228 lines
6 KiB
Go
228 lines
6 KiB
Go
package middleware
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/stretchr/testify/assert"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"tss-rocks-be/internal/service"
|
|
)
|
|
|
|
func createTestToken(secret string, claims jwt.MapClaims) string {
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
signedToken, err := token.SignedString([]byte(secret))
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Failed to sign token: %v", err))
|
|
}
|
|
return signedToken
|
|
}
|
|
|
|
func TestAuthMiddleware(t *testing.T) {
|
|
jwtSecret := "test-secret"
|
|
tokenBlacklist := service.NewTokenBlacklist()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
setupAuth func(*http.Request)
|
|
expectedStatus int
|
|
expectedBody map[string]string
|
|
checkUserData bool
|
|
expectedUserID string
|
|
expectedRole string
|
|
}{
|
|
{
|
|
name: "No Authorization header",
|
|
setupAuth: func(req *http.Request) {},
|
|
expectedStatus: http.StatusUnauthorized,
|
|
expectedBody: map[string]string{"error": "Authorization header is required"},
|
|
},
|
|
{
|
|
name: "Invalid Authorization format",
|
|
setupAuth: func(req *http.Request) {
|
|
req.Header.Set("Authorization", "InvalidFormat")
|
|
},
|
|
expectedStatus: http.StatusUnauthorized,
|
|
expectedBody: map[string]string{"error": "Authorization header format must be Bearer {token}"},
|
|
},
|
|
{
|
|
name: "Invalid token",
|
|
setupAuth: func(req *http.Request) {
|
|
req.Header.Set("Authorization", "Bearer invalid.token.here")
|
|
},
|
|
expectedStatus: http.StatusUnauthorized,
|
|
expectedBody: map[string]string{"error": "Invalid token"},
|
|
},
|
|
{
|
|
name: "Valid token",
|
|
setupAuth: func(req *http.Request) {
|
|
claims := jwt.MapClaims{
|
|
"sub": "123",
|
|
"role": "user",
|
|
"exp": time.Now().Add(time.Hour).Unix(),
|
|
}
|
|
token := createTestToken(jwtSecret, claims)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
},
|
|
expectedStatus: http.StatusOK,
|
|
checkUserData: true,
|
|
expectedUserID: "123",
|
|
expectedRole: "user",
|
|
},
|
|
{
|
|
name: "Expired token",
|
|
setupAuth: func(req *http.Request) {
|
|
claims := jwt.MapClaims{
|
|
"sub": "123",
|
|
"role": "user",
|
|
"exp": time.Now().Add(-time.Hour).Unix(),
|
|
}
|
|
token := createTestToken(jwtSecret, claims)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
},
|
|
expectedStatus: http.StatusUnauthorized,
|
|
expectedBody: map[string]string{"error": "Invalid token"},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
|
|
// 添加认证中间件
|
|
router.Use(func(c *gin.Context) {
|
|
// 设置日志级别为 debug
|
|
gin.SetMode(gin.DebugMode)
|
|
c.Next()
|
|
}, AuthMiddleware(jwtSecret, tokenBlacklist))
|
|
|
|
// 测试路由
|
|
router.GET("/test", func(c *gin.Context) {
|
|
if tc.checkUserData {
|
|
userID, exists := c.Get("user_id")
|
|
assert.True(t, exists, "user_id should exist in context")
|
|
assert.Equal(t, tc.expectedUserID, userID, "user_id should match")
|
|
|
|
role, exists := c.Get("user_role")
|
|
assert.True(t, exists, "user_role should exist in context")
|
|
assert.Equal(t, tc.expectedRole, role, "user_role should match")
|
|
}
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
// 创建请求
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
tc.setupAuth(req)
|
|
rec := httptest.NewRecorder()
|
|
|
|
// 执行请求
|
|
router.ServeHTTP(rec, req)
|
|
|
|
// 验证响应
|
|
assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match")
|
|
|
|
if tc.expectedBody != nil {
|
|
var response map[string]string
|
|
err := json.NewDecoder(rec.Body).Decode(&response)
|
|
assert.NoError(t, err, "Response body should be valid JSON")
|
|
assert.Equal(t, tc.expectedBody, response, "Response body should match")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRoleMiddleware(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
setupContext func(*gin.Context)
|
|
allowedRoles []string
|
|
expectedStatus int
|
|
expectedBody map[string]string
|
|
}{
|
|
{
|
|
name: "No user role",
|
|
setupContext: func(c *gin.Context) {
|
|
// 不设置用户角色
|
|
},
|
|
allowedRoles: []string{"admin"},
|
|
expectedStatus: http.StatusUnauthorized,
|
|
expectedBody: map[string]string{"error": "User role not found"},
|
|
},
|
|
{
|
|
name: "Invalid role type",
|
|
setupContext: func(c *gin.Context) {
|
|
c.Set("user_role", 123) // 设置错误类型的角色
|
|
},
|
|
allowedRoles: []string{"admin"},
|
|
expectedStatus: http.StatusInternalServerError,
|
|
expectedBody: map[string]string{"error": "Invalid user role type"},
|
|
},
|
|
{
|
|
name: "Insufficient permissions",
|
|
setupContext: func(c *gin.Context) {
|
|
c.Set("user_role", "user")
|
|
},
|
|
allowedRoles: []string{"admin"},
|
|
expectedStatus: http.StatusForbidden,
|
|
expectedBody: map[string]string{"error": "Insufficient permissions"},
|
|
},
|
|
{
|
|
name: "Allowed role",
|
|
setupContext: func(c *gin.Context) {
|
|
c.Set("user_role", "admin")
|
|
},
|
|
allowedRoles: []string{"admin"},
|
|
expectedStatus: http.StatusOK,
|
|
},
|
|
{
|
|
name: "One of multiple allowed roles",
|
|
setupContext: func(c *gin.Context) {
|
|
c.Set("user_role", "editor")
|
|
},
|
|
allowedRoles: []string{"admin", "editor", "moderator"},
|
|
expectedStatus: http.StatusOK,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
|
|
// 添加角色中间件
|
|
router.Use(func(c *gin.Context) {
|
|
tc.setupContext(c)
|
|
c.Next()
|
|
})
|
|
router.Use(RoleMiddleware(tc.allowedRoles...))
|
|
|
|
// 测试路由
|
|
router.GET("/test", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
// 创建请求
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
// 执行请求
|
|
router.ServeHTTP(rec, req)
|
|
|
|
// 验证响应
|
|
assert.Equal(t, tc.expectedStatus, rec.Code)
|
|
|
|
if tc.expectedBody != nil {
|
|
var response map[string]string
|
|
err := json.NewDecoder(rec.Body).Decode(&response)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tc.expectedBody, response)
|
|
}
|
|
})
|
|
}
|
|
}
|