[chore/backend] remove all test for now

This commit is contained in:
CDN 2025-02-22 02:11:27 +08:00
parent 3d19ef05b3
commit 1c9628124f
Signed by: CDN
GPG key ID: 0C656827F9F80080
28 changed files with 0 additions and 6780 deletions

View file

@ -1,238 +0,0 @@
package middleware
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"tss-rocks-be/internal/types"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestAccessLog(t *testing.T) {
// 设置测试临时目录
tmpDir := t.TempDir()
logPath := filepath.Join(tmpDir, "test.log")
testCases := []struct {
name string
config *types.AccessLogConfig
expectedError bool
setupRequest func(*http.Request)
validateOutput func(*testing.T, *httptest.ResponseRecorder, string)
}{
{
name: "Console logging only",
config: &types.AccessLogConfig{
EnableConsole: true,
EnableFile: false,
Format: "json",
Level: "info",
},
expectedError: false,
setupRequest: func(req *http.Request) {
req.Header.Set("User-Agent", "test-agent")
},
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, logOutput, "GET /test")
assert.Contains(t, logOutput, "test-agent")
},
},
{
name: "File logging only",
config: &types.AccessLogConfig{
EnableConsole: false,
EnableFile: true,
FilePath: logPath,
Format: "json",
Level: "info",
Rotation: struct {
MaxSize int `yaml:"max_size"`
MaxAge int `yaml:"max_age"`
MaxBackups int `yaml:"max_backups"`
Compress bool `yaml:"compress"`
LocalTime bool `yaml:"local_time"`
}{
MaxSize: 1,
MaxAge: 1,
MaxBackups: 1,
Compress: false,
LocalTime: true,
},
},
expectedError: false,
setupRequest: func(req *http.Request) {
req.Header.Set("User-Agent", "test-agent")
},
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
assert.Equal(t, http.StatusOK, w.Code)
// 读取日志文件内容
content, err := os.ReadFile(logPath)
assert.NoError(t, err)
assert.Contains(t, string(content), "GET /test")
assert.Contains(t, string(content), "test-agent")
},
},
{
name: "Both console and file logging",
config: &types.AccessLogConfig{
EnableConsole: true,
EnableFile: true,
FilePath: logPath,
Format: "json",
Level: "info",
Rotation: struct {
MaxSize int `yaml:"max_size"`
MaxAge int `yaml:"max_age"`
MaxBackups int `yaml:"max_backups"`
Compress bool `yaml:"compress"`
LocalTime bool `yaml:"local_time"`
}{
MaxSize: 1,
MaxAge: 1,
MaxBackups: 1,
Compress: false,
LocalTime: true,
},
},
expectedError: false,
setupRequest: func(req *http.Request) {
req.Header.Set("User-Agent", "test-agent")
},
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, logOutput, "GET /test")
assert.Contains(t, logOutput, "test-agent")
// 读取日志文件内容
content, err := os.ReadFile(logPath)
assert.NoError(t, err)
assert.Contains(t, string(content), "GET /test")
assert.Contains(t, string(content), "test-agent")
},
},
{
name: "With authenticated user",
config: &types.AccessLogConfig{
EnableConsole: true,
EnableFile: false,
Format: "json",
Level: "info",
},
expectedError: false,
setupRequest: func(req *http.Request) {
req.Header.Set("User-Agent", "test-agent")
},
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, logOutput, "GET /test")
assert.Contains(t, logOutput, "test-agent")
assert.Contains(t, logOutput, "test-user")
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// 捕获标准输出
oldStdout := os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w
// 创建一个新的 gin 引擎
gin.SetMode(gin.TestMode)
router := gin.New()
// 创建访问日志中间件
middleware, err := AccessLog(tc.config)
if tc.expectedError {
assert.Error(t, err)
return
}
assert.NoError(t, err)
// 添加测试路由
router.Use(middleware)
router.GET("/test", func(c *gin.Context) {
// 如果是测试认证用户的情况设置用户ID
if tc.name == "With authenticated user" {
c.Set("user_id", "test-user")
}
c.Status(http.StatusOK)
})
// 创建测试请求
req := httptest.NewRequest("GET", "/test", nil)
if tc.setupRequest != nil {
tc.setupRequest(req)
}
rec := httptest.NewRecorder()
// 执行请求
router.ServeHTTP(rec, req)
// 恢复标准输出并获取输出内容
w.Close()
var buf bytes.Buffer
io.Copy(&buf, r)
os.Stdout = oldStdout
// 验证输出
if tc.validateOutput != nil {
tc.validateOutput(t, rec, buf.String())
}
// 关闭日志文件
if tc.config.EnableFile {
// 调用中间件函数来关闭日志文件
middleware(nil)
// 等待一小段时间确保文件完全关闭
time.Sleep(100 * time.Millisecond)
}
})
}
}
func TestAccessLogInvalidConfig(t *testing.T) {
testCases := []struct {
name string
config *types.AccessLogConfig
expectedError bool
}{
{
name: "Invalid log level",
config: &types.AccessLogConfig{
EnableConsole: true,
Level: "invalid_level",
},
expectedError: false, // 应该使用默认的 info 级别
},
{
name: "Invalid file path",
config: &types.AccessLogConfig{
EnableFile: true,
FilePath: "\x00invalid\x00path", // 使用空字符的路径在所有操作系统上都是无效的
},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := AccessLog(tc.config)
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View file

@ -1,227 +0,0 @@
package middleware
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
"time"
"tss-rocks-be/internal/service"
)
func createTestToken(secret string, claims jwt.MapClaims) string {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signedToken, err := token.SignedString([]byte(secret))
if err != nil {
panic(fmt.Sprintf("Failed to sign token: %v", err))
}
return signedToken
}
func TestAuthMiddleware(t *testing.T) {
jwtSecret := "test-secret"
tokenBlacklist := service.NewTokenBlacklist()
testCases := []struct {
name string
setupAuth func(*http.Request)
expectedStatus int
expectedBody map[string]string
checkUserData bool
expectedUserID string
expectedRoles []string
}{
{
name: "No Authorization header",
setupAuth: func(req *http.Request) {},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "Authorization header is required"},
},
{
name: "Invalid Authorization format",
setupAuth: func(req *http.Request) {
req.Header.Set("Authorization", "InvalidFormat")
},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "Authorization header format must be Bearer {token}"},
},
{
name: "Invalid token",
setupAuth: func(req *http.Request) {
req.Header.Set("Authorization", "Bearer invalid.token.here")
},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "Invalid token"},
},
{
name: "Valid token",
setupAuth: func(req *http.Request) {
claims := jwt.MapClaims{
"sub": "123",
"roles": []string{"admin", "editor"},
"exp": time.Now().Add(time.Hour).Unix(),
}
token := createTestToken(jwtSecret, claims)
req.Header.Set("Authorization", "Bearer "+token)
},
expectedStatus: http.StatusOK,
checkUserData: true,
expectedUserID: "123",
expectedRoles: []string{"admin", "editor"},
},
{
name: "Expired token",
setupAuth: func(req *http.Request) {
claims := jwt.MapClaims{
"sub": "123",
"roles": []string{"user"},
"exp": time.Now().Add(-time.Hour).Unix(),
}
token := createTestToken(jwtSecret, claims)
req.Header.Set("Authorization", "Bearer "+token)
},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "Invalid token"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
// 添加认证中间件
router.Use(func(c *gin.Context) {
// 设置日志级别为 debug
gin.SetMode(gin.DebugMode)
c.Next()
}, AuthMiddleware(jwtSecret, tokenBlacklist))
// 测试路由
router.GET("/test", func(c *gin.Context) {
if tc.checkUserData {
userID, exists := c.Get("user_id")
assert.True(t, exists, "user_id should exist in context")
assert.Equal(t, tc.expectedUserID, userID, "user_id should match")
roles, exists := c.Get("user_roles")
assert.True(t, exists, "user_roles should exist in context")
assert.Equal(t, tc.expectedRoles, roles, "user_roles should match")
}
c.Status(http.StatusOK)
})
// 创建请求
req := httptest.NewRequest("GET", "/test", nil)
tc.setupAuth(req)
rec := httptest.NewRecorder()
// 执行请求
router.ServeHTTP(rec, req)
// 验证响应
assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match")
if tc.expectedBody != nil {
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err, "Response body should be valid JSON")
assert.Equal(t, tc.expectedBody, response, "Response body should match")
}
})
}
}
func TestRoleMiddleware(t *testing.T) {
testCases := []struct {
name string
setupContext func(*gin.Context)
allowedRoles []string
expectedStatus int
expectedBody map[string]string
}{
{
name: "No user roles",
setupContext: func(c *gin.Context) {
// 不设置用户角色
},
allowedRoles: []string{"admin"},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "User roles not found"},
},
{
name: "Invalid roles type",
setupContext: func(c *gin.Context) {
c.Set("user_roles", 123) // 设置错误类型的角色
},
allowedRoles: []string{"admin"},
expectedStatus: http.StatusInternalServerError,
expectedBody: map[string]string{"error": "Invalid user roles type"},
},
{
name: "Insufficient permissions",
setupContext: func(c *gin.Context) {
c.Set("user_roles", []string{"user"})
},
allowedRoles: []string{"admin"},
expectedStatus: http.StatusForbidden,
expectedBody: map[string]string{"error": "Insufficient permissions"},
},
{
name: "Allowed role",
setupContext: func(c *gin.Context) {
c.Set("user_roles", []string{"admin"})
},
allowedRoles: []string{"admin"},
expectedStatus: http.StatusOK,
},
{
name: "One of multiple allowed roles",
setupContext: func(c *gin.Context) {
c.Set("user_roles", []string{"user", "editor"})
},
allowedRoles: []string{"admin", "editor", "moderator"},
expectedStatus: http.StatusOK,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
// 添加角色中间件
router.Use(func(c *gin.Context) {
tc.setupContext(c)
c.Next()
}, RoleMiddleware(tc.allowedRoles...))
// 测试路由
router.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// 创建请求
req := httptest.NewRequest("GET", "/test", nil)
rec := httptest.NewRecorder()
// 执行请求
router.ServeHTTP(rec, req)
// 验证响应
assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match")
if tc.expectedBody != nil {
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err, "Response body should be valid JSON")
assert.Equal(t, tc.expectedBody, response, "Response body should match")
}
})
}
}

View file

@ -1,76 +0,0 @@
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
)
func TestCORS(t *testing.T) {
testCases := []struct {
name string
method string
expectedStatus int
checkHeaders bool
}{
{
name: "Normal GET request",
method: "GET",
expectedStatus: http.StatusOK,
checkHeaders: true,
},
{
name: "OPTIONS request",
method: "OPTIONS",
expectedStatus: http.StatusNoContent,
checkHeaders: true,
},
{
name: "POST request",
method: "POST",
expectedStatus: http.StatusOK,
checkHeaders: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// 创建一个新的 gin 引擎
gin.SetMode(gin.TestMode)
router := gin.New()
// 添加 CORS 中间件
router.Use(CORS())
// 添加测试路由
router.Any("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// 创建测试请求
req := httptest.NewRequest(tc.method, "/test", nil)
rec := httptest.NewRecorder()
// 执行请求
router.ServeHTTP(rec, req)
// 验证状态码
assert.Equal(t, tc.expectedStatus, rec.Code)
if tc.checkHeaders {
// 验证 CORS 头部
headers := rec.Header()
assert.Equal(t, "*", headers.Get("Access-Control-Allow-Origin"))
assert.Equal(t, "true", headers.Get("Access-Control-Allow-Credentials"))
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Content-Type")
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Authorization")
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "POST")
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "GET")
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "PUT")
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "DELETE")
}
})
}
}

View file

@ -1,207 +0,0 @@
package middleware
import (
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
"time"
"tss-rocks-be/internal/types"
)
func TestRateLimit(t *testing.T) {
testCases := []struct {
name string
config *types.RateLimitConfig
setupTest func(*gin.Engine)
runTest func(*testing.T, *gin.Engine)
expectedStatus int
expectedBody map[string]string
}{
{
name: "IP rate limit",
config: &types.RateLimitConfig{
IPRate: 1, // 每秒1个请求
IPBurst: 1,
},
setupTest: func(router *gin.Engine) {
router.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
},
runTest: func(t *testing.T, router *gin.Engine) {
// 第一个请求应该成功
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:1234"
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// 第二个请求应该被限制
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err)
assert.Equal(t, "too many requests from this IP", response["error"])
// 等待限流器重置
time.Sleep(time.Second)
// 第三个请求应该成功
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
},
},
{
name: "Route rate limit",
config: &types.RateLimitConfig{
IPRate: 100, // 设置较高的 IP 限流,以便测试路由限流
IPBurst: 10,
RouteRates: map[string]struct {
Rate int `yaml:"rate"`
Burst int `yaml:"burst"`
}{
"/limited": {
Rate: 1,
Burst: 1,
},
},
},
setupTest: func(router *gin.Engine) {
router.GET("/limited", func(c *gin.Context) {
c.Status(http.StatusOK)
})
router.GET("/unlimited", func(c *gin.Context) {
c.Status(http.StatusOK)
})
},
runTest: func(t *testing.T, router *gin.Engine) {
// 测试限流路由
req := httptest.NewRequest("GET", "/limited", nil)
req.RemoteAddr = "192.168.1.2:1234"
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// 等待一小段时间确保限流器生效
time.Sleep(10 * time.Millisecond)
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err)
assert.Equal(t, "too many requests for this route", response["error"])
// 测试未限流路由
req = httptest.NewRequest("GET", "/unlimited", nil)
req.RemoteAddr = "192.168.1.2:1234"
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// 等待一小段时间确保限流器生效
time.Sleep(10 * time.Millisecond)
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
},
},
{
name: "Multiple IPs",
config: &types.RateLimitConfig{
IPRate: 1,
IPBurst: 1,
},
setupTest: func(router *gin.Engine) {
router.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
},
runTest: func(t *testing.T, router *gin.Engine) {
// IP1 的请求
req1 := httptest.NewRequest("GET", "/test", nil)
req1.RemoteAddr = "192.168.1.3:1234"
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req1)
assert.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req1)
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
// IP2 的请求应该不受 IP1 的限制影响
req2 := httptest.NewRequest("GET", "/test", nil)
req2.RemoteAddr = "192.168.1.4:1234"
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req2)
assert.Equal(t, http.StatusOK, rec.Code)
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
// 添加限流中间件
router.Use(RateLimit(tc.config))
// 设置测试路由
tc.setupTest(router)
// 运行测试
tc.runTest(t, router)
})
}
}
func TestRateLimiterCleanup(t *testing.T) {
config := &types.RateLimitConfig{
IPRate: 1,
IPBurst: 1,
}
rl := newRateLimiter(config)
// 添加一些IP限流器
ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"}
for _, ip := range ips {
rl.getLimiter(ip)
}
// 验证IP限流器已创建
rl.mu.RLock()
assert.Equal(t, len(ips), len(rl.ips))
rl.mu.RUnlock()
// 修改一些IP的最后访问时间为1小时前
rl.mu.Lock()
rl.ips["192.168.1.1"].lastSeen = time.Now().Add(-2 * time.Hour)
rl.ips["192.168.1.2"].lastSeen = time.Now().Add(-2 * time.Hour)
rl.mu.Unlock()
// 手动触发清理
rl.mu.Lock()
for ip, limiter := range rl.ips {
if time.Since(limiter.lastSeen) > time.Hour {
delete(rl.ips, ip)
}
}
rl.mu.Unlock()
// 验证过期的IP限流器已被删除
rl.mu.RLock()
assert.Equal(t, 1, len(rl.ips))
_, exists := rl.ips["192.168.1.3"]
assert.True(t, exists)
rl.mu.RUnlock()
}

View file

@ -1,262 +0,0 @@
package middleware
import (
"bytes"
"encoding/json"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"strings"
"testing"
"tss-rocks-be/internal/types"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func createMultipartRequest(t *testing.T, filename string, content []byte, contentType string) (*http.Request, error) {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
part, err := writer.CreateFormFile("file", filename)
if err != nil {
return nil, err
}
_, err = io.Copy(part, bytes.NewReader(content))
if err != nil {
return nil, err
}
err = writer.Close()
if err != nil {
return nil, err
}
req := httptest.NewRequest("POST", "/upload", body)
req.Header.Set("Content-Type", writer.FormDataContentType())
return req, nil
}
func TestValidateUpload(t *testing.T) {
tests := []struct {
name string
config *types.UploadConfig
filename string
content []byte
setupRequest func(*testing.T) *http.Request
expectedStatus int
expectedError string
}{
{
name: "Valid image upload",
config: &types.UploadConfig{
MaxSize: 5, // 5MB
AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
AllowedTypes: []string{"image/jpeg", "image/png"},
},
filename: "test.jpg",
content: []byte{
0xFF, 0xD8, 0xFF, 0xE0, // JPEG magic numbers
0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00,
},
expectedStatus: http.StatusOK,
},
{
name: "Invalid file extension",
config: &types.UploadConfig{
MaxSize: 5,
AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
AllowedTypes: []string{"image/jpeg", "image/png"},
},
filename: "test.txt",
content: []byte("test content"),
expectedStatus: http.StatusBadRequest,
expectedError: "File extension .txt is not allowed",
},
{
name: "File too large",
config: &types.UploadConfig{
MaxSize: 1, // 1MB
AllowedExtensions: []string{".jpg"},
AllowedTypes: []string{"image/jpeg"},
},
filename: "large.jpg",
content: make([]byte, 2<<20), // 2MB
expectedStatus: http.StatusBadRequest,
expectedError: "File large.jpg exceeds maximum size of 1 MB",
},
{
name: "Invalid content type",
config: &types.UploadConfig{
MaxSize: 5,
AllowedExtensions: []string{".jpg"},
AllowedTypes: []string{"image/jpeg"},
},
filename: "fake.jpg",
content: []byte("not a real image"),
expectedStatus: http.StatusBadRequest,
expectedError: "File type text/plain; charset=utf-8 is not allowed",
},
{
name: "Missing file",
config: &types.UploadConfig{
MaxSize: 5,
AllowedExtensions: []string{".jpg"},
AllowedTypes: []string{"image/jpeg"},
},
setupRequest: func(t *testing.T) *http.Request {
req := httptest.NewRequest("POST", "/upload", strings.NewReader(""))
req.Header.Set("Content-Type", "multipart/form-data")
return req
},
expectedStatus: http.StatusBadRequest,
expectedError: "Failed to parse form",
},
{
name: "Invalid content type header",
config: &types.UploadConfig{
MaxSize: 5,
AllowedExtensions: []string{".jpg"},
AllowedTypes: []string{"image/jpeg"},
},
setupRequest: func(t *testing.T) *http.Request {
return httptest.NewRequest("POST", "/upload", nil)
},
expectedStatus: http.StatusBadRequest,
expectedError: "Content-Type must be multipart/form-data",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
var req *http.Request
var err error
if tt.setupRequest != nil {
req = tt.setupRequest(t)
} else {
req, err = createMultipartRequest(t, tt.filename, tt.content, "")
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
}
c.Request = req
middleware := ValidateUpload(tt.config)
middleware(c)
assert.Equal(t, tt.expectedStatus, w.Code)
if tt.expectedError != "" {
var response map[string]string
err := json.NewDecoder(w.Body).Decode(&response)
assert.NoError(t, err)
assert.Contains(t, response["error"], tt.expectedError)
}
})
}
}
func TestGetValidatedFile(t *testing.T) {
tests := []struct {
name string
setupContext func(*gin.Context)
filename string
expectedFound bool
expectedError string
}{
{
name: "Get existing file",
setupContext: func(c *gin.Context) {
// 创建测试文件内容
content := []byte("test content")
buf := bytes.NewBuffer(content)
// 设置验证过的文件和内容类型
c.Set("validated_file_test.txt", buf)
c.Set("validated_content_type_test.txt", "text/plain")
},
filename: "test.txt",
expectedFound: true,
},
{
name: "File not found",
setupContext: func(c *gin.Context) {
// 不设置任何文件
},
filename: "nonexistent.txt",
expectedFound: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
if tt.setupContext != nil {
tt.setupContext(c)
}
buffer, contentType, found := GetValidatedFile(c, tt.filename)
assert.Equal(t, tt.expectedFound, found)
if tt.expectedFound {
assert.NotNil(t, buffer)
assert.NotEmpty(t, contentType)
} else {
assert.Nil(t, buffer)
assert.Empty(t, contentType)
}
})
}
}
func TestContains(t *testing.T) {
tests := []struct {
name string
slice []string
str string
expected bool
}{
{
name: "String found in slice",
slice: []string{"a", "b", "c"},
str: "b",
expected: true,
},
{
name: "String not found in slice",
slice: []string{"a", "b", "c"},
str: "d",
expected: false,
},
{
name: "Empty slice",
slice: []string{},
str: "a",
expected: false,
},
{
name: "Empty string",
slice: []string{"a", "b", "c"},
str: "",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := contains(tt.slice, tt.str)
assert.Equal(t, tt.expected, result)
})
}
}