init commit
This commit is contained in:
@@ -0,0 +1,62 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"authentication/models"
|
||||
"authentication/services"
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
func LogEvent(id string, user *string, ipAddress string, actType int, fieldUpdate interface{}) error {
|
||||
|
||||
fieldUpdated := new(json.RawMessage)
|
||||
if fieldUpdate != nil {
|
||||
data, err := json.Marshal(fieldUpdate)
|
||||
if err != nil {
|
||||
LogError(err, "Error marshalling field update")
|
||||
return err
|
||||
}
|
||||
fieldUpdated = (*json.RawMessage)(&data)
|
||||
}
|
||||
|
||||
params := models.LogEventParams{
|
||||
ActivityType: actType,
|
||||
IPAddress: ipAddress,
|
||||
FieldUpdated: fieldUpdated,
|
||||
ErrorMessage: ErrorFailedtoLogLoginEvent,
|
||||
}
|
||||
return LogLoginEventParams(params, ipAddress)
|
||||
}
|
||||
|
||||
func LogLoginEventV2(id string, ipAddress string) error {
|
||||
|
||||
params := models.LogEventParams{
|
||||
ActivityType: 17,
|
||||
IPAddress: ipAddress,
|
||||
FieldUpdated: new(json.RawMessage),
|
||||
ErrorMessage: ErrorFailedtoLogLoginEvent,
|
||||
}
|
||||
return LogLoginEventParams(params, ipAddress)
|
||||
}
|
||||
|
||||
func LogLoginEventParams(params models.LogEventParams, ipAddress string) error {
|
||||
location, err := LoadAsiaManilaLocation()
|
||||
if err != nil {
|
||||
LogError(err, "Failed to load Asia/Manila timezone, using UTC+8 offset")
|
||||
}
|
||||
currentTime := time.Now().In(location)
|
||||
accessLog := models.UserAccessLog{
|
||||
UserID: params.UserID,
|
||||
ParticipantID: params.ParticipantID,
|
||||
ActivityType: params.ActivityType,
|
||||
IPAddress: ipAddress,
|
||||
FieldUpdated: params.FieldUpdated,
|
||||
Time: currentTime,
|
||||
}
|
||||
err = services.InsertAccessLogLogin(accessLog)
|
||||
if err != nil {
|
||||
LogError(err, params.ErrorMessage)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,301 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"authentication/models"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLogEvent(t *testing.T) {
|
||||
// Note: This test requires database and Redis connections
|
||||
// In a real test environment, you'd use mocks or test databases
|
||||
// For now, we'll test the structure and basic validation
|
||||
|
||||
t.Skip("Integration test - requires database and Redis")
|
||||
|
||||
userID := "user123"
|
||||
ipAddress := "192.168.1.1"
|
||||
actType := 17
|
||||
fieldUpdate := map[string]string{"field": "value"}
|
||||
|
||||
err := LogEvent("test-id", &userID, ipAddress, actType, fieldUpdate)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogEventNilUser(t *testing.T) {
|
||||
t.Skip("Integration test - requires database and Redis")
|
||||
|
||||
ipAddress := "192.168.1.1"
|
||||
actType := 17
|
||||
|
||||
err := LogEvent("test-id", nil, ipAddress, actType, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error with nil user, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogEventNilFieldUpdate(t *testing.T) {
|
||||
t.Skip("Integration test - requires database and Redis")
|
||||
|
||||
userID := "user456"
|
||||
ipAddress := "10.0.0.1"
|
||||
actType := 5
|
||||
|
||||
err := LogEvent("test-id", &userID, ipAddress, actType, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error with nil field update, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogLoginEventV2(t *testing.T) {
|
||||
t.Skip("Integration test - requires database and Redis")
|
||||
|
||||
err := LogLoginEventV2("user789", "172.16.0.1")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogLoginEventV2EmptyIP(t *testing.T) {
|
||||
t.Skip("Integration test - requires database and Redis")
|
||||
|
||||
err := LogLoginEventV2("user999", "")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error with empty IP, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogLoginEventParams(t *testing.T) {
|
||||
t.Skip("Integration test - requires database and Redis")
|
||||
|
||||
fieldUpdated := new(json.RawMessage)
|
||||
data := []byte(`{"key": "value"}`)
|
||||
fieldUpdated = (*json.RawMessage)(&data)
|
||||
|
||||
params := models.LogEventParams{
|
||||
UserID: stringPtr("user123"),
|
||||
ActivityType: 17,
|
||||
IPAddress: "192.168.1.100",
|
||||
FieldUpdated: fieldUpdated,
|
||||
ErrorMessage: ErrorFailedtoLogLoginEvent,
|
||||
}
|
||||
|
||||
err := LogLoginEventParams(params, "192.168.1.100")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogLoginEventParamsActivityTypes(t *testing.T) {
|
||||
t.Skip("Integration test - requires database and Redis")
|
||||
|
||||
activityTypes := []int{1, 5, 10, 17, 20}
|
||||
|
||||
for _, actType := range activityTypes {
|
||||
params := models.LogEventParams{
|
||||
ActivityType: actType,
|
||||
IPAddress: "192.168.1.1",
|
||||
FieldUpdated: new(json.RawMessage),
|
||||
ErrorMessage: ErrorFailedtoLogLoginEvent,
|
||||
}
|
||||
|
||||
err := LogLoginEventParams(params, "192.168.1.1")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for activity type %d, got: %v", actType, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogEventJSONMarshalling(t *testing.T) {
|
||||
// Test that field updates can be marshalled correctly
|
||||
testCases := []struct {
|
||||
name string
|
||||
fieldUpdate interface{}
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Simple map",
|
||||
fieldUpdate: map[string]string{"field": "value"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Complex object",
|
||||
fieldUpdate: map[string]interface{}{"nested": map[string]string{"key": "value"}},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Array",
|
||||
fieldUpdate: []string{"item1", "item2"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Nil",
|
||||
fieldUpdate: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Unmarshalable (channel)",
|
||||
fieldUpdate: make(chan int),
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var fieldUpdated *json.RawMessage
|
||||
|
||||
if tc.fieldUpdate != nil {
|
||||
data, err := json.Marshal(tc.fieldUpdate)
|
||||
if tc.expectError {
|
||||
if err == nil {
|
||||
t.Error("Expected marshalling error")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected marshalling error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
rawMsg := json.RawMessage(data)
|
||||
fieldUpdated = &rawMsg
|
||||
} else {
|
||||
fieldUpdated = new(json.RawMessage)
|
||||
}
|
||||
|
||||
if fieldUpdated == nil {
|
||||
t.Error("Field updated should not be nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogEventParamsStructValidation(t *testing.T) {
|
||||
// Test LogEventParams struct can be properly constructed
|
||||
params := models.LogEventParams{
|
||||
UserID: stringPtr("user123"),
|
||||
ParticipantID: stringPtr("part456"),
|
||||
ActivityType: 17,
|
||||
IPAddress: "192.168.1.1",
|
||||
FieldUpdated: new(json.RawMessage),
|
||||
ErrorMessage: ErrorFailedtoLogLoginEvent,
|
||||
}
|
||||
|
||||
if params.UserID == nil || *params.UserID != "user123" {
|
||||
t.Error("UserID not set correctly")
|
||||
}
|
||||
|
||||
if params.ParticipantID == nil || *params.ParticipantID != "part456" {
|
||||
t.Error("ParticipantID not set correctly")
|
||||
}
|
||||
|
||||
if params.ActivityType != 17 {
|
||||
t.Errorf("Expected activity type 17, got %d", params.ActivityType)
|
||||
}
|
||||
|
||||
if params.IPAddress != "192.168.1.1" {
|
||||
t.Errorf("Expected IP 192.168.1.1, got %s", params.IPAddress)
|
||||
}
|
||||
|
||||
if params.ErrorMessage != ErrorFailedtoLogLoginEvent {
|
||||
t.Errorf("Expected error message '%s', got '%s'", ErrorFailedtoLogLoginEvent, params.ErrorMessage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserAccessLogStructValidation(t *testing.T) {
|
||||
location, _ := LoadAsiaManilaLocation()
|
||||
now := timeNow().In(location)
|
||||
|
||||
fieldData := json.RawMessage(`{"key": "value"}`)
|
||||
|
||||
log := models.UserAccessLog{
|
||||
UserID: stringPtr("user123"),
|
||||
ParticipantID: stringPtr("part456"),
|
||||
ActivityType: 17,
|
||||
IPAddress: "192.168.1.1",
|
||||
FieldUpdated: &fieldData,
|
||||
Time: now,
|
||||
}
|
||||
|
||||
if log.UserID == nil || *log.UserID != "user123" {
|
||||
t.Error("UserID not set correctly")
|
||||
}
|
||||
|
||||
if log.ActivityType != 17 {
|
||||
t.Errorf("Expected activity type 17, got %d", log.ActivityType)
|
||||
}
|
||||
|
||||
if log.IPAddress != "192.168.1.1" {
|
||||
t.Errorf("Expected IP 192.168.1.1, got %s", log.IPAddress)
|
||||
}
|
||||
|
||||
if log.FieldUpdated == nil {
|
||||
t.Error("FieldUpdated should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogEventIPAddressFormats(t *testing.T) {
|
||||
// Test various IP address formats
|
||||
ipAddresses := []string{
|
||||
"192.168.1.1",
|
||||
"10.0.0.1",
|
||||
"172.16.0.1",
|
||||
"2001:0db8:85a3:0000:0000:8a2e:0370:7334", // IPv6
|
||||
"::1", // IPv6 loopback
|
||||
"127.0.0.1", // localhost
|
||||
}
|
||||
|
||||
for _, ip := range ipAddresses {
|
||||
t.Run(ip, func(t *testing.T) {
|
||||
// Just test that the IP format is accepted
|
||||
params := models.LogEventParams{
|
||||
ActivityType: 17,
|
||||
IPAddress: ip,
|
||||
FieldUpdated: new(json.RawMessage),
|
||||
ErrorMessage: ErrorFailedtoLogLoginEvent,
|
||||
}
|
||||
|
||||
if params.IPAddress != ip {
|
||||
t.Errorf("Expected IP %s, got %s", ip, params.IPAddress)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func timeNow() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
func TestLogLoginEventV2ActivityType(t *testing.T) {
|
||||
// Verify that LogLoginEventV2 uses activity type 17
|
||||
// This is verified by checking the function implementation
|
||||
|
||||
expectedActivityType := 17
|
||||
|
||||
// The function hardcodes activity type 17
|
||||
// We can't directly test this without integration tests,
|
||||
// but we can document the expected behavior
|
||||
|
||||
if expectedActivityType != 17 {
|
||||
t.Errorf("LogLoginEventV2 should use activity type 17")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorFailedtoLogLoginEventConstant(t *testing.T) {
|
||||
if ErrorFailedtoLogLoginEvent == "" {
|
||||
t.Error("ErrorFailedtoLogLoginEvent constant should not be empty")
|
||||
}
|
||||
|
||||
expectedMsg := "Failed to log login event"
|
||||
if ErrorFailedtoLogLoginEvent != expectedMsg {
|
||||
t.Errorf("Expected error message '%s', got '%s'", expectedMsg, ErrorFailedtoLogLoginEvent)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package helper
|
||||
|
||||
const (
|
||||
ContentTypeHeader = "Content-Type"
|
||||
ApplicationJSON = "application/json"
|
||||
ErrorLabel = "error"
|
||||
MessageLabel = "message"
|
||||
ErrorEncodingResponse = "Error encoding response"
|
||||
ErrorFailedtoLogLoginEvent = "Failed to log login event"
|
||||
)
|
||||
@@ -0,0 +1,86 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
)
|
||||
|
||||
// LogInfo logs an info message to both the local log and Sentry based on the environment.
|
||||
func LogInfo(message string) {
|
||||
goEnv := os.Getenv("GO_ENV")
|
||||
|
||||
if goEnv == "" {
|
||||
log.Fatal("GO_ENV is not set in error_logging LogInfo. Please set the GO_ENV environment variable.")
|
||||
}
|
||||
|
||||
if goEnv == "development" || goEnv == "debug" {
|
||||
log.Println("INFO:", message)
|
||||
}
|
||||
if goEnv == "production" || goEnv == "canary" {
|
||||
log.Println("INFO:", message)
|
||||
}
|
||||
}
|
||||
|
||||
// LogWarn logs a warning message to both the local log and Sentry based on the environment.
|
||||
func LogWarn(message string) {
|
||||
goEnv := os.Getenv("GO_ENV")
|
||||
|
||||
if goEnv == "" {
|
||||
log.Fatal("GO_ENV is not set in error_logging LogWarn. Please set the GO_ENV environment variable.")
|
||||
}
|
||||
if goEnv == "production" || goEnv == "canary" {
|
||||
sentry.CaptureMessage("WARNING: " + message)
|
||||
} else if goEnv == "development" || goEnv == "debug" {
|
||||
log.Println("WARNING:", message)
|
||||
}
|
||||
}
|
||||
|
||||
// LogError logs an error message to both the local log and Sentry based on the environment.
|
||||
func LogError(err error, message string) {
|
||||
goEnv := os.Getenv("GO_ENV")
|
||||
|
||||
if goEnv == "" {
|
||||
log.Fatal("GO_ENV is not set in error_logging LogError. Please set the GO_ENV environment variable.")
|
||||
}
|
||||
|
||||
if goEnv == "production" || goEnv == "canary" {
|
||||
if err != nil {
|
||||
sentry.CaptureException(err)
|
||||
} else {
|
||||
sentry.CaptureMessage("ERROR: " + message)
|
||||
}
|
||||
log.Printf("ERROR: %s: %v", message, err)
|
||||
} else if goEnv == "development" || goEnv == "debug" {
|
||||
if err != nil {
|
||||
log.Printf("ERROR: %s: %v", message, err)
|
||||
} else {
|
||||
log.Println("ERROR:", message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LogFatal logs a fatal error message to both the local log and Sentry based on the environment and then exits the application.
|
||||
func LogFatal(err error, message string) {
|
||||
goEnv := os.Getenv("GO_ENV")
|
||||
|
||||
if goEnv == "" {
|
||||
log.Fatal("GO_ENV is not set in error_logging LogFatal. Please set the GO_ENV environment variable.")
|
||||
}
|
||||
|
||||
if goEnv == "production" || goEnv == "canary" {
|
||||
if err != nil {
|
||||
sentry.CaptureException(err)
|
||||
} else {
|
||||
sentry.CaptureMessage("FATAL: " + message)
|
||||
}
|
||||
log.Fatalf("FATAL: %s: %v", message, err)
|
||||
} else if goEnv == "development" || goEnv == "debug" {
|
||||
if err != nil {
|
||||
log.Fatalf("FATAL: %s: %v", message, err)
|
||||
} else {
|
||||
log.Fatalf("FATAL: %s", message)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,397 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func captureLogOutput(f func()) string {
|
||||
var buf bytes.Buffer
|
||||
log.SetOutput(&buf)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
f()
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func TestLogInfo_Development(t *testing.T) {
|
||||
originalEnv := os.Getenv("GO_ENV")
|
||||
defer os.Setenv("GO_ENV", originalEnv)
|
||||
|
||||
os.Setenv("GO_ENV", "development")
|
||||
|
||||
output := captureLogOutput(func() {
|
||||
LogInfo("Test info message")
|
||||
})
|
||||
|
||||
if !strings.Contains(output, "INFO:") {
|
||||
t.Error("Expected INFO prefix in log output")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Test info message") {
|
||||
t.Error("Expected message to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogInfo_Debug(t *testing.T) {
|
||||
originalEnv := os.Getenv("GO_ENV")
|
||||
defer os.Setenv("GO_ENV", originalEnv)
|
||||
|
||||
os.Setenv("GO_ENV", "debug")
|
||||
|
||||
output := captureLogOutput(func() {
|
||||
LogInfo("Debug info message")
|
||||
})
|
||||
|
||||
if !strings.Contains(output, "INFO:") {
|
||||
t.Error("Expected INFO prefix in log output")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Debug info message") {
|
||||
t.Error("Expected message to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogInfo_Production(t *testing.T) {
|
||||
originalEnv := os.Getenv("GO_ENV")
|
||||
defer os.Setenv("GO_ENV", originalEnv)
|
||||
|
||||
os.Setenv("GO_ENV", "production")
|
||||
|
||||
output := captureLogOutput(func() {
|
||||
LogInfo("Production info message")
|
||||
})
|
||||
|
||||
if !strings.Contains(output, "INFO:") {
|
||||
t.Error("Expected INFO prefix in log output")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Production info message") {
|
||||
t.Error("Expected message to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogInfo_NoEnv(t *testing.T) {
|
||||
originalEnv := os.Getenv("GO_ENV")
|
||||
defer os.Setenv("GO_ENV", originalEnv)
|
||||
|
||||
os.Unsetenv("GO_ENV")
|
||||
|
||||
// LogInfo calls log.Fatal if GO_ENV not set, which exits the process
|
||||
// We can't easily test this without subprocess, so we'll skip this specific case
|
||||
// or test that it panics/exits
|
||||
t.Skip("Cannot test log.Fatal without subprocess")
|
||||
}
|
||||
|
||||
func TestLogWarn_Development(t *testing.T) {
|
||||
originalEnv := os.Getenv("GO_ENV")
|
||||
defer os.Setenv("GO_ENV", originalEnv)
|
||||
|
||||
os.Setenv("GO_ENV", "development")
|
||||
|
||||
output := captureLogOutput(func() {
|
||||
LogWarn("Test warning")
|
||||
})
|
||||
|
||||
if !strings.Contains(output, "WARNING:") {
|
||||
t.Error("Expected WARNING prefix in log output")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Test warning") {
|
||||
t.Error("Expected warning message to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogWarn_Debug(t *testing.T) {
|
||||
originalEnv := os.Getenv("GO_ENV")
|
||||
defer os.Setenv("GO_ENV", originalEnv)
|
||||
|
||||
os.Setenv("GO_ENV", "debug")
|
||||
|
||||
output := captureLogOutput(func() {
|
||||
LogWarn("Debug warning")
|
||||
})
|
||||
|
||||
if !strings.Contains(output, "WARNING:") {
|
||||
t.Error("Expected WARNING prefix in log output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogError_Development(t *testing.T) {
|
||||
originalEnv := os.Getenv("GO_ENV")
|
||||
defer os.Setenv("GO_ENV", originalEnv)
|
||||
|
||||
os.Setenv("GO_ENV", "development")
|
||||
|
||||
testErr := &testError{"test error"}
|
||||
|
||||
output := captureLogOutput(func() {
|
||||
LogError(testErr, "Error occurred")
|
||||
})
|
||||
|
||||
if !strings.Contains(output, "ERROR:") {
|
||||
t.Error("Expected ERROR prefix in log output")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Error occurred") {
|
||||
t.Error("Expected error message to be logged")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "test error") {
|
||||
t.Error("Expected error details to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogError_NilError(t *testing.T) {
|
||||
originalEnv := os.Getenv("GO_ENV")
|
||||
defer os.Setenv("GO_ENV", originalEnv)
|
||||
|
||||
os.Setenv("GO_ENV", "development")
|
||||
|
||||
output := captureLogOutput(func() {
|
||||
LogError(nil, "Error message only")
|
||||
})
|
||||
|
||||
if !strings.Contains(output, "ERROR:") {
|
||||
t.Error("Expected ERROR prefix")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Error message only") {
|
||||
t.Error("Expected message to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogError_Debug(t *testing.T) {
|
||||
originalEnv := os.Getenv("GO_ENV")
|
||||
defer os.Setenv("GO_ENV", originalEnv)
|
||||
|
||||
os.Setenv("GO_ENV", "debug")
|
||||
|
||||
testErr := &testError{"debug error"}
|
||||
|
||||
output := captureLogOutput(func() {
|
||||
LogError(testErr, "Debug error occurred")
|
||||
})
|
||||
|
||||
if !strings.Contains(output, "ERROR:") {
|
||||
t.Error("Expected ERROR prefix")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Debug error occurred") {
|
||||
t.Error("Expected error message")
|
||||
}
|
||||
}
|
||||
|
||||
type testError struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *testError) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
|
||||
func TestLogInfo_EmptyMessage(t *testing.T) {
|
||||
originalEnv := os.Getenv("GO_ENV")
|
||||
defer os.Setenv("GO_ENV", originalEnv)
|
||||
|
||||
os.Setenv("GO_ENV", "development")
|
||||
|
||||
output := captureLogOutput(func() {
|
||||
LogInfo("")
|
||||
})
|
||||
|
||||
if !strings.Contains(output, "INFO:") {
|
||||
t.Error("Expected INFO prefix even with empty message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogWarn_EmptyMessage(t *testing.T) {
|
||||
originalEnv := os.Getenv("GO_ENV")
|
||||
defer os.Setenv("GO_ENV", originalEnv)
|
||||
|
||||
os.Setenv("GO_ENV", "development")
|
||||
|
||||
output := captureLogOutput(func() {
|
||||
LogWarn("")
|
||||
})
|
||||
|
||||
if !strings.Contains(output, "WARNING:") {
|
||||
t.Error("Expected WARNING prefix even with empty message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogError_EmptyMessage(t *testing.T) {
|
||||
originalEnv := os.Getenv("GO_ENV")
|
||||
defer os.Setenv("GO_ENV", originalEnv)
|
||||
|
||||
os.Setenv("GO_ENV", "development")
|
||||
|
||||
output := captureLogOutput(func() {
|
||||
LogError(nil, "")
|
||||
})
|
||||
|
||||
if !strings.Contains(output, "ERROR:") {
|
||||
t.Error("Expected ERROR prefix even with empty message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogInfo_LongMessage(t *testing.T) {
|
||||
originalEnv := os.Getenv("GO_ENV")
|
||||
defer os.Setenv("GO_ENV", originalEnv)
|
||||
|
||||
os.Setenv("GO_ENV", "development")
|
||||
|
||||
longMessage := strings.Repeat("A", 1000)
|
||||
|
||||
output := captureLogOutput(func() {
|
||||
LogInfo(longMessage)
|
||||
})
|
||||
|
||||
if !strings.Contains(output, longMessage) {
|
||||
t.Error("Expected long message to be logged completely")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogWarn_SpecialCharacters(t *testing.T) {
|
||||
originalEnv := os.Getenv("GO_ENV")
|
||||
defer os.Setenv("GO_ENV", originalEnv)
|
||||
|
||||
os.Setenv("GO_ENV", "development")
|
||||
|
||||
specialMsg := "Warning: \n\t special characters & symbols!"
|
||||
|
||||
output := captureLogOutput(func() {
|
||||
LogWarn(specialMsg)
|
||||
})
|
||||
|
||||
if !strings.Contains(output, "WARNING:") {
|
||||
t.Error("Expected WARNING prefix")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogError_MultilineMessage(t *testing.T) {
|
||||
originalEnv := os.Getenv("GO_ENV")
|
||||
defer os.Setenv("GO_ENV", originalEnv)
|
||||
|
||||
os.Setenv("GO_ENV", "development")
|
||||
|
||||
multilineMsg := "Line 1\nLine 2\nLine 3"
|
||||
|
||||
output := captureLogOutput(func() {
|
||||
LogError(nil, multilineMsg)
|
||||
})
|
||||
|
||||
if !strings.Contains(output, "ERROR:") {
|
||||
t.Error("Expected ERROR prefix")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogInfo_Canary(t *testing.T) {
|
||||
originalEnv := os.Getenv("GO_ENV")
|
||||
defer os.Setenv("GO_ENV", originalEnv)
|
||||
|
||||
os.Setenv("GO_ENV", "canary")
|
||||
|
||||
output := captureLogOutput(func() {
|
||||
LogInfo("Canary info message")
|
||||
})
|
||||
|
||||
if !strings.Contains(output, "INFO:") {
|
||||
t.Error("Expected INFO prefix in canary environment")
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "Canary info message") {
|
||||
t.Error("Expected message to be logged in canary environment")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogEnvironmentCheck(t *testing.T) {
|
||||
validEnvironments := []string{"development", "debug", "production", "canary"}
|
||||
|
||||
originalEnv := os.Getenv("GO_ENV")
|
||||
defer os.Setenv("GO_ENV", originalEnv)
|
||||
|
||||
for _, env := range validEnvironments {
|
||||
t.Run(env, func(t *testing.T) {
|
||||
os.Setenv("GO_ENV", env)
|
||||
|
||||
output := captureLogOutput(func() {
|
||||
LogInfo("Test message")
|
||||
})
|
||||
|
||||
if output == "" {
|
||||
t.Errorf("Expected log output for environment: %s", env)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogError_WithAndWithoutError(t *testing.T) {
|
||||
originalEnv := os.Getenv("GO_ENV")
|
||||
defer os.Setenv("GO_ENV", originalEnv)
|
||||
|
||||
os.Setenv("GO_ENV", "development")
|
||||
|
||||
// With error
|
||||
output1 := captureLogOutput(func() {
|
||||
LogError(&testError{"actual error"}, "Context message")
|
||||
})
|
||||
|
||||
if !strings.Contains(output1, "actual error") {
|
||||
t.Error("Expected error details when error provided")
|
||||
}
|
||||
|
||||
// Without error
|
||||
output2 := captureLogOutput(func() {
|
||||
LogError(nil, "Context message")
|
||||
})
|
||||
|
||||
if strings.Contains(output2, "<nil>") {
|
||||
t.Log("nil error is logged as <nil>")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLogInfo(b *testing.B) {
|
||||
os.Setenv("GO_ENV", "development")
|
||||
defer os.Unsetenv("GO_ENV")
|
||||
|
||||
// Discard log output for benchmarking
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
LogInfo("Benchmark message")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLogWarn(b *testing.B) {
|
||||
os.Setenv("GO_ENV", "development")
|
||||
defer os.Unsetenv("GO_ENV")
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
LogWarn("Benchmark warning")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLogError(b *testing.B) {
|
||||
os.Setenv("GO_ENV", "development")
|
||||
defer os.Unsetenv("GO_ENV")
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
testErr := &testError{"benchmark error"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
LogError(testErr, "Benchmark error message")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"authentication/models"
|
||||
"errors"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
func ExtractEmailFromToken(tokenString string) (string, error) {
|
||||
// Remove "Bearer " prefix if it exists
|
||||
tokenString = strings.TrimPrefix(tokenString, "Bearer ")
|
||||
|
||||
// Handle null/empty token cases
|
||||
if tokenString == "" || tokenString == "null" || tokenString == "undefined" {
|
||||
return "", errors.New("no valid token provided")
|
||||
}
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, &models.AccessToken{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, errors.New("unexpected signing method")
|
||||
}
|
||||
secretKey := os.Getenv("JWT_SECRET_KEY")
|
||||
if secretKey == "" {
|
||||
return nil, errors.New("JWT secret key not set")
|
||||
}
|
||||
return []byte(secretKey), nil
|
||||
})
|
||||
|
||||
if err == nil && token.Valid {
|
||||
if claims, ok := token.Claims.(*models.AccessToken); ok {
|
||||
if claims.Email != "" && strings.Contains(claims.Email, "@") {
|
||||
log.Printf("Successfully extracted email from AccessToken: %s", claims.Email)
|
||||
return claims.Email, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If AccessToken parsing failed, try MapClaims for backward compatibility
|
||||
log.Printf("AccessToken parsing failed: %v, trying MapClaims fallback", err)
|
||||
token, err = jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, errors.New("unexpected signing method")
|
||||
}
|
||||
secretKey := os.Getenv("JWT_SECRET_KEY")
|
||||
if secretKey == "" {
|
||||
return nil, errors.New("JWT secret key not set")
|
||||
}
|
||||
return []byte(secretKey), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Printf("MapClaims parsing also failed: %v", err)
|
||||
return "", errors.New("invalid token signature")
|
||||
}
|
||||
|
||||
// Extract claims from MapClaims
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
return "", errors.New("invalid token claims")
|
||||
}
|
||||
|
||||
if email, ok := claims["email"].(string); ok && strings.Contains(email, "@") {
|
||||
log.Printf("Successfully extracted email from MapClaims: %s", email)
|
||||
return email, nil
|
||||
}
|
||||
|
||||
return "", errors.New("email not found in token")
|
||||
}
|
||||
@@ -0,0 +1,393 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"authentication/models"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
func TestExtractEmailFromToken_ValidAccessToken(t *testing.T) {
|
||||
// Set up test environment
|
||||
secretKey := "test-secret-key-123"
|
||||
os.Setenv("JWT_SECRET_KEY", secretKey)
|
||||
defer os.Unsetenv("JWT_SECRET_KEY")
|
||||
|
||||
email := "test@example.com"
|
||||
|
||||
// Create valid AccessToken
|
||||
claims := &models.AccessToken{
|
||||
Email: email,
|
||||
SessionID: "session123",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString([]byte(secretKey))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to sign token: %v", err)
|
||||
}
|
||||
|
||||
// Test extraction
|
||||
extractedEmail, err := ExtractEmailFromToken(tokenString)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if extractedEmail != email {
|
||||
t.Errorf("Expected email '%s', got '%s'", email, extractedEmail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractEmailFromToken_ValidMapClaims(t *testing.T) {
|
||||
secretKey := "test-secret-key-456"
|
||||
os.Setenv("JWT_SECRET_KEY", secretKey)
|
||||
defer os.Unsetenv("JWT_SECRET_KEY")
|
||||
|
||||
email := "mapuser@example.com"
|
||||
|
||||
// Create token with MapClaims
|
||||
claims := jwt.MapClaims{
|
||||
"email": email,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString([]byte(secretKey))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to sign token: %v", err)
|
||||
}
|
||||
|
||||
// Test extraction
|
||||
extractedEmail, err := ExtractEmailFromToken(tokenString)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if extractedEmail != email {
|
||||
t.Errorf("Expected email '%s', got '%s'", email, extractedEmail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractEmailFromToken_BearerPrefix(t *testing.T) {
|
||||
secretKey := "test-secret-bearer"
|
||||
os.Setenv("JWT_SECRET_KEY", secretKey)
|
||||
defer os.Unsetenv("JWT_SECRET_KEY")
|
||||
|
||||
email := "bearer@example.com"
|
||||
|
||||
claims := &models.AccessToken{
|
||||
Email: email,
|
||||
SessionID: "session789",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString([]byte(secretKey))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to sign token: %v", err)
|
||||
}
|
||||
|
||||
// Test with Bearer prefix
|
||||
bearerToken := "Bearer " + tokenString
|
||||
extractedEmail, err := ExtractEmailFromToken(bearerToken)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if extractedEmail != email {
|
||||
t.Errorf("Expected email '%s', got '%s'", email, extractedEmail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractEmailFromToken_EmptyToken(t *testing.T) {
|
||||
os.Setenv("JWT_SECRET_KEY", "test-key")
|
||||
defer os.Unsetenv("JWT_SECRET_KEY")
|
||||
|
||||
testCases := []string{"", "null", "undefined"}
|
||||
|
||||
for _, tokenString := range testCases {
|
||||
_, err := ExtractEmailFromToken(tokenString)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for token '%s', got nil", tokenString)
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "no valid token provided") {
|
||||
t.Errorf("Expected 'no valid token provided' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractEmailFromToken_InvalidSignature(t *testing.T) {
|
||||
os.Setenv("JWT_SECRET_KEY", "correct-secret")
|
||||
defer os.Unsetenv("JWT_SECRET_KEY")
|
||||
|
||||
// Create token with wrong secret
|
||||
wrongSecret := "wrong-secret"
|
||||
claims := &models.AccessToken{
|
||||
Email: "test@example.com",
|
||||
SessionID: "session123",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, _ := token.SignedString([]byte(wrongSecret))
|
||||
|
||||
// Try to extract with different secret
|
||||
_, err := ExtractEmailFromToken(tokenString)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid signature")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "invalid token signature") {
|
||||
t.Errorf("Expected 'invalid token signature' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractEmailFromToken_ExpiredToken(t *testing.T) {
|
||||
secretKey := "test-expired-key"
|
||||
os.Setenv("JWT_SECRET_KEY", secretKey)
|
||||
defer os.Unsetenv("JWT_SECRET_KEY")
|
||||
|
||||
// Create expired token
|
||||
claims := &models.AccessToken{
|
||||
Email: "expired@example.com",
|
||||
SessionID: "session999",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Hour)), // Expired 1 hour ago
|
||||
IssuedAt: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, _ := token.SignedString([]byte(secretKey))
|
||||
|
||||
// Note: The ExtractEmailFromToken function doesn't validate expiration,
|
||||
// it relies on ParseWithClaims which may or may not enforce expiration
|
||||
// depending on jwt library version. We'll just verify it can extract the email.
|
||||
extractedEmail, _ := ExtractEmailFromToken(tokenString)
|
||||
if extractedEmail != "expired@example.com" {
|
||||
t.Logf("Extracted email from expired token: %s", extractedEmail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractEmailFromToken_NoEmailInClaims(t *testing.T) {
|
||||
secretKey := "test-no-email"
|
||||
os.Setenv("JWT_SECRET_KEY", secretKey)
|
||||
defer os.Unsetenv("JWT_SECRET_KEY")
|
||||
|
||||
// Create token without email
|
||||
claims := jwt.MapClaims{
|
||||
"user_id": "user123",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, _ := token.SignedString([]byte(secretKey))
|
||||
|
||||
_, err := ExtractEmailFromToken(tokenString)
|
||||
if err == nil {
|
||||
t.Error("Expected error for token without email")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "email not found in token") {
|
||||
t.Errorf("Expected 'email not found in token' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractEmailFromToken_InvalidEmailFormat(t *testing.T) {
|
||||
secretKey := "test-invalid-email"
|
||||
os.Setenv("JWT_SECRET_KEY", secretKey)
|
||||
defer os.Unsetenv("JWT_SECRET_KEY")
|
||||
|
||||
// Create token with invalid email (no @ symbol)
|
||||
claims := &models.AccessToken{
|
||||
Email: "notanemail",
|
||||
SessionID: "session123",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, _ := token.SignedString([]byte(secretKey))
|
||||
|
||||
_, err := ExtractEmailFromToken(tokenString)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid email format")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractEmailFromToken_NoSecretKey(t *testing.T) {
|
||||
// Ensure no secret key is set
|
||||
os.Unsetenv("JWT_SECRET_KEY")
|
||||
|
||||
claims := &models.AccessToken{
|
||||
Email: "test@example.com",
|
||||
SessionID: "session123",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, _ := token.SignedString([]byte("any-key"))
|
||||
|
||||
_, err := ExtractEmailFromToken(tokenString)
|
||||
if err == nil {
|
||||
t.Error("Expected error when secret key not set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractEmailFromToken_WrongSigningMethod(t *testing.T) {
|
||||
secretKey := "test-wrong-method"
|
||||
os.Setenv("JWT_SECRET_KEY", secretKey)
|
||||
defer os.Unsetenv("JWT_SECRET_KEY")
|
||||
|
||||
// Try to create token with non-HMAC signing method (would need RSA keys in real scenario)
|
||||
// For simplicity, we'll create a malformed token string
|
||||
malformedToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.invalid"
|
||||
|
||||
_, err := ExtractEmailFromToken(malformedToken)
|
||||
if err == nil {
|
||||
t.Error("Expected error for wrong signing method")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractEmailFromToken_MalformedToken(t *testing.T) {
|
||||
os.Setenv("JWT_SECRET_KEY", "test-key")
|
||||
defer os.Unsetenv("JWT_SECRET_KEY")
|
||||
|
||||
malformedTokens := []string{
|
||||
"not.a.token",
|
||||
"invalid",
|
||||
"Bearer invalid",
|
||||
"...",
|
||||
"a.b",
|
||||
}
|
||||
|
||||
for _, tokenString := range malformedTokens {
|
||||
_, err := ExtractEmailFromToken(tokenString)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for malformed token '%s'", tokenString)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractEmailFromToken_MultipleAtSymbols(t *testing.T) {
|
||||
secretKey := "test-multiple-at"
|
||||
os.Setenv("JWT_SECRET_KEY", secretKey)
|
||||
defer os.Unsetenv("JWT_SECRET_KEY")
|
||||
|
||||
email := "user@sub@example.com"
|
||||
|
||||
claims := &models.AccessToken{
|
||||
Email: email,
|
||||
SessionID: "session123",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, _ := token.SignedString([]byte(secretKey))
|
||||
|
||||
// Should extract successfully (just checks for @ presence)
|
||||
extractedEmail, err := ExtractEmailFromToken(tokenString)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if extractedEmail != email {
|
||||
t.Errorf("Expected email '%s', got '%s'", email, extractedEmail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractEmailFromToken_WhitespaceEmail(t *testing.T) {
|
||||
secretKey := "test-whitespace"
|
||||
os.Setenv("JWT_SECRET_KEY", secretKey)
|
||||
defer os.Unsetenv("JWT_SECRET_KEY")
|
||||
|
||||
// Email with spaces (should still work if it has @)
|
||||
email := " user@example.com "
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"email": email,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, _ := token.SignedString([]byte(secretKey))
|
||||
|
||||
extractedEmail, err := ExtractEmailFromToken(tokenString)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if extractedEmail != email {
|
||||
t.Errorf("Expected email '%s', got '%s'", email, extractedEmail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractEmailFromToken_CaseInsensitiveBearer(t *testing.T) {
|
||||
secretKey := "test-case-bearer"
|
||||
os.Setenv("JWT_SECRET_KEY", secretKey)
|
||||
defer os.Unsetenv("JWT_SECRET_KEY")
|
||||
|
||||
email := "case@example.com"
|
||||
|
||||
claims := &models.AccessToken{
|
||||
Email: email,
|
||||
SessionID: "session123",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, _ := token.SignedString([]byte(secretKey))
|
||||
|
||||
// Test with standard "Bearer " prefix
|
||||
bearerToken := "Bearer " + tokenString
|
||||
extractedEmail, err := ExtractEmailFromToken(bearerToken)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for Bearer prefix, got: %v", err)
|
||||
}
|
||||
|
||||
if extractedEmail != email {
|
||||
t.Errorf("Expected email '%s', got '%s'", email, extractedEmail)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExtractEmailFromToken(b *testing.B) {
|
||||
secretKey := "benchmark-secret"
|
||||
os.Setenv("JWT_SECRET_KEY", secretKey)
|
||||
defer os.Unsetenv("JWT_SECRET_KEY")
|
||||
|
||||
claims := &models.AccessToken{
|
||||
Email: "bench@example.com",
|
||||
SessionID: "sessionBench",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, _ := token.SignedString([]byte(secretKey))
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
ExtractEmailFromToken(tokenString)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
package helper
|
||||
|
||||
// Role caching removed - authorization is handled by separate authz microservice
|
||||
@@ -0,0 +1,12 @@
|
||||
package helper
|
||||
|
||||
import "time"
|
||||
|
||||
func LoadAsiaManilaLocation() (*time.Location, error) {
|
||||
const AsiaManila = "Asia/Manila"
|
||||
location, err := time.LoadLocation(AsiaManila)
|
||||
if err != nil {
|
||||
location = time.FixedZone("Asia/Manila", 8*60*60)
|
||||
}
|
||||
return location, nil
|
||||
}
|
||||
@@ -0,0 +1,188 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLoadAsiaManilaLocation(t *testing.T) {
|
||||
location, err := LoadAsiaManilaLocation()
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if location == nil {
|
||||
t.Fatal("Expected location to not be nil")
|
||||
}
|
||||
|
||||
// Check location name
|
||||
locationName := location.String()
|
||||
if locationName != "Asia/Manila" && locationName != "Local" {
|
||||
// "Local" is acceptable as fallback uses FixedZone
|
||||
t.Logf("Location name: %s (expected 'Asia/Manila' or 'Local')", locationName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAsiaManilaLocationOffset(t *testing.T) {
|
||||
location, err := LoadAsiaManilaLocation()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
// Get current time in Asia/Manila
|
||||
now := time.Now().In(location)
|
||||
|
||||
// Asia/Manila is UTC+8 (28800 seconds)
|
||||
_, offset := now.Zone()
|
||||
|
||||
expectedOffset := 8 * 60 * 60 // 28800 seconds
|
||||
|
||||
if offset != expectedOffset {
|
||||
t.Errorf("Expected offset %d seconds (UTC+8), got %d seconds", expectedOffset, offset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAsiaManilaLocationNotNil(t *testing.T) {
|
||||
location, _ := LoadAsiaManilaLocation()
|
||||
|
||||
if location == nil {
|
||||
t.Error("Location should never be nil due to fallback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAsiaManilaLocationTimezone(t *testing.T) {
|
||||
location, err := LoadAsiaManilaLocation()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
// Create a specific time and check its formatting in Manila timezone
|
||||
testTime := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
manilaTime := testTime.In(location)
|
||||
|
||||
// Manila is UTC+8, so 12:00 UTC should be 20:00 in Manila
|
||||
expectedHour := 20
|
||||
if manilaTime.Hour() != expectedHour {
|
||||
t.Errorf("Expected hour %d in Manila time, got %d", expectedHour, manilaTime.Hour())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAsiaManilaLocationFallback(t *testing.T) {
|
||||
// Even if timezone database is not available, function should not panic
|
||||
location, err := LoadAsiaManilaLocation()
|
||||
|
||||
if location == nil {
|
||||
t.Error("Location should not be nil even with fallback")
|
||||
}
|
||||
|
||||
// Error can be nil if LoadLocation succeeds
|
||||
// Error is not returned from FixedZone fallback
|
||||
if err != nil {
|
||||
t.Logf("Note: LoadLocation failed, using fallback: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAsiaManilaLocationConsistency(t *testing.T) {
|
||||
// Call multiple times to ensure consistency
|
||||
location1, err1 := LoadAsiaManilaLocation()
|
||||
location2, err2 := LoadAsiaManilaLocation()
|
||||
|
||||
if (err1 == nil) != (err2 == nil) {
|
||||
t.Error("Inconsistent error returns")
|
||||
}
|
||||
|
||||
// Both should have same offset
|
||||
now := time.Now()
|
||||
time1 := now.In(location1)
|
||||
time2 := now.In(location2)
|
||||
|
||||
_, offset1 := time1.Zone()
|
||||
_, offset2 := time2.Zone()
|
||||
|
||||
if offset1 != offset2 {
|
||||
t.Errorf("Inconsistent offsets: %d vs %d", offset1, offset2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAsiaManilaLocationUTCConversion(t *testing.T) {
|
||||
location, _ := LoadAsiaManilaLocation()
|
||||
|
||||
utcTime := time.Date(2025, 6, 15, 10, 30, 0, 0, time.UTC)
|
||||
manilaTime := utcTime.In(location)
|
||||
|
||||
// Manila is UTC+8
|
||||
expectedHour := 18 // 10 + 8
|
||||
if manilaTime.Hour() != expectedHour {
|
||||
t.Errorf("Expected hour %d, got %d", expectedHour, manilaTime.Hour())
|
||||
}
|
||||
|
||||
// Minute and second should be the same
|
||||
if manilaTime.Minute() != 30 {
|
||||
t.Errorf("Expected minute 30, got %d", manilaTime.Minute())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAsiaManilaLocationDSTHandling(t *testing.T) {
|
||||
location, _ := LoadAsiaManilaLocation()
|
||||
|
||||
// Philippines doesn't observe DST, so offset should be constant throughout the year
|
||||
|
||||
// Test summer time
|
||||
summerTime := time.Date(2025, 7, 1, 12, 0, 0, 0, location)
|
||||
_, summerOffset := summerTime.Zone()
|
||||
|
||||
// Test winter time
|
||||
winterTime := time.Date(2025, 1, 1, 12, 0, 0, 0, location)
|
||||
_, winterOffset := winterTime.Zone()
|
||||
|
||||
if summerOffset != winterOffset {
|
||||
t.Errorf("Philippines should not have DST. Summer offset %d != Winter offset %d", summerOffset, winterOffset)
|
||||
}
|
||||
|
||||
// Both should be UTC+8
|
||||
expectedOffset := 8 * 60 * 60
|
||||
if summerOffset != expectedOffset {
|
||||
t.Errorf("Expected offset %d, got %d", expectedOffset, summerOffset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAsiaManilaLocationFormatting(t *testing.T) {
|
||||
location, _ := LoadAsiaManilaLocation()
|
||||
|
||||
now := time.Now().In(location)
|
||||
formatted := now.Format("2006-01-02 15:04:05 MST")
|
||||
|
||||
if formatted == "" {
|
||||
t.Error("Formatted time should not be empty")
|
||||
}
|
||||
|
||||
// Should contain timezone information
|
||||
if !containsTimeZone(formatted) {
|
||||
t.Logf("Formatted time: %s (timezone info may vary)", formatted)
|
||||
}
|
||||
}
|
||||
|
||||
func containsTimeZone(s string) bool {
|
||||
// Simple check for common timezone formats
|
||||
return len(s) > 0
|
||||
}
|
||||
|
||||
func BenchmarkLoadAsiaManilaLocation(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
LoadAsiaManilaLocation()
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAsiaManilaLocationReusability(t *testing.T) {
|
||||
location, _ := LoadAsiaManilaLocation()
|
||||
|
||||
// Use the location multiple times
|
||||
for i := 0; i < 100; i++ {
|
||||
now := time.Now().In(location)
|
||||
if now.Location() != location {
|
||||
t.Error("Time location should match provided location")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"authentication/redisclient"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultRedisTTLSeconds = 60
|
||||
|
||||
func SetJSON(ctx context.Context, key string, value interface{}, ttlSeconds *int) error {
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ttl := time.Duration(defaultRedisTTLSeconds) * time.Second
|
||||
if ttlSeconds != nil {
|
||||
ttl = time.Duration(*ttlSeconds) * time.Second
|
||||
}
|
||||
return redisclient.RDB.Set(ctx, key, data, ttl).Err()
|
||||
}
|
||||
func SlotSetJSON(ctx context.Context, key string, value interface{}, ttlSeconds *int) error {
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ttl := time.Duration(0)
|
||||
if ttlSeconds != nil {
|
||||
ttl = time.Duration(*ttlSeconds) * time.Second
|
||||
}
|
||||
return redisclient.RDB.Set(ctx, key, data, ttl).Err()
|
||||
}
|
||||
|
||||
func GetJSON(ctx context.Context, key string, dest interface{}) error {
|
||||
val, err := redisclient.RDB.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal([]byte(val), dest)
|
||||
}
|
||||
|
||||
func SetTTL(ctx context.Context, key string, ttlSeconds *int) error {
|
||||
ttl := time.Duration(defaultRedisTTLSeconds) * time.Second
|
||||
if ttlSeconds != nil {
|
||||
ttl = time.Duration(*ttlSeconds) * time.Second
|
||||
}
|
||||
res, err := redisclient.RDB.Expire(ctx, key, ttl).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !res {
|
||||
return errors.New("failed to set TTL: key does not exist")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,422 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"authentication/redisclient"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// setupTestRedis creates a mock Redis server for testing
|
||||
func setupTestRedis(t *testing.T) (*miniredis.Miniredis, func()) {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start miniredis: %v", err)
|
||||
}
|
||||
|
||||
// Save original client
|
||||
originalRDB := redisclient.RDB
|
||||
|
||||
// Create test client
|
||||
redisclient.RDB = redis.NewClient(&redis.Options{
|
||||
Addr: mr.Addr(),
|
||||
})
|
||||
|
||||
cleanup := func() {
|
||||
redisclient.RDB.Close()
|
||||
redisclient.RDB = originalRDB
|
||||
mr.Close()
|
||||
}
|
||||
|
||||
return mr, cleanup
|
||||
}
|
||||
|
||||
func TestSetJSON(t *testing.T) {
|
||||
mr, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
testData := map[string]interface{}{
|
||||
"name": "Test User",
|
||||
"email": "test@example.com",
|
||||
"age": 30,
|
||||
}
|
||||
|
||||
// Test with default TTL
|
||||
err := SetJSON(ctx, "test:user:1", testData, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
// Verify data was set
|
||||
val, err := redisclient.RDB.Get(ctx, "test:user:1").Result()
|
||||
if err != nil {
|
||||
t.Errorf("Failed to get value: %v", err)
|
||||
}
|
||||
|
||||
var retrieved map[string]interface{}
|
||||
err = json.Unmarshal([]byte(val), &retrieved)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if retrieved["name"] != testData["name"] {
|
||||
t.Errorf("Expected name '%v', got '%v'", testData["name"], retrieved["name"])
|
||||
}
|
||||
|
||||
// Verify TTL was set (miniredis returns TTL)
|
||||
ttl := mr.TTL("test:user:1")
|
||||
if ttl <= 0 {
|
||||
t.Error("Expected TTL to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetJSON_CustomTTL(t *testing.T) {
|
||||
_, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
testData := map[string]string{"key": "value"}
|
||||
customTTL := 120
|
||||
|
||||
err := SetJSON(ctx, "test:custom:ttl", testData, &customTTL)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
// Verify data was set
|
||||
val, err := redisclient.RDB.Get(ctx, "test:custom:ttl").Result()
|
||||
if err != nil {
|
||||
t.Errorf("Failed to get value: %v", err)
|
||||
}
|
||||
|
||||
if val == "" {
|
||||
t.Error("Expected value to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotSetJSON(t *testing.T) {
|
||||
_, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
testData := map[string]int{"count": 42}
|
||||
|
||||
// Test with no TTL
|
||||
err := SlotSetJSON(ctx, "test:slot:1", testData, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
// Verify data was set
|
||||
val, err := redisclient.RDB.Get(ctx, "test:slot:1").Result()
|
||||
if err != nil {
|
||||
t.Errorf("Failed to get value: %v", err)
|
||||
}
|
||||
|
||||
var retrieved map[string]int
|
||||
err = json.Unmarshal([]byte(val), &retrieved)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if retrieved["count"] != testData["count"] {
|
||||
t.Errorf("Expected count %d, got %d", testData["count"], retrieved["count"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlotSetJSON_WithTTL(t *testing.T) {
|
||||
_, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
testData := []string{"item1", "item2"}
|
||||
ttl := 300
|
||||
|
||||
err := SlotSetJSON(ctx, "test:slot:ttl", testData, &ttl)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
// Verify data exists
|
||||
exists, err := redisclient.RDB.Exists(ctx, "test:slot:ttl").Result()
|
||||
if err != nil {
|
||||
t.Errorf("Failed to check existence: %v", err)
|
||||
}
|
||||
|
||||
if exists != 1 {
|
||||
t.Error("Expected key to exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetJSON(t *testing.T) {
|
||||
_, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
type TestStruct struct {
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Age int `json:"age"`
|
||||
}
|
||||
|
||||
original := TestStruct{
|
||||
Name: "John Doe",
|
||||
Email: "john@example.com",
|
||||
Age: 25,
|
||||
}
|
||||
|
||||
// Set data
|
||||
err := SetJSON(ctx, "test:user:get", original, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set JSON: %v", err)
|
||||
}
|
||||
|
||||
// Get data
|
||||
var retrieved TestStruct
|
||||
err = GetJSON(ctx, "test:user:get", &retrieved)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if retrieved.Name != original.Name {
|
||||
t.Errorf("Expected name '%s', got '%s'", original.Name, retrieved.Name)
|
||||
}
|
||||
|
||||
if retrieved.Email != original.Email {
|
||||
t.Errorf("Expected email '%s', got '%s'", original.Email, retrieved.Email)
|
||||
}
|
||||
|
||||
if retrieved.Age != original.Age {
|
||||
t.Errorf("Expected age %d, got %d", original.Age, retrieved.Age)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetJSON_NonExistentKey(t *testing.T) {
|
||||
_, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
var result map[string]string
|
||||
err := GetJSON(ctx, "test:nonexistent", &result)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nonexistent key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetJSON_InvalidJSON(t *testing.T) {
|
||||
_, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Set invalid JSON
|
||||
err := redisclient.RDB.Set(ctx, "test:invalid", "not valid json", time.Minute).Err()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set invalid data: %v", err)
|
||||
}
|
||||
|
||||
var result map[string]string
|
||||
err = GetJSON(ctx, "test:invalid", &result)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetTTL(t *testing.T) {
|
||||
mr, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Set initial data
|
||||
err := redisclient.RDB.Set(ctx, "test:ttl:key", "value", 0).Err()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set initial data: %v", err)
|
||||
}
|
||||
|
||||
// Update TTL
|
||||
ttl := 300
|
||||
err = SetTTL(ctx, "test:ttl:key", &ttl)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
// Verify TTL was set
|
||||
actualTTL := mr.TTL("test:ttl:key")
|
||||
if actualTTL <= 0 {
|
||||
t.Error("Expected TTL to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetTTL_DefaultTTL(t *testing.T) {
|
||||
mr, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Set initial data
|
||||
err := redisclient.RDB.Set(ctx, "test:ttl:default", "value", 0).Err()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set initial data: %v", err)
|
||||
}
|
||||
|
||||
// Set default TTL
|
||||
err = SetTTL(ctx, "test:ttl:default", nil)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
// Verify TTL was set
|
||||
actualTTL := mr.TTL("test:ttl:default")
|
||||
if actualTTL <= 0 {
|
||||
t.Error("Expected default TTL to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetTTL_NonExistentKey(t *testing.T) {
|
||||
_, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err := SetTTL(ctx, "test:ttl:nonexistent", nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nonexistent key")
|
||||
}
|
||||
|
||||
expectedMsg := "failed to set TTL: key does not exist"
|
||||
if err.Error() != expectedMsg {
|
||||
t.Errorf("Expected error '%s', got '%s'", expectedMsg, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetJSON_MarshalError(t *testing.T) {
|
||||
_, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Channel cannot be marshaled to JSON
|
||||
invalidData := make(chan int)
|
||||
|
||||
err := SetJSON(ctx, "test:invalid", invalidData, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for unmarshalable data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultRedisTTLSeconds(t *testing.T) {
|
||||
if defaultRedisTTLSeconds != 60 {
|
||||
t.Errorf("Expected default TTL 60 seconds, got %d", defaultRedisTTLSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetJSON_RoundTrip(t *testing.T) {
|
||||
_, cleanup := setupTestRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
type ComplexStruct struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Tags []string `json:"tags"`
|
||||
Metadata map[string]interface{} `json:"metadata"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
original := ComplexStruct{
|
||||
ID: "123",
|
||||
Name: "Test",
|
||||
Tags: []string{"tag1", "tag2"},
|
||||
Metadata: map[string]interface{}{"key": "value", "count": 5},
|
||||
Active: true,
|
||||
}
|
||||
|
||||
// Set
|
||||
err := SetJSON(ctx, "test:complex", original, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set: %v", err)
|
||||
}
|
||||
|
||||
// Get
|
||||
var retrieved ComplexStruct
|
||||
err = GetJSON(ctx, "test:complex", &retrieved)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get: %v", err)
|
||||
}
|
||||
|
||||
// Verify
|
||||
if retrieved.ID != original.ID {
|
||||
t.Errorf("ID mismatch")
|
||||
}
|
||||
if retrieved.Name != original.Name {
|
||||
t.Errorf("Name mismatch")
|
||||
}
|
||||
if len(retrieved.Tags) != len(original.Tags) {
|
||||
t.Errorf("Tags length mismatch")
|
||||
}
|
||||
if retrieved.Active != original.Active {
|
||||
t.Errorf("Active mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSetJSON(b *testing.B) {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to start miniredis: %v", err)
|
||||
}
|
||||
defer mr.Close()
|
||||
|
||||
originalRDB := redisclient.RDB
|
||||
redisclient.RDB = redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||
defer func() {
|
||||
redisclient.RDB.Close()
|
||||
redisclient.RDB = originalRDB
|
||||
}()
|
||||
|
||||
ctx := context.Background()
|
||||
testData := map[string]string{"key": "value"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
SetJSON(ctx, "bench:key", testData, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetJSON(b *testing.B) {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to start miniredis: %v", err)
|
||||
}
|
||||
defer mr.Close()
|
||||
|
||||
originalRDB := redisclient.RDB
|
||||
redisclient.RDB = redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||
defer func() {
|
||||
redisclient.RDB.Close()
|
||||
redisclient.RDB = originalRDB
|
||||
}()
|
||||
|
||||
ctx := context.Background()
|
||||
testData := map[string]string{"key": "value"}
|
||||
SetJSON(ctx, "bench:key", testData, nil)
|
||||
|
||||
b.ResetTimer()
|
||||
var result map[string]string
|
||||
for i := 0; i < b.N; i++ {
|
||||
GetJSON(ctx, "bench:key", &result)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func RespondWithError(w http.ResponseWriter, statusCode int, message string) {
|
||||
w.Header().Set(ContentTypeHeader, ApplicationJSON)
|
||||
w.WriteHeader(statusCode)
|
||||
if encodeErr := json.NewEncoder(w).Encode(map[string]string{ErrorLabel: message}); encodeErr != nil {
|
||||
LogError(encodeErr, ErrorEncodingResponse)
|
||||
}
|
||||
}
|
||||
|
||||
func RespondWithMessage(w http.ResponseWriter, message string) {
|
||||
if encodeErr := json.NewEncoder(w).Encode(map[string]string{MessageLabel: message}); encodeErr != nil {
|
||||
LogError(encodeErr, ErrorEncodingResponse)
|
||||
}
|
||||
}
|
||||
|
||||
func RespondWithJSON(w http.ResponseWriter, statusCode int, data interface{}) {
|
||||
w.Header().Set(ContentTypeHeader, ApplicationJSON)
|
||||
w.WriteHeader(statusCode)
|
||||
if encodeErr := json.NewEncoder(w).Encode(data); encodeErr != nil {
|
||||
LogError(encodeErr, ErrorEncodingResponse)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,312 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRespondWithError(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
message string
|
||||
}{
|
||||
{"Bad Request", http.StatusBadRequest, "Invalid input"},
|
||||
{"Unauthorized", http.StatusUnauthorized, "Not authenticated"},
|
||||
{"Forbidden", http.StatusForbidden, "Access denied"},
|
||||
{"Not Found", http.StatusNotFound, "Resource not found"},
|
||||
{"Internal Error", http.StatusInternalServerError, "Server error"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
RespondWithError(recorder, tc.statusCode, tc.message)
|
||||
|
||||
// Check status code
|
||||
if recorder.Code != tc.statusCode {
|
||||
t.Errorf("Expected status code %d, got %d", tc.statusCode, recorder.Code)
|
||||
}
|
||||
|
||||
// Check content type
|
||||
contentType := recorder.Header().Get("Content-Type")
|
||||
if contentType != "application/json" {
|
||||
t.Errorf("Expected Content-Type 'application/json', got '%s'", contentType)
|
||||
}
|
||||
|
||||
// Parse response body
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
// Check error message
|
||||
if response["error"] != tc.message {
|
||||
t.Errorf("Expected error message '%s', got '%s'", tc.message, response["error"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRespondWithMessage(t *testing.T) {
|
||||
testCases := []string{
|
||||
"Operation successful",
|
||||
"User created",
|
||||
"Email sent",
|
||||
"Task completed",
|
||||
}
|
||||
|
||||
for _, message := range testCases {
|
||||
t.Run(message, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
RespondWithMessage(recorder, message)
|
||||
|
||||
// Check status code (should default to 200)
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusOK, recorder.Code)
|
||||
}
|
||||
|
||||
// Parse response body
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
// Check message
|
||||
if response["message"] != message {
|
||||
t.Errorf("Expected message '%s', got '%s'", message, response["message"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRespondWithJSON(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
data interface{}
|
||||
}{
|
||||
{
|
||||
name: "Simple object",
|
||||
statusCode: http.StatusOK,
|
||||
data: map[string]string{"key": "value"},
|
||||
},
|
||||
{
|
||||
name: "Array",
|
||||
statusCode: http.StatusOK,
|
||||
data: []string{"item1", "item2", "item3"},
|
||||
},
|
||||
{
|
||||
name: "Nested object",
|
||||
statusCode: http.StatusCreated,
|
||||
data: map[string]interface{}{
|
||||
"user": map[string]string{
|
||||
"name": "John",
|
||||
"email": "john@example.com",
|
||||
},
|
||||
"status": "active",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Number",
|
||||
statusCode: http.StatusOK,
|
||||
data: map[string]int{"count": 42},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
RespondWithJSON(recorder, tc.statusCode, tc.data)
|
||||
|
||||
// Check status code
|
||||
if recorder.Code != tc.statusCode {
|
||||
t.Errorf("Expected status code %d, got %d", tc.statusCode, recorder.Code)
|
||||
}
|
||||
|
||||
// Check content type
|
||||
contentType := recorder.Header().Get("Content-Type")
|
||||
if contentType != "application/json" {
|
||||
t.Errorf("Expected Content-Type 'application/json', got '%s'", contentType)
|
||||
}
|
||||
|
||||
// Verify response can be parsed as JSON
|
||||
var response interface{}
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response as JSON: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRespondWithErrorEmptyMessage(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
RespondWithError(recorder, http.StatusBadRequest, "")
|
||||
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if _, exists := response["error"]; !exists {
|
||||
t.Error("Response should contain 'error' key even with empty message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRespondWithJSONNilData(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
RespondWithJSON(recorder, http.StatusOK, nil)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusOK, recorder.Code)
|
||||
}
|
||||
|
||||
body := recorder.Body.String()
|
||||
if body != "null\n" {
|
||||
t.Errorf("Expected 'null', got '%s'", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRespondWithErrorStatusCodes(t *testing.T) {
|
||||
statusCodes := []int{
|
||||
http.StatusBadRequest,
|
||||
http.StatusUnauthorized,
|
||||
http.StatusForbidden,
|
||||
http.StatusNotFound,
|
||||
http.StatusMethodNotAllowed,
|
||||
http.StatusConflict,
|
||||
http.StatusUnprocessableEntity,
|
||||
http.StatusTooManyRequests,
|
||||
http.StatusInternalServerError,
|
||||
http.StatusServiceUnavailable,
|
||||
}
|
||||
|
||||
for _, code := range statusCodes {
|
||||
t.Run(http.StatusText(code), func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
RespondWithError(recorder, code, "Test error")
|
||||
|
||||
if recorder.Code != code {
|
||||
t.Errorf("Expected status code %d, got %d", code, recorder.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRespondWithJSONComplex(t *testing.T) {
|
||||
type User struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Roles []string `json:"roles"`
|
||||
IsActive bool `json:"is_active"`
|
||||
}
|
||||
|
||||
user := User{
|
||||
ID: 123,
|
||||
Name: "Test User",
|
||||
Email: "test@example.com",
|
||||
Roles: []string{"admin", "user"},
|
||||
IsActive: true,
|
||||
}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
RespondWithJSON(recorder, http.StatusOK, user)
|
||||
|
||||
var decoded User
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if decoded.ID != user.ID {
|
||||
t.Errorf("Expected ID %d, got %d", user.ID, decoded.ID)
|
||||
}
|
||||
|
||||
if decoded.Name != user.Name {
|
||||
t.Errorf("Expected Name '%s', got '%s'", user.Name, decoded.Name)
|
||||
}
|
||||
|
||||
if decoded.Email != user.Email {
|
||||
t.Errorf("Expected Email '%s', got '%s'", user.Email, decoded.Email)
|
||||
}
|
||||
|
||||
if len(decoded.Roles) != len(user.Roles) {
|
||||
t.Errorf("Expected %d roles, got %d", len(user.Roles), len(decoded.Roles))
|
||||
}
|
||||
|
||||
if decoded.IsActive != user.IsActive {
|
||||
t.Errorf("Expected IsActive %v, got %v", user.IsActive, decoded.IsActive)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRespondWithJSONArray(t *testing.T) {
|
||||
data := []map[string]string{
|
||||
{"id": "1", "name": "Item 1"},
|
||||
{"id": "2", "name": "Item 2"},
|
||||
{"id": "3", "name": "Item 3"},
|
||||
}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
RespondWithJSON(recorder, http.StatusOK, data)
|
||||
|
||||
var decoded []map[string]string
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if len(decoded) != len(data) {
|
||||
t.Errorf("Expected %d items, got %d", len(data), len(decoded))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseHeadersSet(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
RespondWithJSON(recorder, http.StatusOK, map[string]string{"test": "data"})
|
||||
|
||||
// Verify Content-Type is set
|
||||
contentType := recorder.Header().Get("Content-Type")
|
||||
if contentType == "" {
|
||||
t.Error("Content-Type header should be set")
|
||||
}
|
||||
|
||||
if contentType != "application/json" {
|
||||
t.Errorf("Expected Content-Type 'application/json', got '%s'", contentType)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRespondWithError(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
recorder := httptest.NewRecorder()
|
||||
RespondWithError(recorder, http.StatusBadRequest, "Test error")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRespondWithMessage(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
recorder := httptest.NewRecorder()
|
||||
RespondWithMessage(recorder, "Test message")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRespondWithJSON(b *testing.B) {
|
||||
data := map[string]interface{}{
|
||||
"id": 123,
|
||||
"name": "Test",
|
||||
"email": "test@example.com",
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
recorder := httptest.NewRecorder()
|
||||
RespondWithJSON(recorder, http.StatusOK, data)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func CalculateSHA256(data string) string {
|
||||
hash := sha256.New()
|
||||
hash.Write([]byte(data))
|
||||
return hex.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
|
||||
// ChecksumFormFile calculates the SHA256 checksum of a multipart form file.
|
||||
func CalculateSHA256FromBytes(data []byte) string {
|
||||
hash := sha256.Sum256(data)
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// Sha256 returns the SHA256 hash of the input string as a hex string
|
||||
func Sha256(s string) string {
|
||||
h := sha256.New()
|
||||
h.Write([]byte(s))
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
@@ -0,0 +1,223 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCalculateSHA256(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Simple string",
|
||||
input: "hello",
|
||||
expected: "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824",
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
expected: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
||||
},
|
||||
{
|
||||
name: "String with spaces",
|
||||
input: "hello world",
|
||||
expected: "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9",
|
||||
},
|
||||
{
|
||||
name: "Numeric string",
|
||||
input: "12345",
|
||||
expected: "5994471abb01112afcc18159f6cc74b4f511b99806da59b3caf5a9c173cacfc5",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := CalculateSHA256(tc.input)
|
||||
if result != tc.expected {
|
||||
t.Errorf("Expected %s, got %s", tc.expected, result)
|
||||
}
|
||||
|
||||
// Verify it's always 64 characters (SHA256 hex)
|
||||
if len(result) != 64 {
|
||||
t.Errorf("Expected 64 character hash, got %d", len(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateSHA256FromBytes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Byte array",
|
||||
input: []byte("test data"),
|
||||
expected: "916f0027a575074ce72a331777c3478d6513f786a591bd892da1a577bf2335f9",
|
||||
},
|
||||
{
|
||||
name: "Empty byte array",
|
||||
input: []byte{},
|
||||
expected: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
||||
},
|
||||
{
|
||||
name: "Binary data",
|
||||
input: []byte{0x00, 0x01, 0x02, 0xFF},
|
||||
expected: "3d1f57c984978ef98a18378c8166c1cb8ede02c03eeb6aee7e2f121dfeee3e56",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := CalculateSHA256FromBytes(tc.input)
|
||||
if result != tc.expected {
|
||||
t.Errorf("Expected %s, got %s", tc.expected, result)
|
||||
}
|
||||
|
||||
if len(result) != 64 {
|
||||
t.Errorf("Expected 64 character hash, got %d", len(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSha256(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{"Simple", "password123"},
|
||||
{"Empty", ""},
|
||||
{"Complex", "P@ssw0rd!#$%^&*()"},
|
||||
{"Unicode", "こんにちは"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := Sha256(tc.input)
|
||||
|
||||
// Should return 64 character hex string
|
||||
if len(result) != 64 {
|
||||
t.Errorf("Expected 64 character hash, got %d", len(result))
|
||||
}
|
||||
|
||||
// Should be lowercase hex
|
||||
if result != strings.ToLower(result) {
|
||||
t.Error("Expected lowercase hex string")
|
||||
}
|
||||
|
||||
// Should be deterministic
|
||||
result2 := Sha256(tc.input)
|
||||
if result != result2 {
|
||||
t.Error("Hash should be deterministic")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSHA256Consistency(t *testing.T) {
|
||||
input := "test consistency"
|
||||
|
||||
// All three functions should produce the same hash for the same input
|
||||
hash1 := CalculateSHA256(input)
|
||||
hash2 := CalculateSHA256FromBytes([]byte(input))
|
||||
hash3 := Sha256(input)
|
||||
|
||||
if hash1 != hash2 {
|
||||
t.Errorf("CalculateSHA256 and CalculateSHA256FromBytes produced different results: %s vs %s", hash1, hash2)
|
||||
}
|
||||
|
||||
if hash1 != hash3 {
|
||||
t.Errorf("CalculateSHA256 and Sha256 produced different results: %s vs %s", hash1, hash3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSHA256Uniqueness(t *testing.T) {
|
||||
inputs := []string{
|
||||
"password1",
|
||||
"password2",
|
||||
"password3",
|
||||
"different",
|
||||
"unique",
|
||||
}
|
||||
|
||||
hashes := make(map[string]bool)
|
||||
|
||||
for _, input := range inputs {
|
||||
hash := CalculateSHA256(input)
|
||||
|
||||
if hashes[hash] {
|
||||
t.Errorf("Collision detected for input: %s", input)
|
||||
}
|
||||
|
||||
hashes[hash] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestSHA256LongInput(t *testing.T) {
|
||||
// Test with very long input
|
||||
longInput := strings.Repeat("a", 10000)
|
||||
hash := CalculateSHA256(longInput)
|
||||
|
||||
if len(hash) != 64 {
|
||||
t.Errorf("Expected 64 character hash for long input, got %d", len(hash))
|
||||
}
|
||||
|
||||
// Hash should be different from short input
|
||||
shortHash := CalculateSHA256("a")
|
||||
if hash == shortHash {
|
||||
t.Error("Long and short inputs should produce different hashes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSHA256SpecialCharacters(t *testing.T) {
|
||||
specialInputs := []string{
|
||||
"\n\r\t",
|
||||
"spaces everywhere",
|
||||
"!@#$%^&*()_+-=[]{}|;':\",./<>?",
|
||||
"emoji 🔐🔑",
|
||||
}
|
||||
|
||||
for _, input := range specialInputs {
|
||||
hash := CalculateSHA256(input)
|
||||
|
||||
if len(hash) != 64 {
|
||||
t.Errorf("Expected 64 character hash for input %q, got %d", input, len(hash))
|
||||
}
|
||||
|
||||
// Should be valid hex
|
||||
for _, char := range hash {
|
||||
if !((char >= '0' && char <= '9') || (char >= 'a' && char <= 'f')) {
|
||||
t.Errorf("Invalid hex character %c in hash", char)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCalculateSHA256(b *testing.B) {
|
||||
input := "benchmark test string"
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
CalculateSHA256(input)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCalculateSHA256FromBytes(b *testing.B) {
|
||||
input := []byte("benchmark test string")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
CalculateSHA256FromBytes(input)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSha256(b *testing.B) {
|
||||
input := "benchmark test string"
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
Sha256(input)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
const IDCharset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
const IDLength = 11
|
||||
|
||||
func UUIDGenerator() string {
|
||||
ID := make([]byte, IDLength)
|
||||
for i := 0; i < IDLength; i++ {
|
||||
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(IDCharset))))
|
||||
if err != nil {
|
||||
panic(err) // Handle error appropriately in production code
|
||||
}
|
||||
ID[i] = IDCharset[num.Int64()]
|
||||
}
|
||||
return string(ID)
|
||||
}
|
||||
@@ -0,0 +1,189 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUUIDGenerator(t *testing.T) {
|
||||
uuid := UUIDGenerator()
|
||||
|
||||
// Check length
|
||||
if len(uuid) != IDLength {
|
||||
t.Errorf("Expected UUID length %d, got %d", IDLength, len(uuid))
|
||||
}
|
||||
|
||||
// Check that it only contains valid characters
|
||||
for _, char := range uuid {
|
||||
if !strings.ContainsRune(IDCharset, char) {
|
||||
t.Errorf("Invalid character %c in UUID", char)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUUIDGeneratorUniqueness(t *testing.T) {
|
||||
iterations := 1000
|
||||
uuids := make(map[string]bool)
|
||||
|
||||
for i := 0; i < iterations; i++ {
|
||||
uuid := UUIDGenerator()
|
||||
|
||||
if uuids[uuid] {
|
||||
t.Errorf("Duplicate UUID generated: %s", uuid)
|
||||
}
|
||||
|
||||
uuids[uuid] = true
|
||||
}
|
||||
|
||||
if len(uuids) != iterations {
|
||||
t.Errorf("Expected %d unique UUIDs, got %d", iterations, len(uuids))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUUIDGeneratorCharset(t *testing.T) {
|
||||
// Generate many UUIDs and verify all characters in charset are used
|
||||
iterations := 10000
|
||||
charCount := make(map[rune]int)
|
||||
|
||||
for i := 0; i < iterations; i++ {
|
||||
uuid := UUIDGenerator()
|
||||
for _, char := range uuid {
|
||||
charCount[char]++
|
||||
}
|
||||
}
|
||||
|
||||
// Verify we have a good distribution (not comprehensive but basic check)
|
||||
if len(charCount) < len(IDCharset)/2 {
|
||||
t.Errorf("Expected more character variety. Only %d out of %d characters used", len(charCount), len(IDCharset))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUUIDGeneratorNotEmpty(t *testing.T) {
|
||||
for i := 0; i < 100; i++ {
|
||||
uuid := UUIDGenerator()
|
||||
if uuid == "" {
|
||||
t.Error("Generated UUID should not be empty")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUUIDGeneratorLength(t *testing.T) {
|
||||
// Verify length constant
|
||||
if IDLength != 11 {
|
||||
t.Errorf("Expected IDLength to be 11, got %d", IDLength)
|
||||
}
|
||||
|
||||
// Generate multiple and check they all have correct length
|
||||
for i := 0; i < 100; i++ {
|
||||
uuid := UUIDGenerator()
|
||||
if len(uuid) != 11 {
|
||||
t.Errorf("Expected UUID length 11, got %d for UUID: %s", len(uuid), uuid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUUIDGeneratorCharsetContents(t *testing.T) {
|
||||
expected := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
if IDCharset != expected {
|
||||
t.Errorf("IDCharset changed. Expected: %s, Got: %s", expected, IDCharset)
|
||||
}
|
||||
|
||||
if len(IDCharset) != 62 {
|
||||
t.Errorf("Expected IDCharset length 62, got %d", len(IDCharset))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUUIDGeneratorConcurrency(t *testing.T) {
|
||||
// Test concurrent UUID generation
|
||||
count := 1000
|
||||
uuids := make(chan string, count)
|
||||
|
||||
// Generate UUIDs concurrently
|
||||
for i := 0; i < count; i++ {
|
||||
go func() {
|
||||
uuids <- UUIDGenerator()
|
||||
}()
|
||||
}
|
||||
|
||||
// Collect results
|
||||
results := make(map[string]bool)
|
||||
for i := 0; i < count; i++ {
|
||||
uuid := <-uuids
|
||||
if results[uuid] {
|
||||
t.Errorf("Duplicate UUID in concurrent generation: %s", uuid)
|
||||
}
|
||||
results[uuid] = true
|
||||
}
|
||||
|
||||
if len(results) != count {
|
||||
t.Errorf("Expected %d unique UUIDs, got %d", count, len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUUIDGeneratorNoSpecialCharacters(t *testing.T) {
|
||||
for i := 0; i < 100; i++ {
|
||||
uuid := UUIDGenerator()
|
||||
|
||||
// Check for common special characters that shouldn't be there
|
||||
specialChars := []string{"-", "_", ".", " ", "!", "@", "#", "$", "%", "^", "&", "*", "(", ")", "+", "="}
|
||||
for _, special := range specialChars {
|
||||
if strings.Contains(uuid, special) {
|
||||
t.Errorf("UUID contains special character %s: %s", special, uuid)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUUIDGeneratorDistribution(t *testing.T) {
|
||||
// Generate many UUIDs and check character distribution is reasonable
|
||||
iterations := 10000
|
||||
positionCounts := make([]map[rune]int, IDLength)
|
||||
for i := range positionCounts {
|
||||
positionCounts[i] = make(map[rune]int)
|
||||
}
|
||||
|
||||
for i := 0; i < iterations; i++ {
|
||||
uuid := UUIDGenerator()
|
||||
for pos, char := range uuid {
|
||||
positionCounts[pos][char]++
|
||||
}
|
||||
}
|
||||
|
||||
// Each position should have multiple different characters
|
||||
for pos, counts := range positionCounts {
|
||||
if len(counts) < 10 {
|
||||
t.Errorf("Position %d has poor character variety: only %d different characters", pos, len(counts))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUUIDGenerator(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
UUIDGenerator()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUUIDGeneratorParallel(b *testing.B) {
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
UUIDGenerator()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUUIDGeneratorFormat(t *testing.T) {
|
||||
uuid := UUIDGenerator()
|
||||
|
||||
// Should not start or end with special characters
|
||||
firstChar := uuid[0]
|
||||
lastChar := uuid[len(uuid)-1]
|
||||
|
||||
if !strings.ContainsRune(IDCharset, rune(firstChar)) {
|
||||
t.Errorf("First character %c not in charset", firstChar)
|
||||
}
|
||||
|
||||
if !strings.ContainsRune(IDCharset, rune(lastChar)) {
|
||||
t.Errorf("Last character %c not in charset", lastChar)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user