Skip to content

Commit

Permalink
Update token endpoint for better error handling (#364)
Browse files Browse the repository at this point in the history
  • Loading branch information
DTCurrie authored Nov 12, 2024
1 parent 000ac9e commit 842479f
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 42 deletions.
136 changes: 94 additions & 42 deletions web/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ type AuthProvider struct {
stateCookieMaxAge time.Duration
}

const (
// ViamTokenCookie is the cookie name for an authenticated access token
//nolint:gosec
ViamTokenCookie string = "viam.auth.token"
// ViamRefreshCookie is the cookie name for an authenticated refresh token.
ViamRefreshCookie string = "viam.auth.refresh"
// ViamExpiryCookie is the cookie name for an authenticated token's expiry.
ViamExpiryCookie string = "viam.auth.expiry"
)

// Close called by io.Closer.
func (s *AuthProvider) Close() error {
s.httpTransport.CloseIdleConnections()
Expand Down Expand Up @@ -262,7 +272,7 @@ func (h *callbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

http.SetCookie(w, &http.Cookie{
Name: "viam.auth.token",
Name: ViamTokenCookie,
Value: token.AccessToken,
Path: "/",
Expires: token.Expiry,
Expand All @@ -272,7 +282,7 @@ func (h *callbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
})

http.SetCookie(w, &http.Cookie{
Name: "viam.auth.refresh",
Name: ViamRefreshCookie,
Value: token.RefreshToken,
Path: "/",
Expires: token.Expiry,
Expand All @@ -282,7 +292,7 @@ func (h *callbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
})

http.SetCookie(w, &http.Cookie{
Name: "viam.auth.expiry",
Name: ViamExpiryCookie,
Value: token.Expiry.Format(time.RFC3339),
Path: "/",
Expires: token.Expiry,
Expand Down Expand Up @@ -413,79 +423,121 @@ type tokenResponse struct {
Expiry string `json:"expiry"`
}

func (h *tokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()

_, span := trace.StartSpan(ctx, r.URL.Path)
defer span.End()

token, err := r.Cookie("viam.auth.token")
if HandleError(w, err, h.logger, "getting token cookie") {
return
func getBearerToken(req *http.Request) string {
authHeader := req.Header.Get("Authorization")
if authHeader == "" {
return ""
}

refresh, err := r.Cookie("viam.auth.refresh")
if HandleError(w, err, h.logger, "getting refresh cookie") {
return
parts := strings.Split(authHeader, " ")
if len(parts) == 2 && parts[0] == "Bearer" {
return parts[1]
}

expiry, err := r.Cookie("viam.auth.expiry")
if HandleError(w, err, h.logger, "getting expiry cookie") {
return
}
return ""
}

response := &tokenResponse{
AccessToken: token.Value,
RefreshToken: refresh.Value,
Expiry: expiry.Value,
// getAuthCookieValues reads the authentication cookie values as a /token response
// before clearing the cookies.
func getAndClearAuthCookieValues(w http.ResponseWriter, r *http.Request) *tokenResponse {
token, err := r.Cookie(ViamTokenCookie)
if err != nil || token.Value == "" {
return nil
}

w.Header().Set("Content-Type", "application/json")
data, err := json.Marshal(response)
refresh, err := r.Cookie(ViamRefreshCookie)
// TODO: Check if refresh is empty when implemented, always empty now
if err != nil {
temp := errors.New("failed to verify marshal token data: " + err.Error())
w.WriteHeader(http.StatusInternalServerError)
_, err = w.Write([]byte(temp.Error()))
if err != nil {
utils.UncheckedError(err)
}
h.logger.Error(temp)
return
return nil
}

expiry, err := r.Cookie(ViamExpiryCookie)
if err != nil || expiry.Value == "" {
return nil
}

http.SetCookie(w, &http.Cookie{
Name: "viam.auth.token",
Name: ViamTokenCookie,
Value: "",
Path: "/",
MaxAge: -1,
Secure: r.TLS != nil,
SameSite: http.SameSiteLaxMode,
HttpOnly: true,
})

http.SetCookie(w, &http.Cookie{
Name: "viam.auth.refresh",
Name: ViamRefreshCookie,
Value: "",
Path: "/",
MaxAge: -1,
Secure: r.TLS != nil,
SameSite: http.SameSiteLaxMode,
HttpOnly: true,
})

http.SetCookie(w, &http.Cookie{
Name: "viam.auth.expiry",
Name: ViamExpiryCookie,
Value: "",
Path: "/",
MaxAge: -1,
Secure: r.TLS != nil,
SameSite: http.SameSiteLaxMode,
HttpOnly: true,
})

_, err = w.Write(data)
utils.UncheckedError(err)
return &tokenResponse{
AccessToken: token.Value,
RefreshToken: refresh.Value,
Expiry: expiry.Value,
}
}

func (h *tokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()

_, span := trace.StartSpan(ctx, r.URL.Path)
defer span.End()

data := getAndClearAuthCookieValues(w, r)

// handle incoming login request with cookies
if data != nil {
w.Header().Set("Content-Type", "application/json")
response, err := json.Marshal(data)
if err != nil {
temp := errors.New("failed to verify marshal token data: " + err.Error())
w.WriteHeader(http.StatusInternalServerError)
_, err = w.Write([]byte(temp.Error()))
if err != nil {
utils.UncheckedError(err)
}
h.logger.Error(temp)
return
}

_, err = w.Write(response)
utils.UncheckedError(err)
return
}

// user calls with no token in the header, no cookies exist
// - return a bad request 400
current := getBearerToken(r)
if current == "" {
w.WriteHeader(http.StatusBadRequest)
return
}

// user calls with an invalid token in the header, no cookies exist
// - return an unauthenticated error 401
isValid := h.state.sessions.HasSessionWithAccessToken(ctx, current)
if !isValid {
w.WriteHeader(http.StatusUnauthorized)
return
}

// user calls with a valid token in the header, no cookies exist
// - return a no content 204 response
w.WriteHeader(http.StatusNoContent)
}

// --------------------------------
Expand Down
93 changes: 93 additions & 0 deletions web/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package web

import (
"net/http"
"net/http/httptest"
"testing"

"go.viam.com/test"
)

func createRequest(t *testing.T) (http.ResponseWriter, *http.Request) {
w := httptest.NewRecorder()
r, err := http.NewRequest(http.MethodGet, "http://localhost/", nil)
if err != nil {
t.Fatal(err)
return nil, nil
}

return w, r
}

func setCookie(r *http.Request, key, value string) {
r.AddCookie(&http.Cookie{
Name: key,
Value: value,
Path: "/",
MaxAge: 10000,
Secure: true,
SameSite: http.SameSiteLaxMode,
HttpOnly: true,
})
}

func TestWebAuth(t *testing.T) {
t.Run("should return nil when token cookie is not present", func(t *testing.T) {
w, r := createRequest(t)
setCookie(r, ViamRefreshCookie, "")
setCookie(r, ViamExpiryCookie, "123456")

data := getAndClearAuthCookieValues(w, r)
test.That(t, data, test.ShouldBeNil)
})

t.Run("should return nil when token cookie is empty", func(t *testing.T) {
w, r := createRequest(t)
setCookie(r, ViamTokenCookie, "")
setCookie(r, ViamRefreshCookie, "")
setCookie(r, ViamExpiryCookie, "123456")

data := getAndClearAuthCookieValues(w, r)
test.That(t, data, test.ShouldBeNil)
})

t.Run("should return nil when refresh cookies is not present", func(t *testing.T) {
w, r := createRequest(t)
setCookie(r, ViamTokenCookie, "abc123")
setCookie(r, ViamExpiryCookie, "123456")

data := getAndClearAuthCookieValues(w, r)
test.That(t, data, test.ShouldBeNil)
})

t.Run("should return nil when expiry cookie is not present", func(t *testing.T) {
w, r := createRequest(t)
setCookie(r, ViamTokenCookie, "abc123")
setCookie(r, ViamRefreshCookie, "")

data := getAndClearAuthCookieValues(w, r)
test.That(t, data, test.ShouldBeNil)
})

t.Run("should return nil when expiry cookie is empty", func(t *testing.T) {
w, r := createRequest(t)
setCookie(r, ViamTokenCookie, "abc123")
setCookie(r, ViamRefreshCookie, "")
setCookie(r, ViamExpiryCookie, "")

data := getAndClearAuthCookieValues(w, r)
test.That(t, data, test.ShouldBeNil)
})

t.Run("should return token response data when cookies are set and clear the cookies", func(t *testing.T) {
w, r := createRequest(t)
setCookie(r, ViamTokenCookie, "abc123")
setCookie(r, ViamRefreshCookie, "")
setCookie(r, ViamExpiryCookie, "123456")

data := getAndClearAuthCookieValues(w, r)
test.That(t, data.AccessToken, test.ShouldEqual, "abc123")
test.That(t, data.RefreshToken, test.ShouldEqual, "")
test.That(t, data.Expiry, test.ShouldEqual, "123456")
})
}

0 comments on commit 842479f

Please sign in to comment.