Compare commits

...

9 commits

Author SHA1 Message Date
CDN
3e6181e578
[feature/backend] overall enhancement of image uploading
All checks were successful
Build Backend / Build Docker Image (push) Successful in 5m3s
2025-02-23 04:42:48 +08:00
CDN
6e1be3d513
[feature/frontend] markdown editor 2025-02-23 02:41:36 +08:00
CDN
086c9761a9
[feature/frontend] create posts (wip) 2025-02-22 03:46:57 +08:00
CDN
e86d8c1576
[chore/backend] remove mocks 2025-02-22 02:46:40 +08:00
CDN
be8bf22017
[feature/backend] add categories param in posts 2025-02-22 02:42:55 +08:00
CDN
958e3c2886
[bugfix/ci] fix Dockerfile build 2025-02-22 02:14:54 +08:00
CDN
1c9628124f
[chore/backend] remove all test for now 2025-02-22 02:11:27 +08:00
CDN
3d19ef05b3
[chore/frontend] call /auth/logout during logout 2025-02-22 01:23:13 +08:00
CDN
2c3e238e9a
[chore/frontend] load on the center of the screen 2025-02-22 00:58:00 +08:00
76 changed files with 4111 additions and 7531 deletions

View file

@ -103,6 +103,7 @@ Post:
- slug
- status
- contents
- categories
properties:
id:
type: integer
@ -117,6 +118,11 @@ Post:
type: array
items:
$ref: '#/PostContent'
categories:
type: array
items:
$ref: '#/Category'
description: 文章所属分类
created_at:
type: string
format: date-time

View file

@ -14,9 +14,7 @@ RUN adduser -u 1000 -D tss-rocks
USER tss-rocks
WORKDIR /app
# 复制二进制文件和配置
COPY --from=builder /app/tss-rocks-be .
COPY --from=builder /app/config/config.yaml ./config/
EXPOSE 8080
ENV GIN_MODE=release

View file

@ -21,10 +21,50 @@ auth:
message: "Registration is currently disabled. Please contact administrator." # 禁用时的提示信息
storage:
driver: local
type: local # local or s3
local:
root: storage
base_url: http://localhost:8080/storage
root_dir: "./storage/media"
s3:
region: "us-east-1"
bucket: "your-bucket-name"
access_key_id: "your-access-key-id"
secret_access_key: "your-secret-access-key"
endpoint: "" # Optional, for MinIO or other S3-compatible services
custom_url: "" # Optional, for CDN or custom domain (e.g., https://cdn.example.com/media)
proxy_s3: false # If true, backend will proxy S3 requests instead of redirecting
upload:
limits:
image:
max_size: 10 # MB
allowed_types:
- image/jpeg
- image/png
- image/gif
- image/webp
- image/svg+xml
video:
max_size: 500 # MB
allowed_types:
- video/mp4
- video/webm
audio:
max_size: 50 # MB
allowed_types:
- audio/mpeg
- audio/ogg
- audio/wav
document:
max_size: 20 # MB
allowed_types:
- application/pdf
- application/msword
- application/vnd.openxmlformats-officedocument.wordprocessingml.document
- application/vnd.ms-excel
- application/vnd.openxmlformats-officedocument.spreadsheetml.sheet
- application/zip
- application/x-rar-compressed
- text/plain
- text/csv
logging:
level: debug

View file

@ -33,13 +33,11 @@ const (
ContentsInverseTable = "category_contents"
// ContentsColumn is the table column denoting the contents relation/edge.
ContentsColumn = "category_contents"
// PostsTable is the table that holds the posts relation/edge.
PostsTable = "posts"
// PostsTable is the table that holds the posts relation/edge. The primary key declared below.
PostsTable = "category_posts"
// PostsInverseTable is the table name for the Post entity.
// It exists in this package in order to avoid circular dependency with the "post" package.
PostsInverseTable = "posts"
// PostsColumn is the table column denoting the posts relation/edge.
PostsColumn = "category_posts"
// DailyItemsTable is the table that holds the daily_items relation/edge.
DailyItemsTable = "dailies"
// DailyItemsInverseTable is the table name for the Daily entity.
@ -56,6 +54,12 @@ var Columns = []string{
FieldUpdatedAt,
}
var (
// PostsPrimaryKey and PostsColumn2 are the table columns denoting the
// primary key for the posts relation (M2M).
PostsPrimaryKey = []string{"category_id", "post_id"}
)
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
@ -145,7 +149,7 @@ func newPostsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(PostsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, PostsTable, PostsColumn),
sqlgraph.Edge(sqlgraph.M2M, false, PostsTable, PostsPrimaryKey...),
)
}
func newDailyItemsStep() *sqlgraph.Step {

View file

@ -173,7 +173,7 @@ func HasPosts() predicate.Category {
return predicate.Category(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, PostsTable, PostsColumn),
sqlgraph.Edge(sqlgraph.M2M, false, PostsTable, PostsPrimaryKey...),
)
sqlgraph.HasNeighbors(s, step)
})

View file

@ -201,10 +201,10 @@ func (cc *CategoryCreate) createSpec() (*Category, *sqlgraph.CreateSpec) {
}
if nodes := cc.mutation.PostsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Rel: sqlgraph.M2M,
Inverse: false,
Table: category.PostsTable,
Columns: []string{category.PostsColumn},
Columns: category.PostsPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt),

View file

@ -101,7 +101,7 @@ func (cq *CategoryQuery) QueryPosts() *PostQuery {
step := sqlgraph.NewStep(
sqlgraph.From(category.Table, category.FieldID, selector),
sqlgraph.To(post.Table, post.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, category.PostsTable, category.PostsColumn),
sqlgraph.Edge(sqlgraph.M2M, false, category.PostsTable, category.PostsPrimaryKey...),
)
fromU = sqlgraph.SetNeighbors(cq.driver.Dialect(), step)
return fromU, nil
@ -523,33 +523,63 @@ func (cq *CategoryQuery) loadContents(ctx context.Context, query *CategoryConten
return nil
}
func (cq *CategoryQuery) loadPosts(ctx context.Context, query *PostQuery, nodes []*Category, init func(*Category), assign func(*Category, *Post)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int]*Category)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
edgeIDs := make([]driver.Value, len(nodes))
byID := make(map[int]*Category)
nids := make(map[int]map[*Category]struct{})
for i, node := range nodes {
edgeIDs[i] = node.ID
byID[node.ID] = node
if init != nil {
init(nodes[i])
init(node)
}
}
query.withFKs = true
query.Where(predicate.Post(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(category.PostsColumn), fks...))
}))
neighbors, err := query.All(ctx)
query.Where(func(s *sql.Selector) {
joinT := sql.Table(category.PostsTable)
s.Join(joinT).On(s.C(post.FieldID), joinT.C(category.PostsPrimaryKey[1]))
s.Where(sql.InValues(joinT.C(category.PostsPrimaryKey[0]), edgeIDs...))
columns := s.SelectedColumns()
s.Select(joinT.C(category.PostsPrimaryKey[0]))
s.AppendSelect(columns...)
s.SetDistinct(false)
})
if err := query.prepareQuery(ctx); err != nil {
return err
}
qr := QuerierFunc(func(ctx context.Context, q Query) (Value, error) {
return query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
assign := spec.Assign
values := spec.ScanValues
spec.ScanValues = func(columns []string) ([]any, error) {
values, err := values(columns[1:])
if err != nil {
return nil, err
}
return append([]any{new(sql.NullInt64)}, values...), nil
}
spec.Assign = func(columns []string, values []any) error {
outValue := int(values[0].(*sql.NullInt64).Int64)
inValue := int(values[1].(*sql.NullInt64).Int64)
if nids[inValue] == nil {
nids[inValue] = map[*Category]struct{}{byID[outValue]: {}}
return assign(columns[1:], values[1:])
}
nids[inValue][byID[outValue]] = struct{}{}
return nil
}
})
})
neighbors, err := withInterceptors[[]*Post](ctx, query, qr, query.inters)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.category_posts
if fk == nil {
return fmt.Errorf(`foreign-key "category_posts" is nil for node %v`, n.ID)
}
node, ok := nodeids[*fk]
nodes, ok := nids[n.ID]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "category_posts" returned %v for node %v`, *fk, n.ID)
return fmt.Errorf(`unexpected "posts" node returned %v`, n.ID)
}
for kn := range nodes {
assign(kn, n)
}
assign(node, n)
}
return nil
}

View file

@ -262,10 +262,10 @@ func (cu *CategoryUpdate) sqlSave(ctx context.Context) (n int, err error) {
}
if cu.mutation.PostsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Rel: sqlgraph.M2M,
Inverse: false,
Table: category.PostsTable,
Columns: []string{category.PostsColumn},
Columns: category.PostsPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt),
@ -275,10 +275,10 @@ func (cu *CategoryUpdate) sqlSave(ctx context.Context) (n int, err error) {
}
if nodes := cu.mutation.RemovedPostsIDs(); len(nodes) > 0 && !cu.mutation.PostsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Rel: sqlgraph.M2M,
Inverse: false,
Table: category.PostsTable,
Columns: []string{category.PostsColumn},
Columns: category.PostsPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt),
@ -291,10 +291,10 @@ func (cu *CategoryUpdate) sqlSave(ctx context.Context) (n int, err error) {
}
if nodes := cu.mutation.PostsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Rel: sqlgraph.M2M,
Inverse: false,
Table: category.PostsTable,
Columns: []string{category.PostsColumn},
Columns: category.PostsPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt),
@ -631,10 +631,10 @@ func (cuo *CategoryUpdateOne) sqlSave(ctx context.Context) (_node *Category, err
}
if cuo.mutation.PostsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Rel: sqlgraph.M2M,
Inverse: false,
Table: category.PostsTable,
Columns: []string{category.PostsColumn},
Columns: category.PostsPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt),
@ -644,10 +644,10 @@ func (cuo *CategoryUpdateOne) sqlSave(ctx context.Context) (_node *Category, err
}
if nodes := cuo.mutation.RemovedPostsIDs(); len(nodes) > 0 && !cuo.mutation.PostsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Rel: sqlgraph.M2M,
Inverse: false,
Table: category.PostsTable,
Columns: []string{category.PostsColumn},
Columns: category.PostsPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt),
@ -660,10 +660,10 @@ func (cuo *CategoryUpdateOne) sqlSave(ctx context.Context) (_node *Category, err
}
if nodes := cuo.mutation.PostsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Rel: sqlgraph.M2M,
Inverse: false,
Table: category.PostsTable,
Columns: []string{category.PostsColumn},
Columns: category.PostsPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(post.FieldID, field.TypeInt),

View file

@ -464,7 +464,7 @@ func (c *CategoryClient) QueryPosts(ca *Category) *PostQuery {
step := sqlgraph.NewStep(
sqlgraph.From(category.Table, category.FieldID, id),
sqlgraph.To(post.Table, post.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, category.PostsTable, category.PostsColumn),
sqlgraph.Edge(sqlgraph.M2M, false, category.PostsTable, category.PostsPrimaryKey...),
)
fromV = sqlgraph.Neighbors(ca.driver.Dialect(), step)
return fromV, nil
@ -2207,15 +2207,15 @@ func (c *PostClient) QueryContributors(po *Post) *PostContributorQuery {
return query
}
// QueryCategory queries the category edge of a Post.
func (c *PostClient) QueryCategory(po *Post) *CategoryQuery {
// QueryCategories queries the categories edge of a Post.
func (c *PostClient) QueryCategories(po *Post) *CategoryQuery {
query := (&CategoryClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := po.ID
step := sqlgraph.NewStep(
sqlgraph.From(post.Table, post.FieldID, id),
sqlgraph.To(category.Table, category.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, post.CategoryTable, post.CategoryColumn),
sqlgraph.Edge(sqlgraph.M2M, true, post.CategoriesTable, post.CategoriesPrimaryKey...),
)
fromV = sqlgraph.Neighbors(po.driver.Dialect(), step)
return fromV, nil

View file

@ -265,21 +265,12 @@ var (
{Name: "slug", Type: field.TypeString, Unique: true},
{Name: "created_at", Type: field.TypeTime},
{Name: "updated_at", Type: field.TypeTime},
{Name: "category_posts", Type: field.TypeInt, Nullable: true},
}
// PostsTable holds the schema information for the "posts" table.
PostsTable = &schema.Table{
Name: "posts",
Columns: PostsColumns,
PrimaryKey: []*schema.Column{PostsColumns[0]},
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "posts_categories_posts",
Columns: []*schema.Column{PostsColumns[5]},
RefColumns: []*schema.Column{CategoriesColumns[0]},
OnDelete: schema.SetNull,
},
},
}
// PostContentsColumns holds the columns for the "post_contents" table.
PostContentsColumns = []*schema.Column{
@ -387,6 +378,31 @@ var (
Columns: UsersColumns,
PrimaryKey: []*schema.Column{UsersColumns[0]},
}
// CategoryPostsColumns holds the columns for the "category_posts" table.
CategoryPostsColumns = []*schema.Column{
{Name: "category_id", Type: field.TypeInt},
{Name: "post_id", Type: field.TypeInt},
}
// CategoryPostsTable holds the schema information for the "category_posts" table.
CategoryPostsTable = &schema.Table{
Name: "category_posts",
Columns: CategoryPostsColumns,
PrimaryKey: []*schema.Column{CategoryPostsColumns[0], CategoryPostsColumns[1]},
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "category_posts_category_id",
Columns: []*schema.Column{CategoryPostsColumns[0]},
RefColumns: []*schema.Column{CategoriesColumns[0]},
OnDelete: schema.Cascade,
},
{
Symbol: "category_posts_post_id",
Columns: []*schema.Column{CategoryPostsColumns[1]},
RefColumns: []*schema.Column{PostsColumns[0]},
OnDelete: schema.Cascade,
},
},
}
// RolePermissionsColumns holds the columns for the "role_permissions" table.
RolePermissionsColumns = []*schema.Column{
{Name: "role_id", Type: field.TypeInt},
@ -455,6 +471,7 @@ var (
PostContributorsTable,
RolesTable,
UsersTable,
CategoryPostsTable,
RolePermissionsTable,
UserRolesTable,
}
@ -469,11 +486,12 @@ func init() {
DailyCategoryContentsTable.ForeignKeys[0].RefTable = DailyCategoriesTable
DailyContentsTable.ForeignKeys[0].RefTable = DailiesTable
MediaTable.ForeignKeys[0].RefTable = UsersTable
PostsTable.ForeignKeys[0].RefTable = CategoriesTable
PostContentsTable.ForeignKeys[0].RefTable = PostsTable
PostContributorsTable.ForeignKeys[0].RefTable = ContributorsTable
PostContributorsTable.ForeignKeys[1].RefTable = ContributorRolesTable
PostContributorsTable.ForeignKeys[2].RefTable = PostsTable
CategoryPostsTable.ForeignKeys[0].RefTable = CategoriesTable
CategoryPostsTable.ForeignKeys[1].RefTable = PostsTable
RolePermissionsTable.ForeignKeys[0].RefTable = RolesTable
RolePermissionsTable.ForeignKeys[1].RefTable = PermissionsTable
UserRolesTable.ForeignKeys[0].RefTable = UsersTable

View file

@ -6578,8 +6578,9 @@ type PostMutation struct {
contributors map[int]struct{}
removedcontributors map[int]struct{}
clearedcontributors bool
category *int
clearedcategory bool
categories map[int]struct{}
removedcategories map[int]struct{}
clearedcategories bool
done bool
oldValue func(context.Context) (*Post, error)
predicates []predicate.Post
@ -6935,43 +6936,58 @@ func (m *PostMutation) ResetContributors() {
m.removedcontributors = nil
}
// SetCategoryID sets the "category" edge to the Category entity by id.
func (m *PostMutation) SetCategoryID(id int) {
m.category = &id
// AddCategoryIDs adds the "categories" edge to the Category entity by ids.
func (m *PostMutation) AddCategoryIDs(ids ...int) {
if m.categories == nil {
m.categories = make(map[int]struct{})
}
for i := range ids {
m.categories[ids[i]] = struct{}{}
}
}
// ClearCategory clears the "category" edge to the Category entity.
func (m *PostMutation) ClearCategory() {
m.clearedcategory = true
// ClearCategories clears the "categories" edge to the Category entity.
func (m *PostMutation) ClearCategories() {
m.clearedcategories = true
}
// CategoryCleared reports if the "category" edge to the Category entity was cleared.
func (m *PostMutation) CategoryCleared() bool {
return m.clearedcategory
// CategoriesCleared reports if the "categories" edge to the Category entity was cleared.
func (m *PostMutation) CategoriesCleared() bool {
return m.clearedcategories
}
// CategoryID returns the "category" edge ID in the mutation.
func (m *PostMutation) CategoryID() (id int, exists bool) {
if m.category != nil {
return *m.category, true
// RemoveCategoryIDs removes the "categories" edge to the Category entity by IDs.
func (m *PostMutation) RemoveCategoryIDs(ids ...int) {
if m.removedcategories == nil {
m.removedcategories = make(map[int]struct{})
}
for i := range ids {
delete(m.categories, ids[i])
m.removedcategories[ids[i]] = struct{}{}
}
}
// RemovedCategories returns the removed IDs of the "categories" edge to the Category entity.
func (m *PostMutation) RemovedCategoriesIDs() (ids []int) {
for id := range m.removedcategories {
ids = append(ids, id)
}
return
}
// CategoryIDs returns the "category" edge IDs in the mutation.
// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
// CategoryID instead. It exists only for internal usage by the builders.
func (m *PostMutation) CategoryIDs() (ids []int) {
if id := m.category; id != nil {
ids = append(ids, *id)
// CategoriesIDs returns the "categories" edge IDs in the mutation.
func (m *PostMutation) CategoriesIDs() (ids []int) {
for id := range m.categories {
ids = append(ids, id)
}
return
}
// ResetCategory resets all changes to the "category" edge.
func (m *PostMutation) ResetCategory() {
m.category = nil
m.clearedcategory = false
// ResetCategories resets all changes to the "categories" edge.
func (m *PostMutation) ResetCategories() {
m.categories = nil
m.clearedcategories = false
m.removedcategories = nil
}
// Where appends a list predicates to the PostMutation builder.
@ -7165,8 +7181,8 @@ func (m *PostMutation) AddedEdges() []string {
if m.contributors != nil {
edges = append(edges, post.EdgeContributors)
}
if m.category != nil {
edges = append(edges, post.EdgeCategory)
if m.categories != nil {
edges = append(edges, post.EdgeCategories)
}
return edges
}
@ -7187,10 +7203,12 @@ func (m *PostMutation) AddedIDs(name string) []ent.Value {
ids = append(ids, id)
}
return ids
case post.EdgeCategory:
if id := m.category; id != nil {
return []ent.Value{*id}
case post.EdgeCategories:
ids := make([]ent.Value, 0, len(m.categories))
for id := range m.categories {
ids = append(ids, id)
}
return ids
}
return nil
}
@ -7204,6 +7222,9 @@ func (m *PostMutation) RemovedEdges() []string {
if m.removedcontributors != nil {
edges = append(edges, post.EdgeContributors)
}
if m.removedcategories != nil {
edges = append(edges, post.EdgeCategories)
}
return edges
}
@ -7223,6 +7244,12 @@ func (m *PostMutation) RemovedIDs(name string) []ent.Value {
ids = append(ids, id)
}
return ids
case post.EdgeCategories:
ids := make([]ent.Value, 0, len(m.removedcategories))
for id := range m.removedcategories {
ids = append(ids, id)
}
return ids
}
return nil
}
@ -7236,8 +7263,8 @@ func (m *PostMutation) ClearedEdges() []string {
if m.clearedcontributors {
edges = append(edges, post.EdgeContributors)
}
if m.clearedcategory {
edges = append(edges, post.EdgeCategory)
if m.clearedcategories {
edges = append(edges, post.EdgeCategories)
}
return edges
}
@ -7250,8 +7277,8 @@ func (m *PostMutation) EdgeCleared(name string) bool {
return m.clearedcontents
case post.EdgeContributors:
return m.clearedcontributors
case post.EdgeCategory:
return m.clearedcategory
case post.EdgeCategories:
return m.clearedcategories
}
return false
}
@ -7260,9 +7287,6 @@ func (m *PostMutation) EdgeCleared(name string) bool {
// if that edge is not defined in the schema.
func (m *PostMutation) ClearEdge(name string) error {
switch name {
case post.EdgeCategory:
m.ClearCategory()
return nil
}
return fmt.Errorf("unknown Post unique edge %s", name)
}
@ -7277,8 +7301,8 @@ func (m *PostMutation) ResetEdge(name string) error {
case post.EdgeContributors:
m.ResetContributors()
return nil
case post.EdgeCategory:
m.ResetCategory()
case post.EdgeCategories:
m.ResetCategories()
return nil
}
return fmt.Errorf("unknown Post edge %s", name)

View file

@ -6,7 +6,6 @@ import (
"fmt"
"strings"
"time"
"tss-rocks-be/ent/category"
"tss-rocks-be/ent/post"
"entgo.io/ent"
@ -28,9 +27,8 @@ type Post struct {
UpdatedAt time.Time `json:"updated_at,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the PostQuery when eager-loading is set.
Edges PostEdges `json:"edges"`
category_posts *int
selectValues sql.SelectValues
Edges PostEdges `json:"edges"`
selectValues sql.SelectValues
}
// PostEdges holds the relations/edges for other nodes in the graph.
@ -39,8 +37,8 @@ type PostEdges struct {
Contents []*PostContent `json:"contents,omitempty"`
// Contributors holds the value of the contributors edge.
Contributors []*PostContributor `json:"contributors,omitempty"`
// Category holds the value of the category edge.
Category *Category `json:"category,omitempty"`
// Categories holds the value of the categories edge.
Categories []*Category `json:"categories,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [3]bool
@ -64,15 +62,13 @@ func (e PostEdges) ContributorsOrErr() ([]*PostContributor, error) {
return nil, &NotLoadedError{edge: "contributors"}
}
// CategoryOrErr returns the Category value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e PostEdges) CategoryOrErr() (*Category, error) {
if e.Category != nil {
return e.Category, nil
} else if e.loadedTypes[2] {
return nil, &NotFoundError{label: category.Label}
// CategoriesOrErr returns the Categories value or an error if the edge
// was not loaded in eager-loading.
func (e PostEdges) CategoriesOrErr() ([]*Category, error) {
if e.loadedTypes[2] {
return e.Categories, nil
}
return nil, &NotLoadedError{edge: "category"}
return nil, &NotLoadedError{edge: "categories"}
}
// scanValues returns the types for scanning values from sql.Rows.
@ -86,8 +82,6 @@ func (*Post) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullString)
case post.FieldCreatedAt, post.FieldUpdatedAt:
values[i] = new(sql.NullTime)
case post.ForeignKeys[0]: // category_posts
values[i] = new(sql.NullInt64)
default:
values[i] = new(sql.UnknownType)
}
@ -133,13 +127,6 @@ func (po *Post) assignValues(columns []string, values []any) error {
} else if value.Valid {
po.UpdatedAt = value.Time
}
case post.ForeignKeys[0]:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for edge-field category_posts", value)
} else if value.Valid {
po.category_posts = new(int)
*po.category_posts = int(value.Int64)
}
default:
po.selectValues.Set(columns[i], values[i])
}
@ -163,9 +150,9 @@ func (po *Post) QueryContributors() *PostContributorQuery {
return NewPostClient(po.config).QueryContributors(po)
}
// QueryCategory queries the "category" edge of the Post entity.
func (po *Post) QueryCategory() *CategoryQuery {
return NewPostClient(po.config).QueryCategory(po)
// QueryCategories queries the "categories" edge of the Post entity.
func (po *Post) QueryCategories() *CategoryQuery {
return NewPostClient(po.config).QueryCategories(po)
}
// Update returns a builder for updating this Post.

View file

@ -27,8 +27,8 @@ const (
EdgeContents = "contents"
// EdgeContributors holds the string denoting the contributors edge name in mutations.
EdgeContributors = "contributors"
// EdgeCategory holds the string denoting the category edge name in mutations.
EdgeCategory = "category"
// EdgeCategories holds the string denoting the categories edge name in mutations.
EdgeCategories = "categories"
// Table holds the table name of the post in the database.
Table = "posts"
// ContentsTable is the table that holds the contents relation/edge.
@ -45,13 +45,11 @@ const (
ContributorsInverseTable = "post_contributors"
// ContributorsColumn is the table column denoting the contributors relation/edge.
ContributorsColumn = "post_contributors"
// CategoryTable is the table that holds the category relation/edge.
CategoryTable = "posts"
// CategoryInverseTable is the table name for the Category entity.
// CategoriesTable is the table that holds the categories relation/edge. The primary key declared below.
CategoriesTable = "category_posts"
// CategoriesInverseTable is the table name for the Category entity.
// It exists in this package in order to avoid circular dependency with the "category" package.
CategoryInverseTable = "categories"
// CategoryColumn is the table column denoting the category relation/edge.
CategoryColumn = "category_posts"
CategoriesInverseTable = "categories"
)
// Columns holds all SQL columns for post fields.
@ -63,11 +61,11 @@ var Columns = []string{
FieldUpdatedAt,
}
// ForeignKeys holds the SQL foreign-keys that are owned by the "posts"
// table and are not defined as standalone fields in the schema.
var ForeignKeys = []string{
"category_posts",
}
var (
// CategoriesPrimaryKey and CategoriesColumn2 are the table columns denoting the
// primary key for the categories relation (M2M).
CategoriesPrimaryKey = []string{"category_id", "post_id"}
)
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
@ -76,11 +74,6 @@ func ValidColumn(column string) bool {
return true
}
}
for i := range ForeignKeys {
if column == ForeignKeys[i] {
return true
}
}
return false
}
@ -178,10 +171,17 @@ func ByContributors(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
}
}
// ByCategoryField orders the results by category field.
func ByCategoryField(field string, opts ...sql.OrderTermOption) OrderOption {
// ByCategoriesCount orders the results by categories count.
func ByCategoriesCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newCategoryStep(), sql.OrderByField(field, opts...))
sqlgraph.OrderByNeighborsCount(s, newCategoriesStep(), opts...)
}
}
// ByCategories orders the results by categories terms.
func ByCategories(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newCategoriesStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
func newContentsStep() *sqlgraph.Step {
@ -198,10 +198,10 @@ func newContributorsStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.O2M, false, ContributorsTable, ContributorsColumn),
)
}
func newCategoryStep() *sqlgraph.Step {
func newCategoriesStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(CategoryInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, CategoryTable, CategoryColumn),
sqlgraph.To(CategoriesInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2M, true, CategoriesTable, CategoriesPrimaryKey...),
)
}

View file

@ -281,21 +281,21 @@ func HasContributorsWith(preds ...predicate.PostContributor) predicate.Post {
})
}
// HasCategory applies the HasEdge predicate on the "category" edge.
func HasCategory() predicate.Post {
// HasCategories applies the HasEdge predicate on the "categories" edge.
func HasCategories() predicate.Post {
return predicate.Post(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, CategoryTable, CategoryColumn),
sqlgraph.Edge(sqlgraph.M2M, true, CategoriesTable, CategoriesPrimaryKey...),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasCategoryWith applies the HasEdge predicate on the "category" edge with a given conditions (other predicates).
func HasCategoryWith(preds ...predicate.Category) predicate.Post {
// HasCategoriesWith applies the HasEdge predicate on the "categories" edge with a given conditions (other predicates).
func HasCategoriesWith(preds ...predicate.Category) predicate.Post {
return predicate.Post(func(s *sql.Selector) {
step := newCategoryStep()
step := newCategoriesStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)

View file

@ -101,23 +101,19 @@ func (pc *PostCreate) AddContributors(p ...*PostContributor) *PostCreate {
return pc.AddContributorIDs(ids...)
}
// SetCategoryID sets the "category" edge to the Category entity by ID.
func (pc *PostCreate) SetCategoryID(id int) *PostCreate {
pc.mutation.SetCategoryID(id)
// AddCategoryIDs adds the "categories" edge to the Category entity by IDs.
func (pc *PostCreate) AddCategoryIDs(ids ...int) *PostCreate {
pc.mutation.AddCategoryIDs(ids...)
return pc
}
// SetNillableCategoryID sets the "category" edge to the Category entity by ID if the given value is not nil.
func (pc *PostCreate) SetNillableCategoryID(id *int) *PostCreate {
if id != nil {
pc = pc.SetCategoryID(*id)
// AddCategories adds the "categories" edges to the Category entity.
func (pc *PostCreate) AddCategories(c ...*Category) *PostCreate {
ids := make([]int, len(c))
for i := range c {
ids[i] = c[i].ID
}
return pc
}
// SetCategory sets the "category" edge to the Category entity.
func (pc *PostCreate) SetCategory(c *Category) *PostCreate {
return pc.SetCategoryID(c.ID)
return pc.AddCategoryIDs(ids...)
}
// Mutation returns the PostMutation object of the builder.
@ -267,12 +263,12 @@ func (pc *PostCreate) createSpec() (*Post, *sqlgraph.CreateSpec) {
}
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := pc.mutation.CategoryIDs(); len(nodes) > 0 {
if nodes := pc.mutation.CategoriesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Rel: sqlgraph.M2M,
Inverse: true,
Table: post.CategoryTable,
Columns: []string{post.CategoryColumn},
Table: post.CategoriesTable,
Columns: post.CategoriesPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt),
@ -281,7 +277,6 @@ func (pc *PostCreate) createSpec() (*Post, *sqlgraph.CreateSpec) {
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_node.category_posts = &nodes[0]
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec

View file

@ -28,8 +28,7 @@ type PostQuery struct {
predicates []predicate.Post
withContents *PostContentQuery
withContributors *PostContributorQuery
withCategory *CategoryQuery
withFKs bool
withCategories *CategoryQuery
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
@ -110,8 +109,8 @@ func (pq *PostQuery) QueryContributors() *PostContributorQuery {
return query
}
// QueryCategory chains the current query on the "category" edge.
func (pq *PostQuery) QueryCategory() *CategoryQuery {
// QueryCategories chains the current query on the "categories" edge.
func (pq *PostQuery) QueryCategories() *CategoryQuery {
query := (&CategoryClient{config: pq.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := pq.prepareQuery(ctx); err != nil {
@ -124,7 +123,7 @@ func (pq *PostQuery) QueryCategory() *CategoryQuery {
step := sqlgraph.NewStep(
sqlgraph.From(post.Table, post.FieldID, selector),
sqlgraph.To(category.Table, category.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, post.CategoryTable, post.CategoryColumn),
sqlgraph.Edge(sqlgraph.M2M, true, post.CategoriesTable, post.CategoriesPrimaryKey...),
)
fromU = sqlgraph.SetNeighbors(pq.driver.Dialect(), step)
return fromU, nil
@ -326,7 +325,7 @@ func (pq *PostQuery) Clone() *PostQuery {
predicates: append([]predicate.Post{}, pq.predicates...),
withContents: pq.withContents.Clone(),
withContributors: pq.withContributors.Clone(),
withCategory: pq.withCategory.Clone(),
withCategories: pq.withCategories.Clone(),
// clone intermediate query.
sql: pq.sql.Clone(),
path: pq.path,
@ -355,14 +354,14 @@ func (pq *PostQuery) WithContributors(opts ...func(*PostContributorQuery)) *Post
return pq
}
// WithCategory tells the query-builder to eager-load the nodes that are connected to
// the "category" edge. The optional arguments are used to configure the query builder of the edge.
func (pq *PostQuery) WithCategory(opts ...func(*CategoryQuery)) *PostQuery {
// WithCategories tells the query-builder to eager-load the nodes that are connected to
// the "categories" edge. The optional arguments are used to configure the query builder of the edge.
func (pq *PostQuery) WithCategories(opts ...func(*CategoryQuery)) *PostQuery {
query := (&CategoryClient{config: pq.config}).Query()
for _, opt := range opts {
opt(query)
}
pq.withCategory = query
pq.withCategories = query
return pq
}
@ -443,20 +442,13 @@ func (pq *PostQuery) prepareQuery(ctx context.Context) error {
func (pq *PostQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Post, error) {
var (
nodes = []*Post{}
withFKs = pq.withFKs
_spec = pq.querySpec()
loadedTypes = [3]bool{
pq.withContents != nil,
pq.withContributors != nil,
pq.withCategory != nil,
pq.withCategories != nil,
}
)
if pq.withCategory != nil {
withFKs = true
}
if withFKs {
_spec.Node.Columns = append(_spec.Node.Columns, post.ForeignKeys...)
}
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*Post).scanValues(nil, columns)
}
@ -489,9 +481,10 @@ func (pq *PostQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Post, e
return nil, err
}
}
if query := pq.withCategory; query != nil {
if err := pq.loadCategory(ctx, query, nodes, nil,
func(n *Post, e *Category) { n.Edges.Category = e }); err != nil {
if query := pq.withCategories; query != nil {
if err := pq.loadCategories(ctx, query, nodes,
func(n *Post) { n.Edges.Categories = []*Category{} },
func(n *Post, e *Category) { n.Edges.Categories = append(n.Edges.Categories, e) }); err != nil {
return nil, err
}
}
@ -560,34 +553,63 @@ func (pq *PostQuery) loadContributors(ctx context.Context, query *PostContributo
}
return nil
}
func (pq *PostQuery) loadCategory(ctx context.Context, query *CategoryQuery, nodes []*Post, init func(*Post), assign func(*Post, *Category)) error {
ids := make([]int, 0, len(nodes))
nodeids := make(map[int][]*Post)
for i := range nodes {
if nodes[i].category_posts == nil {
continue
func (pq *PostQuery) loadCategories(ctx context.Context, query *CategoryQuery, nodes []*Post, init func(*Post), assign func(*Post, *Category)) error {
edgeIDs := make([]driver.Value, len(nodes))
byID := make(map[int]*Post)
nids := make(map[int]map[*Post]struct{})
for i, node := range nodes {
edgeIDs[i] = node.ID
byID[node.ID] = node
if init != nil {
init(node)
}
fk := *nodes[i].category_posts
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
query.Where(func(s *sql.Selector) {
joinT := sql.Table(post.CategoriesTable)
s.Join(joinT).On(s.C(category.FieldID), joinT.C(post.CategoriesPrimaryKey[0]))
s.Where(sql.InValues(joinT.C(post.CategoriesPrimaryKey[1]), edgeIDs...))
columns := s.SelectedColumns()
s.Select(joinT.C(post.CategoriesPrimaryKey[1]))
s.AppendSelect(columns...)
s.SetDistinct(false)
})
if err := query.prepareQuery(ctx); err != nil {
return err
}
query.Where(category.IDIn(ids...))
neighbors, err := query.All(ctx)
qr := QuerierFunc(func(ctx context.Context, q Query) (Value, error) {
return query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
assign := spec.Assign
values := spec.ScanValues
spec.ScanValues = func(columns []string) ([]any, error) {
values, err := values(columns[1:])
if err != nil {
return nil, err
}
return append([]any{new(sql.NullInt64)}, values...), nil
}
spec.Assign = func(columns []string, values []any) error {
outValue := int(values[0].(*sql.NullInt64).Int64)
inValue := int(values[1].(*sql.NullInt64).Int64)
if nids[inValue] == nil {
nids[inValue] = map[*Post]struct{}{byID[outValue]: {}}
return assign(columns[1:], values[1:])
}
nids[inValue][byID[outValue]] = struct{}{}
return nil
}
})
})
neighbors, err := withInterceptors[[]*Category](ctx, query, qr, query.inters)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
nodes, ok := nids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "category_posts" returned %v`, n.ID)
return fmt.Errorf(`unexpected "categories" node returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
for kn := range nodes {
assign(kn, n)
}
}
return nil

View file

@ -109,23 +109,19 @@ func (pu *PostUpdate) AddContributors(p ...*PostContributor) *PostUpdate {
return pu.AddContributorIDs(ids...)
}
// SetCategoryID sets the "category" edge to the Category entity by ID.
func (pu *PostUpdate) SetCategoryID(id int) *PostUpdate {
pu.mutation.SetCategoryID(id)
// AddCategoryIDs adds the "categories" edge to the Category entity by IDs.
func (pu *PostUpdate) AddCategoryIDs(ids ...int) *PostUpdate {
pu.mutation.AddCategoryIDs(ids...)
return pu
}
// SetNillableCategoryID sets the "category" edge to the Category entity by ID if the given value is not nil.
func (pu *PostUpdate) SetNillableCategoryID(id *int) *PostUpdate {
if id != nil {
pu = pu.SetCategoryID(*id)
// AddCategories adds the "categories" edges to the Category entity.
func (pu *PostUpdate) AddCategories(c ...*Category) *PostUpdate {
ids := make([]int, len(c))
for i := range c {
ids[i] = c[i].ID
}
return pu
}
// SetCategory sets the "category" edge to the Category entity.
func (pu *PostUpdate) SetCategory(c *Category) *PostUpdate {
return pu.SetCategoryID(c.ID)
return pu.AddCategoryIDs(ids...)
}
// Mutation returns the PostMutation object of the builder.
@ -175,12 +171,27 @@ func (pu *PostUpdate) RemoveContributors(p ...*PostContributor) *PostUpdate {
return pu.RemoveContributorIDs(ids...)
}
// ClearCategory clears the "category" edge to the Category entity.
func (pu *PostUpdate) ClearCategory() *PostUpdate {
pu.mutation.ClearCategory()
// ClearCategories clears all "categories" edges to the Category entity.
func (pu *PostUpdate) ClearCategories() *PostUpdate {
pu.mutation.ClearCategories()
return pu
}
// RemoveCategoryIDs removes the "categories" edge to Category entities by IDs.
func (pu *PostUpdate) RemoveCategoryIDs(ids ...int) *PostUpdate {
pu.mutation.RemoveCategoryIDs(ids...)
return pu
}
// RemoveCategories removes "categories" edges to Category entities.
func (pu *PostUpdate) RemoveCategories(c ...*Category) *PostUpdate {
ids := make([]int, len(c))
for i := range c {
ids[i] = c[i].ID
}
return pu.RemoveCategoryIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (pu *PostUpdate) Save(ctx context.Context) (int, error) {
pu.defaults()
@ -346,12 +357,12 @@ func (pu *PostUpdate) sqlSave(ctx context.Context) (n int, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if pu.mutation.CategoryCleared() {
if pu.mutation.CategoriesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Rel: sqlgraph.M2M,
Inverse: true,
Table: post.CategoryTable,
Columns: []string{post.CategoryColumn},
Table: post.CategoriesTable,
Columns: post.CategoriesPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt),
@ -359,12 +370,28 @@ func (pu *PostUpdate) sqlSave(ctx context.Context) (n int, err error) {
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := pu.mutation.CategoryIDs(); len(nodes) > 0 {
if nodes := pu.mutation.RemovedCategoriesIDs(); len(nodes) > 0 && !pu.mutation.CategoriesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Rel: sqlgraph.M2M,
Inverse: true,
Table: post.CategoryTable,
Columns: []string{post.CategoryColumn},
Table: post.CategoriesTable,
Columns: post.CategoriesPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := pu.mutation.CategoriesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2M,
Inverse: true,
Table: post.CategoriesTable,
Columns: post.CategoriesPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt),
@ -473,23 +500,19 @@ func (puo *PostUpdateOne) AddContributors(p ...*PostContributor) *PostUpdateOne
return puo.AddContributorIDs(ids...)
}
// SetCategoryID sets the "category" edge to the Category entity by ID.
func (puo *PostUpdateOne) SetCategoryID(id int) *PostUpdateOne {
puo.mutation.SetCategoryID(id)
// AddCategoryIDs adds the "categories" edge to the Category entity by IDs.
func (puo *PostUpdateOne) AddCategoryIDs(ids ...int) *PostUpdateOne {
puo.mutation.AddCategoryIDs(ids...)
return puo
}
// SetNillableCategoryID sets the "category" edge to the Category entity by ID if the given value is not nil.
func (puo *PostUpdateOne) SetNillableCategoryID(id *int) *PostUpdateOne {
if id != nil {
puo = puo.SetCategoryID(*id)
// AddCategories adds the "categories" edges to the Category entity.
func (puo *PostUpdateOne) AddCategories(c ...*Category) *PostUpdateOne {
ids := make([]int, len(c))
for i := range c {
ids[i] = c[i].ID
}
return puo
}
// SetCategory sets the "category" edge to the Category entity.
func (puo *PostUpdateOne) SetCategory(c *Category) *PostUpdateOne {
return puo.SetCategoryID(c.ID)
return puo.AddCategoryIDs(ids...)
}
// Mutation returns the PostMutation object of the builder.
@ -539,12 +562,27 @@ func (puo *PostUpdateOne) RemoveContributors(p ...*PostContributor) *PostUpdateO
return puo.RemoveContributorIDs(ids...)
}
// ClearCategory clears the "category" edge to the Category entity.
func (puo *PostUpdateOne) ClearCategory() *PostUpdateOne {
puo.mutation.ClearCategory()
// ClearCategories clears all "categories" edges to the Category entity.
func (puo *PostUpdateOne) ClearCategories() *PostUpdateOne {
puo.mutation.ClearCategories()
return puo
}
// RemoveCategoryIDs removes the "categories" edge to Category entities by IDs.
func (puo *PostUpdateOne) RemoveCategoryIDs(ids ...int) *PostUpdateOne {
puo.mutation.RemoveCategoryIDs(ids...)
return puo
}
// RemoveCategories removes "categories" edges to Category entities.
func (puo *PostUpdateOne) RemoveCategories(c ...*Category) *PostUpdateOne {
ids := make([]int, len(c))
for i := range c {
ids[i] = c[i].ID
}
return puo.RemoveCategoryIDs(ids...)
}
// Where appends a list predicates to the PostUpdate builder.
func (puo *PostUpdateOne) Where(ps ...predicate.Post) *PostUpdateOne {
puo.mutation.Where(ps...)
@ -740,12 +778,12 @@ func (puo *PostUpdateOne) sqlSave(ctx context.Context) (_node *Post, err error)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if puo.mutation.CategoryCleared() {
if puo.mutation.CategoriesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Rel: sqlgraph.M2M,
Inverse: true,
Table: post.CategoryTable,
Columns: []string{post.CategoryColumn},
Table: post.CategoriesTable,
Columns: post.CategoriesPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt),
@ -753,12 +791,28 @@ func (puo *PostUpdateOne) sqlSave(ctx context.Context) (_node *Post, err error)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := puo.mutation.CategoryIDs(); len(nodes) > 0 {
if nodes := puo.mutation.RemovedCategoriesIDs(); len(nodes) > 0 && !puo.mutation.CategoriesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Rel: sqlgraph.M2M,
Inverse: true,
Table: post.CategoryTable,
Columns: []string{post.CategoryColumn},
Table: post.CategoriesTable,
Columns: post.CategoriesPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := puo.mutation.CategoriesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2M,
Inverse: true,
Table: post.CategoriesTable,
Columns: post.CategoriesPrimaryKey,
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(category.FieldID, field.TypeInt),

View file

@ -16,11 +16,14 @@ type Media struct {
func (Media) Fields() []ent.Field {
return []ent.Field{
field.String("storage_id").
StorageKey("storage_id").
NotEmpty().
Unique(),
field.String("original_name").
StorageKey("original_name").
NotEmpty(),
field.String("mime_type").
StorageKey("mime_type").
NotEmpty(),
field.Int64("size").
Positive(),

View file

@ -34,8 +34,7 @@ func (Post) Edges() []ent.Edge {
return []ent.Edge{
edge.To("contents", PostContent.Type),
edge.To("contributors", PostContributor.Type),
edge.From("category", Category.Type).
Ref("posts").
Unique(),
edge.From("categories", Category.Type).
Ref("posts"),
}
}

View file

@ -3,13 +3,13 @@ module tss-rocks-be
go 1.23.6
require (
bou.ke/monkey v1.0.2
entgo.io/ent v0.14.1
github.com/aws/aws-sdk-go-v2 v1.36.1
github.com/aws/aws-sdk-go-v2/config v1.29.6
github.com/aws/aws-sdk-go-v2/credentials v1.17.59
github.com/aws/aws-sdk-go-v2/service/s3 v1.76.1
github.com/chai2010/webp v1.1.1
github.com/disintegration/imaging v1.6.2
github.com/gin-gonic/gin v1.10.0
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/google/uuid v1.6.0
@ -17,7 +17,6 @@ require (
github.com/rs/zerolog v1.33.0
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.10.0
go.uber.org/mock v0.5.0
golang.org/x/crypto v0.33.0
golang.org/x/time v0.10.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
@ -68,12 +67,12 @@ require (
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/zclconf/go-cty v1.16.2 // indirect
github.com/zclconf/go-cty-yaml v1.1.0 // indirect
golang.org/x/arch v0.14.0 // indirect
golang.org/x/image v0.0.0-20211028202545-6944b10bf410 // indirect
golang.org/x/mod v0.23.0 // indirect
golang.org/x/net v0.35.0 // indirect
golang.org/x/sync v0.11.0 // indirect

View file

@ -1,7 +1,5 @@
ariga.io/atlas v0.31.0 h1:Nw6/Jdc7OpZfiy6oh/dJAYPp5XxGYvMTWLOUutwWjeY=
ariga.io/atlas v0.31.0/go.mod h1:J3chwsQAgjDF6Ostz7JmJJRTCbtqIupUbVR/gqZrMiA=
bou.ke/monkey v1.0.2 h1:kWcnsrCNUatbxncxR/ThdYqbytgOIArtYWqcQLQzKLI=
bou.ke/monkey v1.0.2/go.mod h1:OqickVX3tNx6t33n1xvtTtu85YN5s6cKwVug+oHMaIA=
entgo.io/ent v0.14.1 h1:fUERL506Pqr92EPHJqr8EYxbPioflJo6PudkrEA8a/s=
entgo.io/ent v0.14.1/go.mod h1:MH6XLG0KXpkcDQhKiHfANZSzR55TJyPL5IGNpI8wpco=
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
@ -63,6 +61,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/disintegration/imaging v1.6.2 h1:w1LecBlG2Lnp8B3jk5zSuNqd7b4DXhcjwek1ei82L+c=
github.com/disintegration/imaging v1.6.2/go.mod h1:44/5580QXChDfwIclfc/PCwrr44amcmDAg8hxG0Ewe4=
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
github.com/gin-contrib/sse v1.0.0 h1:y3bT1mUWUxDpW4JLQg/HnTqV4rozuW4tC9eFKTxYI9E=
@ -110,8 +110,6 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0=
@ -121,8 +119,6 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@ -139,7 +135,6 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
@ -159,12 +154,13 @@ github.com/zclconf/go-cty-debug v0.0.0-20240509010212-0d6042c53940 h1:4r45xpDWB6
github.com/zclconf/go-cty-debug v0.0.0-20240509010212-0d6042c53940/go.mod h1:CmBdvvj3nqzfzJ6nTCIwDTPZ56aVGvDrmztiO5g3qrM=
github.com/zclconf/go-cty-yaml v1.1.0 h1:nP+jp0qPHv2IhUVqmQSzjvqAWcObN0KBkUl2rWBdig0=
github.com/zclconf/go-cty-yaml v1.1.0/go.mod h1:9YLUH4g7lOhVWqUbctnVlZ5KLpg7JAprQNgxSZ1Gyxs=
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
golang.org/x/arch v0.14.0 h1:z9JUEZWr8x4rR0OU6c4/4t6E6jOZ8/QBS2bBYBm4tx4=
golang.org/x/arch v0.14.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ=
golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM=
golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
@ -176,10 +172,13 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=

View file

@ -1,27 +0,0 @@
package auth
import (
"context"
"testing"
)
func TestUserIDKey(t *testing.T) {
// Test that the UserIDKey constant is defined correctly
if UserIDKey != "user_id" {
t.Errorf("UserIDKey = %v, want %v", UserIDKey, "user_id")
}
// Test context with user ID
ctx := context.WithValue(context.Background(), UserIDKey, "test-user-123")
value := ctx.Value(UserIDKey)
if value != "test-user-123" {
t.Errorf("Context value = %v, want %v", value, "test-user-123")
}
// Test context without user ID
emptyCtx := context.Background()
emptyValue := emptyCtx.Value(UserIDKey)
if emptyValue != nil {
t.Errorf("Empty context value = %v, want nil", emptyValue)
}
}

View file

@ -49,7 +49,7 @@ type StorageConfig struct {
Type string `yaml:"type"`
Local LocalStorage `yaml:"local"`
S3 S3Storage `yaml:"s3"`
Upload types.UploadConfig `yaml:"upload"`
Upload UploadConfig `yaml:"upload"`
}
type LocalStorage struct {
@ -66,6 +66,27 @@ type S3Storage struct {
ProxyS3 bool `yaml:"proxy_s3"`
}
type UploadConfig struct {
Limits struct {
Image struct {
MaxSize int `yaml:"max_size"`
AllowedTypes []string `yaml:"allowed_types"`
} `yaml:"image"`
Video struct {
MaxSize int `yaml:"max_size"`
AllowedTypes []string `yaml:"allowed_types"`
} `yaml:"video"`
Audio struct {
MaxSize int `yaml:"max_size"`
AllowedTypes []string `yaml:"allowed_types"`
} `yaml:"audio"`
Document struct {
MaxSize int `yaml:"max_size"`
AllowedTypes []string `yaml:"allowed_types"`
} `yaml:"document"`
} `yaml:"limits"`
}
// Load loads configuration from a YAML file
func Load(path string) (*Config, error) {
data, err := os.ReadFile(path)

View file

@ -1,85 +0,0 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoad(t *testing.T) {
// Create a temporary test config file
content := []byte(`
database:
driver: postgres
dsn: postgres://user:pass@localhost:5432/dbname
server:
port: 8080
host: localhost
jwt:
secret: test-secret
expiration: 24h
storage:
type: local
local:
root_dir: /tmp/storage
upload:
max_size: 10485760
logging:
level: info
format: json
`)
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(configPath, content, 0644); err != nil {
t.Fatalf("Failed to write test config: %v", err)
}
// Test loading config
cfg, err := Load(configPath)
if err != nil {
t.Fatalf("Load() error = %v", err)
}
// Verify loaded values
tests := []struct {
name string
got interface{}
want interface{}
errorMsg string
}{
{"Database Driver", cfg.Database.Driver, "postgres", "incorrect database driver"},
{"Server Port", cfg.Server.Port, 8080, "incorrect server port"},
{"JWT Secret", cfg.JWT.Secret, "test-secret", "incorrect JWT secret"},
{"Storage Type", cfg.Storage.Type, "local", "incorrect storage type"},
{"Logging Level", cfg.Logging.Level, "info", "incorrect logging level"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.got != tt.want {
t.Errorf("%s = %v, want %v", tt.name, tt.got, tt.want)
}
})
}
}
func TestLoadError(t *testing.T) {
// Test loading non-existent file
_, err := Load("non-existent-file.yaml")
if err == nil {
t.Error("Load() error = nil, want error for non-existent file")
}
// Test loading invalid YAML
tmpDir := t.TempDir()
invalidPath := filepath.Join(tmpDir, "invalid.yaml")
if err := os.WriteFile(invalidPath, []byte("invalid: }{yaml"), 0644); err != nil {
t.Fatalf("Failed to write invalid config: %v", err)
}
_, err = Load(invalidPath)
if err == nil {
t.Error("Load() error = nil, want error for invalid YAML")
}
}

View file

@ -1,312 +0,0 @@
package handler
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"tss-rocks-be/ent"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/service/mock"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
"golang.org/x/crypto/bcrypt"
)
type AuthHandlerTestSuite struct {
suite.Suite
ctrl *gomock.Controller
service *mock.MockService
handler *Handler
router *gin.Engine
}
func (s *AuthHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl)
s.handler = NewHandler(&config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
},
Auth: config.AuthConfig{
Registration: struct {
Enabled bool `yaml:"enabled"`
Message string `yaml:"message"`
}{
Enabled: true,
Message: "Registration is disabled",
},
},
}, s.service)
s.router = gin.New()
}
func (s *AuthHandlerTestSuite) TearDownTest() {
s.ctrl.Finish()
}
func TestAuthHandlerSuite(t *testing.T) {
suite.Run(t, new(AuthHandlerTestSuite))
}
type ErrorResponse struct {
Error struct {
Code string `json:"code"`
Message string `json:"message"`
} `json:"error"`
}
func (s *AuthHandlerTestSuite) TestRegister() {
testCases := []struct {
name string
request RegisterRequest
setupMock func()
expectedStatus int
expectedError string
registration bool
}{
{
name: "成功注册",
request: RegisterRequest{
Username: "testuser",
Email: "test@example.com",
Password: "password123",
Role: "contributor",
},
setupMock: func() {
s.service.EXPECT().
CreateUser(gomock.Any(), "testuser", "test@example.com", "password123", "contributor").
Return(&ent.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
}, nil)
s.service.EXPECT().
GetUserRoles(gomock.Any(), 1).
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
},
expectedStatus: http.StatusCreated,
registration: true,
},
{
name: "注册功能已禁用",
request: RegisterRequest{
Username: "testuser",
Email: "test@example.com",
Password: "password123",
Role: "contributor",
},
setupMock: func() {},
expectedStatus: http.StatusForbidden,
expectedError: "Registration is disabled",
registration: false,
},
{
name: "无效的邮箱格式",
request: RegisterRequest{
Username: "testuser",
Email: "invalid-email",
Password: "password123",
Role: "contributor",
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedError: "Key: 'RegisterRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag",
registration: true,
},
{
name: "密码太短",
request: RegisterRequest{
Username: "testuser",
Email: "test@example.com",
Password: "short",
Role: "contributor",
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedError: "Key: 'RegisterRequest.Password' Error:Field validation for 'Password' failed on the 'min' tag",
registration: true,
},
{
name: "无效的角色",
request: RegisterRequest{
Username: "testuser",
Email: "test@example.com",
Password: "password123",
Role: "invalid-role",
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedError: "Key: 'RegisterRequest.Role' Error:Field validation for 'Role' failed on the 'oneof' tag",
registration: true,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// 设置注册功能状态
s.handler.cfg.Auth.Registration.Enabled = tc.registration
// 设置 mock
tc.setupMock()
// 创建请求
reqBody, _ := json.Marshal(tc.request)
req, _ := http.NewRequest(http.MethodPost, "/register", bytes.NewBuffer(reqBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// 执行请求
s.handler.Register(c)
// 验证响应
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" {
var response ErrorResponse
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Contains(response.Error.Message, tc.expectedError)
} else {
var response AuthResponse
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.NotEmpty(response.Token)
}
})
}
}
func (s *AuthHandlerTestSuite) TestLogin() {
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
testCases := []struct {
name string
request LoginRequest
setupMock func()
expectedStatus int
expectedError string
}{
{
name: "成功登录",
request: LoginRequest{
Username: "testuser",
Password: "password123",
},
setupMock: func() {
s.service.EXPECT().
GetUserByUsername(gomock.Any(), "testuser").
Return(&ent.User{
ID: 1,
Username: "testuser",
PasswordHash: string(hashedPassword),
}, nil)
s.service.EXPECT().
GetUserRoles(gomock.Any(), 1).
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
},
expectedStatus: http.StatusOK,
},
{
name: "无效的用户名",
request: LoginRequest{
Username: "te",
Password: "password123",
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedError: "Key: 'LoginRequest.Username' Error:Field validation for 'Username' failed on the 'min' tag",
},
{
name: "用户不存在",
request: LoginRequest{
Username: "nonexistent",
Password: "password123",
},
setupMock: func() {
s.service.EXPECT().
GetUserByUsername(gomock.Any(), "nonexistent").
Return(nil, fmt.Errorf("user not found"))
},
expectedStatus: http.StatusUnauthorized,
expectedError: "Invalid username or password",
},
{
name: "密码错误",
request: LoginRequest{
Username: "testuser",
Password: "wrongpassword",
},
setupMock: func() {
s.service.EXPECT().
GetUserByUsername(gomock.Any(), "testuser").
Return(&ent.User{
ID: 1,
Username: "testuser",
PasswordHash: string(hashedPassword),
}, nil)
},
expectedStatus: http.StatusUnauthorized,
expectedError: "Invalid username or password",
},
{
name: "获取用户角色失败",
request: LoginRequest{
Username: "testuser",
Password: "password123",
},
setupMock: func() {
s.service.EXPECT().
GetUserByUsername(gomock.Any(), "testuser").
Return(&ent.User{
ID: 1,
Username: "testuser",
PasswordHash: string(hashedPassword),
}, nil)
s.service.EXPECT().
GetUserRoles(gomock.Any(), 1).
Return(nil, fmt.Errorf("failed to get roles"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to get user roles",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// 设置 mock
tc.setupMock()
// 创建请求
reqBody, _ := json.Marshal(tc.request)
req, _ := http.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(reqBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// 执行请求
s.handler.Login(c)
// 验证响应
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" {
var response ErrorResponse
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Contains(response.Error.Message, tc.expectedError)
} else {
var response AuthResponse
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.NotEmpty(response.Token)
}
})
}
}

View file

@ -1,481 +0,0 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"tss-rocks-be/ent"
"tss-rocks-be/ent/categorycontent"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/service"
"tss-rocks-be/internal/service/mock"
"tss-rocks-be/internal/types"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
"errors"
)
// Custom assertion function for comparing categories
func assertCategoryEqual(t assert.TestingT, expected, actual *ent.Category) bool {
if expected == nil && actual == nil {
return true
}
if expected == nil || actual == nil {
return assert.Fail(t, "One category is nil while the other is not")
}
// Compare only relevant fields, ignoring time fields
return assert.Equal(t, expected.ID, actual.ID) &&
assert.Equal(t, expected.Edges.Contents, actual.Edges.Contents)
}
// Custom assertion function for comparing category slices
func assertCategorySliceEqual(t assert.TestingT, expected, actual []*ent.Category) bool {
if len(expected) != len(actual) {
return assert.Fail(t, "Category slice lengths do not match")
}
for i := range expected {
if !assertCategoryEqual(t, expected[i], actual[i]) {
return false
}
}
return true
}
type CategoryHandlerTestSuite struct {
suite.Suite
ctrl *gomock.Controller
service *mock.MockService
handler *Handler
router *gin.Engine
}
func (s *CategoryHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl)
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
},
}
s.handler = NewHandler(cfg, s.service)
// Setup Gin router
gin.SetMode(gin.TestMode)
s.router = gin.New()
// Setup mock for GetTokenBlacklist
tokenBlacklist := &service.TokenBlacklist{}
s.service.EXPECT().
GetTokenBlacklist().
Return(tokenBlacklist).
AnyTimes()
s.handler.RegisterRoutes(s.router)
}
func (s *CategoryHandlerTestSuite) TearDownTest() {
s.ctrl.Finish()
}
func TestCategoryHandlerSuite(t *testing.T) {
suite.Run(t, new(CategoryHandlerTestSuite))
}
// Test cases for ListCategories
func (s *CategoryHandlerTestSuite) TestListCategories() {
testCases := []struct {
name string
langCode string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success with default language",
langCode: "",
setupMock: func() {
s.service.EXPECT().
ListCategories(gomock.Any(), gomock.Eq("en")).
Return([]*ent.Category{
{
ID: 1,
Edges: ent.CategoryEdges{
Contents: []*ent.CategoryContent{
{
LanguageCode: categorycontent.LanguageCode("en"),
Name: "Test Category",
Description: "Test Description",
Slug: "test-category",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Category{
{
ID: 1,
Edges: ent.CategoryEdges{
Contents: []*ent.CategoryContent{
{
LanguageCode: categorycontent.LanguageCode("en"),
Name: "Test Category",
Description: "Test Description",
Slug: "test-category",
},
},
},
},
},
},
{
name: "Success with specific language",
langCode: "zh",
setupMock: func() {
s.service.EXPECT().
ListCategories(gomock.Any(), gomock.Eq("zh")).
Return([]*ent.Category{
{
ID: 1,
Edges: ent.CategoryEdges{
Contents: []*ent.CategoryContent{
{
LanguageCode: categorycontent.LanguageCode("zh"),
Name: "测试分类",
Description: "测试描述",
Slug: "test-category",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Category{
{
ID: 1,
Edges: ent.CategoryEdges{
Contents: []*ent.CategoryContent{
{
LanguageCode: categorycontent.LanguageCode("zh"),
Name: "测试分类",
Description: "测试描述",
Slug: "test-category",
},
},
},
},
},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// Setup mock
tc.setupMock()
// Create request
url := "/api/v1/categories"
if tc.langCode != "" {
url += "?lang=" + tc.langCode
}
req := httptest.NewRequest(http.MethodGet, url, nil)
w := httptest.NewRecorder()
// Perform request
s.router.ServeHTTP(w, req)
// Assert response
assert.Equal(s.T(), tc.expectedStatus, w.Code)
if tc.expectedBody != nil {
var response []*ent.Category
err := json.Unmarshal(w.Body.Bytes(), &response)
assert.NoError(s.T(), err)
assertCategorySliceEqual(s.T(), tc.expectedBody.([]*ent.Category), response)
}
})
}
}
// Test cases for GetCategory
func (s *CategoryHandlerTestSuite) TestGetCategory() {
testCases := []struct {
name string
langCode string
slug string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
langCode: "en",
slug: "test-category",
setupMock: func() {
s.service.EXPECT().
GetCategoryBySlug(gomock.Any(), gomock.Eq("en"), gomock.Eq("test-category")).
Return(&ent.Category{
ID: 1,
Edges: ent.CategoryEdges{
Contents: []*ent.CategoryContent{
{
LanguageCode: categorycontent.LanguageCode("en"),
Name: "Test Category",
Description: "Test Description",
Slug: "test-category",
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: &ent.Category{
ID: 1,
Edges: ent.CategoryEdges{
Contents: []*ent.CategoryContent{
{
LanguageCode: categorycontent.LanguageCode("en"),
Name: "Test Category",
Description: "Test Description",
Slug: "test-category",
},
},
},
},
},
{
name: "Not Found",
langCode: "en",
slug: "non-existent",
setupMock: func() {
s.service.EXPECT().
GetCategoryBySlug(gomock.Any(), gomock.Eq("en"), gomock.Eq("non-existent")).
Return(nil, types.ErrNotFound)
},
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// Setup mock
tc.setupMock()
// Create request
url := "/api/v1/categories/" + tc.slug
if tc.langCode != "" {
url += "?lang=" + tc.langCode
}
req := httptest.NewRequest(http.MethodGet, url, nil)
w := httptest.NewRecorder()
// Perform request
s.router.ServeHTTP(w, req)
// Assert response
assert.Equal(s.T(), tc.expectedStatus, w.Code)
if tc.expectedBody != nil {
var response ent.Category
err := json.Unmarshal(w.Body.Bytes(), &response)
assert.NoError(s.T(), err)
assertCategoryEqual(s.T(), tc.expectedBody.(*ent.Category), &response)
}
})
}
}
// Test cases for AddCategoryContent
func (s *CategoryHandlerTestSuite) TestAddCategoryContent() {
var description = "Test Description"
testCases := []struct {
name string
categoryID string
requestBody interface{}
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
categoryID: "1",
requestBody: AddCategoryContentRequest{
LanguageCode: "en",
Name: "Test Category",
Description: &description,
Slug: "test-category",
},
setupMock: func() {
s.service.EXPECT().
AddCategoryContent(
gomock.Any(),
1,
"en",
"Test Category",
description,
"test-category",
).
Return(&ent.CategoryContent{
LanguageCode: categorycontent.LanguageCode("en"),
Name: "Test Category",
Description: description,
Slug: "test-category",
}, nil)
},
expectedStatus: http.StatusCreated,
expectedBody: &ent.CategoryContent{
LanguageCode: categorycontent.LanguageCode("en"),
Name: "Test Category",
Description: description,
Slug: "test-category",
},
},
{
name: "Invalid JSON",
categoryID: "1",
requestBody: "invalid json",
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
},
{
name: "Invalid Category ID",
categoryID: "invalid",
requestBody: AddCategoryContentRequest{
LanguageCode: "en",
Name: "Test Category",
Description: &description,
Slug: "test-category",
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
},
{
name: "Service Error",
categoryID: "1",
requestBody: AddCategoryContentRequest{
LanguageCode: "en",
Name: "Test Category",
Description: &description,
Slug: "test-category",
},
setupMock: func() {
s.service.EXPECT().
AddCategoryContent(
gomock.Any(),
1,
"en",
"Test Category",
description,
"test-category",
).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// Setup mock
tc.setupMock()
// Create request
var body []byte
var err error
if str, ok := tc.requestBody.(string); ok {
body = []byte(str)
} else {
body, err = json.Marshal(tc.requestBody)
s.NoError(err)
}
req := httptest.NewRequest(http.MethodPost, "/api/v1/categories/"+tc.categoryID+"/contents", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
// Perform request
s.router.ServeHTTP(w, req)
// Assert response
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedBody != nil {
var response ent.CategoryContent
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Equal(tc.expectedBody, &response)
}
})
}
}
// Test cases for CreateCategory
func (s *CategoryHandlerTestSuite) TestCreateCategory() {
testCases := []struct {
name string
setupMock func()
expectedStatus int
expectedError string
}{
{
name: "成功创建分类",
setupMock: func() {
category := &ent.Category{
ID: 1,
}
s.service.EXPECT().
CreateCategory(gomock.Any()).
Return(category, nil)
},
expectedStatus: http.StatusCreated,
},
{
name: "创建分类失败",
setupMock: func() {
s.service.EXPECT().
CreateCategory(gomock.Any()).
Return(nil, errors.New("failed to create category"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to create category",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// 设置 mock
tc.setupMock()
// 创建请求
req, _ := http.NewRequest(http.MethodPost, "/categories", nil)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// 执行请求
s.handler.CreateCategory(c)
// 验证响应
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" {
var response map[string]string
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Equal(tc.expectedError, response["error"])
} else {
var response *ent.Category
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.NotNil(response)
s.Equal(1, response.ID)
}
})
}
}

View file

@ -1,456 +0,0 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"tss-rocks-be/ent"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/service"
"tss-rocks-be/internal/service/mock"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
"errors"
)
type ContributorHandlerTestSuite struct {
suite.Suite
ctrl *gomock.Controller
service *mock.MockService
handler *Handler
router *gin.Engine
}
func (s *ContributorHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl)
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
},
}
s.handler = NewHandler(cfg, s.service)
// Setup Gin router
gin.SetMode(gin.TestMode)
s.router = gin.New()
// Setup mock for GetTokenBlacklist
tokenBlacklist := &service.TokenBlacklist{}
s.service.EXPECT().
GetTokenBlacklist().
Return(tokenBlacklist).
AnyTimes()
s.handler.RegisterRoutes(s.router)
}
func (s *ContributorHandlerTestSuite) TearDownTest() {
s.ctrl.Finish()
}
func TestContributorHandlerSuite(t *testing.T) {
suite.Run(t, new(ContributorHandlerTestSuite))
}
func (s *ContributorHandlerTestSuite) TestListContributors() {
testCases := []struct {
name string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
setupMock: func() {
s.service.EXPECT().
ListContributors(gomock.Any()).
Return([]*ent.Contributor{
{
ID: 1,
Name: "John Doe",
Edges: ent.ContributorEdges{
SocialLinks: []*ent.ContributorSocialLink{
{
Type: "github",
Value: "https://github.com/johndoe",
Edges: ent.ContributorSocialLinkEdges{},
},
},
},
CreatedAt: time.Time{},
UpdatedAt: time.Time{},
},
{
ID: 2,
Name: "Jane Smith",
Edges: ent.ContributorEdges{
SocialLinks: []*ent.ContributorSocialLink{}, // Ensure empty SocialLinks array is present
},
CreatedAt: time.Time{},
UpdatedAt: time.Time{},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []gin.H{
{
"id": 1,
"name": "John Doe",
"created_at": time.Time{},
"updated_at": time.Time{},
"edges": gin.H{
"social_links": []gin.H{
{
"type": "github",
"value": "https://github.com/johndoe",
"edges": gin.H{},
},
},
},
},
{
"id": 2,
"name": "Jane Smith",
"created_at": time.Time{},
"updated_at": time.Time{},
"edges": gin.H{
"social_links": []gin.H{}, // Ensure empty SocialLinks array is present
},
},
},
},
{
name: "Service error",
setupMock: func() {
s.service.EXPECT().
ListContributors(gomock.Any()).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to list contributors"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
req := httptest.NewRequest(http.MethodGet, "/api/v1/contributors", nil)
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
func (s *ContributorHandlerTestSuite) TestGetContributor() {
testCases := []struct {
name string
id string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
id: "1",
setupMock: func() {
s.service.EXPECT().
GetContributorByID(gomock.Any(), 1).
Return(&ent.Contributor{
ID: 1,
Name: "John Doe",
Edges: ent.ContributorEdges{
SocialLinks: []*ent.ContributorSocialLink{
{
Type: "github",
Value: "https://github.com/johndoe",
Edges: ent.ContributorSocialLinkEdges{},
},
},
},
CreatedAt: time.Time{},
UpdatedAt: time.Time{},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: gin.H{
"id": 1,
"name": "John Doe",
"created_at": time.Time{},
"updated_at": time.Time{},
"edges": gin.H{
"social_links": []gin.H{
{
"type": "github",
"value": "https://github.com/johndoe",
"edges": gin.H{},
},
},
},
},
},
{
name: "Invalid ID",
id: "invalid",
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Invalid contributor ID"},
},
{
name: "Service error",
id: "1",
setupMock: func() {
s.service.EXPECT().
GetContributorByID(gomock.Any(), 1).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to get contributor"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
req := httptest.NewRequest(http.MethodGet, "/api/v1/contributors/"+tc.id, nil)
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
func (s *ContributorHandlerTestSuite) TestCreateContributor() {
testCases := []struct {
name string
body interface{}
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
body: CreateContributorRequest{
Name: "John Doe",
},
setupMock: func() {
name := "John Doe"
s.service.EXPECT().
CreateContributor(
gomock.Any(),
name,
nil,
nil,
).
Return(&ent.Contributor{
ID: 1,
Name: name,
CreatedAt: time.Time{},
UpdatedAt: time.Time{},
}, nil)
},
expectedStatus: http.StatusCreated,
expectedBody: gin.H{
"id": 1,
"name": "John Doe",
"created_at": time.Time{},
"updated_at": time.Time{},
"edges": gin.H{},
},
},
{
name: "Invalid request body",
body: map[string]interface{}{
"name": "", // Empty name is not allowed
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Key: 'CreateContributorRequest.Name' Error:Field validation for 'Name' failed on the 'required' tag"},
},
{
name: "Service error",
body: CreateContributorRequest{
Name: "John Doe",
},
setupMock: func() {
name := "John Doe"
s.service.EXPECT().
CreateContributor(
gomock.Any(),
name,
nil,
nil,
).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to create contributor"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
body, err := json.Marshal(tc.body)
s.NoError(err, "Failed to marshal request body")
req := httptest.NewRequest(http.MethodPost, "/api/v1/contributors", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
func (s *ContributorHandlerTestSuite) TestAddContributorSocialLink() {
testCases := []struct {
name string
id string
body interface{}
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
id: "1",
body: func() AddContributorSocialLinkRequest {
name := "johndoe"
return AddContributorSocialLinkRequest{
Type: "github",
Name: &name,
Value: "https://github.com/johndoe",
}
}(),
setupMock: func() {
name := "johndoe"
s.service.EXPECT().
AddContributorSocialLink(
gomock.Any(),
1,
"github",
name,
"https://github.com/johndoe",
).
Return(&ent.ContributorSocialLink{
Type: "github",
Name: name,
Value: "https://github.com/johndoe",
Edges: ent.ContributorSocialLinkEdges{},
}, nil)
},
expectedStatus: http.StatusCreated,
expectedBody: gin.H{
"type": "github",
"name": "johndoe",
"value": "https://github.com/johndoe",
"edges": gin.H{},
},
},
{
name: "Invalid contributor ID",
id: "invalid",
body: func() AddContributorSocialLinkRequest {
name := "johndoe"
return AddContributorSocialLinkRequest{
Type: "github",
Name: &name,
Value: "https://github.com/johndoe",
}
}(),
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Invalid contributor ID"},
},
{
name: "Invalid request body",
id: "1",
body: map[string]interface{}{
"type": "", // Empty type is not allowed
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Key: 'AddContributorSocialLinkRequest.Type' Error:Field validation for 'Type' failed on the 'required' tag\nKey: 'AddContributorSocialLinkRequest.Value' Error:Field validation for 'Value' failed on the 'required' tag"},
},
{
name: "Service error",
id: "1",
body: func() AddContributorSocialLinkRequest {
name := "johndoe"
return AddContributorSocialLinkRequest{
Type: "github",
Name: &name,
Value: "https://github.com/johndoe",
}
}(),
setupMock: func() {
name := "johndoe"
s.service.EXPECT().
AddContributorSocialLink(
gomock.Any(),
1,
"github",
name,
"https://github.com/johndoe",
).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to add contributor social link"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
body, err := json.Marshal(tc.body)
s.NoError(err, "Failed to marshal request body")
req := httptest.NewRequest(http.MethodPost, "/api/v1/contributors/"+tc.id+"/social-links", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}

View file

@ -1,532 +0,0 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"tss-rocks-be/ent"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/service"
"tss-rocks-be/internal/service/mock"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
"errors"
"strings"
)
type DailyHandlerTestSuite struct {
suite.Suite
ctrl *gomock.Controller
service *mock.MockService
handler *Handler
router *gin.Engine
}
func (s *DailyHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl)
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
},
}
s.handler = NewHandler(cfg, s.service)
// Setup Gin router
gin.SetMode(gin.TestMode)
s.router = gin.New()
// Setup mock for GetTokenBlacklist
tokenBlacklist := &service.TokenBlacklist{}
s.service.EXPECT().
GetTokenBlacklist().
Return(tokenBlacklist).
AnyTimes()
s.handler.RegisterRoutes(s.router)
}
func (s *DailyHandlerTestSuite) TearDownTest() {
s.ctrl.Finish()
}
func TestDailyHandlerSuite(t *testing.T) {
suite.Run(t, new(DailyHandlerTestSuite))
}
func (s *DailyHandlerTestSuite) TestListDailies() {
testCases := []struct {
name string
langCode string
categoryID string
limit string
offset string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success with default language",
langCode: "",
setupMock: func() {
s.service.EXPECT().
ListDailies(gomock.Any(), "en", nil, 10, 0).
Return([]*ent.Daily{
{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 1",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Daily{
{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 1",
},
},
},
},
},
},
{
name: "Success with specific language",
langCode: "zh",
setupMock: func() {
s.service.EXPECT().
ListDailies(gomock.Any(), "zh", nil, 10, 0).
Return([]*ent.Daily{
{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "zh",
Quote: "测试语录1",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Daily{
{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "zh",
Quote: "测试语录1",
},
},
},
},
},
},
{
name: "Success with category filter",
categoryID: "1",
setupMock: func() {
categoryID := 1
s.service.EXPECT().
ListDailies(gomock.Any(), "en", &categoryID, 10, 0).
Return([]*ent.Daily{
{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 1",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Daily{
{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 1",
},
},
},
},
},
},
{
name: "Success with pagination",
limit: "2",
offset: "1",
setupMock: func() {
s.service.EXPECT().
ListDailies(gomock.Any(), "en", nil, 2, 1).
Return([]*ent.Daily{
{
ID: "daily2",
ImageURL: "https://example.com/image2.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 2",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Daily{
{
ID: "daily2",
ImageURL: "https://example.com/image2.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 2",
},
},
},
},
},
},
{
name: "Service Error",
setupMock: func() {
s.service.EXPECT().
ListDailies(gomock.Any(), "en", nil, 10, 0).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to list dailies"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
url := "/api/v1/dailies"
if tc.langCode != "" {
url += "?lang=" + tc.langCode
}
if tc.categoryID != "" {
if strings.Contains(url, "?") {
url += "&"
} else {
url += "?"
}
url += "category_id=" + tc.categoryID
}
if tc.limit != "" {
if strings.Contains(url, "?") {
url += "&"
} else {
url += "?"
}
url += "limit=" + tc.limit
}
if tc.offset != "" {
if strings.Contains(url, "?") {
url += "&"
} else {
url += "?"
}
url += "offset=" + tc.offset
}
req := httptest.NewRequest(http.MethodGet, url, nil)
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
func (s *DailyHandlerTestSuite) TestGetDaily() {
testCases := []struct {
name string
id string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
id: "daily1",
setupMock: func() {
s.service.EXPECT().
GetDailyByID(gomock.Any(), "daily1").
Return(&ent.Daily{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 1",
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: &ent.Daily{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 1",
},
},
},
},
},
{
name: "Service error",
id: "daily1",
setupMock: func() {
s.service.EXPECT().
GetDailyByID(gomock.Any(), "daily1").
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to get daily"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
req := httptest.NewRequest(http.MethodGet, "/api/v1/dailies/"+tc.id, nil)
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
func (s *DailyHandlerTestSuite) TestCreateDaily() {
testCases := []struct {
name string
body interface{}
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
body: CreateDailyRequest{
ID: "daily1",
CategoryID: 1,
ImageURL: "https://example.com/image1.jpg",
},
setupMock: func() {
s.service.EXPECT().
CreateDaily(gomock.Any(), "daily1", 1, "https://example.com/image1.jpg").
Return(&ent.Daily{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{},
},
}, nil)
},
expectedStatus: http.StatusCreated,
expectedBody: &ent.Daily{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{},
},
},
},
{
name: "Invalid request body",
body: map[string]interface{}{
"id": "daily1",
// Missing required fields
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Key: 'CreateDailyRequest.CategoryID' Error:Field validation for 'CategoryID' failed on the 'required' tag\nKey: 'CreateDailyRequest.ImageURL' Error:Field validation for 'ImageURL' failed on the 'required' tag"},
},
{
name: "Service error",
body: CreateDailyRequest{
ID: "daily1",
CategoryID: 1,
ImageURL: "https://example.com/image1.jpg",
},
setupMock: func() {
s.service.EXPECT().
CreateDaily(gomock.Any(), "daily1", 1, "https://example.com/image1.jpg").
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to create daily"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
body, err := json.Marshal(tc.body)
s.NoError(err, "Failed to marshal request body")
req := httptest.NewRequest(http.MethodPost, "/api/v1/dailies", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
func (s *DailyHandlerTestSuite) TestAddDailyContent() {
testCases := []struct {
name string
dailyID string
body interface{}
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
dailyID: "daily1",
body: AddDailyContentRequest{
LanguageCode: "en",
Quote: "Test Quote 1",
},
setupMock: func() {
s.service.EXPECT().
AddDailyContent(gomock.Any(), "daily1", "en", "Test Quote 1").
Return(&ent.DailyContent{
LanguageCode: "en",
Quote: "Test Quote 1",
}, nil)
},
expectedStatus: http.StatusCreated,
expectedBody: &ent.DailyContent{
LanguageCode: "en",
Quote: "Test Quote 1",
},
},
{
name: "Invalid request body",
dailyID: "daily1",
body: map[string]interface{}{
"language_code": "en",
// Missing required fields
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Key: 'AddDailyContentRequest.Quote' Error:Field validation for 'Quote' failed on the 'required' tag"},
},
{
name: "Service error",
dailyID: "daily1",
body: AddDailyContentRequest{
LanguageCode: "en",
Quote: "Test Quote 1",
},
setupMock: func() {
s.service.EXPECT().
AddDailyContent(gomock.Any(), "daily1", "en", "Test Quote 1").
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to add daily content"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
body, err := json.Marshal(tc.body)
s.NoError(err, "Failed to marshal request body")
req := httptest.NewRequest(http.MethodPost, "/api/v1/dailies/"+tc.dailyID+"/contents", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}

View file

@ -25,8 +25,9 @@ func NewHandler(cfg *config.Config, service service.Service) *Handler {
}
// RegisterRoutes registers all the routes
func (h *Handler) RegisterRoutes(r *gin.Engine) {
api := r.Group("/api/v1")
func (h *Handler) RegisterRoutes(router *gin.Engine) {
// API routes
api := router.Group("/api/v1")
{
// Auth routes
auth := api.Group("/auth")
@ -93,6 +94,9 @@ func (h *Handler) RegisterRoutes(r *gin.Engine) {
media.DELETE("/:id", h.DeleteMedia)
}
}
// Public media files
router.GET("/media/:year/:month/:filename", h.GetMediaFile)
}
// Category handlers
@ -186,10 +190,12 @@ func (h *Handler) ListPosts(c *gin.Context) {
langCode = "en" // Default to English
}
var categoryID *int
if catIDStr := c.Query("category_id"); catIDStr != "" {
if id, err := strconv.Atoi(catIDStr); err == nil {
categoryID = &id
var categoryIDs []int
if catIDsStr := c.QueryArray("category_ids"); len(catIDsStr) > 0 {
for _, idStr := range catIDsStr {
if id, err := strconv.Atoi(idStr); err == nil {
categoryIDs = append(categoryIDs, id)
}
}
}
@ -207,7 +213,7 @@ func (h *Handler) ListPosts(c *gin.Context) {
}
}
posts, err := h.service.ListPosts(c.Request.Context(), langCode, categoryID, limit, offset)
posts, err := h.service.ListPosts(c.Request.Context(), langCode, categoryIDs, limit, offset)
if err != nil {
log.Error().Err(err).Msg("Failed to list posts")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list posts"})
@ -244,10 +250,10 @@ func (h *Handler) GetPost(c *gin.Context) {
contents := make([]gin.H, 0, len(post.Edges.Contents))
for _, content := range post.Edges.Contents {
contents = append(contents, gin.H{
"language_code": content.LanguageCode,
"title": content.Title,
"language_code": content.LanguageCode,
"title": content.Title,
"content_markdown": content.ContentMarkdown,
"summary": content.Summary,
"summary": content.Summary,
})
}
response["edges"].(gin.H)["contents"] = contents
@ -256,7 +262,22 @@ func (h *Handler) GetPost(c *gin.Context) {
}
func (h *Handler) CreatePost(c *gin.Context) {
post, err := h.service.CreatePost(c.Request.Context(), "draft") // Default to draft status
var req struct {
Status string `json:"status" binding:"omitempty,oneof=draft published archived"`
CategoryIDs []int `json:"category_ids" binding:"required,min=1"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
status := req.Status
if status == "" {
status = "draft" // Default to draft status
}
post, err := h.service.CreatePost(c.Request.Context(), status, req.CategoryIDs)
if err != nil {
log.Error().Err(err).Msg("Failed to create post")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create post"})
@ -268,7 +289,8 @@ func (h *Handler) CreatePost(c *gin.Context) {
"id": post.ID,
"status": post.Status,
"edges": gin.H{
"contents": []interface{}{},
"contents": []interface{}{},
"categories": []interface{}{},
},
}
@ -276,7 +298,7 @@ func (h *Handler) CreatePost(c *gin.Context) {
}
type AddPostContentRequest struct {
LanguageCode string `json:"language_code" binding:"required"`
LanguageCode string `json:"language_code" binding:"required"`
Title string `json:"title" binding:"required"`
ContentMarkdown string `json:"content_markdown" binding:"required"`
Summary string `json:"summary" binding:"required"`
@ -308,10 +330,10 @@ func (h *Handler) AddPostContent(c *gin.Context) {
"title": content.Title,
"content_markdown": content.ContentMarkdown,
"language_code": content.LanguageCode,
"summary": content.Summary,
"summary": content.Summary,
"meta_keywords": content.MetaKeywords,
"meta_description": content.MetaDescription,
"edges": gin.H{},
"edges": gin.H{},
})
}
@ -517,11 +539,3 @@ func (h *Handler) AddDailyContent(c *gin.Context) {
c.JSON(http.StatusCreated, content)
}
// Helper functions
func stringPtr(s *string) string {
if s == nil {
return ""
}
return *s
}

View file

@ -1,43 +0,0 @@
package handler
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestStringPtr(t *testing.T) {
testCases := []struct {
name string
input *string
expected string
}{
{
name: "nil pointer",
input: nil,
expected: "",
},
{
name: "empty string",
input: strPtr(""),
expected: "",
},
{
name: "non-empty string",
input: strPtr("test"),
expected: "test",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := stringPtr(tc.input)
assert.Equal(t, tc.expected, result)
})
}
}
// Helper function to create string pointer
func strPtr(s string) *string {
return &s
}

View file

@ -5,9 +5,11 @@ import (
"io"
"net/http"
"strconv"
"strings"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"path/filepath"
)
// Media handlers
@ -33,17 +35,29 @@ func (h *Handler) ListMedia(c *gin.Context) {
return
}
c.JSON(http.StatusOK, media)
c.JSON(http.StatusOK, gin.H{
"data": media,
})
}
func (h *Handler) UploadMedia(c *gin.Context) {
// Get user ID from context (set by auth middleware)
userID, exists := c.Get("user_id")
userIDStr, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
// Convert user ID to int
userID, err := strconv.Atoi(userIDStr.(string))
if err != nil {
log.Error().Err(err).
Str("user_id", fmt.Sprintf("%v", userIDStr)).
Msg("Failed to convert user ID to int")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
return
}
// Get file from form
file, err := c.FormFile("file")
if err != nil {
@ -51,38 +65,107 @@ func (h *Handler) UploadMedia(c *gin.Context) {
return
}
// 文件大小限制
if file.Size > 10*1024*1024 { // 10MB
c.JSON(http.StatusBadRequest, gin.H{"error": "File size exceeds the limit (10MB)"})
// 获取文件类型和扩展名
contentType := file.Header.Get("Content-Type")
ext := strings.ToLower(filepath.Ext(file.Filename))
if contentType == "" {
// 如果 Content-Type 为空,尝试从文件扩展名判断
switch ext {
case ".jpg", ".jpeg":
contentType = "image/jpeg"
case ".png":
contentType = "image/png"
case ".gif":
contentType = "image/gif"
case ".webp":
contentType = "image/webp"
case ".mp4":
contentType = "video/mp4"
case ".webm":
contentType = "video/webm"
case ".mp3":
contentType = "audio/mpeg"
case ".ogg":
contentType = "audio/ogg"
case ".wav":
contentType = "audio/wav"
case ".pdf":
contentType = "application/pdf"
case ".doc":
contentType = "application/msword"
case ".docx":
contentType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
}
}
// 根据 Content-Type 确定文件类型和限制
var maxSize int64
var allowedTypes []string
var fileType string
limits := h.cfg.Storage.Upload.Limits
switch {
case strings.HasPrefix(contentType, "image/"):
maxSize = int64(limits.Image.MaxSize) * 1024 * 1024
allowedTypes = limits.Image.AllowedTypes
fileType = "image"
case strings.HasPrefix(contentType, "video/"):
maxSize = int64(limits.Video.MaxSize) * 1024 * 1024
allowedTypes = limits.Video.AllowedTypes
fileType = "video"
case strings.HasPrefix(contentType, "audio/"):
maxSize = int64(limits.Audio.MaxSize) * 1024 * 1024
allowedTypes = limits.Audio.AllowedTypes
fileType = "audio"
case strings.HasPrefix(contentType, "application/"):
maxSize = int64(limits.Document.MaxSize) * 1024 * 1024
allowedTypes = limits.Document.AllowedTypes
fileType = "document"
default:
c.JSON(http.StatusBadRequest, gin.H{
"error": "Unsupported file type",
})
return
}
// 文件类型限制
allowedTypes := map[string]bool{
"image/jpeg": true,
"image/png": true,
"image/gif": true,
"video/mp4": true,
"video/webm": true,
"audio/mpeg": true,
"audio/ogg": true,
"application/pdf": true,
// 检查文件类型是否允许
typeAllowed := false
for _, allowed := range allowedTypes {
if contentType == allowed {
typeAllowed = true
break
}
}
contentType := file.Header.Get("Content-Type")
if _, ok := allowedTypes[contentType]; !ok {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid file type"})
if !typeAllowed {
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("Unsupported %s type: %s", fileType, contentType),
})
return
}
// 检查文件大小
if file.Size > maxSize {
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("File size exceeds the limit (%d MB) for %s files", limits.Image.MaxSize, fileType),
})
return
}
// Upload file
media, err := h.service.Upload(c.Request.Context(), file, userID.(int))
media, err := h.service.Upload(c.Request.Context(), file, userID)
if err != nil {
log.Error().Err(err).Msg("Failed to upload media")
log.Error().Err(err).
Str("filename", file.Filename).
Str("content_type", contentType).
Int("user_id", userID).
Msg("Failed to upload media")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to upload media"})
return
}
c.JSON(http.StatusCreated, media)
c.JSON(http.StatusCreated, gin.H{
"data": media,
})
}
func (h *Handler) GetMedia(c *gin.Context) {
@ -101,7 +184,7 @@ func (h *Handler) GetMedia(c *gin.Context) {
}
// Get file content
reader, info, err := h.service.GetFile(c.Request.Context(), id)
reader, info, err := h.service.GetFile(c.Request.Context(), media.StorageID)
if err != nil {
log.Error().Err(err).Msg("Failed to get media file")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get media file"})
@ -122,16 +205,18 @@ func (h *Handler) GetMedia(c *gin.Context) {
}
func (h *Handler) GetMediaFile(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid media ID"})
return
}
year := c.Param("year")
month := c.Param("month")
filename := c.Param("filename")
// Get file content
reader, info, err := h.service.GetFile(c.Request.Context(), id)
reader, info, err := h.service.GetFile(c.Request.Context(), filename) // 直接使用完整的文件名
if err != nil {
log.Error().Err(err).Msg("Failed to get media file")
log.Error().Err(err).
Str("year", year).
Str("month", month).
Str("filename", filename).
Msg("Failed to get media file")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get media file"})
return
}
@ -151,20 +236,33 @@ func (h *Handler) GetMediaFile(c *gin.Context) {
func (h *Handler) DeleteMedia(c *gin.Context) {
// Get user ID from context (set by auth middleware)
userID, exists := c.Get("user_id")
userIDStr, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
// Convert user ID to int
userID, err := strconv.Atoi(userIDStr.(string))
if err != nil {
log.Error().Err(err).
Str("user_id", fmt.Sprintf("%v", userIDStr)).
Msg("Failed to convert user ID to int")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
return
}
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid media ID"})
return
}
if err := h.service.DeleteMedia(c.Request.Context(), id, userID.(int)); err != nil {
log.Error().Err(err).Msg("Failed to delete media")
if err := h.service.DeleteMedia(c.Request.Context(), id, userID); err != nil {
log.Error().Err(err).
Int("media_id", id).
Int("user_id", userID).
Msg("Failed to delete media")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete media"})
return
}

View file

@ -1,524 +0,0 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"strings"
"testing"
"tss-rocks-be/ent"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/service/mock"
"tss-rocks-be/internal/storage"
"net/textproto"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
)
type MediaHandlerTestSuite struct {
suite.Suite
ctrl *gomock.Controller
service *mock.MockService
handler *Handler
router *gin.Engine
}
func (s *MediaHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl)
s.handler = NewHandler(&config.Config{}, s.service)
s.router = gin.New()
}
func (s *MediaHandlerTestSuite) TearDownTest() {
s.ctrl.Finish()
}
func TestMediaHandlerSuite(t *testing.T) {
suite.Run(t, new(MediaHandlerTestSuite))
}
func (s *MediaHandlerTestSuite) TestListMedia() {
testCases := []struct {
name string
query string
setupMock func()
expectedStatus int
expectedError string
}{
{
name: "成功列出媒体",
query: "?limit=10&offset=0",
setupMock: func() {
s.service.EXPECT().
ListMedia(gomock.Any(), 10, 0).
Return([]*ent.Media{{ID: 1}}, nil)
},
expectedStatus: http.StatusOK,
},
{
name: "使用默认限制和偏移",
query: "",
setupMock: func() {
s.service.EXPECT().
ListMedia(gomock.Any(), 10, 0).
Return([]*ent.Media{{ID: 1}}, nil)
},
expectedStatus: http.StatusOK,
},
{
name: "列出媒体失败",
query: "",
setupMock: func() {
s.service.EXPECT().
ListMedia(gomock.Any(), 10, 0).
Return(nil, errors.New("failed to list media"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to list media",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// 设置 mock
tc.setupMock()
// 创建请求
req, _ := http.NewRequest(http.MethodGet, "/media"+tc.query, nil)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// 执行请求
s.handler.ListMedia(c)
// 验证响应
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" {
var response map[string]string
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Equal(tc.expectedError, response["error"])
} else {
var response []*ent.Media
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.NotEmpty(response)
}
})
}
}
func (s *MediaHandlerTestSuite) TestUploadMedia() {
testCases := []struct {
name string
setupRequest func() (*http.Request, error)
setupMock func()
expectedStatus int
expectedError string
}{
{
name: "成功上传媒体",
setupRequest: func() (*http.Request, error) {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
// 创建文件部分
fileHeader := make(textproto.MIMEHeader)
fileHeader.Set("Content-Type", "image/jpeg")
fileHeader.Set("Content-Disposition", `form-data; name="file"; filename="test.jpg"`)
part, err := writer.CreatePart(fileHeader)
if err != nil {
return nil, err
}
testContent := "test content"
_, err = io.Copy(part, strings.NewReader(testContent))
if err != nil {
return nil, err
}
writer.Close()
req := httptest.NewRequest(http.MethodPost, "/media", body)
req.Header.Set("Content-Type", writer.FormDataContentType())
return req, nil
},
setupMock: func() {
expectedFile := &multipart.FileHeader{
Filename: "test.jpg",
Size: int64(len("test content")),
Header: textproto.MIMEHeader{
"Content-Type": []string{"image/jpeg"},
},
}
s.service.EXPECT().
Upload(gomock.Any(), gomock.Any(), 1).
DoAndReturn(func(_ context.Context, f *multipart.FileHeader, uid int) (*ent.Media, error) {
s.Equal(expectedFile.Filename, f.Filename)
s.Equal(expectedFile.Size, f.Size)
s.Equal(expectedFile.Header.Get("Content-Type"), f.Header.Get("Content-Type"))
return &ent.Media{ID: 1}, nil
})
},
expectedStatus: http.StatusCreated,
},
{
name: "未授权",
setupRequest: func() (*http.Request, error) {
req := httptest.NewRequest(http.MethodPost, "/media", nil)
return req, nil
},
setupMock: func() {},
expectedStatus: http.StatusUnauthorized,
expectedError: "Unauthorized",
},
{
name: "上传失败",
setupRequest: func() (*http.Request, error) {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
// 创建文件部分
fileHeader := make(textproto.MIMEHeader)
fileHeader.Set("Content-Type", "image/jpeg")
fileHeader.Set("Content-Disposition", `form-data; name="file"; filename="test.jpg"`)
part, err := writer.CreatePart(fileHeader)
if err != nil {
return nil, err
}
testContent := "test content"
_, err = io.Copy(part, strings.NewReader(testContent))
if err != nil {
return nil, err
}
writer.Close()
req := httptest.NewRequest(http.MethodPost, "/media", body)
req.Header.Set("Content-Type", writer.FormDataContentType())
return req, nil
},
setupMock: func() {
s.service.EXPECT().
Upload(gomock.Any(), gomock.Any(), 1).
Return(nil, errors.New("failed to upload"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to upload media",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// 设置 mock
tc.setupMock()
// 创建请求
req, err := tc.setupRequest()
s.NoError(err)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// 设置用户ID除了未授权的测试用例
if tc.expectedError != "Unauthorized" {
c.Set("user_id", 1)
}
// 执行请求
s.handler.UploadMedia(c)
// 验证响应
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" {
var response map[string]string
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Equal(tc.expectedError, response["error"])
} else {
var response *ent.Media
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.NotNil(response)
}
})
}
}
func (s *MediaHandlerTestSuite) TestGetMedia() {
testCases := []struct {
name string
mediaID string
setupMock func()
expectedStatus int
expectedError string
}{
{
name: "成功获取媒体",
mediaID: "1",
setupMock: func() {
media := &ent.Media{
ID: 1,
MimeType: "image/jpeg",
OriginalName: "test.jpg",
}
s.service.EXPECT().
GetMedia(gomock.Any(), 1).
Return(media, nil)
s.service.EXPECT().
GetFile(gomock.Any(), 1).
Return(io.NopCloser(strings.NewReader("test content")), &storage.FileInfo{
Size: 11,
Name: "test.jpg",
ContentType: "image/jpeg",
}, nil)
},
expectedStatus: http.StatusOK,
},
{
name: "无效的媒体ID",
mediaID: "invalid",
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid media ID",
},
{
name: "获取媒体元数据失败",
mediaID: "1",
setupMock: func() {
s.service.EXPECT().
GetMedia(gomock.Any(), 1).
Return(nil, errors.New("failed to get media"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to get media",
},
{
name: "获取媒体文件失败",
mediaID: "1",
setupMock: func() {
media := &ent.Media{
ID: 1,
MimeType: "image/jpeg",
OriginalName: "test.jpg",
}
s.service.EXPECT().
GetMedia(gomock.Any(), 1).
Return(media, nil)
s.service.EXPECT().
GetFile(gomock.Any(), 1).
Return(nil, nil, errors.New("failed to get file"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to get media file",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// 设置 mock
tc.setupMock()
// 创建请求
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/media/%s", tc.mediaID), nil)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// Extract ID from URL path
parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/")
if len(parts) >= 2 {
c.Params = []gin.Param{{Key: "id", Value: parts[1]}}
}
// 执行请求
s.handler.GetMedia(c)
// 验证响应
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" {
var response map[string]string
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Equal(tc.expectedError, response["error"])
} else {
s.Equal("image/jpeg", w.Header().Get("Content-Type"))
s.Equal("11", w.Header().Get("Content-Length"))
s.Equal("inline; filename=test.jpg", w.Header().Get("Content-Disposition"))
s.Equal("test content", w.Body.String())
}
})
}
}
func (s *MediaHandlerTestSuite) TestGetMediaFile() {
testCases := []struct {
name string
setupRequest func() (*http.Request, error)
setupMock func()
expectedStatus int
expectedBody []byte
}{
{
name: "成功获取媒体文件",
setupRequest: func() (*http.Request, error) {
return httptest.NewRequest(http.MethodGet, "/media/1/file", nil), nil
},
setupMock: func() {
fileContent := "test file content"
s.service.EXPECT().
GetFile(gomock.Any(), 1).
Return(io.NopCloser(strings.NewReader(fileContent)), &storage.FileInfo{
Name: "test.jpg",
Size: int64(len(fileContent)),
ContentType: "image/jpeg",
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []byte("test file content"),
},
{
name: "无效的媒体ID",
setupRequest: func() (*http.Request, error) {
return httptest.NewRequest(http.MethodGet, "/media/invalid/file", nil), nil
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
},
{
name: "获取媒体文件失败",
setupRequest: func() (*http.Request, error) {
return httptest.NewRequest(http.MethodGet, "/media/1/file", nil), nil
},
setupMock: func() {
s.service.EXPECT().
GetFile(gomock.Any(), 1).
Return(nil, nil, errors.New("failed to get file"))
},
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// Setup
req, err := tc.setupRequest()
s.Require().NoError(err)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// Extract ID from URL path
parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/")
if len(parts) >= 2 {
c.Params = []gin.Param{{Key: "id", Value: parts[1]}}
}
// Setup mock
tc.setupMock()
// Test
s.handler.GetMediaFile(c)
// Verify
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedBody != nil {
s.Equal(tc.expectedBody, w.Body.Bytes())
s.Equal("image/jpeg", w.Header().Get("Content-Type"))
s.Equal(fmt.Sprintf("%d", len(tc.expectedBody)), w.Header().Get("Content-Length"))
s.Equal("inline; filename=test.jpg", w.Header().Get("Content-Disposition"))
}
})
}
}
func (s *MediaHandlerTestSuite) TestDeleteMedia() {
testCases := []struct {
name string
mediaID string
setupMock func()
expectedStatus int
expectedError string
}{
{
name: "成功删除媒体",
mediaID: "1",
setupMock: func() {
s.service.EXPECT().
DeleteMedia(gomock.Any(), 1, 1).
Return(nil)
},
expectedStatus: http.StatusNoContent,
},
{
name: "未授权",
mediaID: "1",
setupMock: func() {},
expectedStatus: http.StatusUnauthorized,
expectedError: "Unauthorized",
},
{
name: "无效的媒体ID",
mediaID: "invalid",
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid media ID",
},
{
name: "删除媒体失败",
mediaID: "1",
setupMock: func() {
s.service.EXPECT().
DeleteMedia(gomock.Any(), 1, 1).
Return(errors.New("failed to delete"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to delete media",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// 设置 mock
tc.setupMock()
// 创建请求
req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/media/%s", tc.mediaID), nil)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// Extract ID from URL path
parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/")
if len(parts) >= 2 {
c.Params = []gin.Param{{Key: "id", Value: parts[1]}}
}
// 设置用户ID除了未授权的测试用例
if tc.expectedError != "Unauthorized" {
c.Set("user_id", 1)
}
// 执行请求
s.handler.DeleteMedia(c)
// 验证响应
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" {
var response map[string]string
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Equal(tc.expectedError, response["error"])
}
})
}
}

View file

@ -1,624 +0,0 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"tss-rocks-be/ent"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/service"
"tss-rocks-be/internal/service/mock"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
"errors"
"strings"
)
type PostHandlerTestSuite struct {
suite.Suite
ctrl *gomock.Controller
service *mock.MockService
handler *Handler
router *gin.Engine
}
func (s *PostHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl)
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
},
}
s.handler = NewHandler(cfg, s.service)
// Setup Gin router
gin.SetMode(gin.TestMode)
s.router = gin.New()
// Setup mock for GetTokenBlacklist
tokenBlacklist := &service.TokenBlacklist{}
s.service.EXPECT().
GetTokenBlacklist().
Return(tokenBlacklist).
AnyTimes()
s.handler.RegisterRoutes(s.router)
}
func (s *PostHandlerTestSuite) TearDownTest() {
s.ctrl.Finish()
}
func TestPostHandlerSuite(t *testing.T) {
suite.Run(t, new(PostHandlerTestSuite))
}
// Test cases for ListPosts
func (s *PostHandlerTestSuite) TestListPosts() {
categoryID := 1
testCases := []struct {
name string
langCode string
categoryID string
limit string
offset string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success with default language",
langCode: "",
setupMock: func() {
s.service.EXPECT().
ListPosts(gomock.Any(), "en", nil, 10, 0).
Return([]*ent.Post{
{
ID: 1,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Post{
{
ID: 1,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
},
},
},
},
},
},
{
name: "Success with specific language",
langCode: "zh",
setupMock: func() {
s.service.EXPECT().
ListPosts(gomock.Any(), "zh", nil, 10, 0).
Return([]*ent.Post{
{
ID: 1,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "zh",
Title: "测试帖子",
ContentMarkdown: "测试内容",
Summary: "测试摘要",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Post{
{
ID: 1,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "zh",
Title: "测试帖子",
ContentMarkdown: "测试内容",
Summary: "测试摘要",
},
},
},
},
},
},
{
name: "Success with category filter",
langCode: "en",
categoryID: "1",
setupMock: func() {
s.service.EXPECT().
ListPosts(gomock.Any(), "en", &categoryID, 10, 0).
Return([]*ent.Post{
{
ID: 1,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Post{
{
ID: 1,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
},
},
},
},
},
},
{
name: "Success with pagination",
langCode: "en",
limit: "2",
offset: "1",
setupMock: func() {
s.service.EXPECT().
ListPosts(gomock.Any(), "en", nil, 2, 1).
Return([]*ent.Post{
{
ID: 2,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "en",
Title: "Test Post 2",
ContentMarkdown: "Test Content 2",
Summary: "Test Summary 2",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Post{
{
ID: 2,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "en",
Title: "Test Post 2",
ContentMarkdown: "Test Content 2",
Summary: "Test Summary 2",
},
},
},
},
},
},
{
name: "Service Error",
langCode: "en",
setupMock: func() {
s.service.EXPECT().
ListPosts(gomock.Any(), "en", nil, 10, 0).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// Setup mock
tc.setupMock()
// Create request
url := "/api/v1/posts"
if tc.langCode != "" {
url += "?lang=" + tc.langCode
}
if tc.categoryID != "" {
if strings.Contains(url, "?") {
url += "&"
} else {
url += "?"
}
url += "category_id=" + tc.categoryID
}
if tc.limit != "" {
if strings.Contains(url, "?") {
url += "&"
} else {
url += "?"
}
url += "limit=" + tc.limit
}
if tc.offset != "" {
if strings.Contains(url, "?") {
url += "&"
} else {
url += "?"
}
url += "offset=" + tc.offset
}
req := httptest.NewRequest(http.MethodGet, url, nil)
w := httptest.NewRecorder()
// Perform request
s.router.ServeHTTP(w, req)
// Assert response
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedBody != nil {
var response []*ent.Post
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Equal(tc.expectedBody, response)
}
})
}
}
// Test cases for GetPost
func (s *PostHandlerTestSuite) TestGetPost() {
testCases := []struct {
name string
langCode string
slug string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success with default language",
langCode: "",
slug: "test-post",
setupMock: func() {
s.service.EXPECT().
GetPostBySlug(gomock.Any(), "en", "test-post").
Return(&ent.Post{
ID: 1,
Status: "published",
Slug: "test-post",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: gin.H{
"id": 1,
"status": "published",
"slug": "test-post",
"edges": gin.H{
"contents": []gin.H{
{
"language_code": "en",
"title": "Test Post",
"content_markdown": "Test Content",
"summary": "Test Summary",
},
},
},
},
},
{
name: "Success with specific language",
langCode: "zh",
slug: "test-post",
setupMock: func() {
s.service.EXPECT().
GetPostBySlug(gomock.Any(), "zh", "test-post").
Return(&ent.Post{
ID: 1,
Status: "published",
Slug: "test-post",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "zh",
Title: "测试帖子",
ContentMarkdown: "测试内容",
Summary: "测试摘要",
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: gin.H{
"id": 1,
"status": "published",
"slug": "test-post",
"edges": gin.H{
"contents": []gin.H{
{
"language_code": "zh",
"title": "测试帖子",
"content_markdown": "测试内容",
"summary": "测试摘要",
},
},
},
},
},
{
name: "Service error",
slug: "test-post",
setupMock: func() {
s.service.EXPECT().
GetPostBySlug(gomock.Any(), "en", "test-post").
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to get post"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
url := "/api/v1/posts/" + tc.slug
if tc.langCode != "" {
url += "?lang=" + tc.langCode
}
req := httptest.NewRequest(http.MethodGet, url, nil)
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
// Test cases for CreatePost
func (s *PostHandlerTestSuite) TestCreatePost() {
testCases := []struct {
name string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
setupMock: func() {
s.service.EXPECT().
CreatePost(gomock.Any(), "draft").
Return(&ent.Post{
ID: 1,
Status: "draft",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{},
},
}, nil)
},
expectedStatus: http.StatusCreated,
expectedBody: gin.H{
"id": 1,
"status": "draft",
"edges": gin.H{
"contents": []gin.H{},
},
},
},
{
name: "Service error",
setupMock: func() {
s.service.EXPECT().
CreatePost(gomock.Any(), "draft").
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to create post"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
req := httptest.NewRequest(http.MethodPost, "/api/v1/posts", nil)
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
// Test cases for AddPostContent
func (s *PostHandlerTestSuite) TestAddPostContent() {
testCases := []struct {
name string
postID string
body interface{}
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
postID: "1",
body: AddPostContentRequest{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
MetaKeywords: "test,keywords",
MetaDescription: "Test meta description",
},
setupMock: func() {
s.service.EXPECT().
AddPostContent(
gomock.Any(),
1,
"en",
"Test Post",
"Test Content",
"Test Summary",
"test,keywords",
"Test meta description",
).
Return(&ent.PostContent{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
MetaKeywords: "test,keywords",
MetaDescription: "Test meta description",
Edges: ent.PostContentEdges{},
}, nil)
},
expectedStatus: http.StatusCreated,
expectedBody: gin.H{
"language_code": "en",
"title": "Test Post",
"content_markdown": "Test Content",
"summary": "Test Summary",
"meta_keywords": "test,keywords",
"meta_description": "Test meta description",
"edges": gin.H{},
},
},
{
name: "Invalid post ID",
postID: "invalid",
body: AddPostContentRequest{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Invalid post ID"},
},
{
name: "Invalid request body",
postID: "1",
body: map[string]interface{}{
"language_code": "en",
// Missing required fields
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Key: 'AddPostContentRequest.Title' Error:Field validation for 'Title' failed on the 'required' tag\nKey: 'AddPostContentRequest.ContentMarkdown' Error:Field validation for 'ContentMarkdown' failed on the 'required' tag\nKey: 'AddPostContentRequest.Summary' Error:Field validation for 'Summary' failed on the 'required' tag"},
},
{
name: "Service error",
postID: "1",
body: AddPostContentRequest{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
MetaKeywords: "test,keywords",
MetaDescription: "Test meta description",
},
setupMock: func() {
s.service.EXPECT().
AddPostContent(
gomock.Any(),
1,
"en",
"Test Post",
"Test Content",
"Test Summary",
"test,keywords",
"Test meta description",
).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to add post content"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
body, err := json.Marshal(tc.body)
s.NoError(err, "Failed to marshal request body")
req := httptest.NewRequest(http.MethodPost, "/api/v1/posts/"+tc.postID+"/contents", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}

View file

@ -1,238 +0,0 @@
package middleware
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"tss-rocks-be/internal/types"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestAccessLog(t *testing.T) {
// 设置测试临时目录
tmpDir := t.TempDir()
logPath := filepath.Join(tmpDir, "test.log")
testCases := []struct {
name string
config *types.AccessLogConfig
expectedError bool
setupRequest func(*http.Request)
validateOutput func(*testing.T, *httptest.ResponseRecorder, string)
}{
{
name: "Console logging only",
config: &types.AccessLogConfig{
EnableConsole: true,
EnableFile: false,
Format: "json",
Level: "info",
},
expectedError: false,
setupRequest: func(req *http.Request) {
req.Header.Set("User-Agent", "test-agent")
},
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, logOutput, "GET /test")
assert.Contains(t, logOutput, "test-agent")
},
},
{
name: "File logging only",
config: &types.AccessLogConfig{
EnableConsole: false,
EnableFile: true,
FilePath: logPath,
Format: "json",
Level: "info",
Rotation: struct {
MaxSize int `yaml:"max_size"`
MaxAge int `yaml:"max_age"`
MaxBackups int `yaml:"max_backups"`
Compress bool `yaml:"compress"`
LocalTime bool `yaml:"local_time"`
}{
MaxSize: 1,
MaxAge: 1,
MaxBackups: 1,
Compress: false,
LocalTime: true,
},
},
expectedError: false,
setupRequest: func(req *http.Request) {
req.Header.Set("User-Agent", "test-agent")
},
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
assert.Equal(t, http.StatusOK, w.Code)
// 读取日志文件内容
content, err := os.ReadFile(logPath)
assert.NoError(t, err)
assert.Contains(t, string(content), "GET /test")
assert.Contains(t, string(content), "test-agent")
},
},
{
name: "Both console and file logging",
config: &types.AccessLogConfig{
EnableConsole: true,
EnableFile: true,
FilePath: logPath,
Format: "json",
Level: "info",
Rotation: struct {
MaxSize int `yaml:"max_size"`
MaxAge int `yaml:"max_age"`
MaxBackups int `yaml:"max_backups"`
Compress bool `yaml:"compress"`
LocalTime bool `yaml:"local_time"`
}{
MaxSize: 1,
MaxAge: 1,
MaxBackups: 1,
Compress: false,
LocalTime: true,
},
},
expectedError: false,
setupRequest: func(req *http.Request) {
req.Header.Set("User-Agent", "test-agent")
},
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, logOutput, "GET /test")
assert.Contains(t, logOutput, "test-agent")
// 读取日志文件内容
content, err := os.ReadFile(logPath)
assert.NoError(t, err)
assert.Contains(t, string(content), "GET /test")
assert.Contains(t, string(content), "test-agent")
},
},
{
name: "With authenticated user",
config: &types.AccessLogConfig{
EnableConsole: true,
EnableFile: false,
Format: "json",
Level: "info",
},
expectedError: false,
setupRequest: func(req *http.Request) {
req.Header.Set("User-Agent", "test-agent")
},
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, logOutput, "GET /test")
assert.Contains(t, logOutput, "test-agent")
assert.Contains(t, logOutput, "test-user")
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// 捕获标准输出
oldStdout := os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w
// 创建一个新的 gin 引擎
gin.SetMode(gin.TestMode)
router := gin.New()
// 创建访问日志中间件
middleware, err := AccessLog(tc.config)
if tc.expectedError {
assert.Error(t, err)
return
}
assert.NoError(t, err)
// 添加测试路由
router.Use(middleware)
router.GET("/test", func(c *gin.Context) {
// 如果是测试认证用户的情况设置用户ID
if tc.name == "With authenticated user" {
c.Set("user_id", "test-user")
}
c.Status(http.StatusOK)
})
// 创建测试请求
req := httptest.NewRequest("GET", "/test", nil)
if tc.setupRequest != nil {
tc.setupRequest(req)
}
rec := httptest.NewRecorder()
// 执行请求
router.ServeHTTP(rec, req)
// 恢复标准输出并获取输出内容
w.Close()
var buf bytes.Buffer
io.Copy(&buf, r)
os.Stdout = oldStdout
// 验证输出
if tc.validateOutput != nil {
tc.validateOutput(t, rec, buf.String())
}
// 关闭日志文件
if tc.config.EnableFile {
// 调用中间件函数来关闭日志文件
middleware(nil)
// 等待一小段时间确保文件完全关闭
time.Sleep(100 * time.Millisecond)
}
})
}
}
func TestAccessLogInvalidConfig(t *testing.T) {
testCases := []struct {
name string
config *types.AccessLogConfig
expectedError bool
}{
{
name: "Invalid log level",
config: &types.AccessLogConfig{
EnableConsole: true,
Level: "invalid_level",
},
expectedError: false, // 应该使用默认的 info 级别
},
{
name: "Invalid file path",
config: &types.AccessLogConfig{
EnableFile: true,
FilePath: "\x00invalid\x00path", // 使用空字符的路径在所有操作系统上都是无效的
},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := AccessLog(tc.config)
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View file

@ -1,227 +0,0 @@
package middleware
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
"time"
"tss-rocks-be/internal/service"
)
func createTestToken(secret string, claims jwt.MapClaims) string {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signedToken, err := token.SignedString([]byte(secret))
if err != nil {
panic(fmt.Sprintf("Failed to sign token: %v", err))
}
return signedToken
}
func TestAuthMiddleware(t *testing.T) {
jwtSecret := "test-secret"
tokenBlacklist := service.NewTokenBlacklist()
testCases := []struct {
name string
setupAuth func(*http.Request)
expectedStatus int
expectedBody map[string]string
checkUserData bool
expectedUserID string
expectedRoles []string
}{
{
name: "No Authorization header",
setupAuth: func(req *http.Request) {},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "Authorization header is required"},
},
{
name: "Invalid Authorization format",
setupAuth: func(req *http.Request) {
req.Header.Set("Authorization", "InvalidFormat")
},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "Authorization header format must be Bearer {token}"},
},
{
name: "Invalid token",
setupAuth: func(req *http.Request) {
req.Header.Set("Authorization", "Bearer invalid.token.here")
},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "Invalid token"},
},
{
name: "Valid token",
setupAuth: func(req *http.Request) {
claims := jwt.MapClaims{
"sub": "123",
"roles": []string{"admin", "editor"},
"exp": time.Now().Add(time.Hour).Unix(),
}
token := createTestToken(jwtSecret, claims)
req.Header.Set("Authorization", "Bearer "+token)
},
expectedStatus: http.StatusOK,
checkUserData: true,
expectedUserID: "123",
expectedRoles: []string{"admin", "editor"},
},
{
name: "Expired token",
setupAuth: func(req *http.Request) {
claims := jwt.MapClaims{
"sub": "123",
"roles": []string{"user"},
"exp": time.Now().Add(-time.Hour).Unix(),
}
token := createTestToken(jwtSecret, claims)
req.Header.Set("Authorization", "Bearer "+token)
},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "Invalid token"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
// 添加认证中间件
router.Use(func(c *gin.Context) {
// 设置日志级别为 debug
gin.SetMode(gin.DebugMode)
c.Next()
}, AuthMiddleware(jwtSecret, tokenBlacklist))
// 测试路由
router.GET("/test", func(c *gin.Context) {
if tc.checkUserData {
userID, exists := c.Get("user_id")
assert.True(t, exists, "user_id should exist in context")
assert.Equal(t, tc.expectedUserID, userID, "user_id should match")
roles, exists := c.Get("user_roles")
assert.True(t, exists, "user_roles should exist in context")
assert.Equal(t, tc.expectedRoles, roles, "user_roles should match")
}
c.Status(http.StatusOK)
})
// 创建请求
req := httptest.NewRequest("GET", "/test", nil)
tc.setupAuth(req)
rec := httptest.NewRecorder()
// 执行请求
router.ServeHTTP(rec, req)
// 验证响应
assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match")
if tc.expectedBody != nil {
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err, "Response body should be valid JSON")
assert.Equal(t, tc.expectedBody, response, "Response body should match")
}
})
}
}
func TestRoleMiddleware(t *testing.T) {
testCases := []struct {
name string
setupContext func(*gin.Context)
allowedRoles []string
expectedStatus int
expectedBody map[string]string
}{
{
name: "No user roles",
setupContext: func(c *gin.Context) {
// 不设置用户角色
},
allowedRoles: []string{"admin"},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "User roles not found"},
},
{
name: "Invalid roles type",
setupContext: func(c *gin.Context) {
c.Set("user_roles", 123) // 设置错误类型的角色
},
allowedRoles: []string{"admin"},
expectedStatus: http.StatusInternalServerError,
expectedBody: map[string]string{"error": "Invalid user roles type"},
},
{
name: "Insufficient permissions",
setupContext: func(c *gin.Context) {
c.Set("user_roles", []string{"user"})
},
allowedRoles: []string{"admin"},
expectedStatus: http.StatusForbidden,
expectedBody: map[string]string{"error": "Insufficient permissions"},
},
{
name: "Allowed role",
setupContext: func(c *gin.Context) {
c.Set("user_roles", []string{"admin"})
},
allowedRoles: []string{"admin"},
expectedStatus: http.StatusOK,
},
{
name: "One of multiple allowed roles",
setupContext: func(c *gin.Context) {
c.Set("user_roles", []string{"user", "editor"})
},
allowedRoles: []string{"admin", "editor", "moderator"},
expectedStatus: http.StatusOK,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
// 添加角色中间件
router.Use(func(c *gin.Context) {
tc.setupContext(c)
c.Next()
}, RoleMiddleware(tc.allowedRoles...))
// 测试路由
router.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// 创建请求
req := httptest.NewRequest("GET", "/test", nil)
rec := httptest.NewRecorder()
// 执行请求
router.ServeHTTP(rec, req)
// 验证响应
assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match")
if tc.expectedBody != nil {
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err, "Response body should be valid JSON")
assert.Equal(t, tc.expectedBody, response, "Response body should match")
}
})
}
}

View file

@ -1,76 +0,0 @@
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
)
func TestCORS(t *testing.T) {
testCases := []struct {
name string
method string
expectedStatus int
checkHeaders bool
}{
{
name: "Normal GET request",
method: "GET",
expectedStatus: http.StatusOK,
checkHeaders: true,
},
{
name: "OPTIONS request",
method: "OPTIONS",
expectedStatus: http.StatusNoContent,
checkHeaders: true,
},
{
name: "POST request",
method: "POST",
expectedStatus: http.StatusOK,
checkHeaders: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// 创建一个新的 gin 引擎
gin.SetMode(gin.TestMode)
router := gin.New()
// 添加 CORS 中间件
router.Use(CORS())
// 添加测试路由
router.Any("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// 创建测试请求
req := httptest.NewRequest(tc.method, "/test", nil)
rec := httptest.NewRecorder()
// 执行请求
router.ServeHTTP(rec, req)
// 验证状态码
assert.Equal(t, tc.expectedStatus, rec.Code)
if tc.checkHeaders {
// 验证 CORS 头部
headers := rec.Header()
assert.Equal(t, "*", headers.Get("Access-Control-Allow-Origin"))
assert.Equal(t, "true", headers.Get("Access-Control-Allow-Credentials"))
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Content-Type")
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Authorization")
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "POST")
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "GET")
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "PUT")
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "DELETE")
}
})
}
}

View file

@ -1,207 +0,0 @@
package middleware
import (
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
"time"
"tss-rocks-be/internal/types"
)
func TestRateLimit(t *testing.T) {
testCases := []struct {
name string
config *types.RateLimitConfig
setupTest func(*gin.Engine)
runTest func(*testing.T, *gin.Engine)
expectedStatus int
expectedBody map[string]string
}{
{
name: "IP rate limit",
config: &types.RateLimitConfig{
IPRate: 1, // 每秒1个请求
IPBurst: 1,
},
setupTest: func(router *gin.Engine) {
router.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
},
runTest: func(t *testing.T, router *gin.Engine) {
// 第一个请求应该成功
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:1234"
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// 第二个请求应该被限制
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err)
assert.Equal(t, "too many requests from this IP", response["error"])
// 等待限流器重置
time.Sleep(time.Second)
// 第三个请求应该成功
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
},
},
{
name: "Route rate limit",
config: &types.RateLimitConfig{
IPRate: 100, // 设置较高的 IP 限流,以便测试路由限流
IPBurst: 10,
RouteRates: map[string]struct {
Rate int `yaml:"rate"`
Burst int `yaml:"burst"`
}{
"/limited": {
Rate: 1,
Burst: 1,
},
},
},
setupTest: func(router *gin.Engine) {
router.GET("/limited", func(c *gin.Context) {
c.Status(http.StatusOK)
})
router.GET("/unlimited", func(c *gin.Context) {
c.Status(http.StatusOK)
})
},
runTest: func(t *testing.T, router *gin.Engine) {
// 测试限流路由
req := httptest.NewRequest("GET", "/limited", nil)
req.RemoteAddr = "192.168.1.2:1234"
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// 等待一小段时间确保限流器生效
time.Sleep(10 * time.Millisecond)
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err)
assert.Equal(t, "too many requests for this route", response["error"])
// 测试未限流路由
req = httptest.NewRequest("GET", "/unlimited", nil)
req.RemoteAddr = "192.168.1.2:1234"
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// 等待一小段时间确保限流器生效
time.Sleep(10 * time.Millisecond)
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
},
},
{
name: "Multiple IPs",
config: &types.RateLimitConfig{
IPRate: 1,
IPBurst: 1,
},
setupTest: func(router *gin.Engine) {
router.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
},
runTest: func(t *testing.T, router *gin.Engine) {
// IP1 的请求
req1 := httptest.NewRequest("GET", "/test", nil)
req1.RemoteAddr = "192.168.1.3:1234"
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req1)
assert.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req1)
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
// IP2 的请求应该不受 IP1 的限制影响
req2 := httptest.NewRequest("GET", "/test", nil)
req2.RemoteAddr = "192.168.1.4:1234"
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req2)
assert.Equal(t, http.StatusOK, rec.Code)
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
// 添加限流中间件
router.Use(RateLimit(tc.config))
// 设置测试路由
tc.setupTest(router)
// 运行测试
tc.runTest(t, router)
})
}
}
func TestRateLimiterCleanup(t *testing.T) {
config := &types.RateLimitConfig{
IPRate: 1,
IPBurst: 1,
}
rl := newRateLimiter(config)
// 添加一些IP限流器
ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"}
for _, ip := range ips {
rl.getLimiter(ip)
}
// 验证IP限流器已创建
rl.mu.RLock()
assert.Equal(t, len(ips), len(rl.ips))
rl.mu.RUnlock()
// 修改一些IP的最后访问时间为1小时前
rl.mu.Lock()
rl.ips["192.168.1.1"].lastSeen = time.Now().Add(-2 * time.Hour)
rl.ips["192.168.1.2"].lastSeen = time.Now().Add(-2 * time.Hour)
rl.mu.Unlock()
// 手动触发清理
rl.mu.Lock()
for ip, limiter := range rl.ips {
if time.Since(limiter.lastSeen) > time.Hour {
delete(rl.ips, ip)
}
}
rl.mu.Unlock()
// 验证过期的IP限流器已被删除
rl.mu.RLock()
assert.Equal(t, 1, len(rl.ips))
_, exists := rl.ips["192.168.1.3"]
assert.True(t, exists)
rl.mu.RUnlock()
}

View file

@ -3,146 +3,120 @@ package middleware
import (
"bytes"
"fmt"
"io"
"net/http"
"path/filepath"
"strings"
"tss-rocks-be/internal/config"
"github.com/gin-gonic/gin"
"tss-rocks-be/internal/types"
)
const (
defaultMaxMemory = 32 << 20 // 32 MB
maxHeaderBytes = 512 // 用于MIME类型检测的最大字节数
)
// ValidateUpload 创建文件上传验证中间件
func ValidateUpload(cfg *types.UploadConfig) gin.HandlerFunc {
// ValidateUpload 验证上传的文件
func ValidateUpload(cfg *config.UploadConfig) gin.HandlerFunc {
return func(c *gin.Context) {
// 检查是否是multipart/form-data请求
if !strings.HasPrefix(c.GetHeader("Content-Type"), "multipart/form-data") {
c.JSON(http.StatusBadRequest, gin.H{
"error": "Content-Type must be multipart/form-data",
})
// Get file from form
file, err := c.FormFile("file")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "No file uploaded"})
c.Abort()
return
}
// 解析multipart表单
if err := c.Request.ParseMultipartForm(defaultMaxMemory); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("Failed to parse form: %v", err),
})
c.Abort()
return
}
// 获取文件类型和扩展名
contentType := file.Header.Get("Content-Type")
ext := strings.ToLower(filepath.Ext(file.Filename))
form := c.Request.MultipartForm
if form == nil || form.File == nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "No file uploaded",
})
c.Abort()
return
}
// 遍历所有上传的文件
for _, files := range form.File {
for _, file := range files {
// 检查文件大小
if file.Size > int64(cfg.MaxSize)<<20 { // 转换为字节
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("File %s exceeds maximum size of %d MB", file.Filename, cfg.MaxSize),
})
c.Abort()
return
}
// 检查文件扩展名
ext := strings.ToLower(filepath.Ext(file.Filename))
if !contains(cfg.AllowedExtensions, ext) {
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("File extension %s is not allowed", ext),
})
c.Abort()
return
}
// 打开文件
src, err := file.Open()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("Failed to open file: %v", err),
})
c.Abort()
return
}
defer src.Close()
// 读取文件头部用于MIME类型检测
header := make([]byte, maxHeaderBytes)
n, err := src.Read(header)
if err != nil && err != io.EOF {
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("Failed to read file: %v", err),
})
c.Abort()
return
}
header = header[:n]
// 检测MIME类型
contentType := http.DetectContentType(header)
if !contains(cfg.AllowedTypes, contentType) {
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("File type %s is not allowed", contentType),
})
c.Abort()
return
}
// 将文件指针重置到开始位置
_, err = src.Seek(0, 0)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("Failed to read file: %v", err),
})
c.Abort()
return
}
// 将文件内容读入缓冲区
buf := &bytes.Buffer{}
_, err = io.Copy(buf, src)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("Failed to read file: %v", err),
})
c.Abort()
return
}
// 将验证过的文件内容和类型保存到上下文中
c.Set("validated_file_"+file.Filename, buf)
c.Set("validated_content_type_"+file.Filename, contentType)
// 如果 Content-Type 为空,尝试从文件扩展名判断
if contentType == "" {
switch ext {
case ".jpg", ".jpeg":
contentType = "image/jpeg"
case ".png":
contentType = "image/png"
case ".gif":
contentType = "image/gif"
case ".webp":
contentType = "image/webp"
case ".mp4":
contentType = "video/mp4"
case ".webm":
contentType = "video/webm"
case ".mp3":
contentType = "audio/mpeg"
case ".ogg":
contentType = "audio/ogg"
case ".wav":
contentType = "audio/wav"
case ".pdf":
contentType = "application/pdf"
case ".doc":
contentType = "application/msword"
case ".docx":
contentType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
}
}
// 根据 Content-Type 确定文件类型和限制
var maxSize int64
var allowedTypes []string
var fileType string
limits := cfg.Limits
switch {
case strings.HasPrefix(contentType, "image/"):
maxSize = int64(limits.Image.MaxSize) * 1024 * 1024
allowedTypes = limits.Image.AllowedTypes
fileType = "image"
case strings.HasPrefix(contentType, "video/"):
maxSize = int64(limits.Video.MaxSize) * 1024 * 1024
allowedTypes = limits.Video.AllowedTypes
fileType = "video"
case strings.HasPrefix(contentType, "audio/"):
maxSize = int64(limits.Audio.MaxSize) * 1024 * 1024
allowedTypes = limits.Audio.AllowedTypes
fileType = "audio"
case strings.HasPrefix(contentType, "application/"):
maxSize = int64(limits.Document.MaxSize) * 1024 * 1024
allowedTypes = limits.Document.AllowedTypes
fileType = "document"
default:
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("Unsupported file type: %s", contentType),
})
c.Abort()
return
}
// 检查文件类型是否允许
typeAllowed := false
for _, allowed := range allowedTypes {
if contentType == allowed {
typeAllowed = true
break
}
}
if !typeAllowed {
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("Unsupported %s type: %s", fileType, contentType),
})
c.Abort()
return
}
// 检查文件大小
if file.Size > maxSize {
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("File size exceeds the limit (%d MB) for %s files", limits.Image.MaxSize, fileType),
})
c.Abort()
return
}
c.Next()
}
}
// contains 检查切片中是否包含指定的字符串
func contains(slice []string, str string) bool {
for _, s := range slice {
if s == str {
return true
}
}
return false
}
// GetValidatedFile 从上下文中获取验证过的文件内容
func GetValidatedFile(c *gin.Context, filename string) (*bytes.Buffer, string, bool) {
file, exists := c.Get("validated_file_" + filename)

View file

@ -1,262 +0,0 @@
package middleware
import (
"bytes"
"encoding/json"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"strings"
"testing"
"tss-rocks-be/internal/types"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func createMultipartRequest(t *testing.T, filename string, content []byte, contentType string) (*http.Request, error) {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
part, err := writer.CreateFormFile("file", filename)
if err != nil {
return nil, err
}
_, err = io.Copy(part, bytes.NewReader(content))
if err != nil {
return nil, err
}
err = writer.Close()
if err != nil {
return nil, err
}
req := httptest.NewRequest("POST", "/upload", body)
req.Header.Set("Content-Type", writer.FormDataContentType())
return req, nil
}
func TestValidateUpload(t *testing.T) {
tests := []struct {
name string
config *types.UploadConfig
filename string
content []byte
setupRequest func(*testing.T) *http.Request
expectedStatus int
expectedError string
}{
{
name: "Valid image upload",
config: &types.UploadConfig{
MaxSize: 5, // 5MB
AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
AllowedTypes: []string{"image/jpeg", "image/png"},
},
filename: "test.jpg",
content: []byte{
0xFF, 0xD8, 0xFF, 0xE0, // JPEG magic numbers
0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00,
},
expectedStatus: http.StatusOK,
},
{
name: "Invalid file extension",
config: &types.UploadConfig{
MaxSize: 5,
AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
AllowedTypes: []string{"image/jpeg", "image/png"},
},
filename: "test.txt",
content: []byte("test content"),
expectedStatus: http.StatusBadRequest,
expectedError: "File extension .txt is not allowed",
},
{
name: "File too large",
config: &types.UploadConfig{
MaxSize: 1, // 1MB
AllowedExtensions: []string{".jpg"},
AllowedTypes: []string{"image/jpeg"},
},
filename: "large.jpg",
content: make([]byte, 2<<20), // 2MB
expectedStatus: http.StatusBadRequest,
expectedError: "File large.jpg exceeds maximum size of 1 MB",
},
{
name: "Invalid content type",
config: &types.UploadConfig{
MaxSize: 5,
AllowedExtensions: []string{".jpg"},
AllowedTypes: []string{"image/jpeg"},
},
filename: "fake.jpg",
content: []byte("not a real image"),
expectedStatus: http.StatusBadRequest,
expectedError: "File type text/plain; charset=utf-8 is not allowed",
},
{
name: "Missing file",
config: &types.UploadConfig{
MaxSize: 5,
AllowedExtensions: []string{".jpg"},
AllowedTypes: []string{"image/jpeg"},
},
setupRequest: func(t *testing.T) *http.Request {
req := httptest.NewRequest("POST", "/upload", strings.NewReader(""))
req.Header.Set("Content-Type", "multipart/form-data")
return req
},
expectedStatus: http.StatusBadRequest,
expectedError: "Failed to parse form",
},
{
name: "Invalid content type header",
config: &types.UploadConfig{
MaxSize: 5,
AllowedExtensions: []string{".jpg"},
AllowedTypes: []string{"image/jpeg"},
},
setupRequest: func(t *testing.T) *http.Request {
return httptest.NewRequest("POST", "/upload", nil)
},
expectedStatus: http.StatusBadRequest,
expectedError: "Content-Type must be multipart/form-data",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
var req *http.Request
var err error
if tt.setupRequest != nil {
req = tt.setupRequest(t)
} else {
req, err = createMultipartRequest(t, tt.filename, tt.content, "")
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
}
c.Request = req
middleware := ValidateUpload(tt.config)
middleware(c)
assert.Equal(t, tt.expectedStatus, w.Code)
if tt.expectedError != "" {
var response map[string]string
err := json.NewDecoder(w.Body).Decode(&response)
assert.NoError(t, err)
assert.Contains(t, response["error"], tt.expectedError)
}
})
}
}
func TestGetValidatedFile(t *testing.T) {
tests := []struct {
name string
setupContext func(*gin.Context)
filename string
expectedFound bool
expectedError string
}{
{
name: "Get existing file",
setupContext: func(c *gin.Context) {
// 创建测试文件内容
content := []byte("test content")
buf := bytes.NewBuffer(content)
// 设置验证过的文件和内容类型
c.Set("validated_file_test.txt", buf)
c.Set("validated_content_type_test.txt", "text/plain")
},
filename: "test.txt",
expectedFound: true,
},
{
name: "File not found",
setupContext: func(c *gin.Context) {
// 不设置任何文件
},
filename: "nonexistent.txt",
expectedFound: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
if tt.setupContext != nil {
tt.setupContext(c)
}
buffer, contentType, found := GetValidatedFile(c, tt.filename)
assert.Equal(t, tt.expectedFound, found)
if tt.expectedFound {
assert.NotNil(t, buffer)
assert.NotEmpty(t, contentType)
} else {
assert.Nil(t, buffer)
assert.Empty(t, contentType)
}
})
}
}
func TestContains(t *testing.T) {
tests := []struct {
name string
slice []string
str string
expected bool
}{
{
name: "String found in slice",
slice: []string{"a", "b", "c"},
str: "b",
expected: true,
},
{
name: "String not found in slice",
slice: []string{"a", "b", "c"},
str: "d",
expected: false,
},
{
name: "Empty slice",
slice: []string{},
str: "a",
expected: false,
},
{
name: "Empty string",
slice: []string{"a", "b", "c"},
str: "",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := contains(tt.slice, tt.str)
assert.Equal(t, tt.expected, result)
})
}
}

