Switch to github.com/lestrrat-go/jwx underlying jwt library (#52)

This commit is contained in:
Peter Kieltyka 2020-12-12 09:40:27 -05:00 committed by GitHub
parent 02fa0c511c
commit b8af768272
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 200 additions and 148 deletions

32
.github/workflows/ci.yml vendored Normal file
View file

@ -0,0 +1,32 @@
on: [push, pull_request]
name: Test
jobs:
test:
env:
GOPATH: ${{ github.workspace }}
GO111MODULE: off
defaults:
run:
working-directory: ${{ env.GOPATH }}/src/github.com/${{ github.repository }}
strategy:
matrix:
go-version: [1.14.x, 1.15.x]
os: [ubuntu-latest, macos-latest, windows-latest]
runs-on: ${{ matrix.os }}
steps:
- name: Install Go
uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go-version }}
- name: Checkout code
uses: actions/checkout@v2
with:
path: ${{ env.GOPATH }}/src/github.com/${{ github.repository }}
- name: Test
run: |
go get -d -t ./...
go test -v ./...

View file

@ -1,16 +0,0 @@
language: go
go:
- 1.10.x
- 1.11.x
- 1.12.x
- 1.13.x
install:
- go get -u golang.org/x/tools/cmd/goimports
script:
- go get -d -t ./...
- go test ./...
- >
goimports -d -e ./ | grep '.*' && { echo; echo "Aborting due to non-empty goimports output."; exit 1; } || :

View file

