added more comprehensive unit test cases
This commit is contained in:
@@ -345,3 +345,368 @@ func TestJWTAuth_ValidToken(t *testing.T) {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Additional comprehensive test cases
|
||||
|
||||
func TestExtractBearerToken_EdgeCases(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
header string
|
||||
wantToken string
|
||||
wantOk bool
|
||||
}{
|
||||
{"Multiple spaces", "Bearer token123", " token123", true}, // Extracts everything after "Bearer "
|
||||
{"No space after Bearer", "Bearertoken123", "", false},
|
||||
{"Lowercase bearer", "bearer token123", "", false},
|
||||
{"Mixed case", "BeArEr token123", "", false},
|
||||
{"Extra whitespace", " Bearer token123", "", false}, // Must start with "Bearer "
|
||||
{"Token with spaces", "Bearer token with spaces", "token with spaces", true},
|
||||
{"Very long token", "Bearer " + string(make([]byte, 5000)), string(make([]byte, 5000)), true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gotToken, gotOk := extractBearerToken(tc.header)
|
||||
if gotToken != tc.wantToken || gotOk != tc.wantOk {
|
||||
t.Errorf("extractBearerToken(%q) = (%q, %v), want (%q, %v)", tc.header, gotToken, gotOk, tc.wantToken, tc.wantOk)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAndValidateToken_MalformedTokens(t *testing.T) {
|
||||
os.Setenv("JWT_KEY", "test-secret")
|
||||
defer os.Unsetenv("JWT_KEY")
|
||||
|
||||
jwtSecretOnce = sync.Once{}
|
||||
jwtSecretError = nil
|
||||
jwtSecretCached = nil
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{"Empty string", ""},
|
||||
{"Random string", "not.a.jwt.token"},
|
||||
{"Only dots", "..."},
|
||||
{"Two parts only", "header.payload"},
|
||||
{"Four parts", "part1.part2.part3.part4"},
|
||||
{"Invalid base64", "!@#$.!@#$.!@#$"},
|
||||
{"Spaces in token", "part1 .part2 .part3"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := parseAndValidateToken(tc.token)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for malformed token %q", tc.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildContext_WithDifferentRoles(t *testing.T) {
|
||||
roles := []string{"admin", "user", "guest", "superadmin", "", "role-with-dash"}
|
||||
|
||||
for _, role := range roles {
|
||||
t.Run("Role: "+role, func(t *testing.T) {
|
||||
claims := &models.Claims{
|
||||
UserID: "user123",
|
||||
Username: "testuser",
|
||||
Role: role,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
newReq := buildContext(req.Context(), claims)
|
||||
reqWithCtx := req.WithContext(newReq)
|
||||
|
||||
retrievedClaims, ok := GetClaims(reqWithCtx)
|
||||
if !ok {
|
||||
t.Error("Claims not found in context")
|
||||
}
|
||||
if retrievedClaims.Role != role {
|
||||
t.Errorf("Role = %q, want %q", retrievedClaims.Role, role)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClaims_WithoutClaims(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
claims, ok := GetClaims(req)
|
||||
if ok {
|
||||
t.Error("Expected ok=false when claims not in context")
|
||||
}
|
||||
if claims != nil {
|
||||
t.Error("Expected nil claims when not in context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClaims_WithWrongType(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
ctx := context.WithValue(req.Context(), models.ContextKey("claims"), "wrong type")
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
claims, ok := GetClaims(req)
|
||||
if ok {
|
||||
t.Error("Expected ok=false when claims are wrong type")
|
||||
}
|
||||
if claims != nil {
|
||||
t.Error("Expected nil claims when wrong type in context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserID_WithNoClaims(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
userID, ok := GetUserID(req)
|
||||
if ok {
|
||||
t.Error("Expected ok=false when no claims")
|
||||
}
|
||||
if userID != "" {
|
||||
t.Errorf("Expected empty string, got %q", userID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUsername_WithNoClaims(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
username, ok := GetUsername(req)
|
||||
if ok {
|
||||
t.Error("Expected ok=false when no claims")
|
||||
}
|
||||
if username != "" {
|
||||
t.Errorf("Expected empty string, got %q", username)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRole_WithNoClaims(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
role, ok := GetRole(req)
|
||||
if ok {
|
||||
t.Error("Expected ok=false when no claims")
|
||||
}
|
||||
if role != "" {
|
||||
t.Errorf("Expected empty string, got %q", role)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTAuth_MissingBearerPrefix(t *testing.T) {
|
||||
os.Setenv("JWT_KEY", "test-secret")
|
||||
defer os.Unsetenv("JWT_KEY")
|
||||
|
||||
jwtSecretOnce = sync.Once{}
|
||||
jwtSecretError = nil
|
||||
jwtSecretCached = nil
|
||||
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("Authorization", "InvalidToken")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
JWTAuth(handler)(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTAuth_ExpiredToken(t *testing.T) {
|
||||
os.Setenv("JWT_KEY", "test-secret")
|
||||
defer os.Unsetenv("JWT_KEY")
|
||||
|
||||
jwtSecretOnce = sync.Once{}
|
||||
jwtSecretError = nil
|
||||
jwtSecretCached = nil
|
||||
|
||||
// Create token that's already expired
|
||||
claims := &models.Claims{
|
||||
UserID: "user123",
|
||||
Username: "testuser",
|
||||
Role: "admin",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, _ := token.SignedString([]byte("test-secret"))
|
||||
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tokenString)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
JWTAuth(handler)(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status %d for expired token, got %d", http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTAuth_TokenWithMissingClaims(t *testing.T) {
|
||||
os.Setenv("JWT_KEY", "test-secret")
|
||||
defer os.Unsetenv("JWT_KEY")
|
||||
|
||||
jwtSecretOnce = sync.Once{}
|
||||
jwtSecretError = nil
|
||||
jwtSecretCached = nil
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
claims *models.Claims
|
||||
}{
|
||||
{
|
||||
"Missing UserID",
|
||||
&models.Claims{
|
||||
Username: "testuser",
|
||||
Role: "admin",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Missing Username",
|
||||
&models.Claims{
|
||||
UserID: "user123",
|
||||
Role: "admin",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"Missing Role",
|
||||
&models.Claims{
|
||||
UserID: "user123",
|
||||
Username: "testuser",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, tc.claims)
|
||||
tokenString, _ := token.SignedString([]byte("test-secret"))
|
||||
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
claims, ok := GetClaims(r)
|
||||
if !ok {
|
||||
t.Error("Claims should still be in context even if some fields are empty")
|
||||
}
|
||||
// Verify the missing field is empty
|
||||
_ = claims
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tokenString)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
JWTAuth(handler)(w, req)
|
||||
|
||||
// Token is valid, just missing some claim fields
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTAuth_ConcurrentRequests(t *testing.T) {
|
||||
os.Setenv("JWT_KEY", "test-secret")
|
||||
defer os.Unsetenv("JWT_KEY")
|
||||
|
||||
jwtSecretOnce = sync.Once{}
|
||||
jwtSecretError = nil
|
||||
jwtSecretCached = nil
|
||||
|
||||
claims := &models.Claims{
|
||||
UserID: "user123",
|
||||
Username: "testuser",
|
||||
Role: "admin",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, _ := token.SignedString([]byte("test-secret"))
|
||||
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
if _, ok := GetClaims(r); !ok {
|
||||
t.Error("Claims not found in concurrent request")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
concurrency := 50
|
||||
done := make(chan bool, concurrency)
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func() {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tokenString)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
JWTAuth(handler)(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status %d in concurrent request, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTAuth_TokenSignedWithWrongKey(t *testing.T) {
|
||||
os.Setenv("JWT_KEY", "correct-secret")
|
||||
defer os.Unsetenv("JWT_KEY")
|
||||
|
||||
jwtSecretOnce = sync.Once{}
|
||||
jwtSecretError = nil
|
||||
jwtSecretCached = nil
|
||||
|
||||
// Create token with wrong key
|
||||
claims := &models.Claims{
|
||||
UserID: "user123",
|
||||
Username: "testuser",
|
||||
Role: "admin",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, _ := token.SignedString([]byte("wrong-secret"))
|
||||
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tokenString)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
JWTAuth(handler)(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status %d for wrong signature, got %d", http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user