[feature] migrate to monorepo
Some checks failed
Build Backend / Build Docker Image (push) Successful in 3m33s
Test Backend / test (push) Failing after 31s

This commit is contained in:
CDN 2025-02-21 00:49:20 +08:00
commit 05ddc1f783
Signed by: CDN
GPG key ID: 0C656827F9F80080
267 changed files with 75165 additions and 0 deletions

View file

@ -0,0 +1,192 @@
package middleware
import (
"fmt"
"os"
"path/filepath"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/rs/zerolog"
"gopkg.in/natefinch/lumberjack.v2"
"tss-rocks-be/internal/types"
)
// AccessLogConfig 访问日志配置
type AccessLogConfig struct {
// 是否启用控制台输出
EnableConsole bool `yaml:"enable_console"`
// 是否启用文件日志
EnableFile bool `yaml:"enable_file"`
// 日志文件路径
FilePath string `yaml:"file_path"`
// 日志格式 (json 或 text)
Format string `yaml:"format"`
// 日志级别
Level string `yaml:"level"`
// 日志轮转配置
Rotation struct {
MaxSize int `yaml:"max_size"` // 每个日志文件的最大大小MB
MaxAge int `yaml:"max_age"` // 保留旧日志文件的最大天数
MaxBackups int `yaml:"max_backups"` // 保留的旧日志文件的最大数量
Compress bool `yaml:"compress"` // 是否压缩旧日志文件
LocalTime bool `yaml:"local_time"` // 使用本地时间作为轮转时间
} `yaml:"rotation"`
}
// accessLogger 访问日志记录器
type accessLogger struct {
consoleLogger *zerolog.Logger
fileLogger *zerolog.Logger
logWriter *lumberjack.Logger
config *types.AccessLogConfig
}
// Close 关闭日志文件
func (l *accessLogger) Close() error {
if l.logWriter != nil {
return l.logWriter.Close()
}
return nil
}
// newAccessLogger 创建新的访问日志记录器
func newAccessLogger(config *types.AccessLogConfig) (*accessLogger, error) {
var consoleLogger, fileLogger *zerolog.Logger
var logWriter *lumberjack.Logger
// 设置日志级别
level, err := zerolog.ParseLevel(config.Level)
if err != nil {
level = zerolog.InfoLevel
}
zerolog.SetGlobalLevel(level)
// 配置控制台日志
if config.EnableConsole {
logger := zerolog.New(os.Stdout).
With().
Timestamp().
Logger()
if config.Format == "text" {
logger = logger.Output(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339})
}
consoleLogger = &logger
}
// 配置文件日志
if config.EnableFile {
// 确保日志目录存在
if err := os.MkdirAll(filepath.Dir(config.FilePath), 0755); err != nil {
return nil, fmt.Errorf("failed to create log directory: %w", err)
}
// 配置日志轮转
logWriter = &lumberjack.Logger{
Filename: config.FilePath,
MaxSize: config.Rotation.MaxSize, // MB
MaxAge: config.Rotation.MaxAge, // days
MaxBackups: config.Rotation.MaxBackups, // files
Compress: config.Rotation.Compress, // 是否压缩
LocalTime: config.Rotation.LocalTime, // 使用本地时间
}
logger := zerolog.New(logWriter).
With().
Timestamp().
Logger()
fileLogger = &logger
}
return &accessLogger{
consoleLogger: consoleLogger,
fileLogger: fileLogger,
logWriter: logWriter,
config: config,
}, nil
}
// logEvent 记录日志事件
func (l *accessLogger) logEvent(fields map[string]interface{}, msg string) {
if l.consoleLogger != nil {
event := l.consoleLogger.Info()
for k, v := range fields {
event = event.Interface(k, v)
}
event.Msg(msg)
}
if l.fileLogger != nil {
event := l.fileLogger.Info()
for k, v := range fields {
event = event.Interface(k, v)
}
event.Msg(msg)
}
}
// AccessLog 创建访问日志中间件
func AccessLog(config *types.AccessLogConfig) (gin.HandlerFunc, error) {
logger, err := newAccessLogger(config)
if err != nil {
return nil, err
}
return func(c *gin.Context) {
// 用于测试时关闭日志文件
if c == nil {
if err := logger.Close(); err != nil {
fmt.Printf("Error closing log file: %v\n", err)
}
return
}
start := time.Now()
requestID := uuid.New().String()
path := c.Request.URL.Path
query := c.Request.URL.RawQuery
// 设置请求ID到上下文
c.Set("request_id", requestID)
// 处理请求
c.Next()
// 计算处理时间
latency := time.Since(start)
// 获取用户ID如果已认证
var userID interface{}
if id, exists := c.Get("user_id"); exists {
userID = id
}
// 准备日志字段
fields := map[string]interface{}{
"request_id": requestID,
"method": c.Request.Method,
"path": path,
"query": query,
"ip": c.ClientIP(),
"user_agent": c.Request.UserAgent(),
"status": c.Writer.Status(),
"size": c.Writer.Size(),
"latency_ms": latency.Milliseconds(),
"component": "access_log",
}
if userID != nil {
fields["user_id"] = userID
}
// 如果有错误,添加到日志中
if len(c.Errors) > 0 {
fields["error"] = c.Errors.String()
}
// 记录日志
logger.logEvent(fields, fmt.Sprintf("%s %s", c.Request.Method, path))
}, nil
}

View file

@ -0,0 +1,238 @@
package middleware
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"tss-rocks-be/internal/types"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestAccessLog(t *testing.T) {
// 设置测试临时目录
tmpDir := t.TempDir()
logPath := filepath.Join(tmpDir, "test.log")
testCases := []struct {
name string
config *types.AccessLogConfig
expectedError bool
setupRequest func(*http.Request)
validateOutput func(*testing.T, *httptest.ResponseRecorder, string)
}{
{
name: "Console logging only",
config: &types.AccessLogConfig{
EnableConsole: true,
EnableFile: false,
Format: "json",
Level: "info",
},
expectedError: false,
setupRequest: func(req *http.Request) {
req.Header.Set("User-Agent", "test-agent")
},
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, logOutput, "GET /test")
assert.Contains(t, logOutput, "test-agent")
},
},
{
name: "File logging only",
config: &types.AccessLogConfig{
EnableConsole: false,
EnableFile: true,
FilePath: logPath,
Format: "json",
Level: "info",
Rotation: struct {
MaxSize int `yaml:"max_size"`
MaxAge int `yaml:"max_age"`
MaxBackups int `yaml:"max_backups"`
Compress bool `yaml:"compress"`
LocalTime bool `yaml:"local_time"`
}{
MaxSize: 1,
MaxAge: 1,
MaxBackups: 1,
Compress: false,
LocalTime: true,
},
},
expectedError: false,
setupRequest: func(req *http.Request) {
req.Header.Set("User-Agent", "test-agent")
},
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
assert.Equal(t, http.StatusOK, w.Code)
// 读取日志文件内容
content, err := os.ReadFile(logPath)
assert.NoError(t, err)
assert.Contains(t, string(content), "GET /test")
assert.Contains(t, string(content), "test-agent")
},
},
{
name: "Both console and file logging",
config: &types.AccessLogConfig{
EnableConsole: true,
EnableFile: true,
FilePath: logPath,
Format: "json",
Level: "info",
Rotation: struct {
MaxSize int `yaml:"max_size"`
MaxAge int `yaml:"max_age"`
MaxBackups int `yaml:"max_backups"`
Compress bool `yaml:"compress"`
LocalTime bool `yaml:"local_time"`
}{
MaxSize: 1,
MaxAge: 1,
MaxBackups: 1,
Compress: false,
LocalTime: true,
},
},
expectedError: false,
setupRequest: func(req *http.Request) {
req.Header.Set("User-Agent", "test-agent")
},
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, logOutput, "GET /test")
assert.Contains(t, logOutput, "test-agent")
// 读取日志文件内容
content, err := os.ReadFile(logPath)
assert.NoError(t, err)
assert.Contains(t, string(content), "GET /test")
assert.Contains(t, string(content), "test-agent")
},
},
{
name: "With authenticated user",
config: &types.AccessLogConfig{
EnableConsole: true,
EnableFile: false,
Format: "json",
Level: "info",
},
expectedError: false,
setupRequest: func(req *http.Request) {
req.Header.Set("User-Agent", "test-agent")
},
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, logOutput, "GET /test")
assert.Contains(t, logOutput, "test-agent")
assert.Contains(t, logOutput, "test-user")
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// 捕获标准输出
oldStdout := os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w
// 创建一个新的 gin 引擎
gin.SetMode(gin.TestMode)
router := gin.New()
// 创建访问日志中间件
middleware, err := AccessLog(tc.config)
if tc.expectedError {
assert.Error(t, err)
return
}
assert.NoError(t, err)
// 添加测试路由
router.Use(middleware)
router.GET("/test", func(c *gin.Context) {
// 如果是测试认证用户的情况设置用户ID
if tc.name == "With authenticated user" {
c.Set("user_id", "test-user")
}
c.Status(http.StatusOK)
})
// 创建测试请求
req := httptest.NewRequest("GET", "/test", nil)
if tc.setupRequest != nil {
tc.setupRequest(req)
}
rec := httptest.NewRecorder()
// 执行请求
router.ServeHTTP(rec, req)
// 恢复标准输出并获取输出内容
w.Close()
var buf bytes.Buffer
io.Copy(&buf, r)
os.Stdout = oldStdout
// 验证输出
if tc.validateOutput != nil {
tc.validateOutput(t, rec, buf.String())
}
// 关闭日志文件
if tc.config.EnableFile {
// 调用中间件函数来关闭日志文件
middleware(nil)
// 等待一小段时间确保文件完全关闭
time.Sleep(100 * time.Millisecond)
}
})
}
}
func TestAccessLogInvalidConfig(t *testing.T) {
testCases := []struct {
name string
config *types.AccessLogConfig
expectedError bool
}{
{
name: "Invalid log level",
config: &types.AccessLogConfig{
EnableConsole: true,
Level: "invalid_level",
},
expectedError: false, // 应该使用默认的 info 级别
},
{
name: "Invalid file path",
config: &types.AccessLogConfig{
EnableFile: true,
FilePath: "/dev/null/nonexistent/test.log", // 在所有操作系统上都无效的路径
},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := AccessLog(tc.config)
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View file

@ -0,0 +1,82 @@
package middleware
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/rs/zerolog/log"
)
// AuthMiddleware creates a middleware for JWT authentication
func AuthMiddleware(jwtSecret string) gin.HandlerFunc {
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header is required"})
c.Abort()
return
}
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header format must be Bearer {token}"})
c.Abort()
return
}
token, err := jwt.Parse(parts[1], func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, jwt.ErrSignatureInvalid
}
return []byte(jwtSecret), nil
})
if err != nil {
log.Error().Err(err).Msg("Failed to parse token")
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
c.Abort()
return
}
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
c.Set("user_id", claims["sub"])
c.Set("user_role", claims["role"])
c.Next()
} else {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
c.Abort()
return
}
}
}
// RoleMiddleware creates a middleware for role-based authorization
func RoleMiddleware(roles ...string) gin.HandlerFunc {
return func(c *gin.Context) {
userRole, exists := c.Get("user_role")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "User role not found"})
c.Abort()
return
}
roleStr, ok := userRole.(string)
if !ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user role type"})
c.Abort()
return
}
for _, role := range roles {
if role == roleStr {
c.Next()
return
}
}
c.JSON(http.StatusForbidden, gin.H{"error": "Insufficient permissions"})
c.Abort()
}
}

View file

@ -0,0 +1,217 @@
package middleware
import (
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func createTestToken(secret string, claims jwt.MapClaims) string {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signedToken, _ := token.SignedString([]byte(secret))
return signedToken
}
func TestAuthMiddleware(t *testing.T) {
jwtSecret := "test-secret"
testCases := []struct {
name string
setupAuth func(*http.Request)
expectedStatus int
expectedBody map[string]string
checkUserData bool
expectedUserID string
expectedRole string
}{
{
name: "No Authorization header",
setupAuth: func(req *http.Request) {},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "Authorization header is required"},
},
{
name: "Invalid Authorization format",
setupAuth: func(req *http.Request) {
req.Header.Set("Authorization", "InvalidFormat")
},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "Authorization header format must be Bearer {token}"},
},
{
name: "Invalid token",
setupAuth: func(req *http.Request) {
req.Header.Set("Authorization", "Bearer invalid.token.here")
},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "Invalid token"},
},
{
name: "Valid token",
setupAuth: func(req *http.Request) {
claims := jwt.MapClaims{
"sub": "user123",
"role": "user",
"exp": time.Now().Add(time.Hour).Unix(),
}
token := createTestToken(jwtSecret, claims)
req.Header.Set("Authorization", "Bearer "+token)
},
expectedStatus: http.StatusOK,
checkUserData: true,
expectedUserID: "user123",
expectedRole: "user",
},
{
name: "Expired token",
setupAuth: func(req *http.Request) {
claims := jwt.MapClaims{
"sub": "user123",
"role": "user",
"exp": time.Now().Add(-time.Hour).Unix(),
}
token := createTestToken(jwtSecret, claims)
req.Header.Set("Authorization", "Bearer "+token)
},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "Invalid token"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
// 添加认证中间件
router.Use(AuthMiddleware(jwtSecret))
// 测试路由
router.GET("/test", func(c *gin.Context) {
if tc.checkUserData {
userID, exists := c.Get("user_id")
assert.True(t, exists)
assert.Equal(t, tc.expectedUserID, userID)
role, exists := c.Get("user_role")
assert.True(t, exists)
assert.Equal(t, tc.expectedRole, role)
}
c.Status(http.StatusOK)
})
// 创建请求
req := httptest.NewRequest("GET", "/test", nil)
tc.setupAuth(req)
rec := httptest.NewRecorder()
// 执行请求
router.ServeHTTP(rec, req)
// 验证响应
assert.Equal(t, tc.expectedStatus, rec.Code)
if tc.expectedBody != nil {
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err)
assert.Equal(t, tc.expectedBody, response)
}
})
}
}
func TestRoleMiddleware(t *testing.T) {
testCases := []struct {
name string
setupContext func(*gin.Context)
allowedRoles []string
expectedStatus int
expectedBody map[string]string
}{
{
name: "No user role",
setupContext: func(c *gin.Context) {
// 不设置用户角色
},
allowedRoles: []string{"admin"},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "User role not found"},
},
{
name: "Invalid role type",
setupContext: func(c *gin.Context) {
c.Set("user_role", 123) // 设置错误类型的角色
},
allowedRoles: []string{"admin"},
expectedStatus: http.StatusInternalServerError,
expectedBody: map[string]string{"error": "Invalid user role type"},
},
{
name: "Insufficient permissions",
setupContext: func(c *gin.Context) {
c.Set("user_role", "user")
},
allowedRoles: []string{"admin"},
expectedStatus: http.StatusForbidden,
expectedBody: map[string]string{"error": "Insufficient permissions"},
},
{
name: "Allowed role",
setupContext: func(c *gin.Context) {
c.Set("user_role", "admin")
},
allowedRoles: []string{"admin"},
expectedStatus: http.StatusOK,
},
{
name: "One of multiple allowed roles",
setupContext: func(c *gin.Context) {
c.Set("user_role", "editor")
},
allowedRoles: []string{"admin", "editor", "moderator"},
expectedStatus: http.StatusOK,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
// 添加角色中间件
router.Use(func(c *gin.Context) {
tc.setupContext(c)
c.Next()
})
router.Use(RoleMiddleware(tc.allowedRoles...))
// 测试路由
router.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// 创建请求
req := httptest.NewRequest("GET", "/test", nil)
rec := httptest.NewRecorder()
// 执行请求
router.ServeHTTP(rec, req)
// 验证响应
assert.Equal(t, tc.expectedStatus, rec.Code)
if tc.expectedBody != nil {
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err)
assert.Equal(t, tc.expectedBody, response)
}
})
}
}

View file

@ -0,0 +1,22 @@
package middleware
import (
"github.com/gin-gonic/gin"
)
// CORS middleware
func CORS() gin.HandlerFunc {
return func(c *gin.Context) {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}
c.Next()
}
}

View file

@ -0,0 +1,76 @@
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
)
func TestCORS(t *testing.T) {
testCases := []struct {
name string
method string
expectedStatus int
checkHeaders bool
}{
{
name: "Normal GET request",
method: "GET",
expectedStatus: http.StatusOK,
checkHeaders: true,
},
{
name: "OPTIONS request",
method: "OPTIONS",
expectedStatus: http.StatusNoContent,
checkHeaders: true,
},
{
name: "POST request",
method: "POST",
expectedStatus: http.StatusOK,
checkHeaders: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// 创建一个新的 gin 引擎
gin.SetMode(gin.TestMode)
router := gin.New()
// 添加 CORS 中间件
router.Use(CORS())
// 添加测试路由
router.Any("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// 创建测试请求
req := httptest.NewRequest(tc.method, "/test", nil)
rec := httptest.NewRecorder()
// 执行请求
router.ServeHTTP(rec, req)
// 验证状态码
assert.Equal(t, tc.expectedStatus, rec.Code)
if tc.checkHeaders {
// 验证 CORS 头部
headers := rec.Header()
assert.Equal(t, "*", headers.Get("Access-Control-Allow-Origin"))
assert.Equal(t, "true", headers.Get("Access-Control-Allow-Credentials"))
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Content-Type")
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Authorization")
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "POST")
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "GET")
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "PUT")
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "DELETE")
}
})
}
}

View file

@ -0,0 +1,107 @@
package middleware
import (
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
"tss-rocks-be/internal/types"
)
// ipLimiter IP限流器
type ipLimiter struct {
limiter *rate.Limiter
lastSeen time.Time
}
// rateLimiter 限流器管理器
type rateLimiter struct {
ips map[string]*ipLimiter
mu sync.RWMutex
config *types.RateLimitConfig
routes map[string]*rate.Limiter
}
// newRateLimiter 创建新的限流器
func newRateLimiter(config *types.RateLimitConfig) *rateLimiter {
// 初始化路由限流器
routes := make(map[string]*rate.Limiter)
for path, cfg := range config.RouteRates {
routes[path] = rate.NewLimiter(rate.Limit(cfg.Rate), cfg.Burst)
}
rl := &rateLimiter{
ips: make(map[string]*ipLimiter),
config: config,
routes: routes,
}
// 启动清理过期IP限流器的goroutine
go rl.cleanupIPLimiters()
return rl
}
// cleanupIPLimiters 清理过期的IP限流器
func (rl *rateLimiter) cleanupIPLimiters() {
for {
time.Sleep(time.Hour) // 每小时清理一次
rl.mu.Lock()
for ip, limiter := range rl.ips {
if time.Since(limiter.lastSeen) > time.Hour {
delete(rl.ips, ip)
}
}
rl.mu.Unlock()
}
}
// getLimiter 获取IP限流器
func (rl *rateLimiter) getLimiter(ip string) *rate.Limiter {
rl.mu.Lock()
defer rl.mu.Unlock()
v, exists := rl.ips[ip]
if !exists {
limiter := rate.NewLimiter(rate.Limit(rl.config.IPRate), rl.config.IPBurst)
rl.ips[ip] = &ipLimiter{limiter: limiter, lastSeen: time.Now()}
return limiter
}
v.lastSeen = time.Now()
return v.limiter
}
// RateLimit 创建限流中间件
func RateLimit(config *types.RateLimitConfig) gin.HandlerFunc {
rl := newRateLimiter(config)
return func(c *gin.Context) {
// 检查路由限流
path := c.Request.URL.Path
if limiter, ok := rl.routes[path]; ok {
if !limiter.Allow() {
c.JSON(http.StatusTooManyRequests, gin.H{
"error": "too many requests for this route",
})
c.Abort()
return
}
}
// 检查IP限流
limiter := rl.getLimiter(c.ClientIP())
if !limiter.Allow() {
c.JSON(http.StatusTooManyRequests, gin.H{
"error": "too many requests from this IP",
})
c.Abort()
return
}
c.Next()
}
}

View file

@ -0,0 +1,207 @@
package middleware
import (
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
"time"
"tss-rocks-be/internal/types"
)
func TestRateLimit(t *testing.T) {
testCases := []struct {
name string
config *types.RateLimitConfig
setupTest func(*gin.Engine)
runTest func(*testing.T, *gin.Engine)
expectedStatus int
expectedBody map[string]string
}{
{
name: "IP rate limit",
config: &types.RateLimitConfig{
IPRate: 1, // 每秒1个请求
IPBurst: 1,
},
setupTest: func(router *gin.Engine) {
router.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
},
runTest: func(t *testing.T, router *gin.Engine) {
// 第一个请求应该成功
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:1234"
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// 第二个请求应该被限制
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err)
assert.Equal(t, "too many requests from this IP", response["error"])
// 等待限流器重置
time.Sleep(time.Second)
// 第三个请求应该成功
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
},
},
{
name: "Route rate limit",
config: &types.RateLimitConfig{
IPRate: 100, // 设置较高的 IP 限流,以便测试路由限流
IPBurst: 10,
RouteRates: map[string]struct {
Rate int `yaml:"rate"`
Burst int `yaml:"burst"`
}{
"/limited": {
Rate: 1,
Burst: 1,
},
},
},
setupTest: func(router *gin.Engine) {
router.GET("/limited", func(c *gin.Context) {
c.Status(http.StatusOK)
})
router.GET("/unlimited", func(c *gin.Context) {
c.Status(http.StatusOK)
})
},
runTest: func(t *testing.T, router *gin.Engine) {
// 测试限流路由
req := httptest.NewRequest("GET", "/limited", nil)
req.RemoteAddr = "192.168.1.2:1234"
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// 等待一小段时间确保限流器生效
time.Sleep(10 * time.Millisecond)
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err)
assert.Equal(t, "too many requests for this route", response["error"])
// 测试未限流路由
req = httptest.NewRequest("GET", "/unlimited", nil)
req.RemoteAddr = "192.168.1.2:1234"
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// 等待一小段时间确保限流器生效
time.Sleep(10 * time.Millisecond)
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
},
},
{
name: "Multiple IPs",
config: &types.RateLimitConfig{
IPRate: 1,
IPBurst: 1,
},
setupTest: func(router *gin.Engine) {
router.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
},
runTest: func(t *testing.T, router *gin.Engine) {
// IP1 的请求
req1 := httptest.NewRequest("GET", "/test", nil)
req1.RemoteAddr = "192.168.1.3:1234"
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req1)
assert.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req1)
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
// IP2 的请求应该不受 IP1 的限制影响
req2 := httptest.NewRequest("GET", "/test", nil)
req2.RemoteAddr = "192.168.1.4:1234"
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req2)
assert.Equal(t, http.StatusOK, rec.Code)
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
// 添加限流中间件
router.Use(RateLimit(tc.config))
// 设置测试路由
tc.setupTest(router)
// 运行测试
tc.runTest(t, router)
})
}
}
func TestRateLimiterCleanup(t *testing.T) {
config := &types.RateLimitConfig{
IPRate: 1,
IPBurst: 1,
}
rl := newRateLimiter(config)
// 添加一些IP限流器
ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"}
for _, ip := range ips {
rl.getLimiter(ip)
}
// 验证IP限流器已创建
rl.mu.RLock()
assert.Equal(t, len(ips), len(rl.ips))
rl.mu.RUnlock()
// 修改一些IP的最后访问时间为1小时前
rl.mu.Lock()
rl.ips["192.168.1.1"].lastSeen = time.Now().Add(-2 * time.Hour)
rl.ips["192.168.1.2"].lastSeen = time.Now().Add(-2 * time.Hour)
rl.mu.Unlock()
// 手动触发清理
rl.mu.Lock()
for ip, limiter := range rl.ips {
if time.Since(limiter.lastSeen) > time.Hour {
delete(rl.ips, ip)
}
}
rl.mu.Unlock()
// 验证过期的IP限流器已被删除
rl.mu.RLock()
assert.Equal(t, 1, len(rl.ips))
_, exists := rl.ips["192.168.1.3"]
assert.True(t, exists)
rl.mu.RUnlock()
}

View file

@ -0,0 +1,110 @@
package middleware
import (
"context"
"fmt"
"net/http"
"tss-rocks-be/ent"
"tss-rocks-be/ent/user"
"tss-rocks-be/internal/auth"
"github.com/gin-gonic/gin"
)
// RequirePermission creates a middleware that checks if the user has the required permission
func RequirePermission(client *ent.Client, resource, action string) gin.HandlerFunc {
return func(c *gin.Context) {
// Get user from context
userID, exists := c.Get(auth.UserIDKey)
if !exists {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "Unauthorized",
})
return
}
// Get user with roles
user, err := client.User.Query().
Where(user.ID(userID.(int))).
WithRoles(func(q *ent.RoleQuery) {
q.WithPermissions()
}).
Only(context.Background())
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "User not found",
})
return
}
// Check if user has the required permission through any of their roles
hasPermission := false
for _, r := range user.Edges.Roles {
for _, p := range r.Edges.Permissions {
if p.Resource == resource && p.Action == action {
hasPermission = true
break
}
}
if hasPermission {
break
}
}
if !hasPermission {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"error": fmt.Sprintf("Missing required permission: %s:%s", resource, action),
})
return
}
c.Next()
}
}
// RequireRole creates a middleware that checks if the user has the required role
func RequireRole(client *ent.Client, roleName string) gin.HandlerFunc {
return func(c *gin.Context) {
// Get user from context
userID, exists := c.Get(auth.UserIDKey)
if !exists {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "Unauthorized",
})
return
}
// Get user with roles
user, err := client.User.Query().
Where(user.ID(userID.(int))).
WithRoles().
Only(context.Background())
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "User not found",
})
return
}
// Check if user has the required role
hasRole := false
for _, r := range user.Edges.Roles {
if r.Name == roleName {
hasRole = true
break
}
}
if !hasRole {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"error": fmt.Sprintf("Required role: %s", roleName),
})
return
}
c.Next()
}
}

View file

@ -0,0 +1,159 @@
package middleware
import (
"bytes"
"fmt"
"io"
"net/http"
"path/filepath"
"strings"
"github.com/gin-gonic/gin"
"tss-rocks-be/internal/types"
)
const (
defaultMaxMemory = 32 << 20 // 32 MB
maxHeaderBytes = 512 // 用于MIME类型检测的最大字节数
)
// ValidateUpload 创建文件上传验证中间件
func ValidateUpload(cfg *types.UploadConfig) gin.HandlerFunc {
return func(c *gin.Context) {
// 检查是否是multipart/form-data请求
if !strings.HasPrefix(c.GetHeader("Content-Type"), "multipart/form-data") {
c.JSON(http.StatusBadRequest, gin.H{
"error": "Content-Type must be multipart/form-data",
})
c.Abort()
return
}
// 解析multipart表单
if err := c.Request.ParseMultipartForm(defaultMaxMemory); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("Failed to parse form: %v", err),
})
c.Abort()
return
}
form := c.Request.MultipartForm
if form == nil || form.File == nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "No file uploaded",
})
c.Abort()
return
}
// 遍历所有上传的文件
for _, files := range form.File {
for _, file := range files {
// 检查文件大小
if file.Size > int64(cfg.MaxSize)<<20 { // 转换为字节
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("File %s exceeds maximum size of %d MB", file.Filename, cfg.MaxSize),
})
c.Abort()
return
}
// 检查文件扩展名
ext := strings.ToLower(filepath.Ext(file.Filename))
if !contains(cfg.AllowedExtensions, ext) {
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("File extension %s is not allowed", ext),
})
c.Abort()
return
}
// 打开文件
src, err := file.Open()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("Failed to open file: %v", err),
})
c.Abort()
return
}
defer src.Close()
// 读取文件头部用于MIME类型检测
header := make([]byte, maxHeaderBytes)
n, err := src.Read(header)
if err != nil && err != io.EOF {
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("Failed to read file: %v", err),
})
c.Abort()
return
}
header = header[:n]
// 检测MIME类型
contentType := http.DetectContentType(header)
if !contains(cfg.AllowedTypes, contentType) {
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("File type %s is not allowed", contentType),
})
c.Abort()
return
}
// 将文件指针重置到开始位置
_, err = src.Seek(0, 0)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("Failed to read file: %v", err),
})
c.Abort()
return
}
// 将文件内容读入缓冲区
buf := &bytes.Buffer{}
_, err = io.Copy(buf, src)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("Failed to read file: %v", err),
})
c.Abort()
return
}
// 将验证过的文件内容和类型保存到上下文中
c.Set("validated_file_"+file.Filename, buf)
c.Set("validated_content_type_"+file.Filename, contentType)
}
}
c.Next()
}
}
// contains 检查切片中是否包含指定的字符串
func contains(slice []string, str string) bool {
for _, s := range slice {
if s == str {
return true
}
}
return false
}
// GetValidatedFile 从上下文中获取验证过的文件内容
func GetValidatedFile(c *gin.Context, filename string) (*bytes.Buffer, string, bool) {
file, exists := c.Get("validated_file_" + filename)
if !exists {
return nil, "", false
}
contentType, exists := c.Get("validated_content_type_" + filename)
if !exists {
return nil, "", false
}
return file.(*bytes.Buffer), contentType.(string), true
}

View file

