package middleware

import (
	"bytes"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"testing"
	"time"

	"tss-rocks-be/internal/types"

	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/assert"
)

// TestAccessLog drives a GET /test request through the AccessLog middleware
// under several console/file configurations and checks that the request line,
// the User-Agent and (when set by the handler) the user ID show up in the
// captured stdout and/or the log file.
func TestAccessLog(t *testing.T) {
	tmpDir := t.TempDir()

	// Every file-backed case writes to its own file. Sharing one path would
	// let a later case pass on content left behind by an earlier one.
	fileOnlyLog := filepath.Join(tmpDir, "file_only.log")
	bothLog := filepath.Join(tmpDir, "both.log")

	// newFileConfig builds a file-logging config with the rotation settings
	// shared by all file-backed cases.
	newFileConfig := func(path string, console bool) *types.AccessLogConfig {
		cfg := &types.AccessLogConfig{
			EnableConsole: console,
			EnableFile:    true,
			FilePath:      path,
			Format:        "json",
			Level:         "info",
		}
		cfg.Rotation.MaxSize = 1
		cfg.Rotation.MaxAge = 1
		cfg.Rotation.MaxBackups = 1
		cfg.Rotation.Compress = false
		cfg.Rotation.LocalTime = true
		return cfg
	}

	setUserAgent := func(req *http.Request) {
		req.Header.Set("User-Agent", "test-agent")
	}

	// assertFileLogged reads the log file and checks the request was recorded.
	assertFileLogged := func(t *testing.T, path string) {
		t.Helper()
		content, err := os.ReadFile(path)
		assert.NoError(t, err)
		assert.Contains(t, string(content), "GET /test")
		assert.Contains(t, string(content), "test-agent")
	}

	testCases := []struct {
		name          string
		config        *types.AccessLogConfig
		expectedError bool
		// userID, when non-empty, is stored under "user_id" in the gin
		// context by the test handler to simulate an authenticated request.
		userID         string
		setupRequest   func(*http.Request)
		validateOutput func(*testing.T, *httptest.ResponseRecorder, string)
	}{
		{
			name: "Console logging only",
			config: &types.AccessLogConfig{
				EnableConsole: true,
				EnableFile:    false,
				Format:        "json",
				Level:         "info",
			},
			setupRequest: setUserAgent,
			validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
				assert.Equal(t, http.StatusOK, w.Code)
				assert.Contains(t, logOutput, "GET /test")
				assert.Contains(t, logOutput, "test-agent")
			},
		},
		{
			name:         "File logging only",
			config:       newFileConfig(fileOnlyLog, false),
			setupRequest: setUserAgent,
			validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
				assert.Equal(t, http.StatusOK, w.Code)
				assertFileLogged(t, fileOnlyLog)
			},
		},
		{
			name:         "Both console and file logging",
			config:       newFileConfig(bothLog, true),
			setupRequest: setUserAgent,
			validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
				assert.Equal(t, http.StatusOK, w.Code)
				assert.Contains(t, logOutput, "GET /test")
				assert.Contains(t, logOutput, "test-agent")
				assertFileLogged(t, bothLog)
			},
		},
		{
			name: "With authenticated user",
			config: &types.AccessLogConfig{
				EnableConsole: true,
				EnableFile:    false,
				Format:        "json",
				Level:         "info",
			},
			userID:       "test-user",
			setupRequest: setUserAgent,
			validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
				assert.Equal(t, http.StatusOK, w.Code)
				assert.Contains(t, logOutput, "GET /test")
				assert.Contains(t, logOutput, "test-agent")
				assert.Contains(t, logOutput, "test-user")
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Redirect stdout into a pipe before the middleware is built so
			// that console output can be captured.
			origStdout := os.Stdout
			pr, pw, err := os.Pipe()
			if !assert.NoError(t, err) {
				return
			}
			os.Stdout = pw
			// Always restore stdout and release the pipe, including on the
			// early-return paths below and on panics.
			defer func() {
				os.Stdout = origStdout
				// The write end is normally closed already; a second Close
				// only returns an error, which is safe to ignore here.
				_ = pw.Close()
				_ = pr.Close()
			}()

			gin.SetMode(gin.TestMode)
			router := gin.New()

			handler, err := AccessLog(tc.config)
			if tc.expectedError {
				assert.Error(t, err)
				return
			}
			// Do not register a handler that failed to construct.
			if !assert.NoError(t, err) {
				return
			}

			router.Use(handler)
			router.GET("/test", func(c *gin.Context) {
				if tc.userID != "" {
					c.Set("user_id", tc.userID)
				}
				c.Status(http.StatusOK)
			})

			req := httptest.NewRequest("GET", "/test", nil)
			if tc.setupRequest != nil {
				tc.setupRequest(req)
			}
			rec := httptest.NewRecorder()

			router.ServeHTTP(rec, req)

			// Close the write end so the copy below sees EOF, then drain the
			// captured output and put the real stdout back.
			assert.NoError(t, pw.Close())
			var captured bytes.Buffer
			_, err = io.Copy(&captured, pr)
			assert.NoError(t, err)
			os.Stdout = origStdout

			if tc.validateOutput != nil {
				tc.validateOutput(t, rec, captured.String())
			}

			if tc.config.EnableFile {
				// Invoking the handler with a nil context is how this test
				// asks the middleware to close its log file.
				// NOTE(review): relies on AccessLog treating a nil context
				// as a close request; confirm against its implementation.
				handler(nil)
				// Brief pause so the file is fully closed before the
				// temporary directory is cleaned up.
				time.Sleep(100 * time.Millisecond)
			}
		})
	}
}

// TestAccessLogInvalidConfig checks how AccessLog reacts to questionable
// configuration: an unknown log level must not produce an error, while an
// unusable log file path must.
func TestAccessLogInvalidConfig(t *testing.T) {
	testCases := []struct {
		name          string
		config        *types.AccessLogConfig
		expectedError bool
	}{
		{
			name: "Invalid log level",
			config: &types.AccessLogConfig{
				EnableConsole: true,
				Level:         "invalid_level",
			},
			// Expected to fall back to the default info level.
			expectedError: false,
		},
		{
			name: "Invalid file path",
			config: &types.AccessLogConfig{
				EnableFile: true,
				// Intended to be an invalid path on every operating system.
				FilePath: "/dev/null/nonexistent/test.log",
			},
			expectedError: true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := AccessLog(tc.config)
			if tc.expectedError {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
		})
	}
}