[feature/backend] implement /auth/logout handling + overall enhancement

This commit is contained in:
CDN 2025-02-21 05:44:18 +08:00
parent d8d8e4b0d7
commit e5fc8691bf
Signed by: CDN
GPG key ID: 0C656827F9F80080
12 changed files with 420 additions and 64 deletions

View file

@ -2,6 +2,7 @@ package handler
import ( import (
"net/http" "net/http"
"strings"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -28,8 +29,8 @@ type AuthResponse struct {
func (h *Handler) Register(c *gin.Context) { func (h *Handler) Register(c *gin.Context) {
// 检查是否启用注册功能 // 检查是否启用注册功能
if !h.config.Auth.Registration.Enabled { if !h.cfg.Auth.Registration.Enabled {
message := h.config.Auth.Registration.Message message := h.cfg.Auth.Registration.Message
if message == "" { if message == "" {
message = "Registration is currently disabled" message = "Registration is currently disabled"
} }
@ -181,3 +182,72 @@ func (h *Handler) Login(c *gin.Context) {
c.JSON(http.StatusOK, AuthResponse{Token: tokenString}) c.JSON(http.StatusOK, AuthResponse{Token: tokenString})
} }
func (h *Handler) Logout(c *gin.Context) {
// 获取当前用户ID
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"code": "UNAUTHORIZED",
"message": "User not authenticated",
},
})
return
}
// 获取 token
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"code": "UNAUTHORIZED",
"message": "Authorization header is required",
},
})
return
}
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"code": "UNAUTHORIZED",
"message": "Authorization header format must be Bearer {token}",
},
})
return
}
// 解析 token 以获取过期时间
token, err := jwt.Parse(parts[1], func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, jwt.ErrSignatureInvalid
}
return []byte(h.cfg.JWT.Secret), nil
})
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"code": "INVALID_TOKEN",
"message": "Invalid token",
},
})
return
}
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
// 将 token 添加到黑名单
h.service.GetTokenBlacklist().AddToBlacklist(parts[1], claims)
}
// 记录日志
log.Info().
Interface("user_id", userID).
Msg("User logged out")
c.JSON(http.StatusOK, gin.H{
"message": "Successfully logged out",
})
}

View file

@ -150,7 +150,7 @@ func (s *AuthHandlerTestSuite) TestRegister() {
for _, tc := range testCases { for _, tc := range testCases {
s.Run(tc.name, func() { s.Run(tc.name, func() {
// 设置注册功能状态 // 设置注册功能状态
s.handler.config.Auth.Registration.Enabled = tc.registration s.handler.cfg.Auth.Registration.Enabled = tc.registration
// 设置 mock // 设置 mock
tc.setupMock() tc.setupMock()

View file

@ -9,6 +9,7 @@ import (
"tss-rocks-be/ent" "tss-rocks-be/ent"
"tss-rocks-be/ent/categorycontent" "tss-rocks-be/ent/categorycontent"
"tss-rocks-be/internal/config" "tss-rocks-be/internal/config"
"tss-rocks-be/internal/service"
"tss-rocks-be/internal/service/mock" "tss-rocks-be/internal/service/mock"
"tss-rocks-be/internal/types" "tss-rocks-be/internal/types"
@ -59,12 +60,24 @@ type CategoryHandlerTestSuite struct {
func (s *CategoryHandlerTestSuite) SetupTest() { func (s *CategoryHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T()) s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl) s.service = mock.NewMockService(s.ctrl)
cfg := &config.Config{} cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
},
}
s.handler = NewHandler(cfg, s.service) s.handler = NewHandler(cfg, s.service)
// Setup Gin router // Setup Gin router
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
s.router = gin.New() s.router = gin.New()
// Setup mock for GetTokenBlacklist
tokenBlacklist := &service.TokenBlacklist{}
s.service.EXPECT().
GetTokenBlacklist().
Return(tokenBlacklist).
AnyTimes()
s.handler.RegisterRoutes(s.router) s.handler.RegisterRoutes(s.router)
} }

View file

@ -9,6 +9,7 @@ import (
"time" "time"
"tss-rocks-be/ent" "tss-rocks-be/ent"
"tss-rocks-be/internal/config" "tss-rocks-be/internal/config"
"tss-rocks-be/internal/service"
"tss-rocks-be/internal/service/mock" "tss-rocks-be/internal/service/mock"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -29,12 +30,24 @@ type ContributorHandlerTestSuite struct {
func (s *ContributorHandlerTestSuite) SetupTest() { func (s *ContributorHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T()) s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl) s.service = mock.NewMockService(s.ctrl)
cfg := &config.Config{} cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
},
}
s.handler = NewHandler(cfg, s.service) s.handler = NewHandler(cfg, s.service)
// Setup Gin router // Setup Gin router
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
s.router = gin.New() s.router = gin.New()
// Setup mock for GetTokenBlacklist
tokenBlacklist := &service.TokenBlacklist{}
s.service.EXPECT().
GetTokenBlacklist().
Return(tokenBlacklist).
AnyTimes()
s.handler.RegisterRoutes(s.router) s.handler.RegisterRoutes(s.router)
} }

View file

@ -8,6 +8,7 @@ import (
"testing" "testing"
"tss-rocks-be/ent" "tss-rocks-be/ent"
"tss-rocks-be/internal/config" "tss-rocks-be/internal/config"
"tss-rocks-be/internal/service"
"tss-rocks-be/internal/service/mock" "tss-rocks-be/internal/service/mock"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -29,12 +30,24 @@ type DailyHandlerTestSuite struct {
func (s *DailyHandlerTestSuite) SetupTest() { func (s *DailyHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T()) s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl) s.service = mock.NewMockService(s.ctrl)
cfg := &config.Config{} cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
},
}
s.handler = NewHandler(cfg, s.service) s.handler = NewHandler(cfg, s.service)
// Setup Gin router // Setup Gin router
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
s.router = gin.New() s.router = gin.New()
// Setup mock for GetTokenBlacklist
tokenBlacklist := &service.TokenBlacklist{}
s.service.EXPECT().
GetTokenBlacklist().
Return(tokenBlacklist).
AnyTimes()
s.handler.RegisterRoutes(s.router) s.handler.RegisterRoutes(s.router)
} }

View file

@ -14,14 +14,12 @@ import (
type Handler struct { type Handler struct {
cfg *config.Config cfg *config.Config
config *config.Config
service service.Service service service.Service
} }
func NewHandler(cfg *config.Config, service service.Service) *Handler { func NewHandler(cfg *config.Config, service service.Service) *Handler {
return &Handler{ return &Handler{
cfg: cfg, cfg: cfg,
config: cfg,
service: service, service: service,
} }
} }
@ -35,10 +33,11 @@ func (h *Handler) RegisterRoutes(r *gin.Engine) {
{ {
auth.POST("/register", h.Register) auth.POST("/register", h.Register)
auth.POST("/login", h.Login) auth.POST("/login", h.Login)
auth.POST("/logout", middleware.AuthMiddleware(h.cfg.JWT.Secret, h.service.GetTokenBlacklist()), h.Logout)
} }
// User routes // User routes
users := api.Group("/users", middleware.AuthMiddleware(h.config.JWT.Secret)) users := api.Group("/users", middleware.AuthMiddleware(h.cfg.JWT.Secret, h.service.GetTokenBlacklist()))
{ {
users.GET("", h.ListUsers) users.GET("", h.ListUsers)
users.POST("", h.CreateUser) users.POST("", h.CreateUser)
@ -86,7 +85,7 @@ func (h *Handler) RegisterRoutes(r *gin.Engine) {
} }
// Media routes // Media routes
media := api.Group("/media") media := api.Group("/media", middleware.AuthMiddleware(h.cfg.JWT.Secret, h.service.GetTokenBlacklist()))
{ {
media.GET("", h.ListMedia) media.GET("", h.ListMedia)
media.POST("", h.UploadMedia) media.POST("", h.UploadMedia)

View file

@ -8,6 +8,7 @@ import (
"testing" "testing"
"tss-rocks-be/ent" "tss-rocks-be/ent"
"tss-rocks-be/internal/config" "tss-rocks-be/internal/config"
"tss-rocks-be/internal/service"
"tss-rocks-be/internal/service/mock" "tss-rocks-be/internal/service/mock"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -29,12 +30,24 @@ type PostHandlerTestSuite struct {
func (s *PostHandlerTestSuite) SetupTest() { func (s *PostHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T()) s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl) s.service = mock.NewMockService(s.ctrl)
cfg := &config.Config{} cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
},
}
s.handler = NewHandler(cfg, s.service) s.handler = NewHandler(cfg, s.service)
// Setup Gin router // Setup Gin router
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
s.router = gin.New() s.router = gin.New()
// Setup mock for GetTokenBlacklist
tokenBlacklist := &service.TokenBlacklist{}
s.service.EXPECT().
GetTokenBlacklist().
Return(tokenBlacklist).
AnyTimes()
s.handler.RegisterRoutes(s.router) s.handler.RegisterRoutes(s.router)
} }

View file

@ -10,10 +10,11 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"tss-rocks-be/internal/service"
) )
// AuthMiddleware creates a middleware for JWT authentication // AuthMiddleware creates a middleware for JWT authentication
func AuthMiddleware(jwtSecret string) gin.HandlerFunc { func AuthMiddleware(jwtSecret string, tokenBlacklist *service.TokenBlacklist) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
if authHeader == "" { if authHeader == "" {
@ -29,8 +30,20 @@ func AuthMiddleware(jwtSecret string) gin.HandlerFunc {
return return
} }
token, err := jwt.Parse(parts[1], func(token *jwt.Token) (interface{}, error) { tokenStr := parts[1]
// 检查 token 是否在黑名单中
if tokenBlacklist.IsBlacklisted(tokenStr) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token has been revoked"})
c.Abort()
return
}
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
// 添加调试输出
log.Debug().Str("token", tokenStr).Msg("Parsing token")
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
log.Error().Str("method", token.Method.Alg()).Msg("Invalid signing method")
return nil, jwt.ErrSignatureInvalid return nil, jwt.ErrSignatureInvalid
} }
return []byte(jwtSecret), nil return []byte(jwtSecret), nil
@ -62,42 +75,43 @@ func AuthMiddleware(jwtSecret string) gin.HandlerFunc {
Interface("value", sub). Interface("value", sub).
Msg("User ID from token") Msg("User ID from token")
var userID int // 将用户 ID 转换为字符串
var userIDStr string
switch v := sub.(type) { switch v := sub.(type) {
case string: case string:
var err error userIDStr = v
userID, err = strconv.Atoi(v)
if err != nil {
log.Error().Err(err).Str("sub", v).Msg("Failed to convert string user ID to int")
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID format"})
c.Abort()
return
}
case float64: case float64:
userID = int(v) userIDStr = strconv.FormatFloat(v, 'f', 0, 64)
case json.Number: case json.Number:
var err error userIDStr = v.String()
userID, err = strconv.Atoi(v.String())
if err != nil {
log.Error().Err(err).Str("sub", v.String()).Msg("Failed to convert json.Number user ID to int")
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID format"})
c.Abort()
return
}
default: default:
userIDStr = fmt.Sprintf("%v", v)
}
// 验证用户 ID 是否为有效的数字字符串
_, err := strconv.Atoi(userIDStr)
if err != nil {
log.Error(). log.Error().
Str("type", fmt.Sprintf("%T", sub)). Err(err).
Interface("value", sub). Str("user_id", userIDStr).
Msg("Unexpected user ID type") Msg("Invalid user ID format")
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID type"}) c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID format"})
c.Abort() c.Abort()
return return
} }
// 将 userID 转换为 int64 以确保类型一致性 // 设置用户 ID
c.Set("user_id", int64(userID)) c.Set("user_id", userIDStr)
if roles, ok := claims["roles"].([]interface{}); ok {
c.Set("user_roles", roles) // 设置用户角色
if role, ok := claims["role"].(string); ok {
log.Debug().Str("role", role).Msg("Found user role")
c.Set("user_role", role)
} else {
log.Error().Interface("role", claims["role"]).Msg("Invalid or missing role claim")
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid role format"})
c.Abort()
return
} }
c.Next() c.Next()
} else { } else {

View file

@ -2,6 +2,7 @@ package middleware
import ( import (
"encoding/json" "encoding/json"
"fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -9,16 +10,22 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"time" "time"
"tss-rocks-be/internal/service"
) )
func createTestToken(secret string, claims jwt.MapClaims) string { func createTestToken(secret string, claims jwt.MapClaims) string {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signedToken, _ := token.SignedString([]byte(secret)) signedToken, err := token.SignedString([]byte(secret))
if err != nil {
panic(fmt.Sprintf("Failed to sign token: %v", err))
}
return signedToken return signedToken
} }
func TestAuthMiddleware(t *testing.T) { func TestAuthMiddleware(t *testing.T) {
jwtSecret := "test-secret" jwtSecret := "test-secret"
tokenBlacklist := service.NewTokenBlacklist()
testCases := []struct { testCases := []struct {
name string name string
@ -55,7 +62,7 @@ func TestAuthMiddleware(t *testing.T) {
name: "Valid token", name: "Valid token",
setupAuth: func(req *http.Request) { setupAuth: func(req *http.Request) {
claims := jwt.MapClaims{ claims := jwt.MapClaims{
"sub": "user123", "sub": "123",
"role": "user", "role": "user",
"exp": time.Now().Add(time.Hour).Unix(), "exp": time.Now().Add(time.Hour).Unix(),
} }
@ -64,14 +71,14 @@ func TestAuthMiddleware(t *testing.T) {
}, },
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
checkUserData: true, checkUserData: true,
expectedUserID: "user123", expectedUserID: "123",
expectedRole: "user", expectedRole: "user",
}, },
{ {
name: "Expired token", name: "Expired token",
setupAuth: func(req *http.Request) { setupAuth: func(req *http.Request) {
claims := jwt.MapClaims{ claims := jwt.MapClaims{
"sub": "user123", "sub": "123",
"role": "user", "role": "user",
"exp": time.Now().Add(-time.Hour).Unix(), "exp": time.Now().Add(-time.Hour).Unix(),
} }
@ -89,18 +96,22 @@ func TestAuthMiddleware(t *testing.T) {
router := gin.New() router := gin.New()
// 添加认证中间件 // 添加认证中间件
router.Use(AuthMiddleware(jwtSecret)) router.Use(func(c *gin.Context) {
// 设置日志级别为 debug
gin.SetMode(gin.DebugMode)
c.Next()
}, AuthMiddleware(jwtSecret, tokenBlacklist))
// 测试路由 // 测试路由
router.GET("/test", func(c *gin.Context) { router.GET("/test", func(c *gin.Context) {
if tc.checkUserData { if tc.checkUserData {
userID, exists := c.Get("user_id") userID, exists := c.Get("user_id")
assert.True(t, exists) assert.True(t, exists, "user_id should exist in context")
assert.Equal(t, tc.expectedUserID, userID) assert.Equal(t, tc.expectedUserID, userID, "user_id should match")
role, exists := c.Get("user_role") role, exists := c.Get("user_role")
assert.True(t, exists) assert.True(t, exists, "user_role should exist in context")
assert.Equal(t, tc.expectedRole, role) assert.Equal(t, tc.expectedRole, role, "user_role should match")
} }
c.Status(http.StatusOK) c.Status(http.StatusOK)
}) })
@ -114,13 +125,13 @@ func TestAuthMiddleware(t *testing.T) {
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
// 验证响应 // 验证响应
assert.Equal(t, tc.expectedStatus, rec.Code) assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match")
if tc.expectedBody != nil { if tc.expectedBody != nil {
var response map[string]string var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response) err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err) assert.NoError(t, err, "Response body should be valid JSON")
assert.Equal(t, tc.expectedBody, response) assert.Equal(t, tc.expectedBody, response, "Response body should match")
} }
}) })
} }

View file

@ -0,0 +1,57 @@
package service
import (
"sync"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/rs/zerolog/log"
)
// TokenBlacklist 用于存储已失效的 token
type TokenBlacklist struct {
tokens sync.Map
}
// NewTokenBlacklist 创建一个新的 token 黑名单
func NewTokenBlacklist() *TokenBlacklist {
bl := &TokenBlacklist{}
// 启动定期清理过期 token 的 goroutine
go bl.cleanupExpiredTokens()
return bl
}
// AddToBlacklist 将 token 添加到黑名单
func (bl *TokenBlacklist) AddToBlacklist(tokenStr string, claims jwt.MapClaims) {
// 获取 token 的过期时间
exp, ok := claims["exp"].(float64)
if !ok {
log.Error().Msg("Failed to get token expiration time")
return
}
// 存储 token 和其过期时间
bl.tokens.Store(tokenStr, time.Unix(int64(exp), 0))
}
// IsBlacklisted 检查 token 是否在黑名单中
func (bl *TokenBlacklist) IsBlacklisted(tokenStr string) bool {
_, exists := bl.tokens.Load(tokenStr)
return exists
}
// cleanupExpiredTokens 定期清理过期的 token
func (bl *TokenBlacklist) cleanupExpiredTokens() {
ticker := time.NewTicker(1 * time.Hour)
for range ticker.C {
now := time.Now()
bl.tokens.Range(func(key, value interface{}) bool {
if expTime, ok := value.(time.Time); ok {
if now.After(expTime) {
bl.tokens.Delete(key)
}
}
return true
})
}
}

View file

@ -18,12 +18,14 @@ import (
"tss-rocks-be/ent/contributorsociallink" "tss-rocks-be/ent/contributorsociallink"
"tss-rocks-be/ent/daily" "tss-rocks-be/ent/daily"
"tss-rocks-be/ent/dailycontent" "tss-rocks-be/ent/dailycontent"
"tss-rocks-be/ent/media"
"tss-rocks-be/ent/permission" "tss-rocks-be/ent/permission"
"tss-rocks-be/ent/post" "tss-rocks-be/ent/post"
"tss-rocks-be/ent/postcontent" "tss-rocks-be/ent/postcontent"
"tss-rocks-be/ent/role" "tss-rocks-be/ent/role"
"tss-rocks-be/ent/user" "tss-rocks-be/ent/user"
"tss-rocks-be/internal/storage" "tss-rocks-be/internal/storage"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@ -42,6 +44,7 @@ var openFile func(fh *multipart.FileHeader) (multipart.File, error) = func(fh *m
type serviceImpl struct { type serviceImpl struct {
client *ent.Client client *ent.Client
storage storage.Storage storage storage.Storage
tokenBlacklist *TokenBlacklist
} }
// NewService creates a new Service instance // NewService creates a new Service instance
@ -49,9 +52,15 @@ func NewService(client *ent.Client, storage storage.Storage) Service {
return &serviceImpl{ return &serviceImpl{
client: client, client: client,
storage: storage, storage: storage,
tokenBlacklist: NewTokenBlacklist(),
} }
} }
// GetTokenBlacklist returns the token blacklist
func (s *serviceImpl) GetTokenBlacklist() *TokenBlacklist {
return s.tokenBlacklist
}
// User operations // User operations
func (s *serviceImpl) CreateUser(ctx context.Context, username, email, password string, roleStr string) (*ent.User, error) { func (s *serviceImpl) CreateUser(ctx context.Context, username, email, password string, roleStr string) (*ent.User, error) {
// 验证邮箱格式 // 验证邮箱格式
@ -451,12 +460,14 @@ func (s *serviceImpl) DeleteMedia(ctx context.Context, id int, userID int) error
} }
// Check ownership // Check ownership
if media.CreatedBy != strconv.Itoa(userID) { isOwner := media.CreatedBy == strconv.Itoa(userID)
if !isOwner {
return ErrUnauthorized return ErrUnauthorized
} }
// Delete from storage // Delete from storage
if err := s.storage.Delete(ctx, media.StorageID); err != nil { err = s.storage.Delete(ctx, media.StorageID)
if err != nil {
return err return err
} }
@ -904,3 +915,138 @@ func (s *serviceImpl) HasPermission(ctx context.Context, userID int, permission
return false, nil return false, nil
} }
func (s *serviceImpl) Delete(ctx context.Context, id int, currentUserID int) error {
// Check if the entity exists and get its type
var entityExists bool
var err error
// Try to find the entity in different tables
if entityExists, err = s.client.User.Query().Where(user.ID(id)).Exist(ctx); err == nil && entityExists {
// Check if user has permission to delete users
hasPermission, err := s.HasPermission(ctx, currentUserID, "users:delete")
if err != nil {
return fmt.Errorf("failed to check permission: %v", err)
}
if !hasPermission {
return ErrUnauthorized
}
// Cannot delete yourself
if id == currentUserID {
return fmt.Errorf("cannot delete your own account")
}
return s.client.User.DeleteOneID(id).Exec(ctx)
}
if entityExists, err = s.client.Post.Query().Where(post.ID(id)).Exist(ctx); err == nil && entityExists {
// Check if user has permission to delete posts
hasPermission, err := s.HasPermission(ctx, currentUserID, "posts:delete")
if err != nil {
return fmt.Errorf("failed to check permission: %v", err)
}
if !hasPermission {
// Check if the user is the author of the post
isAuthor, err := s.client.Post.Query().
Where(post.ID(id)).
QueryContributors().
QueryContributor().
QueryUser().
Where(user.ID(currentUserID)).
Exist(ctx)
if err != nil {
return fmt.Errorf("failed to check post author: %v", err)
}
if !isAuthor {
return ErrUnauthorized
}
}
return s.client.Post.DeleteOneID(id).Exec(ctx)
}
if entityExists, err = s.client.Category.Query().Where(category.ID(id)).Exist(ctx); err == nil && entityExists {
// Check if user has permission to delete categories
hasPermission, err := s.HasPermission(ctx, currentUserID, "categories:delete")
if err != nil {
return fmt.Errorf("failed to check permission: %v", err)
}
if !hasPermission {
return ErrUnauthorized
}
return s.client.Category.DeleteOneID(id).Exec(ctx)
}
if entityExists, err = s.client.Contributor.Query().Where(contributor.ID(id)).Exist(ctx); err == nil && entityExists {
// Check if user has permission to delete contributors
hasPermission, err := s.HasPermission(ctx, currentUserID, "contributors:delete")
if err != nil {
return fmt.Errorf("failed to check permission: %v", err)
}
if !hasPermission {
return ErrUnauthorized
}
return s.client.Contributor.DeleteOneID(id).Exec(ctx)
}
if entityExists, err = s.client.Media.Query().Where(media.ID(id)).Exist(ctx); err == nil && entityExists {
// Check if user has permission to delete media
hasPermission, err := s.HasPermission(ctx, currentUserID, "media:delete")
if err != nil {
return fmt.Errorf("failed to check permission: %v", err)
}
if !hasPermission {
// Check if the user is the uploader of the media
mediaItem, err := s.client.Media.Query().
Where(media.ID(id)).
Only(ctx)
if err != nil {
return fmt.Errorf("failed to get media: %v", err)
}
isOwner := mediaItem.CreatedBy == strconv.Itoa(currentUserID)
if !isOwner {
return ErrUnauthorized
}
}
// Get media item for path
mediaItem, err := s.client.Media.Get(ctx, id)
if err != nil {
return fmt.Errorf("failed to get media: %v", err)
}
// Delete from storage first
if err := s.storage.Delete(ctx, mediaItem.StorageID); err != nil {
return fmt.Errorf("failed to delete media file: %v", err)
}
// Then delete from database
return s.client.Media.DeleteOneID(id).Exec(ctx)
}
return fmt.Errorf("entity with id %d not found or delete operation not supported for this entity type", id)
}
func (s *serviceImpl) DeleteDaily(ctx context.Context, id string, currentUserID int) error {
// Check if user has permission to delete daily content
hasPermission, err := s.HasPermission(ctx, currentUserID, "daily:delete")
if err != nil {
return fmt.Errorf("failed to check permission: %v", err)
}
if !hasPermission {
return ErrUnauthorized
}
exists, err := s.client.Daily.Query().Where(daily.ID(id)).Exist(ctx)
if err != nil {
return fmt.Errorf("failed to check daily existence: %v", err)
}
if !exists {
return fmt.Errorf("daily with id %s not found", id)
}
return s.client.Daily.DeleteOneID(id).Exec(ctx)
}

View file

@ -39,6 +39,13 @@ type Service interface {
GetPostBySlug(ctx context.Context, langCode, slug string) (*ent.Post, error) GetPostBySlug(ctx context.Context, langCode, slug string) (*ent.Post, error)
ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error) ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error)
// Media operations
ListMedia(ctx context.Context, limit, offset int) ([]*ent.Media, error)
Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error)
GetMedia(ctx context.Context, id int) (*ent.Media, error)
GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error)
DeleteMedia(ctx context.Context, id int, userID int) error
// Contributor operations // Contributor operations
CreateContributor(ctx context.Context, name string, avatarURL, bio *string) (*ent.Contributor, error) CreateContributor(ctx context.Context, name string, avatarURL, bio *string) (*ent.Contributor, error)
AddContributorSocialLink(ctx context.Context, contributorID int, linkType, name, value string) (*ent.ContributorSocialLink, error) AddContributorSocialLink(ctx context.Context, contributorID int, linkType, name, value string) (*ent.ContributorSocialLink, error)
@ -51,16 +58,16 @@ type Service interface {
GetDailyByID(ctx context.Context, id string) (*ent.Daily, error) GetDailyByID(ctx context.Context, id string) (*ent.Daily, error)
ListDailies(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Daily, error) ListDailies(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Daily, error)
// Media operations
ListMedia(ctx context.Context, limit, offset int) ([]*ent.Media, error)
Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error)
GetMedia(ctx context.Context, id int) (*ent.Media, error)
GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error)
DeleteMedia(ctx context.Context, id int, userID int) error
// RBAC operations // RBAC operations
InitializeRBAC(ctx context.Context) error InitializeRBAC(ctx context.Context) error
AssignRole(ctx context.Context, userID int, role string) error AssignRole(ctx context.Context, userID int, role string) error
RemoveRole(ctx context.Context, userID int, role string) error RemoveRole(ctx context.Context, userID int, role string) error
HasPermission(ctx context.Context, userID int, permission string) (bool, error) HasPermission(ctx context.Context, userID int, permission string) (bool, error)
// Token blacklist
GetTokenBlacklist() *TokenBlacklist
// Generic operations
Delete(ctx context.Context, id int, currentUserID int) error
DeleteDaily(ctx context.Context, id string, currentUserID int) error
} }