From be8bf22017a0d8b10089a0ec2551bcdb86967e92 Mon Sep 17 00:00:00 2001 From: cdn0x12 Date: Sat, 22 Feb 2025 02:42:55 +0800 Subject: [PATCH] [feature/backend] add categories param in posts --- api/schemas/components/schemas.yaml | 6 ++ backend/ent/category/category.go | 14 ++- backend/ent/category/where.go | 2 +- backend/ent/category_create.go | 4 +- backend/ent/category_query.go | 68 +++++++++---- backend/ent/category_update.go | 24 ++--- backend/ent/client.go | 8 +- backend/ent/migrate/schema.go | 38 +++++-- backend/ent/mutation.go | 102 +++++++++++-------- backend/ent/post.go | 39 +++----- backend/ent/post/post.go | 48 ++++----- backend/ent/post/where.go | 12 +-- backend/ent/post_create.go | 31 +++--- backend/ent/post_query.go | 102 +++++++++++-------- backend/ent/post_update.go | 150 +++++++++++++++++++--------- backend/ent/schema/post.go | 5 +- backend/go.mod | 2 - backend/go.sum | 3 - backend/internal/handler/handler.go | 32 ++++-- backend/internal/service/impl.go | 35 +++++-- backend/internal/service/service.go | 4 +- 21 files changed, 448 insertions(+), 281 deletions(-) diff --git a/api/schemas/components/schemas.yaml b/api/schemas/components/schemas.yaml index 056ef91..3812e98 100644 --- a/api/schemas/components/schemas.yaml +++ b/api/schemas/components/schemas.yaml @@ -103,6 +103,7 @@ Post: - slug - status - contents + - categories properties: id: type: integer @@ -117,6 +118,11 @@ Post: type: array items: $ref: '#/PostContent' + categories: + type: array + items: + $ref: '#/Category' + description: 文章所属分类 created_at: type: string format: date-time diff --git a/backend/ent/category/category.go b/backend/ent/category/category.go index 80073b1..529079d 100644 --- a/backend/ent/category/category.go +++ b/backend/ent/category/category.go @@ -33,13 +33,11 @@ const ( ContentsInverseTable = "category_contents" // ContentsColumn is the table column denoting the contents relation/edge. ContentsColumn = "category_contents" - // PostsTable is the table that holds the posts relation/edge. - PostsTable = "posts" + // PostsTable is the table that holds the posts relation/edge. The primary key declared below. + PostsTable = "category_posts" // PostsInverseTable is the table name for the Post entity. // It exists in this package in order to avoid circular dependency with the "post" package. PostsInverseTable = "posts" - // PostsColumn is the table column denoting the posts relation/edge. - PostsColumn = "category_posts" // DailyItemsTable is the table that holds the daily_items relation/edge. DailyItemsTable = "dailies" // DailyItemsInverseTable is the table name for the Daily entity. @@ -56,6 +54,12 @@ var Columns = []string{ FieldUpdatedAt, } +var ( + // PostsPrimaryKey and PostsColumn2 are the table columns denoting the + // primary key for the posts relation (M2M). + PostsPrimaryKey = []string{"category_id", "post_id"} +) + // ValidColumn reports if the column name is valid (part of the table columns). func ValidColumn(column string) bool { for i := range Columns { @@ -145,7 +149,7 @@ func newPostsStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), sqlgraph.To(PostsInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, PostsTable, PostsColumn), + sqlgraph.Edge(sqlgraph.M2M, false, PostsTable, PostsPrimaryKey...), ) } func newDailyItemsStep() *sqlgraph.Step { diff --git a/backend/ent/category/where.go b/backend/ent/category/where.go index ea42ba5..223d752 100644 --- a/backend/ent/category/where.go +++ b/backend/ent/category/where.go @@ -173,7 +173,7 @@ func HasPosts() predicate.Category { return predicate.Category(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, PostsTable, PostsColumn), + sqlgraph.Edge(sqlgraph.M2M, false, PostsTable, PostsPrimaryKey...), ) sqlgraph.HasNeighbors(s, step) }) diff --git a/backend/ent/category_create.go b/backend/ent/category_create.go index f479883..b78ea04 100644 --- a/backend/ent/category_create.go +++ b/backend/ent/category_create.go @@ -201,10 +201,10 @@ func (cc *CategoryCreate) createSpec() (*Category, *sqlgraph.CreateSpec) { } if nodes := cc.mutation.PostsIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.O2M, + Rel: sqlgraph.M2M, Inverse: false, Table: category.PostsTable, - Columns: []string{category.PostsColumn}, + Columns: category.PostsPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt), diff --git a/backend/ent/category_query.go b/backend/ent/category_query.go index 26c7b61..83dadcc 100644 --- a/backend/ent/category_query.go +++ b/backend/ent/category_query.go @@ -101,7 +101,7 @@ func (cq *CategoryQuery) QueryPosts() *PostQuery { step := sqlgraph.NewStep( sqlgraph.From(category.Table, category.FieldID, selector), sqlgraph.To(post.Table, post.FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, category.PostsTable, category.PostsColumn), + sqlgraph.Edge(sqlgraph.M2M, false, category.PostsTable, category.PostsPrimaryKey...), ) fromU = sqlgraph.SetNeighbors(cq.driver.Dialect(), step) return fromU, nil @@ -523,33 +523,63 @@ func (cq *CategoryQuery) loadContents(ctx context.Context, query *CategoryConten return nil } func (cq *CategoryQuery) loadPosts(ctx context.Context, query *PostQuery, nodes []*Category, init func(*Category), assign func(*Category, *Post)) error { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*Category) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Category) + nids := make(map[int]map[*Category]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node if init != nil { - init(nodes[i]) + init(node) } } - query.withFKs = true - query.Where(predicate.Post(func(s *sql.Selector) { - s.Where(sql.InValues(s.C(category.PostsColumn), fks...)) - })) - neighbors, err := query.All(ctx) + query.Where(func(s *sql.Selector) { + joinT := sql.Table(category.PostsTable) + s.Join(joinT).On(s.C(post.FieldID), joinT.C(category.PostsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(category.PostsPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(category.PostsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + if err := query.prepareQuery(ctx); err != nil { + return err + } + qr := QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + return query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]any, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err + } + return append([]any{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []any) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Category]struct{}{byID[outValue]: {}} + return assign(columns[1:], values[1:]) + } + nids[inValue][byID[outValue]] = struct{}{} + return nil + } + }) + }) + neighbors, err := withInterceptors[[]*Post](ctx, query, qr, query.inters) if err != nil { return err } for _, n := range neighbors { - fk := n.category_posts - if fk == nil { - return fmt.Errorf(`foreign-key "category_posts" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] + nodes, ok := nids[n.ID] if !ok { - return fmt.Errorf(`unexpected referenced foreign-key "category_posts" returned %v for node %v`, *fk, n.ID) + return fmt.Errorf(`unexpected "posts" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) } - assign(node, n) } return nil } diff --git a/backend/ent/category_update.go b/backend/ent/category_update.go index 628a7f2..57c3012 100644 --- a/backend/ent/category_update.go +++ b/backend/ent/category_update.go @@ -262,10 +262,10 @@ func (cu *CategoryUpdate) sqlSave(ctx context.Context) (n int, err error) { } if cu.mutation.PostsCleared() { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.O2M, + Rel: sqlgraph.M2M, Inverse: false, Table: category.PostsTable, - Columns: []string{category.PostsColumn}, + Columns: category.PostsPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt), @@ -275,10 +275,10 @@ func (cu *CategoryUpdate) sqlSave(ctx context.Context) (n int, err error) { } if nodes := cu.mutation.RemovedPostsIDs(); len(nodes) > 0 && !cu.mutation.PostsCleared() { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.O2M, + Rel: sqlgraph.M2M, Inverse: false, Table: category.PostsTable, - Columns: []string{category.PostsColumn}, + Columns: category.PostsPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt), @@ -291,10 +291,10 @@ func (cu *CategoryUpdate) sqlSave(ctx context.Context) (n int, err error) { } if nodes := cu.mutation.PostsIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.O2M, + Rel: sqlgraph.M2M, Inverse: false, Table: category.PostsTable, - Columns: []string{category.PostsColumn}, + Columns: category.PostsPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt), @@ -631,10 +631,10 @@ func (cuo *CategoryUpdateOne) sqlSave(ctx context.Context) (_node *Category, err } if cuo.mutation.PostsCleared() { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.O2M, + Rel: sqlgraph.M2M, Inverse: false, Table: category.PostsTable, - Columns: []string{category.PostsColumn}, + Columns: category.PostsPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt), @@ -644,10 +644,10 @@ func (cuo *CategoryUpdateOne) sqlSave(ctx context.Context) (_node *Category, err } if nodes := cuo.mutation.RemovedPostsIDs(); len(nodes) > 0 && !cuo.mutation.PostsCleared() { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.O2M, + Rel: sqlgraph.M2M, Inverse: false, Table: category.PostsTable, - Columns: []string{category.PostsColumn}, + Columns: category.PostsPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt), @@ -660,10 +660,10 @@ func (cuo *CategoryUpdateOne) sqlSave(ctx context.Context) (_node *Category, err } if nodes := cuo.mutation.PostsIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.O2M, + Rel: sqlgraph.M2M, Inverse: false, Table: category.PostsTable, - Columns: []string{category.PostsColumn}, + Columns: category.PostsPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt), diff --git a/backend/ent/client.go b/backend/ent/client.go index bbcf6f3..2e7b243 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -464,7 +464,7 @@ func (c *CategoryClient) QueryPosts(ca *Category) *PostQuery { step := sqlgraph.NewStep( sqlgraph.From(category.Table, category.FieldID, id), sqlgraph.To(post.Table, post.FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, category.PostsTable, category.PostsColumn), + sqlgraph.Edge(sqlgraph.M2M, false, category.PostsTable, category.PostsPrimaryKey...), ) fromV = sqlgraph.Neighbors(ca.driver.Dialect(), step) return fromV, nil @@ -2207,15 +2207,15 @@ func (c *PostClient) QueryContributors(po *Post) *PostContributorQuery { return query } -// QueryCategory queries the category edge of a Post. -func (c *PostClient) QueryCategory(po *Post) *CategoryQuery { +// QueryCategories queries the categories edge of a Post. +func (c *PostClient) QueryCategories(po *Post) *CategoryQuery { query := (&CategoryClient{config: c.config}).Query() query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := po.ID step := sqlgraph.NewStep( sqlgraph.From(post.Table, post.FieldID, id), sqlgraph.To(category.Table, category.FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, post.CategoryTable, post.CategoryColumn), + sqlgraph.Edge(sqlgraph.M2M, true, post.CategoriesTable, post.CategoriesPrimaryKey...), ) fromV = sqlgraph.Neighbors(po.driver.Dialect(), step) return fromV, nil diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 562eed3..614e3af 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -265,21 +265,12 @@ var ( {Name: "slug", Type: field.TypeString, Unique: true}, {Name: "created_at", Type: field.TypeTime}, {Name: "updated_at", Type: field.TypeTime}, - {Name: "category_posts", Type: field.TypeInt, Nullable: true}, } // PostsTable holds the schema information for the "posts" table. PostsTable = &schema.Table{ Name: "posts", Columns: PostsColumns, PrimaryKey: []*schema.Column{PostsColumns[0]}, - ForeignKeys: []*schema.ForeignKey{ - { - Symbol: "posts_categories_posts", - Columns: []*schema.Column{PostsColumns[5]}, - RefColumns: []*schema.Column{CategoriesColumns[0]}, - OnDelete: schema.SetNull, - }, - }, } // PostContentsColumns holds the columns for the "post_contents" table. PostContentsColumns = []*schema.Column{ @@ -387,6 +378,31 @@ var ( Columns: UsersColumns, PrimaryKey: []*schema.Column{UsersColumns[0]}, } + // CategoryPostsColumns holds the columns for the "category_posts" table. + CategoryPostsColumns = []*schema.Column{ + {Name: "category_id", Type: field.TypeInt}, + {Name: "post_id", Type: field.TypeInt}, + } + // CategoryPostsTable holds the schema information for the "category_posts" table. + CategoryPostsTable = &schema.Table{ + Name: "category_posts", + Columns: CategoryPostsColumns, + PrimaryKey: []*schema.Column{CategoryPostsColumns[0], CategoryPostsColumns[1]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "category_posts_category_id", + Columns: []*schema.Column{CategoryPostsColumns[0]}, + RefColumns: []*schema.Column{CategoriesColumns[0]}, + OnDelete: schema.Cascade, + }, + { + Symbol: "category_posts_post_id", + Columns: []*schema.Column{CategoryPostsColumns[1]}, + RefColumns: []*schema.Column{PostsColumns[0]}, + OnDelete: schema.Cascade, + }, + }, + } // RolePermissionsColumns holds the columns for the "role_permissions" table. RolePermissionsColumns = []*schema.Column{ {Name: "role_id", Type: field.TypeInt}, @@ -455,6 +471,7 @@ var ( PostContributorsTable, RolesTable, UsersTable, + CategoryPostsTable, RolePermissionsTable, UserRolesTable, } @@ -469,11 +486,12 @@ func init() { DailyCategoryContentsTable.ForeignKeys[0].RefTable = DailyCategoriesTable DailyContentsTable.ForeignKeys[0].RefTable = DailiesTable MediaTable.ForeignKeys[0].RefTable = UsersTable - PostsTable.ForeignKeys[0].RefTable = CategoriesTable PostContentsTable.ForeignKeys[0].RefTable = PostsTable PostContributorsTable.ForeignKeys[0].RefTable = ContributorsTable PostContributorsTable.ForeignKeys[1].RefTable = ContributorRolesTable PostContributorsTable.ForeignKeys[2].RefTable = PostsTable + CategoryPostsTable.ForeignKeys[0].RefTable = CategoriesTable + CategoryPostsTable.ForeignKeys[1].RefTable = PostsTable RolePermissionsTable.ForeignKeys[0].RefTable = RolesTable RolePermissionsTable.ForeignKeys[1].RefTable = PermissionsTable UserRolesTable.ForeignKeys[0].RefTable = UsersTable diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index b8ccb5b..f15bc73 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -6578,8 +6578,9 @@ type PostMutation struct { contributors map[int]struct{} removedcontributors map[int]struct{} clearedcontributors bool - category *int - clearedcategory bool + categories map[int]struct{} + removedcategories map[int]struct{} + clearedcategories bool done bool oldValue func(context.Context) (*Post, error) predicates []predicate.Post @@ -6935,43 +6936,58 @@ func (m *PostMutation) ResetContributors() { m.removedcontributors = nil } -// SetCategoryID sets the "category" edge to the Category entity by id. -func (m *PostMutation) SetCategoryID(id int) { - m.category = &id +// AddCategoryIDs adds the "categories" edge to the Category entity by ids. +func (m *PostMutation) AddCategoryIDs(ids ...int) { + if m.categories == nil { + m.categories = make(map[int]struct{}) + } + for i := range ids { + m.categories[ids[i]] = struct{}{} + } } -// ClearCategory clears the "category" edge to the Category entity. -func (m *PostMutation) ClearCategory() { - m.clearedcategory = true +// ClearCategories clears the "categories" edge to the Category entity. +func (m *PostMutation) ClearCategories() { + m.clearedcategories = true } -// CategoryCleared reports if the "category" edge to the Category entity was cleared. -func (m *PostMutation) CategoryCleared() bool { - return m.clearedcategory +// CategoriesCleared reports if the "categories" edge to the Category entity was cleared. +func (m *PostMutation) CategoriesCleared() bool { + return m.clearedcategories } -// CategoryID returns the "category" edge ID in the mutation. -func (m *PostMutation) CategoryID() (id int, exists bool) { - if m.category != nil { - return *m.category, true +// RemoveCategoryIDs removes the "categories" edge to the Category entity by IDs. +func (m *PostMutation) RemoveCategoryIDs(ids ...int) { + if m.removedcategories == nil { + m.removedcategories = make(map[int]struct{}) + } + for i := range ids { + delete(m.categories, ids[i]) + m.removedcategories[ids[i]] = struct{}{} + } +} + +// RemovedCategories returns the removed IDs of the "categories" edge to the Category entity. +func (m *PostMutation) RemovedCategoriesIDs() (ids []int) { + for id := range m.removedcategories { + ids = append(ids, id) } return } -// CategoryIDs returns the "category" edge IDs in the mutation. -// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use -// CategoryID instead. It exists only for internal usage by the builders. -func (m *PostMutation) CategoryIDs() (ids []int) { - if id := m.category; id != nil { - ids = append(ids, *id) +// CategoriesIDs returns the "categories" edge IDs in the mutation. +func (m *PostMutation) CategoriesIDs() (ids []int) { + for id := range m.categories { + ids = append(ids, id) } return } -// ResetCategory resets all changes to the "category" edge. -func (m *PostMutation) ResetCategory() { - m.category = nil - m.clearedcategory = false +// ResetCategories resets all changes to the "categories" edge. +func (m *PostMutation) ResetCategories() { + m.categories = nil + m.clearedcategories = false + m.removedcategories = nil } // Where appends a list predicates to the PostMutation builder. @@ -7165,8 +7181,8 @@ func (m *PostMutation) AddedEdges() []string { if m.contributors != nil { edges = append(edges, post.EdgeContributors) } - if m.category != nil { - edges = append(edges, post.EdgeCategory) + if m.categories != nil { + edges = append(edges, post.EdgeCategories) } return edges } @@ -7187,10 +7203,12 @@ func (m *PostMutation) AddedIDs(name string) []ent.Value { ids = append(ids, id) } return ids - case post.EdgeCategory: - if id := m.category; id != nil { - return []ent.Value{*id} + case post.EdgeCategories: + ids := make([]ent.Value, 0, len(m.categories)) + for id := range m.categories { + ids = append(ids, id) } + return ids } return nil } @@ -7204,6 +7222,9 @@ func (m *PostMutation) RemovedEdges() []string { if m.removedcontributors != nil { edges = append(edges, post.EdgeContributors) } + if m.removedcategories != nil { + edges = append(edges, post.EdgeCategories) + } return edges } @@ -7223,6 +7244,12 @@ func (m *PostMutation) RemovedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case post.EdgeCategories: + ids := make([]ent.Value, 0, len(m.removedcategories)) + for id := range m.removedcategories { + ids = append(ids, id) + } + return ids } return nil } @@ -7236,8 +7263,8 @@ func (m *PostMutation) ClearedEdges() []string { if m.clearedcontributors { edges = append(edges, post.EdgeContributors) } - if m.clearedcategory { - edges = append(edges, post.EdgeCategory) + if m.clearedcategories { + edges = append(edges, post.EdgeCategories) } return edges } @@ -7250,8 +7277,8 @@ func (m *PostMutation) EdgeCleared(name string) bool { return m.clearedcontents case post.EdgeContributors: return m.clearedcontributors - case post.EdgeCategory: - return m.clearedcategory + case post.EdgeCategories: + return m.clearedcategories } return false } @@ -7260,9 +7287,6 @@ func (m *PostMutation) EdgeCleared(name string) bool { // if that edge is not defined in the schema. func (m *PostMutation) ClearEdge(name string) error { switch name { - case post.EdgeCategory: - m.ClearCategory() - return nil } return fmt.Errorf("unknown Post unique edge %s", name) } @@ -7277,8 +7301,8 @@ func (m *PostMutation) ResetEdge(name string) error { case post.EdgeContributors: m.ResetContributors() return nil - case post.EdgeCategory: - m.ResetCategory() + case post.EdgeCategories: + m.ResetCategories() return nil } return fmt.Errorf("unknown Post edge %s", name) diff --git a/backend/ent/post.go b/backend/ent/post.go index 0be8691..3e4529c 100644 --- a/backend/ent/post.go +++ b/backend/ent/post.go @@ -6,7 +6,6 @@ import ( "fmt" "strings" "time" - "tss-rocks-be/ent/category" "tss-rocks-be/ent/post" "entgo.io/ent" @@ -28,9 +27,8 @@ type Post struct { UpdatedAt time.Time `json:"updated_at,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the PostQuery when eager-loading is set. - Edges PostEdges `json:"edges"` - category_posts *int - selectValues sql.SelectValues + Edges PostEdges `json:"edges"` + selectValues sql.SelectValues } // PostEdges holds the relations/edges for other nodes in the graph. @@ -39,8 +37,8 @@ type PostEdges struct { Contents []*PostContent `json:"contents,omitempty"` // Contributors holds the value of the contributors edge. Contributors []*PostContributor `json:"contributors,omitempty"` - // Category holds the value of the category edge. - Category *Category `json:"category,omitempty"` + // Categories holds the value of the categories edge. + Categories []*Category `json:"categories,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. loadedTypes [3]bool @@ -64,15 +62,13 @@ func (e PostEdges) ContributorsOrErr() ([]*PostContributor, error) { return nil, &NotLoadedError{edge: "contributors"} } -// CategoryOrErr returns the Category value or an error if the edge -// was not loaded in eager-loading, or loaded but was not found. -func (e PostEdges) CategoryOrErr() (*Category, error) { - if e.Category != nil { - return e.Category, nil - } else if e.loadedTypes[2] { - return nil, &NotFoundError{label: category.Label} +// CategoriesOrErr returns the Categories value or an error if the edge +// was not loaded in eager-loading. +func (e PostEdges) CategoriesOrErr() ([]*Category, error) { + if e.loadedTypes[2] { + return e.Categories, nil } - return nil, &NotLoadedError{edge: "category"} + return nil, &NotLoadedError{edge: "categories"} } // scanValues returns the types for scanning values from sql.Rows. @@ -86,8 +82,6 @@ func (*Post) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullString) case post.FieldCreatedAt, post.FieldUpdatedAt: values[i] = new(sql.NullTime) - case post.ForeignKeys[0]: // category_posts - values[i] = new(sql.NullInt64) default: values[i] = new(sql.UnknownType) } @@ -133,13 +127,6 @@ func (po *Post) assignValues(columns []string, values []any) error { } else if value.Valid { po.UpdatedAt = value.Time } - case post.ForeignKeys[0]: - if value, ok := values[i].(*sql.NullInt64); !ok { - return fmt.Errorf("unexpected type %T for edge-field category_posts", value) - } else if value.Valid { - po.category_posts = new(int) - *po.category_posts = int(value.Int64) - } default: po.selectValues.Set(columns[i], values[i]) } @@ -163,9 +150,9 @@ func (po *Post) QueryContributors() *PostContributorQuery { return NewPostClient(po.config).QueryContributors(po) } -// QueryCategory queries the "category" edge of the Post entity. -func (po *Post) QueryCategory() *CategoryQuery { - return NewPostClient(po.config).QueryCategory(po) +// QueryCategories queries the "categories" edge of the Post entity. +func (po *Post) QueryCategories() *CategoryQuery { + return NewPostClient(po.config).QueryCategories(po) } // Update returns a builder for updating this Post. diff --git a/backend/ent/post/post.go b/backend/ent/post/post.go index 6dc1e83..b72d3e0 100644 --- a/backend/ent/post/post.go +++ b/backend/ent/post/post.go @@ -27,8 +27,8 @@ const ( EdgeContents = "contents" // EdgeContributors holds the string denoting the contributors edge name in mutations. EdgeContributors = "contributors" - // EdgeCategory holds the string denoting the category edge name in mutations. - EdgeCategory = "category" + // EdgeCategories holds the string denoting the categories edge name in mutations. + EdgeCategories = "categories" // Table holds the table name of the post in the database. Table = "posts" // ContentsTable is the table that holds the contents relation/edge. @@ -45,13 +45,11 @@ const ( ContributorsInverseTable = "post_contributors" // ContributorsColumn is the table column denoting the contributors relation/edge. ContributorsColumn = "post_contributors" - // CategoryTable is the table that holds the category relation/edge. - CategoryTable = "posts" - // CategoryInverseTable is the table name for the Category entity. + // CategoriesTable is the table that holds the categories relation/edge. The primary key declared below. + CategoriesTable = "category_posts" + // CategoriesInverseTable is the table name for the Category entity. // It exists in this package in order to avoid circular dependency with the "category" package. - CategoryInverseTable = "categories" - // CategoryColumn is the table column denoting the category relation/edge. - CategoryColumn = "category_posts" + CategoriesInverseTable = "categories" ) // Columns holds all SQL columns for post fields. @@ -63,11 +61,11 @@ var Columns = []string{ FieldUpdatedAt, } -// ForeignKeys holds the SQL foreign-keys that are owned by the "posts" -// table and are not defined as standalone fields in the schema. -var ForeignKeys = []string{ - "category_posts", -} +var ( + // CategoriesPrimaryKey and CategoriesColumn2 are the table columns denoting the + // primary key for the categories relation (M2M). + CategoriesPrimaryKey = []string{"category_id", "post_id"} +) // ValidColumn reports if the column name is valid (part of the table columns). func ValidColumn(column string) bool { @@ -76,11 +74,6 @@ func ValidColumn(column string) bool { return true } } - for i := range ForeignKeys { - if column == ForeignKeys[i] { - return true - } - } return false } @@ -178,10 +171,17 @@ func ByContributors(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { } } -// ByCategoryField orders the results by category field. -func ByCategoryField(field string, opts ...sql.OrderTermOption) OrderOption { +// ByCategoriesCount orders the results by categories count. +func ByCategoriesCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { - sqlgraph.OrderByNeighborTerms(s, newCategoryStep(), sql.OrderByField(field, opts...)) + sqlgraph.OrderByNeighborsCount(s, newCategoriesStep(), opts...) + } +} + +// ByCategories orders the results by categories terms. +func ByCategories(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newCategoriesStep(), append([]sql.OrderTerm{term}, terms...)...) } } func newContentsStep() *sqlgraph.Step { @@ -198,10 +198,10 @@ func newContributorsStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.O2M, false, ContributorsTable, ContributorsColumn), ) } -func newCategoryStep() *sqlgraph.Step { +func newCategoriesStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(CategoryInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, CategoryTable, CategoryColumn), + sqlgraph.To(CategoriesInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2M, true, CategoriesTable, CategoriesPrimaryKey...), ) } diff --git a/backend/ent/post/where.go b/backend/ent/post/where.go index 7b09ff9..4d9dfda 100644 --- a/backend/ent/post/where.go +++ b/backend/ent/post/where.go @@ -281,21 +281,21 @@ func HasContributorsWith(preds ...predicate.PostContributor) predicate.Post { }) } -// HasCategory applies the HasEdge predicate on the "category" edge. -func HasCategory() predicate.Post { +// HasCategories applies the HasEdge predicate on the "categories" edge. +func HasCategories() predicate.Post { return predicate.Post(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, CategoryTable, CategoryColumn), + sqlgraph.Edge(sqlgraph.M2M, true, CategoriesTable, CategoriesPrimaryKey...), ) sqlgraph.HasNeighbors(s, step) }) } -// HasCategoryWith applies the HasEdge predicate on the "category" edge with a given conditions (other predicates). -func HasCategoryWith(preds ...predicate.Category) predicate.Post { +// HasCategoriesWith applies the HasEdge predicate on the "categories" edge with a given conditions (other predicates). +func HasCategoriesWith(preds ...predicate.Category) predicate.Post { return predicate.Post(func(s *sql.Selector) { - step := newCategoryStep() + step := newCategoriesStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) diff --git a/backend/ent/post_create.go b/backend/ent/post_create.go index c001fe2..b9fb3c7 100644 --- a/backend/ent/post_create.go +++ b/backend/ent/post_create.go @@ -101,23 +101,19 @@ func (pc *PostCreate) AddContributors(p ...*PostContributor) *PostCreate { return pc.AddContributorIDs(ids...) } -// SetCategoryID sets the "category" edge to the Category entity by ID. -func (pc *PostCreate) SetCategoryID(id int) *PostCreate { - pc.mutation.SetCategoryID(id) +// AddCategoryIDs adds the "categories" edge to the Category entity by IDs. +func (pc *PostCreate) AddCategoryIDs(ids ...int) *PostCreate { + pc.mutation.AddCategoryIDs(ids...) return pc } -// SetNillableCategoryID sets the "category" edge to the Category entity by ID if the given value is not nil. -func (pc *PostCreate) SetNillableCategoryID(id *int) *PostCreate { - if id != nil { - pc = pc.SetCategoryID(*id) +// AddCategories adds the "categories" edges to the Category entity. +func (pc *PostCreate) AddCategories(c ...*Category) *PostCreate { + ids := make([]int, len(c)) + for i := range c { + ids[i] = c[i].ID } - return pc -} - -// SetCategory sets the "category" edge to the Category entity. -func (pc *PostCreate) SetCategory(c *Category) *PostCreate { - return pc.SetCategoryID(c.ID) + return pc.AddCategoryIDs(ids...) } // Mutation returns the PostMutation object of the builder. @@ -267,12 +263,12 @@ func (pc *PostCreate) createSpec() (*Post, *sqlgraph.CreateSpec) { } _spec.Edges = append(_spec.Edges, edge) } - if nodes := pc.mutation.CategoryIDs(); len(nodes) > 0 { + if nodes := pc.mutation.CategoriesIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.M2O, + Rel: sqlgraph.M2M, Inverse: true, - Table: post.CategoryTable, - Columns: []string{post.CategoryColumn}, + Table: post.CategoriesTable, + Columns: post.CategoriesPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt), @@ -281,7 +277,6 @@ func (pc *PostCreate) createSpec() (*Post, *sqlgraph.CreateSpec) { for _, k := range nodes { edge.Target.Nodes = append(edge.Target.Nodes, k) } - _node.category_posts = &nodes[0] _spec.Edges = append(_spec.Edges, edge) } return _node, _spec diff --git a/backend/ent/post_query.go b/backend/ent/post_query.go index f90b5ff..c3fbb7a 100644 --- a/backend/ent/post_query.go +++ b/backend/ent/post_query.go @@ -28,8 +28,7 @@ type PostQuery struct { predicates []predicate.Post withContents *PostContentQuery withContributors *PostContributorQuery - withCategory *CategoryQuery - withFKs bool + withCategories *CategoryQuery // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -110,8 +109,8 @@ func (pq *PostQuery) QueryContributors() *PostContributorQuery { return query } -// QueryCategory chains the current query on the "category" edge. -func (pq *PostQuery) QueryCategory() *CategoryQuery { +// QueryCategories chains the current query on the "categories" edge. +func (pq *PostQuery) QueryCategories() *CategoryQuery { query := (&CategoryClient{config: pq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := pq.prepareQuery(ctx); err != nil { @@ -124,7 +123,7 @@ func (pq *PostQuery) QueryCategory() *CategoryQuery { step := sqlgraph.NewStep( sqlgraph.From(post.Table, post.FieldID, selector), sqlgraph.To(category.Table, category.FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, post.CategoryTable, post.CategoryColumn), + sqlgraph.Edge(sqlgraph.M2M, true, post.CategoriesTable, post.CategoriesPrimaryKey...), ) fromU = sqlgraph.SetNeighbors(pq.driver.Dialect(), step) return fromU, nil @@ -326,7 +325,7 @@ func (pq *PostQuery) Clone() *PostQuery { predicates: append([]predicate.Post{}, pq.predicates...), withContents: pq.withContents.Clone(), withContributors: pq.withContributors.Clone(), - withCategory: pq.withCategory.Clone(), + withCategories: pq.withCategories.Clone(), // clone intermediate query. sql: pq.sql.Clone(), path: pq.path, @@ -355,14 +354,14 @@ func (pq *PostQuery) WithContributors(opts ...func(*PostContributorQuery)) *Post return pq } -// WithCategory tells the query-builder to eager-load the nodes that are connected to -// the "category" edge. The optional arguments are used to configure the query builder of the edge. -func (pq *PostQuery) WithCategory(opts ...func(*CategoryQuery)) *PostQuery { +// WithCategories tells the query-builder to eager-load the nodes that are connected to +// the "categories" edge. The optional arguments are used to configure the query builder of the edge. +func (pq *PostQuery) WithCategories(opts ...func(*CategoryQuery)) *PostQuery { query := (&CategoryClient{config: pq.config}).Query() for _, opt := range opts { opt(query) } - pq.withCategory = query + pq.withCategories = query return pq } @@ -443,20 +442,13 @@ func (pq *PostQuery) prepareQuery(ctx context.Context) error { func (pq *PostQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Post, error) { var ( nodes = []*Post{} - withFKs = pq.withFKs _spec = pq.querySpec() loadedTypes = [3]bool{ pq.withContents != nil, pq.withContributors != nil, - pq.withCategory != nil, + pq.withCategories != nil, } ) - if pq.withCategory != nil { - withFKs = true - } - if withFKs { - _spec.Node.Columns = append(_spec.Node.Columns, post.ForeignKeys...) - } _spec.ScanValues = func(columns []string) ([]any, error) { return (*Post).scanValues(nil, columns) } @@ -489,9 +481,10 @@ func (pq *PostQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Post, e return nil, err } } - if query := pq.withCategory; query != nil { - if err := pq.loadCategory(ctx, query, nodes, nil, - func(n *Post, e *Category) { n.Edges.Category = e }); err != nil { + if query := pq.withCategories; query != nil { + if err := pq.loadCategories(ctx, query, nodes, + func(n *Post) { n.Edges.Categories = []*Category{} }, + func(n *Post, e *Category) { n.Edges.Categories = append(n.Edges.Categories, e) }); err != nil { return nil, err } } @@ -560,34 +553,63 @@ func (pq *PostQuery) loadContributors(ctx context.Context, query *PostContributo } return nil } -func (pq *PostQuery) loadCategory(ctx context.Context, query *CategoryQuery, nodes []*Post, init func(*Post), assign func(*Post, *Category)) error { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Post) - for i := range nodes { - if nodes[i].category_posts == nil { - continue +func (pq *PostQuery) loadCategories(ctx context.Context, query *CategoryQuery, nodes []*Post, init func(*Post), assign func(*Post, *Category)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Post) + nids := make(map[int]map[*Post]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) } - fk := *nodes[i].category_posts - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) } - if len(ids) == 0 { - return nil + query.Where(func(s *sql.Selector) { + joinT := sql.Table(post.CategoriesTable) + s.Join(joinT).On(s.C(category.FieldID), joinT.C(post.CategoriesPrimaryKey[0])) + s.Where(sql.InValues(joinT.C(post.CategoriesPrimaryKey[1]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(post.CategoriesPrimaryKey[1])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + if err := query.prepareQuery(ctx); err != nil { + return err } - query.Where(category.IDIn(ids...)) - neighbors, err := query.All(ctx) + qr := QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + return query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]any, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err + } + return append([]any{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []any) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Post]struct{}{byID[outValue]: {}} + return assign(columns[1:], values[1:]) + } + nids[inValue][byID[outValue]] = struct{}{} + return nil + } + }) + }) + neighbors, err := withInterceptors[[]*Category](ctx, query, qr, query.inters) if err != nil { return err } for _, n := range neighbors { - nodes, ok := nodeids[n.ID] + nodes, ok := nids[n.ID] if !ok { - return fmt.Errorf(`unexpected foreign-key "category_posts" returned %v`, n.ID) + return fmt.Errorf(`unexpected "categories" node returned %v`, n.ID) } - for i := range nodes { - assign(nodes[i], n) + for kn := range nodes { + assign(kn, n) } } return nil diff --git a/backend/ent/post_update.go b/backend/ent/post_update.go index d936a6a..7eb44b3 100644 --- a/backend/ent/post_update.go +++ b/backend/ent/post_update.go @@ -109,23 +109,19 @@ func (pu *PostUpdate) AddContributors(p ...*PostContributor) *PostUpdate { return pu.AddContributorIDs(ids...) } -// SetCategoryID sets the "category" edge to the Category entity by ID. -func (pu *PostUpdate) SetCategoryID(id int) *PostUpdate { - pu.mutation.SetCategoryID(id) +// AddCategoryIDs adds the "categories" edge to the Category entity by IDs. +func (pu *PostUpdate) AddCategoryIDs(ids ...int) *PostUpdate { + pu.mutation.AddCategoryIDs(ids...) return pu } -// SetNillableCategoryID sets the "category" edge to the Category entity by ID if the given value is not nil. -func (pu *PostUpdate) SetNillableCategoryID(id *int) *PostUpdate { - if id != nil { - pu = pu.SetCategoryID(*id) +// AddCategories adds the "categories" edges to the Category entity. +func (pu *PostUpdate) AddCategories(c ...*Category) *PostUpdate { + ids := make([]int, len(c)) + for i := range c { + ids[i] = c[i].ID } - return pu -} - -// SetCategory sets the "category" edge to the Category entity. -func (pu *PostUpdate) SetCategory(c *Category) *PostUpdate { - return pu.SetCategoryID(c.ID) + return pu.AddCategoryIDs(ids...) } // Mutation returns the PostMutation object of the builder. @@ -175,12 +171,27 @@ func (pu *PostUpdate) RemoveContributors(p ...*PostContributor) *PostUpdate { return pu.RemoveContributorIDs(ids...) } -// ClearCategory clears the "category" edge to the Category entity. -func (pu *PostUpdate) ClearCategory() *PostUpdate { - pu.mutation.ClearCategory() +// ClearCategories clears all "categories" edges to the Category entity. +func (pu *PostUpdate) ClearCategories() *PostUpdate { + pu.mutation.ClearCategories() return pu } +// RemoveCategoryIDs removes the "categories" edge to Category entities by IDs. +func (pu *PostUpdate) RemoveCategoryIDs(ids ...int) *PostUpdate { + pu.mutation.RemoveCategoryIDs(ids...) + return pu +} + +// RemoveCategories removes "categories" edges to Category entities. +func (pu *PostUpdate) RemoveCategories(c ...*Category) *PostUpdate { + ids := make([]int, len(c)) + for i := range c { + ids[i] = c[i].ID + } + return pu.RemoveCategoryIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (pu *PostUpdate) Save(ctx context.Context) (int, error) { pu.defaults() @@ -346,12 +357,12 @@ func (pu *PostUpdate) sqlSave(ctx context.Context) (n int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } - if pu.mutation.CategoryCleared() { + if pu.mutation.CategoriesCleared() { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.M2O, + Rel: sqlgraph.M2M, Inverse: true, - Table: post.CategoryTable, - Columns: []string{post.CategoryColumn}, + Table: post.CategoriesTable, + Columns: post.CategoriesPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt), @@ -359,12 +370,28 @@ func (pu *PostUpdate) sqlSave(ctx context.Context) (n int, err error) { } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) } - if nodes := pu.mutation.CategoryIDs(); len(nodes) > 0 { + if nodes := pu.mutation.RemovedCategoriesIDs(); len(nodes) > 0 && !pu.mutation.CategoriesCleared() { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.M2O, + Rel: sqlgraph.M2M, Inverse: true, - Table: post.CategoryTable, - Columns: []string{post.CategoryColumn}, + Table: post.CategoriesTable, + Columns: post.CategoriesPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := pu.mutation.CategoriesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: true, + Table: post.CategoriesTable, + Columns: post.CategoriesPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt), @@ -473,23 +500,19 @@ func (puo *PostUpdateOne) AddContributors(p ...*PostContributor) *PostUpdateOne return puo.AddContributorIDs(ids...) } -// SetCategoryID sets the "category" edge to the Category entity by ID. -func (puo *PostUpdateOne) SetCategoryID(id int) *PostUpdateOne { - puo.mutation.SetCategoryID(id) +// AddCategoryIDs adds the "categories" edge to the Category entity by IDs. +func (puo *PostUpdateOne) AddCategoryIDs(ids ...int) *PostUpdateOne { + puo.mutation.AddCategoryIDs(ids...) return puo } -// SetNillableCategoryID sets the "category" edge to the Category entity by ID if the given value is not nil. -func (puo *PostUpdateOne) SetNillableCategoryID(id *int) *PostUpdateOne { - if id != nil { - puo = puo.SetCategoryID(*id) +// AddCategories adds the "categories" edges to the Category entity. +func (puo *PostUpdateOne) AddCategories(c ...*Category) *PostUpdateOne { + ids := make([]int, len(c)) + for i := range c { + ids[i] = c[i].ID } - return puo -} - -// SetCategory sets the "category" edge to the Category entity. -func (puo *PostUpdateOne) SetCategory(c *Category) *PostUpdateOne { - return puo.SetCategoryID(c.ID) + return puo.AddCategoryIDs(ids...) } // Mutation returns the PostMutation object of the builder. @@ -539,12 +562,27 @@ func (puo *PostUpdateOne) RemoveContributors(p ...*PostContributor) *PostUpdateO return puo.RemoveContributorIDs(ids...) } -// ClearCategory clears the "category" edge to the Category entity. -func (puo *PostUpdateOne) ClearCategory() *PostUpdateOne { - puo.mutation.ClearCategory() +// ClearCategories clears all "categories" edges to the Category entity. +func (puo *PostUpdateOne) ClearCategories() *PostUpdateOne { + puo.mutation.ClearCategories() return puo } +// RemoveCategoryIDs removes the "categories" edge to Category entities by IDs. +func (puo *PostUpdateOne) RemoveCategoryIDs(ids ...int) *PostUpdateOne { + puo.mutation.RemoveCategoryIDs(ids...) + return puo +} + +// RemoveCategories removes "categories" edges to Category entities. +func (puo *PostUpdateOne) RemoveCategories(c ...*Category) *PostUpdateOne { + ids := make([]int, len(c)) + for i := range c { + ids[i] = c[i].ID + } + return puo.RemoveCategoryIDs(ids...) +} + // Where appends a list predicates to the PostUpdate builder. func (puo *PostUpdateOne) Where(ps ...predicate.Post) *PostUpdateOne { puo.mutation.Where(ps...) @@ -740,12 +778,12 @@ func (puo *PostUpdateOne) sqlSave(ctx context.Context) (_node *Post, err error) } _spec.Edges.Add = append(_spec.Edges.Add, edge) } - if puo.mutation.CategoryCleared() { + if puo.mutation.CategoriesCleared() { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.M2O, + Rel: sqlgraph.M2M, Inverse: true, - Table: post.CategoryTable, - Columns: []string{post.CategoryColumn}, + Table: post.CategoriesTable, + Columns: post.CategoriesPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt), @@ -753,12 +791,28 @@ func (puo *PostUpdateOne) sqlSave(ctx context.Context) (_node *Post, err error) } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) } - if nodes := puo.mutation.CategoryIDs(); len(nodes) > 0 { + if nodes := puo.mutation.RemovedCategoriesIDs(); len(nodes) > 0 && !puo.mutation.CategoriesCleared() { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.M2O, + Rel: sqlgraph.M2M, Inverse: true, - Table: post.CategoryTable, - Columns: []string{post.CategoryColumn}, + Table: post.CategoriesTable, + Columns: post.CategoriesPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := puo.mutation.CategoriesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: true, + Table: post.CategoriesTable, + Columns: post.CategoriesPrimaryKey, Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt), diff --git a/backend/ent/schema/post.go b/backend/ent/schema/post.go index 05362e4..ff97e93 100644 --- a/backend/ent/schema/post.go +++ b/backend/ent/schema/post.go @@ -34,8 +34,7 @@ func (Post) Edges() []ent.Edge { return []ent.Edge{ edge.To("contents", PostContent.Type), edge.To("contributors", PostContributor.Type), - edge.From("category", Category.Type). - Ref("posts"). - Unique(), + edge.From("categories", Category.Type). + Ref("posts"), } } diff --git a/backend/go.mod b/backend/go.mod index cc83baa..d42c65c 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -3,7 +3,6 @@ module tss-rocks-be go 1.23.6 require ( - bou.ke/monkey v1.0.2 entgo.io/ent v0.14.1 github.com/aws/aws-sdk-go-v2 v1.36.1 github.com/aws/aws-sdk-go-v2/config v1.29.6 @@ -68,7 +67,6 @@ require ( github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/stretchr/objx v0.5.2 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect github.com/zclconf/go-cty v1.16.2 // indirect diff --git a/backend/go.sum b/backend/go.sum index f3dba0a..4274445 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -1,7 +1,5 @@ ariga.io/atlas v0.31.0 h1:Nw6/Jdc7OpZfiy6oh/dJAYPp5XxGYvMTWLOUutwWjeY= ariga.io/atlas v0.31.0/go.mod h1:J3chwsQAgjDF6Ostz7JmJJRTCbtqIupUbVR/gqZrMiA= -bou.ke/monkey v1.0.2 h1:kWcnsrCNUatbxncxR/ThdYqbytgOIArtYWqcQLQzKLI= -bou.ke/monkey v1.0.2/go.mod h1:OqickVX3tNx6t33n1xvtTtu85YN5s6cKwVug+oHMaIA= entgo.io/ent v0.14.1 h1:fUERL506Pqr92EPHJqr8EYxbPioflJo6PudkrEA8a/s= entgo.io/ent v0.14.1/go.mod h1:MH6XLG0KXpkcDQhKiHfANZSzR55TJyPL5IGNpI8wpco= github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= @@ -139,7 +137,6 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 5afcaf0..f459f47 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -186,10 +186,12 @@ func (h *Handler) ListPosts(c *gin.Context) { langCode = "en" // Default to English } - var categoryID *int - if catIDStr := c.Query("category_id"); catIDStr != "" { - if id, err := strconv.Atoi(catIDStr); err == nil { - categoryID = &id + var categoryIDs []int + if catIDsStr := c.QueryArray("category_ids"); len(catIDsStr) > 0 { + for _, idStr := range catIDsStr { + if id, err := strconv.Atoi(idStr); err == nil { + categoryIDs = append(categoryIDs, id) + } } } @@ -207,7 +209,7 @@ func (h *Handler) ListPosts(c *gin.Context) { } } - posts, err := h.service.ListPosts(c.Request.Context(), langCode, categoryID, limit, offset) + posts, err := h.service.ListPosts(c.Request.Context(), langCode, categoryIDs, limit, offset) if err != nil { log.Error().Err(err).Msg("Failed to list posts") c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list posts"}) @@ -256,7 +258,22 @@ func (h *Handler) GetPost(c *gin.Context) { } func (h *Handler) CreatePost(c *gin.Context) { - post, err := h.service.CreatePost(c.Request.Context(), "draft") // Default to draft status + var req struct { + Status string `json:"status" binding:"omitempty,oneof=draft published archived"` + CategoryIDs []int `json:"category_ids" binding:"required,min=1"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + status := req.Status + if status == "" { + status = "draft" // Default to draft status + } + + post, err := h.service.CreatePost(c.Request.Context(), status, req.CategoryIDs) if err != nil { log.Error().Err(err).Msg("Failed to create post") c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create post"}) @@ -268,7 +285,8 @@ func (h *Handler) CreatePost(c *gin.Context) { "id": post.ID, "status": post.Status, "edges": gin.H{ - "contents": []interface{}{}, + "contents": []interface{}{}, + "categories": []interface{}{}, }, } diff --git a/backend/internal/service/impl.go b/backend/internal/service/impl.go index 216e348..fed9495 100644 --- a/backend/internal/service/impl.go +++ b/backend/internal/service/impl.go @@ -476,7 +476,7 @@ func (s *serviceImpl) DeleteMedia(ctx context.Context, id int, userID int) error } // Post operations -func (s *serviceImpl) CreatePost(ctx context.Context, status string) (*ent.Post, error) { +func (s *serviceImpl) CreatePost(ctx context.Context, status string, categoryIDs []int) (*ent.Post, error) { var postStatus post.Status switch status { case "draft": @@ -492,10 +492,25 @@ func (s *serviceImpl) CreatePost(ctx context.Context, status string) (*ent.Post, // Generate a random slug slug := fmt.Sprintf("post-%s", uuid.New().String()[:8]) - return s.client.Post.Create(). + // Create post with categories + postCreate := s.client.Post.Create(). SetStatus(postStatus). - SetSlug(slug). - Save(ctx) + SetSlug(slug) + + // Add categories if provided + if len(categoryIDs) > 0 { + categories := make([]*ent.Category, 0, len(categoryIDs)) + for _, id := range categoryIDs { + category, err := s.client.Category.Get(ctx, id) + if err != nil { + return nil, fmt.Errorf("failed to get category %d: %w", id, err) + } + categories = append(categories, category) + } + postCreate.AddCategories(categories...) + } + + return postCreate.Save(ctx) } func (s *serviceImpl) AddPostContent(ctx context.Context, postID int, langCode, title, content, summary string, metaKeywords, metaDescription string) (*ent.PostContent, error) { @@ -574,7 +589,7 @@ func (s *serviceImpl) GetPostBySlug(ctx context.Context, langCode, slug string) WithContents(func(q *ent.PostContentQuery) { q.Where(postcontent.LanguageCodeEQ(languageCode)) }). - WithCategory(). + WithCategories(). All(ctx) if err != nil { return nil, fmt.Errorf("failed to get posts: %w", err) @@ -590,7 +605,7 @@ func (s *serviceImpl) GetPostBySlug(ctx context.Context, langCode, slug string) return posts[0], nil } -func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error) { +func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryIDs []int, limit, offset int) ([]*ent.Post, error) { var languageCode postcontent.LanguageCode switch langCode { case "en": @@ -610,8 +625,8 @@ func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID Where(post.StatusEQ(post.StatusPublished)) // Add category filter if provided - if categoryID != nil { - query = query.Where(post.HasCategoryWith(category.ID(*categoryID))) + if len(categoryIDs) > 0 { + query = query.Where(post.HasCategoriesWith(category.IDIn(categoryIDs...))) } // Get unique post IDs @@ -638,7 +653,7 @@ func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID } // If no category filter is applied, only take the latest 5 posts - if categoryID == nil && len(postIDs) > 5 { + if len(categoryIDs) == 0 && len(postIDs) > 5 { postIDs = postIDs[:5] } @@ -666,7 +681,7 @@ func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID WithContents(func(q *ent.PostContentQuery) { q.Where(postcontent.LanguageCodeEQ(languageCode)) }). - WithCategory(). + WithCategories(). Order(ent.Desc(post.FieldCreatedAt)). All(ctx) if err != nil { diff --git a/backend/internal/service/service.go b/backend/internal/service/service.go index 5b88d4f..e5b6213 100644 --- a/backend/internal/service/service.go +++ b/backend/internal/service/service.go @@ -34,10 +34,10 @@ type Service interface { GetCategories(ctx context.Context, langCode string) ([]*ent.Category, error) // Post operations - CreatePost(ctx context.Context, status string) (*ent.Post, error) + CreatePost(ctx context.Context, status string, categoryIDs []int) (*ent.Post, error) AddPostContent(ctx context.Context, postID int, langCode, title, content, summary string, metaKeywords, metaDescription string) (*ent.PostContent, error) GetPostBySlug(ctx context.Context, langCode, slug string) (*ent.Post, error) - ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error) + ListPosts(ctx context.Context, langCode string, categoryIDs []int, limit, offset int) ([]*ent.Post, error) // Media operations ListMedia(ctx context.Context, limit, offset int) ([]*ent.Media, error)