diff --git a/api/schemas/components/schemas.yaml b/api/schemas/components/schemas.yaml index 056ef91..3812e98 100644 --- a/api/schemas/components/schemas.yaml +++ b/api/schemas/components/schemas.yaml @@ -103,6 +103,7 @@ Post: - slug - status - contents + - categories properties: id: type: integer @@ -117,6 +118,11 @@ Post: type: array items: $ref: '#/PostContent' + categories: + type: array + items: + $ref: '#/Category' + description: 文章所属分类 created_at: type: string format: date-time diff --git a/backend/Dockerfile b/backend/Dockerfile index c89062d..3a1eed0 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -14,9 +14,7 @@ RUN adduser -u 1000 -D tss-rocks USER tss-rocks WORKDIR /app -# 复制二进制文件和配置 COPY --from=builder /app/tss-rocks-be . -COPY --from=builder /app/config/config.yaml ./config/ EXPOSE 8080 ENV GIN_MODE=release diff --git a/backend/config/config.yaml.example b/backend/config/config.yaml.example index 220ace8..894b281 100644 --- a/backend/config/config.yaml.example +++ b/backend/config/config.yaml.example @@ -21,10 +21,50 @@ auth: message: "Registration is currently disabled. Please contact administrator." # 禁用时的提示信息 storage: - driver: local + type: local # local or s3 local: - root: storage - base_url: http://localhost:8080/storage + root_dir: "./storage/media" + s3: + region: "us-east-1" + bucket: "your-bucket-name" + access_key_id: "your-access-key-id" + secret_access_key: "your-secret-access-key" + endpoint: "" # Optional, for MinIO or other S3-compatible services + custom_url: "" # Optional, for CDN or custom domain (e.g., https://cdn.example.com/media) + proxy_s3: false # If true, backend will proxy S3 requests instead of redirecting + upload: + limits: + image: + max_size: 10 # MB + allowed_types: + - image/jpeg + - image/png + - image/gif + - image/webp + - image/svg+xml + video: + max_size: 500 # MB + allowed_types: + - video/mp4 + - video/webm + audio: + max_size: 50 # MB + allowed_types: + - audio/mpeg + - audio/ogg + - audio/wav + document: + max_size: 20 # MB + allowed_types: + - application/pdf + - application/msword + - application/vnd.openxmlformats-officedocument.wordprocessingml.document + - application/vnd.ms-excel + - application/vnd.openxmlformats-officedocument.spreadsheetml.sheet + - application/zip + - application/x-rar-compressed + - text/plain + - text/csv logging: level: debug diff --git a/backend/ent/category/category.go b/backend/ent/category/category.go index 80073b1..529079d 100644 --- a/backend/ent/category/category.go +++ b/backend/ent/category/category.go @@ -33,13 +33,11 @@ const ( ContentsInverseTable = "category_contents" // ContentsColumn is the table column denoting the contents relation/edge. ContentsColumn = "category_contents" - // PostsTable is the table that holds the posts relation/edge. - PostsTable = "posts" + // PostsTable is the table that holds the posts relation/edge. The primary key declared below. + PostsTable = "category_posts" // PostsInverseTable is the table name for the Post entity. // It exists in this package in order to avoid circular dependency with the "post" package. PostsInverseTable = "posts" - // PostsColumn is the table column denoting the posts relation/edge. - PostsColumn = "category_posts" // DailyItemsTable is the table that holds the daily_items relation/edge. DailyItemsTable = "dailies" // DailyItemsInverseTable is the table name for the Daily entity. @@ -56,6 +54,12 @@ var Columns = []string{ FieldUpdatedAt, } +var ( + // PostsPrimaryKey and PostsColumn2 are the table columns denoting the + // primary key for the posts relation (M2M). + PostsPrimaryKey = []string{"category_id", "post_id"} +) + // ValidColumn reports if the column name is valid (part of the table columns). func ValidColumn(column string) bool { for i := range Columns { @@ -145,7 +149,7 @@ func newPostsStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), sqlgraph.To(PostsInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, PostsTable, PostsColumn), + sqlgraph.Edge(sqlgraph.M2M, false, PostsTable, PostsPrimaryKey...), ) } func newDailyItemsStep() *sqlgraph.Step { diff --git a/backend/ent/category/where.go b/backend/ent/category/where.go index ea42ba5..223d752 100644 --- a/backend/ent/category/where.go +++ b/backend/ent/category/where.go @@ -173,7 +173,7 @@ func HasPosts() predicate.Category { return predicate.Category(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, PostsTable, PostsColumn), + sqlgraph.Edge(sqlgraph.M2M, false, PostsTable, PostsPrimaryKey...), ) sqlgraph.HasNeighbors(s, step) }) diff --git a/backend/ent/category_create.go b/backend/ent/category_create.go index f479883..b78ea04 100644 --- a/backend/ent/category_create.go +++ b/backend/ent/category_create.go @@ -201,10 +201,10 @@ func (cc *CategoryCreate) createSpec() (*Category, *sqlgraph.CreateSpec) { } if nodes := cc.mutation.PostsIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.O2M, + Rel: sqlgraph.M2M, Inverse: false, Table: category.PostsTable, - Columns: []string{category.PostsColumn}, + Columns: category.PostsPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt), diff --git a/backend/ent/category_query.go b/backend/ent/category_query.go index 26c7b61..83dadcc 100644 --- a/backend/ent/category_query.go +++ b/backend/ent/category_query.go @@ -101,7 +101,7 @@ func (cq *CategoryQuery) QueryPosts() *PostQuery { step := sqlgraph.NewStep( sqlgraph.From(category.Table, category.FieldID, selector), sqlgraph.To(post.Table, post.FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, category.PostsTable, category.PostsColumn), + sqlgraph.Edge(sqlgraph.M2M, false, category.PostsTable, category.PostsPrimaryKey...), ) fromU = sqlgraph.SetNeighbors(cq.driver.Dialect(), step) return fromU, nil @@ -523,33 +523,63 @@ func (cq *CategoryQuery) loadContents(ctx context.Context, query *CategoryConten return nil } func (cq *CategoryQuery) loadPosts(ctx context.Context, query *PostQuery, nodes []*Category, init func(*Category), assign func(*Category, *Post)) error { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*Category) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Category) + nids := make(map[int]map[*Category]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node if init != nil { - init(nodes[i]) + init(node) } } - query.withFKs = true - query.Where(predicate.Post(func(s *sql.Selector) { - s.Where(sql.InValues(s.C(category.PostsColumn), fks...)) - })) - neighbors, err := query.All(ctx) + query.Where(func(s *sql.Selector) { + joinT := sql.Table(category.PostsTable) + s.Join(joinT).On(s.C(post.FieldID), joinT.C(category.PostsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(category.PostsPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(category.PostsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + if err := query.prepareQuery(ctx); err != nil { + return err + } + qr := QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + return query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]any, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err + } + return append([]any{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []any) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Category]struct{}{byID[outValue]: {}} + return assign(columns[1:], values[1:]) + } + nids[inValue][byID[outValue]] = struct{}{} + return nil + } + }) + }) + neighbors, err := withInterceptors[[]*Post](ctx, query, qr, query.inters) if err != nil { return err } for _, n := range neighbors { - fk := n.category_posts - if fk == nil { - return fmt.Errorf(`foreign-key "category_posts" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] + nodes, ok := nids[n.ID] if !ok { - return fmt.Errorf(`unexpected referenced foreign-key "category_posts" returned %v for node %v`, *fk, n.ID) + return fmt.Errorf(`unexpected "posts" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) } - assign(node, n) } return nil } diff --git a/backend/ent/category_update.go b/backend/ent/category_update.go index 628a7f2..57c3012 100644 --- a/backend/ent/category_update.go +++ b/backend/ent/category_update.go @@ -262,10 +262,10 @@ func (cu *CategoryUpdate) sqlSave(ctx context.Context) (n int, err error) { } if cu.mutation.PostsCleared() { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.O2M, + Rel: sqlgraph.M2M, Inverse: false, Table: category.PostsTable, - Columns: []string{category.PostsColumn}, + Columns: category.PostsPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt), @@ -275,10 +275,10 @@ func (cu *CategoryUpdate) sqlSave(ctx context.Context) (n int, err error) { } if nodes := cu.mutation.RemovedPostsIDs(); len(nodes) > 0 && !cu.mutation.PostsCleared() { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.O2M, + Rel: sqlgraph.M2M, Inverse: false, Table: category.PostsTable, - Columns: []string{category.PostsColumn}, + Columns: category.PostsPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt), @@ -291,10 +291,10 @@ func (cu *CategoryUpdate) sqlSave(ctx context.Context) (n int, err error) { } if nodes := cu.mutation.PostsIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.O2M, + Rel: sqlgraph.M2M, Inverse: false, Table: category.PostsTable, - Columns: []string{category.PostsColumn}, + Columns: category.PostsPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt), @@ -631,10 +631,10 @@ func (cuo *CategoryUpdateOne) sqlSave(ctx context.Context) (_node *Category, err } if cuo.mutation.PostsCleared() { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.O2M, + Rel: sqlgraph.M2M, Inverse: false, Table: category.PostsTable, - Columns: []string{category.PostsColumn}, + Columns: category.PostsPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt), @@ -644,10 +644,10 @@ func (cuo *CategoryUpdateOne) sqlSave(ctx context.Context) (_node *Category, err } if nodes := cuo.mutation.RemovedPostsIDs(); len(nodes) > 0 && !cuo.mutation.PostsCleared() { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.O2M, + Rel: sqlgraph.M2M, Inverse: false, Table: category.PostsTable, - Columns: []string{category.PostsColumn}, + Columns: category.PostsPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt), @@ -660,10 +660,10 @@ func (cuo *CategoryUpdateOne) sqlSave(ctx context.Context) (_node *Category, err } if nodes := cuo.mutation.PostsIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.O2M, + Rel: sqlgraph.M2M, Inverse: false, Table: category.PostsTable, - Columns: []string{category.PostsColumn}, + Columns: category.PostsPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt), diff --git a/backend/ent/client.go b/backend/ent/client.go index bbcf6f3..2e7b243 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -464,7 +464,7 @@ func (c *CategoryClient) QueryPosts(ca *Category) *PostQuery { step := sqlgraph.NewStep( sqlgraph.From(category.Table, category.FieldID, id), sqlgraph.To(post.Table, post.FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, category.PostsTable, category.PostsColumn), + sqlgraph.Edge(sqlgraph.M2M, false, category.PostsTable, category.PostsPrimaryKey...), ) fromV = sqlgraph.Neighbors(ca.driver.Dialect(), step) return fromV, nil @@ -2207,15 +2207,15 @@ func (c *PostClient) QueryContributors(po *Post) *PostContributorQuery { return query } -// QueryCategory queries the category edge of a Post. -func (c *PostClient) QueryCategory(po *Post) *CategoryQuery { +// QueryCategories queries the categories edge of a Post. +func (c *PostClient) QueryCategories(po *Post) *CategoryQuery { query := (&CategoryClient{config: c.config}).Query() query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := po.ID step := sqlgraph.NewStep( sqlgraph.From(post.Table, post.FieldID, id), sqlgraph.To(category.Table, category.FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, post.CategoryTable, post.CategoryColumn), + sqlgraph.Edge(sqlgraph.M2M, true, post.CategoriesTable, post.CategoriesPrimaryKey...), ) fromV = sqlgraph.Neighbors(po.driver.Dialect(), step) return fromV, nil diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 562eed3..614e3af 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -265,21 +265,12 @@ var ( {Name: "slug", Type: field.TypeString, Unique: true}, {Name: "created_at", Type: field.TypeTime}, {Name: "updated_at", Type: field.TypeTime}, - {Name: "category_posts", Type: field.TypeInt, Nullable: true}, } // PostsTable holds the schema information for the "posts" table. PostsTable = &schema.Table{ Name: "posts", Columns: PostsColumns, PrimaryKey: []*schema.Column{PostsColumns[0]}, - ForeignKeys: []*schema.ForeignKey{ - { - Symbol: "posts_categories_posts", - Columns: []*schema.Column{PostsColumns[5]}, - RefColumns: []*schema.Column{CategoriesColumns[0]}, - OnDelete: schema.SetNull, - }, - }, } // PostContentsColumns holds the columns for the "post_contents" table. PostContentsColumns = []*schema.Column{ @@ -387,6 +378,31 @@ var ( Columns: UsersColumns, PrimaryKey: []*schema.Column{UsersColumns[0]}, } + // CategoryPostsColumns holds the columns for the "category_posts" table. + CategoryPostsColumns = []*schema.Column{ + {Name: "category_id", Type: field.TypeInt}, + {Name: "post_id", Type: field.TypeInt}, + } + // CategoryPostsTable holds the schema information for the "category_posts" table. + CategoryPostsTable = &schema.Table{ + Name: "category_posts", + Columns: CategoryPostsColumns, + PrimaryKey: []*schema.Column{CategoryPostsColumns[0], CategoryPostsColumns[1]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "category_posts_category_id", + Columns: []*schema.Column{CategoryPostsColumns[0]}, + RefColumns: []*schema.Column{CategoriesColumns[0]}, + OnDelete: schema.Cascade, + }, + { + Symbol: "category_posts_post_id", + Columns: []*schema.Column{CategoryPostsColumns[1]}, + RefColumns: []*schema.Column{PostsColumns[0]}, + OnDelete: schema.Cascade, + }, + }, + } // RolePermissionsColumns holds the columns for the "role_permissions" table. RolePermissionsColumns = []*schema.Column{ {Name: "role_id", Type: field.TypeInt}, @@ -455,6 +471,7 @@ var ( PostContributorsTable, RolesTable, UsersTable, + CategoryPostsTable, RolePermissionsTable, UserRolesTable, } @@ -469,11 +486,12 @@ func init() { DailyCategoryContentsTable.ForeignKeys[0].RefTable = DailyCategoriesTable DailyContentsTable.ForeignKeys[0].RefTable = DailiesTable MediaTable.ForeignKeys[0].RefTable = UsersTable - PostsTable.ForeignKeys[0].RefTable = CategoriesTable PostContentsTable.ForeignKeys[0].RefTable = PostsTable PostContributorsTable.ForeignKeys[0].RefTable = ContributorsTable PostContributorsTable.ForeignKeys[1].RefTable = ContributorRolesTable PostContributorsTable.ForeignKeys[2].RefTable = PostsTable + CategoryPostsTable.ForeignKeys[0].RefTable = CategoriesTable + CategoryPostsTable.ForeignKeys[1].RefTable = PostsTable RolePermissionsTable.ForeignKeys[0].RefTable = RolesTable RolePermissionsTable.ForeignKeys[1].RefTable = PermissionsTable UserRolesTable.ForeignKeys[0].RefTable = UsersTable diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index b8ccb5b..f15bc73 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -6578,8 +6578,9 @@ type PostMutation struct { contributors map[int]struct{} removedcontributors map[int]struct{} clearedcontributors bool - category *int - clearedcategory bool + categories map[int]struct{} + removedcategories map[int]struct{} + clearedcategories bool done bool oldValue func(context.Context) (*Post, error) predicates []predicate.Post @@ -6935,43 +6936,58 @@ func (m *PostMutation) ResetContributors() { m.removedcontributors = nil } -// SetCategoryID sets the "category" edge to the Category entity by id. -func (m *PostMutation) SetCategoryID(id int) { - m.category = &id +// AddCategoryIDs adds the "categories" edge to the Category entity by ids. +func (m *PostMutation) AddCategoryIDs(ids ...int) { + if m.categories == nil { + m.categories = make(map[int]struct{}) + } + for i := range ids { + m.categories[ids[i]] = struct{}{} + } } -// ClearCategory clears the "category" edge to the Category entity. -func (m *PostMutation) ClearCategory() { - m.clearedcategory = true +// ClearCategories clears the "categories" edge to the Category entity. +func (m *PostMutation) ClearCategories() { + m.clearedcategories = true } -// CategoryCleared reports if the "category" edge to the Category entity was cleared. -func (m *PostMutation) CategoryCleared() bool { - return m.clearedcategory +// CategoriesCleared reports if the "categories" edge to the Category entity was cleared. +func (m *PostMutation) CategoriesCleared() bool { + return m.clearedcategories } -// CategoryID returns the "category" edge ID in the mutation. -func (m *PostMutation) CategoryID() (id int, exists bool) { - if m.category != nil { - return *m.category, true +// RemoveCategoryIDs removes the "categories" edge to the Category entity by IDs. +func (m *PostMutation) RemoveCategoryIDs(ids ...int) { + if m.removedcategories == nil { + m.removedcategories = make(map[int]struct{}) + } + for i := range ids { + delete(m.categories, ids[i]) + m.removedcategories[ids[i]] = struct{}{} + } +} + +// RemovedCategories returns the removed IDs of the "categories" edge to the Category entity. +func (m *PostMutation) RemovedCategoriesIDs() (ids []int) { + for id := range m.removedcategories { + ids = append(ids, id) } return } -// CategoryIDs returns the "category" edge IDs in the mutation. -// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use -// CategoryID instead. It exists only for internal usage by the builders. -func (m *PostMutation) CategoryIDs() (ids []int) { - if id := m.category; id != nil { - ids = append(ids, *id) +// CategoriesIDs returns the "categories" edge IDs in the mutation. +func (m *PostMutation) CategoriesIDs() (ids []int) { + for id := range m.categories { + ids = append(ids, id) } return } -// ResetCategory resets all changes to the "category" edge. -func (m *PostMutation) ResetCategory() { - m.category = nil - m.clearedcategory = false +// ResetCategories resets all changes to the "categories" edge. +func (m *PostMutation) ResetCategories() { + m.categories = nil + m.clearedcategories = false + m.removedcategories = nil } // Where appends a list predicates to the PostMutation builder. @@ -7165,8 +7181,8 @@ func (m *PostMutation) AddedEdges() []string { if m.contributors != nil { edges = append(edges, post.EdgeContributors) } - if m.category != nil { - edges = append(edges, post.EdgeCategory) + if m.categories != nil { + edges = append(edges, post.EdgeCategories) } return edges } @@ -7187,10 +7203,12 @@ func (m *PostMutation) AddedIDs(name string) []ent.Value { ids = append(ids, id) } return ids - case post.EdgeCategory: - if id := m.category; id != nil { - return []ent.Value{*id} + case post.EdgeCategories: + ids := make([]ent.Value, 0, len(m.categories)) + for id := range m.categories { + ids = append(ids, id) } + return ids } return nil } @@ -7204,6 +7222,9 @@ func (m *PostMutation) RemovedEdges() []string { if m.removedcontributors != nil { edges = append(edges, post.EdgeContributors) } + if m.removedcategories != nil { + edges = append(edges, post.EdgeCategories) + } return edges } @@ -7223,6 +7244,12 @@ func (m *PostMutation) RemovedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case post.EdgeCategories: + ids := make([]ent.Value, 0, len(m.removedcategories)) + for id := range m.removedcategories { + ids = append(ids, id) + } + return ids } return nil } @@ -7236,8 +7263,8 @@ func (m *PostMutation) ClearedEdges() []string { if m.clearedcontributors { edges = append(edges, post.EdgeContributors) } - if m.clearedcategory { - edges = append(edges, post.EdgeCategory) + if m.clearedcategories { + edges = append(edges, post.EdgeCategories) } return edges } @@ -7250,8 +7277,8 @@ func (m *PostMutation) EdgeCleared(name string) bool { return m.clearedcontents case post.EdgeContributors: return m.clearedcontributors - case post.EdgeCategory: - return m.clearedcategory + case post.EdgeCategories: + return m.clearedcategories } return false } @@ -7260,9 +7287,6 @@ func (m *PostMutation) EdgeCleared(name string) bool { // if that edge is not defined in the schema. func (m *PostMutation) ClearEdge(name string) error { switch name { - case post.EdgeCategory: - m.ClearCategory() - return nil } return fmt.Errorf("unknown Post unique edge %s", name) } @@ -7277,8 +7301,8 @@ func (m *PostMutation) ResetEdge(name string) error { case post.EdgeContributors: m.ResetContributors() return nil - case post.EdgeCategory: - m.ResetCategory() + case post.EdgeCategories: + m.ResetCategories() return nil } return fmt.Errorf("unknown Post edge %s", name) diff --git a/backend/ent/post.go b/backend/ent/post.go index 0be8691..3e4529c 100644 --- a/backend/ent/post.go +++ b/backend/ent/post.go @@ -6,7 +6,6 @@ import ( "fmt" "strings" "time" - "tss-rocks-be/ent/category" "tss-rocks-be/ent/post" "entgo.io/ent" @@ -28,9 +27,8 @@ type Post struct { UpdatedAt time.Time `json:"updated_at,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the PostQuery when eager-loading is set. - Edges PostEdges `json:"edges"` - category_posts *int - selectValues sql.SelectValues + Edges PostEdges `json:"edges"` + selectValues sql.SelectValues } // PostEdges holds the relations/edges for other nodes in the graph. @@ -39,8 +37,8 @@ type PostEdges struct { Contents []*PostContent `json:"contents,omitempty"` // Contributors holds the value of the contributors edge. Contributors []*PostContributor `json:"contributors,omitempty"` - // Category holds the value of the category edge. - Category *Category `json:"category,omitempty"` + // Categories holds the value of the categories edge. + Categories []*Category `json:"categories,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. loadedTypes [3]bool @@ -64,15 +62,13 @@ func (e PostEdges) ContributorsOrErr() ([]*PostContributor, error) { return nil, &NotLoadedError{edge: "contributors"} } -// CategoryOrErr returns the Category value or an error if the edge -// was not loaded in eager-loading, or loaded but was not found. -func (e PostEdges) CategoryOrErr() (*Category, error) { - if e.Category != nil { - return e.Category, nil - } else if e.loadedTypes[2] { - return nil, &NotFoundError{label: category.Label} +// CategoriesOrErr returns the Categories value or an error if the edge +// was not loaded in eager-loading. +func (e PostEdges) CategoriesOrErr() ([]*Category, error) { + if e.loadedTypes[2] { + return e.Categories, nil } - return nil, &NotLoadedError{edge: "category"} + return nil, &NotLoadedError{edge: "categories"} } // scanValues returns the types for scanning values from sql.Rows. @@ -86,8 +82,6 @@ func (*Post) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullString) case post.FieldCreatedAt, post.FieldUpdatedAt: values[i] = new(sql.NullTime) - case post.ForeignKeys[0]: // category_posts - values[i] = new(sql.NullInt64) default: values[i] = new(sql.UnknownType) } @@ -133,13 +127,6 @@ func (po *Post) assignValues(columns []string, values []any) error { } else if value.Valid { po.UpdatedAt = value.Time } - case post.ForeignKeys[0]: - if value, ok := values[i].(*sql.NullInt64); !ok { - return fmt.Errorf("unexpected type %T for edge-field category_posts", value) - } else if value.Valid { - po.category_posts = new(int) - *po.category_posts = int(value.Int64) - } default: po.selectValues.Set(columns[i], values[i]) } @@ -163,9 +150,9 @@ func (po *Post) QueryContributors() *PostContributorQuery { return NewPostClient(po.config).QueryContributors(po) } -// QueryCategory queries the "category" edge of the Post entity. -func (po *Post) QueryCategory() *CategoryQuery { - return NewPostClient(po.config).QueryCategory(po) +// QueryCategories queries the "categories" edge of the Post entity. +func (po *Post) QueryCategories() *CategoryQuery { + return NewPostClient(po.config).QueryCategories(po) } // Update returns a builder for updating this Post. diff --git a/backend/ent/post/post.go b/backend/ent/post/post.go index 6dc1e83..b72d3e0 100644 --- a/backend/ent/post/post.go +++ b/backend/ent/post/post.go @@ -27,8 +27,8 @@ const ( EdgeContents = "contents" // EdgeContributors holds the string denoting the contributors edge name in mutations. EdgeContributors = "contributors" - // EdgeCategory holds the string denoting the category edge name in mutations. - EdgeCategory = "category" + // EdgeCategories holds the string denoting the categories edge name in mutations. + EdgeCategories = "categories" // Table holds the table name of the post in the database. Table = "posts" // ContentsTable is the table that holds the contents relation/edge. @@ -45,13 +45,11 @@ const ( ContributorsInverseTable = "post_contributors" // ContributorsColumn is the table column denoting the contributors relation/edge. ContributorsColumn = "post_contributors" - // CategoryTable is the table that holds the category relation/edge. - CategoryTable = "posts" - // CategoryInverseTable is the table name for the Category entity. + // CategoriesTable is the table that holds the categories relation/edge. The primary key declared below. + CategoriesTable = "category_posts" + // CategoriesInverseTable is the table name for the Category entity. // It exists in this package in order to avoid circular dependency with the "category" package. - CategoryInverseTable = "categories" - // CategoryColumn is the table column denoting the category relation/edge. - CategoryColumn = "category_posts" + CategoriesInverseTable = "categories" ) // Columns holds all SQL columns for post fields. @@ -63,11 +61,11 @@ var Columns = []string{ FieldUpdatedAt, } -// ForeignKeys holds the SQL foreign-keys that are owned by the "posts" -// table and are not defined as standalone fields in the schema. -var ForeignKeys = []string{ - "category_posts", -} +var ( + // CategoriesPrimaryKey and CategoriesColumn2 are the table columns denoting the + // primary key for the categories relation (M2M). + CategoriesPrimaryKey = []string{"category_id", "post_id"} +) // ValidColumn reports if the column name is valid (part of the table columns). func ValidColumn(column string) bool { @@ -76,11 +74,6 @@ func ValidColumn(column string) bool { return true } } - for i := range ForeignKeys { - if column == ForeignKeys[i] { - return true - } - } return false } @@ -178,10 +171,17 @@ func ByContributors(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { } } -// ByCategoryField orders the results by category field. -func ByCategoryField(field string, opts ...sql.OrderTermOption) OrderOption { +// ByCategoriesCount orders the results by categories count. +func ByCategoriesCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { - sqlgraph.OrderByNeighborTerms(s, newCategoryStep(), sql.OrderByField(field, opts...)) + sqlgraph.OrderByNeighborsCount(s, newCategoriesStep(), opts...) + } +} + +// ByCategories orders the results by categories terms. +func ByCategories(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newCategoriesStep(), append([]sql.OrderTerm{term}, terms...)...) } } func newContentsStep() *sqlgraph.Step { @@ -198,10 +198,10 @@ func newContributorsStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.O2M, false, ContributorsTable, ContributorsColumn), ) } -func newCategoryStep() *sqlgraph.Step { +func newCategoriesStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(CategoryInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, CategoryTable, CategoryColumn), + sqlgraph.To(CategoriesInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2M, true, CategoriesTable, CategoriesPrimaryKey...), ) } diff --git a/backend/ent/post/where.go b/backend/ent/post/where.go index 7b09ff9..4d9dfda 100644 --- a/backend/ent/post/where.go +++ b/backend/ent/post/where.go @@ -281,21 +281,21 @@ func HasContributorsWith(preds ...predicate.PostContributor) predicate.Post { }) } -// HasCategory applies the HasEdge predicate on the "category" edge. -func HasCategory() predicate.Post { +// HasCategories applies the HasEdge predicate on the "categories" edge. +func HasCategories() predicate.Post { return predicate.Post(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, CategoryTable, CategoryColumn), + sqlgraph.Edge(sqlgraph.M2M, true, CategoriesTable, CategoriesPrimaryKey...), ) sqlgraph.HasNeighbors(s, step) }) } -// HasCategoryWith applies the HasEdge predicate on the "category" edge with a given conditions (other predicates). -func HasCategoryWith(preds ...predicate.Category) predicate.Post { +// HasCategoriesWith applies the HasEdge predicate on the "categories" edge with a given conditions (other predicates). +func HasCategoriesWith(preds ...predicate.Category) predicate.Post { return predicate.Post(func(s *sql.Selector) { - step := newCategoryStep() + step := newCategoriesStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) diff --git a/backend/ent/post_create.go b/backend/ent/post_create.go index c001fe2..b9fb3c7 100644 --- a/backend/ent/post_create.go +++ b/backend/ent/post_create.go @@ -101,23 +101,19 @@ func (pc *PostCreate) AddContributors(p ...*PostContributor) *PostCreate { return pc.AddContributorIDs(ids...) } -// SetCategoryID sets the "category" edge to the Category entity by ID. -func (pc *PostCreate) SetCategoryID(id int) *PostCreate { - pc.mutation.SetCategoryID(id) +// AddCategoryIDs adds the "categories" edge to the Category entity by IDs. +func (pc *PostCreate) AddCategoryIDs(ids ...int) *PostCreate { + pc.mutation.AddCategoryIDs(ids...) return pc } -// SetNillableCategoryID sets the "category" edge to the Category entity by ID if the given value is not nil. -func (pc *PostCreate) SetNillableCategoryID(id *int) *PostCreate { - if id != nil { - pc = pc.SetCategoryID(*id) +// AddCategories adds the "categories" edges to the Category entity. +func (pc *PostCreate) AddCategories(c ...*Category) *PostCreate { + ids := make([]int, len(c)) + for i := range c { + ids[i] = c[i].ID } - return pc -} - -// SetCategory sets the "category" edge to the Category entity. -func (pc *PostCreate) SetCategory(c *Category) *PostCreate { - return pc.SetCategoryID(c.ID) + return pc.AddCategoryIDs(ids...) } // Mutation returns the PostMutation object of the builder. @@ -267,12 +263,12 @@ func (pc *PostCreate) createSpec() (*Post, *sqlgraph.CreateSpec) { } _spec.Edges = append(_spec.Edges, edge) } - if nodes := pc.mutation.CategoryIDs(); len(nodes) > 0 { + if nodes := pc.mutation.CategoriesIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.M2O, + Rel: sqlgraph.M2M, Inverse: true, - Table: post.CategoryTable, - Columns: []string{post.CategoryColumn}, + Table: post.CategoriesTable, + Columns: post.CategoriesPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt), @@ -281,7 +277,6 @@ func (pc *PostCreate) createSpec() (*Post, *sqlgraph.CreateSpec) { for _, k := range nodes { edge.Target.Nodes = append(edge.Target.Nodes, k) } - _node.category_posts = &nodes[0] _spec.Edges = append(_spec.Edges, edge) } return _node, _spec diff --git a/backend/ent/post_query.go b/backend/ent/post_query.go index f90b5ff..c3fbb7a 100644 --- a/backend/ent/post_query.go +++ b/backend/ent/post_query.go @@ -28,8 +28,7 @@ type PostQuery struct { predicates []predicate.Post withContents *PostContentQuery withContributors *PostContributorQuery - withCategory *CategoryQuery - withFKs bool + withCategories *CategoryQuery // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -110,8 +109,8 @@ func (pq *PostQuery) QueryContributors() *PostContributorQuery { return query } -// QueryCategory chains the current query on the "category" edge. -func (pq *PostQuery) QueryCategory() *CategoryQuery { +// QueryCategories chains the current query on the "categories" edge. +func (pq *PostQuery) QueryCategories() *CategoryQuery { query := (&CategoryClient{config: pq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := pq.prepareQuery(ctx); err != nil { @@ -124,7 +123,7 @@ func (pq *PostQuery) QueryCategory() *CategoryQuery { step := sqlgraph.NewStep( sqlgraph.From(post.Table, post.FieldID, selector), sqlgraph.To(category.Table, category.FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, post.CategoryTable, post.CategoryColumn), + sqlgraph.Edge(sqlgraph.M2M, true, post.CategoriesTable, post.CategoriesPrimaryKey...), ) fromU = sqlgraph.SetNeighbors(pq.driver.Dialect(), step) return fromU, nil @@ -326,7 +325,7 @@ func (pq *PostQuery) Clone() *PostQuery { predicates: append([]predicate.Post{}, pq.predicates...), withContents: pq.withContents.Clone(), withContributors: pq.withContributors.Clone(), - withCategory: pq.withCategory.Clone(), + withCategories: pq.withCategories.Clone(), // clone intermediate query. sql: pq.sql.Clone(), path: pq.path, @@ -355,14 +354,14 @@ func (pq *PostQuery) WithContributors(opts ...func(*PostContributorQuery)) *Post return pq } -// WithCategory tells the query-builder to eager-load the nodes that are connected to -// the "category" edge. The optional arguments are used to configure the query builder of the edge. -func (pq *PostQuery) WithCategory(opts ...func(*CategoryQuery)) *PostQuery { +// WithCategories tells the query-builder to eager-load the nodes that are connected to +// the "categories" edge. The optional arguments are used to configure the query builder of the edge. +func (pq *PostQuery) WithCategories(opts ...func(*CategoryQuery)) *PostQuery { query := (&CategoryClient{config: pq.config}).Query() for _, opt := range opts { opt(query) } - pq.withCategory = query + pq.withCategories = query return pq } @@ -443,20 +442,13 @@ func (pq *PostQuery) prepareQuery(ctx context.Context) error { func (pq *PostQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Post, error) { var ( nodes = []*Post{} - withFKs = pq.withFKs _spec = pq.querySpec() loadedTypes = [3]bool{ pq.withContents != nil, pq.withContributors != nil, - pq.withCategory != nil, + pq.withCategories != nil, } ) - if pq.withCategory != nil { - withFKs = true - } - if withFKs { - _spec.Node.Columns = append(_spec.Node.Columns, post.ForeignKeys...) - } _spec.ScanValues = func(columns []string) ([]any, error) { return (*Post).scanValues(nil, columns) } @@ -489,9 +481,10 @@ func (pq *PostQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Post, e return nil, err } } - if query := pq.withCategory; query != nil { - if err := pq.loadCategory(ctx, query, nodes, nil, - func(n *Post, e *Category) { n.Edges.Category = e }); err != nil { + if query := pq.withCategories; query != nil { + if err := pq.loadCategories(ctx, query, nodes, + func(n *Post) { n.Edges.Categories = []*Category{} }, + func(n *Post, e *Category) { n.Edges.Categories = append(n.Edges.Categories, e) }); err != nil { return nil, err } } @@ -560,34 +553,63 @@ func (pq *PostQuery) loadContributors(ctx context.Context, query *PostContributo } return nil } -func (pq *PostQuery) loadCategory(ctx context.Context, query *CategoryQuery, nodes []*Post, init func(*Post), assign func(*Post, *Category)) error { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Post) - for i := range nodes { - if nodes[i].category_posts == nil { - continue +func (pq *PostQuery) loadCategories(ctx context.Context, query *CategoryQuery, nodes []*Post, init func(*Post), assign func(*Post, *Category)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Post) + nids := make(map[int]map[*Post]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) } - fk := *nodes[i].category_posts - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) } - if len(ids) == 0 { - return nil + query.Where(func(s *sql.Selector) { + joinT := sql.Table(post.CategoriesTable) + s.Join(joinT).On(s.C(category.FieldID), joinT.C(post.CategoriesPrimaryKey[0])) + s.Where(sql.InValues(joinT.C(post.CategoriesPrimaryKey[1]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(post.CategoriesPrimaryKey[1])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + if err := query.prepareQuery(ctx); err != nil { + return err } - query.Where(category.IDIn(ids...)) - neighbors, err := query.All(ctx) + qr := QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + return query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]any, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err + } + return append([]any{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []any) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Post]struct{}{byID[outValue]: {}} + return assign(columns[1:], values[1:]) + } + nids[inValue][byID[outValue]] = struct{}{} + return nil + } + }) + }) + neighbors, err := withInterceptors[[]*Category](ctx, query, qr, query.inters) if err != nil { return err } for _, n := range neighbors { - nodes, ok := nodeids[n.ID] + nodes, ok := nids[n.ID] if !ok { - return fmt.Errorf(`unexpected foreign-key "category_posts" returned %v`, n.ID) + return fmt.Errorf(`unexpected "categories" node returned %v`, n.ID) } - for i := range nodes { - assign(nodes[i], n) + for kn := range nodes { + assign(kn, n) } } return nil diff --git a/backend/ent/post_update.go b/backend/ent/post_update.go index d936a6a..7eb44b3 100644 --- a/backend/ent/post_update.go +++ b/backend/ent/post_update.go @@ -109,23 +109,19 @@ func (pu *PostUpdate) AddContributors(p ...*PostContributor) *PostUpdate { return pu.AddContributorIDs(ids...) } -// SetCategoryID sets the "category" edge to the Category entity by ID. -func (pu *PostUpdate) SetCategoryID(id int) *PostUpdate { - pu.mutation.SetCategoryID(id) +// AddCategoryIDs adds the "categories" edge to the Category entity by IDs. +func (pu *PostUpdate) AddCategoryIDs(ids ...int) *PostUpdate { + pu.mutation.AddCategoryIDs(ids...) return pu } -// SetNillableCategoryID sets the "category" edge to the Category entity by ID if the given value is not nil. -func (pu *PostUpdate) SetNillableCategoryID(id *int) *PostUpdate { - if id != nil { - pu = pu.SetCategoryID(*id) +// AddCategories adds the "categories" edges to the Category entity. +func (pu *PostUpdate) AddCategories(c ...*Category) *PostUpdate { + ids := make([]int, len(c)) + for i := range c { + ids[i] = c[i].ID } - return pu -} - -// SetCategory sets the "category" edge to the Category entity. -func (pu *PostUpdate) SetCategory(c *Category) *PostUpdate { - return pu.SetCategoryID(c.ID) + return pu.AddCategoryIDs(ids...) } // Mutation returns the PostMutation object of the builder. @@ -175,12 +171,27 @@ func (pu *PostUpdate) RemoveContributors(p ...*PostContributor) *PostUpdate { return pu.RemoveContributorIDs(ids...) } -// ClearCategory clears the "category" edge to the Category entity. -func (pu *PostUpdate) ClearCategory() *PostUpdate { - pu.mutation.ClearCategory() +// ClearCategories clears all "categories" edges to the Category entity. +func (pu *PostUpdate) ClearCategories() *PostUpdate { + pu.mutation.ClearCategories() return pu } +// RemoveCategoryIDs removes the "categories" edge to Category entities by IDs. +func (pu *PostUpdate) RemoveCategoryIDs(ids ...int) *PostUpdate { + pu.mutation.RemoveCategoryIDs(ids...) + return pu +} + +// RemoveCategories removes "categories" edges to Category entities. +func (pu *PostUpdate) RemoveCategories(c ...*Category) *PostUpdate { + ids := make([]int, len(c)) + for i := range c { + ids[i] = c[i].ID + } + return pu.RemoveCategoryIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (pu *PostUpdate) Save(ctx context.Context) (int, error) { pu.defaults() @@ -346,12 +357,12 @@ func (pu *PostUpdate) sqlSave(ctx context.Context) (n int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } - if pu.mutation.CategoryCleared() { + if pu.mutation.CategoriesCleared() { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.M2O, + Rel: sqlgraph.M2M, Inverse: true, - Table: post.CategoryTable, - Columns: []string{post.CategoryColumn}, + Table: post.CategoriesTable, + Columns: post.CategoriesPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt), @@ -359,12 +370,28 @@ func (pu *PostUpdate) sqlSave(ctx context.Context) (n int, err error) { } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) } - if nodes := pu.mutation.CategoryIDs(); len(nodes) > 0 { + if nodes := pu.mutation.RemovedCategoriesIDs(); len(nodes) > 0 && !pu.mutation.CategoriesCleared() { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.M2O, + Rel: sqlgraph.M2M, Inverse: true, - Table: post.CategoryTable, - Columns: []string{post.CategoryColumn}, + Table: post.CategoriesTable, + Columns: post.CategoriesPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := pu.mutation.CategoriesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: true, + Table: post.CategoriesTable, + Columns: post.CategoriesPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt), @@ -473,23 +500,19 @@ func (puo *PostUpdateOne) AddContributors(p ...*PostContributor) *PostUpdateOne return puo.AddContributorIDs(ids...) } -// SetCategoryID sets the "category" edge to the Category entity by ID. -func (puo *PostUpdateOne) SetCategoryID(id int) *PostUpdateOne { - puo.mutation.SetCategoryID(id) +// AddCategoryIDs adds the "categories" edge to the Category entity by IDs. +func (puo *PostUpdateOne) AddCategoryIDs(ids ...int) *PostUpdateOne { + puo.mutation.AddCategoryIDs(ids...) return puo } -// SetNillableCategoryID sets the "category" edge to the Category entity by ID if the given value is not nil. -func (puo *PostUpdateOne) SetNillableCategoryID(id *int) *PostUpdateOne { - if id != nil { - puo = puo.SetCategoryID(*id) +// AddCategories adds the "categories" edges to the Category entity. +func (puo *PostUpdateOne) AddCategories(c ...*Category) *PostUpdateOne { + ids := make([]int, len(c)) + for i := range c { + ids[i] = c[i].ID } - return puo -} - -// SetCategory sets the "category" edge to the Category entity. -func (puo *PostUpdateOne) SetCategory(c *Category) *PostUpdateOne { - return puo.SetCategoryID(c.ID) + return puo.AddCategoryIDs(ids...) } // Mutation returns the PostMutation object of the builder. @@ -539,12 +562,27 @@ func (puo *PostUpdateOne) RemoveContributors(p ...*PostContributor) *PostUpdateO return puo.RemoveContributorIDs(ids...) } -// ClearCategory clears the "category" edge to the Category entity. -func (puo *PostUpdateOne) ClearCategory() *PostUpdateOne { - puo.mutation.ClearCategory() +// ClearCategories clears all "categories" edges to the Category entity. +func (puo *PostUpdateOne) ClearCategories() *PostUpdateOne { + puo.mutation.ClearCategories() return puo } +// RemoveCategoryIDs removes the "categories" edge to Category entities by IDs. +func (puo *PostUpdateOne) RemoveCategoryIDs(ids ...int) *PostUpdateOne { + puo.mutation.RemoveCategoryIDs(ids...) + return puo +} + +// RemoveCategories removes "categories" edges to Category entities. +func (puo *PostUpdateOne) RemoveCategories(c ...*Category) *PostUpdateOne { + ids := make([]int, len(c)) + for i := range c { + ids[i] = c[i].ID + } + return puo.RemoveCategoryIDs(ids...) +} + // Where appends a list predicates to the PostUpdate builder. func (puo *PostUpdateOne) Where(ps ...predicate.Post) *PostUpdateOne { puo.mutation.Where(ps...) @@ -740,12 +778,12 @@ func (puo *PostUpdateOne) sqlSave(ctx context.Context) (_node *Post, err error) } _spec.Edges.Add = append(_spec.Edges.Add, edge) } - if puo.mutation.CategoryCleared() { + if puo.mutation.CategoriesCleared() { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.M2O, + Rel: sqlgraph.M2M, Inverse: true, - Table: post.CategoryTable, - Columns: []string{post.CategoryColumn}, + Table: post.CategoriesTable, + Columns: post.CategoriesPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt), @@ -753,12 +791,28 @@ func (puo *PostUpdateOne) sqlSave(ctx context.Context) (_node *Post, err error) } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) } - if nodes := puo.mutation.CategoryIDs(); len(nodes) > 0 { + if nodes := puo.mutation.RemovedCategoriesIDs(); len(nodes) > 0 && !puo.mutation.CategoriesCleared() { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.M2O, + Rel: sqlgraph.M2M, Inverse: true, - Table: post.CategoryTable, - Columns: []string{post.CategoryColumn}, + Table: post.CategoriesTable, + Columns: post.CategoriesPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := puo.mutation.CategoriesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: true, + Table: post.CategoriesTable, + Columns: post.CategoriesPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt), diff --git a/backend/ent/schema/media.go b/backend/ent/schema/media.go index 5c41d70..3ad19d5 100644 --- a/backend/ent/schema/media.go +++ b/backend/ent/schema/media.go @@ -16,11 +16,14 @@ type Media struct { func (Media) Fields() []ent.Field { return []ent.Field{ field.String("storage_id"). + StorageKey("storage_id"). NotEmpty(). Unique(), field.String("original_name"). + StorageKey("original_name"). NotEmpty(), field.String("mime_type"). + StorageKey("mime_type"). NotEmpty(), field.Int64("size"). Positive(), diff --git a/backend/ent/schema/post.go b/backend/ent/schema/post.go index 05362e4..ff97e93 100644 --- a/backend/ent/schema/post.go +++ b/backend/ent/schema/post.go @@ -34,8 +34,7 @@ func (Post) Edges() []ent.Edge { return []ent.Edge{ edge.To("contents", PostContent.Type), edge.To("contributors", PostContributor.Type), - edge.From("category", Category.Type). - Ref("posts"). - Unique(), + edge.From("categories", Category.Type). + Ref("posts"), } } diff --git a/backend/go.mod b/backend/go.mod index cc83baa..e2141f5 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -3,13 +3,13 @@ module tss-rocks-be go 1.23.6 require ( - bou.ke/monkey v1.0.2 entgo.io/ent v0.14.1 github.com/aws/aws-sdk-go-v2 v1.36.1 github.com/aws/aws-sdk-go-v2/config v1.29.6 github.com/aws/aws-sdk-go-v2/credentials v1.17.59 github.com/aws/aws-sdk-go-v2/service/s3 v1.76.1 github.com/chai2010/webp v1.1.1 + github.com/disintegration/imaging v1.6.2 github.com/gin-gonic/gin v1.10.0 github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/uuid v1.6.0 @@ -17,7 +17,6 @@ require ( github.com/rs/zerolog v1.33.0 github.com/spf13/cobra v1.7.0 github.com/stretchr/testify v1.10.0 - go.uber.org/mock v0.5.0 golang.org/x/crypto v0.33.0 golang.org/x/time v0.10.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 @@ -68,12 +67,12 @@ require ( github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/stretchr/objx v0.5.2 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect github.com/zclconf/go-cty v1.16.2 // indirect github.com/zclconf/go-cty-yaml v1.1.0 // indirect golang.org/x/arch v0.14.0 // indirect + golang.org/x/image v0.0.0-20211028202545-6944b10bf410 // indirect golang.org/x/mod v0.23.0 // indirect golang.org/x/net v0.35.0 // indirect golang.org/x/sync v0.11.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index f3dba0a..37a92d6 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -1,7 +1,5 @@ ariga.io/atlas v0.31.0 h1:Nw6/Jdc7OpZfiy6oh/dJAYPp5XxGYvMTWLOUutwWjeY= ariga.io/atlas v0.31.0/go.mod h1:J3chwsQAgjDF6Ostz7JmJJRTCbtqIupUbVR/gqZrMiA= -bou.ke/monkey v1.0.2 h1:kWcnsrCNUatbxncxR/ThdYqbytgOIArtYWqcQLQzKLI= -bou.ke/monkey v1.0.2/go.mod h1:OqickVX3tNx6t33n1xvtTtu85YN5s6cKwVug+oHMaIA= entgo.io/ent v0.14.1 h1:fUERL506Pqr92EPHJqr8EYxbPioflJo6PudkrEA8a/s= entgo.io/ent v0.14.1/go.mod h1:MH6XLG0KXpkcDQhKiHfANZSzR55TJyPL5IGNpI8wpco= github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= @@ -63,6 +61,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/disintegration/imaging v1.6.2 h1:w1LecBlG2Lnp8B3jk5zSuNqd7b4DXhcjwek1ei82L+c= +github.com/disintegration/imaging v1.6.2/go.mod h1:44/5580QXChDfwIclfc/PCwrr44amcmDAg8hxG0Ewe4= github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= github.com/gin-contrib/sse v1.0.0 h1:y3bT1mUWUxDpW4JLQg/HnTqV4rozuW4tC9eFKTxYI9E= @@ -110,8 +110,6 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0= -github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0= @@ -121,8 +119,6 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= -github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= -github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -139,7 +135,6 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -159,12 +154,13 @@ github.com/zclconf/go-cty-debug v0.0.0-20240509010212-0d6042c53940 h1:4r45xpDWB6 github.com/zclconf/go-cty-debug v0.0.0-20240509010212-0d6042c53940/go.mod h1:CmBdvvj3nqzfzJ6nTCIwDTPZ56aVGvDrmztiO5g3qrM= github.com/zclconf/go-cty-yaml v1.1.0 h1:nP+jp0qPHv2IhUVqmQSzjvqAWcObN0KBkUl2rWBdig0= github.com/zclconf/go-cty-yaml v1.1.0/go.mod h1:9YLUH4g7lOhVWqUbctnVlZ5KLpg7JAprQNgxSZ1Gyxs= -go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= -go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= golang.org/x/arch v0.14.0 h1:z9JUEZWr8x4rR0OU6c4/4t6E6jOZ8/QBS2bBYBm4tx4= golang.org/x/arch v0.14.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= +golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ= +golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM= golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM= golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= @@ -176,10 +172,13 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4= golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY= golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY= google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= diff --git a/backend/internal/auth/auth_test.go b/backend/internal/auth/auth_test.go deleted file mode 100644 index 109a7fe..0000000 --- a/backend/internal/auth/auth_test.go +++ /dev/null @@ -1,27 +0,0 @@ -package auth - -import ( - "context" - "testing" -) - -func TestUserIDKey(t *testing.T) { - // Test that the UserIDKey constant is defined correctly - if UserIDKey != "user_id" { - t.Errorf("UserIDKey = %v, want %v", UserIDKey, "user_id") - } - - // Test context with user ID - ctx := context.WithValue(context.Background(), UserIDKey, "test-user-123") - value := ctx.Value(UserIDKey) - if value != "test-user-123" { - t.Errorf("Context value = %v, want %v", value, "test-user-123") - } - - // Test context without user ID - emptyCtx := context.Background() - emptyValue := emptyCtx.Value(UserIDKey) - if emptyValue != nil { - t.Errorf("Empty context value = %v, want nil", emptyValue) - } -} diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index faa6bdd..f7eedd8 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -49,7 +49,7 @@ type StorageConfig struct { Type string `yaml:"type"` Local LocalStorage `yaml:"local"` S3 S3Storage `yaml:"s3"` - Upload types.UploadConfig `yaml:"upload"` + Upload UploadConfig `yaml:"upload"` } type LocalStorage struct { @@ -66,6 +66,27 @@ type S3Storage struct { ProxyS3 bool `yaml:"proxy_s3"` } +type UploadConfig struct { + Limits struct { + Image struct { + MaxSize int `yaml:"max_size"` + AllowedTypes []string `yaml:"allowed_types"` + } `yaml:"image"` + Video struct { + MaxSize int `yaml:"max_size"` + AllowedTypes []string `yaml:"allowed_types"` + } `yaml:"video"` + Audio struct { + MaxSize int `yaml:"max_size"` + AllowedTypes []string `yaml:"allowed_types"` + } `yaml:"audio"` + Document struct { + MaxSize int `yaml:"max_size"` + AllowedTypes []string `yaml:"allowed_types"` + } `yaml:"document"` + } `yaml:"limits"` +} + // Load loads configuration from a YAML file func Load(path string) (*Config, error) { data, err := os.ReadFile(path) diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go deleted file mode 100644 index 800f5b0..0000000 --- a/backend/internal/config/config_test.go +++ /dev/null @@ -1,85 +0,0 @@ -package config - -import ( - "os" - "path/filepath" - "testing" -) - -func TestLoad(t *testing.T) { - // Create a temporary test config file - content := []byte(` -database: - driver: postgres - dsn: postgres://user:pass@localhost:5432/dbname -server: - port: 8080 - host: localhost -jwt: - secret: test-secret - expiration: 24h -storage: - type: local - local: - root_dir: /tmp/storage - upload: - max_size: 10485760 -logging: - level: info - format: json -`) - - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config.yaml") - if err := os.WriteFile(configPath, content, 0644); err != nil { - t.Fatalf("Failed to write test config: %v", err) - } - - // Test loading config - cfg, err := Load(configPath) - if err != nil { - t.Fatalf("Load() error = %v", err) - } - - // Verify loaded values - tests := []struct { - name string - got interface{} - want interface{} - errorMsg string - }{ - {"Database Driver", cfg.Database.Driver, "postgres", "incorrect database driver"}, - {"Server Port", cfg.Server.Port, 8080, "incorrect server port"}, - {"JWT Secret", cfg.JWT.Secret, "test-secret", "incorrect JWT secret"}, - {"Storage Type", cfg.Storage.Type, "local", "incorrect storage type"}, - {"Logging Level", cfg.Logging.Level, "info", "incorrect logging level"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.got != tt.want { - t.Errorf("%s = %v, want %v", tt.name, tt.got, tt.want) - } - }) - } -} - -func TestLoadError(t *testing.T) { - // Test loading non-existent file - _, err := Load("non-existent-file.yaml") - if err == nil { - t.Error("Load() error = nil, want error for non-existent file") - } - - // Test loading invalid YAML - tmpDir := t.TempDir() - invalidPath := filepath.Join(tmpDir, "invalid.yaml") - if err := os.WriteFile(invalidPath, []byte("invalid: }{yaml"), 0644); err != nil { - t.Fatalf("Failed to write invalid config: %v", err) - } - - _, err = Load(invalidPath) - if err == nil { - t.Error("Load() error = nil, want error for invalid YAML") - } -} diff --git a/backend/internal/handler/auth_handler_test.go b/backend/internal/handler/auth_handler_test.go deleted file mode 100644 index cae7360..0000000 --- a/backend/internal/handler/auth_handler_test.go +++ /dev/null @@ -1,312 +0,0 @@ -package handler - -import ( - "bytes" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "testing" - - "tss-rocks-be/ent" - "tss-rocks-be/internal/config" - "tss-rocks-be/internal/service/mock" - - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/suite" - "go.uber.org/mock/gomock" - "golang.org/x/crypto/bcrypt" -) - -type AuthHandlerTestSuite struct { - suite.Suite - ctrl *gomock.Controller - service *mock.MockService - handler *Handler - router *gin.Engine -} - -func (s *AuthHandlerTestSuite) SetupTest() { - s.ctrl = gomock.NewController(s.T()) - s.service = mock.NewMockService(s.ctrl) - s.handler = NewHandler(&config.Config{ - JWT: config.JWTConfig{ - Secret: "test-secret", - }, - Auth: config.AuthConfig{ - Registration: struct { - Enabled bool `yaml:"enabled"` - Message string `yaml:"message"` - }{ - Enabled: true, - Message: "Registration is disabled", - }, - }, - }, s.service) - s.router = gin.New() -} - -func (s *AuthHandlerTestSuite) TearDownTest() { - s.ctrl.Finish() -} - -func TestAuthHandlerSuite(t *testing.T) { - suite.Run(t, new(AuthHandlerTestSuite)) -} - -type ErrorResponse struct { - Error struct { - Code string `json:"code"` - Message string `json:"message"` - } `json:"error"` -} - -func (s *AuthHandlerTestSuite) TestRegister() { - testCases := []struct { - name string - request RegisterRequest - setupMock func() - expectedStatus int - expectedError string - registration bool - }{ - { - name: "成功注册", - request: RegisterRequest{ - Username: "testuser", - Email: "test@example.com", - Password: "password123", - Role: "contributor", - }, - setupMock: func() { - s.service.EXPECT(). - CreateUser(gomock.Any(), "testuser", "test@example.com", "password123", "contributor"). - Return(&ent.User{ - ID: 1, - Username: "testuser", - Email: "test@example.com", - }, nil) - s.service.EXPECT(). - GetUserRoles(gomock.Any(), 1). - Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil) - }, - expectedStatus: http.StatusCreated, - registration: true, - }, - { - name: "注册功能已禁用", - request: RegisterRequest{ - Username: "testuser", - Email: "test@example.com", - Password: "password123", - Role: "contributor", - }, - setupMock: func() {}, - expectedStatus: http.StatusForbidden, - expectedError: "Registration is disabled", - registration: false, - }, - { - name: "无效的邮箱格式", - request: RegisterRequest{ - Username: "testuser", - Email: "invalid-email", - Password: "password123", - Role: "contributor", - }, - setupMock: func() {}, - expectedStatus: http.StatusBadRequest, - expectedError: "Key: 'RegisterRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag", - registration: true, - }, - { - name: "密码太短", - request: RegisterRequest{ - Username: "testuser", - Email: "test@example.com", - Password: "short", - Role: "contributor", - }, - setupMock: func() {}, - expectedStatus: http.StatusBadRequest, - expectedError: "Key: 'RegisterRequest.Password' Error:Field validation for 'Password' failed on the 'min' tag", - registration: true, - }, - { - name: "无效的角色", - request: RegisterRequest{ - Username: "testuser", - Email: "test@example.com", - Password: "password123", - Role: "invalid-role", - }, - setupMock: func() {}, - expectedStatus: http.StatusBadRequest, - expectedError: "Key: 'RegisterRequest.Role' Error:Field validation for 'Role' failed on the 'oneof' tag", - registration: true, - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - // 设置注册功能状态 - s.handler.cfg.Auth.Registration.Enabled = tc.registration - - // 设置 mock - tc.setupMock() - - // 创建请求 - reqBody, _ := json.Marshal(tc.request) - req, _ := http.NewRequest(http.MethodPost, "/register", bytes.NewBuffer(reqBody)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Request = req - - // 执行请求 - s.handler.Register(c) - - // 验证响应 - s.Equal(tc.expectedStatus, w.Code) - if tc.expectedError != "" { - var response ErrorResponse - err := json.Unmarshal(w.Body.Bytes(), &response) - s.NoError(err) - s.Contains(response.Error.Message, tc.expectedError) - } else { - var response AuthResponse - err := json.Unmarshal(w.Body.Bytes(), &response) - s.NoError(err) - s.NotEmpty(response.Token) - } - }) - } -} - -func (s *AuthHandlerTestSuite) TestLogin() { - hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost) - - testCases := []struct { - name string - request LoginRequest - setupMock func() - expectedStatus int - expectedError string - }{ - { - name: "成功登录", - request: LoginRequest{ - Username: "testuser", - Password: "password123", - }, - setupMock: func() { - s.service.EXPECT(). - GetUserByUsername(gomock.Any(), "testuser"). - Return(&ent.User{ - ID: 1, - Username: "testuser", - PasswordHash: string(hashedPassword), - }, nil) - s.service.EXPECT(). - GetUserRoles(gomock.Any(), 1). - Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil) - }, - expectedStatus: http.StatusOK, - }, - { - name: "无效的用户名", - request: LoginRequest{ - Username: "te", - Password: "password123", - }, - setupMock: func() {}, - expectedStatus: http.StatusBadRequest, - expectedError: "Key: 'LoginRequest.Username' Error:Field validation for 'Username' failed on the 'min' tag", - }, - { - name: "用户不存在", - request: LoginRequest{ - Username: "nonexistent", - Password: "password123", - }, - setupMock: func() { - s.service.EXPECT(). - GetUserByUsername(gomock.Any(), "nonexistent"). - Return(nil, fmt.Errorf("user not found")) - }, - expectedStatus: http.StatusUnauthorized, - expectedError: "Invalid username or password", - }, - { - name: "密码错误", - request: LoginRequest{ - Username: "testuser", - Password: "wrongpassword", - }, - setupMock: func() { - s.service.EXPECT(). - GetUserByUsername(gomock.Any(), "testuser"). - Return(&ent.User{ - ID: 1, - Username: "testuser", - PasswordHash: string(hashedPassword), - }, nil) - }, - expectedStatus: http.StatusUnauthorized, - expectedError: "Invalid username or password", - }, - { - name: "获取用户角色失败", - request: LoginRequest{ - Username: "testuser", - Password: "password123", - }, - setupMock: func() { - s.service.EXPECT(). - GetUserByUsername(gomock.Any(), "testuser"). - Return(&ent.User{ - ID: 1, - Username: "testuser", - PasswordHash: string(hashedPassword), - }, nil) - s.service.EXPECT(). - GetUserRoles(gomock.Any(), 1). - Return(nil, fmt.Errorf("failed to get roles")) - }, - expectedStatus: http.StatusInternalServerError, - expectedError: "Failed to get user roles", - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - // 设置 mock - tc.setupMock() - - // 创建请求 - reqBody, _ := json.Marshal(tc.request) - req, _ := http.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(reqBody)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Request = req - - // 执行请求 - s.handler.Login(c) - - // 验证响应 - s.Equal(tc.expectedStatus, w.Code) - if tc.expectedError != "" { - var response ErrorResponse - err := json.Unmarshal(w.Body.Bytes(), &response) - s.NoError(err) - s.Contains(response.Error.Message, tc.expectedError) - } else { - var response AuthResponse - err := json.Unmarshal(w.Body.Bytes(), &response) - s.NoError(err) - s.NotEmpty(response.Token) - } - }) - } -} diff --git a/backend/internal/handler/category_handler_test.go b/backend/internal/handler/category_handler_test.go deleted file mode 100644 index 95c4bd9..0000000 --- a/backend/internal/handler/category_handler_test.go +++ /dev/null @@ -1,481 +0,0 @@ -package handler - -import ( - "bytes" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - "tss-rocks-be/ent" - "tss-rocks-be/ent/categorycontent" - "tss-rocks-be/internal/config" - "tss-rocks-be/internal/service" - "tss-rocks-be/internal/service/mock" - "tss-rocks-be/internal/types" - - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/suite" - "go.uber.org/mock/gomock" - - "errors" -) - -// Custom assertion function for comparing categories -func assertCategoryEqual(t assert.TestingT, expected, actual *ent.Category) bool { - if expected == nil && actual == nil { - return true - } - if expected == nil || actual == nil { - return assert.Fail(t, "One category is nil while the other is not") - } - - // Compare only relevant fields, ignoring time fields - return assert.Equal(t, expected.ID, actual.ID) && - assert.Equal(t, expected.Edges.Contents, actual.Edges.Contents) -} - -// Custom assertion function for comparing category slices -func assertCategorySliceEqual(t assert.TestingT, expected, actual []*ent.Category) bool { - if len(expected) != len(actual) { - return assert.Fail(t, "Category slice lengths do not match") - } - - for i := range expected { - if !assertCategoryEqual(t, expected[i], actual[i]) { - return false - } - } - return true -} - -type CategoryHandlerTestSuite struct { - suite.Suite - ctrl *gomock.Controller - service *mock.MockService - handler *Handler - router *gin.Engine -} - -func (s *CategoryHandlerTestSuite) SetupTest() { - s.ctrl = gomock.NewController(s.T()) - s.service = mock.NewMockService(s.ctrl) - cfg := &config.Config{ - JWT: config.JWTConfig{ - Secret: "test-secret", - }, - } - s.handler = NewHandler(cfg, s.service) - - // Setup Gin router - gin.SetMode(gin.TestMode) - s.router = gin.New() - - // Setup mock for GetTokenBlacklist - tokenBlacklist := &service.TokenBlacklist{} - s.service.EXPECT(). - GetTokenBlacklist(). - Return(tokenBlacklist). - AnyTimes() - - s.handler.RegisterRoutes(s.router) -} - -func (s *CategoryHandlerTestSuite) TearDownTest() { - s.ctrl.Finish() -} - -func TestCategoryHandlerSuite(t *testing.T) { - suite.Run(t, new(CategoryHandlerTestSuite)) -} - -// Test cases for ListCategories -func (s *CategoryHandlerTestSuite) TestListCategories() { - testCases := []struct { - name string - langCode string - setupMock func() - expectedStatus int - expectedBody interface{} - }{ - { - name: "Success with default language", - langCode: "", - setupMock: func() { - s.service.EXPECT(). - ListCategories(gomock.Any(), gomock.Eq("en")). - Return([]*ent.Category{ - { - ID: 1, - Edges: ent.CategoryEdges{ - Contents: []*ent.CategoryContent{ - { - LanguageCode: categorycontent.LanguageCode("en"), - Name: "Test Category", - Description: "Test Description", - Slug: "test-category", - }, - }, - }, - }, - }, nil) - }, - expectedStatus: http.StatusOK, - expectedBody: []*ent.Category{ - { - ID: 1, - Edges: ent.CategoryEdges{ - Contents: []*ent.CategoryContent{ - { - LanguageCode: categorycontent.LanguageCode("en"), - Name: "Test Category", - Description: "Test Description", - Slug: "test-category", - }, - }, - }, - }, - }, - }, - { - name: "Success with specific language", - langCode: "zh", - setupMock: func() { - s.service.EXPECT(). - ListCategories(gomock.Any(), gomock.Eq("zh")). - Return([]*ent.Category{ - { - ID: 1, - Edges: ent.CategoryEdges{ - Contents: []*ent.CategoryContent{ - { - LanguageCode: categorycontent.LanguageCode("zh"), - Name: "测试分类", - Description: "测试描述", - Slug: "test-category", - }, - }, - }, - }, - }, nil) - }, - expectedStatus: http.StatusOK, - expectedBody: []*ent.Category{ - { - ID: 1, - Edges: ent.CategoryEdges{ - Contents: []*ent.CategoryContent{ - { - LanguageCode: categorycontent.LanguageCode("zh"), - Name: "测试分类", - Description: "测试描述", - Slug: "test-category", - }, - }, - }, - }, - }, - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - // Setup mock - tc.setupMock() - - // Create request - url := "/api/v1/categories" - if tc.langCode != "" { - url += "?lang=" + tc.langCode - } - req := httptest.NewRequest(http.MethodGet, url, nil) - w := httptest.NewRecorder() - - // Perform request - s.router.ServeHTTP(w, req) - - // Assert response - assert.Equal(s.T(), tc.expectedStatus, w.Code) - if tc.expectedBody != nil { - var response []*ent.Category - err := json.Unmarshal(w.Body.Bytes(), &response) - assert.NoError(s.T(), err) - assertCategorySliceEqual(s.T(), tc.expectedBody.([]*ent.Category), response) - } - }) - } -} - -// Test cases for GetCategory -func (s *CategoryHandlerTestSuite) TestGetCategory() { - testCases := []struct { - name string - langCode string - slug string - setupMock func() - expectedStatus int - expectedBody interface{} - }{ - { - name: "Success", - langCode: "en", - slug: "test-category", - setupMock: func() { - s.service.EXPECT(). - GetCategoryBySlug(gomock.Any(), gomock.Eq("en"), gomock.Eq("test-category")). - Return(&ent.Category{ - ID: 1, - Edges: ent.CategoryEdges{ - Contents: []*ent.CategoryContent{ - { - LanguageCode: categorycontent.LanguageCode("en"), - Name: "Test Category", - Description: "Test Description", - Slug: "test-category", - }, - }, - }, - }, nil) - }, - expectedStatus: http.StatusOK, - expectedBody: &ent.Category{ - ID: 1, - Edges: ent.CategoryEdges{ - Contents: []*ent.CategoryContent{ - { - LanguageCode: categorycontent.LanguageCode("en"), - Name: "Test Category", - Description: "Test Description", - Slug: "test-category", - }, - }, - }, - }, - }, - { - name: "Not Found", - langCode: "en", - slug: "non-existent", - setupMock: func() { - s.service.EXPECT(). - GetCategoryBySlug(gomock.Any(), gomock.Eq("en"), gomock.Eq("non-existent")). - Return(nil, types.ErrNotFound) - }, - expectedStatus: http.StatusNotFound, - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - // Setup mock - tc.setupMock() - - // Create request - url := "/api/v1/categories/" + tc.slug - if tc.langCode != "" { - url += "?lang=" + tc.langCode - } - req := httptest.NewRequest(http.MethodGet, url, nil) - w := httptest.NewRecorder() - - // Perform request - s.router.ServeHTTP(w, req) - - // Assert response - assert.Equal(s.T(), tc.expectedStatus, w.Code) - if tc.expectedBody != nil { - var response ent.Category - err := json.Unmarshal(w.Body.Bytes(), &response) - assert.NoError(s.T(), err) - assertCategoryEqual(s.T(), tc.expectedBody.(*ent.Category), &response) - } - }) - } -} - -// Test cases for AddCategoryContent -func (s *CategoryHandlerTestSuite) TestAddCategoryContent() { - var description = "Test Description" - testCases := []struct { - name string - categoryID string - requestBody interface{} - setupMock func() - expectedStatus int - expectedBody interface{} - }{ - { - name: "Success", - categoryID: "1", - requestBody: AddCategoryContentRequest{ - LanguageCode: "en", - Name: "Test Category", - Description: &description, - Slug: "test-category", - }, - setupMock: func() { - s.service.EXPECT(). - AddCategoryContent( - gomock.Any(), - 1, - "en", - "Test Category", - description, - "test-category", - ). - Return(&ent.CategoryContent{ - LanguageCode: categorycontent.LanguageCode("en"), - Name: "Test Category", - Description: description, - Slug: "test-category", - }, nil) - }, - expectedStatus: http.StatusCreated, - expectedBody: &ent.CategoryContent{ - LanguageCode: categorycontent.LanguageCode("en"), - Name: "Test Category", - Description: description, - Slug: "test-category", - }, - }, - { - name: "Invalid JSON", - categoryID: "1", - requestBody: "invalid json", - setupMock: func() {}, - expectedStatus: http.StatusBadRequest, - }, - { - name: "Invalid Category ID", - categoryID: "invalid", - requestBody: AddCategoryContentRequest{ - LanguageCode: "en", - Name: "Test Category", - Description: &description, - Slug: "test-category", - }, - setupMock: func() {}, - expectedStatus: http.StatusBadRequest, - }, - { - name: "Service Error", - categoryID: "1", - requestBody: AddCategoryContentRequest{ - LanguageCode: "en", - Name: "Test Category", - Description: &description, - Slug: "test-category", - }, - setupMock: func() { - s.service.EXPECT(). - AddCategoryContent( - gomock.Any(), - 1, - "en", - "Test Category", - description, - "test-category", - ). - Return(nil, errors.New("service error")) - }, - expectedStatus: http.StatusInternalServerError, - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - // Setup mock - tc.setupMock() - - // Create request - var body []byte - var err error - if str, ok := tc.requestBody.(string); ok { - body = []byte(str) - } else { - body, err = json.Marshal(tc.requestBody) - s.NoError(err) - } - - req := httptest.NewRequest(http.MethodPost, "/api/v1/categories/"+tc.categoryID+"/contents", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - - // Perform request - s.router.ServeHTTP(w, req) - - // Assert response - s.Equal(tc.expectedStatus, w.Code) - if tc.expectedBody != nil { - var response ent.CategoryContent - err := json.Unmarshal(w.Body.Bytes(), &response) - s.NoError(err) - s.Equal(tc.expectedBody, &response) - } - }) - } -} - -// Test cases for CreateCategory -func (s *CategoryHandlerTestSuite) TestCreateCategory() { - testCases := []struct { - name string - setupMock func() - expectedStatus int - expectedError string - }{ - { - name: "成功创建分类", - setupMock: func() { - category := &ent.Category{ - ID: 1, - } - s.service.EXPECT(). - CreateCategory(gomock.Any()). - Return(category, nil) - }, - expectedStatus: http.StatusCreated, - }, - { - name: "创建分类失败", - setupMock: func() { - s.service.EXPECT(). - CreateCategory(gomock.Any()). - Return(nil, errors.New("failed to create category")) - }, - expectedStatus: http.StatusInternalServerError, - expectedError: "Failed to create category", - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - // 设置 mock - tc.setupMock() - - // 创建请求 - req, _ := http.NewRequest(http.MethodPost, "/categories", nil) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Request = req - - // 执行请求 - s.handler.CreateCategory(c) - - // 验证响应 - s.Equal(tc.expectedStatus, w.Code) - if tc.expectedError != "" { - var response map[string]string - err := json.Unmarshal(w.Body.Bytes(), &response) - s.NoError(err) - s.Equal(tc.expectedError, response["error"]) - } else { - var response *ent.Category - err := json.Unmarshal(w.Body.Bytes(), &response) - s.NoError(err) - s.NotNil(response) - s.Equal(1, response.ID) - } - }) - } -} diff --git a/backend/internal/handler/contributor_handler_test.go b/backend/internal/handler/contributor_handler_test.go deleted file mode 100644 index 50f36da..0000000 --- a/backend/internal/handler/contributor_handler_test.go +++ /dev/null @@ -1,456 +0,0 @@ -package handler - -import ( - "bytes" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - "time" - "tss-rocks-be/ent" - "tss-rocks-be/internal/config" - "tss-rocks-be/internal/service" - "tss-rocks-be/internal/service/mock" - - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/suite" - "go.uber.org/mock/gomock" - - "errors" -) - -type ContributorHandlerTestSuite struct { - suite.Suite - ctrl *gomock.Controller - service *mock.MockService - handler *Handler - router *gin.Engine -} - -func (s *ContributorHandlerTestSuite) SetupTest() { - s.ctrl = gomock.NewController(s.T()) - s.service = mock.NewMockService(s.ctrl) - cfg := &config.Config{ - JWT: config.JWTConfig{ - Secret: "test-secret", - }, - } - s.handler = NewHandler(cfg, s.service) - - // Setup Gin router - gin.SetMode(gin.TestMode) - s.router = gin.New() - - // Setup mock for GetTokenBlacklist - tokenBlacklist := &service.TokenBlacklist{} - s.service.EXPECT(). - GetTokenBlacklist(). - Return(tokenBlacklist). - AnyTimes() - - s.handler.RegisterRoutes(s.router) -} - -func (s *ContributorHandlerTestSuite) TearDownTest() { - s.ctrl.Finish() -} - -func TestContributorHandlerSuite(t *testing.T) { - suite.Run(t, new(ContributorHandlerTestSuite)) -} - -func (s *ContributorHandlerTestSuite) TestListContributors() { - testCases := []struct { - name string - setupMock func() - expectedStatus int - expectedBody interface{} - }{ - { - name: "Success", - setupMock: func() { - s.service.EXPECT(). - ListContributors(gomock.Any()). - Return([]*ent.Contributor{ - { - ID: 1, - Name: "John Doe", - Edges: ent.ContributorEdges{ - SocialLinks: []*ent.ContributorSocialLink{ - { - Type: "github", - Value: "https://github.com/johndoe", - Edges: ent.ContributorSocialLinkEdges{}, - }, - }, - }, - CreatedAt: time.Time{}, - UpdatedAt: time.Time{}, - }, - { - ID: 2, - Name: "Jane Smith", - Edges: ent.ContributorEdges{ - SocialLinks: []*ent.ContributorSocialLink{}, // Ensure empty SocialLinks array is present - }, - CreatedAt: time.Time{}, - UpdatedAt: time.Time{}, - }, - }, nil) - }, - expectedStatus: http.StatusOK, - expectedBody: []gin.H{ - { - "id": 1, - "name": "John Doe", - "created_at": time.Time{}, - "updated_at": time.Time{}, - "edges": gin.H{ - "social_links": []gin.H{ - { - "type": "github", - "value": "https://github.com/johndoe", - "edges": gin.H{}, - }, - }, - }, - }, - { - "id": 2, - "name": "Jane Smith", - "created_at": time.Time{}, - "updated_at": time.Time{}, - "edges": gin.H{ - "social_links": []gin.H{}, // Ensure empty SocialLinks array is present - }, - }, - }, - }, - { - name: "Service error", - setupMock: func() { - s.service.EXPECT(). - ListContributors(gomock.Any()). - Return(nil, errors.New("service error")) - }, - expectedStatus: http.StatusInternalServerError, - expectedBody: gin.H{"error": "Failed to list contributors"}, - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - tc.setupMock() - - req := httptest.NewRequest(http.MethodGet, "/api/v1/contributors", nil) - w := httptest.NewRecorder() - s.router.ServeHTTP(w, req) - - s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch") - - if tc.expectedBody != nil { - expectedJSON, err := json.Marshal(tc.expectedBody) - s.NoError(err, "Failed to marshal expected body") - s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch") - } - }) - } -} - -func (s *ContributorHandlerTestSuite) TestGetContributor() { - testCases := []struct { - name string - id string - setupMock func() - expectedStatus int - expectedBody interface{} - }{ - { - name: "Success", - id: "1", - setupMock: func() { - s.service.EXPECT(). - GetContributorByID(gomock.Any(), 1). - Return(&ent.Contributor{ - ID: 1, - Name: "John Doe", - Edges: ent.ContributorEdges{ - SocialLinks: []*ent.ContributorSocialLink{ - { - Type: "github", - Value: "https://github.com/johndoe", - Edges: ent.ContributorSocialLinkEdges{}, - }, - }, - }, - CreatedAt: time.Time{}, - UpdatedAt: time.Time{}, - }, nil) - }, - expectedStatus: http.StatusOK, - expectedBody: gin.H{ - "id": 1, - "name": "John Doe", - "created_at": time.Time{}, - "updated_at": time.Time{}, - "edges": gin.H{ - "social_links": []gin.H{ - { - "type": "github", - "value": "https://github.com/johndoe", - "edges": gin.H{}, - }, - }, - }, - }, - }, - { - name: "Invalid ID", - id: "invalid", - setupMock: func() {}, - expectedStatus: http.StatusBadRequest, - expectedBody: gin.H{"error": "Invalid contributor ID"}, - }, - { - name: "Service error", - id: "1", - setupMock: func() { - s.service.EXPECT(). - GetContributorByID(gomock.Any(), 1). - Return(nil, errors.New("service error")) - }, - expectedStatus: http.StatusInternalServerError, - expectedBody: gin.H{"error": "Failed to get contributor"}, - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - tc.setupMock() - - req := httptest.NewRequest(http.MethodGet, "/api/v1/contributors/"+tc.id, nil) - w := httptest.NewRecorder() - s.router.ServeHTTP(w, req) - - s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch") - - if tc.expectedBody != nil { - expectedJSON, err := json.Marshal(tc.expectedBody) - s.NoError(err, "Failed to marshal expected body") - s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch") - } - }) - } -} - -func (s *ContributorHandlerTestSuite) TestCreateContributor() { - testCases := []struct { - name string - body interface{} - setupMock func() - expectedStatus int - expectedBody interface{} - }{ - { - name: "Success", - body: CreateContributorRequest{ - Name: "John Doe", - }, - setupMock: func() { - name := "John Doe" - s.service.EXPECT(). - CreateContributor( - gomock.Any(), - name, - nil, - nil, - ). - Return(&ent.Contributor{ - ID: 1, - Name: name, - CreatedAt: time.Time{}, - UpdatedAt: time.Time{}, - }, nil) - }, - expectedStatus: http.StatusCreated, - expectedBody: gin.H{ - "id": 1, - "name": "John Doe", - "created_at": time.Time{}, - "updated_at": time.Time{}, - "edges": gin.H{}, - }, - }, - { - name: "Invalid request body", - body: map[string]interface{}{ - "name": "", // Empty name is not allowed - }, - setupMock: func() {}, - expectedStatus: http.StatusBadRequest, - expectedBody: gin.H{"error": "Key: 'CreateContributorRequest.Name' Error:Field validation for 'Name' failed on the 'required' tag"}, - }, - { - name: "Service error", - body: CreateContributorRequest{ - Name: "John Doe", - }, - setupMock: func() { - name := "John Doe" - s.service.EXPECT(). - CreateContributor( - gomock.Any(), - name, - nil, - nil, - ). - Return(nil, errors.New("service error")) - }, - expectedStatus: http.StatusInternalServerError, - expectedBody: gin.H{"error": "Failed to create contributor"}, - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - tc.setupMock() - - body, err := json.Marshal(tc.body) - s.NoError(err, "Failed to marshal request body") - - req := httptest.NewRequest(http.MethodPost, "/api/v1/contributors", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - s.router.ServeHTTP(w, req) - - s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch") - - if tc.expectedBody != nil { - expectedJSON, err := json.Marshal(tc.expectedBody) - s.NoError(err, "Failed to marshal expected body") - s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch") - } - }) - } -} - -func (s *ContributorHandlerTestSuite) TestAddContributorSocialLink() { - testCases := []struct { - name string - id string - body interface{} - setupMock func() - expectedStatus int - expectedBody interface{} - }{ - { - name: "Success", - id: "1", - body: func() AddContributorSocialLinkRequest { - name := "johndoe" - return AddContributorSocialLinkRequest{ - Type: "github", - Name: &name, - Value: "https://github.com/johndoe", - } - }(), - setupMock: func() { - name := "johndoe" - s.service.EXPECT(). - AddContributorSocialLink( - gomock.Any(), - 1, - "github", - name, - "https://github.com/johndoe", - ). - Return(&ent.ContributorSocialLink{ - Type: "github", - Name: name, - Value: "https://github.com/johndoe", - Edges: ent.ContributorSocialLinkEdges{}, - }, nil) - }, - expectedStatus: http.StatusCreated, - expectedBody: gin.H{ - "type": "github", - "name": "johndoe", - "value": "https://github.com/johndoe", - "edges": gin.H{}, - }, - }, - { - name: "Invalid contributor ID", - id: "invalid", - body: func() AddContributorSocialLinkRequest { - name := "johndoe" - return AddContributorSocialLinkRequest{ - Type: "github", - Name: &name, - Value: "https://github.com/johndoe", - } - }(), - setupMock: func() {}, - expectedStatus: http.StatusBadRequest, - expectedBody: gin.H{"error": "Invalid contributor ID"}, - }, - { - name: "Invalid request body", - id: "1", - body: map[string]interface{}{ - "type": "", // Empty type is not allowed - }, - setupMock: func() {}, - expectedStatus: http.StatusBadRequest, - expectedBody: gin.H{"error": "Key: 'AddContributorSocialLinkRequest.Type' Error:Field validation for 'Type' failed on the 'required' tag\nKey: 'AddContributorSocialLinkRequest.Value' Error:Field validation for 'Value' failed on the 'required' tag"}, - }, - { - name: "Service error", - id: "1", - body: func() AddContributorSocialLinkRequest { - name := "johndoe" - return AddContributorSocialLinkRequest{ - Type: "github", - Name: &name, - Value: "https://github.com/johndoe", - } - }(), - setupMock: func() { - name := "johndoe" - s.service.EXPECT(). - AddContributorSocialLink( - gomock.Any(), - 1, - "github", - name, - "https://github.com/johndoe", - ). - Return(nil, errors.New("service error")) - }, - expectedStatus: http.StatusInternalServerError, - expectedBody: gin.H{"error": "Failed to add contributor social link"}, - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - tc.setupMock() - - body, err := json.Marshal(tc.body) - s.NoError(err, "Failed to marshal request body") - - req := httptest.NewRequest(http.MethodPost, "/api/v1/contributors/"+tc.id+"/social-links", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - s.router.ServeHTTP(w, req) - - s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch") - - if tc.expectedBody != nil { - expectedJSON, err := json.Marshal(tc.expectedBody) - s.NoError(err, "Failed to marshal expected body") - s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch") - } - }) - } -} diff --git a/backend/internal/handler/daily_handler_test.go b/backend/internal/handler/daily_handler_test.go deleted file mode 100644 index 24c94cc..0000000 --- a/backend/internal/handler/daily_handler_test.go +++ /dev/null @@ -1,532 +0,0 @@ -package handler - -import ( - "bytes" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - "tss-rocks-be/ent" - "tss-rocks-be/internal/config" - "tss-rocks-be/internal/service" - "tss-rocks-be/internal/service/mock" - - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/suite" - "go.uber.org/mock/gomock" - - "errors" - "strings" -) - -type DailyHandlerTestSuite struct { - suite.Suite - ctrl *gomock.Controller - service *mock.MockService - handler *Handler - router *gin.Engine -} - -func (s *DailyHandlerTestSuite) SetupTest() { - s.ctrl = gomock.NewController(s.T()) - s.service = mock.NewMockService(s.ctrl) - cfg := &config.Config{ - JWT: config.JWTConfig{ - Secret: "test-secret", - }, - } - s.handler = NewHandler(cfg, s.service) - - // Setup Gin router - gin.SetMode(gin.TestMode) - s.router = gin.New() - - // Setup mock for GetTokenBlacklist - tokenBlacklist := &service.TokenBlacklist{} - s.service.EXPECT(). - GetTokenBlacklist(). - Return(tokenBlacklist). - AnyTimes() - - s.handler.RegisterRoutes(s.router) -} - -func (s *DailyHandlerTestSuite) TearDownTest() { - s.ctrl.Finish() -} - -func TestDailyHandlerSuite(t *testing.T) { - suite.Run(t, new(DailyHandlerTestSuite)) -} - -func (s *DailyHandlerTestSuite) TestListDailies() { - testCases := []struct { - name string - langCode string - categoryID string - limit string - offset string - setupMock func() - expectedStatus int - expectedBody interface{} - }{ - { - name: "Success with default language", - langCode: "", - setupMock: func() { - s.service.EXPECT(). - ListDailies(gomock.Any(), "en", nil, 10, 0). - Return([]*ent.Daily{ - { - ID: "daily1", - ImageURL: "https://example.com/image1.jpg", - Edges: ent.DailyEdges{ - Category: &ent.Category{ID: 1}, - Contents: []*ent.DailyContent{ - { - LanguageCode: "en", - Quote: "Test Quote 1", - }, - }, - }, - }, - }, nil) - }, - expectedStatus: http.StatusOK, - expectedBody: []*ent.Daily{ - { - ID: "daily1", - ImageURL: "https://example.com/image1.jpg", - Edges: ent.DailyEdges{ - Category: &ent.Category{ID: 1}, - Contents: []*ent.DailyContent{ - { - LanguageCode: "en", - Quote: "Test Quote 1", - }, - }, - }, - }, - }, - }, - { - name: "Success with specific language", - langCode: "zh", - setupMock: func() { - s.service.EXPECT(). - ListDailies(gomock.Any(), "zh", nil, 10, 0). - Return([]*ent.Daily{ - { - ID: "daily1", - ImageURL: "https://example.com/image1.jpg", - Edges: ent.DailyEdges{ - Category: &ent.Category{ID: 1}, - Contents: []*ent.DailyContent{ - { - LanguageCode: "zh", - Quote: "测试语录1", - }, - }, - }, - }, - }, nil) - }, - expectedStatus: http.StatusOK, - expectedBody: []*ent.Daily{ - { - ID: "daily1", - ImageURL: "https://example.com/image1.jpg", - Edges: ent.DailyEdges{ - Category: &ent.Category{ID: 1}, - Contents: []*ent.DailyContent{ - { - LanguageCode: "zh", - Quote: "测试语录1", - }, - }, - }, - }, - }, - }, - { - name: "Success with category filter", - categoryID: "1", - setupMock: func() { - categoryID := 1 - s.service.EXPECT(). - ListDailies(gomock.Any(), "en", &categoryID, 10, 0). - Return([]*ent.Daily{ - { - ID: "daily1", - ImageURL: "https://example.com/image1.jpg", - Edges: ent.DailyEdges{ - Category: &ent.Category{ID: 1}, - Contents: []*ent.DailyContent{ - { - LanguageCode: "en", - Quote: "Test Quote 1", - }, - }, - }, - }, - }, nil) - }, - expectedStatus: http.StatusOK, - expectedBody: []*ent.Daily{ - { - ID: "daily1", - ImageURL: "https://example.com/image1.jpg", - Edges: ent.DailyEdges{ - Category: &ent.Category{ID: 1}, - Contents: []*ent.DailyContent{ - { - LanguageCode: "en", - Quote: "Test Quote 1", - }, - }, - }, - }, - }, - }, - { - name: "Success with pagination", - limit: "2", - offset: "1", - setupMock: func() { - s.service.EXPECT(). - ListDailies(gomock.Any(), "en", nil, 2, 1). - Return([]*ent.Daily{ - { - ID: "daily2", - ImageURL: "https://example.com/image2.jpg", - Edges: ent.DailyEdges{ - Category: &ent.Category{ID: 1}, - Contents: []*ent.DailyContent{ - { - LanguageCode: "en", - Quote: "Test Quote 2", - }, - }, - }, - }, - }, nil) - }, - expectedStatus: http.StatusOK, - expectedBody: []*ent.Daily{ - { - ID: "daily2", - ImageURL: "https://example.com/image2.jpg", - Edges: ent.DailyEdges{ - Category: &ent.Category{ID: 1}, - Contents: []*ent.DailyContent{ - { - LanguageCode: "en", - Quote: "Test Quote 2", - }, - }, - }, - }, - }, - }, - { - name: "Service Error", - setupMock: func() { - s.service.EXPECT(). - ListDailies(gomock.Any(), "en", nil, 10, 0). - Return(nil, errors.New("service error")) - }, - expectedStatus: http.StatusInternalServerError, - expectedBody: gin.H{"error": "Failed to list dailies"}, - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - tc.setupMock() - - url := "/api/v1/dailies" - if tc.langCode != "" { - url += "?lang=" + tc.langCode - } - if tc.categoryID != "" { - if strings.Contains(url, "?") { - url += "&" - } else { - url += "?" - } - url += "category_id=" + tc.categoryID - } - if tc.limit != "" { - if strings.Contains(url, "?") { - url += "&" - } else { - url += "?" - } - url += "limit=" + tc.limit - } - if tc.offset != "" { - if strings.Contains(url, "?") { - url += "&" - } else { - url += "?" - } - url += "offset=" + tc.offset - } - - req := httptest.NewRequest(http.MethodGet, url, nil) - w := httptest.NewRecorder() - s.router.ServeHTTP(w, req) - - s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch") - - if tc.expectedBody != nil { - expectedJSON, err := json.Marshal(tc.expectedBody) - s.NoError(err, "Failed to marshal expected body") - s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch") - } - }) - } -} - -func (s *DailyHandlerTestSuite) TestGetDaily() { - testCases := []struct { - name string - id string - setupMock func() - expectedStatus int - expectedBody interface{} - }{ - { - name: "Success", - id: "daily1", - setupMock: func() { - s.service.EXPECT(). - GetDailyByID(gomock.Any(), "daily1"). - Return(&ent.Daily{ - ID: "daily1", - ImageURL: "https://example.com/image1.jpg", - Edges: ent.DailyEdges{ - Category: &ent.Category{ID: 1}, - Contents: []*ent.DailyContent{ - { - LanguageCode: "en", - Quote: "Test Quote 1", - }, - }, - }, - }, nil) - }, - expectedStatus: http.StatusOK, - expectedBody: &ent.Daily{ - ID: "daily1", - ImageURL: "https://example.com/image1.jpg", - Edges: ent.DailyEdges{ - Category: &ent.Category{ID: 1}, - Contents: []*ent.DailyContent{ - { - LanguageCode: "en", - Quote: "Test Quote 1", - }, - }, - }, - }, - }, - { - name: "Service error", - id: "daily1", - setupMock: func() { - s.service.EXPECT(). - GetDailyByID(gomock.Any(), "daily1"). - Return(nil, errors.New("service error")) - }, - expectedStatus: http.StatusInternalServerError, - expectedBody: gin.H{"error": "Failed to get daily"}, - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - tc.setupMock() - - req := httptest.NewRequest(http.MethodGet, "/api/v1/dailies/"+tc.id, nil) - w := httptest.NewRecorder() - s.router.ServeHTTP(w, req) - - s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch") - - if tc.expectedBody != nil { - expectedJSON, err := json.Marshal(tc.expectedBody) - s.NoError(err, "Failed to marshal expected body") - s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch") - } - }) - } -} - -func (s *DailyHandlerTestSuite) TestCreateDaily() { - testCases := []struct { - name string - body interface{} - setupMock func() - expectedStatus int - expectedBody interface{} - }{ - { - name: "Success", - body: CreateDailyRequest{ - ID: "daily1", - CategoryID: 1, - ImageURL: "https://example.com/image1.jpg", - }, - setupMock: func() { - s.service.EXPECT(). - CreateDaily(gomock.Any(), "daily1", 1, "https://example.com/image1.jpg"). - Return(&ent.Daily{ - ID: "daily1", - ImageURL: "https://example.com/image1.jpg", - Edges: ent.DailyEdges{ - Category: &ent.Category{ID: 1}, - Contents: []*ent.DailyContent{}, - }, - }, nil) - }, - expectedStatus: http.StatusCreated, - expectedBody: &ent.Daily{ - ID: "daily1", - ImageURL: "https://example.com/image1.jpg", - Edges: ent.DailyEdges{ - Category: &ent.Category{ID: 1}, - Contents: []*ent.DailyContent{}, - }, - }, - }, - { - name: "Invalid request body", - body: map[string]interface{}{ - "id": "daily1", - // Missing required fields - }, - setupMock: func() {}, - expectedStatus: http.StatusBadRequest, - expectedBody: gin.H{"error": "Key: 'CreateDailyRequest.CategoryID' Error:Field validation for 'CategoryID' failed on the 'required' tag\nKey: 'CreateDailyRequest.ImageURL' Error:Field validation for 'ImageURL' failed on the 'required' tag"}, - }, - { - name: "Service error", - body: CreateDailyRequest{ - ID: "daily1", - CategoryID: 1, - ImageURL: "https://example.com/image1.jpg", - }, - setupMock: func() { - s.service.EXPECT(). - CreateDaily(gomock.Any(), "daily1", 1, "https://example.com/image1.jpg"). - Return(nil, errors.New("service error")) - }, - expectedStatus: http.StatusInternalServerError, - expectedBody: gin.H{"error": "Failed to create daily"}, - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - tc.setupMock() - - body, err := json.Marshal(tc.body) - s.NoError(err, "Failed to marshal request body") - - req := httptest.NewRequest(http.MethodPost, "/api/v1/dailies", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - s.router.ServeHTTP(w, req) - - s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch") - - if tc.expectedBody != nil { - expectedJSON, err := json.Marshal(tc.expectedBody) - s.NoError(err, "Failed to marshal expected body") - s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch") - } - }) - } -} - -func (s *DailyHandlerTestSuite) TestAddDailyContent() { - testCases := []struct { - name string - dailyID string - body interface{} - setupMock func() - expectedStatus int - expectedBody interface{} - }{ - { - name: "Success", - dailyID: "daily1", - body: AddDailyContentRequest{ - LanguageCode: "en", - Quote: "Test Quote 1", - }, - setupMock: func() { - s.service.EXPECT(). - AddDailyContent(gomock.Any(), "daily1", "en", "Test Quote 1"). - Return(&ent.DailyContent{ - LanguageCode: "en", - Quote: "Test Quote 1", - }, nil) - }, - expectedStatus: http.StatusCreated, - expectedBody: &ent.DailyContent{ - LanguageCode: "en", - Quote: "Test Quote 1", - }, - }, - { - name: "Invalid request body", - dailyID: "daily1", - body: map[string]interface{}{ - "language_code": "en", - // Missing required fields - }, - setupMock: func() {}, - expectedStatus: http.StatusBadRequest, - expectedBody: gin.H{"error": "Key: 'AddDailyContentRequest.Quote' Error:Field validation for 'Quote' failed on the 'required' tag"}, - }, - { - name: "Service error", - dailyID: "daily1", - body: AddDailyContentRequest{ - LanguageCode: "en", - Quote: "Test Quote 1", - }, - setupMock: func() { - s.service.EXPECT(). - AddDailyContent(gomock.Any(), "daily1", "en", "Test Quote 1"). - Return(nil, errors.New("service error")) - }, - expectedStatus: http.StatusInternalServerError, - expectedBody: gin.H{"error": "Failed to add daily content"}, - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - tc.setupMock() - - body, err := json.Marshal(tc.body) - s.NoError(err, "Failed to marshal request body") - - req := httptest.NewRequest(http.MethodPost, "/api/v1/dailies/"+tc.dailyID+"/contents", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - s.router.ServeHTTP(w, req) - - s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch") - - if tc.expectedBody != nil { - expectedJSON, err := json.Marshal(tc.expectedBody) - s.NoError(err, "Failed to marshal expected body") - s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch") - } - }) - } -} diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 5afcaf0..11c5b3e 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -25,8 +25,9 @@ func NewHandler(cfg *config.Config, service service.Service) *Handler { } // RegisterRoutes registers all the routes -func (h *Handler) RegisterRoutes(r *gin.Engine) { - api := r.Group("/api/v1") +func (h *Handler) RegisterRoutes(router *gin.Engine) { + // API routes + api := router.Group("/api/v1") { // Auth routes auth := api.Group("/auth") @@ -93,6 +94,9 @@ func (h *Handler) RegisterRoutes(r *gin.Engine) { media.DELETE("/:id", h.DeleteMedia) } } + + // Public media files + router.GET("/media/:year/:month/:filename", h.GetMediaFile) } // Category handlers @@ -186,10 +190,12 @@ func (h *Handler) ListPosts(c *gin.Context) { langCode = "en" // Default to English } - var categoryID *int - if catIDStr := c.Query("category_id"); catIDStr != "" { - if id, err := strconv.Atoi(catIDStr); err == nil { - categoryID = &id + var categoryIDs []int + if catIDsStr := c.QueryArray("category_ids"); len(catIDsStr) > 0 { + for _, idStr := range catIDsStr { + if id, err := strconv.Atoi(idStr); err == nil { + categoryIDs = append(categoryIDs, id) + } } } @@ -207,7 +213,7 @@ func (h *Handler) ListPosts(c *gin.Context) { } } - posts, err := h.service.ListPosts(c.Request.Context(), langCode, categoryID, limit, offset) + posts, err := h.service.ListPosts(c.Request.Context(), langCode, categoryIDs, limit, offset) if err != nil { log.Error().Err(err).Msg("Failed to list posts") c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list posts"}) @@ -244,10 +250,10 @@ func (h *Handler) GetPost(c *gin.Context) { contents := make([]gin.H, 0, len(post.Edges.Contents)) for _, content := range post.Edges.Contents { contents = append(contents, gin.H{ - "language_code": content.LanguageCode, - "title": content.Title, + "language_code": content.LanguageCode, + "title": content.Title, "content_markdown": content.ContentMarkdown, - "summary": content.Summary, + "summary": content.Summary, }) } response["edges"].(gin.H)["contents"] = contents @@ -256,7 +262,22 @@ func (h *Handler) GetPost(c *gin.Context) { } func (h *Handler) CreatePost(c *gin.Context) { - post, err := h.service.CreatePost(c.Request.Context(), "draft") // Default to draft status + var req struct { + Status string `json:"status" binding:"omitempty,oneof=draft published archived"` + CategoryIDs []int `json:"category_ids" binding:"required,min=1"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + status := req.Status + if status == "" { + status = "draft" // Default to draft status + } + + post, err := h.service.CreatePost(c.Request.Context(), status, req.CategoryIDs) if err != nil { log.Error().Err(err).Msg("Failed to create post") c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create post"}) @@ -268,7 +289,8 @@ func (h *Handler) CreatePost(c *gin.Context) { "id": post.ID, "status": post.Status, "edges": gin.H{ - "contents": []interface{}{}, + "contents": []interface{}{}, + "categories": []interface{}{}, }, } @@ -276,7 +298,7 @@ func (h *Handler) CreatePost(c *gin.Context) { } type AddPostContentRequest struct { - LanguageCode string `json:"language_code" binding:"required"` + LanguageCode string `json:"language_code" binding:"required"` Title string `json:"title" binding:"required"` ContentMarkdown string `json:"content_markdown" binding:"required"` Summary string `json:"summary" binding:"required"` @@ -308,10 +330,10 @@ func (h *Handler) AddPostContent(c *gin.Context) { "title": content.Title, "content_markdown": content.ContentMarkdown, "language_code": content.LanguageCode, - "summary": content.Summary, + "summary": content.Summary, "meta_keywords": content.MetaKeywords, "meta_description": content.MetaDescription, - "edges": gin.H{}, + "edges": gin.H{}, }) } @@ -517,11 +539,3 @@ func (h *Handler) AddDailyContent(c *gin.Context) { c.JSON(http.StatusCreated, content) } - -// Helper functions -func stringPtr(s *string) string { - if s == nil { - return "" - } - return *s -} diff --git a/backend/internal/handler/handler_test.go b/backend/internal/handler/handler_test.go deleted file mode 100644 index a9d3b9f..0000000 --- a/backend/internal/handler/handler_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package handler - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestStringPtr(t *testing.T) { - testCases := []struct { - name string - input *string - expected string - }{ - { - name: "nil pointer", - input: nil, - expected: "", - }, - { - name: "empty string", - input: strPtr(""), - expected: "", - }, - { - name: "non-empty string", - input: strPtr("test"), - expected: "test", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - result := stringPtr(tc.input) - assert.Equal(t, tc.expected, result) - }) - } -} - -// Helper function to create string pointer -func strPtr(s string) *string { - return &s -} diff --git a/backend/internal/handler/media.go b/backend/internal/handler/media.go index 68d4a4b..1ad429b 100644 --- a/backend/internal/handler/media.go +++ b/backend/internal/handler/media.go @@ -5,9 +5,11 @@ import ( "io" "net/http" "strconv" + "strings" "github.com/gin-gonic/gin" "github.com/rs/zerolog/log" + "path/filepath" ) // Media handlers @@ -33,17 +35,29 @@ func (h *Handler) ListMedia(c *gin.Context) { return } - c.JSON(http.StatusOK, media) + c.JSON(http.StatusOK, gin.H{ + "data": media, + }) } func (h *Handler) UploadMedia(c *gin.Context) { // Get user ID from context (set by auth middleware) - userID, exists := c.Get("user_id") + userIDStr, exists := c.Get("user_id") if !exists { c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) return } + // Convert user ID to int + userID, err := strconv.Atoi(userIDStr.(string)) + if err != nil { + log.Error().Err(err). + Str("user_id", fmt.Sprintf("%v", userIDStr)). + Msg("Failed to convert user ID to int") + c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"}) + return + } + // Get file from form file, err := c.FormFile("file") if err != nil { @@ -51,38 +65,107 @@ func (h *Handler) UploadMedia(c *gin.Context) { return } - // 文件大小限制 - if file.Size > 10*1024*1024 { // 10MB - c.JSON(http.StatusBadRequest, gin.H{"error": "File size exceeds the limit (10MB)"}) + // 获取文件类型和扩展名 + contentType := file.Header.Get("Content-Type") + ext := strings.ToLower(filepath.Ext(file.Filename)) + if contentType == "" { + // 如果 Content-Type 为空,尝试从文件扩展名判断 + switch ext { + case ".jpg", ".jpeg": + contentType = "image/jpeg" + case ".png": + contentType = "image/png" + case ".gif": + contentType = "image/gif" + case ".webp": + contentType = "image/webp" + case ".mp4": + contentType = "video/mp4" + case ".webm": + contentType = "video/webm" + case ".mp3": + contentType = "audio/mpeg" + case ".ogg": + contentType = "audio/ogg" + case ".wav": + contentType = "audio/wav" + case ".pdf": + contentType = "application/pdf" + case ".doc": + contentType = "application/msword" + case ".docx": + contentType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + } + } + + // 根据 Content-Type 确定文件类型和限制 + var maxSize int64 + var allowedTypes []string + var fileType string + + limits := h.cfg.Storage.Upload.Limits + switch { + case strings.HasPrefix(contentType, "image/"): + maxSize = int64(limits.Image.MaxSize) * 1024 * 1024 + allowedTypes = limits.Image.AllowedTypes + fileType = "image" + case strings.HasPrefix(contentType, "video/"): + maxSize = int64(limits.Video.MaxSize) * 1024 * 1024 + allowedTypes = limits.Video.AllowedTypes + fileType = "video" + case strings.HasPrefix(contentType, "audio/"): + maxSize = int64(limits.Audio.MaxSize) * 1024 * 1024 + allowedTypes = limits.Audio.AllowedTypes + fileType = "audio" + case strings.HasPrefix(contentType, "application/"): + maxSize = int64(limits.Document.MaxSize) * 1024 * 1024 + allowedTypes = limits.Document.AllowedTypes + fileType = "document" + default: + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Unsupported file type", + }) return } - // 文件类型限制 - allowedTypes := map[string]bool{ - "image/jpeg": true, - "image/png": true, - "image/gif": true, - "video/mp4": true, - "video/webm": true, - "audio/mpeg": true, - "audio/ogg": true, - "application/pdf": true, + // 检查文件类型是否允许 + typeAllowed := false + for _, allowed := range allowedTypes { + if contentType == allowed { + typeAllowed = true + break + } } - contentType := file.Header.Get("Content-Type") - if _, ok := allowedTypes[contentType]; !ok { - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid file type"}) + if !typeAllowed { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Unsupported %s type: %s", fileType, contentType), + }) + return + } + + // 检查文件大小 + if file.Size > maxSize { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("File size exceeds the limit (%d MB) for %s files", limits.Image.MaxSize, fileType), + }) return } // Upload file - media, err := h.service.Upload(c.Request.Context(), file, userID.(int)) + media, err := h.service.Upload(c.Request.Context(), file, userID) if err != nil { - log.Error().Err(err).Msg("Failed to upload media") + log.Error().Err(err). + Str("filename", file.Filename). + Str("content_type", contentType). + Int("user_id", userID). + Msg("Failed to upload media") c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to upload media"}) return } - c.JSON(http.StatusCreated, media) + c.JSON(http.StatusCreated, gin.H{ + "data": media, + }) } func (h *Handler) GetMedia(c *gin.Context) { @@ -101,7 +184,7 @@ func (h *Handler) GetMedia(c *gin.Context) { } // Get file content - reader, info, err := h.service.GetFile(c.Request.Context(), id) + reader, info, err := h.service.GetFile(c.Request.Context(), media.StorageID) if err != nil { log.Error().Err(err).Msg("Failed to get media file") c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get media file"}) @@ -122,16 +205,18 @@ func (h *Handler) GetMedia(c *gin.Context) { } func (h *Handler) GetMediaFile(c *gin.Context) { - id, err := strconv.Atoi(c.Param("id")) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid media ID"}) - return - } + year := c.Param("year") + month := c.Param("month") + filename := c.Param("filename") // Get file content - reader, info, err := h.service.GetFile(c.Request.Context(), id) + reader, info, err := h.service.GetFile(c.Request.Context(), filename) // 直接使用完整的文件名 if err != nil { - log.Error().Err(err).Msg("Failed to get media file") + log.Error().Err(err). + Str("year", year). + Str("month", month). + Str("filename", filename). + Msg("Failed to get media file") c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get media file"}) return } @@ -151,20 +236,33 @@ func (h *Handler) GetMediaFile(c *gin.Context) { func (h *Handler) DeleteMedia(c *gin.Context) { // Get user ID from context (set by auth middleware) - userID, exists := c.Get("user_id") + userIDStr, exists := c.Get("user_id") if !exists { c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) return } + // Convert user ID to int + userID, err := strconv.Atoi(userIDStr.(string)) + if err != nil { + log.Error().Err(err). + Str("user_id", fmt.Sprintf("%v", userIDStr)). + Msg("Failed to convert user ID to int") + c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"}) + return + } + id, err := strconv.Atoi(c.Param("id")) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid media ID"}) return } - if err := h.service.DeleteMedia(c.Request.Context(), id, userID.(int)); err != nil { - log.Error().Err(err).Msg("Failed to delete media") + if err := h.service.DeleteMedia(c.Request.Context(), id, userID); err != nil { + log.Error().Err(err). + Int("media_id", id). + Int("user_id", userID). + Msg("Failed to delete media") c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete media"}) return } diff --git a/backend/internal/handler/media_handler_test.go b/backend/internal/handler/media_handler_test.go deleted file mode 100644 index 5452129..0000000 --- a/backend/internal/handler/media_handler_test.go +++ /dev/null @@ -1,524 +0,0 @@ -package handler - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "mime/multipart" - "net/http" - "net/http/httptest" - "strings" - "testing" - "tss-rocks-be/ent" - "tss-rocks-be/internal/config" - "tss-rocks-be/internal/service/mock" - "tss-rocks-be/internal/storage" - - "net/textproto" - - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/suite" - "go.uber.org/mock/gomock" -) - -type MediaHandlerTestSuite struct { - suite.Suite - ctrl *gomock.Controller - service *mock.MockService - handler *Handler - router *gin.Engine -} - -func (s *MediaHandlerTestSuite) SetupTest() { - s.ctrl = gomock.NewController(s.T()) - s.service = mock.NewMockService(s.ctrl) - s.handler = NewHandler(&config.Config{}, s.service) - s.router = gin.New() -} - -func (s *MediaHandlerTestSuite) TearDownTest() { - s.ctrl.Finish() -} - -func TestMediaHandlerSuite(t *testing.T) { - suite.Run(t, new(MediaHandlerTestSuite)) -} - -func (s *MediaHandlerTestSuite) TestListMedia() { - testCases := []struct { - name string - query string - setupMock func() - expectedStatus int - expectedError string - }{ - { - name: "成功列出媒体", - query: "?limit=10&offset=0", - setupMock: func() { - s.service.EXPECT(). - ListMedia(gomock.Any(), 10, 0). - Return([]*ent.Media{{ID: 1}}, nil) - }, - expectedStatus: http.StatusOK, - }, - { - name: "使用默认限制和偏移", - query: "", - setupMock: func() { - s.service.EXPECT(). - ListMedia(gomock.Any(), 10, 0). - Return([]*ent.Media{{ID: 1}}, nil) - }, - expectedStatus: http.StatusOK, - }, - { - name: "列出媒体失败", - query: "", - setupMock: func() { - s.service.EXPECT(). - ListMedia(gomock.Any(), 10, 0). - Return(nil, errors.New("failed to list media")) - }, - expectedStatus: http.StatusInternalServerError, - expectedError: "Failed to list media", - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - // 设置 mock - tc.setupMock() - - // 创建请求 - req, _ := http.NewRequest(http.MethodGet, "/media"+tc.query, nil) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Request = req - - // 执行请求 - s.handler.ListMedia(c) - - // 验证响应 - s.Equal(tc.expectedStatus, w.Code) - if tc.expectedError != "" { - var response map[string]string - err := json.Unmarshal(w.Body.Bytes(), &response) - s.NoError(err) - s.Equal(tc.expectedError, response["error"]) - } else { - var response []*ent.Media - err := json.Unmarshal(w.Body.Bytes(), &response) - s.NoError(err) - s.NotEmpty(response) - } - }) - } -} - -func (s *MediaHandlerTestSuite) TestUploadMedia() { - testCases := []struct { - name string - setupRequest func() (*http.Request, error) - setupMock func() - expectedStatus int - expectedError string - }{ - { - name: "成功上传媒体", - setupRequest: func() (*http.Request, error) { - body := &bytes.Buffer{} - writer := multipart.NewWriter(body) - - // 创建文件部分 - fileHeader := make(textproto.MIMEHeader) - fileHeader.Set("Content-Type", "image/jpeg") - fileHeader.Set("Content-Disposition", `form-data; name="file"; filename="test.jpg"`) - part, err := writer.CreatePart(fileHeader) - if err != nil { - return nil, err - } - testContent := "test content" - _, err = io.Copy(part, strings.NewReader(testContent)) - if err != nil { - return nil, err - } - writer.Close() - - req := httptest.NewRequest(http.MethodPost, "/media", body) - req.Header.Set("Content-Type", writer.FormDataContentType()) - return req, nil - }, - setupMock: func() { - expectedFile := &multipart.FileHeader{ - Filename: "test.jpg", - Size: int64(len("test content")), - Header: textproto.MIMEHeader{ - "Content-Type": []string{"image/jpeg"}, - }, - } - s.service.EXPECT(). - Upload(gomock.Any(), gomock.Any(), 1). - DoAndReturn(func(_ context.Context, f *multipart.FileHeader, uid int) (*ent.Media, error) { - s.Equal(expectedFile.Filename, f.Filename) - s.Equal(expectedFile.Size, f.Size) - s.Equal(expectedFile.Header.Get("Content-Type"), f.Header.Get("Content-Type")) - return &ent.Media{ID: 1}, nil - }) - }, - expectedStatus: http.StatusCreated, - }, - { - name: "未授权", - setupRequest: func() (*http.Request, error) { - req := httptest.NewRequest(http.MethodPost, "/media", nil) - return req, nil - }, - setupMock: func() {}, - expectedStatus: http.StatusUnauthorized, - expectedError: "Unauthorized", - }, - { - name: "上传失败", - setupRequest: func() (*http.Request, error) { - body := &bytes.Buffer{} - writer := multipart.NewWriter(body) - - // 创建文件部分 - fileHeader := make(textproto.MIMEHeader) - fileHeader.Set("Content-Type", "image/jpeg") - fileHeader.Set("Content-Disposition", `form-data; name="file"; filename="test.jpg"`) - part, err := writer.CreatePart(fileHeader) - if err != nil { - return nil, err - } - testContent := "test content" - _, err = io.Copy(part, strings.NewReader(testContent)) - if err != nil { - return nil, err - } - writer.Close() - - req := httptest.NewRequest(http.MethodPost, "/media", body) - req.Header.Set("Content-Type", writer.FormDataContentType()) - return req, nil - }, - setupMock: func() { - s.service.EXPECT(). - Upload(gomock.Any(), gomock.Any(), 1). - Return(nil, errors.New("failed to upload")) - }, - expectedStatus: http.StatusInternalServerError, - expectedError: "Failed to upload media", - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - // 设置 mock - tc.setupMock() - - // 创建请求 - req, err := tc.setupRequest() - s.NoError(err) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Request = req - - // 设置用户ID(除了未授权的测试用例) - if tc.expectedError != "Unauthorized" { - c.Set("user_id", 1) - } - - // 执行请求 - s.handler.UploadMedia(c) - - // 验证响应 - s.Equal(tc.expectedStatus, w.Code) - if tc.expectedError != "" { - var response map[string]string - err := json.Unmarshal(w.Body.Bytes(), &response) - s.NoError(err) - s.Equal(tc.expectedError, response["error"]) - } else { - var response *ent.Media - err := json.Unmarshal(w.Body.Bytes(), &response) - s.NoError(err) - s.NotNil(response) - } - }) - } -} - -func (s *MediaHandlerTestSuite) TestGetMedia() { - testCases := []struct { - name string - mediaID string - setupMock func() - expectedStatus int - expectedError string - }{ - { - name: "成功获取媒体", - mediaID: "1", - setupMock: func() { - media := &ent.Media{ - ID: 1, - MimeType: "image/jpeg", - OriginalName: "test.jpg", - } - s.service.EXPECT(). - GetMedia(gomock.Any(), 1). - Return(media, nil) - s.service.EXPECT(). - GetFile(gomock.Any(), 1). - Return(io.NopCloser(strings.NewReader("test content")), &storage.FileInfo{ - Size: 11, - Name: "test.jpg", - ContentType: "image/jpeg", - }, nil) - }, - expectedStatus: http.StatusOK, - }, - { - name: "无效的媒体ID", - mediaID: "invalid", - setupMock: func() {}, - expectedStatus: http.StatusBadRequest, - expectedError: "Invalid media ID", - }, - { - name: "获取媒体元数据失败", - mediaID: "1", - setupMock: func() { - s.service.EXPECT(). - GetMedia(gomock.Any(), 1). - Return(nil, errors.New("failed to get media")) - }, - expectedStatus: http.StatusInternalServerError, - expectedError: "Failed to get media", - }, - { - name: "获取媒体文件失败", - mediaID: "1", - setupMock: func() { - media := &ent.Media{ - ID: 1, - MimeType: "image/jpeg", - OriginalName: "test.jpg", - } - s.service.EXPECT(). - GetMedia(gomock.Any(), 1). - Return(media, nil) - s.service.EXPECT(). - GetFile(gomock.Any(), 1). - Return(nil, nil, errors.New("failed to get file")) - }, - expectedStatus: http.StatusInternalServerError, - expectedError: "Failed to get media file", - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - // 设置 mock - tc.setupMock() - - // 创建请求 - req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/media/%s", tc.mediaID), nil) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Request = req - - // Extract ID from URL path - parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/") - if len(parts) >= 2 { - c.Params = []gin.Param{{Key: "id", Value: parts[1]}} - } - - // 执行请求 - s.handler.GetMedia(c) - - // 验证响应 - s.Equal(tc.expectedStatus, w.Code) - if tc.expectedError != "" { - var response map[string]string - err := json.Unmarshal(w.Body.Bytes(), &response) - s.NoError(err) - s.Equal(tc.expectedError, response["error"]) - } else { - s.Equal("image/jpeg", w.Header().Get("Content-Type")) - s.Equal("11", w.Header().Get("Content-Length")) - s.Equal("inline; filename=test.jpg", w.Header().Get("Content-Disposition")) - s.Equal("test content", w.Body.String()) - } - }) - } -} - -func (s *MediaHandlerTestSuite) TestGetMediaFile() { - testCases := []struct { - name string - setupRequest func() (*http.Request, error) - setupMock func() - expectedStatus int - expectedBody []byte - }{ - { - name: "成功获取媒体文件", - setupRequest: func() (*http.Request, error) { - return httptest.NewRequest(http.MethodGet, "/media/1/file", nil), nil - }, - setupMock: func() { - fileContent := "test file content" - s.service.EXPECT(). - GetFile(gomock.Any(), 1). - Return(io.NopCloser(strings.NewReader(fileContent)), &storage.FileInfo{ - Name: "test.jpg", - Size: int64(len(fileContent)), - ContentType: "image/jpeg", - }, nil) - }, - expectedStatus: http.StatusOK, - expectedBody: []byte("test file content"), - }, - { - name: "无效的媒体ID", - setupRequest: func() (*http.Request, error) { - return httptest.NewRequest(http.MethodGet, "/media/invalid/file", nil), nil - }, - setupMock: func() {}, - expectedStatus: http.StatusBadRequest, - }, - { - name: "获取媒体文件失败", - setupRequest: func() (*http.Request, error) { - return httptest.NewRequest(http.MethodGet, "/media/1/file", nil), nil - }, - setupMock: func() { - s.service.EXPECT(). - GetFile(gomock.Any(), 1). - Return(nil, nil, errors.New("failed to get file")) - }, - expectedStatus: http.StatusInternalServerError, - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - // Setup - req, err := tc.setupRequest() - s.Require().NoError(err) - - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Request = req - - // Extract ID from URL path - parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/") - if len(parts) >= 2 { - c.Params = []gin.Param{{Key: "id", Value: parts[1]}} - } - - // Setup mock - tc.setupMock() - - // Test - s.handler.GetMediaFile(c) - - // Verify - s.Equal(tc.expectedStatus, w.Code) - if tc.expectedBody != nil { - s.Equal(tc.expectedBody, w.Body.Bytes()) - s.Equal("image/jpeg", w.Header().Get("Content-Type")) - s.Equal(fmt.Sprintf("%d", len(tc.expectedBody)), w.Header().Get("Content-Length")) - s.Equal("inline; filename=test.jpg", w.Header().Get("Content-Disposition")) - } - }) - } -} - -func (s *MediaHandlerTestSuite) TestDeleteMedia() { - testCases := []struct { - name string - mediaID string - setupMock func() - expectedStatus int - expectedError string - }{ - { - name: "成功删除媒体", - mediaID: "1", - setupMock: func() { - s.service.EXPECT(). - DeleteMedia(gomock.Any(), 1, 1). - Return(nil) - }, - expectedStatus: http.StatusNoContent, - }, - { - name: "未授权", - mediaID: "1", - setupMock: func() {}, - expectedStatus: http.StatusUnauthorized, - expectedError: "Unauthorized", - }, - { - name: "无效的媒体ID", - mediaID: "invalid", - setupMock: func() {}, - expectedStatus: http.StatusBadRequest, - expectedError: "Invalid media ID", - }, - { - name: "删除媒体失败", - mediaID: "1", - setupMock: func() { - s.service.EXPECT(). - DeleteMedia(gomock.Any(), 1, 1). - Return(errors.New("failed to delete")) - }, - expectedStatus: http.StatusInternalServerError, - expectedError: "Failed to delete media", - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - // 设置 mock - tc.setupMock() - - // 创建请求 - req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/media/%s", tc.mediaID), nil) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Request = req - - // Extract ID from URL path - parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/") - if len(parts) >= 2 { - c.Params = []gin.Param{{Key: "id", Value: parts[1]}} - } - - // 设置用户ID(除了未授权的测试用例) - if tc.expectedError != "Unauthorized" { - c.Set("user_id", 1) - } - - // 执行请求 - s.handler.DeleteMedia(c) - - // 验证响应 - s.Equal(tc.expectedStatus, w.Code) - if tc.expectedError != "" { - var response map[string]string - err := json.Unmarshal(w.Body.Bytes(), &response) - s.NoError(err) - s.Equal(tc.expectedError, response["error"]) - } - }) - } -} diff --git a/backend/internal/handler/post_handler_test.go b/backend/internal/handler/post_handler_test.go deleted file mode 100644 index 76d58ed..0000000 --- a/backend/internal/handler/post_handler_test.go +++ /dev/null @@ -1,624 +0,0 @@ -package handler - -import ( - "bytes" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - "tss-rocks-be/ent" - "tss-rocks-be/internal/config" - "tss-rocks-be/internal/service" - "tss-rocks-be/internal/service/mock" - - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/suite" - "go.uber.org/mock/gomock" - - "errors" - "strings" -) - -type PostHandlerTestSuite struct { - suite.Suite - ctrl *gomock.Controller - service *mock.MockService - handler *Handler - router *gin.Engine -} - -func (s *PostHandlerTestSuite) SetupTest() { - s.ctrl = gomock.NewController(s.T()) - s.service = mock.NewMockService(s.ctrl) - cfg := &config.Config{ - JWT: config.JWTConfig{ - Secret: "test-secret", - }, - } - s.handler = NewHandler(cfg, s.service) - - // Setup Gin router - gin.SetMode(gin.TestMode) - s.router = gin.New() - - // Setup mock for GetTokenBlacklist - tokenBlacklist := &service.TokenBlacklist{} - s.service.EXPECT(). - GetTokenBlacklist(). - Return(tokenBlacklist). - AnyTimes() - - s.handler.RegisterRoutes(s.router) -} - -func (s *PostHandlerTestSuite) TearDownTest() { - s.ctrl.Finish() -} - -func TestPostHandlerSuite(t *testing.T) { - suite.Run(t, new(PostHandlerTestSuite)) -} - -// Test cases for ListPosts -func (s *PostHandlerTestSuite) TestListPosts() { - categoryID := 1 - testCases := []struct { - name string - langCode string - categoryID string - limit string - offset string - setupMock func() - expectedStatus int - expectedBody interface{} - }{ - { - name: "Success with default language", - langCode: "", - setupMock: func() { - s.service.EXPECT(). - ListPosts(gomock.Any(), "en", nil, 10, 0). - Return([]*ent.Post{ - { - ID: 1, - Status: "published", - Edges: ent.PostEdges{ - Contents: []*ent.PostContent{ - { - LanguageCode: "en", - Title: "Test Post", - ContentMarkdown: "Test Content", - Summary: "Test Summary", - }, - }, - }, - }, - }, nil) - }, - expectedStatus: http.StatusOK, - expectedBody: []*ent.Post{ - { - ID: 1, - Status: "published", - Edges: ent.PostEdges{ - Contents: []*ent.PostContent{ - { - LanguageCode: "en", - Title: "Test Post", - ContentMarkdown: "Test Content", - Summary: "Test Summary", - }, - }, - }, - }, - }, - }, - { - name: "Success with specific language", - langCode: "zh", - setupMock: func() { - s.service.EXPECT(). - ListPosts(gomock.Any(), "zh", nil, 10, 0). - Return([]*ent.Post{ - { - ID: 1, - Status: "published", - Edges: ent.PostEdges{ - Contents: []*ent.PostContent{ - { - LanguageCode: "zh", - Title: "测试帖子", - ContentMarkdown: "测试内容", - Summary: "测试摘要", - }, - }, - }, - }, - }, nil) - }, - expectedStatus: http.StatusOK, - expectedBody: []*ent.Post{ - { - ID: 1, - Status: "published", - Edges: ent.PostEdges{ - Contents: []*ent.PostContent{ - { - LanguageCode: "zh", - Title: "测试帖子", - ContentMarkdown: "测试内容", - Summary: "测试摘要", - }, - }, - }, - }, - }, - }, - { - name: "Success with category filter", - langCode: "en", - categoryID: "1", - setupMock: func() { - s.service.EXPECT(). - ListPosts(gomock.Any(), "en", &categoryID, 10, 0). - Return([]*ent.Post{ - { - ID: 1, - Status: "published", - Edges: ent.PostEdges{ - Contents: []*ent.PostContent{ - { - LanguageCode: "en", - Title: "Test Post", - ContentMarkdown: "Test Content", - Summary: "Test Summary", - }, - }, - }, - }, - }, nil) - }, - expectedStatus: http.StatusOK, - expectedBody: []*ent.Post{ - { - ID: 1, - Status: "published", - Edges: ent.PostEdges{ - Contents: []*ent.PostContent{ - { - LanguageCode: "en", - Title: "Test Post", - ContentMarkdown: "Test Content", - Summary: "Test Summary", - }, - }, - }, - }, - }, - }, - { - name: "Success with pagination", - langCode: "en", - limit: "2", - offset: "1", - setupMock: func() { - s.service.EXPECT(). - ListPosts(gomock.Any(), "en", nil, 2, 1). - Return([]*ent.Post{ - { - ID: 2, - Status: "published", - Edges: ent.PostEdges{ - Contents: []*ent.PostContent{ - { - LanguageCode: "en", - Title: "Test Post 2", - ContentMarkdown: "Test Content 2", - Summary: "Test Summary 2", - }, - }, - }, - }, - }, nil) - }, - expectedStatus: http.StatusOK, - expectedBody: []*ent.Post{ - { - ID: 2, - Status: "published", - Edges: ent.PostEdges{ - Contents: []*ent.PostContent{ - { - LanguageCode: "en", - Title: "Test Post 2", - ContentMarkdown: "Test Content 2", - Summary: "Test Summary 2", - }, - }, - }, - }, - }, - }, - { - name: "Service Error", - langCode: "en", - setupMock: func() { - s.service.EXPECT(). - ListPosts(gomock.Any(), "en", nil, 10, 0). - Return(nil, errors.New("service error")) - }, - expectedStatus: http.StatusInternalServerError, - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - // Setup mock - tc.setupMock() - - // Create request - url := "/api/v1/posts" - if tc.langCode != "" { - url += "?lang=" + tc.langCode - } - if tc.categoryID != "" { - if strings.Contains(url, "?") { - url += "&" - } else { - url += "?" - } - url += "category_id=" + tc.categoryID - } - if tc.limit != "" { - if strings.Contains(url, "?") { - url += "&" - } else { - url += "?" - } - url += "limit=" + tc.limit - } - if tc.offset != "" { - if strings.Contains(url, "?") { - url += "&" - } else { - url += "?" - } - url += "offset=" + tc.offset - } - - req := httptest.NewRequest(http.MethodGet, url, nil) - w := httptest.NewRecorder() - - // Perform request - s.router.ServeHTTP(w, req) - - // Assert response - s.Equal(tc.expectedStatus, w.Code) - if tc.expectedBody != nil { - var response []*ent.Post - err := json.Unmarshal(w.Body.Bytes(), &response) - s.NoError(err) - s.Equal(tc.expectedBody, response) - } - }) - } -} - -// Test cases for GetPost -func (s *PostHandlerTestSuite) TestGetPost() { - testCases := []struct { - name string - langCode string - slug string - setupMock func() - expectedStatus int - expectedBody interface{} - }{ - { - name: "Success with default language", - langCode: "", - slug: "test-post", - setupMock: func() { - s.service.EXPECT(). - GetPostBySlug(gomock.Any(), "en", "test-post"). - Return(&ent.Post{ - ID: 1, - Status: "published", - Slug: "test-post", - Edges: ent.PostEdges{ - Contents: []*ent.PostContent{ - { - LanguageCode: "en", - Title: "Test Post", - ContentMarkdown: "Test Content", - Summary: "Test Summary", - }, - }, - }, - }, nil) - }, - expectedStatus: http.StatusOK, - expectedBody: gin.H{ - "id": 1, - "status": "published", - "slug": "test-post", - "edges": gin.H{ - "contents": []gin.H{ - { - "language_code": "en", - "title": "Test Post", - "content_markdown": "Test Content", - "summary": "Test Summary", - }, - }, - }, - }, - }, - { - name: "Success with specific language", - langCode: "zh", - slug: "test-post", - setupMock: func() { - s.service.EXPECT(). - GetPostBySlug(gomock.Any(), "zh", "test-post"). - Return(&ent.Post{ - ID: 1, - Status: "published", - Slug: "test-post", - Edges: ent.PostEdges{ - Contents: []*ent.PostContent{ - { - LanguageCode: "zh", - Title: "测试帖子", - ContentMarkdown: "测试内容", - Summary: "测试摘要", - }, - }, - }, - }, nil) - }, - expectedStatus: http.StatusOK, - expectedBody: gin.H{ - "id": 1, - "status": "published", - "slug": "test-post", - "edges": gin.H{ - "contents": []gin.H{ - { - "language_code": "zh", - "title": "测试帖子", - "content_markdown": "测试内容", - "summary": "测试摘要", - }, - }, - }, - }, - }, - { - name: "Service error", - slug: "test-post", - setupMock: func() { - s.service.EXPECT(). - GetPostBySlug(gomock.Any(), "en", "test-post"). - Return(nil, errors.New("service error")) - }, - expectedStatus: http.StatusInternalServerError, - expectedBody: gin.H{"error": "Failed to get post"}, - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - tc.setupMock() - - url := "/api/v1/posts/" + tc.slug - if tc.langCode != "" { - url += "?lang=" + tc.langCode - } - - req := httptest.NewRequest(http.MethodGet, url, nil) - w := httptest.NewRecorder() - s.router.ServeHTTP(w, req) - - s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch") - - if tc.expectedBody != nil { - expectedJSON, err := json.Marshal(tc.expectedBody) - s.NoError(err, "Failed to marshal expected body") - s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch") - } - }) - } -} - -// Test cases for CreatePost -func (s *PostHandlerTestSuite) TestCreatePost() { - testCases := []struct { - name string - setupMock func() - expectedStatus int - expectedBody interface{} - }{ - { - name: "Success", - setupMock: func() { - s.service.EXPECT(). - CreatePost(gomock.Any(), "draft"). - Return(&ent.Post{ - ID: 1, - Status: "draft", - Edges: ent.PostEdges{ - Contents: []*ent.PostContent{}, - }, - }, nil) - }, - expectedStatus: http.StatusCreated, - expectedBody: gin.H{ - "id": 1, - "status": "draft", - "edges": gin.H{ - "contents": []gin.H{}, - }, - }, - }, - { - name: "Service error", - setupMock: func() { - s.service.EXPECT(). - CreatePost(gomock.Any(), "draft"). - Return(nil, errors.New("service error")) - }, - expectedStatus: http.StatusInternalServerError, - expectedBody: gin.H{"error": "Failed to create post"}, - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - tc.setupMock() - - req := httptest.NewRequest(http.MethodPost, "/api/v1/posts", nil) - w := httptest.NewRecorder() - s.router.ServeHTTP(w, req) - - s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch") - - if tc.expectedBody != nil { - expectedJSON, err := json.Marshal(tc.expectedBody) - s.NoError(err, "Failed to marshal expected body") - s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch") - } - }) - } -} - -// Test cases for AddPostContent -func (s *PostHandlerTestSuite) TestAddPostContent() { - testCases := []struct { - name string - postID string - body interface{} - setupMock func() - expectedStatus int - expectedBody interface{} - }{ - { - name: "Success", - postID: "1", - body: AddPostContentRequest{ - LanguageCode: "en", - Title: "Test Post", - ContentMarkdown: "Test Content", - Summary: "Test Summary", - MetaKeywords: "test,keywords", - MetaDescription: "Test meta description", - }, - setupMock: func() { - s.service.EXPECT(). - AddPostContent( - gomock.Any(), - 1, - "en", - "Test Post", - "Test Content", - "Test Summary", - "test,keywords", - "Test meta description", - ). - Return(&ent.PostContent{ - LanguageCode: "en", - Title: "Test Post", - ContentMarkdown: "Test Content", - Summary: "Test Summary", - MetaKeywords: "test,keywords", - MetaDescription: "Test meta description", - Edges: ent.PostContentEdges{}, - }, nil) - }, - expectedStatus: http.StatusCreated, - expectedBody: gin.H{ - "language_code": "en", - "title": "Test Post", - "content_markdown": "Test Content", - "summary": "Test Summary", - "meta_keywords": "test,keywords", - "meta_description": "Test meta description", - "edges": gin.H{}, - }, - }, - { - name: "Invalid post ID", - postID: "invalid", - body: AddPostContentRequest{ - LanguageCode: "en", - Title: "Test Post", - ContentMarkdown: "Test Content", - Summary: "Test Summary", - }, - setupMock: func() {}, - expectedStatus: http.StatusBadRequest, - expectedBody: gin.H{"error": "Invalid post ID"}, - }, - { - name: "Invalid request body", - postID: "1", - body: map[string]interface{}{ - "language_code": "en", - // Missing required fields - }, - setupMock: func() {}, - expectedStatus: http.StatusBadRequest, - expectedBody: gin.H{"error": "Key: 'AddPostContentRequest.Title' Error:Field validation for 'Title' failed on the 'required' tag\nKey: 'AddPostContentRequest.ContentMarkdown' Error:Field validation for 'ContentMarkdown' failed on the 'required' tag\nKey: 'AddPostContentRequest.Summary' Error:Field validation for 'Summary' failed on the 'required' tag"}, - }, - { - name: "Service error", - postID: "1", - body: AddPostContentRequest{ - LanguageCode: "en", - Title: "Test Post", - ContentMarkdown: "Test Content", - Summary: "Test Summary", - MetaKeywords: "test,keywords", - MetaDescription: "Test meta description", - }, - setupMock: func() { - s.service.EXPECT(). - AddPostContent( - gomock.Any(), - 1, - "en", - "Test Post", - "Test Content", - "Test Summary", - "test,keywords", - "Test meta description", - ). - Return(nil, errors.New("service error")) - }, - expectedStatus: http.StatusInternalServerError, - expectedBody: gin.H{"error": "Failed to add post content"}, - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - tc.setupMock() - - body, err := json.Marshal(tc.body) - s.NoError(err, "Failed to marshal request body") - - req := httptest.NewRequest(http.MethodPost, "/api/v1/posts/"+tc.postID+"/contents", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - s.router.ServeHTTP(w, req) - - s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch") - - if tc.expectedBody != nil { - expectedJSON, err := json.Marshal(tc.expectedBody) - s.NoError(err, "Failed to marshal expected body") - s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch") - } - }) - } -} diff --git a/backend/internal/middleware/accesslog_test.go b/backend/internal/middleware/accesslog_test.go deleted file mode 100644 index 449513a..0000000 --- a/backend/internal/middleware/accesslog_test.go +++ /dev/null @@ -1,238 +0,0 @@ -package middleware - -import ( - "bytes" - "io" - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "testing" - "time" - "tss-rocks-be/internal/types" - - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" -) - -func TestAccessLog(t *testing.T) { - // 设置测试临时目录 - tmpDir := t.TempDir() - logPath := filepath.Join(tmpDir, "test.log") - - testCases := []struct { - name string - config *types.AccessLogConfig - expectedError bool - setupRequest func(*http.Request) - validateOutput func(*testing.T, *httptest.ResponseRecorder, string) - }{ - { - name: "Console logging only", - config: &types.AccessLogConfig{ - EnableConsole: true, - EnableFile: false, - Format: "json", - Level: "info", - }, - expectedError: false, - setupRequest: func(req *http.Request) { - req.Header.Set("User-Agent", "test-agent") - }, - validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) { - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, logOutput, "GET /test") - assert.Contains(t, logOutput, "test-agent") - }, - }, - { - name: "File logging only", - config: &types.AccessLogConfig{ - EnableConsole: false, - EnableFile: true, - FilePath: logPath, - Format: "json", - Level: "info", - Rotation: struct { - MaxSize int `yaml:"max_size"` - MaxAge int `yaml:"max_age"` - MaxBackups int `yaml:"max_backups"` - Compress bool `yaml:"compress"` - LocalTime bool `yaml:"local_time"` - }{ - MaxSize: 1, - MaxAge: 1, - MaxBackups: 1, - Compress: false, - LocalTime: true, - }, - }, - expectedError: false, - setupRequest: func(req *http.Request) { - req.Header.Set("User-Agent", "test-agent") - }, - validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) { - assert.Equal(t, http.StatusOK, w.Code) - - // 读取日志文件内容 - content, err := os.ReadFile(logPath) - assert.NoError(t, err) - assert.Contains(t, string(content), "GET /test") - assert.Contains(t, string(content), "test-agent") - }, - }, - { - name: "Both console and file logging", - config: &types.AccessLogConfig{ - EnableConsole: true, - EnableFile: true, - FilePath: logPath, - Format: "json", - Level: "info", - Rotation: struct { - MaxSize int `yaml:"max_size"` - MaxAge int `yaml:"max_age"` - MaxBackups int `yaml:"max_backups"` - Compress bool `yaml:"compress"` - LocalTime bool `yaml:"local_time"` - }{ - MaxSize: 1, - MaxAge: 1, - MaxBackups: 1, - Compress: false, - LocalTime: true, - }, - }, - expectedError: false, - setupRequest: func(req *http.Request) { - req.Header.Set("User-Agent", "test-agent") - }, - validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) { - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, logOutput, "GET /test") - assert.Contains(t, logOutput, "test-agent") - - // 读取日志文件内容 - content, err := os.ReadFile(logPath) - assert.NoError(t, err) - assert.Contains(t, string(content), "GET /test") - assert.Contains(t, string(content), "test-agent") - }, - }, - { - name: "With authenticated user", - config: &types.AccessLogConfig{ - EnableConsole: true, - EnableFile: false, - Format: "json", - Level: "info", - }, - expectedError: false, - setupRequest: func(req *http.Request) { - req.Header.Set("User-Agent", "test-agent") - }, - validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) { - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, logOutput, "GET /test") - assert.Contains(t, logOutput, "test-agent") - assert.Contains(t, logOutput, "test-user") - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // 捕获标准输出 - oldStdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w - - // 创建一个新的 gin 引擎 - gin.SetMode(gin.TestMode) - router := gin.New() - - // 创建访问日志中间件 - middleware, err := AccessLog(tc.config) - if tc.expectedError { - assert.Error(t, err) - return - } - assert.NoError(t, err) - - // 添加测试路由 - router.Use(middleware) - router.GET("/test", func(c *gin.Context) { - // 如果是测试认证用户的情况,设置用户ID - if tc.name == "With authenticated user" { - c.Set("user_id", "test-user") - } - c.Status(http.StatusOK) - }) - - // 创建测试请求 - req := httptest.NewRequest("GET", "/test", nil) - if tc.setupRequest != nil { - tc.setupRequest(req) - } - rec := httptest.NewRecorder() - - // 执行请求 - router.ServeHTTP(rec, req) - - // 恢复标准输出并获取输出内容 - w.Close() - var buf bytes.Buffer - io.Copy(&buf, r) - os.Stdout = oldStdout - - // 验证输出 - if tc.validateOutput != nil { - tc.validateOutput(t, rec, buf.String()) - } - - // 关闭日志文件 - if tc.config.EnableFile { - // 调用中间件函数来关闭日志文件 - middleware(nil) - // 等待一小段时间确保文件完全关闭 - time.Sleep(100 * time.Millisecond) - } - }) - } -} - -func TestAccessLogInvalidConfig(t *testing.T) { - testCases := []struct { - name string - config *types.AccessLogConfig - expectedError bool - }{ - { - name: "Invalid log level", - config: &types.AccessLogConfig{ - EnableConsole: true, - Level: "invalid_level", - }, - expectedError: false, // 应该使用默认的 info 级别 - }, - { - name: "Invalid file path", - config: &types.AccessLogConfig{ - EnableFile: true, - FilePath: "\x00invalid\x00path", // 使用空字符的路径在所有操作系统上都是无效的 - }, - expectedError: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - _, err := AccessLog(tc.config) - if tc.expectedError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - }) - } -} diff --git a/backend/internal/middleware/auth_test.go b/backend/internal/middleware/auth_test.go deleted file mode 100644 index 78e2f03..0000000 --- a/backend/internal/middleware/auth_test.go +++ /dev/null @@ -1,227 +0,0 @@ -package middleware - -import ( - "encoding/json" - "fmt" - "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt/v5" - "github.com/stretchr/testify/assert" - "net/http" - "net/http/httptest" - "testing" - "time" - - "tss-rocks-be/internal/service" -) - -func createTestToken(secret string, claims jwt.MapClaims) string { - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - signedToken, err := token.SignedString([]byte(secret)) - if err != nil { - panic(fmt.Sprintf("Failed to sign token: %v", err)) - } - return signedToken -} - -func TestAuthMiddleware(t *testing.T) { - jwtSecret := "test-secret" - tokenBlacklist := service.NewTokenBlacklist() - - testCases := []struct { - name string - setupAuth func(*http.Request) - expectedStatus int - expectedBody map[string]string - checkUserData bool - expectedUserID string - expectedRoles []string - }{ - { - name: "No Authorization header", - setupAuth: func(req *http.Request) {}, - expectedStatus: http.StatusUnauthorized, - expectedBody: map[string]string{"error": "Authorization header is required"}, - }, - { - name: "Invalid Authorization format", - setupAuth: func(req *http.Request) { - req.Header.Set("Authorization", "InvalidFormat") - }, - expectedStatus: http.StatusUnauthorized, - expectedBody: map[string]string{"error": "Authorization header format must be Bearer {token}"}, - }, - { - name: "Invalid token", - setupAuth: func(req *http.Request) { - req.Header.Set("Authorization", "Bearer invalid.token.here") - }, - expectedStatus: http.StatusUnauthorized, - expectedBody: map[string]string{"error": "Invalid token"}, - }, - { - name: "Valid token", - setupAuth: func(req *http.Request) { - claims := jwt.MapClaims{ - "sub": "123", - "roles": []string{"admin", "editor"}, - "exp": time.Now().Add(time.Hour).Unix(), - } - token := createTestToken(jwtSecret, claims) - req.Header.Set("Authorization", "Bearer "+token) - }, - expectedStatus: http.StatusOK, - checkUserData: true, - expectedUserID: "123", - expectedRoles: []string{"admin", "editor"}, - }, - { - name: "Expired token", - setupAuth: func(req *http.Request) { - claims := jwt.MapClaims{ - "sub": "123", - "roles": []string{"user"}, - "exp": time.Now().Add(-time.Hour).Unix(), - } - token := createTestToken(jwtSecret, claims) - req.Header.Set("Authorization", "Bearer "+token) - }, - expectedStatus: http.StatusUnauthorized, - expectedBody: map[string]string{"error": "Invalid token"}, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - gin.SetMode(gin.TestMode) - router := gin.New() - - // 添加认证中间件 - router.Use(func(c *gin.Context) { - // 设置日志级别为 debug - gin.SetMode(gin.DebugMode) - c.Next() - }, AuthMiddleware(jwtSecret, tokenBlacklist)) - - // 测试路由 - router.GET("/test", func(c *gin.Context) { - if tc.checkUserData { - userID, exists := c.Get("user_id") - assert.True(t, exists, "user_id should exist in context") - assert.Equal(t, tc.expectedUserID, userID, "user_id should match") - - roles, exists := c.Get("user_roles") - assert.True(t, exists, "user_roles should exist in context") - assert.Equal(t, tc.expectedRoles, roles, "user_roles should match") - } - c.Status(http.StatusOK) - }) - - // 创建请求 - req := httptest.NewRequest("GET", "/test", nil) - tc.setupAuth(req) - rec := httptest.NewRecorder() - - // 执行请求 - router.ServeHTTP(rec, req) - - // 验证响应 - assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match") - - if tc.expectedBody != nil { - var response map[string]string - err := json.NewDecoder(rec.Body).Decode(&response) - assert.NoError(t, err, "Response body should be valid JSON") - assert.Equal(t, tc.expectedBody, response, "Response body should match") - } - }) - } -} - -func TestRoleMiddleware(t *testing.T) { - testCases := []struct { - name string - setupContext func(*gin.Context) - allowedRoles []string - expectedStatus int - expectedBody map[string]string - }{ - { - name: "No user roles", - setupContext: func(c *gin.Context) { - // 不设置用户角色 - }, - allowedRoles: []string{"admin"}, - expectedStatus: http.StatusUnauthorized, - expectedBody: map[string]string{"error": "User roles not found"}, - }, - { - name: "Invalid roles type", - setupContext: func(c *gin.Context) { - c.Set("user_roles", 123) // 设置错误类型的角色 - }, - allowedRoles: []string{"admin"}, - expectedStatus: http.StatusInternalServerError, - expectedBody: map[string]string{"error": "Invalid user roles type"}, - }, - { - name: "Insufficient permissions", - setupContext: func(c *gin.Context) { - c.Set("user_roles", []string{"user"}) - }, - allowedRoles: []string{"admin"}, - expectedStatus: http.StatusForbidden, - expectedBody: map[string]string{"error": "Insufficient permissions"}, - }, - { - name: "Allowed role", - setupContext: func(c *gin.Context) { - c.Set("user_roles", []string{"admin"}) - }, - allowedRoles: []string{"admin"}, - expectedStatus: http.StatusOK, - }, - { - name: "One of multiple allowed roles", - setupContext: func(c *gin.Context) { - c.Set("user_roles", []string{"user", "editor"}) - }, - allowedRoles: []string{"admin", "editor", "moderator"}, - expectedStatus: http.StatusOK, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - gin.SetMode(gin.TestMode) - router := gin.New() - - // 添加角色中间件 - router.Use(func(c *gin.Context) { - tc.setupContext(c) - c.Next() - }, RoleMiddleware(tc.allowedRoles...)) - - // 测试路由 - router.GET("/test", func(c *gin.Context) { - c.Status(http.StatusOK) - }) - - // 创建请求 - req := httptest.NewRequest("GET", "/test", nil) - rec := httptest.NewRecorder() - - // 执行请求 - router.ServeHTTP(rec, req) - - // 验证响应 - assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match") - - if tc.expectedBody != nil { - var response map[string]string - err := json.NewDecoder(rec.Body).Decode(&response) - assert.NoError(t, err, "Response body should be valid JSON") - assert.Equal(t, tc.expectedBody, response, "Response body should match") - } - }) - } -} diff --git a/backend/internal/middleware/cors_test.go b/backend/internal/middleware/cors_test.go deleted file mode 100644 index bf187f9..0000000 --- a/backend/internal/middleware/cors_test.go +++ /dev/null @@ -1,76 +0,0 @@ -package middleware - -import ( - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" - "net/http" - "net/http/httptest" - "testing" -) - -func TestCORS(t *testing.T) { - testCases := []struct { - name string - method string - expectedStatus int - checkHeaders bool - }{ - { - name: "Normal GET request", - method: "GET", - expectedStatus: http.StatusOK, - checkHeaders: true, - }, - { - name: "OPTIONS request", - method: "OPTIONS", - expectedStatus: http.StatusNoContent, - checkHeaders: true, - }, - { - name: "POST request", - method: "POST", - expectedStatus: http.StatusOK, - checkHeaders: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // 创建一个新的 gin 引擎 - gin.SetMode(gin.TestMode) - router := gin.New() - - // 添加 CORS 中间件 - router.Use(CORS()) - - // 添加测试路由 - router.Any("/test", func(c *gin.Context) { - c.Status(http.StatusOK) - }) - - // 创建测试请求 - req := httptest.NewRequest(tc.method, "/test", nil) - rec := httptest.NewRecorder() - - // 执行请求 - router.ServeHTTP(rec, req) - - // 验证状态码 - assert.Equal(t, tc.expectedStatus, rec.Code) - - if tc.checkHeaders { - // 验证 CORS 头部 - headers := rec.Header() - assert.Equal(t, "*", headers.Get("Access-Control-Allow-Origin")) - assert.Equal(t, "true", headers.Get("Access-Control-Allow-Credentials")) - assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Content-Type") - assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Authorization") - assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "POST") - assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "GET") - assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "PUT") - assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "DELETE") - } - }) - } -} diff --git a/backend/internal/middleware/ratelimit_test.go b/backend/internal/middleware/ratelimit_test.go deleted file mode 100644 index 381f0af..0000000 --- a/backend/internal/middleware/ratelimit_test.go +++ /dev/null @@ -1,207 +0,0 @@ -package middleware - -import ( - "encoding/json" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" - "net/http" - "net/http/httptest" - "testing" - "time" - "tss-rocks-be/internal/types" -) - -func TestRateLimit(t *testing.T) { - testCases := []struct { - name string - config *types.RateLimitConfig - setupTest func(*gin.Engine) - runTest func(*testing.T, *gin.Engine) - expectedStatus int - expectedBody map[string]string - }{ - { - name: "IP rate limit", - config: &types.RateLimitConfig{ - IPRate: 1, // 每秒1个请求 - IPBurst: 1, - }, - setupTest: func(router *gin.Engine) { - router.GET("/test", func(c *gin.Context) { - c.Status(http.StatusOK) - }) - }, - runTest: func(t *testing.T, router *gin.Engine) { - // 第一个请求应该成功 - req := httptest.NewRequest("GET", "/test", nil) - req.RemoteAddr = "192.168.1.1:1234" - rec := httptest.NewRecorder() - router.ServeHTTP(rec, req) - assert.Equal(t, http.StatusOK, rec.Code) - - // 第二个请求应该被限制 - rec = httptest.NewRecorder() - router.ServeHTTP(rec, req) - assert.Equal(t, http.StatusTooManyRequests, rec.Code) - var response map[string]string - err := json.NewDecoder(rec.Body).Decode(&response) - assert.NoError(t, err) - assert.Equal(t, "too many requests from this IP", response["error"]) - - // 等待限流器重置 - time.Sleep(time.Second) - - // 第三个请求应该成功 - rec = httptest.NewRecorder() - router.ServeHTTP(rec, req) - assert.Equal(t, http.StatusOK, rec.Code) - }, - }, - { - name: "Route rate limit", - config: &types.RateLimitConfig{ - IPRate: 100, // 设置较高的 IP 限流,以便测试路由限流 - IPBurst: 10, - RouteRates: map[string]struct { - Rate int `yaml:"rate"` - Burst int `yaml:"burst"` - }{ - "/limited": { - Rate: 1, - Burst: 1, - }, - }, - }, - setupTest: func(router *gin.Engine) { - router.GET("/limited", func(c *gin.Context) { - c.Status(http.StatusOK) - }) - router.GET("/unlimited", func(c *gin.Context) { - c.Status(http.StatusOK) - }) - }, - runTest: func(t *testing.T, router *gin.Engine) { - // 测试限流路由 - req := httptest.NewRequest("GET", "/limited", nil) - req.RemoteAddr = "192.168.1.2:1234" - rec := httptest.NewRecorder() - router.ServeHTTP(rec, req) - assert.Equal(t, http.StatusOK, rec.Code) - - // 等待一小段时间确保限流器生效 - time.Sleep(10 * time.Millisecond) - - rec = httptest.NewRecorder() - router.ServeHTTP(rec, req) - assert.Equal(t, http.StatusTooManyRequests, rec.Code) - var response map[string]string - err := json.NewDecoder(rec.Body).Decode(&response) - assert.NoError(t, err) - assert.Equal(t, "too many requests for this route", response["error"]) - - // 测试未限流路由 - req = httptest.NewRequest("GET", "/unlimited", nil) - req.RemoteAddr = "192.168.1.2:1234" - rec = httptest.NewRecorder() - router.ServeHTTP(rec, req) - assert.Equal(t, http.StatusOK, rec.Code) - - // 等待一小段时间确保限流器生效 - time.Sleep(10 * time.Millisecond) - - rec = httptest.NewRecorder() - router.ServeHTTP(rec, req) - assert.Equal(t, http.StatusOK, rec.Code) - }, - }, - { - name: "Multiple IPs", - config: &types.RateLimitConfig{ - IPRate: 1, - IPBurst: 1, - }, - setupTest: func(router *gin.Engine) { - router.GET("/test", func(c *gin.Context) { - c.Status(http.StatusOK) - }) - }, - runTest: func(t *testing.T, router *gin.Engine) { - // IP1 的请求 - req1 := httptest.NewRequest("GET", "/test", nil) - req1.RemoteAddr = "192.168.1.3:1234" - rec := httptest.NewRecorder() - router.ServeHTTP(rec, req1) - assert.Equal(t, http.StatusOK, rec.Code) - - rec = httptest.NewRecorder() - router.ServeHTTP(rec, req1) - assert.Equal(t, http.StatusTooManyRequests, rec.Code) - - // IP2 的请求应该不受 IP1 的限制影响 - req2 := httptest.NewRequest("GET", "/test", nil) - req2.RemoteAddr = "192.168.1.4:1234" - rec = httptest.NewRecorder() - router.ServeHTTP(rec, req2) - assert.Equal(t, http.StatusOK, rec.Code) - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - gin.SetMode(gin.TestMode) - router := gin.New() - - // 添加限流中间件 - router.Use(RateLimit(tc.config)) - - // 设置测试路由 - tc.setupTest(router) - - // 运行测试 - tc.runTest(t, router) - }) - } -} - -func TestRateLimiterCleanup(t *testing.T) { - config := &types.RateLimitConfig{ - IPRate: 1, - IPBurst: 1, - } - - rl := newRateLimiter(config) - - // 添加一些IP限流器 - ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"} - for _, ip := range ips { - rl.getLimiter(ip) - } - - // 验证IP限流器已创建 - rl.mu.RLock() - assert.Equal(t, len(ips), len(rl.ips)) - rl.mu.RUnlock() - - // 修改一些IP的最后访问时间为1小时前 - rl.mu.Lock() - rl.ips["192.168.1.1"].lastSeen = time.Now().Add(-2 * time.Hour) - rl.ips["192.168.1.2"].lastSeen = time.Now().Add(-2 * time.Hour) - rl.mu.Unlock() - - // 手动触发清理 - rl.mu.Lock() - for ip, limiter := range rl.ips { - if time.Since(limiter.lastSeen) > time.Hour { - delete(rl.ips, ip) - } - } - rl.mu.Unlock() - - // 验证过期的IP限流器已被删除 - rl.mu.RLock() - assert.Equal(t, 1, len(rl.ips)) - _, exists := rl.ips["192.168.1.3"] - assert.True(t, exists) - rl.mu.RUnlock() -} diff --git a/backend/internal/middleware/upload.go b/backend/internal/middleware/upload.go index 91f9e53..c5a3756 100644 --- a/backend/internal/middleware/upload.go +++ b/backend/internal/middleware/upload.go @@ -3,146 +3,120 @@ package middleware import ( "bytes" "fmt" - "io" "net/http" "path/filepath" "strings" + "tss-rocks-be/internal/config" + "github.com/gin-gonic/gin" - "tss-rocks-be/internal/types" ) -const ( - defaultMaxMemory = 32 << 20 // 32 MB - maxHeaderBytes = 512 // 用于MIME类型检测的最大字节数 -) - -// ValidateUpload 创建文件上传验证中间件 -func ValidateUpload(cfg *types.UploadConfig) gin.HandlerFunc { +// ValidateUpload 验证上传的文件 +func ValidateUpload(cfg *config.UploadConfig) gin.HandlerFunc { return func(c *gin.Context) { - // 检查是否是multipart/form-data请求 - if !strings.HasPrefix(c.GetHeader("Content-Type"), "multipart/form-data") { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "Content-Type must be multipart/form-data", - }) + // Get file from form + file, err := c.FormFile("file") + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "No file uploaded"}) c.Abort() return } - // 解析multipart表单 - if err := c.Request.ParseMultipartForm(defaultMaxMemory); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("Failed to parse form: %v", err), - }) - c.Abort() - return - } + // 获取文件类型和扩展名 + contentType := file.Header.Get("Content-Type") + ext := strings.ToLower(filepath.Ext(file.Filename)) - form := c.Request.MultipartForm - if form == nil || form.File == nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "No file uploaded", - }) - c.Abort() - return - } - - // 遍历所有上传的文件 - for _, files := range form.File { - for _, file := range files { - // 检查文件大小 - if file.Size > int64(cfg.MaxSize)<<20 { // 转换为字节 - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("File %s exceeds maximum size of %d MB", file.Filename, cfg.MaxSize), - }) - c.Abort() - return - } - - // 检查文件扩展名 - ext := strings.ToLower(filepath.Ext(file.Filename)) - if !contains(cfg.AllowedExtensions, ext) { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("File extension %s is not allowed", ext), - }) - c.Abort() - return - } - - // 打开文件 - src, err := file.Open() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": fmt.Sprintf("Failed to open file: %v", err), - }) - c.Abort() - return - } - defer src.Close() - - // 读取文件头部用于MIME类型检测 - header := make([]byte, maxHeaderBytes) - n, err := src.Read(header) - if err != nil && err != io.EOF { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": fmt.Sprintf("Failed to read file: %v", err), - }) - c.Abort() - return - } - header = header[:n] - - // 检测MIME类型 - contentType := http.DetectContentType(header) - if !contains(cfg.AllowedTypes, contentType) { - c.JSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("File type %s is not allowed", contentType), - }) - c.Abort() - return - } - - // 将文件指针重置到开始位置 - _, err = src.Seek(0, 0) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": fmt.Sprintf("Failed to read file: %v", err), - }) - c.Abort() - return - } - - // 将文件内容读入缓冲区 - buf := &bytes.Buffer{} - _, err = io.Copy(buf, src) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": fmt.Sprintf("Failed to read file: %v", err), - }) - c.Abort() - return - } - - // 将验证过的文件内容和类型保存到上下文中 - c.Set("validated_file_"+file.Filename, buf) - c.Set("validated_content_type_"+file.Filename, contentType) + // 如果 Content-Type 为空,尝试从文件扩展名判断 + if contentType == "" { + switch ext { + case ".jpg", ".jpeg": + contentType = "image/jpeg" + case ".png": + contentType = "image/png" + case ".gif": + contentType = "image/gif" + case ".webp": + contentType = "image/webp" + case ".mp4": + contentType = "video/mp4" + case ".webm": + contentType = "video/webm" + case ".mp3": + contentType = "audio/mpeg" + case ".ogg": + contentType = "audio/ogg" + case ".wav": + contentType = "audio/wav" + case ".pdf": + contentType = "application/pdf" + case ".doc": + contentType = "application/msword" + case ".docx": + contentType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" } } + // 根据 Content-Type 确定文件类型和限制 + var maxSize int64 + var allowedTypes []string + var fileType string + + limits := cfg.Limits + switch { + case strings.HasPrefix(contentType, "image/"): + maxSize = int64(limits.Image.MaxSize) * 1024 * 1024 + allowedTypes = limits.Image.AllowedTypes + fileType = "image" + case strings.HasPrefix(contentType, "video/"): + maxSize = int64(limits.Video.MaxSize) * 1024 * 1024 + allowedTypes = limits.Video.AllowedTypes + fileType = "video" + case strings.HasPrefix(contentType, "audio/"): + maxSize = int64(limits.Audio.MaxSize) * 1024 * 1024 + allowedTypes = limits.Audio.AllowedTypes + fileType = "audio" + case strings.HasPrefix(contentType, "application/"): + maxSize = int64(limits.Document.MaxSize) * 1024 * 1024 + allowedTypes = limits.Document.AllowedTypes + fileType = "document" + default: + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Unsupported file type: %s", contentType), + }) + c.Abort() + return + } + + // 检查文件类型是否允许 + typeAllowed := false + for _, allowed := range allowedTypes { + if contentType == allowed { + typeAllowed = true + break + } + } + if !typeAllowed { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("Unsupported %s type: %s", fileType, contentType), + }) + c.Abort() + return + } + + // 检查文件大小 + if file.Size > maxSize { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("File size exceeds the limit (%d MB) for %s files", limits.Image.MaxSize, fileType), + }) + c.Abort() + return + } + c.Next() } } -// contains 检查切片中是否包含指定的字符串 -func contains(slice []string, str string) bool { - for _, s := range slice { - if s == str { - return true - } - } - return false -} - // GetValidatedFile 从上下文中获取验证过的文件内容 func GetValidatedFile(c *gin.Context, filename string) (*bytes.Buffer, string, bool) { file, exists := c.Get("validated_file_" + filename) diff --git a/backend/internal/middleware/upload_test.go b/backend/internal/middleware/upload_test.go deleted file mode 100644 index 7434484..0000000 --- a/backend/internal/middleware/upload_test.go +++ /dev/null @@ -1,262 +0,0 @@ -package middleware - -import ( - "bytes" - "encoding/json" - "io" - "mime/multipart" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "tss-rocks-be/internal/types" - - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" -) - -func createMultipartRequest(t *testing.T, filename string, content []byte, contentType string) (*http.Request, error) { - body := &bytes.Buffer{} - writer := multipart.NewWriter(body) - - part, err := writer.CreateFormFile("file", filename) - if err != nil { - return nil, err - } - - _, err = io.Copy(part, bytes.NewReader(content)) - if err != nil { - return nil, err - } - - err = writer.Close() - if err != nil { - return nil, err - } - - req := httptest.NewRequest("POST", "/upload", body) - req.Header.Set("Content-Type", writer.FormDataContentType()) - return req, nil -} - -func TestValidateUpload(t *testing.T) { - tests := []struct { - name string - config *types.UploadConfig - filename string - content []byte - setupRequest func(*testing.T) *http.Request - expectedStatus int - expectedError string - }{ - { - name: "Valid image upload", - config: &types.UploadConfig{ - MaxSize: 5, // 5MB - AllowedExtensions: []string{".jpg", ".jpeg", ".png"}, - AllowedTypes: []string{"image/jpeg", "image/png"}, - }, - filename: "test.jpg", - content: []byte{ - 0xFF, 0xD8, 0xFF, 0xE0, // JPEG magic numbers - 0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00, - }, - expectedStatus: http.StatusOK, - }, - { - name: "Invalid file extension", - config: &types.UploadConfig{ - MaxSize: 5, - AllowedExtensions: []string{".jpg", ".jpeg", ".png"}, - AllowedTypes: []string{"image/jpeg", "image/png"}, - }, - filename: "test.txt", - content: []byte("test content"), - expectedStatus: http.StatusBadRequest, - expectedError: "File extension .txt is not allowed", - }, - { - name: "File too large", - config: &types.UploadConfig{ - MaxSize: 1, // 1MB - AllowedExtensions: []string{".jpg"}, - AllowedTypes: []string{"image/jpeg"}, - }, - filename: "large.jpg", - content: make([]byte, 2<<20), // 2MB - expectedStatus: http.StatusBadRequest, - expectedError: "File large.jpg exceeds maximum size of 1 MB", - }, - { - name: "Invalid content type", - config: &types.UploadConfig{ - MaxSize: 5, - AllowedExtensions: []string{".jpg"}, - AllowedTypes: []string{"image/jpeg"}, - }, - filename: "fake.jpg", - content: []byte("not a real image"), - expectedStatus: http.StatusBadRequest, - expectedError: "File type text/plain; charset=utf-8 is not allowed", - }, - { - name: "Missing file", - config: &types.UploadConfig{ - MaxSize: 5, - AllowedExtensions: []string{".jpg"}, - AllowedTypes: []string{"image/jpeg"}, - }, - setupRequest: func(t *testing.T) *http.Request { - req := httptest.NewRequest("POST", "/upload", strings.NewReader("")) - req.Header.Set("Content-Type", "multipart/form-data") - return req - }, - expectedStatus: http.StatusBadRequest, - expectedError: "Failed to parse form", - }, - { - name: "Invalid content type header", - config: &types.UploadConfig{ - MaxSize: 5, - AllowedExtensions: []string{".jpg"}, - AllowedTypes: []string{"image/jpeg"}, - }, - setupRequest: func(t *testing.T) *http.Request { - return httptest.NewRequest("POST", "/upload", nil) - }, - expectedStatus: http.StatusBadRequest, - expectedError: "Content-Type must be multipart/form-data", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - var req *http.Request - var err error - - if tt.setupRequest != nil { - req = tt.setupRequest(t) - } else { - req, err = createMultipartRequest(t, tt.filename, tt.content, "") - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - } - - c.Request = req - - middleware := ValidateUpload(tt.config) - middleware(c) - - assert.Equal(t, tt.expectedStatus, w.Code) - if tt.expectedError != "" { - var response map[string]string - err := json.NewDecoder(w.Body).Decode(&response) - assert.NoError(t, err) - assert.Contains(t, response["error"], tt.expectedError) - } - }) - } -} - -func TestGetValidatedFile(t *testing.T) { - tests := []struct { - name string - setupContext func(*gin.Context) - filename string - expectedFound bool - expectedError string - }{ - { - name: "Get existing file", - setupContext: func(c *gin.Context) { - // 创建测试文件内容 - content := []byte("test content") - buf := bytes.NewBuffer(content) - - // 设置验证过的文件和内容类型 - c.Set("validated_file_test.txt", buf) - c.Set("validated_content_type_test.txt", "text/plain") - }, - filename: "test.txt", - expectedFound: true, - }, - { - name: "File not found", - setupContext: func(c *gin.Context) { - // 不设置任何文件 - }, - filename: "nonexistent.txt", - expectedFound: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - if tt.setupContext != nil { - tt.setupContext(c) - } - - buffer, contentType, found := GetValidatedFile(c, tt.filename) - - assert.Equal(t, tt.expectedFound, found) - if tt.expectedFound { - assert.NotNil(t, buffer) - assert.NotEmpty(t, contentType) - } else { - assert.Nil(t, buffer) - assert.Empty(t, contentType) - } - }) - } -} - -func TestContains(t *testing.T) { - tests := []struct { - name string - slice []string - str string - expected bool - }{ - { - name: "String found in slice", - slice: []string{"a", "b", "c"}, - str: "b", - expected: true, - }, - { - name: "String not found in slice", - slice: []string{"a", "b", "c"}, - str: "d", - expected: false, - }, - { - name: "Empty slice", - slice: []string{}, - str: "a", - expected: false, - }, - { - name: "Empty string", - slice: []string{"a", "b", "c"}, - str: "", - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := contains(tt.slice, tt.str) - assert.Equal(t, tt.expected, result) - }) - } -} diff --git a/backend/internal/rbac/init_test.go b/backend/internal/rbac/init_test.go deleted file mode 100644 index 2da0158..0000000 --- a/backend/internal/rbac/init_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package rbac - -import ( - "context" - "testing" - - "tss-rocks-be/ent/enttest" - "tss-rocks-be/ent/role" - - _ "github.com/mattn/go-sqlite3" -) - -func TestInitializeRBAC(t *testing.T) { - // Create an in-memory SQLite client for testing - client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") - defer client.Close() - - ctx := context.Background() - - // Test initialization - err := InitializeRBAC(ctx, client) - if err != nil { - t.Fatalf("Failed to initialize RBAC: %v", err) - } - - // Verify roles were created - for roleName := range DefaultRoles { - r, err := client.Role.Query().Where(role.Name(roleName)).Only(ctx) - if err != nil { - t.Errorf("Role %s was not created: %v", roleName, err) - } - - // Verify permissions for each role - perms, err := r.QueryPermissions().All(ctx) - if err != nil { - t.Errorf("Failed to query permissions for role %s: %v", roleName, err) - } - - expectedPerms := DefaultRoles[roleName] - permCount := 0 - for _, actions := range expectedPerms { - permCount += len(actions) - } - - if len(perms) != permCount { - t.Errorf("Role %s has %d permissions, expected %d", roleName, len(perms), permCount) - } - } -} - -func TestAssignRoleToUser(t *testing.T) { - // Create an in-memory SQLite client for testing - client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") - defer client.Close() - - ctx := context.Background() - - // Initialize RBAC - err := InitializeRBAC(ctx, client) - if err != nil { - t.Fatalf("Failed to initialize RBAC: %v", err) - } - - // Create a test user - user, err := client.User.Create(). - SetEmail("test@example.com"). - SetUsername("testuser"). - SetPasswordHash("$2a$10$hzLdXMZEIzgr8eGXL0YoCOIIrQhqEj6N.S3.wY1Jx5.4vWm1ZyHyy"). - Save(ctx) - if err != nil { - t.Fatalf("Failed to create test user: %v", err) - } - - // Test assigning role to user - err = AssignRoleToUser(ctx, client, user.ID, "editor") - if err != nil { - t.Fatalf("Failed to assign role to user: %v", err) - } - - // Verify role assignment - assignedRoles, err := user.QueryRoles().All(ctx) - if err != nil { - t.Fatalf("Failed to query user roles: %v", err) - } - - if len(assignedRoles) != 1 { - t.Errorf("Expected 1 role, got %d", len(assignedRoles)) - } - - if assignedRoles[0].Name != "editor" { - t.Errorf("Expected role name 'editor', got '%s'", assignedRoles[0].Name) - } - - // Test assigning non-existent role - err = AssignRoleToUser(ctx, client, user.ID, "nonexistent") - if err == nil { - t.Error("Expected error when assigning non-existent role, got nil") - } -} diff --git a/backend/internal/server/database_test.go b/backend/internal/server/database_test.go deleted file mode 100644 index c8e35da..0000000 --- a/backend/internal/server/database_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package server - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestInitDatabase(t *testing.T) { - tests := []struct { - name string - driver string - dsn string - wantErr bool - errContains string - }{ - { - name: "success with sqlite3", - driver: "sqlite3", - dsn: "file:ent?mode=memory&cache=shared&_fk=1", - }, - { - name: "invalid driver", - driver: "invalid_driver", - dsn: "file:ent?mode=memory", - wantErr: true, - errContains: "unsupported driver", - }, - { - name: "invalid dsn", - driver: "sqlite3", - dsn: "file::memory:?not_exist_option=1", // 使用内存数据库但带有无效选项 - wantErr: true, - errContains: "foreign_keys pragma is off", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - client, err := InitDatabase(ctx, tt.driver, tt.dsn) - - if tt.wantErr { - assert.Error(t, err) - if tt.errContains != "" { - assert.Contains(t, err.Error(), tt.errContains) - } - assert.Nil(t, client) - } else { - require.NoError(t, err) - assert.NotNil(t, client) - - // 测试数据库连接是否正常工作 - err = client.Schema.Create(ctx) - assert.NoError(t, err) - - // 清理 - client.Close() - } - }) - } -} diff --git a/backend/internal/server/ent_test.go b/backend/internal/server/ent_test.go deleted file mode 100644 index 1f71881..0000000 --- a/backend/internal/server/ent_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package server - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "tss-rocks-be/internal/config" -) - -func TestNewEntClient(t *testing.T) { - tests := []struct { - name string - cfg *config.Config - }{ - { - name: "default sqlite3 config", - cfg: &config.Config{ - Database: config.DatabaseConfig{ - Driver: "sqlite3", - DSN: "file:ent?mode=memory&cache=shared&_fk=1", - }, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - client := NewEntClient(tt.cfg) - assert.NotNil(t, client) - - // 验证客户端是否可以正常工作 - err := client.Schema.Create(context.Background()) - assert.NoError(t, err) - - // 清理 - client.Close() - }) - } -} diff --git a/backend/internal/server/server_test.go b/backend/internal/server/server_test.go deleted file mode 100644 index 41d552e..0000000 --- a/backend/internal/server/server_test.go +++ /dev/null @@ -1,220 +0,0 @@ -package server - -import ( - "context" - "net/http" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "tss-rocks-be/internal/config" - "tss-rocks-be/internal/types" - "tss-rocks-be/ent/enttest" -) - -func TestNew(t *testing.T) { - // 创建测试配置 - cfg := &config.Config{ - Server: config.ServerConfig{ - Host: "localhost", - Port: 8080, - }, - Storage: config.StorageConfig{ - Type: "local", - Local: config.LocalStorage{ - RootDir: "testdata", - }, - Upload: types.UploadConfig{ - MaxSize: 10, - AllowedTypes: []string{"image/jpeg", "image/png"}, - AllowedExtensions: []string{".jpg", ".png"}, - }, - }, - RateLimit: types.RateLimitConfig{ - IPRate: 100, - IPBurst: 200, - RouteRates: map[string]struct { - Rate int `yaml:"rate"` - Burst int `yaml:"burst"` - }{ - "/api/v1/upload": {Rate: 10, Burst: 20}, - }, - }, - AccessLog: types.AccessLogConfig{ - EnableConsole: true, - EnableFile: true, - FilePath: "testdata/access.log", - Format: "json", - Level: "info", - Rotation: struct { - MaxSize int `yaml:"max_size"` - MaxAge int `yaml:"max_age"` - MaxBackups int `yaml:"max_backups"` - Compress bool `yaml:"compress"` - LocalTime bool `yaml:"local_time"` - }{ - MaxSize: 100, - MaxAge: 7, - MaxBackups: 3, - Compress: true, - LocalTime: true, - }, - }, - } - - // 创建测试数据库客户端 - client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") - defer client.Close() - - // 测试服务器初始化 - s, err := New(cfg, client) - require.NoError(t, err) - assert.NotNil(t, s) - assert.NotNil(t, s.router) - assert.NotNil(t, s.handler) - assert.Equal(t, cfg, s.config) -} - -func TestNew_StorageError(t *testing.T) { - // 创建一个无效的存储配置 - cfg := &config.Config{ - Storage: config.StorageConfig{ - Type: "invalid_type", // 使用无效的存储类型 - }, - } - - client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") - defer client.Close() - - s, err := New(cfg, client) - assert.Error(t, err) - assert.Nil(t, s) - assert.Contains(t, err.Error(), "failed to initialize storage") -} - -func TestServer_StartAndShutdown(t *testing.T) { - // 创建测试配置 - cfg := &config.Config{ - Server: config.ServerConfig{ - Host: "localhost", - Port: 0, // 使用随机端口 - }, - Storage: config.StorageConfig{ - Type: "local", - Local: config.LocalStorage{ - RootDir: "testdata", - }, - }, - RateLimit: types.RateLimitConfig{ - IPRate: 100, - IPBurst: 200, - }, - AccessLog: types.AccessLogConfig{ - EnableConsole: true, - Format: "json", - Level: "info", - }, - } - - // 创建测试数据库客户端 - client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") - defer client.Close() - - // 初始化服务器 - s, err := New(cfg, client) - require.NoError(t, err) - - // 创建一个通道来接收服务器错误 - errChan := make(chan error, 1) - - // 在 goroutine 中启动服务器 - go func() { - err := s.Start() - if err != nil && err != http.ErrServerClosed { - errChan <- err - } - close(errChan) - }() - - // 给服务器一些时间启动 - time.Sleep(100 * time.Millisecond) - - // 测试关闭服务器 - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - err = s.Shutdown(ctx) - assert.NoError(t, err) - - // 检查服务器是否有错误发生 - err = <-errChan - assert.NoError(t, err) -} - -func TestServer_StartError(t *testing.T) { - // 创建一个配置,使用已经被占用的端口来触发错误 - cfg := &config.Config{ - Server: config.ServerConfig{ - Host: "localhost", - Port: 8899, // 使用固定端口以便测试 - }, - Storage: config.StorageConfig{ - Type: "local", - Local: config.LocalStorage{ - RootDir: "testdata", - }, - }, - } - - client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") - defer client.Close() - - // 创建第一个服务器实例 - s1, err := New(cfg, client) - require.NoError(t, err) - - // 创建一个通道来接收服务器错误 - errChan := make(chan error, 1) - - // 启动第一个服务器 - go func() { - err := s1.Start() - if err != nil && err != http.ErrServerClosed { - errChan <- err - } - close(errChan) - }() - - // 给服务器一些时间启动 - time.Sleep(100 * time.Millisecond) - - // 尝试在同一端口启动第二个服务器,应该会失败 - s2, err := New(cfg, client) - require.NoError(t, err) - - err = s2.Start() - assert.Error(t, err) - - // 清理 - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - // 关闭第一个服务器 - err = s1.Shutdown(ctx) - assert.NoError(t, err) - - // 检查第一个服务器是否有错误发生 - err = <-errChan - assert.NoError(t, err) - - // 关闭第二个服务器 - err = s2.Shutdown(ctx) - assert.NoError(t, err) -} - -func TestServer_ShutdownWithNilServer(t *testing.T) { - s := &Server{} - err := s.Shutdown(context.Background()) - assert.NoError(t, err) -} diff --git a/backend/internal/service/impl.go b/backend/internal/service/impl.go index 216e348..98203ee 100644 --- a/backend/internal/service/impl.go +++ b/backend/internal/service/impl.go @@ -1,11 +1,14 @@ package service import ( + "bytes" "context" "errors" "fmt" "io" "mime/multipart" + "os" + "path/filepath" "sort" "strconv" "strings" @@ -26,6 +29,8 @@ import ( "tss-rocks-be/ent/user" "tss-rocks-be/internal/storage" + "github.com/chai2010/webp" + "github.com/disintegration/imaging" "github.com/google/uuid" "golang.org/x/crypto/bcrypt" ) @@ -419,21 +424,71 @@ func (s *serviceImpl) Upload(ctx context.Context, file *multipart.FileHeader, us // Open the uploaded file src, err := openFile(file) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to open file: %v", err) } defer src.Close() + // 获取文件类型和扩展名 + contentType := file.Header.Get("Content-Type") + ext := strings.ToLower(filepath.Ext(file.Filename)) + if contentType == "" { + // 如果 Content-Type 为空,尝试从文件扩展名判断 + switch ext { + case ".jpg", ".jpeg": + contentType = "image/jpeg" + case ".png": + contentType = "image/png" + case ".gif": + contentType = "image/gif" + case ".webp": + contentType = "image/webp" + case ".mp4": + contentType = "video/mp4" + case ".webm": + contentType = "video/webm" + case ".mp3": + contentType = "audio/mpeg" + case ".ogg": + contentType = "audio/ogg" + case ".wav": + contentType = "audio/wav" + case ".pdf": + contentType = "application/pdf" + case ".doc": + contentType = "application/msword" + case ".docx": + contentType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + } + } + + // 如果是图片,检查是否需要转换为 WebP + var fileToSave multipart.File = src + var finalContentType = contentType + if strings.HasPrefix(contentType, "image/") && contentType != "image/webp" { + // 转换为 WebP + webpFile, err := convertToWebP(src) + if err != nil { + return nil, fmt.Errorf("failed to convert image to WebP: %v", err) + } + fileToSave = webpFile + finalContentType = "image/webp" + ext = ".webp" + } + + // 生成带扩展名的存储文件名 + storageFilename := uuid.New().String() + ext + // Save the file to storage - fileInfo, err := s.storage.Save(ctx, file.Filename, file.Header.Get("Content-Type"), src) + fileInfo, err := s.storage.Save(ctx, storageFilename, finalContentType, fileToSave) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to save file: %v", err) } // Create media record return s.client.Media.Create(). SetStorageID(fileInfo.ID). SetOriginalName(file.Filename). - SetMimeType(fileInfo.ContentType). + SetMimeType(finalContentType). SetSize(fileInfo.Size). SetURL(fileInfo.URL). SetCreatedBy(strconv.Itoa(userID)). @@ -444,13 +499,8 @@ func (s *serviceImpl) GetMedia(ctx context.Context, id int) (*ent.Media, error) return s.client.Media.Get(ctx, id) } -func (s *serviceImpl) GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error) { - media, err := s.GetMedia(ctx, id) - if err != nil { - return nil, nil, err - } - - return s.storage.Get(ctx, media.StorageID) +func (s *serviceImpl) GetFile(ctx context.Context, storageID string) (io.ReadCloser, *storage.FileInfo, error) { + return s.storage.Get(ctx, storageID) } func (s *serviceImpl) DeleteMedia(ctx context.Context, id int, userID int) error { @@ -476,7 +526,7 @@ func (s *serviceImpl) DeleteMedia(ctx context.Context, id int, userID int) error } // Post operations -func (s *serviceImpl) CreatePost(ctx context.Context, status string) (*ent.Post, error) { +func (s *serviceImpl) CreatePost(ctx context.Context, status string, categoryIDs []int) (*ent.Post, error) { var postStatus post.Status switch status { case "draft": @@ -492,10 +542,25 @@ func (s *serviceImpl) CreatePost(ctx context.Context, status string) (*ent.Post, // Generate a random slug slug := fmt.Sprintf("post-%s", uuid.New().String()[:8]) - return s.client.Post.Create(). + // Create post with categories + postCreate := s.client.Post.Create(). SetStatus(postStatus). - SetSlug(slug). - Save(ctx) + SetSlug(slug) + + // Add categories if provided + if len(categoryIDs) > 0 { + categories := make([]*ent.Category, 0, len(categoryIDs)) + for _, id := range categoryIDs { + category, err := s.client.Category.Get(ctx, id) + if err != nil { + return nil, fmt.Errorf("failed to get category %d: %w", id, err) + } + categories = append(categories, category) + } + postCreate.AddCategories(categories...) + } + + return postCreate.Save(ctx) } func (s *serviceImpl) AddPostContent(ctx context.Context, postID int, langCode, title, content, summary string, metaKeywords, metaDescription string) (*ent.PostContent, error) { @@ -574,7 +639,7 @@ func (s *serviceImpl) GetPostBySlug(ctx context.Context, langCode, slug string) WithContents(func(q *ent.PostContentQuery) { q.Where(postcontent.LanguageCodeEQ(languageCode)) }). - WithCategory(). + WithCategories(). All(ctx) if err != nil { return nil, fmt.Errorf("failed to get posts: %w", err) @@ -590,7 +655,7 @@ func (s *serviceImpl) GetPostBySlug(ctx context.Context, langCode, slug string) return posts[0], nil } -func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error) { +func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryIDs []int, limit, offset int) ([]*ent.Post, error) { var languageCode postcontent.LanguageCode switch langCode { case "en": @@ -610,8 +675,8 @@ func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID Where(post.StatusEQ(post.StatusPublished)) // Add category filter if provided - if categoryID != nil { - query = query.Where(post.HasCategoryWith(category.ID(*categoryID))) + if len(categoryIDs) > 0 { + query = query.Where(post.HasCategoriesWith(category.IDIn(categoryIDs...))) } // Get unique post IDs @@ -638,7 +703,7 @@ func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID } // If no category filter is applied, only take the latest 5 posts - if categoryID == nil && len(postIDs) > 5 { + if len(categoryIDs) == 0 && len(postIDs) > 5 { postIDs = postIDs[:5] } @@ -666,7 +731,7 @@ func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID WithContents(func(q *ent.PostContentQuery) { q.Where(postcontent.LanguageCodeEQ(languageCode)) }). - WithCategory(). + WithCategories(). Order(ent.Desc(post.FieldCreatedAt)). All(ctx) if err != nil { @@ -1050,3 +1115,47 @@ func (s *serviceImpl) DeleteDaily(ctx context.Context, id string, currentUserID return s.client.Daily.DeleteOneID(id).Exec(ctx) } + +// convertToWebP 将图片转换为 WebP 格式 +func convertToWebP(src multipart.File) (multipart.File, error) { + // 读取原始图片 + img, err := imaging.Decode(src) + if err != nil { + return nil, fmt.Errorf("failed to decode image: %v", err) + } + + // 创建一个新的缓冲区来存储 WebP 图片 + buf := new(bytes.Buffer) + + // 将图片编码为 WebP 格式 + // 设置较高的质量以保持图片质量 + err = webp.Encode(buf, img, &webp.Options{ + Lossless: false, + Quality: 90, + }) + if err != nil { + return nil, fmt.Errorf("failed to encode image to WebP: %v", err) + } + + // 创建一个新的临时文件来存储转换后的图片 + tmpFile, err := os.CreateTemp("", "webp-*.webp") + if err != nil { + return nil, fmt.Errorf("failed to create temp file: %v", err) + } + + // 写入转换后的数据 + if _, err := io.Copy(tmpFile, buf); err != nil { + tmpFile.Close() + os.Remove(tmpFile.Name()) + return nil, fmt.Errorf("failed to write WebP data: %v", err) + } + + // 将文件指针移回开始位置 + if _, err := tmpFile.Seek(0, 0); err != nil { + tmpFile.Close() + os.Remove(tmpFile.Name()) + return nil, fmt.Errorf("failed to seek file: %v", err) + } + + return tmpFile, nil +} diff --git a/backend/internal/service/impl_test.go b/backend/internal/service/impl_test.go deleted file mode 100644 index abba3cb..0000000 --- a/backend/internal/service/impl_test.go +++ /dev/null @@ -1,1090 +0,0 @@ -package service - -import ( - "bytes" - "context" - "fmt" - "io" - "mime/multipart" - "net/textproto" - "strconv" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" - "go.uber.org/mock/gomock" - - "tss-rocks-be/ent" - "tss-rocks-be/ent/categorycontent" - "tss-rocks-be/ent/dailycontent" - "tss-rocks-be/internal/storage" - "tss-rocks-be/internal/storage/mock" - "tss-rocks-be/internal/testutil" -) - -type ServiceImplTestSuite struct { - suite.Suite - ctx context.Context - client *ent.Client - storage *mock.MockStorage - ctrl *gomock.Controller - svc Service -} - -func (s *ServiceImplTestSuite) SetupTest() { - s.ctx = context.Background() - s.client = testutil.NewTestClient() - require.NotNil(s.T(), s.client) - - s.ctrl = gomock.NewController(s.T()) - s.storage = mock.NewMockStorage(s.ctrl) - s.svc = NewService(s.client, s.storage) - - // 清理数据库 - _, err := s.client.Category.Delete().Exec(s.ctx) - require.NoError(s.T(), err) - _, err = s.client.CategoryContent.Delete().Exec(s.ctx) - require.NoError(s.T(), err) - _, err = s.client.User.Delete().Exec(s.ctx) - require.NoError(s.T(), err) - _, err = s.client.Role.Delete().Exec(s.ctx) - require.NoError(s.T(), err) - _, err = s.client.Permission.Delete().Exec(s.ctx) - require.NoError(s.T(), err) - _, err = s.client.Daily.Delete().Exec(s.ctx) - require.NoError(s.T(), err) - _, err = s.client.DailyContent.Delete().Exec(s.ctx) - require.NoError(s.T(), err) - - // 初始化 RBAC 系统 - err = s.svc.InitializeRBAC(s.ctx) - require.NoError(s.T(), err) - - // Set default openFile function - openFile = func(fh *multipart.FileHeader) (multipart.File, error) { - return fh.Open() - } -} - -func (s *ServiceImplTestSuite) TearDownTest() { - s.ctrl.Finish() - s.client.Close() -} - -func TestServiceImplSuite(t *testing.T) { - suite.Run(t, new(ServiceImplTestSuite)) -} - -// mockMultipartFile implements multipart.File interface -type mockMultipartFile struct { - *bytes.Reader -} - -func (m *mockMultipartFile) Close() error { - return nil -} - -func (m *mockMultipartFile) ReadAt(p []byte, off int64) (n int, err error) { - return m.Reader.ReadAt(p, off) -} - -func (m *mockMultipartFile) Seek(offset int64, whence int) (int64, error) { - return m.Reader.Seek(offset, whence) -} - -func newMockMultipartFile(data []byte) *mockMultipartFile { - return &mockMultipartFile{ - Reader: bytes.NewReader(data), - } -} - -func (s *ServiceImplTestSuite) TestCreateUser() { - testCases := []struct { - name string - username string - email string - password string - role string - wantErr bool - }{ - { - name: "有效的用户", - username: "testuser", - email: "test@example.com", - password: "password123", - role: "user", - wantErr: false, - }, - { - name: "无效的邮箱", - username: "testuser2", - email: "invalid-email", - password: "password123", - role: "user", - wantErr: true, - }, - { - name: "空密码", - username: "testuser3", - email: "test3@example.com", - password: "", - role: "user", - wantErr: true, - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - user, err := s.svc.CreateUser(s.ctx, tc.username, tc.email, tc.password, tc.role) - if tc.wantErr { - s.Error(err) - s.Nil(user) - } else { - s.NoError(err) - s.NotNil(user) - s.Equal(tc.email, user.Email) - s.Equal(tc.username, user.Username) - } - }) - } -} - -func (s *ServiceImplTestSuite) TestGetUserByEmail() { - // Create a test user first - email := "test@example.com" - password := "password123" - role := "user" - - user, err := s.svc.CreateUser(s.ctx, "testuser", email, password, role) - require.NoError(s.T(), err) - require.NotNil(s.T(), user) - - s.Run("Existing user", func() { - found, err := s.svc.GetUserByEmail(s.ctx, email) - assert.NoError(s.T(), err) - assert.NotNil(s.T(), found) - assert.Equal(s.T(), email, found.Email) - }) - - s.Run("Non-existing user", func() { - found, err := s.svc.GetUserByEmail(s.ctx, "nonexistent@example.com") - assert.Error(s.T(), err) - assert.Nil(s.T(), found) - }) -} - -func (s *ServiceImplTestSuite) TestValidatePassword() { - // Create a test user first - email := "test@example.com" - password := "password123" - role := "user" - - user, err := s.svc.CreateUser(s.ctx, "testuser", email, password, role) - require.NoError(s.T(), err) - require.NotNil(s.T(), user) - - s.Run("Valid password", func() { - valid := s.svc.ValidatePassword(s.ctx, user, password) - assert.True(s.T(), valid) - }) - - s.Run("Invalid password", func() { - valid := s.svc.ValidatePassword(s.ctx, user, "wrongpassword") - assert.False(s.T(), valid) - }) -} - -func (s *ServiceImplTestSuite) TestRBAC() { - s.Run("AssignRole", func() { - user, err := s.svc.CreateUser(s.ctx, "testuser", "test@example.com", "password", "admin") - require.NoError(s.T(), err) - - err = s.svc.AssignRole(s.ctx, user.ID, "user") - assert.NoError(s.T(), err) - }) - - s.Run("RemoveRole", func() { - user, err := s.svc.CreateUser(s.ctx, "testuser2", "test2@example.com", "password", "admin") - require.NoError(s.T(), err) - - err = s.svc.RemoveRole(s.ctx, user.ID, "admin") - assert.NoError(s.T(), err) - }) - - s.Run("HasPermission", func() { - s.Run("Admin can create users", func() { - user, err := s.svc.CreateUser(s.ctx, "testuser3", "admin@example.com", "password", "admin") - require.NoError(s.T(), err) - - hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create") - require.NoError(s.T(), err) - assert.True(s.T(), hasPermission) - }) - - s.Run("Editor cannot create users", func() { - user, err := s.svc.CreateUser(s.ctx, "testuser4", "editor@example.com", "password", "editor") - require.NoError(s.T(), err) - - hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create") - require.NoError(s.T(), err) - assert.False(s.T(), hasPermission) - }) - - s.Run("User cannot create users", func() { - user, err := s.svc.CreateUser(s.ctx, "testuser5", "user@example.com", "password", "user") - require.NoError(s.T(), err) - - hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create") - require.NoError(s.T(), err) - assert.False(s.T(), hasPermission) - }) - - s.Run("Editor can create posts", func() { - user, err := s.svc.CreateUser(s.ctx, "testuser6", "editor2@example.com", "password", "editor") - require.NoError(s.T(), err) - - hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create") - require.NoError(s.T(), err) - assert.True(s.T(), hasPermission) - }) - - s.Run("User can read posts", func() { - user, err := s.svc.CreateUser(s.ctx, "testuser7", "user2@example.com", "password", "user") - require.NoError(s.T(), err) - - hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:read") - require.NoError(s.T(), err) - assert.True(s.T(), hasPermission) - }) - - s.Run("User cannot create posts", func() { - user, err := s.svc.CreateUser(s.ctx, "testuser8", "user3@example.com", "password", "user") - require.NoError(s.T(), err) - - hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create") - require.NoError(s.T(), err) - assert.False(s.T(), hasPermission) - }) - - s.Run("Invalid permission format", func() { - user, err := s.svc.CreateUser(s.ctx, "testuser9", "user4@example.com", "password", "user") - require.NoError(s.T(), err) - - _, err = s.svc.HasPermission(s.ctx, user.ID, "invalid_permission") - require.Error(s.T(), err) - assert.Contains(s.T(), err.Error(), "invalid permission format") - }) - }) -} - -func (s *ServiceImplTestSuite) TestCategory() { - // Create a test user with admin role for testing - adminUser, err := s.svc.CreateUser(s.ctx, "testuser10", "admin@example.com", "password123", "admin") - require.NoError(s.T(), err) - require.NotNil(s.T(), adminUser) - - s.Run("CreateCategory", func() { - // Test category creation - category, err := s.svc.CreateCategory(s.ctx) - assert.NoError(s.T(), err) - assert.NotNil(s.T(), category) - assert.NotZero(s.T(), category.ID) - }) - - s.Run("AddCategoryContent", func() { - // Create a category first - category, err := s.svc.CreateCategory(s.ctx) - require.NoError(s.T(), err) - require.NotNil(s.T(), category) - - testCases := []struct { - name string - langCode string - catName string - desc string - slug string - wantError bool - }{ - { - name: "Valid category content", - langCode: "en", - catName: "Test Category", - desc: "Test Description", - slug: "test-category", - wantError: false, - }, - { - name: "Empty language code", - langCode: "", - catName: "Test Category", - desc: "Test Description", - slug: "test-category-2", - wantError: true, - }, - { - name: "Empty name", - langCode: "en", - catName: "", - desc: "Test Description", - slug: "test-category-3", - wantError: true, - }, - { - name: "Empty slug", - langCode: "en", - catName: "Test Category", - desc: "Test Description", - slug: "", - wantError: true, - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - content, err := s.svc.AddCategoryContent(s.ctx, category.ID, tc.langCode, tc.catName, tc.desc, tc.slug) - if tc.wantError { - assert.Error(s.T(), err) - assert.Nil(s.T(), content) - } else { - assert.NoError(s.T(), err) - assert.NotNil(s.T(), content) - assert.Equal(s.T(), categorycontent.LanguageCode(tc.langCode), content.LanguageCode) - assert.Equal(s.T(), tc.catName, content.Name) - assert.Equal(s.T(), tc.desc, content.Description) - assert.Equal(s.T(), tc.slug, content.Slug) - } - }) - } - }) - - s.Run("GetCategoryBySlug", func() { - // Create a category with content first - category, err := s.svc.CreateCategory(s.ctx) - require.NoError(s.T(), err) - require.NotNil(s.T(), category) - - content, err := s.svc.AddCategoryContent(s.ctx, category.ID, "en", "Test Category", "Test Description", "test-category-get") - require.NoError(s.T(), err) - require.NotNil(s.T(), content) - - s.Run("Existing category", func() { - found, err := s.svc.GetCategoryBySlug(s.ctx, "en", "test-category-get") - assert.NoError(s.T(), err) - assert.NotNil(s.T(), found) - assert.Equal(s.T(), category.ID, found.ID) - - // Check if content is loaded - require.NotEmpty(s.T(), found.Edges.Contents) - assert.Equal(s.T(), "Test Category", found.Edges.Contents[0].Name) - }) - - s.Run("Non-existing category", func() { - found, err := s.svc.GetCategoryBySlug(s.ctx, "en", "non-existent") - assert.Error(s.T(), err) - assert.Nil(s.T(), found) - }) - - s.Run("Wrong language code", func() { - found, err := s.svc.GetCategoryBySlug(s.ctx, "fr", "test-category-get") - assert.Error(s.T(), err) - assert.Nil(s.T(), found) - }) - }) - - s.Run("ListCategories", func() { - s.Run("List English categories", func() { - // 创建多个分类,但只有 3 个有英文内容 - var createdCategories []*ent.Category - for i := 0; i < 5; i++ { - category, err := s.svc.CreateCategory(s.ctx) - require.NoError(s.T(), err) - require.NotNil(s.T(), category) - createdCategories = append(createdCategories, category) - - // 只给前 3 个分类添加英文内容 - if i < 3 { - _, err = s.svc.AddCategoryContent(s.ctx, category.ID, "en", - fmt.Sprintf("Category %d", i), - fmt.Sprintf("Description %d", i), - fmt.Sprintf("category-list-%d", i)) - require.NoError(s.T(), err) - } - } - - categories, err := s.svc.ListCategories(s.ctx, "en") - assert.NoError(s.T(), err) - assert.NotNil(s.T(), categories) - assert.Len(s.T(), categories, 3) - - // 检查所有返回的分类都有英文内容 - for _, cat := range categories { - assert.NotEmpty(s.T(), cat.Edges.Contents) - for _, content := range cat.Edges.Contents { - assert.Equal(s.T(), categorycontent.LanguageCodeEN, content.LanguageCode) - } - } - }) - - s.Run("List Chinese categories", func() { - // 创建多个分类,但只有 2 个有中文内容 - for i := 0; i < 4; i++ { - category, err := s.svc.CreateCategory(s.ctx) - require.NoError(s.T(), err) - require.NotNil(s.T(), category) - - // 只给前 2 个分类添加中文内容 - if i < 2 { - _, err = s.svc.AddCategoryContent(s.ctx, category.ID, "zh-Hans", - fmt.Sprintf("分类 %d", i), - fmt.Sprintf("描述 %d", i), - fmt.Sprintf("category-list-%d", i)) - require.NoError(s.T(), err) - } - } - - categories, err := s.svc.ListCategories(s.ctx, "zh-Hans") - assert.NoError(s.T(), err) - assert.NotNil(s.T(), categories) - assert.Len(s.T(), categories, 2) - - // 检查所有返回的分类都有中文内容 - for _, cat := range categories { - assert.NotEmpty(s.T(), cat.Edges.Contents) - for _, content := range cat.Edges.Contents { - assert.Equal(s.T(), categorycontent.LanguageCodeZH_HANS, content.LanguageCode) - } - } - }) - - s.Run("List non-existing language", func() { - categories, err := s.svc.ListCategories(s.ctx, "fr") - assert.NoError(s.T(), err) - assert.Empty(s.T(), categories) - }) - }) -} - -func (s *ServiceImplTestSuite) TestGetCategories() { - ctx := context.Background() - - // 测试不支持的语言代码 - categories, err := s.svc.GetCategories(ctx, "invalid") - s.Require().NoError(err) - s.Empty(categories) - - // 创建测试数据 - cat1 := s.createTestCategory(ctx, "test-cat-1") - cat2 := s.createTestCategory(ctx, "test-cat-2") - - // 为分类添加不同语言的内容 - _, err = s.svc.AddCategoryContent(ctx, cat1.ID, "en", "Test Category 1", "Test Description 1", "category-list-test-1") - s.Require().NoError(err) - - _, err = s.svc.AddCategoryContent(ctx, cat2.ID, "zh-Hans", "测试分类2", "测试描述2", "category-list-test-2") - s.Require().NoError(err) - - // 测试获取英文分类 - enCategories, err := s.svc.GetCategories(ctx, "en") - s.Require().NoError(err) - s.Len(enCategories, 1) - s.Equal(cat1.ID, enCategories[0].ID) - - // 测试获取简体中文分类 - zhCategories, err := s.svc.GetCategories(ctx, "zh-Hans") - s.Require().NoError(err) - s.Len(zhCategories, 1) - s.Equal(cat2.ID, zhCategories[0].ID) - - // 测试获取繁体中文分类(应该为空) - zhHantCategories, err := s.svc.GetCategories(ctx, "zh-Hant") - s.Require().NoError(err) - s.Empty(zhHantCategories) -} - -func (s *ServiceImplTestSuite) TestGetUserRoles() { - ctx := context.Background() - - // 创建测试用户,默认会有 "user" 角色 - user, err := s.svc.CreateUser(ctx, "testuser", "test@example.com", "password123", "user") - s.Require().NoError(err) - - // 测试新用户有默认的 "user" 角色 - roles, err := s.svc.GetUserRoles(ctx, user.ID) - s.Require().NoError(err) - s.Len(roles, 1) - s.Equal("user", roles[0].Name) - - // 分配角色给用户 - err = s.svc.AssignRole(ctx, user.ID, "admin") - s.Require().NoError(err) - - // 测试用户现在有两个角色 - roles, err = s.svc.GetUserRoles(ctx, user.ID) - s.Require().NoError(err) - s.Len(roles, 2) - roleNames := []string{roles[0].Name, roles[1].Name} - s.Contains(roleNames, "user") - s.Contains(roleNames, "admin") - - // 测试不存在的用户 - _, err = s.svc.GetUserRoles(ctx, -1) - s.Require().Error(err) -} - -func (s *ServiceImplTestSuite) TestDaily() { - // 创建一个测试分类 - category, err := s.svc.CreateCategory(s.ctx) - require.NoError(s.T(), err) - require.NotNil(s.T(), category) - - // 添加分类内容 - categoryContent, err := s.svc.AddCategoryContent(s.ctx, category.ID, "en", "Test Category", "Test Description", "test-category") - require.NoError(s.T(), err) - require.NotNil(s.T(), categoryContent) - - dailyID := "250212" // 使用符合验证规则的 ID 格式:YYMMDD - - // 测试创建 Daily - s.Run("Create Daily", func() { - daily, err := s.svc.CreateDaily(s.ctx, dailyID, category.ID, "http://example.com/image.jpg") - require.NoError(s.T(), err) - require.NotNil(s.T(), daily) - assert.Equal(s.T(), dailyID, daily.ID) - assert.Equal(s.T(), category.ID, daily.Edges.Category.ID) - assert.Equal(s.T(), "http://example.com/image.jpg", daily.ImageURL) - }) - - // 测试添加 Daily 内容 - s.Run("Add Daily Content", func() { - content, err := s.svc.AddDailyContent(s.ctx, dailyID, "en", "Test quote for the day") - require.NoError(s.T(), err) - require.NotNil(s.T(), content) - assert.Equal(s.T(), dailycontent.LanguageCodeEN, content.LanguageCode) - assert.Equal(s.T(), "Test quote for the day", content.Quote) - }) - - // 测试获取 Daily - s.Run("Get Daily By ID", func() { - daily, err := s.svc.GetDailyByID(s.ctx, dailyID) - require.NoError(s.T(), err) - require.NotNil(s.T(), daily) - assert.Equal(s.T(), dailyID, daily.ID) - assert.Equal(s.T(), category.ID, daily.Edges.Category.ID) - }) - - // 测试列出 Daily - s.Run("List Dailies", func() { - // 创建另一个 Daily 用于测试列表 - anotherDailyID := "250213" - _, err := s.svc.CreateDaily(s.ctx, anotherDailyID, category.ID, "http://example.com/image2.jpg") - assert.NoError(s.T(), err) - _, err = s.svc.AddDailyContent(s.ctx, anotherDailyID, "en", "Another test quote") - assert.NoError(s.T(), err) - - // 测试列表功能 - dailies, err := s.svc.ListDailies(s.ctx, "en", &category.ID, 10, 0) - assert.NoError(s.T(), err) - assert.NotNil(s.T(), dailies) - assert.Len(s.T(), dailies, 2) - - // 测试分页 - dailies, err = s.svc.ListDailies(s.ctx, "en", &category.ID, 1, 0) - assert.NoError(s.T(), err) - assert.NotNil(s.T(), dailies) - assert.Len(s.T(), dailies, 1) - - // 测试无分类过滤 - dailies, err = s.svc.ListDailies(s.ctx, "en", nil, 10, 0) - assert.NoError(s.T(), err) - assert.NotNil(s.T(), dailies) - assert.Len(s.T(), dailies, 2) - }) -} - -func (s *ServiceImplTestSuite) TestPost() { - s.Run("Create Post", func() { - s.Run("Draft", func() { - post, err := s.svc.CreatePost(s.ctx, "draft") - require.NoError(s.T(), err) - require.NotNil(s.T(), post) - assert.Equal(s.T(), "draft", post.Status.String()) - }) - - s.Run("Published", func() { - post, err := s.svc.CreatePost(s.ctx, "published") - require.NoError(s.T(), err) - require.NotNil(s.T(), post) - assert.Equal(s.T(), "published", post.Status.String()) - }) - - s.Run("Archived", func() { - post, err := s.svc.CreatePost(s.ctx, "archived") - require.NoError(s.T(), err) - require.NotNil(s.T(), post) - assert.Equal(s.T(), "archived", post.Status.String()) - }) - - s.Run("Invalid Status", func() { - post, err := s.svc.CreatePost(s.ctx, "invalid") - assert.Error(s.T(), err) - assert.Nil(s.T(), post) - }) - }) - - s.Run("Add Post Content", func() { - // Create a post first - post, err := s.svc.CreatePost(s.ctx, "draft") - require.NoError(s.T(), err) - - s.Run("English Content", func() { - content, err := s.svc.AddPostContent(s.ctx, post.ID, "en", "Test Post", "# Test Content", "Test Summary", "test,post", "Test Description") - require.NoError(s.T(), err) - require.NotNil(s.T(), content) - assert.Equal(s.T(), "en", content.LanguageCode.String()) - assert.Equal(s.T(), "Test Post", content.Title) - assert.Equal(s.T(), "# Test Content", content.ContentMarkdown) - assert.Equal(s.T(), "Test Summary", content.Summary) - assert.Equal(s.T(), "test,post", content.MetaKeywords) - assert.Equal(s.T(), "Test Description", content.MetaDescription) - assert.Equal(s.T(), "test-post", content.Slug) - }) - - s.Run("Simplified Chinese Content", func() { - content, err := s.svc.AddPostContent(s.ctx, post.ID, "zh-Hans", "测试帖子", "# 测试内容", "测试摘要", "测试,帖子", "测试描述") - require.NoError(s.T(), err) - require.NotNil(s.T(), content) - assert.Equal(s.T(), "zh-Hans", content.LanguageCode.String()) - assert.Equal(s.T(), "测试帖子", content.Title) - assert.Equal(s.T(), "# 测试内容", content.ContentMarkdown) - assert.Equal(s.T(), "测试摘要", content.Summary) - assert.Equal(s.T(), "测试,帖子", content.MetaKeywords) - assert.Equal(s.T(), "测试描述", content.MetaDescription) - assert.Equal(s.T(), "测试帖子", content.Slug) - }) - - s.Run("Traditional Chinese Content", func() { - content, err := s.svc.AddPostContent(s.ctx, post.ID, "zh-Hant", "測試貼文", "# 測試內容", "測試摘要", "測試,貼文", "測試描述") - require.NoError(s.T(), err) - require.NotNil(s.T(), content) - assert.Equal(s.T(), "zh-Hant", content.LanguageCode.String()) - assert.Equal(s.T(), "測試貼文", content.Title) - assert.Equal(s.T(), "# 測試內容", content.ContentMarkdown) - assert.Equal(s.T(), "測試摘要", content.Summary) - assert.Equal(s.T(), "測試,貼文", content.MetaKeywords) - assert.Equal(s.T(), "測試描述", content.MetaDescription) - assert.Equal(s.T(), "測試貼文", content.Slug) - }) - - s.Run("Invalid Language Code", func() { - content, err := s.svc.AddPostContent(s.ctx, post.ID, "fr", "Test Post", "# Test Content", "Test Summary", "test,post", "Test Description") - assert.Error(s.T(), err) - assert.Nil(s.T(), content) - }) - - s.Run("Non-existent Post", func() { - content, err := s.svc.AddPostContent(s.ctx, 999999, "en", "Test Post", "# Test Content", "Test Summary", "test,post", "Test Description") - assert.Error(s.T(), err) - assert.Nil(s.T(), content) - }) - }) - - s.Run("Get Post By Slug", func() { - // Create a post first - post, err := s.svc.CreatePost(s.ctx, "published") - require.NoError(s.T(), err) - - // Add content in different languages - _, err = s.svc.AddPostContent(s.ctx, post.ID, "en", "Test Post", "# Test Content", "Test Summary", "test,post", "Test Description") - require.NoError(s.T(), err) - _, err = s.svc.AddPostContent(s.ctx, post.ID, "zh-Hans", "测试帖子", "# 测试内容", "测试摘要", "测试,帖子", "测试描述") - require.NoError(s.T(), err) - - s.Run("Get Post By Slug - English", func() { - result, err := s.svc.GetPostBySlug(s.ctx, "en", "test-post") - require.NoError(s.T(), err) - require.NotNil(s.T(), result) - assert.Equal(s.T(), post.ID, result.ID) - assert.Equal(s.T(), "published", result.Status.String()) - - contents := result.Edges.Contents - require.Len(s.T(), contents, 1) - assert.Equal(s.T(), "en", contents[0].LanguageCode.String()) - assert.Equal(s.T(), "Test Post", contents[0].Title) - }) - - s.Run("Get Post By Slug - Chinese", func() { - result, err := s.svc.GetPostBySlug(s.ctx, "zh-Hans", "测试帖子") - require.NoError(s.T(), err) - require.NotNil(s.T(), result) - assert.Equal(s.T(), post.ID, result.ID) - assert.Equal(s.T(), "published", result.Status.String()) - - contents := result.Edges.Contents - require.Len(s.T(), contents, 1) - assert.Equal(s.T(), "zh-Hans", contents[0].LanguageCode.String()) - assert.Equal(s.T(), "测试帖子", contents[0].Title) - }) - - s.Run("Non-existent Post", func() { - result, err := s.svc.GetPostBySlug(s.ctx, "en", "non-existent") - assert.Error(s.T(), err) - assert.Nil(s.T(), result) - }) - - s.Run("Invalid Language Code", func() { - result, err := s.svc.GetPostBySlug(s.ctx, "fr", "test-post") - assert.Error(s.T(), err) - assert.Nil(s.T(), result) - }) - }) - - s.Run("List Posts", func() { - // Create some posts with content - for i := 0; i < 5; i++ { - post, err := s.svc.CreatePost(s.ctx, "published") - require.NoError(s.T(), err) - - // Add content in different languages - _, err = s.svc.AddPostContent(s.ctx, post.ID, "en", fmt.Sprintf("Post %d", i), "# Content", "Summary", "test", "Description") - require.NoError(s.T(), err) - _, err = s.svc.AddPostContent(s.ctx, post.ID, "zh-Hans", fmt.Sprintf("帖子 %d", i), "# 内容", "摘要", "测试", "描述") - require.NoError(s.T(), err) - } - - s.Run("List All Posts - English", func() { - posts, err := s.svc.ListPosts(s.ctx, "en", nil, 10, 0) - require.NoError(s.T(), err) - require.Len(s.T(), posts, 5) - - // Check that all posts have English content - for _, post := range posts { - contents := post.Edges.Contents - require.Len(s.T(), contents, 1) - assert.Equal(s.T(), "en", contents[0].LanguageCode.String()) - } - }) - - s.Run("List All Posts - Chinese", func() { - posts, err := s.svc.ListPosts(s.ctx, "zh-Hans", nil, 10, 0) - require.NoError(s.T(), err) - require.Len(s.T(), posts, 5) - - // Check that all posts have Chinese content - for _, post := range posts { - contents := post.Edges.Contents - require.Len(s.T(), contents, 1) - assert.Equal(s.T(), "zh-Hans", contents[0].LanguageCode.String()) - } - }) - - s.Run("List Posts with Pagination", func() { - // Get first page - posts, err := s.svc.ListPosts(s.ctx, "en", nil, 2, 0) - require.NoError(s.T(), err) - require.Len(s.T(), posts, 2) - - // Get second page - posts, err = s.svc.ListPosts(s.ctx, "en", nil, 2, 2) - require.NoError(s.T(), err) - require.Len(s.T(), posts, 2) - - // Get last page - posts, err = s.svc.ListPosts(s.ctx, "en", nil, 2, 4) - require.NoError(s.T(), err) - require.Len(s.T(), posts, 1) - }) - - s.Run("List Posts by Category", func() { - // Create a category - category, err := s.svc.CreateCategory(s.ctx) - require.NoError(s.T(), err) - - // Create posts in this category - for i := 0; i < 3; i++ { - post, err := s.svc.CreatePost(s.ctx, "published") - require.NoError(s.T(), err) - - // Set category - _, err = s.client.Post.UpdateOne(post).SetCategoryID(category.ID).Save(s.ctx) - require.NoError(s.T(), err) - - // Add content - _, err = s.svc.AddPostContent(s.ctx, post.ID, "en", fmt.Sprintf("Category Post %d", i), "# Content", "Summary", "test", "Description") - require.NoError(s.T(), err) - } - - // List posts in this category - posts, err := s.svc.ListPosts(s.ctx, "en", &category.ID, 10, 0) - require.NoError(s.T(), err) - require.Len(s.T(), posts, 3) - - // Check that all posts belong to the category - for _, post := range posts { - assert.Equal(s.T(), category.ID, post.Edges.Category.ID) - } - }) - - s.Run("Invalid Language Code", func() { - posts, err := s.svc.ListPosts(s.ctx, "fr", nil, 10, 0) - assert.Error(s.T(), err) - assert.Nil(s.T(), posts) - }) - }) -} - -func (s *ServiceImplTestSuite) TestMedia() { - s.Run("Upload Media", func() { - // Create a user first - user, err := s.svc.CreateUser(s.ctx, "testuser", "test@example.com", "password123", "user") - require.NoError(s.T(), err) - require.NotNil(s.T(), user) - - // Mock file content - fileContent := []byte("test file content") - - // Mock the file header - fileHeader := &multipart.FileHeader{ - Filename: "test.jpg", - Size: int64(len(fileContent)), - Header: textproto.MIMEHeader{ - "Content-Type": []string{"image/jpeg"}, - }, - } - - // Mock the storage behavior - s.storage.EXPECT(). - Save(gomock.Any(), fileHeader.Filename, "image/jpeg", gomock.Any()). - DoAndReturn(func(ctx context.Context, name, contentType string, reader io.Reader) (*storage.FileInfo, error) { - // Verify the reader content - data, err := io.ReadAll(reader) - if err != nil { - return nil, err - } - if !bytes.Equal(data, fileContent) { - return nil, fmt.Errorf("unexpected file content") - } - return &storage.FileInfo{ - ID: "test123", - Name: name, - Size: int64(len(fileContent)), - ContentType: contentType, - URL: "http://example.com/test.jpg", - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - }, nil - }).Times(1) - - // Replace the Open method - openFile = func(fh *multipart.FileHeader) (multipart.File, error) { - return &mockMultipartFile{bytes.NewReader(fileContent)}, nil - } - - // Test upload - media, err := s.svc.Upload(s.ctx, fileHeader, user.ID) - require.NoError(s.T(), err) - require.NotNil(s.T(), media) - assert.Equal(s.T(), "test123", media.StorageID) - assert.Equal(s.T(), "test.jpg", media.OriginalName) - assert.Equal(s.T(), int64(len(fileContent)), media.Size) - assert.Equal(s.T(), "image/jpeg", media.MimeType) - assert.Equal(s.T(), "http://example.com/test.jpg", media.URL) - assert.Equal(s.T(), strconv.Itoa(user.ID), media.CreatedBy) - - // Now we can test other operations since we have a media record - s.Run("Get Media", func() { - result, err := s.svc.GetMedia(s.ctx, media.ID) - require.NoError(s.T(), err) - require.NotNil(s.T(), result) - assert.Equal(s.T(), media.ID, result.ID) - assert.Equal(s.T(), media.StorageID, result.StorageID) - assert.Equal(s.T(), media.URL, result.URL) - }) - - s.Run("Get File", func() { - // Mock the storage behavior - mockReader := io.NopCloser(strings.NewReader("test content")) - mockFileInfo := &storage.FileInfo{ - ID: media.StorageID, - Name: media.OriginalName, - Size: media.Size, - ContentType: media.MimeType, - URL: media.URL, - CreatedAt: media.CreatedAt, - UpdatedAt: media.UpdatedAt, - } - s.storage.EXPECT(). - Get(gomock.Any(), media.StorageID). - Return(mockReader, mockFileInfo, nil) - - // Test get file - reader, fileInfo, err := s.svc.GetFile(s.ctx, media.ID) - require.NoError(s.T(), err) - require.NotNil(s.T(), reader) - require.NotNil(s.T(), fileInfo) - assert.Equal(s.T(), media.OriginalName, fileInfo.Name) - assert.Equal(s.T(), media.Size, fileInfo.Size) - assert.Equal(s.T(), media.MimeType, fileInfo.ContentType) - assert.Equal(s.T(), media.URL, fileInfo.URL) - - // Clean up - reader.Close() - }) - - s.Run("List Media", func() { - // Test list media - list, err := s.svc.ListMedia(s.ctx, 10, 0) - require.NoError(s.T(), err) - require.NotNil(s.T(), list) - require.Len(s.T(), list, 1) - assert.Equal(s.T(), "test.jpg", list[0].OriginalName) - }) - - s.Run("Delete Media", func() { - // Mock the storage behavior - s.storage.EXPECT(). - Delete(gomock.Any(), media.StorageID). - Return(nil) - - // Test delete media - err = s.svc.DeleteMedia(s.ctx, media.ID, user.ID) - require.NoError(s.T(), err) - - // Verify media is deleted - count, err := s.client.Media.Query().Count(s.ctx) - require.NoError(s.T(), err) - assert.Equal(s.T(), 0, count) - }) - }) - - s.Run("Delete Media - Unauthorized", func() { - // Create a user - user, err := s.svc.CreateUser(s.ctx, "anotheruser", "another@example.com", "password123", "user") - require.NoError(s.T(), err) - - // Mock file content - fileContent := []byte("test file content") - - // Mock the file header - fileHeader := &multipart.FileHeader{ - Filename: "test2.jpg", - Size: int64(len(fileContent)), - Header: textproto.MIMEHeader{ - "Content-Type": []string{"image/jpeg"}, - }, - } - - // Mock the storage behavior - s.storage.EXPECT(). - Save(gomock.Any(), fileHeader.Filename, "image/jpeg", gomock.Any()). - DoAndReturn(func(ctx context.Context, name, contentType string, reader io.Reader) (*storage.FileInfo, error) { - // Verify the reader content - data, err := io.ReadAll(reader) - if err != nil { - return nil, err - } - if !bytes.Equal(data, fileContent) { - return nil, fmt.Errorf("unexpected file content") - } - return &storage.FileInfo{ - ID: "test456", - Name: name, - Size: int64(len(fileContent)), - ContentType: contentType, - URL: "http://example.com/test2.jpg", - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - }, nil - }).Times(1) - - // Replace the Open method - openFile = func(fh *multipart.FileHeader) (multipart.File, error) { - return &mockMultipartFile{bytes.NewReader(fileContent)}, nil - } - - media, err := s.svc.Upload(s.ctx, fileHeader, user.ID) - require.NoError(s.T(), err) - - // Try to delete with different user - anotherUser, err := s.svc.CreateUser(s.ctx, "thirduser", "third@example.com", "password123", "user") - require.NoError(s.T(), err) - - err = s.svc.DeleteMedia(s.ctx, media.ID, anotherUser.ID) - assert.Equal(s.T(), ErrUnauthorized, err) - - // Verify media is not deleted - count, err := s.client.Media.Query().Count(s.ctx) - require.NoError(s.T(), err) - assert.Equal(s.T(), 1, count) - }) -} - -func (s *ServiceImplTestSuite) TestContributor() { - // 测试创建贡献者 - avatarURL := "https://example.com/avatar.jpg" - bio := "Test bio" - contributor, err := s.svc.CreateContributor(s.ctx, "Test Contributor", &avatarURL, &bio) - require.NoError(s.T(), err) - require.NotNil(s.T(), contributor) - assert.Equal(s.T(), "Test Contributor", contributor.Name) - assert.Equal(s.T(), avatarURL, contributor.AvatarURL) - assert.Equal(s.T(), bio, contributor.Bio) - - // 测试添加社交链接 - link, err := s.svc.AddContributorSocialLink(s.ctx, contributor.ID, "github", "GitHub", "https://github.com/test") - require.NoError(s.T(), err) - require.NotNil(s.T(), link) - assert.Equal(s.T(), "github", link.Type.String()) - assert.Equal(s.T(), "GitHub", link.Name) - assert.Equal(s.T(), "https://github.com/test", link.Value) - - // 测试获取贡献者 - fetchedContributor, err := s.svc.GetContributorByID(s.ctx, contributor.ID) - require.NoError(s.T(), err) - require.NotNil(s.T(), fetchedContributor) - assert.Equal(s.T(), contributor.ID, fetchedContributor.ID) - assert.Equal(s.T(), contributor.Name, fetchedContributor.Name) - assert.Equal(s.T(), contributor.AvatarURL, fetchedContributor.AvatarURL) - assert.Equal(s.T(), contributor.Bio, fetchedContributor.Bio) - require.Len(s.T(), fetchedContributor.Edges.SocialLinks, 1) - assert.Equal(s.T(), link.ID, fetchedContributor.Edges.SocialLinks[0].ID) - - // 测试列出贡献者 - contributors, err := s.svc.ListContributors(s.ctx) - require.NoError(s.T(), err) - require.NotEmpty(s.T(), contributors) - assert.Equal(s.T(), contributor.ID, contributors[0].ID) - require.Len(s.T(), contributors[0].Edges.SocialLinks, 1) - - // 测试错误情况 - _, err = s.svc.GetContributorByID(s.ctx, -1) - assert.Error(s.T(), err) - - _, err = s.svc.AddContributorSocialLink(s.ctx, -1, "github", "GitHub", "https://github.com/test") - assert.Error(s.T(), err) - - // 测试无效的社交链接类型 - _, err = s.svc.AddContributorSocialLink(s.ctx, contributor.ID, "invalid_type", "Invalid", "https://example.com") - assert.Error(s.T(), err) -} - -func TestServiceSuite(t *testing.T) { - suite.Run(t, new(ServiceSuite)) -} - -type ServiceSuite struct { - suite.Suite -} - -func TestServiceInterface(t *testing.T) { - var _ Service = (*serviceImpl)(nil) -} - -// 创建测试分类的辅助函数 -func (s *ServiceImplTestSuite) createTestCategory(ctx context.Context, slug string) *ent.Category { - category, err := s.svc.CreateCategory(ctx) - s.Require().NoError(err) - return category -} diff --git a/backend/internal/service/media_test.go b/backend/internal/service/media_test.go deleted file mode 100644 index 412689d..0000000 --- a/backend/internal/service/media_test.go +++ /dev/null @@ -1,332 +0,0 @@ -package service - -import ( - "bytes" - "context" - "fmt" - "io" - "mime/multipart" - "net/textproto" - "reflect" - "testing" - - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" - "go.uber.org/mock/gomock" - - "tss-rocks-be/ent" - "tss-rocks-be/internal/storage" - "tss-rocks-be/internal/storage/mock" - "tss-rocks-be/internal/testutil" - - "bou.ke/monkey" -) - -type MediaServiceTestSuite struct { - suite.Suite - ctx context.Context - client *ent.Client - storage *mock.MockStorage - ctrl *gomock.Controller - svc MediaService -} - -func (s *MediaServiceTestSuite) SetupTest() { - s.ctx = context.Background() - s.client = testutil.NewTestClient() - require.NotNil(s.T(), s.client) - - s.ctrl = gomock.NewController(s.T()) - s.storage = mock.NewMockStorage(s.ctrl) - s.svc = NewMediaService(s.client, s.storage) - - // 清理数据库 - _, err := s.client.Media.Delete().Exec(s.ctx) - require.NoError(s.T(), err) -} - -func (s *MediaServiceTestSuite) TearDownTest() { - s.ctrl.Finish() - s.client.Close() -} - -func TestMediaServiceSuite(t *testing.T) { - suite.Run(t, new(MediaServiceTestSuite)) -} - -type mockFileHeader struct { - filename string - contentType string - size int64 - content []byte -} - -func (h *mockFileHeader) Open() (multipart.File, error) { - return newMockMultipartFile(h.content), nil -} - -func (h *mockFileHeader) Filename() string { - return h.filename -} - -func (h *mockFileHeader) Size() int64 { - return h.size -} - -func (h *mockFileHeader) Header() textproto.MIMEHeader { - header := make(textproto.MIMEHeader) - header.Set("Content-Type", h.contentType) - return header -} - -func (s *MediaServiceTestSuite) createTestFile(filename, contentType string, content []byte) *multipart.FileHeader { - header := &multipart.FileHeader{ - Filename: filename, - Header: make(textproto.MIMEHeader), - Size: int64(len(content)), - } - header.Header.Set("Content-Type", contentType) - - monkey.PatchInstanceMethod(reflect.TypeOf(header), "Open", func(_ *multipart.FileHeader) (multipart.File, error) { - return newMockMultipartFile(content), nil - }) - - return header -} - -func (s *MediaServiceTestSuite) TestUpload() { - testCases := []struct { - name string - filename string - contentType string - content []byte - setupMock func() - wantErr bool - errMsg string - }{ - { - name: "Upload text file", - filename: "test.txt", - contentType: "text/plain", - content: []byte("test content"), - setupMock: func() { - s.storage.EXPECT(). - Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()). - DoAndReturn(func(ctx context.Context, name, contentType string, reader io.Reader) (*storage.FileInfo, error) { - content, err := io.ReadAll(reader) - s.Require().NoError(err) - s.Equal([]byte("test content"), content) - return &storage.FileInfo{ - ID: "test-id", - Name: "test.txt", - ContentType: "text/plain", - Size: int64(len(content)), - }, nil - }) - }, - wantErr: false, - }, - { - name: "Invalid filename", - filename: "../test.txt", - contentType: "text/plain", - content: []byte("test content"), - setupMock: func() {}, - wantErr: true, - errMsg: "invalid filename", - }, - { - name: "Storage error", - filename: "test.txt", - contentType: "text/plain", - content: []byte("test content"), - setupMock: func() { - s.storage.EXPECT(). - Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()). - Return(nil, fmt.Errorf("storage error")) - }, - wantErr: true, - errMsg: "storage error", - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - // Setup mock - tc.setupMock() - - // Create test file - fileHeader := s.createTestFile(tc.filename, tc.contentType, tc.content) - - // Add debug output - s.T().Logf("Testing file: %s, content-type: %s, size: %d", fileHeader.Filename, fileHeader.Header.Get("Content-Type"), fileHeader.Size) - - // Test upload - media, err := s.svc.Upload(s.ctx, fileHeader, 1) - - // Add debug output - if err != nil { - s.T().Logf("Upload error: %v", err) - } - - if tc.wantErr { - s.Require().Error(err) - s.Contains(err.Error(), tc.errMsg) - return - } - - s.Require().NoError(err) - s.NotNil(media) - s.Equal(tc.filename, media.OriginalName) - s.Equal(tc.contentType, media.MimeType) - s.Equal(int64(len(tc.content)), media.Size) - s.Equal("1", media.CreatedBy) - }) - } -} - -func (s *MediaServiceTestSuite) TestGet() { - // Create test media - media, err := s.client.Media.Create(). - SetStorageID("test-id"). - SetOriginalName("test.txt"). - SetMimeType("text/plain"). - SetSize(12). - SetURL("/api/media/test-id"). - SetCreatedBy("1"). - Save(s.ctx) - s.Require().NoError(err) - - // Test get existing media - result, err := s.svc.Get(s.ctx, media.ID) - s.Require().NoError(err) - s.Equal(media.ID, result.ID) - s.Equal(media.OriginalName, result.OriginalName) - - // Test get non-existing media - _, err = s.svc.Get(s.ctx, -1) - s.Require().Error(err) - s.Contains(err.Error(), "media not found") -} - -func (s *MediaServiceTestSuite) TestDelete() { - // Create test media - media, err := s.client.Media.Create(). - SetStorageID("test-id"). - SetOriginalName("test.txt"). - SetMimeType("text/plain"). - SetSize(12). - SetURL("/api/media/test-id"). - SetCreatedBy("1"). - Save(s.ctx) - s.Require().NoError(err) - - // Test delete by unauthorized user - err = s.svc.Delete(s.ctx, media.ID, 2) - s.Require().Error(err) - s.Contains(err.Error(), "unauthorized") - - // Test delete by owner - s.storage.EXPECT(). - Delete(gomock.Any(), "test-id"). - Return(nil) - err = s.svc.Delete(s.ctx, media.ID, 1) - s.Require().NoError(err) - - // Verify media is deleted - _, err = s.svc.Get(s.ctx, media.ID) - s.Require().Error(err) - s.Contains(err.Error(), "not found") -} - -func (s *MediaServiceTestSuite) TestList() { - // Create test media - for i := 0; i < 5; i++ { - _, err := s.client.Media.Create(). - SetStorageID(fmt.Sprintf("test-id-%d", i)). - SetOriginalName(fmt.Sprintf("test-%d.txt", i)). - SetMimeType("text/plain"). - SetSize(12). - SetURL(fmt.Sprintf("/api/media/test-id-%d", i)). - SetCreatedBy("1"). - Save(s.ctx) - s.Require().NoError(err) - } - - // Test list with limit and offset - media, err := s.svc.List(s.ctx, 3, 1) - s.Require().NoError(err) - s.Len(media, 3) -} - -func (s *MediaServiceTestSuite) TestGetFile() { - // Create test media - media, err := s.client.Media.Create(). - SetStorageID("test-id"). - SetOriginalName("test.txt"). - SetMimeType("text/plain"). - SetSize(12). - SetURL("/api/media/test-id"). - SetCreatedBy("1"). - Save(s.ctx) - s.Require().NoError(err) - - // Mock storage.Get - mockReader := io.NopCloser(bytes.NewReader([]byte("test content"))) - mockFileInfo := &storage.FileInfo{ - ID: "test-id", - Name: "test.txt", - ContentType: "text/plain", - Size: 12, - } - s.storage.EXPECT(). - Get(gomock.Any(), "test-id"). - Return(mockReader, mockFileInfo, nil) - - // Test get file - reader, info, err := s.svc.GetFile(s.ctx, media.ID) - s.Require().NoError(err) - s.NotNil(reader) - s.Equal(mockFileInfo, info) - - // Test get non-existing file - _, _, err = s.svc.GetFile(s.ctx, -1) - s.Require().Error(err) - s.Contains(err.Error(), "not found") -} - -func (s *MediaServiceTestSuite) TestIsValidFilename() { - testCases := []struct { - name string - filename string - want bool - }{ - { - name: "Valid filename", - filename: "test.txt", - want: true, - }, - { - name: "Invalid filename with ../", - filename: "../test.txt", - want: false, - }, - { - name: "Invalid filename with ./", - filename: "./test.txt", - want: false, - }, - { - name: "Invalid filename with backslash", - filename: "test\\file.txt", - want: false, - }, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - got := isValidFilename(tc.filename) - s.Equal(tc.want, got) - }) - } -} diff --git a/backend/internal/service/mock/mock.go b/backend/internal/service/mock/mock.go deleted file mode 100644 index b9ee689..0000000 --- a/backend/internal/service/mock/mock.go +++ /dev/null @@ -1,3 +0,0 @@ -package mock - -//go:generate mockgen -source=../service.go -destination=mock_service.go -package=mock diff --git a/backend/internal/service/service.go b/backend/internal/service/service.go index 5b88d4f..ee6aedd 100644 --- a/backend/internal/service/service.go +++ b/backend/internal/service/service.go @@ -1,7 +1,5 @@ package service -//go:generate mockgen -source=service.go -destination=mock/mock_service.go -package=mock - import ( "context" "io" @@ -34,16 +32,16 @@ type Service interface { GetCategories(ctx context.Context, langCode string) ([]*ent.Category, error) // Post operations - CreatePost(ctx context.Context, status string) (*ent.Post, error) + CreatePost(ctx context.Context, status string, categoryIDs []int) (*ent.Post, error) AddPostContent(ctx context.Context, postID int, langCode, title, content, summary string, metaKeywords, metaDescription string) (*ent.PostContent, error) GetPostBySlug(ctx context.Context, langCode, slug string) (*ent.Post, error) - ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error) + ListPosts(ctx context.Context, langCode string, categoryIDs []int, limit, offset int) ([]*ent.Post, error) // Media operations ListMedia(ctx context.Context, limit, offset int) ([]*ent.Media, error) Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error) GetMedia(ctx context.Context, id int) (*ent.Media, error) - GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error) + GetFile(ctx context.Context, storageID string) (io.ReadCloser, *storage.FileInfo, error) DeleteMedia(ctx context.Context, id int, userID int) error // Contributor operations diff --git a/backend/internal/storage/local.go b/backend/internal/storage/local.go index 2b1d8bf..d8ccff7 100644 --- a/backend/internal/storage/local.go +++ b/backend/internal/storage/local.go @@ -11,6 +11,8 @@ import ( "path/filepath" "strings" "time" + + "github.com/rs/zerolog/log" ) type LocalStorage struct { @@ -44,6 +46,24 @@ func (s *LocalStorage) generateID() (string, error) { return hex.EncodeToString(bytes), nil } +func (s *LocalStorage) generateFilePath(id string, ext string, createTime time.Time) string { + // Create year/month directory structure + year := createTime.Format("2006") + month := createTime.Format("01") + + // If id already has an extension, don't add ext + if filepath.Ext(id) != "" { + return filepath.Join(year, month, id) + } + + // Otherwise, add the extension if provided + filename := id + if ext != "" { + filename = id + ext + } + return filepath.Join(year, month, filename) +} + func (s *LocalStorage) saveMetadata(id string, info *FileInfo) error { metaPath := filepath.Join(s.metaDir, id+".meta") file, err := os.Create(metaPath) @@ -89,11 +109,64 @@ func (s *LocalStorage) Save(ctx context.Context, name string, contentType string return nil, fmt.Errorf("failed to generate file ID: %w", err) } - // Create the file path - filePath := filepath.Join(s.rootDir, id) + // Get file extension from original name or content type + ext := filepath.Ext(name) + if ext == "" { + // If no extension in name, try to get it from content type + switch contentType { + case "image/jpeg": + ext = ".jpg" + case "image/png": + ext = ".png" + case "image/gif": + ext = ".gif" + case "image/webp": + ext = ".webp" + case "image/svg+xml": + ext = ".svg" + case "video/mp4": + ext = ".mp4" + case "video/webm": + ext = ".webm" + case "audio/mpeg": + ext = ".mp3" + case "audio/ogg": + ext = ".ogg" + case "audio/wav": + ext = ".wav" + case "application/pdf": + ext = ".pdf" + case "application/msword": + ext = ".doc" + case "application/vnd.openxmlformats-officedocument.wordprocessingml.document": + ext = ".docx" + case "application/vnd.ms-excel": + ext = ".xls" + case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": + ext = ".xlsx" + case "application/zip": + ext = ".zip" + case "application/x-rar-compressed": + ext = ".rar" + case "text/plain": + ext = ".txt" + case "text/csv": + ext = ".csv" + } + } + + // Create the file path with year/month structure + now := time.Now() + relPath := s.generateFilePath(id, ext, now) + fullPath := filepath.Join(s.rootDir, relPath) + + // Create directory if it doesn't exist + if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil { + return nil, fmt.Errorf("failed to create directory: %w", err) + } // Create the file - file, err := os.Create(filePath) + file, err := os.Create(fullPath) if err != nil { return nil, fmt.Errorf("failed to create file: %w", err) } @@ -102,52 +175,73 @@ func (s *LocalStorage) Save(ctx context.Context, name string, contentType string // Copy the content size, err := io.Copy(file, reader) if err != nil { - // Clean up the file if there's an error - os.Remove(filePath) + os.Remove(fullPath) // Clean up on error return nil, fmt.Errorf("failed to write file content: %w", err) } - now := time.Now() + // Save metadata info := &FileInfo{ ID: id, Name: name, - Size: size, ContentType: contentType, + Size: size, CreatedAt: now, - UpdatedAt: now, - URL: fmt.Sprintf("/api/media/file/%s", id), + URL: fmt.Sprintf("/media/%s/%s/%s", now.Format("2006"), now.Format("01"), filepath.Base(relPath)), } - - // Save metadata if err := s.saveMetadata(id, info); err != nil { - os.Remove(filePath) - return nil, err + os.Remove(fullPath) // Clean up on error + return nil, fmt.Errorf("failed to save metadata: %w", err) } return info, nil } func (s *LocalStorage) Get(ctx context.Context, id string) (io.ReadCloser, *FileInfo, error) { - filePath := filepath.Join(s.rootDir, id) + // 从 id 中提取文件扩展名和基础 ID + ext := filepath.Ext(id) + baseID := strings.TrimSuffix(id, ext) - // Open the file + // 获取文件的创建时间(从元数据或当前时间) + metaPath := filepath.Join(s.metaDir, baseID+".meta") + var createTime time.Time + if stat, err := os.Stat(metaPath); err == nil { + createTime = stat.ModTime() + } else { + createTime = time.Now() // 如果找不到元数据,使用当前时间 + } + + // 生成完整的文件路径 + year := createTime.Format("2006") + month := createTime.Format("01") + filePath := filepath.Join(s.rootDir, year, month, id) // 直接使用完整的 id(包含扩展名) + + // 调试日志 + log.Debug(). + Str("id", id). + Str("baseID", baseID). + Str("ext", ext). + Str("filePath", filePath). + Time("createTime", createTime). + Msg("Attempting to get file") + + // 打开文件 file, err := os.Open(filePath) if err != nil { if os.IsNotExist(err) { - return nil, nil, fmt.Errorf("file not found: %s", id) + return nil, nil, fmt.Errorf("file not found: %s (path: %s)", id, filePath) } return nil, nil, fmt.Errorf("failed to open file: %w", err) } - // Get file info + // 获取文件信息 stat, err := file.Stat() if err != nil { file.Close() return nil, nil, fmt.Errorf("failed to get file info: %w", err) } - // Load metadata - name, contentType, err := s.loadMetadata(id) + // 加载元数据 + name, contentType, err := s.loadMetadata(baseID) if err != nil { file.Close() return nil, nil, err @@ -158,27 +252,44 @@ func (s *LocalStorage) Get(ctx context.Context, id string) (io.ReadCloser, *File Name: name, Size: stat.Size(), ContentType: contentType, - CreatedAt: stat.ModTime(), + CreatedAt: createTime, UpdatedAt: stat.ModTime(), - URL: fmt.Sprintf("/api/media/file/%s", id), + URL: fmt.Sprintf("/media/%s/%s/%s", year, month, id), } return file, info, nil } func (s *LocalStorage) Delete(ctx context.Context, id string) error { - filePath := filepath.Join(s.rootDir, id) - if err := os.Remove(filePath); err != nil { - if os.IsNotExist(err) { - return fmt.Errorf("file not found: %s", id) - } - return fmt.Errorf("failed to delete file: %w", err) + // 从 id 中提取文件扩展名 + ext := filepath.Ext(id) + baseID := strings.TrimSuffix(id, ext) + + // 获取文件的创建时间(从元数据或当前时间) + metaPath := filepath.Join(s.metaDir, baseID+".meta") + var createTime time.Time + if stat, err := os.Stat(metaPath); err == nil { + createTime = stat.ModTime() + } else { + createTime = time.Now() // 如果找不到元数据,使用当前时间 } - // Remove metadata - metaPath := filepath.Join(s.metaDir, id+".meta") - if err := os.Remove(metaPath); err != nil && !os.IsNotExist(err) { - return fmt.Errorf("failed to remove metadata: %w", err) + // 生成完整的文件路径 + relPath := s.generateFilePath(baseID, ext, createTime) + filePath := filepath.Join(s.rootDir, relPath) + + // 删除文件 + if err := os.Remove(filePath); err != nil { + if !os.IsNotExist(err) { + return fmt.Errorf("failed to delete file: %w", err) + } + } + + // 删除元数据 + if err := os.Remove(metaPath); err != nil { + if !os.IsNotExist(err) { + return fmt.Errorf("failed to delete metadata: %w", err) + } } return nil diff --git a/backend/internal/storage/local_test.go b/backend/internal/storage/local_test.go deleted file mode 100644 index f27a16f..0000000 --- a/backend/internal/storage/local_test.go +++ /dev/null @@ -1,154 +0,0 @@ -package storage - -import ( - "bytes" - "context" - "io" - "os" - "path/filepath" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestLocalStorage(t *testing.T) { - // Create a temporary directory for testing - tempDir, err := os.MkdirTemp("", "storage_test_*") - require.NoError(t, err) - defer os.RemoveAll(tempDir) - - // Create a new LocalStorage instance - storage, err := NewLocalStorage(tempDir) - require.NoError(t, err) - - ctx := context.Background() - - t.Run("Save and Get", func(t *testing.T) { - content := []byte("test content") - reader := bytes.NewReader(content) - - // Save the file - fileInfo, err := storage.Save(ctx, "test.txt", "text/plain", reader) - require.NoError(t, err) - assert.NotEmpty(t, fileInfo.ID) - assert.Equal(t, "test.txt", fileInfo.Name) - assert.Equal(t, int64(len(content)), fileInfo.Size) - assert.Equal(t, "text/plain", fileInfo.ContentType) - assert.False(t, fileInfo.CreatedAt.IsZero()) - - // Get the file - readCloser, info, err := storage.Get(ctx, fileInfo.ID) - require.NoError(t, err) - defer readCloser.Close() - - data, err := io.ReadAll(readCloser) - require.NoError(t, err) - assert.Equal(t, content, data) - assert.Equal(t, fileInfo.ID, info.ID) - assert.Equal(t, fileInfo.Name, info.Name) - assert.Equal(t, fileInfo.Size, info.Size) - }) - - t.Run("List", func(t *testing.T) { - // Clear the directory first - dirEntries, err := os.ReadDir(tempDir) - require.NoError(t, err) - for _, entry := range dirEntries { - if entry.Name() != ".meta" { - os.Remove(filepath.Join(tempDir, entry.Name())) - } - } - - // Save multiple files - testFiles := []struct { - name string - content string - }{ - {"test1.txt", "content1"}, - {"test2.txt", "content2"}, - {"other.txt", "content3"}, - } - - for _, f := range testFiles { - reader := bytes.NewReader([]byte(f.content)) - _, err := storage.Save(ctx, f.name, "text/plain", reader) - require.NoError(t, err) - } - - // List all files - allFiles, err := storage.List(ctx, "", 10, 0) - require.NoError(t, err) - assert.Len(t, allFiles, 3) - - // List files with prefix - filesWithPrefix, err := storage.List(ctx, "test", 10, 0) - require.NoError(t, err) - assert.Len(t, filesWithPrefix, 2) - for _, f := range filesWithPrefix { - assert.True(t, strings.HasPrefix(f.Name, "test")) - } - - // Test pagination - pagedFiles, err := storage.List(ctx, "", 2, 1) - require.NoError(t, err) - assert.Len(t, pagedFiles, 2) - }) - - t.Run("Exists", func(t *testing.T) { - // Save a file - content := []byte("test content") - reader := bytes.NewReader(content) - fileInfo, err := storage.Save(ctx, "exists.txt", "text/plain", reader) - require.NoError(t, err) - - // Check if file exists - exists, err := storage.Exists(ctx, fileInfo.ID) - require.NoError(t, err) - assert.True(t, exists) - - // Check non-existent file - exists, err = storage.Exists(ctx, "non-existent") - require.NoError(t, err) - assert.False(t, exists) - }) - - t.Run("Delete", func(t *testing.T) { - // Save a file - content := []byte("test content") - reader := bytes.NewReader(content) - fileInfo, err := storage.Save(ctx, "delete.txt", "text/plain", reader) - require.NoError(t, err) - - // Delete the file - err = storage.Delete(ctx, fileInfo.ID) - require.NoError(t, err) - - // Verify file is deleted - exists, err := storage.Exists(ctx, fileInfo.ID) - require.NoError(t, err) - assert.False(t, exists) - - // Try to delete non-existent file - err = storage.Delete(ctx, "non-existent") - assert.Error(t, err) - }) - - t.Run("Invalid operations", func(t *testing.T) { - // Try to get non-existent file - _, _, err := storage.Get(ctx, "non-existent") - assert.Error(t, err) - assert.Contains(t, err.Error(), "file not found") - - // Try to save file with nil reader - _, err = storage.Save(ctx, "test.txt", "text/plain", nil) - assert.Error(t, err) - assert.Contains(t, err.Error(), "reader cannot be nil") - - // Try to delete non-existent file - err = storage.Delete(ctx, "non-existent") - assert.Error(t, err) - assert.Contains(t, err.Error(), "file not found") - }) -} diff --git a/backend/internal/storage/s3.go b/backend/internal/storage/s3.go index bea236d..88f8825 100644 --- a/backend/internal/storage/s3.go +++ b/backend/internal/storage/s3.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "path/filepath" "strings" "time" @@ -48,14 +49,25 @@ func (s *S3Storage) generateID() (string, error) { return hex.EncodeToString(bytes), nil } -func (s *S3Storage) getObjectURL(id string) string { +func (s *S3Storage) generateObjectKey(id string, ext string, createTime time.Time) string { + // Create year/month structure + year := createTime.Format("2006") + month := createTime.Format("01") + filename := id + if ext != "" { + filename = id + ext + } + return fmt.Sprintf("%s/%s/%s", year, month, filename) +} + +func (s *S3Storage) getObjectURL(key string) string { if s.customURL != "" { - return fmt.Sprintf("%s/%s", strings.TrimRight(s.customURL, "/"), id) + return fmt.Sprintf("%s/%s", strings.TrimRight(s.customURL, "/"), key) } if s.proxyS3 { - return fmt.Sprintf("/api/media/file/%s", id) + return fmt.Sprintf("/media/%s", key) } - return fmt.Sprintf("https://%s.s3.amazonaws.com/%s", s.bucket, id) + return fmt.Sprintf("https://%s.s3.amazonaws.com/%s", s.bucket, key) } func (s *S3Storage) Save(ctx context.Context, name string, contentType string, reader io.Reader) (*FileInfo, error) { @@ -65,10 +77,60 @@ func (s *S3Storage) Save(ctx context.Context, name string, contentType string, r return nil, fmt.Errorf("failed to generate file ID: %w", err) } + // Get file extension from original name or content type + ext := filepath.Ext(name) + if ext == "" { + // If no extension in name, try to get it from content type + switch contentType { + case "image/jpeg": + ext = ".jpg" + case "image/png": + ext = ".png" + case "image/gif": + ext = ".gif" + case "image/webp": + ext = ".webp" + case "image/svg+xml": + ext = ".svg" + case "video/mp4": + ext = ".mp4" + case "video/webm": + ext = ".webm" + case "audio/mpeg": + ext = ".mp3" + case "audio/ogg": + ext = ".ogg" + case "audio/wav": + ext = ".wav" + case "application/pdf": + ext = ".pdf" + case "application/msword": + ext = ".doc" + case "application/vnd.openxmlformats-officedocument.wordprocessingml.document": + ext = ".docx" + case "application/vnd.ms-excel": + ext = ".xls" + case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": + ext = ".xlsx" + case "application/zip": + ext = ".zip" + case "application/x-rar-compressed": + ext = ".rar" + case "text/plain": + ext = ".txt" + case "text/csv": + ext = ".csv" + } + } + + // Create the object key with year/month structure + now := time.Now() + key := s.generateObjectKey(id, ext, now) + // Check if the file exists _, err = s.client.HeadObject(ctx, &s3.HeadObjectInput{ Bucket: aws.String(s.bucket), - Key: aws.String(id), + Key: aws.String(key), }) if err == nil { return nil, fmt.Errorf("file already exists with ID: %s", id) @@ -82,37 +144,50 @@ func (s *S3Storage) Save(ctx context.Context, name string, contentType string, r // Upload the file _, err = s.client.PutObject(ctx, &s3.PutObjectInput{ Bucket: aws.String(s.bucket), - Key: aws.String(id), + Key: aws.String(key), Body: reader, ContentType: aws.String(contentType), Metadata: map[string]string{ "x-amz-meta-original-name": name, + "x-amz-meta-created-at": now.Format(time.RFC3339), }, }) if err != nil { return nil, fmt.Errorf("failed to upload file: %w", err) } - now := time.Now() info := &FileInfo{ ID: id, Name: name, Size: 0, // Size is not available until after upload ContentType: contentType, CreatedAt: now, - UpdatedAt: now, - URL: s.getObjectURL(id), + URL: s.getObjectURL(key), } return info, nil } func (s *S3Storage) Get(ctx context.Context, id string) (io.ReadCloser, *FileInfo, error) { - // Get the object from S3 - result, err := s.client.GetObject(ctx, &s3.GetObjectInput{ - Bucket: aws.String(s.bucket), - Key: aws.String(id), - }) + // Try to find the file with different extensions + exts := []string{"", ".jpg", ".png", ".gif", ".webp", ".svg", ".mp4", ".webm", ".mp3", ".ogg", ".wav", + ".pdf", ".doc", ".docx", ".xls", ".xlsx", ".zip", ".rar", ".txt", ".csv"} + + var result *s3.GetObjectOutput + var err error + var key string + + for _, ext := range exts { + key = s.generateObjectKey(id, ext, time.Now()) + result, err = s.client.GetObject(ctx, &s3.GetObjectInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(key), + }) + if err == nil { + break + } + } + if err != nil { return nil, nil, fmt.Errorf("failed to get file from S3: %w", err) } @@ -122,111 +197,117 @@ func (s *S3Storage) Get(ctx context.Context, id string) (io.ReadCloser, *FileInf Name: result.Metadata["x-amz-meta-original-name"], Size: aws.ToInt64(result.ContentLength), ContentType: aws.ToString(result.ContentType), - CreatedAt: aws.ToTime(result.LastModified), - UpdatedAt: aws.ToTime(result.LastModified), - URL: s.getObjectURL(id), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + URL: s.getObjectURL(key), } return result.Body, info, nil } func (s *S3Storage) Delete(ctx context.Context, id string) error { - _, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{ - Bucket: aws.String(s.bucket), - Key: aws.String(id), - }) - if err != nil { - return fmt.Errorf("failed to delete file from S3: %w", err) + // Try to find and delete the file with different extensions + exts := []string{"", ".jpg", ".png", ".gif", ".webp", ".svg", ".mp4", ".webm", ".mp3", ".ogg", ".wav", + ".pdf", ".doc", ".docx", ".xls", ".xlsx", ".zip", ".rar", ".txt", ".csv"} + + var lastErr error + for _, ext := range exts { + key := s.generateObjectKey(id, ext, time.Now()) + _, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(key), + }) + if err == nil { + return nil + } + lastErr = err } - return nil + return fmt.Errorf("failed to delete file from S3: %w", lastErr) +} + +func (s *S3Storage) Exists(ctx context.Context, id string) (bool, error) { + // Try to find the file with different extensions + exts := []string{"", ".jpg", ".png", ".gif", ".webp", ".svg", ".mp4", ".webm", ".mp3", ".ogg", ".wav", + ".pdf", ".doc", ".docx", ".xls", ".xlsx", ".zip", ".rar", ".txt", ".csv"} + + for _, ext := range exts { + key := s.generateObjectKey(id, ext, time.Now()) + _, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(key), + }) + if err == nil { + return true, nil + } + } + + return false, nil } func (s *S3Storage) List(ctx context.Context, prefix string, limit int, offset int) ([]*FileInfo, error) { var files []*FileInfo var continuationToken *string + count := 0 + skip := offset - // Skip objects for offset - for i := 0; i < offset/1000; i++ { - output, err := s.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{ + for { + input := &s3.ListObjectsV2Input{ Bucket: aws.String(s.bucket), Prefix: aws.String(prefix), ContinuationToken: continuationToken, - MaxKeys: aws.Int32(1000), - }) + MaxKeys: aws.Int32(100), // Fetch in batches of 100 + } + + result, err := s.client.ListObjectsV2(ctx, input) if err != nil { return nil, fmt.Errorf("failed to list files from S3: %w", err) } - if !aws.ToBool(output.IsTruncated) { - return files, nil - } - continuationToken = output.NextContinuationToken - } - // Get the actual objects - output, err := s.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{ - Bucket: aws.String(s.bucket), - Prefix: aws.String(prefix), - ContinuationToken: continuationToken, - MaxKeys: aws.Int32(int32(limit)), - }) - if err != nil { - return nil, fmt.Errorf("failed to list files from S3: %w", err) - } - - for _, obj := range output.Contents { - // Get the object metadata - head, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{ - Bucket: aws.String(s.bucket), - Key: obj.Key, - }) - - var contentType string - var originalName string - - if err != nil { - var noSuchKey *types.NoSuchKey - if errors.As(err, &noSuchKey) { - // If the object doesn't exist (which shouldn't happen normally), - // we'll still include it in the list but with empty metadata - contentType = "" - originalName = aws.ToString(obj.Key) - } else { + // Process each object + for _, obj := range result.Contents { + if skip > 0 { + skip-- continue } - } else { - contentType = aws.ToString(head.ContentType) - originalName = head.Metadata["x-amz-meta-original-name"] - if originalName == "" { - originalName = aws.ToString(obj.Key) + + // Get object metadata + head, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{ + Bucket: aws.String(s.bucket), + Key: obj.Key, + }) + if err != nil { + continue // Skip files we can't get metadata for + } + + // Extract the ID from the key (remove extension if present) + id := aws.ToString(obj.Key) + if ext := filepath.Ext(id); ext != "" { + id = id[:len(id)-len(ext)] + } + + info := &FileInfo{ + ID: id, + Name: head.Metadata["x-amz-meta-original-name"], + Size: aws.ToInt64(obj.Size), + ContentType: aws.ToString(head.ContentType), + CreatedAt: aws.ToTime(obj.LastModified), + UpdatedAt: aws.ToTime(obj.LastModified), + URL: s.getObjectURL(aws.ToString(obj.Key)), + } + files = append(files, info) + count++ + + if count >= limit { + return files, nil } } - files = append(files, &FileInfo{ - ID: aws.ToString(obj.Key), - Name: originalName, - Size: aws.ToInt64(obj.Size), - ContentType: contentType, - CreatedAt: aws.ToTime(obj.LastModified), - UpdatedAt: aws.ToTime(obj.LastModified), - URL: s.getObjectURL(aws.ToString(obj.Key)), - }) + if !aws.ToBool(result.IsTruncated) { + break + } + continuationToken = result.NextContinuationToken } return files, nil } - -func (s *S3Storage) Exists(ctx context.Context, id string) (bool, error) { - _, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{ - Bucket: aws.String(s.bucket), - Key: aws.String(id), - }) - if err != nil { - var nsk *types.NoSuchKey - if ok := errors.As(err, &nsk); ok { - return false, nil - } - return false, fmt.Errorf("failed to check file existence in S3: %w", err) - } - return true, nil -} diff --git a/backend/internal/storage/s3_test.go b/backend/internal/storage/s3_test.go deleted file mode 100644 index 215f04e..0000000 --- a/backend/internal/storage/s3_test.go +++ /dev/null @@ -1,211 +0,0 @@ -package storage - -import ( - "bytes" - "context" - "io" - "testing" - "time" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/s3" - "github.com/aws/aws-sdk-go-v2/service/s3/types" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -// MockS3Client is a mock implementation of the S3 client interface -type MockS3Client struct { - mock.Mock -} - -func (m *MockS3Client) PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) { - args := m.Called(ctx, params) - return args.Get(0).(*s3.PutObjectOutput), args.Error(1) -} - -func (m *MockS3Client) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) { - args := m.Called(ctx, params) - return args.Get(0).(*s3.GetObjectOutput), args.Error(1) -} - -func (m *MockS3Client) DeleteObject(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error) { - args := m.Called(ctx, params) - return args.Get(0).(*s3.DeleteObjectOutput), args.Error(1) -} - -func (m *MockS3Client) ListObjectsV2(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) { - args := m.Called(ctx, params) - return args.Get(0).(*s3.ListObjectsV2Output), args.Error(1) -} - -func (m *MockS3Client) HeadObject(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { - args := m.Called(ctx, params) - return args.Get(0).(*s3.HeadObjectOutput), args.Error(1) -} - -func TestS3Storage(t *testing.T) { - ctx := context.Background() - mockClient := new(MockS3Client) - storage := NewS3Storage(mockClient, "test-bucket", "", false) - - t.Run("Save", func(t *testing.T) { - mockClient.ExpectedCalls = nil - mockClient.Calls = nil - - content := []byte("test content") - reader := bytes.NewReader(content) - - // Mock HeadObject to return NotFound error - mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool { - return aws.ToString(input.Bucket) == "test-bucket" - })).Return(&s3.HeadObjectOutput{}, &types.NoSuchKey{ - Message: aws.String("The specified key does not exist."), - }) - - mockClient.On("PutObject", ctx, mock.MatchedBy(func(input *s3.PutObjectInput) bool { - return aws.ToString(input.Bucket) == "test-bucket" && - aws.ToString(input.ContentType) == "text/plain" - })).Return(&s3.PutObjectOutput{}, nil) - - fileInfo, err := storage.Save(ctx, "test.txt", "text/plain", reader) - require.NoError(t, err) - assert.NotEmpty(t, fileInfo.ID) - assert.Equal(t, "test.txt", fileInfo.Name) - assert.Equal(t, "text/plain", fileInfo.ContentType) - - mockClient.AssertExpectations(t) - }) - - t.Run("Get", func(t *testing.T) { - content := []byte("test content") - mockClient.On("GetObject", ctx, mock.MatchedBy(func(input *s3.GetObjectInput) bool { - return aws.ToString(input.Bucket) == "test-bucket" && - aws.ToString(input.Key) == "test-id" - })).Return(&s3.GetObjectOutput{ - Body: io.NopCloser(bytes.NewReader(content)), - ContentType: aws.String("text/plain"), - ContentLength: aws.Int64(int64(len(content))), - LastModified: aws.Time(time.Now()), - }, nil) - - readCloser, info, err := storage.Get(ctx, "test-id") - require.NoError(t, err) - defer readCloser.Close() - - data, err := io.ReadAll(readCloser) - require.NoError(t, err) - assert.Equal(t, content, data) - assert.Equal(t, "test-id", info.ID) - assert.Equal(t, int64(len(content)), info.Size) - - mockClient.AssertExpectations(t) - }) - - t.Run("List", func(t *testing.T) { - mockClient.ExpectedCalls = nil - mockClient.Calls = nil - - mockClient.On("ListObjectsV2", ctx, mock.MatchedBy(func(input *s3.ListObjectsV2Input) bool { - return aws.ToString(input.Bucket) == "test-bucket" && - aws.ToString(input.Prefix) == "test" && - aws.ToInt32(input.MaxKeys) == 10 - })).Return(&s3.ListObjectsV2Output{ - Contents: []types.Object{ - { - Key: aws.String("test1"), - Size: aws.Int64(100), - LastModified: aws.Time(time.Now()), - }, - { - Key: aws.String("test2"), - Size: aws.Int64(200), - LastModified: aws.Time(time.Now()), - }, - }, - }, nil) - - // Mock HeadObject for both files - mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool { - return aws.ToString(input.Bucket) == "test-bucket" && - aws.ToString(input.Key) == "test1" - })).Return(&s3.HeadObjectOutput{ - ContentType: aws.String("text/plain"), - Metadata: map[string]string{ - "x-amz-meta-original-name": "test1.txt", - }, - }, nil).Once() - - mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool { - return aws.ToString(input.Bucket) == "test-bucket" && - aws.ToString(input.Key) == "test2" - })).Return(&s3.HeadObjectOutput{ - ContentType: aws.String("text/plain"), - Metadata: map[string]string{ - "x-amz-meta-original-name": "test2.txt", - }, - }, nil).Once() - - files, err := storage.List(ctx, "test", 10, 0) - require.NoError(t, err) - assert.Len(t, files, 2) - assert.Equal(t, "test1", files[0].ID) - assert.Equal(t, int64(100), files[0].Size) - assert.Equal(t, "test1.txt", files[0].Name) - assert.Equal(t, "text/plain", files[0].ContentType) - - mockClient.AssertExpectations(t) - }) - - t.Run("Delete", func(t *testing.T) { - mockClient.On("DeleteObject", ctx, mock.MatchedBy(func(input *s3.DeleteObjectInput) bool { - return aws.ToString(input.Bucket) == "test-bucket" && - aws.ToString(input.Key) == "test-id" - })).Return(&s3.DeleteObjectOutput{}, nil) - - err := storage.Delete(ctx, "test-id") - require.NoError(t, err) - - mockClient.AssertExpectations(t) - }) - - t.Run("Exists", func(t *testing.T) { - mockClient.ExpectedCalls = nil - mockClient.Calls = nil - - // Mock HeadObject for existing file - mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool { - return aws.ToString(input.Bucket) == "test-bucket" && - aws.ToString(input.Key) == "test-id" - })).Return(&s3.HeadObjectOutput{}, nil).Once() - - exists, err := storage.Exists(ctx, "test-id") - require.NoError(t, err) - assert.True(t, exists) - - // Mock HeadObject for non-existing file - mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool { - return aws.ToString(input.Bucket) == "test-bucket" && - aws.ToString(input.Key) == "non-existent" - })).Return(&s3.HeadObjectOutput{}, &types.NoSuchKey{ - Message: aws.String("The specified key does not exist."), - }).Once() - - exists, err = storage.Exists(ctx, "non-existent") - require.NoError(t, err) - assert.False(t, exists) - - mockClient.AssertExpectations(t) - }) - - t.Run("Custom URL", func(t *testing.T) { - customStorage := &S3Storage{ - client: mockClient, - bucket: "test-bucket", - customURL: "https://custom.domain", - proxyS3: true, - } - assert.Contains(t, customStorage.getObjectURL("test-id"), "https://custom.domain") - }) -} diff --git a/backend/internal/storage/storage.go b/backend/internal/storage/storage.go index ccd00d2..81170ad 100644 --- a/backend/internal/storage/storage.go +++ b/backend/internal/storage/storage.go @@ -1,7 +1,5 @@ package storage -//go:generate mockgen -source=storage.go -destination=mock/mock_storage.go -package=mock - import ( "context" "io" diff --git a/backend/internal/types/config_test.go b/backend/internal/types/config_test.go deleted file mode 100644 index 298e318..0000000 --- a/backend/internal/types/config_test.go +++ /dev/null @@ -1,116 +0,0 @@ -package types - -import ( - "testing" -) - -func TestRateLimitConfig(t *testing.T) { - config := RateLimitConfig{ - IPRate: 100, - IPBurst: 200, - RouteRates: map[string]struct { - Rate int `yaml:"rate"` - Burst int `yaml:"burst"` - }{ - "/api/test": { - Rate: 50, - Burst: 100, - }, - }, - } - - if config.IPRate != 100 { - t.Errorf("Expected IPRate 100, got %d", config.IPRate) - } - if config.IPBurst != 200 { - t.Errorf("Expected IPBurst 200, got %d", config.IPBurst) - } - - route := config.RouteRates["/api/test"] - if route.Rate != 50 { - t.Errorf("Expected route rate 50, got %d", route.Rate) - } - if route.Burst != 100 { - t.Errorf("Expected route burst 100, got %d", route.Burst) - } -} - -func TestAccessLogConfig(t *testing.T) { - config := AccessLogConfig{ - EnableConsole: true, - EnableFile: true, - FilePath: "/var/log/app.log", - Format: "json", - Level: "info", - Rotation: struct { - MaxSize int `yaml:"max_size"` - MaxAge int `yaml:"max_age"` - MaxBackups int `yaml:"max_backups"` - Compress bool `yaml:"compress"` - LocalTime bool `yaml:"local_time"` - }{ - MaxSize: 100, - MaxAge: 7, - MaxBackups: 5, - Compress: true, - LocalTime: true, - }, - } - - if !config.EnableConsole { - t.Error("Expected EnableConsole to be true") - } - if !config.EnableFile { - t.Error("Expected EnableFile to be true") - } - if config.FilePath != "/var/log/app.log" { - t.Errorf("Expected FilePath '/var/log/app.log', got '%s'", config.FilePath) - } - if config.Format != "json" { - t.Errorf("Expected Format 'json', got '%s'", config.Format) - } - if config.Level != "info" { - t.Errorf("Expected Level 'info', got '%s'", config.Level) - } - - rotation := config.Rotation - if rotation.MaxSize != 100 { - t.Errorf("Expected MaxSize 100, got %d", rotation.MaxSize) - } - if rotation.MaxAge != 7 { - t.Errorf("Expected MaxAge 7, got %d", rotation.MaxAge) - } - if rotation.MaxBackups != 5 { - t.Errorf("Expected MaxBackups 5, got %d", rotation.MaxBackups) - } - if !rotation.Compress { - t.Error("Expected Compress to be true") - } - if !rotation.LocalTime { - t.Error("Expected LocalTime to be true") - } -} - -func TestUploadConfig(t *testing.T) { - config := UploadConfig{ - MaxSize: 10, - AllowedTypes: []string{"image/jpeg", "image/png"}, - AllowedExtensions: []string{".jpg", ".png"}, - } - - if config.MaxSize != 10 { - t.Errorf("Expected MaxSize 10, got %d", config.MaxSize) - } - if len(config.AllowedTypes) != 2 { - t.Errorf("Expected 2 AllowedTypes, got %d", len(config.AllowedTypes)) - } - if config.AllowedTypes[0] != "image/jpeg" { - t.Errorf("Expected AllowedTypes[0] 'image/jpeg', got '%s'", config.AllowedTypes[0]) - } - if len(config.AllowedExtensions) != 2 { - t.Errorf("Expected 2 AllowedExtensions, got %d", len(config.AllowedExtensions)) - } - if config.AllowedExtensions[0] != ".jpg" { - t.Errorf("Expected AllowedExtensions[0] '.jpg', got '%s'", config.AllowedExtensions[0]) - } -} diff --git a/backend/internal/types/file_test.go b/backend/internal/types/file_test.go deleted file mode 100644 index 1335a21..0000000 --- a/backend/internal/types/file_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package types - -import "testing" - -func TestFileInfo(t *testing.T) { - fileInfo := FileInfo{ - Size: 1024, - Name: "test.jpg", - ContentType: "image/jpeg", - } - - if fileInfo.Size != 1024 { - t.Errorf("Expected Size 1024, got %d", fileInfo.Size) - } - if fileInfo.Name != "test.jpg" { - t.Errorf("Expected Name 'test.jpg', got '%s'", fileInfo.Name) - } - if fileInfo.ContentType != "image/jpeg" { - t.Errorf("Expected ContentType 'image/jpeg', got '%s'", fileInfo.ContentType) - } -} diff --git a/backend/internal/types/types_test.go b/backend/internal/types/types_test.go deleted file mode 100644 index ff3461d..0000000 --- a/backend/internal/types/types_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package types - -import ( - "testing" -) - -func TestCategory(t *testing.T) { - description := "Test Description" - category := Category{ - ID: 1, - Name: "Test Category", - Slug: "test-category", - Description: &description, - } - - if category.ID != 1 { - t.Errorf("Expected ID 1, got %d", category.ID) - } - if category.Name != "Test Category" { - t.Errorf("Expected name 'Test Category', got '%s'", category.Name) - } - if category.Slug != "test-category" { - t.Errorf("Expected slug 'test-category', got '%s'", category.Slug) - } - if *category.Description != description { - t.Errorf("Expected description '%s', got '%s'", description, *category.Description) - } -} - -func TestPost(t *testing.T) { - metaKeywords := "test,blog" - metaDesc := "Test Description" - post := Post{ - ID: 1, - Title: "Test Post", - Slug: "test-post", - ContentMarkdown: "# Test Content", - Summary: "Test Summary", - MetaKeywords: &metaKeywords, - MetaDescription: &metaDesc, - } - - if post.ID != 1 { - t.Errorf("Expected ID 1, got %d", post.ID) - } - if post.Title != "Test Post" { - t.Errorf("Expected title 'Test Post', got '%s'", post.Title) - } - if post.Slug != "test-post" { - t.Errorf("Expected slug 'test-post', got '%s'", post.Slug) - } - if *post.MetaKeywords != metaKeywords { - t.Errorf("Expected meta keywords '%s', got '%s'", metaKeywords, *post.MetaKeywords) - } -} - -func TestDaily(t *testing.T) { - daily := Daily{ - ID: "2025-02-12", - CategoryID: 1, - ImageURL: "https://example.com/image.jpg", - Quote: "Test Quote", - } - - if daily.ID != "2025-02-12" { - t.Errorf("Expected ID '2025-02-12', got '%s'", daily.ID) - } - if daily.CategoryID != 1 { - t.Errorf("Expected CategoryID 1, got %d", daily.CategoryID) - } - if daily.ImageURL != "https://example.com/image.jpg" { - t.Errorf("Expected ImageURL 'https://example.com/image.jpg', got '%s'", daily.ImageURL) - } - if daily.Quote != "Test Quote" { - t.Errorf("Expected Quote 'Test Quote', got '%s'", daily.Quote) - } -} diff --git a/backend/pkg/config/config_test.go b/backend/pkg/config/config_test.go deleted file mode 100644 index a785dff..0000000 --- a/backend/pkg/config/config_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package config - -import ( - "os" - "path/filepath" - "testing" -) - -func TestLoad(t *testing.T) { - // Create a temporary test config file - testConfig := ` -database: - driver: postgres - dsn: postgres://user:pass@localhost:5432/db -server: - port: 8080 - host: localhost -jwt: - secret: test-secret - expiration: 24h -logging: - level: debug - format: console -` - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config.yaml") - if err := os.WriteFile(configPath, []byte(testConfig), 0644); err != nil { - t.Fatalf("Failed to create test config file: %v", err) - } - - // Test successful config loading - cfg, err := Load(configPath) - if err != nil { - t.Fatalf("Failed to load config: %v", err) - } - - // Verify loaded values - tests := []struct { - name string - got interface{} - expected interface{} - }{ - {"database.driver", cfg.Database.Driver, "postgres"}, - {"database.dsn", cfg.Database.DSN, "postgres://user:pass@localhost:5432/db"}, - {"server.port", cfg.Server.Port, 8080}, - {"server.host", cfg.Server.Host, "localhost"}, - {"jwt.secret", cfg.JWT.Secret, "test-secret"}, - {"jwt.expiration", cfg.JWT.Expiration, "24h"}, - {"logging.level", cfg.Logging.Level, "debug"}, - {"logging.format", cfg.Logging.Format, "console"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.got != tt.expected { - t.Errorf("Config %s = %v, want %v", tt.name, tt.got, tt.expected) - } - }) - } - - // Test loading non-existent file - _, err = Load("non-existent.yaml") - if err == nil { - t.Error("Expected error when loading non-existent file, got nil") - } - - // Test loading invalid YAML - invalidPath := filepath.Join(tmpDir, "invalid.yaml") - if err := os.WriteFile(invalidPath, []byte("invalid: yaml: content"), 0644); err != nil { - t.Fatalf("Failed to create invalid config file: %v", err) - } - - _, err = Load(invalidPath) - if err == nil { - t.Error("Expected error when loading invalid YAML, got nil") - } -} diff --git a/backend/pkg/imageutil/processor_test.go b/backend/pkg/imageutil/processor_test.go deleted file mode 100644 index c5cb6d5..0000000 --- a/backend/pkg/imageutil/processor_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package imageutil - -import ( - "bytes" - "image" - "image/color" - "image/png" - "testing" -) - -func TestIsImageFormat(t *testing.T) { - tests := []struct { - name string - contentType string - want bool - }{ - {"JPEG", "image/jpeg", true}, - {"PNG", "image/png", true}, - {"GIF", "image/gif", true}, - {"WebP", "image/webp", true}, - {"Invalid", "image/invalid", false}, - {"Empty", "", false}, - {"Text", "text/plain", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := IsImageFormat(tt.contentType); got != tt.want { - t.Errorf("IsImageFormat(%q) = %v, want %v", tt.contentType, got, tt.want) - } - }) - } -} - -func TestDefaultOptions(t *testing.T) { - opts := DefaultOptions() - - if !opts.Lossless { - t.Error("DefaultOptions().Lossless = false, want true") - } - if opts.Quality != 90 { - t.Errorf("DefaultOptions().Quality = %v, want 90", opts.Quality) - } - if opts.Compression != 4 { - t.Errorf("DefaultOptions().Compression = %v, want 4", opts.Compression) - } -} - -func TestProcessImage(t *testing.T) { - // Create a test image - img := image.NewRGBA(image.Rect(0, 0, 100, 100)) - for y := 0; y < 100; y++ { - for x := 0; x < 100; x++ { - img.Set(x, y, color.RGBA{R: 255, G: 0, B: 0, A: 255}) - } - } - - var buf bytes.Buffer - if err := png.Encode(&buf, img); err != nil { - t.Fatalf("Failed to create test PNG: %v", err) - } - - tests := []struct { - name string - opts ProcessOptions - wantErr bool - }{ - { - name: "Default options", - opts: DefaultOptions(), - }, - { - name: "Custom quality", - opts: ProcessOptions{ - Lossless: false, - Quality: 75, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - reader := bytes.NewReader(buf.Bytes()) - result, err := ProcessImage(reader, tt.opts) - if (err != nil) != tt.wantErr { - t.Errorf("ProcessImage() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !tt.wantErr && len(result) == 0 { - t.Error("ProcessImage() returned empty result") - } - }) - } - - // Test with invalid input - _, err := ProcessImage(bytes.NewReader([]byte("invalid image data")), DefaultOptions()) - if err == nil { - t.Error("ProcessImage() with invalid input should return error") - } -} diff --git a/backend/pkg/logger/logger_test.go b/backend/pkg/logger/logger_test.go deleted file mode 100644 index 319d851..0000000 --- a/backend/pkg/logger/logger_test.go +++ /dev/null @@ -1,85 +0,0 @@ -package logger - -import ( - "testing" - "tss-rocks-be/internal/config" - - "github.com/rs/zerolog" -) - -func TestSetup(t *testing.T) { - tests := []struct { - name string - config *config.Config - expectedLevel zerolog.Level - }{ - { - name: "Debug level", - config: &config.Config{ - Logging: struct { - Level string `yaml:"level"` - Format string `yaml:"format"` - }{ - Level: "debug", - Format: "json", - }, - }, - expectedLevel: zerolog.DebugLevel, - }, - { - name: "Info level", - config: &config.Config{ - Logging: struct { - Level string `yaml:"level"` - Format string `yaml:"format"` - }{ - Level: "info", - Format: "json", - }, - }, - expectedLevel: zerolog.InfoLevel, - }, - { - name: "Error level", - config: &config.Config{ - Logging: struct { - Level string `yaml:"level"` - Format string `yaml:"format"` - }{ - Level: "error", - Format: "json", - }, - }, - expectedLevel: zerolog.ErrorLevel, - }, - { - name: "Invalid level defaults to Info", - config: &config.Config{ - Logging: struct { - Level string `yaml:"level"` - Format string `yaml:"format"` - }{ - Level: "invalid", - Format: "json", - }, - }, - expectedLevel: zerolog.InfoLevel, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - Setup(tt.config) - if zerolog.GlobalLevel() != tt.expectedLevel { - t.Errorf("Setup() set level to %v, want %v", zerolog.GlobalLevel(), tt.expectedLevel) - } - }) - } -} - -func TestGetLogger(t *testing.T) { - logger := GetLogger() - if logger == nil { - t.Error("GetLogger() returned nil") - } -} diff --git a/frontend/data/i18n/en.json b/frontend/data/i18n/en.json index 06015e4..46bf27f 100644 --- a/frontend/data/i18n/en.json +++ b/frontend/data/i18n/en.json @@ -1,6 +1,6 @@ { "categories": { - "man": "Man", + "man": "Human", "machine": "Machine", "earth": "Earth", "space": "Space", @@ -29,8 +29,6 @@ "edit": "Edit", "delete": "Delete", "upload": "Upload", - "save": "Save", - "saving": "Saving...", "status": "Status", "actions": "Actions", "published": "Published", @@ -48,12 +46,10 @@ "joinDate": "Join Date", "username": "Username", "logout": "Logout", - "language": "Language", - "theme": { - "light": "Light Mode", - "dark": "Dark Mode", - "system": "System" - } + "save": "Save", + "saving": "Saving...", + "noData": "No data available", + "unsavedChanges": "You have unsaved changes. Are you sure you want to leave?" }, "dashboard": { "totalPosts": "Total Posts", @@ -61,13 +57,33 @@ "totalUsers": "Total Users", "totalContributors": "Total Contributors" }, + "posts": { + "title": "Title", + "categories": "Categories", + "createdAt": "Created At", + "status": "Status", + "noTitle": "No Title", + "deleteConfirm": "Are you sure you want to delete this post?", + "create": "Create Post", + "edit": "Edit Post", + "slug": "Slug", + "content": "Content", + "summary": "Summary", + "metaKeywords": "Meta Keywords", + "metaDescription": "Meta Description", + "selectCategories": "Select Categories", + "saving": "Saving...", + "publishing": "Publishing...", + "saveDraft": "Save Draft", + "publish": "Publish" + }, "login": { "title": "Admin Login", "username": "Username", "password": "Password", "remember": "Remember me", - "submit": "Sign in", - "loading": "Signing in...", + "submit": "Login", + "loading": "Logging in...", "error": { "failed": "Login failed", "retry": "Login failed, please try again later" @@ -76,7 +92,7 @@ "nav": { "dashboard": "Dashboard", "posts": "Posts", - "daily": "Daily Quotes", + "daily": "Daily", "medias": "Media", "categories": "Categories", "users": "Users", @@ -110,5 +126,45 @@ "passwordMismatch": "Passwords do not match", "passwordTooShort": "Password must be at least 8 characters long" } + }, + "editor": { + "heading1": "Heading 1", + "heading2": "Heading 2", + "heading3": "Heading 3", + "bold": "Bold", + "italic": "Italic", + "orderedList": "Ordered List", + "unorderedList": "Unordered List", + "quote": "Quote", + "link": "Link", + "image": "Image", + "inlineCode": "Inline Code", + "codeBlock": "Code Block", + "table": "Table", + "togglePreview": "Toggle Preview", + "toggleFullscreen": "Toggle Fullscreen", + "selectLanguage": "Select Language", + "plainText": "Plain Text", + "insertTable": "Insert Table", + "addRowAbove": "Add Row Above", + "addRowBelow": "Add Row Below", + "addColumnLeft": "Add Column Left", + "addColumnRight": "Add Column Right", + "deleteRow": "Delete Row", + "deleteColumn": "Delete Column", + "deleteTable": "Delete Table", + "dragAndDrop": "Drop images here", + "dropToUpload": "Drop files here to upload", + "uploading": "Uploading...", + "uploadProgress": "Upload progress: {{progress}}%", + "uploadError": "Failed to upload image: {{error}}", + "uploadSuccess": "Image uploaded successfully", + "codeBlockShortcut": "Ctrl+Shift+K", + "boldShortcut": "Ctrl+B", + "italicShortcut": "Ctrl+I", + "linkShortcut": "Ctrl+K", + "heading1Shortcut": "Shift+1", + "heading2Shortcut": "Shift+2", + "heading3Shortcut": "Shift+3" } } diff --git a/frontend/data/i18n/zh-Hans.json b/frontend/data/i18n/zh-Hans.json index 296d279..50c6432 100644 --- a/frontend/data/i18n/zh-Hans.json +++ b/frontend/data/i18n/zh-Hans.json @@ -47,7 +47,9 @@ "username": "用户名", "logout": "退出登录", "save": "保存", - "saving": "保存中..." + "saving": "保存中...", + "noData": "暂无数据", + "unsavedChanges": "你有未保存的更改,确定要离开吗?" }, "dashboard": { "totalPosts": "文章总数", @@ -55,6 +57,26 @@ "totalUsers": "用户总数", "totalContributors": "贡献者总数" }, + "posts": { + "title": "标题", + "categories": "分类", + "createdAt": "创建时间", + "status": "状态", + "noTitle": "无标题", + "deleteConfirm": "确定要删除这篇文章吗?", + "create": "创建文章", + "edit": "编辑文章", + "slug": "文章链接", + "content": "内容", + "summary": "摘要", + "metaKeywords": "关键词", + "metaDescription": "描述", + "selectCategories": "选择分类", + "saving": "保存中...", + "publishing": "发布中...", + "saveDraft": "保存草稿", + "publish": "发布文章" + }, "login": { "title": "管理员登录", "username": "用户名", @@ -104,5 +126,45 @@ "passwordMismatch": "两次输入的密码不一致", "passwordTooShort": "密码长度不能少于8个字符" } + }, + "editor": { + "heading1": "一级标题", + "heading2": "二级标题", + "heading3": "三级标题", + "bold": "粗体", + "italic": "斜体", + "orderedList": "有序列表", + "unorderedList": "无序列表", + "quote": "引用", + "link": "链接", + "image": "图片", + "inlineCode": "行内代码", + "codeBlock": "代码块", + "table": "表格", + "togglePreview": "切换预览", + "toggleFullscreen": "切换全屏", + "selectLanguage": "选择语言", + "plainText": "纯文本", + "insertTable": "插入表格", + "addRowAbove": "在上方插入行", + "addRowBelow": "在下方插入行", + "addColumnLeft": "在左侧插入列", + "addColumnRight": "在右侧插入列", + "deleteRow": "删除行", + "deleteColumn": "删除列", + "deleteTable": "删除表格", + "dragAndDrop": "拖放图片到此处", + "dropToUpload": "拖放文件到此处上传", + "uploading": "上传中...", + "uploadProgress": "上传进度:{{progress}}%", + "uploadError": "图片上传失败:{{error}}", + "uploadSuccess": "图片上传成功", + "codeBlockShortcut": "Ctrl+Shift+K", + "boldShortcut": "Ctrl+B", + "italicShortcut": "Ctrl+I", + "linkShortcut": "Ctrl+K", + "heading1Shortcut": "Shift+1", + "heading2Shortcut": "Shift+2", + "heading3Shortcut": "Shift+3" } } diff --git a/frontend/data/i18n/zh-Hant.json b/frontend/data/i18n/zh-Hant.json index 7a89534..34b35a4 100644 --- a/frontend/data/i18n/zh-Hant.json +++ b/frontend/data/i18n/zh-Hant.json @@ -45,7 +45,7 @@ "lastLogin": "最後登入", "joinDate": "加入時間", "username": "用戶名", - "logout": "退出登錄", + "logout": "退出登入", "language": "語言", "theme": { "light": "淺色模式", @@ -53,7 +53,9 @@ "system": "跟隨系統" }, "save": "保存", - "saving": "保存中..." + "saving": "保存中...", + "noData": "暫無數據", + "unsavedChanges": "你有未保存的更改,確定要離開嗎?" }, "nav": { "dashboard": "儀表板", @@ -83,15 +85,15 @@ "uploadDate": "上傳日期" }, "login": { - "title": "管理員登錄", + "title": "管理員登入", "username": "用戶名", "password": "密碼", "remember": "記住我", - "submit": "登錄", - "loading": "登錄中...", + "submit": "登入", + "loading": "登入中...", "error": { - "failed": "登錄失敗", - "retry": "登錄失敗,請稍後重試" + "failed": "登入失敗", + "retry": "登入失敗,請稍後重試" } }, "roles": { @@ -110,5 +112,45 @@ "passwordMismatch": "兩次輸入的密碼不一致", "passwordTooShort": "密碼長度不能少於8個字符" } + }, + "editor": { + "heading1": "一級標題", + "heading2": "二級標題", + "heading3": "三級標題", + "bold": "粗體", + "italic": "斜體", + "orderedList": "有序列表", + "unorderedList": "無序列表", + "quote": "引用", + "link": "連結", + "image": "圖片", + "inlineCode": "行內程式碼", + "codeBlock": "程式碼區塊", + "table": "表格", + "togglePreview": "切換預覽", + "toggleFullscreen": "切換全螢幕", + "selectLanguage": "選擇語言", + "plainText": "純文字", + "insertTable": "插入表格", + "addRowAbove": "在上方插入列", + "addRowBelow": "在下方插入列", + "addColumnLeft": "在左側插入欄", + "addColumnRight": "在右側插入欄", + "deleteRow": "刪除列", + "deleteColumn": "刪除欄", + "deleteTable": "刪除表格", + "dragAndDrop": "拖放圖片到此處", + "dropToUpload": "拖放檔案到此處上傳", + "uploading": "上傳中...", + "uploadProgress": "上傳進度:{{progress}}%", + "uploadError": "圖片上傳失敗:{{error}}", + "uploadSuccess": "圖片上傳成功", + "codeBlockShortcut": "Ctrl+Shift+K", + "boldShortcut": "Ctrl+B", + "italicShortcut": "Ctrl+I", + "linkShortcut": "Ctrl+K", + "heading1Shortcut": "Shift+1", + "heading2Shortcut": "Shift+2", + "heading3Shortcut": "Shift+3" } } diff --git a/frontend/package.json b/frontend/package.json index ff1ee5e..195deb2 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -12,7 +12,14 @@ "dependencies": { "@headlessui/react": "^2.2.0", "@tss-rocks/api": "workspace:*", + "@types/axios": "^0.14.4", + "@types/classnames": "^2.3.4", + "@types/highlight.js": "^10.1.0", "@types/markdown-it": "^14.1.2", + "axios": "^1.7.9", + "classnames": "^2.5.1", + "dayjs": "^1.11.13", + "highlight.js": "^11.11.1", "i18next": "^24.2.2", "i18next-browser-languagedetector": "^8.0.4", "lucide-react": "^0.474.0", @@ -21,7 +28,13 @@ "react-dom": "^19.0.0", "react-i18next": "^15.4.0", "react-icons": "^5.4.0", - "react-router-dom": "^7.1.5" + "react-markdown": "^10.0.0", + "react-router-dom": "^7.1.5", + "react-toastify": "^11.0.3", + "rehype-highlight": "^7.0.2", + "rehype-raw": "^7.0.0", + "rehype-sanitize": "^6.0.0", + "remark-gfm": "^4.0.1" }, "devDependencies": { "@eslint/js": "^9.9.1", diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 9fa788f..db0ad14 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -4,16 +4,21 @@ import { AuthProvider } from './contexts/AuthContext'; import { UserProvider } from './contexts/UserContext'; import router from './router'; import LoadingSpinner from './components/LoadingSpinner'; +import { ToastContainer } from 'react-toastify'; +import 'react-toastify/dist/ReactToastify.css'; function App() { return ( - }> - - - - - - + <> + }> + + + + + + + + ); } diff --git a/frontend/src/components/LoadingSpinner.tsx b/frontend/src/components/LoadingSpinner.tsx index f6742fc..e1dc29a 100644 --- a/frontend/src/components/LoadingSpinner.tsx +++ b/frontend/src/components/LoadingSpinner.tsx @@ -9,8 +9,8 @@ export default function LoadingSpinner({ fullScreen = false }: LoadingSpinnerPro const content = (
-
-
+
+
{t('admin.common.loading')}
diff --git a/frontend/src/components/admin/MarkdownEditor.tsx b/frontend/src/components/admin/MarkdownEditor.tsx new file mode 100644 index 0000000..0980c14 --- /dev/null +++ b/frontend/src/components/admin/MarkdownEditor.tsx @@ -0,0 +1,958 @@ +import { FC, useState, useRef, useEffect } from 'react'; +import { useTranslation } from 'react-i18next'; +import ReactMarkdown from 'react-markdown'; +import remarkGfm from 'remark-gfm'; +import rehypeRaw from 'rehype-raw'; +import rehypeSanitize from 'rehype-sanitize'; +import rehypeHighlight from 'rehype-highlight'; +import 'highlight.js/styles/github-dark.css'; +import classNames from 'classnames'; +import axios from 'axios'; +import { useToast } from '@/hooks/useToast'; +import { + RiH1, + RiH2, + RiH3, + RiBold, + RiItalic, + RiListOrdered, + RiListUnordered, + RiLink, + RiImage2Line, + RiCodeLine, + RiCodeBoxLine, + RiTableLine, + RiDoubleQuotesL, + RiEyeLine, + RiEyeOffLine, + RiFullscreenLine, +} from 'react-icons/ri'; + +interface MarkdownEditorProps { + value: string; + onChange: (value: string) => void; + placeholder?: string; +} + +type CodeBlockProps = { + inline?: boolean; + className?: string; + children: React.ReactNode; +} & React.HTMLAttributes; + +const MarkdownEditor: FC = ({ + value, + onChange, + placeholder, +}) => { + const { t } = useTranslation(); + const { showToast } = useToast(); + const textareaRef = useRef(null); + const tableMenuRef = useRef(null); + const langSelectorRef = useRef(null); + const [showTableMenu, setShowTableMenu] = useState(false); + const [showLangSelector, setShowLangSelector] = useState(false); + const [isPreview, setIsPreview] = useState(false); + const [isFullscreen, setIsFullscreen] = useState(false); + const [isDraggingOver, setIsDraggingOver] = useState(false); + + // 检查当前行是否为空 + const isCurrentLineEmpty = (textarea: HTMLTextAreaElement): boolean => { + const text = textarea.value; + const lines = text.split('\n'); + const currentLine = getCurrentLine(textarea); + return !lines[currentLine].trim(); + }; + + // 获取当前行号 + const getCurrentLine = (textarea: HTMLTextAreaElement): number => { + const text = textarea.value.substring(0, textarea.selectionStart); + return text.split('\n').length - 1; + }; + + // 在当前位置插入文本 + const insertText = (textarea: HTMLTextAreaElement, text: string) => { + const start = textarea.selectionStart; + const end = textarea.selectionEnd; + const before = textarea.value.substring(0, start); + const after = textarea.value.substring(end); + + const needNewLine = !isCurrentLineEmpty(textarea) && + (text.startsWith('#') || text.startsWith('>')); + + const newText = needNewLine ? `\n${text}` : text; + const newValue = before + newText + after; + + onChange(newValue); + + textarea.value = newValue; + const newCursorPos = start + newText.length; + textarea.selectionStart = newCursorPos; + textarea.selectionEnd = newCursorPos; + textarea.focus(); + }; + + const insertTable = (rows: number, cols: number) => { + if (!textareaRef.current) return; + + const headers = Array(cols).fill('header').join(' | '); + const separators = Array(cols).fill('---').join(' | '); + const cells = Array(cols).fill('content').join(' | '); + const rows_content = Array(rows).fill(cells).join('\n| '); + + insertText( + textareaRef.current, + `\n| ${headers} |\n| ${separators} |\n| ${rows_content} |\n\n` + ); + setShowTableMenu(false); + }; + + const findTableAtCursor = (): { table: string; start: number; end: number; rows: string[][]; alignments: string[] } | null => { + if (!textareaRef.current) return null; + + const text = textareaRef.current.value; + const cursorPos = textareaRef.current.selectionStart; + + const tableRegex = /\|[^\n]+\|[\s]*\n\|[- |:]+\|[\s]*\n(\|[^\n]+\|[\s]*\n?)+/g; + let match; + while ((match = tableRegex.exec(text)) !== null) { + const start = match.index; + const end = start + match[0].length; + if (cursorPos >= start && cursorPos <= end) { + const rows = match[0] + .trim() + .split('\n') + .map(row => + row + .trim() + .replace(/^\||\|$/g, '') + .split('|') + .map(cell => cell.trim()) + ); + + const alignments = rows[1].map(cell => { + if (cell.startsWith(':') && cell.endsWith(':')) return 'center'; + if (cell.endsWith(':')) return 'right'; + return 'left'; + }); + + return { + table: match[0], + start, + end, + rows: [rows[0], ...rows.slice(2)], + alignments, + }; + } + } + return null; + }; + + const generateTableText = (rows: string[][], alignments: string[]): string => { + const colWidths = rows[0].map((_, colIndex) => + Math.max(...rows.map(row => row[colIndex]?.length || 0)) + ); + + const separator = alignments.map((align, i) => { + const width = Math.max(3, colWidths[i]); + switch (align) { + case 'center': + return ':' + '-'.repeat(width - 2) + ':'; + case 'right': + return '-'.repeat(width - 1) + ':'; + default: + return '-'.repeat(width); + } + }); + + const tableRows = [ + rows[0], + separator, + ...rows.slice(1) + ].map(row => + '| ' + row.map((cell, i) => cell.padEnd(colWidths[i])).join(' | ') + ' |' + ); + + return tableRows.join('\n') + '\n'; + }; + + const updateTable = (pos: { table: string; start: number; end: number; rows: string[][]; alignments: string[] }, newRows?: string[][]) => { + if (!textareaRef.current) return; + + const text = textareaRef.current.value; + const newTable = generateTableText(newRows || pos.rows, pos.alignments); + const newText = text.substring(0, pos.start) + newTable + text.substring(pos.end); + onChange(newText); + }; + + const tableActions = [ + { + label: t('editor.insertTable'), + action: () => insertTable(2, 2), + }, + { + label: t('editor.addRowAbove'), + action: () => { + const tablePos = findTableAtCursor(); + if (!tablePos) return; + + const cursorPos = textareaRef.current?.selectionStart || 0; + let rowIndex = 0; + let currentPos = tablePos.start; + + for (let i = 0; i < tablePos.rows.length; i++) { + const rowText = generateTableText([tablePos.rows[i]], tablePos.alignments); + if (cursorPos <= currentPos + rowText.length) { + rowIndex = i; + break; + } + currentPos += rowText.length; + } + + const newRows = [...tablePos.rows]; + newRows.splice(rowIndex, 0, Array(tablePos.rows[0].length).fill('')); + updateTable(tablePos, newRows); + }, + }, + { + label: t('editor.addRowBelow'), + action: () => { + const tablePos = findTableAtCursor(); + if (!tablePos) return; + + const cursorPos = textareaRef.current?.selectionStart || 0; + let rowIndex = 0; + let currentPos = tablePos.start; + + for (let i = 0; i < tablePos.rows.length; i++) { + const rowText = generateTableText([tablePos.rows[i]], tablePos.alignments); + if (cursorPos <= currentPos + rowText.length) { + rowIndex = i; + break; + } + currentPos += rowText.length; + } + + const newRows = [...tablePos.rows]; + newRows.splice(rowIndex + 1, 0, Array(tablePos.rows[0].length).fill('')); + updateTable(tablePos, newRows); + }, + }, + { + label: t('editor.addColumnLeft'), + action: () => { + const tablePos = findTableAtCursor(); + if (!tablePos) return; + + const cursorPos = textareaRef.current?.selectionStart || 0; + const line = textareaRef.current?.value + .substring(tablePos.start, cursorPos) + .split('\n') + .pop() || ''; + const colIndex = (line.match(/\|/g) || []).length - 1; + + const newRows = tablePos.rows.map(row => { + const newRow = [...row]; + newRow.splice(colIndex, 0, ''); + return newRow; + }); + + const newAlignments = [...tablePos.alignments]; + newAlignments.splice(colIndex, 0, 'left'); + + tablePos.alignments = newAlignments; + updateTable(tablePos, newRows); + }, + }, + { + label: t('editor.addColumnRight'), + action: () => { + const tablePos = findTableAtCursor(); + if (!tablePos) return; + + const cursorPos = textareaRef.current?.selectionStart || 0; + const line = textareaRef.current?.value + .substring(tablePos.start, cursorPos) + .split('\n') + .pop() || ''; + const colIndex = (line.match(/\|/g) || []).length - 1; + + const newRows = tablePos.rows.map(row => { + const newRow = [...row]; + newRow.splice(colIndex + 1, 0, ''); + return newRow; + }); + + const newAlignments = [...tablePos.alignments]; + newAlignments.splice(colIndex + 1, 0, 'left'); + + tablePos.alignments = newAlignments; + updateTable(tablePos, newRows); + }, + }, + { + label: t('editor.deleteRow'), + action: () => { + const tablePos = findTableAtCursor(); + if (!tablePos) return; + + const cursorPos = textareaRef.current?.selectionStart || 0; + let rowIndex = 0; + let currentPos = tablePos.start; + + for (let i = 0; i < tablePos.rows.length; i++) { + const rowText = generateTableText([tablePos.rows[i]], tablePos.alignments); + if (cursorPos <= currentPos + rowText.length) { + rowIndex = i; + break; + } + currentPos += rowText.length; + } + + const newRows = [...tablePos.rows]; + newRows.splice(rowIndex, 1); + updateTable(tablePos, newRows); + }, + }, + { + label: t('editor.deleteColumn'), + action: () => { + const tablePos = findTableAtCursor(); + if (!tablePos) return; + + const cursorPos = textareaRef.current?.selectionStart || 0; + const line = textareaRef.current?.value + .substring(tablePos.start, cursorPos) + .split('\n') + .pop() || ''; + const colIndex = (line.match(/\|/g) || []).length - 1; + + const newRows = tablePos.rows.map(row => { + const newRow = [...row]; + newRow.splice(colIndex, 1); + return newRow; + }); + + const newAlignments = [...tablePos.alignments]; + newAlignments.splice(colIndex, 1); + + tablePos.alignments = newAlignments; + updateTable(tablePos, newRows); + }, + }, + { + label: t('editor.deleteTable'), + action: () => { + const tablePos = findTableAtCursor(); + if (!tablePos) return; + + const text = textareaRef.current?.value || ''; + const newText = text.substring(0, tablePos.start) + text.substring(tablePos.end); + onChange(newText); + }, + }, + ]; + + const toolbarButtons = [ + { + icon: RiH1, + label: t('editor.heading1'), + shortcut: t('editor.heading1Shortcut'), + action: (textarea: HTMLTextAreaElement) => { + insertText(textarea, '# '); + }, + }, + { + icon: RiH2, + label: t('editor.heading2'), + shortcut: t('editor.heading2Shortcut'), + action: (textarea: HTMLTextAreaElement) => { + insertText(textarea, '## '); + }, + }, + { + icon: RiH3, + label: t('editor.heading3'), + shortcut: t('editor.heading3Shortcut'), + action: (textarea: HTMLTextAreaElement) => { + insertText(textarea, '### '); + }, + }, + { + icon: RiBold, + label: t('editor.bold'), + shortcut: t('editor.boldShortcut'), + action: (textarea: HTMLTextAreaElement) => { + const text = textarea.value.substring( + textarea.selectionStart, + textarea.selectionEnd + ); + insertText(textarea, `**${text || t('editor.bold')}**`); + }, + }, + { + icon: RiItalic, + label: t('editor.italic'), + shortcut: t('editor.italicShortcut'), + action: (textarea: HTMLTextAreaElement) => { + const text = textarea.value.substring( + textarea.selectionStart, + textarea.selectionEnd + ); + insertText(textarea, `_${text || t('editor.italic')}_`); + }, + }, + { + icon: RiListUnordered, + label: t('editor.unorderedList'), + action: (textarea: HTMLTextAreaElement) => { + insertText(textarea, '- '); + }, + }, + { + icon: RiListOrdered, + label: t('editor.orderedList'), + action: (textarea: HTMLTextAreaElement) => { + insertText(textarea, '1. '); + }, + }, + { + icon: RiDoubleQuotesL, + label: t('editor.quote'), + action: (textarea: HTMLTextAreaElement) => { + insertText(textarea, '> '); + }, + }, + { + icon: RiLink, + label: t('editor.link'), + shortcut: t('editor.linkShortcut'), + action: (textarea: HTMLTextAreaElement) => { + const text = textarea.value.substring( + textarea.selectionStart, + textarea.selectionEnd + ); + insertText(textarea, `[${text || t('editor.link')}](url)`); + }, + }, + { + icon: RiImage2Line, + label: t('editor.image'), + action: () => { + const input = document.createElement('input'); + input.type = 'file'; + input.accept = 'image/*'; + input.onchange = (e) => { + const file = (e.target as HTMLInputElement).files?.[0]; + if (file) { + handleFileUpload(file); + } + }; + input.click(); + }, + }, + { + icon: RiCodeLine, + label: t('editor.inlineCode'), + action: (textarea: HTMLTextAreaElement) => { + const text = textarea.value.substring( + textarea.selectionStart, + textarea.selectionEnd + ); + insertText(textarea, `\`${text || t('editor.inlineCode')}\``); + }, + }, + ]; + + const commonLanguages = [ + { value: 'javascript', label: 'JavaScript' }, + { value: 'typescript', label: 'TypeScript' }, + { value: 'jsx', label: 'JSX' }, + { value: 'tsx', label: 'TSX' }, + { value: 'css', label: 'CSS' }, + { value: 'html', label: 'HTML' }, + { value: 'json', label: 'JSON' }, + { value: 'markdown', label: 'Markdown' }, + { value: 'python', label: 'Python' }, + { value: 'java', label: 'Java' }, + { value: 'c', label: 'C' }, + { value: 'cpp', label: 'C++' }, + { value: 'csharp', label: 'C#' }, + { value: 'go', label: 'Go' }, + { value: 'rust', label: 'Rust' }, + { value: 'php', label: 'PHP' }, + { value: 'ruby', label: 'Ruby' }, + { value: 'swift', label: 'Swift' }, + { value: 'kotlin', label: 'Kotlin' }, + { value: 'sql', label: 'SQL' }, + { value: 'shell', label: 'Shell' }, + { value: 'yaml', label: 'YAML' }, + { value: 'xml', label: 'XML' }, + ]; + + const insertCodeBlock = (language?: string) => { + if (!textareaRef.current) return; + const lang = language || ''; + insertText(textareaRef.current, `\n\`\`\`${lang}\n\n\`\`\`\n`); + const cursorPos = textareaRef.current.selectionStart - 4; + textareaRef.current.setSelectionRange(cursorPos, cursorPos); + setShowLangSelector(false); + }; + + // 处理文件上传 + const handleFileUpload = async (file: File) => { + if (!file.type.startsWith('image/')) { + showToast('error', t('editor.uploadError', { error: 'Not an image file' })); + return; + } + + const formData = new FormData(); + formData.append('file', file); + + try { + const response = await axios.post('/api/v1/media', formData, { + headers: { + 'Content-Type': 'multipart/form-data', + 'Authorization': `Bearer ${localStorage.getItem('token')}`, + }, + withCredentials: true, + }); + + const imageUrl = response.data.data.url; + const imageMarkdown = `![${file.name}](${imageUrl})`; + + if (textareaRef.current) { + const start = textareaRef.current.selectionStart; + const end = textareaRef.current.selectionEnd; + const before = value.substring(0, start); + const after = value.substring(end); + const newValue = before + imageMarkdown + after; + onChange(newValue); + showToast('success', t('editor.uploadSuccess')); + } + } catch (error: any) { + console.error('Upload error:', error); + const errorMessage = error.response?.data?.error?.message || error.message; + showToast('error', t('editor.uploadError', { error: errorMessage })); + } + }; + + // 处理拖放事件 + const handleDragOver = (e: React.DragEvent) => { + e.preventDefault(); + e.stopPropagation(); + setIsDraggingOver(true); + }; + + const handleDragLeave = (e: React.DragEvent) => { + e.preventDefault(); + e.stopPropagation(); + setIsDraggingOver(false); + }; + + const handleDrop = async (e: React.DragEvent) => { + e.preventDefault(); + e.stopPropagation(); + setIsDraggingOver(false); + + const files = Array.from(e.dataTransfer.files); + for (const file of files) { + await handleFileUpload(file); + } + }; + + // 处理粘贴事件 + const handlePaste = async (e: React.ClipboardEvent) => { + const items = Array.from(e.clipboardData.items); + for (const item of items) { + if (item.type.startsWith('image/')) { + e.preventDefault(); + const file = item.getAsFile(); + if (file) { + await handleFileUpload(file); + } + } + } + }; + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (e.ctrlKey || e.metaKey) { + const key = e.key.toLowerCase(); + switch (key) { + case 'b': // 粗体 + e.preventDefault(); + if (textareaRef.current) { + const text = textareaRef.current.value.substring( + textareaRef.current.selectionStart, + textareaRef.current.selectionEnd + ); + insertText(textareaRef.current, `**${text}**`); + } + break; + case 'i': // 斜体 + e.preventDefault(); + if (textareaRef.current) { + const text = textareaRef.current.value.substring( + textareaRef.current.selectionStart, + textareaRef.current.selectionEnd + ); + insertText(textareaRef.current, `_${text}_`); + } + break; + case 'k': // 链接 + e.preventDefault(); + if (textareaRef.current) { + const text = textareaRef.current.value.substring( + textareaRef.current.selectionStart, + textareaRef.current.selectionEnd + ); + insertText(textareaRef.current, `[${text}](url)`); + } + break; + case '1': // 标题 1 + if (e.shiftKey) { + e.preventDefault(); + if (textareaRef.current) { + const text = textareaRef.current.value.substring( + textareaRef.current.selectionStart, + textareaRef.current.selectionEnd + ); + insertText(textareaRef.current, `# ${text}`); + } + } + break; + case '2': // 标题 2 + if (e.shiftKey) { + e.preventDefault(); + if (textareaRef.current) { + const text = textareaRef.current.value.substring( + textareaRef.current.selectionStart, + textareaRef.current.selectionEnd + ); + insertText(textareaRef.current, `## ${text}`); + } + } + break; + case '3': // 标题 3 + if (e.shiftKey) { + e.preventDefault(); + if (textareaRef.current) { + const text = textareaRef.current.value.substring( + textareaRef.current.selectionStart, + textareaRef.current.selectionEnd + ); + insertText(textareaRef.current, `### ${text}`); + } + } + break; + case 'e': // 预览 + e.preventDefault(); + setIsPreview(!isPreview); + break; + } + } + }; + + useEffect(() => { + const handleClickOutside = (event: MouseEvent) => { + if ( + tableMenuRef.current && + !tableMenuRef.current.contains(event.target as Node) + ) { + setShowTableMenu(false); + } + if ( + langSelectorRef.current && + !langSelectorRef.current.contains(event.target as Node) + ) { + setShowLangSelector(false); + } + }; + + document.addEventListener('mousedown', handleClickOutside); + return () => { + document.removeEventListener('mousedown', handleClickOutside); + }; + }, []); + + return ( +
+
+ {/* 左侧编辑工具栏 */} +
+ {toolbarButtons.map((button, index) => { + return ( + + ); + })} + + {/* 表格按钮 */} +
+ + {showTableMenu && ( +
+ {tableActions.map((action, index) => ( + + ))} +
+ )} +
+ + {/* 代码块按钮 */} +
+ + {showLangSelector && ( +
+
+ {t('editor.selectLanguage')} +
+
+ + {commonLanguages.map(({ value, label }) => ( + + ))} +
+
+ )} +
+
+ + {/* 右侧预览和全屏按钮 */} +
+ + +
+
+ +
+
+ {isDraggingOver && ( +
+
+ {t('editor.dragAndDrop')} +
+
+ )} +