package middleware

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/assert"

	"tss-rocks-be/internal/types"
)

// TestRateLimit exercises the RateLimit middleware end to end through a gin
// router: per-IP limiting, per-route limiting, and isolation between client IPs.
func TestRateLimit(t *testing.T) {
	// serve sends a fresh GET request for path, originating from remoteAddr,
	// through router and returns the recorded response.
	serve := func(router *gin.Engine, path, remoteAddr string) *httptest.ResponseRecorder {
		req := httptest.NewRequest(http.MethodGet, path, nil)
		req.RemoteAddr = remoteAddr
		rec := httptest.NewRecorder()
		router.ServeHTTP(rec, req)
		return rec
	}

	// assertRejected checks that rec is a 429 whose JSON body carries wantErr
	// under the "error" key. The body is decoded only when the status matches,
	// so a wrong status is not followed by a misleading decode failure.
	assertRejected := func(t *testing.T, rec *httptest.ResponseRecorder, wantErr string) {
		t.Helper()
		if !assert.Equal(t, http.StatusTooManyRequests, rec.Code) {
			return
		}
		var response map[string]string
		if assert.NoError(t, json.NewDecoder(rec.Body).Decode(&response)) {
			assert.Equal(t, wantErr, response["error"])
		}
	}

	// okHandler always responds 200 with an empty body.
	okHandler := func(c *gin.Context) { c.Status(http.StatusOK) }

	testCases := []struct {
		name      string
		config    *types.RateLimitConfig
		setupTest func(*gin.Engine)
		runTest   func(*testing.T, *gin.Engine)
	}{
		{
			name: "IP rate limit",
			config: &types.RateLimitConfig{
				IPRate:  1, // 1 request per second
				IPBurst: 1,
			},
			setupTest: func(router *gin.Engine) {
				router.GET("/test", okHandler)
			},
			runTest: func(t *testing.T, router *gin.Engine) {
				const ip = "192.168.1.1:1234"

				// The first request uses up the single-request burst.
				assert.Equal(t, http.StatusOK, serve(router, "/test", ip).Code)

				// An immediate second request must be rejected by the IP limiter.
				assertRejected(t, serve(router, "/test", ip), "too many requests from this IP")

				// With a rate of 1 request per second, waiting one second should
				// allow the next request through.
				time.Sleep(time.Second)
				assert.Equal(t, http.StatusOK, serve(router, "/test", ip).Code)
			},
		},
		{
			name: "Route rate limit",
			config: &types.RateLimitConfig{
				IPRate:  100, // high enough that only the route limit can trigger
				IPBurst: 10,
				RouteRates: map[string]struct {
					Rate  int `yaml:"rate"`
					Burst int `yaml:"burst"`
				}{
					"/limited": {
						Rate:  1,
						Burst: 1,
					},
				},
			},
			setupTest: func(router *gin.Engine) {
				router.GET("/limited", okHandler)
				router.GET("/unlimited", okHandler)
			},
			runTest: func(t *testing.T, router *gin.Engine) {
				const ip = "192.168.1.2:1234"

				// /limited admits one request and rejects the next. The requests
				// are sent back to back on purpose: sleeping in between only gives
				// the limiter time to recover and weakens the check.
				assert.Equal(t, http.StatusOK, serve(router, "/limited", ip).Code)
				assertRejected(t, serve(router, "/limited", ip), "too many requests for this route")

				// /unlimited has no route-specific limit, so back-to-back requests
				// pass (the IP limit is far from exhausted).
				assert.Equal(t, http.StatusOK, serve(router, "/unlimited", ip).Code)
				assert.Equal(t, http.StatusOK, serve(router, "/unlimited", ip).Code)
			},
		},
		{
			name: "Multiple IPs",
			config: &types.RateLimitConfig{
				IPRate:  1,
				IPBurst: 1,
			},
			setupTest: func(router *gin.Engine) {
				router.GET("/test", okHandler)
			},
			runTest: func(t *testing.T, router *gin.Engine) {
				const ip1, ip2 = "192.168.1.3:1234", "192.168.1.4:1234"

				// Exhaust ip1's budget.
				assert.Equal(t, http.StatusOK, serve(router, "/test", ip1).Code)
				assert.Equal(t, http.StatusTooManyRequests, serve(router, "/test", ip1).Code)

				// ip2 must not be affected by ip1 being limited.
				assert.Equal(t, http.StatusOK, serve(router, "/test", ip2).Code)
			},
		},
	}

	// gin's mode is process-global; set it once rather than per subtest.
	gin.SetMode(gin.TestMode)
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// A fresh router per case keeps limiter state isolated between cases.
			router := gin.New()
			router.Use(RateLimit(tc.config))
			tc.setupTest(router)
			tc.runTest(t, router)
		})
	}
}

// TestRateLimiterCleanup checks that per-IP limiter entries idle for more than
// an hour can be pruned while a recently used entry is kept.
func TestRateLimiterCleanup(t *testing.T) {
	config := &types.RateLimitConfig{
		IPRate:  1,
		IPBurst: 1,
	}
	rl := newRateLimiter(config)

	// Create a limiter entry for each IP.
	ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"}
	for _, ip := range ips {
		rl.getLimiter(ip)
	}

	rl.mu.RLock()
	created := len(rl.ips)
	rl.mu.RUnlock()
	// Stop if entries are missing: the lastSeen writes below index rl.ips
	// directly and would dereference a nil entry.
	if !assert.Equal(t, len(ips), created) {
		return
	}

	// Mark the first two IPs as last seen 2 hours ago, past the 1-hour cutoff.
	stale := ips[:2]
	rl.mu.Lock()
	for _, ip := range stale {
		rl.ips[ip].lastSeen = time.Now().Add(-2 * time.Hour)
	}
	rl.mu.Unlock()

	// Prune entries idle for more than an hour.
	// NOTE(review): this loop re-implements the pruning rule inside the test
	// instead of invoking the limiter's own cleanup code, so it does not cover
	// the production cleanup path. TODO: call the real cleanup routine if it
	// can be triggered synchronously.
	rl.mu.Lock()
	for ip, limiter := range rl.ips {
		if time.Since(limiter.lastSeen) > time.Hour {
			delete(rl.ips, ip)
		}
	}
	rl.mu.Unlock()

	// Only the recently used entry should remain.
	rl.mu.RLock()
	defer rl.mu.RUnlock()
	assert.Len(t, rl.ips, 1)
	_, exists := rl.ips["192.168.1.3"]
	assert.True(t, exists)
}