View file

@ -1,99 +0,0 @@
package rbac
import (
"context"
"testing"
"tss-rocks-be/ent/enttest"
"tss-rocks-be/ent/role"
_ "github.com/mattn/go-sqlite3"
)
func TestInitializeRBAC(t *testing.T) {
// Create an in-memory SQLite client for testing
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
ctx := context.Background()
// Test initialization
err := InitializeRBAC(ctx, client)
if err != nil {
t.Fatalf("Failed to initialize RBAC: %v", err)
}
// Verify roles were created
for roleName := range DefaultRoles {
r, err := client.Role.Query().Where(role.Name(roleName)).Only(ctx)
if err != nil {
t.Errorf("Role %s was not created: %v", roleName, err)
}
// Verify permissions for each role
perms, err := r.QueryPermissions().All(ctx)
if err != nil {
t.Errorf("Failed to query permissions for role %s: %v", roleName, err)
}
expectedPerms := DefaultRoles[roleName]
permCount := 0
for _, actions := range expectedPerms {
permCount += len(actions)
}
if len(perms) != permCount {
t.Errorf("Role %s has %d permissions, expected %d", roleName, len(perms), permCount)
}
}
}
func TestAssignRoleToUser(t *testing.T) {
// Create an in-memory SQLite client for testing
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
ctx := context.Background()
// Initialize RBAC
err := InitializeRBAC(ctx, client)
if err != nil {
t.Fatalf("Failed to initialize RBAC: %v", err)
}
// Create a test user
user, err := client.User.Create().
SetEmail("test@example.com").
SetUsername("testuser").
SetPasswordHash("$2a$10$hzLdXMZEIzgr8eGXL0YoCOIIrQhqEj6N.S3.wY1Jx5.4vWm1ZyHyy").
Save(ctx)
if err != nil {
t.Fatalf("Failed to create test user: %v", err)
}
// Test assigning role to user
err = AssignRoleToUser(ctx, client, user.ID, "editor")
if err != nil {
t.Fatalf("Failed to assign role to user: %v", err)
}
// Verify role assignment
assignedRoles, err := user.QueryRoles().All(ctx)
if err != nil {
t.Fatalf("Failed to query user roles: %v", err)
}
if len(assignedRoles) != 1 {
t.Errorf("Expected 1 role, got %d", len(assignedRoles))
}
if assignedRoles[0].Name != "editor" {
t.Errorf("Expected role name 'editor', got '%s'", assignedRoles[0].Name)
}
// Test assigning non-existent role
err = AssignRoleToUser(ctx, client, user.ID, "nonexistent")
if err == nil {
t.Error("Expected error when assigning non-existent role, got nil")
}
}