@ -0,0 +1,262 @@
package middleware
import (
"bytes"
"encoding/json"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"strings"
"testing"
"tss-rocks-be/internal/types"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func createMultipartRequest(t *testing.T, filename string, content []byte, contentType string) (*http.Request, error) {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
part, err := writer.CreateFormFile("file", filename)
if err != nil {
return nil, err
}
_, err = io.Copy(part, bytes.NewReader(content))
if err != nil {
return nil, err
}
err = writer.Close()
if err != nil {
return nil, err
}
req := httptest.NewRequest("POST", "/upload", body)
req.Header.Set("Content-Type", writer.FormDataContentType())
return req, nil
}
func TestValidateUpload(t *testing.T) {
tests := []struct {
name string
config *types.UploadConfig
filename string
content []byte
setupRequest func(*testing.T) *http.Request
expectedStatus int
expectedError string
}{
{
name: "Valid image upload",
config: &types.UploadConfig{
MaxSize: 5, // 5MB
AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
AllowedTypes: []string{"image/jpeg", "image/png"},
},
filename: "test.jpg",
content: []byte{
0xFF, 0xD8, 0xFF, 0xE0, // JPEG magic numbers
0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00,
},
expectedStatus: http.StatusOK,
},
{
name: "Invalid file extension",
config: &types.UploadConfig{
MaxSize: 5,
AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
AllowedTypes: []string{"image/jpeg", "image/png"},
},
filename: "test.txt",
content: []byte("test content"),
expectedStatus: http.StatusBadRequest,
expectedError: "File extension .txt is not allowed",
},
{
name: "File too large",
config: &types.UploadConfig{
MaxSize: 1, // 1MB
AllowedExtensions: []string{".jpg"},
AllowedTypes: []string{"image/jpeg"},
},
filename: "large.jpg",
content: make([]byte, 2<<20), // 2MB
expectedStatus: http.StatusBadRequest,
expectedError: "File large.jpg exceeds maximum size of 1 MB",
},
{
name: "Invalid content type",
config: &types.UploadConfig{
MaxSize: 5,
AllowedExtensions: []string{".jpg"},
AllowedTypes: []string{"image/jpeg"},
},
filename: "fake.jpg",
content: []byte("not a real image"),
expectedStatus: http.StatusBadRequest,
expectedError: "File type text/plain; charset=utf-8 is not allowed",
},
{
name: "Missing file",
config: &types.UploadConfig{
MaxSize: 5,
AllowedExtensions: []string{".jpg"},
AllowedTypes: []string{"image/jpeg"},
},
setupRequest: func(t *testing.T) *http.Request {
req := httptest.NewRequest("POST", "/upload", strings.NewReader(""))
req.Header.Set("Content-Type", "multipart/form-data")
return req
},
expectedStatus: http.StatusBadRequest,
expectedError: "Failed to parse form",
},
{
name: "Invalid content type header",
config: &types.UploadConfig{
MaxSize: 5,
AllowedExtensions: []string{".jpg"},
AllowedTypes: []string{"image/jpeg"},
},
setupRequest: func(t *testing.T) *http.Request {
return httptest.NewRequest("POST", "/upload", nil)
},
expectedStatus: http.StatusBadRequest,
expectedError: "Content-Type must be multipart/form-data",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
var req *http.Request
var err error
if tt.setupRequest != nil {
req = tt.setupRequest(t)
} else {
req, err = createMultipartRequest(t, tt.filename, tt.content, "")
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
}
c.Request = req
middleware := ValidateUpload(tt.config)
middleware(c)
assert.Equal(t, tt.expectedStatus, w.Code)
if tt.expectedError != "" {
var response map[string]string
err := json.NewDecoder(w.Body).Decode(&response)
assert.NoError(t, err)
assert.Contains(t, response["error"], tt.expectedError)
}
})
}
}
func TestGetValidatedFile(t *testing.T) {
tests := []struct {
name string
setupContext func(*gin.Context)
filename string
expectedFound bool
expectedError string
}{
{
name: "Get existing file",
setupContext: func(c *gin.Context) {
// 创建测试文件内容
content := []byte("test content")
buf := bytes.NewBuffer(content)
// 设置验证过的文件和内容类型
c.Set("validated_file_test.txt", buf)
c.Set("validated_content_type_test.txt", "text/plain")
},
filename: "test.txt",
expectedFound: true,
},
{
name: "File not found",
setupContext: func(c *gin.Context) {
// 不设置任何文件
},
filename: "nonexistent.txt",
expectedFound: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
if tt.setupContext != nil {
tt.setupContext(c)
}
buffer, contentType, found := GetValidatedFile(c, tt.filename)
assert.Equal(t, tt.expectedFound, found)
if tt.expectedFound {
assert.NotNil(t, buffer)
assert.NotEmpty(t, contentType)
} else {
assert.Nil(t, buffer)
assert.Empty(t, contentType)
}
})
}
}
func TestContains(t *testing.T) {
tests := []struct {
name string
slice []string
str string
expected bool
}{
{
name: "String found in slice",
slice: []string{"a", "b", "c"},
str: "b",
expected: true,
},
{
name: "String not found in slice",
slice: []string{"a", "b", "c"},
str: "d",
expected: false,
},
{
name: "Empty slice",
slice: []string{},
str: "a",
expected: false,
},
{
name: "Empty string",
slice: []string{"a", "b", "c"},
str: "",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := contains(tt.slice, tt.str)
assert.Equal(t, tt.expected, result)
})
}
}