mirror of
https://forgejo.merr.is/annika/jwtauth.git
synced 2025-12-11 13:47:41 -05:00
Switch to github.com/lestrrat-go/jwx underlying jwt library (#52)
This commit is contained in:
parent
02fa0c511c
commit
b8af768272
8 changed files with 200 additions and 148 deletions
167
jwtauth.go
167
jwtauth.go
|
|
@ -1,14 +1,15 @@
|
|||
package jwtauth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
"github.com/lestrrat-go/jwx/jwa"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
)
|
||||
|
||||
// Context keys
|
||||
|
|
@ -19,36 +20,31 @@ var (
|
|||
|
||||
// Library errors
|
||||
var (
|
||||
ErrUnauthorized = errors.New("jwtauth: token is unauthorized")
|
||||
ErrExpired = errors.New("jwtauth: token is expired")
|
||||
ErrNBFInvalid = errors.New("jwtauth: token nbf validation failed")
|
||||
ErrIATInvalid = errors.New("jwtauth: token iat validation failed")
|
||||
ErrNoTokenFound = errors.New("jwtauth: no token found")
|
||||
ErrAlgoInvalid = errors.New("jwtauth: algorithm mismatch")
|
||||
ErrUnauthorized = errors.New("token is unauthorized")
|
||||
ErrExpired = errors.New("token is expired")
|
||||
ErrNBFInvalid = errors.New("token nbf validation failed")
|
||||
ErrIATInvalid = errors.New("token iat validation failed")
|
||||
ErrNoTokenFound = errors.New("no token found")
|
||||
ErrAlgoInvalid = errors.New("algorithm mismatch")
|
||||
)
|
||||
|
||||
type JWTAuth struct {
|
||||
signKey interface{}
|
||||
verifyKey interface{}
|
||||
signer jwt.SigningMethod
|
||||
parser *jwt.Parser
|
||||
alg jwa.SignatureAlgorithm
|
||||
signKey interface{} // private-key
|
||||
verifyKey interface{} // public-key, only used by RSA and ECDSA algorithms
|
||||
verifier jwt.ParseOption
|
||||
}
|
||||
|
||||
// New creates a JWTAuth authenticator instance that provides middleware handlers
|
||||
// and encoding/decoding functions for JWT signing.
|
||||
func New(alg string, signKey interface{}, verifyKey interface{}) *JWTAuth {
|
||||
return NewWithParser(alg, &jwt.Parser{}, signKey, verifyKey)
|
||||
}
|
||||
ja := &JWTAuth{alg: jwa.SignatureAlgorithm(alg), signKey: signKey, verifyKey: verifyKey}
|
||||
|
||||
// NewWithParser is the same as New, except it supports custom parser settings
|
||||
// introduced in jwt-go/v2.4.0.
|
||||
func NewWithParser(alg string, parser *jwt.Parser, signKey interface{}, verifyKey interface{}) *JWTAuth {
|
||||
return &JWTAuth{
|
||||
signKey: signKey,
|
||||
verifyKey: verifyKey,
|
||||
signer: jwt.GetSigningMethod(alg),
|
||||
parser: parser,
|
||||
if ja.verifyKey != nil {
|
||||
ja.verifier = jwt.WithVerify(ja.alg, ja.verifyKey)
|
||||
} else {
|
||||
ja.verifier = jwt.WithVerify(ja.alg, ja.signKey)
|
||||
}
|
||||
|
||||
return ja
|
||||
}
|
||||
|
||||
// Verifier http middleware handler will verify a JWT string from a http request.
|
||||
|
|
@ -85,73 +81,81 @@ func Verify(ja *JWTAuth, findTokenFns ...func(r *http.Request) string) func(http
|
|||
}
|
||||
}
|
||||
|
||||
func VerifyRequest(ja *JWTAuth, r *http.Request, findTokenFns ...func(r *http.Request) string) (*jwt.Token, error) {
|
||||
var tokenStr string
|
||||
var err error
|
||||
func VerifyRequest(ja *JWTAuth, r *http.Request, findTokenFns ...func(r *http.Request) string) (jwt.Token, error) {
|
||||
var tokenString string
|
||||
|
||||
// Extract token string from the request by calling token find functions in
|
||||
// the order they where provided. Further extraction stops if a function
|
||||
// returns a non-empty string.
|
||||
for _, fn := range findTokenFns {
|
||||
tokenStr = fn(r)
|
||||
if tokenStr != "" {
|
||||
tokenString = fn(r)
|
||||
if tokenString != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
if tokenStr == "" {
|
||||
if tokenString == "" {
|
||||
return nil, ErrNoTokenFound
|
||||
}
|
||||
|
||||
// Verify the token
|
||||
token, err := ja.Decode(tokenStr)
|
||||
return VerifyToken(ja, tokenString)
|
||||
}
|
||||
|
||||
func VerifyToken(ja *JWTAuth, tokenString string) (jwt.Token, error) {
|
||||
// Decode & verify the token
|
||||
token, err := ja.Decode(tokenString)
|
||||
if err != nil {
|
||||
if verr, ok := err.(*jwt.ValidationError); ok {
|
||||
if verr.Errors&jwt.ValidationErrorExpired > 0 {
|
||||
return token, ErrExpired
|
||||
} else if verr.Errors&jwt.ValidationErrorIssuedAt > 0 {
|
||||
return token, ErrIATInvalid
|
||||
} else if verr.Errors&jwt.ValidationErrorNotValidYet > 0 {
|
||||
return token, ErrNBFInvalid
|
||||
}
|
||||
}
|
||||
return token, err
|
||||
return token, ErrorReason(err)
|
||||
}
|
||||
|
||||
if token == nil || !token.Valid {
|
||||
err = ErrUnauthorized
|
||||
return token, err
|
||||
if token == nil {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
|
||||
// Verify signing algorithm
|
||||
if token.Method != ja.signer {
|
||||
return token, ErrAlgoInvalid
|
||||
if err := jwt.Validate(token); err != nil {
|
||||
return token, ErrorReason(err)
|
||||
}
|
||||
|
||||
// Valid!
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (ja *JWTAuth) Encode(claims jwt.Claims) (t *jwt.Token, tokenString string, err error) {
|
||||
t = jwt.New(ja.signer)
|
||||
t.Claims = claims
|
||||
tokenString, err = t.SignedString(ja.signKey)
|
||||
t.Raw = tokenString
|
||||
return
|
||||
}
|
||||
|
||||
func (ja *JWTAuth) Decode(tokenString string) (t *jwt.Token, err error) {
|
||||
t, err = ja.parser.Parse(tokenString, ja.keyFunc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func (ja *JWTAuth) Encode(claims map[string]interface{}) (t jwt.Token, tokenString string, err error) {
|
||||
t = jwt.New()
|
||||
for k, v := range claims {
|
||||
t.Set(k, v)
|
||||
}
|
||||
payload, err := ja.sign(t)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
tokenString = string(payload)
|
||||
return
|
||||
}
|
||||
|
||||
func (ja *JWTAuth) keyFunc(t *jwt.Token) (interface{}, error) {
|
||||
if ja.verifyKey != nil {
|
||||
return ja.verifyKey, nil
|
||||
} else {
|
||||
return ja.signKey, nil
|
||||
func (ja *JWTAuth) Decode(tokenString string) (jwt.Token, error) {
|
||||
return ja.parse([]byte(tokenString))
|
||||
}
|
||||
|
||||
func (ja *JWTAuth) sign(token jwt.Token) ([]byte, error) {
|
||||
return jwt.Sign(token, ja.alg, ja.signKey)
|
||||
}
|
||||
|
||||
func (ja *JWTAuth) parse(payload []byte) (jwt.Token, error) {
|
||||
return jwt.Parse(bytes.NewReader(payload), ja.verifier)
|
||||
}
|
||||
|
||||
// ErrorReason will normalize the error message from the underlining
|
||||
// jwt library
|
||||
func ErrorReason(err error) error {
|
||||
switch err.Error() {
|
||||
case "exp not satisfied", ErrExpired.Error():
|
||||
return ErrExpired
|
||||
case "iat not satisfied", ErrIATInvalid.Error():
|
||||
return ErrIATInvalid
|
||||
case "nbf not satisfied", ErrNBFInvalid.Error():
|
||||
return ErrNBFInvalid
|
||||
default:
|
||||
return ErrUnauthorized
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -164,11 +168,11 @@ func Authenticator(next http.Handler) http.Handler {
|
|||
token, _, err := FromContext(r.Context())
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, http.StatusText(401), 401)
|
||||
http.Error(w, err.Error(), 401)
|
||||
return
|
||||
}
|
||||
|
||||
if token == nil || !token.Valid {
|
||||
if token == nil || jwt.Validate(token) != nil {
|
||||
http.Error(w, http.StatusText(401), 401)
|
||||
return
|
||||
}
|
||||
|
|
@ -178,27 +182,28 @@ func Authenticator(next http.Handler) http.Handler {
|
|||
})
|
||||
}
|
||||
|
||||
func NewContext(ctx context.Context, t *jwt.Token, err error) context.Context {
|
||||
func NewContext(ctx context.Context, t jwt.Token, err error) context.Context {
|
||||
ctx = context.WithValue(ctx, TokenCtxKey, t)
|
||||
ctx = context.WithValue(ctx, ErrorCtxKey, err)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func FromContext(ctx context.Context) (*jwt.Token, jwt.MapClaims, error) {
|
||||
token, _ := ctx.Value(TokenCtxKey).(*jwt.Token)
|
||||
func FromContext(ctx context.Context) (jwt.Token, map[string]interface{}, error) {
|
||||
token, _ := ctx.Value(TokenCtxKey).(jwt.Token)
|
||||
|
||||
var err error
|
||||
var claims map[string]interface{}
|
||||
|
||||
var claims jwt.MapClaims
|
||||
if token != nil {
|
||||
if tokenClaims, ok := token.Claims.(jwt.MapClaims); ok {
|
||||
claims = tokenClaims
|
||||
} else {
|
||||
panic(fmt.Sprintf("jwtauth: unknown type of Claims: %T", token.Claims))
|
||||
claims, err = token.AsMap(context.Background())
|
||||
if err != nil {
|
||||
return token, nil, err
|
||||
}
|
||||
} else {
|
||||
claims = jwt.MapClaims{}
|
||||
claims = map[string]interface{}{}
|
||||
}
|
||||
|
||||
err, _ := ctx.Value(ErrorCtxKey).(error)
|
||||
err, _ = ctx.Value(ErrorCtxKey).(error)
|
||||
|
||||
return token, claims, err
|
||||
}
|
||||
|
|
@ -219,22 +224,22 @@ func ExpireIn(tm time.Duration) int64 {
|
|||
}
|
||||
|
||||
// Set issued at ("iat") to specified time in the claims
|
||||
func SetIssuedAt(claims jwt.MapClaims, tm time.Time) {
|
||||
func SetIssuedAt(claims map[string]interface{}, tm time.Time) {
|
||||
claims["iat"] = tm.UTC().Unix()
|
||||
}
|
||||
|
||||
// Set issued at ("iat") to present time in the claims
|
||||
func SetIssuedNow(claims jwt.MapClaims) {
|
||||
func SetIssuedNow(claims map[string]interface{}) {
|
||||
claims["iat"] = EpochNow()
|
||||
}
|
||||
|
||||
// Set expiry ("exp") in the claims
|
||||
func SetExpiry(claims jwt.MapClaims, tm time.Time) {
|
||||
func SetExpiry(claims map[string]interface{}, tm time.Time) {
|
||||
claims["exp"] = tm.UTC().Unix()
|
||||
}
|
||||
|
||||
// Set expiry ("exp") in the claims to some duration from the present time
|
||||
func SetExpiryIn(claims jwt.MapClaims, tm time.Duration) {
|
||||
func SetExpiryIn(claims map[string]interface{}, tm time.Duration) {
|
||||
claims["exp"] = ExpireIn(tm)
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue