fix tests

This commit is contained in:
2026-01-27 10:58:24 +08:00
parent 8af97e970a
commit ac0ff00880
5 changed files with 76 additions and 54 deletions
+22 -4
View File
@@ -9,6 +9,7 @@ import (
"crypto/rand"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"log"
@@ -30,6 +31,10 @@ var oauthStateString = generateRandomState()
var AuthorizationURL string
var FetchedRedirectURI *string
func isTestEnvironment() bool {
return flag.Lookup("test.v") != nil || strings.Contains(os.Args[0], ".test")
}
// init initializes the Google OAuth2 configuration by loading environment variables
// from a .env file. If the .env file cannot be loaded, it logs a fatal error.
// Note: This init runs AFTER .env is loaded in main() init
@@ -51,6 +56,19 @@ func init() {
clientSecret := os.Getenv("GOOGLE_CLIENT_SECRET")
backendURL := os.Getenv("BACKEND_URL")
if (clientID == "" || clientSecret == "" || backendURL == "") && isTestEnvironment() {
if clientID == "" {
clientID = "test-google-client-id"
}
if clientSecret == "" {
clientSecret = "test-google-client-secret"
}
if backendURL == "" {
backendURL = "http://localhost:8080"
}
log.Print("[google_auth.init] Using test fallback values for Google OAuth configuration")
}
log.Printf("[google_auth.init] GOOGLE_CLIENT_ID: '%s' (length: %d)", clientID, len(clientID))
log.Printf("[google_auth.init] GOOGLE_CLIENT_SECRET: '%s' (length: %d)", clientSecret, len(clientSecret))
log.Printf("[google_auth.init] BACKEND_URL: '%s'", backendURL)
@@ -164,7 +182,7 @@ func GoogleCallback(w http.ResponseWriter, r *http.Request) {
if err != nil {
errMsg := err.Error()
helper.LogError(err, "Failed to fetch Google user info")
// Provide user-friendly error messages for different scenarios
if strings.Contains(errMsg, "TLS handshake timeout") {
helper.RespondWithError(w, http.StatusGatewayTimeout, "Connection to Google failed due to network issues. Please try again in a moment.")
@@ -186,7 +204,7 @@ func GoogleCallback(w http.ResponseWriter, r *http.Request) {
helper.RespondWithError(w, http.StatusForbidden, "Access to Google authentication was denied. Please try again later.")
return
}
helper.RespondWithError(w, http.StatusBadGateway, "Failed to fetch user information from Google. Please try again.")
return
}
@@ -343,7 +361,7 @@ func FetchGoogleUserInfo(w http.ResponseWriter, r *http.Request) (models.UserGoo
if err != nil {
errMsg := fmt.Sprintf("Failed to fetch user info from Google: %v", err)
helper.LogError(err, errMsg)
// Provide more specific error messages for common issues
if os.IsTimeout(err) {
return models.UserGoogleInfo{}, fmt.Errorf("request timed out: Google userinfo endpoint took too long to respond (timeout: 30s)")
@@ -360,7 +378,7 @@ func FetchGoogleUserInfo(w http.ResponseWriter, r *http.Request) (models.UserGoo
if strings.Contains(err.Error(), "no such host") {
return models.UserGoogleInfo{}, fmt.Errorf("DNS resolution failed: Cannot resolve googleapis.com")
}
return models.UserGoogleInfo{}, fmt.Errorf("network error while fetching user info: %w", err)
}
defer func(Body io.ReadCloser) {
+34 -17
View File
@@ -22,6 +22,28 @@ import (
var rsaPrivateKey *rsa.PrivateKey
func parseRSAPrivateKey(keyData []byte) (*rsa.PrivateKey, error) {
block, _ := pem.Decode(keyData)
if block == nil {
return nil, fmt.Errorf("failed to decode PEM block containing private key")
}
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
key, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse RSA private key: %w", err)
}
}
rsaKey, ok := key.(*rsa.PrivateKey)
if !ok {
return nil, fmt.Errorf("not an RSA private key")
}
return rsaKey, nil
}
// Note: .env file is loaded in main() before this init runs
func init() {
keyPath := os.Getenv("JWT_PRIVATE_KEY_PATH")
@@ -31,29 +53,24 @@ func init() {
keyData, err := os.ReadFile(keyPath)
if err != nil {
if isTestEnvironment() {
log.Printf("Failed to read RSA private key file at %s, generating test key: %v", keyPath, err)
generatedKey, genErr := rsa.GenerateKey(rand.Reader, 2048)
if genErr != nil {
log.Fatalf("Failed to generate test RSA private key: %v", genErr)
}
rsaPrivateKey = generatedKey
return
}
log.Fatalf("Failed to read RSA private key file: %v", err)
}
log.Print("Key Data: ", string(keyData))
block, _ := pem.Decode(keyData)
if block == nil {
log.Fatal("Failed to decode PEM block containing private key")
}
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
parsedKey, err := parseRSAPrivateKey(keyData)
if err != nil {
key, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
log.Fatalf("Failed to parse RSA private key: %v", err)
}
}
var ok bool
rsaPrivateKey, ok = key.(*rsa.PrivateKey)
if !ok {
log.Fatal("Not an RSA private key")
log.Fatalf("%v", err)
}
rsaPrivateKey = parsedKey
log.Println("RSA private key loaded successfully for JWT signing")
}
+4 -4
View File
@@ -30,8 +30,8 @@ func TestInsertAccessLogLogin(t *testing.T) {
Time: currentTime,
}
mock.ExpectExec(`INSERT INTO access_log \( user_id, participant_id, activity_type, ip_address, field_updated, time\) VALUES \(\?, \?, \?, \?, \?, \?\)`).
WithArgs(&userID, &participantID, activityType, ipAddress, &fieldData, currentTime).
mock.ExpectExec(`INSERT INTO access_log \( user_id, activity_type, ip_address, field_updated, time\) VALUES \(\?, \?, \?, \?, \?\)`).
WithArgs(&userID, activityType, ipAddress, &fieldData, currentTime).
WillReturnResult(sqlmock.NewResult(1, 1))
err := InsertAccessLogLogin(accessLog)
@@ -62,8 +62,8 @@ func TestInsertAccessLogLoginNullFields(t *testing.T) {
Time: currentTime,
}
mock.ExpectExec(`INSERT INTO access_log \( user_id, participant_id, activity_type, ip_address, field_updated, time\) VALUES \(\?, \?, \?, \?, \?, \?\)`).
WithArgs(nil, nil, activityType, ipAddress, nil, currentTime).
mock.ExpectExec(`INSERT INTO access_log \( user_id, activity_type, ip_address, field_updated, time\) VALUES \(\?, \?, \?, \?, \?\)`).
WithArgs(nil, activityType, ipAddress, nil, currentTime).
WillReturnResult(sqlmock.NewResult(1, 1))
err := InsertAccessLogLogin(accessLog)
+1 -1
View File
@@ -19,7 +19,7 @@ func GetUserID(email string) (string, error) {
func CheckEmailInDB(email string) (bool, error) {
var exists bool
query := `SELECT EXISTS (
SELECT 1 FROM users WHERE email_address = ? AND is_deleted = 0);`
SELECT 1 FROM users WHERE email_address = ? AND is_deleted = 0)`
err := db.DB.QueryRow(query, email).Scan(&exists)
if err != nil {
return false, err
+15 -28
View File
@@ -32,14 +32,11 @@ func TestGetUser(t *testing.T) {
email := "test@example.com"
expectedID := "user123"
expectedFirstName := "John"
expectedLastName := "Doe"
expectedEmail := "test@example.com"
rows := sqlmock.NewRows([]string{"id", "first_name", "last_name", "email_address"}).
AddRow(expectedID, expectedFirstName, expectedLastName, expectedEmail)
rows := sqlmock.NewRows([]string{"user_id"}).
AddRow(expectedID)
mock.ExpectQuery(`SELECT id, first_name, last_name, email_address FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1`).
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
WithArgs(email).
WillReturnRows(rows)
@@ -51,7 +48,7 @@ func TestGetUserNotFound(t *testing.T) {
email := "nonexistent@example.com"
mock.ExpectQuery(`SELECT id, first_name, last_name, email_address FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1`).
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
WithArgs(email).
WillReturnError(sql.ErrNoRows)
@@ -73,10 +70,10 @@ func TestGetUserNullNames(t *testing.T) {
email := "test@example.com"
expectedID := "user456"
rows := sqlmock.NewRows([]string{"id", "first_name", "last_name", "email_address"}).
AddRow(expectedID, nil, nil, email)
rows := sqlmock.NewRows([]string{"user_id"}).
AddRow(expectedID)
mock.ExpectQuery(`SELECT id, first_name, last_name, email_address FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1`).
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
WithArgs(email).
WillReturnRows(rows)
@@ -99,11 +96,10 @@ func TestGetUserID(t *testing.T) {
email := "test@example.com"
expectedID := "user789"
rows := sqlmock.NewRows([]string{"id"}).
rows := sqlmock.NewRows([]string{"user_id"}).
AddRow(expectedID)
// Note: The query has a typo "SELECT id, FROM" but we match it as-is
mock.ExpectQuery(`SELECT id, FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1`).
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
WithArgs(email).
WillReturnRows(rows)
@@ -201,7 +197,7 @@ func TestGetUserIDFromEmail(t *testing.T) {
rows := sqlmock.NewRows([]string{"id"}).
AddRow(expectedID)
mock.ExpectQuery(`SELECT id FROM \( SELECT id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1`).
mock.ExpectQuery(`SELECT user_id FROM \( SELECT user_id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1;`).
WithArgs(email).
WillReturnRows(rows)
@@ -226,7 +222,7 @@ func TestGetUserIDFromEmailNotFound(t *testing.T) {
email := "notfound@example.com"
mock.ExpectQuery(`SELECT id FROM \( SELECT id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1`).
mock.ExpectQuery(`SELECT user_id FROM \( SELECT user_id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1;`).
WithArgs(email).
WillReturnError(sql.ErrNoRows)
@@ -247,7 +243,7 @@ func TestGetUserIDFromEmailDBError(t *testing.T) {
email := "error@example.com"
mock.ExpectQuery(`SELECT id FROM \( SELECT id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1`).
mock.ExpectQuery(`SELECT user_id FROM \( SELECT user_id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1;`).
WithArgs(email).
WillReturnError(sql.ErrConnDone)
@@ -282,19 +278,10 @@ func TestGetUserMultipleEmails(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.email, func(t *testing.T) {
var firstName, lastName interface{}
if tc.hasNames {
firstName = "First"
lastName = "Last"
} else {
firstName = nil
lastName = nil
}
rows := sqlmock.NewRows([]string{"user_id"}).
AddRow(tc.userID)
rows := sqlmock.NewRows([]string{"id", "first_name", "last_name", "email_address"}).
AddRow(tc.userID, firstName, lastName, tc.email)
mock.ExpectQuery(`SELECT id, first_name, last_name, email_address FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1`).
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
WithArgs(tc.email).
WillReturnRows(rows)