1092 lines
35 KiB
Go
1092 lines
35 KiB
Go
package service
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"fmt"
|
||
"io"
|
||
"mime/multipart"
|
||
"net/textproto"
|
||
"strconv"
|
||
"strings"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/stretchr/testify/assert"
|
||
"github.com/stretchr/testify/require"
|
||
"github.com/stretchr/testify/suite"
|
||
"go.uber.org/mock/gomock"
|
||
|
||
"tss-rocks-be/ent"
|
||
"tss-rocks-be/ent/categorycontent"
|
||
"tss-rocks-be/ent/dailycontent"
|
||
"tss-rocks-be/internal/storage"
|
||
"tss-rocks-be/internal/storage/mock"
|
||
"tss-rocks-be/internal/testutil"
|
||
)
|
||
|
||
type ServiceImplTestSuite struct {
|
||
suite.Suite
|
||
ctx context.Context
|
||
client *ent.Client
|
||
storage *mock.MockStorage
|
||
ctrl *gomock.Controller
|
||
svc Service
|
||
}
|
||
|
||
// SetupTest runs before every test method. It builds a fresh ent client, a
// gomock controller with a mocked storage backend, and the service under test.
// It then wipes the tables the tests touch, seeds the RBAC data through
// InitializeRBAC, and restores the default openFile hook. TestMedia replaces
// that hook with an in-memory fake.
func (s *ServiceImplTestSuite) SetupTest() {
	s.ctx = context.Background()
	s.client = testutil.NewTestClient()
	require.NotNil(s.T(), s.client)

	s.ctrl = gomock.NewController(s.T())
	s.storage = mock.NewMockStorage(s.ctrl)
	s.svc = NewService(s.client, s.storage)

	// Clean the database. Rows that reference other rows are removed first:
	// contents before their parents, and dailies before the categories they
	// point at. This keeps the cleanup from tripping over foreign-key
	// constraints when leftover data is present.
	cleanups := []struct {
		entity string
		exec   func(context.Context) (int, error)
	}{
		{"DailyContent", s.client.DailyContent.Delete().Exec},
		{"Daily", s.client.Daily.Delete().Exec},
		{"CategoryContent", s.client.CategoryContent.Delete().Exec},
		{"Category", s.client.Category.Delete().Exec},
		{"User", s.client.User.Delete().Exec},
		{"Role", s.client.Role.Delete().Exec},
		{"Permission", s.client.Permission.Delete().Exec},
	}
	for _, c := range cleanups {
		_, err := c.exec(s.ctx)
		require.NoError(s.T(), err, "cleaning %s rows", c.entity)
	}

	// Initialize the RBAC system.
	require.NoError(s.T(), s.svc.InitializeRBAC(s.ctx))

	// Restore the default openFile function in case a previous test replaced it.
	openFile = func(fh *multipart.FileHeader) (multipart.File, error) {
		return fh.Open()
	}
}
|
||
|
||
func (s *ServiceImplTestSuite) TearDownTest() {
|
||
s.ctrl.Finish()
|
||
s.client.Close()
|
||
}
|
||
|
||
// TestServiceImplSuite is the `go test` entry point that runs every
// ServiceImplTestSuite test method.
func TestServiceImplSuite(t *testing.T) {
	s := &ServiceImplTestSuite{}
	suite.Run(t, s)
}
|
||
|
||
// mockMultipartFile implements multipart.File interface
|
||
type mockMultipartFile struct {
|
||
*bytes.Reader
|
||
}
|
||
|
||
func (m *mockMultipartFile) Close() error {
|
||
return nil
|
||
}
|
||
|
||
func (m *mockMultipartFile) ReadAt(p []byte, off int64) (n int, err error) {
|
||
return m.Reader.ReadAt(p, off)
|
||
}
|
||
|
||
func (m *mockMultipartFile) Seek(offset int64, whence int) (int64, error) {
|
||
return m.Reader.Seek(offset, whence)
|
||
}
|
||
|
||
// newMockMultipartFile wraps data in a mockMultipartFile that reads from the
// beginning of data.
func newMockMultipartFile(data []byte) *mockMultipartFile {
	f := new(mockMultipartFile)
	f.Reader = bytes.NewReader(data)
	return f
}
|
||
|
||
func (s *ServiceImplTestSuite) TestCreateUser() {
|
||
testCases := []struct {
|
||
name string
|
||
email string
|
||
password string
|
||
role string
|
||
wantError bool
|
||
}{
|
||
{
|
||
name: "Valid user creation",
|
||
email: "test@example.com",
|
||
password: "password123",
|
||
role: "admin",
|
||
wantError: false,
|
||
},
|
||
{
|
||
name: "Empty email",
|
||
email: "",
|
||
password: "password123",
|
||
role: "user",
|
||
wantError: true,
|
||
},
|
||
{
|
||
name: "Empty password",
|
||
email: "test@example.com",
|
||
password: "",
|
||
role: "user",
|
||
wantError: true,
|
||
},
|
||
{
|
||
name: "Invalid role",
|
||
email: "test@example.com",
|
||
password: "password123",
|
||
role: "invalid_role",
|
||
wantError: true,
|
||
},
|
||
}
|
||
|
||
for _, tc := range testCases {
|
||
s.Run(tc.name, func() {
|
||
user, err := s.svc.CreateUser(s.ctx, tc.email, tc.password, tc.role)
|
||
if tc.wantError {
|
||
assert.Error(s.T(), err)
|
||
assert.Nil(s.T(), user)
|
||
} else {
|
||
assert.NoError(s.T(), err)
|
||
assert.NotNil(s.T(), user)
|
||
assert.Equal(s.T(), tc.email, user.Email)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func (s *ServiceImplTestSuite) TestGetUserByEmail() {
|
||
// Create a test user first
|
||
email := "test@example.com"
|
||
password := "password123"
|
||
role := "user"
|
||
|
||
user, err := s.svc.CreateUser(s.ctx, email, password, role)
|
||
require.NoError(s.T(), err)
|
||
require.NotNil(s.T(), user)
|
||
|
||
s.Run("Existing user", func() {
|
||
found, err := s.svc.GetUserByEmail(s.ctx, email)
|
||
assert.NoError(s.T(), err)
|
||
assert.NotNil(s.T(), found)
|
||
assert.Equal(s.T(), email, found.Email)
|
||
})
|
||
|
||
s.Run("Non-existing user", func() {
|
||
found, err := s.svc.GetUserByEmail(s.ctx, "nonexistent@example.com")
|
||
assert.Error(s.T(), err)
|
||
assert.Nil(s.T(), found)
|
||
})
|
||
}
|
||
|
||
func (s *ServiceImplTestSuite) TestValidatePassword() {
|
||
// Create a test user first
|
||
email := "test@example.com"
|
||
password := "password123"
|
||
role := "user"
|
||
|
||
user, err := s.svc.CreateUser(s.ctx, email, password, role)
|
||
require.NoError(s.T(), err)
|
||
require.NotNil(s.T(), user)
|
||
|
||
s.Run("Valid password", func() {
|
||
valid := s.svc.ValidatePassword(s.ctx, user, password)
|
||
assert.True(s.T(), valid)
|
||
})
|
||
|
||
s.Run("Invalid password", func() {
|
||
valid := s.svc.ValidatePassword(s.ctx, user, "wrongpassword")
|
||
assert.False(s.T(), valid)
|
||
})
|
||
}
|
||
|
||
func (s *ServiceImplTestSuite) TestRBAC() {
|
||
s.Run("AssignRole", func() {
|
||
user, err := s.svc.CreateUser(s.ctx, "test@example.com", "password", "admin")
|
||
require.NoError(s.T(), err)
|
||
|
||
err = s.svc.AssignRole(s.ctx, user.ID, "user")
|
||
assert.NoError(s.T(), err)
|
||
})
|
||
|
||
s.Run("RemoveRole", func() {
|
||
user, err := s.svc.CreateUser(s.ctx, "test2@example.com", "password", "admin")
|
||
require.NoError(s.T(), err)
|
||
|
||
err = s.svc.RemoveRole(s.ctx, user.ID, "admin")
|
||
assert.NoError(s.T(), err)
|
||
})
|
||
|
||
s.Run("HasPermission", func() {
|
||
s.Run("Admin can create users", func() {
|
||
user, err := s.svc.CreateUser(s.ctx, "admin@example.com", "password", "admin")
|
||
require.NoError(s.T(), err)
|
||
|
||
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create")
|
||
require.NoError(s.T(), err)
|
||
assert.True(s.T(), hasPermission)
|
||
})
|
||
|
||
s.Run("Editor cannot create users", func() {
|
||
user, err := s.svc.CreateUser(s.ctx, "editor@example.com", "password", "editor")
|
||
require.NoError(s.T(), err)
|
||
|
||
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create")
|
||
require.NoError(s.T(), err)
|
||
assert.False(s.T(), hasPermission)
|
||
})
|
||
|
||
s.Run("User cannot create users", func() {
|
||
user, err := s.svc.CreateUser(s.ctx, "user@example.com", "password", "user")
|
||
require.NoError(s.T(), err)
|
||
|
||
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create")
|
||
require.NoError(s.T(), err)
|
||
assert.False(s.T(), hasPermission)
|
||
})
|
||
|
||
s.Run("Editor can create posts", func() {
|
||
user, err := s.svc.CreateUser(s.ctx, "editor2@example.com", "password", "editor")
|
||
require.NoError(s.T(), err)
|
||
|
||
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create")
|
||
require.NoError(s.T(), err)
|
||
assert.True(s.T(), hasPermission)
|
||
})
|
||
|
||
s.Run("User can read posts", func() {
|
||
user, err := s.svc.CreateUser(s.ctx, "user2@example.com", "password", "user")
|
||
require.NoError(s.T(), err)
|
||
|
||
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:read")
|
||
require.NoError(s.T(), err)
|
||
assert.True(s.T(), hasPermission)
|
||
})
|
||
|
||
s.Run("User cannot create posts", func() {
|
||
user, err := s.svc.CreateUser(s.ctx, "user3@example.com", "password", "user")
|
||
require.NoError(s.T(), err)
|
||
|
||
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create")
|
||
require.NoError(s.T(), err)
|
||
assert.False(s.T(), hasPermission)
|
||
})
|
||
|
||
s.Run("Invalid permission format", func() {
|
||
user, err := s.svc.CreateUser(s.ctx, "user4@example.com", "password", "user")
|
||
require.NoError(s.T(), err)
|
||
|
||
_, err = s.svc.HasPermission(s.ctx, user.ID, "invalid_permission")
|
||
require.Error(s.T(), err)
|
||
assert.Contains(s.T(), err.Error(), "invalid permission format")
|
||
})
|
||
})
|
||
}
|
||
|
||
// TestCategory covers category creation, adding localized content, slug lookup
// and listing by language. SetupTest runs once per test method, not once per
// subtest, so all subtests here share one database. The list assertions
// therefore check which of a subtest's own categories appear, rather than
// relying on a global row count.
func (s *ServiceImplTestSuite) TestCategory() {
	// Create an admin user up front. The category calls below do not take a
	// user argument.
	adminUser, err := s.svc.CreateUser(s.ctx, "admin@example.com", "password123", "admin")
	require.NoError(s.T(), err)
	require.NotNil(s.T(), adminUser)

	s.Run("CreateCategory", func() {
		category, err := s.svc.CreateCategory(s.ctx)
		require.NoError(s.T(), err)
		require.NotNil(s.T(), category)
		assert.NotZero(s.T(), category.ID)
	})

	s.Run("AddCategoryContent", func() {
		// Create a category first.
		category, err := s.svc.CreateCategory(s.ctx)
		require.NoError(s.T(), err)
		require.NotNil(s.T(), category)

		testCases := []struct {
			name      string
			langCode  string
			catName   string
			desc      string
			slug      string
			wantError bool
		}{
			{
				name:      "Valid category content",
				langCode:  "en",
				catName:   "Test Category",
				desc:      "Test Description",
				slug:      "test-category",
				wantError: false,
			},
			{
				name:      "Empty language code",
				langCode:  "",
				catName:   "Test Category",
				desc:      "Test Description",
				slug:      "test-category-2",
				wantError: true,
			},
			{
				name:      "Empty name",
				langCode:  "en",
				catName:   "",
				desc:      "Test Description",
				slug:      "test-category-3",
				wantError: true,
			},
			{
				name:      "Empty slug",
				langCode:  "en",
				catName:   "Test Category",
				desc:      "Test Description",
				slug:      "",
				wantError: true,
			},
		}

		for _, tc := range testCases {
			s.Run(tc.name, func() {
				content, err := s.svc.AddCategoryContent(s.ctx, category.ID, tc.langCode, tc.catName, tc.desc, tc.slug)
				if tc.wantError {
					assert.Error(s.T(), err)
					assert.Nil(s.T(), content)
					return
				}
				// Stop on failure so a nil content is never dereferenced below.
				require.NoError(s.T(), err)
				require.NotNil(s.T(), content)
				assert.Equal(s.T(), categorycontent.LanguageCode(tc.langCode), content.LanguageCode)
				assert.Equal(s.T(), tc.catName, content.Name)
				assert.Equal(s.T(), tc.desc, content.Description)
				assert.Equal(s.T(), tc.slug, content.Slug)
			})
		}
	})

	s.Run("GetCategoryBySlug", func() {
		// Create a category with content first.
		category, err := s.svc.CreateCategory(s.ctx)
		require.NoError(s.T(), err)
		require.NotNil(s.T(), category)

		content, err := s.svc.AddCategoryContent(s.ctx, category.ID, "en", "Test Category", "Test Description", "test-category-get")
		require.NoError(s.T(), err)
		require.NotNil(s.T(), content)

		s.Run("Existing category", func() {
			found, err := s.svc.GetCategoryBySlug(s.ctx, "en", "test-category-get")
			require.NoError(s.T(), err)
			require.NotNil(s.T(), found)
			assert.Equal(s.T(), category.ID, found.ID)

			// The matching content must be loaded on the returned category.
			require.NotEmpty(s.T(), found.Edges.Contents)
			assert.Equal(s.T(), "Test Category", found.Edges.Contents[0].Name)
		})

		s.Run("Non-existing category", func() {
			found, err := s.svc.GetCategoryBySlug(s.ctx, "en", "non-existent")
			assert.Error(s.T(), err)
			assert.Nil(s.T(), found)
		})

		s.Run("Wrong language code", func() {
			found, err := s.svc.GetCategoryBySlug(s.ctx, "fr", "test-category-get")
			assert.Error(s.T(), err)
			assert.Nil(s.T(), found)
		})
	})

	s.Run("ListCategories", func() {
		s.Run("List English categories", func() {
			// Create five categories; only the first three get English content.
			var withContent, withoutContent []*ent.Category
			for i := 0; i < 5; i++ {
				category, err := s.svc.CreateCategory(s.ctx)
				require.NoError(s.T(), err)
				require.NotNil(s.T(), category)

				if i >= 3 {
					withoutContent = append(withoutContent, category)
					continue
				}
				_, err = s.svc.AddCategoryContent(s.ctx, category.ID, "en",
					fmt.Sprintf("Category %d", i),
					fmt.Sprintf("Description %d", i),
					fmt.Sprintf("category-list-%d", i))
				require.NoError(s.T(), err)
				withContent = append(withContent, category)
			}

			categories, err := s.svc.ListCategories(s.ctx, "en")
			require.NoError(s.T(), err)
			require.NotNil(s.T(), categories)

			listed := func(c *ent.Category) bool {
				for _, cat := range categories {
					if cat.ID == c.ID {
						return true
					}
				}
				return false
			}
			// "AddCategoryContent" and "GetCategoryBySlug" above also gave
			// English content to other categories. So instead of an exact
			// length, check that the three new English categories are listed
			// and the two without English content are not.
			for _, c := range withContent {
				assert.True(s.T(), listed(c), "category %v has English content but is not listed", c.ID)
			}
			for _, c := range withoutContent {
				assert.False(s.T(), listed(c), "category %v has no English content but is listed", c.ID)
			}

			// Every returned category carries only English content.
			for _, cat := range categories {
				assert.NotEmpty(s.T(), cat.Edges.Contents)
				for _, content := range cat.Edges.Contents {
					assert.Equal(s.T(), categorycontent.LanguageCodeEN, content.LanguageCode)
				}
			}
		})

		s.Run("List Chinese categories", func() {
			// Create four categories; only the first two get Simplified Chinese content.
			var withContent []*ent.Category
			for i := 0; i < 4; i++ {
				category, err := s.svc.CreateCategory(s.ctx)
				require.NoError(s.T(), err)
				require.NotNil(s.T(), category)

				if i >= 2 {
					continue
				}
				_, err = s.svc.AddCategoryContent(s.ctx, category.ID, "zh-Hans",
					fmt.Sprintf("分类 %d", i),
					fmt.Sprintf("描述 %d", i),
					fmt.Sprintf("category-list-%d", i))
				require.NoError(s.T(), err)
				withContent = append(withContent, category)
			}

			categories, err := s.svc.ListCategories(s.ctx, "zh-Hans")
			require.NoError(s.T(), err)
			require.NotNil(s.T(), categories)
			// No other subtest stores Simplified Chinese content, so the count is exact.
			assert.Len(s.T(), categories, 2)

			for _, c := range withContent {
				found := false
				for _, cat := range categories {
					if cat.ID == c.ID {
						found = true
						break
					}
				}
				assert.True(s.T(), found, "category %v has Chinese content but is not listed", c.ID)
			}

			// Every returned category carries only Simplified Chinese content.
			for _, cat := range categories {
				assert.NotEmpty(s.T(), cat.Edges.Contents)
				for _, content := range cat.Edges.Contents {
					assert.Equal(s.T(), categorycontent.LanguageCodeZH_HANS, content.LanguageCode)
				}
			}
		})

		s.Run("List non-existing language", func() {
			categories, err := s.svc.ListCategories(s.ctx, "fr")
			assert.NoError(s.T(), err)
			assert.Empty(s.T(), categories)
		})
	})
}
|
||
|
||
func (s *ServiceImplTestSuite) TestGetCategories() {
|
||
ctx := context.Background()
|
||
|
||
// 测试不支持的语言代码
|
||
categories, err := s.svc.GetCategories(ctx, "invalid")
|
||
s.Require().NoError(err)
|
||
s.Empty(categories)
|
||
|
||
// 创建测试数据
|
||
cat1 := s.createTestCategory(ctx, "test-cat-1")
|
||
cat2 := s.createTestCategory(ctx, "test-cat-2")
|
||
|
||
// 为分类添加不同语言的内容
|
||
_, err = s.svc.AddCategoryContent(ctx, cat1.ID, "en", "Test Category 1", "Test Description 1", "category-list-test-1")
|
||
s.Require().NoError(err)
|
||
|
||
_, err = s.svc.AddCategoryContent(ctx, cat2.ID, "zh-Hans", "测试分类2", "测试描述2", "category-list-test-2")
|
||
s.Require().NoError(err)
|
||
|
||
// 测试获取英文分类
|
||
enCategories, err := s.svc.GetCategories(ctx, "en")
|
||
s.Require().NoError(err)
|
||
s.Len(enCategories, 1)
|
||
s.Equal(cat1.ID, enCategories[0].ID)
|
||
|
||
// 测试获取简体中文分类
|
||
zhCategories, err := s.svc.GetCategories(ctx, "zh-Hans")
|
||
s.Require().NoError(err)
|
||
s.Len(zhCategories, 1)
|
||
s.Equal(cat2.ID, zhCategories[0].ID)
|
||
|
||
// 测试获取繁体中文分类(应该为空)
|
||
zhHantCategories, err := s.svc.GetCategories(ctx, "zh-Hant")
|
||
s.Require().NoError(err)
|
||
s.Empty(zhHantCategories)
|
||
}
|
||
|
||
func (s *ServiceImplTestSuite) TestGetUserRoles() {
|
||
ctx := context.Background()
|
||
|
||
// 创建测试用户,默认会有 "user" 角色
|
||
user, err := s.svc.CreateUser(ctx, "test@example.com", "password123", "user")
|
||
s.Require().NoError(err)
|
||
|
||
// 测试新用户有默认的 "user" 角色
|
||
roles, err := s.svc.GetUserRoles(ctx, user.ID)
|
||
s.Require().NoError(err)
|
||
s.Len(roles, 1)
|
||
s.Equal("user", roles[0].Name)
|
||
|
||
// 分配角色给用户
|
||
err = s.svc.AssignRole(ctx, user.ID, "admin")
|
||
s.Require().NoError(err)
|
||
|
||
// 测试用户现在有两个角色
|
||
roles, err = s.svc.GetUserRoles(ctx, user.ID)
|
||
s.Require().NoError(err)
|
||
s.Len(roles, 2)
|
||
roleNames := []string{roles[0].Name, roles[1].Name}
|
||
s.Contains(roleNames, "user")
|
||
s.Contains(roleNames, "admin")
|
||
|
||
// 测试不存在的用户
|
||
_, err = s.svc.GetUserRoles(ctx, -1)
|
||
s.Require().Error(err)
|
||
}
|
||
|
||
func (s *ServiceImplTestSuite) TestDaily() {
|
||
// 创建一个测试分类
|
||
category, err := s.svc.CreateCategory(s.ctx)
|
||
require.NoError(s.T(), err)
|
||
require.NotNil(s.T(), category)
|
||
|
||
// 添加分类内容
|
||
categoryContent, err := s.svc.AddCategoryContent(s.ctx, category.ID, "en", "Test Category", "Test Description", "test-category")
|
||
require.NoError(s.T(), err)
|
||
require.NotNil(s.T(), categoryContent)
|
||
|
||
dailyID := "250212" // 使用符合验证规则的 ID 格式:YYMMDD
|
||
|
||
// 测试创建 Daily
|
||
s.Run("Create Daily", func() {
|
||
daily, err := s.svc.CreateDaily(s.ctx, dailyID, category.ID, "http://example.com/image.jpg")
|
||
require.NoError(s.T(), err)
|
||
require.NotNil(s.T(), daily)
|
||
assert.Equal(s.T(), dailyID, daily.ID)
|
||
assert.Equal(s.T(), category.ID, daily.Edges.Category.ID)
|
||
assert.Equal(s.T(), "http://example.com/image.jpg", daily.ImageURL)
|
||
})
|
||
|
||
// 测试添加 Daily 内容
|
||
s.Run("Add Daily Content", func() {
|
||
content, err := s.svc.AddDailyContent(s.ctx, dailyID, "en", "Test quote for the day")
|
||
require.NoError(s.T(), err)
|
||
require.NotNil(s.T(), content)
|
||
assert.Equal(s.T(), dailycontent.LanguageCodeEN, content.LanguageCode)
|
||
assert.Equal(s.T(), "Test quote for the day", content.Quote)
|
||
})
|
||
|
||
// 测试获取 Daily
|
||
s.Run("Get Daily By ID", func() {
|
||
daily, err := s.svc.GetDailyByID(s.ctx, dailyID)
|
||
require.NoError(s.T(), err)
|
||
require.NotNil(s.T(), daily)
|
||
assert.Equal(s.T(), dailyID, daily.ID)
|
||
assert.Equal(s.T(), category.ID, daily.Edges.Category.ID)
|
||
})
|
||
|
||
// 测试列出 Daily
|
||
s.Run("List Dailies", func() {
|
||
// 创建另一个 Daily 用于测试列表
|
||
anotherDailyID := "250213"
|
||
_, err := s.svc.CreateDaily(s.ctx, anotherDailyID, category.ID, "http://example.com/image2.jpg")
|
||
assert.NoError(s.T(), err)
|
||
_, err = s.svc.AddDailyContent(s.ctx, anotherDailyID, "en", "Another test quote")
|
||
assert.NoError(s.T(), err)
|
||
|
||
// 测试列表功能
|
||
dailies, err := s.svc.ListDailies(s.ctx, "en", &category.ID, 10, 0)
|
||
assert.NoError(s.T(), err)
|
||
assert.NotNil(s.T(), dailies)
|
||
assert.Len(s.T(), dailies, 2)
|
||
|
||
// 测试分页
|
||
dailies, err = s.svc.ListDailies(s.ctx, "en", &category.ID, 1, 0)
|
||
assert.NoError(s.T(), err)
|
||
assert.NotNil(s.T(), dailies)
|
||
assert.Len(s.T(), dailies, 1)
|
||
|
||
// 测试无分类过滤
|
||
dailies, err = s.svc.ListDailies(s.ctx, "en", nil, 10, 0)
|
||
assert.NoError(s.T(), err)
|
||
assert.NotNil(s.T(), dailies)
|
||
assert.Len(s.T(), dailies, 2)
|
||
})
|
||
}
|
||
|
||
// TestPost covers post creation for each status, adding localized content,
// slug lookup, and listing with language, pagination and category filters.
// SetupTest runs once per test method, so all subtests share one database.
func (s *ServiceImplTestSuite) TestPost() {
	s.Run("Create Post", func() {
		statuses := []struct {
			name   string
			status string
		}{
			{name: "Draft", status: "draft"},
			{name: "Published", status: "published"},
			{name: "Archived", status: "archived"},
		}
		for _, tc := range statuses {
			s.Run(tc.name, func() {
				post, err := s.svc.CreatePost(s.ctx, tc.status)
				require.NoError(s.T(), err)
				require.NotNil(s.T(), post)
				assert.Equal(s.T(), tc.status, post.Status.String())
			})
		}

		s.Run("Invalid Status", func() {
			post, err := s.svc.CreatePost(s.ctx, "invalid")
			assert.Error(s.T(), err)
			assert.Nil(s.T(), post)
		})
	})

	s.Run("Add Post Content", func() {
		// Create a post first.
		post, err := s.svc.CreatePost(s.ctx, "draft")
		require.NoError(s.T(), err)
		require.NotNil(s.T(), post)

		contents := []struct {
			name, lang, title, markdown, summary, keywords, description, wantSlug string
		}{
			{"English Content", "en", "Test Post", "# Test Content", "Test Summary", "test,post", "Test Description", "test-post"},
			{"Simplified Chinese Content", "zh-Hans", "测试帖子", "# 测试内容", "测试摘要", "测试,帖子", "测试描述", "测试帖子"},
			{"Traditional Chinese Content", "zh-Hant", "測試貼文", "# 測試內容", "測試摘要", "測試,貼文", "測試描述", "測試貼文"},
		}
		for _, tc := range contents {
			s.Run(tc.name, func() {
				content, err := s.svc.AddPostContent(s.ctx, post.ID, tc.lang, tc.title, tc.markdown, tc.summary, tc.keywords, tc.description)
				require.NoError(s.T(), err)
				require.NotNil(s.T(), content)
				assert.Equal(s.T(), tc.lang, content.LanguageCode.String())
				assert.Equal(s.T(), tc.title, content.Title)
				assert.Equal(s.T(), tc.markdown, content.ContentMarkdown)
				assert.Equal(s.T(), tc.summary, content.Summary)
				assert.Equal(s.T(), tc.keywords, content.MetaKeywords)
				assert.Equal(s.T(), tc.description, content.MetaDescription)
				assert.Equal(s.T(), tc.wantSlug, content.Slug)
			})
		}

		s.Run("Invalid Language Code", func() {
			content, err := s.svc.AddPostContent(s.ctx, post.ID, "fr", "Test Post", "# Test Content", "Test Summary", "test,post", "Test Description")
			assert.Error(s.T(), err)
			assert.Nil(s.T(), content)
		})

		s.Run("Non-existent Post", func() {
			content, err := s.svc.AddPostContent(s.ctx, 999999, "en", "Test Post", "# Test Content", "Test Summary", "test,post", "Test Description")
			assert.Error(s.T(), err)
			assert.Nil(s.T(), content)
		})
	})

	s.Run("Get Post By Slug", func() {
		// Create a post first.
		post, err := s.svc.CreatePost(s.ctx, "published")
		require.NoError(s.T(), err)
		require.NotNil(s.T(), post)

		// "Add Post Content" above already stored "test-post" and "测试帖子".
		// Reusing those titles would make the lookup ambiguous or collide
		// with existing slugs, so these titles are unique to this subtest.
		// The expected slugs assume the title-to-slug rule seen above
		// ("Test Post" -> "test-post").
		_, err = s.svc.AddPostContent(s.ctx, post.ID, "en", "Slug Lookup Post", "# Test Content", "Test Summary", "test,post", "Test Description")
		require.NoError(s.T(), err)
		_, err = s.svc.AddPostContent(s.ctx, post.ID, "zh-Hans", "查找帖子", "# 测试内容", "测试摘要", "测试,帖子", "测试描述")
		require.NoError(s.T(), err)

		s.Run("Get Post By Slug - English", func() {
			result, err := s.svc.GetPostBySlug(s.ctx, "en", "slug-lookup-post")
			require.NoError(s.T(), err)
			require.NotNil(s.T(), result)
			assert.Equal(s.T(), post.ID, result.ID)
			assert.Equal(s.T(), "published", result.Status.String())

			// Only the requested language's content is loaded.
			contents := result.Edges.Contents
			require.Len(s.T(), contents, 1)
			assert.Equal(s.T(), "en", contents[0].LanguageCode.String())
			assert.Equal(s.T(), "Slug Lookup Post", contents[0].Title)
		})

		s.Run("Get Post By Slug - Chinese", func() {
			result, err := s.svc.GetPostBySlug(s.ctx, "zh-Hans", "查找帖子")
			require.NoError(s.T(), err)
			require.NotNil(s.T(), result)
			assert.Equal(s.T(), post.ID, result.ID)
			assert.Equal(s.T(), "published", result.Status.String())

			contents := result.Edges.Contents
			require.Len(s.T(), contents, 1)
			assert.Equal(s.T(), "zh-Hans", contents[0].LanguageCode.String())
			assert.Equal(s.T(), "查找帖子", contents[0].Title)
		})

		s.Run("Non-existent Post", func() {
			result, err := s.svc.GetPostBySlug(s.ctx, "en", "non-existent")
			assert.Error(s.T(), err)
			assert.Nil(s.T(), result)
		})

		s.Run("Invalid Language Code", func() {
			result, err := s.svc.GetPostBySlug(s.ctx, "fr", "slug-lookup-post")
			assert.Error(s.T(), err)
			assert.Nil(s.T(), result)
		})
	})

	s.Run("List Posts", func() {
		// Create five published posts, each with English and Simplified Chinese content.
		for i := 0; i < 5; i++ {
			post, err := s.svc.CreatePost(s.ctx, "published")
			require.NoError(s.T(), err)

			_, err = s.svc.AddPostContent(s.ctx, post.ID, "en", fmt.Sprintf("Post %d", i), "# Content", "Summary", "test", "Description")
			require.NoError(s.T(), err)
			_, err = s.svc.AddPostContent(s.ctx, post.ID, "zh-Hans", fmt.Sprintf("帖子 %d", i), "# 内容", "摘要", "测试", "描述")
			require.NoError(s.T(), err)
		}

		// NOTE(review): "Add Post Content" and "Get Post By Slug" above also
		// store English and Simplified Chinese content for other posts in this
		// same database. Confirm that ListPosts' filtering makes the exact
		// counts below hold.

		s.Run("List All Posts - English", func() {
			posts, err := s.svc.ListPosts(s.ctx, "en", nil, 10, 0)
			require.NoError(s.T(), err)
			require.Len(s.T(), posts, 5)

			// Every post carries only its English content.
			for _, post := range posts {
				contents := post.Edges.Contents
				require.Len(s.T(), contents, 1)
				assert.Equal(s.T(), "en", contents[0].LanguageCode.String())
			}
		})

		s.Run("List All Posts - Chinese", func() {
			posts, err := s.svc.ListPosts(s.ctx, "zh-Hans", nil, 10, 0)
			require.NoError(s.T(), err)
			require.Len(s.T(), posts, 5)

			// Every post carries only its Simplified Chinese content.
			for _, post := range posts {
				contents := post.Edges.Contents
				require.Len(s.T(), contents, 1)
				assert.Equal(s.T(), "zh-Hans", contents[0].LanguageCode.String())
			}
		})

		s.Run("List Posts with Pagination", func() {
			// First page.
			posts, err := s.svc.ListPosts(s.ctx, "en", nil, 2, 0)
			require.NoError(s.T(), err)
			require.Len(s.T(), posts, 2)

			// Second page.
			posts, err = s.svc.ListPosts(s.ctx, "en", nil, 2, 2)
			require.NoError(s.T(), err)
			require.Len(s.T(), posts, 2)

			// Last page.
			posts, err = s.svc.ListPosts(s.ctx, "en", nil, 2, 4)
			require.NoError(s.T(), err)
			require.Len(s.T(), posts, 1)
		})

		s.Run("List Posts by Category", func() {
			// Create a category.
			category, err := s.svc.CreateCategory(s.ctx)
			require.NoError(s.T(), err)

			// Create three posts in this category, each with English content.
			for i := 0; i < 3; i++ {
				post, err := s.svc.CreatePost(s.ctx, "published")
				require.NoError(s.T(), err)

				_, err = s.client.Post.UpdateOne(post).SetCategoryID(category.ID).Save(s.ctx)
				require.NoError(s.T(), err)

				_, err = s.svc.AddPostContent(s.ctx, post.ID, "en", fmt.Sprintf("Category Post %d", i), "# Content", "Summary", "test", "Description")
				require.NoError(s.T(), err)
			}

			// List posts in this category.
			posts, err := s.svc.ListPosts(s.ctx, "en", &category.ID, 10, 0)
			require.NoError(s.T(), err)
			require.Len(s.T(), posts, 3)

			// Every post belongs to the category. Guard the edge so a missing
			// edge fails cleanly instead of panicking.
			for _, post := range posts {
				require.NotNil(s.T(), post.Edges.Category)
				assert.Equal(s.T(), category.ID, post.Edges.Category.ID)
			}
		})

		s.Run("Invalid Language Code", func() {
			posts, err := s.svc.ListPosts(s.ctx, "fr", nil, 10, 0)
			assert.Error(s.T(), err)
			assert.Nil(s.T(), posts)
		})
	})
}
|
||
|
||
// TestMedia covers uploading a file through the mocked storage, then reading
// back its metadata, its contents and the media list, and deleting it. It also
// checks that a user who did not upload a media item cannot delete it.
func (s *ServiceImplTestSuite) TestMedia() {
	// expectSave registers exactly one storage.Save call for filename with
	// content type image/jpeg. The fake checks that the uploaded bytes equal
	// content, then returns a FileInfo with the given storage ID and URL.
	expectSave := func(filename string, content []byte, storageID, url string) {
		s.storage.EXPECT().
			Save(gomock.Any(), filename, "image/jpeg", gomock.Any()).
			DoAndReturn(func(ctx context.Context, name, contentType string, reader io.Reader) (*storage.FileInfo, error) {
				// Verify the reader content.
				data, err := io.ReadAll(reader)
				if err != nil {
					return nil, err
				}
				if !bytes.Equal(data, content) {
					return nil, fmt.Errorf("unexpected file content")
				}
				return &storage.FileInfo{
					ID:          storageID,
					Name:        name,
					Size:        int64(len(content)),
					ContentType: contentType,
					URL:         url,
					CreatedAt:   time.Now(),
					UpdatedAt:   time.Now(),
				}, nil
			}).Times(1)
	}

	// jpegHeader builds the multipart header of an image/jpeg upload.
	jpegHeader := func(filename string, size int) *multipart.FileHeader {
		return &multipart.FileHeader{
			Filename: filename,
			Size:     int64(size),
			Header: textproto.MIMEHeader{
				"Content-Type": []string{"image/jpeg"},
			},
		}
	}

	// serveFromMemory replaces the openFile hook so every opened upload reads
	// content from memory. SetupTest reinstalls the default hook.
	serveFromMemory := func(content []byte) {
		openFile = func(*multipart.FileHeader) (multipart.File, error) {
			return newMockMultipartFile(content), nil
		}
	}

	s.Run("Upload Media", func() {
		// Create a user first.
		user, err := s.svc.CreateUser(s.ctx, "test@example.com", "password123", "")
		require.NoError(s.T(), err)
		require.NotNil(s.T(), user)

		fileContent := []byte("test file content")
		fileHeader := jpegHeader("test.jpg", len(fileContent))
		expectSave(fileHeader.Filename, fileContent, "test123", "http://example.com/test.jpg")
		serveFromMemory(fileContent)

		media, err := s.svc.Upload(s.ctx, fileHeader, user.ID)
		require.NoError(s.T(), err)
		require.NotNil(s.T(), media)
		assert.Equal(s.T(), "test123", media.StorageID)
		assert.Equal(s.T(), "test.jpg", media.OriginalName)
		assert.Equal(s.T(), int64(len(fileContent)), media.Size)
		assert.Equal(s.T(), "image/jpeg", media.MimeType)
		assert.Equal(s.T(), "http://example.com/test.jpg", media.URL)
		assert.Equal(s.T(), strconv.Itoa(user.ID), media.CreatedBy)

		// The remaining operations run against the media record created above.
		s.Run("Get Media", func() {
			result, err := s.svc.GetMedia(s.ctx, media.ID)
			require.NoError(s.T(), err)
			require.NotNil(s.T(), result)
			assert.Equal(s.T(), media.ID, result.ID)
			assert.Equal(s.T(), media.StorageID, result.StorageID)
			assert.Equal(s.T(), media.URL, result.URL)
		})

		s.Run("Get File", func() {
			mockReader := io.NopCloser(strings.NewReader("test content"))
			mockFileInfo := &storage.FileInfo{
				ID:          media.StorageID,
				Name:        media.OriginalName,
				Size:        media.Size,
				ContentType: media.MimeType,
				URL:         media.URL,
				CreatedAt:   media.CreatedAt,
				UpdatedAt:   media.UpdatedAt,
			}
			s.storage.EXPECT().
				Get(gomock.Any(), media.StorageID).
				Return(mockReader, mockFileInfo, nil)

			reader, fileInfo, err := s.svc.GetFile(s.ctx, media.ID)
			require.NoError(s.T(), err)
			require.NotNil(s.T(), reader)
			require.NotNil(s.T(), fileInfo)
			assert.Equal(s.T(), media.OriginalName, fileInfo.Name)
			assert.Equal(s.T(), media.Size, fileInfo.Size)
			assert.Equal(s.T(), media.MimeType, fileInfo.ContentType)
			assert.Equal(s.T(), media.URL, fileInfo.URL)

			// Clean up, reporting a failed close instead of ignoring it.
			assert.NoError(s.T(), reader.Close())
		})

		s.Run("List Media", func() {
			list, err := s.svc.ListMedia(s.ctx, 10, 0)
			require.NoError(s.T(), err)
			require.NotNil(s.T(), list)
			require.Len(s.T(), list, 1)
			assert.Equal(s.T(), "test.jpg", list[0].OriginalName)
		})

		s.Run("Delete Media", func() {
			s.storage.EXPECT().
				Delete(gomock.Any(), media.StorageID).
				Return(nil)

			err := s.svc.DeleteMedia(s.ctx, media.ID, user.ID)
			require.NoError(s.T(), err)

			// The media row is gone.
			count, err := s.client.Media.Query().Count(s.ctx)
			require.NoError(s.T(), err)
			assert.Equal(s.T(), 0, count)
		})
	})

	s.Run("Delete Media - Unauthorized", func() {
		// Create the uploading user.
		user, err := s.svc.CreateUser(s.ctx, "another@example.com", "password123", "")
		require.NoError(s.T(), err)

		fileContent := []byte("test file content")
		fileHeader := jpegHeader("test2.jpg", len(fileContent))
		expectSave(fileHeader.Filename, fileContent, "test456", "http://example.com/test2.jpg")
		serveFromMemory(fileContent)

		media, err := s.svc.Upload(s.ctx, fileHeader, user.ID)
		require.NoError(s.T(), err)
		require.NotNil(s.T(), media)

		// Try to delete the media as a different user.
		anotherUser, err := s.svc.CreateUser(s.ctx, "third@example.com", "password123", "")
		require.NoError(s.T(), err)

		err = s.svc.DeleteMedia(s.ctx, media.ID, anotherUser.ID)
		assert.ErrorIs(s.T(), err, ErrUnauthorized)

		// The media row is still there.
		count, err := s.client.Media.Query().Count(s.ctx)
		require.NoError(s.T(), err)
		assert.Equal(s.T(), 1, count)
	})
}
|
||
|
||
func (s *ServiceImplTestSuite) TestContributor() {
|
||
// 测试创建贡献者
|
||
avatarURL := "https://example.com/avatar.jpg"
|
||
bio := "Test bio"
|
||
contributor, err := s.svc.CreateContributor(s.ctx, "Test Contributor", &avatarURL, &bio)
|
||
require.NoError(s.T(), err)
|
||
require.NotNil(s.T(), contributor)
|
||
assert.Equal(s.T(), "Test Contributor", contributor.Name)
|
||
assert.Equal(s.T(), avatarURL, contributor.AvatarURL)
|
||
assert.Equal(s.T(), bio, contributor.Bio)
|
||
|
||
// 测试添加社交链接
|
||
link, err := s.svc.AddContributorSocialLink(s.ctx, contributor.ID, "github", "GitHub", "https://github.com/test")
|
||
require.NoError(s.T(), err)
|
||
require.NotNil(s.T(), link)
|
||
assert.Equal(s.T(), "github", link.Type.String())
|
||
assert.Equal(s.T(), "GitHub", link.Name)
|
||
assert.Equal(s.T(), "https://github.com/test", link.Value)
|
||
|
||
// 测试获取贡献者
|
||
fetchedContributor, err := s.svc.GetContributorByID(s.ctx, contributor.ID)
|
||
require.NoError(s.T(), err)
|
||
require.NotNil(s.T(), fetchedContributor)
|
||
assert.Equal(s.T(), contributor.ID, fetchedContributor.ID)
|
||
assert.Equal(s.T(), contributor.Name, fetchedContributor.Name)
|
||
assert.Equal(s.T(), contributor.AvatarURL, fetchedContributor.AvatarURL)
|
||
assert.Equal(s.T(), contributor.Bio, fetchedContributor.Bio)
|
||
require.Len(s.T(), fetchedContributor.Edges.SocialLinks, 1)
|
||
assert.Equal(s.T(), link.ID, fetchedContributor.Edges.SocialLinks[0].ID)
|
||
|
||
// 测试列出贡献者
|
||
contributors, err := s.svc.ListContributors(s.ctx)
|
||
require.NoError(s.T(), err)
|
||
require.NotEmpty(s.T(), contributors)
|
||
assert.Equal(s.T(), contributor.ID, contributors[0].ID)
|
||
require.Len(s.T(), contributors[0].Edges.SocialLinks, 1)
|
||
|
||
// 测试错误情况
|
||
_, err = s.svc.GetContributorByID(s.ctx, -1)
|
||
assert.Error(s.T(), err)
|
||
|
||
_, err = s.svc.AddContributorSocialLink(s.ctx, -1, "github", "GitHub", "https://github.com/test")
|
||
assert.Error(s.T(), err)
|
||
|
||
// 测试无效的社交链接类型
|
||
_, err = s.svc.AddContributorSocialLink(s.ctx, contributor.ID, "invalid_type", "Invalid", "https://example.com")
|
||
assert.Error(s.T(), err)
|
||
}
|
||
|
||
func TestServiceSuite(t *testing.T) {
|
||
suite.Run(t, new(ServiceSuite))
|
||
}
|
||
|
||
type ServiceSuite struct {
|
||
suite.Suite
|
||
}
|
||
|
||
// TestServiceInterface fails to compile if *serviceImpl stops implementing
// Service; at run time it does nothing.
func TestServiceInterface(t *testing.T) {
	var impl *serviceImpl
	var _ Service = impl
}
|
||
|
||
// createTestCategory creates a category with no localized content through the
// service. It stops the test immediately if creation fails.
// NOTE(review): slug is not used; callers add content (and slugs) separately.
func (s *ServiceImplTestSuite) createTestCategory(ctx context.Context, slug string) *ent.Category {
	created, err := s.svc.CreateCategory(ctx)
	s.Require().NoError(err)
	return created
}
|