mirror of
https://forgejo.merr.is/annika/jwtauth.git
synced 2025-12-10 12:25:27 -05:00
347 lines
9.7 KiB
Go
347 lines
9.7 KiB
Go
package jwtauth_test
|
|
|
|
import (
|
|
"context"
|
|
"crypto/x509"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/go-chi/jwtauth/v5"
|
|
"github.com/lestrrat-go/jwx/v2/jwa"
|
|
"github.com/lestrrat-go/jwx/v2/jwt"
|
|
)
|
|
|
|
var (
|
|
TokenAuthHS256 *jwtauth.JWTAuth
|
|
TokenSecret = []byte("secretpass")
|
|
|
|
TokenAuthRS256 *jwtauth.JWTAuth
|
|
|
|
PrivateKeyRS256String = `-----BEGIN RSA PRIVATE KEY-----
|
|
MIIBOwIBAAJBALxo3PCjFw4QjgOX06QCJIJBnXXNiEYwDLxxa5/7QyH6y77nCRQy
|
|
J3x3UwF9rUD0RCsp4sNdX5kOQ9PUyHyOtCUCAwEAAQJARjFLHtuj2zmPrwcBcjja
|
|
IS0Q3LKV8pA0LoCS+CdD+4QwCxeKFq0yEMZtMvcQOfqo9x9oAywFClMSlLRyl7ng
|
|
gQIhAOyerGbcdQxxwjwGpLS61Mprf4n2HzjwISg20cEEH1tfAiEAy9dXmgQpDPir
|
|
C6Q9QdLXpNgSB+o5CDqfor7TTyTCovsCIQDNCfpu795luDYN+dvD2JoIBfrwu9v2
|
|
ZO72f/pm/YGGlQIgUdRXyW9kH13wJFNBeBwxD27iBiVj0cbe8NFUONBUBmMCIQCN
|
|
jVK4eujt1lm/m60TlEhaWBC3p+3aPT2TqFPUigJ3RQ==
|
|
-----END RSA PRIVATE KEY-----
|
|
`
|
|
|
|
PublicKeyRS256String = `-----BEGIN PUBLIC KEY-----
|
|
MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBALxo3PCjFw4QjgOX06QCJIJBnXXNiEYw
|
|
DLxxa5/7QyH6y77nCRQyJ3x3UwF9rUD0RCsp4sNdX5kOQ9PUyHyOtCUCAwEAAQ==
|
|
-----END PUBLIC KEY-----
|
|
`
|
|
)
|
|
|
|
func init() {
|
|
TokenAuthHS256 = jwtauth.New(jwa.HS256.String(), TokenSecret, nil, jwt.WithAcceptableSkew(30*time.Second))
|
|
}
|
|
|
|
//
|
|
// Tests
|
|
//
|
|
|
|
func TestSimple(t *testing.T) {
|
|
r := chi.NewRouter()
|
|
|
|
r.Use(
|
|
jwtauth.Verifier(TokenAuthHS256),
|
|
jwtauth.Authenticator(TokenAuthHS256),
|
|
)
|
|
|
|
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte("welcome"))
|
|
})
|
|
|
|
ts := httptest.NewServer(r)
|
|
defer ts.Close()
|
|
|
|
// sending unauthorized requests
|
|
if status, resp := testRequest(t, ts, "GET", "/", nil, nil); status != 401 || resp != "no token found\n" {
|
|
t.Fatalf(resp)
|
|
}
|
|
|
|
h := http.Header{}
|
|
h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{}))
|
|
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
|
t.Fatalf(resp)
|
|
}
|
|
h.Set("Authorization", "BEARER asdf")
|
|
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
|
t.Fatalf(resp)
|
|
}
|
|
// wrong token secret and wrong alg
|
|
h.Set("Authorization", "BEARER "+newJwt512Token([]byte("wrong"), map[string]interface{}{}))
|
|
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
|
t.Fatalf(resp)
|
|
}
|
|
// correct token secret but wrong alg
|
|
h.Set("Authorization", "BEARER "+newJwt512Token(TokenSecret, map[string]interface{}{}))
|
|
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
|
t.Fatalf(resp)
|
|
}
|
|
|
|
// correct token, but has expired within the skew time
|
|
h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, map[string]interface{}{"exp": time.Now().Unix() - 29}))
|
|
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 200 || resp != "welcome" {
|
|
fmt.Println("status", status, "resp", resp)
|
|
t.Fatalf(resp)
|
|
}
|
|
|
|
// correct token, but has expired outside of the skew time
|
|
h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, map[string]interface{}{"exp": time.Now().Unix() - 31}))
|
|
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is expired\n" {
|
|
t.Fatalf(resp)
|
|
}
|
|
|
|
// sending authorized requests
|
|
if status, resp := testRequest(t, ts, "GET", "/", newAuthHeader(), nil); status != 200 || resp != "welcome" {
|
|
t.Fatalf(resp)
|
|
}
|
|
}
|
|
|
|
func TestSimpleRSA(t *testing.T) {
|
|
privateKeyBlock, _ := pem.Decode([]byte(PrivateKeyRS256String))
|
|
|
|
privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes)
|
|
if err != nil {
|
|
t.Fatalf(err.Error())
|
|
}
|
|
|
|
publicKeyBlock, _ := pem.Decode([]byte(PublicKeyRS256String))
|
|
|
|
publicKey, err := x509.ParsePKIXPublicKey(publicKeyBlock.Bytes)
|
|
if err != nil {
|
|
t.Fatalf(err.Error())
|
|
}
|
|
|
|
TokenAuthRS256 = jwtauth.New(jwa.RS256.String(), privateKey, publicKey)
|
|
|
|
claims := map[string]interface{}{
|
|
"key": "val",
|
|
"key2": "val2",
|
|
"key3": "val3",
|
|
}
|
|
|
|
_, tokenString, err := TokenAuthRS256.Encode(claims)
|
|
if err != nil {
|
|
t.Fatalf("Failed to encode claims %s\n", err.Error())
|
|
}
|
|
|
|
token, err := TokenAuthRS256.Decode(tokenString)
|
|
if err != nil {
|
|
t.Fatalf("Failed to decode token string %s\n", err.Error())
|
|
}
|
|
|
|
tokenClaims, err := token.AsMap(context.Background())
|
|
if err != nil {
|
|
t.Fatal(err.Error())
|
|
}
|
|
|
|
if !reflect.DeepEqual(claims, tokenClaims) {
|
|
t.Fatalf("The decoded claims don't match the original ones\n")
|
|
}
|
|
}
|
|
|
|
func TestSimpleRSAVerifyOnly(t *testing.T) {
|
|
tokenString := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJrZXkiOiJ2YWwiLCJrZXkyIjoidmFsMiIsImtleTMiOiJ2YWwzIn0.kLEK3FZZPsAlQNKR5yHyjRyrlCJFhvKmrh7o-GqDT_zaGQgvb0Dufp8uNSMeOFAlLGK5FbKX7BckjJqfvEyrTQ"
|
|
claims := map[string]interface{}{
|
|
"key": "val",
|
|
"key2": "val2",
|
|
"key3": "val3",
|
|
}
|
|
|
|
publicKeyBlock, _ := pem.Decode([]byte(PublicKeyRS256String))
|
|
publicKey, err := x509.ParsePKIXPublicKey(publicKeyBlock.Bytes)
|
|
if err != nil {
|
|
t.Fatalf(err.Error())
|
|
}
|
|
|
|
TokenAuthRS256 = jwtauth.New(jwa.RS256.String(), nil, publicKey)
|
|
|
|
_, _, err = TokenAuthRS256.Encode(claims)
|
|
if err == nil {
|
|
t.Fatalf("Expecting error when encoding claims without signing key")
|
|
}
|
|
|
|
token, err := TokenAuthRS256.Decode(tokenString)
|
|
if err != nil {
|
|
t.Fatalf("Failed to decode token string %s\n", err.Error())
|
|
}
|
|
|
|
tokenClaims, err := token.AsMap(context.Background())
|
|
if err != nil {
|
|
t.Fatal(err.Error())
|
|
}
|
|
|
|
if !reflect.DeepEqual(claims, tokenClaims) {
|
|
t.Fatalf("The decoded claims don't match the original ones\n")
|
|
}
|
|
}
|
|
|
|
func TestMore(t *testing.T) {
|
|
r := chi.NewRouter()
|
|
|
|
// Protected routes
|
|
r.Group(func(r chi.Router) {
|
|
r.Use(jwtauth.Verifier(TokenAuthHS256))
|
|
|
|
authenticator := func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
token, _, err := jwtauth.FromContext(r.Context())
|
|
|
|
if err != nil {
|
|
http.Error(w, jwtauth.ErrorReason(err).Error(), http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
if err := jwt.Validate(token); err != nil {
|
|
http.Error(w, jwtauth.ErrorReason(err).Error(), http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Token is authenticated, pass it through
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
r.Use(authenticator)
|
|
|
|
r.Get("/admin", func(w http.ResponseWriter, r *http.Request) {
|
|
_, claims, err := jwtauth.FromContext(r.Context())
|
|
|
|
if err != nil {
|
|
w.Write([]byte(fmt.Sprintf("error! %v", err)))
|
|
return
|
|
}
|
|
|
|
w.Write([]byte(fmt.Sprintf("protected, user:%v", claims["user_id"])))
|
|
})
|
|
})
|
|
|
|
// Public routes
|
|
r.Group(func(r chi.Router) {
|
|
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte("welcome"))
|
|
})
|
|
})
|
|
|
|
ts := httptest.NewServer(r)
|
|
defer ts.Close()
|
|
|
|
// sending unauthorized requests
|
|
if status, resp := testRequest(t, ts, "GET", "/admin", nil, nil); status != 401 || resp != "token is unauthorized\n" {
|
|
t.Fatalf(resp)
|
|
}
|
|
|
|
h := http.Header{}
|
|
h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{}))
|
|
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
|
t.Fatalf(resp)
|
|
}
|
|
h.Set("Authorization", "BEARER asdf")
|
|
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
|
t.Fatalf(resp)
|
|
}
|
|
// wrong token secret and wrong alg
|
|
h.Set("Authorization", "BEARER "+newJwt512Token([]byte("wrong"), map[string]interface{}{}))
|
|
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
|
t.Fatalf(resp)
|
|
}
|
|
// correct token secret but wrong alg
|
|
h.Set("Authorization", "BEARER "+newJwt512Token(TokenSecret, map[string]interface{}{}))
|
|
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
|
t.Fatalf(resp)
|
|
}
|
|
|
|
h = newAuthHeader(map[string]interface{}{"exp": jwtauth.EpochNow() - 1000})
|
|
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is expired\n" {
|
|
t.Fatalf(resp)
|
|
}
|
|
|
|
// sending authorized requests
|
|
if status, resp := testRequest(t, ts, "GET", "/", nil, nil); status != 200 || resp != "welcome" {
|
|
t.Fatalf(resp)
|
|
}
|
|
|
|
h = newAuthHeader((map[string]interface{}{"user_id": 31337, "exp": jwtauth.ExpireIn(5 * time.Minute)}))
|
|
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 || resp != "protected, user:31337" {
|
|
t.Fatalf(resp)
|
|
}
|
|
}
|
|
|
|
//
|
|
// Test helper functions
|
|
//
|
|
|
|
// testRequest sends a request with the given method, path, headers and body
// to the test server and returns the response status code and the full
// response body as a string.
//
// Every value of every entry in header is copied onto the request; a nil
// header or an entry with no values adds nothing. Any error building,
// sending or reading the request fails the test immediately via t.Fatal.
func testRequest(t *testing.T, ts *httptest.Server, method, path string, header http.Header, body io.Reader) (int, string) {
	t.Helper()

	req, err := http.NewRequest(method, ts.URL+path, body)
	if err != nil {
		t.Fatal(err)
	}

	for key, values := range header {
		for _, value := range values {
			req.Header.Add(key, value)
		}
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Register the close before reading so the body is released even when
	// the read below fails.
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}

	return resp.StatusCode, string(respBody)
}
|
|
|
|
func newJwtToken(secret []byte, claims ...map[string]interface{}) string {
|
|
token := jwt.New()
|
|
if len(claims) > 0 {
|
|
for k, v := range claims[0] {
|
|
token.Set(k, v)
|
|
}
|
|
}
|
|
|
|
tokenPayload, err := jwt.Sign(token, jwt.WithKey(jwa.HS256, secret))
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
return string(tokenPayload)
|
|
}
|
|
|
|
func newJwt512Token(secret []byte, claims ...map[string]interface{}) string {
|
|
// use-case: when token is signed with a different alg than expected
|
|
token := jwt.New()
|
|
if len(claims) > 0 {
|
|
for k, v := range claims[0] {
|
|
token.Set(k, v)
|
|
}
|
|
}
|
|
tokenPayload, err := jwt.Sign(token, jwt.WithKey(jwa.HS512, secret))
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
return string(tokenPayload)
|
|
}
|
|
|
|
func newAuthHeader(claims ...map[string]interface{}) http.Header {
|
|
h := http.Header{}
|
|
h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, claims...))
|
|
return h
|
|
}
|