Merge pull request #4 from goware/exp

Refactor+expiry
This commit is contained in:
Peter Kieltyka 2016-01-21 14:20:12 -05:00
commit 35183c04d2
4 changed files with 354 additions and 16 deletions

20
LICENSE Normal file
View file

@ -0,0 +1,20 @@
Copyright (c) 2015-2016 Peter Kieltyka (https://twitter.com/peterk)
MIT License
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -1,2 +1 @@
# jwtauth # jwtauth

View file

@ -4,6 +4,7 @@ import (
"errors" "errors"
"net/http" "net/http"
"strings" "strings"
"time"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
"github.com/pressly/chi" "github.com/pressly/chi"
@ -11,7 +12,8 @@ import (
) )
var ( var (
errUnauthorized = errors.New("unauthorized token") ErrUnauthorized = errors.New("jwtauth: unauthorized token")
ErrExpired = errors.New("jwtauth: expired token")
) )
type JwtAuth struct { type JwtAuth struct {
@ -21,7 +23,8 @@ type JwtAuth struct {
parser *jwt.Parser parser *jwt.Parser
} }
// verifyKey is only for RSA // New creates a JwtAuth authenticator instance that provides middleware handlers
// and encoding/decoding functions for JWT signing.
func New(alg string, signKey []byte, verifyKey []byte) *JwtAuth { func New(alg string, signKey []byte, verifyKey []byte) *JwtAuth {
return &JwtAuth{ return &JwtAuth{
signKey: signKey, signKey: signKey,
@ -30,7 +33,8 @@ func New(alg string, signKey []byte, verifyKey []byte) *JwtAuth {
} }
} }
// the same as New, except it supports custom parser settings introduced in ver. 2.4.0 of jwt-go // NewWithParser is the same as New, except it supports custom parser settings
// introduced in ver. 2.4.0 of jwt-go
func NewWithParser(alg string, parser *jwt.Parser, signKey []byte, verifyKey []byte) *JwtAuth { func NewWithParser(alg string, parser *jwt.Parser, signKey []byte, verifyKey []byte) *JwtAuth {
return &JwtAuth{ return &JwtAuth{
signKey: signKey, signKey: signKey,
@ -40,7 +44,25 @@ func NewWithParser(alg string, parser *jwt.Parser, signKey []byte, verifyKey []b
} }
} }
func (ja *JwtAuth) Handle(paramAliases ...string) func(chi.Handler) chi.Handler { // Verifier middleware will verify a JWT passed by a client request.
// The Verifier will look for a JWT token from:
// 1. 'jwt' URI query parameter
// 2. 'Authorization: BEARER T' request header
// 3. Cookie 'jwt' value
//
// The verification processes finishes here and sets the token and
// a error in the request context and calls the next handler.
//
// Make sure to have your own handler following the Validator that
// will check the value of the "jwt" and "jwt.err" in the context
// and respond to the client accordingly. A generic Authenticator
// middleware is provided by this package, that will return a 401
// message for all unverified tokens, see jwtauth.Authenticator.
func (ja *JwtAuth) Verifier(next chi.Handler) chi.Handler {
return ja.Verify("")(next)
}
func (ja *JwtAuth) Verify(paramAliases ...string) func(chi.Handler) chi.Handler {
return func(next chi.Handler) chi.Handler { return func(next chi.Handler) chi.Handler {
hfn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) { hfn := func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
@ -78,30 +100,45 @@ func (ja *JwtAuth) Handle(paramAliases ...string) func(chi.Handler) chi.Handler
// Token is required, cya // Token is required, cya
if tokenStr == "" { if tokenStr == "" {
err = errUnauthorized err = ErrUnauthorized
} }
// Verify the token // Verify the token
token, err := ja.Decode(tokenStr) token, err := ja.Decode(tokenStr)
if err != nil || !token.Valid || token.Method != ja.signer { if err != nil || !token.Valid || token.Method != ja.signer {
http.Error(w, errUnauthorized.Error(), 401) switch err.Error() {
case "token is expired":
err = ErrExpired
}
ctx = ja.SetContext(ctx, token, err)
next.ServeHTTPC(ctx, w, r)
return return
} }
ctx = context.WithValue(ctx, "jwt", token.Raw) // Check expiry via "exp" claim
ctx = context.WithValue(ctx, "jwt.token", token) if ja.IsExpired(token) {
err = ErrExpired
ctx = ja.SetContext(ctx, token, err)
next.ServeHTTPC(ctx, w, r)
return
}
// Valid! pass it down the context to an authenticator middleware
ctx = ja.SetContext(ctx, token, err)
next.ServeHTTPC(ctx, w, r) next.ServeHTTPC(ctx, w, r)
} }
return chi.HandlerFunc(hfn) return chi.HandlerFunc(hfn)
} }
} }
func (ja *JwtAuth) Handler(next chi.Handler) chi.Handler { func (ja *JwtAuth) SetContext(ctx context.Context, t *jwt.Token, err error) context.Context {
return ja.Handle("")(next) ctx = context.WithValue(ctx, "jwt", t)
ctx = context.WithValue(ctx, "jwt.err", err)
return ctx
} }
func (ja *JwtAuth) Encode(claims map[string]interface{}) (t *jwt.Token, tokenString string, err error) { func (ja *JwtAuth) Encode(claims Claims) (t *jwt.Token, tokenString string, err error) {
t = jwt.New(ja.signer) t = jwt.New(ja.signer)
t.Claims = claims t.Claims = claims
tokenString, err = t.SignedString(ja.signKey) tokenString, err = t.SignedString(ja.signKey)
@ -109,6 +146,13 @@ func (ja *JwtAuth) Encode(claims map[string]interface{}) (t *jwt.Token, tokenStr
return return
} }
func (ja *JwtAuth) Decode(tokenString string) (t *jwt.Token, err error) {
if ja.parser != nil {
return ja.parser.Parse(tokenString, ja.keyFunc)
}
return jwt.Parse(tokenString, ja.keyFunc)
}
func (ja *JwtAuth) keyFunc(t *jwt.Token) (interface{}, error) { func (ja *JwtAuth) keyFunc(t *jwt.Token) (interface{}, error) {
if ja.verifyKey != nil && len(ja.verifyKey) > 0 { if ja.verifyKey != nil && len(ja.verifyKey) > 0 {
return ja.verifyKey, nil return ja.verifyKey, nil
@ -117,9 +161,86 @@ func (ja *JwtAuth) keyFunc(t *jwt.Token) (interface{}, error) {
} }
} }
func (ja *JwtAuth) Decode(tokenString string) (t *jwt.Token, err error) { func (ja *JwtAuth) IsExpired(t *jwt.Token) bool {
if ja.parser != nil { if expv, ok := t.Claims["exp"]; ok {
return ja.parser.Parse(tokenString, ja.keyFunc) var exp int64
switch v := expv.(type) {
case float64:
exp = int64(v)
case int64:
exp = v
default:
}
if exp < EpochNow() {
return true
}
} }
return jwt.Parse(tokenString, ja.keyFunc)
return false
}
// Authenticator is a default authentication middleware to enforce access following
// the Verifier middleware. The Authenticator sends a 401 Unauthorized response for
// all unverified tokens and passes the good ones through. It's just fine until you
// decide to write something similar and customize your client response.
func Authenticator(next chi.Handler) chi.Handler {
return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
if jwtErr, ok := ctx.Value("jwt.err").(error); ok {
if jwtErr != nil {
http.Error(w, http.StatusText(401), 401)
return
}
}
jwtToken, ok := ctx.Value("jwt").(*jwt.Token)
if !ok || jwtToken == nil || !jwtToken.Valid {
http.Error(w, http.StatusText(401), 401)
return
}
// Token is authenticated, pass it through
next.ServeHTTPC(ctx, w, r)
})
}
// Claims is a convenience type to manage a JWT claims hash.
type Claims map[string]interface{}
func (c Claims) Set(k string, v interface{}) Claims {
c[k] = v
return c
}
func (c Claims) Get(k string) (interface{}, bool) {
v, ok := c[k]
return v, ok
}
// Set issued at ("iat") to specified time in the claims
func (c Claims) SetIssuedAt(tm time.Time) Claims {
c["iat"] = tm.UTC().Unix()
return c
}
// Set issued at ("iat") to present time in the claims
func (c Claims) SetIssuedNow() Claims {
c["iat"] = EpochNow()
return c
}
// Set expiry ("exp") in the claims and return itself so it can be chained
func (c Claims) SetExpiryIn(tm time.Duration) Claims {
c["exp"] = ExpireIn(tm)
return c
}
// Helper function that returns the NumericDate time value used by the spec
func EpochNow() int64 {
return time.Now().UTC().Unix()
}
// Helper function to return calculated time in the future for "exp" claim.
func ExpireIn(tm time.Duration) int64 {
return EpochNow() + int64(tm.Seconds())
} }

198
jwtauth_test.go Normal file
View file

@ -0,0 +1,198 @@
package jwtauth_test
import (
"io"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/dgrijalva/jwt-go"
"github.com/goware/jwtauth"
"github.com/pressly/chi"
"golang.org/x/net/context"
)
var (
TokenAuth *jwtauth.JwtAuth
TokenSecret = []byte("secretpass")
)
func init() {
TokenAuth = jwtauth.New("HS256", TokenSecret, nil)
}
//
// Tests
//
func TestSimple(t *testing.T) {
r := chi.NewRouter()
r.Use(TokenAuth.Verifier, jwtauth.Authenticator)
r.Get("/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
w.Write([]byte("welcome"))
})
ts := httptest.NewServer(r)
defer ts.Close()
// sending unauthorized requests
if status, resp := testRequest(t, ts, "GET", "/", nil, nil); status != 401 && resp != "Unauthorized\n" {
t.Fatalf(resp)
}
h := http.Header{}
h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{}))
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 && resp != "Unauthorized\n" {
t.Fatalf(resp)
}
h.Set("Authorization", "BEARER asdf")
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 && resp != "Unauthorized\n" {
t.Fatalf(resp)
}
// sending authorized requests
if status, resp := testRequest(t, ts, "GET", "/", newAuthHeader(), nil); status != 200 && resp != "welcome" {
t.Fatalf(resp)
}
}
func TestMore(t *testing.T) {
r := chi.NewRouter()
// Protected routes
r.Group(func(r chi.Router) {
r.Use(TokenAuth.Verifier)
authenticator := func(next chi.Handler) chi.Handler {
return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
if jwtErr, ok := ctx.Value("jwt.err").(error); ok {
switch jwtErr {
default:
http.Error(w, http.StatusText(401), 401)
return
case jwtauth.ErrExpired:
http.Error(w, "expired", 401)
return
case jwtauth.ErrUnauthorized:
http.Error(w, http.StatusText(401), 401)
return
case nil:
// no error
}
}
jwtToken, ok := ctx.Value("jwt").(*jwt.Token)
if !ok || jwtToken == nil || !jwtToken.Valid {
http.Error(w, http.StatusText(401), 401)
return
}
// Token is authenticated, pass it through
next.ServeHTTPC(ctx, w, r)
})
}
r.Use(authenticator)
r.Get("/admin", func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
w.Write([]byte("protected"))
})
})
// Public routes
r.Group(func(r chi.Router) {
r.Get("/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
w.Write([]byte("welcome"))
})
})
ts := httptest.NewServer(r)
defer ts.Close()
// sending unauthorized requests
if status, resp := testRequest(t, ts, "GET", "/admin", nil, nil); status != 401 && resp != "Unauthorized\n" {
t.Fatalf(resp)
}
h := http.Header{}
h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{}))
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 && resp != "Unauthorized\n" {
t.Fatalf(resp)
}
h.Set("Authorization", "BEARER asdf")
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 && resp != "Unauthorized\n" {
t.Fatalf(resp)
}
h = newAuthHeader((jwtauth.Claims{}).Set("exp", jwtauth.EpochNow()-1000))
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 && resp != "expired\n" {
t.Fatalf(resp)
}
// sending authorized requests
if status, resp := testRequest(t, ts, "GET", "/", nil, nil); status != 200 && resp != "welcome" {
t.Fatalf(resp)
}
h = newAuthHeader((jwtauth.Claims{}).SetExpiryIn(5 * time.Minute))
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 && resp != "protected" {
t.Fatalf(resp)
}
}
//
// Test helper functions
//
func testRequest(t *testing.T, ts *httptest.Server, method, path string, header http.Header, body io.Reader) (int, string) {
req, err := http.NewRequest(method, ts.URL+path, body)
if err != nil {
t.Fatal(err)
return 0, ""
}
if header != nil {
for k, v := range header {
req.Header.Set(k, v[0])
}
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
return 0, ""
}
respBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
return 0, ""
}
defer resp.Body.Close()
return resp.StatusCode, string(respBody)
}
func newJwtToken(secret []byte, claims ...jwtauth.Claims) string {
token := jwt.New(jwt.GetSigningMethod("HS256"))
if len(claims) > 0 {
for k, v := range claims[0] {
token.Claims[k] = v
}
}
tokenStr, err := token.SignedString(secret)
if err != nil {
log.Fatal(err)
}
return tokenStr
}
func newAuthHeader(claims ...jwtauth.Claims) http.Header {
h := http.Header{}
h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, claims...))
return h
}