mirror of
https://forgejo.merr.is/annika/jwtauth.git
synced 2025-12-11 17:07:44 -05:00
201 lines
4.6 KiB
Go
201 lines
4.6 KiB
Go
package jwtauth_test
|
|
|
|
import (
|
|
"io"
|
|
"io/ioutil"
|
|
"log"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/dgrijalva/jwt-go"
|
|
"github.com/goware/jwtauth"
|
|
"github.com/pressly/chi"
|
|
"golang.org/x/net/context"
|
|
)
|
|
|
|
var (
|
|
TokenAuth *jwtauth.JwtAuth
|
|
TokenSecret = []byte("secretpass")
|
|
)
|
|
|
|
func init() {
|
|
TokenAuth = jwtauth.New("HS256", []byte("secretpass"), nil)
|
|
}
|
|
|
|
//
|
|
// Tests
|
|
//
|
|
|
|
func TestSimple(t *testing.T) {
|
|
r := chi.NewRouter()
|
|
|
|
r.Use(TokenAuth.Verifier)
|
|
r.Use(jwtauth.Authenticator)
|
|
|
|
r.Get("/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte("welcome"))
|
|
})
|
|
|
|
ts := httptest.NewServer(r)
|
|
defer ts.Close()
|
|
|
|
// sending unauthorized requests
|
|
if resp := testRequest(t, ts, "GET", "/", nil, nil); resp != "Unauthorized\n" {
|
|
t.Fatalf(resp)
|
|
}
|
|
|
|
h := http.Header{}
|
|
h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{}))
|
|
if resp := testRequest(t, ts, "GET", "/", h, nil); resp != "Unauthorized\n" {
|
|
t.Fatalf(resp)
|
|
}
|
|
h.Set("Authorization", "BEARER asdf")
|
|
if resp := testRequest(t, ts, "GET", "/", h, nil); resp != "Unauthorized\n" {
|
|
t.Fatalf(resp)
|
|
}
|
|
|
|
// sending authorized requests
|
|
if resp := testRequest(t, ts, "GET", "/", newAuthHeader(), nil); resp != "welcome" {
|
|
t.Fatalf(resp)
|
|
}
|
|
}
|
|
|
|
func TestMore(t *testing.T) {
|
|
r := chi.NewRouter()
|
|
|
|
// Protected routes
|
|
r.Group(func(r chi.Router) {
|
|
r.Use(TokenAuth.Verifier)
|
|
|
|
authenticator := func(next chi.Handler) chi.Handler {
|
|
return chi.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
|
if jwtErr, ok := ctx.Value("jwt.err").(error); ok {
|
|
switch jwtErr {
|
|
default:
|
|
log.Println("...we're expired... I think..:", jwtErr)
|
|
http.Error(w, http.StatusText(401), 401)
|
|
return
|
|
case jwtauth.ErrExpired:
|
|
http.Error(w, "expired", 401)
|
|
return
|
|
case jwtauth.ErrUnauthorized:
|
|
http.Error(w, http.StatusText(401), 401)
|
|
return
|
|
case nil:
|
|
// no error
|
|
}
|
|
}
|
|
|
|
jwtToken, ok := ctx.Value("jwt").(*jwt.Token)
|
|
if !ok || jwtToken == nil || !jwtToken.Valid {
|
|
log.Println("jwt token..........?", jwtToken)
|
|
http.Error(w, http.StatusText(401), 401)
|
|
return
|
|
}
|
|
|
|
// Token is authenticated, pass it through
|
|
next.ServeHTTPC(ctx, w, r)
|
|
})
|
|
}
|
|
r.Use(authenticator)
|
|
|
|
r.Get("/admin", func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte("protected"))
|
|
})
|
|
})
|
|
|
|
// Public routes
|
|
r.Group(func(r chi.Router) {
|
|
r.Get("/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte("welcome"))
|
|
})
|
|
})
|
|
|
|
ts := httptest.NewServer(r)
|
|
defer ts.Close()
|
|
|
|
// sending unauthorized requests
|
|
if resp := testRequest(t, ts, "GET", "/admin", nil, nil); resp != "Unauthorized\n" {
|
|
t.Fatalf(resp)
|
|
}
|
|
|
|
h := http.Header{}
|
|
h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{}))
|
|
if resp := testRequest(t, ts, "GET", "/admin", h, nil); resp != "Unauthorized\n" {
|
|
t.Fatalf(resp)
|
|
}
|
|
h.Set("Authorization", "BEARER asdf")
|
|
if resp := testRequest(t, ts, "GET", "/admin", h, nil); resp != "Unauthorized\n" {
|
|
t.Fatalf(resp)
|
|
}
|
|
|
|
h = newAuthHeader((jwtauth.Claims{}).Set("exp", jwtauth.EpochNow()-1000))
|
|
if resp := testRequest(t, ts, "GET", "/admin", h, nil); resp != "expired\n" {
|
|
t.Fatalf(resp)
|
|
}
|
|
|
|
// sending authorized requests
|
|
if resp := testRequest(t, ts, "GET", "/", nil, nil); resp != "welcome" {
|
|
t.Fatalf(resp)
|
|
}
|
|
|
|
h = newAuthHeader((jwtauth.Claims{}).SetExpiryIn(5 * time.Minute))
|
|
if resp := testRequest(t, ts, "GET", "/admin", h, nil); resp != "protected" {
|
|
t.Fatalf(resp)
|
|
}
|
|
}
|
|
|
|
//
|
|
// Test helper functions
|
|
//
|
|
|
|
func testRequest(t *testing.T, ts *httptest.Server, method, path string, header http.Header, body io.Reader) string {
|
|
req, err := http.NewRequest(method, ts.URL+path, body)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
return ""
|
|
}
|
|
|
|
if header != nil {
|
|
for k, v := range header {
|
|
req.Header.Set(k, v[0])
|
|
}
|
|
}
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
return ""
|
|
}
|
|
|
|
respBody, err := ioutil.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
return ""
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
return string(respBody)
|
|
}
|
|
|
|
func newJwtToken(secret []byte, claims ...jwtauth.Claims) string {
|
|
token := jwt.New(jwt.GetSigningMethod("HS256"))
|
|
if len(claims) > 0 {
|
|
for k, v := range claims[0] {
|
|
token.Claims[k] = v
|
|
}
|
|
}
|
|
tokenStr, err := token.SignedString(secret)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
return tokenStr
|
|
}
|
|
|
|
func newAuthHeader(claims ...jwtauth.Claims) http.Header {
|
|
h := http.Header{}
|
|
h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, claims...))
|
|
return h
|
|
}
|