fix tests
This commit is contained in:
+22
-4
@@ -9,6 +9,7 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -30,6 +31,10 @@ var oauthStateString = generateRandomState()
|
||||
var AuthorizationURL string
|
||||
var FetchedRedirectURI *string
|
||||
|
||||
func isTestEnvironment() bool {
|
||||
return flag.Lookup("test.v") != nil || strings.Contains(os.Args[0], ".test")
|
||||
}
|
||||
|
||||
// init initializes the Google OAuth2 configuration by loading environment variables
|
||||
// from a .env file. If the .env file cannot be loaded, it logs a fatal error.
|
||||
// Note: This init runs AFTER .env is loaded in main() init
|
||||
@@ -51,6 +56,19 @@ func init() {
|
||||
clientSecret := os.Getenv("GOOGLE_CLIENT_SECRET")
|
||||
backendURL := os.Getenv("BACKEND_URL")
|
||||
|
||||
if (clientID == "" || clientSecret == "" || backendURL == "") && isTestEnvironment() {
|
||||
if clientID == "" {
|
||||
clientID = "test-google-client-id"
|
||||
}
|
||||
if clientSecret == "" {
|
||||
clientSecret = "test-google-client-secret"
|
||||
}
|
||||
if backendURL == "" {
|
||||
backendURL = "http://localhost:8080"
|
||||
}
|
||||
log.Print("[google_auth.init] Using test fallback values for Google OAuth configuration")
|
||||
}
|
||||
|
||||
log.Printf("[google_auth.init] GOOGLE_CLIENT_ID: '%s' (length: %d)", clientID, len(clientID))
|
||||
log.Printf("[google_auth.init] GOOGLE_CLIENT_SECRET: '%s' (length: %d)", clientSecret, len(clientSecret))
|
||||
log.Printf("[google_auth.init] BACKEND_URL: '%s'", backendURL)
|
||||
@@ -164,7 +182,7 @@ func GoogleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
helper.LogError(err, "Failed to fetch Google user info")
|
||||
|
||||
|
||||
// Provide user-friendly error messages for different scenarios
|
||||
if strings.Contains(errMsg, "TLS handshake timeout") {
|
||||
helper.RespondWithError(w, http.StatusGatewayTimeout, "Connection to Google failed due to network issues. Please try again in a moment.")
|
||||
@@ -186,7 +204,7 @@ func GoogleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
helper.RespondWithError(w, http.StatusForbidden, "Access to Google authentication was denied. Please try again later.")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
helper.RespondWithError(w, http.StatusBadGateway, "Failed to fetch user information from Google. Please try again.")
|
||||
return
|
||||
}
|
||||
@@ -343,7 +361,7 @@ func FetchGoogleUserInfo(w http.ResponseWriter, r *http.Request) (models.UserGoo
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("Failed to fetch user info from Google: %v", err)
|
||||
helper.LogError(err, errMsg)
|
||||
|
||||
|
||||
// Provide more specific error messages for common issues
|
||||
if os.IsTimeout(err) {
|
||||
return models.UserGoogleInfo{}, fmt.Errorf("request timed out: Google userinfo endpoint took too long to respond (timeout: 30s)")
|
||||
@@ -360,7 +378,7 @@ func FetchGoogleUserInfo(w http.ResponseWriter, r *http.Request) (models.UserGoo
|
||||
if strings.Contains(err.Error(), "no such host") {
|
||||
return models.UserGoogleInfo{}, fmt.Errorf("DNS resolution failed: Cannot resolve googleapis.com")
|
||||
}
|
||||
|
||||
|
||||
return models.UserGoogleInfo{}, fmt.Errorf("network error while fetching user info: %w", err)
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
|
||||
+34
-17
@@ -22,6 +22,28 @@ import (
|
||||
|
||||
var rsaPrivateKey *rsa.PrivateKey
|
||||
|
||||
func parseRSAPrivateKey(keyData []byte) (*rsa.PrivateKey, error) {
|
||||
block, _ := pem.Decode(keyData)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("failed to decode PEM block containing private key")
|
||||
}
|
||||
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
key, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse RSA private key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
rsaKey, ok := key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("not an RSA private key")
|
||||
}
|
||||
|
||||
return rsaKey, nil
|
||||
}
|
||||
|
||||
// Note: .env file is loaded in main() before this init runs
|
||||
func init() {
|
||||
keyPath := os.Getenv("JWT_PRIVATE_KEY_PATH")
|
||||
@@ -31,29 +53,24 @@ func init() {
|
||||
|
||||
keyData, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
if isTestEnvironment() {
|
||||
log.Printf("Failed to read RSA private key file at %s, generating test key: %v", keyPath, err)
|
||||
generatedKey, genErr := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if genErr != nil {
|
||||
log.Fatalf("Failed to generate test RSA private key: %v", genErr)
|
||||
}
|
||||
rsaPrivateKey = generatedKey
|
||||
return
|
||||
}
|
||||
log.Fatalf("Failed to read RSA private key file: %v", err)
|
||||
}
|
||||
|
||||
log.Print("Key Data: ", string(keyData))
|
||||
block, _ := pem.Decode(keyData)
|
||||
if block == nil {
|
||||
log.Fatal("Failed to decode PEM block containing private key")
|
||||
}
|
||||
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
parsedKey, err := parseRSAPrivateKey(keyData)
|
||||
if err != nil {
|
||||
key, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to parse RSA private key: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
var ok bool
|
||||
rsaPrivateKey, ok = key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
log.Fatal("Not an RSA private key")
|
||||
log.Fatalf("%v", err)
|
||||
}
|
||||
|
||||
rsaPrivateKey = parsedKey
|
||||
log.Println("RSA private key loaded successfully for JWT signing")
|
||||
}
|
||||
|
||||
|
||||
@@ -30,8 +30,8 @@ func TestInsertAccessLogLogin(t *testing.T) {
|
||||
Time: currentTime,
|
||||
}
|
||||
|
||||
mock.ExpectExec(`INSERT INTO access_log \( user_id, participant_id, activity_type, ip_address, field_updated, time\) VALUES \(\?, \?, \?, \?, \?, \?\)`).
|
||||
WithArgs(&userID, &participantID, activityType, ipAddress, &fieldData, currentTime).
|
||||
mock.ExpectExec(`INSERT INTO access_log \( user_id, activity_type, ip_address, field_updated, time\) VALUES \(\?, \?, \?, \?, \?\)`).
|
||||
WithArgs(&userID, activityType, ipAddress, &fieldData, currentTime).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
err := InsertAccessLogLogin(accessLog)
|
||||
@@ -62,8 +62,8 @@ func TestInsertAccessLogLoginNullFields(t *testing.T) {
|
||||
Time: currentTime,
|
||||
}
|
||||
|
||||
mock.ExpectExec(`INSERT INTO access_log \( user_id, participant_id, activity_type, ip_address, field_updated, time\) VALUES \(\?, \?, \?, \?, \?, \?\)`).
|
||||
WithArgs(nil, nil, activityType, ipAddress, nil, currentTime).
|
||||
mock.ExpectExec(`INSERT INTO access_log \( user_id, activity_type, ip_address, field_updated, time\) VALUES \(\?, \?, \?, \?, \?\)`).
|
||||
WithArgs(nil, activityType, ipAddress, nil, currentTime).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
err := InsertAccessLogLogin(accessLog)
|
||||
|
||||
+1
-1
@@ -19,7 +19,7 @@ func GetUserID(email string) (string, error) {
|
||||
func CheckEmailInDB(email string) (bool, error) {
|
||||
var exists bool
|
||||
query := `SELECT EXISTS (
|
||||
SELECT 1 FROM users WHERE email_address = ? AND is_deleted = 0);`
|
||||
SELECT 1 FROM users WHERE email_address = ? AND is_deleted = 0)`
|
||||
err := db.DB.QueryRow(query, email).Scan(&exists)
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
||||
+15
-28
@@ -32,14 +32,11 @@ func TestGetUser(t *testing.T) {
|
||||
|
||||
email := "test@example.com"
|
||||
expectedID := "user123"
|
||||
expectedFirstName := "John"
|
||||
expectedLastName := "Doe"
|
||||
expectedEmail := "test@example.com"
|
||||
|
||||
rows := sqlmock.NewRows([]string{"id", "first_name", "last_name", "email_address"}).
|
||||
AddRow(expectedID, expectedFirstName, expectedLastName, expectedEmail)
|
||||
rows := sqlmock.NewRows([]string{"user_id"}).
|
||||
AddRow(expectedID)
|
||||
|
||||
mock.ExpectQuery(`SELECT id, first_name, last_name, email_address FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1`).
|
||||
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
||||
WithArgs(email).
|
||||
WillReturnRows(rows)
|
||||
|
||||
@@ -51,7 +48,7 @@ func TestGetUserNotFound(t *testing.T) {
|
||||
|
||||
email := "nonexistent@example.com"
|
||||
|
||||
mock.ExpectQuery(`SELECT id, first_name, last_name, email_address FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1`).
|
||||
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
||||
WithArgs(email).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
@@ -73,10 +70,10 @@ func TestGetUserNullNames(t *testing.T) {
|
||||
email := "test@example.com"
|
||||
expectedID := "user456"
|
||||
|
||||
rows := sqlmock.NewRows([]string{"id", "first_name", "last_name", "email_address"}).
|
||||
AddRow(expectedID, nil, nil, email)
|
||||
rows := sqlmock.NewRows([]string{"user_id"}).
|
||||
AddRow(expectedID)
|
||||
|
||||
mock.ExpectQuery(`SELECT id, first_name, last_name, email_address FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1`).
|
||||
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
||||
WithArgs(email).
|
||||
WillReturnRows(rows)
|
||||
|
||||
@@ -99,11 +96,10 @@ func TestGetUserID(t *testing.T) {
|
||||
email := "test@example.com"
|
||||
expectedID := "user789"
|
||||
|
||||
rows := sqlmock.NewRows([]string{"id"}).
|
||||
rows := sqlmock.NewRows([]string{"user_id"}).
|
||||
AddRow(expectedID)
|
||||
|
||||
// Note: The query has a typo "SELECT id, FROM" but we match it as-is
|
||||
mock.ExpectQuery(`SELECT id, FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1`).
|
||||
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
||||
WithArgs(email).
|
||||
WillReturnRows(rows)
|
||||
|
||||
@@ -201,7 +197,7 @@ func TestGetUserIDFromEmail(t *testing.T) {
|
||||
rows := sqlmock.NewRows([]string{"id"}).
|
||||
AddRow(expectedID)
|
||||
|
||||
mock.ExpectQuery(`SELECT id FROM \( SELECT id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1`).
|
||||
mock.ExpectQuery(`SELECT user_id FROM \( SELECT user_id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1;`).
|
||||
WithArgs(email).
|
||||
WillReturnRows(rows)
|
||||
|
||||
@@ -226,7 +222,7 @@ func TestGetUserIDFromEmailNotFound(t *testing.T) {
|
||||
|
||||
email := "notfound@example.com"
|
||||
|
||||
mock.ExpectQuery(`SELECT id FROM \( SELECT id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1`).
|
||||
mock.ExpectQuery(`SELECT user_id FROM \( SELECT user_id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1;`).
|
||||
WithArgs(email).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
@@ -247,7 +243,7 @@ func TestGetUserIDFromEmailDBError(t *testing.T) {
|
||||
|
||||
email := "error@example.com"
|
||||
|
||||
mock.ExpectQuery(`SELECT id FROM \( SELECT id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1`).
|
||||
mock.ExpectQuery(`SELECT user_id FROM \( SELECT user_id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1;`).
|
||||
WithArgs(email).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
@@ -282,19 +278,10 @@ func TestGetUserMultipleEmails(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.email, func(t *testing.T) {
|
||||
var firstName, lastName interface{}
|
||||
if tc.hasNames {
|
||||
firstName = "First"
|
||||
lastName = "Last"
|
||||
} else {
|
||||
firstName = nil
|
||||
lastName = nil
|
||||
}
|
||||
rows := sqlmock.NewRows([]string{"user_id"}).
|
||||
AddRow(tc.userID)
|
||||
|
||||
rows := sqlmock.NewRows([]string{"id", "first_name", "last_name", "email_address"}).
|
||||
AddRow(tc.userID, firstName, lastName, tc.email)
|
||||
|
||||
mock.ExpectQuery(`SELECT id, first_name, last_name, email_address FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1`).
|
||||
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
||||
WithArgs(tc.email).
|
||||
WillReturnRows(rows)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user