diff --git a/backend/internal/middleware/auth.go b/backend/internal/middleware/auth.go index 1132377..3816d1b 100644 --- a/backend/internal/middleware/auth.go +++ b/backend/internal/middleware/auth.go @@ -100,20 +100,30 @@ func AuthMiddleware(jwtSecret string, tokenBlacklist *service.TokenBlacklist) gi return } - // 设置用户 ID - c.Set("user_id", userIDStr) - - // 设置用户角色 - if role, ok := claims["role"].(string); ok { - log.Debug().Str("role", role).Msg("Found user role") - c.Set("user_role", role) - } else { - log.Error().Interface("role", claims["role"]).Msg("Invalid or missing role claim") - c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid role format"}) + // 获取用户角色 + roles, exists := claims["roles"] + if !exists { + log.Error().Msg("Token does not contain roles claim") + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token format: missing roles"}) c.Abort() return } + + // 将角色转换为字符串数组 + var roleNames []string + if rolesArray, ok := roles.([]interface{}); ok { + for _, r := range rolesArray { + if roleStr, ok := r.(string); ok { + roleNames = append(roleNames, roleStr) + } + } + } + + // 设置上下文 + c.Set("user_id", userIDStr) + c.Set("user_roles", roleNames) // 存储角色数组 c.Next() + return } else { c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"}) c.Abort() @@ -125,27 +135,37 @@ func AuthMiddleware(jwtSecret string, tokenBlacklist *service.TokenBlacklist) gi // RoleMiddleware creates a middleware for role-based authorization func RoleMiddleware(roles ...string) gin.HandlerFunc { return func(c *gin.Context) { - userRole, exists := c.Get("user_role") + userRoles, exists := c.Get("user_roles") if !exists { - c.JSON(http.StatusUnauthorized, gin.H{"error": "User role not found"}) + log.Error().Msg("User roles not found in context") + c.JSON(http.StatusUnauthorized, gin.H{"error": "User roles not found"}) c.Abort() return } - roleStr, ok := userRole.(string) + roleNames, ok := userRoles.([]string) if !ok { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user role type"}) + log.Error().Msg("Invalid user roles type") + c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user roles type"}) c.Abort() return } - for _, role := range roles { - if role == roleStr { - c.Next() - return + // 检查用户是否拥有任一所需角色 + for _, requiredRole := range roles { + for _, userRole := range roleNames { + if requiredRole == userRole { + c.Next() + return + } } } + log.Warn(). + Strs("required_roles", roles). + Strs("user_roles", roleNames). + Msg("Insufficient permissions") + c.JSON(http.StatusForbidden, gin.H{"error": "Insufficient permissions"}) c.Abort() } diff --git a/backend/internal/middleware/auth_test.go b/backend/internal/middleware/auth_test.go index 6c8b249..78e2f03 100644 --- a/backend/internal/middleware/auth_test.go +++ b/backend/internal/middleware/auth_test.go @@ -34,7 +34,7 @@ func TestAuthMiddleware(t *testing.T) { expectedBody map[string]string checkUserData bool expectedUserID string - expectedRole string + expectedRoles []string }{ { name: "No Authorization header", @@ -62,9 +62,9 @@ func TestAuthMiddleware(t *testing.T) { name: "Valid token", setupAuth: func(req *http.Request) { claims := jwt.MapClaims{ - "sub": "123", - "role": "user", - "exp": time.Now().Add(time.Hour).Unix(), + "sub": "123", + "roles": []string{"admin", "editor"}, + "exp": time.Now().Add(time.Hour).Unix(), } token := createTestToken(jwtSecret, claims) req.Header.Set("Authorization", "Bearer "+token) @@ -72,15 +72,15 @@ func TestAuthMiddleware(t *testing.T) { expectedStatus: http.StatusOK, checkUserData: true, expectedUserID: "123", - expectedRole: "user", + expectedRoles: []string{"admin", "editor"}, }, { name: "Expired token", setupAuth: func(req *http.Request) { claims := jwt.MapClaims{ - "sub": "123", - "role": "user", - "exp": time.Now().Add(-time.Hour).Unix(), + "sub": "123", + "roles": []string{"user"}, + "exp": time.Now().Add(-time.Hour).Unix(), } token := createTestToken(jwtSecret, claims) req.Header.Set("Authorization", "Bearer "+token) @@ -109,9 +109,9 @@ func TestAuthMiddleware(t *testing.T) { assert.True(t, exists, "user_id should exist in context") assert.Equal(t, tc.expectedUserID, userID, "user_id should match") - role, exists := c.Get("user_role") - assert.True(t, exists, "user_role should exist in context") - assert.Equal(t, tc.expectedRole, role, "user_role should match") + roles, exists := c.Get("user_roles") + assert.True(t, exists, "user_roles should exist in context") + assert.Equal(t, tc.expectedRoles, roles, "user_roles should match") } c.Status(http.StatusOK) }) @@ -146,27 +146,27 @@ func TestRoleMiddleware(t *testing.T) { expectedBody map[string]string }{ { - name: "No user role", + name: "No user roles", setupContext: func(c *gin.Context) { // 不设置用户角色 }, allowedRoles: []string{"admin"}, expectedStatus: http.StatusUnauthorized, - expectedBody: map[string]string{"error": "User role not found"}, + expectedBody: map[string]string{"error": "User roles not found"}, }, { - name: "Invalid role type", + name: "Invalid roles type", setupContext: func(c *gin.Context) { - c.Set("user_role", 123) // 设置错误类型的角色 + c.Set("user_roles", 123) // 设置错误类型的角色 }, allowedRoles: []string{"admin"}, expectedStatus: http.StatusInternalServerError, - expectedBody: map[string]string{"error": "Invalid user role type"}, + expectedBody: map[string]string{"error": "Invalid user roles type"}, }, { name: "Insufficient permissions", setupContext: func(c *gin.Context) { - c.Set("user_role", "user") + c.Set("user_roles", []string{"user"}) }, allowedRoles: []string{"admin"}, expectedStatus: http.StatusForbidden, @@ -175,7 +175,7 @@ func TestRoleMiddleware(t *testing.T) { { name: "Allowed role", setupContext: func(c *gin.Context) { - c.Set("user_role", "admin") + c.Set("user_roles", []string{"admin"}) }, allowedRoles: []string{"admin"}, expectedStatus: http.StatusOK, @@ -183,7 +183,7 @@ func TestRoleMiddleware(t *testing.T) { { name: "One of multiple allowed roles", setupContext: func(c *gin.Context) { - c.Set("user_role", "editor") + c.Set("user_roles", []string{"user", "editor"}) }, allowedRoles: []string{"admin", "editor", "moderator"}, expectedStatus: http.StatusOK, @@ -199,8 +199,7 @@ func TestRoleMiddleware(t *testing.T) { router.Use(func(c *gin.Context) { tc.setupContext(c) c.Next() - }) - router.Use(RoleMiddleware(tc.allowedRoles...)) + }, RoleMiddleware(tc.allowedRoles...)) // 测试路由 router.GET("/test", func(c *gin.Context) { @@ -215,13 +214,13 @@ func TestRoleMiddleware(t *testing.T) { router.ServeHTTP(rec, req) // 验证响应 - assert.Equal(t, tc.expectedStatus, rec.Code) + assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match") if tc.expectedBody != nil { var response map[string]string err := json.NewDecoder(rec.Body).Decode(&response) - assert.NoError(t, err) - assert.Equal(t, tc.expectedBody, response) + assert.NoError(t, err, "Response body should be valid JSON") + assert.Equal(t, tc.expectedBody, response, "Response body should match") } }) }