Remove unnecessary code and clean up errors (#30)

* Initial refactor

* Fix some of the tests

* Handle more specific errors from jwt library

* Fix comments

* Rename package

* Verify signing algo after nil check
This commit is contained in:
Jonathan ES Lin 2018-09-27 01:43:54 +08:00 committed by Peter Kieltyka
parent 53a0a4877a
commit ea7d7e213f
4 changed files with 49 additions and 121 deletions

View file

@ -2,7 +2,6 @@ package jwtauth
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
@ -12,15 +11,20 @@ import (
"github.com/dgrijalva/jwt-go"
)
// Context keys
var (
TokenCtxKey = &contextKey{"Token"}
ErrorCtxKey = &contextKey{"Error"}
)
// Library errors
var (
ErrUnauthorized = errors.New("jwtauth: token is unauthorized")
ErrExpired = errors.New("jwtauth: token is expired")
ErrNBFInvalid = errors.New("jwtauth: token nbf validation failed")
ErrIATInvalid = errors.New("jwtauth: token iat validation failed")
ErrNoTokenFound = errors.New("jwtauth: no token found")
ErrAlgoInvalid = errors.New("jwtauth: algorithm mismatch")
)
type JWTAuth struct {
@ -38,12 +42,7 @@ func New(alg string, signKey interface{}, verifyKey interface{}) *JWTAuth {
// NewWithParser is the same as New, except it supports custom parser settings
// introduced in jwt-go/v2.4.0.
//
// We explicitly toggle `SkipClaimsValidation` in the `jwt-go` parser so that
// we can control when the claims are validated - in our case, by the Verifier
// http middleware handler.
func NewWithParser(alg string, parser *jwt.Parser, signKey interface{}, verifyKey interface{}) *JWTAuth {
parser.SkipClaimsValidation = true
return &JWTAuth{
signKey: signKey,
verifyKey: verifyKey,
@ -103,34 +102,36 @@ func VerifyRequest(ja *JWTAuth, r *http.Request, findTokenFns ...func(r *http.Re
return nil, ErrNoTokenFound
}
// TODO: what other kinds of validations should we do / error messages?
// Verify the token
token, err := ja.Decode(tokenStr)
if err != nil {
switch err.Error() {
case "token is expired":
err = ErrExpired
if verr, ok := err.(*jwt.ValidationError); ok {
if verr.Errors&jwt.ValidationErrorExpired > 0 {
return token, ErrExpired
} else if verr.Errors&jwt.ValidationErrorIssuedAt > 0 {
return token, ErrIATInvalid
} else if verr.Errors&jwt.ValidationErrorIssuedAt > 0 {
return token, ErrNBFInvalid
}
}
return token, err
}
if token == nil || !token.Valid || token.Method != ja.signer {
if token == nil || !token.Valid {
err = ErrUnauthorized
return token, err
}
// Check expiry via "exp" claim
if IsExpired(token) {
err = ErrExpired
return token, err
// Verify signing algorithm
if token.Method != ja.signer {
return token, ErrAlgoInvalid
}
// Valid!
return token, nil
}
func (ja *JWTAuth) Encode(claims Claims) (t *jwt.Token, tokenString string, err error) {
func (ja *JWTAuth) Encode(claims jwt.MapClaims) (t *jwt.Token, tokenString string, err error) {
t = jwt.New(ja.signer)
t.Claims = claims
tokenString, err = t.SignedString(ja.signKey)
@ -139,10 +140,6 @@ func (ja *JWTAuth) Encode(claims Claims) (t *jwt.Token, tokenString string, err
}
func (ja *JWTAuth) Decode(tokenString string) (t *jwt.Token, err error) {
// Decode the tokenString, but avoid using custom Claims via jwt-go's
// ParseWithClaims as the jwt-go types will cause some glitches, so easier
// to decode as MapClaims then wrap the underlying map[string]interface{}
// to our Claims type
t, err = ja.parser.Parse(tokenString, ja.keyFunc)
if err != nil {
return nil, err
@ -187,21 +184,18 @@ func NewContext(ctx context.Context, t *jwt.Token, err error) context.Context {
return ctx
}
func FromContext(ctx context.Context) (*jwt.Token, Claims, error) {
func FromContext(ctx context.Context) (*jwt.Token, jwt.MapClaims, error) {
token, _ := ctx.Value(TokenCtxKey).(*jwt.Token)
var claims Claims
var claims jwt.MapClaims
if token != nil {
switch tokenClaims := token.Claims.(type) {
case Claims:
if tokenClaims, ok := token.Claims.(jwt.MapClaims); ok {
claims = tokenClaims
case jwt.MapClaims:
claims = Claims(tokenClaims)
default:
} else {
panic(fmt.Sprintf("jwtauth: unknown type of Claims: %T", token.Claims))
}
} else {
claims = Claims{}
claims = jwt.MapClaims{}
}
err, _ := ctx.Value(ErrorCtxKey).(error)
@ -209,83 +203,17 @@ func FromContext(ctx context.Context) (*jwt.Token, Claims, error) {
return token, claims, err
}
func IsExpired(t *jwt.Token) bool {
claims, ok := t.Claims.(jwt.MapClaims)
if !ok {
panic("jwtauth: expecting jwt.MapClaims")
}
if expv, ok := claims["exp"]; ok {
var exp int64
switch v := expv.(type) {
case float64:
exp = int64(v)
case int64:
exp = v
case json.Number:
exp, _ = v.Int64()
default:
}
if exp < EpochNow() {
return true
}
}
return false
// UnixTime returns the given time in UTC milliseconds
func UnixTime(tm time.Time) int64 {
return tm.UTC().Unix()
}
// Claims is a convenience type to manage a JWT claims hash.
type Claims map[string]interface{}
// NOTE: as of v3.0 of jwt-go, Valid() interface method is called to verify
// the claims. However, the current design we test these claims in the
// Verifier middleware, so we skip this step.
func (c Claims) Valid() error {
return nil
}
func (c Claims) Set(k string, v interface{}) Claims {
c[k] = v
return c
}
func (c Claims) Get(k string) (interface{}, bool) {
v, ok := c[k]
return v, ok
}
// Set issued at ("iat") to specified time in the claims
func (c Claims) SetIssuedAt(tm time.Time) Claims {
c["iat"] = tm.UTC().Unix()
return c
}
// Set issued at ("iat") to present time in the claims
func (c Claims) SetIssuedNow() Claims {
c["iat"] = EpochNow()
return c
}
// Set expiry ("exp") in the claims and return itself so it can be chained
func (c Claims) SetExpiry(tm time.Time) Claims {
c["exp"] = tm.UTC().Unix()
return c
}
// Set expiry ("exp") in the claims to some duration from the present time
// and return itself so it can be chained
func (c Claims) SetExpiryIn(tm time.Duration) Claims {
c["exp"] = ExpireIn(tm)
return c
}
// Helper function that returns the NumericDate time value used by the spec
// EpochNow is a helper function that returns the NumericDate time value used by the spec
func EpochNow() int64 {
return time.Now().UTC().Unix()
}
// Helper function to return calculated time in the future for "exp" claim.
// ExpireIn is a helper function to return calculated time in the future for "exp" claim
func ExpireIn(tm time.Duration) int64 {
return EpochNow() + int64(tm.Seconds())
}