From 137d669fb7769f7c7b470957d4d8447e3e5989db Mon Sep 17 00:00:00 2001 From: kvii <56432636+kvii@users.noreply.github.com> Date: Wed, 9 Feb 2022 19:21:49 +0800 Subject: [PATCH] remove unnecessary code. (#69) --- jwtauth.go | 8 +++----- jwtauth_test.go | 10 ++++------ 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/jwtauth.go b/jwtauth.go index 38104fe..e6534e8 100644 --- a/jwtauth.go +++ b/jwtauth.go @@ -61,9 +61,7 @@ func New(alg string, signKey interface{}, verifyKey interface{}) *JWTAuth { // which checks the request context jwt token and error to prepare a custom // http response. func Verifier(ja *JWTAuth) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return Verify(ja, TokenFromHeader, TokenFromCookie)(next) - } + return Verify(ja, TokenFromHeader, TokenFromCookie) } func Verify(ja *JWTAuth, findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler { @@ -165,12 +163,12 @@ func Authenticator(next http.Handler) http.Handler { token, _, err := FromContext(r.Context()) if err != nil { - http.Error(w, err.Error(), 401) + http.Error(w, err.Error(), http.StatusUnauthorized) return } if token == nil || jwt.Validate(token) != nil { - http.Error(w, http.StatusText(401), 401) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } diff --git a/jwtauth_test.go b/jwtauth_test.go index c8cd84c..dd2bc57 100644 --- a/jwtauth_test.go +++ b/jwtauth_test.go @@ -153,12 +153,12 @@ func TestMore(t *testing.T) { token, _, err := jwtauth.FromContext(r.Context()) if err != nil { - http.Error(w, jwtauth.ErrorReason(err).Error(), 401) + http.Error(w, jwtauth.ErrorReason(err).Error(), http.StatusUnauthorized) return } if err := jwt.Validate(token); err != nil { - http.Error(w, jwtauth.ErrorReason(err).Error(), 401) + http.Error(w, jwtauth.ErrorReason(err).Error(), http.StatusUnauthorized) return } @@ -242,10 +242,8 @@ func testRequest(t *testing.T, ts *httptest.Server, method, path string, header return 0, "" } - if header != nil { - for k, v := range header { - req.Header.Set(k, v[0]) - } + for k, v := range header { + req.Header.Set(k, v[0]) } resp, err := http.DefaultClient.Do(req)