Files
Authorization/middleware/jwt_test.go
T
2025-12-16 10:57:26 +08:00

348 lines
8.1 KiB
Go

package middleware
import (
"authorization/models"
"context"
"net/http"
"net/http/httptest"
"os"
"sync"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
)
func TestGetJWTSecret(t *testing.T) {
// Save original and restore after test
originalSecret := os.Getenv("JWT_KEY")
defer func() {
if originalSecret != "" {
os.Setenv("JWT_KEY", originalSecret)
} else {
os.Unsetenv("JWT_KEY")
}
}()
t.Run("JWT_KEY not set", func(t *testing.T) {
// Reset state for testing
oldCached := jwtSecretCached
oldError := jwtSecretError
defer func() {
jwtSecretCached = oldCached
jwtSecretError = oldError
}()
os.Unsetenv("JWT_KEY")
jwtSecretOnce = sync.Once{}
jwtSecretError = nil
jwtSecretCached = nil
_, err := getJWTSecret()
if err == nil {
t.Error("Expected error when JWT_KEY is not set")
}
})
t.Run("JWT_KEY set", func(t *testing.T) {
// Reset state for testing
oldCached := jwtSecretCached
oldError := jwtSecretError
defer func() {
jwtSecretCached = oldCached
jwtSecretError = oldError
}()
os.Setenv("JWT_KEY", "test-secret-key")
jwtSecretOnce = sync.Once{}
jwtSecretError = nil
jwtSecretCached = nil
secret, err := getJWTSecret()
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if string(secret) != "test-secret-key" {
t.Errorf("Expected 'test-secret-key', got '%s'", string(secret))
}
})
}
// TestExtractBearerToken runs extractBearerToken over a table of Authorization
// header values. Only a "Bearer <token>" header yields the token with
// ok == true; every malformed value is expected to yield "" and false.
func TestExtractBearerToken(t *testing.T) {
	type testCase struct {
		name       string
		header     string
		expected   string
		expectedOK bool
	}

	cases := []testCase{
		{name: "Valid Bearer token", header: "Bearer token123", expected: "token123", expectedOK: true},
		{name: "Empty header", header: "", expected: "", expectedOK: false},
		{name: "Too short", header: "Bearer", expected: "", expectedOK: false},
		{name: "Wrong prefix", header: "Basic token123", expected: "", expectedOK: false},
		{name: "Missing space", header: "Bearertoken123", expected: "", expectedOK: false},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			got, gotOK := extractBearerToken(tc.header)
			if got != tc.expected {
				t.Errorf("Expected token '%s', got '%s'", tc.expected, got)
			}
			if gotOK != tc.expectedOK {
				t.Errorf("Expected ok %v, got %v", tc.expectedOK, gotOK)
			}
		})
	}
}
func TestParseAndValidateToken(t *testing.T) {
// Setup
os.Setenv("JWT_KEY", "test-secret-key")
jwtSecretOnce = sync.Once{}
jwtSecretError = nil
defer os.Unsetenv("JWT_KEY")
t.Run("Valid token", func(t *testing.T) {
// Create a valid token
claims := &models.Claims{
UserID: "user123",
Username: "testuser",
Role: "admin",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte("test-secret-key"))
if err != nil {
t.Fatalf("Failed to create token: %v", err)
}
parsedClaims, err := parseAndValidateToken(tokenString)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if parsedClaims.UserID != "user123" {
t.Errorf("Expected UserID 'user123', got '%s'", parsedClaims.UserID)
}
})
t.Run("Invalid token", func(t *testing.T) {
_, err := parseAndValidateToken("invalid.token.string")
if err == nil {
t.Error("Expected error for invalid token")
}
})
t.Run("Expired token", func(t *testing.T) {
claims := &models.Claims{
UserID: "user123",
Username: "testuser",
Role: "admin",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, _ := token.SignedString([]byte("test-secret-key"))
_, err := parseAndValidateToken(tokenString)
if err == nil {
t.Error("Expected error for expired token")
}
})
}
// TestBuildContext verifies that buildContext stores the claims pointer under
// claimsKey. It also verifies that the user ID, username and role are exposed
// as individual string values under their own keys.
func TestBuildContext(t *testing.T) {
	input := &models.Claims{
		UserID:   "user123",
		Username: "testuser",
		Role:     "admin",
	}
	ctx := buildContext(context.Background(), input)

	stored, isClaims := ctx.Value(claimsKey).(*models.Claims)
	if !isClaims || stored.UserID != "user123" {
		t.Error("Claims not properly set in context")
	}

	// Each field is checked in the same order as before so failures are
	// reported identically.
	stringFields := []struct {
		key     interface{}
		want    string
		failMsg string
	}{
		{userIDKey, "user123", "UserID not properly set in context"},
		{usernameKey, "testuser", "Username not properly set in context"},
		{roleKey, "admin", "Role not properly set in context"},
	}
	for _, f := range stringFields {
		if got, isString := ctx.Value(f.key).(string); !isString || got != f.want {
			t.Error(f.failMsg)
		}
	}
}
// TestGetClaims checks that GetClaims returns the claims stored in the request
// context under claimsKey, with all identifying fields intact.
func TestGetClaims(t *testing.T) {
	claims := &models.Claims{
		UserID:   "user123",
		Username: "testuser",
		Role:     "admin",
	}
	req := httptest.NewRequest("GET", "/", nil)
	ctx := context.WithValue(req.Context(), claimsKey, claims)
	req = req.WithContext(ctx)

	retrievedClaims, ok := GetClaims(req)
	if !ok {
		// Fatal rather than Error: when nothing was found the returned claims
		// may be nil, and the field accesses below would panic.
		t.Fatal("Expected claims to be found")
	}
	if retrievedClaims.UserID != "user123" {
		t.Errorf("Expected UserID 'user123', got '%s'", retrievedClaims.UserID)
	}
	if retrievedClaims.Username != "testuser" {
		t.Errorf("Expected Username 'testuser', got '%s'", retrievedClaims.Username)
	}
	if retrievedClaims.Role != "admin" {
		t.Errorf("Expected Role 'admin', got '%s'", retrievedClaims.Role)
	}
}
// TestGetUserID ensures GetUserID reads the string stored under userIDKey in
// the request's context.
func TestGetUserID(t *testing.T) {
	base := httptest.NewRequest(http.MethodGet, "/", nil)
	withID := context.WithValue(base.Context(), userIDKey, "user123")
	r := base.WithContext(withID)

	id, found := GetUserID(r)
	if !found {
		t.Error("Expected userID to be found")
	}
	if id != "user123" {
		t.Errorf("Expected 'user123', got '%s'", id)
	}
}
// TestGetUsername ensures GetUsername returns the string stored under
// usernameKey in the request's context.
func TestGetUsername(t *testing.T) {
	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r = r.WithContext(context.WithValue(r.Context(), usernameKey, "testuser"))

	name, present := GetUsername(r)
	if !present {
		t.Error("Expected username to be found")
	}
	if name != "testuser" {
		t.Errorf("Expected 'testuser', got '%s'", name)
	}
}
// TestGetRole ensures GetRole returns the string stored under roleKey in the
// request's context.
func TestGetRole(t *testing.T) {
	original := httptest.NewRequest(http.MethodGet, "/", nil)
	roleCtx := context.WithValue(original.Context(), roleKey, "admin")

	got, exists := GetRole(original.WithContext(roleCtx))
	if !exists {
		t.Error("Expected role to be found")
	}
	if got != "admin" {
		t.Errorf("Expected 'admin', got '%s'", got)
	}
}
func TestJWTAuth_NoAuthHeader(t *testing.T) {
handler := func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
JWTAuth(handler)(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
}
}
func TestJWTAuth_InvalidToken(t *testing.T) {
os.Setenv("JWT_KEY", "test-secret-key")
jwtSecretOnce = sync.Once{}
jwtSecretError = nil
defer os.Unsetenv("JWT_KEY")
handler := func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("Authorization", "Bearer invalid.token.here")
w := httptest.NewRecorder()
JWTAuth(handler)(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
}
}
func TestJWTAuth_ValidToken(t *testing.T) {
os.Setenv("JWT_KEY", "test-secret-key")
jwtSecretOnce = sync.Once{}
jwtSecretError = nil
defer os.Unsetenv("JWT_KEY")
claims := &models.Claims{
UserID: "user123",
Username: "testuser",
Role: "admin",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, _ := token.SignedString([]byte("test-secret-key"))
handler := func(w http.ResponseWriter, r *http.Request) {
// Verify claims are in context
if retrievedClaims, ok := GetClaims(r); !ok || retrievedClaims.UserID != "user123" {
t.Error("Claims not found or incorrect in context")
}
w.WriteHeader(http.StatusOK)
}
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("Authorization", "Bearer "+tokenString)
w := httptest.NewRecorder()
JWTAuth(handler)(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
}
}