Merge branch 'rsa' into 'main'
Rsa See merge request psa/uess/authn!1
This commit is contained in:
+82
-22
@@ -603,42 +603,47 @@ func checkEmailInDB(email string) (bool, error) {
|
||||
|
||||
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if !isValidAuthHeader(authHeader) {
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, "Authorization header missing or invalid")
|
||||
return
|
||||
}
|
||||
clearRefreshTokenCookie(w)
|
||||
clearCSRFCookie(w)
|
||||
|
||||
tokenString := strings.TrimSpace(strings.TrimPrefix(authHeader, bearerPrefix))
|
||||
if tokenString == "" {
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, "Token is missing or empty")
|
||||
return
|
||||
}
|
||||
if isValidAuthHeader(authHeader) {
|
||||
tokenString := strings.TrimSpace(strings.TrimPrefix(authHeader, bearerPrefix))
|
||||
if tokenString != "" {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &models.AccessToken{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
if rsaPrivateKey == nil {
|
||||
return nil, errors.New("RSA private key is not initialized")
|
||||
}
|
||||
return &rsaPrivateKey.PublicKey, nil
|
||||
})
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, &models.AccessToken{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(os.Getenv("JWT_SECRET_KEY")), nil
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
if claims, ok := token.Claims.(*models.AccessToken); ok {
|
||||
userID, err := services.GetUserIDFromEmail(claims.Email)
|
||||
if err == nil {
|
||||
if err := RevokeAllUserSessions(userID); err != nil {
|
||||
helper.LogError(err, "Failed to revoke user sessions during logout")
|
||||
if claims, ok := token.Claims.(*models.AccessToken); ok {
|
||||
userID, err := services.GetUserIDFromEmail(claims.Email)
|
||||
if err == nil {
|
||||
if err := RevokeAllUserSessions(userID); err != nil {
|
||||
helper.LogError(err, "Failed to revoke user sessions during logout")
|
||||
}
|
||||
} else {
|
||||
helper.LogError(err, "Failed to get user ID during logout")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
helper.LogError(err, "Failed to get user ID during logout")
|
||||
helper.LogError(err, "Failed to parse JWT token during logout")
|
||||
}
|
||||
} else {
|
||||
helper.LogWarn("Authorization header contains empty bearer token during logout")
|
||||
}
|
||||
} else {
|
||||
helper.LogError(err, "Failed to parse JWT token during logout")
|
||||
helper.LogWarn("Authorization header missing or invalid during logout; proceeding with cookie clear only")
|
||||
}
|
||||
|
||||
if err := accessLog(r, nil, 18, nil); err != nil {
|
||||
helper.LogError(err, "Failed to write access log during logout")
|
||||
}
|
||||
|
||||
clearRefreshTokenCookie(w)
|
||||
|
||||
response := map[string]interface{}{
|
||||
"message": "Successfully logged out",
|
||||
"action": "clear_session_storage",
|
||||
@@ -703,3 +708,58 @@ func clearRefreshTokenCookie(w http.ResponseWriter) {
|
||||
|
||||
helper.LogInfo("Refresh token cookie clearing commands sent to browser")
|
||||
}
|
||||
|
||||
func clearCSRFCookie(w http.ResponseWriter) {
|
||||
helper.LogInfo("Clearing csrf_token cookie...")
|
||||
|
||||
isSecure := strings.HasPrefix(os.Getenv("BACKEND_URL"), HTTPS)
|
||||
|
||||
// Match middleware cookie characteristics first (host-only, SameSiteStrict)
|
||||
primaryCookie := &http.Cookie{
|
||||
Name: "csrf_token",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
Expires: time.Unix(0, 0),
|
||||
MaxAge: -1,
|
||||
}
|
||||
http.SetCookie(w, primaryCookie)
|
||||
helper.LogInfo(fmt.Sprintf("CSRF cookie clear #1 sent: Name=%s, Domain=%s, Secure=%v, SameSite=%v",
|
||||
primaryCookie.Name, primaryCookie.Domain, primaryCookie.Secure, primaryCookie.SameSite))
|
||||
|
||||
// Fallback for local/dev browser behavior where secure or samesite attributes differ
|
||||
fallbackCookie := &http.Cookie{
|
||||
Name: "csrf_token",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: isSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Expires: time.Unix(0, 0),
|
||||
MaxAge: -1,
|
||||
}
|
||||
http.SetCookie(w, fallbackCookie)
|
||||
helper.LogInfo(fmt.Sprintf("CSRF cookie clear #2 sent: Name=%s, Domain=%s, Secure=%v, SameSite=%v",
|
||||
fallbackCookie.Name, fallbackCookie.Domain, fallbackCookie.Secure, fallbackCookie.SameSite))
|
||||
|
||||
if !isSecure {
|
||||
localhostCookie := &http.Cookie{
|
||||
Name: "csrf_token",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Domain: "localhost",
|
||||
HttpOnly: true,
|
||||
Secure: false,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Expires: time.Unix(0, 0),
|
||||
MaxAge: -1,
|
||||
}
|
||||
http.SetCookie(w, localhostCookie)
|
||||
helper.LogInfo(fmt.Sprintf("CSRF cookie clear #3 sent: Name=%s, Domain=%s, Secure=%v, SameSite=%v",
|
||||
localhostCookie.Name, localhostCookie.Domain, localhostCookie.Secure, localhostCookie.SameSite))
|
||||
}
|
||||
|
||||
helper.LogInfo("CSRF token cookie clearing commands sent to browser")
|
||||
}
|
||||
|
||||
+17
-6
@@ -183,14 +183,25 @@ func generateAccessToken(email, sessionID, userID string, roleID []int) (string,
|
||||
AccessTokenExpiration = "45"
|
||||
}
|
||||
|
||||
if roleID == nil {
|
||||
roleID = []int{}
|
||||
}
|
||||
|
||||
var primaryRoleID *int
|
||||
if len(roleID) > 0 {
|
||||
value := roleID[0]
|
||||
primaryRoleID = &value
|
||||
}
|
||||
|
||||
expirationTime := time.Now().Add(24 * time.Hour).Unix()
|
||||
|
||||
claims := &models.AccessToken{
|
||||
Email: email,
|
||||
UsersID: userID,
|
||||
RoleID: roleID,
|
||||
SessionID: sessionID,
|
||||
Exp: expirationTime,
|
||||
Email: email,
|
||||
UsersID: userID,
|
||||
RoleID: primaryRoleID,
|
||||
AdditionalRoleID: roleID,
|
||||
SessionID: sessionID,
|
||||
Exp: expirationTime,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Unix(expirationTime, 0)),
|
||||
},
|
||||
@@ -528,7 +539,7 @@ func RevokeSession(sessionID string) error {
|
||||
func RevokeAllUserSessions(userID string) error {
|
||||
ctx := context.Background()
|
||||
|
||||
rows, err := db.DB.Query("SELECT id FROM jwt_sessions WHERE user_id = ? AND is_revoked = false", userID)
|
||||
rows, err := db.DB.Query("SELECT jwt_sessions_id FROM jwt_sessions WHERE user_id = ? AND is_revoked = false", userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errFormatWithContext, errMsgFailedToGetUserSessions, err)
|
||||
}
|
||||
|
||||
+6
-5
@@ -7,11 +7,12 @@ import (
|
||||
)
|
||||
|
||||
type AccessToken struct {
|
||||
Email string `json:"email"`
|
||||
UsersID string `json:"users_id"`
|
||||
RoleID []int `json:"role_id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Exp int64 `json:"exp"`
|
||||
Email string `json:"email"`
|
||||
UsersID string `json:"users_id"`
|
||||
RoleID *int `json:"role_id,omitempty"`
|
||||
AdditionalRoleID []int `json:"additional_role_id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Exp int64 `json:"exp"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
package models
|
||||
|
||||
type User struct {
|
||||
UserID string `json:"user_id"`
|
||||
FirstName string `json:"first_name"`
|
||||
MiddleInitial *string `json:"middle_initial"`
|
||||
LastName string `json:"last_name"`
|
||||
Suffix *string `json:"suffix"`
|
||||
OfficeID *int `json:"office_id"`
|
||||
RoleID *int `json:"role_id,omitempty"`
|
||||
Projects *[]ProjectMetadata `json:"projects,omitempty"`
|
||||
EmailAddress string `json:"email_address"`
|
||||
MIS *int `json:"mis"`
|
||||
CAPI *int `json:"capi"`
|
||||
CAWI *int `json:"cawi"`
|
||||
DPS *int `json:"dps"`
|
||||
Sex *string `json:"sex"`
|
||||
OfficeName *string `json:"office_name"`
|
||||
StatusOfEmployment *string `json:"status_of_employment"`
|
||||
UserType *int `json:"user_type"`
|
||||
HomeAddress *string `json:"home_address"`
|
||||
ContactNumber *string `json:"contact_number"`
|
||||
UpdatedBy *string `json:"updated_by"`
|
||||
}
|
||||
|
||||
type ProjectMetadata struct {
|
||||
ProjectID int `json:"project_id"`
|
||||
Alias *string `json:"alias"`
|
||||
RoleID []int `json:"role_id"`
|
||||
OfficeID *int `json:"office_id"`
|
||||
}
|
||||
+249
-15
@@ -2,7 +2,13 @@ package services
|
||||
|
||||
import (
|
||||
"authentication/db"
|
||||
"authentication/models"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetUserID(email string) (string, error) {
|
||||
@@ -53,27 +59,255 @@ func GetUserIDFromEmail(email string) (string, error) {
|
||||
|
||||
func GetRoleIDsFromEmail(email string) ([]int, error) {
|
||||
log.Print(email)
|
||||
query := `SELECT ur.role_id
|
||||
FROM uess_user_management.user_roles ur
|
||||
JOIN uess_user_management.users u ON ur.users_id = u.users_id
|
||||
|
||||
globalQuery := `SELECT DISTINCT ur.role_id
|
||||
FROM uess_user_management.users u
|
||||
JOIN uess_user_management.user_roles ur ON u.users_id = ur.users_id
|
||||
WHERE u.email_address = ?
|
||||
AND u.is_deleted = 0`
|
||||
rows, err := db.DB.Query(query, email)
|
||||
AND u.is_deleted = 0
|
||||
AND ur.role_id IS NOT NULL`
|
||||
|
||||
globalRows, err := db.DB.Query(globalQuery, email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
defer globalRows.Close()
|
||||
|
||||
var roleIDs []int
|
||||
for rows.Next() {
|
||||
var roleID int
|
||||
if err := rows.Scan(&roleID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roleIDs = append(roleIDs, roleID)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
projectQuery := `SELECT DISTINCT pa.role_id
|
||||
FROM uess_user_management.users u
|
||||
JOIN uess_project_management.project_assignment pa ON u.users_id = pa.users_id
|
||||
WHERE u.email_address = ?
|
||||
AND u.is_deleted = 0
|
||||
AND pa.is_active = 1
|
||||
AND pa.role_id IS NOT NULL`
|
||||
|
||||
projectRows, err := db.DB.Query(projectQuery, email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer projectRows.Close()
|
||||
|
||||
roleIDs := make([]int, 0)
|
||||
seen := make(map[int]struct{})
|
||||
|
||||
for globalRows.Next() {
|
||||
var roleID int
|
||||
if err := globalRows.Scan(&roleID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, exists := seen[roleID]; !exists {
|
||||
seen[roleID] = struct{}{}
|
||||
roleIDs = append(roleIDs, roleID)
|
||||
}
|
||||
}
|
||||
if err := globalRows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for projectRows.Next() {
|
||||
var roleID int
|
||||
if err := projectRows.Scan(&roleID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, exists := seen[roleID]; !exists {
|
||||
seen[roleID] = struct{}{}
|
||||
roleIDs = append(roleIDs, roleID)
|
||||
}
|
||||
}
|
||||
if err := projectRows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sort.Ints(roleIDs)
|
||||
return roleIDs, nil
|
||||
}
|
||||
|
||||
func FetchUserByEmail(email string) (models.User, error) {
|
||||
query := `SELECT u.users_id, u.first_name, u.middle_initial, u.last_name, u.suffix, GROUP_CONCAT(DISTINCT ur.role_id ORDER BY ur.role_id) AS role_ids, u.office_id, u.email_address, MAX(ur.MIS), MAX(ur.CAPI), MAX(ur.CAWI), MAX(ur.DPS),
|
||||
u.sex, u.status_of_employment, u.home_address, u.contact_number, u.user_type
|
||||
FROM uess_user_management.users u
|
||||
LEFT JOIN uess_user_management.user_roles ur
|
||||
ON u.users_id = ur.users_id
|
||||
WHERE u.email_address = ?
|
||||
GROUP BY u.users_id, u.first_name, u.middle_initial, u.last_name, u.suffix, u.office_id, u.email_address, u.sex, u.status_of_employment, u.home_address, u.contact_number, u.user_type`
|
||||
|
||||
var user models.User
|
||||
var roleIDsCSV sql.NullString
|
||||
|
||||
var middleInitial sql.NullString
|
||||
var suffix sql.NullString
|
||||
var officeID sql.NullInt64
|
||||
var mis sql.NullInt64
|
||||
var capi sql.NullInt64
|
||||
var cawi sql.NullInt64
|
||||
var dps sql.NullInt64
|
||||
var sex sql.NullString
|
||||
var statusOfEmployment sql.NullString
|
||||
var homeAddress sql.NullString
|
||||
var contactNumber sql.NullString
|
||||
var userType sql.NullInt64
|
||||
|
||||
err := db.DB.QueryRow(query, email).Scan(
|
||||
&user.UserID,
|
||||
&user.FirstName,
|
||||
&middleInitial,
|
||||
&user.LastName,
|
||||
&suffix,
|
||||
&roleIDsCSV,
|
||||
&officeID,
|
||||
&user.EmailAddress,
|
||||
&mis,
|
||||
&capi,
|
||||
&cawi,
|
||||
&dps,
|
||||
&sex,
|
||||
&statusOfEmployment,
|
||||
&homeAddress,
|
||||
&contactNumber,
|
||||
&userType,
|
||||
)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
if middleInitial.Valid {
|
||||
value := middleInitial.String
|
||||
user.MiddleInitial = &value
|
||||
}
|
||||
if suffix.Valid {
|
||||
value := suffix.String
|
||||
user.Suffix = &value
|
||||
}
|
||||
if officeID.Valid {
|
||||
value := int(officeID.Int64)
|
||||
user.OfficeID = &value
|
||||
}
|
||||
if mis.Valid {
|
||||
value := int(mis.Int64)
|
||||
user.MIS = &value
|
||||
}
|
||||
if capi.Valid {
|
||||
value := int(capi.Int64)
|
||||
user.CAPI = &value
|
||||
}
|
||||
if cawi.Valid {
|
||||
value := int(cawi.Int64)
|
||||
user.CAWI = &value
|
||||
}
|
||||
if dps.Valid {
|
||||
value := int(dps.Int64)
|
||||
user.DPS = &value
|
||||
}
|
||||
if sex.Valid {
|
||||
value := sex.String
|
||||
user.Sex = &value
|
||||
}
|
||||
if statusOfEmployment.Valid {
|
||||
value := statusOfEmployment.String
|
||||
user.StatusOfEmployment = &value
|
||||
}
|
||||
if homeAddress.Valid {
|
||||
value := homeAddress.String
|
||||
user.HomeAddress = &value
|
||||
}
|
||||
if contactNumber.Valid {
|
||||
value := contactNumber.String
|
||||
user.ContactNumber = &value
|
||||
}
|
||||
if userType.Valid {
|
||||
value := int(userType.Int64)
|
||||
user.UserType = &value
|
||||
}
|
||||
|
||||
baseRoleIDs, parseErr := parseRoleIDsCSV(roleIDsCSV.String)
|
||||
if parseErr != nil {
|
||||
return user, parseErr
|
||||
}
|
||||
if len(baseRoleIDs) > 0 {
|
||||
primaryRoleID := baseRoleIDs[0]
|
||||
user.RoleID = &primaryRoleID
|
||||
}
|
||||
|
||||
projectsQuery := `SELECT pa.project_id, p.alias, GROUP_CONCAT(DISTINCT pa.role_id ORDER BY pa.role_id) AS role_ids, u.office_id
|
||||
FROM uess_user_management.users u
|
||||
LEFT JOIN uess_project_management.project_assignment pa
|
||||
ON u.users_id = pa.users_id AND pa.is_active = 1
|
||||
LEFT JOIN uess_project_management.project p
|
||||
ON pa.project_id = p.project_id
|
||||
WHERE u.email_address = ? AND pa.project_id IS NOT NULL
|
||||
GROUP BY pa.project_id, p.alias, u.office_id`
|
||||
|
||||
rows, err := db.DB.Query(projectsQuery, email)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
projects := make([]models.ProjectMetadata, 0)
|
||||
for rows.Next() {
|
||||
var project models.ProjectMetadata
|
||||
var projectAlias sql.NullString
|
||||
var projectRoleIDsCSV sql.NullString
|
||||
var projectOfficeID sql.NullInt64
|
||||
|
||||
if scanErr := rows.Scan(&project.ProjectID, &projectAlias, &projectRoleIDsCSV, &projectOfficeID); scanErr != nil {
|
||||
return user, scanErr
|
||||
}
|
||||
|
||||
if projectAlias.Valid {
|
||||
alias := projectAlias.String
|
||||
project.Alias = &alias
|
||||
}
|
||||
|
||||
if projectOfficeID.Valid {
|
||||
office := int(projectOfficeID.Int64)
|
||||
project.OfficeID = &office
|
||||
}
|
||||
|
||||
roleIDs, parseErr := parseRoleIDsCSV(projectRoleIDsCSV.String)
|
||||
if parseErr != nil {
|
||||
return user, parseErr
|
||||
}
|
||||
project.RoleID = roleIDs
|
||||
projects = append(projects, project)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
if len(projects) > 0 {
|
||||
user.Projects = &projects
|
||||
if user.RoleID == nil && len(projects[0].RoleID) > 0 {
|
||||
primaryRoleID := projects[0].RoleID[0]
|
||||
user.RoleID = &primaryRoleID
|
||||
}
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func parseRoleIDsCSV(roleIDsCSV string) ([]int, error) {
|
||||
trimmed := strings.TrimSpace(roleIDsCSV)
|
||||
if trimmed == "" {
|
||||
return make([]int, 0), nil
|
||||
}
|
||||
|
||||
parts := strings.Split(trimmed, ",")
|
||||
roleIDs := make([]int, 0, len(parts))
|
||||
|
||||
for _, part := range parts {
|
||||
value := strings.TrimSpace(part)
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parsed, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid role id %q: %w", value, err)
|
||||
}
|
||||
roleIDs = append(roleIDs, parsed)
|
||||
}
|
||||
|
||||
return roleIDs, nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package services
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"authentication/db"
|
||||
@@ -330,3 +331,164 @@ func TestCheckEmailInDBVariousEmails(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRoleIDsFromEmail(t *testing.T) {
|
||||
mock, cleanup := setupMockDB(t)
|
||||
defer cleanup()
|
||||
|
||||
email := "roles@example.com"
|
||||
expectedRoleIDs := []int{2, 4, 8}
|
||||
|
||||
globalRows := sqlmock.NewRows([]string{"role_id"}).
|
||||
AddRow(2).
|
||||
AddRow(8)
|
||||
|
||||
projectRows := sqlmock.NewRows([]string{"role_id"}).
|
||||
AddRow(4).
|
||||
AddRow(8).
|
||||
AddRow(2)
|
||||
|
||||
mock.ExpectQuery(`SELECT DISTINCT ur\.role_id\s+FROM uess_user_management\.users u\s+JOIN uess_user_management\.user_roles ur ON u\.users_id = ur\.users_id\s+WHERE u\.email_address = \?\s+AND u\.is_deleted = 0\s+AND ur\.role_id IS NOT NULL`).
|
||||
WithArgs(email).
|
||||
WillReturnRows(globalRows)
|
||||
|
||||
mock.ExpectQuery(`SELECT DISTINCT pa\.role_id\s+FROM uess_user_management\.users u\s+JOIN uess_project_management\.project_assignment pa ON u\.users_id = pa\.users_id\s+WHERE u\.email_address = \?\s+AND u\.is_deleted = 0\s+AND pa\.is_active = 1\s+AND pa\.role_id IS NOT NULL`).
|
||||
WithArgs(email).
|
||||
WillReturnRows(projectRows)
|
||||
|
||||
roleIDs, err := GetRoleIDsFromEmail(email)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(roleIDs, expectedRoleIDs) {
|
||||
t.Errorf("Expected role IDs %v, got %v", expectedRoleIDs, roleIDs)
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("Unfulfilled expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRoleIDsFromEmailQueryError(t *testing.T) {
|
||||
mock, cleanup := setupMockDB(t)
|
||||
defer cleanup()
|
||||
|
||||
email := "roles-error@example.com"
|
||||
|
||||
mock.ExpectQuery(`SELECT DISTINCT ur\.role_id\s+FROM uess_user_management\.users u\s+JOIN uess_user_management\.user_roles ur ON u\.users_id = ur\.users_id\s+WHERE u\.email_address = \?\s+AND u\.is_deleted = 0\s+AND ur\.role_id IS NOT NULL`).
|
||||
WithArgs(email).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
roleIDs, err := GetRoleIDsFromEmail(email)
|
||||
if err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
|
||||
if roleIDs != nil {
|
||||
t.Errorf("Expected nil role IDs on error, got %v", roleIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRoleIDsFromEmailNoRowsReturnsEmptySlice(t *testing.T) {
|
||||
mock, cleanup := setupMockDB(t)
|
||||
defer cleanup()
|
||||
|
||||
email := "no-roles@example.com"
|
||||
|
||||
globalRows := sqlmock.NewRows([]string{"role_id"})
|
||||
projectRows := sqlmock.NewRows([]string{"role_id"})
|
||||
|
||||
mock.ExpectQuery(`SELECT DISTINCT ur\.role_id\s+FROM uess_user_management\.users u\s+JOIN uess_user_management\.user_roles ur ON u\.users_id = ur\.users_id\s+WHERE u\.email_address = \?\s+AND u\.is_deleted = 0\s+AND ur\.role_id IS NOT NULL`).
|
||||
WithArgs(email).
|
||||
WillReturnRows(globalRows)
|
||||
|
||||
mock.ExpectQuery(`SELECT DISTINCT pa\.role_id\s+FROM uess_user_management\.users u\s+JOIN uess_project_management\.project_assignment pa ON u\.users_id = pa\.users_id\s+WHERE u\.email_address = \?\s+AND u\.is_deleted = 0\s+AND pa\.is_active = 1\s+AND pa\.role_id IS NOT NULL`).
|
||||
WithArgs(email).
|
||||
WillReturnRows(projectRows)
|
||||
|
||||
roleIDs, err := GetRoleIDsFromEmail(email)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if roleIDs == nil {
|
||||
t.Error("Expected empty slice, got nil")
|
||||
}
|
||||
|
||||
if len(roleIDs) != 0 {
|
||||
t.Errorf("Expected empty slice, got %v", roleIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRoleIDsCSV(t *testing.T) {
|
||||
roleIDs, err := parseRoleIDsCSV("1, 12,3")
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
expected := []int{1, 12, 3}
|
||||
if !reflect.DeepEqual(roleIDs, expected) {
|
||||
t.Fatalf("Expected %v, got %v", expected, roleIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRoleIDsCSVEmpty(t *testing.T) {
|
||||
roleIDs, err := parseRoleIDsCSV("")
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if roleIDs == nil {
|
||||
t.Fatal("Expected empty slice, got nil")
|
||||
}
|
||||
|
||||
if len(roleIDs) != 0 {
|
||||
t.Fatalf("Expected empty slice, got %v", roleIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchUserByEmail(t *testing.T) {
|
||||
mock, cleanup := setupMockDB(t)
|
||||
defer cleanup()
|
||||
|
||||
email := "d.israel.psa@gmail.com"
|
||||
|
||||
userRows := sqlmock.NewRows([]string{
|
||||
"users_id", "first_name", "middle_initial", "last_name", "suffix", "role_ids", "office_id", "email_address", "MIS", "CAPI", "CAWI", "DPS", "sex", "status_of_employment", "home_address", "contact_number", "user_type",
|
||||
}).AddRow(
|
||||
"U0000000001", "AAAAAAAA", "A", "Israel", "", "1", 103, email, 1, 1, 1, 1, nil, "COSW", "Quezon City", "09171234567", nil,
|
||||
)
|
||||
|
||||
mock.ExpectQuery(`SELECT u\.users_id, u\.first_name, u\.middle_initial, u\.last_name, u\.suffix, GROUP_CONCAT\(DISTINCT ur\.role_id ORDER BY ur\.role_id\) AS role_ids, u\.office_id, u\.email_address, MAX\(ur\.MIS\), MAX\(ur\.CAPI\), MAX\(ur\.CAWI\), MAX\(ur\.DPS\),\s+u\.sex, u\.status_of_employment, u\.home_address, u\.contact_number, u\.user_type\s+FROM uess_user_management\.users u\s+LEFT JOIN uess_user_management\.user_roles ur\s+ON u\.users_id = ur\.users_id\s+WHERE u\.email_address = \?\s+GROUP BY u\.users_id, u\.first_name, u\.middle_initial, u\.last_name, u\.suffix, u\.office_id, u\.email_address, u\.sex, u\.status_of_employment, u\.home_address, u\.contact_number, u\.user_type`).
|
||||
WithArgs(email).
|
||||
WillReturnRows(userRows)
|
||||
|
||||
projectRows := sqlmock.NewRows([]string{"project_id", "alias", "role_ids", "office_id"}).
|
||||
AddRow(1, "TTT", "1,12", 103)
|
||||
|
||||
mock.ExpectQuery(`SELECT pa\.project_id, p\.alias, GROUP_CONCAT\(DISTINCT pa\.role_id ORDER BY pa\.role_id\) AS role_ids, u\.office_id\s+FROM uess_user_management\.users u\s+LEFT JOIN uess_project_management\.project_assignment pa\s+ON u\.users_id = pa\.users_id AND pa\.is_active = 1\s+LEFT JOIN uess_project_management\.project p\s+ON pa\.project_id = p\.project_id\s+WHERE u\.email_address = \? AND pa\.project_id IS NOT NULL\s+GROUP BY pa\.project_id, p\.alias, u\.office_id`).
|
||||
WithArgs(email).
|
||||
WillReturnRows(projectRows)
|
||||
|
||||
user, err := FetchUserByEmail(email)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if user.RoleID == nil || *user.RoleID != 1 {
|
||||
t.Fatalf("Expected RoleID=1, got %v", user.RoleID)
|
||||
}
|
||||
|
||||
if user.Projects == nil || len(*user.Projects) != 1 {
|
||||
t.Fatalf("Expected 1 project, got %v", user.Projects)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual((*user.Projects)[0].RoleID, []int{1, 12}) {
|
||||
t.Fatalf("Expected project roles [1 12], got %v", (*user.Projects)[0].RoleID)
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("Unfulfilled expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user