package middleware

import (
	"net/http"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"golang.org/x/time/rate"

	"tss-rocks-be/internal/types"
)

const (
	// ipLimiterIdleTTL is how long a per-IP limiter may go unused before the
	// background sweep evicts it.
	ipLimiterIdleTTL = time.Hour

	// ipLimiterSweepInterval is the period between two eviction sweeps.
	ipLimiterSweepInterval = time.Hour
)

// ipLimiter pairs a per-client token bucket with the time it was last used.
type ipLimiter struct {
	limiter  *rate.Limiter
	lastSeen time.Time
}

// rateLimiter owns every limiter used by the RateLimit middleware.
//
// ips is guarded by mu. routes is populated once in newRateLimiter and only
// read afterwards, so it is accessed without locking.
type rateLimiter struct {
	ips    map[string]*ipLimiter
	mu     sync.RWMutex
	config *types.RateLimitConfig
	routes map[string]*rate.Limiter
}

// newRateLimiter builds a rateLimiter from config and starts the background
// goroutine that evicts idle per-IP limiters.
//
// config must be non-nil: it is dereferenced immediately. One route limiter
// is created for each entry of config.RouteRates, keyed by that entry's path.
func newRateLimiter(config *types.RateLimitConfig) *rateLimiter {
	routes := make(map[string]*rate.Limiter, len(config.RouteRates))
	for path, cfg := range config.RouteRates {
		routes[path] = rate.NewLimiter(rate.Limit(cfg.Rate), cfg.Burst)
	}

	rl := &rateLimiter{
		ips:    make(map[string]*ipLimiter),
		config: config,
		routes: routes,
	}

	// The sweeper has no stop signal; it runs until the process exits.
	go rl.cleanupIPLimiters()

	return rl
}

// cleanupIPLimiters evicts idle per-IP limiters once per
// ipLimiterSweepInterval. It never returns and is meant to run in its own
// goroutine.
func (rl *rateLimiter) cleanupIPLimiters() {
	ticker := time.NewTicker(ipLimiterSweepInterval)
	for range ticker.C {
		rl.evictIdleIPLimiters()
	}
}

// evictIdleIPLimiters removes every per-IP limiter whose lastSeen is older
// than ipLimiterIdleTTL.
func (rl *rateLimiter) evictIdleIPLimiters() {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	for ip, entry := range rl.ips {
		if time.Since(entry.lastSeen) > ipLimiterIdleTTL {
			delete(rl.ips, ip)
		}
	}
}

// getLimiter returns the token bucket for ip, creating it from
// config.IPRate / config.IPBurst on first use, and refreshes the entry's
// lastSeen timestamp.
func (rl *rateLimiter) getLimiter(ip string) *rate.Limiter {
	// A write lock is required even on the hit path because lastSeen is
	// updated on every call.
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	if entry, ok := rl.ips[ip]; ok {
		entry.lastSeen = now
		return entry.limiter
	}

	limiter := rate.NewLimiter(rate.Limit(rl.config.IPRate), rl.config.IPBurst)
	rl.ips[ip] = &ipLimiter{limiter: limiter, lastSeen: now}
	return limiter
}

// RateLimit returns a gin middleware that enforces a per-client-IP limit and,
// for paths listed in config.RouteRates, a limit shared by all clients of
// that path. Requests over either limit are aborted with HTTP 429 and a JSON
// body of the form {"error": "..."}.
//
// config must be non-nil. Route limits are looked up by the literal request
// path (c.Request.URL.Path), so a key only matches requests whose path is
// exactly equal to it.
func RateLimit(config *types.RateLimitConfig) gin.HandlerFunc {
	rl := newRateLimiter(config)

	return func(c *gin.Context) {
		// Check the per-IP bucket first. The route bucket is shared by every
		// client, so a request that is going to be rejected for its IP must
		// not consume a route token; otherwise a single throttled client
		// could drain the route-wide budget for everyone else.
		if !rl.getLimiter(c.ClientIP()).Allow() {
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
				"error": "too many requests from this IP",
			})
			return
		}

		if routeLimiter, ok := rl.routes[c.Request.URL.Path]; ok && !routeLimiter.Allow() {
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
				"error": "too many requests for this route",
			})
			return
		}

		c.Next()
	}
}