207 lines
5.3 KiB
Go
207 lines
5.3 KiB
Go
package middleware
|
|
|
|
import (
|
|
"encoding/json"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/stretchr/testify/assert"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
"tss-rocks-be/internal/types"
|
|
)
|
|
|
|
func TestRateLimit(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
config *types.RateLimitConfig
|
|
setupTest func(*gin.Engine)
|
|
runTest func(*testing.T, *gin.Engine)
|
|
expectedStatus int
|
|
expectedBody map[string]string
|
|
}{
|
|
{
|
|
name: "IP rate limit",
|
|
config: &types.RateLimitConfig{
|
|
IPRate: 1, // 每秒1个请求
|
|
IPBurst: 1,
|
|
},
|
|
setupTest: func(router *gin.Engine) {
|
|
router.GET("/test", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
},
|
|
runTest: func(t *testing.T, router *gin.Engine) {
|
|
// 第一个请求应该成功
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.RemoteAddr = "192.168.1.1:1234"
|
|
rec := httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
|
|
// 第二个请求应该被限制
|
|
rec = httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
|
var response map[string]string
|
|
err := json.NewDecoder(rec.Body).Decode(&response)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "too many requests from this IP", response["error"])
|
|
|
|
// 等待限流器重置
|
|
time.Sleep(time.Second)
|
|
|
|
// 第三个请求应该成功
|
|
rec = httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
},
|
|
},
|
|
{
|
|
name: "Route rate limit",
|
|
config: &types.RateLimitConfig{
|
|
IPRate: 100, // 设置较高的 IP 限流,以便测试路由限流
|
|
IPBurst: 10,
|
|
RouteRates: map[string]struct {
|
|
Rate int `yaml:"rate"`
|
|
Burst int `yaml:"burst"`
|
|
}{
|
|
"/limited": {
|
|
Rate: 1,
|
|
Burst: 1,
|
|
},
|
|
},
|
|
},
|
|
setupTest: func(router *gin.Engine) {
|
|
router.GET("/limited", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
router.GET("/unlimited", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
},
|
|
runTest: func(t *testing.T, router *gin.Engine) {
|
|
// 测试限流路由
|
|
req := httptest.NewRequest("GET", "/limited", nil)
|
|
req.RemoteAddr = "192.168.1.2:1234"
|
|
rec := httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
|
|
// 等待一小段时间确保限流器生效
|
|
time.Sleep(10 * time.Millisecond)
|
|
|
|
rec = httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
|
var response map[string]string
|
|
err := json.NewDecoder(rec.Body).Decode(&response)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "too many requests for this route", response["error"])
|
|
|
|
// 测试未限流路由
|
|
req = httptest.NewRequest("GET", "/unlimited", nil)
|
|
req.RemoteAddr = "192.168.1.2:1234"
|
|
rec = httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
|
|
// 等待一小段时间确保限流器生效
|
|
time.Sleep(10 * time.Millisecond)
|
|
|
|
rec = httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
},
|
|
},
|
|
{
|
|
name: "Multiple IPs",
|
|
config: &types.RateLimitConfig{
|
|
IPRate: 1,
|
|
IPBurst: 1,
|
|
},
|
|
setupTest: func(router *gin.Engine) {
|
|
router.GET("/test", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
},
|
|
runTest: func(t *testing.T, router *gin.Engine) {
|
|
// IP1 的请求
|
|
req1 := httptest.NewRequest("GET", "/test", nil)
|
|
req1.RemoteAddr = "192.168.1.3:1234"
|
|
rec := httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req1)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
|
|
rec = httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req1)
|
|
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
|
|
|
// IP2 的请求应该不受 IP1 的限制影响
|
|
req2 := httptest.NewRequest("GET", "/test", nil)
|
|
req2.RemoteAddr = "192.168.1.4:1234"
|
|
rec = httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req2)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
|
|
// 添加限流中间件
|
|
router.Use(RateLimit(tc.config))
|
|
|
|
// 设置测试路由
|
|
tc.setupTest(router)
|
|
|
|
// 运行测试
|
|
tc.runTest(t, router)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRateLimiterCleanup(t *testing.T) {
|
|
config := &types.RateLimitConfig{
|
|
IPRate: 1,
|
|
IPBurst: 1,
|
|
}
|
|
|
|
rl := newRateLimiter(config)
|
|
|
|
// 添加一些IP限流器
|
|
ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"}
|
|
for _, ip := range ips {
|
|
rl.getLimiter(ip)
|
|
}
|
|
|
|
// 验证IP限流器已创建
|
|
rl.mu.RLock()
|
|
assert.Equal(t, len(ips), len(rl.ips))
|
|
rl.mu.RUnlock()
|
|
|
|
// 修改一些IP的最后访问时间为1小时前
|
|
rl.mu.Lock()
|
|
rl.ips["192.168.1.1"].lastSeen = time.Now().Add(-2 * time.Hour)
|
|
rl.ips["192.168.1.2"].lastSeen = time.Now().Add(-2 * time.Hour)
|
|
rl.mu.Unlock()
|
|
|
|
// 手动触发清理
|
|
rl.mu.Lock()
|
|
for ip, limiter := range rl.ips {
|
|
if time.Since(limiter.lastSeen) > time.Hour {
|
|
delete(rl.ips, ip)
|
|
}
|
|
}
|
|
rl.mu.Unlock()
|
|
|
|
// 验证过期的IP限流器已被删除
|
|
rl.mu.RLock()
|
|
assert.Equal(t, 1, len(rl.ips))
|
|
_, exists := rl.ips["192.168.1.3"]
|
|
assert.True(t, exists)
|
|
rl.mu.RUnlock()
|
|
}
|