From ac0ff0088045ff8ca25e46a575647c2439ebf1da Mon Sep 17 00:00:00 2001 From: F04C Date: Tue, 27 Jan 2026 10:58:24 +0800 Subject: [PATCH] fix tests --- handlers/google_auth.go | 26 ++++++++++++++++--- handlers/jwt.go | 51 ++++++++++++++++++++++++------------- services/access_log_test.go | 8 +++--- services/users.go | 2 +- services/users_test.go | 43 +++++++++++-------------------- 5 files changed, 76 insertions(+), 54 deletions(-) diff --git a/handlers/google_auth.go b/handlers/google_auth.go index 74bf08d..688fdc9 100644 --- a/handlers/google_auth.go +++ b/handlers/google_auth.go @@ -9,6 +9,7 @@ import ( "crypto/rand" "encoding/json" "errors" + "flag" "fmt" "io" "log" @@ -30,6 +31,10 @@ var oauthStateString = generateRandomState() var AuthorizationURL string var FetchedRedirectURI *string +func isTestEnvironment() bool { + return flag.Lookup("test.v") != nil || strings.Contains(os.Args[0], ".test") +} + // init initializes the Google OAuth2 configuration by loading environment variables // from a .env file. If the .env file cannot be loaded, it logs a fatal error. // Note: This init runs AFTER .env is loaded in main() init @@ -51,6 +56,19 @@ func init() { clientSecret := os.Getenv("GOOGLE_CLIENT_SECRET") backendURL := os.Getenv("BACKEND_URL") + if (clientID == "" || clientSecret == "" || backendURL == "") && isTestEnvironment() { + if clientID == "" { + clientID = "test-google-client-id" + } + if clientSecret == "" { + clientSecret = "test-google-client-secret" + } + if backendURL == "" { + backendURL = "http://localhost:8080" + } + log.Print("[google_auth.init] Using test fallback values for Google OAuth configuration") + } + log.Printf("[google_auth.init] GOOGLE_CLIENT_ID: '%s' (length: %d)", clientID, len(clientID)) log.Printf("[google_auth.init] GOOGLE_CLIENT_SECRET: '%s' (length: %d)", clientSecret, len(clientSecret)) log.Printf("[google_auth.init] BACKEND_URL: '%s'", backendURL) @@ -164,7 +182,7 @@ func GoogleCallback(w http.ResponseWriter, r *http.Request) { if err != nil { errMsg := err.Error() helper.LogError(err, "Failed to fetch Google user info") - + // Provide user-friendly error messages for different scenarios if strings.Contains(errMsg, "TLS handshake timeout") { helper.RespondWithError(w, http.StatusGatewayTimeout, "Connection to Google failed due to network issues. Please try again in a moment.") @@ -186,7 +204,7 @@ func GoogleCallback(w http.ResponseWriter, r *http.Request) { helper.RespondWithError(w, http.StatusForbidden, "Access to Google authentication was denied. Please try again later.") return } - + helper.RespondWithError(w, http.StatusBadGateway, "Failed to fetch user information from Google. Please try again.") return } @@ -343,7 +361,7 @@ func FetchGoogleUserInfo(w http.ResponseWriter, r *http.Request) (models.UserGoo if err != nil { errMsg := fmt.Sprintf("Failed to fetch user info from Google: %v", err) helper.LogError(err, errMsg) - + // Provide more specific error messages for common issues if os.IsTimeout(err) { return models.UserGoogleInfo{}, fmt.Errorf("request timed out: Google userinfo endpoint took too long to respond (timeout: 30s)") @@ -360,7 +378,7 @@ func FetchGoogleUserInfo(w http.ResponseWriter, r *http.Request) (models.UserGoo if strings.Contains(err.Error(), "no such host") { return models.UserGoogleInfo{}, fmt.Errorf("DNS resolution failed: Cannot resolve googleapis.com") } - + return models.UserGoogleInfo{}, fmt.Errorf("network error while fetching user info: %w", err) } defer func(Body io.ReadCloser) { diff --git a/handlers/jwt.go b/handlers/jwt.go index 422acc4..13afa31 100644 --- a/handlers/jwt.go +++ b/handlers/jwt.go @@ -22,6 +22,28 @@ import ( var rsaPrivateKey *rsa.PrivateKey +func parseRSAPrivateKey(keyData []byte) (*rsa.PrivateKey, error) { + block, _ := pem.Decode(keyData) + if block == nil { + return nil, fmt.Errorf("failed to decode PEM block containing private key") + } + + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + key, err = x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse RSA private key: %w", err) + } + } + + rsaKey, ok := key.(*rsa.PrivateKey) + if !ok { + return nil, fmt.Errorf("not an RSA private key") + } + + return rsaKey, nil +} + // Note: .env file is loaded in main() before this init runs func init() { keyPath := os.Getenv("JWT_PRIVATE_KEY_PATH") @@ -31,29 +53,24 @@ func init() { keyData, err := os.ReadFile(keyPath) if err != nil { + if isTestEnvironment() { + log.Printf("Failed to read RSA private key file at %s, generating test key: %v", keyPath, err) + generatedKey, genErr := rsa.GenerateKey(rand.Reader, 2048) + if genErr != nil { + log.Fatalf("Failed to generate test RSA private key: %v", genErr) + } + rsaPrivateKey = generatedKey + return + } log.Fatalf("Failed to read RSA private key file: %v", err) } - log.Print("Key Data: ", string(keyData)) - block, _ := pem.Decode(keyData) - if block == nil { - log.Fatal("Failed to decode PEM block containing private key") - } - - key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + parsedKey, err := parseRSAPrivateKey(keyData) if err != nil { - key, err = x509.ParsePKCS1PrivateKey(block.Bytes) - if err != nil { - log.Fatalf("Failed to parse RSA private key: %v", err) - } - } - - var ok bool - rsaPrivateKey, ok = key.(*rsa.PrivateKey) - if !ok { - log.Fatal("Not an RSA private key") + log.Fatalf("%v", err) } + rsaPrivateKey = parsedKey log.Println("RSA private key loaded successfully for JWT signing") } diff --git a/services/access_log_test.go b/services/access_log_test.go index 3c61d05..b3a245c 100644 --- a/services/access_log_test.go +++ b/services/access_log_test.go @@ -30,8 +30,8 @@ func TestInsertAccessLogLogin(t *testing.T) { Time: currentTime, } - mock.ExpectExec(`INSERT INTO access_log \( user_id, participant_id, activity_type, ip_address, field_updated, time\) VALUES \(\?, \?, \?, \?, \?, \?\)`). - WithArgs(&userID, &participantID, activityType, ipAddress, &fieldData, currentTime). + mock.ExpectExec(`INSERT INTO access_log \( user_id, activity_type, ip_address, field_updated, time\) VALUES \(\?, \?, \?, \?, \?\)`). + WithArgs(&userID, activityType, ipAddress, &fieldData, currentTime). WillReturnResult(sqlmock.NewResult(1, 1)) err := InsertAccessLogLogin(accessLog) @@ -62,8 +62,8 @@ func TestInsertAccessLogLoginNullFields(t *testing.T) { Time: currentTime, } - mock.ExpectExec(`INSERT INTO access_log \( user_id, participant_id, activity_type, ip_address, field_updated, time\) VALUES \(\?, \?, \?, \?, \?, \?\)`). - WithArgs(nil, nil, activityType, ipAddress, nil, currentTime). + mock.ExpectExec(`INSERT INTO access_log \( user_id, activity_type, ip_address, field_updated, time\) VALUES \(\?, \?, \?, \?, \?\)`). + WithArgs(nil, activityType, ipAddress, nil, currentTime). WillReturnResult(sqlmock.NewResult(1, 1)) err := InsertAccessLogLogin(accessLog) diff --git a/services/users.go b/services/users.go index f5b1387..e84d436 100644 --- a/services/users.go +++ b/services/users.go @@ -19,7 +19,7 @@ func GetUserID(email string) (string, error) { func CheckEmailInDB(email string) (bool, error) { var exists bool query := `SELECT EXISTS ( - SELECT 1 FROM users WHERE email_address = ? AND is_deleted = 0);` + SELECT 1 FROM users WHERE email_address = ? AND is_deleted = 0)` err := db.DB.QueryRow(query, email).Scan(&exists) if err != nil { return false, err diff --git a/services/users_test.go b/services/users_test.go index c8c73d6..6770aa7 100644 --- a/services/users_test.go +++ b/services/users_test.go @@ -32,14 +32,11 @@ func TestGetUser(t *testing.T) { email := "test@example.com" expectedID := "user123" - expectedFirstName := "John" - expectedLastName := "Doe" - expectedEmail := "test@example.com" - rows := sqlmock.NewRows([]string{"id", "first_name", "last_name", "email_address"}). - AddRow(expectedID, expectedFirstName, expectedLastName, expectedEmail) + rows := sqlmock.NewRows([]string{"user_id"}). + AddRow(expectedID) - mock.ExpectQuery(`SELECT id, first_name, last_name, email_address FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1`). + mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`). WithArgs(email). WillReturnRows(rows) @@ -51,7 +48,7 @@ func TestGetUserNotFound(t *testing.T) { email := "nonexistent@example.com" - mock.ExpectQuery(`SELECT id, first_name, last_name, email_address FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1`). + mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`). WithArgs(email). WillReturnError(sql.ErrNoRows) @@ -73,10 +70,10 @@ func TestGetUserNullNames(t *testing.T) { email := "test@example.com" expectedID := "user456" - rows := sqlmock.NewRows([]string{"id", "first_name", "last_name", "email_address"}). - AddRow(expectedID, nil, nil, email) + rows := sqlmock.NewRows([]string{"user_id"}). + AddRow(expectedID) - mock.ExpectQuery(`SELECT id, first_name, last_name, email_address FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1`). + mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`). WithArgs(email). WillReturnRows(rows) @@ -99,11 +96,10 @@ func TestGetUserID(t *testing.T) { email := "test@example.com" expectedID := "user789" - rows := sqlmock.NewRows([]string{"id"}). + rows := sqlmock.NewRows([]string{"user_id"}). AddRow(expectedID) - // Note: The query has a typo "SELECT id, FROM" but we match it as-is - mock.ExpectQuery(`SELECT id, FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1`). + mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`). WithArgs(email). WillReturnRows(rows) @@ -201,7 +197,7 @@ func TestGetUserIDFromEmail(t *testing.T) { rows := sqlmock.NewRows([]string{"id"}). AddRow(expectedID) - mock.ExpectQuery(`SELECT id FROM \( SELECT id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1`). + mock.ExpectQuery(`SELECT user_id FROM \( SELECT user_id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1;`). WithArgs(email). WillReturnRows(rows) @@ -226,7 +222,7 @@ func TestGetUserIDFromEmailNotFound(t *testing.T) { email := "notfound@example.com" - mock.ExpectQuery(`SELECT id FROM \( SELECT id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1`). + mock.ExpectQuery(`SELECT user_id FROM \( SELECT user_id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1;`). WithArgs(email). WillReturnError(sql.ErrNoRows) @@ -247,7 +243,7 @@ func TestGetUserIDFromEmailDBError(t *testing.T) { email := "error@example.com" - mock.ExpectQuery(`SELECT id FROM \( SELECT id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1`). + mock.ExpectQuery(`SELECT user_id FROM \( SELECT user_id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1;`). WithArgs(email). WillReturnError(sql.ErrConnDone) @@ -282,19 +278,10 @@ func TestGetUserMultipleEmails(t *testing.T) { for _, tc := range testCases { t.Run(tc.email, func(t *testing.T) { - var firstName, lastName interface{} - if tc.hasNames { - firstName = "First" - lastName = "Last" - } else { - firstName = nil - lastName = nil - } + rows := sqlmock.NewRows([]string{"user_id"}). + AddRow(tc.userID) - rows := sqlmock.NewRows([]string{"id", "first_name", "last_name", "email_address"}). - AddRow(tc.userID, firstName, lastName, tc.email) - - mock.ExpectQuery(`SELECT id, first_name, last_name, email_address FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1`). + mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`). WithArgs(tc.email). WillReturnRows(rows)