Compare commits
9 commits
7a33038af8
...
3e6181e578
Author | SHA1 | Date | |
---|---|---|---|
3e6181e578 | |||
6e1be3d513 | |||
086c9761a9 | |||
e86d8c1576 | |||
be8bf22017 | |||
958e3c2886 | |||
1c9628124f | |||
3d19ef05b3 | |||
2c3e238e9a |
76 changed files with 4111 additions and 7531 deletions
|
@ -103,6 +103,7 @@ Post:
|
|||
- slug
|
||||
- status
|
||||
- contents
|
||||
- categories
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
|
@ -117,6 +118,11 @@ Post:
|
|||
type: array
|
||||
items:
|
||||
$ref: '#/PostContent'
|
||||
categories:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/Category'
|
||||
description: 文章所属分类
|
||||
created_at:
|
||||
type: string
|
||||
format: date-time
|
||||
|
|
|
@ -14,9 +14,7 @@ RUN adduser -u 1000 -D tss-rocks
|
|||
USER tss-rocks
|
||||
WORKDIR /app
|
||||
|
||||
# 复制二进制文件和配置
|
||||
COPY --from=builder /app/tss-rocks-be .
|
||||
COPY --from=builder /app/config/config.yaml ./config/
|
||||
|
||||
EXPOSE 8080
|
||||
ENV GIN_MODE=release
|
||||
|
|
|
@ -21,10 +21,50 @@ auth:
|
|||
message: "Registration is currently disabled. Please contact administrator." # 禁用时的提示信息
|
||||
|
||||
storage:
|
||||
driver: local
|
||||
type: local # local or s3
|
||||
local:
|
||||
root: storage
|
||||
base_url: http://localhost:8080/storage
|
||||
root_dir: "./storage/media"
|
||||
s3:
|
||||
region: "us-east-1"
|
||||
bucket: "your-bucket-name"
|
||||
access_key_id: "your-access-key-id"
|
||||
secret_access_key: "your-secret-access-key"
|
||||
endpoint: "" # Optional, for MinIO or other S3-compatible services
|
||||
custom_url: "" # Optional, for CDN or custom domain (e.g., https://cdn.example.com/media)
|
||||
proxy_s3: false # If true, backend will proxy S3 requests instead of redirecting
|
||||
upload:
|
||||
limits:
|
||||
image:
|
||||
max_size: 10 # MB
|
||||
allowed_types:
|
||||
- image/jpeg
|
||||
- image/png
|
||||
- image/gif
|
||||
- image/webp
|
||||
- image/svg+xml
|
||||
video:
|
||||
max_size: 500 # MB
|
||||
allowed_types:
|
||||
- video/mp4
|
||||
- video/webm
|
||||
audio:
|
||||
max_size: 50 # MB
|
||||
allowed_types:
|
||||
- audio/mpeg
|
||||
- audio/ogg
|
||||
- audio/wav
|
||||
document:
|
||||
max_size: 20 # MB
|
||||
allowed_types:
|
||||
- application/pdf
|
||||
- application/msword
|
||||
- application/vnd.openxmlformats-officedocument.wordprocessingml.document
|
||||
- application/vnd.ms-excel
|
||||
- application/vnd.openxmlformats-officedocument.spreadsheetml.sheet
|
||||
- application/zip
|
||||
- application/x-rar-compressed
|
||||
- text/plain
|
||||
- text/csv
|
||||
|
||||
logging:
|
||||
level: debug
|
||||
|
|
|
@ -33,13 +33,11 @@ const (
|
|||
ContentsInverseTable = "category_contents"
|
||||
// ContentsColumn is the table column denoting the contents relation/edge.
|
||||
ContentsColumn = "category_contents"
|
||||
// PostsTable is the table that holds the posts relation/edge.
|
||||
PostsTable = "posts"
|
||||
// PostsTable is the table that holds the posts relation/edge. The primary key declared below.
|
||||
PostsTable = "category_posts"
|
||||
// PostsInverseTable is the table name for the Post entity.
|
||||
// It exists in this package in order to avoid circular dependency with the "post" package.
|
||||
PostsInverseTable = "posts"
|
||||
// PostsColumn is the table column denoting the posts relation/edge.
|
||||
PostsColumn = "category_posts"
|
||||
// DailyItemsTable is the table that holds the daily_items relation/edge.
|
||||
DailyItemsTable = "dailies"
|
||||
// DailyItemsInverseTable is the table name for the Daily entity.
|
||||
|
@ -56,6 +54,12 @@ var Columns = []string{
|
|||
FieldUpdatedAt,
|
||||
}
|
||||
|
||||
var (
|
||||
// PostsPrimaryKey and PostsColumn2 are the table columns denoting the
|
||||
// primary key for the posts relation (M2M).
|
||||
PostsPrimaryKey = []string{"category_id", "post_id"}
|
||||
)
|
||||
|
||||
// ValidColumn reports if the column name is valid (part of the table columns).
|
||||
func ValidColumn(column string) bool {
|
||||
for i := range Columns {
|
||||
|
@ -145,7 +149,7 @@ func newPostsStep() *sqlgraph.Step {
|
|||
return sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
sqlgraph.To(PostsInverseTable, FieldID),
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, PostsTable, PostsColumn),
|
||||
sqlgraph.Edge(sqlgraph.M2M, false, PostsTable, PostsPrimaryKey...),
|
||||
)
|
||||
}
|
||||
func newDailyItemsStep() *sqlgraph.Step {
|
||||
|
|
|
@ -173,7 +173,7 @@ func HasPosts() predicate.Category {
|
|||
return predicate.Category(func(s *sql.Selector) {
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, PostsTable, PostsColumn),
|
||||
sqlgraph.Edge(sqlgraph.M2M, false, PostsTable, PostsPrimaryKey...),
|
||||
)
|
||||
sqlgraph.HasNeighbors(s, step)
|
||||
})
|
||||
|
|
|
@ -201,10 +201,10 @@ func (cc *CategoryCreate) createSpec() (*Category, *sqlgraph.CreateSpec) {
|
|||
}
|
||||
if nodes := cc.mutation.PostsIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Rel: sqlgraph.M2M,
|
||||
Inverse: false,
|
||||
Table: category.PostsTable,
|
||||
Columns: []string{category.PostsColumn},
|
||||
Columns: category.PostsPrimaryKey,
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt),
|
||||
|
|
|
@ -101,7 +101,7 @@ func (cq *CategoryQuery) QueryPosts() *PostQuery {
|
|||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(category.Table, category.FieldID, selector),
|
||||
sqlgraph.To(post.Table, post.FieldID),
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, category.PostsTable, category.PostsColumn),
|
||||
sqlgraph.Edge(sqlgraph.M2M, false, category.PostsTable, category.PostsPrimaryKey...),
|
||||
)
|
||||
fromU = sqlgraph.SetNeighbors(cq.driver.Dialect(), step)
|
||||
return fromU, nil
|
||||
|
@ -523,33 +523,63 @@ func (cq *CategoryQuery) loadContents(ctx context.Context, query *CategoryConten
|
|||
return nil
|
||||
}
|
||||
func (cq *CategoryQuery) loadPosts(ctx context.Context, query *PostQuery, nodes []*Category, init func(*Category), assign func(*Category, *Post)) error {
|
||||
fks := make([]driver.Value, 0, len(nodes))
|
||||
nodeids := make(map[int]*Category)
|
||||
for i := range nodes {
|
||||
fks = append(fks, nodes[i].ID)
|
||||
nodeids[nodes[i].ID] = nodes[i]
|
||||
edgeIDs := make([]driver.Value, len(nodes))
|
||||
byID := make(map[int]*Category)
|
||||
nids := make(map[int]map[*Category]struct{})
|
||||
for i, node := range nodes {
|
||||
edgeIDs[i] = node.ID
|
||||
byID[node.ID] = node
|
||||
if init != nil {
|
||||
init(nodes[i])
|
||||
init(node)
|
||||
}
|
||||
}
|
||||
query.withFKs = true
|
||||
query.Where(predicate.Post(func(s *sql.Selector) {
|
||||
s.Where(sql.InValues(s.C(category.PostsColumn), fks...))
|
||||
}))
|
||||
neighbors, err := query.All(ctx)
|
||||
query.Where(func(s *sql.Selector) {
|
||||
joinT := sql.Table(category.PostsTable)
|
||||
s.Join(joinT).On(s.C(post.FieldID), joinT.C(category.PostsPrimaryKey[1]))
|
||||
s.Where(sql.InValues(joinT.C(category.PostsPrimaryKey[0]), edgeIDs...))
|
||||
columns := s.SelectedColumns()
|
||||
s.Select(joinT.C(category.PostsPrimaryKey[0]))
|
||||
s.AppendSelect(columns...)
|
||||
s.SetDistinct(false)
|
||||
})
|
||||
if err := query.prepareQuery(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
qr := QuerierFunc(func(ctx context.Context, q Query) (Value, error) {
|
||||
return query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
|
||||
assign := spec.Assign
|
||||
values := spec.ScanValues
|
||||
spec.ScanValues = func(columns []string) ([]any, error) {
|
||||
values, err := values(columns[1:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append([]any{new(sql.NullInt64)}, values...), nil
|
||||
}
|
||||
spec.Assign = func(columns []string, values []any) error {
|
||||
outValue := int(values[0].(*sql.NullInt64).Int64)
|
||||
inValue := int(values[1].(*sql.NullInt64).Int64)
|
||||
if nids[inValue] == nil {
|
||||
nids[inValue] = map[*Category]struct{}{byID[outValue]: {}}
|
||||
return assign(columns[1:], values[1:])
|
||||
}
|
||||
nids[inValue][byID[outValue]] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
})
|
||||
})
|
||||
neighbors, err := withInterceptors[[]*Post](ctx, query, qr, query.inters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
fk := n.category_posts
|
||||
if fk == nil {
|
||||
return fmt.Errorf(`foreign-key "category_posts" is nil for node %v`, n.ID)
|
||||
}
|
||||
node, ok := nodeids[*fk]
|
||||
nodes, ok := nids[n.ID]
|
||||
if !ok {
|
||||
return fmt.Errorf(`unexpected referenced foreign-key "category_posts" returned %v for node %v`, *fk, n.ID)
|
||||
return fmt.Errorf(`unexpected "posts" node returned %v`, n.ID)
|
||||
}
|
||||
for kn := range nodes {
|
||||
assign(kn, n)
|
||||
}
|
||||
assign(node, n)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -262,10 +262,10 @@ func (cu *CategoryUpdate) sqlSave(ctx context.Context) (n int, err error) {
|
|||
}
|
||||
if cu.mutation.PostsCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Rel: sqlgraph.M2M,
|
||||
Inverse: false,
|
||||
Table: category.PostsTable,
|
||||
Columns: []string{category.PostsColumn},
|
||||
Columns: category.PostsPrimaryKey,
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt),
|
||||
|
@ -275,10 +275,10 @@ func (cu *CategoryUpdate) sqlSave(ctx context.Context) (n int, err error) {
|
|||
}
|
||||
if nodes := cu.mutation.RemovedPostsIDs(); len(nodes) > 0 && !cu.mutation.PostsCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Rel: sqlgraph.M2M,
|
||||
Inverse: false,
|
||||
Table: category.PostsTable,
|
||||
Columns: []string{category.PostsColumn},
|
||||
Columns: category.PostsPrimaryKey,
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt),
|
||||
|
@ -291,10 +291,10 @@ func (cu *CategoryUpdate) sqlSave(ctx context.Context) (n int, err error) {
|
|||
}
|
||||
if nodes := cu.mutation.PostsIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Rel: sqlgraph.M2M,
|
||||
Inverse: false,
|
||||
Table: category.PostsTable,
|
||||
Columns: []string{category.PostsColumn},
|
||||
Columns: category.PostsPrimaryKey,
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt),
|
||||
|
@ -631,10 +631,10 @@ func (cuo *CategoryUpdateOne) sqlSave(ctx context.Context) (_node *Category, err
|
|||
}
|
||||
if cuo.mutation.PostsCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Rel: sqlgraph.M2M,
|
||||
Inverse: false,
|
||||
Table: category.PostsTable,
|
||||
Columns: []string{category.PostsColumn},
|
||||
Columns: category.PostsPrimaryKey,
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt),
|
||||
|
@ -644,10 +644,10 @@ func (cuo *CategoryUpdateOne) sqlSave(ctx context.Context) (_node *Category, err
|
|||
}
|
||||
if nodes := cuo.mutation.RemovedPostsIDs(); len(nodes) > 0 && !cuo.mutation.PostsCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Rel: sqlgraph.M2M,
|
||||
Inverse: false,
|
||||
Table: category.PostsTable,
|
||||
Columns: []string{category.PostsColumn},
|
||||
Columns: category.PostsPrimaryKey,
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt),
|
||||
|
@ -660,10 +660,10 @@ func (cuo *CategoryUpdateOne) sqlSave(ctx context.Context) (_node *Category, err
|
|||
}
|
||||
if nodes := cuo.mutation.PostsIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Rel: sqlgraph.M2M,
|
||||
Inverse: false,
|
||||
Table: category.PostsTable,
|
||||
Columns: []string{category.PostsColumn},
|
||||
Columns: category.PostsPrimaryKey,
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt),
|
||||
|
|
|
@ -464,7 +464,7 @@ func (c *CategoryClient) QueryPosts(ca *Category) *PostQuery {
|
|||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(category.Table, category.FieldID, id),
|
||||
sqlgraph.To(post.Table, post.FieldID),
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, category.PostsTable, category.PostsColumn),
|
||||
sqlgraph.Edge(sqlgraph.M2M, false, category.PostsTable, category.PostsPrimaryKey...),
|
||||
)
|
||||
fromV = sqlgraph.Neighbors(ca.driver.Dialect(), step)
|
||||
return fromV, nil
|
||||
|
@ -2207,15 +2207,15 @@ func (c *PostClient) QueryContributors(po *Post) *PostContributorQuery {
|
|||
return query
|
||||
}
|
||||
|
||||
// QueryCategory queries the category edge of a Post.
|
||||
func (c *PostClient) QueryCategory(po *Post) *CategoryQuery {
|
||||
// QueryCategories queries the categories edge of a Post.
|
||||
func (c *PostClient) QueryCategories(po *Post) *CategoryQuery {
|
||||
query := (&CategoryClient{config: c.config}).Query()
|
||||
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
|
||||
id := po.ID
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(post.Table, post.FieldID, id),
|
||||
sqlgraph.To(category.Table, category.FieldID),
|
||||
sqlgraph.Edge(sqlgraph.M2O, true, post.CategoryTable, post.CategoryColumn),
|
||||
sqlgraph.Edge(sqlgraph.M2M, true, post.CategoriesTable, post.CategoriesPrimaryKey...),
|
||||
)
|
||||
fromV = sqlgraph.Neighbors(po.driver.Dialect(), step)
|
||||
return fromV, nil
|
||||
|
|
|
@ -265,21 +265,12 @@ var (
|
|||
{Name: "slug", Type: field.TypeString, Unique: true},
|
||||
{Name: "created_at", Type: field.TypeTime},
|
||||
{Name: "updated_at", Type: field.TypeTime},
|
||||
{Name: "category_posts", Type: field.TypeInt, Nullable: true},
|
||||
}
|
||||
// PostsTable holds the schema information for the "posts" table.
|
||||
PostsTable = &schema.Table{
|
||||
Name: "posts",
|
||||
Columns: PostsColumns,
|
||||
PrimaryKey: []*schema.Column{PostsColumns[0]},
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "posts_categories_posts",
|
||||
Columns: []*schema.Column{PostsColumns[5]},
|
||||
RefColumns: []*schema.Column{CategoriesColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
},
|
||||
}
|
||||
// PostContentsColumns holds the columns for the "post_contents" table.
|
||||
PostContentsColumns = []*schema.Column{
|
||||
|
@ -387,6 +378,31 @@ var (
|
|||
Columns: UsersColumns,
|
||||
PrimaryKey: []*schema.Column{UsersColumns[0]},
|
||||
}
|
||||
// CategoryPostsColumns holds the columns for the "category_posts" table.
|
||||
CategoryPostsColumns = []*schema.Column{
|
||||
{Name: "category_id", Type: field.TypeInt},
|
||||
{Name: "post_id", Type: field.TypeInt},
|
||||
}
|
||||
// CategoryPostsTable holds the schema information for the "category_posts" table.
|
||||
CategoryPostsTable = &schema.Table{
|
||||
Name: "category_posts",
|
||||
Columns: CategoryPostsColumns,
|
||||
PrimaryKey: []*schema.Column{CategoryPostsColumns[0], CategoryPostsColumns[1]},
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "category_posts_category_id",
|
||||
Columns: []*schema.Column{CategoryPostsColumns[0]},
|
||||
RefColumns: []*schema.Column{CategoriesColumns[0]},
|
||||
OnDelete: schema.Cascade,
|
||||
},
|
||||
{
|
||||
Symbol: "category_posts_post_id",
|
||||
Columns: []*schema.Column{CategoryPostsColumns[1]},
|
||||
RefColumns: []*schema.Column{PostsColumns[0]},
|
||||
OnDelete: schema.Cascade,
|
||||
},
|
||||
},
|
||||
}
|
||||
// RolePermissionsColumns holds the columns for the "role_permissions" table.
|
||||
RolePermissionsColumns = []*schema.Column{
|
||||
{Name: "role_id", Type: field.TypeInt},
|
||||
|
@ -455,6 +471,7 @@ var (
|
|||
PostContributorsTable,
|
||||
RolesTable,
|
||||
UsersTable,
|
||||
CategoryPostsTable,
|
||||
RolePermissionsTable,
|
||||
UserRolesTable,
|
||||
}
|
||||
|
@ -469,11 +486,12 @@ func init() {
|
|||
DailyCategoryContentsTable.ForeignKeys[0].RefTable = DailyCategoriesTable
|
||||
DailyContentsTable.ForeignKeys[0].RefTable = DailiesTable
|
||||
MediaTable.ForeignKeys[0].RefTable = UsersTable
|
||||
PostsTable.ForeignKeys[0].RefTable = CategoriesTable
|
||||
PostContentsTable.ForeignKeys[0].RefTable = PostsTable
|
||||
PostContributorsTable.ForeignKeys[0].RefTable = ContributorsTable
|
||||
PostContributorsTable.ForeignKeys[1].RefTable = ContributorRolesTable
|
||||
PostContributorsTable.ForeignKeys[2].RefTable = PostsTable
|
||||
CategoryPostsTable.ForeignKeys[0].RefTable = CategoriesTable
|
||||
CategoryPostsTable.ForeignKeys[1].RefTable = PostsTable
|
||||
RolePermissionsTable.ForeignKeys[0].RefTable = RolesTable
|
||||
RolePermissionsTable.ForeignKeys[1].RefTable = PermissionsTable
|
||||
UserRolesTable.ForeignKeys[0].RefTable = UsersTable
|
||||
|
|
|
@ -6578,8 +6578,9 @@ type PostMutation struct {
|
|||
contributors map[int]struct{}
|
||||
removedcontributors map[int]struct{}
|
||||
clearedcontributors bool
|
||||
category *int
|
||||
clearedcategory bool
|
||||
categories map[int]struct{}
|
||||
removedcategories map[int]struct{}
|
||||
clearedcategories bool
|
||||
done bool
|
||||
oldValue func(context.Context) (*Post, error)
|
||||
predicates []predicate.Post
|
||||
|
@ -6935,43 +6936,58 @@ func (m *PostMutation) ResetContributors() {
|
|||
m.removedcontributors = nil
|
||||
}
|
||||
|
||||
// SetCategoryID sets the "category" edge to the Category entity by id.
|
||||
func (m *PostMutation) SetCategoryID(id int) {
|
||||
m.category = &id
|
||||
// AddCategoryIDs adds the "categories" edge to the Category entity by ids.
|
||||
func (m *PostMutation) AddCategoryIDs(ids ...int) {
|
||||
if m.categories == nil {
|
||||
m.categories = make(map[int]struct{})
|
||||
}
|
||||
for i := range ids {
|
||||
m.categories[ids[i]] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// ClearCategory clears the "category" edge to the Category entity.
|
||||
func (m *PostMutation) ClearCategory() {
|
||||
m.clearedcategory = true
|
||||
// ClearCategories clears the "categories" edge to the Category entity.
|
||||
func (m *PostMutation) ClearCategories() {
|
||||
m.clearedcategories = true
|
||||
}
|
||||
|
||||
// CategoryCleared reports if the "category" edge to the Category entity was cleared.
|
||||
func (m *PostMutation) CategoryCleared() bool {
|
||||
return m.clearedcategory
|
||||
// CategoriesCleared reports if the "categories" edge to the Category entity was cleared.
|
||||
func (m *PostMutation) CategoriesCleared() bool {
|
||||
return m.clearedcategories
|
||||
}
|
||||
|
||||
// CategoryID returns the "category" edge ID in the mutation.
|
||||
func (m *PostMutation) CategoryID() (id int, exists bool) {
|
||||
if m.category != nil {
|
||||
return *m.category, true
|
||||
// RemoveCategoryIDs removes the "categories" edge to the Category entity by IDs.
|
||||
func (m *PostMutation) RemoveCategoryIDs(ids ...int) {
|
||||
if m.removedcategories == nil {
|
||||
m.removedcategories = make(map[int]struct{})
|
||||
}
|
||||
for i := range ids {
|
||||
delete(m.categories, ids[i])
|
||||
m.removedcategories[ids[i]] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// RemovedCategories returns the removed IDs of the "categories" edge to the Category entity.
|
||||
func (m *PostMutation) RemovedCategoriesIDs() (ids []int) {
|
||||
for id := range m.removedcategories {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CategoryIDs returns the "category" edge IDs in the mutation.
|
||||
// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
|
||||
// CategoryID instead. It exists only for internal usage by the builders.
|
||||
func (m *PostMutation) CategoryIDs() (ids []int) {
|
||||
if id := m.category; id != nil {
|
||||
ids = append(ids, *id)
|
||||
// CategoriesIDs returns the "categories" edge IDs in the mutation.
|
||||
func (m *PostMutation) CategoriesIDs() (ids []int) {
|
||||
for id := range m.categories {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ResetCategory resets all changes to the "category" edge.
|
||||
func (m *PostMutation) ResetCategory() {
|
||||
m.category = nil
|
||||
m.clearedcategory = false
|
||||
// ResetCategories resets all changes to the "categories" edge.
|
||||
func (m *PostMutation) ResetCategories() {
|
||||
m.categories = nil
|
||||
m.clearedcategories = false
|
||||
m.removedcategories = nil
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the PostMutation builder.
|
||||
|
@ -7165,8 +7181,8 @@ func (m *PostMutation) AddedEdges() []string {
|
|||
if m.contributors != nil {
|
||||
edges = append(edges, post.EdgeContributors)
|
||||
}
|
||||
if m.category != nil {
|
||||
edges = append(edges, post.EdgeCategory)
|
||||
if m.categories != nil {
|
||||
edges = append(edges, post.EdgeCategories)
|
||||
}
|
||||
return edges
|
||||
}
|
||||
|
@ -7187,10 +7203,12 @@ func (m *PostMutation) AddedIDs(name string) []ent.Value {
|
|||
ids = append(ids, id)
|
||||
}
|
||||
return ids
|
||||
case post.EdgeCategory:
|
||||
if id := m.category; id != nil {
|
||||
return []ent.Value{*id}
|
||||
case post.EdgeCategories:
|
||||
ids := make([]ent.Value, 0, len(m.categories))
|
||||
for id := range m.categories {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -7204,6 +7222,9 @@ func (m *PostMutation) RemovedEdges() []string {
|
|||
if m.removedcontributors != nil {
|
||||
edges = append(edges, post.EdgeContributors)
|
||||
}
|
||||
if m.removedcategories != nil {
|
||||
edges = append(edges, post.EdgeCategories)
|
||||
}
|
||||
return edges
|
||||
}
|
||||
|
||||
|
@ -7223,6 +7244,12 @@ func (m *PostMutation) RemovedIDs(name string) []ent.Value {
|
|||
ids = append(ids, id)
|
||||
}
|
||||
return ids
|
||||
case post.EdgeCategories:
|
||||
ids := make([]ent.Value, 0, len(m.removedcategories))
|
||||
for id := range m.removedcategories {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -7236,8 +7263,8 @@ func (m *PostMutation) ClearedEdges() []string {
|
|||
if m.clearedcontributors {
|
||||
edges = append(edges, post.EdgeContributors)
|
||||
}
|
||||
if m.clearedcategory {
|
||||
edges = append(edges, post.EdgeCategory)
|
||||
if m.clearedcategories {
|
||||
edges = append(edges, post.EdgeCategories)
|
||||
}
|
||||
return edges
|
||||
}
|
||||
|
@ -7250,8 +7277,8 @@ func (m *PostMutation) EdgeCleared(name string) bool {
|
|||
return m.clearedcontents
|
||||
case post.EdgeContributors:
|
||||
return m.clearedcontributors
|
||||
case post.EdgeCategory:
|
||||
return m.clearedcategory
|
||||
case post.EdgeCategories:
|
||||
return m.clearedcategories
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -7260,9 +7287,6 @@ func (m *PostMutation) EdgeCleared(name string) bool {
|
|||
// if that edge is not defined in the schema.
|
||||
func (m *PostMutation) ClearEdge(name string) error {
|
||||
switch name {
|
||||
case post.EdgeCategory:
|
||||
m.ClearCategory()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Post unique edge %s", name)
|
||||
}
|
||||
|
@ -7277,8 +7301,8 @@ func (m *PostMutation) ResetEdge(name string) error {
|
|||
case post.EdgeContributors:
|
||||
m.ResetContributors()
|
||||
return nil
|
||||
case post.EdgeCategory:
|
||||
m.ResetCategory()
|
||||
case post.EdgeCategories:
|
||||
m.ResetCategories()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Post edge %s", name)
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
"tss-rocks-be/ent/category"
|
||||
"tss-rocks-be/ent/post"
|
||||
|
||||
"entgo.io/ent"
|
||||
|
@ -28,9 +27,8 @@ type Post struct {
|
|||
UpdatedAt time.Time `json:"updated_at,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the PostQuery when eager-loading is set.
|
||||
Edges PostEdges `json:"edges"`
|
||||
category_posts *int
|
||||
selectValues sql.SelectValues
|
||||
Edges PostEdges `json:"edges"`
|
||||
selectValues sql.SelectValues
|
||||
}
|
||||
|
||||
// PostEdges holds the relations/edges for other nodes in the graph.
|
||||
|
@ -39,8 +37,8 @@ type PostEdges struct {
|
|||
Contents []*PostContent `json:"contents,omitempty"`
|
||||
// Contributors holds the value of the contributors edge.
|
||||
Contributors []*PostContributor `json:"contributors,omitempty"`
|
||||
// Category holds the value of the category edge.
|
||||
Category *Category `json:"category,omitempty"`
|
||||
// Categories holds the value of the categories edge.
|
||||
Categories []*Category `json:"categories,omitempty"`
|
||||
// loadedTypes holds the information for reporting if a
|
||||
// type was loaded (or requested) in eager-loading or not.
|
||||
loadedTypes [3]bool
|
||||
|
@ -64,15 +62,13 @@ func (e PostEdges) ContributorsOrErr() ([]*PostContributor, error) {
|
|||
return nil, &NotLoadedError{edge: "contributors"}
|
||||
}
|
||||
|
||||
// CategoryOrErr returns the Category value or an error if the edge
|
||||
// was not loaded in eager-loading, or loaded but was not found.
|
||||
func (e PostEdges) CategoryOrErr() (*Category, error) {
|
||||
if e.Category != nil {
|
||||
return e.Category, nil
|
||||
} else if e.loadedTypes[2] {
|
||||
return nil, &NotFoundError{label: category.Label}
|
||||
// CategoriesOrErr returns the Categories value or an error if the edge
|
||||
// was not loaded in eager-loading.
|
||||
func (e PostEdges) CategoriesOrErr() ([]*Category, error) {
|
||||
if e.loadedTypes[2] {
|
||||
return e.Categories, nil
|
||||
}
|
||||
return nil, &NotLoadedError{edge: "category"}
|
||||
return nil, &NotLoadedError{edge: "categories"}
|
||||
}
|
||||
|
||||
// scanValues returns the types for scanning values from sql.Rows.
|
||||
|
@ -86,8 +82,6 @@ func (*Post) scanValues(columns []string) ([]any, error) {
|
|||
values[i] = new(sql.NullString)
|
||||
case post.FieldCreatedAt, post.FieldUpdatedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
case post.ForeignKeys[0]: // category_posts
|
||||
values[i] = new(sql.NullInt64)
|
||||
default:
|
||||
values[i] = new(sql.UnknownType)
|
||||
}
|
||||
|
@ -133,13 +127,6 @@ func (po *Post) assignValues(columns []string, values []any) error {
|
|||
} else if value.Valid {
|
||||
po.UpdatedAt = value.Time
|
||||
}
|
||||
case post.ForeignKeys[0]:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for edge-field category_posts", value)
|
||||
} else if value.Valid {
|
||||
po.category_posts = new(int)
|
||||
*po.category_posts = int(value.Int64)
|
||||
}
|
||||
default:
|
||||
po.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
|
@ -163,9 +150,9 @@ func (po *Post) QueryContributors() *PostContributorQuery {
|
|||
return NewPostClient(po.config).QueryContributors(po)
|
||||
}
|
||||
|
||||
// QueryCategory queries the "category" edge of the Post entity.
|
||||
func (po *Post) QueryCategory() *CategoryQuery {
|
||||
return NewPostClient(po.config).QueryCategory(po)
|
||||
// QueryCategories queries the "categories" edge of the Post entity.
|
||||
func (po *Post) QueryCategories() *CategoryQuery {
|
||||
return NewPostClient(po.config).QueryCategories(po)
|
||||
}
|
||||
|
||||
// Update returns a builder for updating this Post.
|
||||
|
|
|
@ -27,8 +27,8 @@ const (
|
|||
EdgeContents = "contents"
|
||||
// EdgeContributors holds the string denoting the contributors edge name in mutations.
|
||||
EdgeContributors = "contributors"
|
||||
// EdgeCategory holds the string denoting the category edge name in mutations.
|
||||
EdgeCategory = "category"
|
||||
// EdgeCategories holds the string denoting the categories edge name in mutations.
|
||||
EdgeCategories = "categories"
|
||||
// Table holds the table name of the post in the database.
|
||||
Table = "posts"
|
||||
// ContentsTable is the table that holds the contents relation/edge.
|
||||
|
@ -45,13 +45,11 @@ const (
|
|||
ContributorsInverseTable = "post_contributors"
|
||||
// ContributorsColumn is the table column denoting the contributors relation/edge.
|
||||
ContributorsColumn = "post_contributors"
|
||||
// CategoryTable is the table that holds the category relation/edge.
|
||||
CategoryTable = "posts"
|
||||
// CategoryInverseTable is the table name for the Category entity.
|
||||
// CategoriesTable is the table that holds the categories relation/edge. The primary key declared below.
|
||||
CategoriesTable = "category_posts"
|
||||
// CategoriesInverseTable is the table name for the Category entity.
|
||||
// It exists in this package in order to avoid circular dependency with the "category" package.
|
||||
CategoryInverseTable = "categories"
|
||||
// CategoryColumn is the table column denoting the category relation/edge.
|
||||
CategoryColumn = "category_posts"
|
||||
CategoriesInverseTable = "categories"
|
||||
)
|
||||
|
||||
// Columns holds all SQL columns for post fields.
|
||||
|
@ -63,11 +61,11 @@ var Columns = []string{
|
|||
FieldUpdatedAt,
|
||||
}
|
||||
|
||||
// ForeignKeys holds the SQL foreign-keys that are owned by the "posts"
|
||||
// table and are not defined as standalone fields in the schema.
|
||||
var ForeignKeys = []string{
|
||||
"category_posts",
|
||||
}
|
||||
var (
|
||||
// CategoriesPrimaryKey and CategoriesColumn2 are the table columns denoting the
|
||||
// primary key for the categories relation (M2M).
|
||||
CategoriesPrimaryKey = []string{"category_id", "post_id"}
|
||||
)
|
||||
|
||||
// ValidColumn reports if the column name is valid (part of the table columns).
|
||||
func ValidColumn(column string) bool {
|
||||
|
@ -76,11 +74,6 @@ func ValidColumn(column string) bool {
|
|||
return true
|
||||
}
|
||||
}
|
||||
for i := range ForeignKeys {
|
||||
if column == ForeignKeys[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
|
@ -178,10 +171,17 @@ func ByContributors(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
|
|||
}
|
||||
}
|
||||
|
||||
// ByCategoryField orders the results by category field.
|
||||
func ByCategoryField(field string, opts ...sql.OrderTermOption) OrderOption {
|
||||
// ByCategoriesCount orders the results by categories count.
|
||||
func ByCategoriesCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
sqlgraph.OrderByNeighborTerms(s, newCategoryStep(), sql.OrderByField(field, opts...))
|
||||
sqlgraph.OrderByNeighborsCount(s, newCategoriesStep(), opts...)
|
||||
}
|
||||
}
|
||||
|
||||
// ByCategories orders the results by categories terms.
|
||||
func ByCategories(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
sqlgraph.OrderByNeighborTerms(s, newCategoriesStep(), append([]sql.OrderTerm{term}, terms...)...)
|
||||
}
|
||||
}
|
||||
func newContentsStep() *sqlgraph.Step {
|
||||
|
@ -198,10 +198,10 @@ func newContributorsStep() *sqlgraph.Step {
|
|||
sqlgraph.Edge(sqlgraph.O2M, false, ContributorsTable, ContributorsColumn),
|
||||
)
|
||||
}
|
||||
func newCategoryStep() *sqlgraph.Step {
|
||||
func newCategoriesStep() *sqlgraph.Step {
|
||||
return sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
sqlgraph.To(CategoryInverseTable, FieldID),
|
||||
sqlgraph.Edge(sqlgraph.M2O, true, CategoryTable, CategoryColumn),
|
||||
sqlgraph.To(CategoriesInverseTable, FieldID),
|
||||
sqlgraph.Edge(sqlgraph.M2M, true, CategoriesTable, CategoriesPrimaryKey...),
|
||||
)
|
||||
}
|
||||
|
|
|
@ -281,21 +281,21 @@ func HasContributorsWith(preds ...predicate.PostContributor) predicate.Post {
|
|||
})
|
||||
}
|
||||
|
||||
// HasCategory applies the HasEdge predicate on the "category" edge.
|
||||
func HasCategory() predicate.Post {
|
||||
// HasCategories applies the HasEdge predicate on the "categories" edge.
|
||||
func HasCategories() predicate.Post {
|
||||
return predicate.Post(func(s *sql.Selector) {
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
sqlgraph.Edge(sqlgraph.M2O, true, CategoryTable, CategoryColumn),
|
||||
sqlgraph.Edge(sqlgraph.M2M, true, CategoriesTable, CategoriesPrimaryKey...),
|
||||
)
|
||||
sqlgraph.HasNeighbors(s, step)
|
||||
})
|
||||
}
|
||||
|
||||
// HasCategoryWith applies the HasEdge predicate on the "category" edge with a given conditions (other predicates).
|
||||
func HasCategoryWith(preds ...predicate.Category) predicate.Post {
|
||||
// HasCategoriesWith applies the HasEdge predicate on the "categories" edge with a given conditions (other predicates).
|
||||
func HasCategoriesWith(preds ...predicate.Category) predicate.Post {
|
||||
return predicate.Post(func(s *sql.Selector) {
|
||||
step := newCategoryStep()
|
||||
step := newCategoriesStep()
|
||||
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
|
||||
for _, p := range preds {
|
||||
p(s)
|
||||
|
|
|
@ -101,23 +101,19 @@ func (pc *PostCreate) AddContributors(p ...*PostContributor) *PostCreate {
|
|||
return pc.AddContributorIDs(ids...)
|
||||
}
|
||||
|
||||
// SetCategoryID sets the "category" edge to the Category entity by ID.
|
||||
func (pc *PostCreate) SetCategoryID(id int) *PostCreate {
|
||||
pc.mutation.SetCategoryID(id)
|
||||
// AddCategoryIDs adds the "categories" edge to the Category entity by IDs.
|
||||
func (pc *PostCreate) AddCategoryIDs(ids ...int) *PostCreate {
|
||||
pc.mutation.AddCategoryIDs(ids...)
|
||||
return pc
|
||||
}
|
||||
|
||||
// SetNillableCategoryID sets the "category" edge to the Category entity by ID if the given value is not nil.
|
||||
func (pc *PostCreate) SetNillableCategoryID(id *int) *PostCreate {
|
||||
if id != nil {
|
||||
pc = pc.SetCategoryID(*id)
|
||||
// AddCategories adds the "categories" edges to the Category entity.
|
||||
func (pc *PostCreate) AddCategories(c ...*Category) *PostCreate {
|
||||
ids := make([]int, len(c))
|
||||
for i := range c {
|
||||
ids[i] = c[i].ID
|
||||
}
|
||||
return pc
|
||||
}
|
||||
|
||||
// SetCategory sets the "category" edge to the Category entity.
|
||||
func (pc *PostCreate) SetCategory(c *Category) *PostCreate {
|
||||
return pc.SetCategoryID(c.ID)
|
||||
return pc.AddCategoryIDs(ids...)
|
||||
}
|
||||
|
||||
// Mutation returns the PostMutation object of the builder.
|
||||
|
@ -267,12 +263,12 @@ func (pc *PostCreate) createSpec() (*Post, *sqlgraph.CreateSpec) {
|
|||
}
|
||||
_spec.Edges = append(_spec.Edges, edge)
|
||||
}
|
||||
if nodes := pc.mutation.CategoryIDs(); len(nodes) > 0 {
|
||||
if nodes := pc.mutation.CategoriesIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Rel: sqlgraph.M2M,
|
||||
Inverse: true,
|
||||
Table: post.CategoryTable,
|
||||
Columns: []string{post.CategoryColumn},
|
||||
Table: post.CategoriesTable,
|
||||
Columns: post.CategoriesPrimaryKey,
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt),
|
||||
|
@ -281,7 +277,6 @@ func (pc *PostCreate) createSpec() (*Post, *sqlgraph.CreateSpec) {
|
|||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_node.category_posts = &nodes[0]
|
||||
_spec.Edges = append(_spec.Edges, edge)
|
||||
}
|
||||
return _node, _spec
|
||||
|
|
|
@ -28,8 +28,7 @@ type PostQuery struct {
|
|||
predicates []predicate.Post
|
||||
withContents *PostContentQuery
|
||||
withContributors *PostContributorQuery
|
||||
withCategory *CategoryQuery
|
||||
withFKs bool
|
||||
withCategories *CategoryQuery
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
|
@ -110,8 +109,8 @@ func (pq *PostQuery) QueryContributors() *PostContributorQuery {
|
|||
return query
|
||||
}
|
||||
|
||||
// QueryCategory chains the current query on the "category" edge.
|
||||
func (pq *PostQuery) QueryCategory() *CategoryQuery {
|
||||
// QueryCategories chains the current query on the "categories" edge.
|
||||
func (pq *PostQuery) QueryCategories() *CategoryQuery {
|
||||
query := (&CategoryClient{config: pq.config}).Query()
|
||||
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
|
||||
if err := pq.prepareQuery(ctx); err != nil {
|
||||
|
@ -124,7 +123,7 @@ func (pq *PostQuery) QueryCategory() *CategoryQuery {
|
|||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(post.Table, post.FieldID, selector),
|
||||
sqlgraph.To(category.Table, category.FieldID),
|
||||
sqlgraph.Edge(sqlgraph.M2O, true, post.CategoryTable, post.CategoryColumn),
|
||||
sqlgraph.Edge(sqlgraph.M2M, true, post.CategoriesTable, post.CategoriesPrimaryKey...),
|
||||
)
|
||||
fromU = sqlgraph.SetNeighbors(pq.driver.Dialect(), step)
|
||||
return fromU, nil
|
||||
|
@ -326,7 +325,7 @@ func (pq *PostQuery) Clone() *PostQuery {
|
|||
predicates: append([]predicate.Post{}, pq.predicates...),
|
||||
withContents: pq.withContents.Clone(),
|
||||
withContributors: pq.withContributors.Clone(),
|
||||
withCategory: pq.withCategory.Clone(),
|
||||
withCategories: pq.withCategories.Clone(),
|
||||
// clone intermediate query.
|
||||
sql: pq.sql.Clone(),
|
||||
path: pq.path,
|
||||
|
@ -355,14 +354,14 @@ func (pq *PostQuery) WithContributors(opts ...func(*PostContributorQuery)) *Post
|
|||
return pq
|
||||
}
|
||||
|
||||
// WithCategory tells the query-builder to eager-load the nodes that are connected to
|
||||
// the "category" edge. The optional arguments are used to configure the query builder of the edge.
|
||||
func (pq *PostQuery) WithCategory(opts ...func(*CategoryQuery)) *PostQuery {
|
||||
// WithCategories tells the query-builder to eager-load the nodes that are connected to
|
||||
// the "categories" edge. The optional arguments are used to configure the query builder of the edge.
|
||||
func (pq *PostQuery) WithCategories(opts ...func(*CategoryQuery)) *PostQuery {
|
||||
query := (&CategoryClient{config: pq.config}).Query()
|
||||
for _, opt := range opts {
|
||||
opt(query)
|
||||
}
|
||||
pq.withCategory = query
|
||||
pq.withCategories = query
|
||||
return pq
|
||||
}
|
||||
|
||||
|
@ -443,20 +442,13 @@ func (pq *PostQuery) prepareQuery(ctx context.Context) error {
|
|||
func (pq *PostQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Post, error) {
|
||||
var (
|
||||
nodes = []*Post{}
|
||||
withFKs = pq.withFKs
|
||||
_spec = pq.querySpec()
|
||||
loadedTypes = [3]bool{
|
||||
pq.withContents != nil,
|
||||
pq.withContributors != nil,
|
||||
pq.withCategory != nil,
|
||||
pq.withCategories != nil,
|
||||
}
|
||||
)
|
||||
if pq.withCategory != nil {
|
||||
withFKs = true
|
||||
}
|
||||
if withFKs {
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, post.ForeignKeys...)
|
||||
}
|
||||
_spec.ScanValues = func(columns []string) ([]any, error) {
|
||||
return (*Post).scanValues(nil, columns)
|
||||
}
|
||||
|
@ -489,9 +481,10 @@ func (pq *PostQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Post, e
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
if query := pq.withCategory; query != nil {
|
||||
if err := pq.loadCategory(ctx, query, nodes, nil,
|
||||
func(n *Post, e *Category) { n.Edges.Category = e }); err != nil {
|
||||
if query := pq.withCategories; query != nil {
|
||||
if err := pq.loadCategories(ctx, query, nodes,
|
||||
func(n *Post) { n.Edges.Categories = []*Category{} },
|
||||
func(n *Post, e *Category) { n.Edges.Categories = append(n.Edges.Categories, e) }); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
@ -560,34 +553,63 @@ func (pq *PostQuery) loadContributors(ctx context.Context, query *PostContributo
|
|||
}
|
||||
return nil
|
||||
}
|
||||
func (pq *PostQuery) loadCategory(ctx context.Context, query *CategoryQuery, nodes []*Post, init func(*Post), assign func(*Post, *Category)) error {
|
||||
ids := make([]int, 0, len(nodes))
|
||||
nodeids := make(map[int][]*Post)
|
||||
for i := range nodes {
|
||||
if nodes[i].category_posts == nil {
|
||||
continue
|
||||
func (pq *PostQuery) loadCategories(ctx context.Context, query *CategoryQuery, nodes []*Post, init func(*Post), assign func(*Post, *Category)) error {
|
||||
edgeIDs := make([]driver.Value, len(nodes))
|
||||
byID := make(map[int]*Post)
|
||||
nids := make(map[int]map[*Post]struct{})
|
||||
for i, node := range nodes {
|
||||
edgeIDs[i] = node.ID
|
||||
byID[node.ID] = node
|
||||
if init != nil {
|
||||
init(node)
|
||||
}
|
||||
fk := *nodes[i].category_posts
|
||||
if _, ok := nodeids[fk]; !ok {
|
||||
ids = append(ids, fk)
|
||||
}
|
||||
nodeids[fk] = append(nodeids[fk], nodes[i])
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
query.Where(func(s *sql.Selector) {
|
||||
joinT := sql.Table(post.CategoriesTable)
|
||||
s.Join(joinT).On(s.C(category.FieldID), joinT.C(post.CategoriesPrimaryKey[0]))
|
||||
s.Where(sql.InValues(joinT.C(post.CategoriesPrimaryKey[1]), edgeIDs...))
|
||||
columns := s.SelectedColumns()
|
||||
s.Select(joinT.C(post.CategoriesPrimaryKey[1]))
|
||||
s.AppendSelect(columns...)
|
||||
s.SetDistinct(false)
|
||||
})
|
||||
if err := query.prepareQuery(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
query.Where(category.IDIn(ids...))
|
||||
neighbors, err := query.All(ctx)
|
||||
qr := QuerierFunc(func(ctx context.Context, q Query) (Value, error) {
|
||||
return query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
|
||||
assign := spec.Assign
|
||||
values := spec.ScanValues
|
||||
spec.ScanValues = func(columns []string) ([]any, error) {
|
||||
values, err := values(columns[1:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append([]any{new(sql.NullInt64)}, values...), nil
|
||||
}
|
||||
spec.Assign = func(columns []string, values []any) error {
|
||||
outValue := int(values[0].(*sql.NullInt64).Int64)
|
||||
inValue := int(values[1].(*sql.NullInt64).Int64)
|
||||
if nids[inValue] == nil {
|
||||
nids[inValue] = map[*Post]struct{}{byID[outValue]: {}}
|
||||
return assign(columns[1:], values[1:])
|
||||
}
|
||||
nids[inValue][byID[outValue]] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
})
|
||||
})
|
||||
neighbors, err := withInterceptors[[]*Category](ctx, query, qr, query.inters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
nodes, ok := nodeids[n.ID]
|
||||
nodes, ok := nids[n.ID]
|
||||
if !ok {
|
||||
return fmt.Errorf(`unexpected foreign-key "category_posts" returned %v`, n.ID)
|
||||
return fmt.Errorf(`unexpected "categories" node returned %v`, n.ID)
|
||||
}
|
||||
for i := range nodes {
|
||||
assign(nodes[i], n)
|
||||
for kn := range nodes {
|
||||
assign(kn, n)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -109,23 +109,19 @@ func (pu *PostUpdate) AddContributors(p ...*PostContributor) *PostUpdate {
|
|||
return pu.AddContributorIDs(ids...)
|
||||
}
|
||||
|
||||
// SetCategoryID sets the "category" edge to the Category entity by ID.
|
||||
func (pu *PostUpdate) SetCategoryID(id int) *PostUpdate {
|
||||
pu.mutation.SetCategoryID(id)
|
||||
// AddCategoryIDs adds the "categories" edge to the Category entity by IDs.
|
||||
func (pu *PostUpdate) AddCategoryIDs(ids ...int) *PostUpdate {
|
||||
pu.mutation.AddCategoryIDs(ids...)
|
||||
return pu
|
||||
}
|
||||
|
||||
// SetNillableCategoryID sets the "category" edge to the Category entity by ID if the given value is not nil.
|
||||
func (pu *PostUpdate) SetNillableCategoryID(id *int) *PostUpdate {
|
||||
if id != nil {
|
||||
pu = pu.SetCategoryID(*id)
|
||||
// AddCategories adds the "categories" edges to the Category entity.
|
||||
func (pu *PostUpdate) AddCategories(c ...*Category) *PostUpdate {
|
||||
ids := make([]int, len(c))
|
||||
for i := range c {
|
||||
ids[i] = c[i].ID
|
||||
}
|
||||
return pu
|
||||
}
|
||||
|
||||
// SetCategory sets the "category" edge to the Category entity.
|
||||
func (pu *PostUpdate) SetCategory(c *Category) *PostUpdate {
|
||||
return pu.SetCategoryID(c.ID)
|
||||
return pu.AddCategoryIDs(ids...)
|
||||
}
|
||||
|
||||
// Mutation returns the PostMutation object of the builder.
|
||||
|
@ -175,12 +171,27 @@ func (pu *PostUpdate) RemoveContributors(p ...*PostContributor) *PostUpdate {
|
|||
return pu.RemoveContributorIDs(ids...)
|
||||
}
|
||||
|
||||
// ClearCategory clears the "category" edge to the Category entity.
|
||||
func (pu *PostUpdate) ClearCategory() *PostUpdate {
|
||||
pu.mutation.ClearCategory()
|
||||
// ClearCategories clears all "categories" edges to the Category entity.
|
||||
func (pu *PostUpdate) ClearCategories() *PostUpdate {
|
||||
pu.mutation.ClearCategories()
|
||||
return pu
|
||||
}
|
||||
|
||||
// RemoveCategoryIDs removes the "categories" edge to Category entities by IDs.
|
||||
func (pu *PostUpdate) RemoveCategoryIDs(ids ...int) *PostUpdate {
|
||||
pu.mutation.RemoveCategoryIDs(ids...)
|
||||
return pu
|
||||
}
|
||||
|
||||
// RemoveCategories removes "categories" edges to Category entities.
|
||||
func (pu *PostUpdate) RemoveCategories(c ...*Category) *PostUpdate {
|
||||
ids := make([]int, len(c))
|
||||
for i := range c {
|
||||
ids[i] = c[i].ID
|
||||
}
|
||||
return pu.RemoveCategoryIDs(ids...)
|
||||
}
|
||||
|
||||
// Save executes the query and returns the number of nodes affected by the update operation.
|
||||
func (pu *PostUpdate) Save(ctx context.Context) (int, error) {
|
||||
pu.defaults()
|
||||
|
@ -346,12 +357,12 @@ func (pu *PostUpdate) sqlSave(ctx context.Context) (n int, err error) {
|
|||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
if pu.mutation.CategoryCleared() {
|
||||
if pu.mutation.CategoriesCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Rel: sqlgraph.M2M,
|
||||
Inverse: true,
|
||||
Table: post.CategoryTable,
|
||||
Columns: []string{post.CategoryColumn},
|
||||
Table: post.CategoriesTable,
|
||||
Columns: post.CategoriesPrimaryKey,
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt),
|
||||
|
@ -359,12 +370,28 @@ func (pu *PostUpdate) sqlSave(ctx context.Context) (n int, err error) {
|
|||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := pu.mutation.CategoryIDs(); len(nodes) > 0 {
|
||||
if nodes := pu.mutation.RemovedCategoriesIDs(); len(nodes) > 0 && !pu.mutation.CategoriesCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Rel: sqlgraph.M2M,
|
||||
Inverse: true,
|
||||
Table: post.CategoryTable,
|
||||
Columns: []string{post.CategoryColumn},
|
||||
Table: post.CategoriesTable,
|
||||
Columns: post.CategoriesPrimaryKey,
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := pu.mutation.CategoriesIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2M,
|
||||
Inverse: true,
|
||||
Table: post.CategoriesTable,
|
||||
Columns: post.CategoriesPrimaryKey,
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt),
|
||||
|
@ -473,23 +500,19 @@ func (puo *PostUpdateOne) AddContributors(p ...*PostContributor) *PostUpdateOne
|
|||
return puo.AddContributorIDs(ids...)
|
||||
}
|
||||
|
||||
// SetCategoryID sets the "category" edge to the Category entity by ID.
|
||||
func (puo *PostUpdateOne) SetCategoryID(id int) *PostUpdateOne {
|
||||
puo.mutation.SetCategoryID(id)
|
||||
// AddCategoryIDs adds the "categories" edge to the Category entity by IDs.
|
||||
func (puo *PostUpdateOne) AddCategoryIDs(ids ...int) *PostUpdateOne {
|
||||
puo.mutation.AddCategoryIDs(ids...)
|
||||
return puo
|
||||
}
|
||||
|
||||
// SetNillableCategoryID sets the "category" edge to the Category entity by ID if the given value is not nil.
|
||||
func (puo *PostUpdateOne) SetNillableCategoryID(id *int) *PostUpdateOne {
|
||||
if id != nil {
|
||||
puo = puo.SetCategoryID(*id)
|
||||
// AddCategories adds the "categories" edges to the Category entity.
|
||||
func (puo *PostUpdateOne) AddCategories(c ...*Category) *PostUpdateOne {
|
||||
ids := make([]int, len(c))
|
||||
for i := range c {
|
||||
ids[i] = c[i].ID
|
||||
}
|
||||
return puo
|
||||
}
|
||||
|
||||
// SetCategory sets the "category" edge to the Category entity.
|
||||
func (puo *PostUpdateOne) SetCategory(c *Category) *PostUpdateOne {
|
||||
return puo.SetCategoryID(c.ID)
|
||||
return puo.AddCategoryIDs(ids...)
|
||||
}
|
||||
|
||||
// Mutation returns the PostMutation object of the builder.
|
||||
|
@ -539,12 +562,27 @@ func (puo *PostUpdateOne) RemoveContributors(p ...*PostContributor) *PostUpdateO
|
|||
return puo.RemoveContributorIDs(ids...)
|
||||
}
|
||||
|
||||
// ClearCategory clears the "category" edge to the Category entity.
|
||||
func (puo *PostUpdateOne) ClearCategory() *PostUpdateOne {
|
||||
puo.mutation.ClearCategory()
|
||||
// ClearCategories clears all "categories" edges to the Category entity.
|
||||
func (puo *PostUpdateOne) ClearCategories() *PostUpdateOne {
|
||||
puo.mutation.ClearCategories()
|
||||
return puo
|
||||
}
|
||||
|
||||
// RemoveCategoryIDs removes the "categories" edge to Category entities by IDs.
|
||||
func (puo *PostUpdateOne) RemoveCategoryIDs(ids ...int) *PostUpdateOne {
|
||||
puo.mutation.RemoveCategoryIDs(ids...)
|
||||
return puo
|
||||
}
|
||||
|
||||
// RemoveCategories removes "categories" edges to Category entities.
|
||||
func (puo *PostUpdateOne) RemoveCategories(c ...*Category) *PostUpdateOne {
|
||||
ids := make([]int, len(c))
|
||||
for i := range c {
|
||||
ids[i] = c[i].ID
|
||||
}
|
||||
return puo.RemoveCategoryIDs(ids...)
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the PostUpdate builder.
|
||||
func (puo *PostUpdateOne) Where(ps ...predicate.Post) *PostUpdateOne {
|
||||
puo.mutation.Where(ps...)
|
||||
|
@ -740,12 +778,12 @@ func (puo *PostUpdateOne) sqlSave(ctx context.Context) (_node *Post, err error)
|
|||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
if puo.mutation.CategoryCleared() {
|
||||
if puo.mutation.CategoriesCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Rel: sqlgraph.M2M,
|
||||
Inverse: true,
|
||||
Table: post.CategoryTable,
|
||||
Columns: []string{post.CategoryColumn},
|
||||
Table: post.CategoriesTable,
|
||||
Columns: post.CategoriesPrimaryKey,
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt),
|
||||
|
@ -753,12 +791,28 @@ func (puo *PostUpdateOne) sqlSave(ctx context.Context) (_node *Post, err error)
|
|||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := puo.mutation.CategoryIDs(); len(nodes) > 0 {
|
||||
if nodes := puo.mutation.RemovedCategoriesIDs(); len(nodes) > 0 && !puo.mutation.CategoriesCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Rel: sqlgraph.M2M,
|
||||
Inverse: true,
|
||||
Table: post.CategoryTable,
|
||||
Columns: []string{post.CategoryColumn},
|
||||
Table: post.CategoriesTable,
|
||||
Columns: post.CategoriesPrimaryKey,
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := puo.mutation.CategoriesIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2M,
|
||||
Inverse: true,
|
||||
Table: post.CategoriesTable,
|
||||
Columns: post.CategoriesPrimaryKey,
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt),
|
||||
|
|
|
@ -16,11 +16,14 @@ type Media struct {
|
|||
func (Media) Fields() []ent.Field {
|
||||
return []ent.Field{
|
||||
field.String("storage_id").
|
||||
StorageKey("storage_id").
|
||||
NotEmpty().
|
||||
Unique(),
|
||||
field.String("original_name").
|
||||
StorageKey("original_name").
|
||||
NotEmpty(),
|
||||
field.String("mime_type").
|
||||
StorageKey("mime_type").
|
||||
NotEmpty(),
|
||||
field.Int64("size").
|
||||
Positive(),
|
||||
|
|
|
@ -34,8 +34,7 @@ func (Post) Edges() []ent.Edge {
|
|||
return []ent.Edge{
|
||||
edge.To("contents", PostContent.Type),
|
||||
edge.To("contributors", PostContributor.Type),
|
||||
edge.From("category", Category.Type).
|
||||
Ref("posts").
|
||||
Unique(),
|
||||
edge.From("categories", Category.Type).
|
||||
Ref("posts"),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,13 +3,13 @@ module tss-rocks-be
|
|||
go 1.23.6
|
||||
|
||||
require (
|
||||
bou.ke/monkey v1.0.2
|
||||
entgo.io/ent v0.14.1
|
||||
github.com/aws/aws-sdk-go-v2 v1.36.1
|
||||
github.com/aws/aws-sdk-go-v2/config v1.29.6
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.59
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.76.1
|
||||
github.com/chai2010/webp v1.1.1
|
||||
github.com/disintegration/imaging v1.6.2
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1
|
||||
github.com/google/uuid v1.6.0
|
||||
|
@ -17,7 +17,6 @@ require (
|
|||
github.com/rs/zerolog v1.33.0
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
go.uber.org/mock v0.5.0
|
||||
golang.org/x/crypto v0.33.0
|
||||
golang.org/x/time v0.10.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
|
@ -68,12 +67,12 @@ require (
|
|||
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
github.com/zclconf/go-cty v1.16.2 // indirect
|
||||
github.com/zclconf/go-cty-yaml v1.1.0 // indirect
|
||||
golang.org/x/arch v0.14.0 // indirect
|
||||
golang.org/x/image v0.0.0-20211028202545-6944b10bf410 // indirect
|
||||
golang.org/x/mod v0.23.0 // indirect
|
||||
golang.org/x/net v0.35.0 // indirect
|
||||
golang.org/x/sync v0.11.0 // indirect
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
ariga.io/atlas v0.31.0 h1:Nw6/Jdc7OpZfiy6oh/dJAYPp5XxGYvMTWLOUutwWjeY=
|
||||
ariga.io/atlas v0.31.0/go.mod h1:J3chwsQAgjDF6Ostz7JmJJRTCbtqIupUbVR/gqZrMiA=
|
||||
bou.ke/monkey v1.0.2 h1:kWcnsrCNUatbxncxR/ThdYqbytgOIArtYWqcQLQzKLI=
|
||||
bou.ke/monkey v1.0.2/go.mod h1:OqickVX3tNx6t33n1xvtTtu85YN5s6cKwVug+oHMaIA=
|
||||
entgo.io/ent v0.14.1 h1:fUERL506Pqr92EPHJqr8EYxbPioflJo6PudkrEA8a/s=
|
||||
entgo.io/ent v0.14.1/go.mod h1:MH6XLG0KXpkcDQhKiHfANZSzR55TJyPL5IGNpI8wpco=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
|
||||
|
@ -63,6 +61,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t
|
|||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/disintegration/imaging v1.6.2 h1:w1LecBlG2Lnp8B3jk5zSuNqd7b4DXhcjwek1ei82L+c=
|
||||
github.com/disintegration/imaging v1.6.2/go.mod h1:44/5580QXChDfwIclfc/PCwrr44amcmDAg8hxG0Ewe4=
|
||||
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
|
||||
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
|
||||
github.com/gin-contrib/sse v1.0.0 h1:y3bT1mUWUxDpW4JLQg/HnTqV4rozuW4tC9eFKTxYI9E=
|
||||
|
@ -110,8 +110,6 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
|
|||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0=
|
||||
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
|
||||
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0=
|
||||
|
@ -121,8 +119,6 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
|
|||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
|
||||
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
|
@ -139,7 +135,6 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An
|
|||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
|
@ -159,12 +154,13 @@ github.com/zclconf/go-cty-debug v0.0.0-20240509010212-0d6042c53940 h1:4r45xpDWB6
|
|||
github.com/zclconf/go-cty-debug v0.0.0-20240509010212-0d6042c53940/go.mod h1:CmBdvvj3nqzfzJ6nTCIwDTPZ56aVGvDrmztiO5g3qrM=
|
||||
github.com/zclconf/go-cty-yaml v1.1.0 h1:nP+jp0qPHv2IhUVqmQSzjvqAWcObN0KBkUl2rWBdig0=
|
||||
github.com/zclconf/go-cty-yaml v1.1.0/go.mod h1:9YLUH4g7lOhVWqUbctnVlZ5KLpg7JAprQNgxSZ1Gyxs=
|
||||
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
|
||||
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
|
||||
golang.org/x/arch v0.14.0 h1:z9JUEZWr8x4rR0OU6c4/4t6E6jOZ8/QBS2bBYBm4tx4=
|
||||
golang.org/x/arch v0.14.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
||||
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
||||
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ=
|
||||
golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM=
|
||||
golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
|
||||
golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
|
||||
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||
|
@ -176,10 +172,13 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
|
||||
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
|
||||
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
|
||||
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
|
||||
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
|
||||
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
|
||||
|
|
|
@ -1,27 +0,0 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUserIDKey(t *testing.T) {
|
||||
// Test that the UserIDKey constant is defined correctly
|
||||
if UserIDKey != "user_id" {
|
||||
t.Errorf("UserIDKey = %v, want %v", UserIDKey, "user_id")
|
||||
}
|
||||
|
||||
// Test context with user ID
|
||||
ctx := context.WithValue(context.Background(), UserIDKey, "test-user-123")
|
||||
value := ctx.Value(UserIDKey)
|
||||
if value != "test-user-123" {
|
||||
t.Errorf("Context value = %v, want %v", value, "test-user-123")
|
||||
}
|
||||
|
||||
// Test context without user ID
|
||||
emptyCtx := context.Background()
|
||||
emptyValue := emptyCtx.Value(UserIDKey)
|
||||
if emptyValue != nil {
|
||||
t.Errorf("Empty context value = %v, want nil", emptyValue)
|
||||
}
|
||||
}
|
|
@ -49,7 +49,7 @@ type StorageConfig struct {
|
|||
Type string `yaml:"type"`
|
||||
Local LocalStorage `yaml:"local"`
|
||||
S3 S3Storage `yaml:"s3"`
|
||||
Upload types.UploadConfig `yaml:"upload"`
|
||||
Upload UploadConfig `yaml:"upload"`
|
||||
}
|
||||
|
||||
type LocalStorage struct {
|
||||
|
@ -66,6 +66,27 @@ type S3Storage struct {
|
|||
ProxyS3 bool `yaml:"proxy_s3"`
|
||||
}
|
||||
|
||||
type UploadConfig struct {
|
||||
Limits struct {
|
||||
Image struct {
|
||||
MaxSize int `yaml:"max_size"`
|
||||
AllowedTypes []string `yaml:"allowed_types"`
|
||||
} `yaml:"image"`
|
||||
Video struct {
|
||||
MaxSize int `yaml:"max_size"`
|
||||
AllowedTypes []string `yaml:"allowed_types"`
|
||||
} `yaml:"video"`
|
||||
Audio struct {
|
||||
MaxSize int `yaml:"max_size"`
|
||||
AllowedTypes []string `yaml:"allowed_types"`
|
||||
} `yaml:"audio"`
|
||||
Document struct {
|
||||
MaxSize int `yaml:"max_size"`
|
||||
AllowedTypes []string `yaml:"allowed_types"`
|
||||
} `yaml:"document"`
|
||||
} `yaml:"limits"`
|
||||
}
|
||||
|
||||
// Load loads configuration from a YAML file
|
||||
func Load(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
|
|
|
@ -1,85 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
// Create a temporary test config file
|
||||
content := []byte(`
|
||||
database:
|
||||
driver: postgres
|
||||
dsn: postgres://user:pass@localhost:5432/dbname
|
||||
server:
|
||||
port: 8080
|
||||
host: localhost
|
||||
jwt:
|
||||
secret: test-secret
|
||||
expiration: 24h
|
||||
storage:
|
||||
type: local
|
||||
local:
|
||||
root_dir: /tmp/storage
|
||||
upload:
|
||||
max_size: 10485760
|
||||
logging:
|
||||
level: info
|
||||
format: json
|
||||
`)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
if err := os.WriteFile(configPath, content, 0644); err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
// Test loading config
|
||||
cfg, err := Load(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify loaded values
|
||||
tests := []struct {
|
||||
name string
|
||||
got interface{}
|
||||
want interface{}
|
||||
errorMsg string
|
||||
}{
|
||||
{"Database Driver", cfg.Database.Driver, "postgres", "incorrect database driver"},
|
||||
{"Server Port", cfg.Server.Port, 8080, "incorrect server port"},
|
||||
{"JWT Secret", cfg.JWT.Secret, "test-secret", "incorrect JWT secret"},
|
||||
{"Storage Type", cfg.Storage.Type, "local", "incorrect storage type"},
|
||||
{"Logging Level", cfg.Logging.Level, "info", "incorrect logging level"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.got != tt.want {
|
||||
t.Errorf("%s = %v, want %v", tt.name, tt.got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadError(t *testing.T) {
|
||||
// Test loading non-existent file
|
||||
_, err := Load("non-existent-file.yaml")
|
||||
if err == nil {
|
||||
t.Error("Load() error = nil, want error for non-existent file")
|
||||
}
|
||||
|
||||
// Test loading invalid YAML
|
||||
tmpDir := t.TempDir()
|
||||
invalidPath := filepath.Join(tmpDir, "invalid.yaml")
|
||||
if err := os.WriteFile(invalidPath, []byte("invalid: }{yaml"), 0644); err != nil {
|
||||
t.Fatalf("Failed to write invalid config: %v", err)
|
||||
}
|
||||
|
||||
_, err = Load(invalidPath)
|
||||
if err == nil {
|
||||
t.Error("Load() error = nil, want error for invalid YAML")
|
||||
}
|
||||
}
|
|
@ -1,312 +0,0 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type AuthHandlerTestSuite struct {
|
||||
suite.Suite
|
||||
ctrl *gomock.Controller
|
||||
service *mock.MockService
|
||||
handler *Handler
|
||||
router *gin.Engine
|
||||
}
|
||||
|
||||
func (s *AuthHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
s.handler = NewHandler(&config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
},
|
||||
Auth: config.AuthConfig{
|
||||
Registration: struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Message string `yaml:"message"`
|
||||
}{
|
||||
Enabled: true,
|
||||
Message: "Registration is disabled",
|
||||
},
|
||||
},
|
||||
}, s.service)
|
||||
s.router = gin.New()
|
||||
}
|
||||
|
||||
func (s *AuthHandlerTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestAuthHandlerSuite(t *testing.T) {
|
||||
suite.Run(t, new(AuthHandlerTestSuite))
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Error struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func (s *AuthHandlerTestSuite) TestRegister() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
request RegisterRequest
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
registration bool
|
||||
}{
|
||||
{
|
||||
name: "成功注册",
|
||||
request: RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "password123",
|
||||
Role: "contributor",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
CreateUser(gomock.Any(), "testuser", "test@example.com", "password123", "contributor").
|
||||
Return(&ent.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
}, nil)
|
||||
s.service.EXPECT().
|
||||
GetUserRoles(gomock.Any(), 1).
|
||||
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
registration: true,
|
||||
},
|
||||
{
|
||||
name: "注册功能已禁用",
|
||||
request: RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "password123",
|
||||
Role: "contributor",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedError: "Registration is disabled",
|
||||
registration: false,
|
||||
},
|
||||
{
|
||||
name: "无效的邮箱格式",
|
||||
request: RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "invalid-email",
|
||||
Password: "password123",
|
||||
Role: "contributor",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Key: 'RegisterRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag",
|
||||
registration: true,
|
||||
},
|
||||
{
|
||||
name: "密码太短",
|
||||
request: RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "short",
|
||||
Role: "contributor",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Key: 'RegisterRequest.Password' Error:Field validation for 'Password' failed on the 'min' tag",
|
||||
registration: true,
|
||||
},
|
||||
{
|
||||
name: "无效的角色",
|
||||
request: RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "password123",
|
||||
Role: "invalid-role",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Key: 'RegisterRequest.Role' Error:Field validation for 'Role' failed on the 'oneof' tag",
|
||||
registration: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置注册功能状态
|
||||
s.handler.cfg.Auth.Registration.Enabled = tc.registration
|
||||
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
// 创建请求
|
||||
reqBody, _ := json.Marshal(tc.request)
|
||||
req, _ := http.NewRequest(http.MethodPost, "/register", bytes.NewBuffer(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// 执行请求
|
||||
s.handler.Register(c)
|
||||
|
||||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response ErrorResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Contains(response.Error.Message, tc.expectedError)
|
||||
} else {
|
||||
var response AuthResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.NotEmpty(response.Token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthHandlerTestSuite) TestLogin() {
|
||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
request LoginRequest
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "成功登录",
|
||||
request: LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "password123",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetUserByUsername(gomock.Any(), "testuser").
|
||||
Return(&ent.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
PasswordHash: string(hashedPassword),
|
||||
}, nil)
|
||||
s.service.EXPECT().
|
||||
GetUserRoles(gomock.Any(), 1).
|
||||
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "无效的用户名",
|
||||
request: LoginRequest{
|
||||
Username: "te",
|
||||
Password: "password123",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Key: 'LoginRequest.Username' Error:Field validation for 'Username' failed on the 'min' tag",
|
||||
},
|
||||
{
|
||||
name: "用户不存在",
|
||||
request: LoginRequest{
|
||||
Username: "nonexistent",
|
||||
Password: "password123",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetUserByUsername(gomock.Any(), "nonexistent").
|
||||
Return(nil, fmt.Errorf("user not found"))
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedError: "Invalid username or password",
|
||||
},
|
||||
{
|
||||
name: "密码错误",
|
||||
request: LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "wrongpassword",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetUserByUsername(gomock.Any(), "testuser").
|
||||
Return(&ent.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
PasswordHash: string(hashedPassword),
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedError: "Invalid username or password",
|
||||
},
|
||||
{
|
||||
name: "获取用户角色失败",
|
||||
request: LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "password123",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetUserByUsername(gomock.Any(), "testuser").
|
||||
Return(&ent.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
PasswordHash: string(hashedPassword),
|
||||
}, nil)
|
||||
s.service.EXPECT().
|
||||
GetUserRoles(gomock.Any(), 1).
|
||||
Return(nil, fmt.Errorf("failed to get roles"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to get user roles",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
// 创建请求
|
||||
reqBody, _ := json.Marshal(tc.request)
|
||||
req, _ := http.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// 执行请求
|
||||
s.handler.Login(c)
|
||||
|
||||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response ErrorResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Contains(response.Error.Message, tc.expectedError)
|
||||
} else {
|
||||
var response AuthResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.NotEmpty(response.Token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,481 +0,0 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/ent/categorycontent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
"tss-rocks-be/internal/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"errors"
|
||||
)
|
||||
|
||||
// Custom assertion function for comparing categories
|
||||
func assertCategoryEqual(t assert.TestingT, expected, actual *ent.Category) bool {
|
||||
if expected == nil && actual == nil {
|
||||
return true
|
||||
}
|
||||
if expected == nil || actual == nil {
|
||||
return assert.Fail(t, "One category is nil while the other is not")
|
||||
}
|
||||
|
||||
// Compare only relevant fields, ignoring time fields
|
||||
return assert.Equal(t, expected.ID, actual.ID) &&
|
||||
assert.Equal(t, expected.Edges.Contents, actual.Edges.Contents)
|
||||
}
|
||||
|
||||
// Custom assertion function for comparing category slices
|
||||
func assertCategorySliceEqual(t assert.TestingT, expected, actual []*ent.Category) bool {
|
||||
if len(expected) != len(actual) {
|
||||
return assert.Fail(t, "Category slice lengths do not match")
|
||||
}
|
||||
|
||||
for i := range expected {
|
||||
if !assertCategoryEqual(t, expected[i], actual[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type CategoryHandlerTestSuite struct {
|
||||
suite.Suite
|
||||
ctrl *gomock.Controller
|
||||
service *mock.MockService
|
||||
handler *Handler
|
||||
router *gin.Engine
|
||||
}
|
||||
|
||||
func (s *CategoryHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
},
|
||||
}
|
||||
s.handler = NewHandler(cfg, s.service)
|
||||
|
||||
// Setup Gin router
|
||||
gin.SetMode(gin.TestMode)
|
||||
s.router = gin.New()
|
||||
|
||||
// Setup mock for GetTokenBlacklist
|
||||
tokenBlacklist := &service.TokenBlacklist{}
|
||||
s.service.EXPECT().
|
||||
GetTokenBlacklist().
|
||||
Return(tokenBlacklist).
|
||||
AnyTimes()
|
||||
|
||||
s.handler.RegisterRoutes(s.router)
|
||||
}
|
||||
|
||||
func (s *CategoryHandlerTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestCategoryHandlerSuite(t *testing.T) {
|
||||
suite.Run(t, new(CategoryHandlerTestSuite))
|
||||
}
|
||||
|
||||
// Test cases for ListCategories
|
||||
func (s *CategoryHandlerTestSuite) TestListCategories() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
langCode string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success with default language",
|
||||
langCode: "",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListCategories(gomock.Any(), gomock.Eq("en")).
|
||||
Return([]*ent.Category{
|
||||
{
|
||||
ID: 1,
|
||||
Edges: ent.CategoryEdges{
|
||||
Contents: []*ent.CategoryContent{
|
||||
{
|
||||
LanguageCode: categorycontent.LanguageCode("en"),
|
||||
Name: "Test Category",
|
||||
Description: "Test Description",
|
||||
Slug: "test-category",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Category{
|
||||
{
|
||||
ID: 1,
|
||||
Edges: ent.CategoryEdges{
|
||||
Contents: []*ent.CategoryContent{
|
||||
{
|
||||
LanguageCode: categorycontent.LanguageCode("en"),
|
||||
Name: "Test Category",
|
||||
Description: "Test Description",
|
||||
Slug: "test-category",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with specific language",
|
||||
langCode: "zh",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListCategories(gomock.Any(), gomock.Eq("zh")).
|
||||
Return([]*ent.Category{
|
||||
{
|
||||
ID: 1,
|
||||
Edges: ent.CategoryEdges{
|
||||
Contents: []*ent.CategoryContent{
|
||||
{
|
||||
LanguageCode: categorycontent.LanguageCode("zh"),
|
||||
Name: "测试分类",
|
||||
Description: "测试描述",
|
||||
Slug: "test-category",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Category{
|
||||
{
|
||||
ID: 1,
|
||||
Edges: ent.CategoryEdges{
|
||||
Contents: []*ent.CategoryContent{
|
||||
{
|
||||
LanguageCode: categorycontent.LanguageCode("zh"),
|
||||
Name: "测试分类",
|
||||
Description: "测试描述",
|
||||
Slug: "test-category",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// Setup mock
|
||||
tc.setupMock()
|
||||
|
||||
// Create request
|
||||
url := "/api/v1/categories"
|
||||
if tc.langCode != "" {
|
||||
url += "?lang=" + tc.langCode
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Perform request
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
// Assert response
|
||||
assert.Equal(s.T(), tc.expectedStatus, w.Code)
|
||||
if tc.expectedBody != nil {
|
||||
var response []*ent.Category
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(s.T(), err)
|
||||
assertCategorySliceEqual(s.T(), tc.expectedBody.([]*ent.Category), response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases for GetCategory
|
||||
func (s *CategoryHandlerTestSuite) TestGetCategory() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
langCode string
|
||||
slug string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
langCode: "en",
|
||||
slug: "test-category",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetCategoryBySlug(gomock.Any(), gomock.Eq("en"), gomock.Eq("test-category")).
|
||||
Return(&ent.Category{
|
||||
ID: 1,
|
||||
Edges: ent.CategoryEdges{
|
||||
Contents: []*ent.CategoryContent{
|
||||
{
|
||||
LanguageCode: categorycontent.LanguageCode("en"),
|
||||
Name: "Test Category",
|
||||
Description: "Test Description",
|
||||
Slug: "test-category",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: &ent.Category{
|
||||
ID: 1,
|
||||
Edges: ent.CategoryEdges{
|
||||
Contents: []*ent.CategoryContent{
|
||||
{
|
||||
LanguageCode: categorycontent.LanguageCode("en"),
|
||||
Name: "Test Category",
|
||||
Description: "Test Description",
|
||||
Slug: "test-category",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Not Found",
|
||||
langCode: "en",
|
||||
slug: "non-existent",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetCategoryBySlug(gomock.Any(), gomock.Eq("en"), gomock.Eq("non-existent")).
|
||||
Return(nil, types.ErrNotFound)
|
||||
},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// Setup mock
|
||||
tc.setupMock()
|
||||
|
||||
// Create request
|
||||
url := "/api/v1/categories/" + tc.slug
|
||||
if tc.langCode != "" {
|
||||
url += "?lang=" + tc.langCode
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Perform request
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
// Assert response
|
||||
assert.Equal(s.T(), tc.expectedStatus, w.Code)
|
||||
if tc.expectedBody != nil {
|
||||
var response ent.Category
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(s.T(), err)
|
||||
assertCategoryEqual(s.T(), tc.expectedBody.(*ent.Category), &response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases for AddCategoryContent
|
||||
func (s *CategoryHandlerTestSuite) TestAddCategoryContent() {
|
||||
var description = "Test Description"
|
||||
testCases := []struct {
|
||||
name string
|
||||
categoryID string
|
||||
requestBody interface{}
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
categoryID: "1",
|
||||
requestBody: AddCategoryContentRequest{
|
||||
LanguageCode: "en",
|
||||
Name: "Test Category",
|
||||
Description: &description,
|
||||
Slug: "test-category",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
AddCategoryContent(
|
||||
gomock.Any(),
|
||||
1,
|
||||
"en",
|
||||
"Test Category",
|
||||
description,
|
||||
"test-category",
|
||||
).
|
||||
Return(&ent.CategoryContent{
|
||||
LanguageCode: categorycontent.LanguageCode("en"),
|
||||
Name: "Test Category",
|
||||
Description: description,
|
||||
Slug: "test-category",
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: &ent.CategoryContent{
|
||||
LanguageCode: categorycontent.LanguageCode("en"),
|
||||
Name: "Test Category",
|
||||
Description: description,
|
||||
Slug: "test-category",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid JSON",
|
||||
categoryID: "1",
|
||||
requestBody: "invalid json",
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "Invalid Category ID",
|
||||
categoryID: "invalid",
|
||||
requestBody: AddCategoryContentRequest{
|
||||
LanguageCode: "en",
|
||||
Name: "Test Category",
|
||||
Description: &description,
|
||||
Slug: "test-category",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "Service Error",
|
||||
categoryID: "1",
|
||||
requestBody: AddCategoryContentRequest{
|
||||
LanguageCode: "en",
|
||||
Name: "Test Category",
|
||||
Description: &description,
|
||||
Slug: "test-category",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
AddCategoryContent(
|
||||
gomock.Any(),
|
||||
1,
|
||||
"en",
|
||||
"Test Category",
|
||||
description,
|
||||
"test-category",
|
||||
).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// Setup mock
|
||||
tc.setupMock()
|
||||
|
||||
// Create request
|
||||
var body []byte
|
||||
var err error
|
||||
if str, ok := tc.requestBody.(string); ok {
|
||||
body = []byte(str)
|
||||
} else {
|
||||
body, err = json.Marshal(tc.requestBody)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/categories/"+tc.categoryID+"/contents", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Perform request
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
// Assert response
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedBody != nil {
|
||||
var response ent.CategoryContent
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Equal(tc.expectedBody, &response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases for CreateCategory
|
||||
func (s *CategoryHandlerTestSuite) TestCreateCategory() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "成功创建分类",
|
||||
setupMock: func() {
|
||||
category := &ent.Category{
|
||||
ID: 1,
|
||||
}
|
||||
s.service.EXPECT().
|
||||
CreateCategory(gomock.Any()).
|
||||
Return(category, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "创建分类失败",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
CreateCategory(gomock.Any()).
|
||||
Return(nil, errors.New("failed to create category"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to create category",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
// 创建请求
|
||||
req, _ := http.NewRequest(http.MethodPost, "/categories", nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// 执行请求
|
||||
s.handler.CreateCategory(c)
|
||||
|
||||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Equal(tc.expectedError, response["error"])
|
||||
} else {
|
||||
var response *ent.Category
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.NotNil(response)
|
||||
s.Equal(1, response.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,456 +0,0 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"errors"
|
||||
)
|
||||
|
||||
type ContributorHandlerTestSuite struct {
|
||||
suite.Suite
|
||||
ctrl *gomock.Controller
|
||||
service *mock.MockService
|
||||
handler *Handler
|
||||
router *gin.Engine
|
||||
}
|
||||
|
||||
func (s *ContributorHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
},
|
||||
}
|
||||
s.handler = NewHandler(cfg, s.service)
|
||||
|
||||
// Setup Gin router
|
||||
gin.SetMode(gin.TestMode)
|
||||
s.router = gin.New()
|
||||
|
||||
// Setup mock for GetTokenBlacklist
|
||||
tokenBlacklist := &service.TokenBlacklist{}
|
||||
s.service.EXPECT().
|
||||
GetTokenBlacklist().
|
||||
Return(tokenBlacklist).
|
||||
AnyTimes()
|
||||
|
||||
s.handler.RegisterRoutes(s.router)
|
||||
}
|
||||
|
||||
func (s *ContributorHandlerTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestContributorHandlerSuite(t *testing.T) {
|
||||
suite.Run(t, new(ContributorHandlerTestSuite))
|
||||
}
|
||||
|
||||
func (s *ContributorHandlerTestSuite) TestListContributors() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListContributors(gomock.Any()).
|
||||
Return([]*ent.Contributor{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "John Doe",
|
||||
Edges: ent.ContributorEdges{
|
||||
SocialLinks: []*ent.ContributorSocialLink{
|
||||
{
|
||||
Type: "github",
|
||||
Value: "https://github.com/johndoe",
|
||||
Edges: ent.ContributorSocialLinkEdges{},
|
||||
},
|
||||
},
|
||||
},
|
||||
CreatedAt: time.Time{},
|
||||
UpdatedAt: time.Time{},
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Name: "Jane Smith",
|
||||
Edges: ent.ContributorEdges{
|
||||
SocialLinks: []*ent.ContributorSocialLink{}, // Ensure empty SocialLinks array is present
|
||||
},
|
||||
CreatedAt: time.Time{},
|
||||
UpdatedAt: time.Time{},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []gin.H{
|
||||
{
|
||||
"id": 1,
|
||||
"name": "John Doe",
|
||||
"created_at": time.Time{},
|
||||
"updated_at": time.Time{},
|
||||
"edges": gin.H{
|
||||
"social_links": []gin.H{
|
||||
{
|
||||
"type": "github",
|
||||
"value": "https://github.com/johndoe",
|
||||
"edges": gin.H{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"name": "Jane Smith",
|
||||
"created_at": time.Time{},
|
||||
"updated_at": time.Time{},
|
||||
"edges": gin.H{
|
||||
"social_links": []gin.H{}, // Ensure empty SocialLinks array is present
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListContributors(gomock.Any()).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to list contributors"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/contributors", nil)
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ContributorHandlerTestSuite) TestGetContributor() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
id string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
id: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetContributorByID(gomock.Any(), 1).
|
||||
Return(&ent.Contributor{
|
||||
ID: 1,
|
||||
Name: "John Doe",
|
||||
Edges: ent.ContributorEdges{
|
||||
SocialLinks: []*ent.ContributorSocialLink{
|
||||
{
|
||||
Type: "github",
|
||||
Value: "https://github.com/johndoe",
|
||||
Edges: ent.ContributorSocialLinkEdges{},
|
||||
},
|
||||
},
|
||||
},
|
||||
CreatedAt: time.Time{},
|
||||
UpdatedAt: time.Time{},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: gin.H{
|
||||
"id": 1,
|
||||
"name": "John Doe",
|
||||
"created_at": time.Time{},
|
||||
"updated_at": time.Time{},
|
||||
"edges": gin.H{
|
||||
"social_links": []gin.H{
|
||||
{
|
||||
"type": "github",
|
||||
"value": "https://github.com/johndoe",
|
||||
"edges": gin.H{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid ID",
|
||||
id: "invalid",
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Invalid contributor ID"},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
id: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetContributorByID(gomock.Any(), 1).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to get contributor"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/contributors/"+tc.id, nil)
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ContributorHandlerTestSuite) TestCreateContributor() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
body interface{}
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
body: CreateContributorRequest{
|
||||
Name: "John Doe",
|
||||
},
|
||||
setupMock: func() {
|
||||
name := "John Doe"
|
||||
s.service.EXPECT().
|
||||
CreateContributor(
|
||||
gomock.Any(),
|
||||
name,
|
||||
nil,
|
||||
nil,
|
||||
).
|
||||
Return(&ent.Contributor{
|
||||
ID: 1,
|
||||
Name: name,
|
||||
CreatedAt: time.Time{},
|
||||
UpdatedAt: time.Time{},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: gin.H{
|
||||
"id": 1,
|
||||
"name": "John Doe",
|
||||
"created_at": time.Time{},
|
||||
"updated_at": time.Time{},
|
||||
"edges": gin.H{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid request body",
|
||||
body: map[string]interface{}{
|
||||
"name": "", // Empty name is not allowed
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Key: 'CreateContributorRequest.Name' Error:Field validation for 'Name' failed on the 'required' tag"},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
body: CreateContributorRequest{
|
||||
Name: "John Doe",
|
||||
},
|
||||
setupMock: func() {
|
||||
name := "John Doe"
|
||||
s.service.EXPECT().
|
||||
CreateContributor(
|
||||
gomock.Any(),
|
||||
name,
|
||||
nil,
|
||||
nil,
|
||||
).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to create contributor"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
body, err := json.Marshal(tc.body)
|
||||
s.NoError(err, "Failed to marshal request body")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/contributors", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ContributorHandlerTestSuite) TestAddContributorSocialLink() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
id string
|
||||
body interface{}
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
id: "1",
|
||||
body: func() AddContributorSocialLinkRequest {
|
||||
name := "johndoe"
|
||||
return AddContributorSocialLinkRequest{
|
||||
Type: "github",
|
||||
Name: &name,
|
||||
Value: "https://github.com/johndoe",
|
||||
}
|
||||
}(),
|
||||
setupMock: func() {
|
||||
name := "johndoe"
|
||||
s.service.EXPECT().
|
||||
AddContributorSocialLink(
|
||||
gomock.Any(),
|
||||
1,
|
||||
"github",
|
||||
name,
|
||||
"https://github.com/johndoe",
|
||||
).
|
||||
Return(&ent.ContributorSocialLink{
|
||||
Type: "github",
|
||||
Name: name,
|
||||
Value: "https://github.com/johndoe",
|
||||
Edges: ent.ContributorSocialLinkEdges{},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: gin.H{
|
||||
"type": "github",
|
||||
"name": "johndoe",
|
||||
"value": "https://github.com/johndoe",
|
||||
"edges": gin.H{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid contributor ID",
|
||||
id: "invalid",
|
||||
body: func() AddContributorSocialLinkRequest {
|
||||
name := "johndoe"
|
||||
return AddContributorSocialLinkRequest{
|
||||
Type: "github",
|
||||
Name: &name,
|
||||
Value: "https://github.com/johndoe",
|
||||
}
|
||||
}(),
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Invalid contributor ID"},
|
||||
},
|
||||
{
|
||||
name: "Invalid request body",
|
||||
id: "1",
|
||||
body: map[string]interface{}{
|
||||
"type": "", // Empty type is not allowed
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Key: 'AddContributorSocialLinkRequest.Type' Error:Field validation for 'Type' failed on the 'required' tag\nKey: 'AddContributorSocialLinkRequest.Value' Error:Field validation for 'Value' failed on the 'required' tag"},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
id: "1",
|
||||
body: func() AddContributorSocialLinkRequest {
|
||||
name := "johndoe"
|
||||
return AddContributorSocialLinkRequest{
|
||||
Type: "github",
|
||||
Name: &name,
|
||||
Value: "https://github.com/johndoe",
|
||||
}
|
||||
}(),
|
||||
setupMock: func() {
|
||||
name := "johndoe"
|
||||
s.service.EXPECT().
|
||||
AddContributorSocialLink(
|
||||
gomock.Any(),
|
||||
1,
|
||||
"github",
|
||||
name,
|
||||
"https://github.com/johndoe",
|
||||
).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to add contributor social link"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
body, err := json.Marshal(tc.body)
|
||||
s.NoError(err, "Failed to marshal request body")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/contributors/"+tc.id+"/social-links", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,532 +0,0 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type DailyHandlerTestSuite struct {
|
||||
suite.Suite
|
||||
ctrl *gomock.Controller
|
||||
service *mock.MockService
|
||||
handler *Handler
|
||||
router *gin.Engine
|
||||
}
|
||||
|
||||
func (s *DailyHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
},
|
||||
}
|
||||
s.handler = NewHandler(cfg, s.service)
|
||||
|
||||
// Setup Gin router
|
||||
gin.SetMode(gin.TestMode)
|
||||
s.router = gin.New()
|
||||
|
||||
// Setup mock for GetTokenBlacklist
|
||||
tokenBlacklist := &service.TokenBlacklist{}
|
||||
s.service.EXPECT().
|
||||
GetTokenBlacklist().
|
||||
Return(tokenBlacklist).
|
||||
AnyTimes()
|
||||
|
||||
s.handler.RegisterRoutes(s.router)
|
||||
}
|
||||
|
||||
func (s *DailyHandlerTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestDailyHandlerSuite(t *testing.T) {
|
||||
suite.Run(t, new(DailyHandlerTestSuite))
|
||||
}
|
||||
|
||||
func (s *DailyHandlerTestSuite) TestListDailies() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
langCode string
|
||||
categoryID string
|
||||
limit string
|
||||
offset string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success with default language",
|
||||
langCode: "",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListDailies(gomock.Any(), "en", nil, 10, 0).
|
||||
Return([]*ent.Daily{
|
||||
{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Daily{
|
||||
{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with specific language",
|
||||
langCode: "zh",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListDailies(gomock.Any(), "zh", nil, 10, 0).
|
||||
Return([]*ent.Daily{
|
||||
{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "zh",
|
||||
Quote: "测试语录1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Daily{
|
||||
{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "zh",
|
||||
Quote: "测试语录1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with category filter",
|
||||
categoryID: "1",
|
||||
setupMock: func() {
|
||||
categoryID := 1
|
||||
s.service.EXPECT().
|
||||
ListDailies(gomock.Any(), "en", &categoryID, 10, 0).
|
||||
Return([]*ent.Daily{
|
||||
{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Daily{
|
||||
{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with pagination",
|
||||
limit: "2",
|
||||
offset: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListDailies(gomock.Any(), "en", nil, 2, 1).
|
||||
Return([]*ent.Daily{
|
||||
{
|
||||
ID: "daily2",
|
||||
ImageURL: "https://example.com/image2.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Daily{
|
||||
{
|
||||
ID: "daily2",
|
||||
ImageURL: "https://example.com/image2.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Service Error",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListDailies(gomock.Any(), "en", nil, 10, 0).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to list dailies"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
url := "/api/v1/dailies"
|
||||
if tc.langCode != "" {
|
||||
url += "?lang=" + tc.langCode
|
||||
}
|
||||
if tc.categoryID != "" {
|
||||
if strings.Contains(url, "?") {
|
||||
url += "&"
|
||||
} else {
|
||||
url += "?"
|
||||
}
|
||||
url += "category_id=" + tc.categoryID
|
||||
}
|
||||
if tc.limit != "" {
|
||||
if strings.Contains(url, "?") {
|
||||
url += "&"
|
||||
} else {
|
||||
url += "?"
|
||||
}
|
||||
url += "limit=" + tc.limit
|
||||
}
|
||||
if tc.offset != "" {
|
||||
if strings.Contains(url, "?") {
|
||||
url += "&"
|
||||
} else {
|
||||
url += "?"
|
||||
}
|
||||
url += "offset=" + tc.offset
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DailyHandlerTestSuite) TestGetDaily() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
id string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
id: "daily1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetDailyByID(gomock.Any(), "daily1").
|
||||
Return(&ent.Daily{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: &ent.Daily{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
id: "daily1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetDailyByID(gomock.Any(), "daily1").
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to get daily"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/dailies/"+tc.id, nil)
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DailyHandlerTestSuite) TestCreateDaily() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
body interface{}
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
body: CreateDailyRequest{
|
||||
ID: "daily1",
|
||||
CategoryID: 1,
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
CreateDaily(gomock.Any(), "daily1", 1, "https://example.com/image1.jpg").
|
||||
Return(&ent.Daily{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: &ent.Daily{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid request body",
|
||||
body: map[string]interface{}{
|
||||
"id": "daily1",
|
||||
// Missing required fields
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Key: 'CreateDailyRequest.CategoryID' Error:Field validation for 'CategoryID' failed on the 'required' tag\nKey: 'CreateDailyRequest.ImageURL' Error:Field validation for 'ImageURL' failed on the 'required' tag"},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
body: CreateDailyRequest{
|
||||
ID: "daily1",
|
||||
CategoryID: 1,
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
CreateDaily(gomock.Any(), "daily1", 1, "https://example.com/image1.jpg").
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to create daily"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
body, err := json.Marshal(tc.body)
|
||||
s.NoError(err, "Failed to marshal request body")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/dailies", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DailyHandlerTestSuite) TestAddDailyContent() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
dailyID string
|
||||
body interface{}
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
dailyID: "daily1",
|
||||
body: AddDailyContentRequest{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
AddDailyContent(gomock.Any(), "daily1", "en", "Test Quote 1").
|
||||
Return(&ent.DailyContent{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: &ent.DailyContent{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid request body",
|
||||
dailyID: "daily1",
|
||||
body: map[string]interface{}{
|
||||
"language_code": "en",
|
||||
// Missing required fields
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Key: 'AddDailyContentRequest.Quote' Error:Field validation for 'Quote' failed on the 'required' tag"},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
dailyID: "daily1",
|
||||
body: AddDailyContentRequest{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
AddDailyContent(gomock.Any(), "daily1", "en", "Test Quote 1").
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to add daily content"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
body, err := json.Marshal(tc.body)
|
||||
s.NoError(err, "Failed to marshal request body")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/dailies/"+tc.dailyID+"/contents", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -25,8 +25,9 @@ func NewHandler(cfg *config.Config, service service.Service) *Handler {
|
|||
}
|
||||
|
||||
// RegisterRoutes registers all the routes
|
||||
func (h *Handler) RegisterRoutes(r *gin.Engine) {
|
||||
api := r.Group("/api/v1")
|
||||
func (h *Handler) RegisterRoutes(router *gin.Engine) {
|
||||
// API routes
|
||||
api := router.Group("/api/v1")
|
||||
{
|
||||
// Auth routes
|
||||
auth := api.Group("/auth")
|
||||
|
@ -93,6 +94,9 @@ func (h *Handler) RegisterRoutes(r *gin.Engine) {
|
|||
media.DELETE("/:id", h.DeleteMedia)
|
||||
}
|
||||
}
|
||||
|
||||
// Public media files
|
||||
router.GET("/media/:year/:month/:filename", h.GetMediaFile)
|
||||
}
|
||||
|
||||
// Category handlers
|
||||
|
@ -186,10 +190,12 @@ func (h *Handler) ListPosts(c *gin.Context) {
|
|||
langCode = "en" // Default to English
|
||||
}
|
||||
|
||||
var categoryID *int
|
||||
if catIDStr := c.Query("category_id"); catIDStr != "" {
|
||||
if id, err := strconv.Atoi(catIDStr); err == nil {
|
||||
categoryID = &id
|
||||
var categoryIDs []int
|
||||
if catIDsStr := c.QueryArray("category_ids"); len(catIDsStr) > 0 {
|
||||
for _, idStr := range catIDsStr {
|
||||
if id, err := strconv.Atoi(idStr); err == nil {
|
||||
categoryIDs = append(categoryIDs, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -207,7 +213,7 @@ func (h *Handler) ListPosts(c *gin.Context) {
|
|||
}
|
||||
}
|
||||
|
||||
posts, err := h.service.ListPosts(c.Request.Context(), langCode, categoryID, limit, offset)
|
||||
posts, err := h.service.ListPosts(c.Request.Context(), langCode, categoryIDs, limit, offset)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to list posts")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list posts"})
|
||||
|
@ -244,10 +250,10 @@ func (h *Handler) GetPost(c *gin.Context) {
|
|||
contents := make([]gin.H, 0, len(post.Edges.Contents))
|
||||
for _, content := range post.Edges.Contents {
|
||||
contents = append(contents, gin.H{
|
||||
"language_code": content.LanguageCode,
|
||||
"title": content.Title,
|
||||
"language_code": content.LanguageCode,
|
||||
"title": content.Title,
|
||||
"content_markdown": content.ContentMarkdown,
|
||||
"summary": content.Summary,
|
||||
"summary": content.Summary,
|
||||
})
|
||||
}
|
||||
response["edges"].(gin.H)["contents"] = contents
|
||||
|
@ -256,7 +262,22 @@ func (h *Handler) GetPost(c *gin.Context) {
|
|||
}
|
||||
|
||||
func (h *Handler) CreatePost(c *gin.Context) {
|
||||
post, err := h.service.CreatePost(c.Request.Context(), "draft") // Default to draft status
|
||||
var req struct {
|
||||
Status string `json:"status" binding:"omitempty,oneof=draft published archived"`
|
||||
CategoryIDs []int `json:"category_ids" binding:"required,min=1"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
status := req.Status
|
||||
if status == "" {
|
||||
status = "draft" // Default to draft status
|
||||
}
|
||||
|
||||
post, err := h.service.CreatePost(c.Request.Context(), status, req.CategoryIDs)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to create post")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create post"})
|
||||
|
@ -268,7 +289,8 @@ func (h *Handler) CreatePost(c *gin.Context) {
|
|||
"id": post.ID,
|
||||
"status": post.Status,
|
||||
"edges": gin.H{
|
||||
"contents": []interface{}{},
|
||||
"contents": []interface{}{},
|
||||
"categories": []interface{}{},
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -276,7 +298,7 @@ func (h *Handler) CreatePost(c *gin.Context) {
|
|||
}
|
||||
|
||||
type AddPostContentRequest struct {
|
||||
LanguageCode string `json:"language_code" binding:"required"`
|
||||
LanguageCode string `json:"language_code" binding:"required"`
|
||||
Title string `json:"title" binding:"required"`
|
||||
ContentMarkdown string `json:"content_markdown" binding:"required"`
|
||||
Summary string `json:"summary" binding:"required"`
|
||||
|
@ -308,10 +330,10 @@ func (h *Handler) AddPostContent(c *gin.Context) {
|
|||
"title": content.Title,
|
||||
"content_markdown": content.ContentMarkdown,
|
||||
"language_code": content.LanguageCode,
|
||||
"summary": content.Summary,
|
||||
"summary": content.Summary,
|
||||
"meta_keywords": content.MetaKeywords,
|
||||
"meta_description": content.MetaDescription,
|
||||
"edges": gin.H{},
|
||||
"edges": gin.H{},
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -517,11 +539,3 @@ func (h *Handler) AddDailyContent(c *gin.Context) {
|
|||
|
||||
c.JSON(http.StatusCreated, content)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func stringPtr(s *string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return *s
|
||||
}
|
||||
|
|
|
@ -1,43 +0,0 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStringPtr(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input *string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "nil pointer",
|
||||
input: nil,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: strPtr(""),
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "non-empty string",
|
||||
input: strPtr("test"),
|
||||
expected: "test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := stringPtr(tc.input)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create string pointer
|
||||
func strPtr(s string) *string {
|
||||
return &s
|
||||
}
|
|
@ -5,9 +5,11 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/rs/zerolog/log"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// Media handlers
|
||||
|
@ -33,17 +35,29 @@ func (h *Handler) ListMedia(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, media)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": media,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) UploadMedia(c *gin.Context) {
|
||||
// Get user ID from context (set by auth middleware)
|
||||
userID, exists := c.Get("user_id")
|
||||
userIDStr, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
// Convert user ID to int
|
||||
userID, err := strconv.Atoi(userIDStr.(string))
|
||||
if err != nil {
|
||||
log.Error().Err(err).
|
||||
Str("user_id", fmt.Sprintf("%v", userIDStr)).
|
||||
Msg("Failed to convert user ID to int")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get file from form
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
|
@ -51,38 +65,107 @@ func (h *Handler) UploadMedia(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
// 文件大小限制
|
||||
if file.Size > 10*1024*1024 { // 10MB
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "File size exceeds the limit (10MB)"})
|
||||
// 获取文件类型和扩展名
|
||||
contentType := file.Header.Get("Content-Type")
|
||||
ext := strings.ToLower(filepath.Ext(file.Filename))
|
||||
if contentType == "" {
|
||||
// 如果 Content-Type 为空,尝试从文件扩展名判断
|
||||
switch ext {
|
||||
case ".jpg", ".jpeg":
|
||||
contentType = "image/jpeg"
|
||||
case ".png":
|
||||
contentType = "image/png"
|
||||
case ".gif":
|
||||
contentType = "image/gif"
|
||||
case ".webp":
|
||||
contentType = "image/webp"
|
||||
case ".mp4":
|
||||
contentType = "video/mp4"
|
||||
case ".webm":
|
||||
contentType = "video/webm"
|
||||
case ".mp3":
|
||||
contentType = "audio/mpeg"
|
||||
case ".ogg":
|
||||
contentType = "audio/ogg"
|
||||
case ".wav":
|
||||
contentType = "audio/wav"
|
||||
case ".pdf":
|
||||
contentType = "application/pdf"
|
||||
case ".doc":
|
||||
contentType = "application/msword"
|
||||
case ".docx":
|
||||
contentType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
}
|
||||
}
|
||||
|
||||
// 根据 Content-Type 确定文件类型和限制
|
||||
var maxSize int64
|
||||
var allowedTypes []string
|
||||
var fileType string
|
||||
|
||||
limits := h.cfg.Storage.Upload.Limits
|
||||
switch {
|
||||
case strings.HasPrefix(contentType, "image/"):
|
||||
maxSize = int64(limits.Image.MaxSize) * 1024 * 1024
|
||||
allowedTypes = limits.Image.AllowedTypes
|
||||
fileType = "image"
|
||||
case strings.HasPrefix(contentType, "video/"):
|
||||
maxSize = int64(limits.Video.MaxSize) * 1024 * 1024
|
||||
allowedTypes = limits.Video.AllowedTypes
|
||||
fileType = "video"
|
||||
case strings.HasPrefix(contentType, "audio/"):
|
||||
maxSize = int64(limits.Audio.MaxSize) * 1024 * 1024
|
||||
allowedTypes = limits.Audio.AllowedTypes
|
||||
fileType = "audio"
|
||||
case strings.HasPrefix(contentType, "application/"):
|
||||
maxSize = int64(limits.Document.MaxSize) * 1024 * 1024
|
||||
allowedTypes = limits.Document.AllowedTypes
|
||||
fileType = "document"
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "Unsupported file type",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 文件类型限制
|
||||
allowedTypes := map[string]bool{
|
||||
"image/jpeg": true,
|
||||
"image/png": true,
|
||||
"image/gif": true,
|
||||
"video/mp4": true,
|
||||
"video/webm": true,
|
||||
"audio/mpeg": true,
|
||||
"audio/ogg": true,
|
||||
"application/pdf": true,
|
||||
// 检查文件类型是否允许
|
||||
typeAllowed := false
|
||||
for _, allowed := range allowedTypes {
|
||||
if contentType == allowed {
|
||||
typeAllowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
contentType := file.Header.Get("Content-Type")
|
||||
if _, ok := allowedTypes[contentType]; !ok {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid file type"})
|
||||
if !typeAllowed {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("Unsupported %s type: %s", fileType, contentType),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查文件大小
|
||||
if file.Size > maxSize {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("File size exceeds the limit (%d MB) for %s files", limits.Image.MaxSize, fileType),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Upload file
|
||||
media, err := h.service.Upload(c.Request.Context(), file, userID.(int))
|
||||
media, err := h.service.Upload(c.Request.Context(), file, userID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to upload media")
|
||||
log.Error().Err(err).
|
||||
Str("filename", file.Filename).
|
||||
Str("content_type", contentType).
|
||||
Int("user_id", userID).
|
||||
Msg("Failed to upload media")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to upload media"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, media)
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"data": media,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) GetMedia(c *gin.Context) {
|
||||
|
@ -101,7 +184,7 @@ func (h *Handler) GetMedia(c *gin.Context) {
|
|||
}
|
||||
|
||||
// Get file content
|
||||
reader, info, err := h.service.GetFile(c.Request.Context(), id)
|
||||
reader, info, err := h.service.GetFile(c.Request.Context(), media.StorageID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get media file")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get media file"})
|
||||
|
@ -122,16 +205,18 @@ func (h *Handler) GetMedia(c *gin.Context) {
|
|||
}
|
||||
|
||||
func (h *Handler) GetMediaFile(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid media ID"})
|
||||
return
|
||||
}
|
||||
year := c.Param("year")
|
||||
month := c.Param("month")
|
||||
filename := c.Param("filename")
|
||||
|
||||
// Get file content
|
||||
reader, info, err := h.service.GetFile(c.Request.Context(), id)
|
||||
reader, info, err := h.service.GetFile(c.Request.Context(), filename) // 直接使用完整的文件名
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get media file")
|
||||
log.Error().Err(err).
|
||||
Str("year", year).
|
||||
Str("month", month).
|
||||
Str("filename", filename).
|
||||
Msg("Failed to get media file")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get media file"})
|
||||
return
|
||||
}
|
||||
|
@ -151,20 +236,33 @@ func (h *Handler) GetMediaFile(c *gin.Context) {
|
|||
|
||||
func (h *Handler) DeleteMedia(c *gin.Context) {
|
||||
// Get user ID from context (set by auth middleware)
|
||||
userID, exists := c.Get("user_id")
|
||||
userIDStr, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
// Convert user ID to int
|
||||
userID, err := strconv.Atoi(userIDStr.(string))
|
||||
if err != nil {
|
||||
log.Error().Err(err).
|
||||
Str("user_id", fmt.Sprintf("%v", userIDStr)).
|
||||
Msg("Failed to convert user ID to int")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid media ID"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.service.DeleteMedia(c.Request.Context(), id, userID.(int)); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to delete media")
|
||||
if err := h.service.DeleteMedia(c.Request.Context(), id, userID); err != nil {
|
||||
log.Error().Err(err).
|
||||
Int("media_id", id).
|
||||
Int("user_id", userID).
|
||||
Msg("Failed to delete media")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete media"})
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,524 +0,0 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
"tss-rocks-be/internal/storage"
|
||||
|
||||
"net/textproto"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
type MediaHandlerTestSuite struct {
|
||||
suite.Suite
|
||||
ctrl *gomock.Controller
|
||||
service *mock.MockService
|
||||
handler *Handler
|
||||
router *gin.Engine
|
||||
}
|
||||
|
||||
func (s *MediaHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
s.handler = NewHandler(&config.Config{}, s.service)
|
||||
s.router = gin.New()
|
||||
}
|
||||
|
||||
func (s *MediaHandlerTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestMediaHandlerSuite(t *testing.T) {
|
||||
suite.Run(t, new(MediaHandlerTestSuite))
|
||||
}
|
||||
|
||||
func (s *MediaHandlerTestSuite) TestListMedia() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
query string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "成功列出媒体",
|
||||
query: "?limit=10&offset=0",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListMedia(gomock.Any(), 10, 0).
|
||||
Return([]*ent.Media{{ID: 1}}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "使用默认限制和偏移",
|
||||
query: "",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListMedia(gomock.Any(), 10, 0).
|
||||
Return([]*ent.Media{{ID: 1}}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "列出媒体失败",
|
||||
query: "",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListMedia(gomock.Any(), 10, 0).
|
||||
Return(nil, errors.New("failed to list media"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to list media",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
// 创建请求
|
||||
req, _ := http.NewRequest(http.MethodGet, "/media"+tc.query, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// 执行请求
|
||||
s.handler.ListMedia(c)
|
||||
|
||||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Equal(tc.expectedError, response["error"])
|
||||
} else {
|
||||
var response []*ent.Media
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.NotEmpty(response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MediaHandlerTestSuite) TestUploadMedia() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupRequest func() (*http.Request, error)
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "成功上传媒体",
|
||||
setupRequest: func() (*http.Request, error) {
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
// 创建文件部分
|
||||
fileHeader := make(textproto.MIMEHeader)
|
||||
fileHeader.Set("Content-Type", "image/jpeg")
|
||||
fileHeader.Set("Content-Disposition", `form-data; name="file"; filename="test.jpg"`)
|
||||
part, err := writer.CreatePart(fileHeader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
testContent := "test content"
|
||||
_, err = io.Copy(part, strings.NewReader(testContent))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writer.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/media", body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return req, nil
|
||||
},
|
||||
setupMock: func() {
|
||||
expectedFile := &multipart.FileHeader{
|
||||
Filename: "test.jpg",
|
||||
Size: int64(len("test content")),
|
||||
Header: textproto.MIMEHeader{
|
||||
"Content-Type": []string{"image/jpeg"},
|
||||
},
|
||||
}
|
||||
s.service.EXPECT().
|
||||
Upload(gomock.Any(), gomock.Any(), 1).
|
||||
DoAndReturn(func(_ context.Context, f *multipart.FileHeader, uid int) (*ent.Media, error) {
|
||||
s.Equal(expectedFile.Filename, f.Filename)
|
||||
s.Equal(expectedFile.Size, f.Size)
|
||||
s.Equal(expectedFile.Header.Get("Content-Type"), f.Header.Get("Content-Type"))
|
||||
return &ent.Media{ID: 1}, nil
|
||||
})
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "未授权",
|
||||
setupRequest: func() (*http.Request, error) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/media", nil)
|
||||
return req, nil
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedError: "Unauthorized",
|
||||
},
|
||||
{
|
||||
name: "上传失败",
|
||||
setupRequest: func() (*http.Request, error) {
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
// 创建文件部分
|
||||
fileHeader := make(textproto.MIMEHeader)
|
||||
fileHeader.Set("Content-Type", "image/jpeg")
|
||||
fileHeader.Set("Content-Disposition", `form-data; name="file"; filename="test.jpg"`)
|
||||
part, err := writer.CreatePart(fileHeader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
testContent := "test content"
|
||||
_, err = io.Copy(part, strings.NewReader(testContent))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writer.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/media", body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return req, nil
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
Upload(gomock.Any(), gomock.Any(), 1).
|
||||
Return(nil, errors.New("failed to upload"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to upload media",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
// 创建请求
|
||||
req, err := tc.setupRequest()
|
||||
s.NoError(err)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// 设置用户ID(除了未授权的测试用例)
|
||||
if tc.expectedError != "Unauthorized" {
|
||||
c.Set("user_id", 1)
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
s.handler.UploadMedia(c)
|
||||
|
||||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Equal(tc.expectedError, response["error"])
|
||||
} else {
|
||||
var response *ent.Media
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.NotNil(response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MediaHandlerTestSuite) TestGetMedia() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
mediaID string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "成功获取媒体",
|
||||
mediaID: "1",
|
||||
setupMock: func() {
|
||||
media := &ent.Media{
|
||||
ID: 1,
|
||||
MimeType: "image/jpeg",
|
||||
OriginalName: "test.jpg",
|
||||
}
|
||||
s.service.EXPECT().
|
||||
GetMedia(gomock.Any(), 1).
|
||||
Return(media, nil)
|
||||
s.service.EXPECT().
|
||||
GetFile(gomock.Any(), 1).
|
||||
Return(io.NopCloser(strings.NewReader("test content")), &storage.FileInfo{
|
||||
Size: 11,
|
||||
Name: "test.jpg",
|
||||
ContentType: "image/jpeg",
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "无效的媒体ID",
|
||||
mediaID: "invalid",
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Invalid media ID",
|
||||
},
|
||||
{
|
||||
name: "获取媒体元数据失败",
|
||||
mediaID: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetMedia(gomock.Any(), 1).
|
||||
Return(nil, errors.New("failed to get media"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to get media",
|
||||
},
|
||||
{
|
||||
name: "获取媒体文件失败",
|
||||
mediaID: "1",
|
||||
setupMock: func() {
|
||||
media := &ent.Media{
|
||||
ID: 1,
|
||||
MimeType: "image/jpeg",
|
||||
OriginalName: "test.jpg",
|
||||
}
|
||||
s.service.EXPECT().
|
||||
GetMedia(gomock.Any(), 1).
|
||||
Return(media, nil)
|
||||
s.service.EXPECT().
|
||||
GetFile(gomock.Any(), 1).
|
||||
Return(nil, nil, errors.New("failed to get file"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to get media file",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
// 创建请求
|
||||
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/media/%s", tc.mediaID), nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// Extract ID from URL path
|
||||
parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/")
|
||||
if len(parts) >= 2 {
|
||||
c.Params = []gin.Param{{Key: "id", Value: parts[1]}}
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
s.handler.GetMedia(c)
|
||||
|
||||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Equal(tc.expectedError, response["error"])
|
||||
} else {
|
||||
s.Equal("image/jpeg", w.Header().Get("Content-Type"))
|
||||
s.Equal("11", w.Header().Get("Content-Length"))
|
||||
s.Equal("inline; filename=test.jpg", w.Header().Get("Content-Disposition"))
|
||||
s.Equal("test content", w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MediaHandlerTestSuite) TestGetMediaFile() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupRequest func() (*http.Request, error)
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody []byte
|
||||
}{
|
||||
{
|
||||
name: "成功获取媒体文件",
|
||||
setupRequest: func() (*http.Request, error) {
|
||||
return httptest.NewRequest(http.MethodGet, "/media/1/file", nil), nil
|
||||
},
|
||||
setupMock: func() {
|
||||
fileContent := "test file content"
|
||||
s.service.EXPECT().
|
||||
GetFile(gomock.Any(), 1).
|
||||
Return(io.NopCloser(strings.NewReader(fileContent)), &storage.FileInfo{
|
||||
Name: "test.jpg",
|
||||
Size: int64(len(fileContent)),
|
||||
ContentType: "image/jpeg",
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []byte("test file content"),
|
||||
},
|
||||
{
|
||||
name: "无效的媒体ID",
|
||||
setupRequest: func() (*http.Request, error) {
|
||||
return httptest.NewRequest(http.MethodGet, "/media/invalid/file", nil), nil
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "获取媒体文件失败",
|
||||
setupRequest: func() (*http.Request, error) {
|
||||
return httptest.NewRequest(http.MethodGet, "/media/1/file", nil), nil
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetFile(gomock.Any(), 1).
|
||||
Return(nil, nil, errors.New("failed to get file"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// Setup
|
||||
req, err := tc.setupRequest()
|
||||
s.Require().NoError(err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// Extract ID from URL path
|
||||
parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/")
|
||||
if len(parts) >= 2 {
|
||||
c.Params = []gin.Param{{Key: "id", Value: parts[1]}}
|
||||
}
|
||||
|
||||
// Setup mock
|
||||
tc.setupMock()
|
||||
|
||||
// Test
|
||||
s.handler.GetMediaFile(c)
|
||||
|
||||
// Verify
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedBody != nil {
|
||||
s.Equal(tc.expectedBody, w.Body.Bytes())
|
||||
s.Equal("image/jpeg", w.Header().Get("Content-Type"))
|
||||
s.Equal(fmt.Sprintf("%d", len(tc.expectedBody)), w.Header().Get("Content-Length"))
|
||||
s.Equal("inline; filename=test.jpg", w.Header().Get("Content-Disposition"))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MediaHandlerTestSuite) TestDeleteMedia() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
mediaID string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "成功删除媒体",
|
||||
mediaID: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
DeleteMedia(gomock.Any(), 1, 1).
|
||||
Return(nil)
|
||||
},
|
||||
expectedStatus: http.StatusNoContent,
|
||||
},
|
||||
{
|
||||
name: "未授权",
|
||||
mediaID: "1",
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedError: "Unauthorized",
|
||||
},
|
||||
{
|
||||
name: "无效的媒体ID",
|
||||
mediaID: "invalid",
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Invalid media ID",
|
||||
},
|
||||
{
|
||||
name: "删除媒体失败",
|
||||
mediaID: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
DeleteMedia(gomock.Any(), 1, 1).
|
||||
Return(errors.New("failed to delete"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to delete media",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
// 创建请求
|
||||
req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/media/%s", tc.mediaID), nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// Extract ID from URL path
|
||||
parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/")
|
||||
if len(parts) >= 2 {
|
||||
c.Params = []gin.Param{{Key: "id", Value: parts[1]}}
|
||||
}
|
||||
|
||||
// 设置用户ID(除了未授权的测试用例)
|
||||
if tc.expectedError != "Unauthorized" {
|
||||
c.Set("user_id", 1)
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
s.handler.DeleteMedia(c)
|
||||
|
||||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Equal(tc.expectedError, response["error"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,624 +0,0 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type PostHandlerTestSuite struct {
|
||||
suite.Suite
|
||||
ctrl *gomock.Controller
|
||||
service *mock.MockService
|
||||
handler *Handler
|
||||
router *gin.Engine
|
||||
}
|
||||
|
||||
func (s *PostHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
},
|
||||
}
|
||||
s.handler = NewHandler(cfg, s.service)
|
||||
|
||||
// Setup Gin router
|
||||
gin.SetMode(gin.TestMode)
|
||||
s.router = gin.New()
|
||||
|
||||
// Setup mock for GetTokenBlacklist
|
||||
tokenBlacklist := &service.TokenBlacklist{}
|
||||
s.service.EXPECT().
|
||||
GetTokenBlacklist().
|
||||
Return(tokenBlacklist).
|
||||
AnyTimes()
|
||||
|
||||
s.handler.RegisterRoutes(s.router)
|
||||
}
|
||||
|
||||
func (s *PostHandlerTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestPostHandlerSuite(t *testing.T) {
|
||||
suite.Run(t, new(PostHandlerTestSuite))
|
||||
}
|
||||
|
||||
// Test cases for ListPosts
|
||||
func (s *PostHandlerTestSuite) TestListPosts() {
|
||||
categoryID := 1
|
||||
testCases := []struct {
|
||||
name string
|
||||
langCode string
|
||||
categoryID string
|
||||
limit string
|
||||
offset string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success with default language",
|
||||
langCode: "",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListPosts(gomock.Any(), "en", nil, 10, 0).
|
||||
Return([]*ent.Post{
|
||||
{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Post{
|
||||
{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with specific language",
|
||||
langCode: "zh",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListPosts(gomock.Any(), "zh", nil, 10, 0).
|
||||
Return([]*ent.Post{
|
||||
{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "zh",
|
||||
Title: "测试帖子",
|
||||
ContentMarkdown: "测试内容",
|
||||
Summary: "测试摘要",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Post{
|
||||
{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "zh",
|
||||
Title: "测试帖子",
|
||||
ContentMarkdown: "测试内容",
|
||||
Summary: "测试摘要",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with category filter",
|
||||
langCode: "en",
|
||||
categoryID: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListPosts(gomock.Any(), "en", &categoryID, 10, 0).
|
||||
Return([]*ent.Post{
|
||||
{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Post{
|
||||
{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with pagination",
|
||||
langCode: "en",
|
||||
limit: "2",
|
||||
offset: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListPosts(gomock.Any(), "en", nil, 2, 1).
|
||||
Return([]*ent.Post{
|
||||
{
|
||||
ID: 2,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post 2",
|
||||
ContentMarkdown: "Test Content 2",
|
||||
Summary: "Test Summary 2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Post{
|
||||
{
|
||||
ID: 2,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post 2",
|
||||
ContentMarkdown: "Test Content 2",
|
||||
Summary: "Test Summary 2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Service Error",
|
||||
langCode: "en",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListPosts(gomock.Any(), "en", nil, 10, 0).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// Setup mock
|
||||
tc.setupMock()
|
||||
|
||||
// Create request
|
||||
url := "/api/v1/posts"
|
||||
if tc.langCode != "" {
|
||||
url += "?lang=" + tc.langCode
|
||||
}
|
||||
if tc.categoryID != "" {
|
||||
if strings.Contains(url, "?") {
|
||||
url += "&"
|
||||
} else {
|
||||
url += "?"
|
||||
}
|
||||
url += "category_id=" + tc.categoryID
|
||||
}
|
||||
if tc.limit != "" {
|
||||
if strings.Contains(url, "?") {
|
||||
url += "&"
|
||||
} else {
|
||||
url += "?"
|
||||
}
|
||||
url += "limit=" + tc.limit
|
||||
}
|
||||
if tc.offset != "" {
|
||||
if strings.Contains(url, "?") {
|
||||
url += "&"
|
||||
} else {
|
||||
url += "?"
|
||||
}
|
||||
url += "offset=" + tc.offset
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Perform request
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
// Assert response
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedBody != nil {
|
||||
var response []*ent.Post
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Equal(tc.expectedBody, response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases for GetPost
|
||||
func (s *PostHandlerTestSuite) TestGetPost() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
langCode string
|
||||
slug string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success with default language",
|
||||
langCode: "",
|
||||
slug: "test-post",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetPostBySlug(gomock.Any(), "en", "test-post").
|
||||
Return(&ent.Post{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Slug: "test-post",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: gin.H{
|
||||
"id": 1,
|
||||
"status": "published",
|
||||
"slug": "test-post",
|
||||
"edges": gin.H{
|
||||
"contents": []gin.H{
|
||||
{
|
||||
"language_code": "en",
|
||||
"title": "Test Post",
|
||||
"content_markdown": "Test Content",
|
||||
"summary": "Test Summary",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with specific language",
|
||||
langCode: "zh",
|
||||
slug: "test-post",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetPostBySlug(gomock.Any(), "zh", "test-post").
|
||||
Return(&ent.Post{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Slug: "test-post",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "zh",
|
||||
Title: "测试帖子",
|
||||
ContentMarkdown: "测试内容",
|
||||
Summary: "测试摘要",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: gin.H{
|
||||
"id": 1,
|
||||
"status": "published",
|
||||
"slug": "test-post",
|
||||
"edges": gin.H{
|
||||
"contents": []gin.H{
|
||||
{
|
||||
"language_code": "zh",
|
||||
"title": "测试帖子",
|
||||
"content_markdown": "测试内容",
|
||||
"summary": "测试摘要",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
slug: "test-post",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetPostBySlug(gomock.Any(), "en", "test-post").
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to get post"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
url := "/api/v1/posts/" + tc.slug
|
||||
if tc.langCode != "" {
|
||||
url += "?lang=" + tc.langCode
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases for CreatePost
|
||||
func (s *PostHandlerTestSuite) TestCreatePost() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
CreatePost(gomock.Any(), "draft").
|
||||
Return(&ent.Post{
|
||||
ID: 1,
|
||||
Status: "draft",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: gin.H{
|
||||
"id": 1,
|
||||
"status": "draft",
|
||||
"edges": gin.H{
|
||||
"contents": []gin.H{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
CreatePost(gomock.Any(), "draft").
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to create post"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/posts", nil)
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases for AddPostContent
|
||||
func (s *PostHandlerTestSuite) TestAddPostContent() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
postID string
|
||||
body interface{}
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
postID: "1",
|
||||
body: AddPostContentRequest{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
MetaKeywords: "test,keywords",
|
||||
MetaDescription: "Test meta description",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
AddPostContent(
|
||||
gomock.Any(),
|
||||
1,
|
||||
"en",
|
||||
"Test Post",
|
||||
"Test Content",
|
||||
"Test Summary",
|
||||
"test,keywords",
|
||||
"Test meta description",
|
||||
).
|
||||
Return(&ent.PostContent{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
MetaKeywords: "test,keywords",
|
||||
MetaDescription: "Test meta description",
|
||||
Edges: ent.PostContentEdges{},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: gin.H{
|
||||
"language_code": "en",
|
||||
"title": "Test Post",
|
||||
"content_markdown": "Test Content",
|
||||
"summary": "Test Summary",
|
||||
"meta_keywords": "test,keywords",
|
||||
"meta_description": "Test meta description",
|
||||
"edges": gin.H{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid post ID",
|
||||
postID: "invalid",
|
||||
body: AddPostContentRequest{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Invalid post ID"},
|
||||
},
|
||||
{
|
||||
name: "Invalid request body",
|
||||
postID: "1",
|
||||
body: map[string]interface{}{
|
||||
"language_code": "en",
|
||||
// Missing required fields
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Key: 'AddPostContentRequest.Title' Error:Field validation for 'Title' failed on the 'required' tag\nKey: 'AddPostContentRequest.ContentMarkdown' Error:Field validation for 'ContentMarkdown' failed on the 'required' tag\nKey: 'AddPostContentRequest.Summary' Error:Field validation for 'Summary' failed on the 'required' tag"},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
postID: "1",
|
||||
body: AddPostContentRequest{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
MetaKeywords: "test,keywords",
|
||||
MetaDescription: "Test meta description",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
AddPostContent(
|
||||
gomock.Any(),
|
||||
1,
|
||||
"en",
|
||||
"Test Post",
|
||||
"Test Content",
|
||||
"Test Summary",
|
||||
"test,keywords",
|
||||
"Test meta description",
|
||||
).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to add post content"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
body, err := json.Marshal(tc.body)
|
||||
s.NoError(err, "Failed to marshal request body")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/posts/"+tc.postID+"/contents", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,238 +0,0 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
"tss-rocks-be/internal/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAccessLog(t *testing.T) {
|
||||
// 设置测试临时目录
|
||||
tmpDir := t.TempDir()
|
||||
logPath := filepath.Join(tmpDir, "test.log")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
config *types.AccessLogConfig
|
||||
expectedError bool
|
||||
setupRequest func(*http.Request)
|
||||
validateOutput func(*testing.T, *httptest.ResponseRecorder, string)
|
||||
}{
|
||||
{
|
||||
name: "Console logging only",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
EnableFile: false,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
},
|
||||
expectedError: false,
|
||||
setupRequest: func(req *http.Request) {
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
},
|
||||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, logOutput, "GET /test")
|
||||
assert.Contains(t, logOutput, "test-agent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "File logging only",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: false,
|
||||
EnableFile: true,
|
||||
FilePath: logPath,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
Rotation: struct {
|
||||
MaxSize int `yaml:"max_size"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
Compress bool `yaml:"compress"`
|
||||
LocalTime bool `yaml:"local_time"`
|
||||
}{
|
||||
MaxSize: 1,
|
||||
MaxAge: 1,
|
||||
MaxBackups: 1,
|
||||
Compress: false,
|
||||
LocalTime: true,
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
setupRequest: func(req *http.Request) {
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
},
|
||||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// 读取日志文件内容
|
||||
content, err := os.ReadFile(logPath)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(content), "GET /test")
|
||||
assert.Contains(t, string(content), "test-agent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Both console and file logging",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
EnableFile: true,
|
||||
FilePath: logPath,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
Rotation: struct {
|
||||
MaxSize int `yaml:"max_size"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
Compress bool `yaml:"compress"`
|
||||
LocalTime bool `yaml:"local_time"`
|
||||
}{
|
||||
MaxSize: 1,
|
||||
MaxAge: 1,
|
||||
MaxBackups: 1,
|
||||
Compress: false,
|
||||
LocalTime: true,
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
setupRequest: func(req *http.Request) {
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
},
|
||||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, logOutput, "GET /test")
|
||||
assert.Contains(t, logOutput, "test-agent")
|
||||
|
||||
// 读取日志文件内容
|
||||
content, err := os.ReadFile(logPath)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(content), "GET /test")
|
||||
assert.Contains(t, string(content), "test-agent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "With authenticated user",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
EnableFile: false,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
},
|
||||
expectedError: false,
|
||||
setupRequest: func(req *http.Request) {
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
},
|
||||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, logOutput, "GET /test")
|
||||
assert.Contains(t, logOutput, "test-agent")
|
||||
assert.Contains(t, logOutput, "test-user")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// 捕获标准输出
|
||||
oldStdout := os.Stdout
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stdout = w
|
||||
|
||||
// 创建一个新的 gin 引擎
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 创建访问日志中间件
|
||||
middleware, err := AccessLog(tc.config)
|
||||
if tc.expectedError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 添加测试路由
|
||||
router.Use(middleware)
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
// 如果是测试认证用户的情况,设置用户ID
|
||||
if tc.name == "With authenticated user" {
|
||||
c.Set("user_id", "test-user")
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// 创建测试请求
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
if tc.setupRequest != nil {
|
||||
tc.setupRequest(req)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// 恢复标准输出并获取输出内容
|
||||
w.Close()
|
||||
var buf bytes.Buffer
|
||||
io.Copy(&buf, r)
|
||||
os.Stdout = oldStdout
|
||||
|
||||
// 验证输出
|
||||
if tc.validateOutput != nil {
|
||||
tc.validateOutput(t, rec, buf.String())
|
||||
}
|
||||
|
||||
// 关闭日志文件
|
||||
if tc.config.EnableFile {
|
||||
// 调用中间件函数来关闭日志文件
|
||||
middleware(nil)
|
||||
// 等待一小段时间确保文件完全关闭
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogInvalidConfig(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
config *types.AccessLogConfig
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "Invalid log level",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
Level: "invalid_level",
|
||||
},
|
||||
expectedError: false, // 应该使用默认的 info 级别
|
||||
},
|
||||
{
|
||||
name: "Invalid file path",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableFile: true,
|
||||
FilePath: "\x00invalid\x00path", // 使用空字符的路径在所有操作系统上都是无效的
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := AccessLog(tc.config)
|
||||
if tc.expectedError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,227 +0,0 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tss-rocks-be/internal/service"
|
||||
)
|
||||
|
||||
func createTestToken(secret string, claims jwt.MapClaims) string {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signedToken, err := token.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to sign token: %v", err))
|
||||
}
|
||||
return signedToken
|
||||
}
|
||||
|
||||
func TestAuthMiddleware(t *testing.T) {
|
||||
jwtSecret := "test-secret"
|
||||
tokenBlacklist := service.NewTokenBlacklist()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupAuth func(*http.Request)
|
||||
expectedStatus int
|
||||
expectedBody map[string]string
|
||||
checkUserData bool
|
||||
expectedUserID string
|
||||
expectedRoles []string
|
||||
}{
|
||||
{
|
||||
name: "No Authorization header",
|
||||
setupAuth: func(req *http.Request) {},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "Authorization header is required"},
|
||||
},
|
||||
{
|
||||
name: "Invalid Authorization format",
|
||||
setupAuth: func(req *http.Request) {
|
||||
req.Header.Set("Authorization", "InvalidFormat")
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "Authorization header format must be Bearer {token}"},
|
||||
},
|
||||
{
|
||||
name: "Invalid token",
|
||||
setupAuth: func(req *http.Request) {
|
||||
req.Header.Set("Authorization", "Bearer invalid.token.here")
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "Invalid token"},
|
||||
},
|
||||
{
|
||||
name: "Valid token",
|
||||
setupAuth: func(req *http.Request) {
|
||||
claims := jwt.MapClaims{
|
||||
"sub": "123",
|
||||
"roles": []string{"admin", "editor"},
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
}
|
||||
token := createTestToken(jwtSecret, claims)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
checkUserData: true,
|
||||
expectedUserID: "123",
|
||||
expectedRoles: []string{"admin", "editor"},
|
||||
},
|
||||
{
|
||||
name: "Expired token",
|
||||
setupAuth: func(req *http.Request) {
|
||||
claims := jwt.MapClaims{
|
||||
"sub": "123",
|
||||
"roles": []string{"user"},
|
||||
"exp": time.Now().Add(-time.Hour).Unix(),
|
||||
}
|
||||
token := createTestToken(jwtSecret, claims)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "Invalid token"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 添加认证中间件
|
||||
router.Use(func(c *gin.Context) {
|
||||
// 设置日志级别为 debug
|
||||
gin.SetMode(gin.DebugMode)
|
||||
c.Next()
|
||||
}, AuthMiddleware(jwtSecret, tokenBlacklist))
|
||||
|
||||
// 测试路由
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
if tc.checkUserData {
|
||||
userID, exists := c.Get("user_id")
|
||||
assert.True(t, exists, "user_id should exist in context")
|
||||
assert.Equal(t, tc.expectedUserID, userID, "user_id should match")
|
||||
|
||||
roles, exists := c.Get("user_roles")
|
||||
assert.True(t, exists, "user_roles should exist in context")
|
||||
assert.Equal(t, tc.expectedRoles, roles, "user_roles should match")
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// 创建请求
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
tc.setupAuth(req)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// 验证响应
|
||||
assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
assert.NoError(t, err, "Response body should be valid JSON")
|
||||
assert.Equal(t, tc.expectedBody, response, "Response body should match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoleMiddleware(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupContext func(*gin.Context)
|
||||
allowedRoles []string
|
||||
expectedStatus int
|
||||
expectedBody map[string]string
|
||||
}{
|
||||
{
|
||||
name: "No user roles",
|
||||
setupContext: func(c *gin.Context) {
|
||||
// 不设置用户角色
|
||||
},
|
||||
allowedRoles: []string{"admin"},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "User roles not found"},
|
||||
},
|
||||
{
|
||||
name: "Invalid roles type",
|
||||
setupContext: func(c *gin.Context) {
|
||||
c.Set("user_roles", 123) // 设置错误类型的角色
|
||||
},
|
||||
allowedRoles: []string{"admin"},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: map[string]string{"error": "Invalid user roles type"},
|
||||
},
|
||||
{
|
||||
name: "Insufficient permissions",
|
||||
setupContext: func(c *gin.Context) {
|
||||
c.Set("user_roles", []string{"user"})
|
||||
},
|
||||
allowedRoles: []string{"admin"},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedBody: map[string]string{"error": "Insufficient permissions"},
|
||||
},
|
||||
{
|
||||
name: "Allowed role",
|
||||
setupContext: func(c *gin.Context) {
|
||||
c.Set("user_roles", []string{"admin"})
|
||||
},
|
||||
allowedRoles: []string{"admin"},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "One of multiple allowed roles",
|
||||
setupContext: func(c *gin.Context) {
|
||||
c.Set("user_roles", []string{"user", "editor"})
|
||||
},
|
||||
allowedRoles: []string{"admin", "editor", "moderator"},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 添加角色中间件
|
||||
router.Use(func(c *gin.Context) {
|
||||
tc.setupContext(c)
|
||||
c.Next()
|
||||
}, RoleMiddleware(tc.allowedRoles...))
|
||||
|
||||
// 测试路由
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// 创建请求
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// 验证响应
|
||||
assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
assert.NoError(t, err, "Response body should be valid JSON")
|
||||
assert.Equal(t, tc.expectedBody, response, "Response body should match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,76 +0,0 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCORS(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
method string
|
||||
expectedStatus int
|
||||
checkHeaders bool
|
||||
}{
|
||||
{
|
||||
name: "Normal GET request",
|
||||
method: "GET",
|
||||
expectedStatus: http.StatusOK,
|
||||
checkHeaders: true,
|
||||
},
|
||||
{
|
||||
name: "OPTIONS request",
|
||||
method: "OPTIONS",
|
||||
expectedStatus: http.StatusNoContent,
|
||||
checkHeaders: true,
|
||||
},
|
||||
{
|
||||
name: "POST request",
|
||||
method: "POST",
|
||||
expectedStatus: http.StatusOK,
|
||||
checkHeaders: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// 创建一个新的 gin 引擎
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 添加 CORS 中间件
|
||||
router.Use(CORS())
|
||||
|
||||
// 添加测试路由
|
||||
router.Any("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// 创建测试请求
|
||||
req := httptest.NewRequest(tc.method, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// 验证状态码
|
||||
assert.Equal(t, tc.expectedStatus, rec.Code)
|
||||
|
||||
if tc.checkHeaders {
|
||||
// 验证 CORS 头部
|
||||
headers := rec.Header()
|
||||
assert.Equal(t, "*", headers.Get("Access-Control-Allow-Origin"))
|
||||
assert.Equal(t, "true", headers.Get("Access-Control-Allow-Credentials"))
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Content-Type")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Authorization")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "POST")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "GET")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "PUT")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "DELETE")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,207 +0,0 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
"tss-rocks-be/internal/types"
|
||||
)
|
||||
|
||||
func TestRateLimit(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
config *types.RateLimitConfig
|
||||
setupTest func(*gin.Engine)
|
||||
runTest func(*testing.T, *gin.Engine)
|
||||
expectedStatus int
|
||||
expectedBody map[string]string
|
||||
}{
|
||||
{
|
||||
name: "IP rate limit",
|
||||
config: &types.RateLimitConfig{
|
||||
IPRate: 1, // 每秒1个请求
|
||||
IPBurst: 1,
|
||||
},
|
||||
setupTest: func(router *gin.Engine) {
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
},
|
||||
runTest: func(t *testing.T, router *gin.Engine) {
|
||||
// 第一个请求应该成功
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:1234"
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// 第二个请求应该被限制
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "too many requests from this IP", response["error"])
|
||||
|
||||
// 等待限流器重置
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// 第三个请求应该成功
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Route rate limit",
|
||||
config: &types.RateLimitConfig{
|
||||
IPRate: 100, // 设置较高的 IP 限流,以便测试路由限流
|
||||
IPBurst: 10,
|
||||
RouteRates: map[string]struct {
|
||||
Rate int `yaml:"rate"`
|
||||
Burst int `yaml:"burst"`
|
||||
}{
|
||||
"/limited": {
|
||||
Rate: 1,
|
||||
Burst: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
setupTest: func(router *gin.Engine) {
|
||||
router.GET("/limited", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
router.GET("/unlimited", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
},
|
||||
runTest: func(t *testing.T, router *gin.Engine) {
|
||||
// 测试限流路由
|
||||
req := httptest.NewRequest("GET", "/limited", nil)
|
||||
req.RemoteAddr = "192.168.1.2:1234"
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// 等待一小段时间确保限流器生效
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "too many requests for this route", response["error"])
|
||||
|
||||
// 测试未限流路由
|
||||
req = httptest.NewRequest("GET", "/unlimited", nil)
|
||||
req.RemoteAddr = "192.168.1.2:1234"
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// 等待一小段时间确保限流器生效
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Multiple IPs",
|
||||
config: &types.RateLimitConfig{
|
||||
IPRate: 1,
|
||||
IPBurst: 1,
|
||||
},
|
||||
setupTest: func(router *gin.Engine) {
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
},
|
||||
runTest: func(t *testing.T, router *gin.Engine) {
|
||||
// IP1 的请求
|
||||
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||
req1.RemoteAddr = "192.168.1.3:1234"
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req1)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req1)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
|
||||
// IP2 的请求应该不受 IP1 的限制影响
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
req2.RemoteAddr = "192.168.1.4:1234"
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req2)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 添加限流中间件
|
||||
router.Use(RateLimit(tc.config))
|
||||
|
||||
// 设置测试路由
|
||||
tc.setupTest(router)
|
||||
|
||||
// 运行测试
|
||||
tc.runTest(t, router)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterCleanup(t *testing.T) {
|
||||
config := &types.RateLimitConfig{
|
||||
IPRate: 1,
|
||||
IPBurst: 1,
|
||||
}
|
||||
|
||||
rl := newRateLimiter(config)
|
||||
|
||||
// 添加一些IP限流器
|
||||
ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"}
|
||||
for _, ip := range ips {
|
||||
rl.getLimiter(ip)
|
||||
}
|
||||
|
||||
// 验证IP限流器已创建
|
||||
rl.mu.RLock()
|
||||
assert.Equal(t, len(ips), len(rl.ips))
|
||||
rl.mu.RUnlock()
|
||||
|
||||
// 修改一些IP的最后访问时间为1小时前
|
||||
rl.mu.Lock()
|
||||
rl.ips["192.168.1.1"].lastSeen = time.Now().Add(-2 * time.Hour)
|
||||
rl.ips["192.168.1.2"].lastSeen = time.Now().Add(-2 * time.Hour)
|
||||
rl.mu.Unlock()
|
||||
|
||||
// 手动触发清理
|
||||
rl.mu.Lock()
|
||||
for ip, limiter := range rl.ips {
|
||||
if time.Since(limiter.lastSeen) > time.Hour {
|
||||
delete(rl.ips, ip)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
|
||||
// 验证过期的IP限流器已被删除
|
||||
rl.mu.RLock()
|
||||
assert.Equal(t, 1, len(rl.ips))
|
||||
_, exists := rl.ips["192.168.1.3"]
|
||||
assert.True(t, exists)
|
||||
rl.mu.RUnlock()
|
||||
}
|
|
@ -3,146 +3,120 @@ package middleware
|
|||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"tss-rocks-be/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"tss-rocks-be/internal/types"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMaxMemory = 32 << 20 // 32 MB
|
||||
maxHeaderBytes = 512 // 用于MIME类型检测的最大字节数
|
||||
)
|
||||
|
||||
// ValidateUpload 创建文件上传验证中间件
|
||||
func ValidateUpload(cfg *types.UploadConfig) gin.HandlerFunc {
|
||||
// ValidateUpload 验证上传的文件
|
||||
func ValidateUpload(cfg *config.UploadConfig) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 检查是否是multipart/form-data请求
|
||||
if !strings.HasPrefix(c.GetHeader("Content-Type"), "multipart/form-data") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "Content-Type must be multipart/form-data",
|
||||
})
|
||||
// Get file from form
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "No file uploaded"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 解析multipart表单
|
||||
if err := c.Request.ParseMultipartForm(defaultMaxMemory); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("Failed to parse form: %v", err),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
// 获取文件类型和扩展名
|
||||
contentType := file.Header.Get("Content-Type")
|
||||
ext := strings.ToLower(filepath.Ext(file.Filename))
|
||||
|
||||
form := c.Request.MultipartForm
|
||||
if form == nil || form.File == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "No file uploaded",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 遍历所有上传的文件
|
||||
for _, files := range form.File {
|
||||
for _, file := range files {
|
||||
// 检查文件大小
|
||||
if file.Size > int64(cfg.MaxSize)<<20 { // 转换为字节
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("File %s exceeds maximum size of %d MB", file.Filename, cfg.MaxSize),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查文件扩展名
|
||||
ext := strings.ToLower(filepath.Ext(file.Filename))
|
||||
if !contains(cfg.AllowedExtensions, ext) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("File extension %s is not allowed", ext),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 打开文件
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": fmt.Sprintf("Failed to open file: %v", err),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
// 读取文件头部用于MIME类型检测
|
||||
header := make([]byte, maxHeaderBytes)
|
||||
n, err := src.Read(header)
|
||||
if err != nil && err != io.EOF {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": fmt.Sprintf("Failed to read file: %v", err),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
header = header[:n]
|
||||
|
||||
// 检测MIME类型
|
||||
contentType := http.DetectContentType(header)
|
||||
if !contains(cfg.AllowedTypes, contentType) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("File type %s is not allowed", contentType),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将文件指针重置到开始位置
|
||||
_, err = src.Seek(0, 0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": fmt.Sprintf("Failed to read file: %v", err),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将文件内容读入缓冲区
|
||||
buf := &bytes.Buffer{}
|
||||
_, err = io.Copy(buf, src)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": fmt.Sprintf("Failed to read file: %v", err),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将验证过的文件内容和类型保存到上下文中
|
||||
c.Set("validated_file_"+file.Filename, buf)
|
||||
c.Set("validated_content_type_"+file.Filename, contentType)
|
||||
// 如果 Content-Type 为空,尝试从文件扩展名判断
|
||||
if contentType == "" {
|
||||
switch ext {
|
||||
case ".jpg", ".jpeg":
|
||||
contentType = "image/jpeg"
|
||||
case ".png":
|
||||
contentType = "image/png"
|
||||
case ".gif":
|
||||
contentType = "image/gif"
|
||||
case ".webp":
|
||||
contentType = "image/webp"
|
||||
case ".mp4":
|
||||
contentType = "video/mp4"
|
||||
case ".webm":
|
||||
contentType = "video/webm"
|
||||
case ".mp3":
|
||||
contentType = "audio/mpeg"
|
||||
case ".ogg":
|
||||
contentType = "audio/ogg"
|
||||
case ".wav":
|
||||
contentType = "audio/wav"
|
||||
case ".pdf":
|
||||
contentType = "application/pdf"
|
||||
case ".doc":
|
||||
contentType = "application/msword"
|
||||
case ".docx":
|
||||
contentType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
}
|
||||
}
|
||||
|
||||
// 根据 Content-Type 确定文件类型和限制
|
||||
var maxSize int64
|
||||
var allowedTypes []string
|
||||
var fileType string
|
||||
|
||||
limits := cfg.Limits
|
||||
switch {
|
||||
case strings.HasPrefix(contentType, "image/"):
|
||||
maxSize = int64(limits.Image.MaxSize) * 1024 * 1024
|
||||
allowedTypes = limits.Image.AllowedTypes
|
||||
fileType = "image"
|
||||
case strings.HasPrefix(contentType, "video/"):
|
||||
maxSize = int64(limits.Video.MaxSize) * 1024 * 1024
|
||||
allowedTypes = limits.Video.AllowedTypes
|
||||
fileType = "video"
|
||||
case strings.HasPrefix(contentType, "audio/"):
|
||||
maxSize = int64(limits.Audio.MaxSize) * 1024 * 1024
|
||||
allowedTypes = limits.Audio.AllowedTypes
|
||||
fileType = "audio"
|
||||
case strings.HasPrefix(contentType, "application/"):
|
||||
maxSize = int64(limits.Document.MaxSize) * 1024 * 1024
|
||||
allowedTypes = limits.Document.AllowedTypes
|
||||
fileType = "document"
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("Unsupported file type: %s", contentType),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查文件类型是否允许
|
||||
typeAllowed := false
|
||||
for _, allowed := range allowedTypes {
|
||||
if contentType == allowed {
|
||||
typeAllowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !typeAllowed {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("Unsupported %s type: %s", fileType, contentType),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查文件大小
|
||||
if file.Size > maxSize {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("File size exceeds the limit (%d MB) for %s files", limits.Image.MaxSize, fileType),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// contains 检查切片中是否包含指定的字符串
|
||||
func contains(slice []string, str string) bool {
|
||||
for _, s := range slice {
|
||||
if s == str {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetValidatedFile 从上下文中获取验证过的文件内容
|
||||
func GetValidatedFile(c *gin.Context, filename string) (*bytes.Buffer, string, bool) {
|
||||
file, exists := c.Get("validated_file_" + filename)
|
||||
|
|
|
@ -1,262 +0,0 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"tss-rocks-be/internal/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func createMultipartRequest(t *testing.T, filename string, content []byte, contentType string) (*http.Request, error) {
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
part, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = io.Copy(part, bytes.NewReader(content))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/upload", body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func TestValidateUpload(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *types.UploadConfig
|
||||
filename string
|
||||
content []byte
|
||||
setupRequest func(*testing.T) *http.Request
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "Valid image upload",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5, // 5MB
|
||||
AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
|
||||
AllowedTypes: []string{"image/jpeg", "image/png"},
|
||||
},
|
||||
filename: "test.jpg",
|
||||
content: []byte{
|
||||
0xFF, 0xD8, 0xFF, 0xE0, // JPEG magic numbers
|
||||
0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00,
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Invalid file extension",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5,
|
||||
AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
|
||||
AllowedTypes: []string{"image/jpeg", "image/png"},
|
||||
},
|
||||
filename: "test.txt",
|
||||
content: []byte("test content"),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "File extension .txt is not allowed",
|
||||
},
|
||||
{
|
||||
name: "File too large",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 1, // 1MB
|
||||
AllowedExtensions: []string{".jpg"},
|
||||
AllowedTypes: []string{"image/jpeg"},
|
||||
},
|
||||
filename: "large.jpg",
|
||||
content: make([]byte, 2<<20), // 2MB
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "File large.jpg exceeds maximum size of 1 MB",
|
||||
},
|
||||
{
|
||||
name: "Invalid content type",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5,
|
||||
AllowedExtensions: []string{".jpg"},
|
||||
AllowedTypes: []string{"image/jpeg"},
|
||||
},
|
||||
filename: "fake.jpg",
|
||||
content: []byte("not a real image"),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "File type text/plain; charset=utf-8 is not allowed",
|
||||
},
|
||||
{
|
||||
name: "Missing file",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5,
|
||||
AllowedExtensions: []string{".jpg"},
|
||||
AllowedTypes: []string{"image/jpeg"},
|
||||
},
|
||||
setupRequest: func(t *testing.T) *http.Request {
|
||||
req := httptest.NewRequest("POST", "/upload", strings.NewReader(""))
|
||||
req.Header.Set("Content-Type", "multipart/form-data")
|
||||
return req
|
||||
},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Failed to parse form",
|
||||
},
|
||||
{
|
||||
name: "Invalid content type header",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5,
|
||||
AllowedExtensions: []string{".jpg"},
|
||||
AllowedTypes: []string{"image/jpeg"},
|
||||
},
|
||||
setupRequest: func(t *testing.T) *http.Request {
|
||||
return httptest.NewRequest("POST", "/upload", nil)
|
||||
},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Content-Type must be multipart/form-data",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
var req *http.Request
|
||||
var err error
|
||||
|
||||
if tt.setupRequest != nil {
|
||||
req = tt.setupRequest(t)
|
||||
} else {
|
||||
req, err = createMultipartRequest(t, tt.filename, tt.content, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.Request = req
|
||||
|
||||
middleware := ValidateUpload(tt.config)
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||
if tt.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(w.Body).Decode(&response)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, response["error"], tt.expectedError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetValidatedFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupContext func(*gin.Context)
|
||||
filename string
|
||||
expectedFound bool
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "Get existing file",
|
||||
setupContext: func(c *gin.Context) {
|
||||
// 创建测试文件内容
|
||||
content := []byte("test content")
|
||||
buf := bytes.NewBuffer(content)
|
||||
|
||||
// 设置验证过的文件和内容类型
|
||||
c.Set("validated_file_test.txt", buf)
|
||||
c.Set("validated_content_type_test.txt", "text/plain")
|
||||
},
|
||||
filename: "test.txt",
|
||||
expectedFound: true,
|
||||
},
|
||||
{
|
||||
name: "File not found",
|
||||
setupContext: func(c *gin.Context) {
|
||||
// 不设置任何文件
|
||||
},
|
||||
filename: "nonexistent.txt",
|
||||
expectedFound: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
if tt.setupContext != nil {
|
||||
tt.setupContext(c)
|
||||
}
|
||||
|
||||
buffer, contentType, found := GetValidatedFile(c, tt.filename)
|
||||
|
||||
assert.Equal(t, tt.expectedFound, found)
|
||||
if tt.expectedFound {
|
||||
assert.NotNil(t, buffer)
|
||||
assert.NotEmpty(t, contentType)
|
||||
} else {
|
||||
assert.Nil(t, buffer)
|
||||
assert.Empty(t, contentType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
slice []string
|
||||
str string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "String found in slice",
|
||||
slice: []string{"a", "b", "c"},
|
||||
str: "b",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "String not found in slice",
|
||||
slice: []string{"a", "b", "c"},
|
||||
str: "d",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty slice",
|
||||
slice: []string{},
|
||||
str: "a",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
slice: []string{"a", "b", "c"},
|
||||
str: "",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := contains(tt.slice, tt.str)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,99 +0,0 @@
|
|||
package rbac
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"tss-rocks-be/ent/enttest"
|
||||
"tss-rocks-be/ent/role"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestInitializeRBAC(t *testing.T) {
|
||||
// Create an in-memory SQLite client for testing
|
||||
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Test initialization
|
||||
err := InitializeRBAC(ctx, client)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize RBAC: %v", err)
|
||||
}
|
||||
|
||||
// Verify roles were created
|
||||
for roleName := range DefaultRoles {
|
||||
r, err := client.Role.Query().Where(role.Name(roleName)).Only(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Role %s was not created: %v", roleName, err)
|
||||
}
|
||||
|
||||
// Verify permissions for each role
|
||||
perms, err := r.QueryPermissions().All(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to query permissions for role %s: %v", roleName, err)
|
||||
}
|
||||
|
||||
expectedPerms := DefaultRoles[roleName]
|
||||
permCount := 0
|
||||
for _, actions := range expectedPerms {
|
||||
permCount += len(actions)
|
||||
}
|
||||
|
||||
if len(perms) != permCount {
|
||||
t.Errorf("Role %s has %d permissions, expected %d", roleName, len(perms), permCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssignRoleToUser(t *testing.T) {
|
||||
// Create an in-memory SQLite client for testing
|
||||
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initialize RBAC
|
||||
err := InitializeRBAC(ctx, client)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize RBAC: %v", err)
|
||||
}
|
||||
|
||||
// Create a test user
|
||||
user, err := client.User.Create().
|
||||
SetEmail("test@example.com").
|
||||
SetUsername("testuser").
|
||||
SetPasswordHash("$2a$10$hzLdXMZEIzgr8eGXL0YoCOIIrQhqEj6N.S3.wY1Jx5.4vWm1ZyHyy").
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test user: %v", err)
|
||||
}
|
||||
|
||||
// Test assigning role to user
|
||||
err = AssignRoleToUser(ctx, client, user.ID, "editor")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to assign role to user: %v", err)
|
||||
}
|
||||
|
||||
// Verify role assignment
|
||||
assignedRoles, err := user.QueryRoles().All(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query user roles: %v", err)
|
||||
}
|
||||
|
||||
if len(assignedRoles) != 1 {
|
||||
t.Errorf("Expected 1 role, got %d", len(assignedRoles))
|
||||
}
|
||||
|
||||
if assignedRoles[0].Name != "editor" {
|
||||
t.Errorf("Expected role name 'editor', got '%s'", assignedRoles[0].Name)
|
||||
}
|
||||
|
||||
// Test assigning non-existent role
|
||||
err = AssignRoleToUser(ctx, client, user.ID, "nonexistent")
|
||||
if err == nil {
|
||||
t.Error("Expected error when assigning non-existent role, got nil")
|
||||
}
|
||||
}
|
|
@ -1,64 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInitDatabase(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
driver string
|
||||
dsn string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "success with sqlite3",
|
||||
driver: "sqlite3",
|
||||
dsn: "file:ent?mode=memory&cache=shared&_fk=1",
|
||||
},
|
||||
{
|
||||
name: "invalid driver",
|
||||
driver: "invalid_driver",
|
||||
dsn: "file:ent?mode=memory",
|
||||
wantErr: true,
|
||||
errContains: "unsupported driver",
|
||||
},
|
||||
{
|
||||
name: "invalid dsn",
|
||||
driver: "sqlite3",
|
||||
dsn: "file::memory:?not_exist_option=1", // 使用内存数据库但带有无效选项
|
||||
wantErr: true,
|
||||
errContains: "foreign_keys pragma is off",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client, err := InitDatabase(ctx, tt.driver, tt.dsn)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
assert.Nil(t, client)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, client)
|
||||
|
||||
// 测试数据库连接是否正常工作
|
||||
err = client.Schema.Create(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 清理
|
||||
client.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,40 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"tss-rocks-be/internal/config"
|
||||
)
|
||||
|
||||
func TestNewEntClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg *config.Config
|
||||
}{
|
||||
{
|
||||
name: "default sqlite3 config",
|
||||
cfg: &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
Driver: "sqlite3",
|
||||
DSN: "file:ent?mode=memory&cache=shared&_fk=1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
client := NewEntClient(tt.cfg)
|
||||
assert.NotNil(t, client)
|
||||
|
||||
// 验证客户端是否可以正常工作
|
||||
err := client.Schema.Create(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 清理
|
||||
client.Close()
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,220 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/types"
|
||||
"tss-rocks-be/ent/enttest"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
// 创建测试配置
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
},
|
||||
Storage: config.StorageConfig{
|
||||
Type: "local",
|
||||
Local: config.LocalStorage{
|
||||
RootDir: "testdata",
|
||||
},
|
||||
Upload: types.UploadConfig{
|
||||
MaxSize: 10,
|
||||
AllowedTypes: []string{"image/jpeg", "image/png"},
|
||||
AllowedExtensions: []string{".jpg", ".png"},
|
||||
},
|
||||
},
|
||||
RateLimit: types.RateLimitConfig{
|
||||
IPRate: 100,
|
||||
IPBurst: 200,
|
||||
RouteRates: map[string]struct {
|
||||
Rate int `yaml:"rate"`
|
||||
Burst int `yaml:"burst"`
|
||||
}{
|
||||
"/api/v1/upload": {Rate: 10, Burst: 20},
|
||||
},
|
||||
},
|
||||
AccessLog: types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
EnableFile: true,
|
||||
FilePath: "testdata/access.log",
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
Rotation: struct {
|
||||
MaxSize int `yaml:"max_size"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
Compress bool `yaml:"compress"`
|
||||
LocalTime bool `yaml:"local_time"`
|
||||
}{
|
||||
MaxSize: 100,
|
||||
MaxAge: 7,
|
||||
MaxBackups: 3,
|
||||
Compress: true,
|
||||
LocalTime: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 创建测试数据库客户端
|
||||
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
defer client.Close()
|
||||
|
||||
// 测试服务器初始化
|
||||
s, err := New(cfg, client)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, s)
|
||||
assert.NotNil(t, s.router)
|
||||
assert.NotNil(t, s.handler)
|
||||
assert.Equal(t, cfg, s.config)
|
||||
}
|
||||
|
||||
func TestNew_StorageError(t *testing.T) {
|
||||
// 创建一个无效的存储配置
|
||||
cfg := &config.Config{
|
||||
Storage: config.StorageConfig{
|
||||
Type: "invalid_type", // 使用无效的存储类型
|
||||
},
|
||||
}
|
||||
|
||||
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
defer client.Close()
|
||||
|
||||
s, err := New(cfg, client)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, s)
|
||||
assert.Contains(t, err.Error(), "failed to initialize storage")
|
||||
}
|
||||
|
||||
func TestServer_StartAndShutdown(t *testing.T) {
|
||||
// 创建测试配置
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Host: "localhost",
|
||||
Port: 0, // 使用随机端口
|
||||
},
|
||||
Storage: config.StorageConfig{
|
||||
Type: "local",
|
||||
Local: config.LocalStorage{
|
||||
RootDir: "testdata",
|
||||
},
|
||||
},
|
||||
RateLimit: types.RateLimitConfig{
|
||||
IPRate: 100,
|
||||
IPBurst: 200,
|
||||
},
|
||||
AccessLog: types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
},
|
||||
}
|
||||
|
||||
// 创建测试数据库客户端
|
||||
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
defer client.Close()
|
||||
|
||||
// 初始化服务器
|
||||
s, err := New(cfg, client)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 创建一个通道来接收服务器错误
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
// 在 goroutine 中启动服务器
|
||||
go func() {
|
||||
err := s.Start()
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
errChan <- err
|
||||
}
|
||||
close(errChan)
|
||||
}()
|
||||
|
||||
// 给服务器一些时间启动
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 测试关闭服务器
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = s.Shutdown(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 检查服务器是否有错误发生
|
||||
err = <-errChan
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestServer_StartError(t *testing.T) {
|
||||
// 创建一个配置,使用已经被占用的端口来触发错误
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Host: "localhost",
|
||||
Port: 8899, // 使用固定端口以便测试
|
||||
},
|
||||
Storage: config.StorageConfig{
|
||||
Type: "local",
|
||||
Local: config.LocalStorage{
|
||||
RootDir: "testdata",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
defer client.Close()
|
||||
|
||||
// 创建第一个服务器实例
|
||||
s1, err := New(cfg, client)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 创建一个通道来接收服务器错误
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
// 启动第一个服务器
|
||||
go func() {
|
||||
err := s1.Start()
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
errChan <- err
|
||||
}
|
||||
close(errChan)
|
||||
}()
|
||||
|
||||
// 给服务器一些时间启动
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 尝试在同一端口启动第二个服务器,应该会失败
|
||||
s2, err := New(cfg, client)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s2.Start()
|
||||
assert.Error(t, err)
|
||||
|
||||
// 清理
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 关闭第一个服务器
|
||||
err = s1.Shutdown(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 检查第一个服务器是否有错误发生
|
||||
err = <-errChan
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 关闭第二个服务器
|
||||
err = s2.Shutdown(ctx)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestServer_ShutdownWithNilServer(t *testing.T) {
|
||||
s := &Server{}
|
||||
err := s.Shutdown(context.Background())
|
||||
assert.NoError(t, err)
|
||||
}
|
|
@ -1,11 +1,14 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -26,6 +29,8 @@ import (
|
|||
"tss-rocks-be/ent/user"
|
||||
"tss-rocks-be/internal/storage"
|
||||
|
||||
"github.com/chai2010/webp"
|
||||
"github.com/disintegration/imaging"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
@ -419,21 +424,71 @@ func (s *serviceImpl) Upload(ctx context.Context, file *multipart.FileHeader, us
|
|||
// Open the uploaded file
|
||||
src, err := openFile(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to open file: %v", err)
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
// 获取文件类型和扩展名
|
||||
contentType := file.Header.Get("Content-Type")
|
||||
ext := strings.ToLower(filepath.Ext(file.Filename))
|
||||
if contentType == "" {
|
||||
// 如果 Content-Type 为空,尝试从文件扩展名判断
|
||||
switch ext {
|
||||
case ".jpg", ".jpeg":
|
||||
contentType = "image/jpeg"
|
||||
case ".png":
|
||||
contentType = "image/png"
|
||||
case ".gif":
|
||||
contentType = "image/gif"
|
||||
case ".webp":
|
||||
contentType = "image/webp"
|
||||
case ".mp4":
|
||||
contentType = "video/mp4"
|
||||
case ".webm":
|
||||
contentType = "video/webm"
|
||||
case ".mp3":
|
||||
contentType = "audio/mpeg"
|
||||
case ".ogg":
|
||||
contentType = "audio/ogg"
|
||||
case ".wav":
|
||||
contentType = "audio/wav"
|
||||
case ".pdf":
|
||||
contentType = "application/pdf"
|
||||
case ".doc":
|
||||
contentType = "application/msword"
|
||||
case ".docx":
|
||||
contentType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
}
|
||||
}
|
||||
|
||||
// 如果是图片,检查是否需要转换为 WebP
|
||||
var fileToSave multipart.File = src
|
||||
var finalContentType = contentType
|
||||
if strings.HasPrefix(contentType, "image/") && contentType != "image/webp" {
|
||||
// 转换为 WebP
|
||||
webpFile, err := convertToWebP(src)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert image to WebP: %v", err)
|
||||
}
|
||||
fileToSave = webpFile
|
||||
finalContentType = "image/webp"
|
||||
ext = ".webp"
|
||||
}
|
||||
|
||||
// 生成带扩展名的存储文件名
|
||||
storageFilename := uuid.New().String() + ext
|
||||
|
||||
// Save the file to storage
|
||||
fileInfo, err := s.storage.Save(ctx, file.Filename, file.Header.Get("Content-Type"), src)
|
||||
fileInfo, err := s.storage.Save(ctx, storageFilename, finalContentType, fileToSave)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to save file: %v", err)
|
||||
}
|
||||
|
||||
// Create media record
|
||||
return s.client.Media.Create().
|
||||
SetStorageID(fileInfo.ID).
|
||||
SetOriginalName(file.Filename).
|
||||
SetMimeType(fileInfo.ContentType).
|
||||
SetMimeType(finalContentType).
|
||||
SetSize(fileInfo.Size).
|
||||
SetURL(fileInfo.URL).
|
||||
SetCreatedBy(strconv.Itoa(userID)).
|
||||
|
@ -444,13 +499,8 @@ func (s *serviceImpl) GetMedia(ctx context.Context, id int) (*ent.Media, error)
|
|||
return s.client.Media.Get(ctx, id)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error) {
|
||||
media, err := s.GetMedia(ctx, id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return s.storage.Get(ctx, media.StorageID)
|
||||
func (s *serviceImpl) GetFile(ctx context.Context, storageID string) (io.ReadCloser, *storage.FileInfo, error) {
|
||||
return s.storage.Get(ctx, storageID)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) DeleteMedia(ctx context.Context, id int, userID int) error {
|
||||
|
@ -476,7 +526,7 @@ func (s *serviceImpl) DeleteMedia(ctx context.Context, id int, userID int) error
|
|||
}
|
||||
|
||||
// Post operations
|
||||
func (s *serviceImpl) CreatePost(ctx context.Context, status string) (*ent.Post, error) {
|
||||
func (s *serviceImpl) CreatePost(ctx context.Context, status string, categoryIDs []int) (*ent.Post, error) {
|
||||
var postStatus post.Status
|
||||
switch status {
|
||||
case "draft":
|
||||
|
@ -492,10 +542,25 @@ func (s *serviceImpl) CreatePost(ctx context.Context, status string) (*ent.Post,
|
|||
// Generate a random slug
|
||||
slug := fmt.Sprintf("post-%s", uuid.New().String()[:8])
|
||||
|
||||
return s.client.Post.Create().
|
||||
// Create post with categories
|
||||
postCreate := s.client.Post.Create().
|
||||
SetStatus(postStatus).
|
||||
SetSlug(slug).
|
||||
Save(ctx)
|
||||
SetSlug(slug)
|
||||
|
||||
// Add categories if provided
|
||||
if len(categoryIDs) > 0 {
|
||||
categories := make([]*ent.Category, 0, len(categoryIDs))
|
||||
for _, id := range categoryIDs {
|
||||
category, err := s.client.Category.Get(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get category %d: %w", id, err)
|
||||
}
|
||||
categories = append(categories, category)
|
||||
}
|
||||
postCreate.AddCategories(categories...)
|
||||
}
|
||||
|
||||
return postCreate.Save(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) AddPostContent(ctx context.Context, postID int, langCode, title, content, summary string, metaKeywords, metaDescription string) (*ent.PostContent, error) {
|
||||
|
@ -574,7 +639,7 @@ func (s *serviceImpl) GetPostBySlug(ctx context.Context, langCode, slug string)
|
|||
WithContents(func(q *ent.PostContentQuery) {
|
||||
q.Where(postcontent.LanguageCodeEQ(languageCode))
|
||||
}).
|
||||
WithCategory().
|
||||
WithCategories().
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get posts: %w", err)
|
||||
|
@ -590,7 +655,7 @@ func (s *serviceImpl) GetPostBySlug(ctx context.Context, langCode, slug string)
|
|||
return posts[0], nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error) {
|
||||
func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryIDs []int, limit, offset int) ([]*ent.Post, error) {
|
||||
var languageCode postcontent.LanguageCode
|
||||
switch langCode {
|
||||
case "en":
|
||||
|
@ -610,8 +675,8 @@ func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID
|
|||
Where(post.StatusEQ(post.StatusPublished))
|
||||
|
||||
// Add category filter if provided
|
||||
if categoryID != nil {
|
||||
query = query.Where(post.HasCategoryWith(category.ID(*categoryID)))
|
||||
if len(categoryIDs) > 0 {
|
||||
query = query.Where(post.HasCategoriesWith(category.IDIn(categoryIDs...)))
|
||||
}
|
||||
|
||||
// Get unique post IDs
|
||||
|
@ -638,7 +703,7 @@ func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID
|
|||
}
|
||||
|
||||
// If no category filter is applied, only take the latest 5 posts
|
||||
if categoryID == nil && len(postIDs) > 5 {
|
||||
if len(categoryIDs) == 0 && len(postIDs) > 5 {
|
||||
postIDs = postIDs[:5]
|
||||
}
|
||||
|
||||
|
@ -666,7 +731,7 @@ func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID
|
|||
WithContents(func(q *ent.PostContentQuery) {
|
||||
q.Where(postcontent.LanguageCodeEQ(languageCode))
|
||||
}).
|
||||
WithCategory().
|
||||
WithCategories().
|
||||
Order(ent.Desc(post.FieldCreatedAt)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
|
@ -1050,3 +1115,47 @@ func (s *serviceImpl) DeleteDaily(ctx context.Context, id string, currentUserID
|
|||
|
||||
return s.client.Daily.DeleteOneID(id).Exec(ctx)
|
||||
}
|
||||
|
||||
// convertToWebP 将图片转换为 WebP 格式
|
||||
func convertToWebP(src multipart.File) (multipart.File, error) {
|
||||
// 读取原始图片
|
||||
img, err := imaging.Decode(src)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode image: %v", err)
|
||||
}
|
||||
|
||||
// 创建一个新的缓冲区来存储 WebP 图片
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
// 将图片编码为 WebP 格式
|
||||
// 设置较高的质量以保持图片质量
|
||||
err = webp.Encode(buf, img, &webp.Options{
|
||||
Lossless: false,
|
||||
Quality: 90,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode image to WebP: %v", err)
|
||||
}
|
||||
|
||||
// 创建一个新的临时文件来存储转换后的图片
|
||||
tmpFile, err := os.CreateTemp("", "webp-*.webp")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create temp file: %v", err)
|
||||
}
|
||||
|
||||
// 写入转换后的数据
|
||||
if _, err := io.Copy(tmpFile, buf); err != nil {
|
||||
tmpFile.Close()
|
||||
os.Remove(tmpFile.Name())
|
||||
return nil, fmt.Errorf("failed to write WebP data: %v", err)
|
||||
}
|
||||
|
||||
// 将文件指针移回开始位置
|
||||
if _, err := tmpFile.Seek(0, 0); err != nil {
|
||||
tmpFile.Close()
|
||||
os.Remove(tmpFile.Name())
|
||||
return nil, fmt.Errorf("failed to seek file: %v", err)
|
||||
}
|
||||
|
||||
return tmpFile, nil
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,332 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/textproto"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/storage"
|
||||
"tss-rocks-be/internal/storage/mock"
|
||||
"tss-rocks-be/internal/testutil"
|
||||
|
||||
"bou.ke/monkey"
|
||||
)
|
||||
|
||||
type MediaServiceTestSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *ent.Client
|
||||
storage *mock.MockStorage
|
||||
ctrl *gomock.Controller
|
||||
svc MediaService
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.client = testutil.NewTestClient()
|
||||
require.NotNil(s.T(), s.client)
|
||||
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.storage = mock.NewMockStorage(s.ctrl)
|
||||
s.svc = NewMediaService(s.client, s.storage)
|
||||
|
||||
// 清理数据库
|
||||
_, err := s.client.Media.Delete().Exec(s.ctx)
|
||||
require.NoError(s.T(), err)
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
s.client.Close()
|
||||
}
|
||||
|
||||
func TestMediaServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(MediaServiceTestSuite))
|
||||
}
|
||||
|
||||
type mockFileHeader struct {
|
||||
filename string
|
||||
contentType string
|
||||
size int64
|
||||
content []byte
|
||||
}
|
||||
|
||||
func (h *mockFileHeader) Open() (multipart.File, error) {
|
||||
return newMockMultipartFile(h.content), nil
|
||||
}
|
||||
|
||||
func (h *mockFileHeader) Filename() string {
|
||||
return h.filename
|
||||
}
|
||||
|
||||
func (h *mockFileHeader) Size() int64 {
|
||||
return h.size
|
||||
}
|
||||
|
||||
func (h *mockFileHeader) Header() textproto.MIMEHeader {
|
||||
header := make(textproto.MIMEHeader)
|
||||
header.Set("Content-Type", h.contentType)
|
||||
return header
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) createTestFile(filename, contentType string, content []byte) *multipart.FileHeader {
|
||||
header := &multipart.FileHeader{
|
||||
Filename: filename,
|
||||
Header: make(textproto.MIMEHeader),
|
||||
Size: int64(len(content)),
|
||||
}
|
||||
header.Header.Set("Content-Type", contentType)
|
||||
|
||||
monkey.PatchInstanceMethod(reflect.TypeOf(header), "Open", func(_ *multipart.FileHeader) (multipart.File, error) {
|
||||
return newMockMultipartFile(content), nil
|
||||
})
|
||||
|
||||
return header
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestUpload() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
filename string
|
||||
contentType string
|
||||
content []byte
|
||||
setupMock func()
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "Upload text file",
|
||||
filename: "test.txt",
|
||||
contentType: "text/plain",
|
||||
content: []byte("test content"),
|
||||
setupMock: func() {
|
||||
s.storage.EXPECT().
|
||||
Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, name, contentType string, reader io.Reader) (*storage.FileInfo, error) {
|
||||
content, err := io.ReadAll(reader)
|
||||
s.Require().NoError(err)
|
||||
s.Equal([]byte("test content"), content)
|
||||
return &storage.FileInfo{
|
||||
ID: "test-id",
|
||||
Name: "test.txt",
|
||||
ContentType: "text/plain",
|
||||
Size: int64(len(content)),
|
||||
}, nil
|
||||
})
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid filename",
|
||||
filename: "../test.txt",
|
||||
contentType: "text/plain",
|
||||
content: []byte("test content"),
|
||||
setupMock: func() {},
|
||||
wantErr: true,
|
||||
errMsg: "invalid filename",
|
||||
},
|
||||
{
|
||||
name: "Storage error",
|
||||
filename: "test.txt",
|
||||
contentType: "text/plain",
|
||||
content: []byte("test content"),
|
||||
setupMock: func() {
|
||||
s.storage.EXPECT().
|
||||
Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()).
|
||||
Return(nil, fmt.Errorf("storage error"))
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "storage error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// Setup mock
|
||||
tc.setupMock()
|
||||
|
||||
// Create test file
|
||||
fileHeader := s.createTestFile(tc.filename, tc.contentType, tc.content)
|
||||
|
||||
// Add debug output
|
||||
s.T().Logf("Testing file: %s, content-type: %s, size: %d", fileHeader.Filename, fileHeader.Header.Get("Content-Type"), fileHeader.Size)
|
||||
|
||||
// Test upload
|
||||
media, err := s.svc.Upload(s.ctx, fileHeader, 1)
|
||||
|
||||
// Add debug output
|
||||
if err != nil {
|
||||
s.T().Logf("Upload error: %v", err)
|
||||
}
|
||||
|
||||
if tc.wantErr {
|
||||
s.Require().Error(err)
|
||||
s.Contains(err.Error(), tc.errMsg)
|
||||
return
|
||||
}
|
||||
|
||||
s.Require().NoError(err)
|
||||
s.NotNil(media)
|
||||
s.Equal(tc.filename, media.OriginalName)
|
||||
s.Equal(tc.contentType, media.MimeType)
|
||||
s.Equal(int64(len(tc.content)), media.Size)
|
||||
s.Equal("1", media.CreatedBy)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestGet() {
|
||||
// Create test media
|
||||
media, err := s.client.Media.Create().
|
||||
SetStorageID("test-id").
|
||||
SetOriginalName("test.txt").
|
||||
SetMimeType("text/plain").
|
||||
SetSize(12).
|
||||
SetURL("/api/media/test-id").
|
||||
SetCreatedBy("1").
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Test get existing media
|
||||
result, err := s.svc.Get(s.ctx, media.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(media.ID, result.ID)
|
||||
s.Equal(media.OriginalName, result.OriginalName)
|
||||
|
||||
// Test get non-existing media
|
||||
_, err = s.svc.Get(s.ctx, -1)
|
||||
s.Require().Error(err)
|
||||
s.Contains(err.Error(), "media not found")
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestDelete() {
|
||||
// Create test media
|
||||
media, err := s.client.Media.Create().
|
||||
SetStorageID("test-id").
|
||||
SetOriginalName("test.txt").
|
||||
SetMimeType("text/plain").
|
||||
SetSize(12).
|
||||
SetURL("/api/media/test-id").
|
||||
SetCreatedBy("1").
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Test delete by unauthorized user
|
||||
err = s.svc.Delete(s.ctx, media.ID, 2)
|
||||
s.Require().Error(err)
|
||||
s.Contains(err.Error(), "unauthorized")
|
||||
|
||||
// Test delete by owner
|
||||
s.storage.EXPECT().
|
||||
Delete(gomock.Any(), "test-id").
|
||||
Return(nil)
|
||||
err = s.svc.Delete(s.ctx, media.ID, 1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Verify media is deleted
|
||||
_, err = s.svc.Get(s.ctx, media.ID)
|
||||
s.Require().Error(err)
|
||||
s.Contains(err.Error(), "not found")
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestList() {
|
||||
// Create test media
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err := s.client.Media.Create().
|
||||
SetStorageID(fmt.Sprintf("test-id-%d", i)).
|
||||
SetOriginalName(fmt.Sprintf("test-%d.txt", i)).
|
||||
SetMimeType("text/plain").
|
||||
SetSize(12).
|
||||
SetURL(fmt.Sprintf("/api/media/test-id-%d", i)).
|
||||
SetCreatedBy("1").
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
// Test list with limit and offset
|
||||
media, err := s.svc.List(s.ctx, 3, 1)
|
||||
s.Require().NoError(err)
|
||||
s.Len(media, 3)
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestGetFile() {
|
||||
// Create test media
|
||||
media, err := s.client.Media.Create().
|
||||
SetStorageID("test-id").
|
||||
SetOriginalName("test.txt").
|
||||
SetMimeType("text/plain").
|
||||
SetSize(12).
|
||||
SetURL("/api/media/test-id").
|
||||
SetCreatedBy("1").
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Mock storage.Get
|
||||
mockReader := io.NopCloser(bytes.NewReader([]byte("test content")))
|
||||
mockFileInfo := &storage.FileInfo{
|
||||
ID: "test-id",
|
||||
Name: "test.txt",
|
||||
ContentType: "text/plain",
|
||||
Size: 12,
|
||||
}
|
||||
s.storage.EXPECT().
|
||||
Get(gomock.Any(), "test-id").
|
||||
Return(mockReader, mockFileInfo, nil)
|
||||
|
||||
// Test get file
|
||||
reader, info, err := s.svc.GetFile(s.ctx, media.ID)
|
||||
s.Require().NoError(err)
|
||||
s.NotNil(reader)
|
||||
s.Equal(mockFileInfo, info)
|
||||
|
||||
// Test get non-existing file
|
||||
_, _, err = s.svc.GetFile(s.ctx, -1)
|
||||
s.Require().Error(err)
|
||||
s.Contains(err.Error(), "not found")
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestIsValidFilename() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
filename string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "Valid filename",
|
||||
filename: "test.txt",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid filename with ../",
|
||||
filename: "../test.txt",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid filename with ./",
|
||||
filename: "./test.txt",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid filename with backslash",
|
||||
filename: "test\\file.txt",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
got := isValidFilename(tc.filename)
|
||||
s.Equal(tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,3 +0,0 @@
|
|||
package mock
|
||||
|
||||
//go:generate mockgen -source=../service.go -destination=mock_service.go -package=mock
|
|
@ -1,7 +1,5 @@
|
|||
package service
|
||||
|
||||
//go:generate mockgen -source=service.go -destination=mock/mock_service.go -package=mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
@ -34,16 +32,16 @@ type Service interface {
|
|||
GetCategories(ctx context.Context, langCode string) ([]*ent.Category, error)
|
||||
|
||||
// Post operations
|
||||
CreatePost(ctx context.Context, status string) (*ent.Post, error)
|
||||
CreatePost(ctx context.Context, status string, categoryIDs []int) (*ent.Post, error)
|
||||
AddPostContent(ctx context.Context, postID int, langCode, title, content, summary string, metaKeywords, metaDescription string) (*ent.PostContent, error)
|
||||
GetPostBySlug(ctx context.Context, langCode, slug string) (*ent.Post, error)
|
||||
ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error)
|
||||
ListPosts(ctx context.Context, langCode string, categoryIDs []int, limit, offset int) ([]*ent.Post, error)
|
||||
|
||||
// Media operations
|
||||
ListMedia(ctx context.Context, limit, offset int) ([]*ent.Media, error)
|
||||
Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error)
|
||||
GetMedia(ctx context.Context, id int) (*ent.Media, error)
|
||||
GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error)
|
||||
GetFile(ctx context.Context, storageID string) (io.ReadCloser, *storage.FileInfo, error)
|
||||
DeleteMedia(ctx context.Context, id int, userID int) error
|
||||
|
||||
// Contributor operations
|
||||
|
|
|
@ -11,6 +11,8 @@ import (
|
|||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type LocalStorage struct {
|
||||
|
@ -44,6 +46,24 @@ func (s *LocalStorage) generateID() (string, error) {
|
|||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
func (s *LocalStorage) generateFilePath(id string, ext string, createTime time.Time) string {
|
||||
// Create year/month directory structure
|
||||
year := createTime.Format("2006")
|
||||
month := createTime.Format("01")
|
||||
|
||||
// If id already has an extension, don't add ext
|
||||
if filepath.Ext(id) != "" {
|
||||
return filepath.Join(year, month, id)
|
||||
}
|
||||
|
||||
// Otherwise, add the extension if provided
|
||||
filename := id
|
||||
if ext != "" {
|
||||
filename = id + ext
|
||||
}
|
||||
return filepath.Join(year, month, filename)
|
||||
}
|
||||
|
||||
func (s *LocalStorage) saveMetadata(id string, info *FileInfo) error {
|
||||
metaPath := filepath.Join(s.metaDir, id+".meta")
|
||||
file, err := os.Create(metaPath)
|
||||
|
@ -89,11 +109,64 @@ func (s *LocalStorage) Save(ctx context.Context, name string, contentType string
|
|||
return nil, fmt.Errorf("failed to generate file ID: %w", err)
|
||||
}
|
||||
|
||||
// Create the file path
|
||||
filePath := filepath.Join(s.rootDir, id)
|
||||
// Get file extension from original name or content type
|
||||
ext := filepath.Ext(name)
|
||||
if ext == "" {
|
||||
// If no extension in name, try to get it from content type
|
||||
switch contentType {
|
||||
case "image/jpeg":
|
||||
ext = ".jpg"
|
||||
case "image/png":
|
||||
ext = ".png"
|
||||
case "image/gif":
|
||||
ext = ".gif"
|
||||
case "image/webp":
|
||||
ext = ".webp"
|
||||
case "image/svg+xml":
|
||||
ext = ".svg"
|
||||
case "video/mp4":
|
||||
ext = ".mp4"
|
||||
case "video/webm":
|
||||
ext = ".webm"
|
||||
case "audio/mpeg":
|
||||
ext = ".mp3"
|
||||
case "audio/ogg":
|
||||
ext = ".ogg"
|
||||
case "audio/wav":
|
||||
ext = ".wav"
|
||||
case "application/pdf":
|
||||
ext = ".pdf"
|
||||
case "application/msword":
|
||||
ext = ".doc"
|
||||
case "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
|
||||
ext = ".docx"
|
||||
case "application/vnd.ms-excel":
|
||||
ext = ".xls"
|
||||
case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
|
||||
ext = ".xlsx"
|
||||
case "application/zip":
|
||||
ext = ".zip"
|
||||
case "application/x-rar-compressed":
|
||||
ext = ".rar"
|
||||
case "text/plain":
|
||||
ext = ".txt"
|
||||
case "text/csv":
|
||||
ext = ".csv"
|
||||
}
|
||||
}
|
||||
|
||||
// Create the file path with year/month structure
|
||||
now := time.Now()
|
||||
relPath := s.generateFilePath(id, ext, now)
|
||||
fullPath := filepath.Join(s.rootDir, relPath)
|
||||
|
||||
// Create directory if it doesn't exist
|
||||
if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
|
||||
// Create the file
|
||||
file, err := os.Create(filePath)
|
||||
file, err := os.Create(fullPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create file: %w", err)
|
||||
}
|
||||
|
@ -102,52 +175,73 @@ func (s *LocalStorage) Save(ctx context.Context, name string, contentType string
|
|||
// Copy the content
|
||||
size, err := io.Copy(file, reader)
|
||||
if err != nil {
|
||||
// Clean up the file if there's an error
|
||||
os.Remove(filePath)
|
||||
os.Remove(fullPath) // Clean up on error
|
||||
return nil, fmt.Errorf("failed to write file content: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
// Save metadata
|
||||
info := &FileInfo{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Size: size,
|
||||
ContentType: contentType,
|
||||
Size: size,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
URL: fmt.Sprintf("/api/media/file/%s", id),
|
||||
URL: fmt.Sprintf("/media/%s/%s/%s", now.Format("2006"), now.Format("01"), filepath.Base(relPath)),
|
||||
}
|
||||
|
||||
// Save metadata
|
||||
if err := s.saveMetadata(id, info); err != nil {
|
||||
os.Remove(filePath)
|
||||
return nil, err
|
||||
os.Remove(fullPath) // Clean up on error
|
||||
return nil, fmt.Errorf("failed to save metadata: %w", err)
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func (s *LocalStorage) Get(ctx context.Context, id string) (io.ReadCloser, *FileInfo, error) {
|
||||
filePath := filepath.Join(s.rootDir, id)
|
||||
// 从 id 中提取文件扩展名和基础 ID
|
||||
ext := filepath.Ext(id)
|
||||
baseID := strings.TrimSuffix(id, ext)
|
||||
|
||||
// Open the file
|
||||
// 获取文件的创建时间(从元数据或当前时间)
|
||||
metaPath := filepath.Join(s.metaDir, baseID+".meta")
|
||||
var createTime time.Time
|
||||
if stat, err := os.Stat(metaPath); err == nil {
|
||||
createTime = stat.ModTime()
|
||||
} else {
|
||||
createTime = time.Now() // 如果找不到元数据,使用当前时间
|
||||
}
|
||||
|
||||
// 生成完整的文件路径
|
||||
year := createTime.Format("2006")
|
||||
month := createTime.Format("01")
|
||||
filePath := filepath.Join(s.rootDir, year, month, id) // 直接使用完整的 id(包含扩展名)
|
||||
|
||||
// 调试日志
|
||||
log.Debug().
|
||||
Str("id", id).
|
||||
Str("baseID", baseID).
|
||||
Str("ext", ext).
|
||||
Str("filePath", filePath).
|
||||
Time("createTime", createTime).
|
||||
Msg("Attempting to get file")
|
||||
|
||||
// 打开文件
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil, fmt.Errorf("file not found: %s", id)
|
||||
return nil, nil, fmt.Errorf("file not found: %s (path: %s)", id, filePath)
|
||||
}
|
||||
return nil, nil, fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
|
||||
// Get file info
|
||||
// 获取文件信息
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return nil, nil, fmt.Errorf("failed to get file info: %w", err)
|
||||
}
|
||||
|
||||
// Load metadata
|
||||
name, contentType, err := s.loadMetadata(id)
|
||||
// 加载元数据
|
||||
name, contentType, err := s.loadMetadata(baseID)
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return nil, nil, err
|
||||
|
@ -158,27 +252,44 @@ func (s *LocalStorage) Get(ctx context.Context, id string) (io.ReadCloser, *File
|
|||
Name: name,
|
||||
Size: stat.Size(),
|
||||
ContentType: contentType,
|
||||
CreatedAt: stat.ModTime(),
|
||||
CreatedAt: createTime,
|
||||
UpdatedAt: stat.ModTime(),
|
||||
URL: fmt.Sprintf("/api/media/file/%s", id),
|
||||
URL: fmt.Sprintf("/media/%s/%s/%s", year, month, id),
|
||||
}
|
||||
|
||||
return file, info, nil
|
||||
}
|
||||
|
||||
func (s *LocalStorage) Delete(ctx context.Context, id string) error {
|
||||
filePath := filepath.Join(s.rootDir, id)
|
||||
if err := os.Remove(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("file not found: %s", id)
|
||||
}
|
||||
return fmt.Errorf("failed to delete file: %w", err)
|
||||
// 从 id 中提取文件扩展名
|
||||
ext := filepath.Ext(id)
|
||||
baseID := strings.TrimSuffix(id, ext)
|
||||
|
||||
// 获取文件的创建时间(从元数据或当前时间)
|
||||
metaPath := filepath.Join(s.metaDir, baseID+".meta")
|
||||
var createTime time.Time
|
||||
if stat, err := os.Stat(metaPath); err == nil {
|
||||
createTime = stat.ModTime()
|
||||
} else {
|
||||
createTime = time.Now() // 如果找不到元数据,使用当前时间
|
||||
}
|
||||
|
||||
// Remove metadata
|
||||
metaPath := filepath.Join(s.metaDir, id+".meta")
|
||||
if err := os.Remove(metaPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove metadata: %w", err)
|
||||
// 生成完整的文件路径
|
||||
relPath := s.generateFilePath(baseID, ext, createTime)
|
||||
filePath := filepath.Join(s.rootDir, relPath)
|
||||
|
||||
// 删除文件
|
||||
if err := os.Remove(filePath); err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to delete file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 删除元数据
|
||||
if err := os.Remove(metaPath); err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to delete metadata: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -1,154 +0,0 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLocalStorage(t *testing.T) {
|
||||
// Create a temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "storage_test_*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create a new LocalStorage instance
|
||||
storage, err := NewLocalStorage(tempDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Save and Get", func(t *testing.T) {
|
||||
content := []byte("test content")
|
||||
reader := bytes.NewReader(content)
|
||||
|
||||
// Save the file
|
||||
fileInfo, err := storage.Save(ctx, "test.txt", "text/plain", reader)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, fileInfo.ID)
|
||||
assert.Equal(t, "test.txt", fileInfo.Name)
|
||||
assert.Equal(t, int64(len(content)), fileInfo.Size)
|
||||
assert.Equal(t, "text/plain", fileInfo.ContentType)
|
||||
assert.False(t, fileInfo.CreatedAt.IsZero())
|
||||
|
||||
// Get the file
|
||||
readCloser, info, err := storage.Get(ctx, fileInfo.ID)
|
||||
require.NoError(t, err)
|
||||
defer readCloser.Close()
|
||||
|
||||
data, err := io.ReadAll(readCloser)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, data)
|
||||
assert.Equal(t, fileInfo.ID, info.ID)
|
||||
assert.Equal(t, fileInfo.Name, info.Name)
|
||||
assert.Equal(t, fileInfo.Size, info.Size)
|
||||
})
|
||||
|
||||
t.Run("List", func(t *testing.T) {
|
||||
// Clear the directory first
|
||||
dirEntries, err := os.ReadDir(tempDir)
|
||||
require.NoError(t, err)
|
||||
for _, entry := range dirEntries {
|
||||
if entry.Name() != ".meta" {
|
||||
os.Remove(filepath.Join(tempDir, entry.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
// Save multiple files
|
||||
testFiles := []struct {
|
||||
name string
|
||||
content string
|
||||
}{
|
||||
{"test1.txt", "content1"},
|
||||
{"test2.txt", "content2"},
|
||||
{"other.txt", "content3"},
|
||||
}
|
||||
|
||||
for _, f := range testFiles {
|
||||
reader := bytes.NewReader([]byte(f.content))
|
||||
_, err := storage.Save(ctx, f.name, "text/plain", reader)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// List all files
|
||||
allFiles, err := storage.List(ctx, "", 10, 0)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, allFiles, 3)
|
||||
|
||||
// List files with prefix
|
||||
filesWithPrefix, err := storage.List(ctx, "test", 10, 0)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, filesWithPrefix, 2)
|
||||
for _, f := range filesWithPrefix {
|
||||
assert.True(t, strings.HasPrefix(f.Name, "test"))
|
||||
}
|
||||
|
||||
// Test pagination
|
||||
pagedFiles, err := storage.List(ctx, "", 2, 1)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pagedFiles, 2)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
// Save a file
|
||||
content := []byte("test content")
|
||||
reader := bytes.NewReader(content)
|
||||
fileInfo, err := storage.Save(ctx, "exists.txt", "text/plain", reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check if file exists
|
||||
exists, err := storage.Exists(ctx, fileInfo.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Check non-existent file
|
||||
exists, err = storage.Exists(ctx, "non-existent")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
// Save a file
|
||||
content := []byte("test content")
|
||||
reader := bytes.NewReader(content)
|
||||
fileInfo, err := storage.Save(ctx, "delete.txt", "text/plain", reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete the file
|
||||
err = storage.Delete(ctx, fileInfo.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify file is deleted
|
||||
exists, err := storage.Exists(ctx, fileInfo.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
// Try to delete non-existent file
|
||||
err = storage.Delete(ctx, "non-existent")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Invalid operations", func(t *testing.T) {
|
||||
// Try to get non-existent file
|
||||
_, _, err := storage.Get(ctx, "non-existent")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "file not found")
|
||||
|
||||
// Try to save file with nil reader
|
||||
_, err = storage.Save(ctx, "test.txt", "text/plain", nil)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "reader cannot be nil")
|
||||
|
||||
// Try to delete non-existent file
|
||||
err = storage.Delete(ctx, "non-existent")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "file not found")
|
||||
})
|
||||
}
|
|
@ -7,6 +7,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -48,14 +49,25 @@ func (s *S3Storage) generateID() (string, error) {
|
|||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
func (s *S3Storage) getObjectURL(id string) string {
|
||||
func (s *S3Storage) generateObjectKey(id string, ext string, createTime time.Time) string {
|
||||
// Create year/month structure
|
||||
year := createTime.Format("2006")
|
||||
month := createTime.Format("01")
|
||||
filename := id
|
||||
if ext != "" {
|
||||
filename = id + ext
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/%s", year, month, filename)
|
||||
}
|
||||
|
||||
func (s *S3Storage) getObjectURL(key string) string {
|
||||
if s.customURL != "" {
|
||||
return fmt.Sprintf("%s/%s", strings.TrimRight(s.customURL, "/"), id)
|
||||
return fmt.Sprintf("%s/%s", strings.TrimRight(s.customURL, "/"), key)
|
||||
}
|
||||
if s.proxyS3 {
|
||||
return fmt.Sprintf("/api/media/file/%s", id)
|
||||
return fmt.Sprintf("/media/%s", key)
|
||||
}
|
||||
return fmt.Sprintf("https://%s.s3.amazonaws.com/%s", s.bucket, id)
|
||||
return fmt.Sprintf("https://%s.s3.amazonaws.com/%s", s.bucket, key)
|
||||
}
|
||||
|
||||
func (s *S3Storage) Save(ctx context.Context, name string, contentType string, reader io.Reader) (*FileInfo, error) {
|
||||
|
@ -65,10 +77,60 @@ func (s *S3Storage) Save(ctx context.Context, name string, contentType string, r
|
|||
return nil, fmt.Errorf("failed to generate file ID: %w", err)
|
||||
}
|
||||
|
||||
// Get file extension from original name or content type
|
||||
ext := filepath.Ext(name)
|
||||
if ext == "" {
|
||||
// If no extension in name, try to get it from content type
|
||||
switch contentType {
|
||||
case "image/jpeg":
|
||||
ext = ".jpg"
|
||||
case "image/png":
|
||||
ext = ".png"
|
||||
case "image/gif":
|
||||
ext = ".gif"
|
||||
case "image/webp":
|
||||
ext = ".webp"
|
||||
case "image/svg+xml":
|
||||
ext = ".svg"
|
||||
case "video/mp4":
|
||||
ext = ".mp4"
|
||||
case "video/webm":
|
||||
ext = ".webm"
|
||||
case "audio/mpeg":
|
||||
ext = ".mp3"
|
||||
case "audio/ogg":
|
||||
ext = ".ogg"
|
||||
case "audio/wav":
|
||||
ext = ".wav"
|
||||
case "application/pdf":
|
||||
ext = ".pdf"
|
||||
case "application/msword":
|
||||
ext = ".doc"
|
||||
case "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
|
||||
ext = ".docx"
|
||||
case "application/vnd.ms-excel":
|
||||
ext = ".xls"
|
||||
case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
|
||||
ext = ".xlsx"
|
||||
case "application/zip":
|
||||
ext = ".zip"
|
||||
case "application/x-rar-compressed":
|
||||
ext = ".rar"
|
||||
case "text/plain":
|
||||
ext = ".txt"
|
||||
case "text/csv":
|
||||
ext = ".csv"
|
||||
}
|
||||
}
|
||||
|
||||
// Create the object key with year/month structure
|
||||
now := time.Now()
|
||||
key := s.generateObjectKey(id, ext, now)
|
||||
|
||||
// Check if the file exists
|
||||
_, err = s.client.HeadObject(ctx, &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(id),
|
||||
Key: aws.String(key),
|
||||
})
|
||||
if err == nil {
|
||||
return nil, fmt.Errorf("file already exists with ID: %s", id)
|
||||
|
@ -82,37 +144,50 @@ func (s *S3Storage) Save(ctx context.Context, name string, contentType string, r
|
|||
// Upload the file
|
||||
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(id),
|
||||
Key: aws.String(key),
|
||||
Body: reader,
|
||||
ContentType: aws.String(contentType),
|
||||
Metadata: map[string]string{
|
||||
"x-amz-meta-original-name": name,
|
||||
"x-amz-meta-created-at": now.Format(time.RFC3339),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to upload file: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
info := &FileInfo{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Size: 0, // Size is not available until after upload
|
||||
ContentType: contentType,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
URL: s.getObjectURL(id),
|
||||
URL: s.getObjectURL(key),
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func (s *S3Storage) Get(ctx context.Context, id string) (io.ReadCloser, *FileInfo, error) {
|
||||
// Get the object from S3
|
||||
result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(id),
|
||||
})
|
||||
// Try to find the file with different extensions
|
||||
exts := []string{"", ".jpg", ".png", ".gif", ".webp", ".svg", ".mp4", ".webm", ".mp3", ".ogg", ".wav",
|
||||
".pdf", ".doc", ".docx", ".xls", ".xlsx", ".zip", ".rar", ".txt", ".csv"}
|
||||
|
||||
var result *s3.GetObjectOutput
|
||||
var err error
|
||||
var key string
|
||||
|
||||
for _, ext := range exts {
|
||||
key = s.generateObjectKey(id, ext, time.Now())
|
||||
result, err = s.client.GetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(key),
|
||||
})
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get file from S3: %w", err)
|
||||
}
|
||||
|
@ -122,111 +197,117 @@ func (s *S3Storage) Get(ctx context.Context, id string) (io.ReadCloser, *FileInf
|
|||
Name: result.Metadata["x-amz-meta-original-name"],
|
||||
Size: aws.ToInt64(result.ContentLength),
|
||||
ContentType: aws.ToString(result.ContentType),
|
||||
CreatedAt: aws.ToTime(result.LastModified),
|
||||
UpdatedAt: aws.ToTime(result.LastModified),
|
||||
URL: s.getObjectURL(id),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
URL: s.getObjectURL(key),
|
||||
}
|
||||
|
||||
return result.Body, info, nil
|
||||
}
|
||||
|
||||
func (s *S3Storage) Delete(ctx context.Context, id string) error {
|
||||
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(id),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete file from S3: %w", err)
|
||||
// Try to find and delete the file with different extensions
|
||||
exts := []string{"", ".jpg", ".png", ".gif", ".webp", ".svg", ".mp4", ".webm", ".mp3", ".ogg", ".wav",
|
||||
".pdf", ".doc", ".docx", ".xls", ".xlsx", ".zip", ".rar", ".txt", ".csv"}
|
||||
|
||||
var lastErr error
|
||||
for _, ext := range exts {
|
||||
key := s.generateObjectKey(id, ext, time.Now())
|
||||
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(key),
|
||||
})
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
return nil
|
||||
return fmt.Errorf("failed to delete file from S3: %w", lastErr)
|
||||
}
|
||||
|
||||
func (s *S3Storage) Exists(ctx context.Context, id string) (bool, error) {
|
||||
// Try to find the file with different extensions
|
||||
exts := []string{"", ".jpg", ".png", ".gif", ".webp", ".svg", ".mp4", ".webm", ".mp3", ".ogg", ".wav",
|
||||
".pdf", ".doc", ".docx", ".xls", ".xlsx", ".zip", ".rar", ".txt", ".csv"}
|
||||
|
||||
for _, ext := range exts {
|
||||
key := s.generateObjectKey(id, ext, time.Now())
|
||||
_, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(key),
|
||||
})
|
||||
if err == nil {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *S3Storage) List(ctx context.Context, prefix string, limit int, offset int) ([]*FileInfo, error) {
|
||||
var files []*FileInfo
|
||||
var continuationToken *string
|
||||
count := 0
|
||||
skip := offset
|
||||
|
||||
// Skip objects for offset
|
||||
for i := 0; i < offset/1000; i++ {
|
||||
output, err := s.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
|
||||
for {
|
||||
input := &s3.ListObjectsV2Input{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Prefix: aws.String(prefix),
|
||||
ContinuationToken: continuationToken,
|
||||
MaxKeys: aws.Int32(1000),
|
||||
})
|
||||
MaxKeys: aws.Int32(100), // Fetch in batches of 100
|
||||
}
|
||||
|
||||
result, err := s.client.ListObjectsV2(ctx, input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list files from S3: %w", err)
|
||||
}
|
||||
if !aws.ToBool(output.IsTruncated) {
|
||||
return files, nil
|
||||
}
|
||||
continuationToken = output.NextContinuationToken
|
||||
}
|
||||
|
||||
// Get the actual objects
|
||||
output, err := s.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Prefix: aws.String(prefix),
|
||||
ContinuationToken: continuationToken,
|
||||
MaxKeys: aws.Int32(int32(limit)),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list files from S3: %w", err)
|
||||
}
|
||||
|
||||
for _, obj := range output.Contents {
|
||||
// Get the object metadata
|
||||
head, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: obj.Key,
|
||||
})
|
||||
|
||||
var contentType string
|
||||
var originalName string
|
||||
|
||||
if err != nil {
|
||||
var noSuchKey *types.NoSuchKey
|
||||
if errors.As(err, &noSuchKey) {
|
||||
// If the object doesn't exist (which shouldn't happen normally),
|
||||
// we'll still include it in the list but with empty metadata
|
||||
contentType = ""
|
||||
originalName = aws.ToString(obj.Key)
|
||||
} else {
|
||||
// Process each object
|
||||
for _, obj := range result.Contents {
|
||||
if skip > 0 {
|
||||
skip--
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
contentType = aws.ToString(head.ContentType)
|
||||
originalName = head.Metadata["x-amz-meta-original-name"]
|
||||
if originalName == "" {
|
||||
originalName = aws.ToString(obj.Key)
|
||||
|
||||
// Get object metadata
|
||||
head, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: obj.Key,
|
||||
})
|
||||
if err != nil {
|
||||
continue // Skip files we can't get metadata for
|
||||
}
|
||||
|
||||
// Extract the ID from the key (remove extension if present)
|
||||
id := aws.ToString(obj.Key)
|
||||
if ext := filepath.Ext(id); ext != "" {
|
||||
id = id[:len(id)-len(ext)]
|
||||
}
|
||||
|
||||
info := &FileInfo{
|
||||
ID: id,
|
||||
Name: head.Metadata["x-amz-meta-original-name"],
|
||||
Size: aws.ToInt64(obj.Size),
|
||||
ContentType: aws.ToString(head.ContentType),
|
||||
CreatedAt: aws.ToTime(obj.LastModified),
|
||||
UpdatedAt: aws.ToTime(obj.LastModified),
|
||||
URL: s.getObjectURL(aws.ToString(obj.Key)),
|
||||
}
|
||||
files = append(files, info)
|
||||
count++
|
||||
|
||||
if count >= limit {
|
||||
return files, nil
|
||||
}
|
||||
}
|
||||
|
||||
files = append(files, &FileInfo{
|
||||
ID: aws.ToString(obj.Key),
|
||||
Name: originalName,
|
||||
Size: aws.ToInt64(obj.Size),
|
||||
ContentType: contentType,
|
||||
CreatedAt: aws.ToTime(obj.LastModified),
|
||||
UpdatedAt: aws.ToTime(obj.LastModified),
|
||||
URL: s.getObjectURL(aws.ToString(obj.Key)),
|
||||
})
|
||||
if !aws.ToBool(result.IsTruncated) {
|
||||
break
|
||||
}
|
||||
continuationToken = result.NextContinuationToken
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (s *S3Storage) Exists(ctx context.Context, id string) (bool, error) {
|
||||
_, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(id),
|
||||
})
|
||||
if err != nil {
|
||||
var nsk *types.NoSuchKey
|
||||
if ok := errors.As(err, &nsk); ok {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("failed to check file existence in S3: %w", err)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
|
|
@ -1,211 +0,0 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// MockS3Client is a mock implementation of the S3 client interface
|
||||
type MockS3Client struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockS3Client) PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) {
|
||||
args := m.Called(ctx, params)
|
||||
return args.Get(0).(*s3.PutObjectOutput), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockS3Client) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
|
||||
args := m.Called(ctx, params)
|
||||
return args.Get(0).(*s3.GetObjectOutput), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockS3Client) DeleteObject(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error) {
|
||||
args := m.Called(ctx, params)
|
||||
return args.Get(0).(*s3.DeleteObjectOutput), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockS3Client) ListObjectsV2(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) {
|
||||
args := m.Called(ctx, params)
|
||||
return args.Get(0).(*s3.ListObjectsV2Output), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockS3Client) HeadObject(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) {
|
||||
args := m.Called(ctx, params)
|
||||
return args.Get(0).(*s3.HeadObjectOutput), args.Error(1)
|
||||
}
|
||||
|
||||
func TestS3Storage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockClient := new(MockS3Client)
|
||||
storage := NewS3Storage(mockClient, "test-bucket", "", false)
|
||||
|
||||
t.Run("Save", func(t *testing.T) {
|
||||
mockClient.ExpectedCalls = nil
|
||||
mockClient.Calls = nil
|
||||
|
||||
content := []byte("test content")
|
||||
reader := bytes.NewReader(content)
|
||||
|
||||
// Mock HeadObject to return NotFound error
|
||||
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket"
|
||||
})).Return(&s3.HeadObjectOutput{}, &types.NoSuchKey{
|
||||
Message: aws.String("The specified key does not exist."),
|
||||
})
|
||||
|
||||
mockClient.On("PutObject", ctx, mock.MatchedBy(func(input *s3.PutObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.ContentType) == "text/plain"
|
||||
})).Return(&s3.PutObjectOutput{}, nil)
|
||||
|
||||
fileInfo, err := storage.Save(ctx, "test.txt", "text/plain", reader)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, fileInfo.ID)
|
||||
assert.Equal(t, "test.txt", fileInfo.Name)
|
||||
assert.Equal(t, "text/plain", fileInfo.ContentType)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("Get", func(t *testing.T) {
|
||||
content := []byte("test content")
|
||||
mockClient.On("GetObject", ctx, mock.MatchedBy(func(input *s3.GetObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.Key) == "test-id"
|
||||
})).Return(&s3.GetObjectOutput{
|
||||
Body: io.NopCloser(bytes.NewReader(content)),
|
||||
ContentType: aws.String("text/plain"),
|
||||
ContentLength: aws.Int64(int64(len(content))),
|
||||
LastModified: aws.Time(time.Now()),
|
||||
}, nil)
|
||||
|
||||
readCloser, info, err := storage.Get(ctx, "test-id")
|
||||
require.NoError(t, err)
|
||||
defer readCloser.Close()
|
||||
|
||||
data, err := io.ReadAll(readCloser)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, data)
|
||||
assert.Equal(t, "test-id", info.ID)
|
||||
assert.Equal(t, int64(len(content)), info.Size)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("List", func(t *testing.T) {
|
||||
mockClient.ExpectedCalls = nil
|
||||
mockClient.Calls = nil
|
||||
|
||||
mockClient.On("ListObjectsV2", ctx, mock.MatchedBy(func(input *s3.ListObjectsV2Input) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.Prefix) == "test" &&
|
||||
aws.ToInt32(input.MaxKeys) == 10
|
||||
})).Return(&s3.ListObjectsV2Output{
|
||||
Contents: []types.Object{
|
||||
{
|
||||
Key: aws.String("test1"),
|
||||
Size: aws.Int64(100),
|
||||
LastModified: aws.Time(time.Now()),
|
||||
},
|
||||
{
|
||||
Key: aws.String("test2"),
|
||||
Size: aws.Int64(200),
|
||||
LastModified: aws.Time(time.Now()),
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
// Mock HeadObject for both files
|
||||
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.Key) == "test1"
|
||||
})).Return(&s3.HeadObjectOutput{
|
||||
ContentType: aws.String("text/plain"),
|
||||
Metadata: map[string]string{
|
||||
"x-amz-meta-original-name": "test1.txt",
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.Key) == "test2"
|
||||
})).Return(&s3.HeadObjectOutput{
|
||||
ContentType: aws.String("text/plain"),
|
||||
Metadata: map[string]string{
|
||||
"x-amz-meta-original-name": "test2.txt",
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
files, err := storage.List(ctx, "test", 10, 0)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, files, 2)
|
||||
assert.Equal(t, "test1", files[0].ID)
|
||||
assert.Equal(t, int64(100), files[0].Size)
|
||||
assert.Equal(t, "test1.txt", files[0].Name)
|
||||
assert.Equal(t, "text/plain", files[0].ContentType)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
mockClient.On("DeleteObject", ctx, mock.MatchedBy(func(input *s3.DeleteObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.Key) == "test-id"
|
||||
})).Return(&s3.DeleteObjectOutput{}, nil)
|
||||
|
||||
err := storage.Delete(ctx, "test-id")
|
||||
require.NoError(t, err)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
mockClient.ExpectedCalls = nil
|
||||
mockClient.Calls = nil
|
||||
|
||||
// Mock HeadObject for existing file
|
||||
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.Key) == "test-id"
|
||||
})).Return(&s3.HeadObjectOutput{}, nil).Once()
|
||||
|
||||
exists, err := storage.Exists(ctx, "test-id")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Mock HeadObject for non-existing file
|
||||
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.Key) == "non-existent"
|
||||
})).Return(&s3.HeadObjectOutput{}, &types.NoSuchKey{
|
||||
Message: aws.String("The specified key does not exist."),
|
||||
}).Once()
|
||||
|
||||
exists, err = storage.Exists(ctx, "non-existent")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("Custom URL", func(t *testing.T) {
|
||||
customStorage := &S3Storage{
|
||||
client: mockClient,
|
||||
bucket: "test-bucket",
|
||||
customURL: "https://custom.domain",
|
||||
proxyS3: true,
|
||||
}
|
||||
assert.Contains(t, customStorage.getObjectURL("test-id"), "https://custom.domain")
|
||||
})
|
||||
}
|
|
@ -1,7 +1,5 @@
|
|||
package storage
|
||||
|
||||
//go:generate mockgen -source=storage.go -destination=mock/mock_storage.go -package=mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
|
|
@ -1,116 +0,0 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRateLimitConfig(t *testing.T) {
|
||||
config := RateLimitConfig{
|
||||
IPRate: 100,
|
||||
IPBurst: 200,
|
||||
RouteRates: map[string]struct {
|
||||
Rate int `yaml:"rate"`
|
||||
Burst int `yaml:"burst"`
|
||||
}{
|
||||
"/api/test": {
|
||||
Rate: 50,
|
||||
Burst: 100,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if config.IPRate != 100 {
|
||||
t.Errorf("Expected IPRate 100, got %d", config.IPRate)
|
||||
}
|
||||
if config.IPBurst != 200 {
|
||||
t.Errorf("Expected IPBurst 200, got %d", config.IPBurst)
|
||||
}
|
||||
|
||||
route := config.RouteRates["/api/test"]
|
||||
if route.Rate != 50 {
|
||||
t.Errorf("Expected route rate 50, got %d", route.Rate)
|
||||
}
|
||||
if route.Burst != 100 {
|
||||
t.Errorf("Expected route burst 100, got %d", route.Burst)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogConfig(t *testing.T) {
|
||||
config := AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
EnableFile: true,
|
||||
FilePath: "/var/log/app.log",
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
Rotation: struct {
|
||||
MaxSize int `yaml:"max_size"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
Compress bool `yaml:"compress"`
|
||||
LocalTime bool `yaml:"local_time"`
|
||||
}{
|
||||
MaxSize: 100,
|
||||
MaxAge: 7,
|
||||
MaxBackups: 5,
|
||||
Compress: true,
|
||||
LocalTime: true,
|
||||
},
|
||||
}
|
||||
|
||||
if !config.EnableConsole {
|
||||
t.Error("Expected EnableConsole to be true")
|
||||
}
|
||||
if !config.EnableFile {
|
||||
t.Error("Expected EnableFile to be true")
|
||||
}
|
||||
if config.FilePath != "/var/log/app.log" {
|
||||
t.Errorf("Expected FilePath '/var/log/app.log', got '%s'", config.FilePath)
|
||||
}
|
||||
if config.Format != "json" {
|
||||
t.Errorf("Expected Format 'json', got '%s'", config.Format)
|
||||
}
|
||||
if config.Level != "info" {
|
||||
t.Errorf("Expected Level 'info', got '%s'", config.Level)
|
||||
}
|
||||
|
||||
rotation := config.Rotation
|
||||
if rotation.MaxSize != 100 {
|
||||
t.Errorf("Expected MaxSize 100, got %d", rotation.MaxSize)
|
||||
}
|
||||
if rotation.MaxAge != 7 {
|
||||
t.Errorf("Expected MaxAge 7, got %d", rotation.MaxAge)
|
||||
}
|
||||
if rotation.MaxBackups != 5 {
|
||||
t.Errorf("Expected MaxBackups 5, got %d", rotation.MaxBackups)
|
||||
}
|
||||
if !rotation.Compress {
|
||||
t.Error("Expected Compress to be true")
|
||||
}
|
||||
if !rotation.LocalTime {
|
||||
t.Error("Expected LocalTime to be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUploadConfig(t *testing.T) {
|
||||
config := UploadConfig{
|
||||
MaxSize: 10,
|
||||
AllowedTypes: []string{"image/jpeg", "image/png"},
|
||||
AllowedExtensions: []string{".jpg", ".png"},
|
||||
}
|
||||
|
||||
if config.MaxSize != 10 {
|
||||
t.Errorf("Expected MaxSize 10, got %d", config.MaxSize)
|
||||
}
|
||||
if len(config.AllowedTypes) != 2 {
|
||||
t.Errorf("Expected 2 AllowedTypes, got %d", len(config.AllowedTypes))
|
||||
}
|
||||
if config.AllowedTypes[0] != "image/jpeg" {
|
||||
t.Errorf("Expected AllowedTypes[0] 'image/jpeg', got '%s'", config.AllowedTypes[0])
|
||||
}
|
||||
if len(config.AllowedExtensions) != 2 {
|
||||
t.Errorf("Expected 2 AllowedExtensions, got %d", len(config.AllowedExtensions))
|
||||
}
|
||||
if config.AllowedExtensions[0] != ".jpg" {
|
||||
t.Errorf("Expected AllowedExtensions[0] '.jpg', got '%s'", config.AllowedExtensions[0])
|
||||
}
|
||||
}
|
|
@ -1,21 +0,0 @@
|
|||
package types
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestFileInfo(t *testing.T) {
|
||||
fileInfo := FileInfo{
|
||||
Size: 1024,
|
||||
Name: "test.jpg",
|
||||
ContentType: "image/jpeg",
|
||||
}
|
||||
|
||||
if fileInfo.Size != 1024 {
|
||||
t.Errorf("Expected Size 1024, got %d", fileInfo.Size)
|
||||
}
|
||||
if fileInfo.Name != "test.jpg" {
|
||||
t.Errorf("Expected Name 'test.jpg', got '%s'", fileInfo.Name)
|
||||
}
|
||||
if fileInfo.ContentType != "image/jpeg" {
|
||||
t.Errorf("Expected ContentType 'image/jpeg', got '%s'", fileInfo.ContentType)
|
||||
}
|
||||
}
|
|
@ -1,77 +0,0 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCategory(t *testing.T) {
|
||||
description := "Test Description"
|
||||
category := Category{
|
||||
ID: 1,
|
||||
Name: "Test Category",
|
||||
Slug: "test-category",
|
||||
Description: &description,
|
||||
}
|
||||
|
||||
if category.ID != 1 {
|
||||
t.Errorf("Expected ID 1, got %d", category.ID)
|
||||
}
|
||||
if category.Name != "Test Category" {
|
||||
t.Errorf("Expected name 'Test Category', got '%s'", category.Name)
|
||||
}
|
||||
if category.Slug != "test-category" {
|
||||
t.Errorf("Expected slug 'test-category', got '%s'", category.Slug)
|
||||
}
|
||||
if *category.Description != description {
|
||||
t.Errorf("Expected description '%s', got '%s'", description, *category.Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPost(t *testing.T) {
|
||||
metaKeywords := "test,blog"
|
||||
metaDesc := "Test Description"
|
||||
post := Post{
|
||||
ID: 1,
|
||||
Title: "Test Post",
|
||||
Slug: "test-post",
|
||||
ContentMarkdown: "# Test Content",
|
||||
Summary: "Test Summary",
|
||||
MetaKeywords: &metaKeywords,
|
||||
MetaDescription: &metaDesc,
|
||||
}
|
||||
|
||||
if post.ID != 1 {
|
||||
t.Errorf("Expected ID 1, got %d", post.ID)
|
||||
}
|
||||
if post.Title != "Test Post" {
|
||||
t.Errorf("Expected title 'Test Post', got '%s'", post.Title)
|
||||
}
|
||||
if post.Slug != "test-post" {
|
||||
t.Errorf("Expected slug 'test-post', got '%s'", post.Slug)
|
||||
}
|
||||
if *post.MetaKeywords != metaKeywords {
|
||||
t.Errorf("Expected meta keywords '%s', got '%s'", metaKeywords, *post.MetaKeywords)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDaily(t *testing.T) {
|
||||
daily := Daily{
|
||||
ID: "2025-02-12",
|
||||
CategoryID: 1,
|
||||
ImageURL: "https://example.com/image.jpg",
|
||||
Quote: "Test Quote",
|
||||
}
|
||||
|
||||
if daily.ID != "2025-02-12" {
|
||||
t.Errorf("Expected ID '2025-02-12', got '%s'", daily.ID)
|
||||
}
|
||||
if daily.CategoryID != 1 {
|
||||
t.Errorf("Expected CategoryID 1, got %d", daily.CategoryID)
|
||||
}
|
||||
if daily.ImageURL != "https://example.com/image.jpg" {
|
||||
t.Errorf("Expected ImageURL 'https://example.com/image.jpg', got '%s'", daily.ImageURL)
|
||||
}
|
||||
if daily.Quote != "Test Quote" {
|
||||
t.Errorf("Expected Quote 'Test Quote', got '%s'", daily.Quote)
|
||||
}
|
||||
}
|
|
@ -1,77 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
// Create a temporary test config file
|
||||
testConfig := `
|
||||
database:
|
||||
driver: postgres
|
||||
dsn: postgres://user:pass@localhost:5432/db
|
||||
server:
|
||||
port: 8080
|
||||
host: localhost
|
||||
jwt:
|
||||
secret: test-secret
|
||||
expiration: 24h
|
||||
logging:
|
||||
level: debug
|
||||
format: console
|
||||
`
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
if err := os.WriteFile(configPath, []byte(testConfig), 0644); err != nil {
|
||||
t.Fatalf("Failed to create test config file: %v", err)
|
||||
}
|
||||
|
||||
// Test successful config loading
|
||||
cfg, err := Load(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Verify loaded values
|
||||
tests := []struct {
|
||||
name string
|
||||
got interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{"database.driver", cfg.Database.Driver, "postgres"},
|
||||
{"database.dsn", cfg.Database.DSN, "postgres://user:pass@localhost:5432/db"},
|
||||
{"server.port", cfg.Server.Port, 8080},
|
||||
{"server.host", cfg.Server.Host, "localhost"},
|
||||
{"jwt.secret", cfg.JWT.Secret, "test-secret"},
|
||||
{"jwt.expiration", cfg.JWT.Expiration, "24h"},
|
||||
{"logging.level", cfg.Logging.Level, "debug"},
|
||||
{"logging.format", cfg.Logging.Format, "console"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.got != tt.expected {
|
||||
t.Errorf("Config %s = %v, want %v", tt.name, tt.got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test loading non-existent file
|
||||
_, err = Load("non-existent.yaml")
|
||||
if err == nil {
|
||||
t.Error("Expected error when loading non-existent file, got nil")
|
||||
}
|
||||
|
||||
// Test loading invalid YAML
|
||||
invalidPath := filepath.Join(tmpDir, "invalid.yaml")
|
||||
if err := os.WriteFile(invalidPath, []byte("invalid: yaml: content"), 0644); err != nil {
|
||||
t.Fatalf("Failed to create invalid config file: %v", err)
|
||||
}
|
||||
|
||||
_, err = Load(invalidPath)
|
||||
if err == nil {
|
||||
t.Error("Expected error when loading invalid YAML, got nil")
|
||||
}
|
||||
}
|
|
@ -1,100 +0,0 @@
|
|||
package imageutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/png"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsImageFormat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
contentType string
|
||||
want bool
|
||||
}{
|
||||
{"JPEG", "image/jpeg", true},
|
||||
{"PNG", "image/png", true},
|
||||
{"GIF", "image/gif", true},
|
||||
{"WebP", "image/webp", true},
|
||||
{"Invalid", "image/invalid", false},
|
||||
{"Empty", "", false},
|
||||
{"Text", "text/plain", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := IsImageFormat(tt.contentType); got != tt.want {
|
||||
t.Errorf("IsImageFormat(%q) = %v, want %v", tt.contentType, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultOptions(t *testing.T) {
|
||||
opts := DefaultOptions()
|
||||
|
||||
if !opts.Lossless {
|
||||
t.Error("DefaultOptions().Lossless = false, want true")
|
||||
}
|
||||
if opts.Quality != 90 {
|
||||
t.Errorf("DefaultOptions().Quality = %v, want 90", opts.Quality)
|
||||
}
|
||||
if opts.Compression != 4 {
|
||||
t.Errorf("DefaultOptions().Compression = %v, want 4", opts.Compression)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessImage(t *testing.T) {
|
||||
// Create a test image
|
||||
img := image.NewRGBA(image.Rect(0, 0, 100, 100))
|
||||
for y := 0; y < 100; y++ {
|
||||
for x := 0; x < 100; x++ {
|
||||
img.Set(x, y, color.RGBA{R: 255, G: 0, B: 0, A: 255})
|
||||
}
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := png.Encode(&buf, img); err != nil {
|
||||
t.Fatalf("Failed to create test PNG: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts ProcessOptions
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Default options",
|
||||
opts: DefaultOptions(),
|
||||
},
|
||||
{
|
||||
name: "Custom quality",
|
||||
opts: ProcessOptions{
|
||||
Lossless: false,
|
||||
Quality: 75,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := bytes.NewReader(buf.Bytes())
|
||||
result, err := ProcessImage(reader, tt.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ProcessImage() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && len(result) == 0 {
|
||||
t.Error("ProcessImage() returned empty result")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test with invalid input
|
||||
_, err := ProcessImage(bytes.NewReader([]byte("invalid image data")), DefaultOptions())
|
||||
if err == nil {
|
||||
t.Error("ProcessImage() with invalid input should return error")
|
||||
}
|
||||
}
|
|
@ -1,85 +0,0 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"tss-rocks-be/internal/config"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *config.Config
|
||||
expectedLevel zerolog.Level
|
||||
}{
|
||||
{
|
||||
name: "Debug level",
|
||||
config: &config.Config{
|
||||
Logging: struct {
|
||||
Level string `yaml:"level"`
|
||||
Format string `yaml:"format"`
|
||||
}{
|
||||
Level: "debug",
|
||||
Format: "json",
|
||||
},
|
||||
},
|
||||
expectedLevel: zerolog.DebugLevel,
|
||||
},
|
||||
{
|
||||
name: "Info level",
|
||||
config: &config.Config{
|
||||
Logging: struct {
|
||||
Level string `yaml:"level"`
|
||||
Format string `yaml:"format"`
|
||||
}{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
},
|
||||
},
|
||||
expectedLevel: zerolog.InfoLevel,
|
||||
},
|
||||
{
|
||||
name: "Error level",
|
||||
config: &config.Config{
|
||||
Logging: struct {
|
||||
Level string `yaml:"level"`
|
||||
Format string `yaml:"format"`
|
||||
}{
|
||||
Level: "error",
|
||||
Format: "json",
|
||||
},
|
||||
},
|
||||
expectedLevel: zerolog.ErrorLevel,
|
||||
},
|
||||
{
|
||||
name: "Invalid level defaults to Info",
|
||||
config: &config.Config{
|
||||
Logging: struct {
|
||||
Level string `yaml:"level"`
|
||||
Format string `yaml:"format"`
|
||||
}{
|
||||
Level: "invalid",
|
||||
Format: "json",
|
||||
},
|
||||
},
|
||||
expectedLevel: zerolog.InfoLevel,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
Setup(tt.config)
|
||||
if zerolog.GlobalLevel() != tt.expectedLevel {
|
||||
t.Errorf("Setup() set level to %v, want %v", zerolog.GlobalLevel(), tt.expectedLevel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLogger(t *testing.T) {
|
||||
logger := GetLogger()
|
||||
if logger == nil {
|
||||
t.Error("GetLogger() returned nil")
|
||||
}
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"categories": {
|
||||
"man": "Man",
|
||||
"man": "Human",
|
||||
"machine": "Machine",
|
||||
"earth": "Earth",
|
||||
"space": "Space",
|
||||
|
@ -29,8 +29,6 @@
|
|||
"edit": "Edit",
|
||||
"delete": "Delete",
|
||||
"upload": "Upload",
|
||||
"save": "Save",
|
||||
"saving": "Saving...",
|
||||
"status": "Status",
|
||||
"actions": "Actions",
|
||||
"published": "Published",
|
||||
|
@ -48,12 +46,10 @@
|
|||
"joinDate": "Join Date",
|
||||
"username": "Username",
|
||||
"logout": "Logout",
|
||||
"language": "Language",
|
||||
"theme": {
|
||||
"light": "Light Mode",
|
||||
"dark": "Dark Mode",
|
||||
"system": "System"
|
||||
}
|
||||
"save": "Save",
|
||||
"saving": "Saving...",
|
||||
"noData": "No data available",
|
||||
"unsavedChanges": "You have unsaved changes. Are you sure you want to leave?"
|
||||
},
|
||||
"dashboard": {
|
||||
"totalPosts": "Total Posts",
|
||||
|
@ -61,13 +57,33 @@
|
|||
"totalUsers": "Total Users",
|
||||
"totalContributors": "Total Contributors"
|
||||
},
|
||||
"posts": {
|
||||
"title": "Title",
|
||||
"categories": "Categories",
|
||||
"createdAt": "Created At",
|
||||
"status": "Status",
|
||||
"noTitle": "No Title",
|
||||
"deleteConfirm": "Are you sure you want to delete this post?",
|
||||
"create": "Create Post",
|
||||
"edit": "Edit Post",
|
||||
"slug": "Slug",
|
||||
"content": "Content",
|
||||
"summary": "Summary",
|
||||
"metaKeywords": "Meta Keywords",
|
||||
"metaDescription": "Meta Description",
|
||||
"selectCategories": "Select Categories",
|
||||
"saving": "Saving...",
|
||||
"publishing": "Publishing...",
|
||||
"saveDraft": "Save Draft",
|
||||
"publish": "Publish"
|
||||
},
|
||||
"login": {
|
||||
"title": "Admin Login",
|
||||
"username": "Username",
|
||||
"password": "Password",
|
||||
"remember": "Remember me",
|
||||
"submit": "Sign in",
|
||||
"loading": "Signing in...",
|
||||
"submit": "Login",
|
||||
"loading": "Logging in...",
|
||||
"error": {
|
||||
"failed": "Login failed",
|
||||
"retry": "Login failed, please try again later"
|
||||
|
@ -76,7 +92,7 @@
|
|||
"nav": {
|
||||
"dashboard": "Dashboard",
|
||||
"posts": "Posts",
|
||||
"daily": "Daily Quotes",
|
||||
"daily": "Daily",
|
||||
"medias": "Media",
|
||||
"categories": "Categories",
|
||||
"users": "Users",
|
||||
|
@ -110,5 +126,45 @@
|
|||
"passwordMismatch": "Passwords do not match",
|
||||
"passwordTooShort": "Password must be at least 8 characters long"
|
||||
}
|
||||
},
|
||||
"editor": {
|
||||
"heading1": "Heading 1",
|
||||
"heading2": "Heading 2",
|
||||
"heading3": "Heading 3",
|
||||
"bold": "Bold",
|
||||
"italic": "Italic",
|
||||
"orderedList": "Ordered List",
|
||||
"unorderedList": "Unordered List",
|
||||
"quote": "Quote",
|
||||
"link": "Link",
|
||||
"image": "Image",
|
||||
"inlineCode": "Inline Code",
|
||||
"codeBlock": "Code Block",
|
||||
"table": "Table",
|
||||
"togglePreview": "Toggle Preview",
|
||||
"toggleFullscreen": "Toggle Fullscreen",
|
||||
"selectLanguage": "Select Language",
|
||||
"plainText": "Plain Text",
|
||||
"insertTable": "Insert Table",
|
||||
"addRowAbove": "Add Row Above",
|
||||
"addRowBelow": "Add Row Below",
|
||||
"addColumnLeft": "Add Column Left",
|
||||
"addColumnRight": "Add Column Right",
|
||||
"deleteRow": "Delete Row",
|
||||
"deleteColumn": "Delete Column",
|
||||
"deleteTable": "Delete Table",
|
||||
"dragAndDrop": "Drop images here",
|
||||
"dropToUpload": "Drop files here to upload",
|
||||
"uploading": "Uploading...",
|
||||
"uploadProgress": "Upload progress: {{progress}}%",
|
||||
"uploadError": "Failed to upload image: {{error}}",
|
||||
"uploadSuccess": "Image uploaded successfully",
|
||||
"codeBlockShortcut": "Ctrl+Shift+K",
|
||||
"boldShortcut": "Ctrl+B",
|
||||
"italicShortcut": "Ctrl+I",
|
||||
"linkShortcut": "Ctrl+K",
|
||||
"heading1Shortcut": "Shift+1",
|
||||
"heading2Shortcut": "Shift+2",
|
||||
"heading3Shortcut": "Shift+3"
|
||||
}
|
||||
}
|
||||
|
|
|
@ -47,7 +47,9 @@
|
|||
"username": "用户名",
|
||||
"logout": "退出登录",
|
||||
"save": "保存",
|
||||
"saving": "保存中..."
|
||||
"saving": "保存中...",
|
||||
"noData": "暂无数据",
|
||||
"unsavedChanges": "你有未保存的更改,确定要离开吗?"
|
||||
},
|
||||
"dashboard": {
|
||||
"totalPosts": "文章总数",
|
||||
|
@ -55,6 +57,26 @@
|
|||
"totalUsers": "用户总数",
|
||||
"totalContributors": "贡献者总数"
|
||||
},
|
||||
"posts": {
|
||||
"title": "标题",
|
||||
"categories": "分类",
|
||||
"createdAt": "创建时间",
|
||||
"status": "状态",
|
||||
"noTitle": "无标题",
|
||||
"deleteConfirm": "确定要删除这篇文章吗?",
|
||||
"create": "创建文章",
|
||||
"edit": "编辑文章",
|
||||
"slug": "文章链接",
|
||||
"content": "内容",
|
||||
"summary": "摘要",
|
||||
"metaKeywords": "关键词",
|
||||
"metaDescription": "描述",
|
||||
"selectCategories": "选择分类",
|
||||
"saving": "保存中...",
|
||||
"publishing": "发布中...",
|
||||
"saveDraft": "保存草稿",
|
||||
"publish": "发布文章"
|
||||
},
|
||||
"login": {
|
||||
"title": "管理员登录",
|
||||
"username": "用户名",
|
||||
|
@ -104,5 +126,45 @@
|
|||
"passwordMismatch": "两次输入的密码不一致",
|
||||
"passwordTooShort": "密码长度不能少于8个字符"
|
||||
}
|
||||
},
|
||||
"editor": {
|
||||
"heading1": "一级标题",
|
||||
"heading2": "二级标题",
|
||||
"heading3": "三级标题",
|
||||
"bold": "粗体",
|
||||
"italic": "斜体",
|
||||
"orderedList": "有序列表",
|
||||
"unorderedList": "无序列表",
|
||||
"quote": "引用",
|
||||
"link": "链接",
|
||||
"image": "图片",
|
||||
"inlineCode": "行内代码",
|
||||
"codeBlock": "代码块",
|
||||
"table": "表格",
|
||||
"togglePreview": "切换预览",
|
||||
"toggleFullscreen": "切换全屏",
|
||||
"selectLanguage": "选择语言",
|
||||
"plainText": "纯文本",
|
||||
"insertTable": "插入表格",
|
||||
"addRowAbove": "在上方插入行",
|
||||
"addRowBelow": "在下方插入行",
|
||||
"addColumnLeft": "在左侧插入列",
|
||||
"addColumnRight": "在右侧插入列",
|
||||
"deleteRow": "删除行",
|
||||
"deleteColumn": "删除列",
|
||||
"deleteTable": "删除表格",
|
||||
"dragAndDrop": "拖放图片到此处",
|
||||
"dropToUpload": "拖放文件到此处上传",
|
||||
"uploading": "上传中...",
|
||||
"uploadProgress": "上传进度:{{progress}}%",
|
||||
"uploadError": "图片上传失败:{{error}}",
|
||||
"uploadSuccess": "图片上传成功",
|
||||
"codeBlockShortcut": "Ctrl+Shift+K",
|
||||
"boldShortcut": "Ctrl+B",
|
||||
"italicShortcut": "Ctrl+I",
|
||||
"linkShortcut": "Ctrl+K",
|
||||
"heading1Shortcut": "Shift+1",
|
||||
"heading2Shortcut": "Shift+2",
|
||||
"heading3Shortcut": "Shift+3"
|
||||
}
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@
|
|||
"lastLogin": "最後登入",
|
||||
"joinDate": "加入時間",
|
||||
"username": "用戶名",
|
||||
"logout": "退出登錄",
|
||||
"logout": "退出登入",
|
||||
"language": "語言",
|
||||
"theme": {
|
||||
"light": "淺色模式",
|
||||
|
@ -53,7 +53,9 @@
|
|||
"system": "跟隨系統"
|
||||
},
|
||||
"save": "保存",
|
||||
"saving": "保存中..."
|
||||
"saving": "保存中...",
|
||||
"noData": "暫無數據",
|
||||
"unsavedChanges": "你有未保存的更改,確定要離開嗎?"
|
||||
},
|
||||
"nav": {
|
||||
"dashboard": "儀表板",
|
||||
|
@ -83,15 +85,15 @@
|
|||
"uploadDate": "上傳日期"
|
||||
},
|
||||
"login": {
|
||||
"title": "管理員登錄",
|
||||
"title": "管理員登入",
|
||||
"username": "用戶名",
|
||||
"password": "密碼",
|
||||
"remember": "記住我",
|
||||
"submit": "登錄",
|
||||
"loading": "登錄中...",
|
||||
"submit": "登入",
|
||||
"loading": "登入中...",
|
||||
"error": {
|
||||
"failed": "登錄失敗",
|
||||
"retry": "登錄失敗,請稍後重試"
|
||||
"failed": "登入失敗",
|
||||
"retry": "登入失敗,請稍後重試"
|
||||
}
|
||||
},
|
||||
"roles": {
|
||||
|
@ -110,5 +112,45 @@
|
|||
"passwordMismatch": "兩次輸入的密碼不一致",
|
||||
"passwordTooShort": "密碼長度不能少於8個字符"
|
||||
}
|
||||
},
|
||||
"editor": {
|
||||
"heading1": "一級標題",
|
||||
"heading2": "二級標題",
|
||||
"heading3": "三級標題",
|
||||
"bold": "粗體",
|
||||
"italic": "斜體",
|
||||
"orderedList": "有序列表",
|
||||
"unorderedList": "無序列表",
|
||||
"quote": "引用",
|
||||
"link": "連結",
|
||||
"image": "圖片",
|
||||
"inlineCode": "行內程式碼",
|
||||
"codeBlock": "程式碼區塊",
|
||||
"table": "表格",
|
||||
"togglePreview": "切換預覽",
|
||||
"toggleFullscreen": "切換全螢幕",
|
||||
"selectLanguage": "選擇語言",
|
||||
"plainText": "純文字",
|
||||
"insertTable": "插入表格",
|
||||
"addRowAbove": "在上方插入列",
|
||||
"addRowBelow": "在下方插入列",
|
||||
"addColumnLeft": "在左側插入欄",
|
||||
"addColumnRight": "在右側插入欄",
|
||||
"deleteRow": "刪除列",
|
||||
"deleteColumn": "刪除欄",
|
||||
"deleteTable": "刪除表格",
|
||||
"dragAndDrop": "拖放圖片到此處",
|
||||
"dropToUpload": "拖放檔案到此處上傳",
|
||||
"uploading": "上傳中...",
|
||||
"uploadProgress": "上傳進度:{{progress}}%",
|
||||
"uploadError": "圖片上傳失敗:{{error}}",
|
||||
"uploadSuccess": "圖片上傳成功",
|
||||
"codeBlockShortcut": "Ctrl+Shift+K",
|
||||
"boldShortcut": "Ctrl+B",
|
||||
"italicShortcut": "Ctrl+I",
|
||||
"linkShortcut": "Ctrl+K",
|
||||
"heading1Shortcut": "Shift+1",
|
||||
"heading2Shortcut": "Shift+2",
|
||||
"heading3Shortcut": "Shift+3"
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,7 +12,14 @@
|
|||
"dependencies": {
|
||||
"@headlessui/react": "^2.2.0",
|
||||
"@tss-rocks/api": "workspace:*",
|
||||
"@types/axios": "^0.14.4",
|
||||
"@types/classnames": "^2.3.4",
|
||||
"@types/highlight.js": "^10.1.0",
|
||||
"@types/markdown-it": "^14.1.2",
|
||||
"axios": "^1.7.9",
|
||||
"classnames": "^2.5.1",
|
||||
"dayjs": "^1.11.13",
|
||||
"highlight.js": "^11.11.1",
|
||||
"i18next": "^24.2.2",
|
||||
"i18next-browser-languagedetector": "^8.0.4",
|
||||
"lucide-react": "^0.474.0",
|
||||
|
@ -21,7 +28,13 @@
|
|||
"react-dom": "^19.0.0",
|
||||
"react-i18next": "^15.4.0",
|
||||
"react-icons": "^5.4.0",
|
||||
"react-router-dom": "^7.1.5"
|
||||
"react-markdown": "^10.0.0",
|
||||
"react-router-dom": "^7.1.5",
|
||||
"react-toastify": "^11.0.3",
|
||||
"rehype-highlight": "^7.0.2",
|
||||
"rehype-raw": "^7.0.0",
|
||||
"rehype-sanitize": "^6.0.0",
|
||||
"remark-gfm": "^4.0.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/js": "^9.9.1",
|
||||
|
|
|
@ -4,16 +4,21 @@ import { AuthProvider } from './contexts/AuthContext';
|
|||
import { UserProvider } from './contexts/UserContext';
|
||||
import router from './router';
|
||||
import LoadingSpinner from './components/LoadingSpinner';
|
||||
import { ToastContainer } from 'react-toastify';
|
||||
import 'react-toastify/dist/ReactToastify.css';
|
||||
|
||||
function App() {
|
||||
return (
|
||||
<Suspense fallback={<LoadingSpinner fullScreen />}>
|
||||
<AuthProvider>
|
||||
<UserProvider>
|
||||
<RouterProvider router={router} />
|
||||
</UserProvider>
|
||||
</AuthProvider>
|
||||
</Suspense>
|
||||
<>
|
||||
<Suspense fallback={<LoadingSpinner fullScreen />}>
|
||||
<AuthProvider>
|
||||
<UserProvider>
|
||||
<RouterProvider router={router} />
|
||||
</UserProvider>
|
||||
</AuthProvider>
|
||||
</Suspense>
|
||||
<ToastContainer />
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -9,8 +9,8 @@ export default function LoadingSpinner({ fullScreen = false }: LoadingSpinnerPro
|
|||
|
||||
const content = (
|
||||
<div className="flex flex-col items-center gap-3">
|
||||
<div className="w-10 h-10 border-4 border-indigo-200 dark:border-indigo-900 border-t-indigo-500 dark:border-t-indigo-400 rounded-full animate-spin" />
|
||||
<div className="text-slate-600 dark:text-slate-300 text-sm font-medium">
|
||||
<div className="w-10 h-10 border-4 border-gray-200 dark:border-gray-700 border-t-gray-900 dark:border-t-gray-200 rounded-full animate-spin" />
|
||||
<div className="text-gray-900 dark:text-gray-200 text-sm font-medium">
|
||||
{t('admin.common.loading')}
|
||||
</div>
|
||||
</div>
|
||||
|
|
958
frontend/src/components/admin/MarkdownEditor.tsx
Normal file
958
frontend/src/components/admin/MarkdownEditor.tsx
Normal file
|
@ -0,0 +1,958 @@
|
|||
import { FC, useState, useRef, useEffect } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import ReactMarkdown from 'react-markdown';
|
||||
import remarkGfm from 'remark-gfm';
|
||||
import rehypeRaw from 'rehype-raw';
|
||||
import rehypeSanitize from 'rehype-sanitize';
|
||||
import rehypeHighlight from 'rehype-highlight';
|
||||
import 'highlight.js/styles/github-dark.css';
|
||||
import classNames from 'classnames';
|
||||
import axios from 'axios';
|
||||
import { useToast } from '@/hooks/useToast';
|
||||
import {
|
||||
RiH1,
|
||||
RiH2,
|
||||
RiH3,
|
||||
RiBold,
|
||||
RiItalic,
|
||||
RiListOrdered,
|
||||
RiListUnordered,
|
||||
RiLink,
|
||||
RiImage2Line,
|
||||
RiCodeLine,
|
||||
RiCodeBoxLine,
|
||||
RiTableLine,
|
||||
RiDoubleQuotesL,
|
||||
RiEyeLine,
|
||||
RiEyeOffLine,
|
||||
RiFullscreenLine,
|
||||
} from 'react-icons/ri';
|
||||
|
||||
interface MarkdownEditorProps {
|
||||
value: string;
|
||||
onChange: (value: string) => void;
|
||||
placeholder?: string;
|
||||
}
|
||||
|
||||
type CodeBlockProps = {
|
||||
inline?: boolean;
|
||||
className?: string;
|
||||
children: React.ReactNode;
|
||||
} & React.HTMLAttributes<HTMLElement>;
|
||||
|
||||
const MarkdownEditor: FC<MarkdownEditorProps> = ({
|
||||
value,
|
||||
onChange,
|
||||
placeholder,
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
const { showToast } = useToast();
|
||||
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
||||
const tableMenuRef = useRef<HTMLDivElement>(null);
|
||||
const langSelectorRef = useRef<HTMLDivElement>(null);
|
||||
const [showTableMenu, setShowTableMenu] = useState(false);
|
||||
const [showLangSelector, setShowLangSelector] = useState(false);
|
||||
const [isPreview, setIsPreview] = useState(false);
|
||||
const [isFullscreen, setIsFullscreen] = useState(false);
|
||||
const [isDraggingOver, setIsDraggingOver] = useState(false);
|
||||
|
||||
// 检查当前行是否为空
|
||||
const isCurrentLineEmpty = (textarea: HTMLTextAreaElement): boolean => {
|
||||
const text = textarea.value;
|
||||
const lines = text.split('\n');
|
||||
const currentLine = getCurrentLine(textarea);
|
||||
return !lines[currentLine].trim();
|
||||
};
|
||||
|
||||
// 获取当前行号
|
||||
const getCurrentLine = (textarea: HTMLTextAreaElement): number => {
|
||||
const text = textarea.value.substring(0, textarea.selectionStart);
|
||||
return text.split('\n').length - 1;
|
||||
};
|
||||
|
||||
// 在当前位置插入文本
|
||||
const insertText = (textarea: HTMLTextAreaElement, text: string) => {
|
||||
const start = textarea.selectionStart;
|
||||
const end = textarea.selectionEnd;
|
||||
const before = textarea.value.substring(0, start);
|
||||
const after = textarea.value.substring(end);
|
||||
|
||||
const needNewLine = !isCurrentLineEmpty(textarea) &&
|
||||
(text.startsWith('#') || text.startsWith('>'));
|
||||
|
||||
const newText = needNewLine ? `\n${text}` : text;
|
||||
const newValue = before + newText + after;
|
||||
|
||||
onChange(newValue);
|
||||
|
||||
textarea.value = newValue;
|
||||
const newCursorPos = start + newText.length;
|
||||
textarea.selectionStart = newCursorPos;
|
||||
textarea.selectionEnd = newCursorPos;
|
||||
textarea.focus();
|
||||
};
|
||||
|
||||
const insertTable = (rows: number, cols: number) => {
|
||||
if (!textareaRef.current) return;
|
||||
|
||||
const headers = Array(cols).fill('header').join(' | ');
|
||||
const separators = Array(cols).fill('---').join(' | ');
|
||||
const cells = Array(cols).fill('content').join(' | ');
|
||||
const rows_content = Array(rows).fill(cells).join('\n| ');
|
||||
|
||||
insertText(
|
||||
textareaRef.current,
|
||||
`\n| ${headers} |\n| ${separators} |\n| ${rows_content} |\n\n`
|
||||
);
|
||||
setShowTableMenu(false);
|
||||
};
|
||||
|
||||
const findTableAtCursor = (): { table: string; start: number; end: number; rows: string[][]; alignments: string[] } | null => {
|
||||
if (!textareaRef.current) return null;
|
||||
|
||||
const text = textareaRef.current.value;
|
||||
const cursorPos = textareaRef.current.selectionStart;
|
||||
|
||||
const tableRegex = /\|[^\n]+\|[\s]*\n\|[- |:]+\|[\s]*\n(\|[^\n]+\|[\s]*\n?)+/g;
|
||||
let match;
|
||||
while ((match = tableRegex.exec(text)) !== null) {
|
||||
const start = match.index;
|
||||
const end = start + match[0].length;
|
||||
if (cursorPos >= start && cursorPos <= end) {
|
||||
const rows = match[0]
|
||||
.trim()
|
||||
.split('\n')
|
||||
.map(row =>
|
||||
row
|
||||
.trim()
|
||||
.replace(/^\||\|$/g, '')
|
||||
.split('|')
|
||||
.map(cell => cell.trim())
|
||||
);
|
||||
|
||||
const alignments = rows[1].map(cell => {
|
||||
if (cell.startsWith(':') && cell.endsWith(':')) return 'center';
|
||||
if (cell.endsWith(':')) return 'right';
|
||||
return 'left';
|
||||
});
|
||||
|
||||
return {
|
||||
table: match[0],
|
||||
start,
|
||||
end,
|
||||
rows: [rows[0], ...rows.slice(2)],
|
||||
alignments,
|
||||
};
|
||||
}
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
const generateTableText = (rows: string[][], alignments: string[]): string => {
|
||||
const colWidths = rows[0].map((_, colIndex) =>
|
||||
Math.max(...rows.map(row => row[colIndex]?.length || 0))
|
||||
);
|
||||
|
||||
const separator = alignments.map((align, i) => {
|
||||
const width = Math.max(3, colWidths[i]);
|
||||
switch (align) {
|
||||
case 'center':
|
||||
return ':' + '-'.repeat(width - 2) + ':';
|
||||
case 'right':
|
||||
return '-'.repeat(width - 1) + ':';
|
||||
default:
|
||||
return '-'.repeat(width);
|
||||
}
|
||||
});
|
||||
|
||||
const tableRows = [
|
||||
rows[0],
|
||||
separator,
|
||||
...rows.slice(1)
|
||||
].map(row =>
|
||||
'| ' + row.map((cell, i) => cell.padEnd(colWidths[i])).join(' | ') + ' |'
|
||||
);
|
||||
|
||||
return tableRows.join('\n') + '\n';
|
||||
};
|
||||
|
||||
const updateTable = (pos: { table: string; start: number; end: number; rows: string[][]; alignments: string[] }, newRows?: string[][]) => {
|
||||
if (!textareaRef.current) return;
|
||||
|
||||
const text = textareaRef.current.value;
|
||||
const newTable = generateTableText(newRows || pos.rows, pos.alignments);
|
||||
const newText = text.substring(0, pos.start) + newTable + text.substring(pos.end);
|
||||
onChange(newText);
|
||||
};
|
||||
|
||||
const tableActions = [
|
||||
{
|
||||
label: t('editor.insertTable'),
|
||||
action: () => insertTable(2, 2),
|
||||
},
|
||||
{
|
||||
label: t('editor.addRowAbove'),
|
||||
action: () => {
|
||||
const tablePos = findTableAtCursor();
|
||||
if (!tablePos) return;
|
||||
|
||||
const cursorPos = textareaRef.current?.selectionStart || 0;
|
||||
let rowIndex = 0;
|
||||
let currentPos = tablePos.start;
|
||||
|
||||
for (let i = 0; i < tablePos.rows.length; i++) {
|
||||
const rowText = generateTableText([tablePos.rows[i]], tablePos.alignments);
|
||||
if (cursorPos <= currentPos + rowText.length) {
|
||||
rowIndex = i;
|
||||
break;
|
||||
}
|
||||
currentPos += rowText.length;
|
||||
}
|
||||
|
||||
const newRows = [...tablePos.rows];
|
||||
newRows.splice(rowIndex, 0, Array(tablePos.rows[0].length).fill(''));
|
||||
updateTable(tablePos, newRows);
|
||||
},
|
||||
},
|
||||
{
|
||||
label: t('editor.addRowBelow'),
|
||||
action: () => {
|
||||
const tablePos = findTableAtCursor();
|
||||
if (!tablePos) return;
|
||||
|
||||
const cursorPos = textareaRef.current?.selectionStart || 0;
|
||||
let rowIndex = 0;
|
||||
let currentPos = tablePos.start;
|
||||
|
||||
for (let i = 0; i < tablePos.rows.length; i++) {
|
||||
const rowText = generateTableText([tablePos.rows[i]], tablePos.alignments);
|
||||
if (cursorPos <= currentPos + rowText.length) {
|
||||
rowIndex = i;
|
||||
break;
|
||||
}
|
||||
currentPos += rowText.length;
|
||||
}
|
||||
|
||||
const newRows = [...tablePos.rows];
|
||||
newRows.splice(rowIndex + 1, 0, Array(tablePos.rows[0].length).fill(''));
|
||||
updateTable(tablePos, newRows);
|
||||
},
|
||||
},
|
||||
{
|
||||
label: t('editor.addColumnLeft'),
|
||||
action: () => {
|
||||
const tablePos = findTableAtCursor();
|
||||
if (!tablePos) return;
|
||||
|
||||
const cursorPos = textareaRef.current?.selectionStart || 0;
|
||||
const line = textareaRef.current?.value
|
||||
.substring(tablePos.start, cursorPos)
|
||||
.split('\n')
|
||||
.pop() || '';
|
||||
const colIndex = (line.match(/\|/g) || []).length - 1;
|
||||
|
||||
const newRows = tablePos.rows.map(row => {
|
||||
const newRow = [...row];
|
||||
newRow.splice(colIndex, 0, '');
|
||||
return newRow;
|
||||
});
|
||||
|
||||
const newAlignments = [...tablePos.alignments];
|
||||
newAlignments.splice(colIndex, 0, 'left');
|
||||
|
||||
tablePos.alignments = newAlignments;
|
||||
updateTable(tablePos, newRows);
|
||||
},
|
||||
},
|
||||
{
|
||||
label: t('editor.addColumnRight'),
|
||||
action: () => {
|
||||
const tablePos = findTableAtCursor();
|
||||
if (!tablePos) return;
|
||||
|
||||
const cursorPos = textareaRef.current?.selectionStart || 0;
|
||||
const line = textareaRef.current?.value
|
||||
.substring(tablePos.start, cursorPos)
|
||||
.split('\n')
|
||||
.pop() || '';
|
||||
const colIndex = (line.match(/\|/g) || []).length - 1;
|
||||
|
||||
const newRows = tablePos.rows.map(row => {
|
||||
const newRow = [...row];
|
||||
newRow.splice(colIndex + 1, 0, '');
|
||||
return newRow;
|
||||
});
|
||||
|
||||
const newAlignments = [...tablePos.alignments];
|
||||
newAlignments.splice(colIndex + 1, 0, 'left');
|
||||
|
||||
tablePos.alignments = newAlignments;
|
||||
updateTable(tablePos, newRows);
|
||||
},
|
||||
},
|
||||
{
|
||||
label: t('editor.deleteRow'),
|
||||
action: () => {
|
||||
const tablePos = findTableAtCursor();
|
||||
if (!tablePos) return;
|
||||
|
||||
const cursorPos = textareaRef.current?.selectionStart || 0;
|
||||
let rowIndex = 0;
|
||||
let currentPos = tablePos.start;
|
||||
|
||||
for (let i = 0; i < tablePos.rows.length; i++) {
|
||||
const rowText = generateTableText([tablePos.rows[i]], tablePos.alignments);
|
||||
if (cursorPos <= currentPos + rowText.length) {
|
||||
rowIndex = i;
|
||||
break;
|
||||
}
|
||||
currentPos += rowText.length;
|
||||
}
|
||||
|
||||
const newRows = [...tablePos.rows];
|
||||
newRows.splice(rowIndex, 1);
|
||||
updateTable(tablePos, newRows);
|
||||
},
|
||||
},
|
||||
{
|
||||
label: t('editor.deleteColumn'),
|
||||
action: () => {
|
||||
const tablePos = findTableAtCursor();
|
||||
if (!tablePos) return;
|
||||
|
||||
const cursorPos = textareaRef.current?.selectionStart || 0;
|
||||
const line = textareaRef.current?.value
|
||||
.substring(tablePos.start, cursorPos)
|
||||
.split('\n')
|
||||
.pop() || '';
|
||||
const colIndex = (line.match(/\|/g) || []).length - 1;
|
||||
|
||||
const newRows = tablePos.rows.map(row => {
|
||||
const newRow = [...row];
|
||||
newRow.splice(colIndex, 1);
|
||||
return newRow;
|
||||
});
|
||||
|
||||
const newAlignments = [...tablePos.alignments];
|
||||
newAlignments.splice(colIndex, 1);
|
||||
|
||||
tablePos.alignments = newAlignments;
|
||||
updateTable(tablePos, newRows);
|
||||
},
|
||||
},
|
||||
{
|
||||
label: t('editor.deleteTable'),
|
||||
action: () => {
|
||||
const tablePos = findTableAtCursor();
|
||||
if (!tablePos) return;
|
||||
|
||||
const text = textareaRef.current?.value || '';
|
||||
const newText = text.substring(0, tablePos.start) + text.substring(tablePos.end);
|
||||
onChange(newText);
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const toolbarButtons = [
|
||||
{
|
||||
icon: RiH1,
|
||||
label: t('editor.heading1'),
|
||||
shortcut: t('editor.heading1Shortcut'),
|
||||
action: (textarea: HTMLTextAreaElement) => {
|
||||
insertText(textarea, '# ');
|
||||
},
|
||||
},
|
||||
{
|
||||
icon: RiH2,
|
||||
label: t('editor.heading2'),
|
||||
shortcut: t('editor.heading2Shortcut'),
|
||||
action: (textarea: HTMLTextAreaElement) => {
|
||||
insertText(textarea, '## ');
|
||||
},
|
||||
},
|
||||
{
|
||||
icon: RiH3,
|
||||
label: t('editor.heading3'),
|
||||
shortcut: t('editor.heading3Shortcut'),
|
||||
action: (textarea: HTMLTextAreaElement) => {
|
||||
insertText(textarea, '### ');
|
||||
},
|
||||
},
|
||||
{
|
||||
icon: RiBold,
|
||||
label: t('editor.bold'),
|
||||
shortcut: t('editor.boldShortcut'),
|
||||
action: (textarea: HTMLTextAreaElement) => {
|
||||
const text = textarea.value.substring(
|
||||
textarea.selectionStart,
|
||||
textarea.selectionEnd
|
||||
);
|
||||
insertText(textarea, `**${text || t('editor.bold')}**`);
|
||||
},
|
||||
},
|
||||
{
|
||||
icon: RiItalic,
|
||||
label: t('editor.italic'),
|
||||
shortcut: t('editor.italicShortcut'),
|
||||
action: (textarea: HTMLTextAreaElement) => {
|
||||
const text = textarea.value.substring(
|
||||
textarea.selectionStart,
|
||||
textarea.selectionEnd
|
||||
);
|
||||
insertText(textarea, `_${text || t('editor.italic')}_`);
|
||||
},
|
||||
},
|
||||
{
|
||||
icon: RiListUnordered,
|
||||
label: t('editor.unorderedList'),
|
||||
action: (textarea: HTMLTextAreaElement) => {
|
||||
insertText(textarea, '- ');
|
||||
},
|
||||
},
|
||||
{
|
||||
icon: RiListOrdered,
|
||||
label: t('editor.orderedList'),
|
||||
action: (textarea: HTMLTextAreaElement) => {
|
||||
insertText(textarea, '1. ');
|
||||
},
|
||||
},
|
||||
{
|
||||
icon: RiDoubleQuotesL,
|
||||
label: t('editor.quote'),
|
||||
action: (textarea: HTMLTextAreaElement) => {
|
||||
insertText(textarea, '> ');
|
||||
},
|
||||
},
|
||||
{
|
||||
icon: RiLink,
|
||||
label: t('editor.link'),
|
||||
shortcut: t('editor.linkShortcut'),
|
||||
action: (textarea: HTMLTextAreaElement) => {
|
||||
const text = textarea.value.substring(
|
||||
textarea.selectionStart,
|
||||
textarea.selectionEnd
|
||||
);
|
||||
insertText(textarea, `[${text || t('editor.link')}](url)`);
|
||||
},
|
||||
},
|
||||
{
|
||||
icon: RiImage2Line,
|
||||
label: t('editor.image'),
|
||||
action: () => {
|
||||
const input = document.createElement('input');
|
||||
input.type = 'file';
|
||||
input.accept = 'image/*';
|
||||
input.onchange = (e) => {
|
||||
const file = (e.target as HTMLInputElement).files?.[0];
|
||||
if (file) {
|
||||
handleFileUpload(file);
|
||||
}
|
||||
};
|
||||
input.click();
|
||||
},
|
||||
},
|
||||
{
|
||||
icon: RiCodeLine,
|
||||
label: t('editor.inlineCode'),
|
||||
action: (textarea: HTMLTextAreaElement) => {
|
||||
const text = textarea.value.substring(
|
||||
textarea.selectionStart,
|
||||
textarea.selectionEnd
|
||||
);
|
||||
insertText(textarea, `\`${text || t('editor.inlineCode')}\``);
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const commonLanguages = [
|
||||
{ value: 'javascript', label: 'JavaScript' },
|
||||
{ value: 'typescript', label: 'TypeScript' },
|
||||
{ value: 'jsx', label: 'JSX' },
|
||||
{ value: 'tsx', label: 'TSX' },
|
||||
{ value: 'css', label: 'CSS' },
|
||||
{ value: 'html', label: 'HTML' },
|
||||
{ value: 'json', label: 'JSON' },
|
||||
{ value: 'markdown', label: 'Markdown' },
|
||||
{ value: 'python', label: 'Python' },
|
||||
{ value: 'java', label: 'Java' },
|
||||
{ value: 'c', label: 'C' },
|
||||
{ value: 'cpp', label: 'C++' },
|
||||
{ value: 'csharp', label: 'C#' },
|
||||
{ value: 'go', label: 'Go' },
|
||||
{ value: 'rust', label: 'Rust' },
|
||||
{ value: 'php', label: 'PHP' },
|
||||
{ value: 'ruby', label: 'Ruby' },
|
||||
{ value: 'swift', label: 'Swift' },
|
||||
{ value: 'kotlin', label: 'Kotlin' },
|
||||
{ value: 'sql', label: 'SQL' },
|
||||
{ value: 'shell', label: 'Shell' },
|
||||
{ value: 'yaml', label: 'YAML' },
|
||||
{ value: 'xml', label: 'XML' },
|
||||
];
|
||||
|
||||
const insertCodeBlock = (language?: string) => {
|
||||
if (!textareaRef.current) return;
|
||||
const lang = language || '';
|
||||
insertText(textareaRef.current, `\n\`\`\`${lang}\n\n\`\`\`\n`);
|
||||
const cursorPos = textareaRef.current.selectionStart - 4;
|
||||
textareaRef.current.setSelectionRange(cursorPos, cursorPos);
|
||||
setShowLangSelector(false);
|
||||
};
|
||||
|
||||
// 处理文件上传
|
||||
const handleFileUpload = async (file: File) => {
|
||||
if (!file.type.startsWith('image/')) {
|
||||
showToast('error', t('editor.uploadError', { error: 'Not an image file' }));
|
||||
return;
|
||||
}
|
||||
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
|
||||
try {
|
||||
const response = await axios.post('/api/v1/media', formData, {
|
||||
headers: {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
'Authorization': `Bearer ${localStorage.getItem('token')}`,
|
||||
},
|
||||
withCredentials: true,
|
||||
});
|
||||
|
||||
const imageUrl = response.data.data.url;
|
||||
const imageMarkdown = ``;
|
||||
|
||||
if (textareaRef.current) {
|
||||
const start = textareaRef.current.selectionStart;
|
||||
const end = textareaRef.current.selectionEnd;
|
||||
const before = value.substring(0, start);
|
||||
const after = value.substring(end);
|
||||
const newValue = before + imageMarkdown + after;
|
||||
onChange(newValue);
|
||||
showToast('success', t('editor.uploadSuccess'));
|
||||
}
|
||||
} catch (error: any) {
|
||||
console.error('Upload error:', error);
|
||||
const errorMessage = error.response?.data?.error?.message || error.message;
|
||||
showToast('error', t('editor.uploadError', { error: errorMessage }));
|
||||
}
|
||||
};
|
||||
|
||||
// 处理拖放事件
|
||||
const handleDragOver = (e: React.DragEvent) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
setIsDraggingOver(true);
|
||||
};
|
||||
|
||||
const handleDragLeave = (e: React.DragEvent) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
setIsDraggingOver(false);
|
||||
};
|
||||
|
||||
const handleDrop = async (e: React.DragEvent) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
setIsDraggingOver(false);
|
||||
|
||||
const files = Array.from(e.dataTransfer.files);
|
||||
for (const file of files) {
|
||||
await handleFileUpload(file);
|
||||
}
|
||||
};
|
||||
|
||||
// 处理粘贴事件
|
||||
const handlePaste = async (e: React.ClipboardEvent) => {
|
||||
const items = Array.from(e.clipboardData.items);
|
||||
for (const item of items) {
|
||||
if (item.type.startsWith('image/')) {
|
||||
e.preventDefault();
|
||||
const file = item.getAsFile();
|
||||
if (file) {
|
||||
await handleFileUpload(file);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
if (e.ctrlKey || e.metaKey) {
|
||||
const key = e.key.toLowerCase();
|
||||
switch (key) {
|
||||
case 'b': // 粗体
|
||||
e.preventDefault();
|
||||
if (textareaRef.current) {
|
||||
const text = textareaRef.current.value.substring(
|
||||
textareaRef.current.selectionStart,
|
||||
textareaRef.current.selectionEnd
|
||||
);
|
||||
insertText(textareaRef.current, `**${text}**`);
|
||||
}
|
||||
break;
|
||||
case 'i': // 斜体
|
||||
e.preventDefault();
|
||||
if (textareaRef.current) {
|
||||
const text = textareaRef.current.value.substring(
|
||||
textareaRef.current.selectionStart,
|
||||
textareaRef.current.selectionEnd
|
||||
);
|
||||
insertText(textareaRef.current, `_${text}_`);
|
||||
}
|
||||
break;
|
||||
case 'k': // 链接
|
||||
e.preventDefault();
|
||||
if (textareaRef.current) {
|
||||
const text = textareaRef.current.value.substring(
|
||||
textareaRef.current.selectionStart,
|
||||
textareaRef.current.selectionEnd
|
||||
);
|
||||
insertText(textareaRef.current, `[${text}](url)`);
|
||||
}
|
||||
break;
|
||||
case '1': // 标题 1
|
||||
if (e.shiftKey) {
|
||||
e.preventDefault();
|
||||
if (textareaRef.current) {
|
||||
const text = textareaRef.current.value.substring(
|
||||
textareaRef.current.selectionStart,
|
||||
textareaRef.current.selectionEnd
|
||||
);
|
||||
insertText(textareaRef.current, `# ${text}`);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case '2': // 标题 2
|
||||
if (e.shiftKey) {
|
||||
e.preventDefault();
|
||||
if (textareaRef.current) {
|
||||
const text = textareaRef.current.value.substring(
|
||||
textareaRef.current.selectionStart,
|
||||
textareaRef.current.selectionEnd
|
||||
);
|
||||
insertText(textareaRef.current, `## ${text}`);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case '3': // 标题 3
|
||||
if (e.shiftKey) {
|
||||
e.preventDefault();
|
||||
if (textareaRef.current) {
|
||||
const text = textareaRef.current.value.substring(
|
||||
textareaRef.current.selectionStart,
|
||||
textareaRef.current.selectionEnd
|
||||
);
|
||||
insertText(textareaRef.current, `### ${text}`);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case 'e': // 预览
|
||||
e.preventDefault();
|
||||
setIsPreview(!isPreview);
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const handleClickOutside = (event: MouseEvent) => {
|
||||
if (
|
||||
tableMenuRef.current &&
|
||||
!tableMenuRef.current.contains(event.target as Node)
|
||||
) {
|
||||
setShowTableMenu(false);
|
||||
}
|
||||
if (
|
||||
langSelectorRef.current &&
|
||||
!langSelectorRef.current.contains(event.target as Node)
|
||||
) {
|
||||
setShowLangSelector(false);
|
||||
}
|
||||
};
|
||||
|
||||
document.addEventListener('mousedown', handleClickOutside);
|
||||
return () => {
|
||||
document.removeEventListener('mousedown', handleClickOutside);
|
||||
};
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={classNames(
|
||||
'flex flex-col border border-slate-300 dark:border-slate-600 rounded-lg',
|
||||
{ 'fixed inset-0 z-50 bg-white dark:bg-slate-900': isFullscreen }
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-1 p-2 border-b border-slate-300 dark:border-slate-600 bg-slate-50 dark:bg-slate-800">
|
||||
{/* 左侧编辑工具栏 */}
|
||||
<div className="flex-1 flex items-center gap-1 border-r border-slate-300 dark:border-slate-600 pr-2">
|
||||
{toolbarButtons.map((button, index) => {
|
||||
return (
|
||||
<button
|
||||
key={index}
|
||||
onClick={() => textareaRef.current && button.action(textareaRef.current)}
|
||||
className="p-1.5 text-slate-600 dark:text-slate-400 hover:text-slate-900 dark:hover:text-white hover:bg-slate-200 dark:hover:bg-slate-700 rounded"
|
||||
title={`${button.label}${button.shortcut ? ` (${button.shortcut})` : ''}`}
|
||||
>
|
||||
<button.icon className="text-lg" />
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
|
||||
{/* 表格按钮 */}
|
||||
<div className="relative">
|
||||
<button
|
||||
onClick={() => setShowTableMenu(!showTableMenu)}
|
||||
className="p-1.5 text-slate-600 dark:text-slate-400 hover:text-slate-900 dark:hover:text-white hover:bg-slate-200 dark:hover:bg-slate-700 rounded"
|
||||
title={t('editor.table')}
|
||||
>
|
||||
<RiTableLine className="text-lg" />
|
||||
</button>
|
||||
{showTableMenu && (
|
||||
<div
|
||||
ref={tableMenuRef}
|
||||
className="absolute bg-white dark:bg-slate-800 border border-slate-300 dark:border-slate-600 rounded-lg shadow-lg z-50"
|
||||
style={{
|
||||
width: '200px',
|
||||
top: '100%',
|
||||
left: '50%',
|
||||
transform: 'translateX(-50%)',
|
||||
marginTop: '0.5rem',
|
||||
}}
|
||||
>
|
||||
{tableActions.map((action, index) => (
|
||||
<button
|
||||
key={index}
|
||||
onClick={() => {
|
||||
action.action();
|
||||
setShowTableMenu(false);
|
||||
}}
|
||||
className="w-full px-4 py-2 text-left text-sm text-slate-700 dark:text-slate-300 hover:bg-slate-100 dark:hover:bg-slate-700"
|
||||
>
|
||||
{t(action.label)}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* 代码块按钮 */}
|
||||
<div className="relative">
|
||||
<button
|
||||
onClick={() => setShowLangSelector(!showLangSelector)}
|
||||
className="p-1.5 text-slate-600 dark:text-slate-400 hover:text-slate-900 dark:hover:text-white hover:bg-slate-200 dark:hover:bg-slate-700 rounded"
|
||||
title={`${t('editor.codeBlock')} (${t('editor.codeBlockShortcut')})`}
|
||||
>
|
||||
<RiCodeBoxLine className="text-lg" />
|
||||
</button>
|
||||
{showLangSelector && (
|
||||
<div
|
||||
ref={langSelectorRef}
|
||||
className="absolute bg-white dark:bg-slate-800 border border-slate-300 dark:border-slate-600 rounded-lg shadow-lg z-50"
|
||||
style={{
|
||||
width: '200px',
|
||||
top: '100%',
|
||||
left: '50%',
|
||||
transform: 'translateX(-50%)',
|
||||
marginTop: '0.5rem',
|
||||
}}
|
||||
>
|
||||
<div className="px-4 py-1 text-sm font-medium text-slate-900 dark:text-white border-b border-slate-200 dark:border-slate-700">
|
||||
{t('editor.selectLanguage')}
|
||||
</div>
|
||||
<div className="max-h-64 overflow-y-auto">
|
||||
<button
|
||||
onClick={() => {
|
||||
insertCodeBlock();
|
||||
setShowLangSelector(false);
|
||||
}}
|
||||
className="w-full px-4 py-2 text-left text-sm text-slate-700 dark:text-slate-300 hover:bg-slate-100 dark:hover:bg-slate-700"
|
||||
>
|
||||
{t('editor.plainText')}
|
||||
</button>
|
||||
{commonLanguages.map(({ value, label }) => (
|
||||
<button
|
||||
key={value}
|
||||
onClick={() => {
|
||||
insertCodeBlock(value);
|
||||
setShowLangSelector(false);
|
||||
}}
|
||||
className="w-full px-4 py-2 text-left text-sm text-slate-700 dark:text-slate-300 hover:bg-slate-100 dark:hover:bg-slate-700"
|
||||
>
|
||||
{label}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* 右侧预览和全屏按钮 */}
|
||||
<div className="flex items-center gap-1 pl-2">
|
||||
<button
|
||||
onClick={() => setIsPreview(!isPreview)}
|
||||
className={classNames(
|
||||
'p-1.5 rounded',
|
||||
isPreview
|
||||
? 'text-slate-900 dark:text-white bg-slate-200 dark:bg-slate-700'
|
||||
: 'text-slate-600 dark:text-slate-400 hover:text-slate-900 dark:hover:text-white hover:bg-slate-200 dark:hover:bg-slate-700'
|
||||
)}
|
||||
title={t('editor.togglePreview')}
|
||||
>
|
||||
{isPreview ? (
|
||||
<RiEyeOffLine className="text-lg" />
|
||||
) : (
|
||||
<RiEyeLine className="text-lg" />
|
||||
)}
|
||||
</button>
|
||||
<button
|
||||
onClick={() => setIsFullscreen(!isFullscreen)}
|
||||
className="p-1.5 text-slate-600 dark:text-slate-400 hover:text-slate-900 dark:hover:text-white hover:bg-slate-200 dark:hover:bg-slate-700 rounded"
|
||||
title={t('editor.toggleFullscreen')}
|
||||
>
|
||||
<RiFullscreenLine className="text-lg" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex-1 flex">
|
||||
<div
|
||||
className={classNames(
|
||||
'flex-1 relative',
|
||||
{ hidden: isPreview },
|
||||
{ 'before:absolute before:inset-0 before:bg-slate-900/10 before:z-10 before:pointer-events-none': isDraggingOver }
|
||||
)}
|
||||
>
|
||||
{isDraggingOver && (
|
||||
<div className="absolute inset-0 flex items-center justify-center z-20 pointer-events-none">
|
||||
<div className="bg-white dark:bg-slate-800 rounded-lg shadow-lg px-4 py-2">
|
||||
{t('editor.dragAndDrop')}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
<textarea
|
||||
ref={textareaRef}
|
||||
value={value}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
className="w-full h-full p-4 bg-white dark:bg-slate-900 text-slate-900 dark:text-white resize-vertical focus:outline-none"
|
||||
style={{ minHeight: '200px' }}
|
||||
placeholder={placeholder}
|
||||
onDragOver={handleDragOver}
|
||||
onDragLeave={handleDragLeave}
|
||||
onDrop={handleDrop}
|
||||
onPaste={handlePaste}
|
||||
onKeyDown={handleKeyDown}
|
||||
/>
|
||||
</div>
|
||||
{isPreview && (
|
||||
<div className="flex-1 overflow-auto">
|
||||
<div className="prose dark:prose-invert max-w-none p-4">
|
||||
<ReactMarkdown
|
||||
remarkPlugins={[remarkGfm]}
|
||||
rehypePlugins={[rehypeRaw, rehypeSanitize, rehypeHighlight]}
|
||||
components={{
|
||||
code: ({ inline, className, children, ...props }) => {
|
||||
const match = /language-(\w+)/.exec(className || '');
|
||||
const lang = match ? match[1] : '';
|
||||
if (inline) {
|
||||
return (
|
||||
<code className="bg-slate-100 dark:bg-slate-800 text-slate-900 dark:text-slate-100 px-1 py-0.5 rounded" {...props}>
|
||||
{children}
|
||||
</code>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<div className="not-prose">
|
||||
<pre className={classNames(
|
||||
'bg-slate-100 dark:bg-slate-800 text-slate-900 dark:text-slate-100 p-4 rounded-lg overflow-x-auto',
|
||||
lang && `language-${lang}`
|
||||
)}>
|
||||
<code className={lang ? `language-${lang}` : ''} {...props}>
|
||||
{children}
|
||||
</code>
|
||||
</pre>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
pre: ({ children, ...props }) => children,
|
||||
p: ({ children, ...props }) => (
|
||||
<p className="text-slate-900 dark:text-slate-100 mb-4" {...props}>{children}</p>
|
||||
),
|
||||
h1: ({ children, ...props }) => (
|
||||
<h1 className="text-slate-900 dark:text-slate-100 text-4xl font-bold mt-6 mb-4" {...props}>{children}</h1>
|
||||
),
|
||||
h2: ({ children, ...props }) => (
|
||||
<h2 className="text-slate-900 dark:text-slate-100 text-3xl font-bold mt-5 mb-3" {...props}>{children}</h2>
|
||||
),
|
||||
h3: ({ children, ...props }) => (
|
||||
<h3 className="text-slate-900 dark:text-slate-100 text-2xl font-bold mt-4 mb-3" {...props}>{children}</h3>
|
||||
),
|
||||
h4: ({ children, ...props }) => (
|
||||
<h4 className="text-slate-900 dark:text-slate-100 text-xl font-bold mt-4 mb-2" {...props}>{children}</h4>
|
||||
),
|
||||
h5: ({ children, ...props }) => (
|
||||
<h5 className="text-slate-900 dark:text-slate-100 text-lg font-bold mt-3 mb-2" {...props}>{children}</h5>
|
||||
),
|
||||
h6: ({ children, ...props }) => (
|
||||
<h6 className="text-slate-900 dark:text-slate-100 text-base font-bold mt-3 mb-2" {...props}>{children}</h6>
|
||||
),
|
||||
a: ({ href, children, ...props }) => (
|
||||
<a href={href} className="text-blue-600 dark:text-blue-400 hover:underline" {...props}>{children}</a>
|
||||
),
|
||||
ul: ({ children, ...props }) => (
|
||||
<ul className="text-slate-900 dark:text-slate-100 list-disc pl-5 mb-4 space-y-1" {...props}>{children}</ul>
|
||||
),
|
||||
ol: ({ children, ...props }) => (
|
||||
<ol className="text-slate-900 dark:text-slate-100 list-decimal pl-5 mb-4 space-y-1" {...props}>{children}</ol>
|
||||
),
|
||||
li: ({ children, ...props }) => (
|
||||
<li className="text-slate-900 dark:text-slate-100" {...props}>{children}</li>
|
||||
),
|
||||
blockquote: ({ children, ...props }) => (
|
||||
<blockquote className="border-l-4 border-slate-300 dark:border-slate-600 pl-4 italic text-slate-700 dark:text-slate-300 my-4" {...props}>{children}</blockquote>
|
||||
),
|
||||
table: ({ children, ...props }) => (
|
||||
<div className="overflow-x-auto mb-4">
|
||||
<table className="min-w-full divide-y divide-slate-300 dark:divide-slate-600" {...props}>
|
||||
{children}
|
||||
</table>
|
||||
</div>
|
||||
),
|
||||
thead: ({ children, ...props }) => (
|
||||
<thead className="bg-slate-50 dark:bg-slate-800" {...props}>{children}</thead>
|
||||
),
|
||||
tbody: ({ children, ...props }) => (
|
||||
<tbody className="divide-y divide-slate-200 dark:divide-slate-700" {...props}>{children}</tbody>
|
||||
),
|
||||
tr: ({ children, ...props }) => (
|
||||
<tr {...props}>{children}</tr>
|
||||
),
|
||||
th: ({ children, ...props }) => (
|
||||
<th className="px-3 py-2 text-left text-sm font-semibold text-slate-900 dark:text-slate-100" {...props}>{children}</th>
|
||||
),
|
||||
td: ({ children, ...props }) => (
|
||||
<td className="px-3 py-2 text-sm text-slate-900 dark:text-slate-100" {...props}>{children}</td>
|
||||
),
|
||||
img: ({ src, alt, ...props }) => (
|
||||
<img src={src} alt={alt} className="max-w-full h-auto rounded-lg my-4" {...props} />
|
||||
),
|
||||
hr: (props) => (
|
||||
<hr className="border-t border-slate-300 dark:border-slate-600 my-8" {...props} />
|
||||
),
|
||||
}}
|
||||
>
|
||||
{value}
|
||||
</ReactMarkdown>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="p-2 text-xs text-slate-500 dark:text-slate-400 border-t border-slate-300 dark:border-slate-600">
|
||||
{t('editor.dropToUpload')}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default MarkdownEditor;
|
26
frontend/src/contexts/PageTitleContext.tsx
Normal file
26
frontend/src/contexts/PageTitleContext.tsx
Normal file
|
@ -0,0 +1,26 @@
|
|||
import { createContext, useContext, FC, ReactNode, useState } from 'react';
|
||||
|
||||
interface PageTitleContextType {
|
||||
title: string;
|
||||
setTitle: (title: string) => void;
|
||||
}
|
||||
|
||||
const PageTitleContext = createContext<PageTitleContextType | undefined>(undefined);
|
||||
|
||||
export const PageTitleProvider: FC<{ children: ReactNode }> = ({ children }) => {
|
||||
const [title, setTitle] = useState('');
|
||||
|
||||
return (
|
||||
<PageTitleContext.Provider value={{ title, setTitle }}>
|
||||
{children}
|
||||
</PageTitleContext.Provider>
|
||||
);
|
||||
};
|
||||
|
||||
export const usePageTitle = () => {
|
||||
const context = useContext(PageTitleContext);
|
||||
if (context === undefined) {
|
||||
throw new Error('usePageTitle must be used within a PageTitleProvider');
|
||||
}
|
||||
return context;
|
||||
};
|
30
frontend/src/hooks/useToast.ts
Normal file
30
frontend/src/hooks/useToast.ts
Normal file
|
@ -0,0 +1,30 @@
|
|||
import { useCallback } from 'react';
|
||||
import { toast, ToastOptions } from 'react-toastify';
|
||||
|
||||
type ToastType = 'success' | 'error' | 'info' | 'warning';
|
||||
|
||||
const darkModeToastStyle: ToastOptions = {
|
||||
theme: 'dark',
|
||||
style: {
|
||||
background: '#1e293b', // slate-800
|
||||
color: '#f1f5f9', // slate-100
|
||||
},
|
||||
};
|
||||
|
||||
export const useToast = () => {
|
||||
const isDarkMode = document.documentElement.classList.contains('dark');
|
||||
|
||||
const showToast = useCallback((type: ToastType, message: string) => {
|
||||
toast[type](message, {
|
||||
position: 'bottom-right',
|
||||
autoClose: 3000,
|
||||
hideProgressBar: false,
|
||||
closeOnClick: true,
|
||||
pauseOnHover: true,
|
||||
draggable: true,
|
||||
...(isDarkMode ? darkModeToastStyle : {}),
|
||||
});
|
||||
}, [isDarkMode]);
|
||||
|
||||
return { showToast };
|
||||
};
|
|
@ -19,6 +19,7 @@ import { useTheme } from '../../../hooks/useTheme';
|
|||
import { Suspense } from 'react';
|
||||
import LoadingSpinner from '../../../components/LoadingSpinner';
|
||||
import { useUser } from '../../../contexts/UserContext';
|
||||
import { PageTitleProvider, usePageTitle } from '../../../contexts/PageTitleContext';
|
||||
|
||||
interface AdminLayoutProps {}
|
||||
|
||||
|
@ -56,22 +57,58 @@ const languageMap: LanguageMap = {
|
|||
'zh-Hant': 'en'
|
||||
};
|
||||
|
||||
const AdminLayout: FC<AdminLayoutProps> = () => {
|
||||
const AdminLayoutContent: FC = () => {
|
||||
const { t, i18n } = useTranslation();
|
||||
const location = useLocation();
|
||||
const navigate = useNavigate();
|
||||
const { theme, setTheme } = useTheme();
|
||||
const { user, loading, error } = useUser();
|
||||
const { user, loading, error, fetchUser } = useUser();
|
||||
const { title } = usePageTitle();
|
||||
|
||||
useEffect(() => {
|
||||
console.log('AdminLayout user:', user);
|
||||
console.log('AdminLayout loading:', loading);
|
||||
console.log('AdminLayout error:', error);
|
||||
}, [user, loading, error]);
|
||||
// 如果没有 token,重定向到登录页
|
||||
if (!localStorage.getItem('token')) {
|
||||
navigate('/admin/login');
|
||||
return;
|
||||
}
|
||||
|
||||
const handleLogout = () => {
|
||||
localStorage.removeItem('token');
|
||||
navigate('/admin/login');
|
||||
// 如果没有用户信息且没有在加载中,尝试获取用户信息
|
||||
if (!user && !loading && !error) {
|
||||
fetchUser();
|
||||
}
|
||||
|
||||
// 如果获取用户信息出错,可能是 token 过期,重定向到登录页
|
||||
if (error) {
|
||||
localStorage.removeItem('token');
|
||||
navigate('/admin/login');
|
||||
}
|
||||
}, [user, loading, error, navigate, fetchUser]);
|
||||
|
||||
const handleLogout = async () => {
|
||||
try {
|
||||
// 调用后端登出接口
|
||||
const token = localStorage.getItem('token');
|
||||
if (token) {
|
||||
await fetch('/api/v1/auth/logout', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Authorization': `Bearer ${token}`,
|
||||
},
|
||||
});
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Logout error:', err);
|
||||
} finally {
|
||||
// 清除所有认证相关的存储数据
|
||||
localStorage.removeItem('token');
|
||||
localStorage.removeItem('username');
|
||||
|
||||
// 重置用户状态
|
||||
await fetchUser();
|
||||
|
||||
// 重定向到登录页
|
||||
navigate('/admin/login');
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
|
@ -161,7 +198,7 @@ const AdminLayout: FC<AdminLayoutProps> = () => {
|
|||
<div className="h-16 px-8 flex items-center justify-between">
|
||||
<div>
|
||||
<h2 className="text-2xl font-bold text-slate-800 dark:text-white">
|
||||
{t(menuItems.find(item => item.path === location.pathname)?.label || 'admin.nav.dashboard')}
|
||||
{title || t(menuItems.find(item => item.path === location.pathname)?.label || 'admin.nav.dashboard')}
|
||||
</h2>
|
||||
</div>
|
||||
<div className="flex items-center gap-3">
|
||||
|
@ -178,7 +215,7 @@ const AdminLayout: FC<AdminLayoutProps> = () => {
|
|||
</header>
|
||||
<div className="flex-1 p-6">
|
||||
<div className="h-full bg-white dark:bg-slate-800 rounded-lg shadow-sm border border-slate-200/60 dark:border-slate-700/60">
|
||||
<Suspense fallback={<LoadingSpinner />}>
|
||||
<Suspense fallback={<LoadingSpinner fullScreen />}>
|
||||
<Outlet />
|
||||
</Suspense>
|
||||
</div>
|
||||
|
@ -189,4 +226,12 @@ const AdminLayout: FC<AdminLayoutProps> = () => {
|
|||
);
|
||||
};
|
||||
|
||||
const AdminLayout: FC<AdminLayoutProps> = () => {
|
||||
return (
|
||||
<PageTitleProvider>
|
||||
<AdminLayoutContent />
|
||||
</PageTitleProvider>
|
||||
);
|
||||
};
|
||||
|
||||
export default AdminLayout;
|
||||
|
|
|
@ -4,6 +4,7 @@ import { FiUser, FiLock, FiSun, FiMoon, FiMonitor, FiGlobe } from 'react-icons/f
|
|||
import { useTranslation } from 'react-i18next';
|
||||
import { useTheme } from '../../hooks/useTheme';
|
||||
import { Menu } from '@headlessui/react';
|
||||
import { useUser } from '../../contexts/UserContext';
|
||||
|
||||
interface LoginFormData {
|
||||
username: string;
|
||||
|
@ -27,6 +28,7 @@ export default function Login() {
|
|||
const navigate = useNavigate();
|
||||
const { t, i18n } = useTranslation();
|
||||
const { theme, setTheme } = useTheme();
|
||||
const { fetchUser } = useUser();
|
||||
const [formData, setFormData] = useState<LoginFormData>({
|
||||
username: '',
|
||||
password: '',
|
||||
|
@ -66,6 +68,9 @@ export default function Login() {
|
|||
localStorage.removeItem('username');
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
await fetchUser();
|
||||
|
||||
// 跳转到管理面板
|
||||
navigate('/admin');
|
||||
} catch (err) {
|
||||
|
|
296
frontend/src/pages/admin/posts/PostEditor.tsx
Normal file
296
frontend/src/pages/admin/posts/PostEditor.tsx
Normal file
|
@ -0,0 +1,296 @@
|
|||
import { useState, useEffect, FC } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useNavigate, useParams, useBlocker } from 'react-router-dom';
|
||||
import LoadingSpinner from '../../../components/LoadingSpinner';
|
||||
import { usePageTitle } from '../../../contexts/PageTitleContext';
|
||||
import MarkdownEditor from '../../../components/admin/MarkdownEditor';
|
||||
import classNames from 'classnames';
|
||||
|
||||
interface PostContent {
|
||||
language_code: 'en' | 'zh-Hans' | 'zh-Hant';
|
||||
title: string;
|
||||
content_markdown: string;
|
||||
summary?: string;
|
||||
meta_keywords?: string;
|
||||
meta_description?: string;
|
||||
}
|
||||
|
||||
interface Category {
|
||||
id: number;
|
||||
contents: Array<{
|
||||
language_code: string;
|
||||
name: string;
|
||||
slug: string;
|
||||
description?: string;
|
||||
}>;
|
||||
}
|
||||
|
||||
interface Post {
|
||||
id?: number;
|
||||
slug: string;
|
||||
status: 'draft' | 'published';
|
||||
contents: PostContent[];
|
||||
categories: Category[];
|
||||
created_at?: string;
|
||||
updated_at?: string;
|
||||
}
|
||||
|
||||
const LANGUAGES = [
|
||||
{ code: 'en', label: 'English' },
|
||||
{ code: 'zh-Hans', label: '简体中文' },
|
||||
{ code: 'zh-Hant', label: '繁體中文' }
|
||||
] as const;
|
||||
|
||||
const PostEditor: FC = () => {
|
||||
const { t } = useTranslation();
|
||||
const navigate = useNavigate();
|
||||
const { postId } = useParams<{ postId: string }>();
|
||||
const isEditing = !!postId;
|
||||
const { setTitle } = usePageTitle();
|
||||
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [saving, setSaving] = useState(false);
|
||||
const [isDirty, setIsDirty] = useState(false);
|
||||
const [activeTab, setActiveTab] = useState<'en' | 'zh-Hans' | 'zh-Hant'>('en');
|
||||
const [post, setPost] = useState<Post>({
|
||||
slug: '',
|
||||
status: 'draft',
|
||||
contents: [
|
||||
{
|
||||
language_code: 'en',
|
||||
title: '',
|
||||
content_markdown: ''
|
||||
},
|
||||
{
|
||||
language_code: 'zh-Hans',
|
||||
title: '',
|
||||
content_markdown: ''
|
||||
},
|
||||
{
|
||||
language_code: 'zh-Hant',
|
||||
title: '',
|
||||
content_markdown: ''
|
||||
}
|
||||
],
|
||||
categories: []
|
||||
});
|
||||
|
||||
const blocker = useBlocker(
|
||||
({ currentLocation, nextLocation }) =>
|
||||
isDirty && currentLocation.pathname !== nextLocation.pathname
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (blocker.state === 'blocked') {
|
||||
const confirmed = window.confirm(t('admin.common.unsavedChanges'));
|
||||
if (confirmed) {
|
||||
blocker.proceed();
|
||||
} else {
|
||||
blocker.reset();
|
||||
}
|
||||
}
|
||||
}, [blocker, t]);
|
||||
|
||||
useEffect(() => {
|
||||
setTitle(isEditing ? t('admin.posts.edit') : t('admin.posts.create'));
|
||||
}, [isEditing, t, setTitle]);
|
||||
|
||||
useEffect(() => {
|
||||
if (isEditing) {
|
||||
fetchPost();
|
||||
}
|
||||
}, [postId]);
|
||||
|
||||
const fetchPost = async () => {
|
||||
try {
|
||||
setLoading(true);
|
||||
const response = await fetch(`/api/v1/posts/${postId}`);
|
||||
if (!response.ok) throw new Error('Failed to fetch post');
|
||||
const data = await response.json();
|
||||
setPost(data.data);
|
||||
setIsDirty(false);
|
||||
} catch (error) {
|
||||
console.error('Error fetching post:', error);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSubmit = async (publish: boolean = false) => {
|
||||
try {
|
||||
setSaving(true);
|
||||
const method = isEditing ? 'PUT' : 'POST';
|
||||
const url = isEditing ? `/api/v1/posts/${postId}` : '/api/v1/posts';
|
||||
|
||||
const response = await fetch(url, {
|
||||
method,
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
...post,
|
||||
status: publish ? 'published' : 'draft'
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) throw new Error('Failed to save post');
|
||||
|
||||
setIsDirty(false);
|
||||
navigate('/admin/posts');
|
||||
} catch (error) {
|
||||
console.error('Error saving post:', error);
|
||||
} finally {
|
||||
setSaving(false);
|
||||
}
|
||||
};
|
||||
|
||||
const activeContent = post.contents.find(c => c.language_code === activeTab)!;
|
||||
const updateContent = (updates: Partial<PostContent>) => {
|
||||
setIsDirty(true);
|
||||
const newContents = post.contents.map(content =>
|
||||
content.language_code === activeTab
|
||||
? { ...content, ...updates }
|
||||
: content
|
||||
);
|
||||
setPost(prev => ({ ...prev, contents: newContents }));
|
||||
};
|
||||
|
||||
const handleSlugChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
setIsDirty(true);
|
||||
setPost(prev => ({ ...prev, slug: e.target.value }));
|
||||
};
|
||||
|
||||
if (loading) {
|
||||
return <LoadingSpinner />;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="p-6">
|
||||
<div className="space-y-6">
|
||||
{/* Slug */}
|
||||
<div>
|
||||
<label className="block text-sm font-medium mb-2 text-slate-900 dark:text-white">
|
||||
{t('admin.posts.slug')}
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={post.slug}
|
||||
onChange={handleSlugChange}
|
||||
className="w-full px-3 py-2 bg-white dark:bg-slate-800 text-slate-900 dark:text-white border border-slate-300 dark:border-slate-600 rounded-md focus:outline-none focus:ring-2 focus:ring-slate-500 dark:focus:ring-slate-400"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Language Tabs */}
|
||||
<div className="border-b border-slate-200 dark:border-slate-700">
|
||||
<nav className="-mb-px flex space-x-8" aria-label="Language">
|
||||
{LANGUAGES.map(({ code, label }) => (
|
||||
<button
|
||||
key={code}
|
||||
onClick={() => setActiveTab(code)}
|
||||
className={`whitespace-nowrap py-4 px-1 border-b-2 font-medium text-sm ${
|
||||
activeTab === code
|
||||
? 'border-slate-900 dark:border-white text-slate-900 dark:text-white'
|
||||
: 'border-transparent text-slate-500 dark:text-slate-400 hover:text-slate-700 dark:hover:text-slate-300 hover:border-slate-300 dark:hover:border-slate-600'
|
||||
}`}
|
||||
>
|
||||
{label}
|
||||
</button>
|
||||
))}
|
||||
</nav>
|
||||
</div>
|
||||
|
||||
{/* Content Fields */}
|
||||
<div className="space-y-4">
|
||||
{/* Title */}
|
||||
<div>
|
||||
<label className="block text-sm font-medium mb-2 text-slate-900 dark:text-white">
|
||||
{t('admin.posts.title')}
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={activeContent.title}
|
||||
onChange={e => updateContent({ title: e.target.value })}
|
||||
className="w-full px-3 py-2 bg-white dark:bg-slate-800 text-slate-900 dark:text-white border border-slate-300 dark:border-slate-600 rounded-md focus:outline-none focus:ring-2 focus:ring-slate-500 dark:focus:ring-slate-400"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Content */}
|
||||
<div>
|
||||
<label className="block text-sm font-medium mb-2 text-slate-900 dark:text-white">
|
||||
{t('admin.posts.content')}
|
||||
</label>
|
||||
<MarkdownEditor
|
||||
value={activeContent.content_markdown}
|
||||
onChange={(value) => updateContent({ content_markdown: value })}
|
||||
placeholder={t('admin.posts.content')}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Summary */}
|
||||
<div>
|
||||
<label className="block text-sm font-medium mb-2 text-slate-900 dark:text-white">
|
||||
{t('admin.posts.summary')}
|
||||
</label>
|
||||
<textarea
|
||||
value={activeContent.summary || ''}
|
||||
onChange={e => updateContent({ summary: e.target.value })}
|
||||
rows={3}
|
||||
className="w-full px-3 py-2 bg-white dark:bg-slate-800 text-slate-900 dark:text-white border border-slate-300 dark:border-slate-600 rounded-md focus:outline-none focus:ring-2 focus:ring-slate-500 dark:focus:ring-slate-400"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Meta Keywords */}
|
||||
<div>
|
||||
<label className="block text-sm font-medium mb-2 text-slate-900 dark:text-white">
|
||||
{t('admin.posts.metaKeywords')}
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={activeContent.meta_keywords || ''}
|
||||
onChange={e => updateContent({ meta_keywords: e.target.value })}
|
||||
className="w-full px-3 py-2 bg-white dark:bg-slate-800 text-slate-900 dark:text-white border border-slate-300 dark:border-slate-600 rounded-md focus:outline-none focus:ring-2 focus:ring-slate-500 dark:focus:ring-slate-400"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Meta Description */}
|
||||
<div>
|
||||
<label className="block text-sm font-medium mb-2 text-slate-900 dark:text-white">
|
||||
{t('admin.posts.metaDescription')}
|
||||
</label>
|
||||
<textarea
|
||||
value={activeContent.meta_description || ''}
|
||||
onChange={e => updateContent({ meta_description: e.target.value })}
|
||||
rows={2}
|
||||
className="w-full px-3 py-2 bg-white dark:bg-slate-800 text-slate-900 dark:text-white border border-slate-300 dark:border-slate-600 rounded-md focus:outline-none focus:ring-2 focus:ring-slate-500 dark:focus:ring-slate-400"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Categories */}
|
||||
{/* TODO: Add category selection */}
|
||||
|
||||
{/* Actions */}
|
||||
<div className="flex justify-end space-x-4">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => handleSubmit(false)}
|
||||
disabled={saving}
|
||||
className="px-4 py-2 text-sm font-medium text-slate-700 dark:text-slate-300 bg-white dark:bg-slate-800 border border-slate-300 dark:border-slate-600 rounded-md hover:bg-slate-50 dark:hover:bg-slate-700 focus:outline-none focus:ring-2 focus:ring-slate-500 dark:focus:ring-slate-400 disabled:opacity-50"
|
||||
>
|
||||
{saving ? t('admin.posts.saving') : t('admin.posts.saveDraft')}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => handleSubmit(true)}
|
||||
disabled={saving}
|
||||
className="px-4 py-2 text-sm font-medium text-white bg-slate-900 dark:bg-slate-700 rounded-md hover:bg-slate-800 dark:hover:bg-slate-600 focus:outline-none focus:ring-2 focus:ring-slate-500 dark:focus:ring-slate-400 disabled:opacity-50"
|
||||
>
|
||||
{saving ? t('admin.posts.publishing') : t('admin.posts.publish')}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default PostEditor;
|
|
@ -1,73 +1,183 @@
|
|||
import { useState } from 'react';
|
||||
import { useState, useEffect, FC, ReactNode } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import Table from '../../../components/admin/Table';
|
||||
import TableActions from '../../../components/admin/TableActions';
|
||||
import { RiAddLine, RiSearchLine } from 'react-icons/ri';
|
||||
import { RiSearchLine } from 'react-icons/ri';
|
||||
import dayjs from 'dayjs';
|
||||
import relativeTime from 'dayjs/plugin/relativeTime';
|
||||
import 'dayjs/locale/zh-cn';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
import { usePageTitle } from '../../../contexts/PageTitleContext';
|
||||
|
||||
// 初始化 dayjs 插件
|
||||
dayjs.extend(relativeTime);
|
||||
dayjs.locale('zh-cn');
|
||||
|
||||
interface PostContent {
|
||||
language_code: 'en' | 'zh-Hans' | 'zh-Hant';
|
||||
title: string;
|
||||
content_markdown: string;
|
||||
summary?: string;
|
||||
meta_keywords?: string;
|
||||
meta_description?: string;
|
||||
}
|
||||
|
||||
interface Category {
|
||||
id: number;
|
||||
contents: Array<{
|
||||
language_code: string;
|
||||
name: string;
|
||||
slug: string;
|
||||
description?: string;
|
||||
}>;
|
||||
}
|
||||
|
||||
interface Post {
|
||||
id: string;
|
||||
title: string;
|
||||
category: string;
|
||||
publishDate: string;
|
||||
id: number;
|
||||
slug: string;
|
||||
status: 'draft' | 'published';
|
||||
contents: PostContent[];
|
||||
categories: Category[];
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
}
|
||||
|
||||
type TablePost = {
|
||||
id: number;
|
||||
slug: string;
|
||||
status: 'draft' | 'published';
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
displayTitle: string;
|
||||
displayCategories: string;
|
||||
};
|
||||
|
||||
const PostsManagement: FC = () => {
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [posts, setPosts] = useState<TablePost[]>([]);
|
||||
const { t } = useTranslation();
|
||||
const navigate = useNavigate();
|
||||
const { setTitle } = usePageTitle();
|
||||
|
||||
// 这里后续会通过 API 获取数据
|
||||
const posts: Post[] = [
|
||||
{
|
||||
id: '1',
|
||||
title: '示例文章标题',
|
||||
category: '示例分类',
|
||||
publishDate: '2024-02-20',
|
||||
status: 'published',
|
||||
},
|
||||
];
|
||||
useEffect(() => {
|
||||
setTitle(t('admin.nav.posts'));
|
||||
return () => setTitle('');
|
||||
}, [setTitle, t]);
|
||||
|
||||
const handleEdit = (post: Post) => {
|
||||
console.log('Edit post:', post);
|
||||
useEffect(() => {
|
||||
fetchPosts();
|
||||
}, []);
|
||||
|
||||
const fetchPosts = async () => {
|
||||
try {
|
||||
setLoading(true);
|
||||
const response = await fetch('/api/v1/posts?sort=-created_at');
|
||||
if (!response.ok) throw new Error('Failed to fetch posts');
|
||||
const data = await response.json();
|
||||
|
||||
// 处理数据,添加显示字段
|
||||
const processedPosts = data.data.map((post: Post): TablePost => ({
|
||||
id: post.id,
|
||||
slug: post.slug,
|
||||
status: post.status,
|
||||
created_at: post.created_at,
|
||||
updated_at: post.updated_at,
|
||||
displayTitle: getPostTitle(post),
|
||||
displayCategories: getCategoryNames(post)
|
||||
}));
|
||||
|
||||
setPosts(processedPosts);
|
||||
} catch (error) {
|
||||
console.error('Error fetching posts:', error);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDelete = (post: Post) => {
|
||||
console.log('Delete post:', post);
|
||||
const getPostTitle = (post: Post) => {
|
||||
// Try to get English title first
|
||||
const englishContent = post.contents.find(c => c.language_code === 'en');
|
||||
if (englishContent?.title) return englishContent.title;
|
||||
|
||||
// Fallback to the first available title
|
||||
const firstContent = post.contents[0];
|
||||
return firstContent?.title || t('admin.posts.noTitle');
|
||||
};
|
||||
|
||||
const getCategoryNames = (post: Post) => {
|
||||
return post.categories
|
||||
.map(category => {
|
||||
const englishContent = category.contents.find(c => c.language_code === 'en');
|
||||
if (englishContent?.name) return englishContent.name;
|
||||
return category.contents[0]?.name || '';
|
||||
})
|
||||
.filter(Boolean)
|
||||
.join(', ');
|
||||
};
|
||||
|
||||
const handleEdit = async (post: TablePost) => {
|
||||
navigate(post.slug);
|
||||
};
|
||||
|
||||
const handleDelete = async (post: TablePost) => {
|
||||
if (!window.confirm(t('admin.posts.deleteConfirm'))) return;
|
||||
|
||||
try {
|
||||
const response = await fetch(`/api/v1/posts/${post.slug}`, {
|
||||
method: 'DELETE',
|
||||
});
|
||||
|
||||
if (!response.ok) throw new Error('Failed to delete post');
|
||||
|
||||
await fetchPosts();
|
||||
} catch (error) {
|
||||
console.error('Error deleting post:', error);
|
||||
}
|
||||
};
|
||||
|
||||
const columns = [
|
||||
{
|
||||
key: 'title' as keyof Post,
|
||||
key: 'displayTitle' as keyof TablePost,
|
||||
title: t('admin.posts.title'),
|
||||
render: (value: string): ReactNode => value
|
||||
},
|
||||
{
|
||||
key: 'category' as keyof Post,
|
||||
title: t('admin.posts.category'),
|
||||
key: 'displayCategories' as keyof TablePost,
|
||||
title: t('admin.posts.categories'),
|
||||
render: (value: string): ReactNode => value
|
||||
},
|
||||
{
|
||||
key: 'publishDate' as keyof Post,
|
||||
title: t('admin.posts.publishDate'),
|
||||
},
|
||||
{
|
||||
key: 'status' as keyof Post,
|
||||
title: t('admin.posts.status'),
|
||||
render: (value: Post['status']) => (
|
||||
<span
|
||||
className={`inline-block px-2 py-1 text-xs font-medium rounded-full ${
|
||||
value === 'published'
|
||||
? 'bg-green-100 text-green-700 dark:bg-green-900 dark:text-green-300'
|
||||
: 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900 dark:text-yellow-300'
|
||||
}`}
|
||||
>
|
||||
{t(`admin.common.${value}`)}
|
||||
key: 'created_at' as keyof TablePost,
|
||||
title: t('admin.posts.createdAt'),
|
||||
render: (value: string): ReactNode => (
|
||||
<span title={dayjs(value).format('YYYY-MM-DD HH:mm:ss')}>
|
||||
{dayjs(value).fromNow()}
|
||||
</span>
|
||||
),
|
||||
)
|
||||
},
|
||||
{
|
||||
key: 'status' as keyof TablePost,
|
||||
title: t('admin.posts.status'),
|
||||
render: (value: string | number): ReactNode => {
|
||||
const status = value as TablePost['status'];
|
||||
return (
|
||||
<span
|
||||
className={`inline-block px-2 py-1 text-xs font-medium rounded-full ${
|
||||
status === 'published'
|
||||
? 'bg-green-100 text-green-700 dark:bg-green-900 dark:text-green-300'
|
||||
: 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900 dark:text-yellow-300'
|
||||
}`}
|
||||
>
|
||||
{t(`admin.common.${status}`)}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
}
|
||||
];
|
||||
|
||||
const handleCreate = () => {
|
||||
// TODO: 实现创建文章的逻辑
|
||||
navigate('new');
|
||||
};
|
||||
|
||||
return (
|
||||
|
@ -89,59 +199,13 @@ const PostsManagement: FC = () => {
|
|||
</div>
|
||||
|
||||
<div>
|
||||
<Table<Post>
|
||||
<Table<TablePost>
|
||||
columns={columns}
|
||||
data={posts}
|
||||
loading={loading}
|
||||
onEdit={handleEdit}
|
||||
onDelete={handleDelete}
|
||||
>
|
||||
{({ data, onEdit, onDelete }) => (
|
||||
<table className="w-full">
|
||||
<thead>
|
||||
<tr className="border-b border-slate-200 dark:border-slate-700">
|
||||
<th className="text-left py-4 px-6 text-slate-600 dark:text-slate-400 font-medium rounded-tl-md">{t('admin.common.title')}</th>
|
||||
<th className="text-left py-4 px-6 text-slate-600 dark:text-slate-400 font-medium">{t('admin.common.category')}</th>
|
||||
<th className="text-left py-4 px-6 text-slate-600 dark:text-slate-400 font-medium">{t('admin.common.publishDate')}</th>
|
||||
<th className="text-left py-4 px-6 text-slate-600 dark:text-slate-400 font-medium">{t('admin.common.status')}</th>
|
||||
<th className="text-right py-4 px-6 text-slate-600 dark:text-slate-400 font-medium rounded-tr-md">{t('admin.common.actions')}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{data.map((post) => (
|
||||
<tr key={post.id} className="border-b border-slate-200 dark:border-slate-700 last:border-0">
|
||||
<td className="py-4 px-6 text-slate-800 dark:text-slate-200">{post.title}</td>
|
||||
<td className="py-4 px-6 text-slate-800 dark:text-slate-200">{post.category}</td>
|
||||
<td className="py-4 px-6 text-slate-800 dark:text-slate-200">{post.publishDate}</td>
|
||||
<td className="py-4 px-6">
|
||||
<span className={`px-2 py-1 text-sm font-medium rounded-full ${
|
||||
post.status === 'published'
|
||||
? 'bg-green-100 text-green-700 dark:bg-green-900 dark:text-green-300'
|
||||
: 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900 dark:text-yellow-300'
|
||||
}`}>
|
||||
{t(`admin.common.${post.status}`)}
|
||||
</span>
|
||||
</td>
|
||||
<td className="py-4 px-6 text-right">
|
||||
<button
|
||||
className="text-slate-600 dark:text-slate-400 hover:text-slate-900 dark:hover:text-slate-200 px-2 py-1 rounded-md transition-colors"
|
||||
onClick={() => onEdit(post)}
|
||||
>
|
||||
{t('admin.common.edit')}
|
||||
</button>
|
||||
<button
|
||||
className="text-red-600 dark:text-red-400 hover:text-red-700 dark:hover:text-red-300 px-2 py-1 rounded-md transition-colors"
|
||||
onClick={() => onDelete(post)}
|
||||
>
|
||||
{t('admin.common.delete')}
|
||||
</button>
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
)}
|
||||
</Table>
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
|
|
@ -14,6 +14,7 @@ const Footer = lazy(() => import('./components/Footer'));
|
|||
// 管理页面组件
|
||||
const Dashboard = lazy(() => import('./pages/admin/dashboard/Dashboard'));
|
||||
const PostsManagement = lazy(() => import('./pages/admin/posts/PostsManagement'));
|
||||
const PostEditor = lazy(() => import('./pages/admin/posts/PostEditor'));
|
||||
const DailyManagement = lazy(() => import('./pages/admin/daily/DailyManagement'));
|
||||
const MediasManagement = lazy(() => import('./pages/admin/medias/MediasManagement'));
|
||||
const CategoriesManagement = lazy(() => import('./pages/admin/categories/CategoriesManagement'));
|
||||
|
@ -56,7 +57,7 @@ const LoginRoute = () => {
|
|||
}
|
||||
|
||||
return (
|
||||
<Suspense fallback={<LoadingSpinner />}>
|
||||
<Suspense fallback={<LoadingSpinner fullScreen />}>
|
||||
<Login />
|
||||
</Suspense>
|
||||
);
|
||||
|
@ -65,7 +66,7 @@ const LoginRoute = () => {
|
|||
// 页面布局组件
|
||||
const PageLayout = () => (
|
||||
<div className="flex flex-col min-h-screen bg-white dark:bg-neutral-900 text-gray-900 dark:text-gray-100">
|
||||
<Suspense fallback={<LoadingSpinner />}>
|
||||
<Suspense fallback={<LoadingSpinner fullScreen />}>
|
||||
<Header />
|
||||
<div className="w-[95%] mx-auto">
|
||||
<div className="border-t-2 border-gray-900 dark:border-gray-100 w-full mb-2" />
|
||||
|
@ -82,7 +83,7 @@ const router = createBrowserRouter([
|
|||
{
|
||||
path: '/',
|
||||
element: (
|
||||
<Suspense fallback={<LoadingSpinner />}>
|
||||
<Suspense fallback={<LoadingSpinner fullScreen />}>
|
||||
<PageLayout />
|
||||
</Suspense>
|
||||
),
|
||||
|
@ -90,7 +91,7 @@ const router = createBrowserRouter([
|
|||
{
|
||||
index: true,
|
||||
element: (
|
||||
<Suspense fallback={<LoadingSpinner />}>
|
||||
<Suspense fallback={<LoadingSpinner fullScreen />}>
|
||||
<Home />
|
||||
</Suspense>
|
||||
),
|
||||
|
@ -98,7 +99,7 @@ const router = createBrowserRouter([
|
|||
{
|
||||
path: 'daily',
|
||||
element: (
|
||||
<Suspense fallback={<LoadingSpinner />}>
|
||||
<Suspense fallback={<LoadingSpinner fullScreen />}>
|
||||
<Daily />
|
||||
</Suspense>
|
||||
),
|
||||
|
@ -106,7 +107,7 @@ const router = createBrowserRouter([
|
|||
{
|
||||
path: 'posts/:articleId',
|
||||
element: (
|
||||
<Suspense fallback={<LoadingSpinner />}>
|
||||
<Suspense fallback={<LoadingSpinner fullScreen />}>
|
||||
<Article />
|
||||
</Suspense>
|
||||
),
|
||||
|
@ -133,23 +134,44 @@ const router = createBrowserRouter([
|
|||
{
|
||||
index: true,
|
||||
element: (
|
||||
<Suspense fallback={<LoadingSpinner />}>
|
||||
<Suspense fallback={<LoadingSpinner fullScreen />}>
|
||||
<Dashboard />
|
||||
</Suspense>
|
||||
),
|
||||
},
|
||||
{
|
||||
path: 'posts',
|
||||
element: (
|
||||
<Suspense fallback={<LoadingSpinner />}>
|
||||
<PostsManagement />
|
||||
</Suspense>
|
||||
),
|
||||
children: [
|
||||
{
|
||||
index: true,
|
||||
element: (
|
||||
<Suspense fallback={<LoadingSpinner fullScreen />}>
|
||||
<PostsManagement />
|
||||
</Suspense>
|
||||
),
|
||||
},
|
||||
{
|
||||
path: 'new',
|
||||
element: (
|
||||
<Suspense fallback={<LoadingSpinner fullScreen />}>
|
||||
<PostEditor />
|
||||
</Suspense>
|
||||
),
|
||||
},
|
||||
{
|
||||
path: ':postId',
|
||||
element: (
|
||||
<Suspense fallback={<LoadingSpinner fullScreen />}>
|
||||
<PostEditor />
|
||||
</Suspense>
|
||||
),
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
path: 'daily',
|
||||
element: (
|
||||
<Suspense fallback={<LoadingSpinner />}>
|
||||
<Suspense fallback={<LoadingSpinner fullScreen />}>
|
||||
<DailyManagement />
|
||||
</Suspense>
|
||||
),
|
||||
|
@ -157,7 +179,7 @@ const router = createBrowserRouter([
|
|||
{
|
||||
path: 'medias',
|
||||
element: (
|
||||
<Suspense fallback={<LoadingSpinner />}>
|
||||
<Suspense fallback={<LoadingSpinner fullScreen />}>
|
||||
<MediasManagement />
|
||||
</Suspense>
|
||||
),
|
||||
|
@ -165,7 +187,7 @@ const router = createBrowserRouter([
|
|||
{
|
||||
path: 'categories',
|
||||
element: (
|
||||
<Suspense fallback={<LoadingSpinner />}>
|
||||
<Suspense fallback={<LoadingSpinner fullScreen />}>
|
||||
<CategoriesManagement />
|
||||
</Suspense>
|
||||
),
|
||||
|
@ -173,7 +195,7 @@ const router = createBrowserRouter([
|
|||
{
|
||||
path: 'users',
|
||||
element: (
|
||||
<Suspense fallback={<LoadingSpinner />}>
|
||||
<Suspense fallback={<LoadingSpinner fullScreen />}>
|
||||
<UsersManagement />
|
||||
</Suspense>
|
||||
),
|
||||
|
@ -181,7 +203,7 @@ const router = createBrowserRouter([
|
|||
{
|
||||
path: 'contributors',
|
||||
element: (
|
||||
<Suspense fallback={<LoadingSpinner />}>
|
||||
<Suspense fallback={<LoadingSpinner fullScreen />}>
|
||||
<ContributorsManagement />
|
||||
</Suspense>
|
||||
),
|
||||
|
@ -189,7 +211,7 @@ const router = createBrowserRouter([
|
|||
{
|
||||
path: 'settings',
|
||||
element: (
|
||||
<Suspense fallback={<LoadingSpinner />}>
|
||||
<Suspense fallback={<LoadingSpinner fullScreen />}>
|
||||
<Settings />
|
||||
</Suspense>
|
||||
),
|
||||
|
|
|
@ -1,4 +1,9 @@
|
|||
{
|
||||
"compilerOptions": {
|
||||
"paths": {
|
||||
"@/*": ["./src/*"]
|
||||
}
|
||||
},
|
||||
"files": [],
|
||||
"references": [
|
||||
{ "path": "./tsconfig.app.json" },
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import { defineConfig } from 'vite';
|
||||
import react from '@vitejs/plugin-react';
|
||||
import tailwindcss from "@tailwindcss/vite";
|
||||
import path from 'path';
|
||||
|
||||
// https://vitejs.dev/config/
|
||||
export default defineConfig({
|
||||
|
@ -8,6 +9,11 @@ export default defineConfig({
|
|||
react(),
|
||||
tailwindcss()
|
||||
],
|
||||
resolve: {
|
||||
alias: {
|
||||
'@': path.resolve(__dirname, './src')
|
||||
}
|
||||
},
|
||||
optimizeDeps: {
|
||||
exclude: ['lucide-react'],
|
||||
},
|
||||
|
@ -18,6 +24,11 @@ export default defineConfig({
|
|||
changeOrigin: true,
|
||||
secure: false,
|
||||
},
|
||||
'/media': {
|
||||
target: 'http://localhost:8080',
|
||||
changeOrigin: true,
|
||||
secure: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
|
1141
pnpm-lock.yaml
generated
1141
pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue