mirror of
https://forgejo.merr.is/annika/isl-api.git
synced 2025-12-13 11:28:58 -05:00
Added JWT Auth
Wrote my own JWT auth middleware, since I could not get the go-chi middleware to accept a JWKS instead of a certificate.
This commit is contained in:
parent
ac18b94a86
commit
b5ea01729b
12 changed files with 336 additions and 132 deletions
224
middlewares/jwtAuth.go
Normal file
224
middlewares/jwtAuth.go
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"forgejo.merr.is/annika/isl-api/helpers"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
)
|
||||
|
||||
type JWTAuth struct {
|
||||
jwksUri string
|
||||
jwksCache *jwk.Cache
|
||||
jwksContext context.Context
|
||||
jwkKeySet jwk.Set
|
||||
verifier jwt.ParseOption
|
||||
validateOptions []jwt.ValidateOption
|
||||
}
|
||||
|
||||
type ContextKey struct {
|
||||
name string
|
||||
}
|
||||
|
||||
// Errors!
|
||||
var (
|
||||
ErrUnauthorized = errors.New("token is unauthorized")
|
||||
ErrExpired = errors.New("token is expired")
|
||||
ErrNBFInvalid = errors.New("token nbf validation failed")
|
||||
ErrIATInvalid = errors.New("token iat validation failed")
|
||||
ErrNoTokenFound = errors.New("no token found")
|
||||
ErrAlgoInvalid = errors.New("algorithm mismatch")
|
||||
)
|
||||
|
||||
func ErrorReason(err error) error {
|
||||
switch {
|
||||
case errors.Is(err, jwt.ErrTokenExpired()), err == ErrExpired:
|
||||
return ErrExpired
|
||||
case errors.Is(err, jwt.ErrInvalidIssuedAt()), err == ErrIATInvalid:
|
||||
return ErrIATInvalid
|
||||
case errors.Is(err, jwt.ErrTokenNotYetValid()), err == ErrNBFInvalid:
|
||||
return ErrNBFInvalid
|
||||
default:
|
||||
return ErrUnauthorized
|
||||
}
|
||||
}
|
||||
|
||||
var TokenContextKey = &ContextKey{"Token"}
|
||||
var ErrorContextKey = &ContextKey{"Error"}
|
||||
|
||||
func New(jwksUri string, ctx context.Context) (*JWTAuth, error) {
|
||||
jwtAuth := &JWTAuth{
|
||||
jwksContext: ctx,
|
||||
jwksUri: jwksUri,
|
||||
}
|
||||
|
||||
if jwtAuth.jwksUri != "" {
|
||||
jwtAuth.jwksCache = jwk.NewCache(jwtAuth.jwksContext)
|
||||
jwtAuth.jwksCache.Register(jwtAuth.jwksUri, jwk.WithRefreshInterval(15*time.Minute))
|
||||
var err error
|
||||
jwtAuth.jwkKeySet, err = jwtAuth.jwksCache.Refresh(jwtAuth.jwksContext, jwtAuth.jwksUri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwtAuth.verifier = jwt.WithKeySet(jwtAuth.jwkKeySet)
|
||||
}
|
||||
|
||||
return jwtAuth, nil
|
||||
}
|
||||
|
||||
func (ja *JWTAuth) Verifier() func(http.Handler) http.Handler {
|
||||
return ja.Verify(TokenFromHeader, TokenFromCookie)
|
||||
}
|
||||
|
||||
func (ja *JWTAuth) Verify(findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
handlerFunc := func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
// Refresh the JWKS keyset
|
||||
var err error
|
||||
ja.jwkKeySet, err = ja.jwksCache.Get(ctx, ja.jwksUri)
|
||||
ja.verifier = jwt.WithKeySet(ja.jwkKeySet)
|
||||
if err != nil {
|
||||
ctx = context.WithValue(ctx, ErrorContextKey, err)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
// Now we do stuff with it
|
||||
token, err := ja.VerifyRequest(r, findTokenFns...)
|
||||
ctx = context.WithValue(ctx, TokenContextKey, token)
|
||||
ctx = context.WithValue(ctx, ErrorContextKey, err)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
return http.HandlerFunc(handlerFunc)
|
||||
}
|
||||
}
|
||||
|
||||
func (ja *JWTAuth) VerifyRequest(r *http.Request, findTokenFns ...func(r *http.Request) string) (jwt.Token, error) {
|
||||
var tokenString string
|
||||
|
||||
for _, fn := range findTokenFns {
|
||||
tokenString = fn(r)
|
||||
if tokenString != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
if tokenString == "" {
|
||||
return nil, ErrNoTokenFound
|
||||
}
|
||||
|
||||
return ja.VerifyToken(tokenString)
|
||||
}
|
||||
|
||||
func (ja *JWTAuth) VerifyToken(tokenString string) (jwt.Token, error) {
|
||||
token, err := ja.Decode(tokenString)
|
||||
if err != nil {
|
||||
return token, err
|
||||
}
|
||||
if token == nil {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
if err := jwt.Validate(token, ja.validateOptions...); err != nil {
|
||||
return token, err
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (ja *JWTAuth) Decode(tokenString string) (jwt.Token, error) {
|
||||
return ja.parse([]byte(tokenString))
|
||||
}
|
||||
|
||||
func (ja *JWTAuth) parse(payload []byte) (jwt.Token, error) {
|
||||
return jwt.Parse(payload, ja.verifier, jwt.WithValidate(false))
|
||||
}
|
||||
|
||||
func (ja *JWTAuth) Authenticator() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
handlerFunc := func(w http.ResponseWriter, r *http.Request) {
|
||||
token, _, err := FromContext(r.Context())
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if token == nil || jwt.Validate(token, ja.validateOptions...) != nil {
|
||||
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Token is authenticated, pass it through
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
return http.HandlerFunc(handlerFunc)
|
||||
}
|
||||
}
|
||||
|
||||
func (ja *JWTAuth) AuthorizeRoles(roles []string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
handlerFunc := func(w http.ResponseWriter, r *http.Request) {
|
||||
token := r.Context().Value(TokenContextKey).(jwt.Token)
|
||||
hasAllRoles := true
|
||||
privateClaims := token.PrivateClaims()
|
||||
for _, role := range roles {
|
||||
hasRole := helpers.JwtHasClaim(privateClaims, role)
|
||||
if !hasRole {
|
||||
hasAllRoles = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasAllRoles {
|
||||
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
return http.HandlerFunc(handlerFunc)
|
||||
}
|
||||
}
|
||||
|
||||
func FromContext(ctx context.Context) (jwt.Token, map[string]interface{}, error) {
|
||||
token, _ := ctx.Value(TokenContextKey).(jwt.Token)
|
||||
|
||||
var err error
|
||||
var claims map[string]interface{}
|
||||
|
||||
if token != nil {
|
||||
claims, err = token.AsMap(context.Background())
|
||||
if err != nil {
|
||||
return token, nil, err
|
||||
}
|
||||
} else {
|
||||
claims = map[string]interface{}{}
|
||||
}
|
||||
|
||||
err, _ = ctx.Value(ErrorContextKey).(error)
|
||||
return token, claims, err
|
||||
}
|
||||
|
||||
func TokenFromCookie(r *http.Request) string {
|
||||
cookie, err := r.Cookie("jwt")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cookie.Value
|
||||
}
|
||||
|
||||
func TokenFromHeader(r *http.Request) string {
|
||||
// Get token from authorization header.
|
||||
bearer := r.Header.Get("Authorization")
|
||||
if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" {
|
||||
return bearer[7:]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TokenFromQuery(r *http.Request) string {
|
||||
// Get token from query param named "jwt".
|
||||
return r.URL.Query().Get("jwt")
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue