2015-10-29 13:44:09 -04:00
|
|
|
package jwtauth
|
|
|
|
|
|
|
|
|
|
import (
|
2016-07-11 17:51:29 -04:00
|
|
|
"context"
|
2016-01-21 14:45:08 -05:00
|
|
|
"encoding/json"
|
2015-10-29 13:44:09 -04:00
|
|
|
"errors"
|
|
|
|
|
"net/http"
|
|
|
|
|
"strings"
|
2016-01-19 17:43:58 -05:00
|
|
|
"time"
|
2015-10-29 13:44:09 -04:00
|
|
|
|
2017-07-05 17:05:27 -04:00
|
|
|
"github.com/dgrijalva/jwt-go"
|
2015-10-29 13:44:09 -04:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
var (
|
2017-07-05 17:05:27 -04:00
|
|
|
TokenCtxKey = &contextKey{"Token"}
|
|
|
|
|
ErrorCtxKey = &contextKey{"Error"}
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// ErrUnauthorized is reported when a token is missing, invalid, or signed
// with a method other than the one the authenticator was configured with.
var ErrUnauthorized = errors.New("jwtauth: token is unauthorized")

// ErrExpired is reported when a token's "exp" claim lies in the past.
var ErrExpired = errors.New("jwtauth: token is expired")
|
|
|
|
|
|
|
|
|
|
type JwtAuth struct {
|
|
|
|
|
signKey []byte
|
|
|
|
|
verifyKey []byte
|
|
|
|
|
signer jwt.SigningMethod
|
2016-01-19 12:05:48 -05:00
|
|
|
parser *jwt.Parser
|
2015-10-29 13:44:09 -04:00
|
|
|
}
|
|
|
|
|
|
2016-01-21 14:00:36 -05:00
|
|
|
// New creates a JwtAuth authenticator instance that provides middleware handlers
|
|
|
|
|
// and encoding/decoding functions for JWT signing.
|
2015-10-29 13:44:09 -04:00
|
|
|
func New(alg string, signKey []byte, verifyKey []byte) *JwtAuth {
|
2017-07-05 17:05:27 -04:00
|
|
|
return NewWithParser(alg, &jwt.Parser{}, signKey, verifyKey)
|
2015-10-29 13:44:09 -04:00
|
|
|
}
|
|
|
|
|
|
2016-01-21 14:00:36 -05:00
|
|
|
// NewWithParser is the same as New, except it supports custom parser settings
|
2017-07-05 17:05:27 -04:00
|
|
|
// introduced in jwt-go/v2.4.0.
|
|
|
|
|
//
|
|
|
|
|
// We explicitly toggle `SkipClaimsValidation` in the `jwt-go` parser so that
|
|
|
|
|
// we can control when the claims are validated - in our case, by the Verifier
|
|
|
|
|
// http middleware handler.
|
2016-01-19 12:05:48 -05:00
|
|
|
func NewWithParser(alg string, parser *jwt.Parser, signKey []byte, verifyKey []byte) *JwtAuth {
|
2017-07-05 17:05:27 -04:00
|
|
|
parser.SkipClaimsValidation = true
|
2016-01-19 12:05:48 -05:00
|
|
|
return &JwtAuth{
|
|
|
|
|
signKey: signKey,
|
|
|
|
|
verifyKey: verifyKey,
|
|
|
|
|
signer: jwt.GetSigningMethod(alg),
|
|
|
|
|
parser: parser,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2017-07-05 17:05:27 -04:00
|
|
|
// Verifier http middleware handler will verify a JWT string from a http request.
|
|
|
|
|
//
|
|
|
|
|
// Verifier will search for a JWT token in a http request, in the order:
|
|
|
|
|
// 1. 'jwt' URI query parameter
|
|
|
|
|
// 2. 'Authorization: BEARER T' request header
|
|
|
|
|
// 3. Cookie 'jwt' value
|
2016-01-21 14:00:36 -05:00
|
|
|
//
|
2017-07-05 17:05:27 -04:00
|
|
|
// The first JWT string that is found as a query parameter, authorization header
|
|
|
|
|
// or cookie header is then decoded by the `jwt-go` library and a *jwt.Token
|
|
|
|
|
// object is set on the request context. In the case of a signature decoding error
|
|
|
|
|
// the Verifier will also set the error on the request context.
|
2016-01-21 14:00:36 -05:00
|
|
|
//
|
2017-07-05 17:05:27 -04:00
|
|
|
// The Verifier always calls the next http handler in sequence, which can either
|
|
|
|
|
// be the generic `jwtauth.Authenticator` middleware or your own custom handler
|
|
|
|
|
// which checks the request context jwt token and error to prepare a custom
|
|
|
|
|
// http response.
|
2017-07-06 18:09:09 -04:00
|
|
|
func Verifier(ja *JwtAuth) func(http.Handler) http.Handler {
|
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
|
|
|
return Verify(ja, "")(next)
|
|
|
|
|
}
|
2016-01-21 14:00:36 -05:00
|
|
|
}
|
|
|
|
|
|
2017-07-06 18:09:09 -04:00
|
|
|
func Verify(ja *JwtAuth, paramAliases ...string) func(http.Handler) http.Handler {
|
2016-07-11 17:51:29 -04:00
|
|
|
return func(next http.Handler) http.Handler {
|
|
|
|
|
hfn := func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
ctx := r.Context()
|
2017-07-12 21:25:34 -04:00
|
|
|
token, err := VerifyRequest(ja, r, paramAliases...)
|
|
|
|
|
ctx = NewContext(ctx, token, err)
|
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
|
|
|
}
|
|
|
|
|
return http.HandlerFunc(hfn)
|
|
|
|
|
}
|
|
|
|
|
}
|
2015-10-29 13:44:09 -04:00
|
|
|
|
2017-07-12 21:25:34 -04:00
|
|
|
func VerifyRequest(ja *JwtAuth, r *http.Request, paramAliases ...string) (*jwt.Token, error) {
|
|
|
|
|
var tokenStr string
|
|
|
|
|
var err error
|
|
|
|
|
|
|
|
|
|
// Get token from query params
|
|
|
|
|
tokenStr = r.URL.Query().Get("jwt")
|
2015-10-29 13:44:09 -04:00
|
|
|
|
2017-07-12 21:25:34 -04:00
|
|
|
// Get token from other param aliases
|
|
|
|
|
if tokenStr == "" && paramAliases != nil && len(paramAliases) > 0 {
|
|
|
|
|
for _, p := range paramAliases {
|
|
|
|
|
tokenStr = r.URL.Query().Get(p)
|
|
|
|
|
if tokenStr != "" {
|
|
|
|
|
break
|
2015-10-29 13:44:09 -04:00
|
|
|
}
|
2017-07-12 21:25:34 -04:00
|
|
|
}
|
|
|
|
|
}
|
2015-10-29 13:44:09 -04:00
|
|
|
|
2017-07-12 21:25:34 -04:00
|
|
|
// Get token from authorization header
|
|
|
|
|
if tokenStr == "" {
|
|
|
|
|
bearer := r.Header.Get("Authorization")
|
|
|
|
|
if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" {
|
|
|
|
|
tokenStr = bearer[7:]
|
|
|
|
|
}
|
|
|
|
|
}
|
2015-10-29 13:44:09 -04:00
|
|
|
|
2017-07-12 21:25:34 -04:00
|
|
|
// Get token from cookie
|
|
|
|
|
if tokenStr == "" {
|
|
|
|
|
// TODO: paramAliases should apply to cookies too..
|
|
|
|
|
cookie, err := r.Cookie("jwt")
|
|
|
|
|
if err == nil {
|
|
|
|
|
tokenStr = cookie.Value
|
|
|
|
|
}
|
|
|
|
|
}
|
2016-01-21 14:00:36 -05:00
|
|
|
|
2017-07-12 21:25:34 -04:00
|
|
|
// TODO: what other kinds of validations should we do / error messages?
|
2015-10-29 13:44:09 -04:00
|
|
|
|
2017-07-12 21:25:34 -04:00
|
|
|
// Verify the token
|
|
|
|
|
token, err := ja.Decode(tokenStr)
|
|
|
|
|
if err != nil {
|
|
|
|
|
switch err.Error() {
|
|
|
|
|
case "token is expired":
|
|
|
|
|
err = ErrExpired
|
|
|
|
|
}
|
|
|
|
|
return token, err
|
|
|
|
|
}
|
2016-01-19 17:43:58 -05:00
|
|
|
|
2017-07-12 21:25:34 -04:00
|
|
|
if token == nil || !token.Valid || token.Method != ja.signer {
|
|
|
|
|
err = ErrUnauthorized
|
|
|
|
|
return token, err
|
2015-10-29 13:44:09 -04:00
|
|
|
}
|
2017-07-12 21:25:34 -04:00
|
|
|
|
|
|
|
|
// Check expiry via "exp" claim
|
|
|
|
|
if IsExpired(token) {
|
|
|
|
|
err = ErrExpired
|
|
|
|
|
return token, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Valid!
|
|
|
|
|
return token, nil
|
2015-10-29 13:44:09 -04:00
|
|
|
}
|
|
|
|
|
|
2017-07-06 16:49:11 -04:00
|
|
|
func (ja *JwtAuth) Encode(claims Claims) (t *jwt.Token, tokenString string, err error) {
|
|
|
|
|
t = jwt.New(ja.signer)
|
|
|
|
|
t.Claims = claims
|
|
|
|
|
tokenString, err = t.SignedString(ja.signKey)
|
|
|
|
|
t.Raw = tokenString
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (ja *JwtAuth) Decode(tokenString string) (t *jwt.Token, err error) {
|
|
|
|
|
// Decode the tokenString, but avoid using custom Claims via jwt-go's
|
|
|
|
|
// ParseWithClaims as the jwt-go types will cause some glitches, so easier
|
|
|
|
|
// to decode as MapClaims then wrap the underlying map[string]interface{}
|
|
|
|
|
// to our Claims type
|
|
|
|
|
t, err = ja.parser.Parse(tokenString, ja.keyFunc)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (ja *JwtAuth) keyFunc(t *jwt.Token) (interface{}, error) {
|
|
|
|
|
if ja.verifyKey != nil && len(ja.verifyKey) > 0 {
|
|
|
|
|
return ja.verifyKey, nil
|
|
|
|
|
} else {
|
|
|
|
|
return ja.signKey, nil
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2017-07-05 17:05:27 -04:00
|
|
|
// Authenticator is a default authentication middleware to enforce access from the
|
|
|
|
|
// Verifier middleware request context values. The Authenticator sends a 401 Unauthorized
|
|
|
|
|
// response for any unverified tokens and passes the good ones through. It's just fine
|
|
|
|
|
// until you decide to write something similar and customize your client response.
|
|
|
|
|
func Authenticator(next http.Handler) http.Handler {
|
|
|
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
2017-07-06 15:23:11 -04:00
|
|
|
token, _, err := FromContext(r.Context())
|
2017-07-05 17:05:27 -04:00
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
http.Error(w, http.StatusText(401), 401)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if token == nil || !token.Valid {
|
|
|
|
|
http.Error(w, http.StatusText(401), 401)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Token is authenticated, pass it through
|
|
|
|
|
next.ServeHTTP(w, r)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
2017-07-06 16:49:11 -04:00
|
|
|
func NewContext(ctx context.Context, t *jwt.Token, err error) context.Context {
|
|
|
|
|
ctx = context.WithValue(ctx, TokenCtxKey, t)
|
|
|
|
|
ctx = context.WithValue(ctx, ErrorCtxKey, err)
|
|
|
|
|
return ctx
|
|
|
|
|
}
|
|
|
|
|
|
2017-07-06 15:23:11 -04:00
|
|
|
func FromContext(ctx context.Context) (*jwt.Token, Claims, error) {
|
2017-07-05 17:05:27 -04:00
|
|
|
token, _ := ctx.Value(TokenCtxKey).(*jwt.Token)
|
|
|
|
|
|
|
|
|
|
var claims Claims
|
|
|
|
|
if token != nil {
|
2017-07-10 21:30:13 -04:00
|
|
|
tokenClaims, ok := token.Claims.(jwt.MapClaims)
|
|
|
|
|
if !ok {
|
|
|
|
|
panic("jwtauth: expecting jwt.MapClaims")
|
|
|
|
|
}
|
|
|
|
|
claims = Claims(tokenClaims)
|
2017-07-05 17:05:27 -04:00
|
|
|
} else {
|
|
|
|
|
claims = Claims{}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
err, _ := ctx.Value(ErrorCtxKey).(error)
|
|
|
|
|
|
|
|
|
|
return token, claims, err
|
|
|
|
|
}
|
|
|
|
|
|
2017-07-06 16:49:11 -04:00
|
|
|
func IsExpired(t *jwt.Token) bool {
|
2017-07-10 21:30:13 -04:00
|
|
|
claims, ok := t.Claims.(jwt.MapClaims)
|
|
|
|
|
if !ok {
|
|
|
|
|
panic("jwtauth: expecting jwt.MapClaims")
|
|
|
|
|
}
|
2017-07-05 17:05:27 -04:00
|
|
|
|
|
|
|
|
if expv, ok := claims["exp"]; ok {
|
2016-01-21 14:00:36 -05:00
|
|
|
var exp int64
|
|
|
|
|
switch v := expv.(type) {
|
|
|
|
|
case float64:
|
|
|
|
|
exp = int64(v)
|
|
|
|
|
case int64:
|
|
|
|
|
exp = v
|
2016-01-21 14:45:08 -05:00
|
|
|
case json.Number:
|
|
|
|
|
exp, _ = v.Int64()
|
2016-01-21 14:00:36 -05:00
|
|
|
default:
|
|
|
|
|
}
|
|
|
|
|
|
2016-01-21 14:07:01 -05:00
|
|
|
if exp < EpochNow() {
|
2016-01-21 14:00:36 -05:00
|
|
|
return true
|
|
|
|
|
}
|
2016-01-19 12:05:48 -05:00
|
|
|
}
|
2016-01-21 14:00:36 -05:00
|
|
|
|
|
|
|
|
return false
|
2015-10-29 13:44:09 -04:00
|
|
|
}
|
2016-01-19 17:43:58 -05:00
|
|
|
|
2016-01-21 14:00:36 -05:00
|
|
|
// Claims is a convenience type to manage a JWT claims hash.
type Claims map[string]interface{}

// Valid always returns nil.
//
// NOTE: as of v3.0 of jwt-go, the Valid() interface method is called to
// verify the claims. However, in the current design these claims are tested
// in the Verifier middleware, so this step is skipped.
func (c Claims) Valid() error {
	return nil
}

// Set stores v under key k and returns the claims so calls can be chained.
func (c Claims) Set(k string, v interface{}) Claims {
	c[k] = v
	return c
}

// Get returns the value stored under key k and whether the key is present.
func (c Claims) Get(k string) (interface{}, bool) {
	v, ok := c[k]
	return v, ok
}

// SetIssuedAt sets issued at ("iat") to the specified time, as Unix seconds,
// and returns the claims so calls can be chained.
func (c Claims) SetIssuedAt(tm time.Time) Claims {
	c["iat"] = tm.UTC().Unix()
	return c
}

// SetIssuedNow sets issued at ("iat") to the present time and returns the
// claims so calls can be chained.
func (c Claims) SetIssuedNow() Claims {
	c["iat"] = EpochNow()
	return c
}

// SetExpiry sets expiry ("exp") to the specified time, as Unix seconds, and
// returns the claims so calls can be chained.
func (c Claims) SetExpiry(tm time.Time) Claims {
	c["exp"] = tm.UTC().Unix()
	return c
}

// SetExpiryIn sets expiry ("exp") to some duration from the present time and
// returns the claims so calls can be chained.
func (c Claims) SetExpiryIn(tm time.Duration) Claims {
	c["exp"] = ExpireIn(tm)
	return c
}

// EpochNow is a helper function that returns the current time as the
// NumericDate value (Unix seconds) used by the spec.
func EpochNow() int64 {
	return time.Now().UTC().Unix()
}

// ExpireIn is a helper function that returns the time tm from now, as Unix
// seconds, for use as an "exp" claim. tm is truncated to whole seconds.
func ExpireIn(tm time.Duration) int64 {
	// Integer division truncates exactly; converting through float64 seconds
	// can round up by a full second for long durations whose fractional part
	// is just below one second.
	return EpochNow() + int64(tm/time.Second)
}
|
2017-07-05 17:05:27 -04:00
|
|
|
|
|
|
|
|
// contextKey is the key type used with context.WithValue. Keys are used as
// pointers so they fit in an interface{} without allocation; the technique
// was copied from Go 1.7's use of context in net/http.
type contextKey struct {
	name string
}

// String returns a human-readable description of the key.
func (k *contextKey) String() string {
	const prefix = "jwtauth context value "
	return prefix + k.name
}
|