diff --git a/handlers/access_log.go b/handlers/access_log.go index 5aa01db..b9d89fc 100644 --- a/handlers/access_log.go +++ b/handlers/access_log.go @@ -5,7 +5,6 @@ import ( "authentication/services" "fmt" "net/http" - "net/url" ) func accessLog(w http.ResponseWriter, r *http.Request, user *string, actType int, fieldUpdated interface{}) { @@ -27,7 +26,7 @@ func accessLog(w http.ResponseWriter, r *http.Request, user *string, actType int if err == nil { errMsg = "Perform Action" } - http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape(fmt.Sprintf("Failed to %s", errMsg))), http.StatusSeeOther) + helper.RespondWithError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to %s", errMsg)) return } } diff --git a/handlers/google_auth.go b/handlers/google_auth.go index 815c13e..9953c3c 100644 --- a/handlers/google_auth.go +++ b/handlers/google_auth.go @@ -14,7 +14,6 @@ import ( "log" "net" "net/http" - "net/url" "os" "strings" "time" @@ -28,7 +27,6 @@ import ( var googleOauthConfig oauth2.Config var oauthStateString = generateRandomState() -var DashboardBaseURL string var AuthorizationURL string // init initializes the Google OAuth2 configuration by loading environment variables @@ -75,7 +73,6 @@ func init() { log.Fatal("GOOGLE_CLIENT_SECRET is not set in environment variables") } - DashboardBaseURL = os.Getenv("DASHBOARD_URL") AuthorizationURL = os.Getenv("AUTHORIZATION_URL") } @@ -90,10 +87,10 @@ func generateRandomState() string { } // checkUserAuthorization calls the authorization microservice to verify user permissions -func checkUserAuthorization(userID, accessToken string) (bool, string, error) { +func checkUserAuthorization(userID, accessToken string) (bool, error) { if AuthorizationURL == "" { helper.LogWarn("AUTHORIZATION_URL not configured, skipping authorization check") - return true, "", nil // Allow access if authorization service is not configured + return false, nil // Allow access if authorization service is not configured } // Prepare request to authorization microservice @@ -108,13 +105,13 @@ func checkUserAuthorization(userID, accessToken string) (bool, string, error) { jsonData, err := json.Marshal(reqBody) if err != nil { helper.LogError(err, "Failed to marshal authorization request") - return false, "", err + return false, err } req, err := http.NewRequest("POST", authCheckURL, strings.NewReader(string(jsonData))) if err != nil { helper.LogError(err, "Failed to create authorization request") - return false, "", err + return false, err } log.Print("JSON Data Sent to AuthZ Service: ", string(jsonData)) @@ -125,7 +122,7 @@ func checkUserAuthorization(userID, accessToken string) (bool, string, error) { resp, err := client.Do(req) if err != nil { helper.LogError(err, "Failed to call authorization microservice") - return false, "", err + return false, err } defer resp.Body.Close() @@ -133,7 +130,7 @@ func checkUserAuthorization(userID, accessToken string) (bool, string, error) { bodyBytes, err := io.ReadAll(resp.Body) if err != nil { helper.LogError(err, "Failed to read authorization response body") - return false, "", err + return false, err } log.Printf("AUTHZ RAW RESPONSE Status: %d, Body: %s", resp.StatusCode, string(bodyBytes)) @@ -143,14 +140,14 @@ func checkUserAuthorization(userID, accessToken string) (bool, string, error) { if err := json.Unmarshal(bodyBytes, &authResp); err != nil { helper.LogError(err, "Failed to decode authorization response") log.Printf("Failed to unmarshal response body: %s", string(bodyBytes)) - return false, "", err + return false, err } log.Printf("AUTHZ RESPONSE for user %s: %+v", userID, authResp) helper.LogInfo(fmt.Sprintf("Authorization check for user %s: allowed=%v, redirect=%s, message=%s", userID, authResp.Allowed, authResp.RedirectRoute, authResp.Message)) - return authResp.Allowed, authResp.RedirectRoute, nil + return authResp.Allowed, nil } func GoogleLogin(w http.ResponseWriter, r *http.Request) { @@ -219,24 +216,23 @@ func GoogleCallback(w http.ResponseWriter, r *http.Request) { userInfo, err := FetchGoogleUserInfo(w, r) if err != nil { - http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("")), http.StatusSeeOther) + helper.RespondWithError(w, http.StatusBadGateway, "Failed to fetch user information from Google") return } email := userInfo.Email - profilePicture := userInfo.Picture emailExists, err := checkEmailInDB(email) if err != nil { helper.LogError(err, "Error checking email") - http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Error checking email in database")), http.StatusSeeOther) + helper.RespondWithError(w, http.StatusBadGateway, "Error checking email in database") return } helper.LogError(fmt.Errorf("%v", emailExists), "Email exists in DB") accessToken, refreshToken, err := GenerateTokens(email, userAgent, ipAddress) if err != nil { helper.LogError(err, "Error generating access token") - http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token generation failed")), http.StatusSeeOther) + helper.RespondWithError(w, http.StatusInternalServerError, "Token generation failed") return } @@ -273,80 +269,55 @@ func GoogleCallback(w http.ResponseWriter, r *http.Request) { cookieConfig.Domain, cookieConfig.Secure, cookieConfig.HttpOnly, cookieConfig.SameSite)) if !emailExists { - helper.LogWarn(fmt.Sprintf("Email %s does not exist in the database", email)) - registrationURL := fmt.Sprintf("%s/callback?error=%s&token=%s", DashboardBaseURL, url.QueryEscape("Please register first"), accessToken) - http.Redirect(w, r, registrationURL, http.StatusSeeOther) + helper.RespondWithError(w, http.StatusUnauthorized, "Email not registered. Please contact the administrator.") return } - var firstName string - helper.LogInfo("Fetching first name for email: " + email) helper.LogInfo("Userinfo Email: " + userInfo.Email) - userID, firstNamePtr, lastNamePtr, emailAddressPtr, err := services.GetUser(email) + userID, err := services.GetUserID(email) if err != nil { helper.LogError(err, "Error fetching user") - http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("User not found")), http.StatusSeeOther) + helper.RespondWithError(w, http.StatusBadGateway, "Failed to fetch user information") return } // Dereference pointers to get actual string values - if firstNamePtr != nil { - firstName = *firstNamePtr - } - lastName := "" - if lastNamePtr != nil { - lastName = *lastNamePtr - } - emailAddress := emailAddressPtr helper.LogInfo("Access Token Generated Copy this: " + accessToken) err = helper.LogLoginEventV2(userID, ipAddress) if err != nil { - http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Failed to log login event")), http.StatusSeeOther) + helper.RespondWithError(w, http.StatusBadGateway, "Failed to Log Login Event") return } // Check user authorization via authorization microservice - allowed, redirectRoute, err := checkUserAuthorization(userID, accessToken) + allowed, err := checkUserAuthorization(userID, accessToken) if err != nil { helper.LogError(err, "Authorization check failed") - // Continue with default flow if authorization service is unavailable - helper.LogWarn("Proceeding without authorization check due to error") + helper.RespondWithError(w, http.StatusBadGateway, "Authorization check failed") } if !allowed { helper.LogWarn(fmt.Sprintf("User %s denied access by authorization service", userID)) - http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Access denied: Insufficient permissions")), http.StatusSeeOther) + helper.RespondWithError(w, http.StatusForbidden, "Access denied: Insufficient permissions") return } helper.LogInfo("Copy this access token: " + accessToken) - - // Determine redirect URL based on authorization response - var redirectURL string - if redirectRoute != "" { - // Authorization service provided a specific route - redirectURL = fmt.Sprintf("%s%s?token=%s&user_id=%s&first_name=%s&last_name=%s&email_address=%s&profile_picture=%s", - DashboardBaseURL, redirectRoute, accessToken, userID, firstName, lastName, emailAddress, profilePicture) - helper.LogInfo(fmt.Sprintf("Redirecting user to authorized route: %s", redirectRoute)) - } else { - // Default to dashboard callback - redirectURL = fmt.Sprintf("%s/callback?token=%s&user_id=%s&first_name=%s&last_name=%s&email_address=%s&profile_picture=%s", - DashboardBaseURL, accessToken, userID, firstName, lastName, emailAddress, profilePicture) - helper.LogInfo("Redirecting user to default dashboard") - } - - http.Redirect(w, r, redirectURL, http.StatusSeeOther) + helper.RespondWithJSON(w, http.StatusOK, map[string]string{ + "message": "Authentication successful", + "access_token": accessToken, + }) } func validateState(w http.ResponseWriter, r *http.Request) bool { cookie, err := r.Cookie("oauth_state") if err != nil || r.URL.Query().Get("state") != cookie.Value { helper.LogWarn(errorInvalidState) - http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape(errorInvalidState)), http.StatusSeeOther) + helper.RespondWithError(w, http.StatusUnauthorized, errorInvalidState) return false } helper.LogInfo(fmt.Sprintf("Cookie state: %s, Callback state: %s", cookie.Value, r.URL.Query().Get("state"))) diff --git a/services/users.go b/services/users.go index 79a120c..df845dc 100644 --- a/services/users.go +++ b/services/users.go @@ -5,20 +5,6 @@ import ( "log" ) -func GetUser(email string) (string, *string, *string, string, error) { - log.Print(email) - query := `SELECT user_id, first_name, last_name, email_address FROM users WHERE email_address = ? AND is_deleted = 0 LIMIT 1;` - var id string - var firstName *string - var lastName *string - var emailAddress string - err := db.DB.QueryRow(query, email).Scan(&id, &firstName, &lastName, &emailAddress) - if err != nil { - return "", nil, nil, "", err - } - return id, firstName, lastName, emailAddress, nil -} - func GetUserID(email string) (string, error) { log.Print(email) query := `SELECT user_id, FROM users WHERE email_address = ? AND is_deleted = 0 LIMIT 1;` diff --git a/services/users_test.go b/services/users_test.go index 1897cb0..c8c73d6 100644 --- a/services/users_test.go +++ b/services/users_test.go @@ -43,31 +43,6 @@ func TestGetUser(t *testing.T) { WithArgs(email). WillReturnRows(rows) - id, firstName, lastName, emailAddress, err := GetUser(email) - - if err != nil { - t.Errorf("Expected no error, got: %v", err) - } - - if id != expectedID { - t.Errorf("Expected ID %s, got %s", expectedID, id) - } - - if firstName == nil || *firstName != expectedFirstName { - t.Errorf("Expected first name %s, got %v", expectedFirstName, firstName) - } - - if lastName == nil || *lastName != expectedLastName { - t.Errorf("Expected last name %s, got %v", expectedLastName, lastName) - } - - if emailAddress != expectedEmail { - t.Errorf("Expected email %s, got %s", expectedEmail, emailAddress) - } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("Unfulfilled expectations: %v", err) - } } func TestGetUserNotFound(t *testing.T) { @@ -80,7 +55,7 @@ func TestGetUserNotFound(t *testing.T) { WithArgs(email). WillReturnError(sql.ErrNoRows) - id, firstName, lastName, emailAddress, err := GetUser(email) + id, err := GetUserID(email) if err != sql.ErrNoRows { t.Errorf("Expected sql.ErrNoRows, got: %v", err) @@ -89,18 +64,6 @@ func TestGetUserNotFound(t *testing.T) { if id != "" { t.Errorf("Expected empty ID, got %s", id) } - - if firstName != nil { - t.Errorf("Expected nil firstName, got %v", firstName) - } - - if lastName != nil { - t.Errorf("Expected nil lastName, got %v", lastName) - } - - if emailAddress != "" { - t.Errorf("Expected empty email, got %s", emailAddress) - } } func TestGetUserNullNames(t *testing.T) { @@ -117,7 +80,7 @@ func TestGetUserNullNames(t *testing.T) { WithArgs(email). WillReturnRows(rows) - id, firstName, lastName, emailAddress, err := GetUser(email) + id, err := GetUserID(email) if err != nil { t.Errorf("Expected no error, got: %v", err) @@ -127,17 +90,6 @@ func TestGetUserNullNames(t *testing.T) { t.Errorf("Expected ID %s, got %s", expectedID, id) } - if firstName != nil { - t.Errorf("Expected nil firstName for NULL value, got %v", firstName) - } - - if lastName != nil { - t.Errorf("Expected nil lastName for NULL value, got %v", lastName) - } - - if emailAddress != email { - t.Errorf("Expected email %s, got %s", email, emailAddress) - } } func TestGetUserID(t *testing.T) { @@ -346,7 +298,7 @@ func TestGetUserMultipleEmails(t *testing.T) { WithArgs(tc.email). WillReturnRows(rows) - id, fn, ln, email, err := GetUser(tc.email) + id, err := GetUserID(tc.email) if err != nil { t.Errorf("Expected no error, got: %v", err) @@ -355,20 +307,6 @@ func TestGetUserMultipleEmails(t *testing.T) { if id != tc.userID { t.Errorf("Expected ID %s, got %s", tc.userID, id) } - - if tc.hasNames { - if fn == nil || ln == nil { - t.Error("Expected names to be present") - } - } else { - if fn != nil || ln != nil { - t.Error("Expected names to be nil") - } - } - - if email != tc.email { - t.Errorf("Expected email %s, got %s", tc.email, email) - } }) } }