View file

@ -1,64 +0,0 @@
package server
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestInitDatabase(t *testing.T) {
tests := []struct {
name string
driver string
dsn string
wantErr bool
errContains string
}{
{
name: "success with sqlite3",
driver: "sqlite3",
dsn: "file:ent?mode=memory&cache=shared&_fk=1",
},
{
name: "invalid driver",
driver: "invalid_driver",
dsn: "file:ent?mode=memory",
wantErr: true,
errContains: "unsupported driver",
},
{
name: "invalid dsn",
driver: "sqlite3",
dsn: "file::memory:?not_exist_option=1", // 使用内存数据库但带有无效选项
wantErr: true,
errContains: "foreign_keys pragma is off",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
client, err := InitDatabase(ctx, tt.driver, tt.dsn)
if tt.wantErr {
assert.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
assert.Nil(t, client)
} else {
require.NoError(t, err)
assert.NotNil(t, client)
// 测试数据库连接是否正常工作
err = client.Schema.Create(ctx)
assert.NoError(t, err)
// 清理
client.Close()
}
})
}
}

View file

@ -1,40 +0,0 @@
package server
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"tss-rocks-be/internal/config"
)
func TestNewEntClient(t *testing.T) {
tests := []struct {
name string
cfg *config.Config
}{
{
name: "default sqlite3 config",
cfg: &config.Config{
Database: config.DatabaseConfig{
Driver: "sqlite3",
DSN: "file:ent?mode=memory&cache=shared&_fk=1",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := NewEntClient(tt.cfg)
assert.NotNil(t, client)
// 验证客户端是否可以正常工作
err := client.Schema.Create(context.Background())
assert.NoError(t, err)
// 清理
client.Close()
})
}
}

View file

@ -1,220 +0,0 @@
package server
import (
"context"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/types"
"tss-rocks-be/ent/enttest"
)
func TestNew(t *testing.T) {
// 创建测试配置
cfg := &config.Config{
Server: config.ServerConfig{
Host: "localhost",
Port: 8080,
},
Storage: config.StorageConfig{
Type: "local",
Local: config.LocalStorage{
RootDir: "testdata",
},
Upload: types.UploadConfig{
MaxSize: 10,
AllowedTypes: []string{"image/jpeg", "image/png"},
AllowedExtensions: []string{".jpg", ".png"},
},
},
RateLimit: types.RateLimitConfig{
IPRate: 100,
IPBurst: 200,
RouteRates: map[string]struct {
Rate int `yaml:"rate"`
Burst int `yaml:"burst"`
}{
"/api/v1/upload": {Rate: 10, Burst: 20},
},
},
AccessLog: types.AccessLogConfig{
EnableConsole: true,
EnableFile: true,
FilePath: "testdata/access.log",
Format: "json",
Level: "info",
Rotation: struct {
MaxSize int `yaml:"max_size"`
MaxAge int `yaml:"max_age"`
MaxBackups int `yaml:"max_backups"`
Compress bool `yaml:"compress"`
LocalTime bool `yaml:"local_time"`
}{
MaxSize: 100,
MaxAge: 7,
MaxBackups: 3,
Compress: true,
LocalTime: true,
},
},
}
// 创建测试数据库客户端
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
// 测试服务器初始化
s, err := New(cfg, client)
require.NoError(t, err)
assert.NotNil(t, s)
assert.NotNil(t, s.router)
assert.NotNil(t, s.handler)
assert.Equal(t, cfg, s.config)
}
func TestNew_StorageError(t *testing.T) {
// 创建一个无效的存储配置
cfg := &config.Config{
Storage: config.StorageConfig{
Type: "invalid_type", // 使用无效的存储类型
},
}
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
s, err := New(cfg, client)
assert.Error(t, err)
assert.Nil(t, s)
assert.Contains(t, err.Error(), "failed to initialize storage")
}
func TestServer_StartAndShutdown(t *testing.T) {
// 创建测试配置
cfg := &config.Config{
Server: config.ServerConfig{
Host: "localhost",
Port: 0, // 使用随机端口
},
Storage: config.StorageConfig{
Type: "local",
Local: config.LocalStorage{
RootDir: "testdata",
},
},
RateLimit: types.RateLimitConfig{
IPRate: 100,
IPBurst: 200,
},
AccessLog: types.AccessLogConfig{
EnableConsole: true,
Format: "json",
Level: "info",
},
}
// 创建测试数据库客户端
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
// 初始化服务器
s, err := New(cfg, client)
require.NoError(t, err)
// 创建一个通道来接收服务器错误
errChan := make(chan error, 1)
// 在 goroutine 中启动服务器
go func() {
err := s.Start()
if err != nil && err != http.ErrServerClosed {
errChan <- err
}
close(errChan)
}()
// 给服务器一些时间启动
time.Sleep(100 * time.Millisecond)
// 测试关闭服务器
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err = s.Shutdown(ctx)
assert.NoError(t, err)
// 检查服务器是否有错误发生
err = <-errChan
assert.NoError(t, err)
}
func TestServer_StartError(t *testing.T) {
// 创建一个配置,使用已经被占用的端口来触发错误
cfg := &config.Config{
Server: config.ServerConfig{
Host: "localhost",
Port: 8899, // 使用固定端口以便测试
},
Storage: config.StorageConfig{
Type: "local",
Local: config.LocalStorage{
RootDir: "testdata",
},
},
}
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
// 创建第一个服务器实例
s1, err := New(cfg, client)
require.NoError(t, err)
// 创建一个通道来接收服务器错误
errChan := make(chan error, 1)
// 启动第一个服务器
go func() {
err := s1.Start()
if err != nil && err != http.ErrServerClosed {
errChan <- err
}
close(errChan)
}()
// 给服务器一些时间启动
time.Sleep(100 * time.Millisecond)
// 尝试在同一端口启动第二个服务器,应该会失败
s2, err := New(cfg, client)
require.NoError(t, err)
err = s2.Start()
assert.Error(t, err)
// 清理
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// 关闭第一个服务器
err = s1.Shutdown(ctx)
assert.NoError(t, err)
// 检查第一个服务器是否有错误发生
err = <-errChan
assert.NoError(t, err)
// 关闭第二个服务器
err = s2.Shutdown(ctx)
assert.NoError(t, err)
}
func TestServer_ShutdownWithNilServer(t *testing.T) {
s := &Server{}
err := s.Shutdown(context.Background())
assert.NoError(t, err)
}

View file

@ -1,11 +1,14 @@
package service
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"mime/multipart"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
@ -26,6 +29,8 @@ import (
"tss-rocks-be/ent/user"
"tss-rocks-be/internal/storage"
"github.com/chai2010/webp"
"github.com/disintegration/imaging"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
@ -419,21 +424,71 @@ func (s *serviceImpl) Upload(ctx context.Context, file *multipart.FileHeader, us
// Open the uploaded file
src, err := openFile(file)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to open file: %v", err)
}
defer src.Close()
// 获取文件类型和扩展名
contentType := file.Header.Get("Content-Type")
ext := strings.ToLower(filepath.Ext(file.Filename))
if contentType == "" {
// 如果 Content-Type 为空,尝试从文件扩展名判断
switch ext {
case ".jpg", ".jpeg":
contentType = "image/jpeg"
case ".png":
contentType = "image/png"
case ".gif":
contentType = "image/gif"
case ".webp":
contentType = "image/webp"
case ".mp4":
contentType = "video/mp4"
case ".webm":
contentType = "video/webm"
case ".mp3":
contentType = "audio/mpeg"
case ".ogg":
contentType = "audio/ogg"
case ".wav":
contentType = "audio/wav"
case ".pdf":
contentType = "application/pdf"
case ".doc":
contentType = "application/msword"
case ".docx":
contentType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
}
}
// 如果是图片,检查是否需要转换为 WebP
var fileToSave multipart.File = src
var finalContentType = contentType
if strings.HasPrefix(contentType, "image/") && contentType != "image/webp" {
// 转换为 WebP
webpFile, err := convertToWebP(src)
if err != nil {
return nil, fmt.Errorf("failed to convert image to WebP: %v", err)
}
fileToSave = webpFile
finalContentType = "image/webp"
ext = ".webp"
}
// 生成带扩展名的存储文件名
storageFilename := uuid.New().String() + ext
// Save the file to storage
fileInfo, err := s.storage.Save(ctx, file.Filename, file.Header.Get("Content-Type"), src)
fileInfo, err := s.storage.Save(ctx, storageFilename, finalContentType, fileToSave)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to save file: %v", err)
}
// Create media record
return s.client.Media.Create().
SetStorageID(fileInfo.ID).
SetOriginalName(file.Filename).
SetMimeType(fileInfo.ContentType).
SetMimeType(finalContentType).
SetSize(fileInfo.Size).
SetURL(fileInfo.URL).
SetCreatedBy(strconv.Itoa(userID)).
@ -444,13 +499,8 @@ func (s *serviceImpl) GetMedia(ctx context.Context, id int) (*ent.Media, error)
return s.client.Media.Get(ctx, id)
}
func (s *serviceImpl) GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error) {
media, err := s.GetMedia(ctx, id)
if err != nil {
return nil, nil, err
}
return s.storage.Get(ctx, media.StorageID)
func (s *serviceImpl) GetFile(ctx context.Context, storageID string) (io.ReadCloser, *storage.FileInfo, error) {
return s.storage.Get(ctx, storageID)
}
func (s *serviceImpl) DeleteMedia(ctx context.Context, id int, userID int) error {
@ -476,7 +526,7 @@ func (s *serviceImpl) DeleteMedia(ctx context.Context, id int, userID int) error
}
// Post operations
func (s *serviceImpl) CreatePost(ctx context.Context, status string) (*ent.Post, error) {
func (s *serviceImpl) CreatePost(ctx context.Context, status string, categoryIDs []int) (*ent.Post, error) {
var postStatus post.Status
switch status {
case "draft":
@ -492,10 +542,25 @@ func (s *serviceImpl) CreatePost(ctx context.Context, status string) (*ent.Post,
// Generate a random slug
slug := fmt.Sprintf("post-%s", uuid.New().String()[:8])
return s.client.Post.Create().
// Create post with categories
postCreate := s.client.Post.Create().
SetStatus(postStatus).
SetSlug(slug).
Save(ctx)
SetSlug(slug)
// Add categories if provided
if len(categoryIDs) > 0 {
categories := make([]*ent.Category, 0, len(categoryIDs))
for _, id := range categoryIDs {
category, err := s.client.Category.Get(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to get category %d: %w", id, err)
}
categories = append(categories, category)
}
postCreate.AddCategories(categories...)
}
return postCreate.Save(ctx)
}
func (s *serviceImpl) AddPostContent(ctx context.Context, postID int, langCode, title, content, summary string, metaKeywords, metaDescription string) (*ent.PostContent, error) {
@ -574,7 +639,7 @@ func (s *serviceImpl) GetPostBySlug(ctx context.Context, langCode, slug string)
WithContents(func(q *ent.PostContentQuery) {
q.Where(postcontent.LanguageCodeEQ(languageCode))
}).
WithCategory().
WithCategories().
All(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get posts: %w", err)
@ -590,7 +655,7 @@ func (s *serviceImpl) GetPostBySlug(ctx context.Context, langCode, slug string)
return posts[0], nil
}
func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error) {
func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryIDs []int, limit, offset int) ([]*ent.Post, error) {
var languageCode postcontent.LanguageCode
switch langCode {
case "en":
@ -610,8 +675,8 @@ func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID
Where(post.StatusEQ(post.StatusPublished))
// Add category filter if provided
if categoryID != nil {
query = query.Where(post.HasCategoryWith(category.ID(*categoryID)))
if len(categoryIDs) > 0 {
query = query.Where(post.HasCategoriesWith(category.IDIn(categoryIDs...)))
}
// Get unique post IDs
@ -638,7 +703,7 @@ func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID
}
// If no category filter is applied, only take the latest 5 posts
if categoryID == nil && len(postIDs) > 5 {
if len(categoryIDs) == 0 && len(postIDs) > 5 {
postIDs = postIDs[:5]
}
@ -666,7 +731,7 @@ func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID
WithContents(func(q *ent.PostContentQuery) {
q.Where(postcontent.LanguageCodeEQ(languageCode))
}).
WithCategory().
WithCategories().
Order(ent.Desc(post.FieldCreatedAt)).
All(ctx)
if err != nil {
@ -1050,3 +1115,47 @@ func (s *serviceImpl) DeleteDaily(ctx context.Context, id string, currentUserID
return s.client.Daily.DeleteOneID(id).Exec(ctx)
}
// convertToWebP 将图片转换为 WebP 格式
func convertToWebP(src multipart.File) (multipart.File, error) {
// 读取原始图片
img, err := imaging.Decode(src)
if err != nil {
return nil, fmt.Errorf("failed to decode image: %v", err)
}
// 创建一个新的缓冲区来存储 WebP 图片
buf := new(bytes.Buffer)
// 将图片编码为 WebP 格式
// 设置较高的质量以保持图片质量
err = webp.Encode(buf, img, &webp.Options{
Lossless: false,
Quality: 90,
})
if err != nil {
return nil, fmt.Errorf("failed to encode image to WebP: %v", err)
}
// 创建一个新的临时文件来存储转换后的图片
tmpFile, err := os.CreateTemp("", "webp-*.webp")
if err != nil {
return nil, fmt.Errorf("failed to create temp file: %v", err)
}
// 写入转换后的数据
if _, err := io.Copy(tmpFile, buf); err != nil {
tmpFile.Close()
os.Remove(tmpFile.Name())
return nil, fmt.Errorf("failed to write WebP data: %v", err)
}
// 将文件指针移回开始位置
if _, err := tmpFile.Seek(0, 0); err != nil {
tmpFile.Close()
os.Remove(tmpFile.Name())
return nil, fmt.Errorf("failed to seek file: %v", err)
}
return tmpFile, nil
}

File diff suppressed because it is too large Load diff

View file

@ -1,332 +0,0 @@
package service
import (
"bytes"
"context"
"fmt"
"io"
"mime/multipart"
"net/textproto"
"reflect"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
"tss-rocks-be/ent"
"tss-rocks-be/internal/storage"
"tss-rocks-be/internal/storage/mock"
"tss-rocks-be/internal/testutil"
"bou.ke/monkey"
)
type MediaServiceTestSuite struct {
suite.Suite
ctx context.Context
client *ent.Client
storage *mock.MockStorage
ctrl *gomock.Controller
svc MediaService
}
func (s *MediaServiceTestSuite) SetupTest() {
s.ctx = context.Background()
s.client = testutil.NewTestClient()
require.NotNil(s.T(), s.client)
s.ctrl = gomock.NewController(s.T())
s.storage = mock.NewMockStorage(s.ctrl)
s.svc = NewMediaService(s.client, s.storage)
// 清理数据库
_, err := s.client.Media.Delete().Exec(s.ctx)
require.NoError(s.T(), err)
}
func (s *MediaServiceTestSuite) TearDownTest() {
s.ctrl.Finish()
s.client.Close()
}
func TestMediaServiceSuite(t *testing.T) {
suite.Run(t, new(MediaServiceTestSuite))
}
type mockFileHeader struct {
filename string
contentType string
size int64
content []byte
}
func (h *mockFileHeader) Open() (multipart.File, error) {
return newMockMultipartFile(h.content), nil
}
func (h *mockFileHeader) Filename() string {
return h.filename
}
func (h *mockFileHeader) Size() int64 {
return h.size
}
func (h *mockFileHeader) Header() textproto.MIMEHeader {
header := make(textproto.MIMEHeader)
header.Set("Content-Type", h.contentType)
return header
}
func (s *MediaServiceTestSuite) createTestFile(filename, contentType string, content []byte) *multipart.FileHeader {
header := &multipart.FileHeader{
Filename: filename,
Header: make(textproto.MIMEHeader),
Size: int64(len(content)),
}
header.Header.Set("Content-Type", contentType)
monkey.PatchInstanceMethod(reflect.TypeOf(header), "Open", func(_ *multipart.FileHeader) (multipart.File, error) {
return newMockMultipartFile(content), nil
})
return header
}
func (s *MediaServiceTestSuite) TestUpload() {
testCases := []struct {
name string
filename string
contentType string
content []byte
setupMock func()
wantErr bool
errMsg string
}{
{
name: "Upload text file",
filename: "test.txt",
contentType: "text/plain",
content: []byte("test content"),
setupMock: func() {
s.storage.EXPECT().
Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()).
DoAndReturn(func(ctx context.Context, name, contentType string, reader io.Reader) (*storage.FileInfo, error) {
content, err := io.ReadAll(reader)
s.Require().NoError(err)
s.Equal([]byte("test content"), content)
return &storage.FileInfo{
ID: "test-id",
Name: "test.txt",
ContentType: "text/plain",
Size: int64(len(content)),
}, nil
})
},
wantErr: false,
},
{
name: "Invalid filename",
filename: "../test.txt",
contentType: "text/plain",
content: []byte("test content"),
setupMock: func() {},
wantErr: true,
errMsg: "invalid filename",
},
{
name: "Storage error",
filename: "test.txt",
contentType: "text/plain",
content: []byte("test content"),
setupMock: func() {
s.storage.EXPECT().
Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()).
Return(nil, fmt.Errorf("storage error"))
},
wantErr: true,
errMsg: "storage error",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// Setup mock
tc.setupMock()
// Create test file
fileHeader := s.createTestFile(tc.filename, tc.contentType, tc.content)
// Add debug output
s.T().Logf("Testing file: %s, content-type: %s, size: %d", fileHeader.Filename, fileHeader.Header.Get("Content-Type"), fileHeader.Size)
// Test upload
media, err := s.svc.Upload(s.ctx, fileHeader, 1)
// Add debug output
if err != nil {
s.T().Logf("Upload error: %v", err)
}
if tc.wantErr {
s.Require().Error(err)
s.Contains(err.Error(), tc.errMsg)
return
}
s.Require().NoError(err)
s.NotNil(media)
s.Equal(tc.filename, media.OriginalName)
s.Equal(tc.contentType, media.MimeType)
s.Equal(int64(len(tc.content)), media.Size)
s.Equal("1", media.CreatedBy)
})
}
}
func (s *MediaServiceTestSuite) TestGet() {
// Create test media
media, err := s.client.Media.Create().
SetStorageID("test-id").
SetOriginalName("test.txt").
SetMimeType("text/plain").
SetSize(12).
SetURL("/api/media/test-id").
SetCreatedBy("1").
Save(s.ctx)
s.Require().NoError(err)
// Test get existing media
result, err := s.svc.Get(s.ctx, media.ID)
s.Require().NoError(err)
s.Equal(media.ID, result.ID)
s.Equal(media.OriginalName, result.OriginalName)
// Test get non-existing media
_, err = s.svc.Get(s.ctx, -1)
s.Require().Error(err)
s.Contains(err.Error(), "media not found")
}
func (s *MediaServiceTestSuite) TestDelete() {
// Create test media
media, err := s.client.Media.Create().
SetStorageID("test-id").
SetOriginalName("test.txt").
SetMimeType("text/plain").
SetSize(12).
SetURL("/api/media/test-id").
SetCreatedBy("1").
Save(s.ctx)
s.Require().NoError(err)
// Test delete by unauthorized user
err = s.svc.Delete(s.ctx, media.ID, 2)
s.Require().Error(err)
s.Contains(err.Error(), "unauthorized")
// Test delete by owner
s.storage.EXPECT().
Delete(gomock.Any(), "test-id").
Return(nil)
err = s.svc.Delete(s.ctx, media.ID, 1)
s.Require().NoError(err)
// Verify media is deleted
_, err = s.svc.Get(s.ctx, media.ID)
s.Require().Error(err)
s.Contains(err.Error(), "not found")
}
func (s *MediaServiceTestSuite) TestList() {
// Create test media
for i := 0; i < 5; i++ {
_, err := s.client.Media.Create().
SetStorageID(fmt.Sprintf("test-id-%d", i)).
SetOriginalName(fmt.Sprintf("test-%d.txt", i)).
SetMimeType("text/plain").
SetSize(12).
SetURL(fmt.Sprintf("/api/media/test-id-%d", i)).
SetCreatedBy("1").
Save(s.ctx)
s.Require().NoError(err)
}
// Test list with limit and offset
media, err := s.svc.List(s.ctx, 3, 1)
s.Require().NoError(err)
s.Len(media, 3)
}
func (s *MediaServiceTestSuite) TestGetFile() {
// Create test media
media, err := s.client.Media.Create().
SetStorageID("test-id").
SetOriginalName("test.txt").
SetMimeType("text/plain").
SetSize(12).
SetURL("/api/media/test-id").
SetCreatedBy("1").
Save(s.ctx)
s.Require().NoError(err)
// Mock storage.Get
mockReader := io.NopCloser(bytes.NewReader([]byte("test content")))
mockFileInfo := &storage.FileInfo{
ID: "test-id",
Name: "test.txt",
ContentType: "text/plain",
Size: 12,
}
s.storage.EXPECT().
Get(gomock.Any(), "test-id").
Return(mockReader, mockFileInfo, nil)
// Test get file
reader, info, err := s.svc.GetFile(s.ctx, media.ID)
s.Require().NoError(err)
s.NotNil(reader)
s.Equal(mockFileInfo, info)
// Test get non-existing file
_, _, err = s.svc.GetFile(s.ctx, -1)
s.Require().Error(err)
s.Contains(err.Error(), "not found")
}
func (s *MediaServiceTestSuite) TestIsValidFilename() {
testCases := []struct {
name string
filename string
want bool
}{
{
name: "Valid filename",
filename: "test.txt",
want: true,
},
{
name: "Invalid filename with ../",
filename: "../test.txt",
want: false,
},
{
name: "Invalid filename with ./",
filename: "./test.txt",
want: false,
},
{
name: "Invalid filename with backslash",
filename: "test\\file.txt",
want: false,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
got := isValidFilename(tc.filename)
s.Equal(tc.want, got)
})
}
}

View file

@ -1,3 +0,0 @@
package mock
//go:generate mockgen -source=../service.go -destination=mock_service.go -package=mock

View file

@ -1,7 +1,5 @@
package service
//go:generate mockgen -source=service.go -destination=mock/mock_service.go -package=mock
import (
"context"
"io"
@ -34,16 +32,16 @@ type Service interface {
GetCategories(ctx context.Context, langCode string) ([]*ent.Category, error)
// Post operations
CreatePost(ctx context.Context, status string) (*ent.Post, error)
CreatePost(ctx context.Context, status string, categoryIDs []int) (*ent.Post, error)
AddPostContent(ctx context.Context, postID int, langCode, title, content, summary string, metaKeywords, metaDescription string) (*ent.PostContent, error)
GetPostBySlug(ctx context.Context, langCode, slug string) (*ent.Post, error)
ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error)
ListPosts(ctx context.Context, langCode string, categoryIDs []int, limit, offset int) ([]*ent.Post, error)
// Media operations
ListMedia(ctx context.Context, limit, offset int) ([]*ent.Media, error)
Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error)
GetMedia(ctx context.Context, id int) (*ent.Media, error)
GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error)
GetFile(ctx context.Context, storageID string) (io.ReadCloser, *storage.FileInfo, error)
DeleteMedia(ctx context.Context, id int, userID int) error
// Contributor operations

View file

@ -11,6 +11,8 @@ import (
"path/filepath"
"strings"
"time"
"github.com/rs/zerolog/log"
)
type LocalStorage struct {
@ -44,6 +46,24 @@ func (s *LocalStorage) generateID() (string, error) {
return hex.EncodeToString(bytes), nil
}
func (s *LocalStorage) generateFilePath(id string, ext string, createTime time.Time) string {
// Create year/month directory structure
year := createTime.Format("2006")
month := createTime.Format("01")
// If id already has an extension, don't add ext
if filepath.Ext(id) != "" {
return filepath.Join(year, month, id)
}
// Otherwise, add the extension if provided
filename := id
if ext != "" {
filename = id + ext
}
return filepath.Join(year, month, filename)
}
func (s *LocalStorage) saveMetadata(id string, info *FileInfo) error {
metaPath := filepath.Join(s.metaDir, id+".meta")
file, err := os.Create(metaPath)
@ -89,11 +109,64 @@ func (s *LocalStorage) Save(ctx context.Context, name string, contentType string
return nil, fmt.Errorf("failed to generate file ID: %w", err)
}
// Create the file path
filePath := filepath.Join(s.rootDir, id)
// Get file extension from original name or content type
ext := filepath.Ext(name)
if ext == "" {
// If no extension in name, try to get it from content type
switch contentType {
case "image/jpeg":
ext = ".jpg"
case "image/png":
ext = ".png"
case "image/gif":
ext = ".gif"
case "image/webp":
ext = ".webp"
case "image/svg+xml":
ext = ".svg"
case "video/mp4":
ext = ".mp4"
case "video/webm":
ext = ".webm"
case "audio/mpeg":
ext = ".mp3"
case "audio/ogg":
ext = ".ogg"
case "audio/wav":
ext = ".wav"
case "application/pdf":
ext = ".pdf"
case "application/msword":
ext = ".doc"
case "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
ext = ".docx"
case "application/vnd.ms-excel":
ext = ".xls"
case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
ext = ".xlsx"
case "application/zip":
ext = ".zip"
case "application/x-rar-compressed":
ext = ".rar"
case "text/plain":
ext = ".txt"
case "text/csv":
ext = ".csv"
}
}
// Create the file path with year/month structure
now := time.Now()
relPath := s.generateFilePath(id, ext, now)
fullPath := filepath.Join(s.rootDir, relPath)
// Create directory if it doesn't exist
if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil {
return nil, fmt.Errorf("failed to create directory: %w", err)
}
// Create the file
file, err := os.Create(filePath)
file, err := os.Create(fullPath)
if err != nil {
return nil, fmt.Errorf("failed to create file: %w", err)
}
@ -102,52 +175,73 @@ func (s *LocalStorage) Save(ctx context.Context, name string, contentType string
// Copy the content
size, err := io.Copy(file, reader)
if err != nil {
// Clean up the file if there's an error
os.Remove(filePath)
os.Remove(fullPath) // Clean up on error
return nil, fmt.Errorf("failed to write file content: %w", err)
}
now := time.Now()
// Save metadata
info := &FileInfo{
ID: id,
Name: name,
Size: size,
ContentType: contentType,
Size: size,
CreatedAt: now,
UpdatedAt: now,
URL: fmt.Sprintf("/api/media/file/%s", id),
URL: fmt.Sprintf("/media/%s/%s/%s", now.Format("2006"), now.Format("01"), filepath.Base(relPath)),
}
// Save metadata
if err := s.saveMetadata(id, info); err != nil {
os.Remove(filePath)
return nil, err
os.Remove(fullPath) // Clean up on error
return nil, fmt.Errorf("failed to save metadata: %w", err)
}
return info, nil
}
func (s *LocalStorage) Get(ctx context.Context, id string) (io.ReadCloser, *FileInfo, error) {
filePath := filepath.Join(s.rootDir, id)
// 从 id 中提取文件扩展名和基础 ID
ext := filepath.Ext(id)
baseID := strings.TrimSuffix(id, ext)
// Open the file
// 获取文件的创建时间(从元数据或当前时间)
metaPath := filepath.Join(s.metaDir, baseID+".meta")
var createTime time.Time
if stat, err := os.Stat(metaPath); err == nil {
createTime = stat.ModTime()
} else {
createTime = time.Now() // 如果找不到元数据,使用当前时间
}
// 生成完整的文件路径
year := createTime.Format("2006")
month := createTime.Format("01")
filePath := filepath.Join(s.rootDir, year, month, id) // 直接使用完整的 id包含扩展名
// 调试日志
log.Debug().
Str("id", id).
Str("baseID", baseID).
Str("ext", ext).
Str("filePath", filePath).
Time("createTime", createTime).
Msg("Attempting to get file")
// 打开文件
file, err := os.Open(filePath)
if err != nil {
if os.IsNotExist(err) {
return nil, nil, fmt.Errorf("file not found: %s", id)
return nil, nil, fmt.Errorf("file not found: %s (path: %s)", id, filePath)
}
return nil, nil, fmt.Errorf("failed to open file: %w", err)
}
// Get file info
// 获取文件信息
stat, err := file.Stat()
if err != nil {
file.Close()
return nil, nil, fmt.Errorf("failed to get file info: %w", err)
}
// Load metadata
name, contentType, err := s.loadMetadata(id)
// 加载元数据
name, contentType, err := s.loadMetadata(baseID)
if err != nil {
file.Close()
return nil, nil, err
@ -158,27 +252,44 @@ func (s *LocalStorage) Get(ctx context.Context, id string) (io.ReadCloser, *File
Name: name,
Size: stat.Size(),
ContentType: contentType,
CreatedAt: stat.ModTime(),
CreatedAt: createTime,
UpdatedAt: stat.ModTime(),
URL: fmt.Sprintf("/api/media/file/%s", id),
URL: fmt.Sprintf("/media/%s/%s/%s", year, month, id),
}
return file, info, nil
}
func (s *LocalStorage) Delete(ctx context.Context, id string) error {
filePath := filepath.Join(s.rootDir, id)
if err := os.Remove(filePath); err != nil {
if os.IsNotExist(err) {
return fmt.Errorf("file not found: %s", id)
}
return fmt.Errorf("failed to delete file: %w", err)
// 从 id 中提取文件扩展名
ext := filepath.Ext(id)
baseID := strings.TrimSuffix(id, ext)
// 获取文件的创建时间(从元数据或当前时间)
metaPath := filepath.Join(s.metaDir, baseID+".meta")
var createTime time.Time
if stat, err := os.Stat(metaPath); err == nil {
createTime = stat.ModTime()
} else {
createTime = time.Now() // 如果找不到元数据,使用当前时间
}
// Remove metadata
metaPath := filepath.Join(s.metaDir, id+".meta")
if err := os.Remove(metaPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove metadata: %w", err)
// 生成完整的文件路径
relPath := s.generateFilePath(baseID, ext, createTime)
filePath := filepath.Join(s.rootDir, relPath)
// 删除文件
if err := os.Remove(filePath); err != nil {
if !os.IsNotExist(err) {
return fmt.Errorf("failed to delete file: %w", err)
}
}
// 删除元数据
if err := os.Remove(metaPath); err != nil {
if !os.IsNotExist(err) {
return fmt.Errorf("failed to delete metadata: %w", err)
}
}
return nil

View file

@ -1,154 +0,0 @@
package storage
import (
"bytes"
"context"
"io"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLocalStorage(t *testing.T) {
// Create a temporary directory for testing
tempDir, err := os.MkdirTemp("", "storage_test_*")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
// Create a new LocalStorage instance
storage, err := NewLocalStorage(tempDir)
require.NoError(t, err)
ctx := context.Background()
t.Run("Save and Get", func(t *testing.T) {
content := []byte("test content")
reader := bytes.NewReader(content)
// Save the file
fileInfo, err := storage.Save(ctx, "test.txt", "text/plain", reader)
require.NoError(t, err)
assert.NotEmpty(t, fileInfo.ID)
assert.Equal(t, "test.txt", fileInfo.Name)
assert.Equal(t, int64(len(content)), fileInfo.Size)
assert.Equal(t, "text/plain", fileInfo.ContentType)
assert.False(t, fileInfo.CreatedAt.IsZero())
// Get the file
readCloser, info, err := storage.Get(ctx, fileInfo.ID)
require.NoError(t, err)
defer readCloser.Close()
data, err := io.ReadAll(readCloser)
require.NoError(t, err)
assert.Equal(t, content, data)
assert.Equal(t, fileInfo.ID, info.ID)
assert.Equal(t, fileInfo.Name, info.Name)
assert.Equal(t, fileInfo.Size, info.Size)
})
t.Run("List", func(t *testing.T) {
// Clear the directory first
dirEntries, err := os.ReadDir(tempDir)
require.NoError(t, err)
for _, entry := range dirEntries {
if entry.Name() != ".meta" {
os.Remove(filepath.Join(tempDir, entry.Name()))
}
}
// Save multiple files
testFiles := []struct {
name string
content string
}{
{"test1.txt", "content1"},
{"test2.txt", "content2"},
{"other.txt", "content3"},
}
for _, f := range testFiles {
reader := bytes.NewReader([]byte(f.content))
_, err := storage.Save(ctx, f.name, "text/plain", reader)
require.NoError(t, err)
}
// List all files
allFiles, err := storage.List(ctx, "", 10, 0)
require.NoError(t, err)
assert.Len(t, allFiles, 3)
// List files with prefix
filesWithPrefix, err := storage.List(ctx, "test", 10, 0)
require.NoError(t, err)
assert.Len(t, filesWithPrefix, 2)
for _, f := range filesWithPrefix {
assert.True(t, strings.HasPrefix(f.Name, "test"))
}
// Test pagination
pagedFiles, err := storage.List(ctx, "", 2, 1)
require.NoError(t, err)
assert.Len(t, pagedFiles, 2)
})
t.Run("Exists", func(t *testing.T) {
// Save a file
content := []byte("test content")
reader := bytes.NewReader(content)
fileInfo, err := storage.Save(ctx, "exists.txt", "text/plain", reader)
require.NoError(t, err)
// Check if file exists
exists, err := storage.Exists(ctx, fileInfo.ID)
require.NoError(t, err)
assert.True(t, exists)
// Check non-existent file
exists, err = storage.Exists(ctx, "non-existent")
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("Delete", func(t *testing.T) {
// Save a file
content := []byte("test content")
reader := bytes.NewReader(content)
fileInfo, err := storage.Save(ctx, "delete.txt", "text/plain", reader)
require.NoError(t, err)
// Delete the file
err = storage.Delete(ctx, fileInfo.ID)
require.NoError(t, err)
// Verify file is deleted
exists, err := storage.Exists(ctx, fileInfo.ID)
require.NoError(t, err)
assert.False(t, exists)
// Try to delete non-existent file
err = storage.Delete(ctx, "non-existent")
assert.Error(t, err)
})
t.Run("Invalid operations", func(t *testing.T) {
// Try to get non-existent file
_, _, err := storage.Get(ctx, "non-existent")
assert.Error(t, err)
assert.Contains(t, err.Error(), "file not found")
// Try to save file with nil reader
_, err = storage.Save(ctx, "test.txt", "text/plain", nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "reader cannot be nil")
// Try to delete non-existent file
err = storage.Delete(ctx, "non-existent")
assert.Error(t, err)
assert.Contains(t, err.Error(), "file not found")
})
}

View file

@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"path/filepath"
"strings"
"time"
@ -48,14 +49,25 @@ func (s *S3Storage) generateID() (string, error) {
return hex.EncodeToString(bytes), nil
}
func (s *S3Storage) getObjectURL(id string) string {
func (s *S3Storage) generateObjectKey(id string, ext string, createTime time.Time) string {
// Create year/month structure
year := createTime.Format("2006")
month := createTime.Format("01")
filename := id
if ext != "" {
filename = id + ext
}
return fmt.Sprintf("%s/%s/%s", year, month, filename)
}
func (s *S3Storage) getObjectURL(key string) string {
if s.customURL != "" {
return fmt.Sprintf("%s/%s", strings.TrimRight(s.customURL, "/"), id)
return fmt.Sprintf("%s/%s", strings.TrimRight(s.customURL, "/"), key)
}
if s.proxyS3 {
return fmt.Sprintf("/api/media/file/%s", id)
return fmt.Sprintf("/media/%s", key)
}
return fmt.Sprintf("https://%s.s3.amazonaws.com/%s", s.bucket, id)
return fmt.Sprintf("https://%s.s3.amazonaws.com/%s", s.bucket, key)
}
func (s *S3Storage) Save(ctx context.Context, name string, contentType string, reader io.Reader) (*FileInfo, error) {
@ -65,10 +77,60 @@ func (s *S3Storage) Save(ctx context.Context, name string, contentType string, r
return nil, fmt.Errorf("failed to generate file ID: %w", err)
}
// Get file extension from original name or content type
ext := filepath.Ext(name)
if ext == "" {
// If no extension in name, try to get it from content type
switch contentType {
case "image/jpeg":
ext = ".jpg"
case "image/png":
ext = ".png"
case "image/gif":
ext = ".gif"
case "image/webp":
ext = ".webp"
case "image/svg+xml":
ext = ".svg"
case "video/mp4":
ext = ".mp4"
case "video/webm":
ext = ".webm"
case "audio/mpeg":
ext = ".mp3"
case "audio/ogg":
ext = ".ogg"
case "audio/wav":
ext = ".wav"
case "application/pdf":
ext = ".pdf"
case "application/msword":
ext = ".doc"
case "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
ext = ".docx"
case "application/vnd.ms-excel":
ext = ".xls"
case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
ext = ".xlsx"
case "application/zip":
ext = ".zip"
case "application/x-rar-compressed":
ext = ".rar"
case "text/plain":
ext = ".txt"
case "text/csv":
ext = ".csv"
}
}
// Create the object key with year/month structure
now := time.Now()
key := s.generateObjectKey(id, ext, now)
// Check if the file exists
_, err = s.client.HeadObject(ctx, &s3.HeadObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(id),
Key: aws.String(key),
})
if err == nil {
return nil, fmt.Errorf("file already exists with ID: %s", id)
@ -82,37 +144,50 @@ func (s *S3Storage) Save(ctx context.Context, name string, contentType string, r
// Upload the file
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(id),
Key: aws.String(key),
Body: reader,
ContentType: aws.String(contentType),
Metadata: map[string]string{
"x-amz-meta-original-name": name,
"x-amz-meta-created-at": now.Format(time.RFC3339),
},
})
if err != nil {
return nil, fmt.Errorf("failed to upload file: %w", err)
}
now := time.Now()
info := &FileInfo{
ID: id,
Name: name,
Size: 0, // Size is not available until after upload
ContentType: contentType,
CreatedAt: now,
UpdatedAt: now,
URL: s.getObjectURL(id),
URL: s.getObjectURL(key),
}
return info, nil
}
func (s *S3Storage) Get(ctx context.Context, id string) (io.ReadCloser, *FileInfo, error) {
// Get the object from S3
result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(id),
})
// Try to find the file with different extensions
exts := []string{"", ".jpg", ".png", ".gif", ".webp", ".svg", ".mp4", ".webm", ".mp3", ".ogg", ".wav",
".pdf", ".doc", ".docx", ".xls", ".xlsx", ".zip", ".rar", ".txt", ".csv"}
var result *s3.GetObjectOutput
var err error
var key string
for _, ext := range exts {
key = s.generateObjectKey(id, ext, time.Now())
result, err = s.client.GetObject(ctx, &s3.GetObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(key),
})
if err == nil {
break
}
}
if err != nil {
return nil, nil, fmt.Errorf("failed to get file from S3: %w", err)
}
@ -122,111 +197,117 @@ func (s *S3Storage) Get(ctx context.Context, id string) (io.ReadCloser, *FileInf
Name: result.Metadata["x-amz-meta-original-name"],
Size: aws.ToInt64(result.ContentLength),
ContentType: aws.ToString(result.ContentType),
CreatedAt: aws.ToTime(result.LastModified),
UpdatedAt: aws.ToTime(result.LastModified),
URL: s.getObjectURL(id),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
URL: s.getObjectURL(key),
}
return result.Body, info, nil
}
func (s *S3Storage) Delete(ctx context.Context, id string) error {
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(id),
})
if err != nil {
return fmt.Errorf("failed to delete file from S3: %w", err)
// Try to find and delete the file with different extensions
exts := []string{"", ".jpg", ".png", ".gif", ".webp", ".svg", ".mp4", ".webm", ".mp3", ".ogg", ".wav",
".pdf", ".doc", ".docx", ".xls", ".xlsx", ".zip", ".rar", ".txt", ".csv"}
var lastErr error
for _, ext := range exts {
key := s.generateObjectKey(id, ext, time.Now())
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(key),
})
if err == nil {
return nil
}
lastErr = err
}
return nil
return fmt.Errorf("failed to delete file from S3: %w", lastErr)
}
func (s *S3Storage) Exists(ctx context.Context, id string) (bool, error) {
// Try to find the file with different extensions
exts := []string{"", ".jpg", ".png", ".gif", ".webp", ".svg", ".mp4", ".webm", ".mp3", ".ogg", ".wav",
".pdf", ".doc", ".docx", ".xls", ".xlsx", ".zip", ".rar", ".txt", ".csv"}
for _, ext := range exts {
key := s.generateObjectKey(id, ext, time.Now())
_, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(key),
})
if err == nil {
return true, nil
}
}
return false, nil
}
func (s *S3Storage) List(ctx context.Context, prefix string, limit int, offset int) ([]*FileInfo, error) {
var files []*FileInfo
var continuationToken *string
count := 0
skip := offset
// Skip objects for offset
for i := 0; i < offset/1000; i++ {
output, err := s.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
for {
input := &s3.ListObjectsV2Input{
Bucket: aws.String(s.bucket),
Prefix: aws.String(prefix),
ContinuationToken: continuationToken,
MaxKeys: aws.Int32(1000),
})
MaxKeys: aws.Int32(100), // Fetch in batches of 100
}
result, err := s.client.ListObjectsV2(ctx, input)
if err != nil {
return nil, fmt.Errorf("failed to list files from S3: %w", err)
}
if !aws.ToBool(output.IsTruncated) {
return files, nil
}
continuationToken = output.NextContinuationToken
}
// Get the actual objects
output, err := s.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
Bucket: aws.String(s.bucket),
Prefix: aws.String(prefix),
ContinuationToken: continuationToken,
MaxKeys: aws.Int32(int32(limit)),
})
if err != nil {
return nil, fmt.Errorf("failed to list files from S3: %w", err)
}
for _, obj := range output.Contents {
// Get the object metadata
head, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
Bucket: aws.String(s.bucket),
Key: obj.Key,
})
var contentType string
var originalName string
if err != nil {
var noSuchKey *types.NoSuchKey
if errors.As(err, &noSuchKey) {
// If the object doesn't exist (which shouldn't happen normally),
// we'll still include it in the list but with empty metadata
contentType = ""
originalName = aws.ToString(obj.Key)
} else {
// Process each object
for _, obj := range result.Contents {
if skip > 0 {
skip--
continue
}
} else {
contentType = aws.ToString(head.ContentType)
originalName = head.Metadata["x-amz-meta-original-name"]
if originalName == "" {
originalName = aws.ToString(obj.Key)
// Get object metadata
head, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
Bucket: aws.String(s.bucket),
Key: obj.Key,
})
if err != nil {
continue // Skip files we can't get metadata for
}
// Extract the ID from the key (remove extension if present)
id := aws.ToString(obj.Key)
if ext := filepath.Ext(id); ext != "" {
id = id[:len(id)-len(ext)]
}
info := &FileInfo{
ID: id,
Name: head.Metadata["x-amz-meta-original-name"],
Size: aws.ToInt64(obj.Size),
ContentType: aws.ToString(head.ContentType),
CreatedAt: aws.ToTime(obj.LastModified),
UpdatedAt: aws.ToTime(obj.LastModified),
URL: s.getObjectURL(aws.ToString(obj.Key)),
}
files = append(files, info)
count++
if count >= limit {
return files, nil
}
}
files = append(files, &FileInfo{
ID: aws.ToString(obj.Key),
Name: originalName,
Size: aws.ToInt64(obj.Size),
ContentType: contentType,
CreatedAt: aws.ToTime(obj.LastModified),
UpdatedAt: aws.ToTime(obj.LastModified),
URL: s.getObjectURL(aws.ToString(obj.Key)),
})
if !aws.ToBool(result.IsTruncated) {
break
}
continuationToken = result.NextContinuationToken
}
return files, nil
}
func (s *S3Storage) Exists(ctx context.Context, id string) (bool, error) {
_, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(id),
})
if err != nil {
var nsk *types.NoSuchKey
if ok := errors.As(err, &nsk); ok {
return false, nil
}
return false, fmt.Errorf("failed to check file existence in S3: %w", err)
}
return true, nil
}

View file

@ -1,211 +0,0 @@
package storage
import (
"bytes"
"context"
"io"
"testing"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// MockS3Client is a mock implementation of the S3 client interface
type MockS3Client struct {
mock.Mock
}
func (m *MockS3Client) PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) {
args := m.Called(ctx, params)
return args.Get(0).(*s3.PutObjectOutput), args.Error(1)
}
func (m *MockS3Client) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
args := m.Called(ctx, params)
return args.Get(0).(*s3.GetObjectOutput), args.Error(1)
}
func (m *MockS3Client) DeleteObject(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error) {
args := m.Called(ctx, params)
return args.Get(0).(*s3.DeleteObjectOutput), args.Error(1)
}
func (m *MockS3Client) ListObjectsV2(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) {
args := m.Called(ctx, params)
return args.Get(0).(*s3.ListObjectsV2Output), args.Error(1)
}
func (m *MockS3Client) HeadObject(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) {
args := m.Called(ctx, params)
return args.Get(0).(*s3.HeadObjectOutput), args.Error(1)
}
func TestS3Storage(t *testing.T) {
ctx := context.Background()
mockClient := new(MockS3Client)
storage := NewS3Storage(mockClient, "test-bucket", "", false)
t.Run("Save", func(t *testing.T) {
mockClient.ExpectedCalls = nil
mockClient.Calls = nil
content := []byte("test content")
reader := bytes.NewReader(content)
// Mock HeadObject to return NotFound error
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket"
})).Return(&s3.HeadObjectOutput{}, &types.NoSuchKey{
Message: aws.String("The specified key does not exist."),
})
mockClient.On("PutObject", ctx, mock.MatchedBy(func(input *s3.PutObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.ContentType) == "text/plain"
})).Return(&s3.PutObjectOutput{}, nil)
fileInfo, err := storage.Save(ctx, "test.txt", "text/plain", reader)
require.NoError(t, err)
assert.NotEmpty(t, fileInfo.ID)
assert.Equal(t, "test.txt", fileInfo.Name)
assert.Equal(t, "text/plain", fileInfo.ContentType)
mockClient.AssertExpectations(t)
})
t.Run("Get", func(t *testing.T) {
content := []byte("test content")
mockClient.On("GetObject", ctx, mock.MatchedBy(func(input *s3.GetObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.Key) == "test-id"
})).Return(&s3.GetObjectOutput{
Body: io.NopCloser(bytes.NewReader(content)),
ContentType: aws.String("text/plain"),
ContentLength: aws.Int64(int64(len(content))),
LastModified: aws.Time(time.Now()),
}, nil)
readCloser, info, err := storage.Get(ctx, "test-id")
require.NoError(t, err)
defer readCloser.Close()
data, err := io.ReadAll(readCloser)
require.NoError(t, err)
assert.Equal(t, content, data)
assert.Equal(t, "test-id", info.ID)
assert.Equal(t, int64(len(content)), info.Size)
mockClient.AssertExpectations(t)
})
t.Run("List", func(t *testing.T) {
mockClient.ExpectedCalls = nil
mockClient.Calls = nil
mockClient.On("ListObjectsV2", ctx, mock.MatchedBy(func(input *s3.ListObjectsV2Input) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.Prefix) == "test" &&
aws.ToInt32(input.MaxKeys) == 10
})).Return(&s3.ListObjectsV2Output{
Contents: []types.Object{
{
Key: aws.String("test1"),
Size: aws.Int64(100),
LastModified: aws.Time(time.Now()),
},
{
Key: aws.String("test2"),
Size: aws.Int64(200),
LastModified: aws.Time(time.Now()),
},
},
}, nil)
// Mock HeadObject for both files
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.Key) == "test1"
})).Return(&s3.HeadObjectOutput{
ContentType: aws.String("text/plain"),
Metadata: map[string]string{
"x-amz-meta-original-name": "test1.txt",
},
}, nil).Once()
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.Key) == "test2"
})).Return(&s3.HeadObjectOutput{
ContentType: aws.String("text/plain"),
Metadata: map[string]string{
"x-amz-meta-original-name": "test2.txt",
},
}, nil).Once()
files, err := storage.List(ctx, "test", 10, 0)
require.NoError(t, err)
assert.Len(t, files, 2)
assert.Equal(t, "test1", files[0].ID)
assert.Equal(t, int64(100), files[0].Size)
assert.Equal(t, "test1.txt", files[0].Name)
assert.Equal(t, "text/plain", files[0].ContentType)
mockClient.AssertExpectations(t)
})
t.Run("Delete", func(t *testing.T) {
mockClient.On("DeleteObject", ctx, mock.MatchedBy(func(input *s3.DeleteObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.Key) == "test-id"
})).Return(&s3.DeleteObjectOutput{}, nil)
err := storage.Delete(ctx, "test-id")
require.NoError(t, err)
mockClient.AssertExpectations(t)
})
t.Run("Exists", func(t *testing.T) {
mockClient.ExpectedCalls = nil
mockClient.Calls = nil
// Mock HeadObject for existing file
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.Key) == "test-id"
})).Return(&s3.HeadObjectOutput{}, nil).Once()
exists, err := storage.Exists(ctx, "test-id")
require.NoError(t, err)
assert.True(t, exists)
// Mock HeadObject for non-existing file
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.Key) == "non-existent"
})).Return(&s3.HeadObjectOutput{}, &types.NoSuchKey{
Message: aws.String("The specified key does not exist."),
}).Once()
exists, err = storage.Exists(ctx, "non-existent")
require.NoError(t, err)
assert.False(t, exists)
mockClient.AssertExpectations(t)
})
t.Run("Custom URL", func(t *testing.T) {
customStorage := &S3Storage{
client: mockClient,
bucket: "test-bucket",
customURL: "https://custom.domain",
proxyS3: true,
}
assert.Contains(t, customStorage.getObjectURL("test-id"), "https://custom.domain")
})
}

View file

@ -1,7 +1,5 @@
package storage
//go:generate mockgen -source=storage.go -destination=mock/mock_storage.go -package=mock
import (
"context"
"io"

View file

@ -1,116 +0,0 @@
package types
import (
"testing"
)
func TestRateLimitConfig(t *testing.T) {
config := RateLimitConfig{
IPRate: 100,
IPBurst: 200,
RouteRates: map[string]struct {
Rate int `yaml:"rate"`
Burst int `yaml:"burst"`
}{
"/api/test": {
Rate: 50,
Burst: 100,
},
},
}
if config.IPRate != 100 {
t.Errorf("Expected IPRate 100, got %d", config.IPRate)
}
if config.IPBurst != 200 {
t.Errorf("Expected IPBurst 200, got %d", config.IPBurst)
}
route := config.RouteRates["/api/test"]
if route.Rate != 50 {
t.Errorf("Expected route rate 50, got %d", route.Rate)
}
if route.Burst != 100 {
t.Errorf("Expected route burst 100, got %d", route.Burst)
}
}
func TestAccessLogConfig(t *testing.T) {
config := AccessLogConfig{
EnableConsole: true,
EnableFile: true,
FilePath: "/var/log/app.log",
Format: "json",
Level: "info",
Rotation: struct {
MaxSize int `yaml:"max_size"`
MaxAge int `yaml:"max_age"`
MaxBackups int `yaml:"max_backups"`
Compress bool `yaml:"compress"`
LocalTime bool `yaml:"local_time"`
}{
MaxSize: 100,
MaxAge: 7,
MaxBackups: 5,
Compress: true,
LocalTime: true,
},
}
if !config.EnableConsole {
t.Error("Expected EnableConsole to be true")
}
if !config.EnableFile {
t.Error("Expected EnableFile to be true")
}
if config.FilePath != "/var/log/app.log" {
t.Errorf("Expected FilePath '/var/log/app.log', got '%s'", config.FilePath)
}
if config.Format != "json" {
t.Errorf("Expected Format 'json', got '%s'", config.Format)
}
if config.Level != "info" {
t.Errorf("Expected Level 'info', got '%s'", config.Level)
}
rotation := config.Rotation
if rotation.MaxSize != 100 {
t.Errorf("Expected MaxSize 100, got %d", rotation.MaxSize)
}
if rotation.MaxAge != 7 {
t.Errorf("Expected MaxAge 7, got %d", rotation.MaxAge)
}
if rotation.MaxBackups != 5 {
t.Errorf("Expected MaxBackups 5, got %d", rotation.MaxBackups)
}
if !rotation.Compress {
t.Error("Expected Compress to be true")
}
if !rotation.LocalTime {
t.Error("Expected LocalTime to be true")
}
}
func TestUploadConfig(t *testing.T) {
config := UploadConfig{
MaxSize: 10,
AllowedTypes: []string{"image/jpeg", "image/png"},
AllowedExtensions: []string{".jpg", ".png"},
}
if config.MaxSize != 10 {
t.Errorf("Expected MaxSize 10, got %d", config.MaxSize)
}
if len(config.AllowedTypes) != 2 {
t.Errorf("Expected 2 AllowedTypes, got %d", len(config.AllowedTypes))
}
if config.AllowedTypes[0] != "image/jpeg" {
t.Errorf("Expected AllowedTypes[0] 'image/jpeg', got '%s'", config.AllowedTypes[0])
}
if len(config.AllowedExtensions) != 2 {
t.Errorf("Expected 2 AllowedExtensions, got %d", len(config.AllowedExtensions))
}
if config.AllowedExtensions[0] != ".jpg" {
t.Errorf("Expected AllowedExtensions[0] '.jpg', got '%s'", config.AllowedExtensions[0])
}
}

View file

@ -1,21 +0,0 @@
package types
import "testing"
func TestFileInfo(t *testing.T) {
fileInfo := FileInfo{
Size: 1024,
Name: "test.jpg",
ContentType: "image/jpeg",
}
if fileInfo.Size != 1024 {
t.Errorf("Expected Size 1024, got %d", fileInfo.Size)
}
if fileInfo.Name != "test.jpg" {
t.Errorf("Expected Name 'test.jpg', got '%s'", fileInfo.Name)
}
if fileInfo.ContentType != "image/jpeg" {
t.Errorf("Expected ContentType 'image/jpeg', got '%s'", fileInfo.ContentType)
}
}

View file

@ -1,77 +0,0 @@
package types
import (
"testing"
)
func TestCategory(t *testing.T) {
description := "Test Description"
category := Category{
ID: 1,
Name: "Test Category",
Slug: "test-category",
Description: &description,
}
if category.ID != 1 {
t.Errorf("Expected ID 1, got %d", category.ID)
}
if category.Name != "Test Category" {
t.Errorf("Expected name 'Test Category', got '%s'", category.Name)
}
if category.Slug != "test-category" {
t.Errorf("Expected slug 'test-category', got '%s'", category.Slug)
}
if *category.Description != description {
t.Errorf("Expected description '%s', got '%s'", description, *category.Description)
}
}
func TestPost(t *testing.T) {
metaKeywords := "test,blog"
metaDesc := "Test Description"
post := Post{
ID: 1,
Title: "Test Post",
Slug: "test-post",
ContentMarkdown: "# Test Content",
Summary: "Test Summary",
MetaKeywords: &metaKeywords,
MetaDescription: &metaDesc,
}
if post.ID != 1 {
t.Errorf("Expected ID 1, got %d", post.ID)
}
if post.Title != "Test Post" {
t.Errorf("Expected title 'Test Post', got '%s'", post.Title)
}
if post.Slug != "test-post" {
t.Errorf("Expected slug 'test-post', got '%s'", post.Slug)
}
if *post.MetaKeywords != metaKeywords {
t.Errorf("Expected meta keywords '%s', got '%s'", metaKeywords, *post.MetaKeywords)
}
}
func TestDaily(t *testing.T) {
daily := Daily{
ID: "2025-02-12",
CategoryID: 1,
ImageURL: "https://example.com/image.jpg",
Quote: "Test Quote",
}
if daily.ID != "2025-02-12" {
t.Errorf("Expected ID '2025-02-12', got '%s'", daily.ID)
}
if daily.CategoryID != 1 {
t.Errorf("Expected CategoryID 1, got %d", daily.CategoryID)
}
if daily.ImageURL != "https://example.com/image.jpg" {
t.Errorf("Expected ImageURL 'https://example.com/image.jpg', got '%s'", daily.ImageURL)
}
if daily.Quote != "Test Quote" {
t.Errorf("Expected Quote 'Test Quote', got '%s'", daily.Quote)
}
}

View file

@ -1,77 +0,0 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoad(t *testing.T) {
// Create a temporary test config file
testConfig := `
database:
driver: postgres
dsn: postgres://user:pass@localhost:5432/db
server:
port: 8080
host: localhost
jwt:
secret: test-secret
expiration: 24h
logging:
level: debug
format: console
`
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(configPath, []byte(testConfig), 0644); err != nil {
t.Fatalf("Failed to create test config file: %v", err)
}
// Test successful config loading
cfg, err := Load(configPath)
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
// Verify loaded values
tests := []struct {
name string
got interface{}
expected interface{}
}{
{"database.driver", cfg.Database.Driver, "postgres"},
{"database.dsn", cfg.Database.DSN, "postgres://user:pass@localhost:5432/db"},
{"server.port", cfg.Server.Port, 8080},
{"server.host", cfg.Server.Host, "localhost"},
{"jwt.secret", cfg.JWT.Secret, "test-secret"},
{"jwt.expiration", cfg.JWT.Expiration, "24h"},
{"logging.level", cfg.Logging.Level, "debug"},
{"logging.format", cfg.Logging.Format, "console"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.got != tt.expected {
t.Errorf("Config %s = %v, want %v", tt.name, tt.got, tt.expected)
}
})
}
// Test loading non-existent file
_, err = Load("non-existent.yaml")
if err == nil {
t.Error("Expected error when loading non-existent file, got nil")
}
// Test loading invalid YAML
invalidPath := filepath.Join(tmpDir, "invalid.yaml")
if err := os.WriteFile(invalidPath, []byte("invalid: yaml: content"), 0644); err != nil {
t.Fatalf("Failed to create invalid config file: %v", err)
}
_, err = Load(invalidPath)
if err == nil {
t.Error("Expected error when loading invalid YAML, got nil")
}
}

View file

@ -1,100 +0,0 @@
package imageutil
import (
"bytes"
"image"
"image/color"
"image/png"
"testing"
)
func TestIsImageFormat(t *testing.T) {
tests := []struct {
name string
contentType string
want bool
}{
{"JPEG", "image/jpeg", true},
{"PNG", "image/png", true},
{"GIF", "image/gif", true},
{"WebP", "image/webp", true},
{"Invalid", "image/invalid", false},
{"Empty", "", false},
{"Text", "text/plain", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsImageFormat(tt.contentType); got != tt.want {
t.Errorf("IsImageFormat(%q) = %v, want %v", tt.contentType, got, tt.want)
}
})
}
}
func TestDefaultOptions(t *testing.T) {
opts := DefaultOptions()
if !opts.Lossless {
t.Error("DefaultOptions().Lossless = false, want true")
}
if opts.Quality != 90 {
t.Errorf("DefaultOptions().Quality = %v, want 90", opts.Quality)
}
if opts.Compression != 4 {
t.Errorf("DefaultOptions().Compression = %v, want 4", opts.Compression)
}
}
func TestProcessImage(t *testing.T) {
// Create a test image
img := image.NewRGBA(image.Rect(0, 0, 100, 100))
for y := 0; y < 100; y++ {
for x := 0; x < 100; x++ {
img.Set(x, y, color.RGBA{R: 255, G: 0, B: 0, A: 255})
}
}
var buf bytes.Buffer
if err := png.Encode(&buf, img); err != nil {
t.Fatalf("Failed to create test PNG: %v", err)
}
tests := []struct {
name string
opts ProcessOptions
wantErr bool
}{
{
name: "Default options",
opts: DefaultOptions(),
},
{
name: "Custom quality",
opts: ProcessOptions{
Lossless: false,
Quality: 75,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := bytes.NewReader(buf.Bytes())
result, err := ProcessImage(reader, tt.opts)
if (err != nil) != tt.wantErr {
t.Errorf("ProcessImage() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && len(result) == 0 {
t.Error("ProcessImage() returned empty result")
}
})
}
// Test with invalid input
_, err := ProcessImage(bytes.NewReader([]byte("invalid image data")), DefaultOptions())
if err == nil {
t.Error("ProcessImage() with invalid input should return error")
}
}

View file

@ -1,85 +0,0 @@
package logger
import (
"testing"
"tss-rocks-be/internal/config"
"github.com/rs/zerolog"
)
func TestSetup(t *testing.T) {
tests := []struct {
name string
config *config.Config
expectedLevel zerolog.Level
}{
{
name: "Debug level",
config: &config.Config{
Logging: struct {
Level string `yaml:"level"`
Format string `yaml:"format"`
}{
Level: "debug",
Format: "json",
},
},
expectedLevel: zerolog.DebugLevel,
},
{
name: "Info level",
config: &config.Config{
Logging: struct {
Level string `yaml:"level"`
Format string `yaml:"format"`
}{
Level: "info",
Format: "json",
},
},
expectedLevel: zerolog.InfoLevel,
},
{
name: "Error level",
config: &config.Config{
Logging: struct {
Level string `yaml:"level"`
Format string `yaml:"format"`
}{
Level: "error",
Format: "json",
},
},
expectedLevel: zerolog.ErrorLevel,
},
{
name: "Invalid level defaults to Info",
config: &config.Config{
Logging: struct {
Level string `yaml:"level"`
Format string `yaml:"format"`
}{
Level: "invalid",
Format: "json",
},
},
expectedLevel: zerolog.InfoLevel,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Setup(tt.config)
if zerolog.GlobalLevel() != tt.expectedLevel {
t.Errorf("Setup() set level to %v, want %v", zerolog.GlobalLevel(), tt.expectedLevel)
}
})
}
}
func TestGetLogger(t *testing.T) {
logger := GetLogger()
if logger == nil {
t.Error("GetLogger() returned nil")
}
}

View file

@ -1,6 +1,6 @@
{
"categories": {
"man": "Man",
"man": "Human",
"machine": "Machine",
"earth": "Earth",
"space": "Space",
@ -29,8 +29,6 @@
"edit": "Edit",
"delete": "Delete",
"upload": "Upload",
"save": "Save",
"saving": "Saving...",
"status": "Status",
"actions": "Actions",
"published": "Published",
@ -48,12 +46,10 @@
"joinDate": "Join Date",
"username": "Username",
"logout": "Logout",
"language": "Language",
"theme": {
"light": "Light Mode",
"dark": "Dark Mode",
"system": "System"
}
"save": "Save",
"saving": "Saving...",
"noData": "No data available",
"unsavedChanges": "You have unsaved changes. Are you sure you want to leave?"
},
"dashboard": {
"totalPosts": "Total Posts",
@ -61,13 +57,33 @@
"totalUsers": "Total Users",
"totalContributors": "Total Contributors"
},
"posts": {
"title": "Title",
"categories": "Categories",
"createdAt": "Created At",
"status": "Status",
"noTitle": "No Title",
"deleteConfirm": "Are you sure you want to delete this post?",
"create": "Create Post",
"edit": "Edit Post",
"slug": "Slug",
"content": "Content",
"summary": "Summary",
"metaKeywords": "Meta Keywords",
"metaDescription": "Meta Description",
"selectCategories": "Select Categories",
"saving": "Saving...",
"publishing": "Publishing...",
"saveDraft": "Save Draft",
"publish": "Publish"
},
"login": {
"title": "Admin Login",
"username": "Username",
"password": "Password",
"remember": "Remember me",
"submit": "Sign in",
"loading": "Signing in...",
"submit": "Login",
"loading": "Logging in...",
"error": {
"failed": "Login failed",
"retry": "Login failed, please try again later"
@ -76,7 +92,7 @@
"nav": {
"dashboard": "Dashboard",
"posts": "Posts",
"daily": "Daily Quotes",
"daily": "Daily",
"medias": "Media",
"categories": "Categories",
"users": "Users",
@ -110,5 +126,45 @@
"passwordMismatch": "Passwords do not match",
"passwordTooShort": "Password must be at least 8 characters long"
}
},
"editor": {
"heading1": "Heading 1",
"heading2": "Heading 2",
"heading3": "Heading 3",
"bold": "Bold",
"italic": "Italic",
"orderedList": "Ordered List",
"unorderedList": "Unordered List",
"quote": "Quote",
"link": "Link",
"image": "Image",
"inlineCode": "Inline Code",
"codeBlock": "Code Block",
"table": "Table",
"togglePreview": "Toggle Preview",
"toggleFullscreen": "Toggle Fullscreen",
"selectLanguage": "Select Language",
"plainText": "Plain Text",
"insertTable": "Insert Table",
"addRowAbove": "Add Row Above",
"addRowBelow": "Add Row Below",
"addColumnLeft": "Add Column Left",
"addColumnRight": "Add Column Right",
"deleteRow": "Delete Row",
"deleteColumn": "Delete Column",
"deleteTable": "Delete Table",
"dragAndDrop": "Drop images here",
"dropToUpload": "Drop files here to upload",
"uploading": "Uploading...",
"uploadProgress": "Upload progress: {{progress}}%",
"uploadError": "Failed to upload image: {{error}}",
"uploadSuccess": "Image uploaded successfully",
"codeBlockShortcut": "Ctrl+Shift+K",
"boldShortcut": "Ctrl+B",
"italicShortcut": "Ctrl+I",
"linkShortcut": "Ctrl+K",
"heading1Shortcut": "Shift+1",
"heading2Shortcut": "Shift+2",
"heading3Shortcut": "Shift+3"
}
}

View file

@ -47,7 +47,9 @@
"username": "用户名",
"logout": "退出登录",
"save": "保存",
"saving": "保存中..."
"saving": "保存中...",
"noData": "暂无数据",
"unsavedChanges": "你有未保存的更改,确定要离开吗?"
},
"dashboard": {
"totalPosts": "文章总数",
@ -55,6 +57,26 @@
"totalUsers": "用户总数",
"totalContributors": "贡献者总数"
},
"posts": {
"title": "标题",
"categories": "分类",
"createdAt": "创建时间",
"status": "状态",
"noTitle": "无标题",
"deleteConfirm": "确定要删除这篇文章吗?",
"create": "创建文章",
"edit": "编辑文章",
"slug": "文章链接",
"content": "内容",
"summary": "摘要",
"metaKeywords": "关键词",
"metaDescription": "描述",
"selectCategories": "选择分类",
"saving": "保存中...",
"publishing": "发布中...",
"saveDraft": "保存草稿",
"publish": "发布文章"
},
"login": {
"title": "管理员登录",
"username": "用户名",
@ -104,5 +126,45 @@
"passwordMismatch": "两次输入的密码不一致",
"passwordTooShort": "密码长度不能少于8个字符"
}
},
"editor": {
"heading1": "一级标题",
"heading2": "二级标题",
"heading3": "三级标题",
"bold": "粗体",
"italic": "斜体",
"orderedList": "有序列表",
"unorderedList": "无序列表",
"quote": "引用",
"link": "链接",
"image": "图片",
"inlineCode": "行内代码",
"codeBlock": "代码块",
"table": "表格",
"togglePreview": "切换预览",
"toggleFullscreen": "切换全屏",
"selectLanguage": "选择语言",
"plainText": "纯文本",
"insertTable": "插入表格",
"addRowAbove": "在上方插入行",
"addRowBelow": "在下方插入行",
"addColumnLeft": "在左侧插入列",
"addColumnRight": "在右侧插入列",
"deleteRow": "删除行",
"deleteColumn": "删除列",
"deleteTable": "删除表格",
"dragAndDrop": "拖放图片到此处",
"dropToUpload": "拖放文件到此处上传",
"uploading": "上传中...",
"uploadProgress": "上传进度:{{progress}}%",
"uploadError": "图片上传失败:{{error}}",
"uploadSuccess": "图片上传成功",
"codeBlockShortcut": "Ctrl+Shift+K",
"boldShortcut": "Ctrl+B",
"italicShortcut": "Ctrl+I",
"linkShortcut": "Ctrl+K",
"heading1Shortcut": "Shift+1",
"heading2Shortcut": "Shift+2",
"heading3Shortcut": "Shift+3"
}
}

View file

@ -45,7 +45,7 @@
"lastLogin": "最後登入",
"joinDate": "加入時間",
"username": "用戶名",
"logout": "退出登",
"logout": "退出登",
"language": "語言",
"theme": {
"light": "淺色模式",
@ -53,7 +53,9 @@
"system": "跟隨系統"
},
"save": "保存",
"saving": "保存中..."
"saving": "保存中...",
"noData": "暫無數據",
"unsavedChanges": "你有未保存的更改,確定要離開嗎?"
},
"nav": {
"dashboard": "儀表板",
@ -83,15 +85,15 @@
"uploadDate": "上傳日期"
},
"login": {
"title": "管理員登",
"title": "管理員登",
"username": "用戶名",
"password": "密碼",
"remember": "記住我",
"submit": "登",
"loading": "登中...",
"submit": "登",
"loading": "登中...",
"error": {
"failed": "登失敗",
"retry": "登失敗,請稍後重試"
"failed": "登失敗",
"retry": "登失敗,請稍後重試"
}
},
"roles": {
@ -110,5 +112,45 @@
"passwordMismatch": "兩次輸入的密碼不一致",
"passwordTooShort": "密碼長度不能少於8個字符"
}
},
"editor": {
"heading1": "一級標題",
"heading2": "二級標題",
"heading3": "三級標題",
"bold": "粗體",
"italic": "斜體",
"orderedList": "有序列表",
"unorderedList": "無序列表",
"quote": "引用",
"link": "連結",
"image": "圖片",
"inlineCode": "行內程式碼",
"codeBlock": "程式碼區塊",
"table": "表格",
"togglePreview": "切換預覽",
"toggleFullscreen": "切換全螢幕",
"selectLanguage": "選擇語言",
"plainText": "純文字",
"insertTable": "插入表格",
"addRowAbove": "在上方插入列",
"addRowBelow": "在下方插入列",
"addColumnLeft": "在左側插入欄",
"addColumnRight": "在右側插入欄",
"deleteRow": "刪除列",
"deleteColumn": "刪除欄",
"deleteTable": "刪除表格",
"dragAndDrop": "拖放圖片到此處",
"dropToUpload": "拖放檔案到此處上傳",
"uploading": "上傳中...",
"uploadProgress": "上傳進度:{{progress}}%",
"uploadError": "圖片上傳失敗:{{error}}",
"uploadSuccess": "圖片上傳成功",
"codeBlockShortcut": "Ctrl+Shift+K",
"boldShortcut": "Ctrl+B",
"italicShortcut": "Ctrl+I",
"linkShortcut": "Ctrl+K",
"heading1Shortcut": "Shift+1",
"heading2Shortcut": "Shift+2",
"heading3Shortcut": "Shift+3"
}
}

View file

