mirror of
https://forgejo.merr.is/annika/jwtauth.git
synced 2025-12-11 18:41:09 -05:00
Handler improvements, Claims type, expiry and tests
This commit is contained in:
parent
66c0c85e36
commit
a872c75843
2 changed files with 332 additions and 23 deletions
154
jwtauth.go
154
jwtauth.go
|
|
@ -12,7 +12,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errUnauthorized = errors.New("unauthorized token")
|
ErrUnauthorized = errors.New("jwtauth: unauthorized token")
|
||||||
|
ErrExpired = errors.New("jwtauth: expired token")
|
||||||
)
|
)
|
||||||
|
|
||||||
type JwtAuth struct {
|
type JwtAuth struct {
|
||||||
|
|
@ -22,7 +23,8 @@ type JwtAuth struct {
|
||||||
parser *jwt.Parser
|
parser *jwt.Parser
|
||||||
}
|
}
|
||||||
|
|
||||||
// verifyKey is only for RSA
|
// New creates a JwtAuth authenticator instance that provides middleware handlers
|
||||||
|
// and encoding/decoding functions for JWT signing.
|
||||||
func New(alg string, signKey []byte, verifyKey []byte) *JwtAuth {
|
func New(alg string, signKey []byte, verifyKey []byte) *JwtAuth {
|
||||||
return &JwtAuth{
|
return &JwtAuth{
|
||||||
signKey: signKey,
|
signKey: signKey,
|
||||||
|
|
@ -31,7 +33,8 @@ func New(alg string, signKey []byte, verifyKey []byte) *JwtAuth {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// the same as New, except it supports custom parser settings introduced in ver. 2.4.0 of jwt-go
|
// NewWithParser is the same as New, except it supports custom parser settings
|
||||||
|
// introduced in ver. 2.4.0 of jwt-go
|
||||||
func NewWithParser(alg string, parser *jwt.Parser, signKey []byte, verifyKey []byte) *JwtAuth {
|
func NewWithParser(alg string, parser *jwt.Parser, signKey []byte, verifyKey []byte) *JwtAuth {
|
||||||
return &JwtAuth{
|
return &JwtAuth{
|
||||||
signKey: signKey,
|
signKey: signKey,
|
||||||
|
|
@ -41,7 +44,25 @@ func NewWithParser(alg string, parser *jwt.Parser, signKey []byte, verifyKey []b
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ja *JwtAuth) Handle(paramAliases ...string) func(chi.Handler) chi.Handler {
|
// Verifier middleware will verify a JWT passed by a client request.
|
||||||
|
// The Verifier will look for a JWT token from:
|
||||||
|
// 1. 'jwt' URI query parameter
|
||||||
|
// 2. 'Authorization: BEARER T' request header
|
||||||
|
// 3. Cookie 'jwt' value
|
||||||
|
//
|
||||||
|
// The verification processes finishes here and sets the token and
|
||||||
|
// a error in the request context and calls the next handler.
|
||||||
|
//
|
||||||
|
// Make sure to have your own handler following the Validator that
|
||||||
|
// will check the value of the "jwt" and "jwt.err" in the context
|
||||||
|
// and respond to the client accordingly. A generic Authenticator
|
||||||
|
// middleware is provided by this package, that will return a 401
|
||||||
|
// message for all unverified tokens, see jwtauth.Authenticator.
|
||||||
|
func (ja *JwtAuth) Verifier(next chi.Handler) chi.Handler {
|
||||||
|
return ja.Verify("")(next)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ja *JwtAuth) Verify(paramAliases ...string) func(chi.Handler) chi.Handler {
|
||||||
return func(next chi.Handler) chi.Handler {
|
return func(next chi.Handler) chi.Handler {
|
||||||
hfn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
hfn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
|
|
@ -79,39 +100,45 @@ func (ja *JwtAuth) Handle(paramAliases ...string) func(chi.Handler) chi.Handler
|
||||||
|
|
||||||
// Token is required, cya
|
// Token is required, cya
|
||||||
if tokenStr == "" {
|
if tokenStr == "" {
|
||||||
err = errUnauthorized
|
err = ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the token
|
// Verify the token
|
||||||
token, err := ja.Decode(tokenStr)
|
token, err := ja.Decode(tokenStr)
|
||||||
if err != nil || !token.Valid || token.Method != ja.signer {
|
if err != nil || !token.Valid || token.Method != ja.signer {
|
||||||
http.Error(w, errUnauthorized.Error(), 401)
|
switch err.Error() {
|
||||||
|
case "token is expired":
|
||||||
|
err = ErrExpired
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = ja.SetContext(ctx, token, err)
|
||||||
|
next.ServeHTTPC(ctx, w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check expiry via "exp" claim
|
// Check expiry via "exp" claim
|
||||||
if exp, ok := token.Claims["exp"].(int64); ok {
|
if ja.IsExpired(token) {
|
||||||
now := EpochNow()
|
err = ErrExpired
|
||||||
if exp < now {
|
ctx = ja.SetContext(ctx, token, err)
|
||||||
http.Error(w, errUnauthorized.Error(), 401)
|
next.ServeHTTPC(ctx, w, r)
|
||||||
return
|
return
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx = context.WithValue(ctx, "jwt", token.Raw)
|
// Valid! pass it down the context to an authenticator middleware
|
||||||
ctx = context.WithValue(ctx, "jwt.token", token)
|
ctx = ja.SetContext(ctx, token, err)
|
||||||
|
|
||||||
next.ServeHTTPC(ctx, w, r)
|
next.ServeHTTPC(ctx, w, r)
|
||||||
}
|
}
|
||||||
return chi.HandlerFunc(hfn)
|
return chi.HandlerFunc(hfn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ja *JwtAuth) Handler(next chi.Handler) chi.Handler {
|
func (ja *JwtAuth) SetContext(ctx context.Context, t *jwt.Token, err error) context.Context {
|
||||||
return ja.Handle("")(next)
|
ctx = context.WithValue(ctx, "jwt", t)
|
||||||
|
ctx = context.WithValue(ctx, "jwt.err", err)
|
||||||
|
return ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ja *JwtAuth) Encode(claims map[string]interface{}) (t *jwt.Token, tokenString string, err error) {
|
func (ja *JwtAuth) Encode(claims Claims) (t *jwt.Token, tokenString string, err error) {
|
||||||
t = jwt.New(ja.signer)
|
t = jwt.New(ja.signer)
|
||||||
t.Claims = claims
|
t.Claims = claims
|
||||||
tokenString, err = t.SignedString(ja.signKey)
|
tokenString, err = t.SignedString(ja.signKey)
|
||||||
|
|
@ -119,6 +146,13 @@ func (ja *JwtAuth) Encode(claims map[string]interface{}) (t *jwt.Token, tokenStr
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ja *JwtAuth) Decode(tokenString string) (t *jwt.Token, err error) {
|
||||||
|
if ja.parser != nil {
|
||||||
|
return ja.parser.Parse(tokenString, ja.keyFunc)
|
||||||
|
}
|
||||||
|
return jwt.Parse(tokenString, ja.keyFunc)
|
||||||
|
}
|
||||||
|
|
||||||
func (ja *JwtAuth) keyFunc(t *jwt.Token) (interface{}, error) {
|
func (ja *JwtAuth) keyFunc(t *jwt.Token) (interface{}, error) {
|
||||||
if ja.verifyKey != nil && len(ja.verifyKey) > 0 {
|
if ja.verifyKey != nil && len(ja.verifyKey) > 0 {
|
||||||
return ja.verifyKey, nil
|
return ja.verifyKey, nil
|
||||||
|
|
@ -127,14 +161,88 @@ func (ja *JwtAuth) keyFunc(t *jwt.Token) (interface{}, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ja *JwtAuth) Decode(tokenString string) (t *jwt.Token, err error) {
|
func (ja *JwtAuth) IsExpired(t *jwt.Token) bool {
|
||||||
if ja.parser != nil {
|
if expv, ok := t.Claims["exp"]; ok {
|
||||||
return ja.parser.Parse(tokenString, ja.keyFunc)
|
var exp int64
|
||||||
|
now := EpochNow()
|
||||||
|
|
||||||
|
switch v := expv.(type) {
|
||||||
|
case float64:
|
||||||
|
exp = int64(v)
|
||||||
|
case int64:
|
||||||
|
exp = v
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if exp < now {
|
||||||
|
return true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return jwt.Parse(tokenString, ja.keyFunc)
|
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the NumericDate time value used in conventional jwt claims
|
// Authenticator is a default authentication middleware to enforce access following
|
||||||
|
// the Verifier middleware. The Authenticator sends a 401 Unauthorized response for
|
||||||
|
// all unverified tokens and passes the good ones through. It's just fine until you
|
||||||
|
// decide to write something similar and customize your client response.
|
||||||
|
func Authenticator(next chi.Handler) chi.Handler {
|
||||||
|
return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||||
|
if jwtErr, ok := ctx.Value("jwt.err").(error); ok {
|
||||||
|
if jwtErr != nil {
|
||||||
|
http.Error(w, http.StatusText(401), 401)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
jwtToken, ok := ctx.Value("jwt").(*jwt.Token)
|
||||||
|
if !ok || jwtToken == nil || !jwtToken.Valid {
|
||||||
|
http.Error(w, http.StatusText(401), 401)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token is authenticated, pass it through
|
||||||
|
next.ServeHTTPC(ctx, w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Claims is a convenience type to manage a JWT claims hash.
|
||||||
|
type Claims map[string]interface{}
|
||||||
|
|
||||||
|
func (c Claims) Set(k string, v interface{}) Claims {
|
||||||
|
c[k] = v
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Claims) Get(k string) (interface{}, bool) {
|
||||||
|
v, ok := c[k]
|
||||||
|
return v, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set issues at ("iat") to specified time in the claims
|
||||||
|
func (c Claims) SetIssuedAt(tm time.Time) Claims {
|
||||||
|
c["iat"] = tm.UTC().Unix()
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set issues at ("iat") to present time in the claims
|
||||||
|
func (c Claims) SetIssuedNow() Claims {
|
||||||
|
c["iat"] = EpochNow()
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set expiry ("exp") in the claims and return itself so it can be chained
|
||||||
|
func (c Claims) SetExpiryIn(tm time.Duration) Claims {
|
||||||
|
c["exp"] = ExpireIn(tm)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function that returns the NumericDate time value used by the spec
|
||||||
func EpochNow() int64 {
|
func EpochNow() int64 {
|
||||||
return time.Now().UTC().Unix()
|
return time.Now().UTC().Unix()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper function to return calculated time in the future for "exp" claim.
|
||||||
|
func ExpireIn(tm time.Duration) int64 {
|
||||||
|
return EpochNow() + int64(tm.Seconds())
|
||||||
|
}
|
||||||
|
|
|
||||||
201
jwtauth_test.go
Normal file
201
jwtauth_test.go
Normal file
|
|
@ -0,0 +1,201 @@
|
||||||
|
package jwtauth_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/dgrijalva/jwt-go"
|
||||||
|
"github.com/goware/jwtauth"
|
||||||
|
"github.com/pressly/chi"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
TokenAuth *jwtauth.JwtAuth
|
||||||
|
TokenSecret = []byte("secretpass")
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
TokenAuth = jwtauth.New("HS256", []byte("secretpass"), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Tests
|
||||||
|
//
|
||||||
|
|
||||||
|
func TestSimple(t *testing.T) {
|
||||||
|
r := chi.NewRouter()
|
||||||
|
|
||||||
|
r.Use(TokenAuth.Verifier)
|
||||||
|
r.Use(jwtauth.Authenticator)
|
||||||
|
|
||||||
|
r.Get("/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("welcome"))
|
||||||
|
})
|
||||||
|
|
||||||
|
ts := httptest.NewServer(r)
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
// sending unauthorized requests
|
||||||
|
if resp := testRequest(t, ts, "GET", "/", nil, nil); resp != "Unauthorized\n" {
|
||||||
|
t.Fatalf(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := http.Header{}
|
||||||
|
h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{}))
|
||||||
|
if resp := testRequest(t, ts, "GET", "/", h, nil); resp != "Unauthorized\n" {
|
||||||
|
t.Fatalf(resp)
|
||||||
|
}
|
||||||
|
h.Set("Authorization", "BEARER asdf")
|
||||||
|
if resp := testRequest(t, ts, "GET", "/", h, nil); resp != "Unauthorized\n" {
|
||||||
|
t.Fatalf(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sending authorized requests
|
||||||
|
if resp := testRequest(t, ts, "GET", "/", newAuthHeader(), nil); resp != "welcome" {
|
||||||
|
t.Fatalf(resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMore(t *testing.T) {
|
||||||
|
r := chi.NewRouter()
|
||||||
|
|
||||||
|
// Protected routes
|
||||||
|
r.Group(func(r chi.Router) {
|
||||||
|
r.Use(TokenAuth.Verifier)
|
||||||
|
|
||||||
|
authenticator := func(next chi.Handler) chi.Handler {
|
||||||
|
return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||||
|
if jwtErr, ok := ctx.Value("jwt.err").(error); ok {
|
||||||
|
switch jwtErr {
|
||||||
|
default:
|
||||||
|
log.Println("...we're expired... I think..:", jwtErr)
|
||||||
|
http.Error(w, http.StatusText(401), 401)
|
||||||
|
return
|
||||||
|
case jwtauth.ErrExpired:
|
||||||
|
http.Error(w, "expired", 401)
|
||||||
|
return
|
||||||
|
case jwtauth.ErrUnauthorized:
|
||||||
|
http.Error(w, http.StatusText(401), 401)
|
||||||
|
return
|
||||||
|
case nil:
|
||||||
|
// no error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
jwtToken, ok := ctx.Value("jwt").(*jwt.Token)
|
||||||
|
if !ok || jwtToken == nil || !jwtToken.Valid {
|
||||||
|
log.Println("jwt token..........?", jwtToken)
|
||||||
|
http.Error(w, http.StatusText(401), 401)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token is authenticated, pass it through
|
||||||
|
next.ServeHTTPC(ctx, w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
r.Use(authenticator)
|
||||||
|
|
||||||
|
r.Get("/admin", func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("protected"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Public routes
|
||||||
|
r.Group(func(r chi.Router) {
|
||||||
|
r.Get("/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("welcome"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
ts := httptest.NewServer(r)
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
// sending unauthorized requests
|
||||||
|
if resp := testRequest(t, ts, "GET", "/admin", nil, nil); resp != "Unauthorized\n" {
|
||||||
|
t.Fatalf(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := http.Header{}
|
||||||
|
h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{}))
|
||||||
|
if resp := testRequest(t, ts, "GET", "/admin", h, nil); resp != "Unauthorized\n" {
|
||||||
|
t.Fatalf(resp)
|
||||||
|
}
|
||||||
|
h.Set("Authorization", "BEARER asdf")
|
||||||
|
if resp := testRequest(t, ts, "GET", "/admin", h, nil); resp != "Unauthorized\n" {
|
||||||
|
t.Fatalf(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
h = newAuthHeader((jwtauth.Claims{}).Set("exp", jwtauth.EpochNow()-1000))
|
||||||
|
if resp := testRequest(t, ts, "GET", "/admin", h, nil); resp != "expired\n" {
|
||||||
|
t.Fatalf(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sending authorized requests
|
||||||
|
if resp := testRequest(t, ts, "GET", "/", nil, nil); resp != "welcome" {
|
||||||
|
t.Fatalf(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
h = newAuthHeader((jwtauth.Claims{}).SetExpiryIn(5 * time.Minute))
|
||||||
|
if resp := testRequest(t, ts, "GET", "/admin", h, nil); resp != "protected" {
|
||||||
|
t.Fatalf(resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Test helper functions
|
||||||
|
//
|
||||||
|
|
||||||
|
func testRequest(t *testing.T, ts *httptest.Server, method, path string, header http.Header, body io.Reader) string {
|
||||||
|
req, err := http.NewRequest(method, ts.URL+path, body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if header != nil {
|
||||||
|
for k, v := range header {
|
||||||
|
req.Header.Set(k, v[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
return string(respBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newJwtToken(secret []byte, claims ...jwtauth.Claims) string {
|
||||||
|
token := jwt.New(jwt.GetSigningMethod("HS256"))
|
||||||
|
if len(claims) > 0 {
|
||||||
|
for k, v := range claims[0] {
|
||||||
|
token.Claims[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tokenStr, err := token.SignedString(secret)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
return tokenStr
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAuthHeader(claims ...jwtauth.Claims) http.Header {
|
||||||
|
h := http.Header{}
|
||||||
|
h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, claims...))
|
||||||
|
return h
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue