// tss-rocks/backend/internal/service/impl.go
//
// 1052 lines
// 29 KiB
// Go

package service
import (
"context"
"errors"
"fmt"
"io"
"mime/multipart"
"sort"
"strconv"
"strings"
"regexp"
"tss-rocks-be/ent"
"tss-rocks-be/ent/category"
"tss-rocks-be/ent/categorycontent"
"tss-rocks-be/ent/contributor"
"tss-rocks-be/ent/contributorsociallink"
"tss-rocks-be/ent/daily"
"tss-rocks-be/ent/dailycontent"
"tss-rocks-be/ent/media"
"tss-rocks-be/ent/permission"
"tss-rocks-be/ent/post"
"tss-rocks-be/ent/postcontent"
"tss-rocks-be/ent/role"
"tss-rocks-be/ent/user"
"tss-rocks-be/internal/storage"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
// Error definitions

// ErrUnauthorized is the sentinel returned when the acting user lacks the
// required permission or does not own the resource being modified.
// Compare against it with errors.Is.
var ErrUnauthorized = errors.New("unauthorized")
// openFile opens the contents of an uploaded multipart file. It is a package
// variable, not a direct call to (*multipart.FileHeader).Open, so tests can
// swap in a fake implementation.
var openFile = func(fh *multipart.FileHeader) (multipart.File, error) {
	return fh.Open()
}
// serviceImpl is the concrete implementation of Service returned by NewService.
// It bundles the dependencies shared by the operations in this file.
type serviceImpl struct {
// client is the ent ORM client used for all database reads and writes.
client *ent.Client
// storage persists uploaded media files (used by Upload, GetFile,
// DeleteMedia and Delete).
storage storage.Storage
// tokenBlacklist is created in NewService and exposed via GetTokenBlacklist.
// NOTE(review): presumably tracks revoked auth tokens — confirm in TokenBlacklist.
tokenBlacklist *TokenBlacklist
}
// NewService builds the default Service implementation on top of the given
// ent client and file storage backend. Each call creates its own token
// blacklist via NewTokenBlacklist.
func NewService(client *ent.Client, storage storage.Storage) Service {
	svc := &serviceImpl{
		client:  client,
		storage: storage,
	}
	svc.tokenBlacklist = NewTokenBlacklist()
	return svc
}
// GetTokenBlacklist returns the TokenBlacklist created in NewService. The
// pointer is shared rather than copied, so every caller sees the same instance.
func (s *serviceImpl) GetTokenBlacklist() *TokenBlacklist {
	blacklist := s.tokenBlacklist
	return blacklist
}
// User operations

// emailPattern is a deliberately simple sanity check for email addresses used
// by CreateUser. It is compiled once at package initialisation rather than on
// every call.
var emailPattern = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)

// CreateUser registers a new active user and assigns it the role named roleStr.
//
// It validates the email format and a minimum password length of 8 bytes, and
// rejects usernames that already exist. The password is stored only as a
// bcrypt hash. If the role assignment fails, the freshly created user row is
// deleted again (best effort). That way a failed call does not leave behind a
// role-less account whose username would block a retry.
func (s *serviceImpl) CreateUser(ctx context.Context, username, email, password string, roleStr string) (*ent.User, error) {
	if !emailPattern.MatchString(email) {
		return nil, errors.New("invalid email format")
	}
	if len(password) < 8 {
		return nil, errors.New("password must be at least 8 characters")
	}

	exists, err := s.client.User.Query().Where(user.Username(username)).Exist(ctx)
	if err != nil {
		return nil, fmt.Errorf("error checking username: %w", err)
	}
	if exists {
		return nil, fmt.Errorf("username '%s' already exists", username)
	}

	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return nil, fmt.Errorf("error hashing password: %w", err)
	}

	u, err := s.client.User.Create().
		SetUsername(username).
		SetEmail(email).
		SetPasswordHash(string(hashedPassword)).
		SetStatus("active").
		Save(ctx)
	if err != nil {
		return nil, fmt.Errorf("error creating user: %w", err)
	}

	if err := s.AssignRole(ctx, u.ID, roleStr); err != nil {
		// Undo the insert so the username is free again for a retry.
		if delErr := s.client.User.DeleteOneID(u.ID).Exec(ctx); delErr != nil {
			return nil, fmt.Errorf("error assigning role: %w (cleanup of user %d also failed: %v)", err, u.ID, delErr)
		}
		return nil, fmt.Errorf("error assigning role: %w", err)
	}
	return u, nil
}
// GetUserByUsername loads the single user whose username equals username.
//
// A missing user yields a descriptive "not found" error. Any other query
// failure is wrapped with %w, so callers can still inspect the underlying ent
// error with errors.Is / errors.As.
func (s *serviceImpl) GetUserByUsername(ctx context.Context, username string) (*ent.User, error) {
	u, err := s.client.User.Query().Where(user.Username(username)).Only(ctx)
	if err != nil {
		if ent.IsNotFound(err) {
			return nil, fmt.Errorf("user with username '%s' not found", username)
		}
		return nil, fmt.Errorf("error getting user: %w", err)
	}
	return u, nil
}
// GetUserByEmail loads the single user whose email equals email.
//
// A missing user yields a descriptive "not found" error. Any other query
// failure is wrapped with %w, so callers can still inspect the underlying ent
// error with errors.Is / errors.As.
func (s *serviceImpl) GetUserByEmail(ctx context.Context, email string) (*ent.User, error) {
	u, err := s.client.User.Query().Where(user.Email(email)).Only(ctx)
	if err != nil {
		if ent.IsNotFound(err) {
			return nil, fmt.Errorf("user with email '%s' not found", email)
		}
		return nil, fmt.Errorf("error getting user: %w", err)
	}
	return u, nil
}
// ValidatePassword reports whether password matches the bcrypt hash stored
// on user. Any comparison error, including a mismatch or a malformed hash,
// is reported as false.
func (s *serviceImpl) ValidatePassword(ctx context.Context, user *ent.User, password string) bool {
	return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)) == nil
}
// Category operations

// CreateCategory inserts a new category row without setting any fields
// explicitly. Localized names and slugs are stored separately as
// CategoryContent rows (see AddCategoryContent).
func (s *serviceImpl) CreateCategory(ctx context.Context) (*ent.Category, error) {
	builder := s.client.Category.Create()
	return builder.Save(ctx)
}
func (s *serviceImpl) AddCategoryContent(ctx context.Context, categoryID int, langCode, name, description, slug string) (*ent.CategoryContent, error) {
var languageCode categorycontent.LanguageCode
switch langCode {
case "en":
languageCode = categorycontent.LanguageCodeEN
case "zh-Hans":
languageCode = categorycontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = categorycontent.LanguageCodeZH_HANT
default:
return nil, fmt.Errorf("unsupported language code: %s", langCode)
}
return s.client.CategoryContent.Create().
SetCategoryID(categoryID).
SetLanguageCode(languageCode).
SetName(name).
SetDescription(description).
SetSlug(slug).
Save(ctx)
}
func (s *serviceImpl) GetCategoryBySlug(ctx context.Context, langCode, slug string) (*ent.Category, error) {
var languageCode categorycontent.LanguageCode
switch langCode {
case "en":
languageCode = categorycontent.LanguageCodeEN
case "zh-Hans":
languageCode = categorycontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = categorycontent.LanguageCodeZH_HANT
default:
return nil, fmt.Errorf("unsupported language code: %s", langCode)
}
return s.client.Category.Query().
Where(
category.HasContentsWith(
categorycontent.And(
categorycontent.LanguageCodeEQ(languageCode),
categorycontent.SlugEQ(slug),
),
),
).
WithContents(func(q *ent.CategoryContentQuery) {
q.Where(categorycontent.LanguageCodeEQ(languageCode))
}).
Only(ctx)
}
// GetCategories lists the categories that have "listing" content in the
// requested language. A category qualifies if it has at least one
// CategoryContent row with that language code and a slug starting with
// "category-list-".
//
// Only those matching contents are eager-loaded onto each category, and the
// result is sorted by ascending category ID for stable output. Unsupported
// language codes yield an empty (non-nil) slice rather than an error.
//
// The categories and their contents are fetched with a single eager-loading
// query rather than one extra lookup per category.
func (s *serviceImpl) GetCategories(ctx context.Context, langCode string) ([]*ent.Category, error) {
	var languageCode categorycontent.LanguageCode
	switch langCode {
	case "en":
		languageCode = categorycontent.LanguageCodeEN
	case "zh-Hans":
		languageCode = categorycontent.LanguageCodeZH_HANS
	case "zh-Hant":
		languageCode = categorycontent.LanguageCodeZH_HANT
	default:
		// Unsupported language codes produce an empty list, not an error.
		return []*ent.Category{}, nil
	}

	// The same predicate both selects the categories and filters the contents
	// that are eager-loaded onto them.
	listingContent := categorycontent.And(
		categorycontent.LanguageCodeEQ(languageCode),
		categorycontent.SlugHasPrefix("category-list-"),
	)

	categories, err := s.client.Category.Query().
		Where(category.HasContentsWith(listingContent)).
		WithContents(func(q *ent.CategoryContentQuery) {
			q.Where(listingContent)
		}).
		All(ctx)
	if err != nil {
		return nil, err
	}

	// Sort by ID so the output order is stable.
	sort.Slice(categories, func(i, j int) bool {
		return categories[i].ID < categories[j].ID
	})
	return categories, nil
}
func (s *serviceImpl) ListCategories(ctx context.Context, langCode string) ([]*ent.Category, error) {
// 转换语言代码
var languageCode categorycontent.LanguageCode
switch langCode {
case "en":
languageCode = categorycontent.LanguageCodeEN
case "zh-Hans":
languageCode = categorycontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = categorycontent.LanguageCodeZH_HANT
default:
// 不支持的语言代码返回空列表而不是错误
return []*ent.Category{}, nil
}
// 从 CategoryContent 表开始查询,确保只返回有指定语言内容的分类
contents, err := s.client.CategoryContent.Query().
Where(
categorycontent.And(
categorycontent.LanguageCodeEQ(languageCode),
categorycontent.SlugHasPrefix("category-list-"),
),
).
WithCategory().
All(ctx)
if err != nil {
return nil, err
}
// 使用 map 去重,因为一个分类可能有多个相同语言的内容
categoryMap := make(map[int]*ent.Category)
for _, content := range contents {
if content.Edges.Category != nil {
categoryMap[content.Edges.Category.ID] = content.Edges.Category
}
}
// 将 map 转换为有序的切片
var categories []*ent.Category
for _, cat := range categoryMap {
// 重新查询分类以获取完整的关联数据
c, err := s.client.Category.Query().
Where(category.ID(cat.ID)).
WithContents(func(q *ent.CategoryContentQuery) {
q.Where(
categorycontent.And(
categorycontent.LanguageCodeEQ(languageCode),
categorycontent.SlugHasPrefix("category-list-"),
),
)
}).
Only(ctx)
if err != nil {
return nil, err
}
categories = append(categories, c)
}
// 按 ID 排序以保持结果稳定
sort.Slice(categories, func(i, j int) bool {
return categories[i].ID < categories[j].ID
})
return categories, nil
}
// Daily operations

// CreateDaily inserts a daily entry with the caller-supplied string id,
// category and image URL. The row is then re-read so the returned value has
// its Category edge eager-loaded.
func (s *serviceImpl) CreateDaily(ctx context.Context, id string, categoryID int, imageURL string) (*ent.Daily, error) {
	create := s.client.Daily.Create().
		SetID(id).
		SetCategoryID(categoryID).
		SetImageURL(imageURL)
	if _, err := create.Save(ctx); err != nil {
		return nil, err
	}

	reload := s.client.Daily.Query().
		Where(daily.IDEQ(id)).
		WithCategory()
	return reload.Only(ctx)
}
func (s *serviceImpl) AddDailyContent(ctx context.Context, dailyID string, langCode string, quote string) (*ent.DailyContent, error) {
var languageCode dailycontent.LanguageCode
switch langCode {
case "en":
languageCode = dailycontent.LanguageCodeEN
case "zh-Hans":
languageCode = dailycontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = dailycontent.LanguageCodeZH_HANT
default:
return nil, fmt.Errorf("unsupported language code: %s", langCode)
}
return s.client.DailyContent.Create().
SetDailyID(dailyID).
SetLanguageCode(languageCode).
SetQuote(quote).
Save(ctx)
}
// GetDailyByID loads the daily entry with the given id together with its
// category and all of its localized contents (every language).
func (s *serviceImpl) GetDailyByID(ctx context.Context, id string) (*ent.Daily, error) {
	q := s.client.Daily.Query().Where(daily.IDEQ(id))
	q = q.WithCategory().WithContents()
	return q.Only(ctx)
}
// ListDailies returns daily entries, newest first (by created_at), with their
// category and only the contents in langCode eager-loaded.
//
// categoryID, when non-nil, restricts the result to that category. limit and
// offset are applied only when positive; zero or negative values mean "no
// limit" / "no offset". Unsupported language codes (including "") are
// rejected with an error, so the language filter always applies.
func (s *serviceImpl) ListDailies(ctx context.Context, langCode string, categoryID *int, limit int, offset int) ([]*ent.Daily, error) {
	var languageCode dailycontent.LanguageCode
	switch langCode {
	case "en":
		languageCode = dailycontent.LanguageCodeEN
	case "zh-Hans":
		languageCode = dailycontent.LanguageCodeZH_HANS
	case "zh-Hant":
		languageCode = dailycontent.LanguageCodeZH_HANT
	default:
		return nil, fmt.Errorf("unsupported language code: %s", langCode)
	}

	query := s.client.Daily.Query().
		WithContents(func(q *ent.DailyContentQuery) {
			// langCode has been validated above, so the filter is unconditional.
			q.Where(dailycontent.LanguageCodeEQ(languageCode))
		}).
		WithCategory().
		Order(ent.Desc(daily.FieldCreatedAt))
	if categoryID != nil {
		query = query.Where(daily.HasCategoryWith(category.ID(*categoryID)))
	}
	if limit > 0 {
		query = query.Limit(limit)
	}
	if offset > 0 {
		query = query.Offset(offset)
	}
	return query.All(ctx)
}
// Media operations

// ListMedia returns media records newest first (by created_at).
//
// limit and offset are applied only when positive, consistent with
// ListDailies. A zero limit therefore means "no limit" rather than
// "LIMIT 0" (which would always return nothing). A negative offset is
// ignored instead of being sent to the database.
func (s *serviceImpl) ListMedia(ctx context.Context, limit, offset int) ([]*ent.Media, error) {
	query := s.client.Media.Query().
		Order(ent.Desc("created_at"))
	if limit > 0 {
		query = query.Limit(limit)
	}
	if offset > 0 {
		query = query.Offset(offset)
	}
	return query.All(ctx)
}
// Upload stores an uploaded multipart file and records it as a Media row owned
// by userID. The user ID is stored as a decimal string in CreatedBy.
//
// The file is saved to storage first, using the client-supplied filename and
// Content-Type header. If creating the database record then fails, the stored
// file is deleted again (best effort) so it is not orphaned. A failure of
// that cleanup is reported alongside the original error.
func (s *serviceImpl) Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error) {
	if file == nil {
		return nil, errors.New("no file provided")
	}

	src, err := openFile(file)
	if err != nil {
		return nil, err
	}
	defer src.Close()

	fileInfo, err := s.storage.Save(ctx, file.Filename, file.Header.Get("Content-Type"), src)
	if err != nil {
		return nil, err
	}

	created, err := s.client.Media.Create().
		SetStorageID(fileInfo.ID).
		SetOriginalName(file.Filename).
		SetMimeType(fileInfo.ContentType).
		SetSize(fileInfo.Size).
		SetURL(fileInfo.URL).
		SetCreatedBy(strconv.Itoa(userID)).
		Save(ctx)
	if err != nil {
		// Roll back the storage write so no unreferenced file remains.
		if delErr := s.storage.Delete(ctx, fileInfo.ID); delErr != nil {
			return nil, fmt.Errorf("%w (additionally failed to remove stored file %v: %v)", err, fileInfo.ID, delErr)
		}
		return nil, err
	}
	return created, nil
}
// GetMedia fetches the media record with the given id, returning ent's error
// (for example a not-found error) unchanged.
func (s *serviceImpl) GetMedia(ctx context.Context, id int) (*ent.Media, error) {
	item, err := s.client.Media.Get(ctx, id)
	if err != nil {
		return nil, err
	}
	return item, nil
}
// GetFile looks up media record id and opens its stored file. It returns a
// reader for the contents together with the storage metadata. The returned
// reader is an io.ReadCloser that the caller should close.
func (s *serviceImpl) GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error) {
	record, err := s.GetMedia(ctx, id)
	if err == nil {
		return s.storage.Get(ctx, record.StorageID)
	}
	return nil, nil, err
}
// DeleteMedia removes media record id on behalf of userID.
//
// Only the uploader (CreatedBy equals the decimal userID) may delete it.
// Anyone else gets ErrUnauthorized. The stored file is removed before the
// database row. If the storage delete fails, the row is left untouched.
func (s *serviceImpl) DeleteMedia(ctx context.Context, id int, userID int) error {
	item, err := s.GetMedia(ctx, id)
	if err != nil {
		return err
	}

	if owner := strconv.Itoa(userID); item.CreatedBy != owner {
		return ErrUnauthorized
	}

	if err := s.storage.Delete(ctx, item.StorageID); err != nil {
		return err
	}
	return s.client.Media.DeleteOne(item).Exec(ctx)
}
// Post operations
func (s *serviceImpl) CreatePost(ctx context.Context, status string) (*ent.Post, error) {
var postStatus post.Status
switch status {
case "draft":
postStatus = post.StatusDraft
case "published":
postStatus = post.StatusPublished
case "archived":
postStatus = post.StatusArchived
default:
return nil, fmt.Errorf("invalid status: %s", status)
}
// Generate a random slug
slug := fmt.Sprintf("post-%s", uuid.New().String()[:8])
return s.client.Post.Create().
SetStatus(postStatus).
SetSlug(slug).
Save(ctx)
}
func (s *serviceImpl) AddPostContent(ctx context.Context, postID int, langCode, title, content, summary string, metaKeywords, metaDescription string) (*ent.PostContent, error) {
var languageCode postcontent.LanguageCode
switch langCode {
case "en":
languageCode = postcontent.LanguageCodeEN
case "zh-Hans":
languageCode = postcontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = postcontent.LanguageCodeZH_HANT
default:
return nil, fmt.Errorf("unsupported language code: %s", langCode)
}
// Get the post first to check if it exists
post, err := s.client.Post.Get(ctx, postID)
if err != nil {
return nil, fmt.Errorf("failed to get post: %w", err)
}
// Generate slug from title
var slug string
if langCode == "en" {
// For English titles, convert to lowercase and replace spaces with dashes
slug = strings.ToLower(strings.ReplaceAll(title, " ", "-"))
// Remove all non-alphanumeric characters except dashes
slug = regexp.MustCompile(`[^a-z0-9-]+`).ReplaceAllString(slug, "")
// Ensure slug is not empty and has minimum length
if slug == "" || len(slug) < 4 {
slug = fmt.Sprintf("post-%s", uuid.NewString()[:8])
}
} else {
// For Chinese titles, use the title as is
slug = title
}
return s.client.PostContent.Create().
SetPost(post).
SetLanguageCode(languageCode).
SetTitle(title).
SetContentMarkdown(content).
SetSummary(summary).
SetMetaKeywords(metaKeywords).
SetMetaDescription(metaDescription).
SetSlug(slug).
Save(ctx)
}
func (s *serviceImpl) GetPostBySlug(ctx context.Context, langCode, slug string) (*ent.Post, error) {
var languageCode postcontent.LanguageCode
switch langCode {
case "en":
languageCode = postcontent.LanguageCodeEN
case "zh-Hans":
languageCode = postcontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = postcontent.LanguageCodeZH_HANT
default:
return nil, fmt.Errorf("unsupported language code: %s", langCode)
}
// Find posts that have content with the given slug and language code
posts, err := s.client.Post.Query().
Where(
post.And(
post.StatusEQ(post.StatusPublished),
post.HasContentsWith(
postcontent.And(
postcontent.LanguageCodeEQ(languageCode),
postcontent.SlugEQ(slug),
),
),
),
).
WithContents(func(q *ent.PostContentQuery) {
q.Where(postcontent.LanguageCodeEQ(languageCode))
}).
WithCategory().
All(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get posts: %w", err)
}
if len(posts) == 0 {
return nil, fmt.Errorf("post not found")
}
if len(posts) > 1 {
return nil, fmt.Errorf("multiple posts found with the same slug")
}
return posts[0], nil
}
// ListPosts returns published posts that have content in langCode, newest
// first (by created_at). Each post has only that language's contents and its
// category eager-loaded.
//
// If categoryID is non-nil, only posts in that category are considered.
// Otherwise the candidate set is capped at the 5 most recent posts before
// pagination. offset skips that many candidates (negative values count as 0).
// limit caps the page size (values <= 0 mean "no limit"). Out-of-range values
// yield an empty slice instead of panicking. Unsupported language codes
// return an error.
func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error) {
	var languageCode postcontent.LanguageCode
	switch langCode {
	case "en":
		languageCode = postcontent.LanguageCodeEN
	case "zh-Hans":
		languageCode = postcontent.LanguageCodeZH_HANS
	case "zh-Hant":
		languageCode = postcontent.LanguageCodeZH_HANT
	default:
		return nil, fmt.Errorf("unsupported language code: %s", langCode)
	}

	// Candidate posts: published, with at least one content row in the language.
	query := s.client.PostContent.Query().
		Where(postcontent.LanguageCodeEQ(languageCode)).
		QueryPost().
		Where(post.StatusEQ(post.StatusPublished))
	if categoryID != nil {
		query = query.Where(post.HasCategoryWith(category.ID(*categoryID)))
	}
	postIDs, err := query.
		Order(ent.Desc(post.FieldCreatedAt)).
		IDs(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get post IDs: %w", err)
	}

	// Drop duplicate IDs defensively while keeping the newest-first order.
	seen := make(map[int]struct{}, len(postIDs))
	uniqueIDs := make([]int, 0, len(postIDs))
	for _, id := range postIDs {
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}
		uniqueIDs = append(uniqueIDs, id)
	}
	postIDs = uniqueIDs
	if len(postIDs) == 0 {
		return []*ent.Post{}, nil
	}

	// Without a category filter, only the 5 most recent posts are candidates.
	if categoryID == nil && len(postIDs) > 5 {
		postIDs = postIDs[:5]
	}

	// Clamp pagination so negative or oversized values can never produce an
	// out-of-range slice expression, which would panic. Comparing limit
	// against the remaining length also avoids an offset+limit overflow.
	if offset < 0 {
		offset = 0
	}
	if offset >= len(postIDs) {
		return []*ent.Post{}, nil
	}
	end := len(postIDs)
	if limit > 0 && limit < end-offset {
		end = offset + limit
	}
	pageIDs := postIDs[offset:end]

	posts, err := s.client.Post.Query().
		Where(post.IDIn(pageIDs...)).
		WithContents(func(q *ent.PostContentQuery) {
			q.Where(postcontent.LanguageCodeEQ(languageCode))
		}).
		WithCategory().
		All(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get posts: %w", err)
	}

	// Restore the newest-first order of pageIDs in linear time via an ID index.
	byID := make(map[int]*ent.Post, len(posts))
	for _, p := range posts {
		byID[p.ID] = p
	}
	ordered := make([]*ent.Post, 0, len(posts))
	for _, id := range pageIDs {
		if p, ok := byID[id]; ok {
			ordered = append(ordered, p)
		}
	}
	return ordered, nil
}
// Contributor operations

// CreateContributor creates a contributor named name. The avatar URL and bio
// are set only when the corresponding pointer is non-nil.
func (s *serviceImpl) CreateContributor(ctx context.Context, name string, avatarURL, bio *string) (*ent.Contributor, error) {
	create := s.client.Contributor.Create().SetName(name)
	if avatarURL != nil {
		create = create.SetAvatarURL(*avatarURL)
	}
	if bio != nil {
		create = create.SetBio(*bio)
	}

	created, err := create.Save(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to create contributor: %w", err)
	}
	return created, nil
}
// AddContributorSocialLink attaches a social link (type, display name, value)
// to contributor contributorID. The contributor must already exist.
// NOTE(review): linkType is converted to contributorsociallink.Type without a
// check here; presumably ent's enum validation rejects unknown values on Save
// — confirm.
func (s *serviceImpl) AddContributorSocialLink(ctx context.Context, contributorID int, linkType, name, value string) (*ent.ContributorSocialLink, error) {
	owner, err := s.client.Contributor.Get(ctx, contributorID)
	if err != nil {
		return nil, fmt.Errorf("failed to get contributor: %w", err)
	}

	create := s.client.ContributorSocialLink.Create().
		SetContributor(owner).
		SetType(contributorsociallink.Type(linkType))
	create = create.SetName(name).SetValue(value)

	created, err := create.Save(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to create social link: %w", err)
	}
	return created, nil
}
// GetContributorByID loads contributor id together with its social links.
func (s *serviceImpl) GetContributorByID(ctx context.Context, id int) (*ent.Contributor, error) {
	q := s.client.Contributor.Query().
		Where(contributor.ID(id)).
		WithSocialLinks()
	found, err := q.Only(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get contributor: %w", err)
	}
	return found, nil
}
// ListContributors returns every contributor with its social links,
// sorted by name in ascending order.
func (s *serviceImpl) ListContributors(ctx context.Context) ([]*ent.Contributor, error) {
	q := s.client.Contributor.Query().
		WithSocialLinks().
		Order(ent.Asc(contributor.FieldName))
	list, err := q.All(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to list contributors: %w", err)
	}
	return list, nil
}
// RBAC operations

// InitializeRBAC seeds the built-in "admin", "editor" and "user" roles and
// grants each of them its resource:action permissions.
//
// Rows that already exist are detected through constraint errors on insert
// and then re-queried. Duplicate role→permission edges are ignored the same
// way, so the function can be run repeatedly.
// NOTE(review): this relies on Role.name and Permission(resource, action)
// being unique in the schema — confirm.
func (s *serviceImpl) InitializeRBAC(ctx context.Context) error {
	// ensureRole inserts the named role, or loads the existing one when the
	// insert violates a constraint.
	ensureRole := func(name string) (*ent.Role, error) {
		r, err := s.client.Role.Create().SetName(name).Save(ctx)
		if ent.IsConstraintError(err) {
			r, err = s.client.Role.Query().Where(role.NameEQ(name)).Only(ctx)
		}
		if err != nil {
			return nil, fmt.Errorf("failed to create %s role: %w", name, err)
		}
		return r, nil
	}

	adminRole, err := ensureRole("admin")
	if err != nil {
		return err
	}
	editorRole, err := ensureRole("editor")
	if err != nil {
		return err
	}
	userRole, err := ensureRole("user")
	if err != nil {
		return err
	}

	crud := []string{"create", "read", "update", "delete"}
	readOnly := []string{"read"}

	grants := []struct {
		role     *ent.Role
		resource string
		actions  []string
	}{
		// Admin: full access, including role assignment.
		{adminRole, "users", []string{"create", "read", "update", "delete", "assign_role"}},
		{adminRole, "roles", crud},
		{adminRole, "media", crud},
		{adminRole, "posts", crud},
		{adminRole, "categories", crud},
		{adminRole, "contributors", crud},
		{adminRole, "dailies", crud},
		// Editor: manages content, read-only on categories and contributors.
		{editorRole, "media", crud},
		{editorRole, "posts", crud},
		{editorRole, "categories", readOnly},
		{editorRole, "contributors", readOnly},
		{editorRole, "dailies", crud},
		// User: read-only access.
		{userRole, "media", readOnly},
		{userRole, "posts", readOnly},
		{userRole, "categories", readOnly},
		{userRole, "contributors", readOnly},
		{userRole, "dailies", readOnly},
	}

	// ensurePermission inserts resource:action, or loads the existing row when
	// the insert violates a constraint.
	ensurePermission := func(resource, action string) (*ent.Permission, error) {
		perm, err := s.client.Permission.Create().
			SetResource(resource).
			SetAction(action).
			Save(ctx)
		if ent.IsConstraintError(err) {
			perm, err = s.client.Permission.Query().
				Where(
					permission.ResourceEQ(resource),
					permission.ActionEQ(action),
				).
				Only(ctx)
		}
		return perm, err
	}

	for _, g := range grants {
		for _, action := range g.actions {
			perm, err := ensurePermission(g.resource, action)
			if err != nil {
				return fmt.Errorf("failed to create permission %s:%s: %w", g.resource, action, err)
			}
			err = s.client.Role.UpdateOne(g.role).
				AddPermissions(perm).
				Exec(ctx)
			// A constraint error here means the role already has the permission.
			if err != nil && !ent.IsConstraintError(err) {
				return fmt.Errorf("failed to add permission %s:%s to role %s: %w", g.resource, action, g.role.Name, err)
			}
		}
	}
	return nil
}
// AssignRole adds the role named roleName to user userID. Both the user and
// the role must already exist.
func (s *serviceImpl) AssignRole(ctx context.Context, userID int, roleName string) error {
	target, err := s.client.User.Get(ctx, userID)
	if err != nil {
		return fmt.Errorf("failed to get user: %w", err)
	}

	granted, err := s.client.Role.Query().Where(role.NameEQ(roleName)).Only(ctx)
	if err != nil {
		return fmt.Errorf("failed to get role: %w", err)
	}

	update := s.client.User.UpdateOne(target).AddRoles(granted)
	return update.Exec(ctx)
}
func (s *serviceImpl) RemoveRole(ctx context.Context, userID int, roleName string) error {
// Don't allow removing the user role
if roleName == "user" {
return errors.New("cannot remove user role")
}
user, err := s.client.User.Get(ctx, userID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
role, err := s.client.Role.Query().Where(role.NameEQ(roleName)).Only(ctx)
if err != nil {
return fmt.Errorf("failed to get role: %w", err)
}
return s.client.User.UpdateOne(user).RemoveRoles(role).Exec(ctx)
}
// GetUserRoles returns the roles assigned to user userID, loaded through the
// user's roles edge. It errors if the user cannot be loaded.
func (s *serviceImpl) GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error) {
	withRoles, err := s.client.User.Query().
		Where(user.ID(userID)).
		WithRoles().
		Only(ctx)
	if err == nil {
		return withRoles.Edges.Roles, nil
	}
	return nil, fmt.Errorf("failed to get user: %w", err)
}
// HasPermission reports whether any role assigned to userID grants the given
// permission, written as "resource:action" (for example "posts:delete").
//
// The user's roles and their permissions are eager-loaded first. It errors if
// the user cannot be loaded, or if permission does not split into exactly two
// ':'-separated parts.
func (s *serviceImpl) HasPermission(ctx context.Context, userID int, permission string) (bool, error) {
	subject, err := s.client.User.Query().
		Where(user.ID(userID)).
		WithRoles(func(rq *ent.RoleQuery) {
			rq.WithPermissions()
		}).
		Only(ctx)
	if err != nil {
		return false, fmt.Errorf("failed to get user: %w", err)
	}

	parts := strings.Split(permission, ":")
	if len(parts) != 2 {
		return false, fmt.Errorf("invalid permission format: %s, expected format: resource:action", permission)
	}
	wantResource, wantAction := parts[0], parts[1]

	granted := false
	for _, r := range subject.Edges.Roles {
		for _, p := range r.Edges.Permissions {
			if p.Resource == wantResource && p.Action == wantAction {
				granted = true
				break
			}
		}
		if granted {
			break
		}
	}
	return granted, nil
}
func (s *serviceImpl) Delete(ctx context.Context, id int, currentUserID int) error {
// Check if the entity exists and get its type
var entityExists bool
var err error
// Try to find the entity in different tables
if entityExists, err = s.client.User.Query().Where(user.ID(id)).Exist(ctx); err == nil && entityExists {
// Check if user has permission to delete users
hasPermission, err := s.HasPermission(ctx, currentUserID, "users:delete")
if err != nil {
return fmt.Errorf("failed to check permission: %v", err)
}
if !hasPermission {
return ErrUnauthorized
}
// Cannot delete yourself
if id == currentUserID {
return fmt.Errorf("cannot delete your own account")
}
return s.client.User.DeleteOneID(id).Exec(ctx)
}
if entityExists, err = s.client.Post.Query().Where(post.ID(id)).Exist(ctx); err == nil && entityExists {
// Check if user has permission to delete posts
hasPermission, err := s.HasPermission(ctx, currentUserID, "posts:delete")
if err != nil {
return fmt.Errorf("failed to check permission: %v", err)
}
if !hasPermission {
// Check if the user is the author of the post
isAuthor, err := s.client.Post.Query().
Where(post.ID(id)).
QueryContributors().
QueryContributor().
QueryUser().
Where(user.ID(currentUserID)).
Exist(ctx)
if err != nil {
return fmt.Errorf("failed to check post author: %v", err)
}
if !isAuthor {
return ErrUnauthorized
}
}
return s.client.Post.DeleteOneID(id).Exec(ctx)
}
if entityExists, err = s.client.Category.Query().Where(category.ID(id)).Exist(ctx); err == nil && entityExists {
// Check if user has permission to delete categories
hasPermission, err := s.HasPermission(ctx, currentUserID, "categories:delete")
if err != nil {
return fmt.Errorf("failed to check permission: %v", err)
}
if !hasPermission {
return ErrUnauthorized
}
return s.client.Category.DeleteOneID(id).Exec(ctx)
}
if entityExists, err = s.client.Contributor.Query().Where(contributor.ID(id)).Exist(ctx); err == nil && entityExists {
// Check if user has permission to delete contributors
hasPermission, err := s.HasPermission(ctx, currentUserID, "contributors:delete")
if err != nil {
return fmt.Errorf("failed to check permission: %v", err)
}
if !hasPermission {
return ErrUnauthorized
}
return s.client.Contributor.DeleteOneID(id).Exec(ctx)
}
if entityExists, err = s.client.Media.Query().Where(media.ID(id)).Exist(ctx); err == nil && entityExists {
// Check if user has permission to delete media
hasPermission, err := s.HasPermission(ctx, currentUserID, "media:delete")
if err != nil {
return fmt.Errorf("failed to check permission: %v", err)
}
if !hasPermission {
// Check if the user is the uploader of the media
mediaItem, err := s.client.Media.Query().
Where(media.ID(id)).
Only(ctx)
if err != nil {
return fmt.Errorf("failed to get media: %v", err)
}
isOwner := mediaItem.CreatedBy == strconv.Itoa(currentUserID)
if !isOwner {
return ErrUnauthorized
}
}
// Get media item for path
mediaItem, err := s.client.Media.Get(ctx, id)
if err != nil {
return fmt.Errorf("failed to get media: %v", err)
}
// Delete from storage first
if err := s.storage.Delete(ctx, mediaItem.StorageID); err != nil {
return fmt.Errorf("failed to delete media file: %v", err)
}
// Then delete from database
return s.client.Media.DeleteOneID(id).Exec(ctx)
}
return fmt.Errorf("entity with id %d not found or delete operation not supported for this entity type", id)
}
// DeleteDaily removes the daily entry with the given string id on behalf of
// currentUserID.
//
// The caller must hold "dailies:delete". The resource name is plural, matching
// the permissions InitializeRBAC seeds for the admin and editor roles; no role
// is ever granted a singular "daily" resource. Otherwise ErrUnauthorized is
// returned. A missing daily yields a "not found" error.
func (s *serviceImpl) DeleteDaily(ctx context.Context, id string, currentUserID int) error {
	hasPermission, err := s.HasPermission(ctx, currentUserID, "dailies:delete")
	if err != nil {
		return fmt.Errorf("failed to check permission: %w", err)
	}
	if !hasPermission {
		return ErrUnauthorized
	}

	exists, err := s.client.Daily.Query().Where(daily.ID(id)).Exist(ctx)
	if err != nil {
		return fmt.Errorf("failed to check daily existence: %w", err)
	}
	if !exists {
		return fmt.Errorf("daily with id %s not found", id)
	}
	return s.client.Daily.DeleteOneID(id).Exec(ctx)
}