mirror of
https://forgejo.merr.is/annika/isl-api.git
synced 2025-12-11 11:02:03 -05:00
Wrote my own JWT auth middleware, since I could not get the go-chi middleware to accept a JWKS instead of a certificate.
224 lines
5.7 KiB
Go
224 lines
5.7 KiB
Go
package middlewares
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"forgejo.merr.is/annika/isl-api/helpers"
|
|
"github.com/lestrrat-go/jwx/v2/jwk"
|
|
"github.com/lestrrat-go/jwx/v2/jwt"
|
|
)
|
|
|
|
// JWTAuth verifies JSON Web Tokens against a JSON Web Key Set (JWKS) fetched
// from a remote URI and provides HTTP middleware built on that check.
// Construct it with New.
type JWTAuth struct {
	// jwksUri is the JWKS endpoint passed to New; when it is empty no key set
	// is fetched.
	jwksUri string

	// jwksCache fetches and periodically refreshes the key set at jwksUri.
	// It is nil when jwksUri is empty.
	jwksCache *jwk.Cache

	// jwksContext is the context passed to New; it is used to create
	// jwksCache and for the initial fetch.
	// NOTE(review): Go convention discourages storing a Context in a struct.
	jwksContext context.Context

	// jwkKeySet is a key set obtained from jwksCache.
	jwkKeySet jwk.Set

	// verifier is the jwt.Parse option carrying the key set used to check
	// token signatures. It is nil when jwksUri is empty.
	verifier jwt.ParseOption

	// validateOptions is passed to jwt.Validate when checking claims.
	// Nothing in this file populates it — TODO confirm whether it is set
	// elsewhere in the package.
	validateOptions []jwt.ValidateOption
}
|
|
|
|
// ContextKey is the type of the keys this package uses to store values in a
// request's context. A dedicated struct type (rather than a plain string)
// keeps these keys from colliding with context values set by other packages.
type ContextKey struct {
	// name labels the key (e.g. "Token", "Error"); it is not read anywhere
	// in this file — presumably kept for debugging.
	name string
}
|
|
|
|
// ErrUnauthorized is the generic reason reported for a token that could not
// be verified or validated.
var ErrUnauthorized = errors.New("token is unauthorized")

// ErrExpired reports that the token's expiry time has passed.
var ErrExpired = errors.New("token is expired")

// ErrNBFInvalid reports that the token's "nbf" (not before) claim failed
// validation.
var ErrNBFInvalid = errors.New("token nbf validation failed")

// ErrIATInvalid reports that the token's "iat" (issued at) claim failed
// validation.
var ErrIATInvalid = errors.New("token iat validation failed")

// ErrNoTokenFound reports that none of the configured token finders located
// a token in the request.
var ErrNoTokenFound = errors.New("no token found")

// ErrAlgoInvalid reports a signing-algorithm mismatch. It is not returned
// anywhere in this file.
var ErrAlgoInvalid = errors.New("algorithm mismatch")
|
|
|
|
func ErrorReason(err error) error {
|
|
switch {
|
|
case errors.Is(err, jwt.ErrTokenExpired()), err == ErrExpired:
|
|
return ErrExpired
|
|
case errors.Is(err, jwt.ErrInvalidIssuedAt()), err == ErrIATInvalid:
|
|
return ErrIATInvalid
|
|
case errors.Is(err, jwt.ErrTokenNotYetValid()), err == ErrNBFInvalid:
|
|
return ErrNBFInvalid
|
|
default:
|
|
return ErrUnauthorized
|
|
}
|
|
}
|
|
|
|
var TokenContextKey = &ContextKey{"Token"}
|
|
var ErrorContextKey = &ContextKey{"Error"}
|
|
|
|
func New(jwksUri string, ctx context.Context) (*JWTAuth, error) {
|
|
jwtAuth := &JWTAuth{
|
|
jwksContext: ctx,
|
|
jwksUri: jwksUri,
|
|
}
|
|
|
|
if jwtAuth.jwksUri != "" {
|
|
jwtAuth.jwksCache = jwk.NewCache(jwtAuth.jwksContext)
|
|
jwtAuth.jwksCache.Register(jwtAuth.jwksUri, jwk.WithRefreshInterval(15*time.Minute))
|
|
var err error
|
|
jwtAuth.jwkKeySet, err = jwtAuth.jwksCache.Refresh(jwtAuth.jwksContext, jwtAuth.jwksUri)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
jwtAuth.verifier = jwt.WithKeySet(jwtAuth.jwkKeySet)
|
|
}
|
|
|
|
return jwtAuth, nil
|
|
}
|
|
|
|
func (ja *JWTAuth) Verifier() func(http.Handler) http.Handler {
|
|
return ja.Verify(TokenFromHeader, TokenFromCookie)
|
|
}
|
|
|
|
func (ja *JWTAuth) Verify(findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
handlerFunc := func(w http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
// Refresh the JWKS keyset
|
|
var err error
|
|
ja.jwkKeySet, err = ja.jwksCache.Get(ctx, ja.jwksUri)
|
|
ja.verifier = jwt.WithKeySet(ja.jwkKeySet)
|
|
if err != nil {
|
|
ctx = context.WithValue(ctx, ErrorContextKey, err)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
// Now we do stuff with it
|
|
token, err := ja.VerifyRequest(r, findTokenFns...)
|
|
ctx = context.WithValue(ctx, TokenContextKey, token)
|
|
ctx = context.WithValue(ctx, ErrorContextKey, err)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
}
|
|
return http.HandlerFunc(handlerFunc)
|
|
}
|
|
}
|
|
|
|
func (ja *JWTAuth) VerifyRequest(r *http.Request, findTokenFns ...func(r *http.Request) string) (jwt.Token, error) {
|
|
var tokenString string
|
|
|
|
for _, fn := range findTokenFns {
|
|
tokenString = fn(r)
|
|
if tokenString != "" {
|
|
break
|
|
}
|
|
}
|
|
if tokenString == "" {
|
|
return nil, ErrNoTokenFound
|
|
}
|
|
|
|
return ja.VerifyToken(tokenString)
|
|
}
|
|
|
|
func (ja *JWTAuth) VerifyToken(tokenString string) (jwt.Token, error) {
|
|
token, err := ja.Decode(tokenString)
|
|
if err != nil {
|
|
return token, err
|
|
}
|
|
if token == nil {
|
|
return nil, ErrUnauthorized
|
|
}
|
|
if err := jwt.Validate(token, ja.validateOptions...); err != nil {
|
|
return token, err
|
|
}
|
|
|
|
return token, nil
|
|
}
|
|
|
|
func (ja *JWTAuth) Decode(tokenString string) (jwt.Token, error) {
|
|
return ja.parse([]byte(tokenString))
|
|
}
|
|
|
|
func (ja *JWTAuth) parse(payload []byte) (jwt.Token, error) {
|
|
return jwt.Parse(payload, ja.verifier, jwt.WithValidate(false))
|
|
}
|
|
|
|
func (ja *JWTAuth) Authenticator() func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
handlerFunc := func(w http.ResponseWriter, r *http.Request) {
|
|
token, _, err := FromContext(r.Context())
|
|
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
if token == nil || jwt.Validate(token, ja.validateOptions...) != nil {
|
|
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Token is authenticated, pass it through
|
|
next.ServeHTTP(w, r)
|
|
}
|
|
return http.HandlerFunc(handlerFunc)
|
|
}
|
|
}
|
|
|
|
func (ja *JWTAuth) AuthorizeRoles(roles []string) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
handlerFunc := func(w http.ResponseWriter, r *http.Request) {
|
|
token := r.Context().Value(TokenContextKey).(jwt.Token)
|
|
hasAllRoles := true
|
|
privateClaims := token.PrivateClaims()
|
|
for _, role := range roles {
|
|
hasRole := helpers.JwtHasClaim(privateClaims, role)
|
|
if !hasRole {
|
|
hasAllRoles = false
|
|
break
|
|
}
|
|
}
|
|
if !hasAllRoles {
|
|
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
}
|
|
return http.HandlerFunc(handlerFunc)
|
|
}
|
|
}
|
|
|
|
func FromContext(ctx context.Context) (jwt.Token, map[string]interface{}, error) {
|
|
token, _ := ctx.Value(TokenContextKey).(jwt.Token)
|
|
|
|
var err error
|
|
var claims map[string]interface{}
|
|
|
|
if token != nil {
|
|
claims, err = token.AsMap(context.Background())
|
|
if err != nil {
|
|
return token, nil, err
|
|
}
|
|
} else {
|
|
claims = map[string]interface{}{}
|
|
}
|
|
|
|
err, _ = ctx.Value(ErrorContextKey).(error)
|
|
return token, claims, err
|
|
}
|
|
|
|
// TokenFromCookie returns the value of the request's "jwt" cookie, or "" when
// the request has no such cookie.
func TokenFromCookie(r *http.Request) string {
	if c, err := r.Cookie("jwt"); err == nil {
		return c.Value
	}
	return ""
}
|
|
|
|
// TokenFromHeader returns the credentials from an "Authorization: Bearer
// <token>" header, or "" when the header is missing, uses another scheme, or
// carries no token.
//
// The scheme name is matched case-insensitively and must be followed by a
// space (RFC 6750 §2.1). The previous check ignored the separator byte, so
// "BearerXabc" was accepted as token "abc".
func TokenFromHeader(r *http.Request) string {
	const prefix = "Bearer "

	bearer := r.Header.Get("Authorization")
	if len(bearer) > len(prefix) && strings.EqualFold(bearer[:len(prefix)], prefix) {
		return bearer[len(prefix):]
	}
	return ""
}
|
|
|
|
// TokenFromQuery returns the value of the "jwt" URL query parameter, or ""
// when it is absent.
//
// NOTE(review): tokens carried in URLs tend to end up in access and proxy
// logs; prefer the header or cookie finders where possible.
func TokenFromQuery(r *http.Request) string {
	params := r.URL.Query()
	return params.Get("jwt")
}
|