package server import ( "context" "net/http" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tss-rocks-be/internal/config" "tss-rocks-be/internal/types" "tss-rocks-be/ent/enttest" ) func TestNew(t *testing.T) { // 创建测试配置 cfg := &config.Config{ Server: config.ServerConfig{ Host: "localhost", Port: 8080, }, Storage: config.StorageConfig{ Type: "local", Local: config.LocalStorage{ RootDir: "testdata", }, Upload: types.UploadConfig{ MaxSize: 10, AllowedTypes: []string{"image/jpeg", "image/png"}, AllowedExtensions: []string{".jpg", ".png"}, }, }, RateLimit: types.RateLimitConfig{ IPRate: 100, IPBurst: 200, RouteRates: map[string]struct { Rate int `yaml:"rate"` Burst int `yaml:"burst"` }{ "/api/v1/upload": {Rate: 10, Burst: 20}, }, }, AccessLog: types.AccessLogConfig{ EnableConsole: true, EnableFile: true, FilePath: "testdata/access.log", Format: "json", Level: "info", Rotation: struct { MaxSize int `yaml:"max_size"` MaxAge int `yaml:"max_age"` MaxBackups int `yaml:"max_backups"` Compress bool `yaml:"compress"` LocalTime bool `yaml:"local_time"` }{ MaxSize: 100, MaxAge: 7, MaxBackups: 3, Compress: true, LocalTime: true, }, }, } // 创建测试数据库客户端 client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") defer client.Close() // 测试服务器初始化 s, err := New(cfg, client) require.NoError(t, err) assert.NotNil(t, s) assert.NotNil(t, s.router) assert.NotNil(t, s.handler) assert.Equal(t, cfg, s.config) } func TestNew_StorageError(t *testing.T) { // 创建一个无效的存储配置 cfg := &config.Config{ Storage: config.StorageConfig{ Type: "invalid_type", // 使用无效的存储类型 }, } client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") defer client.Close() s, err := New(cfg, client) assert.Error(t, err) assert.Nil(t, s) assert.Contains(t, err.Error(), "failed to initialize storage") } func TestServer_StartAndShutdown(t *testing.T) { // 创建测试配置 cfg := &config.Config{ Server: config.ServerConfig{ Host: "localhost", Port: 0, // 使用随机端口 }, Storage: config.StorageConfig{ Type: "local", Local: config.LocalStorage{ RootDir: "testdata", }, }, RateLimit: types.RateLimitConfig{ IPRate: 100, IPBurst: 200, }, AccessLog: types.AccessLogConfig{ EnableConsole: true, Format: "json", Level: "info", }, } // 创建测试数据库客户端 client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") defer client.Close() // 初始化服务器 s, err := New(cfg, client) require.NoError(t, err) // 创建一个通道来接收服务器错误 errChan := make(chan error, 1) // 在 goroutine 中启动服务器 go func() { err := s.Start() if err != nil && err != http.ErrServerClosed { errChan <- err } close(errChan) }() // 给服务器一些时间启动 time.Sleep(100 * time.Millisecond) // 测试关闭服务器 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() err = s.Shutdown(ctx) assert.NoError(t, err) // 检查服务器是否有错误发生 err = <-errChan assert.NoError(t, err) } func TestServer_StartError(t *testing.T) { // 创建一个配置,使用已经被占用的端口来触发错误 cfg := &config.Config{ Server: config.ServerConfig{ Host: "localhost", Port: 8899, // 使用固定端口以便测试 }, Storage: config.StorageConfig{ Type: "local", Local: config.LocalStorage{ RootDir: "testdata", }, }, } client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") defer client.Close() // 创建第一个服务器实例 s1, err := New(cfg, client) require.NoError(t, err) // 创建一个通道来接收服务器错误 errChan := make(chan error, 1) // 启动第一个服务器 go func() { err := s1.Start() if err != nil && err != http.ErrServerClosed { errChan <- err } close(errChan) }() // 给服务器一些时间启动 time.Sleep(100 * time.Millisecond) // 尝试在同一端口启动第二个服务器,应该会失败 s2, err := New(cfg, client) require.NoError(t, err) err = s2.Start() assert.Error(t, err) // 清理 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // 关闭第一个服务器 err = s1.Shutdown(ctx) assert.NoError(t, err) // 检查第一个服务器是否有错误发生 err = <-errChan assert.NoError(t, err) // 关闭第二个服务器 err = s2.Shutdown(ctx) assert.NoError(t, err) } func TestServer_ShutdownWithNilServer(t *testing.T) { s := &Server{} err := s.Shutdown(context.Background()) assert.NoError(t, err) }