package service import ( "bytes" "context" "errors" "fmt" "io" "mime/multipart" "os" "path/filepath" "sort" "strconv" "strings" "regexp" "tss-rocks-be/ent" "tss-rocks-be/ent/category" "tss-rocks-be/ent/categorycontent" "tss-rocks-be/ent/contributor" "tss-rocks-be/ent/contributorsociallink" "tss-rocks-be/ent/daily" "tss-rocks-be/ent/dailycontent" "tss-rocks-be/ent/media" "tss-rocks-be/ent/permission" "tss-rocks-be/ent/post" "tss-rocks-be/ent/postcontent" "tss-rocks-be/ent/role" "tss-rocks-be/ent/user" "tss-rocks-be/internal/storage" "github.com/chai2010/webp" "github.com/disintegration/imaging" "github.com/google/uuid" "golang.org/x/crypto/bcrypt" ) // Error definitions var ( ErrUnauthorized = errors.New("unauthorized") ) // openFile is a variable that holds the Open method of multipart.FileHeader // This allows us to mock it in tests var openFile func(fh *multipart.FileHeader) (multipart.File, error) = func(fh *multipart.FileHeader) (multipart.File, error) { return fh.Open() } type serviceImpl struct { client *ent.Client storage storage.Storage tokenBlacklist *TokenBlacklist } // NewService creates a new Service instance func NewService(client *ent.Client, storage storage.Storage) Service { return &serviceImpl{ client: client, storage: storage, tokenBlacklist: NewTokenBlacklist(), } } // GetTokenBlacklist returns the token blacklist func (s *serviceImpl) GetTokenBlacklist() *TokenBlacklist { return s.tokenBlacklist } // User operations func (s *serviceImpl) CreateUser(ctx context.Context, username, email, password string, roleStr string) (*ent.User, error) { // 验证邮箱格式 emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`) if !emailRegex.MatchString(email) { return nil, fmt.Errorf("invalid email format") } // 验证密码长度 if len(password) < 8 { return nil, fmt.Errorf("password must be at least 8 characters") } // 检查用户名是否已存在 exists, err := s.client.User.Query().Where(user.Username(username)).Exist(ctx) if err != nil { return nil, fmt.Errorf("error checking username: %v", err) } if exists { return nil, fmt.Errorf("username '%s' already exists", username) } // 生成密码哈希 hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return nil, fmt.Errorf("error hashing password: %v", err) } // 创建用户 u, err := s.client.User.Create(). SetUsername(username). SetEmail(email). SetPasswordHash(string(hashedPassword)). SetStatus("active"). Save(ctx) if err != nil { return nil, fmt.Errorf("error creating user: %v", err) } // 分配角色 err = s.AssignRole(ctx, u.ID, roleStr) if err != nil { return nil, fmt.Errorf("error assigning role: %v", err) } return u, nil } func (s *serviceImpl) GetUserByUsername(ctx context.Context, username string) (*ent.User, error) { u, err := s.client.User.Query().Where(user.Username(username)).Only(ctx) if err != nil { if ent.IsNotFound(err) { return nil, fmt.Errorf("user with username '%s' not found", username) } return nil, fmt.Errorf("error getting user: %v", err) } return u, nil } func (s *serviceImpl) GetUserByEmail(ctx context.Context, email string) (*ent.User, error) { u, err := s.client.User.Query().Where(user.Email(email)).Only(ctx) if err != nil { if ent.IsNotFound(err) { return nil, fmt.Errorf("user with email '%s' not found", email) } return nil, fmt.Errorf("error getting user: %v", err) } return u, nil } func (s *serviceImpl) ValidatePassword(ctx context.Context, user *ent.User, password string) bool { err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)) return err == nil } // Category operations func (s *serviceImpl) CreateCategory(ctx context.Context) (*ent.Category, error) { return s.client.Category.Create().Save(ctx) } func (s *serviceImpl) AddCategoryContent(ctx context.Context, categoryID int, langCode, name, description, slug string) (*ent.CategoryContent, error) { var languageCode categorycontent.LanguageCode switch langCode { case "en": languageCode = categorycontent.LanguageCodeEN case "zh-Hans": languageCode = categorycontent.LanguageCodeZH_HANS case "zh-Hant": languageCode = categorycontent.LanguageCodeZH_HANT default: return nil, fmt.Errorf("unsupported language code: %s", langCode) } return s.client.CategoryContent.Create(). SetCategoryID(categoryID). SetLanguageCode(languageCode). SetName(name). SetDescription(description). SetSlug(slug). Save(ctx) } func (s *serviceImpl) GetCategoryBySlug(ctx context.Context, langCode, slug string) (*ent.Category, error) { var languageCode categorycontent.LanguageCode switch langCode { case "en": languageCode = categorycontent.LanguageCodeEN case "zh-Hans": languageCode = categorycontent.LanguageCodeZH_HANS case "zh-Hant": languageCode = categorycontent.LanguageCodeZH_HANT default: return nil, fmt.Errorf("unsupported language code: %s", langCode) } return s.client.Category.Query(). Where( category.HasContentsWith( categorycontent.And( categorycontent.LanguageCodeEQ(languageCode), categorycontent.SlugEQ(slug), ), ), ). WithContents(func(q *ent.CategoryContentQuery) { q.Where(categorycontent.LanguageCodeEQ(languageCode)) }). Only(ctx) } func (s *serviceImpl) GetCategories(ctx context.Context, langCode string) ([]*ent.Category, error) { // 转换语言代码 var languageCode categorycontent.LanguageCode switch langCode { case "en": languageCode = categorycontent.LanguageCodeEN case "zh-Hans": languageCode = categorycontent.LanguageCodeZH_HANS case "zh-Hant": languageCode = categorycontent.LanguageCodeZH_HANT default: // 不支持的语言代码返回空列表而不是错误 return []*ent.Category{}, nil } // 从 CategoryContent 表开始查询,确保只返回有指定语言内容的分类 contents, err := s.client.CategoryContent.Query(). Where( categorycontent.And( categorycontent.LanguageCodeEQ(languageCode), categorycontent.SlugHasPrefix("category-list-"), ), ). WithCategory(). All(ctx) if err != nil { return nil, err } // 使用 map 去重,因为一个分类可能有多个相同语言的内容 categoryMap := make(map[int]*ent.Category) for _, content := range contents { if content.Edges.Category != nil { categoryMap[content.Edges.Category.ID] = content.Edges.Category } } // 将 map 转换为有序的切片 var categories []*ent.Category for _, cat := range categoryMap { // 重新查询分类以获取完整的关联数据 c, err := s.client.Category.Query(). Where(category.ID(cat.ID)). WithContents(func(q *ent.CategoryContentQuery) { q.Where( categorycontent.And( categorycontent.LanguageCodeEQ(languageCode), categorycontent.SlugHasPrefix("category-list-"), ), ) }). Only(ctx) if err != nil { return nil, err } categories = append(categories, c) } // 按 ID 排序以保持结果稳定 sort.Slice(categories, func(i, j int) bool { return categories[i].ID < categories[j].ID }) return categories, nil } func (s *serviceImpl) ListCategories(ctx context.Context, langCode string) ([]*ent.Category, error) { // 转换语言代码 var languageCode categorycontent.LanguageCode switch langCode { case "en": languageCode = categorycontent.LanguageCodeEN case "zh-Hans": languageCode = categorycontent.LanguageCodeZH_HANS case "zh-Hant": languageCode = categorycontent.LanguageCodeZH_HANT default: // 不支持的语言代码返回空列表而不是错误 return []*ent.Category{}, nil } // 从 CategoryContent 表开始查询,确保只返回有指定语言内容的分类 contents, err := s.client.CategoryContent.Query(). Where( categorycontent.And( categorycontent.LanguageCodeEQ(languageCode), categorycontent.SlugHasPrefix("category-list-"), ), ). WithCategory(). All(ctx) if err != nil { return nil, err } // 使用 map 去重,因为一个分类可能有多个相同语言的内容 categoryMap := make(map[int]*ent.Category) for _, content := range contents { if content.Edges.Category != nil { categoryMap[content.Edges.Category.ID] = content.Edges.Category } } // 将 map 转换为有序的切片 var categories []*ent.Category for _, cat := range categoryMap { // 重新查询分类以获取完整的关联数据 c, err := s.client.Category.Query(). Where(category.ID(cat.ID)). WithContents(func(q *ent.CategoryContentQuery) { q.Where( categorycontent.And( categorycontent.LanguageCodeEQ(languageCode), categorycontent.SlugHasPrefix("category-list-"), ), ) }). Only(ctx) if err != nil { return nil, err } categories = append(categories, c) } // 按 ID 排序以保持结果稳定 sort.Slice(categories, func(i, j int) bool { return categories[i].ID < categories[j].ID }) return categories, nil } // Daily operations func (s *serviceImpl) CreateDaily(ctx context.Context, id string, categoryID int, imageURL string) (*ent.Daily, error) { _, err := s.client.Daily.Create(). SetID(id). SetCategoryID(categoryID). SetImageURL(imageURL). Save(ctx) if err != nil { return nil, err } // 加载 Category Edge return s.client.Daily.Query(). Where(daily.IDEQ(id)). WithCategory(). Only(ctx) } func (s *serviceImpl) AddDailyContent(ctx context.Context, dailyID string, langCode string, quote string) (*ent.DailyContent, error) { var languageCode dailycontent.LanguageCode switch langCode { case "en": languageCode = dailycontent.LanguageCodeEN case "zh-Hans": languageCode = dailycontent.LanguageCodeZH_HANS case "zh-Hant": languageCode = dailycontent.LanguageCodeZH_HANT default: return nil, fmt.Errorf("unsupported language code: %s", langCode) } return s.client.DailyContent.Create(). SetDailyID(dailyID). SetLanguageCode(languageCode). SetQuote(quote). Save(ctx) } func (s *serviceImpl) GetDailyByID(ctx context.Context, id string) (*ent.Daily, error) { return s.client.Daily.Query(). Where(daily.IDEQ(id)). WithCategory(). WithContents(). Only(ctx) } func (s *serviceImpl) ListDailies(ctx context.Context, langCode string, categoryID *int, limit int, offset int) ([]*ent.Daily, error) { var languageCode dailycontent.LanguageCode switch langCode { case "en": languageCode = dailycontent.LanguageCodeEN case "zh-Hans": languageCode = dailycontent.LanguageCodeZH_HANS case "zh-Hant": languageCode = dailycontent.LanguageCodeZH_HANT default: return nil, fmt.Errorf("unsupported language code: %s", langCode) } query := s.client.Daily.Query(). WithContents(func(q *ent.DailyContentQuery) { if langCode != "" { q.Where(dailycontent.LanguageCodeEQ(languageCode)) } }). WithCategory() if categoryID != nil { query.Where(daily.HasCategoryWith(category.ID(*categoryID))) } query.Order(ent.Desc(daily.FieldCreatedAt)) if limit > 0 { query.Limit(limit) } if offset > 0 { query.Offset(offset) } return query.All(ctx) } // Media operations func (s *serviceImpl) ListMedia(ctx context.Context, limit, offset int) ([]*ent.Media, error) { return s.client.Media.Query(). Order(ent.Desc("created_at")). Limit(limit). Offset(offset). All(ctx) } func (s *serviceImpl) Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error) { // Open the uploaded file src, err := openFile(file) if err != nil { return nil, fmt.Errorf("failed to open file: %v", err) } defer src.Close() // 获取文件类型和扩展名 contentType := file.Header.Get("Content-Type") ext := strings.ToLower(filepath.Ext(file.Filename)) if contentType == "" { // 如果 Content-Type 为空,尝试从文件扩展名判断 switch ext { case ".jpg", ".jpeg": contentType = "image/jpeg" case ".png": contentType = "image/png" case ".gif": contentType = "image/gif" case ".webp": contentType = "image/webp" case ".mp4": contentType = "video/mp4" case ".webm": contentType = "video/webm" case ".mp3": contentType = "audio/mpeg" case ".ogg": contentType = "audio/ogg" case ".wav": contentType = "audio/wav" case ".pdf": contentType = "application/pdf" case ".doc": contentType = "application/msword" case ".docx": contentType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" } } // 如果是图片,检查是否需要转换为 WebP var fileToSave multipart.File = src var finalContentType = contentType if strings.HasPrefix(contentType, "image/") && contentType != "image/webp" { // 转换为 WebP webpFile, err := convertToWebP(src) if err != nil { return nil, fmt.Errorf("failed to convert image to WebP: %v", err) } fileToSave = webpFile finalContentType = "image/webp" ext = ".webp" } // 生成带扩展名的存储文件名 storageFilename := uuid.New().String() + ext // Save the file to storage fileInfo, err := s.storage.Save(ctx, storageFilename, finalContentType, fileToSave) if err != nil { return nil, fmt.Errorf("failed to save file: %v", err) } // Create media record return s.client.Media.Create(). SetStorageID(fileInfo.ID). SetOriginalName(file.Filename). SetMimeType(finalContentType). SetSize(fileInfo.Size). SetURL(fileInfo.URL). SetCreatedBy(strconv.Itoa(userID)). Save(ctx) } func (s *serviceImpl) GetMedia(ctx context.Context, id int) (*ent.Media, error) { return s.client.Media.Get(ctx, id) } func (s *serviceImpl) GetFile(ctx context.Context, storageID string) (io.ReadCloser, *storage.FileInfo, error) { return s.storage.Get(ctx, storageID) } func (s *serviceImpl) DeleteMedia(ctx context.Context, id int, userID int) error { media, err := s.GetMedia(ctx, id) if err != nil { return err } // Check ownership isOwner := media.CreatedBy == strconv.Itoa(userID) if !isOwner { return ErrUnauthorized } // Delete from storage err = s.storage.Delete(ctx, media.StorageID) if err != nil { return err } // Delete from database return s.client.Media.DeleteOne(media).Exec(ctx) } // Post operations func (s *serviceImpl) CreatePost(ctx context.Context, status string, categoryIDs []int) (*ent.Post, error) { var postStatus post.Status switch status { case "draft": postStatus = post.StatusDraft case "published": postStatus = post.StatusPublished case "archived": postStatus = post.StatusArchived default: return nil, fmt.Errorf("invalid status: %s", status) } // Generate a random slug slug := fmt.Sprintf("post-%s", uuid.New().String()[:8]) // Create post with categories postCreate := s.client.Post.Create(). SetStatus(postStatus). SetSlug(slug) // Add categories if provided if len(categoryIDs) > 0 { categories := make([]*ent.Category, 0, len(categoryIDs)) for _, id := range categoryIDs { category, err := s.client.Category.Get(ctx, id) if err != nil { return nil, fmt.Errorf("failed to get category %d: %w", id, err) } categories = append(categories, category) } postCreate.AddCategories(categories...) } return postCreate.Save(ctx) } func (s *serviceImpl) AddPostContent(ctx context.Context, postID int, langCode, title, content, summary string, metaKeywords, metaDescription string) (*ent.PostContent, error) { var languageCode postcontent.LanguageCode switch langCode { case "en": languageCode = postcontent.LanguageCodeEN case "zh-Hans": languageCode = postcontent.LanguageCodeZH_HANS case "zh-Hant": languageCode = postcontent.LanguageCodeZH_HANT default: return nil, fmt.Errorf("unsupported language code: %s", langCode) } // Get the post first to check if it exists post, err := s.client.Post.Get(ctx, postID) if err != nil { return nil, fmt.Errorf("failed to get post: %w", err) } // Generate slug from title var slug string if langCode == "en" { // For English titles, convert to lowercase and replace spaces with dashes slug = strings.ToLower(strings.ReplaceAll(title, " ", "-")) // Remove all non-alphanumeric characters except dashes slug = regexp.MustCompile(`[^a-z0-9-]+`).ReplaceAllString(slug, "") // Ensure slug is not empty and has minimum length if slug == "" || len(slug) < 4 { slug = fmt.Sprintf("post-%s", uuid.NewString()[:8]) } } else { // For Chinese titles, use the title as is slug = title } return s.client.PostContent.Create(). SetPost(post). SetLanguageCode(languageCode). SetTitle(title). SetContentMarkdown(content). SetSummary(summary). SetMetaKeywords(metaKeywords). SetMetaDescription(metaDescription). SetSlug(slug). Save(ctx) } func (s *serviceImpl) GetPostBySlug(ctx context.Context, langCode, slug string) (*ent.Post, error) { var languageCode postcontent.LanguageCode switch langCode { case "en": languageCode = postcontent.LanguageCodeEN case "zh-Hans": languageCode = postcontent.LanguageCodeZH_HANS case "zh-Hant": languageCode = postcontent.LanguageCodeZH_HANT default: return nil, fmt.Errorf("unsupported language code: %s", langCode) } // Find posts that have content with the given slug and language code posts, err := s.client.Post.Query(). Where( post.And( post.StatusEQ(post.StatusPublished), post.HasContentsWith( postcontent.And( postcontent.LanguageCodeEQ(languageCode), postcontent.SlugEQ(slug), ), ), ), ). WithContents(func(q *ent.PostContentQuery) { q.Where(postcontent.LanguageCodeEQ(languageCode)) }). WithCategories(). All(ctx) if err != nil { return nil, fmt.Errorf("failed to get posts: %w", err) } if len(posts) == 0 { return nil, fmt.Errorf("post not found") } if len(posts) > 1 { return nil, fmt.Errorf("multiple posts found with the same slug") } return posts[0], nil } func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryIDs []int, limit, offset int) ([]*ent.Post, error) { var languageCode postcontent.LanguageCode switch langCode { case "en": languageCode = postcontent.LanguageCodeEN case "zh-Hans": languageCode = postcontent.LanguageCodeZH_HANS case "zh-Hant": languageCode = postcontent.LanguageCodeZH_HANT default: return nil, fmt.Errorf("unsupported language code: %s", langCode) } // First find all post IDs that have content in the given language query := s.client.PostContent.Query(). Where(postcontent.LanguageCodeEQ(languageCode)). QueryPost(). Where(post.StatusEQ(post.StatusPublished)) // Add category filter if provided if len(categoryIDs) > 0 { query = query.Where(post.HasCategoriesWith(category.IDIn(categoryIDs...))) } // Get unique post IDs postIDs, err := query. Order(ent.Desc(post.FieldCreatedAt)). IDs(ctx) if err != nil { return nil, fmt.Errorf("failed to get post IDs: %w", err) } // Remove duplicates while preserving order seen := make(map[int]bool) uniqueIDs := make([]int, 0, len(postIDs)) for _, id := range postIDs { if !seen[id] { seen[id] = true uniqueIDs = append(uniqueIDs, id) } } postIDs = uniqueIDs if len(postIDs) == 0 { return []*ent.Post{}, nil } // If no category filter is applied, only take the latest 5 posts if len(categoryIDs) == 0 && len(postIDs) > 5 { postIDs = postIDs[:5] } // Apply pagination if offset >= len(postIDs) { return []*ent.Post{}, nil } // If limit is 0, set it to the length of postIDs if limit == 0 { limit = len(postIDs) } // Adjust limit if it would exceed total if offset+limit > len(postIDs) { limit = len(postIDs) - offset } // Get the paginated post IDs paginatedIDs := postIDs[offset : offset+limit] // Get the posts with their contents posts, err := s.client.Post.Query(). Where(post.IDIn(paginatedIDs...)). WithContents(func(q *ent.PostContentQuery) { q.Where(postcontent.LanguageCodeEQ(languageCode)) }). WithCategories(). Order(ent.Desc(post.FieldCreatedAt)). All(ctx) if err != nil { return nil, fmt.Errorf("failed to get posts: %w", err) } // Sort posts by ID to match the order of postIDs sort.Slice(posts, func(i, j int) bool { // Find index of each post ID in postIDs var iIndex, jIndex int for idx, id := range paginatedIDs { if posts[i].ID == id { iIndex = idx } if posts[j].ID == id { jIndex = idx } } return iIndex < jIndex }) return posts, nil } // Contributor operations func (s *serviceImpl) CreateContributor(ctx context.Context, name string, avatarURL, bio *string) (*ent.Contributor, error) { builder := s.client.Contributor.Create(). SetName(name) if avatarURL != nil { builder.SetAvatarURL(*avatarURL) } if bio != nil { builder.SetBio(*bio) } contributor, err := builder.Save(ctx) if err != nil { return nil, fmt.Errorf("failed to create contributor: %w", err) } return contributor, nil } func (s *serviceImpl) AddContributorSocialLink(ctx context.Context, contributorID int, linkType, name, value string) (*ent.ContributorSocialLink, error) { // 验证贡献者是否存在 contributor, err := s.client.Contributor.Get(ctx, contributorID) if err != nil { return nil, fmt.Errorf("failed to get contributor: %w", err) } // 创建社交链接 link, err := s.client.ContributorSocialLink.Create(). SetContributor(contributor). SetType(contributorsociallink.Type(linkType)). SetName(name). SetValue(value). Save(ctx) if err != nil { return nil, fmt.Errorf("failed to create social link: %w", err) } return link, nil } func (s *serviceImpl) GetContributorByID(ctx context.Context, id int) (*ent.Contributor, error) { contributor, err := s.client.Contributor.Query(). Where(contributor.ID(id)). WithSocialLinks(). Only(ctx) if err != nil { return nil, fmt.Errorf("failed to get contributor: %w", err) } return contributor, nil } func (s *serviceImpl) ListContributors(ctx context.Context) ([]*ent.Contributor, error) { contributors, err := s.client.Contributor.Query(). WithSocialLinks(). Order(ent.Asc(contributor.FieldName)). All(ctx) if err != nil { return nil, fmt.Errorf("failed to list contributors: %w", err) } return contributors, nil } // RBAC operations func (s *serviceImpl) InitializeRBAC(ctx context.Context) error { // Create roles if they don't exist adminRole, err := s.client.Role.Create().SetName("admin").Save(ctx) if ent.IsConstraintError(err) { adminRole, err = s.client.Role.Query().Where(role.NameEQ("admin")).Only(ctx) } if err != nil { return fmt.Errorf("failed to create admin role: %w", err) } editorRole, err := s.client.Role.Create().SetName("editor").Save(ctx) if ent.IsConstraintError(err) { editorRole, err = s.client.Role.Query().Where(role.NameEQ("editor")).Only(ctx) } if err != nil { return fmt.Errorf("failed to create editor role: %w", err) } userRole, err := s.client.Role.Create().SetName("user").Save(ctx) if ent.IsConstraintError(err) { userRole, err = s.client.Role.Query().Where(role.NameEQ("user")).Only(ctx) } if err != nil { return fmt.Errorf("failed to create user role: %w", err) } // Define permissions permissions := []struct { role *ent.Role resource string actions []string }{ // Admin permissions (full access) {adminRole, "users", []string{"create", "read", "update", "delete", "assign_role"}}, {adminRole, "roles", []string{"create", "read", "update", "delete"}}, {adminRole, "media", []string{"create", "read", "update", "delete"}}, {adminRole, "posts", []string{"create", "read", "update", "delete"}}, {adminRole, "categories", []string{"create", "read", "update", "delete"}}, {adminRole, "contributors", []string{"create", "read", "update", "delete"}}, {adminRole, "dailies", []string{"create", "read", "update", "delete"}}, // Editor permissions (can create and manage content) {editorRole, "media", []string{"create", "read", "update", "delete"}}, {editorRole, "posts", []string{"create", "read", "update", "delete"}}, {editorRole, "categories", []string{"read"}}, {editorRole, "contributors", []string{"read"}}, {editorRole, "dailies", []string{"create", "read", "update", "delete"}}, // User permissions (read-only access) {userRole, "media", []string{"read"}}, {userRole, "posts", []string{"read"}}, {userRole, "categories", []string{"read"}}, {userRole, "contributors", []string{"read"}}, {userRole, "dailies", []string{"read"}}, } // Create permissions for each role for _, p := range permissions { for _, action := range p.actions { perm, err := s.client.Permission.Create(). SetResource(p.resource). SetAction(action). Save(ctx) if ent.IsConstraintError(err) { perm, err = s.client.Permission.Query(). Where( permission.ResourceEQ(p.resource), permission.ActionEQ(action), ). Only(ctx) } if err != nil { return fmt.Errorf("failed to create permission %s:%s: %w", p.resource, action, err) } // Add permission to role err = s.client.Role.UpdateOne(p.role). AddPermissions(perm). Exec(ctx) if err != nil && !ent.IsConstraintError(err) { return fmt.Errorf("failed to add permission %s:%s to role %s: %w", p.resource, action, p.role.Name, err) } } } return nil } func (s *serviceImpl) AssignRole(ctx context.Context, userID int, roleName string) error { user, err := s.client.User.Get(ctx, userID) if err != nil { return fmt.Errorf("failed to get user: %w", err) } role, err := s.client.Role.Query().Where(role.NameEQ(roleName)).Only(ctx) if err != nil { return fmt.Errorf("failed to get role: %w", err) } return s.client.User.UpdateOne(user).AddRoles(role).Exec(ctx) } func (s *serviceImpl) RemoveRole(ctx context.Context, userID int, roleName string) error { // Don't allow removing the user role if roleName == "user" { return errors.New("cannot remove user role") } user, err := s.client.User.Get(ctx, userID) if err != nil { return fmt.Errorf("failed to get user: %w", err) } role, err := s.client.Role.Query().Where(role.NameEQ(roleName)).Only(ctx) if err != nil { return fmt.Errorf("failed to get role: %w", err) } return s.client.User.UpdateOne(user).RemoveRoles(role).Exec(ctx) } func (s *serviceImpl) GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error) { user, err := s.client.User.Query(). Where(user.ID(userID)). WithRoles(). Only(ctx) if err != nil { return nil, fmt.Errorf("failed to get user: %w", err) } return user.Edges.Roles, nil } func (s *serviceImpl) HasPermission(ctx context.Context, userID int, permission string) (bool, error) { user, err := s.client.User.Query(). Where(user.ID(userID)). WithRoles(func(q *ent.RoleQuery) { q.WithPermissions() }). Only(ctx) if err != nil { return false, fmt.Errorf("failed to get user: %w", err) } parts := strings.Split(permission, ":") if len(parts) != 2 { return false, fmt.Errorf("invalid permission format: %s, expected format: resource:action", permission) } resource, action := parts[0], parts[1] for _, r := range user.Edges.Roles { for _, p := range r.Edges.Permissions { if p.Resource == resource && p.Action == action { return true, nil } } } return false, nil } func (s *serviceImpl) Delete(ctx context.Context, id int, currentUserID int) error { // Check if the entity exists and get its type var entityExists bool var err error // Try to find the entity in different tables if entityExists, err = s.client.User.Query().Where(user.ID(id)).Exist(ctx); err == nil && entityExists { // Check if user has permission to delete users hasPermission, err := s.HasPermission(ctx, currentUserID, "users:delete") if err != nil { return fmt.Errorf("failed to check permission: %v", err) } if !hasPermission { return ErrUnauthorized } // Cannot delete yourself if id == currentUserID { return fmt.Errorf("cannot delete your own account") } return s.client.User.DeleteOneID(id).Exec(ctx) } if entityExists, err = s.client.Post.Query().Where(post.ID(id)).Exist(ctx); err == nil && entityExists { // Check if user has permission to delete posts hasPermission, err := s.HasPermission(ctx, currentUserID, "posts:delete") if err != nil { return fmt.Errorf("failed to check permission: %v", err) } if !hasPermission { // Check if the user is the author of the post isAuthor, err := s.client.Post.Query(). Where(post.ID(id)). QueryContributors(). QueryContributor(). QueryUser(). Where(user.ID(currentUserID)). Exist(ctx) if err != nil { return fmt.Errorf("failed to check post author: %v", err) } if !isAuthor { return ErrUnauthorized } } return s.client.Post.DeleteOneID(id).Exec(ctx) } if entityExists, err = s.client.Category.Query().Where(category.ID(id)).Exist(ctx); err == nil && entityExists { // Check if user has permission to delete categories hasPermission, err := s.HasPermission(ctx, currentUserID, "categories:delete") if err != nil { return fmt.Errorf("failed to check permission: %v", err) } if !hasPermission { return ErrUnauthorized } return s.client.Category.DeleteOneID(id).Exec(ctx) } if entityExists, err = s.client.Contributor.Query().Where(contributor.ID(id)).Exist(ctx); err == nil && entityExists { // Check if user has permission to delete contributors hasPermission, err := s.HasPermission(ctx, currentUserID, "contributors:delete") if err != nil { return fmt.Errorf("failed to check permission: %v", err) } if !hasPermission { return ErrUnauthorized } return s.client.Contributor.DeleteOneID(id).Exec(ctx) } if entityExists, err = s.client.Media.Query().Where(media.ID(id)).Exist(ctx); err == nil && entityExists { // Check if user has permission to delete media hasPermission, err := s.HasPermission(ctx, currentUserID, "media:delete") if err != nil { return fmt.Errorf("failed to check permission: %v", err) } if !hasPermission { // Check if the user is the uploader of the media mediaItem, err := s.client.Media.Query(). Where(media.ID(id)). Only(ctx) if err != nil { return fmt.Errorf("failed to get media: %v", err) } isOwner := mediaItem.CreatedBy == strconv.Itoa(currentUserID) if !isOwner { return ErrUnauthorized } } // Get media item for path mediaItem, err := s.client.Media.Get(ctx, id) if err != nil { return fmt.Errorf("failed to get media: %v", err) } // Delete from storage first if err := s.storage.Delete(ctx, mediaItem.StorageID); err != nil { return fmt.Errorf("failed to delete media file: %v", err) } // Then delete from database return s.client.Media.DeleteOneID(id).Exec(ctx) } return fmt.Errorf("entity with id %d not found or delete operation not supported for this entity type", id) } func (s *serviceImpl) DeleteDaily(ctx context.Context, id string, currentUserID int) error { // Check if user has permission to delete daily content hasPermission, err := s.HasPermission(ctx, currentUserID, "daily:delete") if err != nil { return fmt.Errorf("failed to check permission: %v", err) } if !hasPermission { return ErrUnauthorized } exists, err := s.client.Daily.Query().Where(daily.ID(id)).Exist(ctx) if err != nil { return fmt.Errorf("failed to check daily existence: %v", err) } if !exists { return fmt.Errorf("daily with id %s not found", id) } return s.client.Daily.DeleteOneID(id).Exec(ctx) } // convertToWebP 将图片转换为 WebP 格式 func convertToWebP(src multipart.File) (multipart.File, error) { // 读取原始图片 img, err := imaging.Decode(src) if err != nil { return nil, fmt.Errorf("failed to decode image: %v", err) } // 创建一个新的缓冲区来存储 WebP 图片 buf := new(bytes.Buffer) // 将图片编码为 WebP 格式 // 设置较高的质量以保持图片质量 err = webp.Encode(buf, img, &webp.Options{ Lossless: false, Quality: 90, }) if err != nil { return nil, fmt.Errorf("failed to encode image to WebP: %v", err) } // 创建一个新的临时文件来存储转换后的图片 tmpFile, err := os.CreateTemp("", "webp-*.webp") if err != nil { return nil, fmt.Errorf("failed to create temp file: %v", err) } // 写入转换后的数据 if _, err := io.Copy(tmpFile, buf); err != nil { tmpFile.Close() os.Remove(tmpFile.Name()) return nil, fmt.Errorf("failed to write WebP data: %v", err) } // 将文件指针移回开始位置 if _, err := tmpFile.Seek(0, 0); err != nil { tmpFile.Close() os.Remove(tmpFile.Name()) return nil, fmt.Errorf("failed to seek file: %v", err) } return tmpFile, nil }