isl-api/middlewares/jwtAuth.go

225 lines
5.7 KiB
Go
Raw Normal View History

package middlewares
import (
"context"
"errors"
"net/http"
"strings"
"time"
"forgejo.merr.is/annika/isl-api/helpers"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
)
// JWTAuth holds the state needed to verify JWTs against a remote JWKS
// endpoint and exposes that verification as HTTP middleware.
type JWTAuth struct {
// jwksUri is the JWKS endpoint URL registered with and fetched through jwksCache.
jwksUri string
// jwksCache is created by New only when jwksUri is non-empty; it is nil otherwise.
jwksCache *jwk.Cache
// jwksContext is the context handed to New; it is used to build the cache
// and to perform the initial refresh.
jwksContext context.Context
// jwkKeySet is the key set most recently obtained from jwksCache.
jwkKeySet jwk.Set
// verifier is the parse option built from jwkKeySet and passed to jwt.Parse.
verifier jwt.ParseOption
// validateOptions are forwarded to jwt.Validate. Nothing in this file
// assigns them, so they are empty unless set elsewhere in the package.
validateOptions []jwt.ValidateOption
}
// ContextKey is the key type for values this package stores in a request
// context. The package-level keys are pointers to ContextKey values, so a
// lookup matches on pointer identity rather than on the name string.
type ContextKey struct {
// name labels the key; it is not used for comparison.
name string
}
// Sentinel errors reported by this package.
var (
// ErrUnauthorized is the fallback returned by ErrorReason and the error
// VerifyToken returns when decoding yields no token.
ErrUnauthorized = errors.New("token is unauthorized")
// ErrExpired is what ErrorReason returns for an expired token.
ErrExpired = errors.New("token is expired")
// ErrNBFInvalid is what ErrorReason returns for a token that is not yet valid.
ErrNBFInvalid = errors.New("token nbf validation failed")
// ErrIATInvalid is what ErrorReason returns for an invalid issued-at claim.
ErrIATInvalid = errors.New("token iat validation failed")
// ErrNoTokenFound is returned by VerifyRequest when no finder yields a token.
ErrNoTokenFound = errors.New("no token found")
// ErrAlgoInvalid is not referenced elsewhere in this file.
ErrAlgoInvalid = errors.New("algorithm mismatch")
)
// ErrorReason maps a verification error onto one of this package's sentinel
// errors.
//
// Expired, issued-at and not-before failures (either the jwx validation
// errors or this package's own sentinels, wrapped or not) map to ErrExpired,
// ErrIATInvalid and ErrNBFInvalid respectively. Every other error, including
// nil, maps to ErrUnauthorized.
func ErrorReason(err error) error {
	switch {
	// errors.Is is used for the package sentinels as well, so a sentinel that
	// was wrapped with additional context is still recognised.
	case errors.Is(err, jwt.ErrTokenExpired()), errors.Is(err, ErrExpired):
		return ErrExpired
	case errors.Is(err, jwt.ErrInvalidIssuedAt()), errors.Is(err, ErrIATInvalid):
		return ErrIATInvalid
	case errors.Is(err, jwt.ErrTokenNotYetValid()), errors.Is(err, ErrNBFInvalid):
		return ErrNBFInvalid
	default:
		return ErrUnauthorized
	}
}
// TokenContextKey is the request-context key under which Verify stores the
// token returned by VerifyRequest.
var TokenContextKey = &ContextKey{"Token"}
// ErrorContextKey is the request-context key under which Verify stores the
// error from the JWKS lookup or from token verification.
var ErrorContextKey = &ContextKey{"Error"}
// New builds a JWTAuth for the JWKS endpoint at jwksUri.
//
// When jwksUri is non-empty, a JWKS cache is created on ctx, the URI is
// registered with a 15-minute refresh interval, and the key set is fetched
// once up front so that a bad endpoint is reported here rather than on the
// first request. Any registration or fetch error is returned.
//
// When jwksUri is empty, no cache or verifier is configured and the returned
// JWTAuth cannot verify tokens.
func New(jwksUri string, ctx context.Context) (*JWTAuth, error) {
	jwtAuth := &JWTAuth{
		jwksContext: ctx,
		jwksUri:     jwksUri,
	}
	if jwtAuth.jwksUri == "" {
		return jwtAuth, nil
	}

	jwtAuth.jwksCache = jwk.NewCache(jwtAuth.jwksContext)
	// Register reports invalid registrations; surface that instead of
	// continuing with a cache that does not know about the URI.
	if err := jwtAuth.jwksCache.Register(jwtAuth.jwksUri, jwk.WithRefreshInterval(15*time.Minute)); err != nil {
		return nil, err
	}

	keySet, err := jwtAuth.jwksCache.Refresh(jwtAuth.jwksContext, jwtAuth.jwksUri)
	if err != nil {
		return nil, err
	}
	jwtAuth.jwkKeySet = keySet
	jwtAuth.verifier = jwt.WithKeySet(keySet)
	return jwtAuth, nil
}
// Verifier returns the Verify middleware configured to look for a token in
// the Authorization header first and in the "jwt" cookie second.
func (ja *JWTAuth) Verifier() func(http.Handler) http.Handler {
	finders := []func(r *http.Request) string{
		TokenFromHeader,
		TokenFromCookie,
	}
	return ja.Verify(finders...)
}
// Verify returns middleware that locates a token with findTokenFns, verifies
// it, and stores the outcome in the request context.
//
// The middleware never rejects a request itself: the token is stored under
// TokenContextKey and the error (possibly nil) under ErrorContextKey, and the
// next handler is always called. If the JWKS key set cannot be obtained, only
// the error is stored.
func (ja *JWTAuth) Verify(findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		handlerFunc := func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()

			// New leaves the cache nil when it was given an empty JWKS URI.
			// Fail closed rather than dereferencing a nil cache.
			if ja.jwksCache == nil {
				ctx = context.WithValue(ctx, ErrorContextKey, ErrUnauthorized)
				next.ServeHTTP(w, r.WithContext(ctx))
				return
			}

			// Pick up the current key set from the cache.
			keySet, err := ja.jwksCache.Get(ctx, ja.jwksUri)
			if err != nil {
				// Keep the previously stored key set and verifier; do not
				// replace them with the result of a failed lookup.
				ctx = context.WithValue(ctx, ErrorContextKey, err)
				next.ServeHTTP(w, r.WithContext(ctx))
				return
			}
			// NOTE(review): these fields are written on every request with no
			// synchronisation visible in this file, so concurrent requests
			// race on them. Fixing that needs a lock or a per-request
			// verifier and touches the other methods that read ja.verifier.
			ja.jwkKeySet = keySet
			ja.verifier = jwt.WithKeySet(keySet)

			token, err := ja.VerifyRequest(r, findTokenFns...)
			ctx = context.WithValue(ctx, TokenContextKey, token)
			ctx = context.WithValue(ctx, ErrorContextKey, err)
			next.ServeHTTP(w, r.WithContext(ctx))
		}
		return http.HandlerFunc(handlerFunc)
	}
}
// VerifyRequest tries each finder in order and verifies the first non-empty
// token string it produces. It returns ErrNoTokenFound when every finder
// returns an empty string (or when no finders are given).
func (ja *JWTAuth) VerifyRequest(r *http.Request, findTokenFns ...func(r *http.Request) string) (jwt.Token, error) {
	for _, find := range findTokenFns {
		if raw := find(r); raw != "" {
			return ja.VerifyToken(raw)
		}
	}
	return nil, ErrNoTokenFound
}
// VerifyToken decodes tokenString and then validates the resulting token
// with the configured validate options.
//
// A decode error is returned together with whatever Decode returned. A nil
// token without an error yields ErrUnauthorized. A validation failure is
// returned together with the decoded token.
func (ja *JWTAuth) VerifyToken(tokenString string) (jwt.Token, error) {
	decoded, decodeErr := ja.Decode(tokenString)
	switch {
	case decodeErr != nil:
		return decoded, decodeErr
	case decoded == nil:
		return nil, ErrUnauthorized
	}
	return decoded, jwt.Validate(decoded, ja.validateOptions...)
}
// Decode parses tokenString into a token by handing its bytes to parse.
func (ja *JWTAuth) Decode(tokenString string) (jwt.Token, error) {
	payload := []byte(tokenString)
	return ja.parse(payload)
}
// parse decodes payload with the configured verifier option. Claim
// validation is switched off here; callers run jwt.Validate separately.
//
// It returns ErrUnauthorized when no verifier is configured, which is the
// case for a JWTAuth built by New with an empty JWKS URI.
func (ja *JWTAuth) parse(payload []byte) (jwt.Token, error) {
	// A nil option must not reach jwt.Parse, and parsing without a verifier
	// would mean accepting tokens unverified, so fail closed.
	if ja.verifier == nil {
		return nil, ErrUnauthorized
	}
	return jwt.Parse(payload, ja.verifier, jwt.WithValidate(false))
}
// Authenticator returns middleware that rejects requests lacking a valid
// token in their context.
//
// If FromContext reports an error, the response is 401 with that error's
// message. If there is no token, or the token fails jwt.Validate, the
// response is 401 with the standard status text. Otherwise the request is
// passed on unchanged.
func (ja *JWTAuth) Authenticator() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			const status = http.StatusUnauthorized

			token, _, ctxErr := FromContext(r.Context())
			switch {
			case ctxErr != nil:
				http.Error(w, ctxErr.Error(), status)
			case token == nil:
				http.Error(w, http.StatusText(status), status)
			case jwt.Validate(token, ja.validateOptions...) != nil:
				http.Error(w, http.StatusText(status), status)
			default:
				// Token is present and valid: let the request through.
				next.ServeHTTP(w, r)
			}
		})
	}
}
// AuthorizeRoles returns middleware that lets a request through only when the
// token in its context carries every role in roles, as decided by
// helpers.JwtHasClaim on the token's private claims.
//
// The request is answered with 401 when the context holds no token, when the
// context holds a verification error, or when any role is missing.
func (ja *JWTAuth) AuthorizeRoles(roles []string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		handlerFunc := func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()

			// Checked assertion: a missing or nil token must produce a 401,
			// not a panic.
			token, ok := ctx.Value(TokenContextKey).(jwt.Token)
			// The context can hold a token alongside a verification error
			// (for example when validation failed); such a token must not
			// grant any role.
			verifyErr, _ := ctx.Value(ErrorContextKey).(error)
			if !ok || verifyErr != nil {
				http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
				return
			}

			privateClaims := token.PrivateClaims()
			for _, role := range roles {
				if !helpers.JwtHasClaim(privateClaims, role) {
					http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
					return
				}
			}
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(handlerFunc)
	}
}
// FromContext reads the token and error stored in ctx under TokenContextKey
// and ErrorContextKey.
//
// It returns the token (nil if absent), the token's claims as a map (an empty
// map when there is no token), and the stored error (nil if absent). If
// converting the token to a map fails, it returns the token, a nil map and
// the conversion error.
func FromContext(ctx context.Context) (jwt.Token, map[string]interface{}, error) {
	token, _ := ctx.Value(TokenContextKey).(jwt.Token)
	storedErr, _ := ctx.Value(ErrorContextKey).(error)

	if token == nil {
		return token, map[string]interface{}{}, storedErr
	}

	claims, mapErr := token.AsMap(context.Background())
	if mapErr != nil {
		return token, nil, mapErr
	}
	return token, claims, storedErr
}
// TokenFromCookie returns the value of the request's "jwt" cookie, or an
// empty string when the cookie is not present.
func TokenFromCookie(r *http.Request) string {
	if c, err := r.Cookie("jwt"); err == nil {
		return c.Value
	}
	return ""
}
// TokenFromHeader extracts a bearer token from the Authorization header.
//
// The header must start with the scheme "Bearer" (matched case-insensitively)
// followed by a single space and at least one more character; everything
// after that space is returned. Any other header value, or a missing header,
// yields an empty string.
func TokenFromHeader(r *http.Request) string {
	const prefix = "Bearer "

	auth := r.Header.Get("Authorization")
	// Match the whole prefix including the separating space, so that a value
	// such as "BearerXabc" is not mistaken for a bearer credential.
	if len(auth) > len(prefix) && strings.EqualFold(auth[:len(prefix)], prefix) {
		return auth[len(prefix):]
	}
	return ""
}
// TokenFromQuery returns the first value of the "jwt" query parameter from
// the request URL, or an empty string when the parameter is absent.
func TokenFromQuery(r *http.Request) string {
	params := r.URL.Query()
	return params.Get("jwt")
}