mirror of
https://forgejo.merr.is/annika/jwtauth.git
synced 2025-12-11 13:47:41 -05:00
Switch to github.com/lestrrat-go/jwx underlying jwt library (#52)
This commit is contained in:
parent
02fa0c511c
commit
b8af768272
8 changed files with 200 additions and 148 deletions
|
|
@ -1,6 +1,7 @@
|
|||
package jwtauth_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
|
|
@ -13,9 +14,9 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/jwtauth"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -69,7 +70,7 @@ func TestSimpleRSA(t *testing.T) {
|
|||
|
||||
TokenAuthRS256 = jwtauth.New("RS256", privateKey, publicKey)
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
claims := map[string]interface{}{
|
||||
"key": "val",
|
||||
"key2": "val2",
|
||||
"key3": "val3",
|
||||
|
|
@ -87,7 +88,12 @@ func TestSimpleRSA(t *testing.T) {
|
|||
t.Fatalf("Failed to decode token string %s\n", err.Error())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(claims, jwt.MapClaims(token.Claims.(jwt.MapClaims))) {
|
||||
tokenClaims, err := token.AsMap(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(claims, tokenClaims) {
|
||||
t.Fatalf("The decoded claims don't match the original ones\n")
|
||||
}
|
||||
}
|
||||
|
|
@ -105,27 +111,27 @@ func TestSimple(t *testing.T) {
|
|||
defer ts.Close()
|
||||
|
||||
// sending unauthorized requests
|
||||
if status, resp := testRequest(t, ts, "GET", "/", nil, nil); status != 401 || resp != "Unauthorized\n" {
|
||||
if status, resp := testRequest(t, ts, "GET", "/", nil, nil); status != 401 || resp != "no token found\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
|
||||
h := http.Header{}
|
||||
h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{}))
|
||||
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "Unauthorized\n" {
|
||||
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
h.Set("Authorization", "BEARER asdf")
|
||||
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "Unauthorized\n" {
|
||||
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
// wrong token secret and wrong alg
|
||||
h.Set("Authorization", "BEARER "+newJwt512Token([]byte("wrong"), map[string]interface{}{}))
|
||||
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "Unauthorized\n" {
|
||||
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
// correct token secret but wrong alg
|
||||
h.Set("Authorization", "BEARER "+newJwt512Token(TokenSecret, map[string]interface{}{}))
|
||||
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "Unauthorized\n" {
|
||||
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
|
||||
|
|
@ -147,23 +153,12 @@ func TestMore(t *testing.T) {
|
|||
token, _, err := jwtauth.FromContext(r.Context())
|
||||
|
||||
if err != nil {
|
||||
switch err {
|
||||
default:
|
||||
http.Error(w, http.StatusText(401), 401)
|
||||
return
|
||||
case jwtauth.ErrExpired:
|
||||
http.Error(w, "expired", 401)
|
||||
return
|
||||
case jwtauth.ErrUnauthorized:
|
||||
http.Error(w, http.StatusText(401), 401)
|
||||
return
|
||||
case nil:
|
||||
// no error
|
||||
}
|
||||
http.Error(w, jwtauth.ErrorReason(err).Error(), 401)
|
||||
return
|
||||
}
|
||||
|
||||
if token == nil || !token.Valid {
|
||||
http.Error(w, http.StatusText(401), 401)
|
||||
if err := jwt.Validate(token); err != nil {
|
||||
http.Error(w, jwtauth.ErrorReason(err).Error(), 401)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -196,32 +191,32 @@ func TestMore(t *testing.T) {
|
|||
defer ts.Close()
|
||||
|
||||
// sending unauthorized requests
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", nil, nil); status != 401 || resp != "Unauthorized\n" {
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", nil, nil); status != 401 || resp != "token is unauthorized\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
|
||||
h := http.Header{}
|
||||
h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{}))
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" {
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
h.Set("Authorization", "BEARER asdf")
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" {
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
// wrong token secret and wrong alg
|
||||
h.Set("Authorization", "BEARER "+newJwt512Token([]byte("wrong"), map[string]interface{}{}))
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" {
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
// correct token secret but wrong alg
|
||||
h.Set("Authorization", "BEARER "+newJwt512Token(TokenSecret, map[string]interface{}{}))
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" {
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
|
||||
h = newAuthHeader(jwt.MapClaims{"exp": jwtauth.EpochNow() - 1000})
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "expired\n" {
|
||||
h = newAuthHeader(map[string]interface{}{"exp": jwtauth.EpochNow() - 1000})
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is expired\n" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
|
||||
|
|
@ -230,7 +225,7 @@ func TestMore(t *testing.T) {
|
|||
t.Fatalf(resp)
|
||||
}
|
||||
|
||||
h = newAuthHeader((jwt.MapClaims{"user_id": 31337, "exp": jwtauth.ExpireIn(5 * time.Minute)}))
|
||||
h = newAuthHeader((map[string]interface{}{"user_id": 31337, "exp": jwtauth.ExpireIn(5 * time.Minute)}))
|
||||
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 || resp != "protected, user:31337" {
|
||||
t.Fatalf(resp)
|
||||
}
|
||||
|
|
@ -269,32 +264,36 @@ func testRequest(t *testing.T, ts *httptest.Server, method, path string, header
|
|||
return resp.StatusCode, string(respBody)
|
||||
}
|
||||
|
||||
func newJwtToken(secret []byte, claims ...jwt.MapClaims) string {
|
||||
token := jwt.New(jwt.GetSigningMethod("HS256"))
|
||||
func newJwtToken(secret []byte, claims ...map[string]interface{}) string {
|
||||
token := jwt.New()
|
||||
if len(claims) > 0 {
|
||||
token.Claims = claims[0]
|
||||
for k, v := range claims[0] {
|
||||
token.Set(k, v)
|
||||
}
|
||||
}
|
||||
tokenStr, err := token.SignedString(secret)
|
||||
tokenPayload, err := jwt.Sign(token, "HS256", secret)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return tokenStr
|
||||
return string(tokenPayload)
|
||||
}
|
||||
|
||||
func newJwt512Token(secret []byte, claims ...jwt.MapClaims) string {
|
||||
func newJwt512Token(secret []byte, claims ...map[string]interface{}) string {
|
||||
// use-case: when token is signed with a different alg than expected
|
||||
token := jwt.New(jwt.GetSigningMethod("HS512"))
|
||||
token := jwt.New()
|
||||
if len(claims) > 0 {
|
||||
token.Claims = claims[0]
|
||||
for k, v := range claims[0] {
|
||||
token.Set(k, v)
|
||||
}
|
||||
}
|
||||
tokenStr, err := token.SignedString(secret)
|
||||
tokenPayload, err := jwt.Sign(token, "HS512", secret)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return tokenStr
|
||||
return string(tokenPayload)
|
||||
}
|
||||
|
||||
func newAuthHeader(claims ...jwt.MapClaims) http.Header {
|
||||
func newAuthHeader(claims ...map[string]interface{}) http.Header {
|
||||
h := http.Header{}
|
||||
h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, claims...))
|
||||
return h
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue