[feature] migrate to monorepo
This commit is contained in:
commit
05ddc1f783
267 changed files with 75165 additions and 0 deletions
192
backend/internal/middleware/accesslog.go
Normal file
192
backend/internal/middleware/accesslog.go
Normal file
|
@ -0,0 +1,192 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
"tss-rocks-be/internal/types"
|
||||
)
|
||||
|
||||
// AccessLogConfig 访问日志配置
|
||||
type AccessLogConfig struct {
|
||||
// 是否启用控制台输出
|
||||
EnableConsole bool `yaml:"enable_console"`
|
||||
// 是否启用文件日志
|
||||
EnableFile bool `yaml:"enable_file"`
|
||||
// 日志文件路径
|
||||
FilePath string `yaml:"file_path"`
|
||||
// 日志格式 (json 或 text)
|
||||
Format string `yaml:"format"`
|
||||
// 日志级别
|
||||
Level string `yaml:"level"`
|
||||
// 日志轮转配置
|
||||
Rotation struct {
|
||||
MaxSize int `yaml:"max_size"` // 每个日志文件的最大大小(MB)
|
||||
MaxAge int `yaml:"max_age"` // 保留旧日志文件的最大天数
|
||||
MaxBackups int `yaml:"max_backups"` // 保留的旧日志文件的最大数量
|
||||
Compress bool `yaml:"compress"` // 是否压缩旧日志文件
|
||||
LocalTime bool `yaml:"local_time"` // 使用本地时间作为轮转时间
|
||||
} `yaml:"rotation"`
|
||||
}
|
||||
|
||||
// accessLogger 访问日志记录器
|
||||
type accessLogger struct {
|
||||
consoleLogger *zerolog.Logger
|
||||
fileLogger *zerolog.Logger
|
||||
logWriter *lumberjack.Logger
|
||||
config *types.AccessLogConfig
|
||||
}
|
||||
|
||||
// Close 关闭日志文件
|
||||
func (l *accessLogger) Close() error {
|
||||
if l.logWriter != nil {
|
||||
return l.logWriter.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// newAccessLogger 创建新的访问日志记录器
|
||||
func newAccessLogger(config *types.AccessLogConfig) (*accessLogger, error) {
|
||||
var consoleLogger, fileLogger *zerolog.Logger
|
||||
var logWriter *lumberjack.Logger
|
||||
|
||||
// 设置日志级别
|
||||
level, err := zerolog.ParseLevel(config.Level)
|
||||
if err != nil {
|
||||
level = zerolog.InfoLevel
|
||||
}
|
||||
zerolog.SetGlobalLevel(level)
|
||||
|
||||
// 配置控制台日志
|
||||
if config.EnableConsole {
|
||||
logger := zerolog.New(os.Stdout).
|
||||
With().
|
||||
Timestamp().
|
||||
Logger()
|
||||
|
||||
if config.Format == "text" {
|
||||
logger = logger.Output(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339})
|
||||
}
|
||||
consoleLogger = &logger
|
||||
}
|
||||
|
||||
// 配置文件日志
|
||||
if config.EnableFile {
|
||||
// 确保日志目录存在
|
||||
if err := os.MkdirAll(filepath.Dir(config.FilePath), 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create log directory: %w", err)
|
||||
}
|
||||
|
||||
// 配置日志轮转
|
||||
logWriter = &lumberjack.Logger{
|
||||
Filename: config.FilePath,
|
||||
MaxSize: config.Rotation.MaxSize, // MB
|
||||
MaxAge: config.Rotation.MaxAge, // days
|
||||
MaxBackups: config.Rotation.MaxBackups, // files
|
||||
Compress: config.Rotation.Compress, // 是否压缩
|
||||
LocalTime: config.Rotation.LocalTime, // 使用本地时间
|
||||
}
|
||||
|
||||
logger := zerolog.New(logWriter).
|
||||
With().
|
||||
Timestamp().
|
||||
Logger()
|
||||
|
||||
fileLogger = &logger
|
||||
}
|
||||
|
||||
return &accessLogger{
|
||||
consoleLogger: consoleLogger,
|
||||
fileLogger: fileLogger,
|
||||
logWriter: logWriter,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// logEvent 记录日志事件
|
||||
func (l *accessLogger) logEvent(fields map[string]interface{}, msg string) {
|
||||
if l.consoleLogger != nil {
|
||||
event := l.consoleLogger.Info()
|
||||
for k, v := range fields {
|
||||
event = event.Interface(k, v)
|
||||
}
|
||||
event.Msg(msg)
|
||||
}
|
||||
if l.fileLogger != nil {
|
||||
event := l.fileLogger.Info()
|
||||
for k, v := range fields {
|
||||
event = event.Interface(k, v)
|
||||
}
|
||||
event.Msg(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// AccessLog 创建访问日志中间件
|
||||
func AccessLog(config *types.AccessLogConfig) (gin.HandlerFunc, error) {
|
||||
logger, err := newAccessLogger(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
// 用于测试时关闭日志文件
|
||||
if c == nil {
|
||||
if err := logger.Close(); err != nil {
|
||||
fmt.Printf("Error closing log file: %v\n", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
requestID := uuid.New().String()
|
||||
path := c.Request.URL.Path
|
||||
query := c.Request.URL.RawQuery
|
||||
|
||||
// 设置请求ID到上下文
|
||||
c.Set("request_id", requestID)
|
||||
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 计算处理时间
|
||||
latency := time.Since(start)
|
||||
|
||||
// 获取用户ID(如果已认证)
|
||||
var userID interface{}
|
||||
if id, exists := c.Get("user_id"); exists {
|
||||
userID = id
|
||||
}
|
||||
|
||||
// 准备日志字段
|
||||
fields := map[string]interface{}{
|
||||
"request_id": requestID,
|
||||
"method": c.Request.Method,
|
||||
"path": path,
|
||||
"query": query,
|
||||
"ip": c.ClientIP(),
|
||||
"user_agent": c.Request.UserAgent(),
|
||||
"status": c.Writer.Status(),
|
||||
"size": c.Writer.Size(),
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
"component": "access_log",
|
||||
}
|
||||
|
||||
if userID != nil {
|
||||
fields["user_id"] = userID
|
||||
}
|
||||
|
||||
// 如果有错误,添加到日志中
|
||||
if len(c.Errors) > 0 {
|
||||
fields["error"] = c.Errors.String()
|
||||
}
|
||||
|
||||
// 记录日志
|
||||
logger.logEvent(fields, fmt.Sprintf("%s %s", c.Request.Method, path))
|
||||
}, nil
|
||||
}
|
238
backend/internal/middleware/accesslog_test.go
Normal file
238
backend/internal/middleware/accesslog_test.go
Normal file
|
@ -0,0 +1,238 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
"tss-rocks-be/internal/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAccessLog(t *testing.T) {
|
||||
// 设置测试临时目录
|
||||
tmpDir := t.TempDir()
|
||||
logPath := filepath.Join(tmpDir, "test.log")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
config *types.AccessLogConfig
|
||||
expectedError bool
|
||||
setupRequest func(*http.Request)
|
||||
validateOutput func(*testing.T, *httptest.ResponseRecorder, string)
|
||||
}{
|
||||
{
|
||||
name: "Console logging only",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
EnableFile: false,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
},
|
||||
expectedError: false,
|
||||
setupRequest: func(req *http.Request) {
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
},
|
||||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, logOutput, "GET /test")
|
||||
assert.Contains(t, logOutput, "test-agent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "File logging only",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: false,
|
||||
EnableFile: true,
|
||||
FilePath: logPath,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
Rotation: struct {
|
||||
MaxSize int `yaml:"max_size"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
Compress bool `yaml:"compress"`
|
||||
LocalTime bool `yaml:"local_time"`
|
||||
}{
|
||||
MaxSize: 1,
|
||||
MaxAge: 1,
|
||||
MaxBackups: 1,
|
||||
Compress: false,
|
||||
LocalTime: true,
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
setupRequest: func(req *http.Request) {
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
},
|
||||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// 读取日志文件内容
|
||||
content, err := os.ReadFile(logPath)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(content), "GET /test")
|
||||
assert.Contains(t, string(content), "test-agent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Both console and file logging",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
EnableFile: true,
|
||||
FilePath: logPath,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
Rotation: struct {
|
||||
MaxSize int `yaml:"max_size"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
Compress bool `yaml:"compress"`
|
||||
LocalTime bool `yaml:"local_time"`
|
||||
}{
|
||||
MaxSize: 1,
|
||||
MaxAge: 1,
|
||||
MaxBackups: 1,
|
||||
Compress: false,
|
||||
LocalTime: true,
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
setupRequest: func(req *http.Request) {
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
},
|
||||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, logOutput, "GET /test")
|
||||
assert.Contains(t, logOutput, "test-agent")
|
||||
|
||||
// 读取日志文件内容
|
||||
content, err := os.ReadFile(logPath)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(content), "GET /test")
|
||||
assert.Contains(t, string(content), "test-agent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "With authenticated user",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
EnableFile: false,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
},
|
||||
expectedError: false,
|
||||
setupRequest: func(req *http.Request) {
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
},
|
||||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, logOutput, "GET /test")
|
||||
assert.Contains(t, logOutput, "test-agent")
|
||||
assert.Contains(t, logOutput, "test-user")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// 捕获标准输出
|
||||
oldStdout := os.Stdout
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stdout = w
|
||||
|
||||
// 创建一个新的 gin 引擎
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 创建访问日志中间件
|
||||
middleware, err := AccessLog(tc.config)
|
||||
if tc.expectedError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 添加测试路由
|
||||
router.Use(middleware)
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
// 如果是测试认证用户的情况,设置用户ID
|
||||
if tc.name == "With authenticated user" {
|
||||
c.Set("user_id", "test-user")
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// 创建测试请求
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
if tc.setupRequest != nil {
|
||||
tc.setupRequest(req)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// 恢复标准输出并获取输出内容
|
||||
w.Close()
|
||||
var buf bytes.Buffer
|
||||
io.Copy(&buf, r)
|
||||
os.Stdout = oldStdout
|
||||
|
||||
// 验证输出
|
||||
if tc.validateOutput != nil {
|
||||
tc.validateOutput(t, rec, buf.String())
|
||||
}
|
||||
|
||||
// 关闭日志文件
|
||||
if tc.config.EnableFile {
|
||||
// 调用中间件函数来关闭日志文件
|
||||
middleware(nil)
|
||||
// 等待一小段时间确保文件完全关闭
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogInvalidConfig(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
config *types.AccessLogConfig
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "Invalid log level",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
Level: "invalid_level",
|
||||
},
|
||||
expectedError: false, // 应该使用默认的 info 级别
|
||||
},
|
||||
{
|
||||
name: "Invalid file path",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableFile: true,
|
||||
FilePath: "/dev/null/nonexistent/test.log", // 在所有操作系统上都无效的路径
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := AccessLog(tc.config)
|
||||
if tc.expectedError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
82
backend/internal/middleware/auth.go
Normal file
82
backend/internal/middleware/auth.go
Normal file
|
@ -0,0 +1,82 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// AuthMiddleware creates a middleware for JWT authentication
|
||||
func AuthMiddleware(jwtSecret string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header is required"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header format must be Bearer {token}"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(parts[1], func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
return []byte(jwtSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to parse token")
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
||||
c.Set("user_id", claims["sub"])
|
||||
c.Set("user_role", claims["role"])
|
||||
c.Next()
|
||||
} else {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RoleMiddleware creates a middleware for role-based authorization
|
||||
func RoleMiddleware(roles ...string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
userRole, exists := c.Get("user_role")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "User role not found"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
roleStr, ok := userRole.(string)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user role type"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
for _, role := range roles {
|
||||
if role == roleStr {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Insufficient permissions"})
|
||||
c.Abort()
|
||||
}
|
||||
}
|
217
backend/internal/middleware/auth_test.go
Normal file
217
backend/internal/middleware/auth_test.go
Normal file
|
@ -0,0 +1,217 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func createTestToken(secret string, claims jwt.MapClaims) string {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signedToken, _ := token.SignedString([]byte(secret))
|
||||
return signedToken
|
||||
}
|
||||
|
||||
func TestAuthMiddleware(t *testing.T) {
|
||||
jwtSecret := "test-secret"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupAuth func(*http.Request)
|
||||
expectedStatus int
|
||||
expectedBody map[string]string
|
||||
checkUserData bool
|
||||
expectedUserID string
|
||||
expectedRole string
|
||||
}{
|
||||
{
|
||||
name: "No Authorization header",
|
||||
setupAuth: func(req *http.Request) {},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "Authorization header is required"},
|
||||
},
|
||||
{
|
||||
name: "Invalid Authorization format",
|
||||
setupAuth: func(req *http.Request) {
|
||||
req.Header.Set("Authorization", "InvalidFormat")
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "Authorization header format must be Bearer {token}"},
|
||||
},
|
||||
{
|
||||
name: "Invalid token",
|
||||
setupAuth: func(req *http.Request) {
|
||||
req.Header.Set("Authorization", "Bearer invalid.token.here")
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "Invalid token"},
|
||||
},
|
||||
{
|
||||
name: "Valid token",
|
||||
setupAuth: func(req *http.Request) {
|
||||
claims := jwt.MapClaims{
|
||||
"sub": "user123",
|
||||
"role": "user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
}
|
||||
token := createTestToken(jwtSecret, claims)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
checkUserData: true,
|
||||
expectedUserID: "user123",
|
||||
expectedRole: "user",
|
||||
},
|
||||
{
|
||||
name: "Expired token",
|
||||
setupAuth: func(req *http.Request) {
|
||||
claims := jwt.MapClaims{
|
||||
"sub": "user123",
|
||||
"role": "user",
|
||||
"exp": time.Now().Add(-time.Hour).Unix(),
|
||||
}
|
||||
token := createTestToken(jwtSecret, claims)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "Invalid token"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 添加认证中间件
|
||||
router.Use(AuthMiddleware(jwtSecret))
|
||||
|
||||
// 测试路由
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
if tc.checkUserData {
|
||||
userID, exists := c.Get("user_id")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, tc.expectedUserID, userID)
|
||||
|
||||
role, exists := c.Get("user_role")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, tc.expectedRole, role)
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// 创建请求
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
tc.setupAuth(req)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// 验证响应
|
||||
assert.Equal(t, tc.expectedStatus, rec.Code)
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expectedBody, response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoleMiddleware(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupContext func(*gin.Context)
|
||||
allowedRoles []string
|
||||
expectedStatus int
|
||||
expectedBody map[string]string
|
||||
}{
|
||||
{
|
||||
name: "No user role",
|
||||
setupContext: func(c *gin.Context) {
|
||||
// 不设置用户角色
|
||||
},
|
||||
allowedRoles: []string{"admin"},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "User role not found"},
|
||||
},
|
||||
{
|
||||
name: "Invalid role type",
|
||||
setupContext: func(c *gin.Context) {
|
||||
c.Set("user_role", 123) // 设置错误类型的角色
|
||||
},
|
||||
allowedRoles: []string{"admin"},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: map[string]string{"error": "Invalid user role type"},
|
||||
},
|
||||
{
|
||||
name: "Insufficient permissions",
|
||||
setupContext: func(c *gin.Context) {
|
||||
c.Set("user_role", "user")
|
||||
},
|
||||
allowedRoles: []string{"admin"},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedBody: map[string]string{"error": "Insufficient permissions"},
|
||||
},
|
||||
{
|
||||
name: "Allowed role",
|
||||
setupContext: func(c *gin.Context) {
|
||||
c.Set("user_role", "admin")
|
||||
},
|
||||
allowedRoles: []string{"admin"},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "One of multiple allowed roles",
|
||||
setupContext: func(c *gin.Context) {
|
||||
c.Set("user_role", "editor")
|
||||
},
|
||||
allowedRoles: []string{"admin", "editor", "moderator"},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 添加角色中间件
|
||||
router.Use(func(c *gin.Context) {
|
||||
tc.setupContext(c)
|
||||
c.Next()
|
||||
})
|
||||
router.Use(RoleMiddleware(tc.allowedRoles...))
|
||||
|
||||
// 测试路由
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// 创建请求
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// 验证响应
|
||||
assert.Equal(t, tc.expectedStatus, rec.Code)
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expectedBody, response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
22
backend/internal/middleware/cors.go
Normal file
22
backend/internal/middleware/cors.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORS middleware
|
||||
func CORS() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
76
backend/internal/middleware/cors_test.go
Normal file
76
backend/internal/middleware/cors_test.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCORS(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
method string
|
||||
expectedStatus int
|
||||
checkHeaders bool
|
||||
}{
|
||||
{
|
||||
name: "Normal GET request",
|
||||
method: "GET",
|
||||
expectedStatus: http.StatusOK,
|
||||
checkHeaders: true,
|
||||
},
|
||||
{
|
||||
name: "OPTIONS request",
|
||||
method: "OPTIONS",
|
||||
expectedStatus: http.StatusNoContent,
|
||||
checkHeaders: true,
|
||||
},
|
||||
{
|
||||
name: "POST request",
|
||||
method: "POST",
|
||||
expectedStatus: http.StatusOK,
|
||||
checkHeaders: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// 创建一个新的 gin 引擎
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 添加 CORS 中间件
|
||||
router.Use(CORS())
|
||||
|
||||
// 添加测试路由
|
||||
router.Any("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// 创建测试请求
|
||||
req := httptest.NewRequest(tc.method, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// 验证状态码
|
||||
assert.Equal(t, tc.expectedStatus, rec.Code)
|
||||
|
||||
if tc.checkHeaders {
|
||||
// 验证 CORS 头部
|
||||
headers := rec.Header()
|
||||
assert.Equal(t, "*", headers.Get("Access-Control-Allow-Origin"))
|
||||
assert.Equal(t, "true", headers.Get("Access-Control-Allow-Credentials"))
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Content-Type")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Authorization")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "POST")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "GET")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "PUT")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "DELETE")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
107
backend/internal/middleware/ratelimit.go
Normal file
107
backend/internal/middleware/ratelimit.go
Normal file
|
@ -0,0 +1,107 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
"tss-rocks-be/internal/types"
|
||||
)
|
||||
|
||||
// ipLimiter IP限流器
|
||||
type ipLimiter struct {
|
||||
limiter *rate.Limiter
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
// rateLimiter 限流器管理器
|
||||
type rateLimiter struct {
|
||||
ips map[string]*ipLimiter
|
||||
mu sync.RWMutex
|
||||
config *types.RateLimitConfig
|
||||
routes map[string]*rate.Limiter
|
||||
}
|
||||
|
||||
// newRateLimiter 创建新的限流器
|
||||
func newRateLimiter(config *types.RateLimitConfig) *rateLimiter {
|
||||
// 初始化路由限流器
|
||||
routes := make(map[string]*rate.Limiter)
|
||||
for path, cfg := range config.RouteRates {
|
||||
routes[path] = rate.NewLimiter(rate.Limit(cfg.Rate), cfg.Burst)
|
||||
}
|
||||
|
||||
rl := &rateLimiter{
|
||||
ips: make(map[string]*ipLimiter),
|
||||
config: config,
|
||||
routes: routes,
|
||||
}
|
||||
|
||||
// 启动清理过期IP限流器的goroutine
|
||||
go rl.cleanupIPLimiters()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
// cleanupIPLimiters 清理过期的IP限流器
|
||||
func (rl *rateLimiter) cleanupIPLimiters() {
|
||||
for {
|
||||
time.Sleep(time.Hour) // 每小时清理一次
|
||||
|
||||
rl.mu.Lock()
|
||||
for ip, limiter := range rl.ips {
|
||||
if time.Since(limiter.lastSeen) > time.Hour {
|
||||
delete(rl.ips, ip)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// getLimiter 获取IP限流器
|
||||
func (rl *rateLimiter) getLimiter(ip string) *rate.Limiter {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
v, exists := rl.ips[ip]
|
||||
if !exists {
|
||||
limiter := rate.NewLimiter(rate.Limit(rl.config.IPRate), rl.config.IPBurst)
|
||||
rl.ips[ip] = &ipLimiter{limiter: limiter, lastSeen: time.Now()}
|
||||
return limiter
|
||||
}
|
||||
|
||||
v.lastSeen = time.Now()
|
||||
return v.limiter
|
||||
}
|
||||
|
||||
// RateLimit 创建限流中间件
|
||||
func RateLimit(config *types.RateLimitConfig) gin.HandlerFunc {
|
||||
rl := newRateLimiter(config)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
// 检查路由限流
|
||||
path := c.Request.URL.Path
|
||||
if limiter, ok := rl.routes[path]; ok {
|
||||
if !limiter.Allow() {
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "too many requests for this route",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 检查IP限流
|
||||
limiter := rl.getLimiter(c.ClientIP())
|
||||
if !limiter.Allow() {
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "too many requests from this IP",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
207
backend/internal/middleware/ratelimit_test.go
Normal file
207
backend/internal/middleware/ratelimit_test.go
Normal file
|
@ -0,0 +1,207 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
"tss-rocks-be/internal/types"
|
||||
)
|
||||
|
||||
func TestRateLimit(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
config *types.RateLimitConfig
|
||||
setupTest func(*gin.Engine)
|
||||
runTest func(*testing.T, *gin.Engine)
|
||||
expectedStatus int
|
||||
expectedBody map[string]string
|
||||
}{
|
||||
{
|
||||
name: "IP rate limit",
|
||||
config: &types.RateLimitConfig{
|
||||
IPRate: 1, // 每秒1个请求
|
||||
IPBurst: 1,
|
||||
},
|
||||
setupTest: func(router *gin.Engine) {
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
},
|
||||
runTest: func(t *testing.T, router *gin.Engine) {
|
||||
// 第一个请求应该成功
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:1234"
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// 第二个请求应该被限制
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "too many requests from this IP", response["error"])
|
||||
|
||||
// 等待限流器重置
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// 第三个请求应该成功
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Route rate limit",
|
||||
config: &types.RateLimitConfig{
|
||||
IPRate: 100, // 设置较高的 IP 限流,以便测试路由限流
|
||||
IPBurst: 10,
|
||||
RouteRates: map[string]struct {
|
||||
Rate int `yaml:"rate"`
|
||||
Burst int `yaml:"burst"`
|
||||
}{
|
||||
"/limited": {
|
||||
Rate: 1,
|
||||
Burst: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
setupTest: func(router *gin.Engine) {
|
||||
router.GET("/limited", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
router.GET("/unlimited", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
},
|
||||
runTest: func(t *testing.T, router *gin.Engine) {
|
||||
// 测试限流路由
|
||||
req := httptest.NewRequest("GET", "/limited", nil)
|
||||
req.RemoteAddr = "192.168.1.2:1234"
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// 等待一小段时间确保限流器生效
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "too many requests for this route", response["error"])
|
||||
|
||||
// 测试未限流路由
|
||||
req = httptest.NewRequest("GET", "/unlimited", nil)
|
||||
req.RemoteAddr = "192.168.1.2:1234"
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// 等待一小段时间确保限流器生效
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Multiple IPs",
|
||||
config: &types.RateLimitConfig{
|
||||
IPRate: 1,
|
||||
IPBurst: 1,
|
||||
},
|
||||
setupTest: func(router *gin.Engine) {
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
},
|
||||
runTest: func(t *testing.T, router *gin.Engine) {
|
||||
// IP1 的请求
|
||||
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||
req1.RemoteAddr = "192.168.1.3:1234"
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req1)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req1)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
|
||||
// IP2 的请求应该不受 IP1 的限制影响
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
req2.RemoteAddr = "192.168.1.4:1234"
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req2)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 添加限流中间件
|
||||
router.Use(RateLimit(tc.config))
|
||||
|
||||
// 设置测试路由
|
||||
tc.setupTest(router)
|
||||
|
||||
// 运行测试
|
||||
tc.runTest(t, router)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterCleanup(t *testing.T) {
|
||||
config := &types.RateLimitConfig{
|
||||
IPRate: 1,
|
||||
IPBurst: 1,
|
||||
}
|
||||
|
||||
rl := newRateLimiter(config)
|
||||
|
||||
// 添加一些IP限流器
|
||||
ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"}
|
||||
for _, ip := range ips {
|
||||
rl.getLimiter(ip)
|
||||
}
|
||||
|
||||
// 验证IP限流器已创建
|
||||
rl.mu.RLock()
|
||||
assert.Equal(t, len(ips), len(rl.ips))
|
||||
rl.mu.RUnlock()
|
||||
|
||||
// 修改一些IP的最后访问时间为1小时前
|
||||
rl.mu.Lock()
|
||||
rl.ips["192.168.1.1"].lastSeen = time.Now().Add(-2 * time.Hour)
|
||||
rl.ips["192.168.1.2"].lastSeen = time.Now().Add(-2 * time.Hour)
|
||||
rl.mu.Unlock()
|
||||
|
||||
// 手动触发清理
|
||||
rl.mu.Lock()
|
||||
for ip, limiter := range rl.ips {
|
||||
if time.Since(limiter.lastSeen) > time.Hour {
|
||||
delete(rl.ips, ip)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
|
||||
// 验证过期的IP限流器已被删除
|
||||
rl.mu.RLock()
|
||||
assert.Equal(t, 1, len(rl.ips))
|
||||
_, exists := rl.ips["192.168.1.3"]
|
||||
assert.True(t, exists)
|
||||
rl.mu.RUnlock()
|
||||
}
|
110
backend/internal/middleware/rbac.go
Normal file
110
backend/internal/middleware/rbac.go
Normal file
|
@ -0,0 +1,110 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/ent/user"
|
||||
"tss-rocks-be/internal/auth"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RequirePermission creates a middleware that checks if the user has the required permission
|
||||
func RequirePermission(client *ent.Client, resource, action string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Get user from context
|
||||
userID, exists := c.Get(auth.UserIDKey)
|
||||
if !exists {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Unauthorized",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user with roles
|
||||
user, err := client.User.Query().
|
||||
Where(user.ID(userID.(int))).
|
||||
WithRoles(func(q *ent.RoleQuery) {
|
||||
q.WithPermissions()
|
||||
}).
|
||||
Only(context.Background())
|
||||
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "User not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user has the required permission through any of their roles
|
||||
hasPermission := false
|
||||
for _, r := range user.Edges.Roles {
|
||||
for _, p := range r.Edges.Permissions {
|
||||
if p.Resource == resource && p.Action == action {
|
||||
hasPermission = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasPermission {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasPermission {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
||||
"error": fmt.Sprintf("Missing required permission: %s:%s", resource, action),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireRole creates a middleware that checks if the user has the required role
|
||||
func RequireRole(client *ent.Client, roleName string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Get user from context
|
||||
userID, exists := c.Get(auth.UserIDKey)
|
||||
if !exists {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Unauthorized",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user with roles
|
||||
user, err := client.User.Query().
|
||||
Where(user.ID(userID.(int))).
|
||||
WithRoles().
|
||||
Only(context.Background())
|
||||
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "User not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user has the required role
|
||||
hasRole := false
|
||||
for _, r := range user.Edges.Roles {
|
||||
if r.Name == roleName {
|
||||
hasRole = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasRole {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
||||
"error": fmt.Sprintf("Required role: %s", roleName),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
159
backend/internal/middleware/upload.go
Normal file
159
backend/internal/middleware/upload.go
Normal file
|
@ -0,0 +1,159 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"tss-rocks-be/internal/types"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMaxMemory = 32 << 20 // 32 MB
|
||||
maxHeaderBytes = 512 // 用于MIME类型检测的最大字节数
|
||||
)
|
||||
|
||||
// ValidateUpload 创建文件上传验证中间件
|
||||
func ValidateUpload(cfg *types.UploadConfig) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 检查是否是multipart/form-data请求
|
||||
if !strings.HasPrefix(c.GetHeader("Content-Type"), "multipart/form-data") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "Content-Type must be multipart/form-data",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 解析multipart表单
|
||||
if err := c.Request.ParseMultipartForm(defaultMaxMemory); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("Failed to parse form: %v", err),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
form := c.Request.MultipartForm
|
||||
if form == nil || form.File == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "No file uploaded",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 遍历所有上传的文件
|
||||
for _, files := range form.File {
|
||||
for _, file := range files {
|
||||
// 检查文件大小
|
||||
if file.Size > int64(cfg.MaxSize)<<20 { // 转换为字节
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("File %s exceeds maximum size of %d MB", file.Filename, cfg.MaxSize),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查文件扩展名
|
||||
ext := strings.ToLower(filepath.Ext(file.Filename))
|
||||
if !contains(cfg.AllowedExtensions, ext) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("File extension %s is not allowed", ext),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 打开文件
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": fmt.Sprintf("Failed to open file: %v", err),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
// 读取文件头部用于MIME类型检测
|
||||
header := make([]byte, maxHeaderBytes)
|
||||
n, err := src.Read(header)
|
||||
if err != nil && err != io.EOF {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": fmt.Sprintf("Failed to read file: %v", err),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
header = header[:n]
|
||||
|
||||
// 检测MIME类型
|
||||
contentType := http.DetectContentType(header)
|
||||
if !contains(cfg.AllowedTypes, contentType) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("File type %s is not allowed", contentType),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将文件指针重置到开始位置
|
||||
_, err = src.Seek(0, 0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": fmt.Sprintf("Failed to read file: %v", err),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将文件内容读入缓冲区
|
||||
buf := &bytes.Buffer{}
|
||||
_, err = io.Copy(buf, src)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": fmt.Sprintf("Failed to read file: %v", err),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将验证过的文件内容和类型保存到上下文中
|
||||
c.Set("validated_file_"+file.Filename, buf)
|
||||
c.Set("validated_content_type_"+file.Filename, contentType)
|
||||
}
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// contains 检查切片中是否包含指定的字符串
|
||||
func contains(slice []string, str string) bool {
|
||||
for _, s := range slice {
|
||||
if s == str {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetValidatedFile 从上下文中获取验证过的文件内容
|
||||
func GetValidatedFile(c *gin.Context, filename string) (*bytes.Buffer, string, bool) {
|
||||
file, exists := c.Get("validated_file_" + filename)
|
||||
if !exists {
|
||||
return nil, "", false
|
||||
}
|
||||
|
||||
contentType, exists := c.Get("validated_content_type_" + filename)
|
||||
if !exists {
|
||||
return nil, "", false
|
||||
}
|
||||
|
||||
return file.(*bytes.Buffer), contentType.(string), true
|
||||
}
|
262
backend/internal/middleware/upload_test.go
Normal file
262
backend/internal/middleware/upload_test.go
Normal file
|
@ -0,0 +1,262 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"tss-rocks-be/internal/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func createMultipartRequest(t *testing.T, filename string, content []byte, contentType string) (*http.Request, error) {
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
part, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = io.Copy(part, bytes.NewReader(content))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/upload", body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func TestValidateUpload(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *types.UploadConfig
|
||||
filename string
|
||||
content []byte
|
||||
setupRequest func(*testing.T) *http.Request
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "Valid image upload",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5, // 5MB
|
||||
AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
|
||||
AllowedTypes: []string{"image/jpeg", "image/png"},
|
||||
},
|
||||
filename: "test.jpg",
|
||||
content: []byte{
|
||||
0xFF, 0xD8, 0xFF, 0xE0, // JPEG magic numbers
|
||||
0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00,
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Invalid file extension",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5,
|
||||
AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
|
||||
AllowedTypes: []string{"image/jpeg", "image/png"},
|
||||
},
|
||||
filename: "test.txt",
|
||||
content: []byte("test content"),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "File extension .txt is not allowed",
|
||||
},
|
||||
{
|
||||
name: "File too large",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 1, // 1MB
|
||||
AllowedExtensions: []string{".jpg"},
|
||||
AllowedTypes: []string{"image/jpeg"},
|
||||
},
|
||||
filename: "large.jpg",
|
||||
content: make([]byte, 2<<20), // 2MB
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "File large.jpg exceeds maximum size of 1 MB",
|
||||
},
|
||||
{
|
||||
name: "Invalid content type",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5,
|
||||
AllowedExtensions: []string{".jpg"},
|
||||
AllowedTypes: []string{"image/jpeg"},
|
||||
},
|
||||
filename: "fake.jpg",
|
||||
content: []byte("not a real image"),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "File type text/plain; charset=utf-8 is not allowed",
|
||||
},
|
||||
{
|
||||
name: "Missing file",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5,
|
||||
AllowedExtensions: []string{".jpg"},
|
||||
AllowedTypes: []string{"image/jpeg"},
|
||||
},
|
||||
setupRequest: func(t *testing.T) *http.Request {
|
||||
req := httptest.NewRequest("POST", "/upload", strings.NewReader(""))
|
||||
req.Header.Set("Content-Type", "multipart/form-data")
|
||||
return req
|
||||
},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Failed to parse form",
|
||||
},
|
||||
{
|
||||
name: "Invalid content type header",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5,
|
||||
AllowedExtensions: []string{".jpg"},
|
||||
AllowedTypes: []string{"image/jpeg"},
|
||||
},
|
||||
setupRequest: func(t *testing.T) *http.Request {
|
||||
return httptest.NewRequest("POST", "/upload", nil)
|
||||
},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Content-Type must be multipart/form-data",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
var req *http.Request
|
||||
var err error
|
||||
|
||||
if tt.setupRequest != nil {
|
||||
req = tt.setupRequest(t)
|
||||
} else {
|
||||
req, err = createMultipartRequest(t, tt.filename, tt.content, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.Request = req
|
||||
|
||||
middleware := ValidateUpload(tt.config)
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||
if tt.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(w.Body).Decode(&response)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, response["error"], tt.expectedError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetValidatedFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupContext func(*gin.Context)
|
||||
filename string
|
||||
expectedFound bool
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "Get existing file",
|
||||
setupContext: func(c *gin.Context) {
|
||||
// 创建测试文件内容
|
||||
content := []byte("test content")
|
||||
buf := bytes.NewBuffer(content)
|
||||
|
||||
// 设置验证过的文件和内容类型
|
||||
c.Set("validated_file_test.txt", buf)
|
||||
c.Set("validated_content_type_test.txt", "text/plain")
|
||||
},
|
||||
filename: "test.txt",
|
||||
expectedFound: true,
|
||||
},
|
||||
{
|
||||
name: "File not found",
|
||||
setupContext: func(c *gin.Context) {
|
||||
// 不设置任何文件
|
||||
},
|
||||
filename: "nonexistent.txt",
|
||||
expectedFound: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
if tt.setupContext != nil {
|
||||
tt.setupContext(c)
|
||||
}
|
||||
|
||||
buffer, contentType, found := GetValidatedFile(c, tt.filename)
|
||||
|
||||
assert.Equal(t, tt.expectedFound, found)
|
||||
if tt.expectedFound {
|
||||
assert.NotNil(t, buffer)
|
||||
assert.NotEmpty(t, contentType)
|
||||
} else {
|
||||
assert.Nil(t, buffer)
|
||||
assert.Empty(t, contentType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
slice []string
|
||||
str string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "String found in slice",
|
||||
slice: []string{"a", "b", "c"},
|
||||
str: "b",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "String not found in slice",
|
||||
slice: []string{"a", "b", "c"},
|
||||
str: "d",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty slice",
|
||||
slice: []string{},
|
||||
str: "a",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
slice: []string{"a", "b", "c"},
|
||||
str: "",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := contains(tt.slice, tt.str)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue