package middleware

import (
	"encoding/json"
	"fmt"
	"net/http"
	"strconv"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
	"github.com/rs/zerolog/log"

	"tss-rocks-be/internal/service"
)

// AuthMiddleware returns a gin handler that authenticates requests using an
// HMAC-signed JWT carried in the "Authorization: Bearer <token>" header.
//
// The request is aborted with 401 Unauthorized when:
//   - the Authorization header is missing or not of the form "Bearer <token>";
//   - tokenBlacklist reports the token as blacklisted;
//   - the token fails to parse or verify against jwtSecret, or is not HMAC-signed;
//   - the "sub" claim is absent or does not represent a base-10 integer;
//   - the "role" claim is absent or not a string.
//
// On success the subject is stored in the gin context under "user_id" (as a
// decimal string) and the role under "user_role", and the next handler runs.
func AuthMiddleware(jwtSecret string, tokenBlacklist *service.TokenBlacklist) gin.HandlerFunc {
	return func(c *gin.Context) {
		authHeader := c.GetHeader("Authorization")
		if authHeader == "" {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header is required"})
			return
		}

		// The auth scheme is case-insensitive (RFC 7235 §2.1), so "bearer" and
		// "BEARER" are accepted alongside "Bearer". An empty token is rejected
		// here rather than being passed on to the blacklist and the parser.
		scheme, tokenStr, found := strings.Cut(authHeader, " ")
		tokenStr = strings.TrimSpace(tokenStr)
		if !found || !strings.EqualFold(scheme, "Bearer") || tokenStr == "" {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header format must be Bearer {token}"})
			return
		}

		// Reject tokens that have been explicitly revoked before doing any
		// signature verification.
		if tokenBlacklist.IsBlacklisted(tokenStr) {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token has been revoked"})
			return
		}

		token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (any, error) {
			// Never log the raw token: it is a bearer credential and anyone able
			// to read the logs could replay it. The algorithm is safe to log.
			log.Debug().Str("alg", t.Method.Alg()).Msg("Parsing token")
			// Only HMAC algorithms are accepted, so a token cannot switch the
			// verifier to a different algorithm family via its "alg" header.
			if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
				log.Error().Str("method", t.Method.Alg()).Msg("Invalid signing method")
				return nil, jwt.ErrSignatureInvalid
			}
			return []byte(jwtSecret), nil
		})
		if err != nil {
			log.Error().Err(err).Msg("Failed to parse token")
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
			return
		}

		claims, ok := token.Claims.(jwt.MapClaims)
		if !ok || !token.Valid {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
			return
		}
		log.Debug().Interface("claims", claims).Msg("Token claims")

		sub, exists := claims["sub"]
		if !exists {
			log.Error().Msg("Token does not contain sub claim")
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token format"})
			return
		}
		log.Debug().
			Str("type", fmt.Sprintf("%T", sub)).
			Interface("value", sub).
			Msg("User ID from token")

		// Normalize the subject to a string. Numeric JSON claims generally
		// decode as float64 (or json.Number if the parser is configured to
		// keep numbers as json.Number), so both are handled explicitly.
		var userIDStr string
		switch v := sub.(type) {
		case string:
			userIDStr = v
		case float64:
			// Precision -1 prints the shortest exact representation, so an
			// integral value such as 42 becomes "42" while a fractional one
			// such as 1.5 stays "1.5" and is rejected by the check below,
			// instead of being silently rounded to a different user ID.
			userIDStr = strconv.FormatFloat(v, 'f', -1, 64)
		case json.Number:
			userIDStr = v.String()
		default:
			userIDStr = fmt.Sprintf("%v", v)
		}

		// The user ID must be a base-10 integer that fits in an int.
		if _, err := strconv.Atoi(userIDStr); err != nil {
			log.Error().
				Err(err).
				Str("user_id", userIDStr).
				Msg("Invalid user ID format")
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID format"})
			return
		}
		c.Set("user_id", userIDStr)

		role, ok := claims["role"].(string)
		if !ok {
			log.Error().Interface("role", claims["role"]).Msg("Invalid or missing role claim")
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid role format"})
			return
		}
		log.Debug().Str("role", role).Msg("Found user role")
		c.Set("user_role", role)

		c.Next()
	}
}

// RoleMiddleware returns a gin handler that lets the request through only when
// the "user_role" context value (set by AuthMiddleware in this file) equals one
// of roles. It must therefore run after a handler that sets "user_role".
//
// It responds 401 when no role is present in the context, 500 when the stored
// value is not a string, and 403 when the role is not in the allowed list.
func RoleMiddleware(roles ...string) gin.HandlerFunc {
	return func(c *gin.Context) {
		userRole, exists := c.Get("user_role")
		if !exists {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "User role not found"})
			return
		}

		roleStr, ok := userRole.(string)
		if !ok {
			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Invalid user role type"})
			return
		}

		for _, role := range roles {
			if role == roleStr {
				c.Next()
				return
			}
		}

		c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "Insufficient permissions"})
	}
}