220 lines
4.9 KiB
Go
220 lines
4.9 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"tss-rocks-be/internal/config"
|
|
"tss-rocks-be/internal/types"
|
|
"tss-rocks-be/ent/enttest"
|
|
)
|
|
|
|
func TestNew(t *testing.T) {
|
|
// 创建测试配置
|
|
cfg := &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
Port: 8080,
|
|
},
|
|
Storage: config.StorageConfig{
|
|
Type: "local",
|
|
Local: config.LocalStorage{
|
|
RootDir: "testdata",
|
|
},
|
|
Upload: types.UploadConfig{
|
|
MaxSize: 10,
|
|
AllowedTypes: []string{"image/jpeg", "image/png"},
|
|
AllowedExtensions: []string{".jpg", ".png"},
|
|
},
|
|
},
|
|
RateLimit: types.RateLimitConfig{
|
|
IPRate: 100,
|
|
IPBurst: 200,
|
|
RouteRates: map[string]struct {
|
|
Rate int `yaml:"rate"`
|
|
Burst int `yaml:"burst"`
|
|
}{
|
|
"/api/v1/upload": {Rate: 10, Burst: 20},
|
|
},
|
|
},
|
|
AccessLog: types.AccessLogConfig{
|
|
EnableConsole: true,
|
|
EnableFile: true,
|
|
FilePath: "testdata/access.log",
|
|
Format: "json",
|
|
Level: "info",
|
|
Rotation: struct {
|
|
MaxSize int `yaml:"max_size"`
|
|
MaxAge int `yaml:"max_age"`
|
|
MaxBackups int `yaml:"max_backups"`
|
|
Compress bool `yaml:"compress"`
|
|
LocalTime bool `yaml:"local_time"`
|
|
}{
|
|
MaxSize: 100,
|
|
MaxAge: 7,
|
|
MaxBackups: 3,
|
|
Compress: true,
|
|
LocalTime: true,
|
|
},
|
|
},
|
|
}
|
|
|
|
// 创建测试数据库客户端
|
|
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
|
defer client.Close()
|
|
|
|
// 测试服务器初始化
|
|
s, err := New(cfg, client)
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, s)
|
|
assert.NotNil(t, s.router)
|
|
assert.NotNil(t, s.handler)
|
|
assert.Equal(t, cfg, s.config)
|
|
}
|
|
|
|
func TestNew_StorageError(t *testing.T) {
|
|
// 创建一个无效的存储配置
|
|
cfg := &config.Config{
|
|
Storage: config.StorageConfig{
|
|
Type: "invalid_type", // 使用无效的存储类型
|
|
},
|
|
}
|
|
|
|
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
|
defer client.Close()
|
|
|
|
s, err := New(cfg, client)
|
|
assert.Error(t, err)
|
|
assert.Nil(t, s)
|
|
assert.Contains(t, err.Error(), "failed to initialize storage")
|
|
}
|
|
|
|
func TestServer_StartAndShutdown(t *testing.T) {
|
|
// 创建测试配置
|
|
cfg := &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
Port: 0, // 使用随机端口
|
|
},
|
|
Storage: config.StorageConfig{
|
|
Type: "local",
|
|
Local: config.LocalStorage{
|
|
RootDir: "testdata",
|
|
},
|
|
},
|
|
RateLimit: types.RateLimitConfig{
|
|
IPRate: 100,
|
|
IPBurst: 200,
|
|
},
|
|
AccessLog: types.AccessLogConfig{
|
|
EnableConsole: true,
|
|
Format: "json",
|
|
Level: "info",
|
|
},
|
|
}
|
|
|
|
// 创建测试数据库客户端
|
|
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
|
defer client.Close()
|
|
|
|
// 初始化服务器
|
|
s, err := New(cfg, client)
|
|
require.NoError(t, err)
|
|
|
|
// 创建一个通道来接收服务器错误
|
|
errChan := make(chan error, 1)
|
|
|
|
// 在 goroutine 中启动服务器
|
|
go func() {
|
|
err := s.Start()
|
|
if err != nil && err != http.ErrServerClosed {
|
|
errChan <- err
|
|
}
|
|
close(errChan)
|
|
}()
|
|
|
|
// 给服务器一些时间启动
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
// 测试关闭服务器
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
err = s.Shutdown(ctx)
|
|
assert.NoError(t, err)
|
|
|
|
// 检查服务器是否有错误发生
|
|
err = <-errChan
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
func TestServer_StartError(t *testing.T) {
|
|
// 创建一个配置,使用已经被占用的端口来触发错误
|
|
cfg := &config.Config{
|
|
Server: config.ServerConfig{
|
|
Host: "localhost",
|
|
Port: 8899, // 使用固定端口以便测试
|
|
},
|
|
Storage: config.StorageConfig{
|
|
Type: "local",
|
|
Local: config.LocalStorage{
|
|
RootDir: "testdata",
|
|
},
|
|
},
|
|
}
|
|
|
|
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
|
defer client.Close()
|
|
|
|
// 创建第一个服务器实例
|
|
s1, err := New(cfg, client)
|
|
require.NoError(t, err)
|
|
|
|
// 创建一个通道来接收服务器错误
|
|
errChan := make(chan error, 1)
|
|
|
|
// 启动第一个服务器
|
|
go func() {
|
|
err := s1.Start()
|
|
if err != nil && err != http.ErrServerClosed {
|
|
errChan <- err
|
|
}
|
|
close(errChan)
|
|
}()
|
|
|
|
// 给服务器一些时间启动
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
// 尝试在同一端口启动第二个服务器,应该会失败
|
|
s2, err := New(cfg, client)
|
|
require.NoError(t, err)
|
|
|
|
err = s2.Start()
|
|
assert.Error(t, err)
|
|
|
|
// 清理
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
// 关闭第一个服务器
|
|
err = s1.Shutdown(ctx)
|
|
assert.NoError(t, err)
|
|
|
|
// 检查第一个服务器是否有错误发生
|
|
err = <-errChan
|
|
assert.NoError(t, err)
|
|
|
|
// 关闭第二个服务器
|
|
err = s2.Shutdown(ctx)
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
func TestServer_ShutdownWithNilServer(t *testing.T) {
|
|
s := &Server{}
|
|
err := s.Shutdown(context.Background())
|
|
assert.NoError(t, err)
|
|
}
|