tss-rocks/backend/internal/middleware/accesslog_test.go
CDN 05ddc1f783
Some checks failed
Build Backend / Build Docker Image (push) Successful in 3m33s
Test Backend / test (push) Failing after 31s
[feature] migrate to monorepo
2025-02-21 00:49:20 +08:00

238 lines
5.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"tss-rocks-be/internal/types"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
// TestAccessLog exercises the AccessLog middleware against several
// configurations:
//   - console-only
//   - file-only
//   - combined console+file
//   - a request whose handler stores a user ID in the gin context
//
// Console output is captured by replacing os.Stdout with a pipe before
// AccessLog is constructed. Whatever writer the middleware takes from
// os.Stdout at construction time therefore writes into the pipe.
func TestAccessLog(t *testing.T) {
	// All file-logging cases share one log file inside a per-test temp dir.
	tmpDir := t.TempDir()
	logPath := filepath.Join(tmpDir, "test.log")

	// rotation is the rotation policy shared by the file-logging cases. Its
	// anonymous struct type must be identical to types.AccessLogConfig.Rotation
	// (field names, types, and tags) for the assignment below to compile.
	rotation := struct {
		MaxSize    int  `yaml:"max_size"`
		MaxAge     int  `yaml:"max_age"`
		MaxBackups int  `yaml:"max_backups"`
		Compress   bool `yaml:"compress"`
		LocalTime  bool `yaml:"local_time"`
	}{
		MaxSize:    1,
		MaxAge:     1,
		MaxBackups: 1,
		Compress:   false,
		LocalTime:  true,
	}

	testCases := []struct {
		name          string
		config        *types.AccessLogConfig
		expectedError bool
		// userID, when non-empty, is stored under "user_id" in the gin context
		// by the route handler before the response is written.
		userID         string
		setupRequest   func(*http.Request)
		validateOutput func(*testing.T, *httptest.ResponseRecorder, string)
	}{
		{
			name: "Console logging only",
			config: &types.AccessLogConfig{
				EnableConsole: true,
				EnableFile:    false,
				Format:        "json",
				Level:         "info",
			},
			expectedError: false,
			setupRequest: func(req *http.Request) {
				req.Header.Set("User-Agent", "test-agent")
			},
			validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
				assert.Equal(t, http.StatusOK, w.Code)
				assert.Contains(t, logOutput, "GET /test")
				assert.Contains(t, logOutput, "test-agent")
			},
		},
		{
			name: "File logging only",
			config: &types.AccessLogConfig{
				EnableConsole: false,
				EnableFile:    true,
				FilePath:      logPath,
				Format:        "json",
				Level:         "info",
				Rotation:      rotation,
			},
			expectedError: false,
			setupRequest: func(req *http.Request) {
				req.Header.Set("User-Agent", "test-agent")
			},
			validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
				assert.Equal(t, http.StatusOK, w.Code)
				// The access log line is expected in the log file, not on stdout.
				content, err := os.ReadFile(logPath)
				assert.NoError(t, err)
				assert.Contains(t, string(content), "GET /test")
				assert.Contains(t, string(content), "test-agent")
			},
		},
		{
			name: "Both console and file logging",
			config: &types.AccessLogConfig{
				EnableConsole: true,
				EnableFile:    true,
				FilePath:      logPath,
				Format:        "json",
				Level:         "info",
				Rotation:      rotation,
			},
			expectedError: false,
			setupRequest: func(req *http.Request) {
				req.Header.Set("User-Agent", "test-agent")
			},
			validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
				assert.Equal(t, http.StatusOK, w.Code)
				assert.Contains(t, logOutput, "GET /test")
				assert.Contains(t, logOutput, "test-agent")
				// The same line should also have reached the log file.
				content, err := os.ReadFile(logPath)
				assert.NoError(t, err)
				assert.Contains(t, string(content), "GET /test")
				assert.Contains(t, string(content), "test-agent")
			},
		},
		{
			name: "With authenticated user",
			config: &types.AccessLogConfig{
				EnableConsole: true,
				EnableFile:    false,
				Format:        "json",
				Level:         "info",
			},
			expectedError: false,
			userID:        "test-user",
			setupRequest: func(req *http.Request) {
				req.Header.Set("User-Agent", "test-agent")
			},
			validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
				assert.Equal(t, http.StatusOK, w.Code)
				assert.Contains(t, logOutput, "GET /test")
				assert.Contains(t, logOutput, "test-agent")
				assert.Contains(t, logOutput, "test-user")
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Redirect stdout into a pipe so console log lines can be inspected.
			// The deferred restore runs on every exit path, including the
			// early return for expected errors. A second Close on w is
			// harmless (it only returns os.ErrClosed).
			oldStdout := os.Stdout
			r, w, err := os.Pipe()
			if err != nil {
				t.Fatalf("os.Pipe: %v", err)
			}
			os.Stdout = w
			defer func() {
				os.Stdout = oldStdout
				w.Close()
				r.Close()
			}()

			gin.SetMode(gin.TestMode)
			router := gin.New()

			handler, err := AccessLog(tc.config)
			if tc.expectedError {
				assert.Error(t, err)
				return
			}
			// Without a middleware there is nothing meaningful left to test.
			if !assert.NoError(t, err) {
				return
			}

			router.Use(handler)
			router.GET("/test", func(c *gin.Context) {
				if tc.userID != "" {
					c.Set("user_id", tc.userID)
				}
				c.Status(http.StatusOK)
			})

			req := httptest.NewRequest(http.MethodGet, "/test", nil)
			if tc.setupRequest != nil {
				tc.setupRequest(req)
			}
			rec := httptest.NewRecorder()
			router.ServeHTTP(rec, req)

			// Stop capturing. Restore stdout and close the write end so the
			// read below terminates with EOF.
			os.Stdout = oldStdout
			w.Close()
			var buf bytes.Buffer
			if _, err := io.Copy(&buf, r); err != nil {
				t.Fatalf("reading captured stdout: %v", err)
			}

			if tc.validateOutput != nil {
				tc.validateOutput(t, rec, buf.String())
			}

			if tc.config.EnableFile {
				// Calling the handler with a nil context is how this test asks
				// the middleware to release its log file.
				// NOTE(review): relies on AccessLog's handler special-casing a
				// nil *gin.Context — confirm against the implementation.
				handler(nil)
				// Give the file a moment to be fully closed before the next
				// case (or TempDir cleanup) touches it.
				time.Sleep(100 * time.Millisecond)
			}
		})
	}
}
// TestAccessLogInvalidConfig checks how AccessLog reacts to bad configuration.
// An unknown level is expected to be tolerated, falling back to the default
// info level per the original intent. An unusable log file path is expected
// to make AccessLog return an error.
func TestAccessLogInvalidConfig(t *testing.T) {
	// Build a log path whose parent "directory" is actually a regular file.
	// Neither creating the directory nor opening the file can succeed on any
	// OS. A "/dev/null/..." path only fails on Unix-like systems; on Windows
	// it resolves to a creatable \dev\null\... directory.
	blocker := filepath.Join(t.TempDir(), "not-a-dir")
	if err := os.WriteFile(blocker, nil, 0o600); err != nil {
		t.Fatalf("creating blocker file: %v", err)
	}
	invalidPath := filepath.Join(blocker, "nonexistent", "test.log")

	testCases := []struct {
		name          string
		config        *types.AccessLogConfig
		expectedError bool
	}{
		{
			name: "Invalid log level",
			config: &types.AccessLogConfig{
				EnableConsole: true,
				Level:         "invalid_level",
			},
			expectedError: false, // expected to fall back to the default info level
		},
		{
			name: "Invalid file path",
			config: &types.AccessLogConfig{
				EnableFile: true,
				FilePath:   invalidPath,
			},
			expectedError: true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := AccessLog(tc.config)
			if tc.expectedError {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
		})
	}
}