tss-rocks/backend/internal/middleware/ratelimit_test.go
CDN 05ddc1f783
Some checks failed
Build Backend / Build Docker Image (push) Successful in 3m33s
Test Backend / test (push) Failing after 31s
[feature] migrate to monorepo
2025-02-21 00:49:20 +08:00

207 lines
5.3 KiB
Go

package middleware
import (
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
"time"
"tss-rocks-be/internal/types"
)
func TestRateLimit(t *testing.T) {
testCases := []struct {
name string
config *types.RateLimitConfig
setupTest func(*gin.Engine)
runTest func(*testing.T, *gin.Engine)
expectedStatus int
expectedBody map[string]string
}{
{
name: "IP rate limit",
config: &types.RateLimitConfig{
IPRate: 1, // 每秒1个请求
IPBurst: 1,
},
setupTest: func(router *gin.Engine) {
router.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
},
runTest: func(t *testing.T, router *gin.Engine) {
// 第一个请求应该成功
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:1234"
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// 第二个请求应该被限制
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err)
assert.Equal(t, "too many requests from this IP", response["error"])
// 等待限流器重置
time.Sleep(time.Second)
// 第三个请求应该成功
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
},
},
{
name: "Route rate limit",
config: &types.RateLimitConfig{
IPRate: 100, // 设置较高的 IP 限流,以便测试路由限流
IPBurst: 10,
RouteRates: map[string]struct {
Rate int `yaml:"rate"`
Burst int `yaml:"burst"`
}{
"/limited": {
Rate: 1,
Burst: 1,
},
},
},
setupTest: func(router *gin.Engine) {
router.GET("/limited", func(c *gin.Context) {
c.Status(http.StatusOK)
})
router.GET("/unlimited", func(c *gin.Context) {
c.Status(http.StatusOK)
})
},
runTest: func(t *testing.T, router *gin.Engine) {
// 测试限流路由
req := httptest.NewRequest("GET", "/limited", nil)
req.RemoteAddr = "192.168.1.2:1234"
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// 等待一小段时间确保限流器生效
time.Sleep(10 * time.Millisecond)
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err)
assert.Equal(t, "too many requests for this route", response["error"])
// 测试未限流路由
req = httptest.NewRequest("GET", "/unlimited", nil)
req.RemoteAddr = "192.168.1.2:1234"
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// 等待一小段时间确保限流器生效
time.Sleep(10 * time.Millisecond)
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
},
},
{
name: "Multiple IPs",
config: &types.RateLimitConfig{
IPRate: 1,
IPBurst: 1,
},
setupTest: func(router *gin.Engine) {
router.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
},
runTest: func(t *testing.T, router *gin.Engine) {
// IP1 的请求
req1 := httptest.NewRequest("GET", "/test", nil)
req1.RemoteAddr = "192.168.1.3:1234"
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req1)
assert.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req1)
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
// IP2 的请求应该不受 IP1 的限制影响
req2 := httptest.NewRequest("GET", "/test", nil)
req2.RemoteAddr = "192.168.1.4:1234"
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req2)
assert.Equal(t, http.StatusOK, rec.Code)
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
// 添加限流中间件
router.Use(RateLimit(tc.config))
// 设置测试路由
tc.setupTest(router)
// 运行测试
tc.runTest(t, router)
})
}
}
// TestRateLimiterCleanup checks that per-IP limiter entries idle for more than
// one hour are evicted, while recently used entries are kept.
//
// NOTE(review): the eviction loop below re-implements the cleanup rule inline
// instead of calling the rate limiter's own cleanup routine, so it does not
// exercise production code. If that routine is reachable from this package,
// invoke it here instead — TODO confirm its name and idle threshold.
func TestRateLimiterCleanup(t *testing.T) {
	config := &types.RateLimitConfig{
		IPRate:  1,
		IPBurst: 1,
	}
	rl := newRateLimiter(config)

	// Register a limiter for each IP through the normal lookup path.
	ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"}
	for _, ip := range ips {
		rl.getLimiter(ip)
	}

	// One entry per IP should now exist.
	rl.mu.RLock()
	assert.Equal(t, len(ips), len(rl.ips))
	rl.mu.RUnlock()

	// Backdate the first two entries to two hours ago, past the one-hour idle
	// threshold. A missing entry fails the test here instead of panicking on
	// a nil dereference.
	stale := time.Now().Add(-2 * time.Hour)
	rl.mu.Lock()
	for _, ip := range ips[:2] {
		entry, ok := rl.ips[ip]
		if !ok {
			rl.mu.Unlock()
			t.Fatalf("no limiter registered for %s", ip)
		}
		entry.lastSeen = stale
	}
	rl.mu.Unlock()

	// Evict entries idle for more than an hour.
	rl.mu.Lock()
	for ip, limiter := range rl.ips {
		if time.Since(limiter.lastSeen) > time.Hour {
			delete(rl.ips, ip)
		}
	}
	rl.mu.Unlock()

	// Only the entry that was never backdated should remain.
	rl.mu.RLock()
	defer rl.mu.RUnlock()
	assert.Equal(t, 1, len(rl.ips))
	_, exists := rl.ips["192.168.1.3"]
	assert.True(t, exists)
}