diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..d6c1e9f --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,32 @@ +on: [push, pull_request] +name: Test +jobs: + test: + env: + GOPATH: ${{ github.workspace }} + GO111MODULE: off + + defaults: + run: + working-directory: ${{ env.GOPATH }}/src/github.com/${{ github.repository }} + + strategy: + matrix: + go-version: [1.14.x, 1.15.x] + os: [ubuntu-latest, macos-latest, windows-latest] + + runs-on: ${{ matrix.os }} + + steps: + - name: Install Go + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go-version }} + - name: Checkout code + uses: actions/checkout@v2 + with: + path: ${{ env.GOPATH }}/src/github.com/${{ github.repository }} + - name: Test + run: | + go get -d -t ./... + go test -v ./... diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index ec62162..0000000 --- a/.travis.yml +++ /dev/null @@ -1,16 +0,0 @@ -language: go - -go: - - 1.10.x - - 1.11.x - - 1.12.x - - 1.13.x - -install: - - go get -u golang.org/x/tools/cmd/goimports - -script: - - go get -d -t ./... - - go test ./... - - > - goimports -d -e ./ | grep '.*' && { echo; echo "Aborting due to non-empty goimports output."; exit 1; } || : diff --git a/README.md b/README.md index cd42dfe..b69b438 100644 --- a/README.md +++ b/README.md @@ -7,9 +7,7 @@ from a http request and send the result down the request context (`context.Conte Please note, `jwtauth` works with any Go http router, but resides under the go-chi group for maintenance and organization - its only 3rd party dependency is the underlying jwt library -"github.com/dgrijalva/jwt-go". - -This package uses the new `context` package in Go 1.7 stdlib and [net/http#Request.Context](https://golang.org/pkg/net/http/#Request.Context) to pass values between handler chains. +"github.com/lestrrat-go/jwx". In a complete JWT-authentication flow, you'll first capture the token from a http request, decode it, verify it and then validate that its correctly signed and hasn't @@ -65,7 +63,7 @@ func init() { // For debugging/example purposes, we generate and print // a sample jwt token with claims `user_id:123` here: - _, tokenString, _ := tokenAuth.Encode(jwt.MapClaims{"user_id": 123}) + _, tokenString, _ := tokenAuth.Encode(map[string]interface{}{"user_id": 123}) fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString) } diff --git a/_example/main.go b/_example/main.go index 8210926..4c18cf1 100644 --- a/_example/main.go +++ b/_example/main.go @@ -62,7 +62,6 @@ import ( "fmt" "net/http" - jwt "github.com/dgrijalva/jwt-go" "github.com/go-chi/chi" "github.com/go-chi/jwtauth" ) @@ -74,7 +73,7 @@ func init() { // For debugging/example purposes, we generate and print // a sample jwt token with claims `user_id:123` here: - _, tokenString, _ := tokenAuth.Encode(jwt.MapClaims{"user_id": 123}) + _, tokenString, _ := tokenAuth.Encode(map[string]interface{}{"user_id": 123}) fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString) } diff --git a/go.mod b/go.mod index 0e75fd1..7e54503 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,6 @@ module github.com/go-chi/jwtauth go 1.15 require ( - github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/go-chi/chi v1.5.1 + github.com/lestrrat-go/jwx v1.0.6-0.20201127121120-26218808f029 ) diff --git a/go.sum b/go.sum index f45f929..0b170e8 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,39 @@ -github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= -github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-chi/chi v1.5.1 h1:kfTK3Cxd/dkMu/rKs5ZceWYp+t5CtiE7vmaTv3LjC6w= github.com/go-chi/chi v1.5.1/go.mod h1:REp24E+25iKvxgeTfHmdUoL5x15kBiDBlnIl5bCwe2k= +github.com/lestrrat-go/iter v0.0.0-20200422075355-fc1769541911 h1:FvnrqecqX4zT0wOIbYK1gNgTm0677INEWiFY8UEYggY= +github.com/lestrrat-go/iter v0.0.0-20200422075355-fc1769541911/go.mod h1:zIdgO1mRKhn8l9vrZJZz9TUMMFbQbLeTsbqPDrJ/OJc= +github.com/lestrrat-go/jwx v1.0.6-0.20201127121120-26218808f029 h1:+HTAqhgKkKqizghOYb4uEpZ7wK8tl3Y48ZbUTHF521c= +github.com/lestrrat-go/jwx v1.0.6-0.20201127121120-26218808f029/go.mod h1:TPF17WiSFegZo+c20fdpw49QD+/7n4/IsGvEmCSWwT0= +github.com/lestrrat-go/pdebug v0.0.0-20200204225717-4d6bd78da58d h1:aEZT3f1GGg5RIlHMAy4/4fe4ciOi3SCwYoaURphcB4k= +github.com/lestrrat-go/pdebug v0.0.0-20200204225717-4d6bd78da58d/go.mod h1:B06CSso/AWxiPejj+fheUINGeBKeeEZNt8w+EoU7+L8= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200417140056-c07e33ef3290/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/jwtauth.go b/jwtauth.go index 0e7641e..dc6be6e 100644 --- a/jwtauth.go +++ b/jwtauth.go @@ -1,14 +1,15 @@ package jwtauth import ( + "bytes" "context" "errors" - "fmt" "net/http" "strings" "time" - jwt "github.com/dgrijalva/jwt-go" + "github.com/lestrrat-go/jwx/jwa" + "github.com/lestrrat-go/jwx/jwt" ) // Context keys @@ -19,36 +20,31 @@ var ( // Library errors var ( - ErrUnauthorized = errors.New("jwtauth: token is unauthorized") - ErrExpired = errors.New("jwtauth: token is expired") - ErrNBFInvalid = errors.New("jwtauth: token nbf validation failed") - ErrIATInvalid = errors.New("jwtauth: token iat validation failed") - ErrNoTokenFound = errors.New("jwtauth: no token found") - ErrAlgoInvalid = errors.New("jwtauth: algorithm mismatch") + ErrUnauthorized = errors.New("token is unauthorized") + ErrExpired = errors.New("token is expired") + ErrNBFInvalid = errors.New("token nbf validation failed") + ErrIATInvalid = errors.New("token iat validation failed") + ErrNoTokenFound = errors.New("no token found") + ErrAlgoInvalid = errors.New("algorithm mismatch") ) type JWTAuth struct { - signKey interface{} - verifyKey interface{} - signer jwt.SigningMethod - parser *jwt.Parser + alg jwa.SignatureAlgorithm + signKey interface{} // private-key + verifyKey interface{} // public-key, only used by RSA and ECDSA algorithms + verifier jwt.ParseOption } -// New creates a JWTAuth authenticator instance that provides middleware handlers -// and encoding/decoding functions for JWT signing. func New(alg string, signKey interface{}, verifyKey interface{}) *JWTAuth { - return NewWithParser(alg, &jwt.Parser{}, signKey, verifyKey) -} + ja := &JWTAuth{alg: jwa.SignatureAlgorithm(alg), signKey: signKey, verifyKey: verifyKey} -// NewWithParser is the same as New, except it supports custom parser settings -// introduced in jwt-go/v2.4.0. -func NewWithParser(alg string, parser *jwt.Parser, signKey interface{}, verifyKey interface{}) *JWTAuth { - return &JWTAuth{ - signKey: signKey, - verifyKey: verifyKey, - signer: jwt.GetSigningMethod(alg), - parser: parser, + if ja.verifyKey != nil { + ja.verifier = jwt.WithVerify(ja.alg, ja.verifyKey) + } else { + ja.verifier = jwt.WithVerify(ja.alg, ja.signKey) } + + return ja } // Verifier http middleware handler will verify a JWT string from a http request. @@ -85,73 +81,81 @@ func Verify(ja *JWTAuth, findTokenFns ...func(r *http.Request) string) func(http } } -func VerifyRequest(ja *JWTAuth, r *http.Request, findTokenFns ...func(r *http.Request) string) (*jwt.Token, error) { - var tokenStr string - var err error +func VerifyRequest(ja *JWTAuth, r *http.Request, findTokenFns ...func(r *http.Request) string) (jwt.Token, error) { + var tokenString string // Extract token string from the request by calling token find functions in // the order they where provided. Further extraction stops if a function // returns a non-empty string. for _, fn := range findTokenFns { - tokenStr = fn(r) - if tokenStr != "" { + tokenString = fn(r) + if tokenString != "" { break } } - if tokenStr == "" { + if tokenString == "" { return nil, ErrNoTokenFound } - // Verify the token - token, err := ja.Decode(tokenStr) + return VerifyToken(ja, tokenString) +} + +func VerifyToken(ja *JWTAuth, tokenString string) (jwt.Token, error) { + // Decode & verify the token + token, err := ja.Decode(tokenString) if err != nil { - if verr, ok := err.(*jwt.ValidationError); ok { - if verr.Errors&jwt.ValidationErrorExpired > 0 { - return token, ErrExpired - } else if verr.Errors&jwt.ValidationErrorIssuedAt > 0 { - return token, ErrIATInvalid - } else if verr.Errors&jwt.ValidationErrorNotValidYet > 0 { - return token, ErrNBFInvalid - } - } - return token, err + return token, ErrorReason(err) } - if token == nil || !token.Valid { - err = ErrUnauthorized - return token, err + if token == nil { + return nil, ErrUnauthorized } - // Verify signing algorithm - if token.Method != ja.signer { - return token, ErrAlgoInvalid + if err := jwt.Validate(token); err != nil { + return token, ErrorReason(err) } // Valid! return token, nil } -func (ja *JWTAuth) Encode(claims jwt.Claims) (t *jwt.Token, tokenString string, err error) { - t = jwt.New(ja.signer) - t.Claims = claims - tokenString, err = t.SignedString(ja.signKey) - t.Raw = tokenString - return -} - -func (ja *JWTAuth) Decode(tokenString string) (t *jwt.Token, err error) { - t, err = ja.parser.Parse(tokenString, ja.keyFunc) - if err != nil { - return nil, err +func (ja *JWTAuth) Encode(claims map[string]interface{}) (t jwt.Token, tokenString string, err error) { + t = jwt.New() + for k, v := range claims { + t.Set(k, v) } + payload, err := ja.sign(t) + if err != nil { + return nil, "", err + } + tokenString = string(payload) return } -func (ja *JWTAuth) keyFunc(t *jwt.Token) (interface{}, error) { - if ja.verifyKey != nil { - return ja.verifyKey, nil - } else { - return ja.signKey, nil +func (ja *JWTAuth) Decode(tokenString string) (jwt.Token, error) { + return ja.parse([]byte(tokenString)) +} + +func (ja *JWTAuth) sign(token jwt.Token) ([]byte, error) { + return jwt.Sign(token, ja.alg, ja.signKey) +} + +func (ja *JWTAuth) parse(payload []byte) (jwt.Token, error) { + return jwt.Parse(bytes.NewReader(payload), ja.verifier) +} + +// ErrorReason will normalize the error message from the underlining +// jwt library +func ErrorReason(err error) error { + switch err.Error() { + case "exp not satisfied", ErrExpired.Error(): + return ErrExpired + case "iat not satisfied", ErrIATInvalid.Error(): + return ErrIATInvalid + case "nbf not satisfied", ErrNBFInvalid.Error(): + return ErrNBFInvalid + default: + return ErrUnauthorized } } @@ -164,11 +168,11 @@ func Authenticator(next http.Handler) http.Handler { token, _, err := FromContext(r.Context()) if err != nil { - http.Error(w, http.StatusText(401), 401) + http.Error(w, err.Error(), 401) return } - if token == nil || !token.Valid { + if token == nil || jwt.Validate(token) != nil { http.Error(w, http.StatusText(401), 401) return } @@ -178,27 +182,28 @@ func Authenticator(next http.Handler) http.Handler { }) } -func NewContext(ctx context.Context, t *jwt.Token, err error) context.Context { +func NewContext(ctx context.Context, t jwt.Token, err error) context.Context { ctx = context.WithValue(ctx, TokenCtxKey, t) ctx = context.WithValue(ctx, ErrorCtxKey, err) return ctx } -func FromContext(ctx context.Context) (*jwt.Token, jwt.MapClaims, error) { - token, _ := ctx.Value(TokenCtxKey).(*jwt.Token) +func FromContext(ctx context.Context) (jwt.Token, map[string]interface{}, error) { + token, _ := ctx.Value(TokenCtxKey).(jwt.Token) + + var err error + var claims map[string]interface{} - var claims jwt.MapClaims if token != nil { - if tokenClaims, ok := token.Claims.(jwt.MapClaims); ok { - claims = tokenClaims - } else { - panic(fmt.Sprintf("jwtauth: unknown type of Claims: %T", token.Claims)) + claims, err = token.AsMap(context.Background()) + if err != nil { + return token, nil, err } } else { - claims = jwt.MapClaims{} + claims = map[string]interface{}{} } - err, _ := ctx.Value(ErrorCtxKey).(error) + err, _ = ctx.Value(ErrorCtxKey).(error) return token, claims, err } @@ -219,22 +224,22 @@ func ExpireIn(tm time.Duration) int64 { } // Set issued at ("iat") to specified time in the claims -func SetIssuedAt(claims jwt.MapClaims, tm time.Time) { +func SetIssuedAt(claims map[string]interface{}, tm time.Time) { claims["iat"] = tm.UTC().Unix() } // Set issued at ("iat") to present time in the claims -func SetIssuedNow(claims jwt.MapClaims) { +func SetIssuedNow(claims map[string]interface{}) { claims["iat"] = EpochNow() } // Set expiry ("exp") in the claims -func SetExpiry(claims jwt.MapClaims, tm time.Time) { +func SetExpiry(claims map[string]interface{}, tm time.Time) { claims["exp"] = tm.UTC().Unix() } // Set expiry ("exp") in the claims to some duration from the present time -func SetExpiryIn(claims jwt.MapClaims, tm time.Duration) { +func SetExpiryIn(claims map[string]interface{}, tm time.Duration) { claims["exp"] = ExpireIn(tm) } diff --git a/jwtauth_test.go b/jwtauth_test.go index 709aac5..38583b7 100644 --- a/jwtauth_test.go +++ b/jwtauth_test.go @@ -1,6 +1,7 @@ package jwtauth_test import ( + "context" "crypto/x509" "encoding/pem" "fmt" @@ -13,9 +14,9 @@ import ( "testing" "time" - jwt "github.com/dgrijalva/jwt-go" "github.com/go-chi/chi" "github.com/go-chi/jwtauth" + "github.com/lestrrat-go/jwx/jwt" ) var ( @@ -69,7 +70,7 @@ func TestSimpleRSA(t *testing.T) { TokenAuthRS256 = jwtauth.New("RS256", privateKey, publicKey) - claims := jwt.MapClaims{ + claims := map[string]interface{}{ "key": "val", "key2": "val2", "key3": "val3", @@ -87,7 +88,12 @@ func TestSimpleRSA(t *testing.T) { t.Fatalf("Failed to decode token string %s\n", err.Error()) } - if !reflect.DeepEqual(claims, jwt.MapClaims(token.Claims.(jwt.MapClaims))) { + tokenClaims, err := token.AsMap(context.Background()) + if err != nil { + t.Fatal(err.Error()) + } + + if !reflect.DeepEqual(claims, tokenClaims) { t.Fatalf("The decoded claims don't match the original ones\n") } } @@ -105,27 +111,27 @@ func TestSimple(t *testing.T) { defer ts.Close() // sending unauthorized requests - if status, resp := testRequest(t, ts, "GET", "/", nil, nil); status != 401 || resp != "Unauthorized\n" { + if status, resp := testRequest(t, ts, "GET", "/", nil, nil); status != 401 || resp != "no token found\n" { t.Fatalf(resp) } h := http.Header{} h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{})) - if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "Unauthorized\n" { + if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" { t.Fatalf(resp) } h.Set("Authorization", "BEARER asdf") - if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "Unauthorized\n" { + if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" { t.Fatalf(resp) } // wrong token secret and wrong alg h.Set("Authorization", "BEARER "+newJwt512Token([]byte("wrong"), map[string]interface{}{})) - if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "Unauthorized\n" { + if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" { t.Fatalf(resp) } // correct token secret but wrong alg h.Set("Authorization", "BEARER "+newJwt512Token(TokenSecret, map[string]interface{}{})) - if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "Unauthorized\n" { + if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" { t.Fatalf(resp) } @@ -147,23 +153,12 @@ func TestMore(t *testing.T) { token, _, err := jwtauth.FromContext(r.Context()) if err != nil { - switch err { - default: - http.Error(w, http.StatusText(401), 401) - return - case jwtauth.ErrExpired: - http.Error(w, "expired", 401) - return - case jwtauth.ErrUnauthorized: - http.Error(w, http.StatusText(401), 401) - return - case nil: - // no error - } + http.Error(w, jwtauth.ErrorReason(err).Error(), 401) + return } - if token == nil || !token.Valid { - http.Error(w, http.StatusText(401), 401) + if err := jwt.Validate(token); err != nil { + http.Error(w, jwtauth.ErrorReason(err).Error(), 401) return } @@ -196,32 +191,32 @@ func TestMore(t *testing.T) { defer ts.Close() // sending unauthorized requests - if status, resp := testRequest(t, ts, "GET", "/admin", nil, nil); status != 401 || resp != "Unauthorized\n" { + if status, resp := testRequest(t, ts, "GET", "/admin", nil, nil); status != 401 || resp != "token is unauthorized\n" { t.Fatalf(resp) } h := http.Header{} h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{})) - if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" { + if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" { t.Fatalf(resp) } h.Set("Authorization", "BEARER asdf") - if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" { + if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" { t.Fatalf(resp) } // wrong token secret and wrong alg h.Set("Authorization", "BEARER "+newJwt512Token([]byte("wrong"), map[string]interface{}{})) - if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" { + if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" { t.Fatalf(resp) } // correct token secret but wrong alg h.Set("Authorization", "BEARER "+newJwt512Token(TokenSecret, map[string]interface{}{})) - if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" { + if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" { t.Fatalf(resp) } - h = newAuthHeader(jwt.MapClaims{"exp": jwtauth.EpochNow() - 1000}) - if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "expired\n" { + h = newAuthHeader(map[string]interface{}{"exp": jwtauth.EpochNow() - 1000}) + if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is expired\n" { t.Fatalf(resp) } @@ -230,7 +225,7 @@ func TestMore(t *testing.T) { t.Fatalf(resp) } - h = newAuthHeader((jwt.MapClaims{"user_id": 31337, "exp": jwtauth.ExpireIn(5 * time.Minute)})) + h = newAuthHeader((map[string]interface{}{"user_id": 31337, "exp": jwtauth.ExpireIn(5 * time.Minute)})) if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 || resp != "protected, user:31337" { t.Fatalf(resp) } @@ -269,32 +264,36 @@ func testRequest(t *testing.T, ts *httptest.Server, method, path string, header return resp.StatusCode, string(respBody) } -func newJwtToken(secret []byte, claims ...jwt.MapClaims) string { - token := jwt.New(jwt.GetSigningMethod("HS256")) +func newJwtToken(secret []byte, claims ...map[string]interface{}) string { + token := jwt.New() if len(claims) > 0 { - token.Claims = claims[0] + for k, v := range claims[0] { + token.Set(k, v) + } } - tokenStr, err := token.SignedString(secret) + tokenPayload, err := jwt.Sign(token, "HS256", secret) if err != nil { log.Fatal(err) } - return tokenStr + return string(tokenPayload) } -func newJwt512Token(secret []byte, claims ...jwt.MapClaims) string { +func newJwt512Token(secret []byte, claims ...map[string]interface{}) string { // use-case: when token is signed with a different alg than expected - token := jwt.New(jwt.GetSigningMethod("HS512")) + token := jwt.New() if len(claims) > 0 { - token.Claims = claims[0] + for k, v := range claims[0] { + token.Set(k, v) + } } - tokenStr, err := token.SignedString(secret) + tokenPayload, err := jwt.Sign(token, "HS512", secret) if err != nil { log.Fatal(err) } - return tokenStr + return string(tokenPayload) } -func newAuthHeader(claims ...jwt.MapClaims) http.Header { +func newAuthHeader(claims ...map[string]interface{}) http.Header { h := http.Header{} h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, claims...)) return h