@ -12,7 +12,14 @@
"dependencies": {
"@headlessui/react": "^2.2.0",
"@tss-rocks/api": "workspace:*",
"@types/axios": "^0.14.4",
"@types/classnames": "^2.3.4",
"@types/highlight.js": "^10.1.0",
"@types/markdown-it": "^14.1.2",
"axios": "^1.7.9",
"classnames": "^2.5.1",
"dayjs": "^1.11.13",
"highlight.js": "^11.11.1",
"i18next": "^24.2.2",
"i18next-browser-languagedetector": "^8.0.4",
"lucide-react": "^0.474.0",
@ -21,7 +28,13 @@
"react-dom": "^19.0.0",
"react-i18next": "^15.4.0",
"react-icons": "^5.4.0",
"react-router-dom": "^7.1.5"
"react-markdown": "^10.0.0",
"react-router-dom": "^7.1.5",
"react-toastify": "^11.0.3",
"rehype-highlight": "^7.0.2",
"rehype-raw": "^7.0.0",
"rehype-sanitize": "^6.0.0",
"remark-gfm": "^4.0.1"
},
"devDependencies": {
"@eslint/js": "^9.9.1",

View file

@ -4,16 +4,21 @@ import { AuthProvider } from './contexts/AuthContext';
import { UserProvider } from './contexts/UserContext';
import router from './router';
import LoadingSpinner from './components/LoadingSpinner';
import { ToastContainer } from 'react-toastify';
import 'react-toastify/dist/ReactToastify.css';
function App() {
return (
<Suspense fallback={<LoadingSpinner fullScreen />}>
<AuthProvider>
<UserProvider>
<RouterProvider router={router} />
</UserProvider>
</AuthProvider>
</Suspense>
<>
<Suspense fallback={<LoadingSpinner fullScreen />}>
<AuthProvider>
<UserProvider>
<RouterProvider router={router} />
</UserProvider>
</AuthProvider>
</Suspense>
<ToastContainer />
</>
);
}

View file

@ -9,8 +9,8 @@ export default function LoadingSpinner({ fullScreen = false }: LoadingSpinnerPro
const content = (
<div className="flex flex-col items-center gap-3">
<div className="w-10 h-10 border-4 border-indigo-200 dark:border-indigo-900 border-t-indigo-500 dark:border-t-indigo-400 rounded-full animate-spin" />
<div className="text-slate-600 dark:text-slate-300 text-sm font-medium">
<div className="w-10 h-10 border-4 border-gray-200 dark:border-gray-700 border-t-gray-900 dark:border-t-gray-200 rounded-full animate-spin" />
<div className="text-gray-900 dark:text-gray-200 text-sm font-medium">
{t('admin.common.loading')}
</div>
</div>

View file

@ -0,0 +1,958 @@
import { FC, useState, useRef, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import ReactMarkdown from 'react-markdown';
import remarkGfm from 'remark-gfm';
import rehypeRaw from 'rehype-raw';
import rehypeSanitize from 'rehype-sanitize';
import rehypeHighlight from 'rehype-highlight';
import 'highlight.js/styles/github-dark.css';
import classNames from 'classnames';
import axios from 'axios';
import { useToast } from '@/hooks/useToast';
import {
RiH1,
RiH2,
RiH3,
RiBold,
RiItalic,
RiListOrdered,
RiListUnordered,
RiLink,
RiImage2Line,
RiCodeLine,
RiCodeBoxLine,
RiTableLine,
RiDoubleQuotesL,
RiEyeLine,
RiEyeOffLine,
RiFullscreenLine,
} from 'react-icons/ri';
interface MarkdownEditorProps {
value: string;
onChange: (value: string) => void;
placeholder?: string;
}
type CodeBlockProps = {
inline?: boolean;
className?: string;
children: React.ReactNode;
} & React.HTMLAttributes<HTMLElement>;
const MarkdownEditor: FC<MarkdownEditorProps> = ({
value,
onChange,
placeholder,
}) => {
const { t } = useTranslation();
const { showToast } = useToast();
const textareaRef = useRef<HTMLTextAreaElement>(null);
const tableMenuRef = useRef<HTMLDivElement>(null);
const langSelectorRef = useRef<HTMLDivElement>(null);
const [showTableMenu, setShowTableMenu] = useState(false);
const [showLangSelector, setShowLangSelector] = useState(false);
const [isPreview, setIsPreview] = useState(false);
const [isFullscreen, setIsFullscreen] = useState(false);
const [isDraggingOver, setIsDraggingOver] = useState(false);
// 检查当前行是否为空
const isCurrentLineEmpty = (textarea: HTMLTextAreaElement): boolean => {
const text = textarea.value;
const lines = text.split('\n');
const currentLine = getCurrentLine(textarea);
return !lines[currentLine].trim();
};
// 获取当前行号
const getCurrentLine = (textarea: HTMLTextAreaElement): number => {
const text = textarea.value.substring(0, textarea.selectionStart);
return text.split('\n').length - 1;
};
// 在当前位置插入文本
const insertText = (textarea: HTMLTextAreaElement, text: string) => {
const start = textarea.selectionStart;
const end = textarea.selectionEnd;
const before = textarea.value.substring(0, start);
const after = textarea.value.substring(end);
const needNewLine = !isCurrentLineEmpty(textarea) &&
(text.startsWith('#') || text.startsWith('>'));
const newText = needNewLine ? `\n${text}` : text;
const newValue = before + newText + after;
onChange(newValue);
textarea.value = newValue;
const newCursorPos = start + newText.length;
textarea.selectionStart = newCursorPos;
textarea.selectionEnd = newCursorPos;
textarea.focus();
};
const insertTable = (rows: number, cols: number) => {
if (!textareaRef.current) return;
const headers = Array(cols).fill('header').join(' | ');
const separators = Array(cols).fill('---').join(' | ');
const cells = Array(cols).fill('content').join(' | ');
const rows_content = Array(rows).fill(cells).join('\n| ');
insertText(
textareaRef.current,
`\n| ${headers} |\n| ${separators} |\n| ${rows_content} |\n\n`
);
setShowTableMenu(false);
};
const findTableAtCursor = (): { table: string; start: number; end: number; rows: string[][]; alignments: string[] } | null => {
if (!textareaRef.current) return null;
const text = textareaRef.current.value;
const cursorPos = textareaRef.current.selectionStart;
const tableRegex = /\|[^\n]+\|[\s]*\n\|[- |:]+\|[\s]*\n(\|[^\n]+\|[\s]*\n?)+/g;
let match;
while ((match = tableRegex.exec(text)) !== null) {
const start = match.index;
const end = start + match[0].length;
if (cursorPos >= start && cursorPos <= end) {
const rows = match[0]
.trim()
.split('\n')
.map(row =>
row
.trim()
.replace(/^\||\|$/g, '')
.split('|')
.map(cell => cell.trim())
);
const alignments = rows[1].map(cell => {
if (cell.startsWith(':') && cell.endsWith(':')) return 'center';
if (cell.endsWith(':')) return 'right';
return 'left';
});
return {
table: match[0],
start,
end,
rows: [rows[0], ...rows.slice(2)],
alignments,
};
}
}
return null;
};
const generateTableText = (rows: string[][], alignments: string[]): string => {
const colWidths = rows[0].map((_, colIndex) =>
Math.max(...rows.map(row => row[colIndex]?.length || 0))
);
const separator = alignments.map((align, i) => {
const width = Math.max(3, colWidths[i]);
switch (align) {
case 'center':
return ':' + '-'.repeat(width - 2) + ':';
case 'right':
return '-'.repeat(width - 1) + ':';
default:
return '-'.repeat(width);
}
});
const tableRows = [
rows[0],
separator,
...rows.slice(1)
].map(row =>
'| ' + row.map((cell, i) => cell.padEnd(colWidths[i])).join(' | ') + ' |'
);
return tableRows.join('\n') + '\n';
};
const updateTable = (pos: { table: string; start: number; end: number; rows: string[][]; alignments: string[] }, newRows?: string[][]) => {
if (!textareaRef.current) return;
const text = textareaRef.current.value;
const newTable = generateTableText(newRows || pos.rows, pos.alignments);
const newText = text.substring(0, pos.start) + newTable + text.substring(pos.end);
onChange(newText);
};
const tableActions = [
{
label: t('editor.insertTable'),
action: () => insertTable(2, 2),
},
{
label: t('editor.addRowAbove'),
action: () => {
const tablePos = findTableAtCursor();
if (!tablePos) return;
const cursorPos = textareaRef.current?.selectionStart || 0;
let rowIndex = 0;
let currentPos = tablePos.start;
for (let i = 0; i < tablePos.rows.length; i++) {
const rowText = generateTableText([tablePos.rows[i]], tablePos.alignments);
if (cursorPos <= currentPos + rowText.length) {
rowIndex = i;
break;
}
currentPos += rowText.length;
}
const newRows = [...tablePos.rows];
newRows.splice(rowIndex, 0, Array(tablePos.rows[0].length).fill(''));
updateTable(tablePos, newRows);
},
},
{
label: t('editor.addRowBelow'),
action: () => {
const tablePos = findTableAtCursor();
if (!tablePos) return;
const cursorPos = textareaRef.current?.selectionStart || 0;
let rowIndex = 0;
let currentPos = tablePos.start;
for (let i = 0; i < tablePos.rows.length; i++) {
const rowText = generateTableText([tablePos.rows[i]], tablePos.alignments);
if (cursorPos <= currentPos + rowText.length) {
rowIndex = i;
break;
}
currentPos += rowText.length;
}
const newRows = [...tablePos.rows];
newRows.splice(rowIndex + 1, 0, Array(tablePos.rows[0].length).fill(''));
updateTable(tablePos, newRows);
},
},
{
label: t('editor.addColumnLeft'),
action: () => {
const tablePos = findTableAtCursor();
if (!tablePos) return;
const cursorPos = textareaRef.current?.selectionStart || 0;
const line = textareaRef.current?.value
.substring(tablePos.start, cursorPos)
.split('\n')
.pop() || '';
const colIndex = (line.match(/\|/g) || []).length - 1;
const newRows = tablePos.rows.map(row => {
const newRow = [...row];
newRow.splice(colIndex, 0, '');
return newRow;
});
const newAlignments = [...tablePos.alignments];
newAlignments.splice(colIndex, 0, 'left');
tablePos.alignments = newAlignments;
updateTable(tablePos, newRows);
},
},
{
label: t('editor.addColumnRight'),
action: () => {
const tablePos = findTableAtCursor();
if (!tablePos) return;
const cursorPos = textareaRef.current?.selectionStart || 0;
const line = textareaRef.current?.value
.substring(tablePos.start, cursorPos)
.split('\n')
.pop() || '';
const colIndex = (line.match(/\|/g) || []).length - 1;
const newRows = tablePos.rows.map(row => {
const newRow = [...row];
newRow.splice(colIndex + 1, 0, '');
return newRow;
});
const newAlignments = [...tablePos.alignments];
newAlignments.splice(colIndex + 1, 0, 'left');
tablePos.alignments = newAlignments;
updateTable(tablePos, newRows);
},
},
{
label: t('editor.deleteRow'),
action: () => {
const tablePos = findTableAtCursor();
if (!tablePos) return;
const cursorPos = textareaRef.current?.selectionStart || 0;
let rowIndex = 0;
let currentPos = tablePos.start;
for (let i = 0; i < tablePos.rows.length; i++) {
const rowText = generateTableText([tablePos.rows[i]], tablePos.alignments);
if (cursorPos <= currentPos + rowText.length) {
rowIndex = i;
break;
}
currentPos += rowText.length;
}
const newRows = [...tablePos.rows];
newRows.splice(rowIndex, 1);
updateTable(tablePos, newRows);
},
},
{
label: t('editor.deleteColumn'),
action: () => {
const tablePos = findTableAtCursor();
if (!tablePos) return;
const cursorPos = textareaRef.current?.selectionStart || 0;
const line = textareaRef.current?.value
.substring(tablePos.start, cursorPos)
.split('\n')
.pop() || '';
const colIndex = (line.match(/\|/g) || []).length - 1;
const newRows = tablePos.rows.map(row => {
const newRow = [...row];
newRow.splice(colIndex, 1);
return newRow;
});
const newAlignments = [...tablePos.alignments];
newAlignments.splice(colIndex, 1);
tablePos.alignments = newAlignments;
updateTable(tablePos, newRows);
},
},
{
label: t('editor.deleteTable'),
action: () => {
const tablePos = findTableAtCursor();
if (!tablePos) return;
const text = textareaRef.current?.value || '';
const newText = text.substring(0, tablePos.start) + text.substring(tablePos.end);
onChange(newText);
},
},
];
const toolbarButtons = [
{
icon: RiH1,
label: t('editor.heading1'),
shortcut: t('editor.heading1Shortcut'),
action: (textarea: HTMLTextAreaElement) => {
insertText(textarea, '# ');
},
},
{
icon: RiH2,
label: t('editor.heading2'),
shortcut: t('editor.heading2Shortcut'),
action: (textarea: HTMLTextAreaElement) => {
insertText(textarea, '## ');
},
},
{
icon: RiH3,
label: t('editor.heading3'),
shortcut: t('editor.heading3Shortcut'),
action: (textarea: HTMLTextAreaElement) => {
insertText(textarea, '### ');
},
},
{
icon: RiBold,
label: t('editor.bold'),
shortcut: t('editor.boldShortcut'),
action: (textarea: HTMLTextAreaElement) => {
const text = textarea.value.substring(
textarea.selectionStart,
textarea.selectionEnd
);
insertText(textarea, `**${text || t('editor.bold')}**`);
},
},
{
icon: RiItalic,
label: t('editor.italic'),
shortcut: t('editor.italicShortcut'),
action: (textarea: HTMLTextAreaElement) => {
const text = textarea.value.substring(
textarea.selectionStart,
textarea.selectionEnd
);
insertText(textarea, `_${text || t('editor.italic')}_`);
},
},
{
icon: RiListUnordered,
label: t('editor.unorderedList'),
action: (textarea: HTMLTextAreaElement) => {
insertText(textarea, '- ');
},
},
{
icon: RiListOrdered,
label: t('editor.orderedList'),
action: (textarea: HTMLTextAreaElement) => {
insertText(textarea, '1. ');
},
},
{
icon: RiDoubleQuotesL,
label: t('editor.quote'),
action: (textarea: HTMLTextAreaElement) => {
insertText(textarea, '> ');
},
},
{
icon: RiLink,
label: t('editor.link'),
shortcut: t('editor.linkShortcut'),
action: (textarea: HTMLTextAreaElement) => {
const text = textarea.value.substring(
textarea.selectionStart,
textarea.selectionEnd
);
insertText(textarea, `[${text || t('editor.link')}](url)`);
},
},
{
icon: RiImage2Line,
label: t('editor.image'),
action: () => {
const input = document.createElement('input');
input.type = 'file';
input.accept = 'image/*';
input.onchange = (e) => {
const file = (e.target as HTMLInputElement).files?.[0];
if (file) {
handleFileUpload(file);
}
};
input.click();
},
},
{
icon: RiCodeLine,
label: t('editor.inlineCode'),
action: (textarea: HTMLTextAreaElement) => {
const text = textarea.value.substring(
textarea.selectionStart,
textarea.selectionEnd
);
insertText(textarea, `\`${text || t('editor.inlineCode')}\``);
},
},
];
const commonLanguages = [
{ value: 'javascript', label: 'JavaScript' },
{ value: 'typescript', label: 'TypeScript' },
{ value: 'jsx', label: 'JSX' },
{ value: 'tsx', label: 'TSX' },
{ value: 'css', label: 'CSS' },
{ value: 'html', label: 'HTML' },
{ value: 'json', label: 'JSON' },
{ value: 'markdown', label: 'Markdown' },
{ value: 'python', label: 'Python' },
{ value: 'java', label: 'Java' },
{ value: 'c', label: 'C' },
{ value: 'cpp', label: 'C++' },
{ value: 'csharp', label: 'C#' },
{ value: 'go', label: 'Go' },
{ value: 'rust', label: 'Rust' },
{ value: 'php', label: 'PHP' },
{ value: 'ruby', label: 'Ruby' },
{ value: 'swift', label: 'Swift' },
{ value: 'kotlin', label: 'Kotlin' },
{ value: 'sql', label: 'SQL' },
{ value: 'shell', label: 'Shell' },
{ value: 'yaml', label: 'YAML' },
{ value: 'xml', label: 'XML' },
];
const insertCodeBlock = (language?: string) => {
if (!textareaRef.current) return;
const lang = language || '';
insertText(textareaRef.current, `\n\`\`\`${lang}\n\n\`\`\`\n`);
const cursorPos = textareaRef.current.selectionStart - 4;
textareaRef.current.setSelectionRange(cursorPos, cursorPos);
setShowLangSelector(false);
};
// 处理文件上传
const handleFileUpload = async (file: File) => {
if (!file.type.startsWith('image/')) {
showToast('error', t('editor.uploadError', { error: 'Not an image file' }));
return;
}
const formData = new FormData();
formData.append('file', file);
try {
const response = await axios.post('/api/v1/media', formData, {
headers: {
'Content-Type': 'multipart/form-data',
'Authorization': `Bearer ${localStorage.getItem('token')}`,
},
withCredentials: true,
});
const imageUrl = response.data.data.url;
const imageMarkdown = `![${file.name}](${imageUrl})`;
if (textareaRef.current) {
const start = textareaRef.current.selectionStart;
const end = textareaRef.current.selectionEnd;
const before = value.substring(0, start);
const after = value.substring(end);
const newValue = before + imageMarkdown + after;
onChange(newValue);
showToast('success', t('editor.uploadSuccess'));
}
} catch (error: any) {
console.error('Upload error:', error);
const errorMessage = error.response?.data?.error?.message || error.message;
showToast('error', t('editor.uploadError', { error: errorMessage }));
}
};
// 处理拖放事件
const handleDragOver = (e: React.DragEvent) => {
e.preventDefault();
e.stopPropagation();
setIsDraggingOver(true);
};
const handleDragLeave = (e: React.DragEvent) => {
e.preventDefault();
e.stopPropagation();
setIsDraggingOver(false);
};
const handleDrop = async (e: React.DragEvent) => {
e.preventDefault();
e.stopPropagation();
setIsDraggingOver(false);
const files = Array.from(e.dataTransfer.files);
for (const file of files) {
await handleFileUpload(file);
}
};
// 处理粘贴事件
const handlePaste = async (e: React.ClipboardEvent) => {
const items = Array.from(e.clipboardData.items);
for (const item of items) {
if (item.type.startsWith('image/')) {
e.preventDefault();
const file = item.getAsFile();
if (file) {
await handleFileUpload(file);
}
}
}
};
const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
if (e.ctrlKey || e.metaKey) {
const key = e.key.toLowerCase();
switch (key) {
case 'b': // 粗体
e.preventDefault();
if (textareaRef.current) {
const text = textareaRef.current.value.substring(
textareaRef.current.selectionStart,
textareaRef.current.selectionEnd
);
insertText(textareaRef.current, `**${text}**`);
}
break;
case 'i': // 斜体
e.preventDefault();
if (textareaRef.current) {
const text = textareaRef.current.value.substring(
textareaRef.current.selectionStart,
textareaRef.current.selectionEnd
);
insertText(textareaRef.current, `_${text}_`);
}
break;
case 'k': // 链接
e.preventDefault();
if (textareaRef.current) {
const text = textareaRef.current.value.substring(
textareaRef.current.selectionStart,
textareaRef.current.selectionEnd
);
insertText(textareaRef.current, `[${text}](url)`);
}
break;
case '1': // 标题 1
if (e.shiftKey) {
e.preventDefault();
if (textareaRef.current) {
const text = textareaRef.current.value.substring(
textareaRef.current.selectionStart,
textareaRef.current.selectionEnd
);
insertText(textareaRef.current, `# ${text}`);
}
}
break;
case '2': // 标题 2
if (e.shiftKey) {
e.preventDefault();
if (textareaRef.current) {
const text = textareaRef.current.value.substring(
textareaRef.current.selectionStart,
textareaRef.current.selectionEnd
);
insertText(textareaRef.current, `## ${text}`);
}
}
break;
case '3': // 标题 3
if (e.shiftKey) {
e.preventDefault();
if (textareaRef.current) {
const text = textareaRef.current.value.substring(
textareaRef.current.selectionStart,
textareaRef.current.selectionEnd
);
insertText(textareaRef.current, `### ${text}`);
}
}
break;
case 'e': // 预览
e.preventDefault();
setIsPreview(!isPreview);
break;
}
}
};
useEffect(() => {
const handleClickOutside = (event: MouseEvent) => {
if (
tableMenuRef.current &&
!tableMenuRef.current.contains(event.target as Node)
) {
setShowTableMenu(false);
}
if (
langSelectorRef.current &&
!langSelectorRef.current.contains(event.target as Node)
) {
setShowLangSelector(false);
}
};
document.addEventListener('mousedown', handleClickOutside);
return () => {
document.removeEventListener('mousedown', handleClickOutside);
};
}, []);
return (
<div
className={classNames(
'flex flex-col border border-slate-300 dark:border-slate-600 rounded-lg',
{ 'fixed inset-0 z-50 bg-white dark:bg-slate-900': isFullscreen }
)}
>
<div className="flex items-center gap-1 p-2 border-b border-slate-300 dark:border-slate-600 bg-slate-50 dark:bg-slate-800">
{/* 左侧编辑工具栏 */}
<div className="flex-1 flex items-center gap-1 border-r border-slate-300 dark:border-slate-600 pr-2">
{toolbarButtons.map((button, index) => {
return (
<button
key={index}
onClick={() => textareaRef.current && button.action(textareaRef.current)}
className="p-1.5 text-slate-600 dark:text-slate-400 hover:text-slate-900 dark:hover:text-white hover:bg-slate-200 dark:hover:bg-slate-700 rounded"
title={`${button.label}${button.shortcut ? ` (${button.shortcut})` : ''}`}
>
<button.icon className="text-lg" />
</button>
);
})}
{/* 表格按钮 */}
<div className="relative">
<button
onClick={() => setShowTableMenu(!showTableMenu)}
className="p-1.5 text-slate-600 dark:text-slate-400 hover:text-slate-900 dark:hover:text-white hover:bg-slate-200 dark:hover:bg-slate-700 rounded"
title={t('editor.table')}
>
<RiTableLine className="text-lg" />
</button>
{showTableMenu && (
<div
ref={tableMenuRef}
className="absolute bg-white dark:bg-slate-800 border border-slate-300 dark:border-slate-600 rounded-lg shadow-lg z-50"
style={{
width: '200px',
top: '100%',
left: '50%',
transform: 'translateX(-50%)',
marginTop: '0.5rem',
}}
>
{tableActions.map((action, index) => (
<button
key={index}
onClick={() => {
action.action();
setShowTableMenu(false);
}}
className="w-full px-4 py-2 text-left text-sm text-slate-700 dark:text-slate-300 hover:bg-slate-100 dark:hover:bg-slate-700"
>
{t(action.label)}
</button>
))}
</div>
)}
</div>
{/* 代码块按钮 */}
<div className="relative">
<button
onClick={() => setShowLangSelector(!showLangSelector)}
className="p-1.5 text-slate-600 dark:text-slate-400 hover:text-slate-900 dark:hover:text-white hover:bg-slate-200 dark:hover:bg-slate-700 rounded"
title={`${t('editor.codeBlock')} (${t('editor.codeBlockShortcut')})`}
>
<RiCodeBoxLine className="text-lg" />
</button>
{showLangSelector && (
<div
ref={langSelectorRef}
className="absolute bg-white dark:bg-slate-800 border border-slate-300 dark:border-slate-600 rounded-lg shadow-lg z-50"
style={{
width: '200px',
top: '100%',
left: '50%',
transform: 'translateX(-50%)',
marginTop: '0.5rem',
}}
>
<div className="px-4 py-1 text-sm font-medium text-slate-900 dark:text-white border-b border-slate-200 dark:border-slate-700">
{t('editor.selectLanguage')}
</div>
<div className="max-h-64 overflow-y-auto">
<button
onClick={() => {
insertCodeBlock();
setShowLangSelector(false);
}}
className="w-full px-4 py-2 text-left text-sm text-slate-700 dark:text-slate-300 hover:bg-slate-100 dark:hover:bg-slate-700"
>
{t('editor.plainText')}
</button>
{commonLanguages.map(({ value, label }) => (
<button
key={value}
onClick={() => {
insertCodeBlock(value);
setShowLangSelector(false);
}}
className="w-full px-4 py-2 text-left text-sm text-slate-700 dark:text-slate-300 hover:bg-slate-100 dark:hover:bg-slate-700"
>
{label}
</button>
))}
</div>
</div>
)}
</div>
</div>
{/* 右侧预览和全屏按钮 */}
<div className="flex items-center gap-1 pl-2">
<button
onClick={() => setIsPreview(!isPreview)}
className={classNames(
'p-1.5 rounded',
isPreview
? 'text-slate-900 dark:text-white bg-slate-200 dark:bg-slate-700'
: 'text-slate-600 dark:text-slate-400 hover:text-slate-900 dark:hover:text-white hover:bg-slate-200 dark:hover:bg-slate-700'
)}
title={t('editor.togglePreview')}
>
{isPreview ? (
<RiEyeOffLine className="text-lg" />
) : (
<RiEyeLine className="text-lg" />
)}
</button>
<button
onClick={() => setIsFullscreen(!isFullscreen)}
className="p-1.5 text-slate-600 dark:text-slate-400 hover:text-slate-900 dark:hover:text-white hover:bg-slate-200 dark:hover:bg-slate-700 rounded"
title={t('editor.toggleFullscreen')}
>
<RiFullscreenLine className="text-lg" />
</button>
</div>
</div>
<div className="flex-1 flex">
<div
className={classNames(
'flex-1 relative',
{ hidden: isPreview },
{ 'before:absolute before:inset-0 before:bg-slate-900/10 before:z-10 before:pointer-events-none': isDraggingOver }
)}
>
{isDraggingOver && (
<div className="absolute inset-0 flex items-center justify-center z-20 pointer-events-none">
<div className="bg-white dark:bg-slate-800 rounded-lg shadow-lg px-4 py-2">
{t('editor.dragAndDrop')}
</div>
</div>
)}
<textarea
ref={textareaRef}
value={value}
onChange={(e) => onChange(e.target.value)}
className="w-full h-full p-4 bg-white dark:bg-slate-900 text-slate-900 dark:text-white resize-vertical focus:outline-none"
style={{ minHeight: '200px' }}
placeholder={placeholder}
onDragOver={handleDragOver}
onDragLeave={handleDragLeave}
onDrop={handleDrop}
onPaste={handlePaste}
onKeyDown={handleKeyDown}
/>
</div>
{isPreview && (
<div className="flex-1 overflow-auto">
<div className="prose dark:prose-invert max-w-none p-4">
<ReactMarkdown
remarkPlugins={[remarkGfm]}
rehypePlugins={[rehypeRaw, rehypeSanitize, rehypeHighlight]}
components={{
code: ({ inline, className, children, ...props }) => {
const match = /language-(\w+)/.exec(className || '');
const lang = match ? match[1] : '';
if (inline) {
return (
<code className="bg-slate-100 dark:bg-slate-800 text-slate-900 dark:text-slate-100 px-1 py-0.5 rounded" {...props}>
{children}
</code>
);
}
return (
<div className="not-prose">
<pre className={classNames(
'bg-slate-100 dark:bg-slate-800 text-slate-900 dark:text-slate-100 p-4 rounded-lg overflow-x-auto',
lang && `language-${lang}`
)}>
<code className={lang ? `language-${lang}` : ''} {...props}>
{children}
</code>
</pre>
</div>
);
},
pre: ({ children, ...props }) => children,
p: ({ children, ...props }) => (
<p className="text-slate-900 dark:text-slate-100 mb-4" {...props}>{children}</p>
),
h1: ({ children, ...props }) => (
<h1 className="text-slate-900 dark:text-slate-100 text-4xl font-bold mt-6 mb-4" {...props}>{children}</h1>
),
h2: ({ children, ...props }) => (
<h2 className="text-slate-900 dark:text-slate-100 text-3xl font-bold mt-5 mb-3" {...props}>{children}</h2>
),
h3: ({ children, ...props }) => (
<h3 className="text-slate-900 dark:text-slate-100 text-2xl font-bold mt-4 mb-3" {...props}>{children}</h3>
),
h4: ({ children, ...props }) => (
<h4 className="text-slate-900 dark:text-slate-100 text-xl font-bold mt-4 mb-2" {...props}>{children}</h4>
),
h5: ({ children, ...props }) => (
<h5 className="text-slate-900 dark:text-slate-100 text-lg font-bold mt-3 mb-2" {...props}>{children}</h5>
),
h6: ({ children, ...props }) => (
<h6 className="text-slate-900 dark:text-slate-100 text-base font-bold mt-3 mb-2" {...props}>{children}</h6>
),
a: ({ href, children, ...props }) => (
<a href={href} className="text-blue-600 dark:text-blue-400 hover:underline" {...props}>{children}</a>
),
ul: ({ children, ...props }) => (
<ul className="text-slate-900 dark:text-slate-100 list-disc pl-5 mb-4 space-y-1" {...props}>{children}</ul>
),
ol: ({ children, ...props }) => (
<ol className="text-slate-900 dark:text-slate-100 list-decimal pl-5 mb-4 space-y-1" {...props}>{children}</ol>
),
li: ({ children, ...props }) => (
<li className="text-slate-900 dark:text-slate-100" {...props}>{children}</li>
),
blockquote: ({ children, ...props }) => (
<blockquote className="border-l-4 border-slate-300 dark:border-slate-600 pl-4 italic text-slate-700 dark:text-slate-300 my-4" {...props}>{children}</blockquote>
),
table: ({ children, ...props }) => (
<div className="overflow-x-auto mb-4">
<table className="min-w-full divide-y divide-slate-300 dark:divide-slate-600" {...props}>
{children}
</table>
</div>
),
thead: ({ children, ...props }) => (
<thead className="bg-slate-50 dark:bg-slate-800" {...props}>{children}</thead>
),
tbody: ({ children, ...props }) => (
<tbody className="divide-y divide-slate-200 dark:divide-slate-700" {...props}>{children}</tbody>
),
tr: ({ children, ...props }) => (
<tr {...props}>{children}</tr>
),
th: ({ children, ...props }) => (
<th className="px-3 py-2 text-left text-sm font-semibold text-slate-900 dark:text-slate-100" {...props}>{children}</th>
),
td: ({ children, ...props }) => (
<td className="px-3 py-2 text-sm text-slate-900 dark:text-slate-100" {...props}>{children}</td>
),
img: ({ src, alt, ...props }) => (
<img src={src} alt={alt} className="max-w-full h-auto rounded-lg my-4" {...props} />
),
hr: (props) => (
<hr className="border-t border-slate-300 dark:border-slate-600 my-8" {...props} />
),
}}
>
{value}
</ReactMarkdown>
</div>
</div>
)}
</div>
<div className="p-2 text-xs text-slate-500 dark:text-slate-400 border-t border-slate-300 dark:border-slate-600">
{t('editor.dropToUpload')}
</div>
</div>
);
};
export default MarkdownEditor;

View file

@ -0,0 +1,26 @@
import { createContext, useContext, FC, ReactNode, useState } from 'react';
interface PageTitleContextType {
title: string;
setTitle: (title: string) => void;
}
const PageTitleContext = createContext<PageTitleContextType | undefined>(undefined);
export const PageTitleProvider: FC<{ children: ReactNode }> = ({ children }) => {
const [title, setTitle] = useState('');
return (
<PageTitleContext.Provider value={{ title, setTitle }}>
{children}
</PageTitleContext.Provider>
);
};
export const usePageTitle = () => {
const context = useContext(PageTitleContext);
if (context === undefined) {
throw new Error('usePageTitle must be used within a PageTitleProvider');
}
return context;
};

View file

@ -0,0 +1,30 @@
import { useCallback } from 'react';
import { toast, ToastOptions } from 'react-toastify';
type ToastType = 'success' | 'error' | 'info' | 'warning';
const darkModeToastStyle: ToastOptions = {
theme: 'dark',
style: {
background: '#1e293b', // slate-800
color: '#f1f5f9', // slate-100
},
};
export const useToast = () => {
const isDarkMode = document.documentElement.classList.contains('dark');
const showToast = useCallback((type: ToastType, message: string) => {
toast[type](message, {
position: 'bottom-right',
autoClose: 3000,
hideProgressBar: false,
closeOnClick: true,
pauseOnHover: true,
draggable: true,
...(isDarkMode ? darkModeToastStyle : {}),
});
}, [isDarkMode]);
return { showToast };
};

View file

@ -19,6 +19,7 @@ import { useTheme } from '../../../hooks/useTheme';
import { Suspense } from 'react';
import LoadingSpinner from '../../../components/LoadingSpinner';
import { useUser } from '../../../contexts/UserContext';
import { PageTitleProvider, usePageTitle } from '../../../contexts/PageTitleContext';
interface AdminLayoutProps {}
@ -56,22 +57,58 @@ const languageMap: LanguageMap = {
'zh-Hant': 'en'
};
const AdminLayout: FC<AdminLayoutProps> = () => {
const AdminLayoutContent: FC = () => {
const { t, i18n } = useTranslation();
const location = useLocation();
const navigate = useNavigate();
const { theme, setTheme } = useTheme();
const { user, loading, error } = useUser();
const { user, loading, error, fetchUser } = useUser();
const { title } = usePageTitle();
useEffect(() => {
console.log('AdminLayout user:', user);
console.log('AdminLayout loading:', loading);
console.log('AdminLayout error:', error);
}, [user, loading, error]);
// 如果没有 token重定向到登录页
if (!localStorage.getItem('token')) {
navigate('/admin/login');
return;
}
const handleLogout = () => {
localStorage.removeItem('token');
navigate('/admin/login');
// 如果没有用户信息且没有在加载中,尝试获取用户信息
if (!user && !loading && !error) {
fetchUser();
}
// 如果获取用户信息出错,可能是 token 过期,重定向到登录页
if (error) {
localStorage.removeItem('token');
navigate('/admin/login');
}
}, [user, loading, error, navigate, fetchUser]);
const handleLogout = async () => {
try {
// 调用后端登出接口
const token = localStorage.getItem('token');
if (token) {
await fetch('/api/v1/auth/logout', {
method: 'POST',
headers: {
'Authorization': `Bearer ${token}`,
},
});
}
} catch (err) {
console.error('Logout error:', err);
} finally {
// 清除所有认证相关的存储数据
localStorage.removeItem('token');
localStorage.removeItem('username');
// 重置用户状态
await fetchUser();
// 重定向到登录页
navigate('/admin/login');
}
};
return (
@ -161,7 +198,7 @@ const AdminLayout: FC<AdminLayoutProps> = () => {
<div className="h-16 px-8 flex items-center justify-between">
<div>
<h2 className="text-2xl font-bold text-slate-800 dark:text-white">
{t(menuItems.find(item => item.path === location.pathname)?.label || 'admin.nav.dashboard')}
{title || t(menuItems.find(item => item.path === location.pathname)?.label || 'admin.nav.dashboard')}
</h2>
</div>
<div className="flex items-center gap-3">
@ -178,7 +215,7 @@ const AdminLayout: FC<AdminLayoutProps> = () => {
</header>
<div className="flex-1 p-6">
<div className="h-full bg-white dark:bg-slate-800 rounded-lg shadow-sm border border-slate-200/60 dark:border-slate-700/60">
<Suspense fallback={<LoadingSpinner />}>
<Suspense fallback={<LoadingSpinner fullScreen />}>
<Outlet />
</Suspense>
</div>
@ -189,4 +226,12 @@ const AdminLayout: FC<AdminLayoutProps> = () => {
);
};
const AdminLayout: FC<AdminLayoutProps> = () => {
return (
<PageTitleProvider>
<AdminLayoutContent />
</PageTitleProvider>
);
};
export default AdminLayout;

View file

@ -4,6 +4,7 @@ import { FiUser, FiLock, FiSun, FiMoon, FiMonitor, FiGlobe } from 'react-icons/f
import { useTranslation } from 'react-i18next';
import { useTheme } from '../../hooks/useTheme';
import { Menu } from '@headlessui/react';
import { useUser } from '../../contexts/UserContext';
interface LoginFormData {
username: string;
@ -27,6 +28,7 @@ export default function Login() {
const navigate = useNavigate();
const { t, i18n } = useTranslation();
const { theme, setTheme } = useTheme();
const { fetchUser } = useUser();
const [formData, setFormData] = useState<LoginFormData>({
username: '',
password: '',
@ -66,6 +68,9 @@ export default function Login() {
localStorage.removeItem('username');
}
// 获取用户信息
await fetchUser();
// 跳转到管理面板
navigate('/admin');
} catch (err) {

View file

@ -0,0 +1,296 @@
import { useState, useEffect, FC } from 'react';
import { useTranslation } from 'react-i18next';
import { useNavigate, useParams, useBlocker } from 'react-router-dom';
import LoadingSpinner from '../../../components/LoadingSpinner';
import { usePageTitle } from '../../../contexts/PageTitleContext';
import MarkdownEditor from '../../../components/admin/MarkdownEditor';
import classNames from 'classnames';
interface PostContent {
language_code: 'en' | 'zh-Hans' | 'zh-Hant';
title: string;
content_markdown: string;
summary?: string;
meta_keywords?: string;
meta_description?: string;
}
interface Category {
id: number;
contents: Array<{
language_code: string;
name: string;
slug: string;
description?: string;
}>;
}
interface Post {
id?: number;
slug: string;
status: 'draft' | 'published';
contents: PostContent[];
categories: Category[];
created_at?: string;
updated_at?: string;
}
const LANGUAGES = [
{ code: 'en', label: 'English' },
{ code: 'zh-Hans', label: '简体中文' },
{ code: 'zh-Hant', label: '繁體中文' }
] as const;
const PostEditor: FC = () => {
const { t } = useTranslation();
const navigate = useNavigate();
const { postId } = useParams<{ postId: string }>();
const isEditing = !!postId;
const { setTitle } = usePageTitle();
const [loading, setLoading] = useState(false);
const [saving, setSaving] = useState(false);
const [isDirty, setIsDirty] = useState(false);
const [activeTab, setActiveTab] = useState<'en' | 'zh-Hans' | 'zh-Hant'>('en');
const [post, setPost] = useState<Post>({
slug: '',
status: 'draft',
contents: [
{
language_code: 'en',
title: '',
content_markdown: ''
},
{
language_code: 'zh-Hans',
title: '',
content_markdown: ''
},
{
language_code: 'zh-Hant',
title: '',
content_markdown: ''
}
],
categories: []
});
const blocker = useBlocker(
({ currentLocation, nextLocation }) =>
isDirty && currentLocation.pathname !== nextLocation.pathname
);
useEffect(() => {
if (blocker.state === 'blocked') {
const confirmed = window.confirm(t('admin.common.unsavedChanges'));
if (confirmed) {
blocker.proceed();
} else {
blocker.reset();
}
}
}, [blocker, t]);
useEffect(() => {
setTitle(isEditing ? t('admin.posts.edit') : t('admin.posts.create'));
}, [isEditing, t, setTitle]);
useEffect(() => {
if (isEditing) {
fetchPost();
}
}, [postId]);
const fetchPost = async () => {
try {
setLoading(true);
const response = await fetch(`/api/v1/posts/${postId}`);
if (!response.ok) throw new Error('Failed to fetch post');
const data = await response.json();
setPost(data.data);
setIsDirty(false);
} catch (error) {
console.error('Error fetching post:', error);
} finally {
setLoading(false);
}
};
const handleSubmit = async (publish: boolean = false) => {
try {
setSaving(true);
const method = isEditing ? 'PUT' : 'POST';
const url = isEditing ? `/api/v1/posts/${postId}` : '/api/v1/posts';
const response = await fetch(url, {
method,
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
...post,
status: publish ? 'published' : 'draft'
}),
});
if (!response.ok) throw new Error('Failed to save post');
setIsDirty(false);
navigate('/admin/posts');
} catch (error) {
console.error('Error saving post:', error);
} finally {
setSaving(false);
}
};
const activeContent = post.contents.find(c => c.language_code === activeTab)!;
const updateContent = (updates: Partial<PostContent>) => {
setIsDirty(true);
const newContents = post.contents.map(content =>
content.language_code === activeTab
? { ...content, ...updates }
: content
);
setPost(prev => ({ ...prev, contents: newContents }));
};
const handleSlugChange = (e: React.ChangeEvent<HTMLInputElement>) => {
setIsDirty(true);
setPost(prev => ({ ...prev, slug: e.target.value }));
};
if (loading) {
return <LoadingSpinner />;
}
return (
<div className="p-6">
<div className="space-y-6">
{/* Slug */}
<div>
<label className="block text-sm font-medium mb-2 text-slate-900 dark:text-white">
{t('admin.posts.slug')}
</label>
<input
type="text"
value={post.slug}
onChange={handleSlugChange}
className="w-full px-3 py-2 bg-white dark:bg-slate-800 text-slate-900 dark:text-white border border-slate-300 dark:border-slate-600 rounded-md focus:outline-none focus:ring-2 focus:ring-slate-500 dark:focus:ring-slate-400"
/>
</div>
{/* Language Tabs */}
<div className="border-b border-slate-200 dark:border-slate-700">
<nav className="-mb-px flex space-x-8" aria-label="Language">
{LANGUAGES.map(({ code, label }) => (
<button
key={code}
onClick={() => setActiveTab(code)}
className={`whitespace-nowrap py-4 px-1 border-b-2 font-medium text-sm ${
activeTab === code
? 'border-slate-900 dark:border-white text-slate-900 dark:text-white'
: 'border-transparent text-slate-500 dark:text-slate-400 hover:text-slate-700 dark:hover:text-slate-300 hover:border-slate-300 dark:hover:border-slate-600'
}`}
>
{label}
</button>
))}
</nav>
</div>
{/* Content Fields */}
<div className="space-y-4">
{/* Title */}
<div>
<label className="block text-sm font-medium mb-2 text-slate-900 dark:text-white">
{t('admin.posts.title')}
</label>
<input
type="text"
value={activeContent.title}
onChange={e => updateContent({ title: e.target.value })}
className="w-full px-3 py-2 bg-white dark:bg-slate-800 text-slate-900 dark:text-white border border-slate-300 dark:border-slate-600 rounded-md focus:outline-none focus:ring-2 focus:ring-slate-500 dark:focus:ring-slate-400"
/>
</div>
{/* Content */}
<div>
<label className="block text-sm font-medium mb-2 text-slate-900 dark:text-white">
{t('admin.posts.content')}
</label>
<MarkdownEditor
value={activeContent.content_markdown}
onChange={(value) => updateContent({ content_markdown: value })}
placeholder={t('admin.posts.content')}
/>
</div>
{/* Summary */}
<div>
<label className="block text-sm font-medium mb-2 text-slate-900 dark:text-white">
{t('admin.posts.summary')}
</label>
<textarea
value={activeContent.summary || ''}
onChange={e => updateContent({ summary: e.target.value })}
rows={3}
className="w-full px-3 py-2 bg-white dark:bg-slate-800 text-slate-900 dark:text-white border border-slate-300 dark:border-slate-600 rounded-md focus:outline-none focus:ring-2 focus:ring-slate-500 dark:focus:ring-slate-400"
/>
</div>
{/* Meta Keywords */}
<div>
<label className="block text-sm font-medium mb-2 text-slate-900 dark:text-white">
{t('admin.posts.metaKeywords')}
</label>
<input
type="text"
value={activeContent.meta_keywords || ''}
onChange={e => updateContent({ meta_keywords: e.target.value })}
className="w-full px-3 py-2 bg-white dark:bg-slate-800 text-slate-900 dark:text-white border border-slate-300 dark:border-slate-600 rounded-md focus:outline-none focus:ring-2 focus:ring-slate-500 dark:focus:ring-slate-400"
/>
</div>
{/* Meta Description */}
<div>
<label className="block text-sm font-medium mb-2 text-slate-900 dark:text-white">
{t('admin.posts.metaDescription')}
</label>
<textarea
value={activeContent.meta_description || ''}
onChange={e => updateContent({ meta_description: e.target.value })}
rows={2}
className="w-full px-3 py-2 bg-white dark:bg-slate-800 text-slate-900 dark:text-white border border-slate-300 dark:border-slate-600 rounded-md focus:outline-none focus:ring-2 focus:ring-slate-500 dark:focus:ring-slate-400"
/>
</div>
</div>
{/* Categories */}
{/* TODO: Add category selection */}
{/* Actions */}
<div className="flex justify-end space-x-4">
<button
type="button"
onClick={() => handleSubmit(false)}
disabled={saving}
className="px-4 py-2 text-sm font-medium text-slate-700 dark:text-slate-300 bg-white dark:bg-slate-800 border border-slate-300 dark:border-slate-600 rounded-md hover:bg-slate-50 dark:hover:bg-slate-700 focus:outline-none focus:ring-2 focus:ring-slate-500 dark:focus:ring-slate-400 disabled:opacity-50"
>
{saving ? t('admin.posts.saving') : t('admin.posts.saveDraft')}
</button>
<button
type="button"
onClick={() => handleSubmit(true)}
disabled={saving}
className="px-4 py-2 text-sm font-medium text-white bg-slate-900 dark:bg-slate-700 rounded-md hover:bg-slate-800 dark:hover:bg-slate-600 focus:outline-none focus:ring-2 focus:ring-slate-500 dark:focus:ring-slate-400 disabled:opacity-50"
>
{saving ? t('admin.posts.publishing') : t('admin.posts.publish')}
</button>
</div>
</div>
</div>
);
};
export default PostEditor;

View file

@ -1,73 +1,183 @@
import { useState } from 'react';
import { useState, useEffect, FC, ReactNode } from 'react';
import { useTranslation } from 'react-i18next';
import Table from '../../../components/admin/Table';
import TableActions from '../../../components/admin/TableActions';
import { RiAddLine, RiSearchLine } from 'react-icons/ri';
import { RiSearchLine } from 'react-icons/ri';
import dayjs from 'dayjs';
import relativeTime from 'dayjs/plugin/relativeTime';
import 'dayjs/locale/zh-cn';
import { useNavigate } from 'react-router-dom';
import { usePageTitle } from '../../../contexts/PageTitleContext';
// 初始化 dayjs 插件
dayjs.extend(relativeTime);
dayjs.locale('zh-cn');
interface PostContent {
language_code: 'en' | 'zh-Hans' | 'zh-Hant';
title: string;
content_markdown: string;
summary?: string;
meta_keywords?: string;
meta_description?: string;
}
interface Category {
id: number;
contents: Array<{
language_code: string;
name: string;
slug: string;
description?: string;
}>;
}
interface Post {
id: string;
title: string;
category: string;
publishDate: string;
id: number;
slug: string;
status: 'draft' | 'published';
contents: PostContent[];
categories: Category[];
created_at: string;
updated_at: string;
}
type TablePost = {
id: number;
slug: string;
status: 'draft' | 'published';
created_at: string;
updated_at: string;
displayTitle: string;
displayCategories: string;
};
const PostsManagement: FC = () => {
const [searchTerm, setSearchTerm] = useState('');
const [loading, setLoading] = useState(false);
const [posts, setPosts] = useState<TablePost[]>([]);
const { t } = useTranslation();
const navigate = useNavigate();
const { setTitle } = usePageTitle();
// 这里后续会通过 API 获取数据
const posts: Post[] = [
{
id: '1',
title: '示例文章标题',
category: '示例分类',
publishDate: '2024-02-20',
status: 'published',
},
];
useEffect(() => {
setTitle(t('admin.nav.posts'));
return () => setTitle('');
}, [setTitle, t]);
const handleEdit = (post: Post) => {
console.log('Edit post:', post);
useEffect(() => {
fetchPosts();
}, []);
const fetchPosts = async () => {
try {
setLoading(true);
const response = await fetch('/api/v1/posts?sort=-created_at');
if (!response.ok) throw new Error('Failed to fetch posts');
const data = await response.json();
// 处理数据,添加显示字段
const processedPosts = data.data.map((post: Post): TablePost => ({
id: post.id,
slug: post.slug,
status: post.status,
created_at: post.created_at,
updated_at: post.updated_at,
displayTitle: getPostTitle(post),
displayCategories: getCategoryNames(post)
}));
setPosts(processedPosts);
} catch (error) {
console.error('Error fetching posts:', error);
} finally {
setLoading(false);
}
};
const handleDelete = (post: Post) => {
console.log('Delete post:', post);
const getPostTitle = (post: Post) => {
// Try to get English title first
const englishContent = post.contents.find(c => c.language_code === 'en');
if (englishContent?.title) return englishContent.title;
// Fallback to the first available title
const firstContent = post.contents[0];
return firstContent?.title || t('admin.posts.noTitle');
};
const getCategoryNames = (post: Post) => {
return post.categories
.map(category => {
const englishContent = category.contents.find(c => c.language_code === 'en');
if (englishContent?.name) return englishContent.name;
return category.contents[0]?.name || '';
})
.filter(Boolean)
.join(', ');
};
const handleEdit = async (post: TablePost) => {
navigate(post.slug);
};
const handleDelete = async (post: TablePost) => {
if (!window.confirm(t('admin.posts.deleteConfirm'))) return;
try {
const response = await fetch(`/api/v1/posts/${post.slug}`, {
method: 'DELETE',
});
if (!response.ok) throw new Error('Failed to delete post');
await fetchPosts();
} catch (error) {
console.error('Error deleting post:', error);
}
};
const columns = [
{
key: 'title' as keyof Post,
key: 'displayTitle' as keyof TablePost,
title: t('admin.posts.title'),
render: (value: string): ReactNode => value
},
{
key: 'category' as keyof Post,
title: t('admin.posts.category'),
key: 'displayCategories' as keyof TablePost,
title: t('admin.posts.categories'),
render: (value: string): ReactNode => value
},
{
key: 'publishDate' as keyof Post,
title: t('admin.posts.publishDate'),
},
{
key: 'status' as keyof Post,
title: t('admin.posts.status'),
render: (value: Post['status']) => (
<span
className={`inline-block px-2 py-1 text-xs font-medium rounded-full ${
value === 'published'
? 'bg-green-100 text-green-700 dark:bg-green-900 dark:text-green-300'
: 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900 dark:text-yellow-300'
}`}
>
{t(`admin.common.${value}`)}
key: 'created_at' as keyof TablePost,
title: t('admin.posts.createdAt'),
render: (value: string): ReactNode => (
<span title={dayjs(value).format('YYYY-MM-DD HH:mm:ss')}>
{dayjs(value).fromNow()}
</span>
),
)
},
{
key: 'status' as keyof TablePost,
title: t('admin.posts.status'),
render: (value: string | number): ReactNode => {
const status = value as TablePost['status'];
return (
<span
className={`inline-block px-2 py-1 text-xs font-medium rounded-full ${
status === 'published'
? 'bg-green-100 text-green-700 dark:bg-green-900 dark:text-green-300'
: 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900 dark:text-yellow-300'
}`}
>
{t(`admin.common.${status}`)}
</span>
);
}
}
];
const handleCreate = () => {
// TODO: 实现创建文章的逻辑
navigate('new');
};
return (
@ -89,59 +199,13 @@ const PostsManagement: FC = () => {
</div>
<div>
<Table<Post>
<Table<TablePost>
columns={columns}
data={posts}
loading={loading}
onEdit={handleEdit}
onDelete={handleDelete}
>
{({ data, onEdit, onDelete }) => (
<table className="w-full">
<thead>
<tr className="border-b border-slate-200 dark:border-slate-700">
<th className="text-left py-4 px-6 text-slate-600 dark:text-slate-400 font-medium rounded-tl-md">{t('admin.common.title')}</th>
<th className="text-left py-4 px-6 text-slate-600 dark:text-slate-400 font-medium">{t('admin.common.category')}</th>
<th className="text-left py-4 px-6 text-slate-600 dark:text-slate-400 font-medium">{t('admin.common.publishDate')}</th>
<th className="text-left py-4 px-6 text-slate-600 dark:text-slate-400 font-medium">{t('admin.common.status')}</th>
<th className="text-right py-4 px-6 text-slate-600 dark:text-slate-400 font-medium rounded-tr-md">{t('admin.common.actions')}</th>
</tr>
</thead>
<tbody>
{data.map((post) => (
<tr key={post.id} className="border-b border-slate-200 dark:border-slate-700 last:border-0">
<td className="py-4 px-6 text-slate-800 dark:text-slate-200">{post.title}</td>
<td className="py-4 px-6 text-slate-800 dark:text-slate-200">{post.category}</td>
<td className="py-4 px-6 text-slate-800 dark:text-slate-200">{post.publishDate}</td>
<td className="py-4 px-6">
<span className={`px-2 py-1 text-sm font-medium rounded-full ${
post.status === 'published'
? 'bg-green-100 text-green-700 dark:bg-green-900 dark:text-green-300'
: 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900 dark:text-yellow-300'
}`}>
{t(`admin.common.${post.status}`)}
</span>
</td>
<td className="py-4 px-6 text-right">
<button
className="text-slate-600 dark:text-slate-400 hover:text-slate-900 dark:hover:text-slate-200 px-2 py-1 rounded-md transition-colors"
onClick={() => onEdit(post)}
>
{t('admin.common.edit')}
</button>
<button
className="text-red-600 dark:text-red-400 hover:text-red-700 dark:hover:text-red-300 px-2 py-1 rounded-md transition-colors"
onClick={() => onDelete(post)}
>
{t('admin.common.delete')}
</button>
</td>
</tr>
))}
</tbody>
</table>
)}
</Table>
/>
</div>
</div>
);

View file

@ -14,6 +14,7 @@ const Footer = lazy(() => import('./components/Footer'));
// 管理页面组件
const Dashboard = lazy(() => import('./pages/admin/dashboard/Dashboard'));
const PostsManagement = lazy(() => import('./pages/admin/posts/PostsManagement'));
const PostEditor = lazy(() => import('./pages/admin/posts/PostEditor'));
const DailyManagement = lazy(() => import('./pages/admin/daily/DailyManagement'));
const MediasManagement = lazy(() => import('./pages/admin/medias/MediasManagement'));
const CategoriesManagement = lazy(() => import('./pages/admin/categories/CategoriesManagement'));
@ -56,7 +57,7 @@ const LoginRoute = () => {
}
return (
<Suspense fallback={<LoadingSpinner />}>
<Suspense fallback={<LoadingSpinner fullScreen />}>
<Login />
</Suspense>
);
@ -65,7 +66,7 @@ const LoginRoute = () => {
// 页面布局组件
const PageLayout = () => (
<div className="flex flex-col min-h-screen bg-white dark:bg-neutral-900 text-gray-900 dark:text-gray-100">
<Suspense fallback={<LoadingSpinner />}>
<Suspense fallback={<LoadingSpinner fullScreen />}>
<Header />
<div className="w-[95%] mx-auto">
<div className="border-t-2 border-gray-900 dark:border-gray-100 w-full mb-2" />
@ -82,7 +83,7 @@ const router = createBrowserRouter([
{
path: '/',
element: (
<Suspense fallback={<LoadingSpinner />}>
<Suspense fallback={<LoadingSpinner fullScreen />}>
<PageLayout />
</Suspense>
),
@ -90,7 +91,7 @@ const router = createBrowserRouter([
{
index: true,
element: (
<Suspense fallback={<LoadingSpinner />}>
<Suspense fallback={<LoadingSpinner fullScreen />}>
<Home />
</Suspense>
),
@ -98,7 +99,7 @@ const router = createBrowserRouter([
{
path: 'daily',
element: (
<Suspense fallback={<LoadingSpinner />}>
<Suspense fallback={<LoadingSpinner fullScreen />}>
<Daily />
</Suspense>
),
@ -106,7 +107,7 @@ const router = createBrowserRouter([
{
path: 'posts/:articleId',
element: (
<Suspense fallback={<LoadingSpinner />}>
<Suspense fallback={<LoadingSpinner fullScreen />}>
<Article />
</Suspense>
),
@ -133,23 +134,44 @@ const router = createBrowserRouter([
{
index: true,
element: (
<Suspense fallback={<LoadingSpinner />}>
<Suspense fallback={<LoadingSpinner fullScreen />}>
<Dashboard />
</Suspense>
),
},
{
path: 'posts',
element: (
<Suspense fallback={<LoadingSpinner />}>
<PostsManagement />
</Suspense>
),
children: [
{
index: true,
element: (
<Suspense fallback={<LoadingSpinner fullScreen />}>
<PostsManagement />
</Suspense>
),
},
{
path: 'new',
element: (
<Suspense fallback={<LoadingSpinner fullScreen />}>
<PostEditor />
</Suspense>
),
},
{
path: ':postId',
element: (
<Suspense fallback={<LoadingSpinner fullScreen />}>
<PostEditor />
</Suspense>
),
},
],
},
{
path: 'daily',
element: (
<Suspense fallback={<LoadingSpinner />}>
<Suspense fallback={<LoadingSpinner fullScreen />}>
<DailyManagement />
</Suspense>
),
@ -157,7 +179,7 @@ const router = createBrowserRouter([
{
path: 'medias',
element: (
<Suspense fallback={<LoadingSpinner />}>
<Suspense fallback={<LoadingSpinner fullScreen />}>
<MediasManagement />
</Suspense>
),
@ -165,7 +187,7 @@ const router = createBrowserRouter([
{
path: 'categories',
element: (
<Suspense fallback={<LoadingSpinner />}>
<Suspense fallback={<LoadingSpinner fullScreen />}>
<CategoriesManagement />
</Suspense>
),
@ -173,7 +195,7 @@ const router = createBrowserRouter([
{
path: 'users',
element: (
<Suspense fallback={<LoadingSpinner />}>
<Suspense fallback={<LoadingSpinner fullScreen />}>
<UsersManagement />
</Suspense>
),
@ -181,7 +203,7 @@ const router = createBrowserRouter([
{
path: 'contributors',
element: (
<Suspense fallback={<LoadingSpinner />}>
<Suspense fallback={<LoadingSpinner fullScreen />}>
<ContributorsManagement />
</Suspense>
),
@ -189,7 +211,7 @@ const router = createBrowserRouter([
{
path: 'settings',
element: (
<Suspense fallback={<LoadingSpinner />}>
<Suspense fallback={<LoadingSpinner fullScreen />}>
<Settings />
</Suspense>
),

View file

@ -1,4 +1,9 @@
{
"compilerOptions": {
"paths": {
"@/*": ["./src/*"]
}
},
"files": [],
"references": [
{ "path": "./tsconfig.app.json" },

View file

@ -1,6 +1,7 @@
import { defineConfig } from 'vite';
import react from '@vitejs/plugin-react';
import tailwindcss from "@tailwindcss/vite";
import path from 'path';
// https://vitejs.dev/config/
export default defineConfig({
@ -8,6 +9,11 @@ export default defineConfig({
react(),
tailwindcss()
],
resolve: {
alias: {
'@': path.resolve(__dirname, './src')
}
},
optimizeDeps: {
exclude: ['lucide-react'],
},
@ -18,6 +24,11 @@ export default defineConfig({
changeOrigin: true,
secure: false,
},
'/media': {
target: 'http://localhost:8080',
changeOrigin: true,
secure: false,
},
},
},
});

1141
pnpm-lock.yaml generated

File diff suppressed because it is too large Load diff