@ -7,9 +7,7 @@ from a http request and send the result down the request context (`context.Conte
Please note, `jwtauth` works with any Go http router, but resides under the go-chi group
for maintenance and organization - its only 3rd party dependency is the underlying jwt library
"github.com/dgrijalva/jwt-go".
This package uses the new `context` package in Go 1.7 stdlib and [net/http#Request.Context](https://golang.org/pkg/net/http/#Request.Context) to pass values between handler chains.
"github.com/lestrrat-go/jwx".
In a complete JWT-authentication flow, you'll first capture the token from a http
request, decode it, verify it and then validate that its correctly signed and hasn't
@ -65,7 +63,7 @@ func init() {
// For debugging/example purposes, we generate and print
// a sample jwt token with claims `user_id:123` here:
_, tokenString, _ := tokenAuth.Encode(jwt.MapClaims{"user_id": 123})
_, tokenString, _ := tokenAuth.Encode(map[string]interface{}{"user_id": 123})
fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString)
}

View file

@ -62,7 +62,6 @@ import (
"fmt"
"net/http"
jwt "github.com/dgrijalva/jwt-go"
"github.com/go-chi/chi"
"github.com/go-chi/jwtauth"
)
@ -74,7 +73,7 @@ func init() {
// For debugging/example purposes, we generate and print
// a sample jwt token with claims `user_id:123` here:
_, tokenString, _ := tokenAuth.Encode(jwt.MapClaims{"user_id": 123})
_, tokenString, _ := tokenAuth.Encode(map[string]interface{}{"user_id": 123})
fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString)
}

2
go.mod
View file

@ -3,6 +3,6 @@ module github.com/go-chi/jwtauth
go 1.15
require (
github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/go-chi/chi v1.5.1
github.com/lestrrat-go/jwx v1.0.6-0.20201127121120-26218808f029
)

39
go.sum
View file

@ -1,4 +1,39 @@
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-chi/chi v1.5.1 h1:kfTK3Cxd/dkMu/rKs5ZceWYp+t5CtiE7vmaTv3LjC6w=
github.com/go-chi/chi v1.5.1/go.mod h1:REp24E+25iKvxgeTfHmdUoL5x15kBiDBlnIl5bCwe2k=
github.com/lestrrat-go/iter v0.0.0-20200422075355-fc1769541911 h1:FvnrqecqX4zT0wOIbYK1gNgTm0677INEWiFY8UEYggY=
github.com/lestrrat-go/iter v0.0.0-20200422075355-fc1769541911/go.mod h1:zIdgO1mRKhn8l9vrZJZz9TUMMFbQbLeTsbqPDrJ/OJc=
github.com/lestrrat-go/jwx v1.0.6-0.20201127121120-26218808f029 h1:+HTAqhgKkKqizghOYb4uEpZ7wK8tl3Y48ZbUTHF521c=
github.com/lestrrat-go/jwx v1.0.6-0.20201127121120-26218808f029/go.mod h1:TPF17WiSFegZo+c20fdpw49QD+/7n4/IsGvEmCSWwT0=
github.com/lestrrat-go/pdebug v0.0.0-20200204225717-4d6bd78da58d h1:aEZT3f1GGg5RIlHMAy4/4fe4ciOi3SCwYoaURphcB4k=
github.com/lestrrat-go/pdebug v0.0.0-20200204225717-4d6bd78da58d/go.mod h1:B06CSso/AWxiPejj+fheUINGeBKeeEZNt8w+EoU7+L8=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200417140056-c07e33ef3290/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View file

@ -1,14 +1,15 @@
package jwtauth
import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"strings"
"time"
jwt "github.com/dgrijalva/jwt-go"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
)
// Context keys
@ -19,36 +20,31 @@ var (
// Library errors
var (
ErrUnauthorized = errors.New("jwtauth: token is unauthorized")
ErrExpired = errors.New("jwtauth: token is expired")
ErrNBFInvalid = errors.New("jwtauth: token nbf validation failed")
ErrIATInvalid = errors.New("jwtauth: token iat validation failed")
ErrNoTokenFound = errors.New("jwtauth: no token found")
ErrAlgoInvalid = errors.New("jwtauth: algorithm mismatch")
ErrUnauthorized = errors.New("token is unauthorized")
ErrExpired = errors.New("token is expired")
ErrNBFInvalid = errors.New("token nbf validation failed")
ErrIATInvalid = errors.New("token iat validation failed")
ErrNoTokenFound = errors.New("no token found")
ErrAlgoInvalid = errors.New("algorithm mismatch")
)
type JWTAuth struct {
signKey interface{}
verifyKey interface{}
signer jwt.SigningMethod
parser *jwt.Parser
alg jwa.SignatureAlgorithm
signKey interface{} // private-key
verifyKey interface{} // public-key, only used by RSA and ECDSA algorithms
verifier jwt.ParseOption
}
// New creates a JWTAuth authenticator instance that provides middleware handlers
// and encoding/decoding functions for JWT signing.
func New(alg string, signKey interface{}, verifyKey interface{}) *JWTAuth {
return NewWithParser(alg, &jwt.Parser{}, signKey, verifyKey)
}
ja := &JWTAuth{alg: jwa.SignatureAlgorithm(alg), signKey: signKey, verifyKey: verifyKey}
// NewWithParser is the same as New, except it supports custom parser settings
// introduced in jwt-go/v2.4.0.
func NewWithParser(alg string, parser *jwt.Parser, signKey interface{}, verifyKey interface{}) *JWTAuth {
return &JWTAuth{
signKey: signKey,
verifyKey: verifyKey,
signer: jwt.GetSigningMethod(alg),
parser: parser,
if ja.verifyKey != nil {
ja.verifier = jwt.WithVerify(ja.alg, ja.verifyKey)
} else {
ja.verifier = jwt.WithVerify(ja.alg, ja.signKey)
}
return ja
}
// Verifier http middleware handler will verify a JWT string from a http request.
@ -85,73 +81,81 @@ func Verify(ja *JWTAuth, findTokenFns ...func(r *http.Request) string) func(http
}
}
func VerifyRequest(ja *JWTAuth, r *http.Request, findTokenFns ...func(r *http.Request) string) (*jwt.Token, error) {
var tokenStr string
var err error
func VerifyRequest(ja *JWTAuth, r *http.Request, findTokenFns ...func(r *http.Request) string) (jwt.Token, error) {
var tokenString string
// Extract token string from the request by calling token find functions in
// the order they where provided. Further extraction stops if a function
// returns a non-empty string.
for _, fn := range findTokenFns {
tokenStr = fn(r)
if tokenStr != "" {
tokenString = fn(r)
if tokenString != "" {
break
}
}
if tokenStr == "" {
if tokenString == "" {
return nil, ErrNoTokenFound
}
// Verify the token
token, err := ja.Decode(tokenStr)
return VerifyToken(ja, tokenString)
}
func VerifyToken(ja *JWTAuth, tokenString string) (jwt.Token, error) {
// Decode & verify the token
token, err := ja.Decode(tokenString)
if err != nil {
if verr, ok := err.(*jwt.ValidationError); ok {
if verr.Errors&jwt.ValidationErrorExpired > 0 {
return token, ErrExpired
} else if verr.Errors&jwt.ValidationErrorIssuedAt > 0 {
return token, ErrIATInvalid
} else if verr.Errors&jwt.ValidationErrorNotValidYet > 0 {
return token, ErrNBFInvalid
}
}
return token, err
return token, ErrorReason(err)
}
if token == nil || !token.Valid {
err = ErrUnauthorized
return token, err
if token == nil {
return nil, ErrUnauthorized
}
// Verify signing algorithm
if token.Method != ja.signer {
return token, ErrAlgoInvalid
if err := jwt.Validate(token); err != nil {
return token, ErrorReason(err)
}
// Valid!
return token, nil
}
func (ja *JWTAuth) Encode(claims jwt.Claims) (t *jwt.Token, tokenString string, err error) {
t = jwt.New(ja.signer)
t.Claims = claims
tokenString, err = t.SignedString(ja.signKey)
t.Raw = tokenString
return
}
func (ja *JWTAuth) Decode(tokenString string) (t *jwt.Token, err error) {
t, err = ja.parser.Parse(tokenString, ja.keyFunc)
if err != nil {
return nil, err
func (ja *JWTAuth) Encode(claims map[string]interface{}) (t jwt.Token, tokenString string, err error) {
t = jwt.New()
for k, v := range claims {
t.Set(k, v)
}
payload, err := ja.sign(t)
if err != nil {
return nil, "", err
}
tokenString = string(payload)
return
}
func (ja *JWTAuth) keyFunc(t *jwt.Token) (interface{}, error) {
if ja.verifyKey != nil {
return ja.verifyKey, nil
} else {
return ja.signKey, nil
func (ja *JWTAuth) Decode(tokenString string) (jwt.Token, error) {
return ja.parse([]byte(tokenString))
}
func (ja *JWTAuth) sign(token jwt.Token) ([]byte, error) {
return jwt.Sign(token, ja.alg, ja.signKey)
}
func (ja *JWTAuth) parse(payload []byte) (jwt.Token, error) {
return jwt.Parse(bytes.NewReader(payload), ja.verifier)
}
// ErrorReason will normalize the error message from the underlining
// jwt library
func ErrorReason(err error) error {
switch err.Error() {
case "exp not satisfied", ErrExpired.Error():
return ErrExpired
case "iat not satisfied", ErrIATInvalid.Error():
return ErrIATInvalid
case "nbf not satisfied", ErrNBFInvalid.Error():
return ErrNBFInvalid
default:
return ErrUnauthorized
}
}
@ -164,11 +168,11 @@ func Authenticator(next http.Handler) http.Handler {
token, _, err := FromContext(r.Context())
if err != nil {
http.Error(w, http.StatusText(401), 401)
http.Error(w, err.Error(), 401)
return
}
if token == nil || !token.Valid {
if token == nil || jwt.Validate(token) != nil {
http.Error(w, http.StatusText(401), 401)
return
}
@ -178,27 +182,28 @@ func Authenticator(next http.Handler) http.Handler {
})
}
func NewContext(ctx context.Context, t *jwt.Token, err error) context.Context {
func NewContext(ctx context.Context, t jwt.Token, err error) context.Context {
ctx = context.WithValue(ctx, TokenCtxKey, t)
ctx = context.WithValue(ctx, ErrorCtxKey, err)
return ctx
}
func FromContext(ctx context.Context) (*jwt.Token, jwt.MapClaims, error) {
token, _ := ctx.Value(TokenCtxKey).(*jwt.Token)
func FromContext(ctx context.Context) (jwt.Token, map[string]interface{}, error) {
token, _ := ctx.Value(TokenCtxKey).(jwt.Token)
var err error
var claims map[string]interface{}
var claims jwt.MapClaims
if token != nil {
if tokenClaims, ok := token.Claims.(jwt.MapClaims); ok {
claims = tokenClaims
} else {
panic(fmt.Sprintf("jwtauth: unknown type of Claims: %T", token.Claims))
claims, err = token.AsMap(context.Background())
if err != nil {
return token, nil, err
}
} else {
claims = jwt.MapClaims{}
claims = map[string]interface{}{}
}
err, _ := ctx.Value(ErrorCtxKey).(error)
err, _ = ctx.Value(ErrorCtxKey).(error)
return token, claims, err
}
@ -219,22 +224,22 @@ func ExpireIn(tm time.Duration) int64 {
}
// Set issued at ("iat") to specified time in the claims
func SetIssuedAt(claims jwt.MapClaims, tm time.Time) {
func SetIssuedAt(claims map[string]interface{}, tm time.Time) {
claims["iat"] = tm.UTC().Unix()
}
// Set issued at ("iat") to present time in the claims
func SetIssuedNow(claims jwt.MapClaims) {
func SetIssuedNow(claims map[string]interface{}) {
claims["iat"] = EpochNow()
}
// Set expiry ("exp") in the claims
func SetExpiry(claims jwt.MapClaims, tm time.Time) {
func SetExpiry(claims map[string]interface{}, tm time.Time) {
claims["exp"] = tm.UTC().Unix()
}
// Set expiry ("exp") in the claims to some duration from the present time
func SetExpiryIn(claims jwt.MapClaims, tm time.Duration) {
func SetExpiryIn(claims map[string]interface{}, tm time.Duration) {
claims["exp"] = ExpireIn(tm)
}

View file

@ -1,6 +1,7 @@
package jwtauth_test
import (
"context"
"crypto/x509"
"encoding/pem"
"fmt"
@ -13,9 +14,9 @@ import (
"testing"
"time"
jwt "github.com/dgrijalva/jwt-go"
"github.com/go-chi/chi"
"github.com/go-chi/jwtauth"
"github.com/lestrrat-go/jwx/jwt"
)
var (
@ -69,7 +70,7 @@ func TestSimpleRSA(t *testing.T) {
TokenAuthRS256 = jwtauth.New("RS256", privateKey, publicKey)
claims := jwt.MapClaims{
claims := map[string]interface{}{
"key": "val",
"key2": "val2",
"key3": "val3",
@ -87,7 +88,12 @@ func TestSimpleRSA(t *testing.T) {
t.Fatalf("Failed to decode token string %s\n", err.Error())
}
if !reflect.DeepEqual(claims, jwt.MapClaims(token.Claims.(jwt.MapClaims))) {
tokenClaims, err := token.AsMap(context.Background())
if err != nil {
t.Fatal(err.Error())
}
if !reflect.DeepEqual(claims, tokenClaims) {
t.Fatalf("The decoded claims don't match the original ones\n")
}
}
@ -105,27 +111,27 @@ func TestSimple(t *testing.T) {
defer ts.Close()
// sending unauthorized requests
if status, resp := testRequest(t, ts, "GET", "/", nil, nil); status != 401 || resp != "Unauthorized\n" {
if status, resp := testRequest(t, ts, "GET", "/", nil, nil); status != 401 || resp != "no token found\n" {
t.Fatalf(resp)
}
h := http.Header{}
h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{}))
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "Unauthorized\n" {
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" {
t.Fatalf(resp)
}
h.Set("Authorization", "BEARER asdf")
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "Unauthorized\n" {
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" {
t.Fatalf(resp)
}
// wrong token secret and wrong alg
h.Set("Authorization", "BEARER "+newJwt512Token([]byte("wrong"), map[string]interface{}{}))
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "Unauthorized\n" {
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" {
t.Fatalf(resp)
}
// correct token secret but wrong alg
h.Set("Authorization", "BEARER "+newJwt512Token(TokenSecret, map[string]interface{}{}))
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "Unauthorized\n" {
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" {
t.Fatalf(resp)
}
@ -147,23 +153,12 @@ func TestMore(t *testing.T) {
token, _, err := jwtauth.FromContext(r.Context())
if err != nil {
switch err {
default:
http.Error(w, http.StatusText(401), 401)
return
case jwtauth.ErrExpired:
http.Error(w, "expired", 401)
return
case jwtauth.ErrUnauthorized:
http.Error(w, http.StatusText(401), 401)
return
case nil:
// no error
}
http.Error(w, jwtauth.ErrorReason(err).Error(), 401)
return
}
if token == nil || !token.Valid {
http.Error(w, http.StatusText(401), 401)
if err := jwt.Validate(token); err != nil {
http.Error(w, jwtauth.ErrorReason(err).Error(), 401)
return
}
@ -196,32 +191,32 @@ func TestMore(t *testing.T) {
defer ts.Close()
// sending unauthorized requests
if status, resp := testRequest(t, ts, "GET", "/admin", nil, nil); status != 401 || resp != "Unauthorized\n" {
if status, resp := testRequest(t, ts, "GET", "/admin", nil, nil); status != 401 || resp != "token is unauthorized\n" {
t.Fatalf(resp)
}
h := http.Header{}
h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{}))
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" {
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" {
t.Fatalf(resp)
}
h.Set("Authorization", "BEARER asdf")
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" {
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" {
t.Fatalf(resp)
}
// wrong token secret and wrong alg
h.Set("Authorization", "BEARER "+newJwt512Token([]byte("wrong"), map[string]interface{}{}))
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" {
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" {
t.Fatalf(resp)
}
// correct token secret but wrong alg
h.Set("Authorization", "BEARER "+newJwt512Token(TokenSecret, map[string]interface{}{}))
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" {
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" {
t.Fatalf(resp)
}
h = newAuthHeader(jwt.MapClaims{"exp": jwtauth.EpochNow() - 1000})
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "expired\n" {
h = newAuthHeader(map[string]interface{}{"exp": jwtauth.EpochNow() - 1000})
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is expired\n" {
t.Fatalf(resp)
}
@ -230,7 +225,7 @@ func TestMore(t *testing.T) {
t.Fatalf(resp)
}
h = newAuthHeader((jwt.MapClaims{"user_id": 31337, "exp": jwtauth.ExpireIn(5 * time.Minute)}))
h = newAuthHeader((map[string]interface{}{"user_id": 31337, "exp": jwtauth.ExpireIn(5 * time.Minute)}))
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 || resp != "protected, user:31337" {
t.Fatalf(resp)
}
@ -269,32 +264,36 @@ func testRequest(t *testing.T, ts *httptest.Server, method, path string, header
return resp.StatusCode, string(respBody)
}
func newJwtToken(secret []byte, claims ...jwt.MapClaims) string {
token := jwt.New(jwt.GetSigningMethod("HS256"))
func newJwtToken(secret []byte, claims ...map[string]interface{}) string {
token := jwt.New()
if len(claims) > 0 {
token.Claims = claims[0]
for k, v := range claims[0] {
token.Set(k, v)
}
}
tokenStr, err := token.SignedString(secret)
tokenPayload, err := jwt.Sign(token, "HS256", secret)
if err != nil {
log.Fatal(err)
}
return tokenStr
return string(tokenPayload)
}
func newJwt512Token(secret []byte, claims ...jwt.MapClaims) string {
func newJwt512Token(secret []byte, claims ...map[string]interface{}) string {
// use-case: when token is signed with a different alg than expected
token := jwt.New(jwt.GetSigningMethod("HS512"))
token := jwt.New()
if len(claims) > 0 {
token.Claims = claims[0]
for k, v := range claims[0] {
token.Set(k, v)
}
}
tokenStr, err := token.SignedString(secret)
tokenPayload, err := jwt.Sign(token, "HS512", secret)
if err != nil {
log.Fatal(err)
}
return tokenStr
return string(tokenPayload)
}
func newAuthHeader(claims ...jwt.MapClaims) http.Header {
func newAuthHeader(claims ...map[string]interface{}) http.Header {
h := http.Header{}
h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, claims...))
return h