diff --git a/api/schemas/components/schemas.yaml b/api/schemas/components/schemas.yaml
index 056ef91..3812e98 100644
--- a/api/schemas/components/schemas.yaml
+++ b/api/schemas/components/schemas.yaml
@@ -103,6 +103,7 @@ Post:
- slug
- status
- contents
+ - categories
properties:
id:
type: integer
@@ -117,6 +118,11 @@ Post:
type: array
items:
$ref: '#/PostContent'
+ categories:
+ type: array
+ items:
+ $ref: '#/Category'
+ description: 文章所属分类
created_at:
type: string
format: date-time
diff --git a/backend/Dockerfile b/backend/Dockerfile
index c89062d..3a1eed0 100644
--- a/backend/Dockerfile
+++ b/backend/Dockerfile
@@ -14,9 +14,7 @@ RUN adduser -u 1000 -D tss-rocks
USER tss-rocks
WORKDIR /app
-# 复制二进制文件和配置
COPY --from=builder /app/tss-rocks-be .
-COPY --from=builder /app/config/config.yaml ./config/
EXPOSE 8080
ENV GIN_MODE=release
diff --git a/backend/config/config.yaml.example b/backend/config/config.yaml.example
index 220ace8..894b281 100644
--- a/backend/config/config.yaml.example
+++ b/backend/config/config.yaml.example
@@ -21,10 +21,50 @@ auth:
message: "Registration is currently disabled. Please contact administrator." # 禁用时的提示信息
storage:
- driver: local
+ type: local # local or s3
local:
- root: storage
- base_url: http://localhost:8080/storage
+ root_dir: "./storage/media"
+ s3:
+ region: "us-east-1"
+ bucket: "your-bucket-name"
+ access_key_id: "your-access-key-id"
+ secret_access_key: "your-secret-access-key"
+ endpoint: "" # Optional, for MinIO or other S3-compatible services
+ custom_url: "" # Optional, for CDN or custom domain (e.g., https://cdn.example.com/media)
+ proxy_s3: false # If true, backend will proxy S3 requests instead of redirecting
+ upload:
+ limits:
+ image:
+ max_size: 10 # MB
+ allowed_types:
+ - image/jpeg
+ - image/png
+ - image/gif
+ - image/webp
+ - image/svg+xml
+ video:
+ max_size: 500 # MB
+ allowed_types:
+ - video/mp4
+ - video/webm
+ audio:
+ max_size: 50 # MB
+ allowed_types:
+ - audio/mpeg
+ - audio/ogg
+ - audio/wav
+ document:
+ max_size: 20 # MB
+ allowed_types:
+ - application/pdf
+ - application/msword
+ - application/vnd.openxmlformats-officedocument.wordprocessingml.document
+ - application/vnd.ms-excel
+ - application/vnd.openxmlformats-officedocument.spreadsheetml.sheet
+ - application/zip
+ - application/x-rar-compressed
+ - text/plain
+ - text/csv
logging:
level: debug
diff --git a/backend/ent/category/category.go b/backend/ent/category/category.go
index 80073b1..529079d 100644
--- a/backend/ent/category/category.go
+++ b/backend/ent/category/category.go
@@ -33,13 +33,11 @@ const (
ContentsInverseTable = "category_contents"
// ContentsColumn is the table column denoting the contents relation/edge.
ContentsColumn = "category_contents"
- // PostsTable is the table that holds the posts relation/edge.
- PostsTable = "posts"
+ // PostsTable is the table that holds the posts relation/edge. The primary key declared below.
+ PostsTable = "category_posts"
// PostsInverseTable is the table name for the Post entity.
// It exists in this package in order to avoid circular dependency with the "post" package.
PostsInverseTable = "posts"
- // PostsColumn is the table column denoting the posts relation/edge.
- PostsColumn = "category_posts"
// DailyItemsTable is the table that holds the daily_items relation/edge.
DailyItemsTable = "dailies"
// DailyItemsInverseTable is the table name for the Daily entity.
@@ -56,6 +54,12 @@ var Columns = []string{
FieldUpdatedAt,
}
+var (
+ // PostsPrimaryKey and PostsColumn2 are the table columns denoting the
+ // primary key for the posts relation (M2M).
+ PostsPrimaryKey = []string{"category_id", "post_id"}
+)
+
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
@@ -145,7 +149,7 @@ func newPostsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(PostsInverseTable, FieldID),
- sqlgraph.Edge(sqlgraph.O2M, false, PostsTable, PostsColumn),
+ sqlgraph.Edge(sqlgraph.M2M, false, PostsTable, PostsPrimaryKey...),
)
}
func newDailyItemsStep() *sqlgraph.Step {
diff --git a/backend/ent/category/where.go b/backend/ent/category/where.go
index ea42ba5..223d752 100644
--- a/backend/ent/category/where.go
+++ b/backend/ent/category/where.go
@@ -173,7 +173,7 @@ func HasPosts() predicate.Category {
return predicate.Category(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
- sqlgraph.Edge(sqlgraph.O2M, false, PostsTable, PostsColumn),
+ sqlgraph.Edge(sqlgraph.M2M, false, PostsTable, PostsPrimaryKey...),
)
sqlgraph.HasNeighbors(s, step)
})
diff --git a/backend/ent/category_create.go b/backend/ent/category_create.go
index f479883..b78ea04 100644
--- a/backend/ent/category_create.go
+++ b/backend/ent/category_create.go
@@ -201,10 +201,10 @@ func (cc *CategoryCreate) createSpec() (*Category, *sqlgraph.CreateSpec) {
}
if nodes := cc.mutation.PostsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
- Rel: sqlgraph.O2M,
+ Rel: sqlgraph.M2M,
Inverse: false,
Table: category.PostsTable,
- Columns: []string{category.PostsColumn},
+ Columns: category.PostsPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt),
diff --git a/backend/ent/category_query.go b/backend/ent/category_query.go
index 26c7b61..83dadcc 100644
--- a/backend/ent/category_query.go
+++ b/backend/ent/category_query.go
@@ -101,7 +101,7 @@ func (cq *CategoryQuery) QueryPosts() *PostQuery {
step := sqlgraph.NewStep(
sqlgraph.From(category.Table, category.FieldID, selector),
sqlgraph.To(post.Table, post.FieldID),
- sqlgraph.Edge(sqlgraph.O2M, false, category.PostsTable, category.PostsColumn),
+ sqlgraph.Edge(sqlgraph.M2M, false, category.PostsTable, category.PostsPrimaryKey...),
)
fromU = sqlgraph.SetNeighbors(cq.driver.Dialect(), step)
return fromU, nil
@@ -523,33 +523,63 @@ func (cq *CategoryQuery) loadContents(ctx context.Context, query *CategoryConten
return nil
}
func (cq *CategoryQuery) loadPosts(ctx context.Context, query *PostQuery, nodes []*Category, init func(*Category), assign func(*Category, *Post)) error {
- fks := make([]driver.Value, 0, len(nodes))
- nodeids := make(map[int]*Category)
- for i := range nodes {
- fks = append(fks, nodes[i].ID)
- nodeids[nodes[i].ID] = nodes[i]
+ edgeIDs := make([]driver.Value, len(nodes))
+ byID := make(map[int]*Category)
+ nids := make(map[int]map[*Category]struct{})
+ for i, node := range nodes {
+ edgeIDs[i] = node.ID
+ byID[node.ID] = node
if init != nil {
- init(nodes[i])
+ init(node)
}
}
- query.withFKs = true
- query.Where(predicate.Post(func(s *sql.Selector) {
- s.Where(sql.InValues(s.C(category.PostsColumn), fks...))
- }))
- neighbors, err := query.All(ctx)
+ query.Where(func(s *sql.Selector) {
+ joinT := sql.Table(category.PostsTable)
+ s.Join(joinT).On(s.C(post.FieldID), joinT.C(category.PostsPrimaryKey[1]))
+ s.Where(sql.InValues(joinT.C(category.PostsPrimaryKey[0]), edgeIDs...))
+ columns := s.SelectedColumns()
+ s.Select(joinT.C(category.PostsPrimaryKey[0]))
+ s.AppendSelect(columns...)
+ s.SetDistinct(false)
+ })
+ if err := query.prepareQuery(ctx); err != nil {
+ return err
+ }
+ qr := QuerierFunc(func(ctx context.Context, q Query) (Value, error) {
+ return query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
+ assign := spec.Assign
+ values := spec.ScanValues
+ spec.ScanValues = func(columns []string) ([]any, error) {
+ values, err := values(columns[1:])
+ if err != nil {
+ return nil, err
+ }
+ return append([]any{new(sql.NullInt64)}, values...), nil
+ }
+ spec.Assign = func(columns []string, values []any) error {
+ outValue := int(values[0].(*sql.NullInt64).Int64)
+ inValue := int(values[1].(*sql.NullInt64).Int64)
+ if nids[inValue] == nil {
+ nids[inValue] = map[*Category]struct{}{byID[outValue]: {}}
+ return assign(columns[1:], values[1:])
+ }
+ nids[inValue][byID[outValue]] = struct{}{}
+ return nil
+ }
+ })
+ })
+ neighbors, err := withInterceptors[[]*Post](ctx, query, qr, query.inters)
if err != nil {
return err
}
for _, n := range neighbors {
- fk := n.category_posts
- if fk == nil {
- return fmt.Errorf(`foreign-key "category_posts" is nil for node %v`, n.ID)
- }
- node, ok := nodeids[*fk]
+ nodes, ok := nids[n.ID]
if !ok {
- return fmt.Errorf(`unexpected referenced foreign-key "category_posts" returned %v for node %v`, *fk, n.ID)
+ return fmt.Errorf(`unexpected "posts" node returned %v`, n.ID)
+ }
+ for kn := range nodes {
+ assign(kn, n)
}
- assign(node, n)
}
return nil
}
diff --git a/backend/ent/category_update.go b/backend/ent/category_update.go
index 628a7f2..57c3012 100644
--- a/backend/ent/category_update.go
+++ b/backend/ent/category_update.go
@@ -262,10 +262,10 @@ func (cu *CategoryUpdate) sqlSave(ctx context.Context) (n int, err error) {
}
if cu.mutation.PostsCleared() {
edge := &sqlgraph.EdgeSpec{
- Rel: sqlgraph.O2M,
+ Rel: sqlgraph.M2M,
Inverse: false,
Table: category.PostsTable,
- Columns: []string{category.PostsColumn},
+ Columns: category.PostsPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt),
@@ -275,10 +275,10 @@ func (cu *CategoryUpdate) sqlSave(ctx context.Context) (n int, err error) {
}
if nodes := cu.mutation.RemovedPostsIDs(); len(nodes) > 0 && !cu.mutation.PostsCleared() {
edge := &sqlgraph.EdgeSpec{
- Rel: sqlgraph.O2M,
+ Rel: sqlgraph.M2M,
Inverse: false,
Table: category.PostsTable,
- Columns: []string{category.PostsColumn},
+ Columns: category.PostsPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt),
@@ -291,10 +291,10 @@ func (cu *CategoryUpdate) sqlSave(ctx context.Context) (n int, err error) {
}
if nodes := cu.mutation.PostsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
- Rel: sqlgraph.O2M,
+ Rel: sqlgraph.M2M,
Inverse: false,
Table: category.PostsTable,
- Columns: []string{category.PostsColumn},
+ Columns: category.PostsPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt),
@@ -631,10 +631,10 @@ func (cuo *CategoryUpdateOne) sqlSave(ctx context.Context) (_node *Category, err
}
if cuo.mutation.PostsCleared() {
edge := &sqlgraph.EdgeSpec{
- Rel: sqlgraph.O2M,
+ Rel: sqlgraph.M2M,
Inverse: false,
Table: category.PostsTable,
- Columns: []string{category.PostsColumn},
+ Columns: category.PostsPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt),
@@ -644,10 +644,10 @@ func (cuo *CategoryUpdateOne) sqlSave(ctx context.Context) (_node *Category, err
}
if nodes := cuo.mutation.RemovedPostsIDs(); len(nodes) > 0 && !cuo.mutation.PostsCleared() {
edge := &sqlgraph.EdgeSpec{
- Rel: sqlgraph.O2M,
+ Rel: sqlgraph.M2M,
Inverse: false,
Table: category.PostsTable,
- Columns: []string{category.PostsColumn},
+ Columns: category.PostsPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt),
@@ -660,10 +660,10 @@ func (cuo *CategoryUpdateOne) sqlSave(ctx context.Context) (_node *Category, err
}
if nodes := cuo.mutation.PostsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
- Rel: sqlgraph.O2M,
+ Rel: sqlgraph.M2M,
Inverse: false,
Table: category.PostsTable,
- Columns: []string{category.PostsColumn},
+ Columns: category.PostsPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt),
diff --git a/backend/ent/client.go b/backend/ent/client.go
index bbcf6f3..2e7b243 100644
--- a/backend/ent/client.go
+++ b/backend/ent/client.go
@@ -464,7 +464,7 @@ func (c *CategoryClient) QueryPosts(ca *Category) *PostQuery {
step := sqlgraph.NewStep(
sqlgraph.From(category.Table, category.FieldID, id),
sqlgraph.To(post.Table, post.FieldID),
- sqlgraph.Edge(sqlgraph.O2M, false, category.PostsTable, category.PostsColumn),
+ sqlgraph.Edge(sqlgraph.M2M, false, category.PostsTable, category.PostsPrimaryKey...),
)
fromV = sqlgraph.Neighbors(ca.driver.Dialect(), step)
return fromV, nil
@@ -2207,15 +2207,15 @@ func (c *PostClient) QueryContributors(po *Post) *PostContributorQuery {
return query
}
-// QueryCategory queries the category edge of a Post.
-func (c *PostClient) QueryCategory(po *Post) *CategoryQuery {
+// QueryCategories queries the categories edge of a Post.
+func (c *PostClient) QueryCategories(po *Post) *CategoryQuery {
query := (&CategoryClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := po.ID
step := sqlgraph.NewStep(
sqlgraph.From(post.Table, post.FieldID, id),
sqlgraph.To(category.Table, category.FieldID),
- sqlgraph.Edge(sqlgraph.M2O, true, post.CategoryTable, post.CategoryColumn),
+ sqlgraph.Edge(sqlgraph.M2M, true, post.CategoriesTable, post.CategoriesPrimaryKey...),
)
fromV = sqlgraph.Neighbors(po.driver.Dialect(), step)
return fromV, nil
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index 562eed3..614e3af 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -265,21 +265,12 @@ var (
{Name: "slug", Type: field.TypeString, Unique: true},
{Name: "created_at", Type: field.TypeTime},
{Name: "updated_at", Type: field.TypeTime},
- {Name: "category_posts", Type: field.TypeInt, Nullable: true},
}
// PostsTable holds the schema information for the "posts" table.
PostsTable = &schema.Table{
Name: "posts",
Columns: PostsColumns,
PrimaryKey: []*schema.Column{PostsColumns[0]},
- ForeignKeys: []*schema.ForeignKey{
- {
- Symbol: "posts_categories_posts",
- Columns: []*schema.Column{PostsColumns[5]},
- RefColumns: []*schema.Column{CategoriesColumns[0]},
- OnDelete: schema.SetNull,
- },
- },
}
// PostContentsColumns holds the columns for the "post_contents" table.
PostContentsColumns = []*schema.Column{
@@ -387,6 +378,31 @@ var (
Columns: UsersColumns,
PrimaryKey: []*schema.Column{UsersColumns[0]},
}
+ // CategoryPostsColumns holds the columns for the "category_posts" table.
+ CategoryPostsColumns = []*schema.Column{
+ {Name: "category_id", Type: field.TypeInt},
+ {Name: "post_id", Type: field.TypeInt},
+ }
+ // CategoryPostsTable holds the schema information for the "category_posts" table.
+ CategoryPostsTable = &schema.Table{
+ Name: "category_posts",
+ Columns: CategoryPostsColumns,
+ PrimaryKey: []*schema.Column{CategoryPostsColumns[0], CategoryPostsColumns[1]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "category_posts_category_id",
+ Columns: []*schema.Column{CategoryPostsColumns[0]},
+ RefColumns: []*schema.Column{CategoriesColumns[0]},
+ OnDelete: schema.Cascade,
+ },
+ {
+ Symbol: "category_posts_post_id",
+ Columns: []*schema.Column{CategoryPostsColumns[1]},
+ RefColumns: []*schema.Column{PostsColumns[0]},
+ OnDelete: schema.Cascade,
+ },
+ },
+ }
// RolePermissionsColumns holds the columns for the "role_permissions" table.
RolePermissionsColumns = []*schema.Column{
{Name: "role_id", Type: field.TypeInt},
@@ -455,6 +471,7 @@ var (
PostContributorsTable,
RolesTable,
UsersTable,
+ CategoryPostsTable,
RolePermissionsTable,
UserRolesTable,
}
@@ -469,11 +486,12 @@ func init() {
DailyCategoryContentsTable.ForeignKeys[0].RefTable = DailyCategoriesTable
DailyContentsTable.ForeignKeys[0].RefTable = DailiesTable
MediaTable.ForeignKeys[0].RefTable = UsersTable
- PostsTable.ForeignKeys[0].RefTable = CategoriesTable
PostContentsTable.ForeignKeys[0].RefTable = PostsTable
PostContributorsTable.ForeignKeys[0].RefTable = ContributorsTable
PostContributorsTable.ForeignKeys[1].RefTable = ContributorRolesTable
PostContributorsTable.ForeignKeys[2].RefTable = PostsTable
+ CategoryPostsTable.ForeignKeys[0].RefTable = CategoriesTable
+ CategoryPostsTable.ForeignKeys[1].RefTable = PostsTable
RolePermissionsTable.ForeignKeys[0].RefTable = RolesTable
RolePermissionsTable.ForeignKeys[1].RefTable = PermissionsTable
UserRolesTable.ForeignKeys[0].RefTable = UsersTable
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index b8ccb5b..f15bc73 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -6578,8 +6578,9 @@ type PostMutation struct {
contributors map[int]struct{}
removedcontributors map[int]struct{}
clearedcontributors bool
- category *int
- clearedcategory bool
+ categories map[int]struct{}
+ removedcategories map[int]struct{}
+ clearedcategories bool
done bool
oldValue func(context.Context) (*Post, error)
predicates []predicate.Post
@@ -6935,43 +6936,58 @@ func (m *PostMutation) ResetContributors() {
m.removedcontributors = nil
}
-// SetCategoryID sets the "category" edge to the Category entity by id.
-func (m *PostMutation) SetCategoryID(id int) {
- m.category = &id
+// AddCategoryIDs adds the "categories" edge to the Category entity by ids.
+func (m *PostMutation) AddCategoryIDs(ids ...int) {
+ if m.categories == nil {
+ m.categories = make(map[int]struct{})
+ }
+ for i := range ids {
+ m.categories[ids[i]] = struct{}{}
+ }
}
-// ClearCategory clears the "category" edge to the Category entity.
-func (m *PostMutation) ClearCategory() {
- m.clearedcategory = true
+// ClearCategories clears the "categories" edge to the Category entity.
+func (m *PostMutation) ClearCategories() {
+ m.clearedcategories = true
}
-// CategoryCleared reports if the "category" edge to the Category entity was cleared.
-func (m *PostMutation) CategoryCleared() bool {
- return m.clearedcategory
+// CategoriesCleared reports if the "categories" edge to the Category entity was cleared.
+func (m *PostMutation) CategoriesCleared() bool {
+ return m.clearedcategories
}
-// CategoryID returns the "category" edge ID in the mutation.
-func (m *PostMutation) CategoryID() (id int, exists bool) {
- if m.category != nil {
- return *m.category, true
+// RemoveCategoryIDs removes the "categories" edge to the Category entity by IDs.
+func (m *PostMutation) RemoveCategoryIDs(ids ...int) {
+ if m.removedcategories == nil {
+ m.removedcategories = make(map[int]struct{})
+ }
+ for i := range ids {
+ delete(m.categories, ids[i])
+ m.removedcategories[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedCategories returns the removed IDs of the "categories" edge to the Category entity.
+func (m *PostMutation) RemovedCategoriesIDs() (ids []int) {
+ for id := range m.removedcategories {
+ ids = append(ids, id)
}
return
}
-// CategoryIDs returns the "category" edge IDs in the mutation.
-// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
-// CategoryID instead. It exists only for internal usage by the builders.
-func (m *PostMutation) CategoryIDs() (ids []int) {
- if id := m.category; id != nil {
- ids = append(ids, *id)
+// CategoriesIDs returns the "categories" edge IDs in the mutation.
+func (m *PostMutation) CategoriesIDs() (ids []int) {
+ for id := range m.categories {
+ ids = append(ids, id)
}
return
}
-// ResetCategory resets all changes to the "category" edge.
-func (m *PostMutation) ResetCategory() {
- m.category = nil
- m.clearedcategory = false
+// ResetCategories resets all changes to the "categories" edge.
+func (m *PostMutation) ResetCategories() {
+ m.categories = nil
+ m.clearedcategories = false
+ m.removedcategories = nil
}
// Where appends a list predicates to the PostMutation builder.
@@ -7165,8 +7181,8 @@ func (m *PostMutation) AddedEdges() []string {
if m.contributors != nil {
edges = append(edges, post.EdgeContributors)
}
- if m.category != nil {
- edges = append(edges, post.EdgeCategory)
+ if m.categories != nil {
+ edges = append(edges, post.EdgeCategories)
}
return edges
}
@@ -7187,10 +7203,12 @@ func (m *PostMutation) AddedIDs(name string) []ent.Value {
ids = append(ids, id)
}
return ids
- case post.EdgeCategory:
- if id := m.category; id != nil {
- return []ent.Value{*id}
+ case post.EdgeCategories:
+ ids := make([]ent.Value, 0, len(m.categories))
+ for id := range m.categories {
+ ids = append(ids, id)
}
+ return ids
}
return nil
}
@@ -7204,6 +7222,9 @@ func (m *PostMutation) RemovedEdges() []string {
if m.removedcontributors != nil {
edges = append(edges, post.EdgeContributors)
}
+ if m.removedcategories != nil {
+ edges = append(edges, post.EdgeCategories)
+ }
return edges
}
@@ -7223,6 +7244,12 @@ func (m *PostMutation) RemovedIDs(name string) []ent.Value {
ids = append(ids, id)
}
return ids
+ case post.EdgeCategories:
+ ids := make([]ent.Value, 0, len(m.removedcategories))
+ for id := range m.removedcategories {
+ ids = append(ids, id)
+ }
+ return ids
}
return nil
}
@@ -7236,8 +7263,8 @@ func (m *PostMutation) ClearedEdges() []string {
if m.clearedcontributors {
edges = append(edges, post.EdgeContributors)
}
- if m.clearedcategory {
- edges = append(edges, post.EdgeCategory)
+ if m.clearedcategories {
+ edges = append(edges, post.EdgeCategories)
}
return edges
}
@@ -7250,8 +7277,8 @@ func (m *PostMutation) EdgeCleared(name string) bool {
return m.clearedcontents
case post.EdgeContributors:
return m.clearedcontributors
- case post.EdgeCategory:
- return m.clearedcategory
+ case post.EdgeCategories:
+ return m.clearedcategories
}
return false
}
@@ -7260,9 +7287,6 @@ func (m *PostMutation) EdgeCleared(name string) bool {
// if that edge is not defined in the schema.
func (m *PostMutation) ClearEdge(name string) error {
switch name {
- case post.EdgeCategory:
- m.ClearCategory()
- return nil
}
return fmt.Errorf("unknown Post unique edge %s", name)
}
@@ -7277,8 +7301,8 @@ func (m *PostMutation) ResetEdge(name string) error {
case post.EdgeContributors:
m.ResetContributors()
return nil
- case post.EdgeCategory:
- m.ResetCategory()
+ case post.EdgeCategories:
+ m.ResetCategories()
return nil
}
return fmt.Errorf("unknown Post edge %s", name)
diff --git a/backend/ent/post.go b/backend/ent/post.go
index 0be8691..3e4529c 100644
--- a/backend/ent/post.go
+++ b/backend/ent/post.go
@@ -6,7 +6,6 @@ import (
"fmt"
"strings"
"time"
- "tss-rocks-be/ent/category"
"tss-rocks-be/ent/post"
"entgo.io/ent"
@@ -28,9 +27,8 @@ type Post struct {
UpdatedAt time.Time `json:"updated_at,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the PostQuery when eager-loading is set.
- Edges PostEdges `json:"edges"`
- category_posts *int
- selectValues sql.SelectValues
+ Edges PostEdges `json:"edges"`
+ selectValues sql.SelectValues
}
// PostEdges holds the relations/edges for other nodes in the graph.
@@ -39,8 +37,8 @@ type PostEdges struct {
Contents []*PostContent `json:"contents,omitempty"`
// Contributors holds the value of the contributors edge.
Contributors []*PostContributor `json:"contributors,omitempty"`
- // Category holds the value of the category edge.
- Category *Category `json:"category,omitempty"`
+ // Categories holds the value of the categories edge.
+ Categories []*Category `json:"categories,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [3]bool
@@ -64,15 +62,13 @@ func (e PostEdges) ContributorsOrErr() ([]*PostContributor, error) {
return nil, &NotLoadedError{edge: "contributors"}
}
-// CategoryOrErr returns the Category value or an error if the edge
-// was not loaded in eager-loading, or loaded but was not found.
-func (e PostEdges) CategoryOrErr() (*Category, error) {
- if e.Category != nil {
- return e.Category, nil
- } else if e.loadedTypes[2] {
- return nil, &NotFoundError{label: category.Label}
+// CategoriesOrErr returns the Categories value or an error if the edge
+// was not loaded in eager-loading.
+func (e PostEdges) CategoriesOrErr() ([]*Category, error) {
+ if e.loadedTypes[2] {
+ return e.Categories, nil
}
- return nil, &NotLoadedError{edge: "category"}
+ return nil, &NotLoadedError{edge: "categories"}
}
// scanValues returns the types for scanning values from sql.Rows.
@@ -86,8 +82,6 @@ func (*Post) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullString)
case post.FieldCreatedAt, post.FieldUpdatedAt:
values[i] = new(sql.NullTime)
- case post.ForeignKeys[0]: // category_posts
- values[i] = new(sql.NullInt64)
default:
values[i] = new(sql.UnknownType)
}
@@ -133,13 +127,6 @@ func (po *Post) assignValues(columns []string, values []any) error {
} else if value.Valid {
po.UpdatedAt = value.Time
}
- case post.ForeignKeys[0]:
- if value, ok := values[i].(*sql.NullInt64); !ok {
- return fmt.Errorf("unexpected type %T for edge-field category_posts", value)
- } else if value.Valid {
- po.category_posts = new(int)
- *po.category_posts = int(value.Int64)
- }
default:
po.selectValues.Set(columns[i], values[i])
}
@@ -163,9 +150,9 @@ func (po *Post) QueryContributors() *PostContributorQuery {
return NewPostClient(po.config).QueryContributors(po)
}
-// QueryCategory queries the "category" edge of the Post entity.
-func (po *Post) QueryCategory() *CategoryQuery {
- return NewPostClient(po.config).QueryCategory(po)
+// QueryCategories queries the "categories" edge of the Post entity.
+func (po *Post) QueryCategories() *CategoryQuery {
+ return NewPostClient(po.config).QueryCategories(po)
}
// Update returns a builder for updating this Post.
diff --git a/backend/ent/post/post.go b/backend/ent/post/post.go
index 6dc1e83..b72d3e0 100644
--- a/backend/ent/post/post.go
+++ b/backend/ent/post/post.go
@@ -27,8 +27,8 @@ const (
EdgeContents = "contents"
// EdgeContributors holds the string denoting the contributors edge name in mutations.
EdgeContributors = "contributors"
- // EdgeCategory holds the string denoting the category edge name in mutations.
- EdgeCategory = "category"
+ // EdgeCategories holds the string denoting the categories edge name in mutations.
+ EdgeCategories = "categories"
// Table holds the table name of the post in the database.
Table = "posts"
// ContentsTable is the table that holds the contents relation/edge.
@@ -45,13 +45,11 @@ const (
ContributorsInverseTable = "post_contributors"
// ContributorsColumn is the table column denoting the contributors relation/edge.
ContributorsColumn = "post_contributors"
- // CategoryTable is the table that holds the category relation/edge.
- CategoryTable = "posts"
- // CategoryInverseTable is the table name for the Category entity.
+ // CategoriesTable is the table that holds the categories relation/edge. The primary key declared below.
+ CategoriesTable = "category_posts"
+ // CategoriesInverseTable is the table name for the Category entity.
// It exists in this package in order to avoid circular dependency with the "category" package.
- CategoryInverseTable = "categories"
- // CategoryColumn is the table column denoting the category relation/edge.
- CategoryColumn = "category_posts"
+ CategoriesInverseTable = "categories"
)
// Columns holds all SQL columns for post fields.
@@ -63,11 +61,11 @@ var Columns = []string{
FieldUpdatedAt,
}
-// ForeignKeys holds the SQL foreign-keys that are owned by the "posts"
-// table and are not defined as standalone fields in the schema.
-var ForeignKeys = []string{
- "category_posts",
-}
+var (
+ // CategoriesPrimaryKey and CategoriesColumn2 are the table columns denoting the
+ // primary key for the categories relation (M2M).
+ CategoriesPrimaryKey = []string{"category_id", "post_id"}
+)
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
@@ -76,11 +74,6 @@ func ValidColumn(column string) bool {
return true
}
}
- for i := range ForeignKeys {
- if column == ForeignKeys[i] {
- return true
- }
- }
return false
}
@@ -178,10 +171,17 @@ func ByContributors(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
}
}
-// ByCategoryField orders the results by category field.
-func ByCategoryField(field string, opts ...sql.OrderTermOption) OrderOption {
+// ByCategoriesCount orders the results by categories count.
+func ByCategoriesCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
- sqlgraph.OrderByNeighborTerms(s, newCategoryStep(), sql.OrderByField(field, opts...))
+ sqlgraph.OrderByNeighborsCount(s, newCategoriesStep(), opts...)
+ }
+}
+
+// ByCategories orders the results by categories terms.
+func ByCategories(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newCategoriesStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
func newContentsStep() *sqlgraph.Step {
@@ -198,10 +198,10 @@ func newContributorsStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.O2M, false, ContributorsTable, ContributorsColumn),
)
}
-func newCategoryStep() *sqlgraph.Step {
+func newCategoriesStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
- sqlgraph.To(CategoryInverseTable, FieldID),
- sqlgraph.Edge(sqlgraph.M2O, true, CategoryTable, CategoryColumn),
+ sqlgraph.To(CategoriesInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2M, true, CategoriesTable, CategoriesPrimaryKey...),
)
}
diff --git a/backend/ent/post/where.go b/backend/ent/post/where.go
index 7b09ff9..4d9dfda 100644
--- a/backend/ent/post/where.go
+++ b/backend/ent/post/where.go
@@ -281,21 +281,21 @@ func HasContributorsWith(preds ...predicate.PostContributor) predicate.Post {
})
}
-// HasCategory applies the HasEdge predicate on the "category" edge.
-func HasCategory() predicate.Post {
+// HasCategories applies the HasEdge predicate on the "categories" edge.
+func HasCategories() predicate.Post {
return predicate.Post(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
- sqlgraph.Edge(sqlgraph.M2O, true, CategoryTable, CategoryColumn),
+ sqlgraph.Edge(sqlgraph.M2M, true, CategoriesTable, CategoriesPrimaryKey...),
)
sqlgraph.HasNeighbors(s, step)
})
}
-// HasCategoryWith applies the HasEdge predicate on the "category" edge with a given conditions (other predicates).
-func HasCategoryWith(preds ...predicate.Category) predicate.Post {
+// HasCategoriesWith applies the HasEdge predicate on the "categories" edge with a given conditions (other predicates).
+func HasCategoriesWith(preds ...predicate.Category) predicate.Post {
return predicate.Post(func(s *sql.Selector) {
- step := newCategoryStep()
+ step := newCategoriesStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
diff --git a/backend/ent/post_create.go b/backend/ent/post_create.go
index c001fe2..b9fb3c7 100644
--- a/backend/ent/post_create.go
+++ b/backend/ent/post_create.go
@@ -101,23 +101,19 @@ func (pc *PostCreate) AddContributors(p ...*PostContributor) *PostCreate {
return pc.AddContributorIDs(ids...)
}
-// SetCategoryID sets the "category" edge to the Category entity by ID.
-func (pc *PostCreate) SetCategoryID(id int) *PostCreate {
- pc.mutation.SetCategoryID(id)
+// AddCategoryIDs adds the "categories" edge to the Category entity by IDs.
+func (pc *PostCreate) AddCategoryIDs(ids ...int) *PostCreate {
+ pc.mutation.AddCategoryIDs(ids...)
return pc
}
-// SetNillableCategoryID sets the "category" edge to the Category entity by ID if the given value is not nil.
-func (pc *PostCreate) SetNillableCategoryID(id *int) *PostCreate {
- if id != nil {
- pc = pc.SetCategoryID(*id)
+// AddCategories adds the "categories" edges to the Category entity.
+func (pc *PostCreate) AddCategories(c ...*Category) *PostCreate {
+ ids := make([]int, len(c))
+ for i := range c {
+ ids[i] = c[i].ID
}
- return pc
-}
-
-// SetCategory sets the "category" edge to the Category entity.
-func (pc *PostCreate) SetCategory(c *Category) *PostCreate {
- return pc.SetCategoryID(c.ID)
+ return pc.AddCategoryIDs(ids...)
}
// Mutation returns the PostMutation object of the builder.
@@ -267,12 +263,12 @@ func (pc *PostCreate) createSpec() (*Post, *sqlgraph.CreateSpec) {
}
_spec.Edges = append(_spec.Edges, edge)
}
- if nodes := pc.mutation.CategoryIDs(); len(nodes) > 0 {
+ if nodes := pc.mutation.CategoriesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
- Rel: sqlgraph.M2O,
+ Rel: sqlgraph.M2M,
Inverse: true,
- Table: post.CategoryTable,
- Columns: []string{post.CategoryColumn},
+ Table: post.CategoriesTable,
+ Columns: post.CategoriesPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt),
@@ -281,7 +277,6 @@ func (pc *PostCreate) createSpec() (*Post, *sqlgraph.CreateSpec) {
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
- _node.category_posts = &nodes[0]
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
diff --git a/backend/ent/post_query.go b/backend/ent/post_query.go
index f90b5ff..c3fbb7a 100644
--- a/backend/ent/post_query.go
+++ b/backend/ent/post_query.go
@@ -28,8 +28,7 @@ type PostQuery struct {
predicates []predicate.Post
withContents *PostContentQuery
withContributors *PostContributorQuery
- withCategory *CategoryQuery
- withFKs bool
+ withCategories *CategoryQuery
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
@@ -110,8 +109,8 @@ func (pq *PostQuery) QueryContributors() *PostContributorQuery {
return query
}
-// QueryCategory chains the current query on the "category" edge.
-func (pq *PostQuery) QueryCategory() *CategoryQuery {
+// QueryCategories chains the current query on the "categories" edge.
+func (pq *PostQuery) QueryCategories() *CategoryQuery {
query := (&CategoryClient{config: pq.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := pq.prepareQuery(ctx); err != nil {
@@ -124,7 +123,7 @@ func (pq *PostQuery) QueryCategory() *CategoryQuery {
step := sqlgraph.NewStep(
sqlgraph.From(post.Table, post.FieldID, selector),
sqlgraph.To(category.Table, category.FieldID),
- sqlgraph.Edge(sqlgraph.M2O, true, post.CategoryTable, post.CategoryColumn),
+ sqlgraph.Edge(sqlgraph.M2M, true, post.CategoriesTable, post.CategoriesPrimaryKey...),
)
fromU = sqlgraph.SetNeighbors(pq.driver.Dialect(), step)
return fromU, nil
@@ -326,7 +325,7 @@ func (pq *PostQuery) Clone() *PostQuery {
predicates: append([]predicate.Post{}, pq.predicates...),
withContents: pq.withContents.Clone(),
withContributors: pq.withContributors.Clone(),
- withCategory: pq.withCategory.Clone(),
+ withCategories: pq.withCategories.Clone(),
// clone intermediate query.
sql: pq.sql.Clone(),
path: pq.path,
@@ -355,14 +354,14 @@ func (pq *PostQuery) WithContributors(opts ...func(*PostContributorQuery)) *Post
return pq
}
-// WithCategory tells the query-builder to eager-load the nodes that are connected to
-// the "category" edge. The optional arguments are used to configure the query builder of the edge.
-func (pq *PostQuery) WithCategory(opts ...func(*CategoryQuery)) *PostQuery {
+// WithCategories tells the query-builder to eager-load the nodes that are connected to
+// the "categories" edge. The optional arguments are used to configure the query builder of the edge.
+func (pq *PostQuery) WithCategories(opts ...func(*CategoryQuery)) *PostQuery {
query := (&CategoryClient{config: pq.config}).Query()
for _, opt := range opts {
opt(query)
}
- pq.withCategory = query
+ pq.withCategories = query
return pq
}
@@ -443,20 +442,13 @@ func (pq *PostQuery) prepareQuery(ctx context.Context) error {
func (pq *PostQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Post, error) {
var (
nodes = []*Post{}
- withFKs = pq.withFKs
_spec = pq.querySpec()
loadedTypes = [3]bool{
pq.withContents != nil,
pq.withContributors != nil,
- pq.withCategory != nil,
+ pq.withCategories != nil,
}
)
- if pq.withCategory != nil {
- withFKs = true
- }
- if withFKs {
- _spec.Node.Columns = append(_spec.Node.Columns, post.ForeignKeys...)
- }
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*Post).scanValues(nil, columns)
}
@@ -489,9 +481,10 @@ func (pq *PostQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Post, e
return nil, err
}
}
- if query := pq.withCategory; query != nil {
- if err := pq.loadCategory(ctx, query, nodes, nil,
- func(n *Post, e *Category) { n.Edges.Category = e }); err != nil {
+ if query := pq.withCategories; query != nil {
+ if err := pq.loadCategories(ctx, query, nodes,
+ func(n *Post) { n.Edges.Categories = []*Category{} },
+ func(n *Post, e *Category) { n.Edges.Categories = append(n.Edges.Categories, e) }); err != nil {
return nil, err
}
}
@@ -560,34 +553,63 @@ func (pq *PostQuery) loadContributors(ctx context.Context, query *PostContributo
}
return nil
}
-func (pq *PostQuery) loadCategory(ctx context.Context, query *CategoryQuery, nodes []*Post, init func(*Post), assign func(*Post, *Category)) error {
- ids := make([]int, 0, len(nodes))
- nodeids := make(map[int][]*Post)
- for i := range nodes {
- if nodes[i].category_posts == nil {
- continue
+func (pq *PostQuery) loadCategories(ctx context.Context, query *CategoryQuery, nodes []*Post, init func(*Post), assign func(*Post, *Category)) error {
+ edgeIDs := make([]driver.Value, len(nodes))
+ byID := make(map[int]*Post)
+ nids := make(map[int]map[*Post]struct{})
+ for i, node := range nodes {
+ edgeIDs[i] = node.ID
+ byID[node.ID] = node
+ if init != nil {
+ init(node)
}
- fk := *nodes[i].category_posts
- if _, ok := nodeids[fk]; !ok {
- ids = append(ids, fk)
- }
- nodeids[fk] = append(nodeids[fk], nodes[i])
}
- if len(ids) == 0 {
- return nil
+ query.Where(func(s *sql.Selector) {
+ joinT := sql.Table(post.CategoriesTable)
+ s.Join(joinT).On(s.C(category.FieldID), joinT.C(post.CategoriesPrimaryKey[0]))
+ s.Where(sql.InValues(joinT.C(post.CategoriesPrimaryKey[1]), edgeIDs...))
+ columns := s.SelectedColumns()
+ s.Select(joinT.C(post.CategoriesPrimaryKey[1]))
+ s.AppendSelect(columns...)
+ s.SetDistinct(false)
+ })
+ if err := query.prepareQuery(ctx); err != nil {
+ return err
}
- query.Where(category.IDIn(ids...))
- neighbors, err := query.All(ctx)
+ qr := QuerierFunc(func(ctx context.Context, q Query) (Value, error) {
+ return query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
+ assign := spec.Assign
+ values := spec.ScanValues
+ spec.ScanValues = func(columns []string) ([]any, error) {
+ values, err := values(columns[1:])
+ if err != nil {
+ return nil, err
+ }
+ return append([]any{new(sql.NullInt64)}, values...), nil
+ }
+ spec.Assign = func(columns []string, values []any) error {
+ outValue := int(values[0].(*sql.NullInt64).Int64)
+ inValue := int(values[1].(*sql.NullInt64).Int64)
+ if nids[inValue] == nil {
+ nids[inValue] = map[*Post]struct{}{byID[outValue]: {}}
+ return assign(columns[1:], values[1:])
+ }
+ nids[inValue][byID[outValue]] = struct{}{}
+ return nil
+ }
+ })
+ })
+ neighbors, err := withInterceptors[[]*Category](ctx, query, qr, query.inters)
if err != nil {
return err
}
for _, n := range neighbors {
- nodes, ok := nodeids[n.ID]
+ nodes, ok := nids[n.ID]
if !ok {
- return fmt.Errorf(`unexpected foreign-key "category_posts" returned %v`, n.ID)
+ return fmt.Errorf(`unexpected "categories" node returned %v`, n.ID)
}
- for i := range nodes {
- assign(nodes[i], n)
+ for kn := range nodes {
+ assign(kn, n)
}
}
return nil
diff --git a/backend/ent/post_update.go b/backend/ent/post_update.go
index d936a6a..7eb44b3 100644
--- a/backend/ent/post_update.go
+++ b/backend/ent/post_update.go
@@ -109,23 +109,19 @@ func (pu *PostUpdate) AddContributors(p ...*PostContributor) *PostUpdate {
return pu.AddContributorIDs(ids...)
}
-// SetCategoryID sets the "category" edge to the Category entity by ID.
-func (pu *PostUpdate) SetCategoryID(id int) *PostUpdate {
- pu.mutation.SetCategoryID(id)
+// AddCategoryIDs adds the "categories" edge to the Category entity by IDs.
+func (pu *PostUpdate) AddCategoryIDs(ids ...int) *PostUpdate {
+ pu.mutation.AddCategoryIDs(ids...)
return pu
}
-// SetNillableCategoryID sets the "category" edge to the Category entity by ID if the given value is not nil.
-func (pu *PostUpdate) SetNillableCategoryID(id *int) *PostUpdate {
- if id != nil {
- pu = pu.SetCategoryID(*id)
+// AddCategories adds the "categories" edges to the Category entity.
+func (pu *PostUpdate) AddCategories(c ...*Category) *PostUpdate {
+ ids := make([]int, len(c))
+ for i := range c {
+ ids[i] = c[i].ID
}
- return pu
-}
-
-// SetCategory sets the "category" edge to the Category entity.
-func (pu *PostUpdate) SetCategory(c *Category) *PostUpdate {
- return pu.SetCategoryID(c.ID)
+ return pu.AddCategoryIDs(ids...)
}
// Mutation returns the PostMutation object of the builder.
@@ -175,12 +171,27 @@ func (pu *PostUpdate) RemoveContributors(p ...*PostContributor) *PostUpdate {
return pu.RemoveContributorIDs(ids...)
}
-// ClearCategory clears the "category" edge to the Category entity.
-func (pu *PostUpdate) ClearCategory() *PostUpdate {
- pu.mutation.ClearCategory()
+// ClearCategories clears all "categories" edges to the Category entity.
+func (pu *PostUpdate) ClearCategories() *PostUpdate {
+ pu.mutation.ClearCategories()
return pu
}
+// RemoveCategoryIDs removes the "categories" edge to Category entities by IDs.
+func (pu *PostUpdate) RemoveCategoryIDs(ids ...int) *PostUpdate {
+ pu.mutation.RemoveCategoryIDs(ids...)
+ return pu
+}
+
+// RemoveCategories removes "categories" edges to Category entities.
+func (pu *PostUpdate) RemoveCategories(c ...*Category) *PostUpdate {
+ ids := make([]int, len(c))
+ for i := range c {
+ ids[i] = c[i].ID
+ }
+ return pu.RemoveCategoryIDs(ids...)
+}
+
// Save executes the query and returns the number of nodes affected by the update operation.
func (pu *PostUpdate) Save(ctx context.Context) (int, error) {
pu.defaults()
@@ -346,12 +357,12 @@ func (pu *PostUpdate) sqlSave(ctx context.Context) (n int, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
- if pu.mutation.CategoryCleared() {
+ if pu.mutation.CategoriesCleared() {
edge := &sqlgraph.EdgeSpec{
- Rel: sqlgraph.M2O,
+ Rel: sqlgraph.M2M,
Inverse: true,
- Table: post.CategoryTable,
- Columns: []string{post.CategoryColumn},
+ Table: post.CategoriesTable,
+ Columns: post.CategoriesPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt),
@@ -359,12 +370,28 @@ func (pu *PostUpdate) sqlSave(ctx context.Context) (n int, err error) {
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
- if nodes := pu.mutation.CategoryIDs(); len(nodes) > 0 {
+ if nodes := pu.mutation.RemovedCategoriesIDs(); len(nodes) > 0 && !pu.mutation.CategoriesCleared() {
edge := &sqlgraph.EdgeSpec{
- Rel: sqlgraph.M2O,
+ Rel: sqlgraph.M2M,
Inverse: true,
- Table: post.CategoryTable,
- Columns: []string{post.CategoryColumn},
+ Table: post.CategoriesTable,
+ Columns: post.CategoriesPrimaryKey,
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := pu.mutation.CategoriesIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2M,
+ Inverse: true,
+ Table: post.CategoriesTable,
+ Columns: post.CategoriesPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt),
@@ -473,23 +500,19 @@ func (puo *PostUpdateOne) AddContributors(p ...*PostContributor) *PostUpdateOne
return puo.AddContributorIDs(ids...)
}
-// SetCategoryID sets the "category" edge to the Category entity by ID.
-func (puo *PostUpdateOne) SetCategoryID(id int) *PostUpdateOne {
- puo.mutation.SetCategoryID(id)
+// AddCategoryIDs adds the "categories" edge to the Category entity by IDs.
+func (puo *PostUpdateOne) AddCategoryIDs(ids ...int) *PostUpdateOne {
+ puo.mutation.AddCategoryIDs(ids...)
return puo
}
-// SetNillableCategoryID sets the "category" edge to the Category entity by ID if the given value is not nil.
-func (puo *PostUpdateOne) SetNillableCategoryID(id *int) *PostUpdateOne {
- if id != nil {
- puo = puo.SetCategoryID(*id)
+// AddCategories adds the "categories" edges to the Category entity.
+func (puo *PostUpdateOne) AddCategories(c ...*Category) *PostUpdateOne {
+ ids := make([]int, len(c))
+ for i := range c {
+ ids[i] = c[i].ID
}
- return puo
-}
-
-// SetCategory sets the "category" edge to the Category entity.
-func (puo *PostUpdateOne) SetCategory(c *Category) *PostUpdateOne {
- return puo.SetCategoryID(c.ID)
+ return puo.AddCategoryIDs(ids...)
}
// Mutation returns the PostMutation object of the builder.
@@ -539,12 +562,27 @@ func (puo *PostUpdateOne) RemoveContributors(p ...*PostContributor) *PostUpdateO
return puo.RemoveContributorIDs(ids...)
}
-// ClearCategory clears the "category" edge to the Category entity.
-func (puo *PostUpdateOne) ClearCategory() *PostUpdateOne {
- puo.mutation.ClearCategory()
+// ClearCategories clears all "categories" edges to the Category entity.
+func (puo *PostUpdateOne) ClearCategories() *PostUpdateOne {
+ puo.mutation.ClearCategories()
return puo
}
+// RemoveCategoryIDs removes the "categories" edge to Category entities by IDs.
+func (puo *PostUpdateOne) RemoveCategoryIDs(ids ...int) *PostUpdateOne {
+ puo.mutation.RemoveCategoryIDs(ids...)
+ return puo
+}
+
+// RemoveCategories removes "categories" edges to Category entities.
+func (puo *PostUpdateOne) RemoveCategories(c ...*Category) *PostUpdateOne {
+ ids := make([]int, len(c))
+ for i := range c {
+ ids[i] = c[i].ID
+ }
+ return puo.RemoveCategoryIDs(ids...)
+}
+
// Where appends a list predicates to the PostUpdate builder.
func (puo *PostUpdateOne) Where(ps ...predicate.Post) *PostUpdateOne {
puo.mutation.Where(ps...)
@@ -740,12 +778,12 @@ func (puo *PostUpdateOne) sqlSave(ctx context.Context) (_node *Post, err error)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
- if puo.mutation.CategoryCleared() {
+ if puo.mutation.CategoriesCleared() {
edge := &sqlgraph.EdgeSpec{
- Rel: sqlgraph.M2O,
+ Rel: sqlgraph.M2M,
Inverse: true,
- Table: post.CategoryTable,
- Columns: []string{post.CategoryColumn},
+ Table: post.CategoriesTable,
+ Columns: post.CategoriesPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt),
@@ -753,12 +791,28 @@ func (puo *PostUpdateOne) sqlSave(ctx context.Context) (_node *Post, err error)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
- if nodes := puo.mutation.CategoryIDs(); len(nodes) > 0 {
+ if nodes := puo.mutation.RemovedCategoriesIDs(); len(nodes) > 0 && !puo.mutation.CategoriesCleared() {
edge := &sqlgraph.EdgeSpec{
- Rel: sqlgraph.M2O,
+ Rel: sqlgraph.M2M,
Inverse: true,
- Table: post.CategoryTable,
- Columns: []string{post.CategoryColumn},
+ Table: post.CategoriesTable,
+ Columns: post.CategoriesPrimaryKey,
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := puo.mutation.CategoriesIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2M,
+ Inverse: true,
+ Table: post.CategoriesTable,
+ Columns: post.CategoriesPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt),
diff --git a/backend/ent/schema/media.go b/backend/ent/schema/media.go
index 5c41d70..3ad19d5 100644
--- a/backend/ent/schema/media.go
+++ b/backend/ent/schema/media.go
@@ -16,11 +16,14 @@ type Media struct {
func (Media) Fields() []ent.Field {
return []ent.Field{
field.String("storage_id").
+ StorageKey("storage_id").
NotEmpty().
Unique(),
field.String("original_name").
+ StorageKey("original_name").
NotEmpty(),
field.String("mime_type").
+ StorageKey("mime_type").
NotEmpty(),
field.Int64("size").
Positive(),
diff --git a/backend/ent/schema/post.go b/backend/ent/schema/post.go
index 05362e4..ff97e93 100644
--- a/backend/ent/schema/post.go
+++ b/backend/ent/schema/post.go
@@ -34,8 +34,7 @@ func (Post) Edges() []ent.Edge {
return []ent.Edge{
edge.To("contents", PostContent.Type),
edge.To("contributors", PostContributor.Type),
- edge.From("category", Category.Type).
- Ref("posts").
- Unique(),
+ edge.From("categories", Category.Type).
+ Ref("posts"),
}
}
diff --git a/backend/go.mod b/backend/go.mod
index cc83baa..e2141f5 100644
--- a/backend/go.mod
+++ b/backend/go.mod
@@ -3,13 +3,13 @@ module tss-rocks-be
go 1.23.6
require (
- bou.ke/monkey v1.0.2
entgo.io/ent v0.14.1
github.com/aws/aws-sdk-go-v2 v1.36.1
github.com/aws/aws-sdk-go-v2/config v1.29.6
github.com/aws/aws-sdk-go-v2/credentials v1.17.59
github.com/aws/aws-sdk-go-v2/service/s3 v1.76.1
github.com/chai2010/webp v1.1.1
+ github.com/disintegration/imaging v1.6.2
github.com/gin-gonic/gin v1.10.0
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/google/uuid v1.6.0
@@ -17,7 +17,6 @@ require (
github.com/rs/zerolog v1.33.0
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.10.0
- go.uber.org/mock v0.5.0
golang.org/x/crypto v0.33.0
golang.org/x/time v0.10.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
@@ -68,12 +67,12 @@ require (
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
- github.com/stretchr/objx v0.5.2 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/zclconf/go-cty v1.16.2 // indirect
github.com/zclconf/go-cty-yaml v1.1.0 // indirect
golang.org/x/arch v0.14.0 // indirect
+ golang.org/x/image v0.0.0-20211028202545-6944b10bf410 // indirect
golang.org/x/mod v0.23.0 // indirect
golang.org/x/net v0.35.0 // indirect
golang.org/x/sync v0.11.0 // indirect
diff --git a/backend/go.sum b/backend/go.sum
index f3dba0a..37a92d6 100644
--- a/backend/go.sum
+++ b/backend/go.sum
@@ -1,7 +1,5 @@
ariga.io/atlas v0.31.0 h1:Nw6/Jdc7OpZfiy6oh/dJAYPp5XxGYvMTWLOUutwWjeY=
ariga.io/atlas v0.31.0/go.mod h1:J3chwsQAgjDF6Ostz7JmJJRTCbtqIupUbVR/gqZrMiA=
-bou.ke/monkey v1.0.2 h1:kWcnsrCNUatbxncxR/ThdYqbytgOIArtYWqcQLQzKLI=
-bou.ke/monkey v1.0.2/go.mod h1:OqickVX3tNx6t33n1xvtTtu85YN5s6cKwVug+oHMaIA=
entgo.io/ent v0.14.1 h1:fUERL506Pqr92EPHJqr8EYxbPioflJo6PudkrEA8a/s=
entgo.io/ent v0.14.1/go.mod h1:MH6XLG0KXpkcDQhKiHfANZSzR55TJyPL5IGNpI8wpco=
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
@@ -63,6 +61,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/disintegration/imaging v1.6.2 h1:w1LecBlG2Lnp8B3jk5zSuNqd7b4DXhcjwek1ei82L+c=
+github.com/disintegration/imaging v1.6.2/go.mod h1:44/5580QXChDfwIclfc/PCwrr44amcmDAg8hxG0Ewe4=
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
github.com/gin-contrib/sse v1.0.0 h1:y3bT1mUWUxDpW4JLQg/HnTqV4rozuW4tC9eFKTxYI9E=
@@ -110,8 +110,6 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
-github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0=
-github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0=
@@ -121,8 +119,6 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
-github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
-github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@@ -139,7 +135,6 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
-github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
@@ -159,12 +154,13 @@ github.com/zclconf/go-cty-debug v0.0.0-20240509010212-0d6042c53940 h1:4r45xpDWB6
github.com/zclconf/go-cty-debug v0.0.0-20240509010212-0d6042c53940/go.mod h1:CmBdvvj3nqzfzJ6nTCIwDTPZ56aVGvDrmztiO5g3qrM=
github.com/zclconf/go-cty-yaml v1.1.0 h1:nP+jp0qPHv2IhUVqmQSzjvqAWcObN0KBkUl2rWBdig0=
github.com/zclconf/go-cty-yaml v1.1.0/go.mod h1:9YLUH4g7lOhVWqUbctnVlZ5KLpg7JAprQNgxSZ1Gyxs=
-go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
-go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
golang.org/x/arch v0.14.0 h1:z9JUEZWr8x4rR0OU6c4/4t6E6jOZ8/QBS2bBYBm4tx4=
golang.org/x/arch v0.14.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
+golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
+golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ=
+golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM=
golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
@@ -176,10 +172,13 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
+golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
diff --git a/backend/internal/auth/auth_test.go b/backend/internal/auth/auth_test.go
deleted file mode 100644
index 109a7fe..0000000
--- a/backend/internal/auth/auth_test.go
+++ /dev/null
@@ -1,27 +0,0 @@
-package auth
-
-import (
- "context"
- "testing"
-)
-
-func TestUserIDKey(t *testing.T) {
- // Test that the UserIDKey constant is defined correctly
- if UserIDKey != "user_id" {
- t.Errorf("UserIDKey = %v, want %v", UserIDKey, "user_id")
- }
-
- // Test context with user ID
- ctx := context.WithValue(context.Background(), UserIDKey, "test-user-123")
- value := ctx.Value(UserIDKey)
- if value != "test-user-123" {
- t.Errorf("Context value = %v, want %v", value, "test-user-123")
- }
-
- // Test context without user ID
- emptyCtx := context.Background()
- emptyValue := emptyCtx.Value(UserIDKey)
- if emptyValue != nil {
- t.Errorf("Empty context value = %v, want nil", emptyValue)
- }
-}
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index faa6bdd..f7eedd8 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -49,7 +49,7 @@ type StorageConfig struct {
Type string `yaml:"type"`
Local LocalStorage `yaml:"local"`
S3 S3Storage `yaml:"s3"`
- Upload types.UploadConfig `yaml:"upload"`
+ Upload UploadConfig `yaml:"upload"`
}
type LocalStorage struct {
@@ -66,6 +66,27 @@ type S3Storage struct {
ProxyS3 bool `yaml:"proxy_s3"`
}
+type UploadConfig struct {
+ Limits struct {
+ Image struct {
+ MaxSize int `yaml:"max_size"`
+ AllowedTypes []string `yaml:"allowed_types"`
+ } `yaml:"image"`
+ Video struct {
+ MaxSize int `yaml:"max_size"`
+ AllowedTypes []string `yaml:"allowed_types"`
+ } `yaml:"video"`
+ Audio struct {
+ MaxSize int `yaml:"max_size"`
+ AllowedTypes []string `yaml:"allowed_types"`
+ } `yaml:"audio"`
+ Document struct {
+ MaxSize int `yaml:"max_size"`
+ AllowedTypes []string `yaml:"allowed_types"`
+ } `yaml:"document"`
+ } `yaml:"limits"`
+}
+
// Load loads configuration from a YAML file
func Load(path string) (*Config, error) {
data, err := os.ReadFile(path)
diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go
deleted file mode 100644
index 800f5b0..0000000
--- a/backend/internal/config/config_test.go
+++ /dev/null
@@ -1,85 +0,0 @@
-package config
-
-import (
- "os"
- "path/filepath"
- "testing"
-)
-
-func TestLoad(t *testing.T) {
- // Create a temporary test config file
- content := []byte(`
-database:
- driver: postgres
- dsn: postgres://user:pass@localhost:5432/dbname
-server:
- port: 8080
- host: localhost
-jwt:
- secret: test-secret
- expiration: 24h
-storage:
- type: local
- local:
- root_dir: /tmp/storage
- upload:
- max_size: 10485760
-logging:
- level: info
- format: json
-`)
-
- tmpDir := t.TempDir()
- configPath := filepath.Join(tmpDir, "config.yaml")
- if err := os.WriteFile(configPath, content, 0644); err != nil {
- t.Fatalf("Failed to write test config: %v", err)
- }
-
- // Test loading config
- cfg, err := Load(configPath)
- if err != nil {
- t.Fatalf("Load() error = %v", err)
- }
-
- // Verify loaded values
- tests := []struct {
- name string
- got interface{}
- want interface{}
- errorMsg string
- }{
- {"Database Driver", cfg.Database.Driver, "postgres", "incorrect database driver"},
- {"Server Port", cfg.Server.Port, 8080, "incorrect server port"},
- {"JWT Secret", cfg.JWT.Secret, "test-secret", "incorrect JWT secret"},
- {"Storage Type", cfg.Storage.Type, "local", "incorrect storage type"},
- {"Logging Level", cfg.Logging.Level, "info", "incorrect logging level"},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if tt.got != tt.want {
- t.Errorf("%s = %v, want %v", tt.name, tt.got, tt.want)
- }
- })
- }
-}
-
-func TestLoadError(t *testing.T) {
- // Test loading non-existent file
- _, err := Load("non-existent-file.yaml")
- if err == nil {
- t.Error("Load() error = nil, want error for non-existent file")
- }
-
- // Test loading invalid YAML
- tmpDir := t.TempDir()
- invalidPath := filepath.Join(tmpDir, "invalid.yaml")
- if err := os.WriteFile(invalidPath, []byte("invalid: }{yaml"), 0644); err != nil {
- t.Fatalf("Failed to write invalid config: %v", err)
- }
-
- _, err = Load(invalidPath)
- if err == nil {
- t.Error("Load() error = nil, want error for invalid YAML")
- }
-}
diff --git a/backend/internal/handler/auth_handler_test.go b/backend/internal/handler/auth_handler_test.go
deleted file mode 100644
index cae7360..0000000
--- a/backend/internal/handler/auth_handler_test.go
+++ /dev/null
@@ -1,312 +0,0 @@
-package handler
-
-import (
- "bytes"
- "encoding/json"
- "fmt"
- "net/http"
- "net/http/httptest"
- "testing"
-
- "tss-rocks-be/ent"
- "tss-rocks-be/internal/config"
- "tss-rocks-be/internal/service/mock"
-
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/suite"
- "go.uber.org/mock/gomock"
- "golang.org/x/crypto/bcrypt"
-)
-
-type AuthHandlerTestSuite struct {
- suite.Suite
- ctrl *gomock.Controller
- service *mock.MockService
- handler *Handler
- router *gin.Engine
-}
-
-func (s *AuthHandlerTestSuite) SetupTest() {
- s.ctrl = gomock.NewController(s.T())
- s.service = mock.NewMockService(s.ctrl)
- s.handler = NewHandler(&config.Config{
- JWT: config.JWTConfig{
- Secret: "test-secret",
- },
- Auth: config.AuthConfig{
- Registration: struct {
- Enabled bool `yaml:"enabled"`
- Message string `yaml:"message"`
- }{
- Enabled: true,
- Message: "Registration is disabled",
- },
- },
- }, s.service)
- s.router = gin.New()
-}
-
-func (s *AuthHandlerTestSuite) TearDownTest() {
- s.ctrl.Finish()
-}
-
-func TestAuthHandlerSuite(t *testing.T) {
- suite.Run(t, new(AuthHandlerTestSuite))
-}
-
-type ErrorResponse struct {
- Error struct {
- Code string `json:"code"`
- Message string `json:"message"`
- } `json:"error"`
-}
-
-func (s *AuthHandlerTestSuite) TestRegister() {
- testCases := []struct {
- name string
- request RegisterRequest
- setupMock func()
- expectedStatus int
- expectedError string
- registration bool
- }{
- {
- name: "成功注册",
- request: RegisterRequest{
- Username: "testuser",
- Email: "test@example.com",
- Password: "password123",
- Role: "contributor",
- },
- setupMock: func() {
- s.service.EXPECT().
- CreateUser(gomock.Any(), "testuser", "test@example.com", "password123", "contributor").
- Return(&ent.User{
- ID: 1,
- Username: "testuser",
- Email: "test@example.com",
- }, nil)
- s.service.EXPECT().
- GetUserRoles(gomock.Any(), 1).
- Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
- },
- expectedStatus: http.StatusCreated,
- registration: true,
- },
- {
- name: "注册功能已禁用",
- request: RegisterRequest{
- Username: "testuser",
- Email: "test@example.com",
- Password: "password123",
- Role: "contributor",
- },
- setupMock: func() {},
- expectedStatus: http.StatusForbidden,
- expectedError: "Registration is disabled",
- registration: false,
- },
- {
- name: "无效的邮箱格式",
- request: RegisterRequest{
- Username: "testuser",
- Email: "invalid-email",
- Password: "password123",
- Role: "contributor",
- },
- setupMock: func() {},
- expectedStatus: http.StatusBadRequest,
- expectedError: "Key: 'RegisterRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag",
- registration: true,
- },
- {
- name: "密码太短",
- request: RegisterRequest{
- Username: "testuser",
- Email: "test@example.com",
- Password: "short",
- Role: "contributor",
- },
- setupMock: func() {},
- expectedStatus: http.StatusBadRequest,
- expectedError: "Key: 'RegisterRequest.Password' Error:Field validation for 'Password' failed on the 'min' tag",
- registration: true,
- },
- {
- name: "无效的角色",
- request: RegisterRequest{
- Username: "testuser",
- Email: "test@example.com",
- Password: "password123",
- Role: "invalid-role",
- },
- setupMock: func() {},
- expectedStatus: http.StatusBadRequest,
- expectedError: "Key: 'RegisterRequest.Role' Error:Field validation for 'Role' failed on the 'oneof' tag",
- registration: true,
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- // 设置注册功能状态
- s.handler.cfg.Auth.Registration.Enabled = tc.registration
-
- // 设置 mock
- tc.setupMock()
-
- // 创建请求
- reqBody, _ := json.Marshal(tc.request)
- req, _ := http.NewRequest(http.MethodPost, "/register", bytes.NewBuffer(reqBody))
- req.Header.Set("Content-Type", "application/json")
- w := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(w)
- c.Request = req
-
- // 执行请求
- s.handler.Register(c)
-
- // 验证响应
- s.Equal(tc.expectedStatus, w.Code)
- if tc.expectedError != "" {
- var response ErrorResponse
- err := json.Unmarshal(w.Body.Bytes(), &response)
- s.NoError(err)
- s.Contains(response.Error.Message, tc.expectedError)
- } else {
- var response AuthResponse
- err := json.Unmarshal(w.Body.Bytes(), &response)
- s.NoError(err)
- s.NotEmpty(response.Token)
- }
- })
- }
-}
-
-func (s *AuthHandlerTestSuite) TestLogin() {
- hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
-
- testCases := []struct {
- name string
- request LoginRequest
- setupMock func()
- expectedStatus int
- expectedError string
- }{
- {
- name: "成功登录",
- request: LoginRequest{
- Username: "testuser",
- Password: "password123",
- },
- setupMock: func() {
- s.service.EXPECT().
- GetUserByUsername(gomock.Any(), "testuser").
- Return(&ent.User{
- ID: 1,
- Username: "testuser",
- PasswordHash: string(hashedPassword),
- }, nil)
- s.service.EXPECT().
- GetUserRoles(gomock.Any(), 1).
- Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
- },
- expectedStatus: http.StatusOK,
- },
- {
- name: "无效的用户名",
- request: LoginRequest{
- Username: "te",
- Password: "password123",
- },
- setupMock: func() {},
- expectedStatus: http.StatusBadRequest,
- expectedError: "Key: 'LoginRequest.Username' Error:Field validation for 'Username' failed on the 'min' tag",
- },
- {
- name: "用户不存在",
- request: LoginRequest{
- Username: "nonexistent",
- Password: "password123",
- },
- setupMock: func() {
- s.service.EXPECT().
- GetUserByUsername(gomock.Any(), "nonexistent").
- Return(nil, fmt.Errorf("user not found"))
- },
- expectedStatus: http.StatusUnauthorized,
- expectedError: "Invalid username or password",
- },
- {
- name: "密码错误",
- request: LoginRequest{
- Username: "testuser",
- Password: "wrongpassword",
- },
- setupMock: func() {
- s.service.EXPECT().
- GetUserByUsername(gomock.Any(), "testuser").
- Return(&ent.User{
- ID: 1,
- Username: "testuser",
- PasswordHash: string(hashedPassword),
- }, nil)
- },
- expectedStatus: http.StatusUnauthorized,
- expectedError: "Invalid username or password",
- },
- {
- name: "获取用户角色失败",
- request: LoginRequest{
- Username: "testuser",
- Password: "password123",
- },
- setupMock: func() {
- s.service.EXPECT().
- GetUserByUsername(gomock.Any(), "testuser").
- Return(&ent.User{
- ID: 1,
- Username: "testuser",
- PasswordHash: string(hashedPassword),
- }, nil)
- s.service.EXPECT().
- GetUserRoles(gomock.Any(), 1).
- Return(nil, fmt.Errorf("failed to get roles"))
- },
- expectedStatus: http.StatusInternalServerError,
- expectedError: "Failed to get user roles",
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- // 设置 mock
- tc.setupMock()
-
- // 创建请求
- reqBody, _ := json.Marshal(tc.request)
- req, _ := http.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(reqBody))
- req.Header.Set("Content-Type", "application/json")
- w := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(w)
- c.Request = req
-
- // 执行请求
- s.handler.Login(c)
-
- // 验证响应
- s.Equal(tc.expectedStatus, w.Code)
- if tc.expectedError != "" {
- var response ErrorResponse
- err := json.Unmarshal(w.Body.Bytes(), &response)
- s.NoError(err)
- s.Contains(response.Error.Message, tc.expectedError)
- } else {
- var response AuthResponse
- err := json.Unmarshal(w.Body.Bytes(), &response)
- s.NoError(err)
- s.NotEmpty(response.Token)
- }
- })
- }
-}
diff --git a/backend/internal/handler/category_handler_test.go b/backend/internal/handler/category_handler_test.go
deleted file mode 100644
index 95c4bd9..0000000
--- a/backend/internal/handler/category_handler_test.go
+++ /dev/null
@@ -1,481 +0,0 @@
-package handler
-
-import (
- "bytes"
- "encoding/json"
- "net/http"
- "net/http/httptest"
- "testing"
- "tss-rocks-be/ent"
- "tss-rocks-be/ent/categorycontent"
- "tss-rocks-be/internal/config"
- "tss-rocks-be/internal/service"
- "tss-rocks-be/internal/service/mock"
- "tss-rocks-be/internal/types"
-
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/suite"
- "go.uber.org/mock/gomock"
-
- "errors"
-)
-
-// Custom assertion function for comparing categories
-func assertCategoryEqual(t assert.TestingT, expected, actual *ent.Category) bool {
- if expected == nil && actual == nil {
- return true
- }
- if expected == nil || actual == nil {
- return assert.Fail(t, "One category is nil while the other is not")
- }
-
- // Compare only relevant fields, ignoring time fields
- return assert.Equal(t, expected.ID, actual.ID) &&
- assert.Equal(t, expected.Edges.Contents, actual.Edges.Contents)
-}
-
-// Custom assertion function for comparing category slices
-func assertCategorySliceEqual(t assert.TestingT, expected, actual []*ent.Category) bool {
- if len(expected) != len(actual) {
- return assert.Fail(t, "Category slice lengths do not match")
- }
-
- for i := range expected {
- if !assertCategoryEqual(t, expected[i], actual[i]) {
- return false
- }
- }
- return true
-}
-
-type CategoryHandlerTestSuite struct {
- suite.Suite
- ctrl *gomock.Controller
- service *mock.MockService
- handler *Handler
- router *gin.Engine
-}
-
-func (s *CategoryHandlerTestSuite) SetupTest() {
- s.ctrl = gomock.NewController(s.T())
- s.service = mock.NewMockService(s.ctrl)
- cfg := &config.Config{
- JWT: config.JWTConfig{
- Secret: "test-secret",
- },
- }
- s.handler = NewHandler(cfg, s.service)
-
- // Setup Gin router
- gin.SetMode(gin.TestMode)
- s.router = gin.New()
-
- // Setup mock for GetTokenBlacklist
- tokenBlacklist := &service.TokenBlacklist{}
- s.service.EXPECT().
- GetTokenBlacklist().
- Return(tokenBlacklist).
- AnyTimes()
-
- s.handler.RegisterRoutes(s.router)
-}
-
-func (s *CategoryHandlerTestSuite) TearDownTest() {
- s.ctrl.Finish()
-}
-
-func TestCategoryHandlerSuite(t *testing.T) {
- suite.Run(t, new(CategoryHandlerTestSuite))
-}
-
-// Test cases for ListCategories
-func (s *CategoryHandlerTestSuite) TestListCategories() {
- testCases := []struct {
- name string
- langCode string
- setupMock func()
- expectedStatus int
- expectedBody interface{}
- }{
- {
- name: "Success with default language",
- langCode: "",
- setupMock: func() {
- s.service.EXPECT().
- ListCategories(gomock.Any(), gomock.Eq("en")).
- Return([]*ent.Category{
- {
- ID: 1,
- Edges: ent.CategoryEdges{
- Contents: []*ent.CategoryContent{
- {
- LanguageCode: categorycontent.LanguageCode("en"),
- Name: "Test Category",
- Description: "Test Description",
- Slug: "test-category",
- },
- },
- },
- },
- }, nil)
- },
- expectedStatus: http.StatusOK,
- expectedBody: []*ent.Category{
- {
- ID: 1,
- Edges: ent.CategoryEdges{
- Contents: []*ent.CategoryContent{
- {
- LanguageCode: categorycontent.LanguageCode("en"),
- Name: "Test Category",
- Description: "Test Description",
- Slug: "test-category",
- },
- },
- },
- },
- },
- },
- {
- name: "Success with specific language",
- langCode: "zh",
- setupMock: func() {
- s.service.EXPECT().
- ListCategories(gomock.Any(), gomock.Eq("zh")).
- Return([]*ent.Category{
- {
- ID: 1,
- Edges: ent.CategoryEdges{
- Contents: []*ent.CategoryContent{
- {
- LanguageCode: categorycontent.LanguageCode("zh"),
- Name: "测试分类",
- Description: "测试描述",
- Slug: "test-category",
- },
- },
- },
- },
- }, nil)
- },
- expectedStatus: http.StatusOK,
- expectedBody: []*ent.Category{
- {
- ID: 1,
- Edges: ent.CategoryEdges{
- Contents: []*ent.CategoryContent{
- {
- LanguageCode: categorycontent.LanguageCode("zh"),
- Name: "测试分类",
- Description: "测试描述",
- Slug: "test-category",
- },
- },
- },
- },
- },
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- // Setup mock
- tc.setupMock()
-
- // Create request
- url := "/api/v1/categories"
- if tc.langCode != "" {
- url += "?lang=" + tc.langCode
- }
- req := httptest.NewRequest(http.MethodGet, url, nil)
- w := httptest.NewRecorder()
-
- // Perform request
- s.router.ServeHTTP(w, req)
-
- // Assert response
- assert.Equal(s.T(), tc.expectedStatus, w.Code)
- if tc.expectedBody != nil {
- var response []*ent.Category
- err := json.Unmarshal(w.Body.Bytes(), &response)
- assert.NoError(s.T(), err)
- assertCategorySliceEqual(s.T(), tc.expectedBody.([]*ent.Category), response)
- }
- })
- }
-}
-
-// Test cases for GetCategory
-func (s *CategoryHandlerTestSuite) TestGetCategory() {
- testCases := []struct {
- name string
- langCode string
- slug string
- setupMock func()
- expectedStatus int
- expectedBody interface{}
- }{
- {
- name: "Success",
- langCode: "en",
- slug: "test-category",
- setupMock: func() {
- s.service.EXPECT().
- GetCategoryBySlug(gomock.Any(), gomock.Eq("en"), gomock.Eq("test-category")).
- Return(&ent.Category{
- ID: 1,
- Edges: ent.CategoryEdges{
- Contents: []*ent.CategoryContent{
- {
- LanguageCode: categorycontent.LanguageCode("en"),
- Name: "Test Category",
- Description: "Test Description",
- Slug: "test-category",
- },
- },
- },
- }, nil)
- },
- expectedStatus: http.StatusOK,
- expectedBody: &ent.Category{
- ID: 1,
- Edges: ent.CategoryEdges{
- Contents: []*ent.CategoryContent{
- {
- LanguageCode: categorycontent.LanguageCode("en"),
- Name: "Test Category",
- Description: "Test Description",
- Slug: "test-category",
- },
- },
- },
- },
- },
- {
- name: "Not Found",
- langCode: "en",
- slug: "non-existent",
- setupMock: func() {
- s.service.EXPECT().
- GetCategoryBySlug(gomock.Any(), gomock.Eq("en"), gomock.Eq("non-existent")).
- Return(nil, types.ErrNotFound)
- },
- expectedStatus: http.StatusNotFound,
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- // Setup mock
- tc.setupMock()
-
- // Create request
- url := "/api/v1/categories/" + tc.slug
- if tc.langCode != "" {
- url += "?lang=" + tc.langCode
- }
- req := httptest.NewRequest(http.MethodGet, url, nil)
- w := httptest.NewRecorder()
-
- // Perform request
- s.router.ServeHTTP(w, req)
-
- // Assert response
- assert.Equal(s.T(), tc.expectedStatus, w.Code)
- if tc.expectedBody != nil {
- var response ent.Category
- err := json.Unmarshal(w.Body.Bytes(), &response)
- assert.NoError(s.T(), err)
- assertCategoryEqual(s.T(), tc.expectedBody.(*ent.Category), &response)
- }
- })
- }
-}
-
-// Test cases for AddCategoryContent
-func (s *CategoryHandlerTestSuite) TestAddCategoryContent() {
- var description = "Test Description"
- testCases := []struct {
- name string
- categoryID string
- requestBody interface{}
- setupMock func()
- expectedStatus int
- expectedBody interface{}
- }{
- {
- name: "Success",
- categoryID: "1",
- requestBody: AddCategoryContentRequest{
- LanguageCode: "en",
- Name: "Test Category",
- Description: &description,
- Slug: "test-category",
- },
- setupMock: func() {
- s.service.EXPECT().
- AddCategoryContent(
- gomock.Any(),
- 1,
- "en",
- "Test Category",
- description,
- "test-category",
- ).
- Return(&ent.CategoryContent{
- LanguageCode: categorycontent.LanguageCode("en"),
- Name: "Test Category",
- Description: description,
- Slug: "test-category",
- }, nil)
- },
- expectedStatus: http.StatusCreated,
- expectedBody: &ent.CategoryContent{
- LanguageCode: categorycontent.LanguageCode("en"),
- Name: "Test Category",
- Description: description,
- Slug: "test-category",
- },
- },
- {
- name: "Invalid JSON",
- categoryID: "1",
- requestBody: "invalid json",
- setupMock: func() {},
- expectedStatus: http.StatusBadRequest,
- },
- {
- name: "Invalid Category ID",
- categoryID: "invalid",
- requestBody: AddCategoryContentRequest{
- LanguageCode: "en",
- Name: "Test Category",
- Description: &description,
- Slug: "test-category",
- },
- setupMock: func() {},
- expectedStatus: http.StatusBadRequest,
- },
- {
- name: "Service Error",
- categoryID: "1",
- requestBody: AddCategoryContentRequest{
- LanguageCode: "en",
- Name: "Test Category",
- Description: &description,
- Slug: "test-category",
- },
- setupMock: func() {
- s.service.EXPECT().
- AddCategoryContent(
- gomock.Any(),
- 1,
- "en",
- "Test Category",
- description,
- "test-category",
- ).
- Return(nil, errors.New("service error"))
- },
- expectedStatus: http.StatusInternalServerError,
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- // Setup mock
- tc.setupMock()
-
- // Create request
- var body []byte
- var err error
- if str, ok := tc.requestBody.(string); ok {
- body = []byte(str)
- } else {
- body, err = json.Marshal(tc.requestBody)
- s.NoError(err)
- }
-
- req := httptest.NewRequest(http.MethodPost, "/api/v1/categories/"+tc.categoryID+"/contents", bytes.NewReader(body))
- req.Header.Set("Content-Type", "application/json")
- w := httptest.NewRecorder()
-
- // Perform request
- s.router.ServeHTTP(w, req)
-
- // Assert response
- s.Equal(tc.expectedStatus, w.Code)
- if tc.expectedBody != nil {
- var response ent.CategoryContent
- err := json.Unmarshal(w.Body.Bytes(), &response)
- s.NoError(err)
- s.Equal(tc.expectedBody, &response)
- }
- })
- }
-}
-
-// Test cases for CreateCategory
-func (s *CategoryHandlerTestSuite) TestCreateCategory() {
- testCases := []struct {
- name string
- setupMock func()
- expectedStatus int
- expectedError string
- }{
- {
- name: "成功创建分类",
- setupMock: func() {
- category := &ent.Category{
- ID: 1,
- }
- s.service.EXPECT().
- CreateCategory(gomock.Any()).
- Return(category, nil)
- },
- expectedStatus: http.StatusCreated,
- },
- {
- name: "创建分类失败",
- setupMock: func() {
- s.service.EXPECT().
- CreateCategory(gomock.Any()).
- Return(nil, errors.New("failed to create category"))
- },
- expectedStatus: http.StatusInternalServerError,
- expectedError: "Failed to create category",
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- // 设置 mock
- tc.setupMock()
-
- // 创建请求
- req, _ := http.NewRequest(http.MethodPost, "/categories", nil)
- w := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(w)
- c.Request = req
-
- // 执行请求
- s.handler.CreateCategory(c)
-
- // 验证响应
- s.Equal(tc.expectedStatus, w.Code)
- if tc.expectedError != "" {
- var response map[string]string
- err := json.Unmarshal(w.Body.Bytes(), &response)
- s.NoError(err)
- s.Equal(tc.expectedError, response["error"])
- } else {
- var response *ent.Category
- err := json.Unmarshal(w.Body.Bytes(), &response)
- s.NoError(err)
- s.NotNil(response)
- s.Equal(1, response.ID)
- }
- })
- }
-}
diff --git a/backend/internal/handler/contributor_handler_test.go b/backend/internal/handler/contributor_handler_test.go
deleted file mode 100644
index 50f36da..0000000
--- a/backend/internal/handler/contributor_handler_test.go
+++ /dev/null
@@ -1,456 +0,0 @@
-package handler
-
-import (
- "bytes"
- "encoding/json"
- "net/http"
- "net/http/httptest"
- "testing"
- "time"
- "tss-rocks-be/ent"
- "tss-rocks-be/internal/config"
- "tss-rocks-be/internal/service"
- "tss-rocks-be/internal/service/mock"
-
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/suite"
- "go.uber.org/mock/gomock"
-
- "errors"
-)
-
-type ContributorHandlerTestSuite struct {
- suite.Suite
- ctrl *gomock.Controller
- service *mock.MockService
- handler *Handler
- router *gin.Engine
-}
-
-func (s *ContributorHandlerTestSuite) SetupTest() {
- s.ctrl = gomock.NewController(s.T())
- s.service = mock.NewMockService(s.ctrl)
- cfg := &config.Config{
- JWT: config.JWTConfig{
- Secret: "test-secret",
- },
- }
- s.handler = NewHandler(cfg, s.service)
-
- // Setup Gin router
- gin.SetMode(gin.TestMode)
- s.router = gin.New()
-
- // Setup mock for GetTokenBlacklist
- tokenBlacklist := &service.TokenBlacklist{}
- s.service.EXPECT().
- GetTokenBlacklist().
- Return(tokenBlacklist).
- AnyTimes()
-
- s.handler.RegisterRoutes(s.router)
-}
-
-func (s *ContributorHandlerTestSuite) TearDownTest() {
- s.ctrl.Finish()
-}
-
-func TestContributorHandlerSuite(t *testing.T) {
- suite.Run(t, new(ContributorHandlerTestSuite))
-}
-
-func (s *ContributorHandlerTestSuite) TestListContributors() {
- testCases := []struct {
- name string
- setupMock func()
- expectedStatus int
- expectedBody interface{}
- }{
- {
- name: "Success",
- setupMock: func() {
- s.service.EXPECT().
- ListContributors(gomock.Any()).
- Return([]*ent.Contributor{
- {
- ID: 1,
- Name: "John Doe",
- Edges: ent.ContributorEdges{
- SocialLinks: []*ent.ContributorSocialLink{
- {
- Type: "github",
- Value: "https://github.com/johndoe",
- Edges: ent.ContributorSocialLinkEdges{},
- },
- },
- },
- CreatedAt: time.Time{},
- UpdatedAt: time.Time{},
- },
- {
- ID: 2,
- Name: "Jane Smith",
- Edges: ent.ContributorEdges{
- SocialLinks: []*ent.ContributorSocialLink{}, // Ensure empty SocialLinks array is present
- },
- CreatedAt: time.Time{},
- UpdatedAt: time.Time{},
- },
- }, nil)
- },
- expectedStatus: http.StatusOK,
- expectedBody: []gin.H{
- {
- "id": 1,
- "name": "John Doe",
- "created_at": time.Time{},
- "updated_at": time.Time{},
- "edges": gin.H{
- "social_links": []gin.H{
- {
- "type": "github",
- "value": "https://github.com/johndoe",
- "edges": gin.H{},
- },
- },
- },
- },
- {
- "id": 2,
- "name": "Jane Smith",
- "created_at": time.Time{},
- "updated_at": time.Time{},
- "edges": gin.H{
- "social_links": []gin.H{}, // Ensure empty SocialLinks array is present
- },
- },
- },
- },
- {
- name: "Service error",
- setupMock: func() {
- s.service.EXPECT().
- ListContributors(gomock.Any()).
- Return(nil, errors.New("service error"))
- },
- expectedStatus: http.StatusInternalServerError,
- expectedBody: gin.H{"error": "Failed to list contributors"},
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- tc.setupMock()
-
- req := httptest.NewRequest(http.MethodGet, "/api/v1/contributors", nil)
- w := httptest.NewRecorder()
- s.router.ServeHTTP(w, req)
-
- s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
-
- if tc.expectedBody != nil {
- expectedJSON, err := json.Marshal(tc.expectedBody)
- s.NoError(err, "Failed to marshal expected body")
- s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
- }
- })
- }
-}
-
-func (s *ContributorHandlerTestSuite) TestGetContributor() {
- testCases := []struct {
- name string
- id string
- setupMock func()
- expectedStatus int
- expectedBody interface{}
- }{
- {
- name: "Success",
- id: "1",
- setupMock: func() {
- s.service.EXPECT().
- GetContributorByID(gomock.Any(), 1).
- Return(&ent.Contributor{
- ID: 1,
- Name: "John Doe",
- Edges: ent.ContributorEdges{
- SocialLinks: []*ent.ContributorSocialLink{
- {
- Type: "github",
- Value: "https://github.com/johndoe",
- Edges: ent.ContributorSocialLinkEdges{},
- },
- },
- },
- CreatedAt: time.Time{},
- UpdatedAt: time.Time{},
- }, nil)
- },
- expectedStatus: http.StatusOK,
- expectedBody: gin.H{
- "id": 1,
- "name": "John Doe",
- "created_at": time.Time{},
- "updated_at": time.Time{},
- "edges": gin.H{
- "social_links": []gin.H{
- {
- "type": "github",
- "value": "https://github.com/johndoe",
- "edges": gin.H{},
- },
- },
- },
- },
- },
- {
- name: "Invalid ID",
- id: "invalid",
- setupMock: func() {},
- expectedStatus: http.StatusBadRequest,
- expectedBody: gin.H{"error": "Invalid contributor ID"},
- },
- {
- name: "Service error",
- id: "1",
- setupMock: func() {
- s.service.EXPECT().
- GetContributorByID(gomock.Any(), 1).
- Return(nil, errors.New("service error"))
- },
- expectedStatus: http.StatusInternalServerError,
- expectedBody: gin.H{"error": "Failed to get contributor"},
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- tc.setupMock()
-
- req := httptest.NewRequest(http.MethodGet, "/api/v1/contributors/"+tc.id, nil)
- w := httptest.NewRecorder()
- s.router.ServeHTTP(w, req)
-
- s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
-
- if tc.expectedBody != nil {
- expectedJSON, err := json.Marshal(tc.expectedBody)
- s.NoError(err, "Failed to marshal expected body")
- s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
- }
- })
- }
-}
-
-func (s *ContributorHandlerTestSuite) TestCreateContributor() {
- testCases := []struct {
- name string
- body interface{}
- setupMock func()
- expectedStatus int
- expectedBody interface{}
- }{
- {
- name: "Success",
- body: CreateContributorRequest{
- Name: "John Doe",
- },
- setupMock: func() {
- name := "John Doe"
- s.service.EXPECT().
- CreateContributor(
- gomock.Any(),
- name,
- nil,
- nil,
- ).
- Return(&ent.Contributor{
- ID: 1,
- Name: name,
- CreatedAt: time.Time{},
- UpdatedAt: time.Time{},
- }, nil)
- },
- expectedStatus: http.StatusCreated,
- expectedBody: gin.H{
- "id": 1,
- "name": "John Doe",
- "created_at": time.Time{},
- "updated_at": time.Time{},
- "edges": gin.H{},
- },
- },
- {
- name: "Invalid request body",
- body: map[string]interface{}{
- "name": "", // Empty name is not allowed
- },
- setupMock: func() {},
- expectedStatus: http.StatusBadRequest,
- expectedBody: gin.H{"error": "Key: 'CreateContributorRequest.Name' Error:Field validation for 'Name' failed on the 'required' tag"},
- },
- {
- name: "Service error",
- body: CreateContributorRequest{
- Name: "John Doe",
- },
- setupMock: func() {
- name := "John Doe"
- s.service.EXPECT().
- CreateContributor(
- gomock.Any(),
- name,
- nil,
- nil,
- ).
- Return(nil, errors.New("service error"))
- },
- expectedStatus: http.StatusInternalServerError,
- expectedBody: gin.H{"error": "Failed to create contributor"},
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- tc.setupMock()
-
- body, err := json.Marshal(tc.body)
- s.NoError(err, "Failed to marshal request body")
-
- req := httptest.NewRequest(http.MethodPost, "/api/v1/contributors", bytes.NewReader(body))
- req.Header.Set("Content-Type", "application/json")
- w := httptest.NewRecorder()
- s.router.ServeHTTP(w, req)
-
- s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
-
- if tc.expectedBody != nil {
- expectedJSON, err := json.Marshal(tc.expectedBody)
- s.NoError(err, "Failed to marshal expected body")
- s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
- }
- })
- }
-}
-
-func (s *ContributorHandlerTestSuite) TestAddContributorSocialLink() {
- testCases := []struct {
- name string
- id string
- body interface{}
- setupMock func()
- expectedStatus int
- expectedBody interface{}
- }{
- {
- name: "Success",
- id: "1",
- body: func() AddContributorSocialLinkRequest {
- name := "johndoe"
- return AddContributorSocialLinkRequest{
- Type: "github",
- Name: &name,
- Value: "https://github.com/johndoe",
- }
- }(),
- setupMock: func() {
- name := "johndoe"
- s.service.EXPECT().
- AddContributorSocialLink(
- gomock.Any(),
- 1,
- "github",
- name,
- "https://github.com/johndoe",
- ).
- Return(&ent.ContributorSocialLink{
- Type: "github",
- Name: name,
- Value: "https://github.com/johndoe",
- Edges: ent.ContributorSocialLinkEdges{},
- }, nil)
- },
- expectedStatus: http.StatusCreated,
- expectedBody: gin.H{
- "type": "github",
- "name": "johndoe",
- "value": "https://github.com/johndoe",
- "edges": gin.H{},
- },
- },
- {
- name: "Invalid contributor ID",
- id: "invalid",
- body: func() AddContributorSocialLinkRequest {
- name := "johndoe"
- return AddContributorSocialLinkRequest{
- Type: "github",
- Name: &name,
- Value: "https://github.com/johndoe",
- }
- }(),
- setupMock: func() {},
- expectedStatus: http.StatusBadRequest,
- expectedBody: gin.H{"error": "Invalid contributor ID"},
- },
- {
- name: "Invalid request body",
- id: "1",
- body: map[string]interface{}{
- "type": "", // Empty type is not allowed
- },
- setupMock: func() {},
- expectedStatus: http.StatusBadRequest,
- expectedBody: gin.H{"error": "Key: 'AddContributorSocialLinkRequest.Type' Error:Field validation for 'Type' failed on the 'required' tag\nKey: 'AddContributorSocialLinkRequest.Value' Error:Field validation for 'Value' failed on the 'required' tag"},
- },
- {
- name: "Service error",
- id: "1",
- body: func() AddContributorSocialLinkRequest {
- name := "johndoe"
- return AddContributorSocialLinkRequest{
- Type: "github",
- Name: &name,
- Value: "https://github.com/johndoe",
- }
- }(),
- setupMock: func() {
- name := "johndoe"
- s.service.EXPECT().
- AddContributorSocialLink(
- gomock.Any(),
- 1,
- "github",
- name,
- "https://github.com/johndoe",
- ).
- Return(nil, errors.New("service error"))
- },
- expectedStatus: http.StatusInternalServerError,
- expectedBody: gin.H{"error": "Failed to add contributor social link"},
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- tc.setupMock()
-
- body, err := json.Marshal(tc.body)
- s.NoError(err, "Failed to marshal request body")
-
- req := httptest.NewRequest(http.MethodPost, "/api/v1/contributors/"+tc.id+"/social-links", bytes.NewReader(body))
- req.Header.Set("Content-Type", "application/json")
- w := httptest.NewRecorder()
- s.router.ServeHTTP(w, req)
-
- s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
-
- if tc.expectedBody != nil {
- expectedJSON, err := json.Marshal(tc.expectedBody)
- s.NoError(err, "Failed to marshal expected body")
- s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
- }
- })
- }
-}
diff --git a/backend/internal/handler/daily_handler_test.go b/backend/internal/handler/daily_handler_test.go
deleted file mode 100644
index 24c94cc..0000000
--- a/backend/internal/handler/daily_handler_test.go
+++ /dev/null
@@ -1,532 +0,0 @@
-package handler
-
-import (
- "bytes"
- "encoding/json"
- "net/http"
- "net/http/httptest"
- "testing"
- "tss-rocks-be/ent"
- "tss-rocks-be/internal/config"
- "tss-rocks-be/internal/service"
- "tss-rocks-be/internal/service/mock"
-
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/suite"
- "go.uber.org/mock/gomock"
-
- "errors"
- "strings"
-)
-
-type DailyHandlerTestSuite struct {
- suite.Suite
- ctrl *gomock.Controller
- service *mock.MockService
- handler *Handler
- router *gin.Engine
-}
-
-func (s *DailyHandlerTestSuite) SetupTest() {
- s.ctrl = gomock.NewController(s.T())
- s.service = mock.NewMockService(s.ctrl)
- cfg := &config.Config{
- JWT: config.JWTConfig{
- Secret: "test-secret",
- },
- }
- s.handler = NewHandler(cfg, s.service)
-
- // Setup Gin router
- gin.SetMode(gin.TestMode)
- s.router = gin.New()
-
- // Setup mock for GetTokenBlacklist
- tokenBlacklist := &service.TokenBlacklist{}
- s.service.EXPECT().
- GetTokenBlacklist().
- Return(tokenBlacklist).
- AnyTimes()
-
- s.handler.RegisterRoutes(s.router)
-}
-
-func (s *DailyHandlerTestSuite) TearDownTest() {
- s.ctrl.Finish()
-}
-
-func TestDailyHandlerSuite(t *testing.T) {
- suite.Run(t, new(DailyHandlerTestSuite))
-}
-
-func (s *DailyHandlerTestSuite) TestListDailies() {
- testCases := []struct {
- name string
- langCode string
- categoryID string
- limit string
- offset string
- setupMock func()
- expectedStatus int
- expectedBody interface{}
- }{
- {
- name: "Success with default language",
- langCode: "",
- setupMock: func() {
- s.service.EXPECT().
- ListDailies(gomock.Any(), "en", nil, 10, 0).
- Return([]*ent.Daily{
- {
- ID: "daily1",
- ImageURL: "https://example.com/image1.jpg",
- Edges: ent.DailyEdges{
- Category: &ent.Category{ID: 1},
- Contents: []*ent.DailyContent{
- {
- LanguageCode: "en",
- Quote: "Test Quote 1",
- },
- },
- },
- },
- }, nil)
- },
- expectedStatus: http.StatusOK,
- expectedBody: []*ent.Daily{
- {
- ID: "daily1",
- ImageURL: "https://example.com/image1.jpg",
- Edges: ent.DailyEdges{
- Category: &ent.Category{ID: 1},
- Contents: []*ent.DailyContent{
- {
- LanguageCode: "en",
- Quote: "Test Quote 1",
- },
- },
- },
- },
- },
- },
- {
- name: "Success with specific language",
- langCode: "zh",
- setupMock: func() {
- s.service.EXPECT().
- ListDailies(gomock.Any(), "zh", nil, 10, 0).
- Return([]*ent.Daily{
- {
- ID: "daily1",
- ImageURL: "https://example.com/image1.jpg",
- Edges: ent.DailyEdges{
- Category: &ent.Category{ID: 1},
- Contents: []*ent.DailyContent{
- {
- LanguageCode: "zh",
- Quote: "测试语录1",
- },
- },
- },
- },
- }, nil)
- },
- expectedStatus: http.StatusOK,
- expectedBody: []*ent.Daily{
- {
- ID: "daily1",
- ImageURL: "https://example.com/image1.jpg",
- Edges: ent.DailyEdges{
- Category: &ent.Category{ID: 1},
- Contents: []*ent.DailyContent{
- {
- LanguageCode: "zh",
- Quote: "测试语录1",
- },
- },
- },
- },
- },
- },
- {
- name: "Success with category filter",
- categoryID: "1",
- setupMock: func() {
- categoryID := 1
- s.service.EXPECT().
- ListDailies(gomock.Any(), "en", &categoryID, 10, 0).
- Return([]*ent.Daily{
- {
- ID: "daily1",
- ImageURL: "https://example.com/image1.jpg",
- Edges: ent.DailyEdges{
- Category: &ent.Category{ID: 1},
- Contents: []*ent.DailyContent{
- {
- LanguageCode: "en",
- Quote: "Test Quote 1",
- },
- },
- },
- },
- }, nil)
- },
- expectedStatus: http.StatusOK,
- expectedBody: []*ent.Daily{
- {
- ID: "daily1",
- ImageURL: "https://example.com/image1.jpg",
- Edges: ent.DailyEdges{
- Category: &ent.Category{ID: 1},
- Contents: []*ent.DailyContent{
- {
- LanguageCode: "en",
- Quote: "Test Quote 1",
- },
- },
- },
- },
- },
- },
- {
- name: "Success with pagination",
- limit: "2",
- offset: "1",
- setupMock: func() {
- s.service.EXPECT().
- ListDailies(gomock.Any(), "en", nil, 2, 1).
- Return([]*ent.Daily{
- {
- ID: "daily2",
- ImageURL: "https://example.com/image2.jpg",
- Edges: ent.DailyEdges{
- Category: &ent.Category{ID: 1},
- Contents: []*ent.DailyContent{
- {
- LanguageCode: "en",
- Quote: "Test Quote 2",
- },
- },
- },
- },
- }, nil)
- },
- expectedStatus: http.StatusOK,
- expectedBody: []*ent.Daily{
- {
- ID: "daily2",
- ImageURL: "https://example.com/image2.jpg",
- Edges: ent.DailyEdges{
- Category: &ent.Category{ID: 1},
- Contents: []*ent.DailyContent{
- {
- LanguageCode: "en",
- Quote: "Test Quote 2",
- },
- },
- },
- },
- },
- },
- {
- name: "Service Error",
- setupMock: func() {
- s.service.EXPECT().
- ListDailies(gomock.Any(), "en", nil, 10, 0).
- Return(nil, errors.New("service error"))
- },
- expectedStatus: http.StatusInternalServerError,
- expectedBody: gin.H{"error": "Failed to list dailies"},
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- tc.setupMock()
-
- url := "/api/v1/dailies"
- if tc.langCode != "" {
- url += "?lang=" + tc.langCode
- }
- if tc.categoryID != "" {
- if strings.Contains(url, "?") {
- url += "&"
- } else {
- url += "?"
- }
- url += "category_id=" + tc.categoryID
- }
- if tc.limit != "" {
- if strings.Contains(url, "?") {
- url += "&"
- } else {
- url += "?"
- }
- url += "limit=" + tc.limit
- }
- if tc.offset != "" {
- if strings.Contains(url, "?") {
- url += "&"
- } else {
- url += "?"
- }
- url += "offset=" + tc.offset
- }
-
- req := httptest.NewRequest(http.MethodGet, url, nil)
- w := httptest.NewRecorder()
- s.router.ServeHTTP(w, req)
-
- s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
-
- if tc.expectedBody != nil {
- expectedJSON, err := json.Marshal(tc.expectedBody)
- s.NoError(err, "Failed to marshal expected body")
- s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
- }
- })
- }
-}
-
-func (s *DailyHandlerTestSuite) TestGetDaily() {
- testCases := []struct {
- name string
- id string
- setupMock func()
- expectedStatus int
- expectedBody interface{}
- }{
- {
- name: "Success",
- id: "daily1",
- setupMock: func() {
- s.service.EXPECT().
- GetDailyByID(gomock.Any(), "daily1").
- Return(&ent.Daily{
- ID: "daily1",
- ImageURL: "https://example.com/image1.jpg",
- Edges: ent.DailyEdges{
- Category: &ent.Category{ID: 1},
- Contents: []*ent.DailyContent{
- {
- LanguageCode: "en",
- Quote: "Test Quote 1",
- },
- },
- },
- }, nil)
- },
- expectedStatus: http.StatusOK,
- expectedBody: &ent.Daily{
- ID: "daily1",
- ImageURL: "https://example.com/image1.jpg",
- Edges: ent.DailyEdges{
- Category: &ent.Category{ID: 1},
- Contents: []*ent.DailyContent{
- {
- LanguageCode: "en",
- Quote: "Test Quote 1",
- },
- },
- },
- },
- },
- {
- name: "Service error",
- id: "daily1",
- setupMock: func() {
- s.service.EXPECT().
- GetDailyByID(gomock.Any(), "daily1").
- Return(nil, errors.New("service error"))
- },
- expectedStatus: http.StatusInternalServerError,
- expectedBody: gin.H{"error": "Failed to get daily"},
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- tc.setupMock()
-
- req := httptest.NewRequest(http.MethodGet, "/api/v1/dailies/"+tc.id, nil)
- w := httptest.NewRecorder()
- s.router.ServeHTTP(w, req)
-
- s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
-
- if tc.expectedBody != nil {
- expectedJSON, err := json.Marshal(tc.expectedBody)
- s.NoError(err, "Failed to marshal expected body")
- s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
- }
- })
- }
-}
-
-func (s *DailyHandlerTestSuite) TestCreateDaily() {
- testCases := []struct {
- name string
- body interface{}
- setupMock func()
- expectedStatus int
- expectedBody interface{}
- }{
- {
- name: "Success",
- body: CreateDailyRequest{
- ID: "daily1",
- CategoryID: 1,
- ImageURL: "https://example.com/image1.jpg",
- },
- setupMock: func() {
- s.service.EXPECT().
- CreateDaily(gomock.Any(), "daily1", 1, "https://example.com/image1.jpg").
- Return(&ent.Daily{
- ID: "daily1",
- ImageURL: "https://example.com/image1.jpg",
- Edges: ent.DailyEdges{
- Category: &ent.Category{ID: 1},
- Contents: []*ent.DailyContent{},
- },
- }, nil)
- },
- expectedStatus: http.StatusCreated,
- expectedBody: &ent.Daily{
- ID: "daily1",
- ImageURL: "https://example.com/image1.jpg",
- Edges: ent.DailyEdges{
- Category: &ent.Category{ID: 1},
- Contents: []*ent.DailyContent{},
- },
- },
- },
- {
- name: "Invalid request body",
- body: map[string]interface{}{
- "id": "daily1",
- // Missing required fields
- },
- setupMock: func() {},
- expectedStatus: http.StatusBadRequest,
- expectedBody: gin.H{"error": "Key: 'CreateDailyRequest.CategoryID' Error:Field validation for 'CategoryID' failed on the 'required' tag\nKey: 'CreateDailyRequest.ImageURL' Error:Field validation for 'ImageURL' failed on the 'required' tag"},
- },
- {
- name: "Service error",
- body: CreateDailyRequest{
- ID: "daily1",
- CategoryID: 1,
- ImageURL: "https://example.com/image1.jpg",
- },
- setupMock: func() {
- s.service.EXPECT().
- CreateDaily(gomock.Any(), "daily1", 1, "https://example.com/image1.jpg").
- Return(nil, errors.New("service error"))
- },
- expectedStatus: http.StatusInternalServerError,
- expectedBody: gin.H{"error": "Failed to create daily"},
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- tc.setupMock()
-
- body, err := json.Marshal(tc.body)
- s.NoError(err, "Failed to marshal request body")
-
- req := httptest.NewRequest(http.MethodPost, "/api/v1/dailies", bytes.NewReader(body))
- req.Header.Set("Content-Type", "application/json")
- w := httptest.NewRecorder()
- s.router.ServeHTTP(w, req)
-
- s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
-
- if tc.expectedBody != nil {
- expectedJSON, err := json.Marshal(tc.expectedBody)
- s.NoError(err, "Failed to marshal expected body")
- s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
- }
- })
- }
-}
-
-func (s *DailyHandlerTestSuite) TestAddDailyContent() {
- testCases := []struct {
- name string
- dailyID string
- body interface{}
- setupMock func()
- expectedStatus int
- expectedBody interface{}
- }{
- {
- name: "Success",
- dailyID: "daily1",
- body: AddDailyContentRequest{
- LanguageCode: "en",
- Quote: "Test Quote 1",
- },
- setupMock: func() {
- s.service.EXPECT().
- AddDailyContent(gomock.Any(), "daily1", "en", "Test Quote 1").
- Return(&ent.DailyContent{
- LanguageCode: "en",
- Quote: "Test Quote 1",
- }, nil)
- },
- expectedStatus: http.StatusCreated,
- expectedBody: &ent.DailyContent{
- LanguageCode: "en",
- Quote: "Test Quote 1",
- },
- },
- {
- name: "Invalid request body",
- dailyID: "daily1",
- body: map[string]interface{}{
- "language_code": "en",
- // Missing required fields
- },
- setupMock: func() {},
- expectedStatus: http.StatusBadRequest,
- expectedBody: gin.H{"error": "Key: 'AddDailyContentRequest.Quote' Error:Field validation for 'Quote' failed on the 'required' tag"},
- },
- {
- name: "Service error",
- dailyID: "daily1",
- body: AddDailyContentRequest{
- LanguageCode: "en",
- Quote: "Test Quote 1",
- },
- setupMock: func() {
- s.service.EXPECT().
- AddDailyContent(gomock.Any(), "daily1", "en", "Test Quote 1").
- Return(nil, errors.New("service error"))
- },
- expectedStatus: http.StatusInternalServerError,
- expectedBody: gin.H{"error": "Failed to add daily content"},
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- tc.setupMock()
-
- body, err := json.Marshal(tc.body)
- s.NoError(err, "Failed to marshal request body")
-
- req := httptest.NewRequest(http.MethodPost, "/api/v1/dailies/"+tc.dailyID+"/contents", bytes.NewReader(body))
- req.Header.Set("Content-Type", "application/json")
- w := httptest.NewRecorder()
- s.router.ServeHTTP(w, req)
-
- s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
-
- if tc.expectedBody != nil {
- expectedJSON, err := json.Marshal(tc.expectedBody)
- s.NoError(err, "Failed to marshal expected body")
- s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
- }
- })
- }
-}
diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go
index 5afcaf0..11c5b3e 100644
--- a/backend/internal/handler/handler.go
+++ b/backend/internal/handler/handler.go
@@ -25,8 +25,9 @@ func NewHandler(cfg *config.Config, service service.Service) *Handler {
}
// RegisterRoutes registers all the routes
-func (h *Handler) RegisterRoutes(r *gin.Engine) {
- api := r.Group("/api/v1")
+func (h *Handler) RegisterRoutes(router *gin.Engine) {
+ // API routes
+ api := router.Group("/api/v1")
{
// Auth routes
auth := api.Group("/auth")
@@ -93,6 +94,9 @@ func (h *Handler) RegisterRoutes(r *gin.Engine) {
media.DELETE("/:id", h.DeleteMedia)
}
}
+
+ // Public media files
+ router.GET("/media/:year/:month/:filename", h.GetMediaFile)
}
// Category handlers
@@ -186,10 +190,12 @@ func (h *Handler) ListPosts(c *gin.Context) {
langCode = "en" // Default to English
}
- var categoryID *int
- if catIDStr := c.Query("category_id"); catIDStr != "" {
- if id, err := strconv.Atoi(catIDStr); err == nil {
- categoryID = &id
+ var categoryIDs []int
+ if catIDsStr := c.QueryArray("category_ids"); len(catIDsStr) > 0 {
+ for _, idStr := range catIDsStr {
+ if id, err := strconv.Atoi(idStr); err == nil {
+ categoryIDs = append(categoryIDs, id)
+ }
}
}
@@ -207,7 +213,7 @@ func (h *Handler) ListPosts(c *gin.Context) {
}
}
- posts, err := h.service.ListPosts(c.Request.Context(), langCode, categoryID, limit, offset)
+ posts, err := h.service.ListPosts(c.Request.Context(), langCode, categoryIDs, limit, offset)
if err != nil {
log.Error().Err(err).Msg("Failed to list posts")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list posts"})
@@ -244,10 +250,10 @@ func (h *Handler) GetPost(c *gin.Context) {
contents := make([]gin.H, 0, len(post.Edges.Contents))
for _, content := range post.Edges.Contents {
contents = append(contents, gin.H{
- "language_code": content.LanguageCode,
- "title": content.Title,
+ "language_code": content.LanguageCode,
+ "title": content.Title,
"content_markdown": content.ContentMarkdown,
- "summary": content.Summary,
+ "summary": content.Summary,
})
}
response["edges"].(gin.H)["contents"] = contents
@@ -256,7 +262,22 @@ func (h *Handler) GetPost(c *gin.Context) {
}
func (h *Handler) CreatePost(c *gin.Context) {
- post, err := h.service.CreatePost(c.Request.Context(), "draft") // Default to draft status
+ var req struct {
+ Status string `json:"status" binding:"omitempty,oneof=draft published archived"`
+ CategoryIDs []int `json:"category_ids" binding:"required,min=1"`
+ }
+
+ if err := c.ShouldBindJSON(&req); err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+ return
+ }
+
+ status := req.Status
+ if status == "" {
+ status = "draft" // Default to draft status
+ }
+
+ post, err := h.service.CreatePost(c.Request.Context(), status, req.CategoryIDs)
if err != nil {
log.Error().Err(err).Msg("Failed to create post")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create post"})
@@ -268,7 +289,8 @@ func (h *Handler) CreatePost(c *gin.Context) {
"id": post.ID,
"status": post.Status,
"edges": gin.H{
- "contents": []interface{}{},
+ "contents": []interface{}{},
+ "categories": []interface{}{},
},
}
@@ -276,7 +298,7 @@ func (h *Handler) CreatePost(c *gin.Context) {
}
type AddPostContentRequest struct {
- LanguageCode string `json:"language_code" binding:"required"`
+ LanguageCode string `json:"language_code" binding:"required"`
Title string `json:"title" binding:"required"`
ContentMarkdown string `json:"content_markdown" binding:"required"`
Summary string `json:"summary" binding:"required"`
@@ -308,10 +330,10 @@ func (h *Handler) AddPostContent(c *gin.Context) {
"title": content.Title,
"content_markdown": content.ContentMarkdown,
"language_code": content.LanguageCode,
- "summary": content.Summary,
+ "summary": content.Summary,
"meta_keywords": content.MetaKeywords,
"meta_description": content.MetaDescription,
- "edges": gin.H{},
+ "edges": gin.H{},
})
}
@@ -517,11 +539,3 @@ func (h *Handler) AddDailyContent(c *gin.Context) {
c.JSON(http.StatusCreated, content)
}
-
-// Helper functions
-func stringPtr(s *string) string {
- if s == nil {
- return ""
- }
- return *s
-}
diff --git a/backend/internal/handler/handler_test.go b/backend/internal/handler/handler_test.go
deleted file mode 100644
index a9d3b9f..0000000
--- a/backend/internal/handler/handler_test.go
+++ /dev/null
@@ -1,43 +0,0 @@
-package handler
-
-import (
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-func TestStringPtr(t *testing.T) {
- testCases := []struct {
- name string
- input *string
- expected string
- }{
- {
- name: "nil pointer",
- input: nil,
- expected: "",
- },
- {
- name: "empty string",
- input: strPtr(""),
- expected: "",
- },
- {
- name: "non-empty string",
- input: strPtr("test"),
- expected: "test",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- result := stringPtr(tc.input)
- assert.Equal(t, tc.expected, result)
- })
- }
-}
-
-// Helper function to create string pointer
-func strPtr(s string) *string {
- return &s
-}
diff --git a/backend/internal/handler/media.go b/backend/internal/handler/media.go
index 68d4a4b..1ad429b 100644
--- a/backend/internal/handler/media.go
+++ b/backend/internal/handler/media.go
@@ -5,9 +5,11 @@ import (
"io"
"net/http"
"strconv"
+ "strings"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
+ "path/filepath"
)
// Media handlers
@@ -33,17 +35,29 @@ func (h *Handler) ListMedia(c *gin.Context) {
return
}
- c.JSON(http.StatusOK, media)
+ c.JSON(http.StatusOK, gin.H{
+ "data": media,
+ })
}
func (h *Handler) UploadMedia(c *gin.Context) {
// Get user ID from context (set by auth middleware)
- userID, exists := c.Get("user_id")
+ userIDStr, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
+ // Convert user ID to int
+ userID, err := strconv.Atoi(userIDStr.(string))
+ if err != nil {
+ log.Error().Err(err).
+ Str("user_id", fmt.Sprintf("%v", userIDStr)).
+ Msg("Failed to convert user ID to int")
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
+ return
+ }
+
// Get file from form
file, err := c.FormFile("file")
if err != nil {
@@ -51,38 +65,107 @@ func (h *Handler) UploadMedia(c *gin.Context) {
return
}
- // 文件大小限制
- if file.Size > 10*1024*1024 { // 10MB
- c.JSON(http.StatusBadRequest, gin.H{"error": "File size exceeds the limit (10MB)"})
+ // 获取文件类型和扩展名
+ contentType := file.Header.Get("Content-Type")
+ ext := strings.ToLower(filepath.Ext(file.Filename))
+ if contentType == "" {
+ // 如果 Content-Type 为空,尝试从文件扩展名判断
+ switch ext {
+ case ".jpg", ".jpeg":
+ contentType = "image/jpeg"
+ case ".png":
+ contentType = "image/png"
+ case ".gif":
+ contentType = "image/gif"
+ case ".webp":
+ contentType = "image/webp"
+ case ".mp4":
+ contentType = "video/mp4"
+ case ".webm":
+ contentType = "video/webm"
+ case ".mp3":
+ contentType = "audio/mpeg"
+ case ".ogg":
+ contentType = "audio/ogg"
+ case ".wav":
+ contentType = "audio/wav"
+ case ".pdf":
+ contentType = "application/pdf"
+ case ".doc":
+ contentType = "application/msword"
+ case ".docx":
+ contentType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
+ }
+ }
+
+ // 根据 Content-Type 确定文件类型和限制
+ var maxSize int64
+ var allowedTypes []string
+ var fileType string
+
+ limits := h.cfg.Storage.Upload.Limits
+ switch {
+ case strings.HasPrefix(contentType, "image/"):
+ maxSize = int64(limits.Image.MaxSize) * 1024 * 1024
+ allowedTypes = limits.Image.AllowedTypes
+ fileType = "image"
+ case strings.HasPrefix(contentType, "video/"):
+ maxSize = int64(limits.Video.MaxSize) * 1024 * 1024
+ allowedTypes = limits.Video.AllowedTypes
+ fileType = "video"
+ case strings.HasPrefix(contentType, "audio/"):
+ maxSize = int64(limits.Audio.MaxSize) * 1024 * 1024
+ allowedTypes = limits.Audio.AllowedTypes
+ fileType = "audio"
+ case strings.HasPrefix(contentType, "application/"):
+ maxSize = int64(limits.Document.MaxSize) * 1024 * 1024
+ allowedTypes = limits.Document.AllowedTypes
+ fileType = "document"
+ default:
+ c.JSON(http.StatusBadRequest, gin.H{
+ "error": "Unsupported file type",
+ })
return
}
- // 文件类型限制
- allowedTypes := map[string]bool{
- "image/jpeg": true,
- "image/png": true,
- "image/gif": true,
- "video/mp4": true,
- "video/webm": true,
- "audio/mpeg": true,
- "audio/ogg": true,
- "application/pdf": true,
+ // 检查文件类型是否允许
+ typeAllowed := false
+ for _, allowed := range allowedTypes {
+ if contentType == allowed {
+ typeAllowed = true
+ break
+ }
}
- contentType := file.Header.Get("Content-Type")
- if _, ok := allowedTypes[contentType]; !ok {
- c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid file type"})
+ if !typeAllowed {
+ c.JSON(http.StatusBadRequest, gin.H{
+ "error": fmt.Sprintf("Unsupported %s type: %s", fileType, contentType),
+ })
+ return
+ }
+
+ // 检查文件大小
+ if file.Size > maxSize {
+ c.JSON(http.StatusBadRequest, gin.H{
+ "error": fmt.Sprintf("File size exceeds the limit (%d MB) for %s files", limits.Image.MaxSize, fileType),
+ })
return
}
// Upload file
- media, err := h.service.Upload(c.Request.Context(), file, userID.(int))
+ media, err := h.service.Upload(c.Request.Context(), file, userID)
if err != nil {
- log.Error().Err(err).Msg("Failed to upload media")
+ log.Error().Err(err).
+ Str("filename", file.Filename).
+ Str("content_type", contentType).
+ Int("user_id", userID).
+ Msg("Failed to upload media")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to upload media"})
return
}
- c.JSON(http.StatusCreated, media)
+ c.JSON(http.StatusCreated, gin.H{
+ "data": media,
+ })
}
func (h *Handler) GetMedia(c *gin.Context) {
@@ -101,7 +184,7 @@ func (h *Handler) GetMedia(c *gin.Context) {
}
// Get file content
- reader, info, err := h.service.GetFile(c.Request.Context(), id)
+ reader, info, err := h.service.GetFile(c.Request.Context(), media.StorageID)
if err != nil {
log.Error().Err(err).Msg("Failed to get media file")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get media file"})
@@ -122,16 +205,18 @@ func (h *Handler) GetMedia(c *gin.Context) {
}
func (h *Handler) GetMediaFile(c *gin.Context) {
- id, err := strconv.Atoi(c.Param("id"))
- if err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid media ID"})
- return
- }
+ year := c.Param("year")
+ month := c.Param("month")
+ filename := c.Param("filename")
// Get file content
- reader, info, err := h.service.GetFile(c.Request.Context(), id)
+ reader, info, err := h.service.GetFile(c.Request.Context(), filename) // 直接使用完整的文件名
if err != nil {
- log.Error().Err(err).Msg("Failed to get media file")
+ log.Error().Err(err).
+ Str("year", year).
+ Str("month", month).
+ Str("filename", filename).
+ Msg("Failed to get media file")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get media file"})
return
}
@@ -151,20 +236,33 @@ func (h *Handler) GetMediaFile(c *gin.Context) {
func (h *Handler) DeleteMedia(c *gin.Context) {
// Get user ID from context (set by auth middleware)
- userID, exists := c.Get("user_id")
+ userIDStr, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
+ // Convert user ID to int
+ userID, err := strconv.Atoi(userIDStr.(string))
+ if err != nil {
+ log.Error().Err(err).
+ Str("user_id", fmt.Sprintf("%v", userIDStr)).
+ Msg("Failed to convert user ID to int")
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
+ return
+ }
+
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid media ID"})
return
}
- if err := h.service.DeleteMedia(c.Request.Context(), id, userID.(int)); err != nil {
- log.Error().Err(err).Msg("Failed to delete media")
+ if err := h.service.DeleteMedia(c.Request.Context(), id, userID); err != nil {
+ log.Error().Err(err).
+ Int("media_id", id).
+ Int("user_id", userID).
+ Msg("Failed to delete media")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete media"})
return
}
diff --git a/backend/internal/handler/media_handler_test.go b/backend/internal/handler/media_handler_test.go
deleted file mode 100644
index 5452129..0000000
--- a/backend/internal/handler/media_handler_test.go
+++ /dev/null
@@ -1,524 +0,0 @@
-package handler
-
-import (
- "bytes"
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "mime/multipart"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
- "tss-rocks-be/ent"
- "tss-rocks-be/internal/config"
- "tss-rocks-be/internal/service/mock"
- "tss-rocks-be/internal/storage"
-
- "net/textproto"
-
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/suite"
- "go.uber.org/mock/gomock"
-)
-
-type MediaHandlerTestSuite struct {
- suite.Suite
- ctrl *gomock.Controller
- service *mock.MockService
- handler *Handler
- router *gin.Engine
-}
-
-func (s *MediaHandlerTestSuite) SetupTest() {
- s.ctrl = gomock.NewController(s.T())
- s.service = mock.NewMockService(s.ctrl)
- s.handler = NewHandler(&config.Config{}, s.service)
- s.router = gin.New()
-}
-
-func (s *MediaHandlerTestSuite) TearDownTest() {
- s.ctrl.Finish()
-}
-
-func TestMediaHandlerSuite(t *testing.T) {
- suite.Run(t, new(MediaHandlerTestSuite))
-}
-
-func (s *MediaHandlerTestSuite) TestListMedia() {
- testCases := []struct {
- name string
- query string
- setupMock func()
- expectedStatus int
- expectedError string
- }{
- {
- name: "成功列出媒体",
- query: "?limit=10&offset=0",
- setupMock: func() {
- s.service.EXPECT().
- ListMedia(gomock.Any(), 10, 0).
- Return([]*ent.Media{{ID: 1}}, nil)
- },
- expectedStatus: http.StatusOK,
- },
- {
- name: "使用默认限制和偏移",
- query: "",
- setupMock: func() {
- s.service.EXPECT().
- ListMedia(gomock.Any(), 10, 0).
- Return([]*ent.Media{{ID: 1}}, nil)
- },
- expectedStatus: http.StatusOK,
- },
- {
- name: "列出媒体失败",
- query: "",
- setupMock: func() {
- s.service.EXPECT().
- ListMedia(gomock.Any(), 10, 0).
- Return(nil, errors.New("failed to list media"))
- },
- expectedStatus: http.StatusInternalServerError,
- expectedError: "Failed to list media",
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- // 设置 mock
- tc.setupMock()
-
- // 创建请求
- req, _ := http.NewRequest(http.MethodGet, "/media"+tc.query, nil)
- w := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(w)
- c.Request = req
-
- // 执行请求
- s.handler.ListMedia(c)
-
- // 验证响应
- s.Equal(tc.expectedStatus, w.Code)
- if tc.expectedError != "" {
- var response map[string]string
- err := json.Unmarshal(w.Body.Bytes(), &response)
- s.NoError(err)
- s.Equal(tc.expectedError, response["error"])
- } else {
- var response []*ent.Media
- err := json.Unmarshal(w.Body.Bytes(), &response)
- s.NoError(err)
- s.NotEmpty(response)
- }
- })
- }
-}
-
-func (s *MediaHandlerTestSuite) TestUploadMedia() {
- testCases := []struct {
- name string
- setupRequest func() (*http.Request, error)
- setupMock func()
- expectedStatus int
- expectedError string
- }{
- {
- name: "成功上传媒体",
- setupRequest: func() (*http.Request, error) {
- body := &bytes.Buffer{}
- writer := multipart.NewWriter(body)
-
- // 创建文件部分
- fileHeader := make(textproto.MIMEHeader)
- fileHeader.Set("Content-Type", "image/jpeg")
- fileHeader.Set("Content-Disposition", `form-data; name="file"; filename="test.jpg"`)
- part, err := writer.CreatePart(fileHeader)
- if err != nil {
- return nil, err
- }
- testContent := "test content"
- _, err = io.Copy(part, strings.NewReader(testContent))
- if err != nil {
- return nil, err
- }
- writer.Close()
-
- req := httptest.NewRequest(http.MethodPost, "/media", body)
- req.Header.Set("Content-Type", writer.FormDataContentType())
- return req, nil
- },
- setupMock: func() {
- expectedFile := &multipart.FileHeader{
- Filename: "test.jpg",
- Size: int64(len("test content")),
- Header: textproto.MIMEHeader{
- "Content-Type": []string{"image/jpeg"},
- },
- }
- s.service.EXPECT().
- Upload(gomock.Any(), gomock.Any(), 1).
- DoAndReturn(func(_ context.Context, f *multipart.FileHeader, uid int) (*ent.Media, error) {
- s.Equal(expectedFile.Filename, f.Filename)
- s.Equal(expectedFile.Size, f.Size)
- s.Equal(expectedFile.Header.Get("Content-Type"), f.Header.Get("Content-Type"))
- return &ent.Media{ID: 1}, nil
- })
- },
- expectedStatus: http.StatusCreated,
- },
- {
- name: "未授权",
- setupRequest: func() (*http.Request, error) {
- req := httptest.NewRequest(http.MethodPost, "/media", nil)
- return req, nil
- },
- setupMock: func() {},
- expectedStatus: http.StatusUnauthorized,
- expectedError: "Unauthorized",
- },
- {
- name: "上传失败",
- setupRequest: func() (*http.Request, error) {
- body := &bytes.Buffer{}
- writer := multipart.NewWriter(body)
-
- // 创建文件部分
- fileHeader := make(textproto.MIMEHeader)
- fileHeader.Set("Content-Type", "image/jpeg")
- fileHeader.Set("Content-Disposition", `form-data; name="file"; filename="test.jpg"`)
- part, err := writer.CreatePart(fileHeader)
- if err != nil {
- return nil, err
- }
- testContent := "test content"
- _, err = io.Copy(part, strings.NewReader(testContent))
- if err != nil {
- return nil, err
- }
- writer.Close()
-
- req := httptest.NewRequest(http.MethodPost, "/media", body)
- req.Header.Set("Content-Type", writer.FormDataContentType())
- return req, nil
- },
- setupMock: func() {
- s.service.EXPECT().
- Upload(gomock.Any(), gomock.Any(), 1).
- Return(nil, errors.New("failed to upload"))
- },
- expectedStatus: http.StatusInternalServerError,
- expectedError: "Failed to upload media",
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- // 设置 mock
- tc.setupMock()
-
- // 创建请求
- req, err := tc.setupRequest()
- s.NoError(err)
- w := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(w)
- c.Request = req
-
- // 设置用户ID(除了未授权的测试用例)
- if tc.expectedError != "Unauthorized" {
- c.Set("user_id", 1)
- }
-
- // 执行请求
- s.handler.UploadMedia(c)
-
- // 验证响应
- s.Equal(tc.expectedStatus, w.Code)
- if tc.expectedError != "" {
- var response map[string]string
- err := json.Unmarshal(w.Body.Bytes(), &response)
- s.NoError(err)
- s.Equal(tc.expectedError, response["error"])
- } else {
- var response *ent.Media
- err := json.Unmarshal(w.Body.Bytes(), &response)
- s.NoError(err)
- s.NotNil(response)
- }
- })
- }
-}
-
-func (s *MediaHandlerTestSuite) TestGetMedia() {
- testCases := []struct {
- name string
- mediaID string
- setupMock func()
- expectedStatus int
- expectedError string
- }{
- {
- name: "成功获取媒体",
- mediaID: "1",
- setupMock: func() {
- media := &ent.Media{
- ID: 1,
- MimeType: "image/jpeg",
- OriginalName: "test.jpg",
- }
- s.service.EXPECT().
- GetMedia(gomock.Any(), 1).
- Return(media, nil)
- s.service.EXPECT().
- GetFile(gomock.Any(), 1).
- Return(io.NopCloser(strings.NewReader("test content")), &storage.FileInfo{
- Size: 11,
- Name: "test.jpg",
- ContentType: "image/jpeg",
- }, nil)
- },
- expectedStatus: http.StatusOK,
- },
- {
- name: "无效的媒体ID",
- mediaID: "invalid",
- setupMock: func() {},
- expectedStatus: http.StatusBadRequest,
- expectedError: "Invalid media ID",
- },
- {
- name: "获取媒体元数据失败",
- mediaID: "1",
- setupMock: func() {
- s.service.EXPECT().
- GetMedia(gomock.Any(), 1).
- Return(nil, errors.New("failed to get media"))
- },
- expectedStatus: http.StatusInternalServerError,
- expectedError: "Failed to get media",
- },
- {
- name: "获取媒体文件失败",
- mediaID: "1",
- setupMock: func() {
- media := &ent.Media{
- ID: 1,
- MimeType: "image/jpeg",
- OriginalName: "test.jpg",
- }
- s.service.EXPECT().
- GetMedia(gomock.Any(), 1).
- Return(media, nil)
- s.service.EXPECT().
- GetFile(gomock.Any(), 1).
- Return(nil, nil, errors.New("failed to get file"))
- },
- expectedStatus: http.StatusInternalServerError,
- expectedError: "Failed to get media file",
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- // 设置 mock
- tc.setupMock()
-
- // 创建请求
- req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/media/%s", tc.mediaID), nil)
- w := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(w)
- c.Request = req
-
- // Extract ID from URL path
- parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/")
- if len(parts) >= 2 {
- c.Params = []gin.Param{{Key: "id", Value: parts[1]}}
- }
-
- // 执行请求
- s.handler.GetMedia(c)
-
- // 验证响应
- s.Equal(tc.expectedStatus, w.Code)
- if tc.expectedError != "" {
- var response map[string]string
- err := json.Unmarshal(w.Body.Bytes(), &response)
- s.NoError(err)
- s.Equal(tc.expectedError, response["error"])
- } else {
- s.Equal("image/jpeg", w.Header().Get("Content-Type"))
- s.Equal("11", w.Header().Get("Content-Length"))
- s.Equal("inline; filename=test.jpg", w.Header().Get("Content-Disposition"))
- s.Equal("test content", w.Body.String())
- }
- })
- }
-}
-
-func (s *MediaHandlerTestSuite) TestGetMediaFile() {
- testCases := []struct {
- name string
- setupRequest func() (*http.Request, error)
- setupMock func()
- expectedStatus int
- expectedBody []byte
- }{
- {
- name: "成功获取媒体文件",
- setupRequest: func() (*http.Request, error) {
- return httptest.NewRequest(http.MethodGet, "/media/1/file", nil), nil
- },
- setupMock: func() {
- fileContent := "test file content"
- s.service.EXPECT().
- GetFile(gomock.Any(), 1).
- Return(io.NopCloser(strings.NewReader(fileContent)), &storage.FileInfo{
- Name: "test.jpg",
- Size: int64(len(fileContent)),
- ContentType: "image/jpeg",
- }, nil)
- },
- expectedStatus: http.StatusOK,
- expectedBody: []byte("test file content"),
- },
- {
- name: "无效的媒体ID",
- setupRequest: func() (*http.Request, error) {
- return httptest.NewRequest(http.MethodGet, "/media/invalid/file", nil), nil
- },
- setupMock: func() {},
- expectedStatus: http.StatusBadRequest,
- },
- {
- name: "获取媒体文件失败",
- setupRequest: func() (*http.Request, error) {
- return httptest.NewRequest(http.MethodGet, "/media/1/file", nil), nil
- },
- setupMock: func() {
- s.service.EXPECT().
- GetFile(gomock.Any(), 1).
- Return(nil, nil, errors.New("failed to get file"))
- },
- expectedStatus: http.StatusInternalServerError,
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- // Setup
- req, err := tc.setupRequest()
- s.Require().NoError(err)
-
- w := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(w)
- c.Request = req
-
- // Extract ID from URL path
- parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/")
- if len(parts) >= 2 {
- c.Params = []gin.Param{{Key: "id", Value: parts[1]}}
- }
-
- // Setup mock
- tc.setupMock()
-
- // Test
- s.handler.GetMediaFile(c)
-
- // Verify
- s.Equal(tc.expectedStatus, w.Code)
- if tc.expectedBody != nil {
- s.Equal(tc.expectedBody, w.Body.Bytes())
- s.Equal("image/jpeg", w.Header().Get("Content-Type"))
- s.Equal(fmt.Sprintf("%d", len(tc.expectedBody)), w.Header().Get("Content-Length"))
- s.Equal("inline; filename=test.jpg", w.Header().Get("Content-Disposition"))
- }
- })
- }
-}
-
-func (s *MediaHandlerTestSuite) TestDeleteMedia() {
- testCases := []struct {
- name string
- mediaID string
- setupMock func()
- expectedStatus int
- expectedError string
- }{
- {
- name: "成功删除媒体",
- mediaID: "1",
- setupMock: func() {
- s.service.EXPECT().
- DeleteMedia(gomock.Any(), 1, 1).
- Return(nil)
- },
- expectedStatus: http.StatusNoContent,
- },
- {
- name: "未授权",
- mediaID: "1",
- setupMock: func() {},
- expectedStatus: http.StatusUnauthorized,
- expectedError: "Unauthorized",
- },
- {
- name: "无效的媒体ID",
- mediaID: "invalid",
- setupMock: func() {},
- expectedStatus: http.StatusBadRequest,
- expectedError: "Invalid media ID",
- },
- {
- name: "删除媒体失败",
- mediaID: "1",
- setupMock: func() {
- s.service.EXPECT().
- DeleteMedia(gomock.Any(), 1, 1).
- Return(errors.New("failed to delete"))
- },
- expectedStatus: http.StatusInternalServerError,
- expectedError: "Failed to delete media",
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- // 设置 mock
- tc.setupMock()
-
- // 创建请求
- req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/media/%s", tc.mediaID), nil)
- w := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(w)
- c.Request = req
-
- // Extract ID from URL path
- parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/")
- if len(parts) >= 2 {
- c.Params = []gin.Param{{Key: "id", Value: parts[1]}}
- }
-
- // 设置用户ID(除了未授权的测试用例)
- if tc.expectedError != "Unauthorized" {
- c.Set("user_id", 1)
- }
-
- // 执行请求
- s.handler.DeleteMedia(c)
-
- // 验证响应
- s.Equal(tc.expectedStatus, w.Code)
- if tc.expectedError != "" {
- var response map[string]string
- err := json.Unmarshal(w.Body.Bytes(), &response)
- s.NoError(err)
- s.Equal(tc.expectedError, response["error"])
- }
- })
- }
-}
diff --git a/backend/internal/handler/post_handler_test.go b/backend/internal/handler/post_handler_test.go
deleted file mode 100644
index 76d58ed..0000000
--- a/backend/internal/handler/post_handler_test.go
+++ /dev/null
@@ -1,624 +0,0 @@
-package handler
-
-import (
- "bytes"
- "encoding/json"
- "net/http"
- "net/http/httptest"
- "testing"
- "tss-rocks-be/ent"
- "tss-rocks-be/internal/config"
- "tss-rocks-be/internal/service"
- "tss-rocks-be/internal/service/mock"
-
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/suite"
- "go.uber.org/mock/gomock"
-
- "errors"
- "strings"
-)
-
-type PostHandlerTestSuite struct {
- suite.Suite
- ctrl *gomock.Controller
- service *mock.MockService
- handler *Handler
- router *gin.Engine
-}
-
-func (s *PostHandlerTestSuite) SetupTest() {
- s.ctrl = gomock.NewController(s.T())
- s.service = mock.NewMockService(s.ctrl)
- cfg := &config.Config{
- JWT: config.JWTConfig{
- Secret: "test-secret",
- },
- }
- s.handler = NewHandler(cfg, s.service)
-
- // Setup Gin router
- gin.SetMode(gin.TestMode)
- s.router = gin.New()
-
- // Setup mock for GetTokenBlacklist
- tokenBlacklist := &service.TokenBlacklist{}
- s.service.EXPECT().
- GetTokenBlacklist().
- Return(tokenBlacklist).
- AnyTimes()
-
- s.handler.RegisterRoutes(s.router)
-}
-
-func (s *PostHandlerTestSuite) TearDownTest() {
- s.ctrl.Finish()
-}
-
-func TestPostHandlerSuite(t *testing.T) {
- suite.Run(t, new(PostHandlerTestSuite))
-}
-
-// Test cases for ListPosts
-func (s *PostHandlerTestSuite) TestListPosts() {
- categoryID := 1
- testCases := []struct {
- name string
- langCode string
- categoryID string
- limit string
- offset string
- setupMock func()
- expectedStatus int
- expectedBody interface{}
- }{
- {
- name: "Success with default language",
- langCode: "",
- setupMock: func() {
- s.service.EXPECT().
- ListPosts(gomock.Any(), "en", nil, 10, 0).
- Return([]*ent.Post{
- {
- ID: 1,
- Status: "published",
- Edges: ent.PostEdges{
- Contents: []*ent.PostContent{
- {
- LanguageCode: "en",
- Title: "Test Post",
- ContentMarkdown: "Test Content",
- Summary: "Test Summary",
- },
- },
- },
- },
- }, nil)
- },
- expectedStatus: http.StatusOK,
- expectedBody: []*ent.Post{
- {
- ID: 1,
- Status: "published",
- Edges: ent.PostEdges{
- Contents: []*ent.PostContent{
- {
- LanguageCode: "en",
- Title: "Test Post",
- ContentMarkdown: "Test Content",
- Summary: "Test Summary",
- },
- },
- },
- },
- },
- },
- {
- name: "Success with specific language",
- langCode: "zh",
- setupMock: func() {
- s.service.EXPECT().
- ListPosts(gomock.Any(), "zh", nil, 10, 0).
- Return([]*ent.Post{
- {
- ID: 1,
- Status: "published",
- Edges: ent.PostEdges{
- Contents: []*ent.PostContent{
- {
- LanguageCode: "zh",
- Title: "测试帖子",
- ContentMarkdown: "测试内容",
- Summary: "测试摘要",
- },
- },
- },
- },
- }, nil)
- },
- expectedStatus: http.StatusOK,
- expectedBody: []*ent.Post{
- {
- ID: 1,
- Status: "published",
- Edges: ent.PostEdges{
- Contents: []*ent.PostContent{
- {
- LanguageCode: "zh",
- Title: "测试帖子",
- ContentMarkdown: "测试内容",
- Summary: "测试摘要",
- },
- },
- },
- },
- },
- },
- {
- name: "Success with category filter",
- langCode: "en",
- categoryID: "1",
- setupMock: func() {
- s.service.EXPECT().
- ListPosts(gomock.Any(), "en", &categoryID, 10, 0).
- Return([]*ent.Post{
- {
- ID: 1,
- Status: "published",
- Edges: ent.PostEdges{
- Contents: []*ent.PostContent{
- {
- LanguageCode: "en",
- Title: "Test Post",
- ContentMarkdown: "Test Content",
- Summary: "Test Summary",
- },
- },
- },
- },
- }, nil)
- },
- expectedStatus: http.StatusOK,
- expectedBody: []*ent.Post{
- {
- ID: 1,
- Status: "published",
- Edges: ent.PostEdges{
- Contents: []*ent.PostContent{
- {
- LanguageCode: "en",
- Title: "Test Post",
- ContentMarkdown: "Test Content",
- Summary: "Test Summary",
- },
- },
- },
- },
- },
- },
- {
- name: "Success with pagination",
- langCode: "en",
- limit: "2",
- offset: "1",
- setupMock: func() {
- s.service.EXPECT().
- ListPosts(gomock.Any(), "en", nil, 2, 1).
- Return([]*ent.Post{
- {
- ID: 2,
- Status: "published",
- Edges: ent.PostEdges{
- Contents: []*ent.PostContent{
- {
- LanguageCode: "en",
- Title: "Test Post 2",
- ContentMarkdown: "Test Content 2",
- Summary: "Test Summary 2",
- },
- },
- },
- },
- }, nil)
- },
- expectedStatus: http.StatusOK,
- expectedBody: []*ent.Post{
- {
- ID: 2,
- Status: "published",
- Edges: ent.PostEdges{
- Contents: []*ent.PostContent{
- {
- LanguageCode: "en",
- Title: "Test Post 2",
- ContentMarkdown: "Test Content 2",
- Summary: "Test Summary 2",
- },
- },
- },
- },
- },
- },
- {
- name: "Service Error",
- langCode: "en",
- setupMock: func() {
- s.service.EXPECT().
- ListPosts(gomock.Any(), "en", nil, 10, 0).
- Return(nil, errors.New("service error"))
- },
- expectedStatus: http.StatusInternalServerError,
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- // Setup mock
- tc.setupMock()
-
- // Create request
- url := "/api/v1/posts"
- if tc.langCode != "" {
- url += "?lang=" + tc.langCode
- }
- if tc.categoryID != "" {
- if strings.Contains(url, "?") {
- url += "&"
- } else {
- url += "?"
- }
- url += "category_id=" + tc.categoryID
- }
- if tc.limit != "" {
- if strings.Contains(url, "?") {
- url += "&"
- } else {
- url += "?"
- }
- url += "limit=" + tc.limit
- }
- if tc.offset != "" {
- if strings.Contains(url, "?") {
- url += "&"
- } else {
- url += "?"
- }
- url += "offset=" + tc.offset
- }
-
- req := httptest.NewRequest(http.MethodGet, url, nil)
- w := httptest.NewRecorder()
-
- // Perform request
- s.router.ServeHTTP(w, req)
-
- // Assert response
- s.Equal(tc.expectedStatus, w.Code)
- if tc.expectedBody != nil {
- var response []*ent.Post
- err := json.Unmarshal(w.Body.Bytes(), &response)
- s.NoError(err)
- s.Equal(tc.expectedBody, response)
- }
- })
- }
-}
-
-// Test cases for GetPost
-func (s *PostHandlerTestSuite) TestGetPost() {
- testCases := []struct {
- name string
- langCode string
- slug string
- setupMock func()
- expectedStatus int
- expectedBody interface{}
- }{
- {
- name: "Success with default language",
- langCode: "",
- slug: "test-post",
- setupMock: func() {
- s.service.EXPECT().
- GetPostBySlug(gomock.Any(), "en", "test-post").
- Return(&ent.Post{
- ID: 1,
- Status: "published",
- Slug: "test-post",
- Edges: ent.PostEdges{
- Contents: []*ent.PostContent{
- {
- LanguageCode: "en",
- Title: "Test Post",
- ContentMarkdown: "Test Content",
- Summary: "Test Summary",
- },
- },
- },
- }, nil)
- },
- expectedStatus: http.StatusOK,
- expectedBody: gin.H{
- "id": 1,
- "status": "published",
- "slug": "test-post",
- "edges": gin.H{
- "contents": []gin.H{
- {
- "language_code": "en",
- "title": "Test Post",
- "content_markdown": "Test Content",
- "summary": "Test Summary",
- },
- },
- },
- },
- },
- {
- name: "Success with specific language",
- langCode: "zh",
- slug: "test-post",
- setupMock: func() {
- s.service.EXPECT().
- GetPostBySlug(gomock.Any(), "zh", "test-post").
- Return(&ent.Post{
- ID: 1,
- Status: "published",
- Slug: "test-post",
- Edges: ent.PostEdges{
- Contents: []*ent.PostContent{
- {
- LanguageCode: "zh",
- Title: "测试帖子",
- ContentMarkdown: "测试内容",
- Summary: "测试摘要",
- },
- },
- },
- }, nil)
- },
- expectedStatus: http.StatusOK,
- expectedBody: gin.H{
- "id": 1,
- "status": "published",
- "slug": "test-post",
- "edges": gin.H{
- "contents": []gin.H{
- {
- "language_code": "zh",
- "title": "测试帖子",
- "content_markdown": "测试内容",
- "summary": "测试摘要",
- },
- },
- },
- },
- },
- {
- name: "Service error",
- slug: "test-post",
- setupMock: func() {
- s.service.EXPECT().
- GetPostBySlug(gomock.Any(), "en", "test-post").
- Return(nil, errors.New("service error"))
- },
- expectedStatus: http.StatusInternalServerError,
- expectedBody: gin.H{"error": "Failed to get post"},
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- tc.setupMock()
-
- url := "/api/v1/posts/" + tc.slug
- if tc.langCode != "" {
- url += "?lang=" + tc.langCode
- }
-
- req := httptest.NewRequest(http.MethodGet, url, nil)
- w := httptest.NewRecorder()
- s.router.ServeHTTP(w, req)
-
- s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
-
- if tc.expectedBody != nil {
- expectedJSON, err := json.Marshal(tc.expectedBody)
- s.NoError(err, "Failed to marshal expected body")
- s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
- }
- })
- }
-}
-
-// Test cases for CreatePost
-func (s *PostHandlerTestSuite) TestCreatePost() {
- testCases := []struct {
- name string
- setupMock func()
- expectedStatus int
- expectedBody interface{}
- }{
- {
- name: "Success",
- setupMock: func() {
- s.service.EXPECT().
- CreatePost(gomock.Any(), "draft").
- Return(&ent.Post{
- ID: 1,
- Status: "draft",
- Edges: ent.PostEdges{
- Contents: []*ent.PostContent{},
- },
- }, nil)
- },
- expectedStatus: http.StatusCreated,
- expectedBody: gin.H{
- "id": 1,
- "status": "draft",
- "edges": gin.H{
- "contents": []gin.H{},
- },
- },
- },
- {
- name: "Service error",
- setupMock: func() {
- s.service.EXPECT().
- CreatePost(gomock.Any(), "draft").
- Return(nil, errors.New("service error"))
- },
- expectedStatus: http.StatusInternalServerError,
- expectedBody: gin.H{"error": "Failed to create post"},
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- tc.setupMock()
-
- req := httptest.NewRequest(http.MethodPost, "/api/v1/posts", nil)
- w := httptest.NewRecorder()
- s.router.ServeHTTP(w, req)
-
- s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
-
- if tc.expectedBody != nil {
- expectedJSON, err := json.Marshal(tc.expectedBody)
- s.NoError(err, "Failed to marshal expected body")
- s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
- }
- })
- }
-}
-
-// Test cases for AddPostContent
-func (s *PostHandlerTestSuite) TestAddPostContent() {
- testCases := []struct {
- name string
- postID string
- body interface{}
- setupMock func()
- expectedStatus int
- expectedBody interface{}
- }{
- {
- name: "Success",
- postID: "1",
- body: AddPostContentRequest{
- LanguageCode: "en",
- Title: "Test Post",
- ContentMarkdown: "Test Content",
- Summary: "Test Summary",
- MetaKeywords: "test,keywords",
- MetaDescription: "Test meta description",
- },
- setupMock: func() {
- s.service.EXPECT().
- AddPostContent(
- gomock.Any(),
- 1,
- "en",
- "Test Post",
- "Test Content",
- "Test Summary",
- "test,keywords",
- "Test meta description",
- ).
- Return(&ent.PostContent{
- LanguageCode: "en",
- Title: "Test Post",
- ContentMarkdown: "Test Content",
- Summary: "Test Summary",
- MetaKeywords: "test,keywords",
- MetaDescription: "Test meta description",
- Edges: ent.PostContentEdges{},
- }, nil)
- },
- expectedStatus: http.StatusCreated,
- expectedBody: gin.H{
- "language_code": "en",
- "title": "Test Post",
- "content_markdown": "Test Content",
- "summary": "Test Summary",
- "meta_keywords": "test,keywords",
- "meta_description": "Test meta description",
- "edges": gin.H{},
- },
- },
- {
- name: "Invalid post ID",
- postID: "invalid",
- body: AddPostContentRequest{
- LanguageCode: "en",
- Title: "Test Post",
- ContentMarkdown: "Test Content",
- Summary: "Test Summary",
- },
- setupMock: func() {},
- expectedStatus: http.StatusBadRequest,
- expectedBody: gin.H{"error": "Invalid post ID"},
- },
- {
- name: "Invalid request body",
- postID: "1",
- body: map[string]interface{}{
- "language_code": "en",
- // Missing required fields
- },
- setupMock: func() {},
- expectedStatus: http.StatusBadRequest,
- expectedBody: gin.H{"error": "Key: 'AddPostContentRequest.Title' Error:Field validation for 'Title' failed on the 'required' tag\nKey: 'AddPostContentRequest.ContentMarkdown' Error:Field validation for 'ContentMarkdown' failed on the 'required' tag\nKey: 'AddPostContentRequest.Summary' Error:Field validation for 'Summary' failed on the 'required' tag"},
- },
- {
- name: "Service error",
- postID: "1",
- body: AddPostContentRequest{
- LanguageCode: "en",
- Title: "Test Post",
- ContentMarkdown: "Test Content",
- Summary: "Test Summary",
- MetaKeywords: "test,keywords",
- MetaDescription: "Test meta description",
- },
- setupMock: func() {
- s.service.EXPECT().
- AddPostContent(
- gomock.Any(),
- 1,
- "en",
- "Test Post",
- "Test Content",
- "Test Summary",
- "test,keywords",
- "Test meta description",
- ).
- Return(nil, errors.New("service error"))
- },
- expectedStatus: http.StatusInternalServerError,
- expectedBody: gin.H{"error": "Failed to add post content"},
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- tc.setupMock()
-
- body, err := json.Marshal(tc.body)
- s.NoError(err, "Failed to marshal request body")
-
- req := httptest.NewRequest(http.MethodPost, "/api/v1/posts/"+tc.postID+"/contents", bytes.NewReader(body))
- req.Header.Set("Content-Type", "application/json")
- w := httptest.NewRecorder()
- s.router.ServeHTTP(w, req)
-
- s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
-
- if tc.expectedBody != nil {
- expectedJSON, err := json.Marshal(tc.expectedBody)
- s.NoError(err, "Failed to marshal expected body")
- s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
- }
- })
- }
-}
diff --git a/backend/internal/middleware/accesslog_test.go b/backend/internal/middleware/accesslog_test.go
deleted file mode 100644
index 449513a..0000000
--- a/backend/internal/middleware/accesslog_test.go
+++ /dev/null
@@ -1,238 +0,0 @@
-package middleware
-
-import (
- "bytes"
- "io"
- "net/http"
- "net/http/httptest"
- "os"
- "path/filepath"
- "testing"
- "time"
- "tss-rocks-be/internal/types"
-
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/assert"
-)
-
-func TestAccessLog(t *testing.T) {
- // 设置测试临时目录
- tmpDir := t.TempDir()
- logPath := filepath.Join(tmpDir, "test.log")
-
- testCases := []struct {
- name string
- config *types.AccessLogConfig
- expectedError bool
- setupRequest func(*http.Request)
- validateOutput func(*testing.T, *httptest.ResponseRecorder, string)
- }{
- {
- name: "Console logging only",
- config: &types.AccessLogConfig{
- EnableConsole: true,
- EnableFile: false,
- Format: "json",
- Level: "info",
- },
- expectedError: false,
- setupRequest: func(req *http.Request) {
- req.Header.Set("User-Agent", "test-agent")
- },
- validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
- assert.Equal(t, http.StatusOK, w.Code)
- assert.Contains(t, logOutput, "GET /test")
- assert.Contains(t, logOutput, "test-agent")
- },
- },
- {
- name: "File logging only",
- config: &types.AccessLogConfig{
- EnableConsole: false,
- EnableFile: true,
- FilePath: logPath,
- Format: "json",
- Level: "info",
- Rotation: struct {
- MaxSize int `yaml:"max_size"`
- MaxAge int `yaml:"max_age"`
- MaxBackups int `yaml:"max_backups"`
- Compress bool `yaml:"compress"`
- LocalTime bool `yaml:"local_time"`
- }{
- MaxSize: 1,
- MaxAge: 1,
- MaxBackups: 1,
- Compress: false,
- LocalTime: true,
- },
- },
- expectedError: false,
- setupRequest: func(req *http.Request) {
- req.Header.Set("User-Agent", "test-agent")
- },
- validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
- assert.Equal(t, http.StatusOK, w.Code)
-
- // 读取日志文件内容
- content, err := os.ReadFile(logPath)
- assert.NoError(t, err)
- assert.Contains(t, string(content), "GET /test")
- assert.Contains(t, string(content), "test-agent")
- },
- },
- {
- name: "Both console and file logging",
- config: &types.AccessLogConfig{
- EnableConsole: true,
- EnableFile: true,
- FilePath: logPath,
- Format: "json",
- Level: "info",
- Rotation: struct {
- MaxSize int `yaml:"max_size"`
- MaxAge int `yaml:"max_age"`
- MaxBackups int `yaml:"max_backups"`
- Compress bool `yaml:"compress"`
- LocalTime bool `yaml:"local_time"`
- }{
- MaxSize: 1,
- MaxAge: 1,
- MaxBackups: 1,
- Compress: false,
- LocalTime: true,
- },
- },
- expectedError: false,
- setupRequest: func(req *http.Request) {
- req.Header.Set("User-Agent", "test-agent")
- },
- validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
- assert.Equal(t, http.StatusOK, w.Code)
- assert.Contains(t, logOutput, "GET /test")
- assert.Contains(t, logOutput, "test-agent")
-
- // 读取日志文件内容
- content, err := os.ReadFile(logPath)
- assert.NoError(t, err)
- assert.Contains(t, string(content), "GET /test")
- assert.Contains(t, string(content), "test-agent")
- },
- },
- {
- name: "With authenticated user",
- config: &types.AccessLogConfig{
- EnableConsole: true,
- EnableFile: false,
- Format: "json",
- Level: "info",
- },
- expectedError: false,
- setupRequest: func(req *http.Request) {
- req.Header.Set("User-Agent", "test-agent")
- },
- validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
- assert.Equal(t, http.StatusOK, w.Code)
- assert.Contains(t, logOutput, "GET /test")
- assert.Contains(t, logOutput, "test-agent")
- assert.Contains(t, logOutput, "test-user")
- },
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- // 捕获标准输出
- oldStdout := os.Stdout
- r, w, _ := os.Pipe()
- os.Stdout = w
-
- // 创建一个新的 gin 引擎
- gin.SetMode(gin.TestMode)
- router := gin.New()
-
- // 创建访问日志中间件
- middleware, err := AccessLog(tc.config)
- if tc.expectedError {
- assert.Error(t, err)
- return
- }
- assert.NoError(t, err)
-
- // 添加测试路由
- router.Use(middleware)
- router.GET("/test", func(c *gin.Context) {
- // 如果是测试认证用户的情况,设置用户ID
- if tc.name == "With authenticated user" {
- c.Set("user_id", "test-user")
- }
- c.Status(http.StatusOK)
- })
-
- // 创建测试请求
- req := httptest.NewRequest("GET", "/test", nil)
- if tc.setupRequest != nil {
- tc.setupRequest(req)
- }
- rec := httptest.NewRecorder()
-
- // 执行请求
- router.ServeHTTP(rec, req)
-
- // 恢复标准输出并获取输出内容
- w.Close()
- var buf bytes.Buffer
- io.Copy(&buf, r)
- os.Stdout = oldStdout
-
- // 验证输出
- if tc.validateOutput != nil {
- tc.validateOutput(t, rec, buf.String())
- }
-
- // 关闭日志文件
- if tc.config.EnableFile {
- // 调用中间件函数来关闭日志文件
- middleware(nil)
- // 等待一小段时间确保文件完全关闭
- time.Sleep(100 * time.Millisecond)
- }
- })
- }
-}
-
-func TestAccessLogInvalidConfig(t *testing.T) {
- testCases := []struct {
- name string
- config *types.AccessLogConfig
- expectedError bool
- }{
- {
- name: "Invalid log level",
- config: &types.AccessLogConfig{
- EnableConsole: true,
- Level: "invalid_level",
- },
- expectedError: false, // 应该使用默认的 info 级别
- },
- {
- name: "Invalid file path",
- config: &types.AccessLogConfig{
- EnableFile: true,
- FilePath: "\x00invalid\x00path", // 使用空字符的路径在所有操作系统上都是无效的
- },
- expectedError: true,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- _, err := AccessLog(tc.config)
- if tc.expectedError {
- assert.Error(t, err)
- } else {
- assert.NoError(t, err)
- }
- })
- }
-}
diff --git a/backend/internal/middleware/auth_test.go b/backend/internal/middleware/auth_test.go
deleted file mode 100644
index 78e2f03..0000000
--- a/backend/internal/middleware/auth_test.go
+++ /dev/null
@@ -1,227 +0,0 @@
-package middleware
-
-import (
- "encoding/json"
- "fmt"
- "github.com/gin-gonic/gin"
- "github.com/golang-jwt/jwt/v5"
- "github.com/stretchr/testify/assert"
- "net/http"
- "net/http/httptest"
- "testing"
- "time"
-
- "tss-rocks-be/internal/service"
-)
-
-func createTestToken(secret string, claims jwt.MapClaims) string {
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- signedToken, err := token.SignedString([]byte(secret))
- if err != nil {
- panic(fmt.Sprintf("Failed to sign token: %v", err))
- }
- return signedToken
-}
-
-func TestAuthMiddleware(t *testing.T) {
- jwtSecret := "test-secret"
- tokenBlacklist := service.NewTokenBlacklist()
-
- testCases := []struct {
- name string
- setupAuth func(*http.Request)
- expectedStatus int
- expectedBody map[string]string
- checkUserData bool
- expectedUserID string
- expectedRoles []string
- }{
- {
- name: "No Authorization header",
- setupAuth: func(req *http.Request) {},
- expectedStatus: http.StatusUnauthorized,
- expectedBody: map[string]string{"error": "Authorization header is required"},
- },
- {
- name: "Invalid Authorization format",
- setupAuth: func(req *http.Request) {
- req.Header.Set("Authorization", "InvalidFormat")
- },
- expectedStatus: http.StatusUnauthorized,
- expectedBody: map[string]string{"error": "Authorization header format must be Bearer {token}"},
- },
- {
- name: "Invalid token",
- setupAuth: func(req *http.Request) {
- req.Header.Set("Authorization", "Bearer invalid.token.here")
- },
- expectedStatus: http.StatusUnauthorized,
- expectedBody: map[string]string{"error": "Invalid token"},
- },
- {
- name: "Valid token",
- setupAuth: func(req *http.Request) {
- claims := jwt.MapClaims{
- "sub": "123",
- "roles": []string{"admin", "editor"},
- "exp": time.Now().Add(time.Hour).Unix(),
- }
- token := createTestToken(jwtSecret, claims)
- req.Header.Set("Authorization", "Bearer "+token)
- },
- expectedStatus: http.StatusOK,
- checkUserData: true,
- expectedUserID: "123",
- expectedRoles: []string{"admin", "editor"},
- },
- {
- name: "Expired token",
- setupAuth: func(req *http.Request) {
- claims := jwt.MapClaims{
- "sub": "123",
- "roles": []string{"user"},
- "exp": time.Now().Add(-time.Hour).Unix(),
- }
- token := createTestToken(jwtSecret, claims)
- req.Header.Set("Authorization", "Bearer "+token)
- },
- expectedStatus: http.StatusUnauthorized,
- expectedBody: map[string]string{"error": "Invalid token"},
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- gin.SetMode(gin.TestMode)
- router := gin.New()
-
- // 添加认证中间件
- router.Use(func(c *gin.Context) {
- // 设置日志级别为 debug
- gin.SetMode(gin.DebugMode)
- c.Next()
- }, AuthMiddleware(jwtSecret, tokenBlacklist))
-
- // 测试路由
- router.GET("/test", func(c *gin.Context) {
- if tc.checkUserData {
- userID, exists := c.Get("user_id")
- assert.True(t, exists, "user_id should exist in context")
- assert.Equal(t, tc.expectedUserID, userID, "user_id should match")
-
- roles, exists := c.Get("user_roles")
- assert.True(t, exists, "user_roles should exist in context")
- assert.Equal(t, tc.expectedRoles, roles, "user_roles should match")
- }
- c.Status(http.StatusOK)
- })
-
- // 创建请求
- req := httptest.NewRequest("GET", "/test", nil)
- tc.setupAuth(req)
- rec := httptest.NewRecorder()
-
- // 执行请求
- router.ServeHTTP(rec, req)
-
- // 验证响应
- assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match")
-
- if tc.expectedBody != nil {
- var response map[string]string
- err := json.NewDecoder(rec.Body).Decode(&response)
- assert.NoError(t, err, "Response body should be valid JSON")
- assert.Equal(t, tc.expectedBody, response, "Response body should match")
- }
- })
- }
-}
-
-func TestRoleMiddleware(t *testing.T) {
- testCases := []struct {
- name string
- setupContext func(*gin.Context)
- allowedRoles []string
- expectedStatus int
- expectedBody map[string]string
- }{
- {
- name: "No user roles",
- setupContext: func(c *gin.Context) {
- // 不设置用户角色
- },
- allowedRoles: []string{"admin"},
- expectedStatus: http.StatusUnauthorized,
- expectedBody: map[string]string{"error": "User roles not found"},
- },
- {
- name: "Invalid roles type",
- setupContext: func(c *gin.Context) {
- c.Set("user_roles", 123) // 设置错误类型的角色
- },
- allowedRoles: []string{"admin"},
- expectedStatus: http.StatusInternalServerError,
- expectedBody: map[string]string{"error": "Invalid user roles type"},
- },
- {
- name: "Insufficient permissions",
- setupContext: func(c *gin.Context) {
- c.Set("user_roles", []string{"user"})
- },
- allowedRoles: []string{"admin"},
- expectedStatus: http.StatusForbidden,
- expectedBody: map[string]string{"error": "Insufficient permissions"},
- },
- {
- name: "Allowed role",
- setupContext: func(c *gin.Context) {
- c.Set("user_roles", []string{"admin"})
- },
- allowedRoles: []string{"admin"},
- expectedStatus: http.StatusOK,
- },
- {
- name: "One of multiple allowed roles",
- setupContext: func(c *gin.Context) {
- c.Set("user_roles", []string{"user", "editor"})
- },
- allowedRoles: []string{"admin", "editor", "moderator"},
- expectedStatus: http.StatusOK,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- gin.SetMode(gin.TestMode)
- router := gin.New()
-
- // 添加角色中间件
- router.Use(func(c *gin.Context) {
- tc.setupContext(c)
- c.Next()
- }, RoleMiddleware(tc.allowedRoles...))
-
- // 测试路由
- router.GET("/test", func(c *gin.Context) {
- c.Status(http.StatusOK)
- })
-
- // 创建请求
- req := httptest.NewRequest("GET", "/test", nil)
- rec := httptest.NewRecorder()
-
- // 执行请求
- router.ServeHTTP(rec, req)
-
- // 验证响应
- assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match")
-
- if tc.expectedBody != nil {
- var response map[string]string
- err := json.NewDecoder(rec.Body).Decode(&response)
- assert.NoError(t, err, "Response body should be valid JSON")
- assert.Equal(t, tc.expectedBody, response, "Response body should match")
- }
- })
- }
-}
diff --git a/backend/internal/middleware/cors_test.go b/backend/internal/middleware/cors_test.go
deleted file mode 100644
index bf187f9..0000000
--- a/backend/internal/middleware/cors_test.go
+++ /dev/null
@@ -1,76 +0,0 @@
-package middleware
-
-import (
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/assert"
- "net/http"
- "net/http/httptest"
- "testing"
-)
-
-func TestCORS(t *testing.T) {
- testCases := []struct {
- name string
- method string
- expectedStatus int
- checkHeaders bool
- }{
- {
- name: "Normal GET request",
- method: "GET",
- expectedStatus: http.StatusOK,
- checkHeaders: true,
- },
- {
- name: "OPTIONS request",
- method: "OPTIONS",
- expectedStatus: http.StatusNoContent,
- checkHeaders: true,
- },
- {
- name: "POST request",
- method: "POST",
- expectedStatus: http.StatusOK,
- checkHeaders: true,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- // 创建一个新的 gin 引擎
- gin.SetMode(gin.TestMode)
- router := gin.New()
-
- // 添加 CORS 中间件
- router.Use(CORS())
-
- // 添加测试路由
- router.Any("/test", func(c *gin.Context) {
- c.Status(http.StatusOK)
- })
-
- // 创建测试请求
- req := httptest.NewRequest(tc.method, "/test", nil)
- rec := httptest.NewRecorder()
-
- // 执行请求
- router.ServeHTTP(rec, req)
-
- // 验证状态码
- assert.Equal(t, tc.expectedStatus, rec.Code)
-
- if tc.checkHeaders {
- // 验证 CORS 头部
- headers := rec.Header()
- assert.Equal(t, "*", headers.Get("Access-Control-Allow-Origin"))
- assert.Equal(t, "true", headers.Get("Access-Control-Allow-Credentials"))
- assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Content-Type")
- assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Authorization")
- assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "POST")
- assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "GET")
- assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "PUT")
- assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "DELETE")
- }
- })
- }
-}
diff --git a/backend/internal/middleware/ratelimit_test.go b/backend/internal/middleware/ratelimit_test.go
deleted file mode 100644
index 381f0af..0000000
--- a/backend/internal/middleware/ratelimit_test.go
+++ /dev/null
@@ -1,207 +0,0 @@
-package middleware
-
-import (
- "encoding/json"
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/assert"
- "net/http"
- "net/http/httptest"
- "testing"
- "time"
- "tss-rocks-be/internal/types"
-)
-
-func TestRateLimit(t *testing.T) {
- testCases := []struct {
- name string
- config *types.RateLimitConfig
- setupTest func(*gin.Engine)
- runTest func(*testing.T, *gin.Engine)
- expectedStatus int
- expectedBody map[string]string
- }{
- {
- name: "IP rate limit",
- config: &types.RateLimitConfig{
- IPRate: 1, // 每秒1个请求
- IPBurst: 1,
- },
- setupTest: func(router *gin.Engine) {
- router.GET("/test", func(c *gin.Context) {
- c.Status(http.StatusOK)
- })
- },
- runTest: func(t *testing.T, router *gin.Engine) {
- // 第一个请求应该成功
- req := httptest.NewRequest("GET", "/test", nil)
- req.RemoteAddr = "192.168.1.1:1234"
- rec := httptest.NewRecorder()
- router.ServeHTTP(rec, req)
- assert.Equal(t, http.StatusOK, rec.Code)
-
- // 第二个请求应该被限制
- rec = httptest.NewRecorder()
- router.ServeHTTP(rec, req)
- assert.Equal(t, http.StatusTooManyRequests, rec.Code)
- var response map[string]string
- err := json.NewDecoder(rec.Body).Decode(&response)
- assert.NoError(t, err)
- assert.Equal(t, "too many requests from this IP", response["error"])
-
- // 等待限流器重置
- time.Sleep(time.Second)
-
- // 第三个请求应该成功
- rec = httptest.NewRecorder()
- router.ServeHTTP(rec, req)
- assert.Equal(t, http.StatusOK, rec.Code)
- },
- },
- {
- name: "Route rate limit",
- config: &types.RateLimitConfig{
- IPRate: 100, // 设置较高的 IP 限流,以便测试路由限流
- IPBurst: 10,
- RouteRates: map[string]struct {
- Rate int `yaml:"rate"`
- Burst int `yaml:"burst"`
- }{
- "/limited": {
- Rate: 1,
- Burst: 1,
- },
- },
- },
- setupTest: func(router *gin.Engine) {
- router.GET("/limited", func(c *gin.Context) {
- c.Status(http.StatusOK)
- })
- router.GET("/unlimited", func(c *gin.Context) {
- c.Status(http.StatusOK)
- })
- },
- runTest: func(t *testing.T, router *gin.Engine) {
- // 测试限流路由
- req := httptest.NewRequest("GET", "/limited", nil)
- req.RemoteAddr = "192.168.1.2:1234"
- rec := httptest.NewRecorder()
- router.ServeHTTP(rec, req)
- assert.Equal(t, http.StatusOK, rec.Code)
-
- // 等待一小段时间确保限流器生效
- time.Sleep(10 * time.Millisecond)
-
- rec = httptest.NewRecorder()
- router.ServeHTTP(rec, req)
- assert.Equal(t, http.StatusTooManyRequests, rec.Code)
- var response map[string]string
- err := json.NewDecoder(rec.Body).Decode(&response)
- assert.NoError(t, err)
- assert.Equal(t, "too many requests for this route", response["error"])
-
- // 测试未限流路由
- req = httptest.NewRequest("GET", "/unlimited", nil)
- req.RemoteAddr = "192.168.1.2:1234"
- rec = httptest.NewRecorder()
- router.ServeHTTP(rec, req)
- assert.Equal(t, http.StatusOK, rec.Code)
-
- // 等待一小段时间确保限流器生效
- time.Sleep(10 * time.Millisecond)
-
- rec = httptest.NewRecorder()
- router.ServeHTTP(rec, req)
- assert.Equal(t, http.StatusOK, rec.Code)
- },
- },
- {
- name: "Multiple IPs",
- config: &types.RateLimitConfig{
- IPRate: 1,
- IPBurst: 1,
- },
- setupTest: func(router *gin.Engine) {
- router.GET("/test", func(c *gin.Context) {
- c.Status(http.StatusOK)
- })
- },
- runTest: func(t *testing.T, router *gin.Engine) {
- // IP1 的请求
- req1 := httptest.NewRequest("GET", "/test", nil)
- req1.RemoteAddr = "192.168.1.3:1234"
- rec := httptest.NewRecorder()
- router.ServeHTTP(rec, req1)
- assert.Equal(t, http.StatusOK, rec.Code)
-
- rec = httptest.NewRecorder()
- router.ServeHTTP(rec, req1)
- assert.Equal(t, http.StatusTooManyRequests, rec.Code)
-
- // IP2 的请求应该不受 IP1 的限制影响
- req2 := httptest.NewRequest("GET", "/test", nil)
- req2.RemoteAddr = "192.168.1.4:1234"
- rec = httptest.NewRecorder()
- router.ServeHTTP(rec, req2)
- assert.Equal(t, http.StatusOK, rec.Code)
- },
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- gin.SetMode(gin.TestMode)
- router := gin.New()
-
- // 添加限流中间件
- router.Use(RateLimit(tc.config))
-
- // 设置测试路由
- tc.setupTest(router)
-
- // 运行测试
- tc.runTest(t, router)
- })
- }
-}
-
-func TestRateLimiterCleanup(t *testing.T) {
- config := &types.RateLimitConfig{
- IPRate: 1,
- IPBurst: 1,
- }
-
- rl := newRateLimiter(config)
-
- // 添加一些IP限流器
- ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"}
- for _, ip := range ips {
- rl.getLimiter(ip)
- }
-
- // 验证IP限流器已创建
- rl.mu.RLock()
- assert.Equal(t, len(ips), len(rl.ips))
- rl.mu.RUnlock()
-
- // 修改一些IP的最后访问时间为1小时前
- rl.mu.Lock()
- rl.ips["192.168.1.1"].lastSeen = time.Now().Add(-2 * time.Hour)
- rl.ips["192.168.1.2"].lastSeen = time.Now().Add(-2 * time.Hour)
- rl.mu.Unlock()
-
- // 手动触发清理
- rl.mu.Lock()
- for ip, limiter := range rl.ips {
- if time.Since(limiter.lastSeen) > time.Hour {
- delete(rl.ips, ip)
- }
- }
- rl.mu.Unlock()
-
- // 验证过期的IP限流器已被删除
- rl.mu.RLock()
- assert.Equal(t, 1, len(rl.ips))
- _, exists := rl.ips["192.168.1.3"]
- assert.True(t, exists)
- rl.mu.RUnlock()
-}
diff --git a/backend/internal/middleware/upload.go b/backend/internal/middleware/upload.go
index 91f9e53..c5a3756 100644
--- a/backend/internal/middleware/upload.go
+++ b/backend/internal/middleware/upload.go
@@ -3,146 +3,120 @@ package middleware
import (
"bytes"
"fmt"
- "io"
"net/http"
"path/filepath"
"strings"
+ "tss-rocks-be/internal/config"
+
"github.com/gin-gonic/gin"
- "tss-rocks-be/internal/types"
)
-const (
- defaultMaxMemory = 32 << 20 // 32 MB
- maxHeaderBytes = 512 // 用于MIME类型检测的最大字节数
-)
-
-// ValidateUpload 创建文件上传验证中间件
-func ValidateUpload(cfg *types.UploadConfig) gin.HandlerFunc {
+// ValidateUpload 验证上传的文件
+func ValidateUpload(cfg *config.UploadConfig) gin.HandlerFunc {
return func(c *gin.Context) {
- // 检查是否是multipart/form-data请求
- if !strings.HasPrefix(c.GetHeader("Content-Type"), "multipart/form-data") {
- c.JSON(http.StatusBadRequest, gin.H{
- "error": "Content-Type must be multipart/form-data",
- })
+ // Get file from form
+ file, err := c.FormFile("file")
+ if err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "No file uploaded"})
c.Abort()
return
}
- // 解析multipart表单
- if err := c.Request.ParseMultipartForm(defaultMaxMemory); err != nil {
- c.JSON(http.StatusBadRequest, gin.H{
- "error": fmt.Sprintf("Failed to parse form: %v", err),
- })
- c.Abort()
- return
- }
+ // 获取文件类型和扩展名
+ contentType := file.Header.Get("Content-Type")
+ ext := strings.ToLower(filepath.Ext(file.Filename))
- form := c.Request.MultipartForm
- if form == nil || form.File == nil {
- c.JSON(http.StatusBadRequest, gin.H{
- "error": "No file uploaded",
- })
- c.Abort()
- return
- }
-
- // 遍历所有上传的文件
- for _, files := range form.File {
- for _, file := range files {
- // 检查文件大小
- if file.Size > int64(cfg.MaxSize)<<20 { // 转换为字节
- c.JSON(http.StatusBadRequest, gin.H{
- "error": fmt.Sprintf("File %s exceeds maximum size of %d MB", file.Filename, cfg.MaxSize),
- })
- c.Abort()
- return
- }
-
- // 检查文件扩展名
- ext := strings.ToLower(filepath.Ext(file.Filename))
- if !contains(cfg.AllowedExtensions, ext) {
- c.JSON(http.StatusBadRequest, gin.H{
- "error": fmt.Sprintf("File extension %s is not allowed", ext),
- })
- c.Abort()
- return
- }
-
- // 打开文件
- src, err := file.Open()
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{
- "error": fmt.Sprintf("Failed to open file: %v", err),
- })
- c.Abort()
- return
- }
- defer src.Close()
-
- // 读取文件头部用于MIME类型检测
- header := make([]byte, maxHeaderBytes)
- n, err := src.Read(header)
- if err != nil && err != io.EOF {
- c.JSON(http.StatusInternalServerError, gin.H{
- "error": fmt.Sprintf("Failed to read file: %v", err),
- })
- c.Abort()
- return
- }
- header = header[:n]
-
- // 检测MIME类型
- contentType := http.DetectContentType(header)
- if !contains(cfg.AllowedTypes, contentType) {
- c.JSON(http.StatusBadRequest, gin.H{
- "error": fmt.Sprintf("File type %s is not allowed", contentType),
- })
- c.Abort()
- return
- }
-
- // 将文件指针重置到开始位置
- _, err = src.Seek(0, 0)
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{
- "error": fmt.Sprintf("Failed to read file: %v", err),
- })
- c.Abort()
- return
- }
-
- // 将文件内容读入缓冲区
- buf := &bytes.Buffer{}
- _, err = io.Copy(buf, src)
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{
- "error": fmt.Sprintf("Failed to read file: %v", err),
- })
- c.Abort()
- return
- }
-
- // 将验证过的文件内容和类型保存到上下文中
- c.Set("validated_file_"+file.Filename, buf)
- c.Set("validated_content_type_"+file.Filename, contentType)
+ // 如果 Content-Type 为空,尝试从文件扩展名判断
+ if contentType == "" {
+ switch ext {
+ case ".jpg", ".jpeg":
+ contentType = "image/jpeg"
+ case ".png":
+ contentType = "image/png"
+ case ".gif":
+ contentType = "image/gif"
+ case ".webp":
+ contentType = "image/webp"
+ case ".mp4":
+ contentType = "video/mp4"
+ case ".webm":
+ contentType = "video/webm"
+ case ".mp3":
+ contentType = "audio/mpeg"
+ case ".ogg":
+ contentType = "audio/ogg"
+ case ".wav":
+ contentType = "audio/wav"
+ case ".pdf":
+ contentType = "application/pdf"
+ case ".doc":
+ contentType = "application/msword"
+ case ".docx":
+ contentType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
}
}
+ // 根据 Content-Type 确定文件类型和限制
+ var maxSize int64
+ var allowedTypes []string
+ var fileType string
+
+ limits := cfg.Limits
+ switch {
+ case strings.HasPrefix(contentType, "image/"):
+ maxSize = int64(limits.Image.MaxSize) * 1024 * 1024
+ allowedTypes = limits.Image.AllowedTypes
+ fileType = "image"
+ case strings.HasPrefix(contentType, "video/"):
+ maxSize = int64(limits.Video.MaxSize) * 1024 * 1024
+ allowedTypes = limits.Video.AllowedTypes
+ fileType = "video"
+ case strings.HasPrefix(contentType, "audio/"):
+ maxSize = int64(limits.Audio.MaxSize) * 1024 * 1024
+ allowedTypes = limits.Audio.AllowedTypes
+ fileType = "audio"
+ case strings.HasPrefix(contentType, "application/"):
+ maxSize = int64(limits.Document.MaxSize) * 1024 * 1024
+ allowedTypes = limits.Document.AllowedTypes
+ fileType = "document"
+ default:
+ c.JSON(http.StatusBadRequest, gin.H{
+ "error": fmt.Sprintf("Unsupported file type: %s", contentType),
+ })
+ c.Abort()
+ return
+ }
+
+ // 检查文件类型是否允许
+ typeAllowed := false
+ for _, allowed := range allowedTypes {
+ if contentType == allowed {
+ typeAllowed = true
+ break
+ }
+ }
+ if !typeAllowed {
+ c.JSON(http.StatusBadRequest, gin.H{
+ "error": fmt.Sprintf("Unsupported %s type: %s", fileType, contentType),
+ })
+ c.Abort()
+ return
+ }
+
+ // 检查文件大小
+ if file.Size > maxSize {
+ c.JSON(http.StatusBadRequest, gin.H{
+ "error": fmt.Sprintf("File size exceeds the limit (%d MB) for %s files", limits.Image.MaxSize, fileType),
+ })
+ c.Abort()
+ return
+ }
+
c.Next()
}
}
-// contains 检查切片中是否包含指定的字符串
-func contains(slice []string, str string) bool {
- for _, s := range slice {
- if s == str {
- return true
- }
- }
- return false
-}
-
// GetValidatedFile 从上下文中获取验证过的文件内容
func GetValidatedFile(c *gin.Context, filename string) (*bytes.Buffer, string, bool) {
file, exists := c.Get("validated_file_" + filename)
diff --git a/backend/internal/middleware/upload_test.go b/backend/internal/middleware/upload_test.go
deleted file mode 100644
index 7434484..0000000
--- a/backend/internal/middleware/upload_test.go
+++ /dev/null
@@ -1,262 +0,0 @@
-package middleware
-
-import (
- "bytes"
- "encoding/json"
- "io"
- "mime/multipart"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
-
- "tss-rocks-be/internal/types"
-
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/assert"
-)
-
-func createMultipartRequest(t *testing.T, filename string, content []byte, contentType string) (*http.Request, error) {
- body := &bytes.Buffer{}
- writer := multipart.NewWriter(body)
-
- part, err := writer.CreateFormFile("file", filename)
- if err != nil {
- return nil, err
- }
-
- _, err = io.Copy(part, bytes.NewReader(content))
- if err != nil {
- return nil, err
- }
-
- err = writer.Close()
- if err != nil {
- return nil, err
- }
-
- req := httptest.NewRequest("POST", "/upload", body)
- req.Header.Set("Content-Type", writer.FormDataContentType())
- return req, nil
-}
-
-func TestValidateUpload(t *testing.T) {
- tests := []struct {
- name string
- config *types.UploadConfig
- filename string
- content []byte
- setupRequest func(*testing.T) *http.Request
- expectedStatus int
- expectedError string
- }{
- {
- name: "Valid image upload",
- config: &types.UploadConfig{
- MaxSize: 5, // 5MB
- AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
- AllowedTypes: []string{"image/jpeg", "image/png"},
- },
- filename: "test.jpg",
- content: []byte{
- 0xFF, 0xD8, 0xFF, 0xE0, // JPEG magic numbers
- 0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00,
- },
- expectedStatus: http.StatusOK,
- },
- {
- name: "Invalid file extension",
- config: &types.UploadConfig{
- MaxSize: 5,
- AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
- AllowedTypes: []string{"image/jpeg", "image/png"},
- },
- filename: "test.txt",
- content: []byte("test content"),
- expectedStatus: http.StatusBadRequest,
- expectedError: "File extension .txt is not allowed",
- },
- {
- name: "File too large",
- config: &types.UploadConfig{
- MaxSize: 1, // 1MB
- AllowedExtensions: []string{".jpg"},
- AllowedTypes: []string{"image/jpeg"},
- },
- filename: "large.jpg",
- content: make([]byte, 2<<20), // 2MB
- expectedStatus: http.StatusBadRequest,
- expectedError: "File large.jpg exceeds maximum size of 1 MB",
- },
- {
- name: "Invalid content type",
- config: &types.UploadConfig{
- MaxSize: 5,
- AllowedExtensions: []string{".jpg"},
- AllowedTypes: []string{"image/jpeg"},
- },
- filename: "fake.jpg",
- content: []byte("not a real image"),
- expectedStatus: http.StatusBadRequest,
- expectedError: "File type text/plain; charset=utf-8 is not allowed",
- },
- {
- name: "Missing file",
- config: &types.UploadConfig{
- MaxSize: 5,
- AllowedExtensions: []string{".jpg"},
- AllowedTypes: []string{"image/jpeg"},
- },
- setupRequest: func(t *testing.T) *http.Request {
- req := httptest.NewRequest("POST", "/upload", strings.NewReader(""))
- req.Header.Set("Content-Type", "multipart/form-data")
- return req
- },
- expectedStatus: http.StatusBadRequest,
- expectedError: "Failed to parse form",
- },
- {
- name: "Invalid content type header",
- config: &types.UploadConfig{
- MaxSize: 5,
- AllowedExtensions: []string{".jpg"},
- AllowedTypes: []string{"image/jpeg"},
- },
- setupRequest: func(t *testing.T) *http.Request {
- return httptest.NewRequest("POST", "/upload", nil)
- },
- expectedStatus: http.StatusBadRequest,
- expectedError: "Content-Type must be multipart/form-data",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- gin.SetMode(gin.TestMode)
- w := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(w)
-
- var req *http.Request
- var err error
-
- if tt.setupRequest != nil {
- req = tt.setupRequest(t)
- } else {
- req, err = createMultipartRequest(t, tt.filename, tt.content, "")
- if err != nil {
- t.Fatalf("Failed to create request: %v", err)
- }
- }
-
- c.Request = req
-
- middleware := ValidateUpload(tt.config)
- middleware(c)
-
- assert.Equal(t, tt.expectedStatus, w.Code)
- if tt.expectedError != "" {
- var response map[string]string
- err := json.NewDecoder(w.Body).Decode(&response)
- assert.NoError(t, err)
- assert.Contains(t, response["error"], tt.expectedError)
- }
- })
- }
-}
-
-func TestGetValidatedFile(t *testing.T) {
- tests := []struct {
- name string
- setupContext func(*gin.Context)
- filename string
- expectedFound bool
- expectedError string
- }{
- {
- name: "Get existing file",
- setupContext: func(c *gin.Context) {
- // 创建测试文件内容
- content := []byte("test content")
- buf := bytes.NewBuffer(content)
-
- // 设置验证过的文件和内容类型
- c.Set("validated_file_test.txt", buf)
- c.Set("validated_content_type_test.txt", "text/plain")
- },
- filename: "test.txt",
- expectedFound: true,
- },
- {
- name: "File not found",
- setupContext: func(c *gin.Context) {
- // 不设置任何文件
- },
- filename: "nonexistent.txt",
- expectedFound: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- gin.SetMode(gin.TestMode)
- w := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(w)
-
- if tt.setupContext != nil {
- tt.setupContext(c)
- }
-
- buffer, contentType, found := GetValidatedFile(c, tt.filename)
-
- assert.Equal(t, tt.expectedFound, found)
- if tt.expectedFound {
- assert.NotNil(t, buffer)
- assert.NotEmpty(t, contentType)
- } else {
- assert.Nil(t, buffer)
- assert.Empty(t, contentType)
- }
- })
- }
-}
-
-func TestContains(t *testing.T) {
- tests := []struct {
- name string
- slice []string
- str string
- expected bool
- }{
- {
- name: "String found in slice",
- slice: []string{"a", "b", "c"},
- str: "b",
- expected: true,
- },
- {
- name: "String not found in slice",
- slice: []string{"a", "b", "c"},
- str: "d",
- expected: false,
- },
- {
- name: "Empty slice",
- slice: []string{},
- str: "a",
- expected: false,
- },
- {
- name: "Empty string",
- slice: []string{"a", "b", "c"},
- str: "",
- expected: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := contains(tt.slice, tt.str)
- assert.Equal(t, tt.expected, result)
- })
- }
-}
diff --git a/backend/internal/rbac/init_test.go b/backend/internal/rbac/init_test.go
deleted file mode 100644
index 2da0158..0000000
--- a/backend/internal/rbac/init_test.go
+++ /dev/null
@@ -1,99 +0,0 @@
-package rbac
-
-import (
- "context"
- "testing"
-
- "tss-rocks-be/ent/enttest"
- "tss-rocks-be/ent/role"
-
- _ "github.com/mattn/go-sqlite3"
-)
-
-func TestInitializeRBAC(t *testing.T) {
- // Create an in-memory SQLite client for testing
- client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
- defer client.Close()
-
- ctx := context.Background()
-
- // Test initialization
- err := InitializeRBAC(ctx, client)
- if err != nil {
- t.Fatalf("Failed to initialize RBAC: %v", err)
- }
-
- // Verify roles were created
- for roleName := range DefaultRoles {
- r, err := client.Role.Query().Where(role.Name(roleName)).Only(ctx)
- if err != nil {
- t.Errorf("Role %s was not created: %v", roleName, err)
- }
-
- // Verify permissions for each role
- perms, err := r.QueryPermissions().All(ctx)
- if err != nil {
- t.Errorf("Failed to query permissions for role %s: %v", roleName, err)
- }
-
- expectedPerms := DefaultRoles[roleName]
- permCount := 0
- for _, actions := range expectedPerms {
- permCount += len(actions)
- }
-
- if len(perms) != permCount {
- t.Errorf("Role %s has %d permissions, expected %d", roleName, len(perms), permCount)
- }
- }
-}
-
-func TestAssignRoleToUser(t *testing.T) {
- // Create an in-memory SQLite client for testing
- client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
- defer client.Close()
-
- ctx := context.Background()
-
- // Initialize RBAC
- err := InitializeRBAC(ctx, client)
- if err != nil {
- t.Fatalf("Failed to initialize RBAC: %v", err)
- }
-
- // Create a test user
- user, err := client.User.Create().
- SetEmail("test@example.com").
- SetUsername("testuser").
- SetPasswordHash("$2a$10$hzLdXMZEIzgr8eGXL0YoCOIIrQhqEj6N.S3.wY1Jx5.4vWm1ZyHyy").
- Save(ctx)
- if err != nil {
- t.Fatalf("Failed to create test user: %v", err)
- }
-
- // Test assigning role to user
- err = AssignRoleToUser(ctx, client, user.ID, "editor")
- if err != nil {
- t.Fatalf("Failed to assign role to user: %v", err)
- }
-
- // Verify role assignment
- assignedRoles, err := user.QueryRoles().All(ctx)
- if err != nil {
- t.Fatalf("Failed to query user roles: %v", err)
- }
-
- if len(assignedRoles) != 1 {
- t.Errorf("Expected 1 role, got %d", len(assignedRoles))
- }
-
- if assignedRoles[0].Name != "editor" {
- t.Errorf("Expected role name 'editor', got '%s'", assignedRoles[0].Name)
- }
-
- // Test assigning non-existent role
- err = AssignRoleToUser(ctx, client, user.ID, "nonexistent")
- if err == nil {
- t.Error("Expected error when assigning non-existent role, got nil")
- }
-}
diff --git a/backend/internal/server/database_test.go b/backend/internal/server/database_test.go
deleted file mode 100644
index c8e35da..0000000
--- a/backend/internal/server/database_test.go
+++ /dev/null
@@ -1,64 +0,0 @@
-package server
-
-import (
- "context"
- "testing"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func TestInitDatabase(t *testing.T) {
- tests := []struct {
- name string
- driver string
- dsn string
- wantErr bool
- errContains string
- }{
- {
- name: "success with sqlite3",
- driver: "sqlite3",
- dsn: "file:ent?mode=memory&cache=shared&_fk=1",
- },
- {
- name: "invalid driver",
- driver: "invalid_driver",
- dsn: "file:ent?mode=memory",
- wantErr: true,
- errContains: "unsupported driver",
- },
- {
- name: "invalid dsn",
- driver: "sqlite3",
- dsn: "file::memory:?not_exist_option=1", // 使用内存数据库但带有无效选项
- wantErr: true,
- errContains: "foreign_keys pragma is off",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- ctx := context.Background()
- client, err := InitDatabase(ctx, tt.driver, tt.dsn)
-
- if tt.wantErr {
- assert.Error(t, err)
- if tt.errContains != "" {
- assert.Contains(t, err.Error(), tt.errContains)
- }
- assert.Nil(t, client)
- } else {
- require.NoError(t, err)
- assert.NotNil(t, client)
-
- // 测试数据库连接是否正常工作
- err = client.Schema.Create(ctx)
- assert.NoError(t, err)
-
- // 清理
- client.Close()
- }
- })
- }
-}
diff --git a/backend/internal/server/ent_test.go b/backend/internal/server/ent_test.go
deleted file mode 100644
index 1f71881..0000000
--- a/backend/internal/server/ent_test.go
+++ /dev/null
@@ -1,40 +0,0 @@
-package server
-
-import (
- "context"
- "testing"
-
- "github.com/stretchr/testify/assert"
- "tss-rocks-be/internal/config"
-)
-
-func TestNewEntClient(t *testing.T) {
- tests := []struct {
- name string
- cfg *config.Config
- }{
- {
- name: "default sqlite3 config",
- cfg: &config.Config{
- Database: config.DatabaseConfig{
- Driver: "sqlite3",
- DSN: "file:ent?mode=memory&cache=shared&_fk=1",
- },
- },
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- client := NewEntClient(tt.cfg)
- assert.NotNil(t, client)
-
- // 验证客户端是否可以正常工作
- err := client.Schema.Create(context.Background())
- assert.NoError(t, err)
-
- // 清理
- client.Close()
- })
- }
-}
diff --git a/backend/internal/server/server_test.go b/backend/internal/server/server_test.go
deleted file mode 100644
index 41d552e..0000000
--- a/backend/internal/server/server_test.go
+++ /dev/null
@@ -1,220 +0,0 @@
-package server
-
-import (
- "context"
- "net/http"
- "testing"
- "time"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- "tss-rocks-be/internal/config"
- "tss-rocks-be/internal/types"
- "tss-rocks-be/ent/enttest"
-)
-
-func TestNew(t *testing.T) {
- // 创建测试配置
- cfg := &config.Config{
- Server: config.ServerConfig{
- Host: "localhost",
- Port: 8080,
- },
- Storage: config.StorageConfig{
- Type: "local",
- Local: config.LocalStorage{
- RootDir: "testdata",
- },
- Upload: types.UploadConfig{
- MaxSize: 10,
- AllowedTypes: []string{"image/jpeg", "image/png"},
- AllowedExtensions: []string{".jpg", ".png"},
- },
- },
- RateLimit: types.RateLimitConfig{
- IPRate: 100,
- IPBurst: 200,
- RouteRates: map[string]struct {
- Rate int `yaml:"rate"`
- Burst int `yaml:"burst"`
- }{
- "/api/v1/upload": {Rate: 10, Burst: 20},
- },
- },
- AccessLog: types.AccessLogConfig{
- EnableConsole: true,
- EnableFile: true,
- FilePath: "testdata/access.log",
- Format: "json",
- Level: "info",
- Rotation: struct {
- MaxSize int `yaml:"max_size"`
- MaxAge int `yaml:"max_age"`
- MaxBackups int `yaml:"max_backups"`
- Compress bool `yaml:"compress"`
- LocalTime bool `yaml:"local_time"`
- }{
- MaxSize: 100,
- MaxAge: 7,
- MaxBackups: 3,
- Compress: true,
- LocalTime: true,
- },
- },
- }
-
- // 创建测试数据库客户端
- client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
- defer client.Close()
-
- // 测试服务器初始化
- s, err := New(cfg, client)
- require.NoError(t, err)
- assert.NotNil(t, s)
- assert.NotNil(t, s.router)
- assert.NotNil(t, s.handler)
- assert.Equal(t, cfg, s.config)
-}
-
-func TestNew_StorageError(t *testing.T) {
- // 创建一个无效的存储配置
- cfg := &config.Config{
- Storage: config.StorageConfig{
- Type: "invalid_type", // 使用无效的存储类型
- },
- }
-
- client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
- defer client.Close()
-
- s, err := New(cfg, client)
- assert.Error(t, err)
- assert.Nil(t, s)
- assert.Contains(t, err.Error(), "failed to initialize storage")
-}
-
-func TestServer_StartAndShutdown(t *testing.T) {
- // 创建测试配置
- cfg := &config.Config{
- Server: config.ServerConfig{
- Host: "localhost",
- Port: 0, // 使用随机端口
- },
- Storage: config.StorageConfig{
- Type: "local",
- Local: config.LocalStorage{
- RootDir: "testdata",
- },
- },
- RateLimit: types.RateLimitConfig{
- IPRate: 100,
- IPBurst: 200,
- },
- AccessLog: types.AccessLogConfig{
- EnableConsole: true,
- Format: "json",
- Level: "info",
- },
- }
-
- // 创建测试数据库客户端
- client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
- defer client.Close()
-
- // 初始化服务器
- s, err := New(cfg, client)
- require.NoError(t, err)
-
- // 创建一个通道来接收服务器错误
- errChan := make(chan error, 1)
-
- // 在 goroutine 中启动服务器
- go func() {
- err := s.Start()
- if err != nil && err != http.ErrServerClosed {
- errChan <- err
- }
- close(errChan)
- }()
-
- // 给服务器一些时间启动
- time.Sleep(100 * time.Millisecond)
-
- // 测试关闭服务器
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
-
- err = s.Shutdown(ctx)
- assert.NoError(t, err)
-
- // 检查服务器是否有错误发生
- err = <-errChan
- assert.NoError(t, err)
-}
-
-func TestServer_StartError(t *testing.T) {
- // 创建一个配置,使用已经被占用的端口来触发错误
- cfg := &config.Config{
- Server: config.ServerConfig{
- Host: "localhost",
- Port: 8899, // 使用固定端口以便测试
- },
- Storage: config.StorageConfig{
- Type: "local",
- Local: config.LocalStorage{
- RootDir: "testdata",
- },
- },
- }
-
- client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
- defer client.Close()
-
- // 创建第一个服务器实例
- s1, err := New(cfg, client)
- require.NoError(t, err)
-
- // 创建一个通道来接收服务器错误
- errChan := make(chan error, 1)
-
- // 启动第一个服务器
- go func() {
- err := s1.Start()
- if err != nil && err != http.ErrServerClosed {
- errChan <- err
- }
- close(errChan)
- }()
-
- // 给服务器一些时间启动
- time.Sleep(100 * time.Millisecond)
-
- // 尝试在同一端口启动第二个服务器,应该会失败
- s2, err := New(cfg, client)
- require.NoError(t, err)
-
- err = s2.Start()
- assert.Error(t, err)
-
- // 清理
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
-
- // 关闭第一个服务器
- err = s1.Shutdown(ctx)
- assert.NoError(t, err)
-
- // 检查第一个服务器是否有错误发生
- err = <-errChan
- assert.NoError(t, err)
-
- // 关闭第二个服务器
- err = s2.Shutdown(ctx)
- assert.NoError(t, err)
-}
-
-func TestServer_ShutdownWithNilServer(t *testing.T) {
- s := &Server{}
- err := s.Shutdown(context.Background())
- assert.NoError(t, err)
-}
diff --git a/backend/internal/service/impl.go b/backend/internal/service/impl.go
index 216e348..98203ee 100644
--- a/backend/internal/service/impl.go
+++ b/backend/internal/service/impl.go
@@ -1,11 +1,14 @@
package service
import (
+ "bytes"
"context"
"errors"
"fmt"
"io"
"mime/multipart"
+ "os"
+ "path/filepath"
"sort"
"strconv"
"strings"
@@ -26,6 +29,8 @@ import (
"tss-rocks-be/ent/user"
"tss-rocks-be/internal/storage"
+ "github.com/chai2010/webp"
+ "github.com/disintegration/imaging"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
@@ -419,21 +424,71 @@ func (s *serviceImpl) Upload(ctx context.Context, file *multipart.FileHeader, us
// Open the uploaded file
src, err := openFile(file)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("failed to open file: %v", err)
}
defer src.Close()
+ // 获取文件类型和扩展名
+ contentType := file.Header.Get("Content-Type")
+ ext := strings.ToLower(filepath.Ext(file.Filename))
+ if contentType == "" {
+ // 如果 Content-Type 为空,尝试从文件扩展名判断
+ switch ext {
+ case ".jpg", ".jpeg":
+ contentType = "image/jpeg"
+ case ".png":
+ contentType = "image/png"
+ case ".gif":
+ contentType = "image/gif"
+ case ".webp":
+ contentType = "image/webp"
+ case ".mp4":
+ contentType = "video/mp4"
+ case ".webm":
+ contentType = "video/webm"
+ case ".mp3":
+ contentType = "audio/mpeg"
+ case ".ogg":
+ contentType = "audio/ogg"
+ case ".wav":
+ contentType = "audio/wav"
+ case ".pdf":
+ contentType = "application/pdf"
+ case ".doc":
+ contentType = "application/msword"
+ case ".docx":
+ contentType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
+ }
+ }
+
+ // 如果是图片,检查是否需要转换为 WebP
+ var fileToSave multipart.File = src
+ var finalContentType = contentType
+ if strings.HasPrefix(contentType, "image/") && contentType != "image/webp" {
+ // 转换为 WebP
+ webpFile, err := convertToWebP(src)
+ if err != nil {
+ return nil, fmt.Errorf("failed to convert image to WebP: %v", err)
+ }
+ fileToSave = webpFile
+ finalContentType = "image/webp"
+ ext = ".webp"
+ }
+
+ // 生成带扩展名的存储文件名
+ storageFilename := uuid.New().String() + ext
+
// Save the file to storage
- fileInfo, err := s.storage.Save(ctx, file.Filename, file.Header.Get("Content-Type"), src)
+ fileInfo, err := s.storage.Save(ctx, storageFilename, finalContentType, fileToSave)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("failed to save file: %v", err)
}
// Create media record
return s.client.Media.Create().
SetStorageID(fileInfo.ID).
SetOriginalName(file.Filename).
- SetMimeType(fileInfo.ContentType).
+ SetMimeType(finalContentType).
SetSize(fileInfo.Size).
SetURL(fileInfo.URL).
SetCreatedBy(strconv.Itoa(userID)).
@@ -444,13 +499,8 @@ func (s *serviceImpl) GetMedia(ctx context.Context, id int) (*ent.Media, error)
return s.client.Media.Get(ctx, id)
}
-func (s *serviceImpl) GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error) {
- media, err := s.GetMedia(ctx, id)
- if err != nil {
- return nil, nil, err
- }
-
- return s.storage.Get(ctx, media.StorageID)
+func (s *serviceImpl) GetFile(ctx context.Context, storageID string) (io.ReadCloser, *storage.FileInfo, error) {
+ return s.storage.Get(ctx, storageID)
}
func (s *serviceImpl) DeleteMedia(ctx context.Context, id int, userID int) error {
@@ -476,7 +526,7 @@ func (s *serviceImpl) DeleteMedia(ctx context.Context, id int, userID int) error
}
// Post operations
-func (s *serviceImpl) CreatePost(ctx context.Context, status string) (*ent.Post, error) {
+func (s *serviceImpl) CreatePost(ctx context.Context, status string, categoryIDs []int) (*ent.Post, error) {
var postStatus post.Status
switch status {
case "draft":
@@ -492,10 +542,25 @@ func (s *serviceImpl) CreatePost(ctx context.Context, status string) (*ent.Post,
// Generate a random slug
slug := fmt.Sprintf("post-%s", uuid.New().String()[:8])
- return s.client.Post.Create().
+ // Create post with categories
+ postCreate := s.client.Post.Create().
SetStatus(postStatus).
- SetSlug(slug).
- Save(ctx)
+ SetSlug(slug)
+
+ // Add categories if provided
+ if len(categoryIDs) > 0 {
+ categories := make([]*ent.Category, 0, len(categoryIDs))
+ for _, id := range categoryIDs {
+ category, err := s.client.Category.Get(ctx, id)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get category %d: %w", id, err)
+ }
+ categories = append(categories, category)
+ }
+ postCreate.AddCategories(categories...)
+ }
+
+ return postCreate.Save(ctx)
}
func (s *serviceImpl) AddPostContent(ctx context.Context, postID int, langCode, title, content, summary string, metaKeywords, metaDescription string) (*ent.PostContent, error) {
@@ -574,7 +639,7 @@ func (s *serviceImpl) GetPostBySlug(ctx context.Context, langCode, slug string)
WithContents(func(q *ent.PostContentQuery) {
q.Where(postcontent.LanguageCodeEQ(languageCode))
}).
- WithCategory().
+ WithCategories().
All(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get posts: %w", err)
@@ -590,7 +655,7 @@ func (s *serviceImpl) GetPostBySlug(ctx context.Context, langCode, slug string)
return posts[0], nil
}
-func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error) {
+func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryIDs []int, limit, offset int) ([]*ent.Post, error) {
var languageCode postcontent.LanguageCode
switch langCode {
case "en":
@@ -610,8 +675,8 @@ func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID
Where(post.StatusEQ(post.StatusPublished))
// Add category filter if provided
- if categoryID != nil {
- query = query.Where(post.HasCategoryWith(category.ID(*categoryID)))
+ if len(categoryIDs) > 0 {
+ query = query.Where(post.HasCategoriesWith(category.IDIn(categoryIDs...)))
}
// Get unique post IDs
@@ -638,7 +703,7 @@ func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID
}
// If no category filter is applied, only take the latest 5 posts
- if categoryID == nil && len(postIDs) > 5 {
+ if len(categoryIDs) == 0 && len(postIDs) > 5 {
postIDs = postIDs[:5]
}
@@ -666,7 +731,7 @@ func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID
WithContents(func(q *ent.PostContentQuery) {
q.Where(postcontent.LanguageCodeEQ(languageCode))
}).
- WithCategory().
+ WithCategories().
Order(ent.Desc(post.FieldCreatedAt)).
All(ctx)
if err != nil {
@@ -1050,3 +1115,47 @@ func (s *serviceImpl) DeleteDaily(ctx context.Context, id string, currentUserID
return s.client.Daily.DeleteOneID(id).Exec(ctx)
}
+
+// convertToWebP 将图片转换为 WebP 格式
+func convertToWebP(src multipart.File) (multipart.File, error) {
+ // 读取原始图片
+ img, err := imaging.Decode(src)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decode image: %v", err)
+ }
+
+ // 创建一个新的缓冲区来存储 WebP 图片
+ buf := new(bytes.Buffer)
+
+ // 将图片编码为 WebP 格式
+ // 设置较高的质量以保持图片质量
+ err = webp.Encode(buf, img, &webp.Options{
+ Lossless: false,
+ Quality: 90,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to encode image to WebP: %v", err)
+ }
+
+ // 创建一个新的临时文件来存储转换后的图片
+ tmpFile, err := os.CreateTemp("", "webp-*.webp")
+ if err != nil {
+ return nil, fmt.Errorf("failed to create temp file: %v", err)
+ }
+
+ // 写入转换后的数据
+ if _, err := io.Copy(tmpFile, buf); err != nil {
+ tmpFile.Close()
+ os.Remove(tmpFile.Name())
+ return nil, fmt.Errorf("failed to write WebP data: %v", err)
+ }
+
+ // 将文件指针移回开始位置
+ if _, err := tmpFile.Seek(0, 0); err != nil {
+ tmpFile.Close()
+ os.Remove(tmpFile.Name())
+ return nil, fmt.Errorf("failed to seek file: %v", err)
+ }
+
+ return tmpFile, nil
+}
diff --git a/backend/internal/service/impl_test.go b/backend/internal/service/impl_test.go
deleted file mode 100644
index abba3cb..0000000
--- a/backend/internal/service/impl_test.go
+++ /dev/null
@@ -1,1090 +0,0 @@
-package service
-
-import (
- "bytes"
- "context"
- "fmt"
- "io"
- "mime/multipart"
- "net/textproto"
- "strconv"
- "strings"
- "testing"
- "time"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- "github.com/stretchr/testify/suite"
- "go.uber.org/mock/gomock"
-
- "tss-rocks-be/ent"
- "tss-rocks-be/ent/categorycontent"
- "tss-rocks-be/ent/dailycontent"
- "tss-rocks-be/internal/storage"
- "tss-rocks-be/internal/storage/mock"
- "tss-rocks-be/internal/testutil"
-)
-
-type ServiceImplTestSuite struct {
- suite.Suite
- ctx context.Context
- client *ent.Client
- storage *mock.MockStorage
- ctrl *gomock.Controller
- svc Service
-}
-
-func (s *ServiceImplTestSuite) SetupTest() {
- s.ctx = context.Background()
- s.client = testutil.NewTestClient()
- require.NotNil(s.T(), s.client)
-
- s.ctrl = gomock.NewController(s.T())
- s.storage = mock.NewMockStorage(s.ctrl)
- s.svc = NewService(s.client, s.storage)
-
- // 清理数据库
- _, err := s.client.Category.Delete().Exec(s.ctx)
- require.NoError(s.T(), err)
- _, err = s.client.CategoryContent.Delete().Exec(s.ctx)
- require.NoError(s.T(), err)
- _, err = s.client.User.Delete().Exec(s.ctx)
- require.NoError(s.T(), err)
- _, err = s.client.Role.Delete().Exec(s.ctx)
- require.NoError(s.T(), err)
- _, err = s.client.Permission.Delete().Exec(s.ctx)
- require.NoError(s.T(), err)
- _, err = s.client.Daily.Delete().Exec(s.ctx)
- require.NoError(s.T(), err)
- _, err = s.client.DailyContent.Delete().Exec(s.ctx)
- require.NoError(s.T(), err)
-
- // 初始化 RBAC 系统
- err = s.svc.InitializeRBAC(s.ctx)
- require.NoError(s.T(), err)
-
- // Set default openFile function
- openFile = func(fh *multipart.FileHeader) (multipart.File, error) {
- return fh.Open()
- }
-}
-
-func (s *ServiceImplTestSuite) TearDownTest() {
- s.ctrl.Finish()
- s.client.Close()
-}
-
-func TestServiceImplSuite(t *testing.T) {
- suite.Run(t, new(ServiceImplTestSuite))
-}
-
-// mockMultipartFile implements multipart.File interface
-type mockMultipartFile struct {
- *bytes.Reader
-}
-
-func (m *mockMultipartFile) Close() error {
- return nil
-}
-
-func (m *mockMultipartFile) ReadAt(p []byte, off int64) (n int, err error) {
- return m.Reader.ReadAt(p, off)
-}
-
-func (m *mockMultipartFile) Seek(offset int64, whence int) (int64, error) {
- return m.Reader.Seek(offset, whence)
-}
-
-func newMockMultipartFile(data []byte) *mockMultipartFile {
- return &mockMultipartFile{
- Reader: bytes.NewReader(data),
- }
-}
-
-func (s *ServiceImplTestSuite) TestCreateUser() {
- testCases := []struct {
- name string
- username string
- email string
- password string
- role string
- wantErr bool
- }{
- {
- name: "有效的用户",
- username: "testuser",
- email: "test@example.com",
- password: "password123",
- role: "user",
- wantErr: false,
- },
- {
- name: "无效的邮箱",
- username: "testuser2",
- email: "invalid-email",
- password: "password123",
- role: "user",
- wantErr: true,
- },
- {
- name: "空密码",
- username: "testuser3",
- email: "test3@example.com",
- password: "",
- role: "user",
- wantErr: true,
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- user, err := s.svc.CreateUser(s.ctx, tc.username, tc.email, tc.password, tc.role)
- if tc.wantErr {
- s.Error(err)
- s.Nil(user)
- } else {
- s.NoError(err)
- s.NotNil(user)
- s.Equal(tc.email, user.Email)
- s.Equal(tc.username, user.Username)
- }
- })
- }
-}
-
-func (s *ServiceImplTestSuite) TestGetUserByEmail() {
- // Create a test user first
- email := "test@example.com"
- password := "password123"
- role := "user"
-
- user, err := s.svc.CreateUser(s.ctx, "testuser", email, password, role)
- require.NoError(s.T(), err)
- require.NotNil(s.T(), user)
-
- s.Run("Existing user", func() {
- found, err := s.svc.GetUserByEmail(s.ctx, email)
- assert.NoError(s.T(), err)
- assert.NotNil(s.T(), found)
- assert.Equal(s.T(), email, found.Email)
- })
-
- s.Run("Non-existing user", func() {
- found, err := s.svc.GetUserByEmail(s.ctx, "nonexistent@example.com")
- assert.Error(s.T(), err)
- assert.Nil(s.T(), found)
- })
-}
-
-func (s *ServiceImplTestSuite) TestValidatePassword() {
- // Create a test user first
- email := "test@example.com"
- password := "password123"
- role := "user"
-
- user, err := s.svc.CreateUser(s.ctx, "testuser", email, password, role)
- require.NoError(s.T(), err)
- require.NotNil(s.T(), user)
-
- s.Run("Valid password", func() {
- valid := s.svc.ValidatePassword(s.ctx, user, password)
- assert.True(s.T(), valid)
- })
-
- s.Run("Invalid password", func() {
- valid := s.svc.ValidatePassword(s.ctx, user, "wrongpassword")
- assert.False(s.T(), valid)
- })
-}
-
-func (s *ServiceImplTestSuite) TestRBAC() {
- s.Run("AssignRole", func() {
- user, err := s.svc.CreateUser(s.ctx, "testuser", "test@example.com", "password", "admin")
- require.NoError(s.T(), err)
-
- err = s.svc.AssignRole(s.ctx, user.ID, "user")
- assert.NoError(s.T(), err)
- })
-
- s.Run("RemoveRole", func() {
- user, err := s.svc.CreateUser(s.ctx, "testuser2", "test2@example.com", "password", "admin")
- require.NoError(s.T(), err)
-
- err = s.svc.RemoveRole(s.ctx, user.ID, "admin")
- assert.NoError(s.T(), err)
- })
-
- s.Run("HasPermission", func() {
- s.Run("Admin can create users", func() {
- user, err := s.svc.CreateUser(s.ctx, "testuser3", "admin@example.com", "password", "admin")
- require.NoError(s.T(), err)
-
- hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create")
- require.NoError(s.T(), err)
- assert.True(s.T(), hasPermission)
- })
-
- s.Run("Editor cannot create users", func() {
- user, err := s.svc.CreateUser(s.ctx, "testuser4", "editor@example.com", "password", "editor")
- require.NoError(s.T(), err)
-
- hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create")
- require.NoError(s.T(), err)
- assert.False(s.T(), hasPermission)
- })
-
- s.Run("User cannot create users", func() {
- user, err := s.svc.CreateUser(s.ctx, "testuser5", "user@example.com", "password", "user")
- require.NoError(s.T(), err)
-
- hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create")
- require.NoError(s.T(), err)
- assert.False(s.T(), hasPermission)
- })
-
- s.Run("Editor can create posts", func() {
- user, err := s.svc.CreateUser(s.ctx, "testuser6", "editor2@example.com", "password", "editor")
- require.NoError(s.T(), err)
-
- hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create")
- require.NoError(s.T(), err)
- assert.True(s.T(), hasPermission)
- })
-
- s.Run("User can read posts", func() {
- user, err := s.svc.CreateUser(s.ctx, "testuser7", "user2@example.com", "password", "user")
- require.NoError(s.T(), err)
-
- hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:read")
- require.NoError(s.T(), err)
- assert.True(s.T(), hasPermission)
- })
-
- s.Run("User cannot create posts", func() {
- user, err := s.svc.CreateUser(s.ctx, "testuser8", "user3@example.com", "password", "user")
- require.NoError(s.T(), err)
-
- hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create")
- require.NoError(s.T(), err)
- assert.False(s.T(), hasPermission)
- })
-
- s.Run("Invalid permission format", func() {
- user, err := s.svc.CreateUser(s.ctx, "testuser9", "user4@example.com", "password", "user")
- require.NoError(s.T(), err)
-
- _, err = s.svc.HasPermission(s.ctx, user.ID, "invalid_permission")
- require.Error(s.T(), err)
- assert.Contains(s.T(), err.Error(), "invalid permission format")
- })
- })
-}
-
-func (s *ServiceImplTestSuite) TestCategory() {
- // Create a test user with admin role for testing
- adminUser, err := s.svc.CreateUser(s.ctx, "testuser10", "admin@example.com", "password123", "admin")
- require.NoError(s.T(), err)
- require.NotNil(s.T(), adminUser)
-
- s.Run("CreateCategory", func() {
- // Test category creation
- category, err := s.svc.CreateCategory(s.ctx)
- assert.NoError(s.T(), err)
- assert.NotNil(s.T(), category)
- assert.NotZero(s.T(), category.ID)
- })
-
- s.Run("AddCategoryContent", func() {
- // Create a category first
- category, err := s.svc.CreateCategory(s.ctx)
- require.NoError(s.T(), err)
- require.NotNil(s.T(), category)
-
- testCases := []struct {
- name string
- langCode string
- catName string
- desc string
- slug string
- wantError bool
- }{
- {
- name: "Valid category content",
- langCode: "en",
- catName: "Test Category",
- desc: "Test Description",
- slug: "test-category",
- wantError: false,
- },
- {
- name: "Empty language code",
- langCode: "",
- catName: "Test Category",
- desc: "Test Description",
- slug: "test-category-2",
- wantError: true,
- },
- {
- name: "Empty name",
- langCode: "en",
- catName: "",
- desc: "Test Description",
- slug: "test-category-3",
- wantError: true,
- },
- {
- name: "Empty slug",
- langCode: "en",
- catName: "Test Category",
- desc: "Test Description",
- slug: "",
- wantError: true,
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- content, err := s.svc.AddCategoryContent(s.ctx, category.ID, tc.langCode, tc.catName, tc.desc, tc.slug)
- if tc.wantError {
- assert.Error(s.T(), err)
- assert.Nil(s.T(), content)
- } else {
- assert.NoError(s.T(), err)
- assert.NotNil(s.T(), content)
- assert.Equal(s.T(), categorycontent.LanguageCode(tc.langCode), content.LanguageCode)
- assert.Equal(s.T(), tc.catName, content.Name)
- assert.Equal(s.T(), tc.desc, content.Description)
- assert.Equal(s.T(), tc.slug, content.Slug)
- }
- })
- }
- })
-
- s.Run("GetCategoryBySlug", func() {
- // Create a category with content first
- category, err := s.svc.CreateCategory(s.ctx)
- require.NoError(s.T(), err)
- require.NotNil(s.T(), category)
-
- content, err := s.svc.AddCategoryContent(s.ctx, category.ID, "en", "Test Category", "Test Description", "test-category-get")
- require.NoError(s.T(), err)
- require.NotNil(s.T(), content)
-
- s.Run("Existing category", func() {
- found, err := s.svc.GetCategoryBySlug(s.ctx, "en", "test-category-get")
- assert.NoError(s.T(), err)
- assert.NotNil(s.T(), found)
- assert.Equal(s.T(), category.ID, found.ID)
-
- // Check if content is loaded
- require.NotEmpty(s.T(), found.Edges.Contents)
- assert.Equal(s.T(), "Test Category", found.Edges.Contents[0].Name)
- })
-
- s.Run("Non-existing category", func() {
- found, err := s.svc.GetCategoryBySlug(s.ctx, "en", "non-existent")
- assert.Error(s.T(), err)
- assert.Nil(s.T(), found)
- })
-
- s.Run("Wrong language code", func() {
- found, err := s.svc.GetCategoryBySlug(s.ctx, "fr", "test-category-get")
- assert.Error(s.T(), err)
- assert.Nil(s.T(), found)
- })
- })
-
- s.Run("ListCategories", func() {
- s.Run("List English categories", func() {
- // 创建多个分类,但只有 3 个有英文内容
- var createdCategories []*ent.Category
- for i := 0; i < 5; i++ {
- category, err := s.svc.CreateCategory(s.ctx)
- require.NoError(s.T(), err)
- require.NotNil(s.T(), category)
- createdCategories = append(createdCategories, category)
-
- // 只给前 3 个分类添加英文内容
- if i < 3 {
- _, err = s.svc.AddCategoryContent(s.ctx, category.ID, "en",
- fmt.Sprintf("Category %d", i),
- fmt.Sprintf("Description %d", i),
- fmt.Sprintf("category-list-%d", i))
- require.NoError(s.T(), err)
- }
- }
-
- categories, err := s.svc.ListCategories(s.ctx, "en")
- assert.NoError(s.T(), err)
- assert.NotNil(s.T(), categories)
- assert.Len(s.T(), categories, 3)
-
- // 检查所有返回的分类都有英文内容
- for _, cat := range categories {
- assert.NotEmpty(s.T(), cat.Edges.Contents)
- for _, content := range cat.Edges.Contents {
- assert.Equal(s.T(), categorycontent.LanguageCodeEN, content.LanguageCode)
- }
- }
- })
-
- s.Run("List Chinese categories", func() {
- // 创建多个分类,但只有 2 个有中文内容
- for i := 0; i < 4; i++ {
- category, err := s.svc.CreateCategory(s.ctx)
- require.NoError(s.T(), err)
- require.NotNil(s.T(), category)
-
- // 只给前 2 个分类添加中文内容
- if i < 2 {
- _, err = s.svc.AddCategoryContent(s.ctx, category.ID, "zh-Hans",
- fmt.Sprintf("分类 %d", i),
- fmt.Sprintf("描述 %d", i),
- fmt.Sprintf("category-list-%d", i))
- require.NoError(s.T(), err)
- }
- }
-
- categories, err := s.svc.ListCategories(s.ctx, "zh-Hans")
- assert.NoError(s.T(), err)
- assert.NotNil(s.T(), categories)
- assert.Len(s.T(), categories, 2)
-
- // 检查所有返回的分类都有中文内容
- for _, cat := range categories {
- assert.NotEmpty(s.T(), cat.Edges.Contents)
- for _, content := range cat.Edges.Contents {
- assert.Equal(s.T(), categorycontent.LanguageCodeZH_HANS, content.LanguageCode)
- }
- }
- })
-
- s.Run("List non-existing language", func() {
- categories, err := s.svc.ListCategories(s.ctx, "fr")
- assert.NoError(s.T(), err)
- assert.Empty(s.T(), categories)
- })
- })
-}
-
-func (s *ServiceImplTestSuite) TestGetCategories() {
- ctx := context.Background()
-
- // 测试不支持的语言代码
- categories, err := s.svc.GetCategories(ctx, "invalid")
- s.Require().NoError(err)
- s.Empty(categories)
-
- // 创建测试数据
- cat1 := s.createTestCategory(ctx, "test-cat-1")
- cat2 := s.createTestCategory(ctx, "test-cat-2")
-
- // 为分类添加不同语言的内容
- _, err = s.svc.AddCategoryContent(ctx, cat1.ID, "en", "Test Category 1", "Test Description 1", "category-list-test-1")
- s.Require().NoError(err)
-
- _, err = s.svc.AddCategoryContent(ctx, cat2.ID, "zh-Hans", "测试分类2", "测试描述2", "category-list-test-2")
- s.Require().NoError(err)
-
- // 测试获取英文分类
- enCategories, err := s.svc.GetCategories(ctx, "en")
- s.Require().NoError(err)
- s.Len(enCategories, 1)
- s.Equal(cat1.ID, enCategories[0].ID)
-
- // 测试获取简体中文分类
- zhCategories, err := s.svc.GetCategories(ctx, "zh-Hans")
- s.Require().NoError(err)
- s.Len(zhCategories, 1)
- s.Equal(cat2.ID, zhCategories[0].ID)
-
- // 测试获取繁体中文分类(应该为空)
- zhHantCategories, err := s.svc.GetCategories(ctx, "zh-Hant")
- s.Require().NoError(err)
- s.Empty(zhHantCategories)
-}
-
-func (s *ServiceImplTestSuite) TestGetUserRoles() {
- ctx := context.Background()
-
- // 创建测试用户,默认会有 "user" 角色
- user, err := s.svc.CreateUser(ctx, "testuser", "test@example.com", "password123", "user")
- s.Require().NoError(err)
-
- // 测试新用户有默认的 "user" 角色
- roles, err := s.svc.GetUserRoles(ctx, user.ID)
- s.Require().NoError(err)
- s.Len(roles, 1)
- s.Equal("user", roles[0].Name)
-
- // 分配角色给用户
- err = s.svc.AssignRole(ctx, user.ID, "admin")
- s.Require().NoError(err)
-
- // 测试用户现在有两个角色
- roles, err = s.svc.GetUserRoles(ctx, user.ID)
- s.Require().NoError(err)
- s.Len(roles, 2)
- roleNames := []string{roles[0].Name, roles[1].Name}
- s.Contains(roleNames, "user")
- s.Contains(roleNames, "admin")
-
- // 测试不存在的用户
- _, err = s.svc.GetUserRoles(ctx, -1)
- s.Require().Error(err)
-}
-
-func (s *ServiceImplTestSuite) TestDaily() {
- // 创建一个测试分类
- category, err := s.svc.CreateCategory(s.ctx)
- require.NoError(s.T(), err)
- require.NotNil(s.T(), category)
-
- // 添加分类内容
- categoryContent, err := s.svc.AddCategoryContent(s.ctx, category.ID, "en", "Test Category", "Test Description", "test-category")
- require.NoError(s.T(), err)
- require.NotNil(s.T(), categoryContent)
-
- dailyID := "250212" // 使用符合验证规则的 ID 格式:YYMMDD
-
- // 测试创建 Daily
- s.Run("Create Daily", func() {
- daily, err := s.svc.CreateDaily(s.ctx, dailyID, category.ID, "http://example.com/image.jpg")
- require.NoError(s.T(), err)
- require.NotNil(s.T(), daily)
- assert.Equal(s.T(), dailyID, daily.ID)
- assert.Equal(s.T(), category.ID, daily.Edges.Category.ID)
- assert.Equal(s.T(), "http://example.com/image.jpg", daily.ImageURL)
- })
-
- // 测试添加 Daily 内容
- s.Run("Add Daily Content", func() {
- content, err := s.svc.AddDailyContent(s.ctx, dailyID, "en", "Test quote for the day")
- require.NoError(s.T(), err)
- require.NotNil(s.T(), content)
- assert.Equal(s.T(), dailycontent.LanguageCodeEN, content.LanguageCode)
- assert.Equal(s.T(), "Test quote for the day", content.Quote)
- })
-
- // 测试获取 Daily
- s.Run("Get Daily By ID", func() {
- daily, err := s.svc.GetDailyByID(s.ctx, dailyID)
- require.NoError(s.T(), err)
- require.NotNil(s.T(), daily)
- assert.Equal(s.T(), dailyID, daily.ID)
- assert.Equal(s.T(), category.ID, daily.Edges.Category.ID)
- })
-
- // 测试列出 Daily
- s.Run("List Dailies", func() {
- // 创建另一个 Daily 用于测试列表
- anotherDailyID := "250213"
- _, err := s.svc.CreateDaily(s.ctx, anotherDailyID, category.ID, "http://example.com/image2.jpg")
- assert.NoError(s.T(), err)
- _, err = s.svc.AddDailyContent(s.ctx, anotherDailyID, "en", "Another test quote")
- assert.NoError(s.T(), err)
-
- // 测试列表功能
- dailies, err := s.svc.ListDailies(s.ctx, "en", &category.ID, 10, 0)
- assert.NoError(s.T(), err)
- assert.NotNil(s.T(), dailies)
- assert.Len(s.T(), dailies, 2)
-
- // 测试分页
- dailies, err = s.svc.ListDailies(s.ctx, "en", &category.ID, 1, 0)
- assert.NoError(s.T(), err)
- assert.NotNil(s.T(), dailies)
- assert.Len(s.T(), dailies, 1)
-
- // 测试无分类过滤
- dailies, err = s.svc.ListDailies(s.ctx, "en", nil, 10, 0)
- assert.NoError(s.T(), err)
- assert.NotNil(s.T(), dailies)
- assert.Len(s.T(), dailies, 2)
- })
-}
-
-func (s *ServiceImplTestSuite) TestPost() {
- s.Run("Create Post", func() {
- s.Run("Draft", func() {
- post, err := s.svc.CreatePost(s.ctx, "draft")
- require.NoError(s.T(), err)
- require.NotNil(s.T(), post)
- assert.Equal(s.T(), "draft", post.Status.String())
- })
-
- s.Run("Published", func() {
- post, err := s.svc.CreatePost(s.ctx, "published")
- require.NoError(s.T(), err)
- require.NotNil(s.T(), post)
- assert.Equal(s.T(), "published", post.Status.String())
- })
-
- s.Run("Archived", func() {
- post, err := s.svc.CreatePost(s.ctx, "archived")
- require.NoError(s.T(), err)
- require.NotNil(s.T(), post)
- assert.Equal(s.T(), "archived", post.Status.String())
- })
-
- s.Run("Invalid Status", func() {
- post, err := s.svc.CreatePost(s.ctx, "invalid")
- assert.Error(s.T(), err)
- assert.Nil(s.T(), post)
- })
- })
-
- s.Run("Add Post Content", func() {
- // Create a post first
- post, err := s.svc.CreatePost(s.ctx, "draft")
- require.NoError(s.T(), err)
-
- s.Run("English Content", func() {
- content, err := s.svc.AddPostContent(s.ctx, post.ID, "en", "Test Post", "# Test Content", "Test Summary", "test,post", "Test Description")
- require.NoError(s.T(), err)
- require.NotNil(s.T(), content)
- assert.Equal(s.T(), "en", content.LanguageCode.String())
- assert.Equal(s.T(), "Test Post", content.Title)
- assert.Equal(s.T(), "# Test Content", content.ContentMarkdown)
- assert.Equal(s.T(), "Test Summary", content.Summary)
- assert.Equal(s.T(), "test,post", content.MetaKeywords)
- assert.Equal(s.T(), "Test Description", content.MetaDescription)
- assert.Equal(s.T(), "test-post", content.Slug)
- })
-
- s.Run("Simplified Chinese Content", func() {
- content, err := s.svc.AddPostContent(s.ctx, post.ID, "zh-Hans", "测试帖子", "# 测试内容", "测试摘要", "测试,帖子", "测试描述")
- require.NoError(s.T(), err)
- require.NotNil(s.T(), content)
- assert.Equal(s.T(), "zh-Hans", content.LanguageCode.String())
- assert.Equal(s.T(), "测试帖子", content.Title)
- assert.Equal(s.T(), "# 测试内容", content.ContentMarkdown)
- assert.Equal(s.T(), "测试摘要", content.Summary)
- assert.Equal(s.T(), "测试,帖子", content.MetaKeywords)
- assert.Equal(s.T(), "测试描述", content.MetaDescription)
- assert.Equal(s.T(), "测试帖子", content.Slug)
- })
-
- s.Run("Traditional Chinese Content", func() {
- content, err := s.svc.AddPostContent(s.ctx, post.ID, "zh-Hant", "測試貼文", "# 測試內容", "測試摘要", "測試,貼文", "測試描述")
- require.NoError(s.T(), err)
- require.NotNil(s.T(), content)
- assert.Equal(s.T(), "zh-Hant", content.LanguageCode.String())
- assert.Equal(s.T(), "測試貼文", content.Title)
- assert.Equal(s.T(), "# 測試內容", content.ContentMarkdown)
- assert.Equal(s.T(), "測試摘要", content.Summary)
- assert.Equal(s.T(), "測試,貼文", content.MetaKeywords)
- assert.Equal(s.T(), "測試描述", content.MetaDescription)
- assert.Equal(s.T(), "測試貼文", content.Slug)
- })
-
- s.Run("Invalid Language Code", func() {
- content, err := s.svc.AddPostContent(s.ctx, post.ID, "fr", "Test Post", "# Test Content", "Test Summary", "test,post", "Test Description")
- assert.Error(s.T(), err)
- assert.Nil(s.T(), content)
- })
-
- s.Run("Non-existent Post", func() {
- content, err := s.svc.AddPostContent(s.ctx, 999999, "en", "Test Post", "# Test Content", "Test Summary", "test,post", "Test Description")
- assert.Error(s.T(), err)
- assert.Nil(s.T(), content)
- })
- })
-
- s.Run("Get Post By Slug", func() {
- // Create a post first
- post, err := s.svc.CreatePost(s.ctx, "published")
- require.NoError(s.T(), err)
-
- // Add content in different languages
- _, err = s.svc.AddPostContent(s.ctx, post.ID, "en", "Test Post", "# Test Content", "Test Summary", "test,post", "Test Description")
- require.NoError(s.T(), err)
- _, err = s.svc.AddPostContent(s.ctx, post.ID, "zh-Hans", "测试帖子", "# 测试内容", "测试摘要", "测试,帖子", "测试描述")
- require.NoError(s.T(), err)
-
- s.Run("Get Post By Slug - English", func() {
- result, err := s.svc.GetPostBySlug(s.ctx, "en", "test-post")
- require.NoError(s.T(), err)
- require.NotNil(s.T(), result)
- assert.Equal(s.T(), post.ID, result.ID)
- assert.Equal(s.T(), "published", result.Status.String())
-
- contents := result.Edges.Contents
- require.Len(s.T(), contents, 1)
- assert.Equal(s.T(), "en", contents[0].LanguageCode.String())
- assert.Equal(s.T(), "Test Post", contents[0].Title)
- })
-
- s.Run("Get Post By Slug - Chinese", func() {
- result, err := s.svc.GetPostBySlug(s.ctx, "zh-Hans", "测试帖子")
- require.NoError(s.T(), err)
- require.NotNil(s.T(), result)
- assert.Equal(s.T(), post.ID, result.ID)
- assert.Equal(s.T(), "published", result.Status.String())
-
- contents := result.Edges.Contents
- require.Len(s.T(), contents, 1)
- assert.Equal(s.T(), "zh-Hans", contents[0].LanguageCode.String())
- assert.Equal(s.T(), "测试帖子", contents[0].Title)
- })
-
- s.Run("Non-existent Post", func() {
- result, err := s.svc.GetPostBySlug(s.ctx, "en", "non-existent")
- assert.Error(s.T(), err)
- assert.Nil(s.T(), result)
- })
-
- s.Run("Invalid Language Code", func() {
- result, err := s.svc.GetPostBySlug(s.ctx, "fr", "test-post")
- assert.Error(s.T(), err)
- assert.Nil(s.T(), result)
- })
- })
-
- s.Run("List Posts", func() {
- // Create some posts with content
- for i := 0; i < 5; i++ {
- post, err := s.svc.CreatePost(s.ctx, "published")
- require.NoError(s.T(), err)
-
- // Add content in different languages
- _, err = s.svc.AddPostContent(s.ctx, post.ID, "en", fmt.Sprintf("Post %d", i), "# Content", "Summary", "test", "Description")
- require.NoError(s.T(), err)
- _, err = s.svc.AddPostContent(s.ctx, post.ID, "zh-Hans", fmt.Sprintf("帖子 %d", i), "# 内容", "摘要", "测试", "描述")
- require.NoError(s.T(), err)
- }
-
- s.Run("List All Posts - English", func() {
- posts, err := s.svc.ListPosts(s.ctx, "en", nil, 10, 0)
- require.NoError(s.T(), err)
- require.Len(s.T(), posts, 5)
-
- // Check that all posts have English content
- for _, post := range posts {
- contents := post.Edges.Contents
- require.Len(s.T(), contents, 1)
- assert.Equal(s.T(), "en", contents[0].LanguageCode.String())
- }
- })
-
- s.Run("List All Posts - Chinese", func() {
- posts, err := s.svc.ListPosts(s.ctx, "zh-Hans", nil, 10, 0)
- require.NoError(s.T(), err)
- require.Len(s.T(), posts, 5)
-
- // Check that all posts have Chinese content
- for _, post := range posts {
- contents := post.Edges.Contents
- require.Len(s.T(), contents, 1)
- assert.Equal(s.T(), "zh-Hans", contents[0].LanguageCode.String())
- }
- })
-
- s.Run("List Posts with Pagination", func() {
- // Get first page
- posts, err := s.svc.ListPosts(s.ctx, "en", nil, 2, 0)
- require.NoError(s.T(), err)
- require.Len(s.T(), posts, 2)
-
- // Get second page
- posts, err = s.svc.ListPosts(s.ctx, "en", nil, 2, 2)
- require.NoError(s.T(), err)
- require.Len(s.T(), posts, 2)
-
- // Get last page
- posts, err = s.svc.ListPosts(s.ctx, "en", nil, 2, 4)
- require.NoError(s.T(), err)
- require.Len(s.T(), posts, 1)
- })
-
- s.Run("List Posts by Category", func() {
- // Create a category
- category, err := s.svc.CreateCategory(s.ctx)
- require.NoError(s.T(), err)
-
- // Create posts in this category
- for i := 0; i < 3; i++ {
- post, err := s.svc.CreatePost(s.ctx, "published")
- require.NoError(s.T(), err)
-
- // Set category
- _, err = s.client.Post.UpdateOne(post).SetCategoryID(category.ID).Save(s.ctx)
- require.NoError(s.T(), err)
-
- // Add content
- _, err = s.svc.AddPostContent(s.ctx, post.ID, "en", fmt.Sprintf("Category Post %d", i), "# Content", "Summary", "test", "Description")
- require.NoError(s.T(), err)
- }
-
- // List posts in this category
- posts, err := s.svc.ListPosts(s.ctx, "en", &category.ID, 10, 0)
- require.NoError(s.T(), err)
- require.Len(s.T(), posts, 3)
-
- // Check that all posts belong to the category
- for _, post := range posts {
- assert.Equal(s.T(), category.ID, post.Edges.Category.ID)
- }
- })
-
- s.Run("Invalid Language Code", func() {
- posts, err := s.svc.ListPosts(s.ctx, "fr", nil, 10, 0)
- assert.Error(s.T(), err)
- assert.Nil(s.T(), posts)
- })
- })
-}
-
-func (s *ServiceImplTestSuite) TestMedia() {
- s.Run("Upload Media", func() {
- // Create a user first
- user, err := s.svc.CreateUser(s.ctx, "testuser", "test@example.com", "password123", "user")
- require.NoError(s.T(), err)
- require.NotNil(s.T(), user)
-
- // Mock file content
- fileContent := []byte("test file content")
-
- // Mock the file header
- fileHeader := &multipart.FileHeader{
- Filename: "test.jpg",
- Size: int64(len(fileContent)),
- Header: textproto.MIMEHeader{
- "Content-Type": []string{"image/jpeg"},
- },
- }
-
- // Mock the storage behavior
- s.storage.EXPECT().
- Save(gomock.Any(), fileHeader.Filename, "image/jpeg", gomock.Any()).
- DoAndReturn(func(ctx context.Context, name, contentType string, reader io.Reader) (*storage.FileInfo, error) {
- // Verify the reader content
- data, err := io.ReadAll(reader)
- if err != nil {
- return nil, err
- }
- if !bytes.Equal(data, fileContent) {
- return nil, fmt.Errorf("unexpected file content")
- }
- return &storage.FileInfo{
- ID: "test123",
- Name: name,
- Size: int64(len(fileContent)),
- ContentType: contentType,
- URL: "http://example.com/test.jpg",
- CreatedAt: time.Now(),
- UpdatedAt: time.Now(),
- }, nil
- }).Times(1)
-
- // Replace the Open method
- openFile = func(fh *multipart.FileHeader) (multipart.File, error) {
- return &mockMultipartFile{bytes.NewReader(fileContent)}, nil
- }
-
- // Test upload
- media, err := s.svc.Upload(s.ctx, fileHeader, user.ID)
- require.NoError(s.T(), err)
- require.NotNil(s.T(), media)
- assert.Equal(s.T(), "test123", media.StorageID)
- assert.Equal(s.T(), "test.jpg", media.OriginalName)
- assert.Equal(s.T(), int64(len(fileContent)), media.Size)
- assert.Equal(s.T(), "image/jpeg", media.MimeType)
- assert.Equal(s.T(), "http://example.com/test.jpg", media.URL)
- assert.Equal(s.T(), strconv.Itoa(user.ID), media.CreatedBy)
-
- // Now we can test other operations since we have a media record
- s.Run("Get Media", func() {
- result, err := s.svc.GetMedia(s.ctx, media.ID)
- require.NoError(s.T(), err)
- require.NotNil(s.T(), result)
- assert.Equal(s.T(), media.ID, result.ID)
- assert.Equal(s.T(), media.StorageID, result.StorageID)
- assert.Equal(s.T(), media.URL, result.URL)
- })
-
- s.Run("Get File", func() {
- // Mock the storage behavior
- mockReader := io.NopCloser(strings.NewReader("test content"))
- mockFileInfo := &storage.FileInfo{
- ID: media.StorageID,
- Name: media.OriginalName,
- Size: media.Size,
- ContentType: media.MimeType,
- URL: media.URL,
- CreatedAt: media.CreatedAt,
- UpdatedAt: media.UpdatedAt,
- }
- s.storage.EXPECT().
- Get(gomock.Any(), media.StorageID).
- Return(mockReader, mockFileInfo, nil)
-
- // Test get file
- reader, fileInfo, err := s.svc.GetFile(s.ctx, media.ID)
- require.NoError(s.T(), err)
- require.NotNil(s.T(), reader)
- require.NotNil(s.T(), fileInfo)
- assert.Equal(s.T(), media.OriginalName, fileInfo.Name)
- assert.Equal(s.T(), media.Size, fileInfo.Size)
- assert.Equal(s.T(), media.MimeType, fileInfo.ContentType)
- assert.Equal(s.T(), media.URL, fileInfo.URL)
-
- // Clean up
- reader.Close()
- })
-
- s.Run("List Media", func() {
- // Test list media
- list, err := s.svc.ListMedia(s.ctx, 10, 0)
- require.NoError(s.T(), err)
- require.NotNil(s.T(), list)
- require.Len(s.T(), list, 1)
- assert.Equal(s.T(), "test.jpg", list[0].OriginalName)
- })
-
- s.Run("Delete Media", func() {
- // Mock the storage behavior
- s.storage.EXPECT().
- Delete(gomock.Any(), media.StorageID).
- Return(nil)
-
- // Test delete media
- err = s.svc.DeleteMedia(s.ctx, media.ID, user.ID)
- require.NoError(s.T(), err)
-
- // Verify media is deleted
- count, err := s.client.Media.Query().Count(s.ctx)
- require.NoError(s.T(), err)
- assert.Equal(s.T(), 0, count)
- })
- })
-
- s.Run("Delete Media - Unauthorized", func() {
- // Create a user
- user, err := s.svc.CreateUser(s.ctx, "anotheruser", "another@example.com", "password123", "user")
- require.NoError(s.T(), err)
-
- // Mock file content
- fileContent := []byte("test file content")
-
- // Mock the file header
- fileHeader := &multipart.FileHeader{
- Filename: "test2.jpg",
- Size: int64(len(fileContent)),
- Header: textproto.MIMEHeader{
- "Content-Type": []string{"image/jpeg"},
- },
- }
-
- // Mock the storage behavior
- s.storage.EXPECT().
- Save(gomock.Any(), fileHeader.Filename, "image/jpeg", gomock.Any()).
- DoAndReturn(func(ctx context.Context, name, contentType string, reader io.Reader) (*storage.FileInfo, error) {
- // Verify the reader content
- data, err := io.ReadAll(reader)
- if err != nil {
- return nil, err
- }
- if !bytes.Equal(data, fileContent) {
- return nil, fmt.Errorf("unexpected file content")
- }
- return &storage.FileInfo{
- ID: "test456",
- Name: name,
- Size: int64(len(fileContent)),
- ContentType: contentType,
- URL: "http://example.com/test2.jpg",
- CreatedAt: time.Now(),
- UpdatedAt: time.Now(),
- }, nil
- }).Times(1)
-
- // Replace the Open method
- openFile = func(fh *multipart.FileHeader) (multipart.File, error) {
- return &mockMultipartFile{bytes.NewReader(fileContent)}, nil
- }
-
- media, err := s.svc.Upload(s.ctx, fileHeader, user.ID)
- require.NoError(s.T(), err)
-
- // Try to delete with different user
- anotherUser, err := s.svc.CreateUser(s.ctx, "thirduser", "third@example.com", "password123", "user")
- require.NoError(s.T(), err)
-
- err = s.svc.DeleteMedia(s.ctx, media.ID, anotherUser.ID)
- assert.Equal(s.T(), ErrUnauthorized, err)
-
- // Verify media is not deleted
- count, err := s.client.Media.Query().Count(s.ctx)
- require.NoError(s.T(), err)
- assert.Equal(s.T(), 1, count)
- })
-}
-
-func (s *ServiceImplTestSuite) TestContributor() {
- // 测试创建贡献者
- avatarURL := "https://example.com/avatar.jpg"
- bio := "Test bio"
- contributor, err := s.svc.CreateContributor(s.ctx, "Test Contributor", &avatarURL, &bio)
- require.NoError(s.T(), err)
- require.NotNil(s.T(), contributor)
- assert.Equal(s.T(), "Test Contributor", contributor.Name)
- assert.Equal(s.T(), avatarURL, contributor.AvatarURL)
- assert.Equal(s.T(), bio, contributor.Bio)
-
- // 测试添加社交链接
- link, err := s.svc.AddContributorSocialLink(s.ctx, contributor.ID, "github", "GitHub", "https://github.com/test")
- require.NoError(s.T(), err)
- require.NotNil(s.T(), link)
- assert.Equal(s.T(), "github", link.Type.String())
- assert.Equal(s.T(), "GitHub", link.Name)
- assert.Equal(s.T(), "https://github.com/test", link.Value)
-
- // 测试获取贡献者
- fetchedContributor, err := s.svc.GetContributorByID(s.ctx, contributor.ID)
- require.NoError(s.T(), err)
- require.NotNil(s.T(), fetchedContributor)
- assert.Equal(s.T(), contributor.ID, fetchedContributor.ID)
- assert.Equal(s.T(), contributor.Name, fetchedContributor.Name)
- assert.Equal(s.T(), contributor.AvatarURL, fetchedContributor.AvatarURL)
- assert.Equal(s.T(), contributor.Bio, fetchedContributor.Bio)
- require.Len(s.T(), fetchedContributor.Edges.SocialLinks, 1)
- assert.Equal(s.T(), link.ID, fetchedContributor.Edges.SocialLinks[0].ID)
-
- // 测试列出贡献者
- contributors, err := s.svc.ListContributors(s.ctx)
- require.NoError(s.T(), err)
- require.NotEmpty(s.T(), contributors)
- assert.Equal(s.T(), contributor.ID, contributors[0].ID)
- require.Len(s.T(), contributors[0].Edges.SocialLinks, 1)
-
- // 测试错误情况
- _, err = s.svc.GetContributorByID(s.ctx, -1)
- assert.Error(s.T(), err)
-
- _, err = s.svc.AddContributorSocialLink(s.ctx, -1, "github", "GitHub", "https://github.com/test")
- assert.Error(s.T(), err)
-
- // 测试无效的社交链接类型
- _, err = s.svc.AddContributorSocialLink(s.ctx, contributor.ID, "invalid_type", "Invalid", "https://example.com")
- assert.Error(s.T(), err)
-}
-
-func TestServiceSuite(t *testing.T) {
- suite.Run(t, new(ServiceSuite))
-}
-
-type ServiceSuite struct {
- suite.Suite
-}
-
-func TestServiceInterface(t *testing.T) {
- var _ Service = (*serviceImpl)(nil)
-}
-
-// 创建测试分类的辅助函数
-func (s *ServiceImplTestSuite) createTestCategory(ctx context.Context, slug string) *ent.Category {
- category, err := s.svc.CreateCategory(ctx)
- s.Require().NoError(err)
- return category
-}
diff --git a/backend/internal/service/media_test.go b/backend/internal/service/media_test.go
deleted file mode 100644
index 412689d..0000000
--- a/backend/internal/service/media_test.go
+++ /dev/null
@@ -1,332 +0,0 @@
-package service
-
-import (
- "bytes"
- "context"
- "fmt"
- "io"
- "mime/multipart"
- "net/textproto"
- "reflect"
- "testing"
-
- "github.com/stretchr/testify/require"
- "github.com/stretchr/testify/suite"
- "go.uber.org/mock/gomock"
-
- "tss-rocks-be/ent"
- "tss-rocks-be/internal/storage"
- "tss-rocks-be/internal/storage/mock"
- "tss-rocks-be/internal/testutil"
-
- "bou.ke/monkey"
-)
-
-type MediaServiceTestSuite struct {
- suite.Suite
- ctx context.Context
- client *ent.Client
- storage *mock.MockStorage
- ctrl *gomock.Controller
- svc MediaService
-}
-
-func (s *MediaServiceTestSuite) SetupTest() {
- s.ctx = context.Background()
- s.client = testutil.NewTestClient()
- require.NotNil(s.T(), s.client)
-
- s.ctrl = gomock.NewController(s.T())
- s.storage = mock.NewMockStorage(s.ctrl)
- s.svc = NewMediaService(s.client, s.storage)
-
- // 清理数据库
- _, err := s.client.Media.Delete().Exec(s.ctx)
- require.NoError(s.T(), err)
-}
-
-func (s *MediaServiceTestSuite) TearDownTest() {
- s.ctrl.Finish()
- s.client.Close()
-}
-
-func TestMediaServiceSuite(t *testing.T) {
- suite.Run(t, new(MediaServiceTestSuite))
-}
-
-type mockFileHeader struct {
- filename string
- contentType string
- size int64
- content []byte
-}
-
-func (h *mockFileHeader) Open() (multipart.File, error) {
- return newMockMultipartFile(h.content), nil
-}
-
-func (h *mockFileHeader) Filename() string {
- return h.filename
-}
-
-func (h *mockFileHeader) Size() int64 {
- return h.size
-}
-
-func (h *mockFileHeader) Header() textproto.MIMEHeader {
- header := make(textproto.MIMEHeader)
- header.Set("Content-Type", h.contentType)
- return header
-}
-
-func (s *MediaServiceTestSuite) createTestFile(filename, contentType string, content []byte) *multipart.FileHeader {
- header := &multipart.FileHeader{
- Filename: filename,
- Header: make(textproto.MIMEHeader),
- Size: int64(len(content)),
- }
- header.Header.Set("Content-Type", contentType)
-
- monkey.PatchInstanceMethod(reflect.TypeOf(header), "Open", func(_ *multipart.FileHeader) (multipart.File, error) {
- return newMockMultipartFile(content), nil
- })
-
- return header
-}
-
-func (s *MediaServiceTestSuite) TestUpload() {
- testCases := []struct {
- name string
- filename string
- contentType string
- content []byte
- setupMock func()
- wantErr bool
- errMsg string
- }{
- {
- name: "Upload text file",
- filename: "test.txt",
- contentType: "text/plain",
- content: []byte("test content"),
- setupMock: func() {
- s.storage.EXPECT().
- Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()).
- DoAndReturn(func(ctx context.Context, name, contentType string, reader io.Reader) (*storage.FileInfo, error) {
- content, err := io.ReadAll(reader)
- s.Require().NoError(err)
- s.Equal([]byte("test content"), content)
- return &storage.FileInfo{
- ID: "test-id",
- Name: "test.txt",
- ContentType: "text/plain",
- Size: int64(len(content)),
- }, nil
- })
- },
- wantErr: false,
- },
- {
- name: "Invalid filename",
- filename: "../test.txt",
- contentType: "text/plain",
- content: []byte("test content"),
- setupMock: func() {},
- wantErr: true,
- errMsg: "invalid filename",
- },
- {
- name: "Storage error",
- filename: "test.txt",
- contentType: "text/plain",
- content: []byte("test content"),
- setupMock: func() {
- s.storage.EXPECT().
- Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()).
- Return(nil, fmt.Errorf("storage error"))
- },
- wantErr: true,
- errMsg: "storage error",
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- // Setup mock
- tc.setupMock()
-
- // Create test file
- fileHeader := s.createTestFile(tc.filename, tc.contentType, tc.content)
-
- // Add debug output
- s.T().Logf("Testing file: %s, content-type: %s, size: %d", fileHeader.Filename, fileHeader.Header.Get("Content-Type"), fileHeader.Size)
-
- // Test upload
- media, err := s.svc.Upload(s.ctx, fileHeader, 1)
-
- // Add debug output
- if err != nil {
- s.T().Logf("Upload error: %v", err)
- }
-
- if tc.wantErr {
- s.Require().Error(err)
- s.Contains(err.Error(), tc.errMsg)
- return
- }
-
- s.Require().NoError(err)
- s.NotNil(media)
- s.Equal(tc.filename, media.OriginalName)
- s.Equal(tc.contentType, media.MimeType)
- s.Equal(int64(len(tc.content)), media.Size)
- s.Equal("1", media.CreatedBy)
- })
- }
-}
-
-func (s *MediaServiceTestSuite) TestGet() {
- // Create test media
- media, err := s.client.Media.Create().
- SetStorageID("test-id").
- SetOriginalName("test.txt").
- SetMimeType("text/plain").
- SetSize(12).
- SetURL("/api/media/test-id").
- SetCreatedBy("1").
- Save(s.ctx)
- s.Require().NoError(err)
-
- // Test get existing media
- result, err := s.svc.Get(s.ctx, media.ID)
- s.Require().NoError(err)
- s.Equal(media.ID, result.ID)
- s.Equal(media.OriginalName, result.OriginalName)
-
- // Test get non-existing media
- _, err = s.svc.Get(s.ctx, -1)
- s.Require().Error(err)
- s.Contains(err.Error(), "media not found")
-}
-
-func (s *MediaServiceTestSuite) TestDelete() {
- // Create test media
- media, err := s.client.Media.Create().
- SetStorageID("test-id").
- SetOriginalName("test.txt").
- SetMimeType("text/plain").
- SetSize(12).
- SetURL("/api/media/test-id").
- SetCreatedBy("1").
- Save(s.ctx)
- s.Require().NoError(err)
-
- // Test delete by unauthorized user
- err = s.svc.Delete(s.ctx, media.ID, 2)
- s.Require().Error(err)
- s.Contains(err.Error(), "unauthorized")
-
- // Test delete by owner
- s.storage.EXPECT().
- Delete(gomock.Any(), "test-id").
- Return(nil)
- err = s.svc.Delete(s.ctx, media.ID, 1)
- s.Require().NoError(err)
-
- // Verify media is deleted
- _, err = s.svc.Get(s.ctx, media.ID)
- s.Require().Error(err)
- s.Contains(err.Error(), "not found")
-}
-
-func (s *MediaServiceTestSuite) TestList() {
- // Create test media
- for i := 0; i < 5; i++ {
- _, err := s.client.Media.Create().
- SetStorageID(fmt.Sprintf("test-id-%d", i)).
- SetOriginalName(fmt.Sprintf("test-%d.txt", i)).
- SetMimeType("text/plain").
- SetSize(12).
- SetURL(fmt.Sprintf("/api/media/test-id-%d", i)).
- SetCreatedBy("1").
- Save(s.ctx)
- s.Require().NoError(err)
- }
-
- // Test list with limit and offset
- media, err := s.svc.List(s.ctx, 3, 1)
- s.Require().NoError(err)
- s.Len(media, 3)
-}
-
-func (s *MediaServiceTestSuite) TestGetFile() {
- // Create test media
- media, err := s.client.Media.Create().
- SetStorageID("test-id").
- SetOriginalName("test.txt").
- SetMimeType("text/plain").
- SetSize(12).
- SetURL("/api/media/test-id").
- SetCreatedBy("1").
- Save(s.ctx)
- s.Require().NoError(err)
-
- // Mock storage.Get
- mockReader := io.NopCloser(bytes.NewReader([]byte("test content")))
- mockFileInfo := &storage.FileInfo{
- ID: "test-id",
- Name: "test.txt",
- ContentType: "text/plain",
- Size: 12,
- }
- s.storage.EXPECT().
- Get(gomock.Any(), "test-id").
- Return(mockReader, mockFileInfo, nil)
-
- // Test get file
- reader, info, err := s.svc.GetFile(s.ctx, media.ID)
- s.Require().NoError(err)
- s.NotNil(reader)
- s.Equal(mockFileInfo, info)
-
- // Test get non-existing file
- _, _, err = s.svc.GetFile(s.ctx, -1)
- s.Require().Error(err)
- s.Contains(err.Error(), "not found")
-}
-
-func (s *MediaServiceTestSuite) TestIsValidFilename() {
- testCases := []struct {
- name string
- filename string
- want bool
- }{
- {
- name: "Valid filename",
- filename: "test.txt",
- want: true,
- },
- {
- name: "Invalid filename with ../",
- filename: "../test.txt",
- want: false,
- },
- {
- name: "Invalid filename with ./",
- filename: "./test.txt",
- want: false,
- },
- {
- name: "Invalid filename with backslash",
- filename: "test\\file.txt",
- want: false,
- },
- }
-
- for _, tc := range testCases {
- s.Run(tc.name, func() {
- got := isValidFilename(tc.filename)
- s.Equal(tc.want, got)
- })
- }
-}
diff --git a/backend/internal/service/mock/mock.go b/backend/internal/service/mock/mock.go
deleted file mode 100644
index b9ee689..0000000
--- a/backend/internal/service/mock/mock.go
+++ /dev/null
@@ -1,3 +0,0 @@
-package mock
-
-//go:generate mockgen -source=../service.go -destination=mock_service.go -package=mock
diff --git a/backend/internal/service/service.go b/backend/internal/service/service.go
index 5b88d4f..ee6aedd 100644
--- a/backend/internal/service/service.go
+++ b/backend/internal/service/service.go
@@ -1,7 +1,5 @@
package service
-//go:generate mockgen -source=service.go -destination=mock/mock_service.go -package=mock
-
import (
"context"
"io"
@@ -34,16 +32,16 @@ type Service interface {
GetCategories(ctx context.Context, langCode string) ([]*ent.Category, error)
// Post operations
- CreatePost(ctx context.Context, status string) (*ent.Post, error)
+ CreatePost(ctx context.Context, status string, categoryIDs []int) (*ent.Post, error)
AddPostContent(ctx context.Context, postID int, langCode, title, content, summary string, metaKeywords, metaDescription string) (*ent.PostContent, error)
GetPostBySlug(ctx context.Context, langCode, slug string) (*ent.Post, error)
- ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error)
+ ListPosts(ctx context.Context, langCode string, categoryIDs []int, limit, offset int) ([]*ent.Post, error)
// Media operations
ListMedia(ctx context.Context, limit, offset int) ([]*ent.Media, error)
Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error)
GetMedia(ctx context.Context, id int) (*ent.Media, error)
- GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error)
+ GetFile(ctx context.Context, storageID string) (io.ReadCloser, *storage.FileInfo, error)
DeleteMedia(ctx context.Context, id int, userID int) error
// Contributor operations
diff --git a/backend/internal/storage/local.go b/backend/internal/storage/local.go
index 2b1d8bf..d8ccff7 100644
--- a/backend/internal/storage/local.go
+++ b/backend/internal/storage/local.go
@@ -11,6 +11,8 @@ import (
"path/filepath"
"strings"
"time"
+
+ "github.com/rs/zerolog/log"
)
type LocalStorage struct {
@@ -44,6 +46,24 @@ func (s *LocalStorage) generateID() (string, error) {
return hex.EncodeToString(bytes), nil
}
+func (s *LocalStorage) generateFilePath(id string, ext string, createTime time.Time) string {
+ // Create year/month directory structure
+ year := createTime.Format("2006")
+ month := createTime.Format("01")
+
+ // If id already has an extension, don't add ext
+ if filepath.Ext(id) != "" {
+ return filepath.Join(year, month, id)
+ }
+
+ // Otherwise, add the extension if provided
+ filename := id
+ if ext != "" {
+ filename = id + ext
+ }
+ return filepath.Join(year, month, filename)
+}
+
func (s *LocalStorage) saveMetadata(id string, info *FileInfo) error {
metaPath := filepath.Join(s.metaDir, id+".meta")
file, err := os.Create(metaPath)
@@ -89,11 +109,64 @@ func (s *LocalStorage) Save(ctx context.Context, name string, contentType string
return nil, fmt.Errorf("failed to generate file ID: %w", err)
}
- // Create the file path
- filePath := filepath.Join(s.rootDir, id)
+ // Get file extension from original name or content type
+ ext := filepath.Ext(name)
+ if ext == "" {
+ // If no extension in name, try to get it from content type
+ switch contentType {
+ case "image/jpeg":
+ ext = ".jpg"
+ case "image/png":
+ ext = ".png"
+ case "image/gif":
+ ext = ".gif"
+ case "image/webp":
+ ext = ".webp"
+ case "image/svg+xml":
+ ext = ".svg"
+ case "video/mp4":
+ ext = ".mp4"
+ case "video/webm":
+ ext = ".webm"
+ case "audio/mpeg":
+ ext = ".mp3"
+ case "audio/ogg":
+ ext = ".ogg"
+ case "audio/wav":
+ ext = ".wav"
+ case "application/pdf":
+ ext = ".pdf"
+ case "application/msword":
+ ext = ".doc"
+ case "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
+ ext = ".docx"
+ case "application/vnd.ms-excel":
+ ext = ".xls"
+ case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
+ ext = ".xlsx"
+ case "application/zip":
+ ext = ".zip"
+ case "application/x-rar-compressed":
+ ext = ".rar"
+ case "text/plain":
+ ext = ".txt"
+ case "text/csv":
+ ext = ".csv"
+ }
+ }
+
+ // Create the file path with year/month structure
+ now := time.Now()
+ relPath := s.generateFilePath(id, ext, now)
+ fullPath := filepath.Join(s.rootDir, relPath)
+
+ // Create directory if it doesn't exist
+ if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil {
+ return nil, fmt.Errorf("failed to create directory: %w", err)
+ }
// Create the file
- file, err := os.Create(filePath)
+ file, err := os.Create(fullPath)
if err != nil {
return nil, fmt.Errorf("failed to create file: %w", err)
}
@@ -102,52 +175,73 @@ func (s *LocalStorage) Save(ctx context.Context, name string, contentType string
// Copy the content
size, err := io.Copy(file, reader)
if err != nil {
- // Clean up the file if there's an error
- os.Remove(filePath)
+ os.Remove(fullPath) // Clean up on error
return nil, fmt.Errorf("failed to write file content: %w", err)
}
- now := time.Now()
+ // Save metadata
info := &FileInfo{
ID: id,
Name: name,
- Size: size,
ContentType: contentType,
+ Size: size,
CreatedAt: now,
- UpdatedAt: now,
- URL: fmt.Sprintf("/api/media/file/%s", id),
+ URL: fmt.Sprintf("/media/%s/%s/%s", now.Format("2006"), now.Format("01"), filepath.Base(relPath)),
}
-
- // Save metadata
if err := s.saveMetadata(id, info); err != nil {
- os.Remove(filePath)
- return nil, err
+ os.Remove(fullPath) // Clean up on error
+ return nil, fmt.Errorf("failed to save metadata: %w", err)
}
return info, nil
}
func (s *LocalStorage) Get(ctx context.Context, id string) (io.ReadCloser, *FileInfo, error) {
- filePath := filepath.Join(s.rootDir, id)
+ // 从 id 中提取文件扩展名和基础 ID
+ ext := filepath.Ext(id)
+ baseID := strings.TrimSuffix(id, ext)
- // Open the file
+ // 获取文件的创建时间(从元数据或当前时间)
+ metaPath := filepath.Join(s.metaDir, baseID+".meta")
+ var createTime time.Time
+ if stat, err := os.Stat(metaPath); err == nil {
+ createTime = stat.ModTime()
+ } else {
+ createTime = time.Now() // 如果找不到元数据,使用当前时间
+ }
+
+ // 生成完整的文件路径
+ year := createTime.Format("2006")
+ month := createTime.Format("01")
+ filePath := filepath.Join(s.rootDir, year, month, id) // 直接使用完整的 id(包含扩展名)
+
+ // 调试日志
+ log.Debug().
+ Str("id", id).
+ Str("baseID", baseID).
+ Str("ext", ext).
+ Str("filePath", filePath).
+ Time("createTime", createTime).
+ Msg("Attempting to get file")
+
+ // 打开文件
file, err := os.Open(filePath)
if err != nil {
if os.IsNotExist(err) {
- return nil, nil, fmt.Errorf("file not found: %s", id)
+ return nil, nil, fmt.Errorf("file not found: %s (path: %s)", id, filePath)
}
return nil, nil, fmt.Errorf("failed to open file: %w", err)
}
- // Get file info
+ // 获取文件信息
stat, err := file.Stat()
if err != nil {
file.Close()
return nil, nil, fmt.Errorf("failed to get file info: %w", err)
}
- // Load metadata
- name, contentType, err := s.loadMetadata(id)
+ // 加载元数据
+ name, contentType, err := s.loadMetadata(baseID)
if err != nil {
file.Close()
return nil, nil, err
@@ -158,27 +252,44 @@ func (s *LocalStorage) Get(ctx context.Context, id string) (io.ReadCloser, *File
Name: name,
Size: stat.Size(),
ContentType: contentType,
- CreatedAt: stat.ModTime(),
+ CreatedAt: createTime,
UpdatedAt: stat.ModTime(),
- URL: fmt.Sprintf("/api/media/file/%s", id),
+ URL: fmt.Sprintf("/media/%s/%s/%s", year, month, id),
}
return file, info, nil
}
func (s *LocalStorage) Delete(ctx context.Context, id string) error {
- filePath := filepath.Join(s.rootDir, id)
- if err := os.Remove(filePath); err != nil {
- if os.IsNotExist(err) {
- return fmt.Errorf("file not found: %s", id)
- }
- return fmt.Errorf("failed to delete file: %w", err)
+ // 从 id 中提取文件扩展名
+ ext := filepath.Ext(id)
+ baseID := strings.TrimSuffix(id, ext)
+
+ // 获取文件的创建时间(从元数据或当前时间)
+ metaPath := filepath.Join(s.metaDir, baseID+".meta")
+ var createTime time.Time
+ if stat, err := os.Stat(metaPath); err == nil {
+ createTime = stat.ModTime()
+ } else {
+ createTime = time.Now() // 如果找不到元数据,使用当前时间
}
- // Remove metadata
- metaPath := filepath.Join(s.metaDir, id+".meta")
- if err := os.Remove(metaPath); err != nil && !os.IsNotExist(err) {
- return fmt.Errorf("failed to remove metadata: %w", err)
+ // 生成完整的文件路径
+ relPath := s.generateFilePath(baseID, ext, createTime)
+ filePath := filepath.Join(s.rootDir, relPath)
+
+ // 删除文件
+ if err := os.Remove(filePath); err != nil {
+ if !os.IsNotExist(err) {
+ return fmt.Errorf("failed to delete file: %w", err)
+ }
+ }
+
+ // 删除元数据
+ if err := os.Remove(metaPath); err != nil {
+ if !os.IsNotExist(err) {
+ return fmt.Errorf("failed to delete metadata: %w", err)
+ }
}
return nil
diff --git a/backend/internal/storage/local_test.go b/backend/internal/storage/local_test.go
deleted file mode 100644
index f27a16f..0000000
--- a/backend/internal/storage/local_test.go
+++ /dev/null
@@ -1,154 +0,0 @@
-package storage
-
-import (
- "bytes"
- "context"
- "io"
- "os"
- "path/filepath"
- "strings"
- "testing"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func TestLocalStorage(t *testing.T) {
- // Create a temporary directory for testing
- tempDir, err := os.MkdirTemp("", "storage_test_*")
- require.NoError(t, err)
- defer os.RemoveAll(tempDir)
-
- // Create a new LocalStorage instance
- storage, err := NewLocalStorage(tempDir)
- require.NoError(t, err)
-
- ctx := context.Background()
-
- t.Run("Save and Get", func(t *testing.T) {
- content := []byte("test content")
- reader := bytes.NewReader(content)
-
- // Save the file
- fileInfo, err := storage.Save(ctx, "test.txt", "text/plain", reader)
- require.NoError(t, err)
- assert.NotEmpty(t, fileInfo.ID)
- assert.Equal(t, "test.txt", fileInfo.Name)
- assert.Equal(t, int64(len(content)), fileInfo.Size)
- assert.Equal(t, "text/plain", fileInfo.ContentType)
- assert.False(t, fileInfo.CreatedAt.IsZero())
-
- // Get the file
- readCloser, info, err := storage.Get(ctx, fileInfo.ID)
- require.NoError(t, err)
- defer readCloser.Close()
-
- data, err := io.ReadAll(readCloser)
- require.NoError(t, err)
- assert.Equal(t, content, data)
- assert.Equal(t, fileInfo.ID, info.ID)
- assert.Equal(t, fileInfo.Name, info.Name)
- assert.Equal(t, fileInfo.Size, info.Size)
- })
-
- t.Run("List", func(t *testing.T) {
- // Clear the directory first
- dirEntries, err := os.ReadDir(tempDir)
- require.NoError(t, err)
- for _, entry := range dirEntries {
- if entry.Name() != ".meta" {
- os.Remove(filepath.Join(tempDir, entry.Name()))
- }
- }
-
- // Save multiple files
- testFiles := []struct {
- name string
- content string
- }{
- {"test1.txt", "content1"},
- {"test2.txt", "content2"},
- {"other.txt", "content3"},
- }
-
- for _, f := range testFiles {
- reader := bytes.NewReader([]byte(f.content))
- _, err := storage.Save(ctx, f.name, "text/plain", reader)
- require.NoError(t, err)
- }
-
- // List all files
- allFiles, err := storage.List(ctx, "", 10, 0)
- require.NoError(t, err)
- assert.Len(t, allFiles, 3)
-
- // List files with prefix
- filesWithPrefix, err := storage.List(ctx, "test", 10, 0)
- require.NoError(t, err)
- assert.Len(t, filesWithPrefix, 2)
- for _, f := range filesWithPrefix {
- assert.True(t, strings.HasPrefix(f.Name, "test"))
- }
-
- // Test pagination
- pagedFiles, err := storage.List(ctx, "", 2, 1)
- require.NoError(t, err)
- assert.Len(t, pagedFiles, 2)
- })
-
- t.Run("Exists", func(t *testing.T) {
- // Save a file
- content := []byte("test content")
- reader := bytes.NewReader(content)
- fileInfo, err := storage.Save(ctx, "exists.txt", "text/plain", reader)
- require.NoError(t, err)
-
- // Check if file exists
- exists, err := storage.Exists(ctx, fileInfo.ID)
- require.NoError(t, err)
- assert.True(t, exists)
-
- // Check non-existent file
- exists, err = storage.Exists(ctx, "non-existent")
- require.NoError(t, err)
- assert.False(t, exists)
- })
-
- t.Run("Delete", func(t *testing.T) {
- // Save a file
- content := []byte("test content")
- reader := bytes.NewReader(content)
- fileInfo, err := storage.Save(ctx, "delete.txt", "text/plain", reader)
- require.NoError(t, err)
-
- // Delete the file
- err = storage.Delete(ctx, fileInfo.ID)
- require.NoError(t, err)
-
- // Verify file is deleted
- exists, err := storage.Exists(ctx, fileInfo.ID)
- require.NoError(t, err)
- assert.False(t, exists)
-
- // Try to delete non-existent file
- err = storage.Delete(ctx, "non-existent")
- assert.Error(t, err)
- })
-
- t.Run("Invalid operations", func(t *testing.T) {
- // Try to get non-existent file
- _, _, err := storage.Get(ctx, "non-existent")
- assert.Error(t, err)
- assert.Contains(t, err.Error(), "file not found")
-
- // Try to save file with nil reader
- _, err = storage.Save(ctx, "test.txt", "text/plain", nil)
- assert.Error(t, err)
- assert.Contains(t, err.Error(), "reader cannot be nil")
-
- // Try to delete non-existent file
- err = storage.Delete(ctx, "non-existent")
- assert.Error(t, err)
- assert.Contains(t, err.Error(), "file not found")
- })
-}
diff --git a/backend/internal/storage/s3.go b/backend/internal/storage/s3.go
index bea236d..88f8825 100644
--- a/backend/internal/storage/s3.go
+++ b/backend/internal/storage/s3.go
@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
+ "path/filepath"
"strings"
"time"
@@ -48,14 +49,25 @@ func (s *S3Storage) generateID() (string, error) {
return hex.EncodeToString(bytes), nil
}
-func (s *S3Storage) getObjectURL(id string) string {
+func (s *S3Storage) generateObjectKey(id string, ext string, createTime time.Time) string {
+ // Create year/month structure
+ year := createTime.Format("2006")
+ month := createTime.Format("01")
+ filename := id
+ if ext != "" {
+ filename = id + ext
+ }
+ return fmt.Sprintf("%s/%s/%s", year, month, filename)
+}
+
+func (s *S3Storage) getObjectURL(key string) string {
if s.customURL != "" {
- return fmt.Sprintf("%s/%s", strings.TrimRight(s.customURL, "/"), id)
+ return fmt.Sprintf("%s/%s", strings.TrimRight(s.customURL, "/"), key)
}
if s.proxyS3 {
- return fmt.Sprintf("/api/media/file/%s", id)
+ return fmt.Sprintf("/media/%s", key)
}
- return fmt.Sprintf("https://%s.s3.amazonaws.com/%s", s.bucket, id)
+ return fmt.Sprintf("https://%s.s3.amazonaws.com/%s", s.bucket, key)
}
func (s *S3Storage) Save(ctx context.Context, name string, contentType string, reader io.Reader) (*FileInfo, error) {
@@ -65,10 +77,60 @@ func (s *S3Storage) Save(ctx context.Context, name string, contentType string, r
return nil, fmt.Errorf("failed to generate file ID: %w", err)
}
+ // Get file extension from original name or content type
+ ext := filepath.Ext(name)
+ if ext == "" {
+ // If no extension in name, try to get it from content type
+ switch contentType {
+ case "image/jpeg":
+ ext = ".jpg"
+ case "image/png":
+ ext = ".png"
+ case "image/gif":
+ ext = ".gif"
+ case "image/webp":
+ ext = ".webp"
+ case "image/svg+xml":
+ ext = ".svg"
+ case "video/mp4":
+ ext = ".mp4"
+ case "video/webm":
+ ext = ".webm"
+ case "audio/mpeg":
+ ext = ".mp3"
+ case "audio/ogg":
+ ext = ".ogg"
+ case "audio/wav":
+ ext = ".wav"
+ case "application/pdf":
+ ext = ".pdf"
+ case "application/msword":
+ ext = ".doc"
+ case "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
+ ext = ".docx"
+ case "application/vnd.ms-excel":
+ ext = ".xls"
+ case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
+ ext = ".xlsx"
+ case "application/zip":
+ ext = ".zip"
+ case "application/x-rar-compressed":
+ ext = ".rar"
+ case "text/plain":
+ ext = ".txt"
+ case "text/csv":
+ ext = ".csv"
+ }
+ }
+
+ // Create the object key with year/month structure
+ now := time.Now()
+ key := s.generateObjectKey(id, ext, now)
+
// Check if the file exists
_, err = s.client.HeadObject(ctx, &s3.HeadObjectInput{
Bucket: aws.String(s.bucket),
- Key: aws.String(id),
+ Key: aws.String(key),
})
if err == nil {
return nil, fmt.Errorf("file already exists with ID: %s", id)
@@ -82,37 +144,50 @@ func (s *S3Storage) Save(ctx context.Context, name string, contentType string, r
// Upload the file
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
Bucket: aws.String(s.bucket),
- Key: aws.String(id),
+ Key: aws.String(key),
Body: reader,
ContentType: aws.String(contentType),
Metadata: map[string]string{
"x-amz-meta-original-name": name,
+ "x-amz-meta-created-at": now.Format(time.RFC3339),
},
})
if err != nil {
return nil, fmt.Errorf("failed to upload file: %w", err)
}
- now := time.Now()
info := &FileInfo{
ID: id,
Name: name,
Size: 0, // Size is not available until after upload
ContentType: contentType,
CreatedAt: now,
- UpdatedAt: now,
- URL: s.getObjectURL(id),
+ URL: s.getObjectURL(key),
}
return info, nil
}
func (s *S3Storage) Get(ctx context.Context, id string) (io.ReadCloser, *FileInfo, error) {
- // Get the object from S3
- result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
- Bucket: aws.String(s.bucket),
- Key: aws.String(id),
- })
+ // Try to find the file with different extensions
+ exts := []string{"", ".jpg", ".png", ".gif", ".webp", ".svg", ".mp4", ".webm", ".mp3", ".ogg", ".wav",
+ ".pdf", ".doc", ".docx", ".xls", ".xlsx", ".zip", ".rar", ".txt", ".csv"}
+
+ var result *s3.GetObjectOutput
+ var err error
+ var key string
+
+ for _, ext := range exts {
+ key = s.generateObjectKey(id, ext, time.Now())
+ result, err = s.client.GetObject(ctx, &s3.GetObjectInput{
+ Bucket: aws.String(s.bucket),
+ Key: aws.String(key),
+ })
+ if err == nil {
+ break
+ }
+ }
+
if err != nil {
return nil, nil, fmt.Errorf("failed to get file from S3: %w", err)
}
@@ -122,111 +197,117 @@ func (s *S3Storage) Get(ctx context.Context, id string) (io.ReadCloser, *FileInf
Name: result.Metadata["x-amz-meta-original-name"],
Size: aws.ToInt64(result.ContentLength),
ContentType: aws.ToString(result.ContentType),
- CreatedAt: aws.ToTime(result.LastModified),
- UpdatedAt: aws.ToTime(result.LastModified),
- URL: s.getObjectURL(id),
+ CreatedAt: time.Now(),
+ UpdatedAt: time.Now(),
+ URL: s.getObjectURL(key),
}
return result.Body, info, nil
}
func (s *S3Storage) Delete(ctx context.Context, id string) error {
- _, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
- Bucket: aws.String(s.bucket),
- Key: aws.String(id),
- })
- if err != nil {
- return fmt.Errorf("failed to delete file from S3: %w", err)
+ // Try to find and delete the file with different extensions
+ exts := []string{"", ".jpg", ".png", ".gif", ".webp", ".svg", ".mp4", ".webm", ".mp3", ".ogg", ".wav",
+ ".pdf", ".doc", ".docx", ".xls", ".xlsx", ".zip", ".rar", ".txt", ".csv"}
+
+ var lastErr error
+ for _, ext := range exts {
+ key := s.generateObjectKey(id, ext, time.Now())
+ _, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
+ Bucket: aws.String(s.bucket),
+ Key: aws.String(key),
+ })
+ if err == nil {
+ return nil
+ }
+ lastErr = err
}
- return nil
+ return fmt.Errorf("failed to delete file from S3: %w", lastErr)
+}
+
+func (s *S3Storage) Exists(ctx context.Context, id string) (bool, error) {
+ // Try to find the file with different extensions
+ exts := []string{"", ".jpg", ".png", ".gif", ".webp", ".svg", ".mp4", ".webm", ".mp3", ".ogg", ".wav",
+ ".pdf", ".doc", ".docx", ".xls", ".xlsx", ".zip", ".rar", ".txt", ".csv"}
+
+ for _, ext := range exts {
+ key := s.generateObjectKey(id, ext, time.Now())
+ _, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
+ Bucket: aws.String(s.bucket),
+ Key: aws.String(key),
+ })
+ if err == nil {
+ return true, nil
+ }
+ }
+
+ return false, nil
}
func (s *S3Storage) List(ctx context.Context, prefix string, limit int, offset int) ([]*FileInfo, error) {
var files []*FileInfo
var continuationToken *string
+ count := 0
+ skip := offset
- // Skip objects for offset
- for i := 0; i < offset/1000; i++ {
- output, err := s.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
+ for {
+ input := &s3.ListObjectsV2Input{
Bucket: aws.String(s.bucket),
Prefix: aws.String(prefix),
ContinuationToken: continuationToken,
- MaxKeys: aws.Int32(1000),
- })
+ MaxKeys: aws.Int32(100), // Fetch in batches of 100
+ }
+
+ result, err := s.client.ListObjectsV2(ctx, input)
if err != nil {
return nil, fmt.Errorf("failed to list files from S3: %w", err)
}
- if !aws.ToBool(output.IsTruncated) {
- return files, nil
- }
- continuationToken = output.NextContinuationToken
- }
- // Get the actual objects
- output, err := s.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
- Bucket: aws.String(s.bucket),
- Prefix: aws.String(prefix),
- ContinuationToken: continuationToken,
- MaxKeys: aws.Int32(int32(limit)),
- })
- if err != nil {
- return nil, fmt.Errorf("failed to list files from S3: %w", err)
- }
-
- for _, obj := range output.Contents {
- // Get the object metadata
- head, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
- Bucket: aws.String(s.bucket),
- Key: obj.Key,
- })
-
- var contentType string
- var originalName string
-
- if err != nil {
- var noSuchKey *types.NoSuchKey
- if errors.As(err, &noSuchKey) {
- // If the object doesn't exist (which shouldn't happen normally),
- // we'll still include it in the list but with empty metadata
- contentType = ""
- originalName = aws.ToString(obj.Key)
- } else {
+ // Process each object
+ for _, obj := range result.Contents {
+ if skip > 0 {
+ skip--
continue
}
- } else {
- contentType = aws.ToString(head.ContentType)
- originalName = head.Metadata["x-amz-meta-original-name"]
- if originalName == "" {
- originalName = aws.ToString(obj.Key)
+
+ // Get object metadata
+ head, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
+ Bucket: aws.String(s.bucket),
+ Key: obj.Key,
+ })
+ if err != nil {
+ continue // Skip files we can't get metadata for
+ }
+
+ // Extract the ID from the key (remove extension if present)
+ id := aws.ToString(obj.Key)
+ if ext := filepath.Ext(id); ext != "" {
+ id = id[:len(id)-len(ext)]
+ }
+
+ info := &FileInfo{
+ ID: id,
+ Name: head.Metadata["x-amz-meta-original-name"],
+ Size: aws.ToInt64(obj.Size),
+ ContentType: aws.ToString(head.ContentType),
+ CreatedAt: aws.ToTime(obj.LastModified),
+ UpdatedAt: aws.ToTime(obj.LastModified),
+ URL: s.getObjectURL(aws.ToString(obj.Key)),
+ }
+ files = append(files, info)
+ count++
+
+ if count >= limit {
+ return files, nil
}
}
- files = append(files, &FileInfo{
- ID: aws.ToString(obj.Key),
- Name: originalName,
- Size: aws.ToInt64(obj.Size),
- ContentType: contentType,
- CreatedAt: aws.ToTime(obj.LastModified),
- UpdatedAt: aws.ToTime(obj.LastModified),
- URL: s.getObjectURL(aws.ToString(obj.Key)),
- })
+ if !aws.ToBool(result.IsTruncated) {
+ break
+ }
+ continuationToken = result.NextContinuationToken
}
return files, nil
}
-
-func (s *S3Storage) Exists(ctx context.Context, id string) (bool, error) {
- _, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
- Bucket: aws.String(s.bucket),
- Key: aws.String(id),
- })
- if err != nil {
- var nsk *types.NoSuchKey
- if ok := errors.As(err, &nsk); ok {
- return false, nil
- }
- return false, fmt.Errorf("failed to check file existence in S3: %w", err)
- }
- return true, nil
-}
diff --git a/backend/internal/storage/s3_test.go b/backend/internal/storage/s3_test.go
deleted file mode 100644
index 215f04e..0000000
--- a/backend/internal/storage/s3_test.go
+++ /dev/null
@@ -1,211 +0,0 @@
-package storage
-
-import (
- "bytes"
- "context"
- "io"
- "testing"
- "time"
-
- "github.com/aws/aws-sdk-go-v2/aws"
- "github.com/aws/aws-sdk-go-v2/service/s3"
- "github.com/aws/aws-sdk-go-v2/service/s3/types"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/mock"
- "github.com/stretchr/testify/require"
-)
-
-// MockS3Client is a mock implementation of the S3 client interface
-type MockS3Client struct {
- mock.Mock
-}
-
-func (m *MockS3Client) PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) {
- args := m.Called(ctx, params)
- return args.Get(0).(*s3.PutObjectOutput), args.Error(1)
-}
-
-func (m *MockS3Client) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
- args := m.Called(ctx, params)
- return args.Get(0).(*s3.GetObjectOutput), args.Error(1)
-}
-
-func (m *MockS3Client) DeleteObject(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error) {
- args := m.Called(ctx, params)
- return args.Get(0).(*s3.DeleteObjectOutput), args.Error(1)
-}
-
-func (m *MockS3Client) ListObjectsV2(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) {
- args := m.Called(ctx, params)
- return args.Get(0).(*s3.ListObjectsV2Output), args.Error(1)
-}
-
-func (m *MockS3Client) HeadObject(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) {
- args := m.Called(ctx, params)
- return args.Get(0).(*s3.HeadObjectOutput), args.Error(1)
-}
-
-func TestS3Storage(t *testing.T) {
- ctx := context.Background()
- mockClient := new(MockS3Client)
- storage := NewS3Storage(mockClient, "test-bucket", "", false)
-
- t.Run("Save", func(t *testing.T) {
- mockClient.ExpectedCalls = nil
- mockClient.Calls = nil
-
- content := []byte("test content")
- reader := bytes.NewReader(content)
-
- // Mock HeadObject to return NotFound error
- mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
- return aws.ToString(input.Bucket) == "test-bucket"
- })).Return(&s3.HeadObjectOutput{}, &types.NoSuchKey{
- Message: aws.String("The specified key does not exist."),
- })
-
- mockClient.On("PutObject", ctx, mock.MatchedBy(func(input *s3.PutObjectInput) bool {
- return aws.ToString(input.Bucket) == "test-bucket" &&
- aws.ToString(input.ContentType) == "text/plain"
- })).Return(&s3.PutObjectOutput{}, nil)
-
- fileInfo, err := storage.Save(ctx, "test.txt", "text/plain", reader)
- require.NoError(t, err)
- assert.NotEmpty(t, fileInfo.ID)
- assert.Equal(t, "test.txt", fileInfo.Name)
- assert.Equal(t, "text/plain", fileInfo.ContentType)
-
- mockClient.AssertExpectations(t)
- })
-
- t.Run("Get", func(t *testing.T) {
- content := []byte("test content")
- mockClient.On("GetObject", ctx, mock.MatchedBy(func(input *s3.GetObjectInput) bool {
- return aws.ToString(input.Bucket) == "test-bucket" &&
- aws.ToString(input.Key) == "test-id"
- })).Return(&s3.GetObjectOutput{
- Body: io.NopCloser(bytes.NewReader(content)),
- ContentType: aws.String("text/plain"),
- ContentLength: aws.Int64(int64(len(content))),
- LastModified: aws.Time(time.Now()),
- }, nil)
-
- readCloser, info, err := storage.Get(ctx, "test-id")
- require.NoError(t, err)
- defer readCloser.Close()
-
- data, err := io.ReadAll(readCloser)
- require.NoError(t, err)
- assert.Equal(t, content, data)
- assert.Equal(t, "test-id", info.ID)
- assert.Equal(t, int64(len(content)), info.Size)
-
- mockClient.AssertExpectations(t)
- })
-
- t.Run("List", func(t *testing.T) {
- mockClient.ExpectedCalls = nil
- mockClient.Calls = nil
-
- mockClient.On("ListObjectsV2", ctx, mock.MatchedBy(func(input *s3.ListObjectsV2Input) bool {
- return aws.ToString(input.Bucket) == "test-bucket" &&
- aws.ToString(input.Prefix) == "test" &&
- aws.ToInt32(input.MaxKeys) == 10
- })).Return(&s3.ListObjectsV2Output{
- Contents: []types.Object{
- {
- Key: aws.String("test1"),
- Size: aws.Int64(100),
- LastModified: aws.Time(time.Now()),
- },
- {
- Key: aws.String("test2"),
- Size: aws.Int64(200),
- LastModified: aws.Time(time.Now()),
- },
- },
- }, nil)
-
- // Mock HeadObject for both files
- mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
- return aws.ToString(input.Bucket) == "test-bucket" &&
- aws.ToString(input.Key) == "test1"
- })).Return(&s3.HeadObjectOutput{
- ContentType: aws.String("text/plain"),
- Metadata: map[string]string{
- "x-amz-meta-original-name": "test1.txt",
- },
- }, nil).Once()
-
- mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
- return aws.ToString(input.Bucket) == "test-bucket" &&
- aws.ToString(input.Key) == "test2"
- })).Return(&s3.HeadObjectOutput{
- ContentType: aws.String("text/plain"),
- Metadata: map[string]string{
- "x-amz-meta-original-name": "test2.txt",
- },
- }, nil).Once()
-
- files, err := storage.List(ctx, "test", 10, 0)
- require.NoError(t, err)
- assert.Len(t, files, 2)
- assert.Equal(t, "test1", files[0].ID)
- assert.Equal(t, int64(100), files[0].Size)
- assert.Equal(t, "test1.txt", files[0].Name)
- assert.Equal(t, "text/plain", files[0].ContentType)
-
- mockClient.AssertExpectations(t)
- })
-
- t.Run("Delete", func(t *testing.T) {
- mockClient.On("DeleteObject", ctx, mock.MatchedBy(func(input *s3.DeleteObjectInput) bool {
- return aws.ToString(input.Bucket) == "test-bucket" &&
- aws.ToString(input.Key) == "test-id"
- })).Return(&s3.DeleteObjectOutput{}, nil)
-
- err := storage.Delete(ctx, "test-id")
- require.NoError(t, err)
-
- mockClient.AssertExpectations(t)
- })
-
- t.Run("Exists", func(t *testing.T) {
- mockClient.ExpectedCalls = nil
- mockClient.Calls = nil
-
- // Mock HeadObject for existing file
- mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
- return aws.ToString(input.Bucket) == "test-bucket" &&
- aws.ToString(input.Key) == "test-id"
- })).Return(&s3.HeadObjectOutput{}, nil).Once()
-
- exists, err := storage.Exists(ctx, "test-id")
- require.NoError(t, err)
- assert.True(t, exists)
-
- // Mock HeadObject for non-existing file
- mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
- return aws.ToString(input.Bucket) == "test-bucket" &&
- aws.ToString(input.Key) == "non-existent"
- })).Return(&s3.HeadObjectOutput{}, &types.NoSuchKey{
- Message: aws.String("The specified key does not exist."),
- }).Once()
-
- exists, err = storage.Exists(ctx, "non-existent")
- require.NoError(t, err)
- assert.False(t, exists)
-
- mockClient.AssertExpectations(t)
- })
-
- t.Run("Custom URL", func(t *testing.T) {
- customStorage := &S3Storage{
- client: mockClient,
- bucket: "test-bucket",
- customURL: "https://custom.domain",
- proxyS3: true,
- }
- assert.Contains(t, customStorage.getObjectURL("test-id"), "https://custom.domain")
- })
-}
diff --git a/backend/internal/storage/storage.go b/backend/internal/storage/storage.go
index ccd00d2..81170ad 100644
--- a/backend/internal/storage/storage.go
+++ b/backend/internal/storage/storage.go
@@ -1,7 +1,5 @@
package storage
-//go:generate mockgen -source=storage.go -destination=mock/mock_storage.go -package=mock
-
import (
"context"
"io"
diff --git a/backend/internal/types/config_test.go b/backend/internal/types/config_test.go
deleted file mode 100644
index 298e318..0000000
--- a/backend/internal/types/config_test.go
+++ /dev/null
@@ -1,116 +0,0 @@
-package types
-
-import (
- "testing"
-)
-
-func TestRateLimitConfig(t *testing.T) {
- config := RateLimitConfig{
- IPRate: 100,
- IPBurst: 200,
- RouteRates: map[string]struct {
- Rate int `yaml:"rate"`
- Burst int `yaml:"burst"`
- }{
- "/api/test": {
- Rate: 50,
- Burst: 100,
- },
- },
- }
-
- if config.IPRate != 100 {
- t.Errorf("Expected IPRate 100, got %d", config.IPRate)
- }
- if config.IPBurst != 200 {
- t.Errorf("Expected IPBurst 200, got %d", config.IPBurst)
- }
-
- route := config.RouteRates["/api/test"]
- if route.Rate != 50 {
- t.Errorf("Expected route rate 50, got %d", route.Rate)
- }
- if route.Burst != 100 {
- t.Errorf("Expected route burst 100, got %d", route.Burst)
- }
-}
-
-func TestAccessLogConfig(t *testing.T) {
- config := AccessLogConfig{
- EnableConsole: true,
- EnableFile: true,
- FilePath: "/var/log/app.log",
- Format: "json",
- Level: "info",
- Rotation: struct {
- MaxSize int `yaml:"max_size"`
- MaxAge int `yaml:"max_age"`
- MaxBackups int `yaml:"max_backups"`
- Compress bool `yaml:"compress"`
- LocalTime bool `yaml:"local_time"`
- }{
- MaxSize: 100,
- MaxAge: 7,
- MaxBackups: 5,
- Compress: true,
- LocalTime: true,
- },
- }
-
- if !config.EnableConsole {
- t.Error("Expected EnableConsole to be true")
- }
- if !config.EnableFile {
- t.Error("Expected EnableFile to be true")
- }
- if config.FilePath != "/var/log/app.log" {
- t.Errorf("Expected FilePath '/var/log/app.log', got '%s'", config.FilePath)
- }
- if config.Format != "json" {
- t.Errorf("Expected Format 'json', got '%s'", config.Format)
- }
- if config.Level != "info" {
- t.Errorf("Expected Level 'info', got '%s'", config.Level)
- }
-
- rotation := config.Rotation
- if rotation.MaxSize != 100 {
- t.Errorf("Expected MaxSize 100, got %d", rotation.MaxSize)
- }
- if rotation.MaxAge != 7 {
- t.Errorf("Expected MaxAge 7, got %d", rotation.MaxAge)
- }
- if rotation.MaxBackups != 5 {
- t.Errorf("Expected MaxBackups 5, got %d", rotation.MaxBackups)
- }
- if !rotation.Compress {
- t.Error("Expected Compress to be true")
- }
- if !rotation.LocalTime {
- t.Error("Expected LocalTime to be true")
- }
-}
-
-func TestUploadConfig(t *testing.T) {
- config := UploadConfig{
- MaxSize: 10,
- AllowedTypes: []string{"image/jpeg", "image/png"},
- AllowedExtensions: []string{".jpg", ".png"},
- }
-
- if config.MaxSize != 10 {
- t.Errorf("Expected MaxSize 10, got %d", config.MaxSize)
- }
- if len(config.AllowedTypes) != 2 {
- t.Errorf("Expected 2 AllowedTypes, got %d", len(config.AllowedTypes))
- }
- if config.AllowedTypes[0] != "image/jpeg" {
- t.Errorf("Expected AllowedTypes[0] 'image/jpeg', got '%s'", config.AllowedTypes[0])
- }
- if len(config.AllowedExtensions) != 2 {
- t.Errorf("Expected 2 AllowedExtensions, got %d", len(config.AllowedExtensions))
- }
- if config.AllowedExtensions[0] != ".jpg" {
- t.Errorf("Expected AllowedExtensions[0] '.jpg', got '%s'", config.AllowedExtensions[0])
- }
-}
diff --git a/backend/internal/types/file_test.go b/backend/internal/types/file_test.go
deleted file mode 100644
index 1335a21..0000000
--- a/backend/internal/types/file_test.go
+++ /dev/null
@@ -1,21 +0,0 @@
-package types
-
-import "testing"
-
-func TestFileInfo(t *testing.T) {
- fileInfo := FileInfo{
- Size: 1024,
- Name: "test.jpg",
- ContentType: "image/jpeg",
- }
-
- if fileInfo.Size != 1024 {
- t.Errorf("Expected Size 1024, got %d", fileInfo.Size)
- }
- if fileInfo.Name != "test.jpg" {
- t.Errorf("Expected Name 'test.jpg', got '%s'", fileInfo.Name)
- }
- if fileInfo.ContentType != "image/jpeg" {
- t.Errorf("Expected ContentType 'image/jpeg', got '%s'", fileInfo.ContentType)
- }
-}
diff --git a/backend/internal/types/types_test.go b/backend/internal/types/types_test.go
deleted file mode 100644
index ff3461d..0000000
--- a/backend/internal/types/types_test.go
+++ /dev/null
@@ -1,77 +0,0 @@
-package types
-
-import (
- "testing"
-)
-
-func TestCategory(t *testing.T) {
- description := "Test Description"
- category := Category{
- ID: 1,
- Name: "Test Category",
- Slug: "test-category",
- Description: &description,
- }
-
- if category.ID != 1 {
- t.Errorf("Expected ID 1, got %d", category.ID)
- }
- if category.Name != "Test Category" {
- t.Errorf("Expected name 'Test Category', got '%s'", category.Name)
- }
- if category.Slug != "test-category" {
- t.Errorf("Expected slug 'test-category', got '%s'", category.Slug)
- }
- if *category.Description != description {
- t.Errorf("Expected description '%s', got '%s'", description, *category.Description)
- }
-}
-
-func TestPost(t *testing.T) {
- metaKeywords := "test,blog"
- metaDesc := "Test Description"
- post := Post{
- ID: 1,
- Title: "Test Post",
- Slug: "test-post",
- ContentMarkdown: "# Test Content",
- Summary: "Test Summary",
- MetaKeywords: &metaKeywords,
- MetaDescription: &metaDesc,
- }
-
- if post.ID != 1 {
- t.Errorf("Expected ID 1, got %d", post.ID)
- }
- if post.Title != "Test Post" {
- t.Errorf("Expected title 'Test Post', got '%s'", post.Title)
- }
- if post.Slug != "test-post" {
- t.Errorf("Expected slug 'test-post', got '%s'", post.Slug)
- }
- if *post.MetaKeywords != metaKeywords {
- t.Errorf("Expected meta keywords '%s', got '%s'", metaKeywords, *post.MetaKeywords)
- }
-}
-
-func TestDaily(t *testing.T) {
- daily := Daily{
- ID: "2025-02-12",
- CategoryID: 1,
- ImageURL: "https://example.com/image.jpg",
- Quote: "Test Quote",
- }
-
- if daily.ID != "2025-02-12" {
- t.Errorf("Expected ID '2025-02-12', got '%s'", daily.ID)
- }
- if daily.CategoryID != 1 {
- t.Errorf("Expected CategoryID 1, got %d", daily.CategoryID)
- }
- if daily.ImageURL != "https://example.com/image.jpg" {
- t.Errorf("Expected ImageURL 'https://example.com/image.jpg', got '%s'", daily.ImageURL)
- }
- if daily.Quote != "Test Quote" {
- t.Errorf("Expected Quote 'Test Quote', got '%s'", daily.Quote)
- }
-}
diff --git a/backend/pkg/config/config_test.go b/backend/pkg/config/config_test.go
deleted file mode 100644
index a785dff..0000000
--- a/backend/pkg/config/config_test.go
+++ /dev/null
@@ -1,77 +0,0 @@
-package config
-
-import (
- "os"
- "path/filepath"
- "testing"
-)
-
-func TestLoad(t *testing.T) {
- // Create a temporary test config file
- testConfig := `
-database:
- driver: postgres
- dsn: postgres://user:pass@localhost:5432/db
-server:
- port: 8080
- host: localhost
-jwt:
- secret: test-secret
- expiration: 24h
-logging:
- level: debug
- format: console
-`
- tmpDir := t.TempDir()
- configPath := filepath.Join(tmpDir, "config.yaml")
- if err := os.WriteFile(configPath, []byte(testConfig), 0644); err != nil {
- t.Fatalf("Failed to create test config file: %v", err)
- }
-
- // Test successful config loading
- cfg, err := Load(configPath)
- if err != nil {
- t.Fatalf("Failed to load config: %v", err)
- }
-
- // Verify loaded values
- tests := []struct {
- name string
- got interface{}
- expected interface{}
- }{
- {"database.driver", cfg.Database.Driver, "postgres"},
- {"database.dsn", cfg.Database.DSN, "postgres://user:pass@localhost:5432/db"},
- {"server.port", cfg.Server.Port, 8080},
- {"server.host", cfg.Server.Host, "localhost"},
- {"jwt.secret", cfg.JWT.Secret, "test-secret"},
- {"jwt.expiration", cfg.JWT.Expiration, "24h"},
- {"logging.level", cfg.Logging.Level, "debug"},
- {"logging.format", cfg.Logging.Format, "console"},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if tt.got != tt.expected {
- t.Errorf("Config %s = %v, want %v", tt.name, tt.got, tt.expected)
- }
- })
- }
-
- // Test loading non-existent file
- _, err = Load("non-existent.yaml")
- if err == nil {
- t.Error("Expected error when loading non-existent file, got nil")
- }
-
- // Test loading invalid YAML
- invalidPath := filepath.Join(tmpDir, "invalid.yaml")
- if err := os.WriteFile(invalidPath, []byte("invalid: yaml: content"), 0644); err != nil {
- t.Fatalf("Failed to create invalid config file: %v", err)
- }
-
- _, err = Load(invalidPath)
- if err == nil {
- t.Error("Expected error when loading invalid YAML, got nil")
- }
-}
diff --git a/backend/pkg/imageutil/processor_test.go b/backend/pkg/imageutil/processor_test.go
deleted file mode 100644
index c5cb6d5..0000000
--- a/backend/pkg/imageutil/processor_test.go
+++ /dev/null
@@ -1,100 +0,0 @@
-package imageutil
-
-import (
- "bytes"
- "image"
- "image/color"
- "image/png"
- "testing"
-)
-
-func TestIsImageFormat(t *testing.T) {
- tests := []struct {
- name string
- contentType string
- want bool
- }{
- {"JPEG", "image/jpeg", true},
- {"PNG", "image/png", true},
- {"GIF", "image/gif", true},
- {"WebP", "image/webp", true},
- {"Invalid", "image/invalid", false},
- {"Empty", "", false},
- {"Text", "text/plain", false},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := IsImageFormat(tt.contentType); got != tt.want {
- t.Errorf("IsImageFormat(%q) = %v, want %v", tt.contentType, got, tt.want)
- }
- })
- }
-}
-
-func TestDefaultOptions(t *testing.T) {
- opts := DefaultOptions()
-
- if !opts.Lossless {
- t.Error("DefaultOptions().Lossless = false, want true")
- }
- if opts.Quality != 90 {
- t.Errorf("DefaultOptions().Quality = %v, want 90", opts.Quality)
- }
- if opts.Compression != 4 {
- t.Errorf("DefaultOptions().Compression = %v, want 4", opts.Compression)
- }
-}
-
-func TestProcessImage(t *testing.T) {
- // Create a test image
- img := image.NewRGBA(image.Rect(0, 0, 100, 100))
- for y := 0; y < 100; y++ {
- for x := 0; x < 100; x++ {
- img.Set(x, y, color.RGBA{R: 255, G: 0, B: 0, A: 255})
- }
- }
-
- var buf bytes.Buffer
- if err := png.Encode(&buf, img); err != nil {
- t.Fatalf("Failed to create test PNG: %v", err)
- }
-
- tests := []struct {
- name string
- opts ProcessOptions
- wantErr bool
- }{
- {
- name: "Default options",
- opts: DefaultOptions(),
- },
- {
- name: "Custom quality",
- opts: ProcessOptions{
- Lossless: false,
- Quality: 75,
- },
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- reader := bytes.NewReader(buf.Bytes())
- result, err := ProcessImage(reader, tt.opts)
- if (err != nil) != tt.wantErr {
- t.Errorf("ProcessImage() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !tt.wantErr && len(result) == 0 {
- t.Error("ProcessImage() returned empty result")
- }
- })
- }
-
- // Test with invalid input
- _, err := ProcessImage(bytes.NewReader([]byte("invalid image data")), DefaultOptions())
- if err == nil {
- t.Error("ProcessImage() with invalid input should return error")
- }
-}
diff --git a/backend/pkg/logger/logger_test.go b/backend/pkg/logger/logger_test.go
deleted file mode 100644
index 319d851..0000000
--- a/backend/pkg/logger/logger_test.go
+++ /dev/null
@@ -1,85 +0,0 @@
-package logger
-
-import (
- "testing"
- "tss-rocks-be/internal/config"
-
- "github.com/rs/zerolog"
-)
-
-func TestSetup(t *testing.T) {
- tests := []struct {
- name string
- config *config.Config
- expectedLevel zerolog.Level
- }{
- {
- name: "Debug level",
- config: &config.Config{
- Logging: struct {
- Level string `yaml:"level"`
- Format string `yaml:"format"`
- }{
- Level: "debug",
- Format: "json",
- },
- },
- expectedLevel: zerolog.DebugLevel,
- },
- {
- name: "Info level",
- config: &config.Config{
- Logging: struct {
- Level string `yaml:"level"`
- Format string `yaml:"format"`
- }{
- Level: "info",
- Format: "json",
- },
- },
- expectedLevel: zerolog.InfoLevel,
- },
- {
- name: "Error level",
- config: &config.Config{
- Logging: struct {
- Level string `yaml:"level"`
- Format string `yaml:"format"`
- }{
- Level: "error",
- Format: "json",
- },
- },
- expectedLevel: zerolog.ErrorLevel,
- },
- {
- name: "Invalid level defaults to Info",
- config: &config.Config{
- Logging: struct {
- Level string `yaml:"level"`
- Format string `yaml:"format"`
- }{
- Level: "invalid",
- Format: "json",
- },
- },
- expectedLevel: zerolog.InfoLevel,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- Setup(tt.config)
- if zerolog.GlobalLevel() != tt.expectedLevel {
- t.Errorf("Setup() set level to %v, want %v", zerolog.GlobalLevel(), tt.expectedLevel)
- }
- })
- }
-}
-
-func TestGetLogger(t *testing.T) {
- logger := GetLogger()
- if logger == nil {
- t.Error("GetLogger() returned nil")
- }
-}
diff --git a/frontend/data/i18n/en.json b/frontend/data/i18n/en.json
index 06015e4..46bf27f 100644
--- a/frontend/data/i18n/en.json
+++ b/frontend/data/i18n/en.json
@@ -1,6 +1,6 @@
{
"categories": {
- "man": "Man",
+ "man": "Human",
"machine": "Machine",
"earth": "Earth",
"space": "Space",
@@ -29,8 +29,6 @@
"edit": "Edit",
"delete": "Delete",
"upload": "Upload",
- "save": "Save",
- "saving": "Saving...",
"status": "Status",
"actions": "Actions",
"published": "Published",
@@ -48,12 +46,10 @@
"joinDate": "Join Date",
"username": "Username",
"logout": "Logout",
- "language": "Language",
- "theme": {
- "light": "Light Mode",
- "dark": "Dark Mode",
- "system": "System"
- }
+ "save": "Save",
+ "saving": "Saving...",
+ "noData": "No data available",
+ "unsavedChanges": "You have unsaved changes. Are you sure you want to leave?"
},
"dashboard": {
"totalPosts": "Total Posts",
@@ -61,13 +57,33 @@
"totalUsers": "Total Users",
"totalContributors": "Total Contributors"
},
+ "posts": {
+ "title": "Title",
+ "categories": "Categories",
+ "createdAt": "Created At",
+ "status": "Status",
+ "noTitle": "No Title",
+ "deleteConfirm": "Are you sure you want to delete this post?",
+ "create": "Create Post",
+ "edit": "Edit Post",
+ "slug": "Slug",
+ "content": "Content",
+ "summary": "Summary",
+ "metaKeywords": "Meta Keywords",
+ "metaDescription": "Meta Description",
+ "selectCategories": "Select Categories",
+ "saving": "Saving...",
+ "publishing": "Publishing...",
+ "saveDraft": "Save Draft",
+ "publish": "Publish"
+ },
"login": {
"title": "Admin Login",
"username": "Username",
"password": "Password",
"remember": "Remember me",
- "submit": "Sign in",
- "loading": "Signing in...",
+ "submit": "Login",
+ "loading": "Logging in...",
"error": {
"failed": "Login failed",
"retry": "Login failed, please try again later"
@@ -76,7 +92,7 @@
"nav": {
"dashboard": "Dashboard",
"posts": "Posts",
- "daily": "Daily Quotes",
+ "daily": "Daily",
"medias": "Media",
"categories": "Categories",
"users": "Users",
@@ -110,5 +126,45 @@
"passwordMismatch": "Passwords do not match",
"passwordTooShort": "Password must be at least 8 characters long"
}
+ },
+ "editor": {
+ "heading1": "Heading 1",
+ "heading2": "Heading 2",
+ "heading3": "Heading 3",
+ "bold": "Bold",
+ "italic": "Italic",
+ "orderedList": "Ordered List",
+ "unorderedList": "Unordered List",
+ "quote": "Quote",
+ "link": "Link",
+ "image": "Image",
+ "inlineCode": "Inline Code",
+ "codeBlock": "Code Block",
+ "table": "Table",
+ "togglePreview": "Toggle Preview",
+ "toggleFullscreen": "Toggle Fullscreen",
+ "selectLanguage": "Select Language",
+ "plainText": "Plain Text",
+ "insertTable": "Insert Table",
+ "addRowAbove": "Add Row Above",
+ "addRowBelow": "Add Row Below",
+ "addColumnLeft": "Add Column Left",
+ "addColumnRight": "Add Column Right",
+ "deleteRow": "Delete Row",
+ "deleteColumn": "Delete Column",
+ "deleteTable": "Delete Table",
+ "dragAndDrop": "Drop images here",
+ "dropToUpload": "Drop files here to upload",
+ "uploading": "Uploading...",
+ "uploadProgress": "Upload progress: {{progress}}%",
+ "uploadError": "Failed to upload image: {{error}}",
+ "uploadSuccess": "Image uploaded successfully",
+ "codeBlockShortcut": "Ctrl+Shift+K",
+ "boldShortcut": "Ctrl+B",
+ "italicShortcut": "Ctrl+I",
+ "linkShortcut": "Ctrl+K",
+ "heading1Shortcut": "Shift+1",
+ "heading2Shortcut": "Shift+2",
+ "heading3Shortcut": "Shift+3"
}
}
diff --git a/frontend/data/i18n/zh-Hans.json b/frontend/data/i18n/zh-Hans.json
index 296d279..50c6432 100644
--- a/frontend/data/i18n/zh-Hans.json
+++ b/frontend/data/i18n/zh-Hans.json
@@ -47,7 +47,9 @@
"username": "用户名",
"logout": "退出登录",
"save": "保存",
- "saving": "保存中..."
+ "saving": "保存中...",
+ "noData": "暂无数据",
+ "unsavedChanges": "你有未保存的更改,确定要离开吗?"
},
"dashboard": {
"totalPosts": "文章总数",
@@ -55,6 +57,26 @@
"totalUsers": "用户总数",
"totalContributors": "贡献者总数"
},
+ "posts": {
+ "title": "标题",
+ "categories": "分类",
+ "createdAt": "创建时间",
+ "status": "状态",
+ "noTitle": "无标题",
+ "deleteConfirm": "确定要删除这篇文章吗?",
+ "create": "创建文章",
+ "edit": "编辑文章",
+ "slug": "文章链接",
+ "content": "内容",
+ "summary": "摘要",
+ "metaKeywords": "关键词",
+ "metaDescription": "描述",
+ "selectCategories": "选择分类",
+ "saving": "保存中...",
+ "publishing": "发布中...",
+ "saveDraft": "保存草稿",
+ "publish": "发布文章"
+ },
"login": {
"title": "管理员登录",
"username": "用户名",
@@ -104,5 +126,45 @@
"passwordMismatch": "两次输入的密码不一致",
"passwordTooShort": "密码长度不能少于8个字符"
}
+ },
+ "editor": {
+ "heading1": "一级标题",
+ "heading2": "二级标题",
+ "heading3": "三级标题",
+ "bold": "粗体",
+ "italic": "斜体",
+ "orderedList": "有序列表",
+ "unorderedList": "无序列表",
+ "quote": "引用",
+ "link": "链接",
+ "image": "图片",
+ "inlineCode": "行内代码",
+ "codeBlock": "代码块",
+ "table": "表格",
+ "togglePreview": "切换预览",
+ "toggleFullscreen": "切换全屏",
+ "selectLanguage": "选择语言",
+ "plainText": "纯文本",
+ "insertTable": "插入表格",
+ "addRowAbove": "在上方插入行",
+ "addRowBelow": "在下方插入行",
+ "addColumnLeft": "在左侧插入列",
+ "addColumnRight": "在右侧插入列",
+ "deleteRow": "删除行",
+ "deleteColumn": "删除列",
+ "deleteTable": "删除表格",
+ "dragAndDrop": "拖放图片到此处",
+ "dropToUpload": "拖放文件到此处上传",
+ "uploading": "上传中...",
+ "uploadProgress": "上传进度:{{progress}}%",
+ "uploadError": "图片上传失败:{{error}}",
+ "uploadSuccess": "图片上传成功",
+ "codeBlockShortcut": "Ctrl+Shift+K",
+ "boldShortcut": "Ctrl+B",
+ "italicShortcut": "Ctrl+I",
+ "linkShortcut": "Ctrl+K",
+ "heading1Shortcut": "Shift+1",
+ "heading2Shortcut": "Shift+2",
+ "heading3Shortcut": "Shift+3"
}
}
diff --git a/frontend/data/i18n/zh-Hant.json b/frontend/data/i18n/zh-Hant.json
index 7a89534..34b35a4 100644
--- a/frontend/data/i18n/zh-Hant.json
+++ b/frontend/data/i18n/zh-Hant.json
@@ -45,7 +45,7 @@
"lastLogin": "最後登入",
"joinDate": "加入時間",
"username": "用戶名",
- "logout": "退出登錄",
+ "logout": "退出登入",
"language": "語言",
"theme": {
"light": "淺色模式",
@@ -53,7 +53,9 @@
"system": "跟隨系統"
},
"save": "保存",
- "saving": "保存中..."
+ "saving": "保存中...",
+ "noData": "暫無數據",
+ "unsavedChanges": "你有未保存的更改,確定要離開嗎?"
},
"nav": {
"dashboard": "儀表板",
@@ -83,15 +85,15 @@
"uploadDate": "上傳日期"
},
"login": {
- "title": "管理員登錄",
+ "title": "管理員登入",
"username": "用戶名",
"password": "密碼",
"remember": "記住我",
- "submit": "登錄",
- "loading": "登錄中...",
+ "submit": "登入",
+ "loading": "登入中...",
"error": {
- "failed": "登錄失敗",
- "retry": "登錄失敗,請稍後重試"
+ "failed": "登入失敗",
+ "retry": "登入失敗,請稍後重試"
}
},
"roles": {
@@ -110,5 +112,45 @@
"passwordMismatch": "兩次輸入的密碼不一致",
"passwordTooShort": "密碼長度不能少於8個字符"
}
+ },
+ "editor": {
+ "heading1": "一級標題",
+ "heading2": "二級標題",
+ "heading3": "三級標題",
+ "bold": "粗體",
+ "italic": "斜體",
+ "orderedList": "有序列表",
+ "unorderedList": "無序列表",
+ "quote": "引用",
+ "link": "連結",
+ "image": "圖片",
+ "inlineCode": "行內程式碼",
+ "codeBlock": "程式碼區塊",
+ "table": "表格",
+ "togglePreview": "切換預覽",
+ "toggleFullscreen": "切換全螢幕",
+ "selectLanguage": "選擇語言",
+ "plainText": "純文字",
+ "insertTable": "插入表格",
+ "addRowAbove": "在上方插入列",
+ "addRowBelow": "在下方插入列",
+ "addColumnLeft": "在左側插入欄",
+ "addColumnRight": "在右側插入欄",
+ "deleteRow": "刪除列",
+ "deleteColumn": "刪除欄",
+ "deleteTable": "刪除表格",
+ "dragAndDrop": "拖放圖片到此處",
+ "dropToUpload": "拖放檔案到此處上傳",
+ "uploading": "上傳中...",
+ "uploadProgress": "上傳進度:{{progress}}%",
+ "uploadError": "圖片上傳失敗:{{error}}",
+ "uploadSuccess": "圖片上傳成功",
+ "codeBlockShortcut": "Ctrl+Shift+K",
+ "boldShortcut": "Ctrl+B",
+ "italicShortcut": "Ctrl+I",
+ "linkShortcut": "Ctrl+K",
+ "heading1Shortcut": "Shift+1",
+ "heading2Shortcut": "Shift+2",
+ "heading3Shortcut": "Shift+3"
}
}
diff --git a/frontend/package.json b/frontend/package.json
index ff1ee5e..195deb2 100644
--- a/frontend/package.json
+++ b/frontend/package.json
@@ -12,7 +12,14 @@
"dependencies": {
"@headlessui/react": "^2.2.0",
"@tss-rocks/api": "workspace:*",
+ "@types/axios": "^0.14.4",
+ "@types/classnames": "^2.3.4",
+ "@types/highlight.js": "^10.1.0",
"@types/markdown-it": "^14.1.2",
+ "axios": "^1.7.9",
+ "classnames": "^2.5.1",
+ "dayjs": "^1.11.13",
+ "highlight.js": "^11.11.1",
"i18next": "^24.2.2",
"i18next-browser-languagedetector": "^8.0.4",
"lucide-react": "^0.474.0",
@@ -21,7 +28,13 @@
"react-dom": "^19.0.0",
"react-i18next": "^15.4.0",
"react-icons": "^5.4.0",
- "react-router-dom": "^7.1.5"
+ "react-markdown": "^10.0.0",
+ "react-router-dom": "^7.1.5",
+ "react-toastify": "^11.0.3",
+ "rehype-highlight": "^7.0.2",
+ "rehype-raw": "^7.0.0",
+ "rehype-sanitize": "^6.0.0",
+ "remark-gfm": "^4.0.1"
},
"devDependencies": {
"@eslint/js": "^9.9.1",
diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx
index 9fa788f..db0ad14 100644
--- a/frontend/src/App.tsx
+++ b/frontend/src/App.tsx
@@ -4,16 +4,21 @@ import { AuthProvider } from './contexts/AuthContext';
import { UserProvider } from './contexts/UserContext';
import router from './router';
import LoadingSpinner from './components/LoadingSpinner';
+import { ToastContainer } from 'react-toastify';
+import 'react-toastify/dist/ReactToastify.css';
function App() {
return (
-
+ {children}
+
+ );
+ }
+ return (
+
+
+ {children}
+
+
+ {children}
+ ), + h1: ({ children, ...props }) => ( +{children}+ ), + table: ({ children, ...props }) => ( +
{t('admin.common.title')} | -{t('admin.common.category')} | -{t('admin.common.publishDate')} | -{t('admin.common.status')} | -{t('admin.common.actions')} | -
---|---|---|---|---|
{post.title} | -{post.category} | -{post.publishDate} | -- - {t(`admin.common.${post.status}`)} - - | -- - - | -