[chore/backend] remove all test for now
This commit is contained in:
parent
3d19ef05b3
commit
1c9628124f
28 changed files with 0 additions and 6780 deletions
|
@ -1,238 +0,0 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
"tss-rocks-be/internal/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAccessLog(t *testing.T) {
|
||||
// 设置测试临时目录
|
||||
tmpDir := t.TempDir()
|
||||
logPath := filepath.Join(tmpDir, "test.log")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
config *types.AccessLogConfig
|
||||
expectedError bool
|
||||
setupRequest func(*http.Request)
|
||||
validateOutput func(*testing.T, *httptest.ResponseRecorder, string)
|
||||
}{
|
||||
{
|
||||
name: "Console logging only",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
EnableFile: false,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
},
|
||||
expectedError: false,
|
||||
setupRequest: func(req *http.Request) {
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
},
|
||||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, logOutput, "GET /test")
|
||||
assert.Contains(t, logOutput, "test-agent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "File logging only",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: false,
|
||||
EnableFile: true,
|
||||
FilePath: logPath,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
Rotation: struct {
|
||||
MaxSize int `yaml:"max_size"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
Compress bool `yaml:"compress"`
|
||||
LocalTime bool `yaml:"local_time"`
|
||||
}{
|
||||
MaxSize: 1,
|
||||
MaxAge: 1,
|
||||
MaxBackups: 1,
|
||||
Compress: false,
|
||||
LocalTime: true,
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
setupRequest: func(req *http.Request) {
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
},
|
||||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// 读取日志文件内容
|
||||
content, err := os.ReadFile(logPath)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(content), "GET /test")
|
||||
assert.Contains(t, string(content), "test-agent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Both console and file logging",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
EnableFile: true,
|
||||
FilePath: logPath,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
Rotation: struct {
|
||||
MaxSize int `yaml:"max_size"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
Compress bool `yaml:"compress"`
|
||||
LocalTime bool `yaml:"local_time"`
|
||||
}{
|
||||
MaxSize: 1,
|
||||
MaxAge: 1,
|
||||
MaxBackups: 1,
|
||||
Compress: false,
|
||||
LocalTime: true,
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
setupRequest: func(req *http.Request) {
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
},
|
||||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, logOutput, "GET /test")
|
||||
assert.Contains(t, logOutput, "test-agent")
|
||||
|
||||
// 读取日志文件内容
|
||||
content, err := os.ReadFile(logPath)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(content), "GET /test")
|
||||
assert.Contains(t, string(content), "test-agent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "With authenticated user",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
EnableFile: false,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
},
|
||||
expectedError: false,
|
||||
setupRequest: func(req *http.Request) {
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
},
|
||||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, logOutput, "GET /test")
|
||||
assert.Contains(t, logOutput, "test-agent")
|
||||
assert.Contains(t, logOutput, "test-user")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// 捕获标准输出
|
||||
oldStdout := os.Stdout
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stdout = w
|
||||
|
||||
// 创建一个新的 gin 引擎
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 创建访问日志中间件
|
||||
middleware, err := AccessLog(tc.config)
|
||||
if tc.expectedError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 添加测试路由
|
||||
router.Use(middleware)
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
// 如果是测试认证用户的情况,设置用户ID
|
||||
if tc.name == "With authenticated user" {
|
||||
c.Set("user_id", "test-user")
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// 创建测试请求
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
if tc.setupRequest != nil {
|
||||
tc.setupRequest(req)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// 恢复标准输出并获取输出内容
|
||||
w.Close()
|
||||
var buf bytes.Buffer
|
||||
io.Copy(&buf, r)
|
||||
os.Stdout = oldStdout
|
||||
|
||||
// 验证输出
|
||||
if tc.validateOutput != nil {
|
||||
tc.validateOutput(t, rec, buf.String())
|
||||
}
|
||||
|
||||
// 关闭日志文件
|
||||
if tc.config.EnableFile {
|
||||
// 调用中间件函数来关闭日志文件
|
||||
middleware(nil)
|
||||
// 等待一小段时间确保文件完全关闭
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogInvalidConfig(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
config *types.AccessLogConfig
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "Invalid log level",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
Level: "invalid_level",
|
||||
},
|
||||
expectedError: false, // 应该使用默认的 info 级别
|
||||
},
|
||||
{
|
||||
name: "Invalid file path",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableFile: true,
|
||||
FilePath: "\x00invalid\x00path", // 使用空字符的路径在所有操作系统上都是无效的
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := AccessLog(tc.config)
|
||||
if tc.expectedError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,227 +0,0 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tss-rocks-be/internal/service"
|
||||
)
|
||||
|
||||
func createTestToken(secret string, claims jwt.MapClaims) string {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signedToken, err := token.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to sign token: %v", err))
|
||||
}
|
||||
return signedToken
|
||||
}
|
||||
|
||||
func TestAuthMiddleware(t *testing.T) {
|
||||
jwtSecret := "test-secret"
|
||||
tokenBlacklist := service.NewTokenBlacklist()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupAuth func(*http.Request)
|
||||
expectedStatus int
|
||||
expectedBody map[string]string
|
||||
checkUserData bool
|
||||
expectedUserID string
|
||||
expectedRoles []string
|
||||
}{
|
||||
{
|
||||
name: "No Authorization header",
|
||||
setupAuth: func(req *http.Request) {},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "Authorization header is required"},
|
||||
},
|
||||
{
|
||||
name: "Invalid Authorization format",
|
||||
setupAuth: func(req *http.Request) {
|
||||
req.Header.Set("Authorization", "InvalidFormat")
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "Authorization header format must be Bearer {token}"},
|
||||
},
|
||||
{
|
||||
name: "Invalid token",
|
||||
setupAuth: func(req *http.Request) {
|
||||
req.Header.Set("Authorization", "Bearer invalid.token.here")
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "Invalid token"},
|
||||
},
|
||||
{
|
||||
name: "Valid token",
|
||||
setupAuth: func(req *http.Request) {
|
||||
claims := jwt.MapClaims{
|
||||
"sub": "123",
|
||||
"roles": []string{"admin", "editor"},
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
}
|
||||
token := createTestToken(jwtSecret, claims)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
checkUserData: true,
|
||||
expectedUserID: "123",
|
||||
expectedRoles: []string{"admin", "editor"},
|
||||
},
|
||||
{
|
||||
name: "Expired token",
|
||||
setupAuth: func(req *http.Request) {
|
||||
claims := jwt.MapClaims{
|
||||
"sub": "123",
|
||||
"roles": []string{"user"},
|
||||
"exp": time.Now().Add(-time.Hour).Unix(),
|
||||
}
|
||||
token := createTestToken(jwtSecret, claims)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "Invalid token"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 添加认证中间件
|
||||
router.Use(func(c *gin.Context) {
|
||||
// 设置日志级别为 debug
|
||||
gin.SetMode(gin.DebugMode)
|
||||
c.Next()
|
||||
}, AuthMiddleware(jwtSecret, tokenBlacklist))
|
||||
|
||||
// 测试路由
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
if tc.checkUserData {
|
||||
userID, exists := c.Get("user_id")
|
||||
assert.True(t, exists, "user_id should exist in context")
|
||||
assert.Equal(t, tc.expectedUserID, userID, "user_id should match")
|
||||
|
||||
roles, exists := c.Get("user_roles")
|
||||
assert.True(t, exists, "user_roles should exist in context")
|
||||
assert.Equal(t, tc.expectedRoles, roles, "user_roles should match")
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// 创建请求
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
tc.setupAuth(req)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// 验证响应
|
||||
assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
assert.NoError(t, err, "Response body should be valid JSON")
|
||||
assert.Equal(t, tc.expectedBody, response, "Response body should match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoleMiddleware(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupContext func(*gin.Context)
|
||||
allowedRoles []string
|
||||
expectedStatus int
|
||||
expectedBody map[string]string
|
||||
}{
|
||||
{
|
||||
name: "No user roles",
|
||||
setupContext: func(c *gin.Context) {
|
||||
// 不设置用户角色
|
||||
},
|
||||
allowedRoles: []string{"admin"},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "User roles not found"},
|
||||
},
|
||||
{
|
||||
name: "Invalid roles type",
|
||||
setupContext: func(c *gin.Context) {
|
||||
c.Set("user_roles", 123) // 设置错误类型的角色
|
||||
},
|
||||
allowedRoles: []string{"admin"},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: map[string]string{"error": "Invalid user roles type"},
|
||||
},
|
||||
{
|
||||
name: "Insufficient permissions",
|
||||
setupContext: func(c *gin.Context) {
|
||||
c.Set("user_roles", []string{"user"})
|
||||
},
|
||||
allowedRoles: []string{"admin"},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedBody: map[string]string{"error": "Insufficient permissions"},
|
||||
},
|
||||
{
|
||||
name: "Allowed role",
|
||||
setupContext: func(c *gin.Context) {
|
||||
c.Set("user_roles", []string{"admin"})
|
||||
},
|
||||
allowedRoles: []string{"admin"},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "One of multiple allowed roles",
|
||||
setupContext: func(c *gin.Context) {
|
||||
c.Set("user_roles", []string{"user", "editor"})
|
||||
},
|
||||
allowedRoles: []string{"admin", "editor", "moderator"},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 添加角色中间件
|
||||
router.Use(func(c *gin.Context) {
|
||||
tc.setupContext(c)
|
||||
c.Next()
|
||||
}, RoleMiddleware(tc.allowedRoles...))
|
||||
|
||||
// 测试路由
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// 创建请求
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// 验证响应
|
||||
assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
assert.NoError(t, err, "Response body should be valid JSON")
|
||||
assert.Equal(t, tc.expectedBody, response, "Response body should match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,76 +0,0 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCORS(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
method string
|
||||
expectedStatus int
|
||||
checkHeaders bool
|
||||
}{
|
||||
{
|
||||
name: "Normal GET request",
|
||||
method: "GET",
|
||||
expectedStatus: http.StatusOK,
|
||||
checkHeaders: true,
|
||||
},
|
||||
{
|
||||
name: "OPTIONS request",
|
||||
method: "OPTIONS",
|
||||
expectedStatus: http.StatusNoContent,
|
||||
checkHeaders: true,
|
||||
},
|
||||
{
|
||||
name: "POST request",
|
||||
method: "POST",
|
||||
expectedStatus: http.StatusOK,
|
||||
checkHeaders: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// 创建一个新的 gin 引擎
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 添加 CORS 中间件
|
||||
router.Use(CORS())
|
||||
|
||||
// 添加测试路由
|
||||
router.Any("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// 创建测试请求
|
||||
req := httptest.NewRequest(tc.method, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// 验证状态码
|
||||
assert.Equal(t, tc.expectedStatus, rec.Code)
|
||||
|
||||
if tc.checkHeaders {
|
||||
// 验证 CORS 头部
|
||||
headers := rec.Header()
|
||||
assert.Equal(t, "*", headers.Get("Access-Control-Allow-Origin"))
|
||||
assert.Equal(t, "true", headers.Get("Access-Control-Allow-Credentials"))
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Content-Type")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Authorization")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "POST")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "GET")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "PUT")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "DELETE")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,207 +0,0 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
"tss-rocks-be/internal/types"
|
||||
)
|
||||
|
||||
func TestRateLimit(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
config *types.RateLimitConfig
|
||||
setupTest func(*gin.Engine)
|
||||
runTest func(*testing.T, *gin.Engine)
|
||||
expectedStatus int
|
||||
expectedBody map[string]string
|
||||
}{
|
||||
{
|
||||
name: "IP rate limit",
|
||||
config: &types.RateLimitConfig{
|
||||
IPRate: 1, // 每秒1个请求
|
||||
IPBurst: 1,
|
||||
},
|
||||
setupTest: func(router *gin.Engine) {
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
},
|
||||
runTest: func(t *testing.T, router *gin.Engine) {
|
||||
// 第一个请求应该成功
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:1234"
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// 第二个请求应该被限制
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "too many requests from this IP", response["error"])
|
||||
|
||||
// 等待限流器重置
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// 第三个请求应该成功
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Route rate limit",
|
||||
config: &types.RateLimitConfig{
|
||||
IPRate: 100, // 设置较高的 IP 限流,以便测试路由限流
|
||||
IPBurst: 10,
|
||||
RouteRates: map[string]struct {
|
||||
Rate int `yaml:"rate"`
|
||||
Burst int `yaml:"burst"`
|
||||
}{
|
||||
"/limited": {
|
||||
Rate: 1,
|
||||
Burst: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
setupTest: func(router *gin.Engine) {
|
||||
router.GET("/limited", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
router.GET("/unlimited", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
},
|
||||
runTest: func(t *testing.T, router *gin.Engine) {
|
||||
// 测试限流路由
|
||||
req := httptest.NewRequest("GET", "/limited", nil)
|
||||
req.RemoteAddr = "192.168.1.2:1234"
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// 等待一小段时间确保限流器生效
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "too many requests for this route", response["error"])
|
||||
|
||||
// 测试未限流路由
|
||||
req = httptest.NewRequest("GET", "/unlimited", nil)
|
||||
req.RemoteAddr = "192.168.1.2:1234"
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// 等待一小段时间确保限流器生效
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Multiple IPs",
|
||||
config: &types.RateLimitConfig{
|
||||
IPRate: 1,
|
||||
IPBurst: 1,
|
||||
},
|
||||
setupTest: func(router *gin.Engine) {
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
},
|
||||
runTest: func(t *testing.T, router *gin.Engine) {
|
||||
// IP1 的请求
|
||||
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||
req1.RemoteAddr = "192.168.1.3:1234"
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req1)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req1)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
|
||||
// IP2 的请求应该不受 IP1 的限制影响
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
req2.RemoteAddr = "192.168.1.4:1234"
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req2)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 添加限流中间件
|
||||
router.Use(RateLimit(tc.config))
|
||||
|
||||
// 设置测试路由
|
||||
tc.setupTest(router)
|
||||
|
||||
// 运行测试
|
||||
tc.runTest(t, router)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterCleanup(t *testing.T) {
|
||||
config := &types.RateLimitConfig{
|
||||
IPRate: 1,
|
||||
IPBurst: 1,
|
||||
}
|
||||
|
||||
rl := newRateLimiter(config)
|
||||
|
||||
// 添加一些IP限流器
|
||||
ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"}
|
||||
for _, ip := range ips {
|
||||
rl.getLimiter(ip)
|
||||
}
|
||||
|
||||
// 验证IP限流器已创建
|
||||
rl.mu.RLock()
|
||||
assert.Equal(t, len(ips), len(rl.ips))
|
||||
rl.mu.RUnlock()
|
||||
|
||||
// 修改一些IP的最后访问时间为1小时前
|
||||
rl.mu.Lock()
|
||||
rl.ips["192.168.1.1"].lastSeen = time.Now().Add(-2 * time.Hour)
|
||||
rl.ips["192.168.1.2"].lastSeen = time.Now().Add(-2 * time.Hour)
|
||||
rl.mu.Unlock()
|
||||
|
||||
// 手动触发清理
|
||||
rl.mu.Lock()
|
||||
for ip, limiter := range rl.ips {
|
||||
if time.Since(limiter.lastSeen) > time.Hour {
|
||||
delete(rl.ips, ip)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
|
||||
// 验证过期的IP限流器已被删除
|
||||
rl.mu.RLock()
|
||||
assert.Equal(t, 1, len(rl.ips))
|
||||
_, exists := rl.ips["192.168.1.3"]
|
||||
assert.True(t, exists)
|
||||
rl.mu.RUnlock()
|
||||
}
|
|
@ -1,262 +0,0 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"tss-rocks-be/internal/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func createMultipartRequest(t *testing.T, filename string, content []byte, contentType string) (*http.Request, error) {
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
part, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = io.Copy(part, bytes.NewReader(content))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/upload", body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func TestValidateUpload(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *types.UploadConfig
|
||||
filename string
|
||||
content []byte
|
||||
setupRequest func(*testing.T) *http.Request
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "Valid image upload",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5, // 5MB
|
||||
AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
|
||||
AllowedTypes: []string{"image/jpeg", "image/png"},
|
||||
},
|
||||
filename: "test.jpg",
|
||||
content: []byte{
|
||||
0xFF, 0xD8, 0xFF, 0xE0, // JPEG magic numbers
|
||||
0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00,
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Invalid file extension",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5,
|
||||
AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
|
||||
AllowedTypes: []string{"image/jpeg", "image/png"},
|
||||
},
|
||||
filename: "test.txt",
|
||||
content: []byte("test content"),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "File extension .txt is not allowed",
|
||||
},
|
||||
{
|
||||
name: "File too large",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 1, // 1MB
|
||||
AllowedExtensions: []string{".jpg"},
|
||||
AllowedTypes: []string{"image/jpeg"},
|
||||
},
|
||||
filename: "large.jpg",
|
||||
content: make([]byte, 2<<20), // 2MB
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "File large.jpg exceeds maximum size of 1 MB",
|
||||
},
|
||||
{
|
||||
name: "Invalid content type",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5,
|
||||
AllowedExtensions: []string{".jpg"},
|
||||
AllowedTypes: []string{"image/jpeg"},
|
||||
},
|
||||
filename: "fake.jpg",
|
||||
content: []byte("not a real image"),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "File type text/plain; charset=utf-8 is not allowed",
|
||||
},
|
||||
{
|
||||
name: "Missing file",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5,
|
||||
AllowedExtensions: []string{".jpg"},
|
||||
AllowedTypes: []string{"image/jpeg"},
|
||||
},
|
||||
setupRequest: func(t *testing.T) *http.Request {
|
||||
req := httptest.NewRequest("POST", "/upload", strings.NewReader(""))
|
||||
req.Header.Set("Content-Type", "multipart/form-data")
|
||||
return req
|
||||
},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Failed to parse form",
|
||||
},
|
||||
{
|
||||
name: "Invalid content type header",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5,
|
||||
AllowedExtensions: []string{".jpg"},
|
||||
AllowedTypes: []string{"image/jpeg"},
|
||||
},
|
||||
setupRequest: func(t *testing.T) *http.Request {
|
||||
return httptest.NewRequest("POST", "/upload", nil)
|
||||
},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Content-Type must be multipart/form-data",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
var req *http.Request
|
||||
var err error
|
||||
|
||||
if tt.setupRequest != nil {
|
||||
req = tt.setupRequest(t)
|
||||
} else {
|
||||
req, err = createMultipartRequest(t, tt.filename, tt.content, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.Request = req
|
||||
|
||||
middleware := ValidateUpload(tt.config)
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||
if tt.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(w.Body).Decode(&response)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, response["error"], tt.expectedError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetValidatedFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupContext func(*gin.Context)
|
||||
filename string
|
||||
expectedFound bool
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "Get existing file",
|
||||
setupContext: func(c *gin.Context) {
|
||||
// 创建测试文件内容
|
||||
content := []byte("test content")
|
||||
buf := bytes.NewBuffer(content)
|
||||
|
||||
// 设置验证过的文件和内容类型
|
||||
c.Set("validated_file_test.txt", buf)
|
||||
c.Set("validated_content_type_test.txt", "text/plain")
|
||||
},
|
||||
filename: "test.txt",
|
||||
expectedFound: true,
|
||||
},
|
||||
{
|
||||
name: "File not found",
|
||||
setupContext: func(c *gin.Context) {
|
||||
// 不设置任何文件
|
||||
},
|
||||
filename: "nonexistent.txt",
|
||||
expectedFound: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
if tt.setupContext != nil {
|
||||
tt.setupContext(c)
|
||||
}
|
||||
|
||||
buffer, contentType, found := GetValidatedFile(c, tt.filename)
|
||||
|
||||
assert.Equal(t, tt.expectedFound, found)
|
||||
if tt.expectedFound {
|
||||
assert.NotNil(t, buffer)
|
||||
assert.NotEmpty(t, contentType)
|
||||
} else {
|
||||
assert.Nil(t, buffer)
|
||||
assert.Empty(t, contentType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
slice []string
|
||||
str string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "String found in slice",
|
||||
slice: []string{"a", "b", "c"},
|
||||
str: "b",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "String not found in slice",
|
||||
slice: []string{"a", "b", "c"},
|
||||
str: "d",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty slice",
|
||||
slice: []string{},
|
||||
str: "a",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
slice: []string{"a", "b", "c"},
|
||||
str: "",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := contains(tt.slice, tt.str)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue