238 lines
5.9 KiB
Go
238 lines
5.9 KiB
Go
package middleware
|
||
|
||
import (
|
||
"bytes"
|
||
"io"
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"os"
|
||
"path/filepath"
|
||
"testing"
|
||
"time"
|
||
"tss-rocks-be/internal/types"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/stretchr/testify/assert"
|
||
)
|
||
|
||
func TestAccessLog(t *testing.T) {
|
||
// 设置测试临时目录
|
||
tmpDir := t.TempDir()
|
||
logPath := filepath.Join(tmpDir, "test.log")
|
||
|
||
testCases := []struct {
|
||
name string
|
||
config *types.AccessLogConfig
|
||
expectedError bool
|
||
setupRequest func(*http.Request)
|
||
validateOutput func(*testing.T, *httptest.ResponseRecorder, string)
|
||
}{
|
||
{
|
||
name: "Console logging only",
|
||
config: &types.AccessLogConfig{
|
||
EnableConsole: true,
|
||
EnableFile: false,
|
||
Format: "json",
|
||
Level: "info",
|
||
},
|
||
expectedError: false,
|
||
setupRequest: func(req *http.Request) {
|
||
req.Header.Set("User-Agent", "test-agent")
|
||
},
|
||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||
assert.Equal(t, http.StatusOK, w.Code)
|
||
assert.Contains(t, logOutput, "GET /test")
|
||
assert.Contains(t, logOutput, "test-agent")
|
||
},
|
||
},
|
||
{
|
||
name: "File logging only",
|
||
config: &types.AccessLogConfig{
|
||
EnableConsole: false,
|
||
EnableFile: true,
|
||
FilePath: logPath,
|
||
Format: "json",
|
||
Level: "info",
|
||
Rotation: struct {
|
||
MaxSize int `yaml:"max_size"`
|
||
MaxAge int `yaml:"max_age"`
|
||
MaxBackups int `yaml:"max_backups"`
|
||
Compress bool `yaml:"compress"`
|
||
LocalTime bool `yaml:"local_time"`
|
||
}{
|
||
MaxSize: 1,
|
||
MaxAge: 1,
|
||
MaxBackups: 1,
|
||
Compress: false,
|
||
LocalTime: true,
|
||
},
|
||
},
|
||
expectedError: false,
|
||
setupRequest: func(req *http.Request) {
|
||
req.Header.Set("User-Agent", "test-agent")
|
||
},
|
||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||
assert.Equal(t, http.StatusOK, w.Code)
|
||
|
||
// 读取日志文件内容
|
||
content, err := os.ReadFile(logPath)
|
||
assert.NoError(t, err)
|
||
assert.Contains(t, string(content), "GET /test")
|
||
assert.Contains(t, string(content), "test-agent")
|
||
},
|
||
},
|
||
{
|
||
name: "Both console and file logging",
|
||
config: &types.AccessLogConfig{
|
||
EnableConsole: true,
|
||
EnableFile: true,
|
||
FilePath: logPath,
|
||
Format: "json",
|
||
Level: "info",
|
||
Rotation: struct {
|
||
MaxSize int `yaml:"max_size"`
|
||
MaxAge int `yaml:"max_age"`
|
||
MaxBackups int `yaml:"max_backups"`
|
||
Compress bool `yaml:"compress"`
|
||
LocalTime bool `yaml:"local_time"`
|
||
}{
|
||
MaxSize: 1,
|
||
MaxAge: 1,
|
||
MaxBackups: 1,
|
||
Compress: false,
|
||
LocalTime: true,
|
||
},
|
||
},
|
||
expectedError: false,
|
||
setupRequest: func(req *http.Request) {
|
||
req.Header.Set("User-Agent", "test-agent")
|
||
},
|
||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||
assert.Equal(t, http.StatusOK, w.Code)
|
||
assert.Contains(t, logOutput, "GET /test")
|
||
assert.Contains(t, logOutput, "test-agent")
|
||
|
||
// 读取日志文件内容
|
||
content, err := os.ReadFile(logPath)
|
||
assert.NoError(t, err)
|
||
assert.Contains(t, string(content), "GET /test")
|
||
assert.Contains(t, string(content), "test-agent")
|
||
},
|
||
},
|
||
{
|
||
name: "With authenticated user",
|
||
config: &types.AccessLogConfig{
|
||
EnableConsole: true,
|
||
EnableFile: false,
|
||
Format: "json",
|
||
Level: "info",
|
||
},
|
||
expectedError: false,
|
||
setupRequest: func(req *http.Request) {
|
||
req.Header.Set("User-Agent", "test-agent")
|
||
},
|
||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||
assert.Equal(t, http.StatusOK, w.Code)
|
||
assert.Contains(t, logOutput, "GET /test")
|
||
assert.Contains(t, logOutput, "test-agent")
|
||
assert.Contains(t, logOutput, "test-user")
|
||
},
|
||
},
|
||
}
|
||
|
||
for _, tc := range testCases {
|
||
t.Run(tc.name, func(t *testing.T) {
|
||
// 捕获标准输出
|
||
oldStdout := os.Stdout
|
||
r, w, _ := os.Pipe()
|
||
os.Stdout = w
|
||
|
||
// 创建一个新的 gin 引擎
|
||
gin.SetMode(gin.TestMode)
|
||
router := gin.New()
|
||
|
||
// 创建访问日志中间件
|
||
middleware, err := AccessLog(tc.config)
|
||
if tc.expectedError {
|
||
assert.Error(t, err)
|
||
return
|
||
}
|
||
assert.NoError(t, err)
|
||
|
||
// 添加测试路由
|
||
router.Use(middleware)
|
||
router.GET("/test", func(c *gin.Context) {
|
||
// 如果是测试认证用户的情况,设置用户ID
|
||
if tc.name == "With authenticated user" {
|
||
c.Set("user_id", "test-user")
|
||
}
|
||
c.Status(http.StatusOK)
|
||
})
|
||
|
||
// 创建测试请求
|
||
req := httptest.NewRequest("GET", "/test", nil)
|
||
if tc.setupRequest != nil {
|
||
tc.setupRequest(req)
|
||
}
|
||
rec := httptest.NewRecorder()
|
||
|
||
// 执行请求
|
||
router.ServeHTTP(rec, req)
|
||
|
||
// 恢复标准输出并获取输出内容
|
||
w.Close()
|
||
var buf bytes.Buffer
|
||
io.Copy(&buf, r)
|
||
os.Stdout = oldStdout
|
||
|
||
// 验证输出
|
||
if tc.validateOutput != nil {
|
||
tc.validateOutput(t, rec, buf.String())
|
||
}
|
||
|
||
// 关闭日志文件
|
||
if tc.config.EnableFile {
|
||
// 调用中间件函数来关闭日志文件
|
||
middleware(nil)
|
||
// 等待一小段时间确保文件完全关闭
|
||
time.Sleep(100 * time.Millisecond)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestAccessLogInvalidConfig(t *testing.T) {
|
||
testCases := []struct {
|
||
name string
|
||
config *types.AccessLogConfig
|
||
expectedError bool
|
||
}{
|
||
{
|
||
name: "Invalid log level",
|
||
config: &types.AccessLogConfig{
|
||
EnableConsole: true,
|
||
Level: "invalid_level",
|
||
},
|
||
expectedError: false, // 应该使用默认的 info 级别
|
||
},
|
||
{
|
||
name: "Invalid file path",
|
||
config: &types.AccessLogConfig{
|
||
EnableFile: true,
|
||
FilePath: "\x00invalid\x00path", // 使用空字符的路径在所有操作系统上都是无效的
|
||
},
|
||
expectedError: true,
|
||
},
|
||
}
|
||
|
||
for _, tc := range testCases {
|
||
t.Run(tc.name, func(t *testing.T) {
|
||
_, err := AccessLog(tc.config)
|
||
if tc.expectedError {
|
||
assert.Error(t, err)
|
||
} else {
|
||
assert.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|