[feature] migrate to monorepo
Some checks failed
Build Backend / Build Docker Image (push) Successful in 3m33s
Test Backend / test (push) Failing after 31s

This commit is contained in:
CDN 2025-02-21 00:49:20 +08:00
commit 05ddc1f783
Signed by: CDN
GPG key ID: 0C656827F9F80080
267 changed files with 75165 additions and 0 deletions

View file

@ -0,0 +1,892 @@
package service
import (
"context"
"errors"
"fmt"
"io"
"mime/multipart"
"sort"
"strconv"
"strings"
"regexp"
"tss-rocks-be/ent"
"tss-rocks-be/ent/category"
"tss-rocks-be/ent/categorycontent"
"tss-rocks-be/ent/contributor"
"tss-rocks-be/ent/contributorsociallink"
"tss-rocks-be/ent/daily"
"tss-rocks-be/ent/dailycontent"
"tss-rocks-be/ent/permission"
"tss-rocks-be/ent/post"
"tss-rocks-be/ent/postcontent"
"tss-rocks-be/ent/role"
"tss-rocks-be/ent/user"
"tss-rocks-be/internal/storage"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
// Error definitions
var (
ErrUnauthorized = errors.New("unauthorized")
)
// openFile is a variable that holds the Open method of multipart.FileHeader
// This allows us to mock it in tests
var openFile func(fh *multipart.FileHeader) (multipart.File, error) = func(fh *multipart.FileHeader) (multipart.File, error) {
return fh.Open()
}
type serviceImpl struct {
client *ent.Client
storage storage.Storage
}
// NewService creates a new Service instance
func NewService(client *ent.Client, storage storage.Storage) Service {
return &serviceImpl{
client: client,
storage: storage,
}
}
// User operations
func (s *serviceImpl) CreateUser(ctx context.Context, email, password string, roleStr string) (*ent.User, error) {
// Hash the password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("failed to hash password: %w", err)
}
// Add the user role by default
userRole, err := s.client.Role.Query().Where(role.NameEQ("user")).Only(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get user role: %w", err)
}
// If a specific role is requested and it's not "user", get that role too
var additionalRole *ent.Role
if roleStr != "" && roleStr != "user" {
additionalRole, err = s.client.Role.Query().Where(role.NameEQ(roleStr)).Only(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get role: %w", err)
}
}
// Create user with password and user role
userCreate := s.client.User.Create().
SetEmail(email).
SetPasswordHash(string(hashedPassword)).
AddRoles(userRole)
// Add the additional role if specified
if additionalRole != nil {
userCreate.AddRoles(additionalRole)
}
// Save the user
user, err := userCreate.Save(ctx)
if err != nil {
return nil, fmt.Errorf("failed to create user: %w", err)
}
return user, nil
}
func (s *serviceImpl) GetUserByEmail(ctx context.Context, email string) (*ent.User, error) {
user, err := s.client.User.Query().
Where(user.EmailEQ(email)).
Only(ctx)
if err != nil {
if ent.IsNotFound(err) {
return nil, fmt.Errorf("user not found: %s", email)
}
return nil, fmt.Errorf("failed to get user: %w", err)
}
return user, nil
}
func (s *serviceImpl) ValidatePassword(ctx context.Context, user *ent.User, password string) bool {
err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
return err == nil
}
// Category operations
func (s *serviceImpl) CreateCategory(ctx context.Context) (*ent.Category, error) {
return s.client.Category.Create().Save(ctx)
}
func (s *serviceImpl) AddCategoryContent(ctx context.Context, categoryID int, langCode, name, description, slug string) (*ent.CategoryContent, error) {
var languageCode categorycontent.LanguageCode
switch langCode {
case "en":
languageCode = categorycontent.LanguageCodeEN
case "zh-Hans":
languageCode = categorycontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = categorycontent.LanguageCodeZH_HANT
default:
return nil, fmt.Errorf("unsupported language code: %s", langCode)
}
return s.client.CategoryContent.Create().
SetCategoryID(categoryID).
SetLanguageCode(languageCode).
SetName(name).
SetDescription(description).
SetSlug(slug).
Save(ctx)
}
func (s *serviceImpl) GetCategoryBySlug(ctx context.Context, langCode, slug string) (*ent.Category, error) {
var languageCode categorycontent.LanguageCode
switch langCode {
case "en":
languageCode = categorycontent.LanguageCodeEN
case "zh-Hans":
languageCode = categorycontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = categorycontent.LanguageCodeZH_HANT
default:
return nil, fmt.Errorf("unsupported language code: %s", langCode)
}
return s.client.Category.Query().
Where(
category.HasContentsWith(
categorycontent.And(
categorycontent.LanguageCodeEQ(languageCode),
categorycontent.SlugEQ(slug),
),
),
).
WithContents(func(q *ent.CategoryContentQuery) {
q.Where(categorycontent.LanguageCodeEQ(languageCode))
}).
Only(ctx)
}
func (s *serviceImpl) GetCategories(ctx context.Context, langCode string) ([]*ent.Category, error) {
// 转换语言代码
var languageCode categorycontent.LanguageCode
switch langCode {
case "en":
languageCode = categorycontent.LanguageCodeEN
case "zh-Hans":
languageCode = categorycontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = categorycontent.LanguageCodeZH_HANT
default:
// 不支持的语言代码返回空列表而不是错误
return []*ent.Category{}, nil
}
// 从 CategoryContent 表开始查询,确保只返回有指定语言内容的分类
contents, err := s.client.CategoryContent.Query().
Where(
categorycontent.And(
categorycontent.LanguageCodeEQ(languageCode),
categorycontent.SlugHasPrefix("category-list-"),
),
).
WithCategory().
All(ctx)
if err != nil {
return nil, err
}
// 使用 map 去重,因为一个分类可能有多个相同语言的内容
categoryMap := make(map[int]*ent.Category)
for _, content := range contents {
if content.Edges.Category != nil {
categoryMap[content.Edges.Category.ID] = content.Edges.Category
}
}
// 将 map 转换为有序的切片
var categories []*ent.Category
for _, cat := range categoryMap {
// 重新查询分类以获取完整的关联数据
c, err := s.client.Category.Query().
Where(category.ID(cat.ID)).
WithContents(func(q *ent.CategoryContentQuery) {
q.Where(
categorycontent.And(
categorycontent.LanguageCodeEQ(languageCode),
categorycontent.SlugHasPrefix("category-list-"),
),
)
}).
Only(ctx)
if err != nil {
return nil, err
}
categories = append(categories, c)
}
// 按 ID 排序以保持结果稳定
sort.Slice(categories, func(i, j int) bool {
return categories[i].ID < categories[j].ID
})
return categories, nil
}
func (s *serviceImpl) ListCategories(ctx context.Context, langCode string) ([]*ent.Category, error) {
// 转换语言代码
var languageCode categorycontent.LanguageCode
switch langCode {
case "en":
languageCode = categorycontent.LanguageCodeEN
case "zh-Hans":
languageCode = categorycontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = categorycontent.LanguageCodeZH_HANT
default:
// 不支持的语言代码返回空列表而不是错误
return []*ent.Category{}, nil
}
// 从 CategoryContent 表开始查询,确保只返回有指定语言内容的分类
contents, err := s.client.CategoryContent.Query().
Where(
categorycontent.And(
categorycontent.LanguageCodeEQ(languageCode),
categorycontent.SlugHasPrefix("category-list-"),
),
).
WithCategory().
All(ctx)
if err != nil {
return nil, err
}
// 使用 map 去重,因为一个分类可能有多个相同语言的内容
categoryMap := make(map[int]*ent.Category)
for _, content := range contents {
if content.Edges.Category != nil {
categoryMap[content.Edges.Category.ID] = content.Edges.Category
}
}
// 将 map 转换为有序的切片
var categories []*ent.Category
for _, cat := range categoryMap {
// 重新查询分类以获取完整的关联数据
c, err := s.client.Category.Query().
Where(category.ID(cat.ID)).
WithContents(func(q *ent.CategoryContentQuery) {
q.Where(
categorycontent.And(
categorycontent.LanguageCodeEQ(languageCode),
categorycontent.SlugHasPrefix("category-list-"),
),
)
}).
Only(ctx)
if err != nil {
return nil, err
}
categories = append(categories, c)
}
// 按 ID 排序以保持结果稳定
sort.Slice(categories, func(i, j int) bool {
return categories[i].ID < categories[j].ID
})
return categories, nil
}
// Daily operations
func (s *serviceImpl) CreateDaily(ctx context.Context, id string, categoryID int, imageURL string) (*ent.Daily, error) {
_, err := s.client.Daily.Create().
SetID(id).
SetCategoryID(categoryID).
SetImageURL(imageURL).
Save(ctx)
if err != nil {
return nil, err
}
// 加载 Category Edge
return s.client.Daily.Query().
Where(daily.IDEQ(id)).
WithCategory().
Only(ctx)
}
func (s *serviceImpl) AddDailyContent(ctx context.Context, dailyID string, langCode string, quote string) (*ent.DailyContent, error) {
var languageCode dailycontent.LanguageCode
switch langCode {
case "en":
languageCode = dailycontent.LanguageCodeEN
case "zh-Hans":
languageCode = dailycontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = dailycontent.LanguageCodeZH_HANT
default:
return nil, fmt.Errorf("unsupported language code: %s", langCode)
}
return s.client.DailyContent.Create().
SetDailyID(dailyID).
SetLanguageCode(languageCode).
SetQuote(quote).
Save(ctx)
}
func (s *serviceImpl) GetDailyByID(ctx context.Context, id string) (*ent.Daily, error) {
return s.client.Daily.Query().
Where(daily.IDEQ(id)).
WithCategory().
WithContents().
Only(ctx)
}
func (s *serviceImpl) ListDailies(ctx context.Context, langCode string, categoryID *int, limit int, offset int) ([]*ent.Daily, error) {
var languageCode dailycontent.LanguageCode
switch langCode {
case "en":
languageCode = dailycontent.LanguageCodeEN
case "zh-Hans":
languageCode = dailycontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = dailycontent.LanguageCodeZH_HANT
default:
return nil, fmt.Errorf("unsupported language code: %s", langCode)
}
query := s.client.Daily.Query().
WithContents(func(q *ent.DailyContentQuery) {
if langCode != "" {
q.Where(dailycontent.LanguageCodeEQ(languageCode))
}
}).
WithCategory()
if categoryID != nil {
query.Where(daily.HasCategoryWith(category.ID(*categoryID)))
}
query.Order(ent.Desc(daily.FieldCreatedAt))
if limit > 0 {
query.Limit(limit)
}
if offset > 0 {
query.Offset(offset)
}
return query.All(ctx)
}
// Media operations
func (s *serviceImpl) ListMedia(ctx context.Context, limit, offset int) ([]*ent.Media, error) {
return s.client.Media.Query().
Order(ent.Desc("created_at")).
Limit(limit).
Offset(offset).
All(ctx)
}
func (s *serviceImpl) Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error) {
// Open the uploaded file
src, err := openFile(file)
if err != nil {
return nil, err
}
defer src.Close()
// Save the file to storage
fileInfo, err := s.storage.Save(ctx, file.Filename, file.Header.Get("Content-Type"), src)
if err != nil {
return nil, err
}
// Create media record
return s.client.Media.Create().
SetStorageID(fileInfo.ID).
SetOriginalName(file.Filename).
SetMimeType(fileInfo.ContentType).
SetSize(fileInfo.Size).
SetURL(fileInfo.URL).
SetCreatedBy(strconv.Itoa(userID)).
Save(ctx)
}
func (s *serviceImpl) GetMedia(ctx context.Context, id int) (*ent.Media, error) {
return s.client.Media.Get(ctx, id)
}
func (s *serviceImpl) GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error) {
media, err := s.GetMedia(ctx, id)
if err != nil {
return nil, nil, err
}
return s.storage.Get(ctx, media.StorageID)
}
func (s *serviceImpl) DeleteMedia(ctx context.Context, id int, userID int) error {
media, err := s.GetMedia(ctx, id)
if err != nil {
return err
}
// Check ownership
if media.CreatedBy != strconv.Itoa(userID) {
return ErrUnauthorized
}
// Delete from storage
if err := s.storage.Delete(ctx, media.StorageID); err != nil {
return err
}
// Delete from database
return s.client.Media.DeleteOne(media).Exec(ctx)
}
// Post operations
func (s *serviceImpl) CreatePost(ctx context.Context, status string) (*ent.Post, error) {
var postStatus post.Status
switch status {
case "draft":
postStatus = post.StatusDraft
case "published":
postStatus = post.StatusPublished
case "archived":
postStatus = post.StatusArchived
default:
return nil, fmt.Errorf("invalid status: %s", status)
}
// Generate a random slug
slug := fmt.Sprintf("post-%s", uuid.New().String()[:8])
return s.client.Post.Create().
SetStatus(postStatus).
SetSlug(slug).
Save(ctx)
}
func (s *serviceImpl) AddPostContent(ctx context.Context, postID int, langCode, title, content, summary string, metaKeywords, metaDescription string) (*ent.PostContent, error) {
var languageCode postcontent.LanguageCode
switch langCode {
case "en":
languageCode = postcontent.LanguageCodeEN
case "zh-Hans":
languageCode = postcontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = postcontent.LanguageCodeZH_HANT
default:
return nil, fmt.Errorf("unsupported language code: %s", langCode)
}
// Get the post first to check if it exists
post, err := s.client.Post.Get(ctx, postID)
if err != nil {
return nil, fmt.Errorf("failed to get post: %w", err)
}
// Generate slug from title
var slug string
if langCode == "en" {
// For English titles, convert to lowercase and replace spaces with dashes
slug = strings.ToLower(strings.ReplaceAll(title, " ", "-"))
// Remove all non-alphanumeric characters except dashes
slug = regexp.MustCompile(`[^a-z0-9-]+`).ReplaceAllString(slug, "")
// Ensure slug is not empty and has minimum length
if slug == "" || len(slug) < 4 {
slug = fmt.Sprintf("post-%s", uuid.NewString()[:8])
}
} else {
// For Chinese titles, use the title as is
slug = title
}
return s.client.PostContent.Create().
SetPost(post).
SetLanguageCode(languageCode).
SetTitle(title).
SetContentMarkdown(content).
SetSummary(summary).
SetMetaKeywords(metaKeywords).
SetMetaDescription(metaDescription).
SetSlug(slug).
Save(ctx)
}
func (s *serviceImpl) GetPostBySlug(ctx context.Context, langCode, slug string) (*ent.Post, error) {
var languageCode postcontent.LanguageCode
switch langCode {
case "en":
languageCode = postcontent.LanguageCodeEN
case "zh-Hans":
languageCode = postcontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = postcontent.LanguageCodeZH_HANT
default:
return nil, fmt.Errorf("unsupported language code: %s", langCode)
}
// Find posts that have content with the given slug and language code
posts, err := s.client.Post.Query().
Where(
post.And(
post.StatusEQ(post.StatusPublished),
post.HasContentsWith(
postcontent.And(
postcontent.LanguageCodeEQ(languageCode),
postcontent.SlugEQ(slug),
),
),
),
).
WithContents(func(q *ent.PostContentQuery) {
q.Where(postcontent.LanguageCodeEQ(languageCode))
}).
WithCategory().
All(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get posts: %w", err)
}
if len(posts) == 0 {
return nil, fmt.Errorf("post not found")
}
if len(posts) > 1 {
return nil, fmt.Errorf("multiple posts found with the same slug")
}
return posts[0], nil
}
func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error) {
var languageCode postcontent.LanguageCode
switch langCode {
case "en":
languageCode = postcontent.LanguageCodeEN
case "zh-Hans":
languageCode = postcontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = postcontent.LanguageCodeZH_HANT
default:
return nil, fmt.Errorf("unsupported language code: %s", langCode)
}
// First find all post IDs that have content in the given language
query := s.client.PostContent.Query().
Where(postcontent.LanguageCodeEQ(languageCode)).
QueryPost().
Where(post.StatusEQ(post.StatusPublished))
// Add category filter if provided
if categoryID != nil {
query = query.Where(post.HasCategoryWith(category.ID(*categoryID)))
}
// Get unique post IDs
postIDs, err := query.
Order(ent.Desc(post.FieldCreatedAt)).
IDs(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get post IDs: %w", err)
}
// Remove duplicates while preserving order
seen := make(map[int]bool)
uniqueIDs := make([]int, 0, len(postIDs))
for _, id := range postIDs {
if !seen[id] {
seen[id] = true
uniqueIDs = append(uniqueIDs, id)
}
}
postIDs = uniqueIDs
if len(postIDs) == 0 {
return []*ent.Post{}, nil
}
// If no category filter is applied, only take the latest 5 posts
if categoryID == nil && len(postIDs) > 5 {
postIDs = postIDs[:5]
}
// Apply pagination
if offset >= len(postIDs) {
return []*ent.Post{}, nil
}
// If limit is 0, set it to the length of postIDs
if limit == 0 {
limit = len(postIDs)
}
// Adjust limit if it would exceed total
if offset+limit > len(postIDs) {
limit = len(postIDs) - offset
}
// Get the paginated post IDs
paginatedIDs := postIDs[offset : offset+limit]
// Get the posts with their contents
posts, err := s.client.Post.Query().
Where(post.IDIn(paginatedIDs...)).
WithContents(func(q *ent.PostContentQuery) {
q.Where(postcontent.LanguageCodeEQ(languageCode))
}).
WithCategory().
Order(ent.Desc(post.FieldCreatedAt)).
All(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get posts: %w", err)
}
// Sort posts by ID to match the order of postIDs
sort.Slice(posts, func(i, j int) bool {
// Find index of each post ID in postIDs
var iIndex, jIndex int
for idx, id := range paginatedIDs {
if posts[i].ID == id {
iIndex = idx
}
if posts[j].ID == id {
jIndex = idx
}
}
return iIndex < jIndex
})
return posts, nil
}
// Contributor operations
func (s *serviceImpl) CreateContributor(ctx context.Context, name string, avatarURL, bio *string) (*ent.Contributor, error) {
builder := s.client.Contributor.Create().
SetName(name)
if avatarURL != nil {
builder.SetAvatarURL(*avatarURL)
}
if bio != nil {
builder.SetBio(*bio)
}
contributor, err := builder.Save(ctx)
if err != nil {
return nil, fmt.Errorf("failed to create contributor: %w", err)
}
return contributor, nil
}
func (s *serviceImpl) AddContributorSocialLink(ctx context.Context, contributorID int, linkType, name, value string) (*ent.ContributorSocialLink, error) {
// 验证贡献者是否存在
contributor, err := s.client.Contributor.Get(ctx, contributorID)
if err != nil {
return nil, fmt.Errorf("failed to get contributor: %w", err)
}
// 创建社交链接
link, err := s.client.ContributorSocialLink.Create().
SetContributor(contributor).
SetType(contributorsociallink.Type(linkType)).
SetName(name).
SetValue(value).
Save(ctx)
if err != nil {
return nil, fmt.Errorf("failed to create social link: %w", err)
}
return link, nil
}
func (s *serviceImpl) GetContributorByID(ctx context.Context, id int) (*ent.Contributor, error) {
contributor, err := s.client.Contributor.Query().
Where(contributor.ID(id)).
WithSocialLinks().
Only(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get contributor: %w", err)
}
return contributor, nil
}
func (s *serviceImpl) ListContributors(ctx context.Context) ([]*ent.Contributor, error) {
contributors, err := s.client.Contributor.Query().
WithSocialLinks().
Order(ent.Asc(contributor.FieldName)).
All(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list contributors: %w", err)
}
return contributors, nil
}
// RBAC operations
func (s *serviceImpl) InitializeRBAC(ctx context.Context) error {
// Create roles if they don't exist
adminRole, err := s.client.Role.Create().SetName("admin").Save(ctx)
if ent.IsConstraintError(err) {
adminRole, err = s.client.Role.Query().Where(role.NameEQ("admin")).Only(ctx)
}
if err != nil {
return fmt.Errorf("failed to create admin role: %w", err)
}
editorRole, err := s.client.Role.Create().SetName("editor").Save(ctx)
if ent.IsConstraintError(err) {
editorRole, err = s.client.Role.Query().Where(role.NameEQ("editor")).Only(ctx)
}
if err != nil {
return fmt.Errorf("failed to create editor role: %w", err)
}
userRole, err := s.client.Role.Create().SetName("user").Save(ctx)
if ent.IsConstraintError(err) {
userRole, err = s.client.Role.Query().Where(role.NameEQ("user")).Only(ctx)
}
if err != nil {
return fmt.Errorf("failed to create user role: %w", err)
}
// Define permissions
permissions := []struct {
role *ent.Role
resource string
actions []string
}{
// Admin permissions (full access)
{adminRole, "users", []string{"create", "read", "update", "delete", "assign_role"}},
{adminRole, "roles", []string{"create", "read", "update", "delete"}},
{adminRole, "media", []string{"create", "read", "update", "delete"}},
{adminRole, "posts", []string{"create", "read", "update", "delete"}},
{adminRole, "categories", []string{"create", "read", "update", "delete"}},
{adminRole, "contributors", []string{"create", "read", "update", "delete"}},
{adminRole, "dailies", []string{"create", "read", "update", "delete"}},
// Editor permissions (can create and manage content)
{editorRole, "media", []string{"create", "read", "update", "delete"}},
{editorRole, "posts", []string{"create", "read", "update", "delete"}},
{editorRole, "categories", []string{"read"}},
{editorRole, "contributors", []string{"read"}},
{editorRole, "dailies", []string{"create", "read", "update", "delete"}},
// User permissions (read-only access)
{userRole, "media", []string{"read"}},
{userRole, "posts", []string{"read"}},
{userRole, "categories", []string{"read"}},
{userRole, "contributors", []string{"read"}},
{userRole, "dailies", []string{"read"}},
}
// Create permissions for each role
for _, p := range permissions {
for _, action := range p.actions {
perm, err := s.client.Permission.Create().
SetResource(p.resource).
SetAction(action).
Save(ctx)
if ent.IsConstraintError(err) {
perm, err = s.client.Permission.Query().
Where(
permission.ResourceEQ(p.resource),
permission.ActionEQ(action),
).
Only(ctx)
}
if err != nil {
return fmt.Errorf("failed to create permission %s:%s: %w", p.resource, action, err)
}
// Add permission to role
err = s.client.Role.UpdateOne(p.role).
AddPermissions(perm).
Exec(ctx)
if err != nil && !ent.IsConstraintError(err) {
return fmt.Errorf("failed to add permission %s:%s to role %s: %w", p.resource, action, p.role.Name, err)
}
}
}
return nil
}
func (s *serviceImpl) AssignRole(ctx context.Context, userID int, roleName string) error {
user, err := s.client.User.Get(ctx, userID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
role, err := s.client.Role.Query().Where(role.NameEQ(roleName)).Only(ctx)
if err != nil {
return fmt.Errorf("failed to get role: %w", err)
}
return s.client.User.UpdateOne(user).AddRoles(role).Exec(ctx)
}
func (s *serviceImpl) RemoveRole(ctx context.Context, userID int, roleName string) error {
// Don't allow removing the user role
if roleName == "user" {
return errors.New("cannot remove user role")
}
user, err := s.client.User.Get(ctx, userID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
role, err := s.client.Role.Query().Where(role.NameEQ(roleName)).Only(ctx)
if err != nil {
return fmt.Errorf("failed to get role: %w", err)
}
return s.client.User.UpdateOne(user).RemoveRoles(role).Exec(ctx)
}
func (s *serviceImpl) GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error) {
user, err := s.client.User.Query().
Where(user.ID(userID)).
WithRoles().
Only(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
return user.Edges.Roles, nil
}
func (s *serviceImpl) HasPermission(ctx context.Context, userID int, permission string) (bool, error) {
user, err := s.client.User.Query().
Where(user.ID(userID)).
WithRoles(func(q *ent.RoleQuery) {
q.WithPermissions()
}).
Only(ctx)
if err != nil {
return false, fmt.Errorf("failed to get user: %w", err)
}
parts := strings.Split(permission, ":")
if len(parts) != 2 {
return false, fmt.Errorf("invalid permission format: %s, expected format: resource:action", permission)
}
resource, action := parts[0], parts[1]
for _, r := range user.Edges.Roles {
for _, p := range r.Edges.Permissions {
if p.Resource == resource && p.Action == action {
return true, nil
}
}
}
return false, nil
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,179 @@
package service
import (
"bytes"
"context"
"fmt"
"io"
"mime/multipart"
"strings"
"path/filepath"
"tss-rocks-be/ent"
"tss-rocks-be/internal/storage"
"tss-rocks-be/pkg/imageutil"
)
type MediaService interface {
// Upload uploads a new file and creates a media record
Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error)
// Get retrieves a media file by ID
Get(ctx context.Context, id int) (*ent.Media, error)
// Delete deletes a media file
Delete(ctx context.Context, id int, userID int) error
// List lists media files with pagination
List(ctx context.Context, limit, offset int) ([]*ent.Media, error)
// GetFile gets the file content and info
GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error)
}
type mediaService struct {
client *ent.Client
storage storage.Storage
}
func NewMediaService(client *ent.Client, storage storage.Storage) MediaService {
return &mediaService{
client: client,
storage: storage,
}
}
// isValidFilename checks if a filename is valid
func isValidFilename(filename string) bool {
// Check for illegal characters
if strings.Contains(filename, "../") ||
strings.Contains(filename, "./") ||
strings.Contains(filename, "\\") {
return false
}
return true
}
func (s *mediaService) Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error) {
// Validate filename
if !isValidFilename(file.Filename) {
return nil, fmt.Errorf("invalid filename: %s", file.Filename)
}
// Open the file
src, err := file.Open()
if err != nil {
return nil, fmt.Errorf("failed to open file: %w", err)
}
defer src.Close()
// Read file content for processing
fileBytes, err := io.ReadAll(src)
if err != nil {
return nil, fmt.Errorf("failed to read file: %w", err)
}
contentType := file.Header.Get("Content-Type")
filename := file.Filename
var processedBytes []byte
// Process image if it's an image file
if imageutil.IsImageFormat(contentType) {
opts := imageutil.DefaultOptions()
processedBytes, err = imageutil.ProcessImage(bytes.NewReader(fileBytes), opts)
if err != nil {
return nil, fmt.Errorf("failed to process image: %w", err)
}
// Update content type and filename for WebP
contentType = "image/webp"
filename = strings.TrimSuffix(filename, filepath.Ext(filename)) + ".webp"
} else {
processedBytes = fileBytes
}
// Save the processed file
fileInfo, err := s.storage.Save(ctx, filename, contentType, bytes.NewReader(processedBytes))
if err != nil {
return nil, fmt.Errorf("failed to save file: %w", err)
}
// Create media record in database
media, err := s.client.Media.Create().
SetStorageID(fileInfo.ID).
SetOriginalName(filename).
SetMimeType(contentType).
SetSize(int64(len(processedBytes))).
SetURL(fmt.Sprintf("/api/media/%s", fileInfo.ID)).
SetCreatedBy(fmt.Sprint(userID)).
Save(ctx)
if err != nil {
// Try to cleanup the stored file if database operation fails
_ = s.storage.Delete(ctx, fileInfo.ID)
return nil, fmt.Errorf("failed to create media record: %w", err)
}
return media, nil
}
func (s *mediaService) Get(ctx context.Context, id int) (*ent.Media, error) {
media, err := s.client.Media.Get(ctx, id)
if err != nil {
if ent.IsNotFound(err) {
return nil, fmt.Errorf("media not found: %d", id)
}
return nil, fmt.Errorf("failed to get media: %w", err)
}
return media, nil
}
func (s *mediaService) Delete(ctx context.Context, id int, userID int) error {
media, err := s.Get(ctx, id)
if err != nil {
return err
}
// Check ownership
if media.CreatedBy != fmt.Sprintf("%d", userID) {
return fmt.Errorf("unauthorized to delete media")
}
// Delete from storage
if err := s.storage.Delete(ctx, media.StorageID); err != nil {
return fmt.Errorf("failed to delete file from storage: %w", err)
}
// Delete from database
if err := s.client.Media.DeleteOne(media).Exec(ctx); err != nil {
return fmt.Errorf("failed to delete media record: %w", err)
}
return nil
}
func (s *mediaService) List(ctx context.Context, limit, offset int) ([]*ent.Media, error) {
media, err := s.client.Media.Query().
Order(ent.Desc("created_at")).
Limit(limit).
Offset(offset).
All(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list media: %w", err)
}
return media, nil
}
func (s *mediaService) GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error) {
media, err := s.Get(ctx, id)
if err != nil {
return nil, nil, err
}
reader, info, err := s.storage.Get(ctx, media.StorageID)
if err != nil {
return nil, nil, fmt.Errorf("failed to get file from storage: %w", err)
}
return reader, info, nil
}

View file

@ -0,0 +1,332 @@
package service
import (
"bytes"
"context"
"fmt"
"io"
"mime/multipart"
"net/textproto"
"reflect"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
"tss-rocks-be/ent"
"tss-rocks-be/internal/storage"
"tss-rocks-be/internal/storage/mock"
"tss-rocks-be/internal/testutil"
"bou.ke/monkey"
)
type MediaServiceTestSuite struct {
suite.Suite
ctx context.Context
client *ent.Client
storage *mock.MockStorage
ctrl *gomock.Controller
svc MediaService
}
func (s *MediaServiceTestSuite) SetupTest() {
s.ctx = context.Background()
s.client = testutil.NewTestClient()
require.NotNil(s.T(), s.client)
s.ctrl = gomock.NewController(s.T())
s.storage = mock.NewMockStorage(s.ctrl)
s.svc = NewMediaService(s.client, s.storage)
// 清理数据库
_, err := s.client.Media.Delete().Exec(s.ctx)
require.NoError(s.T(), err)
}
func (s *MediaServiceTestSuite) TearDownTest() {
s.ctrl.Finish()
s.client.Close()
}
func TestMediaServiceSuite(t *testing.T) {
suite.Run(t, new(MediaServiceTestSuite))
}
type mockFileHeader struct {
filename string
contentType string
size int64
content []byte
}
func (h *mockFileHeader) Open() (multipart.File, error) {
return newMockMultipartFile(h.content), nil
}
func (h *mockFileHeader) Filename() string {
return h.filename
}
func (h *mockFileHeader) Size() int64 {
return h.size
}
func (h *mockFileHeader) Header() textproto.MIMEHeader {
header := make(textproto.MIMEHeader)
header.Set("Content-Type", h.contentType)
return header
}
func (s *MediaServiceTestSuite) createTestFile(filename, contentType string, content []byte) *multipart.FileHeader {
header := &multipart.FileHeader{
Filename: filename,
Header: make(textproto.MIMEHeader),
Size: int64(len(content)),
}
header.Header.Set("Content-Type", contentType)
monkey.PatchInstanceMethod(reflect.TypeOf(header), "Open", func(_ *multipart.FileHeader) (multipart.File, error) {
return newMockMultipartFile(content), nil
})
return header
}
func (s *MediaServiceTestSuite) TestUpload() {
testCases := []struct {
name string
filename string
contentType string
content []byte
setupMock func()
wantErr bool
errMsg string
}{
{
name: "Upload text file",
filename: "test.txt",
contentType: "text/plain",
content: []byte("test content"),
setupMock: func() {
s.storage.EXPECT().
Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()).
DoAndReturn(func(ctx context.Context, name, contentType string, reader io.Reader) (*storage.FileInfo, error) {
content, err := io.ReadAll(reader)
s.Require().NoError(err)
s.Equal([]byte("test content"), content)
return &storage.FileInfo{
ID: "test-id",
Name: "test.txt",
ContentType: "text/plain",
Size: int64(len(content)),
}, nil
})
},
wantErr: false,
},
{
name: "Invalid filename",
filename: "../test.txt",
contentType: "text/plain",
content: []byte("test content"),
setupMock: func() {},
wantErr: true,
errMsg: "invalid filename",
},
{
name: "Storage error",
filename: "test.txt",
contentType: "text/plain",
content: []byte("test content"),
setupMock: func() {
s.storage.EXPECT().
Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()).
Return(nil, fmt.Errorf("storage error"))
},
wantErr: true,
errMsg: "storage error",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// Setup mock
tc.setupMock()
// Create test file
fileHeader := s.createTestFile(tc.filename, tc.contentType, tc.content)
// Add debug output
s.T().Logf("Testing file: %s, content-type: %s, size: %d", fileHeader.Filename, fileHeader.Header.Get("Content-Type"), fileHeader.Size)
// Test upload
media, err := s.svc.Upload(s.ctx, fileHeader, 1)
// Add debug output
if err != nil {
s.T().Logf("Upload error: %v", err)
}
if tc.wantErr {
s.Require().Error(err)
s.Contains(err.Error(), tc.errMsg)
return
}
s.Require().NoError(err)
s.NotNil(media)
s.Equal(tc.filename, media.OriginalName)
s.Equal(tc.contentType, media.MimeType)
s.Equal(int64(len(tc.content)), media.Size)
s.Equal("1", media.CreatedBy)
})
}
}
func (s *MediaServiceTestSuite) TestGet() {
// Create test media
media, err := s.client.Media.Create().
SetStorageID("test-id").
SetOriginalName("test.txt").
SetMimeType("text/plain").
SetSize(12).
SetURL("/api/media/test-id").
SetCreatedBy("1").
Save(s.ctx)
s.Require().NoError(err)
// Test get existing media
result, err := s.svc.Get(s.ctx, media.ID)
s.Require().NoError(err)
s.Equal(media.ID, result.ID)
s.Equal(media.OriginalName, result.OriginalName)
// Test get non-existing media
_, err = s.svc.Get(s.ctx, -1)
s.Require().Error(err)
s.Contains(err.Error(), "media not found")
}
func (s *MediaServiceTestSuite) TestDelete() {
// Create test media
media, err := s.client.Media.Create().
SetStorageID("test-id").
SetOriginalName("test.txt").
SetMimeType("text/plain").
SetSize(12).
SetURL("/api/media/test-id").
SetCreatedBy("1").
Save(s.ctx)
s.Require().NoError(err)
// Test delete by unauthorized user
err = s.svc.Delete(s.ctx, media.ID, 2)
s.Require().Error(err)
s.Contains(err.Error(), "unauthorized")
// Test delete by owner
s.storage.EXPECT().
Delete(gomock.Any(), "test-id").
Return(nil)
err = s.svc.Delete(s.ctx, media.ID, 1)
s.Require().NoError(err)
// Verify media is deleted
_, err = s.svc.Get(s.ctx, media.ID)
s.Require().Error(err)
s.Contains(err.Error(), "not found")
}
func (s *MediaServiceTestSuite) TestList() {
// Create test media
for i := 0; i < 5; i++ {
_, err := s.client.Media.Create().
SetStorageID(fmt.Sprintf("test-id-%d", i)).
SetOriginalName(fmt.Sprintf("test-%d.txt", i)).
SetMimeType("text/plain").
SetSize(12).
SetURL(fmt.Sprintf("/api/media/test-id-%d", i)).
SetCreatedBy("1").
Save(s.ctx)
s.Require().NoError(err)
}
// Test list with limit and offset
media, err := s.svc.List(s.ctx, 3, 1)
s.Require().NoError(err)
s.Len(media, 3)
}
func (s *MediaServiceTestSuite) TestGetFile() {
// Create test media
media, err := s.client.Media.Create().
SetStorageID("test-id").
SetOriginalName("test.txt").
SetMimeType("text/plain").
SetSize(12).
SetURL("/api/media/test-id").
SetCreatedBy("1").
Save(s.ctx)
s.Require().NoError(err)
// Mock storage.Get
mockReader := io.NopCloser(bytes.NewReader([]byte("test content")))
mockFileInfo := &storage.FileInfo{
ID: "test-id",
Name: "test.txt",
ContentType: "text/plain",
Size: 12,
}
s.storage.EXPECT().
Get(gomock.Any(), "test-id").
Return(mockReader, mockFileInfo, nil)
// Test get file
reader, info, err := s.svc.GetFile(s.ctx, media.ID)
s.Require().NoError(err)
s.NotNil(reader)
s.Equal(mockFileInfo, info)
// Test get non-existing file
_, _, err = s.svc.GetFile(s.ctx, -1)
s.Require().Error(err)
s.Contains(err.Error(), "not found")
}
func (s *MediaServiceTestSuite) TestIsValidFilename() {
testCases := []struct {
name string
filename string
want bool
}{
{
name: "Valid filename",
filename: "test.txt",
want: true,
},
{
name: "Invalid filename with ../",
filename: "../test.txt",
want: false,
},
{
name: "Invalid filename with ./",
filename: "./test.txt",
want: false,
},
{
name: "Invalid filename with backslash",
filename: "test\\file.txt",
want: false,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
got := isValidFilename(tc.filename)
s.Equal(tc.want, got)
})
}
}

View file

@ -0,0 +1,3 @@
package mock
//go:generate mockgen -source=../service.go -destination=mock_service.go -package=mock

View file

@ -0,0 +1,105 @@
package service
import (
"context"
"fmt"
"tss-rocks-be/ent"
"tss-rocks-be/ent/permission"
"tss-rocks-be/ent/role"
)
type RBACService struct {
client *ent.Client
}
func NewRBACService(client *ent.Client) *RBACService {
return &RBACService{
client: client,
}
}
// InitializeRBAC sets up the initial RBAC configuration
func (s *RBACService) InitializeRBAC(ctx context.Context) error {
// Create admin role if it doesn't exist
adminRole, err := s.client.Role.Query().
Where(role.Name("admin")).
Only(ctx)
if ent.IsNotFound(err) {
adminRole, err = s.client.Role.Create().
SetName("admin").
Save(ctx)
if err != nil {
return fmt.Errorf("failed to create admin role: %w", err)
}
} else if err != nil {
return fmt.Errorf("failed to query admin role: %w", err)
}
// Create editor role if it doesn't exist
editorRole, err := s.client.Role.Query().
Where(role.Name("editor")).
Only(ctx)
if ent.IsNotFound(err) {
editorRole, err = s.client.Role.Create().
SetName("editor").
Save(ctx)
if err != nil {
return fmt.Errorf("failed to create editor role: %w", err)
}
} else if err != nil {
return fmt.Errorf("failed to query editor role: %w", err)
}
// Define permissions
permissions := []struct {
role *ent.Role
resource string
actions []string
}{
{adminRole, "users", []string{"create", "read", "update", "delete", "assign_role"}},
{adminRole, "roles", []string{"create", "read", "update", "delete"}},
{adminRole, "media", []string{"create", "read", "update", "delete"}},
{adminRole, "posts", []string{"create", "read", "update", "delete"}},
{adminRole, "categories", []string{"create", "read", "update", "delete"}},
{adminRole, "contributors", []string{"create", "read", "update", "delete"}},
{adminRole, "dailies", []string{"create", "read", "update", "delete"}},
{editorRole, "media", []string{"create", "read", "update"}},
{editorRole, "posts", []string{"create", "read", "update"}},
{editorRole, "categories", []string{"read"}},
{editorRole, "contributors", []string{"read"}},
{editorRole, "dailies", []string{"create", "read", "update"}},
}
// Create permissions for each role
for _, p := range permissions {
for _, action := range p.actions {
// Check if permission already exists
exists, err := s.client.Permission.Query().
Where(
permission.Resource(p.resource),
permission.Action(action),
permission.HasRolesWith(role.ID(p.role.ID)),
).
Exist(ctx)
if err != nil {
return fmt.Errorf("failed to query permission: %w", err)
}
if !exists {
// Create permission and associate it with the role
_, err = s.client.Permission.Create().
SetResource(p.resource).
SetAction(action).
AddRoles(p.role).
Save(ctx)
if err != nil {
return fmt.Errorf("failed to create permission: %w", err)
}
}
}
}
return nil
}

View file

@ -0,0 +1,59 @@
package service
//go:generate mockgen -source=service.go -destination=mock/mock_service.go -package=mock
import (
"context"
"io"
"mime/multipart"
"tss-rocks-be/ent"
"tss-rocks-be/internal/storage"
)
// Service interface defines all business logic operations
type Service interface {
// User operations
CreateUser(ctx context.Context, email, password string, role string) (*ent.User, error)
GetUserByEmail(ctx context.Context, email string) (*ent.User, error)
ValidatePassword(ctx context.Context, user *ent.User, password string) bool
// Category operations
CreateCategory(ctx context.Context) (*ent.Category, error)
AddCategoryContent(ctx context.Context, categoryID int, langCode, name, description, slug string) (*ent.CategoryContent, error)
GetCategoryBySlug(ctx context.Context, langCode, slug string) (*ent.Category, error)
ListCategories(ctx context.Context, langCode string) ([]*ent.Category, error)
GetCategories(ctx context.Context, langCode string) ([]*ent.Category, error)
// Post operations
CreatePost(ctx context.Context, status string) (*ent.Post, error)
AddPostContent(ctx context.Context, postID int, langCode, title, content, summary string, metaKeywords, metaDescription string) (*ent.PostContent, error)
GetPostBySlug(ctx context.Context, langCode, slug string) (*ent.Post, error)
ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error)
// Contributor operations
CreateContributor(ctx context.Context, name string, avatarURL, bio *string) (*ent.Contributor, error)
AddContributorSocialLink(ctx context.Context, contributorID int, linkType, name, value string) (*ent.ContributorSocialLink, error)
GetContributorByID(ctx context.Context, id int) (*ent.Contributor, error)
ListContributors(ctx context.Context) ([]*ent.Contributor, error)
// Daily operations
CreateDaily(ctx context.Context, id string, categoryID int, imageURL string) (*ent.Daily, error)
AddDailyContent(ctx context.Context, dailyID string, langCode, quote string) (*ent.DailyContent, error)
GetDailyByID(ctx context.Context, id string) (*ent.Daily, error)
ListDailies(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Daily, error)
// Media operations
ListMedia(ctx context.Context, limit, offset int) ([]*ent.Media, error)
Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error)
GetMedia(ctx context.Context, id int) (*ent.Media, error)
GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error)
DeleteMedia(ctx context.Context, id int, userID int) error
// RBAC operations
AssignRole(ctx context.Context, userID int, role string) error
RemoveRole(ctx context.Context, userID int, role string) error
GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error)
HasPermission(ctx context.Context, userID int, permission string) (bool, error)
InitializeRBAC(ctx context.Context) error
}