package middleware

import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
	"github.com/stretchr/testify/assert"

	"tss-rocks-be/internal/service"
)

// createTestToken builds a JWT carrying claims, signs it with HS256 using
// secret, and returns the compact serialized token. It panics if signing
// fails, because a test cannot proceed without a token.
func createTestToken(secret string, claims jwt.MapClaims) string {
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	signedToken, err := token.SignedString([]byte(secret))
	if err != nil {
		panic(fmt.Sprintf("Failed to sign token: %v", err))
	}
	return signedToken
}

// TestAuthMiddleware sends requests with various Authorization headers through
// AuthMiddleware and checks the resulting status code, the JSON error body,
// and, for the accepted token, the user data stored in the gin context.
func TestAuthMiddleware(t *testing.T) {
	jwtSecret := "test-secret"
	tokenBlacklist := service.NewTokenBlacklist()

	testCases := []struct {
		name           string
		setupAuth      func(*http.Request) // prepares the request's Authorization header
		expectedStatus int
		expectedBody   map[string]string // nil means the body is not checked
		checkUserData  bool              // when true, the handler verifies the context values
		expectedUserID string
		expectedRole   string
	}{
		{
			name:           "No Authorization header",
			setupAuth:      func(req *http.Request) {},
			expectedStatus: http.StatusUnauthorized,
			expectedBody:   map[string]string{"error": "Authorization header is required"},
		},
		{
			name: "Invalid Authorization format",
			setupAuth: func(req *http.Request) {
				req.Header.Set("Authorization", "InvalidFormat")
			},
			expectedStatus: http.StatusUnauthorized,
			expectedBody:   map[string]string{"error": "Authorization header format must be Bearer {token}"},
		},
		{
			name: "Invalid token",
			setupAuth: func(req *http.Request) {
				req.Header.Set("Authorization", "Bearer invalid.token.here")
			},
			expectedStatus: http.StatusUnauthorized,
			expectedBody:   map[string]string{"error": "Invalid token"},
		},
		{
			name: "Valid token",
			setupAuth: func(req *http.Request) {
				claims := jwt.MapClaims{
					"sub":  "123",
					"role": "user",
					"exp":  time.Now().Add(time.Hour).Unix(),
				}
				token := createTestToken(jwtSecret, claims)
				req.Header.Set("Authorization", "Bearer "+token)
			},
			expectedStatus: http.StatusOK,
			checkUserData:  true,
			expectedUserID: "123",
			expectedRole:   "user",
		},
		{
			name: "Expired token",
			setupAuth: func(req *http.Request) {
				claims := jwt.MapClaims{
					"sub":  "123",
					"role": "user",
					"exp":  time.Now().Add(-time.Hour).Unix(),
				}
				token := createTestToken(jwtSecret, claims)
				req.Header.Set("Authorization", "Bearer "+token)
			},
			expectedStatus: http.StatusUnauthorized,
			expectedBody:   map[string]string{"error": "Invalid token"},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			gin.SetMode(gin.TestMode)
			// gin's mode is process-wide. The helper middleware below flips it
			// to debug while the request is served, so put it back once the
			// subtest ends; otherwise debug mode leaks into later tests.
			t.Cleanup(func() { gin.SetMode(gin.TestMode) })

			router := gin.New()

			// Install the authentication middleware behind a helper that
			// switches gin to debug mode for the duration of the request.
			// NOTE(review): presumably this is meant to surface debug output
			// from the middleware under test; confirm it is still needed.
			router.Use(func(c *gin.Context) {
				gin.SetMode(gin.DebugMode)
				c.Next()
			}, AuthMiddleware(jwtSecret, tokenBlacklist))

			// Test route: reached only when the middleware lets the request through.
			router.GET("/test", func(c *gin.Context) {
				if tc.checkUserData {
					userID, hasUserID := c.Get("user_id")
					assert.True(t, hasUserID, "user_id should exist in context")
					assert.Equal(t, tc.expectedUserID, userID, "user_id should match")

					role, hasRole := c.Get("user_role")
					assert.True(t, hasRole, "user_role should exist in context")
					assert.Equal(t, tc.expectedRole, role, "user_role should match")
				}
				c.Status(http.StatusOK)
			})

			// Build the request.
			req := httptest.NewRequest(http.MethodGet, "/test", nil)
			tc.setupAuth(req)
			rec := httptest.NewRecorder()

			// Execute the request.
			router.ServeHTTP(rec, req)

			// Verify the response.
			assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match")
			if tc.expectedBody != nil {
				var response map[string]string
				err := json.NewDecoder(rec.Body).Decode(&response)
				assert.NoError(t, err, "Response body should be valid JSON")
				assert.Equal(t, tc.expectedBody, response, "Response body should match")
			}
		})
	}
}

// TestRoleMiddleware runs RoleMiddleware behind a helper middleware that
// pre-populates (or deliberately omits) the "user_role" context value, and
// checks the resulting status code and JSON error body.
func TestRoleMiddleware(t *testing.T) {
	testCases := []struct {
		name           string
		setupContext   func(*gin.Context) // seeds the context before RoleMiddleware runs
		allowedRoles   []string
		expectedStatus int
		expectedBody   map[string]string // nil means the body is not checked
	}{
		{
			name: "No user role",
			setupContext: func(c *gin.Context) {
				// Intentionally leave the user role unset.
			},
			allowedRoles:   []string{"admin"},
			expectedStatus: http.StatusUnauthorized,
			expectedBody:   map[string]string{"error": "User role not found"},
		},
		{
			name: "Invalid role type",
			setupContext: func(c *gin.Context) {
				c.Set("user_role", 123) // store a role of the wrong type
			},
			allowedRoles:   []string{"admin"},
			expectedStatus: http.StatusInternalServerError,
			expectedBody:   map[string]string{"error": "Invalid user role type"},
		},
		{
			name: "Insufficient permissions",
			setupContext: func(c *gin.Context) {
				c.Set("user_role", "user")
			},
			allowedRoles:   []string{"admin"},
			expectedStatus: http.StatusForbidden,
			expectedBody:   map[string]string{"error": "Insufficient permissions"},
		},
		{
			name: "Allowed role",
			setupContext: func(c *gin.Context) {
				c.Set("user_role", "admin")
			},
			allowedRoles:   []string{"admin"},
			expectedStatus: http.StatusOK,
		},
		{
			name: "One of multiple allowed roles",
			setupContext: func(c *gin.Context) {
				c.Set("user_role", "editor")
			},
			allowedRoles:   []string{"admin", "editor", "moderator"},
			expectedStatus: http.StatusOK,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			gin.SetMode(gin.TestMode)
			router := gin.New()

			// Seed the context first, then install the role middleware.
			router.Use(func(c *gin.Context) {
				tc.setupContext(c)
				c.Next()
			})
			router.Use(RoleMiddleware(tc.allowedRoles...))

			// Test route: reached only when the role check passes.
			router.GET("/test", func(c *gin.Context) {
				c.Status(http.StatusOK)
			})

			// Build the request.
			req := httptest.NewRequest(http.MethodGet, "/test", nil)
			rec := httptest.NewRecorder()

			// Execute the request.
			router.ServeHTTP(rec, req)

			// Verify the response.
			assert.Equal(t, tc.expectedStatus, rec.Code)
			if tc.expectedBody != nil {
				var response map[string]string
				err := json.NewDecoder(rec.Body).Decode(&response)
				assert.NoError(t, err)
				assert.Equal(t, tc.expectedBody, response)
			}
		})
	}
}