477 lines
12 KiB
Go
477 lines
12 KiB
Go
package db
|
|
|
|
import (
|
|
"os"
|
|
"testing"
|
|
|
|
"github.com/DATA-DOG/go-sqlmock"
|
|
)
|
|
|
|
// TestInitDB_Success checks the DSN format that InitDB is expected to build
// from the DB_USER, DB_PASSWORD, DB_HOST, DB_PORT and DB_NAME variables.
//
// NOTE(review): this does not call InitDB; it re-derives the DSN locally, so
// it guards the documented format rather than db.go itself. Exercising InitDB
// would need a real database connection or more sophisticated mocking.
func TestInitDB_Success(t *testing.T) {
	// Set each variable and restore its previous state on exit instead of
	// unconditionally unsetting it, so DB_* values exported by the caller
	// survive this test. The defers intentionally run at function exit.
	for key, val := range map[string]string{
		"DB_USER":     "testuser",
		"DB_PASSWORD": "testpass",
		"DB_HOST":     "localhost",
		"DB_PORT":     "3306",
		"DB_NAME":     "testdb",
	} {
		prev, had := os.LookupEnv(key)
		if err := os.Setenv(key, val); err != nil {
			t.Fatalf("setting %s: %v", key, err)
		}
		if had {
			defer os.Setenv(key, prev)
		} else {
			defer os.Unsetenv(key)
		}
	}

	connStr := os.Getenv("DB_USER") + ":" + os.Getenv("DB_PASSWORD") +
		"@tcp(" + os.Getenv("DB_HOST") + ":" + os.Getenv("DB_PORT") + ")/" +
		os.Getenv("DB_NAME") + "?parseTime=true"

	expectedConnStr := "testuser:testpass@tcp(localhost:3306)/testdb?parseTime=true"
	if connStr != expectedConnStr {
		t.Errorf("Expected connection string '%s', got '%s'", expectedConnStr, connStr)
	}
}
|
|
|
|
// TestInitDB_ConnectionParameters records the pool sizes InitDB is expected to
// configure: 25 max open and 10 max idle connections.
//
// NOTE(review): both checks compare a local literal with itself and can never
// fail; they only document the expected values. Asserting the real settings
// would need db.go to expose them.
func TestInitDB_ConnectionParameters(t *testing.T) {
	wantOpen, wantIdle := 25, 10

	// These values should match what's in db.go
	if wantOpen != 25 {
		t.Error("Expected MaxOpenConns 25, configuration might have changed")
	}
	if wantIdle != 10 {
		t.Error("Expected MaxIdleConns 10, configuration might have changed")
	}
}
|
|
|
|
// TestInitDB_EnvironmentVariables checks, for several configurations, that the
// DSN derived from the DB_* environment variables has the expected
// user:pass@tcp(host:port)/name?parseTime=true shape.
func TestInitDB_EnvironmentVariables(t *testing.T) {
	tests := []struct {
		name     string
		dbUser   string
		dbPass   string
		dbHost   string
		dbPort   string
		dbName   string
		expected string
	}{
		{
			name:     "Standard configuration",
			dbUser:   "user",
			dbPass:   "pass",
			dbHost:   "localhost",
			dbPort:   "3306",
			dbName:   "mydb",
			expected: "user:pass@tcp(localhost:3306)/mydb?parseTime=true",
		},
		{
			name:     "Remote database",
			dbUser:   "admin",
			dbPass:   "secret",
			dbHost:   "db.example.com",
			dbPort:   "3307",
			dbName:   "production",
			expected: "admin:secret@tcp(db.example.com:3307)/production?parseTime=true",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Cleanup is deferred so it also runs if the subtest stops early,
			// and it restores prior values rather than unsetting them so DB_*
			// variables exported by the caller survive.
			for key, val := range map[string]string{
				"DB_USER":     tt.dbUser,
				"DB_PASSWORD": tt.dbPass,
				"DB_HOST":     tt.dbHost,
				"DB_PORT":     tt.dbPort,
				"DB_NAME":     tt.dbName,
			} {
				prev, had := os.LookupEnv(key)
				if err := os.Setenv(key, val); err != nil {
					t.Fatalf("setting %s: %v", key, err)
				}
				if had {
					defer os.Setenv(key, prev)
				} else {
					defer os.Unsetenv(key)
				}
			}

			connStr := os.Getenv("DB_USER") + ":" + os.Getenv("DB_PASSWORD") +
				"@tcp(" + os.Getenv("DB_HOST") + ":" + os.Getenv("DB_PORT") + ")/" +
				os.Getenv("DB_NAME") + "?parseTime=true"

			if connStr != tt.expected {
				t.Errorf("Expected '%s', got '%s'", tt.expected, connStr)
			}
		})
	}
}
|
|
|
|
func TestDBConnection_MockPing(t *testing.T) {
|
|
// Create a mock database
|
|
mockDB, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
|
|
if err != nil {
|
|
t.Fatalf("Failed to create mock database: %v", err)
|
|
}
|
|
defer mockDB.Close()
|
|
|
|
// Store original DB and replace with mock
|
|
originalDB := DB
|
|
DB = mockDB
|
|
defer func() { DB = originalDB }()
|
|
|
|
// Expect a ping
|
|
mock.ExpectPing()
|
|
|
|
// Test ping
|
|
err = DB.Ping()
|
|
if err != nil {
|
|
t.Errorf("Expected no error on ping, got %v", err)
|
|
}
|
|
|
|
// Verify expectations
|
|
if err := mock.ExpectationsWereMet(); err != nil {
|
|
t.Errorf("Unfulfilled expectations: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestDBConnection_PingFailure(t *testing.T) {
|
|
// Create a mock database
|
|
mockDB, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
|
|
if err != nil {
|
|
t.Fatalf("Failed to create mock database: %v", err)
|
|
}
|
|
defer mockDB.Close()
|
|
|
|
// Store original DB and replace with mock
|
|
originalDB := DB
|
|
DB = mockDB
|
|
defer func() { DB = originalDB }()
|
|
|
|
// Expect a ping to fail
|
|
mock.ExpectPing().WillReturnError(sqlmock.ErrCancelled)
|
|
|
|
// Test ping
|
|
err = DB.Ping()
|
|
if err == nil {
|
|
t.Error("Expected error on failed ping")
|
|
}
|
|
}
|
|
|
|
func TestDBGlobalVariable(t *testing.T) {
|
|
// Test that the global DB variable exists and can be set
|
|
originalDB := DB
|
|
defer func() { DB = originalDB }()
|
|
|
|
// Create a mock database
|
|
mockDB, _, err := sqlmock.New()
|
|
if err != nil {
|
|
t.Fatalf("Failed to create mock database: %v", err)
|
|
}
|
|
defer mockDB.Close()
|
|
|
|
// Set the global variable
|
|
DB = mockDB
|
|
|
|
if DB != mockDB {
|
|
t.Error("Failed to set global DB variable")
|
|
}
|
|
}
|
|
|
|
// TestConnectionString_ParseTime verifies that the DSN built from the DB_*
// environment variables ends with the "?parseTime=true" option.
func TestConnectionString_ParseTime(t *testing.T) {
	// Restore each variable's previous state on exit instead of blindly
	// unsetting it, so DB_* values exported by the caller survive this test.
	// The defers intentionally run at function exit.
	for key, val := range map[string]string{
		"DB_USER":     "user",
		"DB_PASSWORD": "pass",
		"DB_HOST":     "localhost",
		"DB_PORT":     "3306",
		"DB_NAME":     "db",
	} {
		prev, had := os.LookupEnv(key)
		if err := os.Setenv(key, val); err != nil {
			t.Fatalf("setting %s: %v", key, err)
		}
		if had {
			defer os.Setenv(key, prev)
		} else {
			defer os.Unsetenv(key)
		}
	}

	connStr := os.Getenv("DB_USER") + ":" + os.Getenv("DB_PASSWORD") +
		"@tcp(" + os.Getenv("DB_HOST") + ":" + os.Getenv("DB_PORT") + ")/" +
		os.Getenv("DB_NAME") + "?parseTime=true"

	// Length-guarded suffix check: slicing by a hard-coded offset would panic
	// on a shorter string instead of reporting a test failure.
	const suffix = "?parseTime=true"
	if len(connStr) < len(suffix) || connStr[len(connStr)-len(suffix):] != suffix {
		t.Error("Connection string should end with '?parseTime=true'")
	}
}
|
|
|
|
// Additional comprehensive test cases
|
|
|
|
// TestConnectionString_SpecialCharacters checks that usernames and passwords
// containing characters such as '@', '!', '#', '$' and '_' are copied into the
// DSN verbatim. The DSN is built by plain concatenation, so no escaping is
// applied.
func TestConnectionString_SpecialCharacters(t *testing.T) {
	testCases := []struct {
		name     string
		user     string
		pass     string
		expected string
	}{
		{
			name:     "Password with special chars",
			user:     "user",
			pass:     "p@ss!word",
			expected: "user:p@ss!word@tcp(localhost:3306)/testdb?parseTime=true",
		},
		{
			name:     "Username with underscore",
			user:     "test_user",
			pass:     "password",
			expected: "test_user:password@tcp(localhost:3306)/testdb?parseTime=true",
		},
		{
			name:     "Complex password",
			user:     "admin",
			pass:     "P@ssw0rd!#$",
			expected: "admin:P@ssw0rd!#$@tcp(localhost:3306)/testdb?parseTime=true",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Restore prior values on exit rather than unsetting them, so DB_*
			// variables exported by the caller survive.
			for key, val := range map[string]string{
				"DB_USER":     tc.user,
				"DB_PASSWORD": tc.pass,
				"DB_HOST":     "localhost",
				"DB_PORT":     "3306",
				"DB_NAME":     "testdb",
			} {
				prev, had := os.LookupEnv(key)
				if err := os.Setenv(key, val); err != nil {
					t.Fatalf("setting %s: %v", key, err)
				}
				if had {
					defer os.Setenv(key, prev)
				} else {
					defer os.Unsetenv(key)
				}
			}

			connStr := os.Getenv("DB_USER") + ":" + os.Getenv("DB_PASSWORD") +
				"@tcp(" + os.Getenv("DB_HOST") + ":" + os.Getenv("DB_PORT") + ")/" +
				os.Getenv("DB_NAME") + "?parseTime=true"

			if connStr != tc.expected {
				t.Errorf("Expected %q, got %q", tc.expected, connStr)
			}
		})
	}
}
|
|
|
|
// TestConnectionString_EmptyValues checks that the DSN is still formed when one
// of the DB_* variables is empty; that segment is simply missing. Each case
// pins the exact result. The fixed separators mean the string is never empty,
// so a non-emptiness check could not fail.
func TestConnectionString_EmptyValues(t *testing.T) {
	testCases := []struct {
		name     string
		vars     map[string]string
		expected string
	}{
		{
			"Empty user",
			map[string]string{
				"DB_USER":     "",
				"DB_PASSWORD": "pass",
				"DB_HOST":     "localhost",
				"DB_PORT":     "3306",
				"DB_NAME":     "testdb",
			},
			":pass@tcp(localhost:3306)/testdb?parseTime=true",
		},
		{
			"Empty password",
			map[string]string{
				"DB_USER":     "user",
				"DB_PASSWORD": "",
				"DB_HOST":     "localhost",
				"DB_PORT":     "3306",
				"DB_NAME":     "testdb",
			},
			"user:@tcp(localhost:3306)/testdb?parseTime=true",
		},
		{
			"Empty database name",
			map[string]string{
				"DB_USER":     "user",
				"DB_PASSWORD": "pass",
				"DB_HOST":     "localhost",
				"DB_PORT":     "3306",
				"DB_NAME":     "",
			},
			"user:pass@tcp(localhost:3306)/?parseTime=true",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Restore prior values on exit rather than unsetting them, so DB_*
			// variables exported by the caller survive.
			for key, val := range tc.vars {
				prev, had := os.LookupEnv(key)
				if err := os.Setenv(key, val); err != nil {
					t.Fatalf("setting %s: %v", key, err)
				}
				if had {
					defer os.Setenv(key, prev)
				} else {
					defer os.Unsetenv(key)
				}
			}

			connStr := os.Getenv("DB_USER") + ":" + os.Getenv("DB_PASSWORD") +
				"@tcp(" + os.Getenv("DB_HOST") + ":" + os.Getenv("DB_PORT") + ")/" +
				os.Getenv("DB_NAME") + "?parseTime=true"

			// Connection string should still be formed, even if invalid.
			if connStr != tc.expected {
				t.Errorf("Expected %q, got %q", tc.expected, connStr)
			}
		})
	}
}
|
|
|
|
// TestConnectionString_DifferentPorts checks that the DB_PORT value is placed
// verbatim after the host in the DSN, for several port numbers.
func TestConnectionString_DifferentPorts(t *testing.T) {
	ports := []string{"3306", "3307", "13306", "33060"}

	for _, port := range ports {
		t.Run("Port: "+port, func(t *testing.T) {
			// Restore prior values on exit rather than unsetting them, so DB_*
			// variables exported by the caller survive.
			for key, val := range map[string]string{
				"DB_USER":     "user",
				"DB_PASSWORD": "pass",
				"DB_HOST":     "localhost",
				"DB_PORT":     port,
				"DB_NAME":     "testdb",
			} {
				prev, had := os.LookupEnv(key)
				if err := os.Setenv(key, val); err != nil {
					t.Fatalf("setting %s: %v", key, err)
				}
				if had {
					defer os.Setenv(key, prev)
				} else {
					defer os.Unsetenv(key)
				}
			}

			connStr := os.Getenv("DB_USER") + ":" + os.Getenv("DB_PASSWORD") +
				"@tcp(" + os.Getenv("DB_HOST") + ":" + os.Getenv("DB_PORT") + ")/" +
				os.Getenv("DB_NAME") + "?parseTime=true"

			expected := "user:pass@tcp(localhost:" + port + ")/testdb?parseTime=true"
			if connStr != expected {
				t.Errorf("Expected %q, got %q", expected, connStr)
			}
		})
	}
}
|
|
|
|
// TestConnectionString_DifferentHosts checks that the DB_HOST value (hostname
// or IPv4 literal) is placed verbatim inside tcp(...) in the DSN.
func TestConnectionString_DifferentHosts(t *testing.T) {
	hosts := []string{
		"localhost",
		"127.0.0.1",
		"db.example.com",
		"192.168.1.100",
		"mysql-server.local",
	}

	for _, host := range hosts {
		t.Run("Host: "+host, func(t *testing.T) {
			// Restore prior values on exit rather than unsetting them, so DB_*
			// variables exported by the caller survive.
			for key, val := range map[string]string{
				"DB_USER":     "user",
				"DB_PASSWORD": "pass",
				"DB_HOST":     host,
				"DB_PORT":     "3306",
				"DB_NAME":     "testdb",
			} {
				prev, had := os.LookupEnv(key)
				if err := os.Setenv(key, val); err != nil {
					t.Fatalf("setting %s: %v", key, err)
				}
				if had {
					defer os.Setenv(key, prev)
				} else {
					defer os.Unsetenv(key)
				}
			}

			connStr := os.Getenv("DB_USER") + ":" + os.Getenv("DB_PASSWORD") +
				"@tcp(" + os.Getenv("DB_HOST") + ":" + os.Getenv("DB_PORT") + ")/" +
				os.Getenv("DB_NAME") + "?parseTime=true"

			expected := "user:pass@tcp(" + host + ":3306)/testdb?parseTime=true"
			if connStr != expected {
				t.Errorf("Expected %q, got %q", expected, connStr)
			}
		})
	}
}
|
|
|
|
func TestMockDB_BasicOperations(t *testing.T) {
|
|
mockDB, mock, err := sqlmock.New()
|
|
if err != nil {
|
|
t.Fatalf("Failed to create mock DB: %v", err)
|
|
}
|
|
defer mockDB.Close()
|
|
|
|
// Test ping
|
|
mock.ExpectPing()
|
|
if err := mockDB.Ping(); err != nil {
|
|
t.Errorf("Ping failed: %v", err)
|
|
}
|
|
|
|
// Verify expectations
|
|
if err := mock.ExpectationsWereMet(); err != nil {
|
|
t.Errorf("Expectations not met: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestMockDB_QueryExecution(t *testing.T) {
|
|
mockDB, mock, err := sqlmock.New()
|
|
if err != nil {
|
|
t.Fatalf("Failed to create mock DB: %v", err)
|
|
}
|
|
defer mockDB.Close()
|
|
|
|
rows := sqlmock.NewRows([]string{"id", "name"}).
|
|
AddRow(1, "test")
|
|
|
|
mock.ExpectQuery("SELECT id, name FROM test_table").
|
|
WillReturnRows(rows)
|
|
|
|
rows2, err := mockDB.Query("SELECT id, name FROM test_table")
|
|
if err != nil {
|
|
t.Errorf("Query failed: %v", err)
|
|
}
|
|
defer rows2.Close()
|
|
|
|
if !rows2.Next() {
|
|
t.Error("Expected at least one row")
|
|
}
|
|
|
|
if err := mock.ExpectationsWereMet(); err != nil {
|
|
t.Errorf("Expectations not met: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestConnectionString_VeryLongValues checks that a 1000-character DB_USER is
// carried into the DSN unchanged.
func TestConnectionString_VeryLongValues(t *testing.T) {
	// Fill a byte slice and convert once. Splicing one character into a string
	// per iteration re-copies the whole string every step.
	buf := make([]byte, 1000)
	for i := range buf {
		buf[i] = 'a'
	}
	longString := string(buf)

	// Restore each variable's previous state on exit instead of blindly
	// unsetting it, so DB_* values exported by the caller survive this test.
	// The defers intentionally run at function exit.
	for key, val := range map[string]string{
		"DB_USER":     longString,
		"DB_PASSWORD": "pass",
		"DB_HOST":     "localhost",
		"DB_PORT":     "3306",
		"DB_NAME":     "testdb",
	} {
		prev, had := os.LookupEnv(key)
		if err := os.Setenv(key, val); err != nil {
			t.Fatalf("setting %s: %v", key, err)
		}
		if had {
			defer os.Setenv(key, prev)
		} else {
			defer os.Unsetenv(key)
		}
	}

	connStr := os.Getenv("DB_USER") + ":" + os.Getenv("DB_PASSWORD") +
		"@tcp(" + os.Getenv("DB_HOST") + ":" + os.Getenv("DB_PORT") + ")/" +
		os.Getenv("DB_NAME") + "?parseTime=true"

	// Compare the full DSN: a length check alone would accept a truncated or
	// corrupted username.
	expected := longString + ":pass@tcp(localhost:3306)/testdb?parseTime=true"
	if connStr != expected {
		t.Errorf("Connection string should include long username (got len %d, want len %d)",
			len(connStr), len(expected))
	}
}
|
|
|
|
// TestConnectionPoolSettings lists the expected pool sizes (MaxOpenConns 25,
// MaxIdleConns 10) and checks in a subtest per setting that each is positive.
//
// NOTE(review): the values are local literals, not read from db.go, so this
// cannot detect drift in the real configuration.
func TestConnectionPoolSettings(t *testing.T) {
	settings := []struct {
		name  string
		value int
	}{
		{"MaxOpenConns", 25},
		{"MaxIdleConns", 10},
	}

	for _, s := range settings {
		t.Run(s.name, func(t *testing.T) {
			// This is a documentation test to ensure we're aware of pool settings
			if s.value > 0 {
				return
			}
			t.Errorf("%s should be positive, got %d", s.name, s.value)
		})
	}
}
|