mirror of
https://forgejo.merr.is/annika/jwtauth.git
synced 2025-12-11 11:16:32 -05:00
Remove unnecessary code and clean up errors (#30)
* Initial refactor * Fix some of the tests * Handle more specific errors from jwt library * Fix comments * Rename package * Verify signing algo after nil check
This commit is contained in:
parent
53a0a4877a
commit
ea7d7e213f
4 changed files with 49 additions and 121 deletions
25
README.md
25
README.md
|
|
@ -1,6 +1,6 @@
|
|||
jwtauth - JWT authentication middleware for Go HTTP services
|
||||
============================================================
|
||||
[![GoDoc Widget]][GoDoc]
|
||||
# jwtauth - JWT authentication middleware for Go HTTP services
|
||||
|
||||
[![GoDoc Widget]][godoc]
|
||||
|
||||
The `jwtauth` http middleware package provides a simple way to verify a JWT token
|
||||
from a http request and send the result down the request context (`context.Context`).
|
||||
|
|
@ -13,11 +13,11 @@ This package uses the new `context` package in Go 1.7 stdlib and [net/http#Reque
|
|||
|
||||
In a complete JWT-authentication flow, you'll first capture the token from a http
|
||||
request, decode it, verify it and then validate that its correctly signed and hasn't
|
||||
expired - the `jwtauth.Verifier` middleware handler takes care of all of that. The
|
||||
expired - the `jwtauth.Verifier` middleware handler takes care of all of that. The
|
||||
`jwtauth.Verifier` will set the context values on keys `jwtauth.TokenCtxKey` and
|
||||
`jwtauth.ErrorCtxKey`.
|
||||
|
||||
Next, it's up to an authentication handler to respond or continue processing after the
|
||||
Next, it's up to an authentication handler to respond or continue processing after the
|
||||
`jwtauth.Verifier`. The `jwtauth.Authenticator` middleware responds with a 401 Unauthorized
|
||||
plain-text payload for all unverified tokens and passes the good ones through. You can
|
||||
also copy the Authenticator and customize it to handle invalid tokens to better fit
|
||||
|
|
@ -25,12 +25,12 @@ your flow (ie. with a JSON error response body).
|
|||
|
||||
By default, the `Verifier` will search for a JWT token in a http request, in the order:
|
||||
|
||||
1. 'jwt' URI query parameter
|
||||
2. 'Authorization: BEARER T' request header
|
||||
3. 'jwt' Cookie value
|
||||
1. 'jwt' URI query parameter
|
||||
2. 'Authorization: BEARER T' request header
|
||||
3. 'jwt' Cookie value
|
||||
|
||||
The first JWT string that is found as a query parameter, authorization header
|
||||
or cookie header is then decoded by the `jwt-go` library and a *jwt.Token
|
||||
or cookie header is then decoded by the `jwt-go` library and a \*jwt.Token
|
||||
object is set on the request context. In the case of a signature decoding error
|
||||
the Verifier will also set the error on the request context.
|
||||
|
||||
|
|
@ -43,7 +43,6 @@ Note: jwtauth supports custom verification sequences for finding a token
|
|||
from a request by using the `Verify` middleware instantiator directly. The default
|
||||
`Verifier` is instantiated by calling `Verify(ja, TokenFromQuery, TokenFromHeader, TokenFromCookie)`.
|
||||
|
||||
|
||||
# Usage
|
||||
|
||||
See the full [example](https://github.com/go-chi/jwtauth/blob/master/_example/main.go).
|
||||
|
|
@ -66,7 +65,7 @@ func init() {
|
|||
|
||||
// For debugging/example purposes, we generate and print
|
||||
// a sample jwt token with claims `user_id:123` here:
|
||||
_, tokenString, _ := tokenAuth.Encode(jwtauth.Claims{"user_id": 123})
|
||||
_, tokenString, _ := tokenAuth.Encode(jwt.MapClaims{"user_id": 123})
|
||||
fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString)
|
||||
}
|
||||
|
||||
|
|
@ -111,5 +110,5 @@ func router() http.Handler {
|
|||
|
||||
[MIT](/LICENSE)
|
||||
|
||||
[GoDoc]: https://godoc.org/github.com/go-chi/jwtauth
|
||||
[GoDoc Widget]: https://godoc.org/github.com/go-chi/jwtauth?status.svg
|
||||
[godoc]: https://godoc.org/github.com/go-chi/jwtauth
|
||||
[godoc widget]: https://godoc.org/github.com/go-chi/jwtauth?status.svg
|
||||
|
|
|
|||
|
|
@ -62,6 +62,7 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/jwtauth"
|
||||
)
|
||||
|
|
@ -73,7 +74,7 @@ func init() {
|
|||
|
||||
// For debugging/example purposes, we generate and print
|
||||
// a sample jwt token with claims `user_id:123` here:
|
||||
_, tokenString, _ := tokenAuth.Encode(jwtauth.Claims{"user_id": 123})
|
||||
_, tokenString, _ := tokenAuth.Encode(jwt.MapClaims{"user_id": 123})
|
||||
fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString)
|
||||
}
|
||||
|
||||
|
|
|
|||
128
jwtauth.go
128
jwtauth.go
|
|
@ -2,7 +2,6 @@ package jwtauth
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
|
@ -12,15 +11,20 @@ import (
|
|||
"github.com/dgrijalva/jwt-go"
|
||||
)
|
||||
|
||||
// Context keys
|
||||
var (
|
||||
TokenCtxKey = &contextKey{"Token"}
|
||||
ErrorCtxKey = &contextKey{"Error"}
|
||||
)
|
||||
|
||||
// Library errors
|
||||
var (
|
||||
ErrUnauthorized = errors.New("jwtauth: token is unauthorized")
|
||||
ErrExpired = errors.New("jwtauth: token is expired")
|
||||
ErrNBFInvalid = errors.New("jwtauth: token nbf validation failed")
|
||||
ErrIATInvalid = errors.New("jwtauth: token iat validation failed")
|
||||
ErrNoTokenFound = errors.New("jwtauth: no token found")
|
||||
ErrAlgoInvalid = errors.New("jwtauth: algorithm mismatch")
|
||||
)
|
||||
|
||||
type JWTAuth struct {
|
||||
|
|
@ -38,12 +42,7 @@ func New(alg string, signKey interface{}, verifyKey interface{}) *JWTAuth {
|
|||
|
||||
// NewWithParser is the same as New, except it supports custom parser settings
|
||||
// introduced in jwt-go/v2.4.0.
|
||||
//
|
||||
// We explicitly toggle `SkipClaimsValidation` in the `jwt-go` parser so that
|
||||
// we can control when the claims are validated - in our case, by the Verifier
|
||||
// http middleware handler.
|
||||
func NewWithParser(alg string, parser *jwt.Parser, signKey interface{}, verifyKey interface{}) *JWTAuth {
|
||||
parser.SkipClaimsValidation = true
|
||||
return &JWTAuth{
|
||||
signKey: signKey,
|
||||
verifyKey: verifyKey,
|
||||
|
|
@ -103,34 +102,36 @@ func VerifyRequest(ja *JWTAuth, r *http.Request, findTokenFns ...func(r *http.Re
|
|||
return nil, ErrNoTokenFound
|
||||
}
|
||||
|
||||
// TODO: what other kinds of validations should we do / error messages?
|
||||
|
||||
// Verify the token
|
||||
token, err := ja.Decode(tokenStr)
|
||||
if err != nil {
|
||||
switch err.Error() {
|
||||
case "token is expired":
|
||||
err = ErrExpired
|
||||
if verr, ok := err.(*jwt.ValidationError); ok {
|
||||
if verr.Errors&jwt.ValidationErrorExpired > 0 {
|
||||
return token, ErrExpired
|
||||
} else if verr.Errors&jwt.ValidationErrorIssuedAt > 0 {
|
||||
return token, ErrIATInvalid
|
||||
} else if verr.Errors&jwt.ValidationErrorIssuedAt > 0 {
|
||||
return token, ErrNBFInvalid
|
||||
}
|
||||
}
|
||||
return token, err
|
||||
}
|
||||
|
||||
if token == nil || !token.Valid || token.Method != ja.signer {
|
||||
if token == nil || !token.Valid {
|
||||
err = ErrUnauthorized
|
||||
return token, err
|
||||
}
|
||||
|
||||
// Check expiry via "exp" claim
|
||||
if IsExpired(token) {
|
||||
err = ErrExpired
|
||||
return token, err
|
||||
// Verify signing algorithm
|
||||
if token.Method != ja.signer {
|
||||
return token, ErrAlgoInvalid
|
||||
}
|
||||
|
||||
// Valid!
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (ja *JWTAuth) Encode(claims Claims) (t *jwt.Token, tokenString string, err error) {
|
||||
func (ja *JWTAuth) Encode(claims jwt.MapClaims) (t *jwt.Token, tokenString string, err error) {
|
||||
t = jwt.New(ja.signer)
|
||||
t.Claims = claims
|
||||
tokenString, err = t.SignedString(ja.signKey)
|
||||
|
|
@ -139,10 +140,6 @@ func (ja *JWTAuth) Encode(claims Claims) (t *jwt.Token, tokenString string, err
|
|||
}
|
||||
|
||||
func (ja *JWTAuth) Decode(tokenString string) (t *jwt.Token, err error) {
|
||||
// Decode the tokenString, but avoid using custom Claims via jwt-go's
|
||||
// ParseWithClaims as the jwt-go types will cause some glitches, so easier
|
||||
// to decode as MapClaims then wrap the underlying map[string]interface{}
|
||||
// to our Claims type
|
||||
t, err = ja.parser.Parse(tokenString, ja.keyFunc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -187,21 +184,18 @@ func NewContext(ctx context.Context, t *jwt.Token, err error) context.Context {
|
|||
return ctx
|
||||
}
|
||||
|
||||
func FromContext(ctx context.Context) (*jwt.Token, Claims, error) {
|
||||
func FromContext(ctx context.Context) (*jwt.Token, jwt.MapClaims, error) {
|
||||
token, _ := ctx.Value(TokenCtxKey).(*jwt.Token)
|
||||
|
||||
var claims Claims
|
||||
var claims jwt.MapClaims
|
||||
if token != nil {
|
||||
switch tokenClaims := token.Claims.(type) {
|
||||
case Claims:
|
||||
if tokenClaims, ok := token.Claims.(jwt.MapClaims); ok {
|
||||
claims = tokenClaims
|
||||
case jwt.MapClaims:
|
||||
claims = Claims(tokenClaims)
|
||||
default:
|
||||
} else {
|
||||
panic(fmt.Sprintf("jwtauth: unknown type of Claims: %T", token.Claims))
|
||||
}
|
||||
} else {
|
||||
claims = Claims{}
|
||||
claims = jwt.MapClaims{}
|
||||
}
|
||||
|
||||
err, _ := ctx.Value(ErrorCtxKey).(error)
|
||||
|
|
@ -209,83 +203,17 @@ func FromContext(ctx context.Context) (*jwt.Token, Claims, error) {
|
|||
return token, claims, err
|
||||
}
|
||||
|
||||
func IsExpired(t *jwt.Token) bool {
|
||||
claims, ok := t.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
panic("jwtauth: expecting jwt.MapClaims")
|
||||
}
|
||||
|
||||
if expv, ok := claims["exp"]; ok {
|
||||
var exp int64
|
||||
switch v := expv.(type) {
|
||||
case float64:
|
||||
exp = int64(v)
|
||||
case int64:
|
||||
exp = v
|
||||
case json.Number:
|
||||
exp, _ = v.Int64()
|
||||
default:
|
||||
}
|
||||
|
||||
if exp < EpochNow() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
// UnixTime returns the given time in UTC milliseconds
|
||||
func UnixTime(tm time.Time) int64 {
|
||||
return tm.UTC().Unix()
|
||||
}
|
||||
|
||||
// Claims is a convenience type to manage a JWT claims hash.
|
||||
type Claims map[string]interface{}
|
||||
|
||||
// NOTE: as of v3.0 of jwt-go, Valid() interface method is called to verify
|
||||
// the claims. However, the current design we test these claims in the
|
||||
// Verifier middleware, so we skip this step.
|
||||
func (c Claims) Valid() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Claims) Set(k string, v interface{}) Claims {
|
||||
c[k] = v
|
||||
return c
|
||||
}
|
||||
|
||||
func (c Claims) Get(k string) (interface{}, bool) {
|
||||
v, ok := c[k]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// Set issued at ("iat") to specified time in the claims
|
||||
func (c Claims) SetIssuedAt(tm time.Time) Claims {
|
||||
c["iat"] = tm.UTC().Unix()
|
||||
return c
|
||||
}
|
||||
|
||||
// Set issued at ("iat") to present time in the claims
|
||||
func (c Claims) SetIssuedNow() Claims {
|
||||
c["iat"] = EpochNow()
|
||||
return c
|
||||
}
|
||||
|
||||
// Set expiry ("exp") in the claims and return itself so it can be chained
|
||||
func (c Claims) SetExpiry(tm time.Time) Claims {
|
||||
c["exp"] = tm.UTC().Unix()
|
||||
return c
|
||||
}
|
||||
|
||||
// Set expiry ("exp") in the claims to some duration from the present time
|
||||
// and return itself so it can be chained
|
||||
func (c Claims) SetExpiryIn(tm time.Duration) Claims {
|
||||
c["exp"] = ExpireIn(tm)
|
||||
return c
|
||||
}
|
||||
|
||||
// Helper function that returns the NumericDate time value used by the spec
|
||||
// EpochNow is a helper function that returns the NumericDate time value used by the spec
|
||||
func EpochNow() int64 {
|
||||
return time.Now().UTC().Unix()
|
||||
}
|
||||
|
||||
// Helper function to return calculated time in the future for "exp" claim.
|
||||
// ExpireIn is a helper function to return calculated time in the future for "exp" claim
|
||||
func ExpireIn(tm time.Duration) int64 {
|
||||
return EpochNow() + int64(tm.Seconds())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ func TestSimpleRSA(t *testing.T) {
|
|||
|
||||
TokenAuthRS256 = jwtauth.New("RS256", privateKey, publicKey)
|
||||
|
||||
claims := jwtauth.Claims{
|
||||
claims := jwt.MapClaims{
|
||||
"key": "val",
|
||||
"key2": "val2",
|
||||
"key3": "val3",
|
||||
|
|
@ -87,7 +87,7 @@ func TestSimpleRSA(t *testing.T) {
|
|||
t.Fatalf("Failed to decode token string %s\n", err.Error())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(claims, jwtauth.Claims(token.Claims.(jwt.MapClaims))) {
|
||||
if !reflect.DeepEqual(claims, jwt.MapClaims(token.Claims.(jwt.MapClaims))) {
|
||||
t.Fatalf("The decoded claims don't match the original ones\n")
|
||||
}
|
||||
}
|
||||
|
|
@ -220,7 +220,7 @@ func TestMore(t *testing.T) {
|
|||
t.Fatalf(resp)
|
||||
}
|
||||
|
||||
h = newAuthHeader((jwtauth.Claims{}).Set("exp", jwtauth.EpochNow()-1000))
|
||||
h = newAuthHeader(jwt.MapClaims{"exp": jwtauth.EpochNow() - 1000})
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "expired\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
|
|
@ -230,7 +230,7 @@ func TestMore(t *testing.T) {
|
|||
t.Fatalf(resp)
|
||||
}
|
||||
|
||||
h = newAuthHeader((jwtauth.Claims{"user_id": 31337}).SetExpiryIn(5 * time.Minute))
|
||||
h = newAuthHeader((jwt.MapClaims{"user_id": 31337, "exp": jwtauth.ExpireIn(5 * time.Minute)}))
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 || resp != "protected, user:31337" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
|
|
@ -269,7 +269,7 @@ func testRequest(t *testing.T, ts *httptest.Server, method, path string, header
|
|||
return resp.StatusCode, string(respBody)
|
||||
}
|
||||
|
||||
func newJwtToken(secret []byte, claims ...jwtauth.Claims) string {
|
||||
func newJwtToken(secret []byte, claims ...jwt.MapClaims) string {
|
||||
token := jwt.New(jwt.GetSigningMethod("HS256"))
|
||||
if len(claims) > 0 {
|
||||
token.Claims = claims[0]
|
||||
|
|
@ -281,7 +281,7 @@ func newJwtToken(secret []byte, claims ...jwtauth.Claims) string {
|
|||
return tokenStr
|
||||
}
|
||||
|
||||
func newJwt512Token(secret []byte, claims ...jwtauth.Claims) string {
|
||||
func newJwt512Token(secret []byte, claims ...jwt.MapClaims) string {
|
||||
// use-case: when token is signed with a different alg than expected
|
||||
token := jwt.New(jwt.GetSigningMethod("HS512"))
|
||||
if len(claims) > 0 {
|
||||
|
|
@ -294,7 +294,7 @@ func newJwt512Token(secret []byte, claims ...jwtauth.Claims) string {
|
|||
return tokenStr
|
||||
}
|
||||
|
||||
func newAuthHeader(claims ...jwtauth.Claims) http.Header {
|
||||
func newAuthHeader(claims ...jwt.MapClaims) http.Header {
|
||||
h := http.Header{}
|
||||
h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, claims...))
|
||||
return h
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue