Refactor jwtauth pkg, move to go-chi org, and support dgrijalva/jwt-go v3

This commit is contained in:
Peter Kieltyka 2017-07-05 17:05:27 -04:00
parent d5f87ca5c3
commit 4c654d77d5
20 changed files with 180 additions and 1350 deletions

View file

@ -1,6 +1,7 @@
package jwtauth_test
import (
"fmt"
"io"
"io/ioutil"
"log"
@ -10,8 +11,8 @@ import (
"time"
"github.com/dgrijalva/jwt-go"
"github.com/goware/jwtauth"
"github.com/pressly/chi"
"github.com/go-chi/chi"
"github.com/go-chi/jwtauth"
)
var (
@ -40,32 +41,32 @@ func TestSimple(t *testing.T) {
defer ts.Close()
// sending unauthorized requests
if status, resp := testRequest(t, ts, "GET", "/", nil, nil); status != 401 && resp != "Unauthorized\n" {
if status, resp := testRequest(t, ts, "GET", "/", nil, nil); status != 401 || resp != "Unauthorized\n" {
t.Fatalf(resp)
}
h := http.Header{}
h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{}))
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 && resp != "Unauthorized\n" {
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "Unauthorized\n" {
t.Fatalf(resp)
}
h.Set("Authorization", "BEARER asdf")
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 && resp != "Unauthorized\n" {
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "Unauthorized\n" {
t.Fatalf(resp)
}
// wrong token secret and wrong alg
h.Set("Authorization", "BEARER "+newJwt512Token([]byte("wrong"), map[string]interface{}{}))
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 && resp != "Unauthorized\n" {
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "Unauthorized\n" {
t.Fatalf(resp)
}
// correct token secret but wrong alg
h.Set("Authorization", "BEARER "+newJwt512Token(TokenSecret, map[string]interface{}{}))
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 && resp != "Unauthorized\n" {
if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "Unauthorized\n" {
t.Fatalf(resp)
}
// sending authorized requests
if status, resp := testRequest(t, ts, "GET", "/", newAuthHeader(), nil); status != 200 && resp != "welcome" {
if status, resp := testRequest(t, ts, "GET", "/", newAuthHeader(), nil); status != 200 || resp != "welcome" {
t.Fatalf(resp)
}
}
@ -79,10 +80,10 @@ func TestMore(t *testing.T) {
authenticator := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
token, _, err := jwtauth.TokenContext(r.Context())
if jwtErr, ok := ctx.Value("jwt.err").(error); ok {
switch jwtErr {
if err != nil {
switch err {
default:
http.Error(w, http.StatusText(401), 401)
return
@ -97,8 +98,7 @@ func TestMore(t *testing.T) {
}
}
jwtToken, ok := ctx.Value("jwt").(*jwt.Token)
if !ok || jwtToken == nil || !jwtToken.Valid {
if token == nil || !token.Valid {
http.Error(w, http.StatusText(401), 401)
return
}
@ -110,7 +110,14 @@ func TestMore(t *testing.T) {
r.Use(authenticator)
r.Get("/admin", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("protected"))
_, claims, err := jwtauth.TokenContext(r.Context())
if err != nil {
w.Write([]byte(fmt.Sprintf("error! %v", err)))
return
}
w.Write([]byte(fmt.Sprintf("protected, user:%v", claims["user_id"])))
})
})
@ -125,42 +132,42 @@ func TestMore(t *testing.T) {
defer ts.Close()
// sending unauthorized requests
if status, resp := testRequest(t, ts, "GET", "/admin", nil, nil); status != 401 && resp != "Unauthorized\n" {
if status, resp := testRequest(t, ts, "GET", "/admin", nil, nil); status != 401 || resp != "Unauthorized\n" {
t.Fatalf(resp)
}
h := http.Header{}
h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{}))
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 && resp != "Unauthorized\n" {
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" {
t.Fatalf(resp)
}
h.Set("Authorization", "BEARER asdf")
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 && resp != "Unauthorized\n" {
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" {
t.Fatalf(resp)
}
// wrong token secret and wrong alg
h.Set("Authorization", "BEARER "+newJwt512Token([]byte("wrong"), map[string]interface{}{}))
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 && resp != "Unauthorized\n" {
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" {
t.Fatalf(resp)
}
// correct token secret but wrong alg
h.Set("Authorization", "BEARER "+newJwt512Token(TokenSecret, map[string]interface{}{}))
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 && resp != "Unauthorized\n" {
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "Unauthorized\n" {
t.Fatalf(resp)
}
h = newAuthHeader((jwtauth.Claims{}).Set("exp", jwtauth.EpochNow()-1000))
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 && resp != "expired\n" {
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "expired\n" {
t.Fatalf(resp)
}
// sending authorized requests
if status, resp := testRequest(t, ts, "GET", "/", nil, nil); status != 200 && resp != "welcome" {
if status, resp := testRequest(t, ts, "GET", "/", nil, nil); status != 200 || resp != "welcome" {
t.Fatalf(resp)
}
h = newAuthHeader((jwtauth.Claims{}).SetExpiryIn(5 * time.Minute))
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 && resp != "protected" {
h = newAuthHeader((jwtauth.Claims{"user_id": 31337}).SetExpiryIn(5 * time.Minute))
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 || resp != "protected, user:31337" {
t.Fatalf(resp)
}
}
@ -201,9 +208,7 @@ func testRequest(t *testing.T, ts *httptest.Server, method, path string, header
func newJwtToken(secret []byte, claims ...jwtauth.Claims) string {
token := jwt.New(jwt.GetSigningMethod("HS256"))
if len(claims) > 0 {
for k, v := range claims[0] {
token.Claims[k] = v
}
token.Claims = claims[0]
}
tokenStr, err := token.SignedString(secret)
if err != nil {
@ -216,9 +221,7 @@ func newJwt512Token(secret []byte, claims ...jwtauth.Claims) string {
// use-case: when token is signed with a different alg than expected
token := jwt.New(jwt.GetSigningMethod("HS512"))
if len(claims) > 0 {
for k, v := range claims[0] {
token.Claims[k] = v
}
token.Claims = claims[0]
}
tokenStr, err := token.SignedString(secret)
if err != nil {