Remove unnecessary code and clean up errors (#30)

* Initial refactor

* Fix some of the tests

* Handle more specific errors from jwt library

* Fix comments

* Rename package

* Verify signing algo after nil check
This commit is contained in:
Jonathan ES Lin 2018-09-27 01:43:54 +08:00 committed by Peter Kieltyka
parent 53a0a4877a
commit ea7d7e213f
4 changed files with 49 additions and 121 deletions

View file

@ -1,6 +1,6 @@
jwtauth - JWT authentication middleware for Go HTTP services # jwtauth - JWT authentication middleware for Go HTTP services
============================================================
[![GoDoc Widget]][GoDoc] [![GoDoc Widget]][godoc]
The `jwtauth` http middleware package provides a simple way to verify a JWT token The `jwtauth` http middleware package provides a simple way to verify a JWT token
from a http request and send the result down the request context (`context.Context`). from a http request and send the result down the request context (`context.Context`).
@ -13,11 +13,11 @@ This package uses the new `context` package in Go 1.7 stdlib and [net/http#Reque
In a complete JWT-authentication flow, you'll first capture the token from a http In a complete JWT-authentication flow, you'll first capture the token from a http
request, decode it, verify it and then validate that its correctly signed and hasn't request, decode it, verify it and then validate that its correctly signed and hasn't
expired - the `jwtauth.Verifier` middleware handler takes care of all of that. The expired - the `jwtauth.Verifier` middleware handler takes care of all of that. The
`jwtauth.Verifier` will set the context values on keys `jwtauth.TokenCtxKey` and `jwtauth.Verifier` will set the context values on keys `jwtauth.TokenCtxKey` and
`jwtauth.ErrorCtxKey`. `jwtauth.ErrorCtxKey`.
Next, it's up to an authentication handler to respond or continue processing after the Next, it's up to an authentication handler to respond or continue processing after the
`jwtauth.Verifier`. The `jwtauth.Authenticator` middleware responds with a 401 Unauthorized `jwtauth.Verifier`. The `jwtauth.Authenticator` middleware responds with a 401 Unauthorized
plain-text payload for all unverified tokens and passes the good ones through. You can plain-text payload for all unverified tokens and passes the good ones through. You can
also copy the Authenticator and customize it to handle invalid tokens to better fit also copy the Authenticator and customize it to handle invalid tokens to better fit
@ -25,12 +25,12 @@ your flow (ie. with a JSON error response body).
By default, the `Verifier` will search for a JWT token in a http request, in the order: By default, the `Verifier` will search for a JWT token in a http request, in the order:
1. 'jwt' URI query parameter 1. 'jwt' URI query parameter
2. 'Authorization: BEARER T' request header 2. 'Authorization: BEARER T' request header
3. 'jwt' Cookie value 3. 'jwt' Cookie value
The first JWT string that is found as a query parameter, authorization header The first JWT string that is found as a query parameter, authorization header
or cookie header is then decoded by the `jwt-go` library and a *jwt.Token or cookie header is then decoded by the `jwt-go` library and a \*jwt.Token
object is set on the request context. In the case of a signature decoding error object is set on the request context. In the case of a signature decoding error
the Verifier will also set the error on the request context. the Verifier will also set the error on the request context.
@ -43,7 +43,6 @@ Note: jwtauth supports custom verification sequences for finding a token
from a request by using the `Verify` middleware instantiator directly. The default from a request by using the `Verify` middleware instantiator directly. The default
`Verifier` is instantiated by calling `Verify(ja, TokenFromQuery, TokenFromHeader, TokenFromCookie)`. `Verifier` is instantiated by calling `Verify(ja, TokenFromQuery, TokenFromHeader, TokenFromCookie)`.
# Usage # Usage
See the full [example](https://github.com/go-chi/jwtauth/blob/master/_example/main.go). See the full [example](https://github.com/go-chi/jwtauth/blob/master/_example/main.go).
@ -66,7 +65,7 @@ func init() {
// For debugging/example purposes, we generate and print // For debugging/example purposes, we generate and print
// a sample jwt token with claims `user_id:123` here: // a sample jwt token with claims `user_id:123` here:
_, tokenString, _ := tokenAuth.Encode(jwtauth.Claims{"user_id": 123}) _, tokenString, _ := tokenAuth.Encode(jwt.MapClaims{"user_id": 123})
fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString) fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString)
} }
@ -111,5 +110,5 @@ func router() http.Handler {
[MIT](/LICENSE) [MIT](/LICENSE)
[GoDoc]: https://godoc.org/github.com/go-chi/jwtauth [godoc]: https://godoc.org/github.com/go-chi/jwtauth
[GoDoc Widget]: https://godoc.org/github.com/go-chi/jwtauth?status.svg [godoc widget]: https://godoc.org/github.com/go-chi/jwtauth?status.svg

View file

@ -62,6 +62,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
jwt "github.com/dgrijalva/jwt-go"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/go-chi/jwtauth" "github.com/go-chi/jwtauth"
) )
@ -73,7 +74,7 @@ func init() {
// For debugging/example purposes, we generate and print // For debugging/example purposes, we generate and print
// a sample jwt token with claims `user_id:123` here: // a sample jwt token with claims `user_id:123` here:
_, tokenString, _ := tokenAuth.Encode(jwtauth.Claims{"user_id": 123}) _, tokenString, _ := tokenAuth.Encode(jwt.MapClaims{"user_id": 123})
fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString) fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString)
} }

