mirror of
https://forgejo.merr.is/annika/jwtauth.git
synced 2025-12-11 11:16:32 -05:00
Switch to github.com/lestrrat-go/jwx underlying jwt library (#52)
This commit is contained in:
parent
02fa0c511c
commit
b8af768272
8 changed files with 200 additions and 148 deletions
32
.github/workflows/ci.yml
vendored
Normal file
32
.github/workflows/ci.yml
vendored
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
on: [push, pull_request]
|
||||
name: Test
|
||||
jobs:
|
||||
test:
|
||||
env:
|
||||
GOPATH: ${{ github.workspace }}
|
||||
GO111MODULE: off
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ${{ env.GOPATH }}/src/github.com/${{ github.repository }}
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
go-version: [1.14.x, 1.15.x]
|
||||
os: [ubuntu-latest, macos-latest, windows-latest]
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
path: ${{ env.GOPATH }}/src/github.com/${{ github.repository }}
|
||||
- name: Test
|
||||
run: |
|
||||
go get -d -t ./...
|
||||
go test -v ./...
|
||||
16
.travis.yml
16
.travis.yml
|
|
@ -1,16 +0,0 @@
|
|||
language: go
|
||||
|
||||
go:
|
||||
- 1.10.x
|
||||
- 1.11.x
|
||||
- 1.12.x
|
||||
- 1.13.x
|
||||
|
||||
install:
|
||||
- go get -u golang.org/x/tools/cmd/goimports
|
||||
|
||||
script:
|
||||
- go get -d -t ./...
|
||||
- go test ./...
|
||||
- >
|
||||
goimports -d -e ./ | grep '.*' && { echo; echo "Aborting due to non-empty goimports output."; exit 1; } || :
|
||||
|
|
@ -7,9 +7,7 @@ from a http request and send the result down the request context (`context.Conte
|
|||
|
||||
Please note, `jwtauth` works with any Go http router, but resides under the go-chi group
|
||||
for maintenance and organization - its only 3rd party dependency is the underlying jwt library
|
||||
"github.com/dgrijalva/jwt-go".
|
||||
|
||||
This package uses the new `context` package in Go 1.7 stdlib and [net/http#Request.Context](https://golang.org/pkg/net/http/#Request.Context) to pass values between handler chains.
|
||||
"github.com/lestrrat-go/jwx".
|
||||
|
||||
In a complete JWT-authentication flow, you'll first capture the token from a http
|
||||
request, decode it, verify it and then validate that its correctly signed and hasn't
|
||||
|
|
@ -65,7 +63,7 @@ func init() {
|
|||
|
||||
// For debugging/example purposes, we generate and print
|
||||
// a sample jwt token with claims `user_id:123` here:
|
||||
_, tokenString, _ := tokenAuth.Encode(jwt.MapClaims{"user_id": 123})
|
||||
_, tokenString, _ := tokenAuth.Encode(map[string]interface{}{"user_id": 123})
|
||||
fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -62,7 +62,6 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/jwtauth"
|
||||
)
|
||||
|
|
@ -74,7 +73,7 @@ func init() {
|
|||
|
||||
// For debugging/example purposes, we generate and print
|
||||
// a sample jwt token with claims `user_id:123` here:
|
||||
_, tokenString, _ := tokenAuth.Encode(jwt.MapClaims{"user_id": 123})
|
||||
_, tokenString, _ := tokenAuth.Encode(map[string]interface{}{"user_id": 123})
|
||||
fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString)
|
||||
}
|
||||
|
||||
|
|
|
|||
2
go.mod
2
go.mod
|
|
@ -3,6 +3,6 @@ module github.com/go-chi/jwtauth
|
|||
go 1.15
|
||||
|
||||
require (
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible
|
||||
github.com/go-chi/chi v1.5.1
|
||||
github.com/lestrrat-go/jwx v1.0.6-0.20201127121120-26218808f029
|
||||
)
|
||||
|
|
|
|||
39
go.sum
39
go.sum
|
|
@ -1,4 +1,39 @@
|
|||
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
|
||||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-chi/chi v1.5.1 h1:kfTK3Cxd/dkMu/rKs5ZceWYp+t5CtiE7vmaTv3LjC6w=
|
||||
github.com/go-chi/chi v1.5.1/go.mod h1:REp24E+25iKvxgeTfHmdUoL5x15kBiDBlnIl5bCwe2k=
|
||||
github.com/lestrrat-go/iter v0.0.0-20200422075355-fc1769541911 h1:FvnrqecqX4zT0wOIbYK1gNgTm0677INEWiFY8UEYggY=
|
||||
github.com/lestrrat-go/iter v0.0.0-20200422075355-fc1769541911/go.mod h1:zIdgO1mRKhn8l9vrZJZz9TUMMFbQbLeTsbqPDrJ/OJc=
|
||||
github.com/lestrrat-go/jwx v1.0.6-0.20201127121120-26218808f029 h1:+HTAqhgKkKqizghOYb4uEpZ7wK8tl3Y48ZbUTHF521c=
|
||||
github.com/lestrrat-go/jwx v1.0.6-0.20201127121120-26218808f029/go.mod h1:TPF17WiSFegZo+c20fdpw49QD+/7n4/IsGvEmCSWwT0=
|
||||
github.com/lestrrat-go/pdebug v0.0.0-20200204225717-4d6bd78da58d h1:aEZT3f1GGg5RIlHMAy4/4fe4ciOi3SCwYoaURphcB4k=
|
||||
github.com/lestrrat-go/pdebug v0.0.0-20200204225717-4d6bd78da58d/go.mod h1:B06CSso/AWxiPejj+fheUINGeBKeeEZNt8w+EoU7+L8=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
|
||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20200417140056-c07e33ef3290/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
|
|
|
|||
167
jwtauth.go
167
jwtauth.go
|
|
@ -1,14 +1,15 @@
|
|||
package jwtauth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
"github.com/lestrrat-go/jwx/jwa"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
)
|
||||
|
||||
// Context keys
|
||||
|
|
@ -19,36 +20,31 @@ var (
|
|||
|
||||
// Library errors
|
||||
var (
|
||||
ErrUnauthorized = errors.New("jwtauth: token is unauthorized")
|
||||
ErrExpired = errors.New("jwtauth: token is expired")
|
||||
ErrNBFInvalid = errors.New("jwtauth: token nbf validation failed")
|
||||
ErrIATInvalid = errors.New("jwtauth: token iat validation failed")
|
||||
ErrNoTokenFound = errors.New("jwtauth: no token found")
|
||||
ErrAlgoInvalid = errors.New("jwtauth: algorithm mismatch")
|
||||
ErrUnauthorized = errors.New("token is unauthorized")
|
||||
ErrExpired = errors.New("token is expired")
|
||||
ErrNBFInvalid = errors.New("token nbf validation failed")
|
||||
ErrIATInvalid = errors.New("token iat validation failed")
|
||||
ErrNoTokenFound = errors.New("no token found")
|
||||
ErrAlgoInvalid = errors.New("algorithm mismatch")
|
||||
)
|
||||
|
||||
type JWTAuth struct {
|
||||
signKey interface{}
|
||||
verifyKey interface{}
|
||||
signer jwt.SigningMethod
|
||||
parser *jwt.Parser
|
||||
alg jwa.SignatureAlgorithm
|
||||
signKey interface{} // private-key
|
||||
verifyKey interface{} // public-key, only used by RSA and ECDSA algorithms
|
||||
verifier jwt.ParseOption
|
||||
}
|
||||
|
||||
// New creates a JWTAuth authenticator instance that provides middleware handlers
|
||||
// and encoding/decoding functions for JWT signing.
|
||||
func New(alg string, signKey interface{}, verifyKey interface{}) *JWTAuth {
|
||||
return NewWithParser(alg, &jwt.Parser{}, signKey, verifyKey)
|
||||
}
|
||||
ja := &JWTAuth{alg: jwa.SignatureAlgorithm(alg), signKey: signKey, verifyKey: verifyKey}
|
||||
|
||||
// NewWithParser is the same as New, except it supports custom parser settings
|
||||
// introduced in jwt-go/v2.4.0.
|
||||
func NewWithParser(alg string, parser *jwt.Parser, signKey interface{}, verifyKey interface{}) *JWTAuth {
|
||||
return &JWTAuth{
|
||||
signKey: signKey,
|
||||
verifyKey: verifyKey,
|
||||
signer: jwt.GetSigningMethod(alg),
|
||||
parser: parser,
|
||||
if ja.verifyKey != nil {
|
||||
ja.verifier = jwt.WithVerify(ja.alg, ja.verifyKey)
|
||||
} else {
|
||||
ja.verifier = jwt.WithVerify(ja.alg, ja.signKey)
|
||||
}
|
||||
|
||||
return ja
|
||||
}
|
||||
|
||||
// Verifier http middleware handler will verify a JWT string from a http request.
|
||||
|
|
@ -85,73 +81,81 @@ func Verify(ja *JWTAuth, findTokenFns ...func(r *http.Request) string) func(http
|
|||
}
|
||||
}
|
||||
|
||||
func VerifyRequest(ja *JWTAuth, r *http.Request, findTokenFns ...func(r *http.Request) string) (*jwt.Token, error) {
|
||||
var tokenStr string
|
||||
var err error
|
||||
func VerifyRequest(ja *JWTAuth, r *http.Request, findTokenFns ...func(r *http.Request) string) (jwt.Token, error) {
|
||||
var tokenString string
|
||||
|
||||
// Extract token string from the request by calling token find functions in
|
||||
// the order they where provided. Further extraction stops if a function
|
||||
// returns a non-empty string.
|
||||
for _, fn := range findTokenFns {
|
||||
tokenStr = fn(r)
|
||||
if tokenStr != "" {
|
||||
tokenString = fn(r)
|
||||
if tokenString != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
if tokenStr == "" {
|
||||
if tokenString == "" {
|
||||
return nil, ErrNoTokenFound
|
||||
}
|
||||
|
||||
// Verify the token
|
||||
token, err := ja.Decode(tokenStr)
|
||||
return VerifyToken(ja, tokenString)
|
||||
}
|
||||
|
||||
func VerifyToken(ja *JWTAuth, tokenString string) (jwt.Token, error) {
|
||||
// Decode & verify the token
|
||||
token, err := ja.Decode(tokenString)
|
||||
if err != nil {
|
||||
if verr, ok := err.(*jwt.ValidationError); ok {
|
||||
if verr.Errors&jwt.ValidationErrorExpired > 0 {
|
||||
return token, ErrExpired
|
||||
} else if verr.Errors&jwt.ValidationErrorIssuedAt > 0 {
|
||||
return token, ErrIATInvalid
|
||||
} else if verr.Errors&jwt.ValidationErrorNotValidYet > 0 {
|
||||
return token, ErrNBFInvalid
|
||||
}
|
||||
}
|
||||
return token, err
|
||||
return token, ErrorReason(err)
|
||||
}
|
||||
|
||||
if token == nil || !token.Valid {
|
||||
err = ErrUnauthorized
|
||||
return token, err
|
||||
if token == nil {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
|
||||
// Verify signing algorithm
|
||||
if token.Method != ja.signer {
|
||||
return token, ErrAlgoInvalid
|
||||
if err := jwt.Validate(token); err != nil {
|
||||
return token, ErrorReason(err)
|
||||
}
|
||||
|
||||
// Valid!
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (ja *JWTAuth) Encode(claims jwt.Claims) (t *jwt.Token, tokenString string, err error) {
|
||||
t = jwt.New(ja.signer)
|
||||
t.Claims = claims
|
||||
tokenString, err = t.SignedString(ja.signKey)
|
||||
t.Raw = tokenString
|
||||
return
|
||||
}
|
||||
|
||||
func (ja *JWTAuth) Decode(tokenString string) (t *jwt.Token, err error) {
|
||||
t, err = ja.parser.Parse(tokenString, ja.keyFunc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func (ja *JWTAuth) Encode(claims map[string]interface{}) (t jwt.Token, tokenString string, err error) {
|
||||
t = jwt.New()
|
||||
for k, v := range claims {
|
||||
t.Set(k, v)
|
||||
}
|
||||
payload, err := ja.sign(t)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
tokenString = string(payload)
|
||||
return
|
||||
}
|
||||
|
||||
func (ja *JWTAuth) keyFunc(t *jwt.Token) (interface{}, error) {
|
||||
if ja.verifyKey != nil {
|
||||
return ja.verifyKey, nil
|
||||
} else {
|
||||
return ja.signKey, nil
|
||||
func (ja *JWTAuth) Decode(tokenString string) (jwt.Token, error) {
|
||||
return ja.parse([]byte(tokenString))
|
||||
}
|
||||
|
||||
func (ja *JWTAuth) sign(token jwt.Token) ([]byte, error) {
|
||||
return jwt.Sign(token, ja.alg, ja.signKey)
|
||||
}
|
||||
|
||||
func (ja *JWTAuth) parse(payload []byte) (jwt.Token, error) {
|
||||
return jwt.Parse(bytes.NewReader(payload), ja.verifier)
|
||||
}
|
||||
|
||||
// ErrorReason will normalize the error message from the underlining
|
||||
// jwt library
|
||||
func ErrorReason(err error) error {
|
||||
switch err.Error() {
|
||||
case "exp not satisfied", ErrExpired.Error():
|
||||
return ErrExpired
|
||||
case "iat not satisfied", ErrIATInvalid.Error():
|
||||
return ErrIATInvalid
|
||||
case "nbf not satisfied", ErrNBFInvalid.Error():
|
||||
return ErrNBFInvalid
|
||||
default:
|
||||
return ErrUnauthorized
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -164,11 +168,11 @@ func Authenticator(next http.Handler) http.Handler {
|
|||
token, _, err := FromContext(r.Context())
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, http.StatusText(401), 401)
|
||||
http.Error(w, err.Error(), 401)
|
||||
return
|
||||
}
|
||||
|
||||
if token == nil || !token.Valid {
|
||||
if token == nil || jwt.Validate(token) != nil {
|
||||
http.Error(w, http.StatusText(401), 401)
|
||||
return
|
||||
}
|
||||
|
|
@ -178,27 +182,28 @@ func Authenticator(next http.Handler) http.Handler {
|
|||
})
|
||||
}
|
||||
|
||||
func NewContext(ctx context.Context, t *jwt.Token, err error) context.Context {
|
||||
func NewContext(ctx context.Context, t jwt.Token, err error) context.Context {
|
||||
ctx = context.WithValue(ctx, TokenCtxKey, t)
|
||||
ctx = context.WithValue(ctx, ErrorCtxKey, err)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func FromContext(ctx context.Context) (*jwt.Token, jwt.MapClaims, error) {
|
||||
token, _ := ctx.Value(TokenCtxKey).(*jwt.Token)
|
||||
func FromContext(ctx context.Context) (jwt.Token, map[string]interface{}, error) {
|
||||
token, _ := ctx.Value(TokenCtxKey).(jwt.Token)
|
||||
|
||||
var err error
|
||||
var claims map[string]interface{}
|
||||
|
||||
var claims jwt.MapClaims
|
||||
if token != nil {
|
||||
if tokenClaims, ok := token.Claims.(jwt.MapClaims); ok {
|
||||
claims = tokenClaims
|
||||
} else {
|
||||
panic(fmt.Sprintf("jwtauth: unknown type of Claims: %T", token.Claims))
|
||||
claims, err = token.AsMap(context.Background())
|
||||
if err != nil {
|
||||
return token, nil, err
|
||||
}
|
||||
} else {
|
||||
claims = jwt.MapClaims{}
|
||||
claims = map[string]interface{}{}
|
||||
}
|
||||
|
||||
err, _ := ctx.Value(ErrorCtxKey).(error)
|
||||
err, _ = ctx.Value(ErrorCtxKey).(error)
|
||||
|
||||
return token, claims, err
|
||||
}
|
||||
|
|
@ -219,22 +224,22 @@ func ExpireIn(tm time.Duration) int64 {
|
|||
}
|
||||
|
||||
// Set issued at ("iat") to specified time in the claims
|
||||
func SetIssuedAt(claims jwt.MapClaims, tm time.Time) {
|
||||
func SetIssuedAt(claims map[string]interface{}, tm time.Time) {
|
||||
claims["iat"] = tm.UTC().Unix()
|
||||
}
|
||||
|
||||
// Set issued at ("iat") to present time in the claims
|
||||
func SetIssuedNow(claims jwt.MapClaims) {
|
||||
func SetIssuedNow(claims map[string]interface{}) {
|
||||
claims["iat"] = EpochNow()
|
||||
}
|
||||
|
||||
// Set expiry ("exp") in the claims
|
||||
func SetExpiry(claims jwt.MapClaims, tm time.Time) {
|
||||
func SetExpiry(claims map[string]interface{}, tm time.Time) {
|
||||
claims["exp"] = tm.UTC().Unix()
|
||||
}
|
||||
|
||||
// Set expiry ("exp") in the claims to some duration from the present time
|
||||
func SetExpiryIn(claims jwt.MapClaims, tm time.Duration) {
|
||||
func SetExpiryIn(claims map[string]interface{}, tm time.Duration) {
|
||||
claims["exp"] = ExpireIn(tm)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package jwtauth_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
|
|
@ -13,9 +14,9 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/jwtauth"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -69,7 +70,7 @@ func TestSimpleRSA(t *testing.T) {
|
|||
|
||||
TokenAuthRS256 = jwtauth.New("RS256", privateKey, publicKey)
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
claims := map[string]interface{}{
|
||||
"key": "val",
|
||||
"key2": "val2",
|
||||
"key3": "val3",
|
||||
|
|
@ -87,7 +88,12 @@ func TestSimpleRSA(t *testing.T) {
|
|||
t.Fatalf("Failed to decode token string %s\n", err.Error())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(claims, jwt.MapClaims(token.Claims.(jwt.MapClaims))) {
|
||||
tokenClaims, err := token.AsMap(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(claims, tokenClaims) {
|
||||
t.Fatalf("The decoded claims don't match the original ones\n")
|
||||
}
|
||||
}
|
||||
|
|
@ -105,27 +111,27 @@ func TestSimple(t *testing.T) {
|
|||
defer ts.Close()
|
||||
|
||||
// sending unauthorized requests
|
||||
if status, resp := testRequest(t, ts, "GET", "/", nil, nil); status != 401 || resp != "Unauthorized\n" {
|
||||
if status, resp := testRequest(t, ts, "GET", "/", nil, nil); status != 401 || resp != "no token found\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
|
||||
h := http.Header{}
|
||||
h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{}))
|
||||
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "Unauthorized\n" {
|
||||
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
h.Set("Authorization", "BEARER asdf")
|
||||
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "Unauthorized\n" {
|
||||
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
// wrong token secret and wrong alg
|
||||
h.Set("Authorization", "BEARER "+newJwt512Token([]byte("wrong"), map[string]interface{}{}))
|
||||
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "Unauthorized\n" {
|
||||
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
// correct token secret but wrong alg
|
||||
h.Set("Authorization", "BEARER "+newJwt512Token(TokenSecret, map[string]interface{}{}))
|
||||
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "Unauthorized\n" {
|
||||
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
|
||||
|
|
@ -147,23 +153,12 @@ func TestMore(t *testing.T) {
|
|||
token, _, err := jwtauth.FromContext(r.Context())
|
||||
|
||||
if err != nil {
|
||||
switch err {
|
||||
default:
|
||||
http.Error(w, http.StatusText(401), 401)
|
||||
return
|
||||
case jwtauth.ErrExpired:
|
||||
http.Error(w, "expired", 401)
|
||||
return
|
||||
case jwtauth.ErrUnauthorized:
|
||||
http.Error(w, http.StatusText(401), 401)
|
||||
return
|
||||
case nil:
|
||||
// no error
|
||||
}
|
||||
http.Error(w, jwtauth.ErrorReason(err).Error(), 401)
|
||||
return
|
||||
}
|
||||
|
||||
if token == nil || !token.Valid {
|
||||
http.Error(w, http.StatusText(401), 401)
|
||||
if err := jwt.Validate(token); err != nil {
|
||||
http.Error(w, jwtauth.ErrorReason(err).Error(), 401)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -196,32 +191,32 @@ func TestMore(t *testing.T) {
|
|||
defer ts.Close()
|
||||
|
||||
// sending unauthorized requests
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", nil, nil); status != 401 || resp != "Unauthorized\n" {
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", nil, nil); status != 401 || resp != "token is unauthorized\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
|
||||
h := http.Header{}
|
||||
h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{}))
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" {
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
h.Set("Authorization", "BEARER asdf")
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" {
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
// wrong token secret and wrong alg
|
||||
h.Set("Authorization", "BEARER "+newJwt512Token([]byte("wrong"), map[string]interface{}{}))
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" {
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
// correct token secret but wrong alg
|
||||
h.Set("Authorization", "BEARER "+newJwt512Token(TokenSecret, map[string]interface{}{}))
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" {
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
|
||||
h = newAuthHeader(jwt.MapClaims{"exp": jwtauth.EpochNow() - 1000})
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "expired\n" {
|
||||
h = newAuthHeader(map[string]interface{}{"exp": jwtauth.EpochNow() - 1000})
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is expired\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
|
||||
|
|
@ -230,7 +225,7 @@ func TestMore(t *testing.T) {
|
|||
t.Fatalf(resp)
|
||||
}
|
||||
|
||||
h = newAuthHeader((jwt.MapClaims{"user_id": 31337, "exp": jwtauth.ExpireIn(5 * time.Minute)}))
|
||||
h = newAuthHeader((map[string]interface{}{"user_id": 31337, "exp": jwtauth.ExpireIn(5 * time.Minute)}))
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 || resp != "protected, user:31337" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
|
|
@ -269,32 +264,36 @@ func testRequest(t *testing.T, ts *httptest.Server, method, path string, header
|
|||
return resp.StatusCode, string(respBody)
|
||||
}
|
||||
|
||||
func newJwtToken(secret []byte, claims ...jwt.MapClaims) string {
|
||||
token := jwt.New(jwt.GetSigningMethod("HS256"))
|
||||
func newJwtToken(secret []byte, claims ...map[string]interface{}) string {
|
||||
token := jwt.New()
|
||||
if len(claims) > 0 {
|
||||
token.Claims = claims[0]
|
||||
for k, v := range claims[0] {
|
||||
token.Set(k, v)
|
||||
}
|
||||
}
|
||||
tokenStr, err := token.SignedString(secret)
|
||||
tokenPayload, err := jwt.Sign(token, "HS256", secret)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return tokenStr
|
||||
return string(tokenPayload)
|
||||
}
|
||||
|
||||
func newJwt512Token(secret []byte, claims ...jwt.MapClaims) string {
|
||||
func newJwt512Token(secret []byte, claims ...map[string]interface{}) string {
|
||||
// use-case: when token is signed with a different alg than expected
|
||||
token := jwt.New(jwt.GetSigningMethod("HS512"))
|
||||
token := jwt.New()
|
||||
if len(claims) > 0 {
|
||||
token.Claims = claims[0]
|
||||
for k, v := range claims[0] {
|
||||
token.Set(k, v)
|
||||
}
|
||||
}
|
||||
tokenStr, err := token.SignedString(secret)
|
||||
tokenPayload, err := jwt.Sign(token, "HS512", secret)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return tokenStr
|
||||
return string(tokenPayload)
|
||||
}
|
||||
|
||||
func newAuthHeader(claims ...jwt.MapClaims) http.Header {
|
||||
func newAuthHeader(claims ...map[string]interface{}) http.Header {
|
||||
h := http.Header{}
|
||||
h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, claims...))
|
||||
return h
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue