package middleware

import (
	"encoding/json"
	"fmt"
	"net/http"
	"strconv"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
	"github.com/rs/zerolog/log"

	"tss-rocks-be/internal/service"
)

// AuthMiddleware returns a Gin middleware that authenticates requests using an
// HMAC-signed JWT supplied as "Authorization: Bearer <token>".
//
// The request is aborted with 401 Unauthorized when:
//   - the Authorization header is missing or not of the form "Bearer <token>";
//   - tokenBlacklist reports the token as blacklisted;
//   - the token fails to parse, is not HMAC-signed, or does not verify
//     against jwtSecret;
//   - the claims lack "sub", "sub" is not an integer, or "roles" is missing.
//
// On success it stores the user ID as a decimal string under the context key
// "user_id" and the role names as []string under "user_roles", then calls the
// next handler.
func AuthMiddleware(jwtSecret string, tokenBlacklist *service.TokenBlacklist) gin.HandlerFunc {
	return func(c *gin.Context) {
		authHeader := c.GetHeader("Authorization")
		if authHeader == "" {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header is required"})
			return
		}

		// The authentication scheme name is case-insensitive (RFC 7235 §2.1),
		// so "bearer" is accepted as well as "Bearer". Exactly one space must
		// separate the scheme from a non-empty token.
		scheme, tokenStr, found := strings.Cut(authHeader, " ")
		if !found || !strings.EqualFold(scheme, "Bearer") || tokenStr == "" || strings.Contains(tokenStr, " ") {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header format must be Bearer {token}"})
			return
		}

		// Reject revoked tokens before doing any signature verification.
		if tokenBlacklist.IsBlacklisted(tokenStr) {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token has been revoked"})
			return
		}

		// NOTE: never log tokenStr itself — it is a bearer credential and
		// anyone who can read the logs could replay it.
		token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (interface{}, error) {
			// Only HMAC-family algorithms may be verified with the shared
			// secret; anything else is treated as an invalid signature.
			if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
				log.Error().Str("method", t.Method.Alg()).Msg("Invalid signing method")
				return nil, jwt.ErrSignatureInvalid
			}
			return []byte(jwtSecret), nil
		})
		if err != nil {
			log.Error().Err(err).Msg("Failed to parse token")
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
			return
		}

		claims, ok := token.Claims.(jwt.MapClaims)
		if !ok || !token.Valid {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
			return
		}
		log.Debug().Interface("claims", claims).Msg("Token claims")

		// Extract the user ID from the "sub" claim.
		sub, exists := claims["sub"]
		if !exists {
			log.Error().Msg("Token does not contain sub claim")
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token format"})
			return
		}
		log.Debug().
			Str("type", fmt.Sprintf("%T", sub)).
			Interface("value", sub).
			Msg("User ID from token")

		// Normalize "sub" to a string. JSON numbers arrive as float64 by
		// default, or as json.Number if the parser was configured to use it.
		var userIDStr string
		switch v := sub.(type) {
		case string:
			userIDStr = v
		case float64:
			// Precision -1 gives the shortest exact representation, so a
			// fractional value such as 1.5 becomes "1.5" and is rejected by
			// the integer check below instead of being rounded to another ID.
			userIDStr = strconv.FormatFloat(v, 'f', -1, 64)
		case json.Number:
			userIDStr = v.String()
		default:
			userIDStr = fmt.Sprintf("%v", v)
		}

		// The user ID must be an integer that fits in an int.
		if _, convErr := strconv.Atoi(userIDStr); convErr != nil {
			log.Error().
				Err(convErr).
				Str("user_id", userIDStr).
				Msg("Invalid user ID format")
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID format"})
			return
		}

		// Extract the user's roles from the "roles" claim.
		roles, exists := claims["roles"]
		if !exists {
			log.Error().Msg("Token does not contain roles claim")
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token format: missing roles"})
			return
		}

		// Keep only the string entries of the roles array. A "roles" claim
		// that is not an array yields no roles, so any RoleMiddleware check
		// downstream will fail with 403.
		var roleNames []string
		if rolesArray, ok := roles.([]interface{}); ok {
			for _, r := range rolesArray {
				if roleStr, ok := r.(string); ok {
					roleNames = append(roleNames, roleStr)
				}
			}
		} else {
			log.Warn().
				Str("type", fmt.Sprintf("%T", roles)).
				Msg("Roles claim is not an array; no roles granted")
		}

		c.Set("user_id", userIDStr)
		c.Set("user_roles", roleNames) // []string of role names
		c.Next()
	}
}

// RoleMiddleware returns a Gin middleware that lets the request through only if
// the authenticated user holds at least one of the given roles. It expects to
// run after AuthMiddleware, which stores the user's roles as []string under
// the context key "user_roles".
//
// It aborts with 401 if "user_roles" is absent from the context, with 500 if
// the stored value is not a []string, and with 403 if none of the user's roles
// match. Called with no roles, it rejects every request with 403.
func RoleMiddleware(roles ...string) gin.HandlerFunc {
	return func(c *gin.Context) {
		userRoles, exists := c.Get("user_roles")
		if !exists {
			log.Error().Msg("User roles not found in context")
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "User roles not found"})
			return
		}

		roleNames, ok := userRoles.([]string)
		if !ok {
			log.Error().Msg("Invalid user roles type")
			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Invalid user roles type"})
			return
		}

		// Allow the request if the user has any one of the required roles.
		for _, requiredRole := range roles {
			for _, userRole := range roleNames {
				if requiredRole == userRole {
					c.Next()
					return
				}
			}
		}

		log.Warn().
			Strs("required_roles", roles).
			Strs("user_roles", roleNames).
			Msg("Insufficient permissions")
		c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "Insufficient permissions"})
	}
}