ae1831e61f
- Rename user_id → users_id across all models, handlers, services, and tests
- Add custom RoleIDs type supporting string/int/array unmarshaling (e.g., "1", 1, [1])
- Implement flexible JSON unmarshaling for JWT Claims to handle field name variants
- Support both user_id/users_id and email/email_address field names
- Enable role_id as string ("1"), int (1), or array ([1,2])
- Update AuthorizationContext to handle role_id type flexibility
- Add comprehensive logging to repository, service, and handler layers
- Entry/exit logs with full context
- Success (✓) and failure (✗) indicators
- Step-by-step authorization flow tracking
- Add containsRole helper for multi-role membership checks
- Fix database queries: user_id → users_id, id → permissions_id
- Update all tests to use models.RoleIDs{} syntax
- Change GetRole middleware return type: string → []int
- Maintain backward compatibility with legacy JWT tokens
This change improves integration with external services (MIS) that may send
role_id in different formats and standardizes field naming conventions
throughout the authorization microservice.
568 lines
14 KiB
Go
568 lines
14 KiB
Go
package middleware
|
|
|
|
import (
|
|
"authorization/models"
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/pem"
|
|
"math/big"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
)
|
|
|
|
func TestMain(m *testing.M) {
|
|
os.Setenv("GO_ENV", "development")
|
|
code := m.Run()
|
|
os.Unsetenv("GO_ENV")
|
|
os.Exit(code)
|
|
}
|
|
|
|
// Test helper to generate RSA key pair and certificate
|
|
func generateTestRSACertificate(t *testing.T) (privateKey *rsa.PrivateKey, certPEM []byte) {
|
|
t.Helper()
|
|
|
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
}
|
|
|
|
template := x509.Certificate{
|
|
SerialNumber: big.NewInt(1),
|
|
Subject: pkix.Name{
|
|
Organization: []string{"Test Org"},
|
|
},
|
|
NotBefore: time.Now(),
|
|
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
BasicConstraintsValid: true,
|
|
}
|
|
|
|
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create certificate: %v", err)
|
|
}
|
|
|
|
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
|
return privateKey, certPEM
|
|
}
|
|
|
|
func TestGetRSAPublicKey(t *testing.T) {
|
|
t.Run("Valid PEM file", func(t *testing.T) {
|
|
// Reset state for testing
|
|
rsaPublicKeyOnce = sync.Once{}
|
|
rsaPublicKeyError = nil
|
|
rsaPublicKeyCached = nil
|
|
|
|
key, err := getRSAPublicKey()
|
|
if err != nil {
|
|
t.Errorf("Expected no error, got %v", err)
|
|
}
|
|
if key == nil {
|
|
t.Error("Expected RSA public key, got nil")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestExtractBearerToken(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
authHeader string
|
|
wantToken string
|
|
wantOK bool
|
|
}{
|
|
{
|
|
name: "Valid Bearer token",
|
|
authHeader: "Bearer token123",
|
|
wantToken: "token123",
|
|
wantOK: true,
|
|
},
|
|
{
|
|
name: "Empty header",
|
|
authHeader: "",
|
|
wantToken: "",
|
|
wantOK: false,
|
|
},
|
|
{
|
|
name: "Too short",
|
|
authHeader: "Bearer",
|
|
wantToken: "",
|
|
wantOK: false,
|
|
},
|
|
{
|
|
name: "Wrong prefix",
|
|
authHeader: "Basic token123",
|
|
wantToken: "",
|
|
wantOK: false,
|
|
},
|
|
{
|
|
name: "Missing space",
|
|
authHeader: "Bearertoken123",
|
|
wantToken: "",
|
|
wantOK: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
token, ok := extractBearerToken(tt.authHeader)
|
|
if token != tt.wantToken {
|
|
t.Errorf("Expected token '%s', got '%s'", tt.wantToken, token)
|
|
}
|
|
if ok != tt.wantOK {
|
|
t.Errorf("Expected ok %v, got %v", tt.wantOK, ok)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestParseAndValidateToken(t *testing.T) {
|
|
// Reset state
|
|
rsaPublicKeyOnce = sync.Once{}
|
|
rsaPublicKeyError = nil
|
|
rsaPublicKeyCached = nil
|
|
|
|
t.Run("Valid token", func(t *testing.T) {
|
|
// Generate test RSA key and certificate
|
|
_, certPEM := generateTestRSACertificate(t)
|
|
|
|
// Create temporary test certificate file
|
|
tmpFile, err := os.CreateTemp("", "test-cert-*.pem")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp file: %v", err)
|
|
}
|
|
defer os.Remove(tmpFile.Name())
|
|
|
|
if _, err := tmpFile.Write(certPEM); err != nil {
|
|
t.Fatalf("Failed to write cert: %v", err)
|
|
}
|
|
tmpFile.Close()
|
|
|
|
// Note: This test would need the actual PEM file to be present
|
|
// For now, we'll skip the full validation
|
|
t.Skip("Requires actual RSA certificate file")
|
|
})
|
|
|
|
t.Run("Invalid token", func(t *testing.T) {
|
|
// Reset state
|
|
rsaPublicKeyOnce = sync.Once{}
|
|
rsaPublicKeyError = nil
|
|
rsaPublicKeyCached = nil
|
|
|
|
_, err := parseAndValidateToken("invalid.token.string")
|
|
if err == nil {
|
|
t.Error("Expected error for invalid token")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestBuildContext(t *testing.T) {
|
|
claims := &models.Claims{
|
|
UsersID: "user123",
|
|
RoleID: models.RoleIDs{3},
|
|
}
|
|
|
|
parent := context.Background()
|
|
ctx := buildContext(parent, claims)
|
|
|
|
// Check claims
|
|
if val, ok := ctx.Value(claimsKey).(*models.Claims); !ok || val.UsersID != "user123" {
|
|
t.Error("Claims not properly set in context")
|
|
}
|
|
|
|
// Check userID
|
|
if val, ok := ctx.Value(userIDKey).(string); !ok || val != "user123" {
|
|
t.Error("UserID not properly set in context")
|
|
}
|
|
|
|
// Check role
|
|
if val, ok := ctx.Value(roleIDKey).([]int); !ok || len(val) == 0 || val[0] != 3 {
|
|
t.Error("Role not properly set in context")
|
|
}
|
|
}
|
|
|
|
func TestGetClaims(t *testing.T) {
|
|
claims := &models.Claims{
|
|
UsersID: "user123",
|
|
RoleID: models.RoleIDs{3},
|
|
}
|
|
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
ctx := context.WithValue(req.Context(), claimsKey, claims)
|
|
req = req.WithContext(ctx)
|
|
|
|
retrievedClaims, ok := GetClaims(req)
|
|
if !ok {
|
|
t.Error("Expected claims to be found")
|
|
}
|
|
if retrievedClaims.UsersID != "user123" {
|
|
t.Errorf("Expected UserID 'user123', got '%s'", retrievedClaims.UsersID)
|
|
}
|
|
}
|
|
|
|
func TestGetUserID(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
ctx := context.WithValue(req.Context(), userIDKey, "user123")
|
|
req = req.WithContext(ctx)
|
|
|
|
userID, ok := GetUserID(req)
|
|
if !ok {
|
|
t.Error("Expected userID to be found")
|
|
}
|
|
if userID != "user123" {
|
|
t.Errorf("Expected 'user123', got '%s'", userID)
|
|
}
|
|
}
|
|
|
|
func TestGetRole(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
ctx := context.WithValue(req.Context(), roleIDKey, []int{3})
|
|
req = req.WithContext(ctx)
|
|
|
|
role, ok := GetRole(req)
|
|
if !ok {
|
|
t.Error("Expected role to be found")
|
|
}
|
|
if len(role) == 0 {
|
|
t.Errorf("Expected at least one role, got '%v'", role)
|
|
} else if role[0] != 3 {
|
|
t.Errorf("Expected first role to be 3, got '%v'", role[0])
|
|
}
|
|
}
|
|
|
|
func TestJWTAuthNoAuthHeader(t *testing.T) {
|
|
handler := func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
JWTAuth(handler)(w, req)
|
|
|
|
if w.Code != http.StatusUnauthorized {
|
|
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
|
|
}
|
|
}
|
|
|
|
func TestJWTAuthInvalidToken(t *testing.T) {
|
|
// Reset RSA state
|
|
rsaPublicKeyOnce = sync.Once{}
|
|
rsaPublicKeyError = nil
|
|
rsaPublicKeyCached = nil
|
|
|
|
handler := func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
req.Header.Set("Authorization", "Bearer invalid.token.here")
|
|
w := httptest.NewRecorder()
|
|
|
|
JWTAuth(handler)(w, req)
|
|
|
|
if w.Code != http.StatusUnauthorized {
|
|
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
|
|
}
|
|
}
|
|
|
|
func TestJWTAuthValidToken(t *testing.T) {
|
|
t.Skip("Requires RSA certificate setup - integration test")
|
|
}
|
|
|
|
// Additional comprehensive test cases
|
|
|
|
func TestExtractBearerTokenEdgeCases(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
header string
|
|
wantToken string
|
|
wantOk bool
|
|
}{
|
|
{"Multiple spaces", "Bearer token123", " token123", true}, // Extracts everything after "Bearer "
|
|
{"No space after Bearer", "Bearertoken123", "", false},
|
|
{"Lowercase bearer", "bearer token123", "", false},
|
|
{"Mixed case", "BeArEr token123", "", false},
|
|
{"Extra whitespace", " Bearer token123", "", false}, // Must start with "Bearer "
|
|
{"Token with spaces", "Bearer token with spaces", "token with spaces", true},
|
|
{"Very long token", "Bearer " + string(make([]byte, 5000)), string(make([]byte, 5000)), true},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
gotToken, gotOk := extractBearerToken(tc.header)
|
|
if gotToken != tc.wantToken || gotOk != tc.wantOk {
|
|
t.Errorf("extractBearerToken(%q) = (%q, %v), want (%q, %v)", tc.header, gotToken, gotOk, tc.wantToken, tc.wantOk)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestParseAndValidateTokenMalformedTokens(t *testing.T) {
|
|
// Reset RSA state
|
|
rsaPublicKeyOnce = sync.Once{}
|
|
rsaPublicKeyError = nil
|
|
rsaPublicKeyCached = nil
|
|
|
|
testCases := []struct {
|
|
name string
|
|
token string
|
|
}{
|
|
{"Empty string", ""},
|
|
{"Random string", "not.a.jwt.token"},
|
|
{"Only dots", "..."},
|
|
{"Two parts only", "header.payload"},
|
|
{"Four parts", "part1.part2.part3.part4"},
|
|
{"Invalid base64", "!@#$.!@#$.!@#$"},
|
|
{"Spaces in token", "part1 .part2 .part3"},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
_, err := parseAndValidateToken(tc.token)
|
|
if err == nil {
|
|
t.Errorf("Expected error for malformed token %q", tc.name)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestBuildContextWithDifferentRoles(t *testing.T) {
|
|
roles := []int{3, 4, 5, 6, 7, 8}
|
|
|
|
for _, role := range roles {
|
|
t.Run("Role: "+string(rune(role)), func(t *testing.T) {
|
|
claims := &models.Claims{
|
|
UsersID: "user123",
|
|
RoleID: models.RoleIDs{role},
|
|
}
|
|
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
newReq := buildContext(req.Context(), claims)
|
|
reqWithCtx := req.WithContext(newReq)
|
|
|
|
retrievedClaims, ok := GetClaims(reqWithCtx)
|
|
if !ok {
|
|
t.Error("Claims not found in context")
|
|
}
|
|
if len(retrievedClaims.RoleID) == 0 || retrievedClaims.RoleID[0] != role {
|
|
t.Errorf("Role = %v, want %v", retrievedClaims.RoleID, role)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetClaimsWithoutClaims(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
|
|
claims, ok := GetClaims(req)
|
|
if ok {
|
|
t.Error("Expected ok=false when claims not in context")
|
|
}
|
|
if claims != nil {
|
|
t.Error("Expected nil claims when not in context")
|
|
}
|
|
}
|
|
|
|
func TestGetClaimsWithWrongType(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
ctx := context.WithValue(req.Context(), models.ContextKey("claims"), "wrong type")
|
|
req = req.WithContext(ctx)
|
|
|
|
claims, ok := GetClaims(req)
|
|
if ok {
|
|
t.Error("Expected ok=false when claims are wrong type")
|
|
}
|
|
if claims != nil {
|
|
t.Error("Expected nil claims when wrong type in context")
|
|
}
|
|
}
|
|
|
|
func TestGetUserIDWithNoClaims(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
|
|
userID, ok := GetUserID(req)
|
|
if ok {
|
|
t.Error("Expected ok=false when no claims")
|
|
}
|
|
if userID != "" {
|
|
t.Errorf("Expected empty string, got %q", userID)
|
|
}
|
|
}
|
|
|
|
func TestGetRoleWithNoClaims(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
|
|
role, ok := GetRole(req)
|
|
if ok {
|
|
t.Error("Expected ok=false when no claims")
|
|
}
|
|
if role != nil && len(role) != 0 {
|
|
t.Errorf("Expected no roles, got %v", role)
|
|
}
|
|
}
|
|
|
|
func TestJWTAuthMissingBearerPrefix(t *testing.T) {
|
|
// Reset RSA state
|
|
rsaPublicKeyOnce = sync.Once{}
|
|
rsaPublicKeyError = nil
|
|
rsaPublicKeyCached = nil
|
|
|
|
handler := func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
req.Header.Set("Authorization", "InvalidToken")
|
|
w := httptest.NewRecorder()
|
|
|
|
JWTAuth(handler)(w, req)
|
|
|
|
if w.Code != http.StatusUnauthorized {
|
|
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
|
|
}
|
|
}
|
|
|
|
func TestJWTAuthExpiredToken(t *testing.T) {
|
|
t.Skip("Requires RSA certificate setup - integration test")
|
|
}
|
|
|
|
func TestJWTAuthTokenWithMissingClaims(t *testing.T) {
|
|
t.Skip("Requires RSA certificate setup - integration test")
|
|
|
|
testCases := []struct {
|
|
name string
|
|
claims *models.Claims
|
|
}{
|
|
{
|
|
"Missing UserID",
|
|
&models.Claims{
|
|
RoleID: models.RoleIDs{3},
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
|
},
|
|
},
|
|
},
|
|
{
|
|
"Missing Role",
|
|
&models.Claims{
|
|
UsersID: "user123",
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, tc.claims)
|
|
tokenString, _ := token.SignedString([]byte("test-secret"))
|
|
|
|
handler := func(w http.ResponseWriter, r *http.Request) {
|
|
claims, ok := GetClaims(r)
|
|
if !ok {
|
|
t.Error("Claims should still be in context even if some fields are empty")
|
|
}
|
|
// Verify the missing field is empty
|
|
_ = claims
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
req.Header.Set("Authorization", "Bearer "+tokenString)
|
|
w := httptest.NewRecorder()
|
|
|
|
JWTAuth(handler)(w, req)
|
|
|
|
// Token is valid, just missing some claim fields
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestJWTAuthConcurrentRequests(t *testing.T) {
|
|
t.Skip("Requires RSA certificate setup - integration test")
|
|
|
|
claims := &models.Claims{
|
|
UsersID: "user123",
|
|
RoleID: models.RoleIDs{3},
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
|
},
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
tokenString, _ := token.SignedString([]byte("test-secret"))
|
|
|
|
handler := func(w http.ResponseWriter, r *http.Request) {
|
|
if _, ok := GetClaims(r); !ok {
|
|
t.Error("Claims not found in concurrent request")
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
|
|
concurrency := 50
|
|
done := make(chan bool, concurrency)
|
|
|
|
for i := 0; i < concurrency; i++ {
|
|
go func() {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
req.Header.Set("Authorization", "Bearer "+tokenString)
|
|
w := httptest.NewRecorder()
|
|
|
|
JWTAuth(handler)(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("Expected status %d in concurrent request, got %d", http.StatusOK, w.Code)
|
|
}
|
|
done <- true
|
|
}()
|
|
}
|
|
|
|
for i := 0; i < concurrency; i++ {
|
|
<-done
|
|
}
|
|
}
|
|
|
|
func TestJWTAuthTokenSignedWithWrongKey(t *testing.T) {
|
|
t.Skip("Requires RSA certificate setup - integration test")
|
|
|
|
// Create token with wrong key
|
|
claims := &models.Claims{
|
|
UsersID: "user123",
|
|
RoleID: models.RoleIDs{3},
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
|
},
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
tokenString, _ := token.SignedString([]byte("wrong-secret"))
|
|
|
|
handler := func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
req.Header.Set("Authorization", "Bearer "+tokenString)
|
|
w := httptest.NewRecorder()
|
|
|
|
JWTAuth(handler)(w, req)
|
|
|
|
if w.Code != http.StatusUnauthorized {
|
|
t.Errorf("Expected status %d for wrong signature, got %d", http.StatusUnauthorized, w.Code)
|
|
}
|
|
}
|