diff --git a/README.md b/README.md index 4b609cb..cd42dfe 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -jwtauth - JWT authentication middleware for Go HTTP services -============================================================ -[![GoDoc Widget]][GoDoc] +# jwtauth - JWT authentication middleware for Go HTTP services + +[![GoDoc Widget]][godoc] The `jwtauth` http middleware package provides a simple way to verify a JWT token from a http request and send the result down the request context (`context.Context`). @@ -13,11 +13,11 @@ This package uses the new `context` package in Go 1.7 stdlib and [net/http#Reque In a complete JWT-authentication flow, you'll first capture the token from a http request, decode it, verify it and then validate that its correctly signed and hasn't -expired - the `jwtauth.Verifier` middleware handler takes care of all of that. The +expired - the `jwtauth.Verifier` middleware handler takes care of all of that. The `jwtauth.Verifier` will set the context values on keys `jwtauth.TokenCtxKey` and `jwtauth.ErrorCtxKey`. -Next, it's up to an authentication handler to respond or continue processing after the +Next, it's up to an authentication handler to respond or continue processing after the `jwtauth.Verifier`. The `jwtauth.Authenticator` middleware responds with a 401 Unauthorized plain-text payload for all unverified tokens and passes the good ones through. You can also copy the Authenticator and customize it to handle invalid tokens to better fit @@ -25,12 +25,12 @@ your flow (ie. with a JSON error response body). By default, the `Verifier` will search for a JWT token in a http request, in the order: -1. 'jwt' URI query parameter -2. 'Authorization: BEARER T' request header -3. 'jwt' Cookie value +1. 'jwt' URI query parameter +2. 'Authorization: BEARER T' request header +3. 'jwt' Cookie value The first JWT string that is found as a query parameter, authorization header -or cookie header is then decoded by the `jwt-go` library and a *jwt.Token +or cookie header is then decoded by the `jwt-go` library and a \*jwt.Token object is set on the request context. In the case of a signature decoding error the Verifier will also set the error on the request context. @@ -43,7 +43,6 @@ Note: jwtauth supports custom verification sequences for finding a token from a request by using the `Verify` middleware instantiator directly. The default `Verifier` is instantiated by calling `Verify(ja, TokenFromQuery, TokenFromHeader, TokenFromCookie)`. - # Usage See the full [example](https://github.com/go-chi/jwtauth/blob/master/_example/main.go). @@ -66,7 +65,7 @@ func init() { // For debugging/example purposes, we generate and print // a sample jwt token with claims `user_id:123` here: - _, tokenString, _ := tokenAuth.Encode(jwtauth.Claims{"user_id": 123}) + _, tokenString, _ := tokenAuth.Encode(jwt.MapClaims{"user_id": 123}) fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString) } @@ -111,5 +110,5 @@ func router() http.Handler { [MIT](/LICENSE) -[GoDoc]: https://godoc.org/github.com/go-chi/jwtauth -[GoDoc Widget]: https://godoc.org/github.com/go-chi/jwtauth?status.svg +[godoc]: https://godoc.org/github.com/go-chi/jwtauth +[godoc widget]: https://godoc.org/github.com/go-chi/jwtauth?status.svg diff --git a/_example/main.go b/_example/main.go index 5e94d6b..8210926 100644 --- a/_example/main.go +++ b/_example/main.go @@ -62,6 +62,7 @@ import ( "fmt" "net/http" + jwt "github.com/dgrijalva/jwt-go" "github.com/go-chi/chi" "github.com/go-chi/jwtauth" ) @@ -73,7 +74,7 @@ func init() { // For debugging/example purposes, we generate and print // a sample jwt token with claims `user_id:123` here: - _, tokenString, _ := tokenAuth.Encode(jwtauth.Claims{"user_id": 123}) + _, tokenString, _ := tokenAuth.Encode(jwt.MapClaims{"user_id": 123}) fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString) } diff --git a/jwtauth.go b/jwtauth.go index 0e8df98..4f1130d 100644 --- a/jwtauth.go +++ b/jwtauth.go @@ -2,7 +2,6 @@ package jwtauth import ( "context" - "encoding/json" "errors" "fmt" "net/http" @@ -12,15 +11,20 @@ import ( "github.com/dgrijalva/jwt-go" ) +// Context keys var ( TokenCtxKey = &contextKey{"Token"} ErrorCtxKey = &contextKey{"Error"} ) +// Library errors var ( ErrUnauthorized = errors.New("jwtauth: token is unauthorized") ErrExpired = errors.New("jwtauth: token is expired") + ErrNBFInvalid = errors.New("jwtauth: token nbf validation failed") + ErrIATInvalid = errors.New("jwtauth: token iat validation failed") ErrNoTokenFound = errors.New("jwtauth: no token found") + ErrAlgoInvalid = errors.New("jwtauth: algorithm mismatch") ) type JWTAuth struct { @@ -38,12 +42,7 @@ func New(alg string, signKey interface{}, verifyKey interface{}) *JWTAuth { // NewWithParser is the same as New, except it supports custom parser settings // introduced in jwt-go/v2.4.0. -// -// We explicitly toggle `SkipClaimsValidation` in the `jwt-go` parser so that -// we can control when the claims are validated - in our case, by the Verifier -// http middleware handler. func NewWithParser(alg string, parser *jwt.Parser, signKey interface{}, verifyKey interface{}) *JWTAuth { - parser.SkipClaimsValidation = true return &JWTAuth{ signKey: signKey, verifyKey: verifyKey, @@ -103,34 +102,36 @@ func VerifyRequest(ja *JWTAuth, r *http.Request, findTokenFns ...func(r *http.Re return nil, ErrNoTokenFound } - // TODO: what other kinds of validations should we do / error messages? - // Verify the token token, err := ja.Decode(tokenStr) if err != nil { - switch err.Error() { - case "token is expired": - err = ErrExpired + if verr, ok := err.(*jwt.ValidationError); ok { + if verr.Errors&jwt.ValidationErrorExpired > 0 { + return token, ErrExpired + } else if verr.Errors&jwt.ValidationErrorIssuedAt > 0 { + return token, ErrIATInvalid + } else if verr.Errors&jwt.ValidationErrorIssuedAt > 0 { + return token, ErrNBFInvalid + } } return token, err } - if token == nil || !token.Valid || token.Method != ja.signer { + if token == nil || !token.Valid { err = ErrUnauthorized return token, err } - // Check expiry via "exp" claim - if IsExpired(token) { - err = ErrExpired - return token, err + // Verify signing algorithm + if token.Method != ja.signer { + return token, ErrAlgoInvalid } // Valid! return token, nil } -func (ja *JWTAuth) Encode(claims Claims) (t *jwt.Token, tokenString string, err error) { +func (ja *JWTAuth) Encode(claims jwt.MapClaims) (t *jwt.Token, tokenString string, err error) { t = jwt.New(ja.signer) t.Claims = claims tokenString, err = t.SignedString(ja.signKey) @@ -139,10 +140,6 @@ func (ja *JWTAuth) Encode(claims Claims) (t *jwt.Token, tokenString string, err } func (ja *JWTAuth) Decode(tokenString string) (t *jwt.Token, err error) { - // Decode the tokenString, but avoid using custom Claims via jwt-go's - // ParseWithClaims as the jwt-go types will cause some glitches, so easier - // to decode as MapClaims then wrap the underlying map[string]interface{} - // to our Claims type t, err = ja.parser.Parse(tokenString, ja.keyFunc) if err != nil { return nil, err @@ -187,21 +184,18 @@ func NewContext(ctx context.Context, t *jwt.Token, err error) context.Context { return ctx } -func FromContext(ctx context.Context) (*jwt.Token, Claims, error) { +func FromContext(ctx context.Context) (*jwt.Token, jwt.MapClaims, error) { token, _ := ctx.Value(TokenCtxKey).(*jwt.Token) - var claims Claims + var claims jwt.MapClaims if token != nil { - switch tokenClaims := token.Claims.(type) { - case Claims: + if tokenClaims, ok := token.Claims.(jwt.MapClaims); ok { claims = tokenClaims - case jwt.MapClaims: - claims = Claims(tokenClaims) - default: + } else { panic(fmt.Sprintf("jwtauth: unknown type of Claims: %T", token.Claims)) } } else { - claims = Claims{} + claims = jwt.MapClaims{} } err, _ := ctx.Value(ErrorCtxKey).(error) @@ -209,83 +203,17 @@ func FromContext(ctx context.Context) (*jwt.Token, Claims, error) { return token, claims, err } -func IsExpired(t *jwt.Token) bool { - claims, ok := t.Claims.(jwt.MapClaims) - if !ok { - panic("jwtauth: expecting jwt.MapClaims") - } - - if expv, ok := claims["exp"]; ok { - var exp int64 - switch v := expv.(type) { - case float64: - exp = int64(v) - case int64: - exp = v - case json.Number: - exp, _ = v.Int64() - default: - } - - if exp < EpochNow() { - return true - } - } - - return false +// UnixTime returns the given time in UTC milliseconds +func UnixTime(tm time.Time) int64 { + return tm.UTC().Unix() } -// Claims is a convenience type to manage a JWT claims hash. -type Claims map[string]interface{} - -// NOTE: as of v3.0 of jwt-go, Valid() interface method is called to verify -// the claims. However, the current design we test these claims in the -// Verifier middleware, so we skip this step. -func (c Claims) Valid() error { - return nil -} - -func (c Claims) Set(k string, v interface{}) Claims { - c[k] = v - return c -} - -func (c Claims) Get(k string) (interface{}, bool) { - v, ok := c[k] - return v, ok -} - -// Set issued at ("iat") to specified time in the claims -func (c Claims) SetIssuedAt(tm time.Time) Claims { - c["iat"] = tm.UTC().Unix() - return c -} - -// Set issued at ("iat") to present time in the claims -func (c Claims) SetIssuedNow() Claims { - c["iat"] = EpochNow() - return c -} - -// Set expiry ("exp") in the claims and return itself so it can be chained -func (c Claims) SetExpiry(tm time.Time) Claims { - c["exp"] = tm.UTC().Unix() - return c -} - -// Set expiry ("exp") in the claims to some duration from the present time -// and return itself so it can be chained -func (c Claims) SetExpiryIn(tm time.Duration) Claims { - c["exp"] = ExpireIn(tm) - return c -} - -// Helper function that returns the NumericDate time value used by the spec +// EpochNow is a helper function that returns the NumericDate time value used by the spec func EpochNow() int64 { return time.Now().UTC().Unix() } -// Helper function to return calculated time in the future for "exp" claim. +// ExpireIn is a helper function to return calculated time in the future for "exp" claim func ExpireIn(tm time.Duration) int64 { return EpochNow() + int64(tm.Seconds()) } diff --git a/jwtauth_test.go b/jwtauth_test.go index cc1dc19..67a8218 100644 --- a/jwtauth_test.go +++ b/jwtauth_test.go @@ -69,7 +69,7 @@ func TestSimpleRSA(t *testing.T) { TokenAuthRS256 = jwtauth.New("RS256", privateKey, publicKey) - claims := jwtauth.Claims{ + claims := jwt.MapClaims{ "key": "val", "key2": "val2", "key3": "val3", @@ -87,7 +87,7 @@ func TestSimpleRSA(t *testing.T) { t.Fatalf("Failed to decode token string %s\n", err.Error()) } - if !reflect.DeepEqual(claims, jwtauth.Claims(token.Claims.(jwt.MapClaims))) { + if !reflect.DeepEqual(claims, jwt.MapClaims(token.Claims.(jwt.MapClaims))) { t.Fatalf("The decoded claims don't match the original ones\n") } } @@ -220,7 +220,7 @@ func TestMore(t *testing.T) { t.Fatalf(resp) } - h = newAuthHeader((jwtauth.Claims{}).Set("exp", jwtauth.EpochNow()-1000)) + h = newAuthHeader(jwt.MapClaims{"exp": jwtauth.EpochNow() - 1000}) if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "expired\n" { t.Fatalf(resp) } @@ -230,7 +230,7 @@ func TestMore(t *testing.T) { t.Fatalf(resp) } - h = newAuthHeader((jwtauth.Claims{"user_id": 31337}).SetExpiryIn(5 * time.Minute)) + h = newAuthHeader((jwt.MapClaims{"user_id": 31337, "exp": jwtauth.ExpireIn(5 * time.Minute)})) if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 || resp != "protected, user:31337" { t.Fatalf(resp) } @@ -269,7 +269,7 @@ func testRequest(t *testing.T, ts *httptest.Server, method, path string, header return resp.StatusCode, string(respBody) } -func newJwtToken(secret []byte, claims ...jwtauth.Claims) string { +func newJwtToken(secret []byte, claims ...jwt.MapClaims) string { token := jwt.New(jwt.GetSigningMethod("HS256")) if len(claims) > 0 { token.Claims = claims[0] @@ -281,7 +281,7 @@ func newJwtToken(secret []byte, claims ...jwtauth.Claims) string { return tokenStr } -func newJwt512Token(secret []byte, claims ...jwtauth.Claims) string { +func newJwt512Token(secret []byte, claims ...jwt.MapClaims) string { // use-case: when token is signed with a different alg than expected token := jwt.New(jwt.GetSigningMethod("HS512")) if len(claims) > 0 { @@ -294,7 +294,7 @@ func newJwt512Token(secret []byte, claims ...jwtauth.Claims) string { return tokenStr } -func newAuthHeader(claims ...jwtauth.Claims) http.Header { +func newAuthHeader(claims ...jwt.MapClaims) http.Header { h := http.Header{} h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, claims...)) return h