Switch to github.com/lestrrat-go/jwx underlying jwt library (#52)

This commit is contained in:
Peter Kieltyka 2020-12-12 09:40:27 -05:00 committed by GitHub
parent 02fa0c511c
commit b8af768272
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 200 additions and 148 deletions

View file

@ -1,14 +1,15 @@
package jwtauth
import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"strings"
"time"
jwt "github.com/dgrijalva/jwt-go"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
)
// Context keys
@ -19,36 +20,31 @@ var (
// Library errors
var (
ErrUnauthorized = errors.New("jwtauth: token is unauthorized")
ErrExpired = errors.New("jwtauth: token is expired")
ErrNBFInvalid = errors.New("jwtauth: token nbf validation failed")
ErrIATInvalid = errors.New("jwtauth: token iat validation failed")
ErrNoTokenFound = errors.New("jwtauth: no token found")
ErrAlgoInvalid = errors.New("jwtauth: algorithm mismatch")
ErrUnauthorized = errors.New("token is unauthorized")
ErrExpired = errors.New("token is expired")
ErrNBFInvalid = errors.New("token nbf validation failed")
ErrIATInvalid = errors.New("token iat validation failed")
ErrNoTokenFound = errors.New("no token found")
ErrAlgoInvalid = errors.New("algorithm mismatch")
)
type JWTAuth struct {
signKey interface{}
verifyKey interface{}
signer jwt.SigningMethod
parser *jwt.Parser
alg jwa.SignatureAlgorithm
signKey interface{} // private-key
verifyKey interface{} // public-key, only used by RSA and ECDSA algorithms
verifier jwt.ParseOption
}
// New creates a JWTAuth authenticator instance that provides middleware handlers
// and encoding/decoding functions for JWT signing.
func New(alg string, signKey interface{}, verifyKey interface{}) *JWTAuth {
return NewWithParser(alg, &jwt.Parser{}, signKey, verifyKey)
}
ja := &JWTAuth{alg: jwa.SignatureAlgorithm(alg), signKey: signKey, verifyKey: verifyKey}
// NewWithParser is the same as New, except it supports custom parser settings
// introduced in jwt-go/v2.4.0.
func NewWithParser(alg string, parser *jwt.Parser, signKey interface{}, verifyKey interface{}) *JWTAuth {
return &JWTAuth{
signKey: signKey,
verifyKey: verifyKey,
signer: jwt.GetSigningMethod(alg),
parser: parser,
if ja.verifyKey != nil {
ja.verifier = jwt.WithVerify(ja.alg, ja.verifyKey)
} else {
ja.verifier = jwt.WithVerify(ja.alg, ja.signKey)
}
return ja
}
// Verifier http middleware handler will verify a JWT string from a http request.
@ -85,73 +81,81 @@ func Verify(ja *JWTAuth, findTokenFns ...func(r *http.Request) string) func(http
}
}
func VerifyRequest(ja *JWTAuth, r *http.Request, findTokenFns ...func(r *http.Request) string) (*jwt.Token, error) {
var tokenStr string
var err error
func VerifyRequest(ja *JWTAuth, r *http.Request, findTokenFns ...func(r *http.Request) string) (jwt.Token, error) {
var tokenString string
// Extract token string from the request by calling token find functions in
// the order they where provided. Further extraction stops if a function
// returns a non-empty string.
for _, fn := range findTokenFns {
tokenStr = fn(r)
if tokenStr != "" {
tokenString = fn(r)
if tokenString != "" {
break
}
}
if tokenStr == "" {
if tokenString == "" {
return nil, ErrNoTokenFound
}
// Verify the token
token, err := ja.Decode(tokenStr)
return VerifyToken(ja, tokenString)
}
func VerifyToken(ja *JWTAuth, tokenString string) (jwt.Token, error) {
// Decode & verify the token
token, err := ja.Decode(tokenString)
if err != nil {
if verr, ok := err.(*jwt.ValidationError); ok {
if verr.Errors&jwt.ValidationErrorExpired > 0 {
return token, ErrExpired
} else if verr.Errors&jwt.ValidationErrorIssuedAt > 0 {
return token, ErrIATInvalid
} else if verr.Errors&jwt.ValidationErrorNotValidYet > 0 {
return token, ErrNBFInvalid
}
}
return token, err
return token, ErrorReason(err)
}
if token == nil || !token.Valid {
err = ErrUnauthorized
return token, err
if token == nil {
return nil, ErrUnauthorized
}
// Verify signing algorithm
if token.Method != ja.signer {
return token, ErrAlgoInvalid
if err := jwt.Validate(token); err != nil {
return token, ErrorReason(err)
}
// Valid!
return token, nil
}
func (ja *JWTAuth) Encode(claims jwt.Claims) (t *jwt.Token, tokenString string, err error) {
t = jwt.New(ja.signer)
t.Claims = claims
tokenString, err = t.SignedString(ja.signKey)
t.Raw = tokenString
return
}
func (ja *JWTAuth) Decode(tokenString string) (t *jwt.Token, err error) {
t, err = ja.parser.Parse(tokenString, ja.keyFunc)
if err != nil {
return nil, err
func (ja *JWTAuth) Encode(claims map[string]interface{}) (t jwt.Token, tokenString string, err error) {
t = jwt.New()
for k, v := range claims {
t.Set(k, v)
}
payload, err := ja.sign(t)
if err != nil {
return nil, "", err
}
tokenString = string(payload)
return
}
func (ja *JWTAuth) keyFunc(t *jwt.Token) (interface{}, error) {
if ja.verifyKey != nil {
return ja.verifyKey, nil
} else {
return ja.signKey, nil
func (ja *JWTAuth) Decode(tokenString string) (jwt.Token, error) {
return ja.parse([]byte(tokenString))
}
func (ja *JWTAuth) sign(token jwt.Token) ([]byte, error) {
return jwt.Sign(token, ja.alg, ja.signKey)
}
func (ja *JWTAuth) parse(payload []byte) (jwt.Token, error) {
return jwt.Parse(bytes.NewReader(payload), ja.verifier)
}
// ErrorReason will normalize the error message from the underlining
// jwt library
func ErrorReason(err error) error {
switch err.Error() {
case "exp not satisfied", ErrExpired.Error():
return ErrExpired
case "iat not satisfied", ErrIATInvalid.Error():
return ErrIATInvalid
case "nbf not satisfied", ErrNBFInvalid.Error():
return ErrNBFInvalid
default:
return ErrUnauthorized
}
}
@ -164,11 +168,11 @@ func Authenticator(next http.Handler) http.Handler {
token, _, err := FromContext(r.Context())
if err != nil {
http.Error(w, http.StatusText(401), 401)
http.Error(w, err.Error(), 401)
return
}
if token == nil || !token.Valid {
if token == nil || jwt.Validate(token) != nil {
http.Error(w, http.StatusText(401), 401)
return
}
@ -178,27 +182,28 @@ func Authenticator(next http.Handler) http.Handler {
})
}
func NewContext(ctx context.Context, t *jwt.Token, err error) context.Context {
func NewContext(ctx context.Context, t jwt.Token, err error) context.Context {
ctx = context.WithValue(ctx, TokenCtxKey, t)
ctx = context.WithValue(ctx, ErrorCtxKey, err)
return ctx
}
func FromContext(ctx context.Context) (*jwt.Token, jwt.MapClaims, error) {
token, _ := ctx.Value(TokenCtxKey).(*jwt.Token)
func FromContext(ctx context.Context) (jwt.Token, map[string]interface{}, error) {
token, _ := ctx.Value(TokenCtxKey).(jwt.Token)
var err error
var claims map[string]interface{}
var claims jwt.MapClaims
if token != nil {
if tokenClaims, ok := token.Claims.(jwt.MapClaims); ok {
claims = tokenClaims
} else {
panic(fmt.Sprintf("jwtauth: unknown type of Claims: %T", token.Claims))
claims, err = token.AsMap(context.Background())
if err != nil {
return token, nil, err
}
} else {
claims = jwt.MapClaims{}
claims = map[string]interface{}{}
}
err, _ := ctx.Value(ErrorCtxKey).(error)
err, _ = ctx.Value(ErrorCtxKey).(error)
return token, claims, err
}
@ -219,22 +224,22 @@ func ExpireIn(tm time.Duration) int64 {
}
// Set issued at ("iat") to specified time in the claims
func SetIssuedAt(claims jwt.MapClaims, tm time.Time) {
func SetIssuedAt(claims map[string]interface{}, tm time.Time) {
claims["iat"] = tm.UTC().Unix()
}
// Set issued at ("iat") to present time in the claims
func SetIssuedNow(claims jwt.MapClaims) {
func SetIssuedNow(claims map[string]interface{}) {
claims["iat"] = EpochNow()
}
// Set expiry ("exp") in the claims
func SetExpiry(claims jwt.MapClaims, tm time.Time) {
func SetExpiry(claims map[string]interface{}, tm time.Time) {
claims["exp"] = tm.UTC().Unix()
}
// Set expiry ("exp") in the claims to some duration from the present time
func SetExpiryIn(claims jwt.MapClaims, tm time.Duration) {
func SetExpiryIn(claims map[string]interface{}, tm time.Duration) {
claims["exp"] = ExpireIn(tm)
}