View file

@ -2,7 +2,6 @@ package jwtauth
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -12,15 +11,20 @@ import (
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
) )
// Context keys
var ( var (
TokenCtxKey = &contextKey{"Token"} TokenCtxKey = &contextKey{"Token"}
ErrorCtxKey = &contextKey{"Error"} ErrorCtxKey = &contextKey{"Error"}
) )
// Library errors
var ( var (
ErrUnauthorized = errors.New("jwtauth: token is unauthorized") ErrUnauthorized = errors.New("jwtauth: token is unauthorized")
ErrExpired = errors.New("jwtauth: token is expired") ErrExpired = errors.New("jwtauth: token is expired")
ErrNBFInvalid = errors.New("jwtauth: token nbf validation failed")
ErrIATInvalid = errors.New("jwtauth: token iat validation failed")
ErrNoTokenFound = errors.New("jwtauth: no token found") ErrNoTokenFound = errors.New("jwtauth: no token found")
ErrAlgoInvalid = errors.New("jwtauth: algorithm mismatch")
) )
type JWTAuth struct { type JWTAuth struct {
@ -38,12 +42,7 @@ func New(alg string, signKey interface{}, verifyKey interface{}) *JWTAuth {
// NewWithParser is the same as New, except it supports custom parser settings // NewWithParser is the same as New, except it supports custom parser settings
// introduced in jwt-go/v2.4.0. // introduced in jwt-go/v2.4.0.
//
// We explicitly toggle `SkipClaimsValidation` in the `jwt-go` parser so that
// we can control when the claims are validated - in our case, by the Verifier
// http middleware handler.
func NewWithParser(alg string, parser *jwt.Parser, signKey interface{}, verifyKey interface{}) *JWTAuth { func NewWithParser(alg string, parser *jwt.Parser, signKey interface{}, verifyKey interface{}) *JWTAuth {
parser.SkipClaimsValidation = true
return &JWTAuth{ return &JWTAuth{
signKey: signKey, signKey: signKey,
verifyKey: verifyKey, verifyKey: verifyKey,
@ -103,34 +102,36 @@ func VerifyRequest(ja *JWTAuth, r *http.Request, findTokenFns ...func(r *http.Re
return nil, ErrNoTokenFound return nil, ErrNoTokenFound
} }
// TODO: what other kinds of validations should we do / error messages?
// Verify the token // Verify the token
token, err := ja.Decode(tokenStr) token, err := ja.Decode(tokenStr)
if err != nil { if err != nil {
switch err.Error() { if verr, ok := err.(*jwt.ValidationError); ok {
case "token is expired": if verr.Errors&jwt.ValidationErrorExpired > 0 {
err = ErrExpired return token, ErrExpired
} else if verr.Errors&jwt.ValidationErrorIssuedAt > 0 {
return token, ErrIATInvalid
} else if verr.Errors&jwt.ValidationErrorIssuedAt > 0 {
return token, ErrNBFInvalid
}
} }
return token, err return token, err
} }
if token == nil || !token.Valid || token.Method != ja.signer { if token == nil || !token.Valid {
err = ErrUnauthorized err = ErrUnauthorized
return token, err return token, err
} }
// Check expiry via "exp" claim // Verify signing algorithm
if IsExpired(token) { if token.Method != ja.signer {
err = ErrExpired return token, ErrAlgoInvalid
return token, err
} }
// Valid! // Valid!
return token, nil return token, nil
} }
func (ja *JWTAuth) Encode(claims Claims) (t *jwt.Token, tokenString string, err error) { func (ja *JWTAuth) Encode(claims jwt.MapClaims) (t *jwt.Token, tokenString string, err error) {
t = jwt.New(ja.signer) t = jwt.New(ja.signer)
t.Claims = claims t.Claims = claims
tokenString, err = t.SignedString(ja.signKey) tokenString, err = t.SignedString(ja.signKey)
@ -139,10 +140,6 @@ func (ja *JWTAuth) Encode(claims Claims) (t *jwt.Token, tokenString string, err
} }
func (ja *JWTAuth) Decode(tokenString string) (t *jwt.Token, err error) { func (ja *JWTAuth) Decode(tokenString string) (t *jwt.Token, err error) {
// Decode the tokenString, but avoid using custom Claims via jwt-go's
// ParseWithClaims as the jwt-go types will cause some glitches, so easier
// to decode as MapClaims then wrap the underlying map[string]interface{}
// to our Claims type
t, err = ja.parser.Parse(tokenString, ja.keyFunc) t, err = ja.parser.Parse(tokenString, ja.keyFunc)
if err != nil { if err != nil {
return nil, err return nil, err
@ -187,21 +184,18 @@ func NewContext(ctx context.Context, t *jwt.Token, err error) context.Context {
return ctx return ctx
} }
func FromContext(ctx context.Context) (*jwt.Token, Claims, error) { func FromContext(ctx context.Context) (*jwt.Token, jwt.MapClaims, error) {
token, _ := ctx.Value(TokenCtxKey).(*jwt.Token) token, _ := ctx.Value(TokenCtxKey).(*jwt.Token)
var claims Claims var claims jwt.MapClaims
if token != nil { if token != nil {
switch tokenClaims := token.Claims.(type) { if tokenClaims, ok := token.Claims.(jwt.MapClaims); ok {
case Claims:
claims = tokenClaims claims = tokenClaims
case jwt.MapClaims: } else {
claims = Claims(tokenClaims)
default:
panic(fmt.Sprintf("jwtauth: unknown type of Claims: %T", token.Claims)) panic(fmt.Sprintf("jwtauth: unknown type of Claims: %T", token.Claims))
} }
} else { } else {
claims = Claims{} claims = jwt.MapClaims{}
} }
err, _ := ctx.Value(ErrorCtxKey).(error) err, _ := ctx.Value(ErrorCtxKey).(error)
@ -209,83 +203,17 @@ func FromContext(ctx context.Context) (*jwt.Token, Claims, error) {
return token, claims, err return token, claims, err
} }
func IsExpired(t *jwt.Token) bool { // UnixTime returns the given time in UTC milliseconds
claims, ok := t.Claims.(jwt.MapClaims) func UnixTime(tm time.Time) int64 {
if !ok { return tm.UTC().Unix()
panic("jwtauth: expecting jwt.MapClaims")
}
if expv, ok := claims["exp"]; ok {
var exp int64
switch v := expv.(type) {
case float64:
exp = int64(v)
case int64:
exp = v
case json.Number:
exp, _ = v.Int64()
default:
}
if exp < EpochNow() {
return true
}
}
return false
} }
// Claims is a convenience type to manage a JWT claims hash. // EpochNow is a helper function that returns the NumericDate time value used by the spec
type Claims map[string]interface{}
// NOTE: as of v3.0 of jwt-go, Valid() interface method is called to verify
// the claims. However, the current design we test these claims in the
// Verifier middleware, so we skip this step.
func (c Claims) Valid() error {
return nil
}
func (c Claims) Set(k string, v interface{}) Claims {
c[k] = v
return c
}
func (c Claims) Get(k string) (interface{}, bool) {
v, ok := c[k]
return v, ok
}
// Set issued at ("iat") to specified time in the claims
func (c Claims) SetIssuedAt(tm time.Time) Claims {
c["iat"] = tm.UTC().Unix()
return c
}
// Set issued at ("iat") to present time in the claims
func (c Claims) SetIssuedNow() Claims {
c["iat"] = EpochNow()
return c
}
// Set expiry ("exp") in the claims and return itself so it can be chained
func (c Claims) SetExpiry(tm time.Time) Claims {
c["exp"] = tm.UTC().Unix()
return c
}
// Set expiry ("exp") in the claims to some duration from the present time
// and return itself so it can be chained
func (c Claims) SetExpiryIn(tm time.Duration) Claims {
c["exp"] = ExpireIn(tm)
return c
}
// Helper function that returns the NumericDate time value used by the spec
func EpochNow() int64 { func EpochNow() int64 {
return time.Now().UTC().Unix() return time.Now().UTC().Unix()
} }
// Helper function to return calculated time in the future for "exp" claim. // ExpireIn is a helper function to return calculated time in the future for "exp" claim
func ExpireIn(tm time.Duration) int64 { func ExpireIn(tm time.Duration) int64 {
return EpochNow() + int64(tm.Seconds()) return EpochNow() + int64(tm.Seconds())
} }

View file

@ -69,7 +69,7 @@ func TestSimpleRSA(t *testing.T) {
TokenAuthRS256 = jwtauth.New("RS256", privateKey, publicKey) TokenAuthRS256 = jwtauth.New("RS256", privateKey, publicKey)
claims := jwtauth.Claims{ claims := jwt.MapClaims{
"key": "val", "key": "val",
"key2": "val2", "key2": "val2",
"key3": "val3", "key3": "val3",
@ -87,7 +87,7 @@ func TestSimpleRSA(t *testing.T) {
t.Fatalf("Failed to decode token string %s\n", err.Error()) t.Fatalf("Failed to decode token string %s\n", err.Error())
} }
if !reflect.DeepEqual(claims, jwtauth.Claims(token.Claims.(jwt.MapClaims))) { if !reflect.DeepEqual(claims, jwt.MapClaims(token.Claims.(jwt.MapClaims))) {
t.Fatalf("The decoded claims don't match the original ones\n") t.Fatalf("The decoded claims don't match the original ones\n")
} }
} }
@ -220,7 +220,7 @@ func TestMore(t *testing.T) {
t.Fatalf(resp) t.Fatalf(resp)
} }
h = newAuthHeader((jwtauth.Claims{}).Set("exp", jwtauth.EpochNow()-1000)) h = newAuthHeader(jwt.MapClaims{"exp": jwtauth.EpochNow() - 1000})
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "expired\n" { if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "expired\n" {
t.Fatalf(resp) t.Fatalf(resp)
} }
@ -230,7 +230,7 @@ func TestMore(t *testing.T) {
t.Fatalf(resp) t.Fatalf(resp)
} }
h = newAuthHeader((jwtauth.Claims{"user_id": 31337}).SetExpiryIn(5 * time.Minute)) h = newAuthHeader((jwt.MapClaims{"user_id": 31337, "exp": jwtauth.ExpireIn(5 * time.Minute)}))
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 || resp != "protected, user:31337" { if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 || resp != "protected, user:31337" {
t.Fatalf(resp) t.Fatalf(resp)
} }
@ -269,7 +269,7 @@ func testRequest(t *testing.T, ts *httptest.Server, method, path string, header
return resp.StatusCode, string(respBody) return resp.StatusCode, string(respBody)
} }
func newJwtToken(secret []byte, claims ...jwtauth.Claims) string { func newJwtToken(secret []byte, claims ...jwt.MapClaims) string {
token := jwt.New(jwt.GetSigningMethod("HS256")) token := jwt.New(jwt.GetSigningMethod("HS256"))
if len(claims) > 0 { if len(claims) > 0 {
token.Claims = claims[0] token.Claims = claims[0]
@ -281,7 +281,7 @@ func newJwtToken(secret []byte, claims ...jwtauth.Claims) string {
return tokenStr return tokenStr
} }
func newJwt512Token(secret []byte, claims ...jwtauth.Claims) string { func newJwt512Token(secret []byte, claims ...jwt.MapClaims) string {
// use-case: when token is signed with a different alg than expected // use-case: when token is signed with a different alg than expected
token := jwt.New(jwt.GetSigningMethod("HS512")) token := jwt.New(jwt.GetSigningMethod("HS512"))
if len(claims) > 0 { if len(claims) > 0 {
@ -294,7 +294,7 @@ func newJwt512Token(secret []byte, claims ...jwtauth.Claims) string {
return tokenStr return tokenStr
} }
func newAuthHeader(claims ...jwtauth.Claims) http.Header { func newAuthHeader(claims ...jwt.MapClaims) http.Header {
h := http.Header{} h := http.Header{}
h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, claims...)) h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, claims...))
return h return h