76 lines
1.9 KiB
Go
76 lines
1.9 KiB
Go
package middleware
|
|
|
|
import (
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/stretchr/testify/assert"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
)
|
|
|
|
func TestCORS(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
method string
|
|
expectedStatus int
|
|
checkHeaders bool
|
|
}{
|
|
{
|
|
name: "Normal GET request",
|
|
method: "GET",
|
|
expectedStatus: http.StatusOK,
|
|
checkHeaders: true,
|
|
},
|
|
{
|
|
name: "OPTIONS request",
|
|
method: "OPTIONS",
|
|
expectedStatus: http.StatusNoContent,
|
|
checkHeaders: true,
|
|
},
|
|
{
|
|
name: "POST request",
|
|
method: "POST",
|
|
expectedStatus: http.StatusOK,
|
|
checkHeaders: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// 创建一个新的 gin 引擎
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
|
|
// 添加 CORS 中间件
|
|
router.Use(CORS())
|
|
|
|
// 添加测试路由
|
|
router.Any("/test", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
// 创建测试请求
|
|
req := httptest.NewRequest(tc.method, "/test", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
// 执行请求
|
|
router.ServeHTTP(rec, req)
|
|
|
|
// 验证状态码
|
|
assert.Equal(t, tc.expectedStatus, rec.Code)
|
|
|
|
if tc.checkHeaders {
|
|
// 验证 CORS 头部
|
|
headers := rec.Header()
|
|
assert.Equal(t, "*", headers.Get("Access-Control-Allow-Origin"))
|
|
assert.Equal(t, "true", headers.Get("Access-Control-Allow-Credentials"))
|
|
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Content-Type")
|
|
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Authorization")
|
|
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "POST")
|
|
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "GET")
|
|
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "PUT")
|
|
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "DELETE")
|
|
}
|
|
})
|
|
}
|
|
}
|