diff --git a/.gitignore b/.gitignore index 251645c..498781e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ vendor/ Gopkg.lock +.idea/ \ No newline at end of file diff --git a/jwtauth.go b/jwtauth.go index d271554..a9d1b2f 100644 --- a/jwtauth.go +++ b/jwtauth.go @@ -22,15 +22,15 @@ var ( ) type JwtAuth struct { - signKey []byte - verifyKey []byte + signKey interface{} + verifyKey interface{} signer jwt.SigningMethod parser *jwt.Parser } // New creates a JwtAuth authenticator instance that provides middleware handlers // and encoding/decoding functions for JWT signing. -func New(alg string, signKey []byte, verifyKey []byte) *JwtAuth { +func New(alg string, signKey interface{}, verifyKey interface{}) *JwtAuth { return NewWithParser(alg, &jwt.Parser{}, signKey, verifyKey) } @@ -40,7 +40,7 @@ func New(alg string, signKey []byte, verifyKey []byte) *JwtAuth { // We explicitly toggle `SkipClaimsValidation` in the `jwt-go` parser so that // we can control when the claims are validated - in our case, by the Verifier // http middleware handler. -func NewWithParser(alg string, parser *jwt.Parser, signKey []byte, verifyKey []byte) *JwtAuth { +func NewWithParser(alg string, parser *jwt.Parser, signKey interface{}, verifyKey interface{}) *JwtAuth { parser.SkipClaimsValidation = true return &JwtAuth{ signKey: signKey, @@ -166,7 +166,7 @@ func (ja *JwtAuth) Decode(tokenString string) (t *jwt.Token, err error) { } func (ja *JwtAuth) keyFunc(t *jwt.Token) (interface{}, error) { - if ja.verifyKey != nil && len(ja.verifyKey) > 0 { + if ja.verifyKey != nil { return ja.verifyKey, nil } else { return ja.signKey, nil diff --git a/jwtauth_test.go b/jwtauth_test.go index c6a489b..094d722 100644 --- a/jwtauth_test.go +++ b/jwtauth_test.go @@ -1,12 +1,15 @@ package jwtauth_test import ( + "crypto/x509" + "encoding/pem" "fmt" "io" "io/ioutil" "log" "net/http" "net/http/httptest" + "reflect" "testing" "time" @@ -16,22 +19,83 @@ import ( ) var ( - TokenAuth *jwtauth.JwtAuth - TokenSecret = []byte("secretpass") + TokenAuthHS256 *jwtauth.JwtAuth + TokenSecret = []byte("secretpass") + + TokenAuthRS256 *jwtauth.JwtAuth + + PrivateKeyRS256String = `-----BEGIN RSA PRIVATE KEY----- +MIIBOwIBAAJBALxo3PCjFw4QjgOX06QCJIJBnXXNiEYwDLxxa5/7QyH6y77nCRQy +J3x3UwF9rUD0RCsp4sNdX5kOQ9PUyHyOtCUCAwEAAQJARjFLHtuj2zmPrwcBcjja +IS0Q3LKV8pA0LoCS+CdD+4QwCxeKFq0yEMZtMvcQOfqo9x9oAywFClMSlLRyl7ng +gQIhAOyerGbcdQxxwjwGpLS61Mprf4n2HzjwISg20cEEH1tfAiEAy9dXmgQpDPir +C6Q9QdLXpNgSB+o5CDqfor7TTyTCovsCIQDNCfpu795luDYN+dvD2JoIBfrwu9v2 +ZO72f/pm/YGGlQIgUdRXyW9kH13wJFNBeBwxD27iBiVj0cbe8NFUONBUBmMCIQCN +jVK4eujt1lm/m60TlEhaWBC3p+3aPT2TqFPUigJ3RQ== +-----END RSA PRIVATE KEY----- +` + + PublicKeyRS256String = `-----BEGIN PUBLIC KEY----- +MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBALxo3PCjFw4QjgOX06QCJIJBnXXNiEYw +DLxxa5/7QyH6y77nCRQyJ3x3UwF9rUD0RCsp4sNdX5kOQ9PUyHyOtCUCAwEAAQ== +-----END PUBLIC KEY----- +` ) func init() { - TokenAuth = jwtauth.New("HS256", TokenSecret, nil) + TokenAuthHS256 = jwtauth.New("HS256", TokenSecret, nil) } // // Tests // +func TestSimpleRSA(t *testing.T) { + privateKeyBlock, _ := pem.Decode([]byte(PrivateKeyRS256String)) + + privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes) + + if err != nil { + t.Fatalf(err.Error()) + } + + publicKeyBlock, _ := pem.Decode([]byte(PublicKeyRS256String)) + + publicKey, err := x509.ParsePKIXPublicKey(publicKeyBlock.Bytes) + + if err != nil { + t.Fatalf(err.Error()) + } + + TokenAuthRS256 = jwtauth.New("RS256", privateKey, publicKey) + + claims := jwtauth.Claims{ + "key": "val", + "key2": "val2", + "key3": "val3", + } + + _, tokenString, err := TokenAuthRS256.Encode(claims) + + if err != nil { + t.Fatalf("Failed to encode claims %s\n", err.Error()) + } + + token, err := TokenAuthRS256.Decode(tokenString) + + if err != nil { + t.Fatalf("Failed to decode token string %s\n", err.Error()) + } + + if !reflect.DeepEqual(claims, jwtauth.Claims(token.Claims.(jwt.MapClaims))) { + t.Fatalf("The decoded claims don't match the original ones\n") + } +} + func TestSimple(t *testing.T) { r := chi.NewRouter() - r.Use(jwtauth.Verifier(TokenAuth), jwtauth.Authenticator) + r.Use(jwtauth.Verifier(TokenAuthHS256), jwtauth.Authenticator) r.Get("/", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("welcome")) @@ -76,7 +140,7 @@ func TestMore(t *testing.T) { // Protected routes r.Group(func(r chi.Router) { - r.Use(jwtauth.Verifier(TokenAuth)) + r.Use(jwtauth.Verifier(TokenAuthHS256)) authenticator := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {