package middlewares import ( "context" "errors" "net/http" "strings" "time" "forgejo.merr.is/annika/isl-api/helpers" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" ) type JWTAuth struct { jwksUri string jwksCache *jwk.Cache jwksContext context.Context jwkKeySet jwk.Set verifier jwt.ParseOption validateOptions []jwt.ValidateOption } type ContextKey struct { name string } // Errors! var ( ErrUnauthorized = errors.New("token is unauthorized") ErrExpired = errors.New("token is expired") ErrNBFInvalid = errors.New("token nbf validation failed") ErrIATInvalid = errors.New("token iat validation failed") ErrNoTokenFound = errors.New("no token found") ErrAlgoInvalid = errors.New("algorithm mismatch") ) func ErrorReason(err error) error { switch { case errors.Is(err, jwt.ErrTokenExpired()), err == ErrExpired: return ErrExpired case errors.Is(err, jwt.ErrInvalidIssuedAt()), err == ErrIATInvalid: return ErrIATInvalid case errors.Is(err, jwt.ErrTokenNotYetValid()), err == ErrNBFInvalid: return ErrNBFInvalid default: return ErrUnauthorized } } var TokenContextKey = &ContextKey{"Token"} var ErrorContextKey = &ContextKey{"Error"} func New(jwksUri string, ctx context.Context) (*JWTAuth, error) { jwtAuth := &JWTAuth{ jwksContext: ctx, jwksUri: jwksUri, } if jwtAuth.jwksUri != "" { jwtAuth.jwksCache = jwk.NewCache(jwtAuth.jwksContext) jwtAuth.jwksCache.Register(jwtAuth.jwksUri, jwk.WithRefreshInterval(15*time.Minute)) var err error jwtAuth.jwkKeySet, err = jwtAuth.jwksCache.Refresh(jwtAuth.jwksContext, jwtAuth.jwksUri) if err != nil { return nil, err } jwtAuth.verifier = jwt.WithKeySet(jwtAuth.jwkKeySet) } return jwtAuth, nil } func (ja *JWTAuth) Verifier() func(http.Handler) http.Handler { return ja.Verify(TokenFromHeader, TokenFromCookie) } func (ja *JWTAuth) Verify(findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { handlerFunc := func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() // Refresh the JWKS keyset var err error ja.jwkKeySet, err = ja.jwksCache.Get(ctx, ja.jwksUri) ja.verifier = jwt.WithKeySet(ja.jwkKeySet) if err != nil { ctx = context.WithValue(ctx, ErrorContextKey, err) next.ServeHTTP(w, r.WithContext(ctx)) return } // Now we do stuff with it token, err := ja.VerifyRequest(r, findTokenFns...) ctx = context.WithValue(ctx, TokenContextKey, token) ctx = context.WithValue(ctx, ErrorContextKey, err) next.ServeHTTP(w, r.WithContext(ctx)) } return http.HandlerFunc(handlerFunc) } } func (ja *JWTAuth) VerifyRequest(r *http.Request, findTokenFns ...func(r *http.Request) string) (jwt.Token, error) { var tokenString string for _, fn := range findTokenFns { tokenString = fn(r) if tokenString != "" { break } } if tokenString == "" { return nil, ErrNoTokenFound } return ja.VerifyToken(tokenString) } func (ja *JWTAuth) VerifyToken(tokenString string) (jwt.Token, error) { token, err := ja.Decode(tokenString) if err != nil { return token, err } if token == nil { return nil, ErrUnauthorized } if err := jwt.Validate(token, ja.validateOptions...); err != nil { return token, err } return token, nil } func (ja *JWTAuth) Decode(tokenString string) (jwt.Token, error) { return ja.parse([]byte(tokenString)) } func (ja *JWTAuth) parse(payload []byte) (jwt.Token, error) { return jwt.Parse(payload, ja.verifier, jwt.WithValidate(false)) } func (ja *JWTAuth) Authenticator() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { handlerFunc := func(w http.ResponseWriter, r *http.Request) { token, _, err := FromContext(r.Context()) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } if token == nil || jwt.Validate(token, ja.validateOptions...) != nil { http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } // Token is authenticated, pass it through next.ServeHTTP(w, r) } return http.HandlerFunc(handlerFunc) } } func (ja *JWTAuth) AuthorizeRoles(roles []string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { handlerFunc := func(w http.ResponseWriter, r *http.Request) { token := r.Context().Value(TokenContextKey).(jwt.Token) hasAllRoles := true privateClaims := token.PrivateClaims() for _, role := range roles { hasRole := helpers.JwtHasClaim(privateClaims, role) if !hasRole { hasAllRoles = false break } } if !hasAllRoles { http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } next.ServeHTTP(w, r) } return http.HandlerFunc(handlerFunc) } } func FromContext(ctx context.Context) (jwt.Token, map[string]interface{}, error) { token, _ := ctx.Value(TokenContextKey).(jwt.Token) var err error var claims map[string]interface{} if token != nil { claims, err = token.AsMap(context.Background()) if err != nil { return token, nil, err } } else { claims = map[string]interface{}{} } err, _ = ctx.Value(ErrorContextKey).(error) return token, claims, err } func TokenFromCookie(r *http.Request) string { cookie, err := r.Cookie("jwt") if err != nil { return "" } return cookie.Value } func TokenFromHeader(r *http.Request) string { // Get token from authorization header. bearer := r.Header.Get("Authorization") if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" { return bearer[7:] } return "" } func TokenFromQuery(r *http.Request) string { // Get token from query param named "jwt". return r.URL.Query().Get("jwt") }