110 lines
2.3 KiB
Go
110 lines
2.3 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
|
|
"tss-rocks-be/ent"
|
|
"tss-rocks-be/ent/user"
|
|
"tss-rocks-be/internal/auth"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// RequirePermission creates a middleware that checks if the user has the required permission
|
|
func RequirePermission(client *ent.Client, resource, action string) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Get user from context
|
|
userID, exists := c.Get(auth.UserIDKey)
|
|
if !exists {
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
|
"error": "Unauthorized",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Get user with roles
|
|
user, err := client.User.Query().
|
|
Where(user.ID(userID.(int))).
|
|
WithRoles(func(q *ent.RoleQuery) {
|
|
q.WithPermissions()
|
|
}).
|
|
Only(context.Background())
|
|
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
|
"error": "User not found",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Check if user has the required permission through any of their roles
|
|
hasPermission := false
|
|
for _, r := range user.Edges.Roles {
|
|
for _, p := range r.Edges.Permissions {
|
|
if p.Resource == resource && p.Action == action {
|
|
hasPermission = true
|
|
break
|
|
}
|
|
}
|
|
if hasPermission {
|
|
break
|
|
}
|
|
}
|
|
|
|
if !hasPermission {
|
|
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
|
"error": fmt.Sprintf("Missing required permission: %s:%s", resource, action),
|
|
})
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// RequireRole creates a middleware that checks if the user has the required role
|
|
func RequireRole(client *ent.Client, roleName string) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Get user from context
|
|
userID, exists := c.Get(auth.UserIDKey)
|
|
if !exists {
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
|
"error": "Unauthorized",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Get user with roles
|
|
user, err := client.User.Query().
|
|
Where(user.ID(userID.(int))).
|
|
WithRoles().
|
|
Only(context.Background())
|
|
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
|
"error": "User not found",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Check if user has the required role
|
|
hasRole := false
|
|
for _, r := range user.Edges.Roles {
|
|
if r.Name == roleName {
|
|
hasRole = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !hasRole {
|
|
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
|
"error": fmt.Sprintf("Required role: %s", roleName),
|
|
})
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|