[feature] migrate to monorepo
This commit is contained in:
commit
05ddc1f783
267 changed files with 75165 additions and 0 deletions
892
backend/internal/service/impl.go
Normal file
892
backend/internal/service/impl.go
Normal file
|
@ -0,0 +1,892 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"regexp"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/ent/category"
|
||||
"tss-rocks-be/ent/categorycontent"
|
||||
"tss-rocks-be/ent/contributor"
|
||||
"tss-rocks-be/ent/contributorsociallink"
|
||||
"tss-rocks-be/ent/daily"
|
||||
"tss-rocks-be/ent/dailycontent"
|
||||
"tss-rocks-be/ent/permission"
|
||||
"tss-rocks-be/ent/post"
|
||||
"tss-rocks-be/ent/postcontent"
|
||||
"tss-rocks-be/ent/role"
|
||||
"tss-rocks-be/ent/user"
|
||||
"tss-rocks-be/internal/storage"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// Error definitions
|
||||
var (
|
||||
ErrUnauthorized = errors.New("unauthorized")
|
||||
)
|
||||
|
||||
// openFile is a variable that holds the Open method of multipart.FileHeader
|
||||
// This allows us to mock it in tests
|
||||
var openFile func(fh *multipart.FileHeader) (multipart.File, error) = func(fh *multipart.FileHeader) (multipart.File, error) {
|
||||
return fh.Open()
|
||||
}
|
||||
|
||||
type serviceImpl struct {
|
||||
client *ent.Client
|
||||
storage storage.Storage
|
||||
}
|
||||
|
||||
// NewService creates a new Service instance
|
||||
func NewService(client *ent.Client, storage storage.Storage) Service {
|
||||
return &serviceImpl{
|
||||
client: client,
|
||||
storage: storage,
|
||||
}
|
||||
}
|
||||
|
||||
// User operations
|
||||
func (s *serviceImpl) CreateUser(ctx context.Context, email, password string, roleStr string) (*ent.User, error) {
|
||||
// Hash the password
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
// Add the user role by default
|
||||
userRole, err := s.client.Role.Query().Where(role.NameEQ("user")).Only(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user role: %w", err)
|
||||
}
|
||||
|
||||
// If a specific role is requested and it's not "user", get that role too
|
||||
var additionalRole *ent.Role
|
||||
if roleStr != "" && roleStr != "user" {
|
||||
additionalRole, err = s.client.Role.Query().Where(role.NameEQ(roleStr)).Only(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get role: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create user with password and user role
|
||||
userCreate := s.client.User.Create().
|
||||
SetEmail(email).
|
||||
SetPasswordHash(string(hashedPassword)).
|
||||
AddRoles(userRole)
|
||||
|
||||
// Add the additional role if specified
|
||||
if additionalRole != nil {
|
||||
userCreate.AddRoles(additionalRole)
|
||||
}
|
||||
|
||||
// Save the user
|
||||
user, err := userCreate.Save(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) GetUserByEmail(ctx context.Context, email string) (*ent.User, error) {
|
||||
user, err := s.client.User.Query().
|
||||
Where(user.EmailEQ(email)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, fmt.Errorf("user not found: %s", email)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) ValidatePassword(ctx context.Context, user *ent.User, password string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// Category operations
|
||||
func (s *serviceImpl) CreateCategory(ctx context.Context) (*ent.Category, error) {
|
||||
return s.client.Category.Create().Save(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) AddCategoryContent(ctx context.Context, categoryID int, langCode, name, description, slug string) (*ent.CategoryContent, error) {
|
||||
var languageCode categorycontent.LanguageCode
|
||||
switch langCode {
|
||||
case "en":
|
||||
languageCode = categorycontent.LanguageCodeEN
|
||||
case "zh-Hans":
|
||||
languageCode = categorycontent.LanguageCodeZH_HANS
|
||||
case "zh-Hant":
|
||||
languageCode = categorycontent.LanguageCodeZH_HANT
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported language code: %s", langCode)
|
||||
}
|
||||
|
||||
return s.client.CategoryContent.Create().
|
||||
SetCategoryID(categoryID).
|
||||
SetLanguageCode(languageCode).
|
||||
SetName(name).
|
||||
SetDescription(description).
|
||||
SetSlug(slug).
|
||||
Save(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) GetCategoryBySlug(ctx context.Context, langCode, slug string) (*ent.Category, error) {
|
||||
var languageCode categorycontent.LanguageCode
|
||||
switch langCode {
|
||||
case "en":
|
||||
languageCode = categorycontent.LanguageCodeEN
|
||||
case "zh-Hans":
|
||||
languageCode = categorycontent.LanguageCodeZH_HANS
|
||||
case "zh-Hant":
|
||||
languageCode = categorycontent.LanguageCodeZH_HANT
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported language code: %s", langCode)
|
||||
}
|
||||
|
||||
return s.client.Category.Query().
|
||||
Where(
|
||||
category.HasContentsWith(
|
||||
categorycontent.And(
|
||||
categorycontent.LanguageCodeEQ(languageCode),
|
||||
categorycontent.SlugEQ(slug),
|
||||
),
|
||||
),
|
||||
).
|
||||
WithContents(func(q *ent.CategoryContentQuery) {
|
||||
q.Where(categorycontent.LanguageCodeEQ(languageCode))
|
||||
}).
|
||||
Only(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) GetCategories(ctx context.Context, langCode string) ([]*ent.Category, error) {
|
||||
// 转换语言代码
|
||||
var languageCode categorycontent.LanguageCode
|
||||
switch langCode {
|
||||
case "en":
|
||||
languageCode = categorycontent.LanguageCodeEN
|
||||
case "zh-Hans":
|
||||
languageCode = categorycontent.LanguageCodeZH_HANS
|
||||
case "zh-Hant":
|
||||
languageCode = categorycontent.LanguageCodeZH_HANT
|
||||
default:
|
||||
// 不支持的语言代码返回空列表而不是错误
|
||||
return []*ent.Category{}, nil
|
||||
}
|
||||
|
||||
// 从 CategoryContent 表开始查询,确保只返回有指定语言内容的分类
|
||||
contents, err := s.client.CategoryContent.Query().
|
||||
Where(
|
||||
categorycontent.And(
|
||||
categorycontent.LanguageCodeEQ(languageCode),
|
||||
categorycontent.SlugHasPrefix("category-list-"),
|
||||
),
|
||||
).
|
||||
WithCategory().
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 使用 map 去重,因为一个分类可能有多个相同语言的内容
|
||||
categoryMap := make(map[int]*ent.Category)
|
||||
for _, content := range contents {
|
||||
if content.Edges.Category != nil {
|
||||
categoryMap[content.Edges.Category.ID] = content.Edges.Category
|
||||
}
|
||||
}
|
||||
|
||||
// 将 map 转换为有序的切片
|
||||
var categories []*ent.Category
|
||||
for _, cat := range categoryMap {
|
||||
// 重新查询分类以获取完整的关联数据
|
||||
c, err := s.client.Category.Query().
|
||||
Where(category.ID(cat.ID)).
|
||||
WithContents(func(q *ent.CategoryContentQuery) {
|
||||
q.Where(
|
||||
categorycontent.And(
|
||||
categorycontent.LanguageCodeEQ(languageCode),
|
||||
categorycontent.SlugHasPrefix("category-list-"),
|
||||
),
|
||||
)
|
||||
}).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
categories = append(categories, c)
|
||||
}
|
||||
|
||||
// 按 ID 排序以保持结果稳定
|
||||
sort.Slice(categories, func(i, j int) bool {
|
||||
return categories[i].ID < categories[j].ID
|
||||
})
|
||||
|
||||
return categories, nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) ListCategories(ctx context.Context, langCode string) ([]*ent.Category, error) {
|
||||
// 转换语言代码
|
||||
var languageCode categorycontent.LanguageCode
|
||||
switch langCode {
|
||||
case "en":
|
||||
languageCode = categorycontent.LanguageCodeEN
|
||||
case "zh-Hans":
|
||||
languageCode = categorycontent.LanguageCodeZH_HANS
|
||||
case "zh-Hant":
|
||||
languageCode = categorycontent.LanguageCodeZH_HANT
|
||||
default:
|
||||
// 不支持的语言代码返回空列表而不是错误
|
||||
return []*ent.Category{}, nil
|
||||
}
|
||||
|
||||
// 从 CategoryContent 表开始查询,确保只返回有指定语言内容的分类
|
||||
contents, err := s.client.CategoryContent.Query().
|
||||
Where(
|
||||
categorycontent.And(
|
||||
categorycontent.LanguageCodeEQ(languageCode),
|
||||
categorycontent.SlugHasPrefix("category-list-"),
|
||||
),
|
||||
).
|
||||
WithCategory().
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 使用 map 去重,因为一个分类可能有多个相同语言的内容
|
||||
categoryMap := make(map[int]*ent.Category)
|
||||
for _, content := range contents {
|
||||
if content.Edges.Category != nil {
|
||||
categoryMap[content.Edges.Category.ID] = content.Edges.Category
|
||||
}
|
||||
}
|
||||
|
||||
// 将 map 转换为有序的切片
|
||||
var categories []*ent.Category
|
||||
for _, cat := range categoryMap {
|
||||
// 重新查询分类以获取完整的关联数据
|
||||
c, err := s.client.Category.Query().
|
||||
Where(category.ID(cat.ID)).
|
||||
WithContents(func(q *ent.CategoryContentQuery) {
|
||||
q.Where(
|
||||
categorycontent.And(
|
||||
categorycontent.LanguageCodeEQ(languageCode),
|
||||
categorycontent.SlugHasPrefix("category-list-"),
|
||||
),
|
||||
)
|
||||
}).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
categories = append(categories, c)
|
||||
}
|
||||
|
||||
// 按 ID 排序以保持结果稳定
|
||||
sort.Slice(categories, func(i, j int) bool {
|
||||
return categories[i].ID < categories[j].ID
|
||||
})
|
||||
|
||||
return categories, nil
|
||||
}
|
||||
|
||||
// Daily operations
|
||||
func (s *serviceImpl) CreateDaily(ctx context.Context, id string, categoryID int, imageURL string) (*ent.Daily, error) {
|
||||
_, err := s.client.Daily.Create().
|
||||
SetID(id).
|
||||
SetCategoryID(categoryID).
|
||||
SetImageURL(imageURL).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 加载 Category Edge
|
||||
return s.client.Daily.Query().
|
||||
Where(daily.IDEQ(id)).
|
||||
WithCategory().
|
||||
Only(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) AddDailyContent(ctx context.Context, dailyID string, langCode string, quote string) (*ent.DailyContent, error) {
|
||||
var languageCode dailycontent.LanguageCode
|
||||
switch langCode {
|
||||
case "en":
|
||||
languageCode = dailycontent.LanguageCodeEN
|
||||
case "zh-Hans":
|
||||
languageCode = dailycontent.LanguageCodeZH_HANS
|
||||
case "zh-Hant":
|
||||
languageCode = dailycontent.LanguageCodeZH_HANT
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported language code: %s", langCode)
|
||||
}
|
||||
|
||||
return s.client.DailyContent.Create().
|
||||
SetDailyID(dailyID).
|
||||
SetLanguageCode(languageCode).
|
||||
SetQuote(quote).
|
||||
Save(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) GetDailyByID(ctx context.Context, id string) (*ent.Daily, error) {
|
||||
return s.client.Daily.Query().
|
||||
Where(daily.IDEQ(id)).
|
||||
WithCategory().
|
||||
WithContents().
|
||||
Only(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) ListDailies(ctx context.Context, langCode string, categoryID *int, limit int, offset int) ([]*ent.Daily, error) {
|
||||
var languageCode dailycontent.LanguageCode
|
||||
switch langCode {
|
||||
case "en":
|
||||
languageCode = dailycontent.LanguageCodeEN
|
||||
case "zh-Hans":
|
||||
languageCode = dailycontent.LanguageCodeZH_HANS
|
||||
case "zh-Hant":
|
||||
languageCode = dailycontent.LanguageCodeZH_HANT
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported language code: %s", langCode)
|
||||
}
|
||||
|
||||
query := s.client.Daily.Query().
|
||||
WithContents(func(q *ent.DailyContentQuery) {
|
||||
if langCode != "" {
|
||||
q.Where(dailycontent.LanguageCodeEQ(languageCode))
|
||||
}
|
||||
}).
|
||||
WithCategory()
|
||||
|
||||
if categoryID != nil {
|
||||
query.Where(daily.HasCategoryWith(category.ID(*categoryID)))
|
||||
}
|
||||
|
||||
query.Order(ent.Desc(daily.FieldCreatedAt))
|
||||
|
||||
if limit > 0 {
|
||||
query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query.Offset(offset)
|
||||
}
|
||||
|
||||
return query.All(ctx)
|
||||
}
|
||||
|
||||
// Media operations
|
||||
func (s *serviceImpl) ListMedia(ctx context.Context, limit, offset int) ([]*ent.Media, error) {
|
||||
return s.client.Media.Query().
|
||||
Order(ent.Desc("created_at")).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
All(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error) {
|
||||
// Open the uploaded file
|
||||
src, err := openFile(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
// Save the file to storage
|
||||
fileInfo, err := s.storage.Save(ctx, file.Filename, file.Header.Get("Content-Type"), src)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create media record
|
||||
return s.client.Media.Create().
|
||||
SetStorageID(fileInfo.ID).
|
||||
SetOriginalName(file.Filename).
|
||||
SetMimeType(fileInfo.ContentType).
|
||||
SetSize(fileInfo.Size).
|
||||
SetURL(fileInfo.URL).
|
||||
SetCreatedBy(strconv.Itoa(userID)).
|
||||
Save(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) GetMedia(ctx context.Context, id int) (*ent.Media, error) {
|
||||
return s.client.Media.Get(ctx, id)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error) {
|
||||
media, err := s.GetMedia(ctx, id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return s.storage.Get(ctx, media.StorageID)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) DeleteMedia(ctx context.Context, id int, userID int) error {
|
||||
media, err := s.GetMedia(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check ownership
|
||||
if media.CreatedBy != strconv.Itoa(userID) {
|
||||
return ErrUnauthorized
|
||||
}
|
||||
|
||||
// Delete from storage
|
||||
if err := s.storage.Delete(ctx, media.StorageID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete from database
|
||||
return s.client.Media.DeleteOne(media).Exec(ctx)
|
||||
}
|
||||
|
||||
// Post operations
|
||||
func (s *serviceImpl) CreatePost(ctx context.Context, status string) (*ent.Post, error) {
|
||||
var postStatus post.Status
|
||||
switch status {
|
||||
case "draft":
|
||||
postStatus = post.StatusDraft
|
||||
case "published":
|
||||
postStatus = post.StatusPublished
|
||||
case "archived":
|
||||
postStatus = post.StatusArchived
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid status: %s", status)
|
||||
}
|
||||
|
||||
// Generate a random slug
|
||||
slug := fmt.Sprintf("post-%s", uuid.New().String()[:8])
|
||||
|
||||
return s.client.Post.Create().
|
||||
SetStatus(postStatus).
|
||||
SetSlug(slug).
|
||||
Save(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) AddPostContent(ctx context.Context, postID int, langCode, title, content, summary string, metaKeywords, metaDescription string) (*ent.PostContent, error) {
|
||||
var languageCode postcontent.LanguageCode
|
||||
switch langCode {
|
||||
case "en":
|
||||
languageCode = postcontent.LanguageCodeEN
|
||||
case "zh-Hans":
|
||||
languageCode = postcontent.LanguageCodeZH_HANS
|
||||
case "zh-Hant":
|
||||
languageCode = postcontent.LanguageCodeZH_HANT
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported language code: %s", langCode)
|
||||
}
|
||||
|
||||
// Get the post first to check if it exists
|
||||
post, err := s.client.Post.Get(ctx, postID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get post: %w", err)
|
||||
}
|
||||
|
||||
// Generate slug from title
|
||||
var slug string
|
||||
if langCode == "en" {
|
||||
// For English titles, convert to lowercase and replace spaces with dashes
|
||||
slug = strings.ToLower(strings.ReplaceAll(title, " ", "-"))
|
||||
// Remove all non-alphanumeric characters except dashes
|
||||
slug = regexp.MustCompile(`[^a-z0-9-]+`).ReplaceAllString(slug, "")
|
||||
// Ensure slug is not empty and has minimum length
|
||||
if slug == "" || len(slug) < 4 {
|
||||
slug = fmt.Sprintf("post-%s", uuid.NewString()[:8])
|
||||
}
|
||||
} else {
|
||||
// For Chinese titles, use the title as is
|
||||
slug = title
|
||||
}
|
||||
|
||||
return s.client.PostContent.Create().
|
||||
SetPost(post).
|
||||
SetLanguageCode(languageCode).
|
||||
SetTitle(title).
|
||||
SetContentMarkdown(content).
|
||||
SetSummary(summary).
|
||||
SetMetaKeywords(metaKeywords).
|
||||
SetMetaDescription(metaDescription).
|
||||
SetSlug(slug).
|
||||
Save(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) GetPostBySlug(ctx context.Context, langCode, slug string) (*ent.Post, error) {
|
||||
var languageCode postcontent.LanguageCode
|
||||
switch langCode {
|
||||
case "en":
|
||||
languageCode = postcontent.LanguageCodeEN
|
||||
case "zh-Hans":
|
||||
languageCode = postcontent.LanguageCodeZH_HANS
|
||||
case "zh-Hant":
|
||||
languageCode = postcontent.LanguageCodeZH_HANT
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported language code: %s", langCode)
|
||||
}
|
||||
|
||||
// Find posts that have content with the given slug and language code
|
||||
posts, err := s.client.Post.Query().
|
||||
Where(
|
||||
post.And(
|
||||
post.StatusEQ(post.StatusPublished),
|
||||
post.HasContentsWith(
|
||||
postcontent.And(
|
||||
postcontent.LanguageCodeEQ(languageCode),
|
||||
postcontent.SlugEQ(slug),
|
||||
),
|
||||
),
|
||||
),
|
||||
).
|
||||
WithContents(func(q *ent.PostContentQuery) {
|
||||
q.Where(postcontent.LanguageCodeEQ(languageCode))
|
||||
}).
|
||||
WithCategory().
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get posts: %w", err)
|
||||
}
|
||||
|
||||
if len(posts) == 0 {
|
||||
return nil, fmt.Errorf("post not found")
|
||||
}
|
||||
if len(posts) > 1 {
|
||||
return nil, fmt.Errorf("multiple posts found with the same slug")
|
||||
}
|
||||
|
||||
return posts[0], nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error) {
|
||||
var languageCode postcontent.LanguageCode
|
||||
switch langCode {
|
||||
case "en":
|
||||
languageCode = postcontent.LanguageCodeEN
|
||||
case "zh-Hans":
|
||||
languageCode = postcontent.LanguageCodeZH_HANS
|
||||
case "zh-Hant":
|
||||
languageCode = postcontent.LanguageCodeZH_HANT
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported language code: %s", langCode)
|
||||
}
|
||||
|
||||
// First find all post IDs that have content in the given language
|
||||
query := s.client.PostContent.Query().
|
||||
Where(postcontent.LanguageCodeEQ(languageCode)).
|
||||
QueryPost().
|
||||
Where(post.StatusEQ(post.StatusPublished))
|
||||
|
||||
// Add category filter if provided
|
||||
if categoryID != nil {
|
||||
query = query.Where(post.HasCategoryWith(category.ID(*categoryID)))
|
||||
}
|
||||
|
||||
// Get unique post IDs
|
||||
postIDs, err := query.
|
||||
Order(ent.Desc(post.FieldCreatedAt)).
|
||||
IDs(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get post IDs: %w", err)
|
||||
}
|
||||
|
||||
// Remove duplicates while preserving order
|
||||
seen := make(map[int]bool)
|
||||
uniqueIDs := make([]int, 0, len(postIDs))
|
||||
for _, id := range postIDs {
|
||||
if !seen[id] {
|
||||
seen[id] = true
|
||||
uniqueIDs = append(uniqueIDs, id)
|
||||
}
|
||||
}
|
||||
postIDs = uniqueIDs
|
||||
|
||||
if len(postIDs) == 0 {
|
||||
return []*ent.Post{}, nil
|
||||
}
|
||||
|
||||
// If no category filter is applied, only take the latest 5 posts
|
||||
if categoryID == nil && len(postIDs) > 5 {
|
||||
postIDs = postIDs[:5]
|
||||
}
|
||||
|
||||
// Apply pagination
|
||||
if offset >= len(postIDs) {
|
||||
return []*ent.Post{}, nil
|
||||
}
|
||||
|
||||
// If limit is 0, set it to the length of postIDs
|
||||
if limit == 0 {
|
||||
limit = len(postIDs)
|
||||
}
|
||||
|
||||
// Adjust limit if it would exceed total
|
||||
if offset+limit > len(postIDs) {
|
||||
limit = len(postIDs) - offset
|
||||
}
|
||||
|
||||
// Get the paginated post IDs
|
||||
paginatedIDs := postIDs[offset : offset+limit]
|
||||
|
||||
// Get the posts with their contents
|
||||
posts, err := s.client.Post.Query().
|
||||
Where(post.IDIn(paginatedIDs...)).
|
||||
WithContents(func(q *ent.PostContentQuery) {
|
||||
q.Where(postcontent.LanguageCodeEQ(languageCode))
|
||||
}).
|
||||
WithCategory().
|
||||
Order(ent.Desc(post.FieldCreatedAt)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get posts: %w", err)
|
||||
}
|
||||
|
||||
// Sort posts by ID to match the order of postIDs
|
||||
sort.Slice(posts, func(i, j int) bool {
|
||||
// Find index of each post ID in postIDs
|
||||
var iIndex, jIndex int
|
||||
for idx, id := range paginatedIDs {
|
||||
if posts[i].ID == id {
|
||||
iIndex = idx
|
||||
}
|
||||
if posts[j].ID == id {
|
||||
jIndex = idx
|
||||
}
|
||||
}
|
||||
return iIndex < jIndex
|
||||
})
|
||||
|
||||
return posts, nil
|
||||
}
|
||||
|
||||
// Contributor operations
|
||||
func (s *serviceImpl) CreateContributor(ctx context.Context, name string, avatarURL, bio *string) (*ent.Contributor, error) {
|
||||
builder := s.client.Contributor.Create().
|
||||
SetName(name)
|
||||
|
||||
if avatarURL != nil {
|
||||
builder.SetAvatarURL(*avatarURL)
|
||||
}
|
||||
if bio != nil {
|
||||
builder.SetBio(*bio)
|
||||
}
|
||||
|
||||
contributor, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create contributor: %w", err)
|
||||
}
|
||||
|
||||
return contributor, nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) AddContributorSocialLink(ctx context.Context, contributorID int, linkType, name, value string) (*ent.ContributorSocialLink, error) {
|
||||
// 验证贡献者是否存在
|
||||
contributor, err := s.client.Contributor.Get(ctx, contributorID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get contributor: %w", err)
|
||||
}
|
||||
|
||||
// 创建社交链接
|
||||
link, err := s.client.ContributorSocialLink.Create().
|
||||
SetContributor(contributor).
|
||||
SetType(contributorsociallink.Type(linkType)).
|
||||
SetName(name).
|
||||
SetValue(value).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create social link: %w", err)
|
||||
}
|
||||
|
||||
return link, nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) GetContributorByID(ctx context.Context, id int) (*ent.Contributor, error) {
|
||||
contributor, err := s.client.Contributor.Query().
|
||||
Where(contributor.ID(id)).
|
||||
WithSocialLinks().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get contributor: %w", err)
|
||||
}
|
||||
return contributor, nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) ListContributors(ctx context.Context) ([]*ent.Contributor, error) {
|
||||
contributors, err := s.client.Contributor.Query().
|
||||
WithSocialLinks().
|
||||
Order(ent.Asc(contributor.FieldName)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list contributors: %w", err)
|
||||
}
|
||||
return contributors, nil
|
||||
}
|
||||
|
||||
// RBAC operations
|
||||
func (s *serviceImpl) InitializeRBAC(ctx context.Context) error {
|
||||
// Create roles if they don't exist
|
||||
adminRole, err := s.client.Role.Create().SetName("admin").Save(ctx)
|
||||
if ent.IsConstraintError(err) {
|
||||
adminRole, err = s.client.Role.Query().Where(role.NameEQ("admin")).Only(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create admin role: %w", err)
|
||||
}
|
||||
|
||||
editorRole, err := s.client.Role.Create().SetName("editor").Save(ctx)
|
||||
if ent.IsConstraintError(err) {
|
||||
editorRole, err = s.client.Role.Query().Where(role.NameEQ("editor")).Only(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create editor role: %w", err)
|
||||
}
|
||||
|
||||
userRole, err := s.client.Role.Create().SetName("user").Save(ctx)
|
||||
if ent.IsConstraintError(err) {
|
||||
userRole, err = s.client.Role.Query().Where(role.NameEQ("user")).Only(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create user role: %w", err)
|
||||
}
|
||||
|
||||
// Define permissions
|
||||
permissions := []struct {
|
||||
role *ent.Role
|
||||
resource string
|
||||
actions []string
|
||||
}{
|
||||
// Admin permissions (full access)
|
||||
{adminRole, "users", []string{"create", "read", "update", "delete", "assign_role"}},
|
||||
{adminRole, "roles", []string{"create", "read", "update", "delete"}},
|
||||
{adminRole, "media", []string{"create", "read", "update", "delete"}},
|
||||
{adminRole, "posts", []string{"create", "read", "update", "delete"}},
|
||||
{adminRole, "categories", []string{"create", "read", "update", "delete"}},
|
||||
{adminRole, "contributors", []string{"create", "read", "update", "delete"}},
|
||||
{adminRole, "dailies", []string{"create", "read", "update", "delete"}},
|
||||
|
||||
// Editor permissions (can create and manage content)
|
||||
{editorRole, "media", []string{"create", "read", "update", "delete"}},
|
||||
{editorRole, "posts", []string{"create", "read", "update", "delete"}},
|
||||
{editorRole, "categories", []string{"read"}},
|
||||
{editorRole, "contributors", []string{"read"}},
|
||||
{editorRole, "dailies", []string{"create", "read", "update", "delete"}},
|
||||
|
||||
// User permissions (read-only access)
|
||||
{userRole, "media", []string{"read"}},
|
||||
{userRole, "posts", []string{"read"}},
|
||||
{userRole, "categories", []string{"read"}},
|
||||
{userRole, "contributors", []string{"read"}},
|
||||
{userRole, "dailies", []string{"read"}},
|
||||
}
|
||||
|
||||
// Create permissions for each role
|
||||
for _, p := range permissions {
|
||||
for _, action := range p.actions {
|
||||
perm, err := s.client.Permission.Create().
|
||||
SetResource(p.resource).
|
||||
SetAction(action).
|
||||
Save(ctx)
|
||||
if ent.IsConstraintError(err) {
|
||||
perm, err = s.client.Permission.Query().
|
||||
Where(
|
||||
permission.ResourceEQ(p.resource),
|
||||
permission.ActionEQ(action),
|
||||
).
|
||||
Only(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create permission %s:%s: %w", p.resource, action, err)
|
||||
}
|
||||
|
||||
// Add permission to role
|
||||
err = s.client.Role.UpdateOne(p.role).
|
||||
AddPermissions(perm).
|
||||
Exec(ctx)
|
||||
if err != nil && !ent.IsConstraintError(err) {
|
||||
return fmt.Errorf("failed to add permission %s:%s to role %s: %w", p.resource, action, p.role.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) AssignRole(ctx context.Context, userID int, roleName string) error {
|
||||
user, err := s.client.User.Get(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
role, err := s.client.Role.Query().Where(role.NameEQ(roleName)).Only(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get role: %w", err)
|
||||
}
|
||||
|
||||
return s.client.User.UpdateOne(user).AddRoles(role).Exec(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) RemoveRole(ctx context.Context, userID int, roleName string) error {
|
||||
// Don't allow removing the user role
|
||||
if roleName == "user" {
|
||||
return errors.New("cannot remove user role")
|
||||
}
|
||||
|
||||
user, err := s.client.User.Get(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
role, err := s.client.Role.Query().Where(role.NameEQ(roleName)).Only(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get role: %w", err)
|
||||
}
|
||||
|
||||
return s.client.User.UpdateOne(user).RemoveRoles(role).Exec(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error) {
|
||||
user, err := s.client.User.Query().
|
||||
Where(user.ID(userID)).
|
||||
WithRoles().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
return user.Edges.Roles, nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) HasPermission(ctx context.Context, userID int, permission string) (bool, error) {
|
||||
user, err := s.client.User.Query().
|
||||
Where(user.ID(userID)).
|
||||
WithRoles(func(q *ent.RoleQuery) {
|
||||
q.WithPermissions()
|
||||
}).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
parts := strings.Split(permission, ":")
|
||||
if len(parts) != 2 {
|
||||
return false, fmt.Errorf("invalid permission format: %s, expected format: resource:action", permission)
|
||||
}
|
||||
resource, action := parts[0], parts[1]
|
||||
|
||||
for _, r := range user.Edges.Roles {
|
||||
for _, p := range r.Edges.Permissions {
|
||||
if p.Resource == resource && p.Action == action {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
1092
backend/internal/service/impl_test.go
Normal file
1092
backend/internal/service/impl_test.go
Normal file
File diff suppressed because it is too large
Load diff
179
backend/internal/service/media.go
Normal file
179
backend/internal/service/media.go
Normal file
|
@ -0,0 +1,179 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"strings"
|
||||
|
||||
"path/filepath"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/storage"
|
||||
"tss-rocks-be/pkg/imageutil"
|
||||
)
|
||||
|
||||
type MediaService interface {
|
||||
// Upload uploads a new file and creates a media record
|
||||
Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error)
|
||||
|
||||
// Get retrieves a media file by ID
|
||||
Get(ctx context.Context, id int) (*ent.Media, error)
|
||||
|
||||
// Delete deletes a media file
|
||||
Delete(ctx context.Context, id int, userID int) error
|
||||
|
||||
// List lists media files with pagination
|
||||
List(ctx context.Context, limit, offset int) ([]*ent.Media, error)
|
||||
|
||||
// GetFile gets the file content and info
|
||||
GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error)
|
||||
}
|
||||
|
||||
type mediaService struct {
|
||||
client *ent.Client
|
||||
storage storage.Storage
|
||||
}
|
||||
|
||||
func NewMediaService(client *ent.Client, storage storage.Storage) MediaService {
|
||||
return &mediaService{
|
||||
client: client,
|
||||
storage: storage,
|
||||
}
|
||||
}
|
||||
|
||||
// isValidFilename checks if a filename is valid
|
||||
func isValidFilename(filename string) bool {
|
||||
// Check for illegal characters
|
||||
if strings.Contains(filename, "../") ||
|
||||
strings.Contains(filename, "./") ||
|
||||
strings.Contains(filename, "\\") {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *mediaService) Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error) {
|
||||
// Validate filename
|
||||
if !isValidFilename(file.Filename) {
|
||||
return nil, fmt.Errorf("invalid filename: %s", file.Filename)
|
||||
}
|
||||
|
||||
// Open the file
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
// Read file content for processing
|
||||
fileBytes, err := io.ReadAll(src)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
|
||||
contentType := file.Header.Get("Content-Type")
|
||||
filename := file.Filename
|
||||
var processedBytes []byte
|
||||
|
||||
// Process image if it's an image file
|
||||
if imageutil.IsImageFormat(contentType) {
|
||||
opts := imageutil.DefaultOptions()
|
||||
processedBytes, err = imageutil.ProcessImage(bytes.NewReader(fileBytes), opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process image: %w", err)
|
||||
}
|
||||
|
||||
// Update content type and filename for WebP
|
||||
contentType = "image/webp"
|
||||
filename = strings.TrimSuffix(filename, filepath.Ext(filename)) + ".webp"
|
||||
} else {
|
||||
processedBytes = fileBytes
|
||||
}
|
||||
|
||||
// Save the processed file
|
||||
fileInfo, err := s.storage.Save(ctx, filename, contentType, bytes.NewReader(processedBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save file: %w", err)
|
||||
}
|
||||
|
||||
// Create media record in database
|
||||
media, err := s.client.Media.Create().
|
||||
SetStorageID(fileInfo.ID).
|
||||
SetOriginalName(filename).
|
||||
SetMimeType(contentType).
|
||||
SetSize(int64(len(processedBytes))).
|
||||
SetURL(fmt.Sprintf("/api/media/%s", fileInfo.ID)).
|
||||
SetCreatedBy(fmt.Sprint(userID)).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
// Try to cleanup the stored file if database operation fails
|
||||
_ = s.storage.Delete(ctx, fileInfo.ID)
|
||||
return nil, fmt.Errorf("failed to create media record: %w", err)
|
||||
}
|
||||
|
||||
return media, nil
|
||||
}
|
||||
|
||||
func (s *mediaService) Get(ctx context.Context, id int) (*ent.Media, error) {
|
||||
media, err := s.client.Media.Get(ctx, id)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, fmt.Errorf("media not found: %d", id)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get media: %w", err)
|
||||
}
|
||||
return media, nil
|
||||
}
|
||||
|
||||
func (s *mediaService) Delete(ctx context.Context, id int, userID int) error {
|
||||
media, err := s.Get(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check ownership
|
||||
if media.CreatedBy != fmt.Sprintf("%d", userID) {
|
||||
return fmt.Errorf("unauthorized to delete media")
|
||||
}
|
||||
|
||||
// Delete from storage
|
||||
if err := s.storage.Delete(ctx, media.StorageID); err != nil {
|
||||
return fmt.Errorf("failed to delete file from storage: %w", err)
|
||||
}
|
||||
|
||||
// Delete from database
|
||||
if err := s.client.Media.DeleteOne(media).Exec(ctx); err != nil {
|
||||
return fmt.Errorf("failed to delete media record: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *mediaService) List(ctx context.Context, limit, offset int) ([]*ent.Media, error) {
|
||||
media, err := s.client.Media.Query().
|
||||
Order(ent.Desc("created_at")).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list media: %w", err)
|
||||
}
|
||||
return media, nil
|
||||
}
|
||||
|
||||
func (s *mediaService) GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error) {
|
||||
media, err := s.Get(ctx, id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
reader, info, err := s.storage.Get(ctx, media.StorageID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get file from storage: %w", err)
|
||||
}
|
||||
|
||||
return reader, info, nil
|
||||
}
|
332
backend/internal/service/media_test.go
Normal file
332
backend/internal/service/media_test.go
Normal file
|
@ -0,0 +1,332 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/textproto"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/storage"
|
||||
"tss-rocks-be/internal/storage/mock"
|
||||
"tss-rocks-be/internal/testutil"
|
||||
|
||||
"bou.ke/monkey"
|
||||
)
|
||||
|
||||
type MediaServiceTestSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *ent.Client
|
||||
storage *mock.MockStorage
|
||||
ctrl *gomock.Controller
|
||||
svc MediaService
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.client = testutil.NewTestClient()
|
||||
require.NotNil(s.T(), s.client)
|
||||
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.storage = mock.NewMockStorage(s.ctrl)
|
||||
s.svc = NewMediaService(s.client, s.storage)
|
||||
|
||||
// 清理数据库
|
||||
_, err := s.client.Media.Delete().Exec(s.ctx)
|
||||
require.NoError(s.T(), err)
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
s.client.Close()
|
||||
}
|
||||
|
||||
func TestMediaServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(MediaServiceTestSuite))
|
||||
}
|
||||
|
||||
type mockFileHeader struct {
|
||||
filename string
|
||||
contentType string
|
||||
size int64
|
||||
content []byte
|
||||
}
|
||||
|
||||
func (h *mockFileHeader) Open() (multipart.File, error) {
|
||||
return newMockMultipartFile(h.content), nil
|
||||
}
|
||||
|
||||
func (h *mockFileHeader) Filename() string {
|
||||
return h.filename
|
||||
}
|
||||
|
||||
func (h *mockFileHeader) Size() int64 {
|
||||
return h.size
|
||||
}
|
||||
|
||||
func (h *mockFileHeader) Header() textproto.MIMEHeader {
|
||||
header := make(textproto.MIMEHeader)
|
||||
header.Set("Content-Type", h.contentType)
|
||||
return header
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) createTestFile(filename, contentType string, content []byte) *multipart.FileHeader {
|
||||
header := &multipart.FileHeader{
|
||||
Filename: filename,
|
||||
Header: make(textproto.MIMEHeader),
|
||||
Size: int64(len(content)),
|
||||
}
|
||||
header.Header.Set("Content-Type", contentType)
|
||||
|
||||
monkey.PatchInstanceMethod(reflect.TypeOf(header), "Open", func(_ *multipart.FileHeader) (multipart.File, error) {
|
||||
return newMockMultipartFile(content), nil
|
||||
})
|
||||
|
||||
return header
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestUpload() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
filename string
|
||||
contentType string
|
||||
content []byte
|
||||
setupMock func()
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "Upload text file",
|
||||
filename: "test.txt",
|
||||
contentType: "text/plain",
|
||||
content: []byte("test content"),
|
||||
setupMock: func() {
|
||||
s.storage.EXPECT().
|
||||
Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, name, contentType string, reader io.Reader) (*storage.FileInfo, error) {
|
||||
content, err := io.ReadAll(reader)
|
||||
s.Require().NoError(err)
|
||||
s.Equal([]byte("test content"), content)
|
||||
return &storage.FileInfo{
|
||||
ID: "test-id",
|
||||
Name: "test.txt",
|
||||
ContentType: "text/plain",
|
||||
Size: int64(len(content)),
|
||||
}, nil
|
||||
})
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid filename",
|
||||
filename: "../test.txt",
|
||||
contentType: "text/plain",
|
||||
content: []byte("test content"),
|
||||
setupMock: func() {},
|
||||
wantErr: true,
|
||||
errMsg: "invalid filename",
|
||||
},
|
||||
{
|
||||
name: "Storage error",
|
||||
filename: "test.txt",
|
||||
contentType: "text/plain",
|
||||
content: []byte("test content"),
|
||||
setupMock: func() {
|
||||
s.storage.EXPECT().
|
||||
Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()).
|
||||
Return(nil, fmt.Errorf("storage error"))
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "storage error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// Setup mock
|
||||
tc.setupMock()
|
||||
|
||||
// Create test file
|
||||
fileHeader := s.createTestFile(tc.filename, tc.contentType, tc.content)
|
||||
|
||||
// Add debug output
|
||||
s.T().Logf("Testing file: %s, content-type: %s, size: %d", fileHeader.Filename, fileHeader.Header.Get("Content-Type"), fileHeader.Size)
|
||||
|
||||
// Test upload
|
||||
media, err := s.svc.Upload(s.ctx, fileHeader, 1)
|
||||
|
||||
// Add debug output
|
||||
if err != nil {
|
||||
s.T().Logf("Upload error: %v", err)
|
||||
}
|
||||
|
||||
if tc.wantErr {
|
||||
s.Require().Error(err)
|
||||
s.Contains(err.Error(), tc.errMsg)
|
||||
return
|
||||
}
|
||||
|
||||
s.Require().NoError(err)
|
||||
s.NotNil(media)
|
||||
s.Equal(tc.filename, media.OriginalName)
|
||||
s.Equal(tc.contentType, media.MimeType)
|
||||
s.Equal(int64(len(tc.content)), media.Size)
|
||||
s.Equal("1", media.CreatedBy)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestGet() {
|
||||
// Create test media
|
||||
media, err := s.client.Media.Create().
|
||||
SetStorageID("test-id").
|
||||
SetOriginalName("test.txt").
|
||||
SetMimeType("text/plain").
|
||||
SetSize(12).
|
||||
SetURL("/api/media/test-id").
|
||||
SetCreatedBy("1").
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Test get existing media
|
||||
result, err := s.svc.Get(s.ctx, media.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(media.ID, result.ID)
|
||||
s.Equal(media.OriginalName, result.OriginalName)
|
||||
|
||||
// Test get non-existing media
|
||||
_, err = s.svc.Get(s.ctx, -1)
|
||||
s.Require().Error(err)
|
||||
s.Contains(err.Error(), "media not found")
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestDelete() {
|
||||
// Create test media
|
||||
media, err := s.client.Media.Create().
|
||||
SetStorageID("test-id").
|
||||
SetOriginalName("test.txt").
|
||||
SetMimeType("text/plain").
|
||||
SetSize(12).
|
||||
SetURL("/api/media/test-id").
|
||||
SetCreatedBy("1").
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Test delete by unauthorized user
|
||||
err = s.svc.Delete(s.ctx, media.ID, 2)
|
||||
s.Require().Error(err)
|
||||
s.Contains(err.Error(), "unauthorized")
|
||||
|
||||
// Test delete by owner
|
||||
s.storage.EXPECT().
|
||||
Delete(gomock.Any(), "test-id").
|
||||
Return(nil)
|
||||
err = s.svc.Delete(s.ctx, media.ID, 1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Verify media is deleted
|
||||
_, err = s.svc.Get(s.ctx, media.ID)
|
||||
s.Require().Error(err)
|
||||
s.Contains(err.Error(), "not found")
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestList() {
|
||||
// Create test media
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err := s.client.Media.Create().
|
||||
SetStorageID(fmt.Sprintf("test-id-%d", i)).
|
||||
SetOriginalName(fmt.Sprintf("test-%d.txt", i)).
|
||||
SetMimeType("text/plain").
|
||||
SetSize(12).
|
||||
SetURL(fmt.Sprintf("/api/media/test-id-%d", i)).
|
||||
SetCreatedBy("1").
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
// Test list with limit and offset
|
||||
media, err := s.svc.List(s.ctx, 3, 1)
|
||||
s.Require().NoError(err)
|
||||
s.Len(media, 3)
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestGetFile() {
|
||||
// Create test media
|
||||
media, err := s.client.Media.Create().
|
||||
SetStorageID("test-id").
|
||||
SetOriginalName("test.txt").
|
||||
SetMimeType("text/plain").
|
||||
SetSize(12).
|
||||
SetURL("/api/media/test-id").
|
||||
SetCreatedBy("1").
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Mock storage.Get
|
||||
mockReader := io.NopCloser(bytes.NewReader([]byte("test content")))
|
||||
mockFileInfo := &storage.FileInfo{
|
||||
ID: "test-id",
|
||||
Name: "test.txt",
|
||||
ContentType: "text/plain",
|
||||
Size: 12,
|
||||
}
|
||||
s.storage.EXPECT().
|
||||
Get(gomock.Any(), "test-id").
|
||||
Return(mockReader, mockFileInfo, nil)
|
||||
|
||||
// Test get file
|
||||
reader, info, err := s.svc.GetFile(s.ctx, media.ID)
|
||||
s.Require().NoError(err)
|
||||
s.NotNil(reader)
|
||||
s.Equal(mockFileInfo, info)
|
||||
|
||||
// Test get non-existing file
|
||||
_, _, err = s.svc.GetFile(s.ctx, -1)
|
||||
s.Require().Error(err)
|
||||
s.Contains(err.Error(), "not found")
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestIsValidFilename() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
filename string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "Valid filename",
|
||||
filename: "test.txt",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid filename with ../",
|
||||
filename: "../test.txt",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid filename with ./",
|
||||
filename: "./test.txt",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid filename with backslash",
|
||||
filename: "test\\file.txt",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
got := isValidFilename(tc.filename)
|
||||
s.Equal(tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
3
backend/internal/service/mock/mock.go
Normal file
3
backend/internal/service/mock/mock.go
Normal file
|
@ -0,0 +1,3 @@
|
|||
package mock
|
||||
|
||||
//go:generate mockgen -source=../service.go -destination=mock_service.go -package=mock
|
105
backend/internal/service/rbac_service.go
Normal file
105
backend/internal/service/rbac_service.go
Normal file
|
@ -0,0 +1,105 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/ent/permission"
|
||||
"tss-rocks-be/ent/role"
|
||||
)
|
||||
|
||||
type RBACService struct {
|
||||
client *ent.Client
|
||||
}
|
||||
|
||||
func NewRBACService(client *ent.Client) *RBACService {
|
||||
return &RBACService{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// InitializeRBAC sets up the initial RBAC configuration
|
||||
func (s *RBACService) InitializeRBAC(ctx context.Context) error {
|
||||
// Create admin role if it doesn't exist
|
||||
adminRole, err := s.client.Role.Query().
|
||||
Where(role.Name("admin")).
|
||||
Only(ctx)
|
||||
if ent.IsNotFound(err) {
|
||||
adminRole, err = s.client.Role.Create().
|
||||
SetName("admin").
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create admin role: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to query admin role: %w", err)
|
||||
}
|
||||
|
||||
// Create editor role if it doesn't exist
|
||||
editorRole, err := s.client.Role.Query().
|
||||
Where(role.Name("editor")).
|
||||
Only(ctx)
|
||||
if ent.IsNotFound(err) {
|
||||
editorRole, err = s.client.Role.Create().
|
||||
SetName("editor").
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create editor role: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to query editor role: %w", err)
|
||||
}
|
||||
|
||||
// Define permissions
|
||||
permissions := []struct {
|
||||
role *ent.Role
|
||||
resource string
|
||||
actions []string
|
||||
}{
|
||||
{adminRole, "users", []string{"create", "read", "update", "delete", "assign_role"}},
|
||||
{adminRole, "roles", []string{"create", "read", "update", "delete"}},
|
||||
{adminRole, "media", []string{"create", "read", "update", "delete"}},
|
||||
{adminRole, "posts", []string{"create", "read", "update", "delete"}},
|
||||
{adminRole, "categories", []string{"create", "read", "update", "delete"}},
|
||||
{adminRole, "contributors", []string{"create", "read", "update", "delete"}},
|
||||
{adminRole, "dailies", []string{"create", "read", "update", "delete"}},
|
||||
|
||||
{editorRole, "media", []string{"create", "read", "update"}},
|
||||
{editorRole, "posts", []string{"create", "read", "update"}},
|
||||
{editorRole, "categories", []string{"read"}},
|
||||
{editorRole, "contributors", []string{"read"}},
|
||||
{editorRole, "dailies", []string{"create", "read", "update"}},
|
||||
}
|
||||
|
||||
// Create permissions for each role
|
||||
for _, p := range permissions {
|
||||
for _, action := range p.actions {
|
||||
// Check if permission already exists
|
||||
exists, err := s.client.Permission.Query().
|
||||
Where(
|
||||
permission.Resource(p.resource),
|
||||
permission.Action(action),
|
||||
permission.HasRolesWith(role.ID(p.role.ID)),
|
||||
).
|
||||
Exist(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query permission: %w", err)
|
||||
}
|
||||
|
||||
if !exists {
|
||||
// Create permission and associate it with the role
|
||||
_, err = s.client.Permission.Create().
|
||||
SetResource(p.resource).
|
||||
SetAction(action).
|
||||
AddRoles(p.role).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create permission: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
59
backend/internal/service/service.go
Normal file
59
backend/internal/service/service.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
package service
|
||||
|
||||
//go:generate mockgen -source=service.go -destination=mock/mock_service.go -package=mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/storage"
|
||||
)
|
||||
|
||||
// Service interface defines all business logic operations
|
||||
type Service interface {
|
||||
// User operations
|
||||
CreateUser(ctx context.Context, email, password string, role string) (*ent.User, error)
|
||||
GetUserByEmail(ctx context.Context, email string) (*ent.User, error)
|
||||
ValidatePassword(ctx context.Context, user *ent.User, password string) bool
|
||||
|
||||
// Category operations
|
||||
CreateCategory(ctx context.Context) (*ent.Category, error)
|
||||
AddCategoryContent(ctx context.Context, categoryID int, langCode, name, description, slug string) (*ent.CategoryContent, error)
|
||||
GetCategoryBySlug(ctx context.Context, langCode, slug string) (*ent.Category, error)
|
||||
ListCategories(ctx context.Context, langCode string) ([]*ent.Category, error)
|
||||
GetCategories(ctx context.Context, langCode string) ([]*ent.Category, error)
|
||||
|
||||
// Post operations
|
||||
CreatePost(ctx context.Context, status string) (*ent.Post, error)
|
||||
AddPostContent(ctx context.Context, postID int, langCode, title, content, summary string, metaKeywords, metaDescription string) (*ent.PostContent, error)
|
||||
GetPostBySlug(ctx context.Context, langCode, slug string) (*ent.Post, error)
|
||||
ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error)
|
||||
|
||||
// Contributor operations
|
||||
CreateContributor(ctx context.Context, name string, avatarURL, bio *string) (*ent.Contributor, error)
|
||||
AddContributorSocialLink(ctx context.Context, contributorID int, linkType, name, value string) (*ent.ContributorSocialLink, error)
|
||||
GetContributorByID(ctx context.Context, id int) (*ent.Contributor, error)
|
||||
ListContributors(ctx context.Context) ([]*ent.Contributor, error)
|
||||
|
||||
// Daily operations
|
||||
CreateDaily(ctx context.Context, id string, categoryID int, imageURL string) (*ent.Daily, error)
|
||||
AddDailyContent(ctx context.Context, dailyID string, langCode, quote string) (*ent.DailyContent, error)
|
||||
GetDailyByID(ctx context.Context, id string) (*ent.Daily, error)
|
||||
ListDailies(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Daily, error)
|
||||
|
||||
// Media operations
|
||||
ListMedia(ctx context.Context, limit, offset int) ([]*ent.Media, error)
|
||||
Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error)
|
||||
GetMedia(ctx context.Context, id int) (*ent.Media, error)
|
||||
GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error)
|
||||
DeleteMedia(ctx context.Context, id int, userID int) error
|
||||
|
||||
// RBAC operations
|
||||
AssignRole(ctx context.Context, userID int, role string) error
|
||||
RemoveRole(ctx context.Context, userID int, role string) error
|
||||
GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error)
|
||||
HasPermission(ctx context.Context, userID int, permission string) (bool, error)
|
||||
InitializeRBAC(ctx context.Context) error
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue