[feature] migrate to monorepo
Some checks failed
Build Backend / Build Docker Image (push) Successful in 3m33s
Test Backend / test (push) Failing after 31s

This commit is contained in:
CDN 2025-02-21 00:49:20 +08:00
commit 05ddc1f783
Signed by: CDN
GPG key ID: 0C656827F9F80080
267 changed files with 75165 additions and 0 deletions

View file

@ -0,0 +1,6 @@
package auth
// Constants for auth-related context keys
const (
UserIDKey = "user_id"
)

View file

@ -0,0 +1,27 @@
package auth
import (
"context"
"testing"
)
func TestUserIDKey(t *testing.T) {
// Test that the UserIDKey constant is defined correctly
if UserIDKey != "user_id" {
t.Errorf("UserIDKey = %v, want %v", UserIDKey, "user_id")
}
// Test context with user ID
ctx := context.WithValue(context.Background(), UserIDKey, "test-user-123")
value := ctx.Value(UserIDKey)
if value != "test-user-123" {
t.Errorf("Context value = %v, want %v", value, "test-user-123")
}
// Test context without user ID
emptyCtx := context.Background()
emptyValue := emptyCtx.Value(UserIDKey)
if emptyValue != nil {
t.Errorf("Empty context value = %v, want nil", emptyValue)
}
}

View file

@ -0,0 +1,74 @@
package config
import (
"os"
"gopkg.in/yaml.v3"
"tss-rocks-be/internal/types"
)
type Config struct {
Database DatabaseConfig `yaml:"database"`
Server ServerConfig `yaml:"server"`
JWT JWTConfig `yaml:"jwt"`
Storage StorageConfig `yaml:"storage"`
Logging LoggingConfig `yaml:"logging"`
RateLimit types.RateLimitConfig `yaml:"rate_limit"`
AccessLog types.AccessLogConfig `yaml:"access_log"`
}
type DatabaseConfig struct {
Driver string `yaml:"driver"`
DSN string `yaml:"dsn"`
}
type ServerConfig struct {
Port int `yaml:"port"`
Host string `yaml:"host"`
}
type JWTConfig struct {
Secret string `yaml:"secret"`
Expiration string `yaml:"expiration"`
}
type LoggingConfig struct {
Level string `yaml:"level"`
Format string `yaml:"format"`
}
type StorageConfig struct {
Type string `yaml:"type"`
Local LocalStorage `yaml:"local"`
S3 S3Storage `yaml:"s3"`
Upload types.UploadConfig `yaml:"upload"`
}
type LocalStorage struct {
RootDir string `yaml:"root_dir"`
}
type S3Storage struct {
Region string `yaml:"region"`
Bucket string `yaml:"bucket"`
AccessKeyID string `yaml:"access_key_id"`
SecretAccessKey string `yaml:"secret_access_key"`
Endpoint string `yaml:"endpoint"`
CustomURL string `yaml:"custom_url"`
ProxyS3 bool `yaml:"proxy_s3"`
}
// Load loads configuration from a YAML file
func Load(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var cfg Config
if err := yaml.Unmarshal(data, &cfg); err != nil {
return nil, err
}
return &cfg, nil
}

View file

@ -0,0 +1,85 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoad(t *testing.T) {
// Create a temporary test config file
content := []byte(`
database:
driver: postgres
dsn: postgres://user:pass@localhost:5432/dbname
server:
port: 8080
host: localhost
jwt:
secret: test-secret
expiration: 24h
storage:
type: local
local:
root_dir: /tmp/storage
upload:
max_size: 10485760
logging:
level: info
format: json
`)
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(configPath, content, 0644); err != nil {
t.Fatalf("Failed to write test config: %v", err)
}
// Test loading config
cfg, err := Load(configPath)
if err != nil {
t.Fatalf("Load() error = %v", err)
}
// Verify loaded values
tests := []struct {
name string
got interface{}
want interface{}
errorMsg string
}{
{"Database Driver", cfg.Database.Driver, "postgres", "incorrect database driver"},
{"Server Port", cfg.Server.Port, 8080, "incorrect server port"},
{"JWT Secret", cfg.JWT.Secret, "test-secret", "incorrect JWT secret"},
{"Storage Type", cfg.Storage.Type, "local", "incorrect storage type"},
{"Logging Level", cfg.Logging.Level, "info", "incorrect logging level"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.got != tt.want {
t.Errorf("%s = %v, want %v", tt.name, tt.got, tt.want)
}
})
}
}
func TestLoadError(t *testing.T) {
// Test loading non-existent file
_, err := Load("non-existent-file.yaml")
if err == nil {
t.Error("Load() error = nil, want error for non-existent file")
}
// Test loading invalid YAML
tmpDir := t.TempDir()
invalidPath := filepath.Join(tmpDir, "invalid.yaml")
if err := os.WriteFile(invalidPath, []byte("invalid: }{yaml"), 0644); err != nil {
t.Fatalf("Failed to write invalid config: %v", err)
}
_, err = Load(invalidPath)
if err == nil {
t.Error("Load() error = nil, want error for invalid YAML")
}
}

View file

@ -0,0 +1,119 @@
package handler
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/rs/zerolog/log"
)
type RegisterRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=8"`
Role string `json:"role" binding:"required,oneof=admin editor contributor"`
}
type LoginRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required"`
}
type AuthResponse struct {
Token string `json:"token"`
}
func (h *Handler) Register(c *gin.Context) {
var req RegisterRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
user, err := h.service.CreateUser(c.Request.Context(), req.Email, req.Password, req.Role)
if err != nil {
log.Error().Err(err).Msg("Failed to create user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
return
}
// Get user roles
roles, err := h.service.GetUserRoles(c.Request.Context(), user.ID)
if err != nil {
log.Error().Err(err).Msg("Failed to get user roles")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user roles"})
return
}
// Extract role names for JWT
roleNames := make([]string, len(roles))
for i, r := range roles {
roleNames[i] = r.Name
}
// Generate JWT token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"sub": user.ID,
"roles": roleNames,
"exp": time.Now().Add(24 * time.Hour).Unix(),
})
tokenString, err := token.SignedString([]byte(h.cfg.JWT.Secret))
if err != nil {
log.Error().Err(err).Msg("Failed to generate token")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
return
}
c.JSON(http.StatusCreated, AuthResponse{Token: tokenString})
}
func (h *Handler) Login(c *gin.Context) {
var req LoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
user, err := h.service.GetUserByEmail(c.Request.Context(), req.Email)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid credentials"})
return
}
if !h.service.ValidatePassword(c.Request.Context(), user, req.Password) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid credentials"})
return
}
// Get user roles
roles, err := h.service.GetUserRoles(c.Request.Context(), user.ID)
if err != nil {
log.Error().Err(err).Msg("Failed to get user roles")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user roles"})
return
}
// Extract role names for JWT
roleNames := make([]string, len(roles))
for i, r := range roles {
roleNames[i] = r.Name
}
// Generate JWT token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"sub": user.ID,
"roles": roleNames,
"exp": time.Now().Add(24 * time.Hour).Unix(),
})
tokenString, err := token.SignedString([]byte(h.cfg.JWT.Secret))
if err != nil {
log.Error().Err(err).Msg("Failed to generate token")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
return
}
c.JSON(http.StatusOK, AuthResponse{Token: tokenString})
}

View file

@ -0,0 +1,276 @@
package handler
import (
"bytes"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"tss-rocks-be/ent"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/service/mock"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
)
type AuthHandlerTestSuite struct {
suite.Suite
ctrl *gomock.Controller
service *mock.MockService
handler *Handler
router *gin.Engine
}
func (s *AuthHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl)
s.handler = NewHandler(&config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
},
}, s.service)
s.router = gin.New()
}
func (s *AuthHandlerTestSuite) TearDownTest() {
s.ctrl.Finish()
}
func TestAuthHandlerSuite(t *testing.T) {
suite.Run(t, new(AuthHandlerTestSuite))
}
func (s *AuthHandlerTestSuite) TestRegister() {
testCases := []struct {
name string
request RegisterRequest
setupMock func()
expectedStatus int
expectedError string
}{
{
name: "成功注册",
request: RegisterRequest{
Email: "test@example.com",
Password: "password123",
Role: "contributor",
},
setupMock: func() {
user := &ent.User{
ID: 1,
Email: "test@example.com",
}
s.service.EXPECT().
CreateUser(gomock.Any(), "test@example.com", "password123", "contributor").
Return(user, nil)
s.service.EXPECT().
GetUserRoles(gomock.Any(), user.ID).
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
},
expectedStatus: http.StatusCreated,
},
{
name: "无效的邮箱格式",
request: RegisterRequest{
Email: "invalid-email",
Password: "password123",
Role: "contributor",
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedError: "Key: 'RegisterRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag",
},
{
name: "密码太短",
request: RegisterRequest{
Email: "test@example.com",
Password: "short",
Role: "contributor",
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedError: "Key: 'RegisterRequest.Password' Error:Field validation for 'Password' failed on the 'min' tag",
},
{
name: "无效的角色",
request: RegisterRequest{
Email: "test@example.com",
Password: "password123",
Role: "invalid-role",
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedError: "Key: 'RegisterRequest.Role' Error:Field validation for 'Role' failed on the 'oneof' tag",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// 设置 mock
tc.setupMock()
// 创建请求
reqBody, _ := json.Marshal(tc.request)
req, _ := http.NewRequest(http.MethodPost, "/register", bytes.NewBuffer(reqBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// 执行请求
s.handler.Register(c)
// 验证响应
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" {
var response map[string]string
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Contains(response["error"], tc.expectedError)
} else {
var response AuthResponse
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.NotEmpty(response.Token)
}
})
}
}
func (s *AuthHandlerTestSuite) TestLogin() {
testCases := []struct {
name string
request LoginRequest
setupMock func()
expectedStatus int
expectedError string
}{
{
name: "成功登录",
request: LoginRequest{
Email: "test@example.com",
Password: "password123",
},
setupMock: func() {
user := &ent.User{
ID: 1,
Email: "test@example.com",
}
s.service.EXPECT().
GetUserByEmail(gomock.Any(), "test@example.com").
Return(user, nil)
s.service.EXPECT().
ValidatePassword(gomock.Any(), user, "password123").
Return(true)
s.service.EXPECT().
GetUserRoles(gomock.Any(), user.ID).
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
},
expectedStatus: http.StatusOK,
},
{
name: "无效的邮箱格式",
request: LoginRequest{
Email: "invalid-email",
Password: "password123",
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedError: "Key: 'LoginRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag",
},
{
name: "用户不存在",
request: LoginRequest{
Email: "nonexistent@example.com",
Password: "password123",
},
setupMock: func() {
s.service.EXPECT().
GetUserByEmail(gomock.Any(), "nonexistent@example.com").
Return(nil, errors.New("user not found"))
},
expectedStatus: http.StatusUnauthorized,
expectedError: "Invalid credentials",
},
{
name: "密码错误",
request: LoginRequest{
Email: "test@example.com",
Password: "wrong-password",
},
setupMock: func() {
user := &ent.User{
ID: 1,
Email: "test@example.com",
}
s.service.EXPECT().
GetUserByEmail(gomock.Any(), "test@example.com").
Return(user, nil)
s.service.EXPECT().
ValidatePassword(gomock.Any(), user, "wrong-password").
Return(false)
},
expectedStatus: http.StatusUnauthorized,
expectedError: "Invalid credentials",
},
{
name: "获取用户角色失败",
request: LoginRequest{
Email: "test@example.com",
Password: "password123",
},
setupMock: func() {
user := &ent.User{
ID: 1,
Email: "test@example.com",
}
s.service.EXPECT().
GetUserByEmail(gomock.Any(), "test@example.com").
Return(user, nil)
s.service.EXPECT().
ValidatePassword(gomock.Any(), user, "password123").
Return(true)
s.service.EXPECT().
GetUserRoles(gomock.Any(), user.ID).
Return(nil, errors.New("failed to get roles"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to get user roles",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// 设置 mock
tc.setupMock()
// 创建请求
reqBody, _ := json.Marshal(tc.request)
req, _ := http.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(reqBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// 执行请求
s.handler.Login(c)
// 验证响应
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" {
var response map[string]string
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Contains(response["error"], tc.expectedError)
} else {
var response AuthResponse
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.NotEmpty(response.Token)
}
})
}
}

View file

@ -0,0 +1,468 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"tss-rocks-be/ent"
"tss-rocks-be/ent/categorycontent"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/service/mock"
"tss-rocks-be/internal/types"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
"errors"
)
// Custom assertion function for comparing categories
func assertCategoryEqual(t assert.TestingT, expected, actual *ent.Category) bool {
if expected == nil && actual == nil {
return true
}
if expected == nil || actual == nil {
return assert.Fail(t, "One category is nil while the other is not")
}
// Compare only relevant fields, ignoring time fields
return assert.Equal(t, expected.ID, actual.ID) &&
assert.Equal(t, expected.Edges.Contents, actual.Edges.Contents)
}
// Custom assertion function for comparing category slices
func assertCategorySliceEqual(t assert.TestingT, expected, actual []*ent.Category) bool {
if len(expected) != len(actual) {
return assert.Fail(t, "Category slice lengths do not match")
}
for i := range expected {
if !assertCategoryEqual(t, expected[i], actual[i]) {
return false
}
}
return true
}
type CategoryHandlerTestSuite struct {
suite.Suite
ctrl *gomock.Controller
service *mock.MockService
handler *Handler
router *gin.Engine
}
func (s *CategoryHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl)
cfg := &config.Config{}
s.handler = NewHandler(cfg, s.service)
// Setup Gin router
gin.SetMode(gin.TestMode)
s.router = gin.New()
s.handler.RegisterRoutes(s.router)
}
func (s *CategoryHandlerTestSuite) TearDownTest() {
s.ctrl.Finish()
}
func TestCategoryHandlerSuite(t *testing.T) {
suite.Run(t, new(CategoryHandlerTestSuite))
}
// Test cases for ListCategories
func (s *CategoryHandlerTestSuite) TestListCategories() {
testCases := []struct {
name string
langCode string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success with default language",
langCode: "",
setupMock: func() {
s.service.EXPECT().
ListCategories(gomock.Any(), gomock.Eq("en")).
Return([]*ent.Category{
{
ID: 1,
Edges: ent.CategoryEdges{
Contents: []*ent.CategoryContent{
{
LanguageCode: categorycontent.LanguageCode("en"),
Name: "Test Category",
Description: "Test Description",
Slug: "test-category",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Category{
{
ID: 1,
Edges: ent.CategoryEdges{
Contents: []*ent.CategoryContent{
{
LanguageCode: categorycontent.LanguageCode("en"),
Name: "Test Category",
Description: "Test Description",
Slug: "test-category",
},
},
},
},
},
},
{
name: "Success with specific language",
langCode: "zh",
setupMock: func() {
s.service.EXPECT().
ListCategories(gomock.Any(), gomock.Eq("zh")).
Return([]*ent.Category{
{
ID: 1,
Edges: ent.CategoryEdges{
Contents: []*ent.CategoryContent{
{
LanguageCode: categorycontent.LanguageCode("zh"),
Name: "测试分类",
Description: "测试描述",
Slug: "test-category",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Category{
{
ID: 1,
Edges: ent.CategoryEdges{
Contents: []*ent.CategoryContent{
{
LanguageCode: categorycontent.LanguageCode("zh"),
Name: "测试分类",
Description: "测试描述",
Slug: "test-category",
},
},
},
},
},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// Setup mock
tc.setupMock()
// Create request
url := "/api/v1/categories"
if tc.langCode != "" {
url += "?lang=" + tc.langCode
}
req := httptest.NewRequest(http.MethodGet, url, nil)
w := httptest.NewRecorder()
// Perform request
s.router.ServeHTTP(w, req)
// Assert response
assert.Equal(s.T(), tc.expectedStatus, w.Code)
if tc.expectedBody != nil {
var response []*ent.Category
err := json.Unmarshal(w.Body.Bytes(), &response)
assert.NoError(s.T(), err)
assertCategorySliceEqual(s.T(), tc.expectedBody.([]*ent.Category), response)
}
})
}
}
// Test cases for GetCategory
func (s *CategoryHandlerTestSuite) TestGetCategory() {
testCases := []struct {
name string
langCode string
slug string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
langCode: "en",
slug: "test-category",
setupMock: func() {
s.service.EXPECT().
GetCategoryBySlug(gomock.Any(), gomock.Eq("en"), gomock.Eq("test-category")).
Return(&ent.Category{
ID: 1,
Edges: ent.CategoryEdges{
Contents: []*ent.CategoryContent{
{
LanguageCode: categorycontent.LanguageCode("en"),
Name: "Test Category",
Description: "Test Description",
Slug: "test-category",
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: &ent.Category{
ID: 1,
Edges: ent.CategoryEdges{
Contents: []*ent.CategoryContent{
{
LanguageCode: categorycontent.LanguageCode("en"),
Name: "Test Category",
Description: "Test Description",
Slug: "test-category",
},
},
},
},
},
{
name: "Not Found",
langCode: "en",
slug: "non-existent",
setupMock: func() {
s.service.EXPECT().
GetCategoryBySlug(gomock.Any(), gomock.Eq("en"), gomock.Eq("non-existent")).
Return(nil, types.ErrNotFound)
},
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// Setup mock
tc.setupMock()
// Create request
url := "/api/v1/categories/" + tc.slug
if tc.langCode != "" {
url += "?lang=" + tc.langCode
}
req := httptest.NewRequest(http.MethodGet, url, nil)
w := httptest.NewRecorder()
// Perform request
s.router.ServeHTTP(w, req)
// Assert response
assert.Equal(s.T(), tc.expectedStatus, w.Code)
if tc.expectedBody != nil {
var response ent.Category
err := json.Unmarshal(w.Body.Bytes(), &response)
assert.NoError(s.T(), err)
assertCategoryEqual(s.T(), tc.expectedBody.(*ent.Category), &response)
}
})
}
}
// Test cases for AddCategoryContent
func (s *CategoryHandlerTestSuite) TestAddCategoryContent() {
var description = "Test Description"
testCases := []struct {
name string
categoryID string
requestBody interface{}
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
categoryID: "1",
requestBody: AddCategoryContentRequest{
LanguageCode: "en",
Name: "Test Category",
Description: &description,
Slug: "test-category",
},
setupMock: func() {
s.service.EXPECT().
AddCategoryContent(
gomock.Any(),
1,
"en",
"Test Category",
description,
"test-category",
).
Return(&ent.CategoryContent{
LanguageCode: categorycontent.LanguageCode("en"),
Name: "Test Category",
Description: description,
Slug: "test-category",
}, nil)
},
expectedStatus: http.StatusCreated,
expectedBody: &ent.CategoryContent{
LanguageCode: categorycontent.LanguageCode("en"),
Name: "Test Category",
Description: description,
Slug: "test-category",
},
},
{
name: "Invalid JSON",
categoryID: "1",
requestBody: "invalid json",
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
},
{
name: "Invalid Category ID",
categoryID: "invalid",
requestBody: AddCategoryContentRequest{
LanguageCode: "en",
Name: "Test Category",
Description: &description,
Slug: "test-category",
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
},
{
name: "Service Error",
categoryID: "1",
requestBody: AddCategoryContentRequest{
LanguageCode: "en",
Name: "Test Category",
Description: &description,
Slug: "test-category",
},
setupMock: func() {
s.service.EXPECT().
AddCategoryContent(
gomock.Any(),
1,
"en",
"Test Category",
description,
"test-category",
).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// Setup mock
tc.setupMock()
// Create request
var body []byte
var err error
if str, ok := tc.requestBody.(string); ok {
body = []byte(str)
} else {
body, err = json.Marshal(tc.requestBody)
s.NoError(err)
}
req := httptest.NewRequest(http.MethodPost, "/api/v1/categories/"+tc.categoryID+"/contents", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
// Perform request
s.router.ServeHTTP(w, req)
// Assert response
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedBody != nil {
var response ent.CategoryContent
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Equal(tc.expectedBody, &response)
}
})
}
}
// Test cases for CreateCategory
func (s *CategoryHandlerTestSuite) TestCreateCategory() {
testCases := []struct {
name string
setupMock func()
expectedStatus int
expectedError string
}{
{
name: "成功创建分类",
setupMock: func() {
category := &ent.Category{
ID: 1,
}
s.service.EXPECT().
CreateCategory(gomock.Any()).
Return(category, nil)
},
expectedStatus: http.StatusCreated,
},
{
name: "创建分类失败",
setupMock: func() {
s.service.EXPECT().
CreateCategory(gomock.Any()).
Return(nil, errors.New("failed to create category"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to create category",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// 设置 mock
tc.setupMock()
// 创建请求
req, _ := http.NewRequest(http.MethodPost, "/categories", nil)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// 执行请求
s.handler.CreateCategory(c)
// 验证响应
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" {
var response map[string]string
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Equal(tc.expectedError, response["error"])
} else {
var response *ent.Category
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.NotNil(response)
s.Equal(1, response.ID)
}
})
}
}

View file

@ -0,0 +1,443 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"tss-rocks-be/ent"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/service/mock"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
"errors"
)
type ContributorHandlerTestSuite struct {
suite.Suite
ctrl *gomock.Controller
service *mock.MockService
handler *Handler
router *gin.Engine
}
func (s *ContributorHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl)
cfg := &config.Config{}
s.handler = NewHandler(cfg, s.service)
// Setup Gin router
gin.SetMode(gin.TestMode)
s.router = gin.New()
s.handler.RegisterRoutes(s.router)
}
func (s *ContributorHandlerTestSuite) TearDownTest() {
s.ctrl.Finish()
}
func TestContributorHandlerSuite(t *testing.T) {
suite.Run(t, new(ContributorHandlerTestSuite))
}
func (s *ContributorHandlerTestSuite) TestListContributors() {
testCases := []struct {
name string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
setupMock: func() {
s.service.EXPECT().
ListContributors(gomock.Any()).
Return([]*ent.Contributor{
{
ID: 1,
Name: "John Doe",
Edges: ent.ContributorEdges{
SocialLinks: []*ent.ContributorSocialLink{
{
Type: "github",
Value: "https://github.com/johndoe",
Edges: ent.ContributorSocialLinkEdges{},
},
},
},
CreatedAt: time.Time{},
UpdatedAt: time.Time{},
},
{
ID: 2,
Name: "Jane Smith",
Edges: ent.ContributorEdges{
SocialLinks: []*ent.ContributorSocialLink{}, // Ensure empty SocialLinks array is present
},
CreatedAt: time.Time{},
UpdatedAt: time.Time{},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []gin.H{
{
"id": 1,
"name": "John Doe",
"created_at": time.Time{},
"updated_at": time.Time{},
"edges": gin.H{
"social_links": []gin.H{
{
"type": "github",
"value": "https://github.com/johndoe",
"edges": gin.H{},
},
},
},
},
{
"id": 2,
"name": "Jane Smith",
"created_at": time.Time{},
"updated_at": time.Time{},
"edges": gin.H{
"social_links": []gin.H{}, // Ensure empty SocialLinks array is present
},
},
},
},
{
name: "Service error",
setupMock: func() {
s.service.EXPECT().
ListContributors(gomock.Any()).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to list contributors"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
req := httptest.NewRequest(http.MethodGet, "/api/v1/contributors", nil)
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
func (s *ContributorHandlerTestSuite) TestGetContributor() {
testCases := []struct {
name string
id string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
id: "1",
setupMock: func() {
s.service.EXPECT().
GetContributorByID(gomock.Any(), 1).
Return(&ent.Contributor{
ID: 1,
Name: "John Doe",
Edges: ent.ContributorEdges{
SocialLinks: []*ent.ContributorSocialLink{
{
Type: "github",
Value: "https://github.com/johndoe",
Edges: ent.ContributorSocialLinkEdges{},
},
},
},
CreatedAt: time.Time{},
UpdatedAt: time.Time{},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: gin.H{
"id": 1,
"name": "John Doe",
"created_at": time.Time{},
"updated_at": time.Time{},
"edges": gin.H{
"social_links": []gin.H{
{
"type": "github",
"value": "https://github.com/johndoe",
"edges": gin.H{},
},
},
},
},
},
{
name: "Invalid ID",
id: "invalid",
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Invalid contributor ID"},
},
{
name: "Service error",
id: "1",
setupMock: func() {
s.service.EXPECT().
GetContributorByID(gomock.Any(), 1).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to get contributor"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
req := httptest.NewRequest(http.MethodGet, "/api/v1/contributors/"+tc.id, nil)
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
func (s *ContributorHandlerTestSuite) TestCreateContributor() {
testCases := []struct {
name string
body interface{}
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
body: CreateContributorRequest{
Name: "John Doe",
},
setupMock: func() {
name := "John Doe"
s.service.EXPECT().
CreateContributor(
gomock.Any(),
name,
nil,
nil,
).
Return(&ent.Contributor{
ID: 1,
Name: name,
CreatedAt: time.Time{},
UpdatedAt: time.Time{},
}, nil)
},
expectedStatus: http.StatusCreated,
expectedBody: gin.H{
"id": 1,
"name": "John Doe",
"created_at": time.Time{},
"updated_at": time.Time{},
"edges": gin.H{},
},
},
{
name: "Invalid request body",
body: map[string]interface{}{
"name": "", // Empty name is not allowed
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Key: 'CreateContributorRequest.Name' Error:Field validation for 'Name' failed on the 'required' tag"},
},
{
name: "Service error",
body: CreateContributorRequest{
Name: "John Doe",
},
setupMock: func() {
name := "John Doe"
s.service.EXPECT().
CreateContributor(
gomock.Any(),
name,
nil,
nil,
).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to create contributor"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
body, err := json.Marshal(tc.body)
s.NoError(err, "Failed to marshal request body")
req := httptest.NewRequest(http.MethodPost, "/api/v1/contributors", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
func (s *ContributorHandlerTestSuite) TestAddContributorSocialLink() {
testCases := []struct {
name string
id string
body interface{}
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
id: "1",
body: func() AddContributorSocialLinkRequest {
name := "johndoe"
return AddContributorSocialLinkRequest{
Type: "github",
Name: &name,
Value: "https://github.com/johndoe",
}
}(),
setupMock: func() {
name := "johndoe"
s.service.EXPECT().
AddContributorSocialLink(
gomock.Any(),
1,
"github",
name,
"https://github.com/johndoe",
).
Return(&ent.ContributorSocialLink{
Type: "github",
Name: name,
Value: "https://github.com/johndoe",
Edges: ent.ContributorSocialLinkEdges{},
}, nil)
},
expectedStatus: http.StatusCreated,
expectedBody: gin.H{
"type": "github",
"name": "johndoe",
"value": "https://github.com/johndoe",
"edges": gin.H{},
},
},
{
name: "Invalid contributor ID",
id: "invalid",
body: func() AddContributorSocialLinkRequest {
name := "johndoe"
return AddContributorSocialLinkRequest{
Type: "github",
Name: &name,
Value: "https://github.com/johndoe",
}
}(),
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Invalid contributor ID"},
},
{
name: "Invalid request body",
id: "1",
body: map[string]interface{}{
"type": "", // Empty type is not allowed
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Key: 'AddContributorSocialLinkRequest.Type' Error:Field validation for 'Type' failed on the 'required' tag\nKey: 'AddContributorSocialLinkRequest.Value' Error:Field validation for 'Value' failed on the 'required' tag"},
},
{
name: "Service error",
id: "1",
body: func() AddContributorSocialLinkRequest {
name := "johndoe"
return AddContributorSocialLinkRequest{
Type: "github",
Name: &name,
Value: "https://github.com/johndoe",
}
}(),
setupMock: func() {
name := "johndoe"
s.service.EXPECT().
AddContributorSocialLink(
gomock.Any(),
1,
"github",
name,
"https://github.com/johndoe",
).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to add contributor social link"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
body, err := json.Marshal(tc.body)
s.NoError(err, "Failed to marshal request body")
req := httptest.NewRequest(http.MethodPost, "/api/v1/contributors/"+tc.id+"/social-links", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}

View file

@ -0,0 +1,519 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"tss-rocks-be/ent"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/service/mock"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
"errors"
"strings"
)
type DailyHandlerTestSuite struct {
suite.Suite
ctrl *gomock.Controller
service *mock.MockService
handler *Handler
router *gin.Engine
}
func (s *DailyHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl)
cfg := &config.Config{}
s.handler = NewHandler(cfg, s.service)
// Setup Gin router
gin.SetMode(gin.TestMode)
s.router = gin.New()
s.handler.RegisterRoutes(s.router)
}
func (s *DailyHandlerTestSuite) TearDownTest() {
s.ctrl.Finish()
}
func TestDailyHandlerSuite(t *testing.T) {
suite.Run(t, new(DailyHandlerTestSuite))
}
func (s *DailyHandlerTestSuite) TestListDailies() {
testCases := []struct {
name string
langCode string
categoryID string
limit string
offset string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success with default language",
langCode: "",
setupMock: func() {
s.service.EXPECT().
ListDailies(gomock.Any(), "en", nil, 10, 0).
Return([]*ent.Daily{
{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 1",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Daily{
{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 1",
},
},
},
},
},
},
{
name: "Success with specific language",
langCode: "zh",
setupMock: func() {
s.service.EXPECT().
ListDailies(gomock.Any(), "zh", nil, 10, 0).
Return([]*ent.Daily{
{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "zh",
Quote: "测试语录1",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Daily{
{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "zh",
Quote: "测试语录1",
},
},
},
},
},
},
{
name: "Success with category filter",
categoryID: "1",
setupMock: func() {
categoryID := 1
s.service.EXPECT().
ListDailies(gomock.Any(), "en", &categoryID, 10, 0).
Return([]*ent.Daily{
{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 1",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Daily{
{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 1",
},
},
},
},
},
},
{
name: "Success with pagination",
limit: "2",
offset: "1",
setupMock: func() {
s.service.EXPECT().
ListDailies(gomock.Any(), "en", nil, 2, 1).
Return([]*ent.Daily{
{
ID: "daily2",
ImageURL: "https://example.com/image2.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 2",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Daily{
{
ID: "daily2",
ImageURL: "https://example.com/image2.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 2",
},
},
},
},
},
},
{
name: "Service Error",
setupMock: func() {
s.service.EXPECT().
ListDailies(gomock.Any(), "en", nil, 10, 0).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to list dailies"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
url := "/api/v1/dailies"
if tc.langCode != "" {
url += "?lang=" + tc.langCode
}
if tc.categoryID != "" {
if strings.Contains(url, "?") {
url += "&"
} else {
url += "?"
}
url += "category_id=" + tc.categoryID
}
if tc.limit != "" {
if strings.Contains(url, "?") {
url += "&"
} else {
url += "?"
}
url += "limit=" + tc.limit
}
if tc.offset != "" {
if strings.Contains(url, "?") {
url += "&"
} else {
url += "?"
}
url += "offset=" + tc.offset
}
req := httptest.NewRequest(http.MethodGet, url, nil)
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
func (s *DailyHandlerTestSuite) TestGetDaily() {
testCases := []struct {
name string
id string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
id: "daily1",
setupMock: func() {
s.service.EXPECT().
GetDailyByID(gomock.Any(), "daily1").
Return(&ent.Daily{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 1",
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: &ent.Daily{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 1",
},
},
},
},
},
{
name: "Service error",
id: "daily1",
setupMock: func() {
s.service.EXPECT().
GetDailyByID(gomock.Any(), "daily1").
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to get daily"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
req := httptest.NewRequest(http.MethodGet, "/api/v1/dailies/"+tc.id, nil)
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
func (s *DailyHandlerTestSuite) TestCreateDaily() {
testCases := []struct {
name string
body interface{}
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
body: CreateDailyRequest{
ID: "daily1",
CategoryID: 1,
ImageURL: "https://example.com/image1.jpg",
},
setupMock: func() {
s.service.EXPECT().
CreateDaily(gomock.Any(), "daily1", 1, "https://example.com/image1.jpg").
Return(&ent.Daily{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{},
},
}, nil)
},
expectedStatus: http.StatusCreated,
expectedBody: &ent.Daily{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{},
},
},
},
{
name: "Invalid request body",
body: map[string]interface{}{
"id": "daily1",
// Missing required fields
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Key: 'CreateDailyRequest.CategoryID' Error:Field validation for 'CategoryID' failed on the 'required' tag\nKey: 'CreateDailyRequest.ImageURL' Error:Field validation for 'ImageURL' failed on the 'required' tag"},
},
{
name: "Service error",
body: CreateDailyRequest{
ID: "daily1",
CategoryID: 1,
ImageURL: "https://example.com/image1.jpg",
},
setupMock: func() {
s.service.EXPECT().
CreateDaily(gomock.Any(), "daily1", 1, "https://example.com/image1.jpg").
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to create daily"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
body, err := json.Marshal(tc.body)
s.NoError(err, "Failed to marshal request body")
req := httptest.NewRequest(http.MethodPost, "/api/v1/dailies", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
func (s *DailyHandlerTestSuite) TestAddDailyContent() {
testCases := []struct {
name string
dailyID string
body interface{}
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
dailyID: "daily1",
body: AddDailyContentRequest{
LanguageCode: "en",
Quote: "Test Quote 1",
},
setupMock: func() {
s.service.EXPECT().
AddDailyContent(gomock.Any(), "daily1", "en", "Test Quote 1").
Return(&ent.DailyContent{
LanguageCode: "en",
Quote: "Test Quote 1",
}, nil)
},
expectedStatus: http.StatusCreated,
expectedBody: &ent.DailyContent{
LanguageCode: "en",
Quote: "Test Quote 1",
},
},
{
name: "Invalid request body",
dailyID: "daily1",
body: map[string]interface{}{
"language_code": "en",
// Missing required fields
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Key: 'AddDailyContentRequest.Quote' Error:Field validation for 'Quote' failed on the 'required' tag"},
},
{
name: "Service error",
dailyID: "daily1",
body: AddDailyContentRequest{
LanguageCode: "en",
Quote: "Test Quote 1",
},
setupMock: func() {
s.service.EXPECT().
AddDailyContent(gomock.Any(), "daily1", "en", "Test Quote 1").
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to add daily content"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
body, err := json.Marshal(tc.body)
s.NoError(err, "Failed to marshal request body")
req := httptest.NewRequest(http.MethodPost, "/api/v1/dailies/"+tc.dailyID+"/contents", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}

View file

@ -0,0 +1,513 @@
package handler
import (
"net/http"
"strconv"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/service"
"tss-rocks-be/internal/types"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
)
type Handler struct {
cfg *config.Config
service service.Service
}
func NewHandler(cfg *config.Config, service service.Service) *Handler {
return &Handler{
cfg: cfg,
service: service,
}
}
// RegisterRoutes registers all the routes
func (h *Handler) RegisterRoutes(r *gin.Engine) {
api := r.Group("/api/v1")
{
// Auth routes
auth := api.Group("/auth")
{
auth.POST("/register", h.Register)
auth.POST("/login", h.Login)
}
// Category routes
categories := api.Group("/categories")
{
categories.GET("", h.ListCategories)
categories.GET("/:slug", h.GetCategory)
categories.POST("", h.CreateCategory)
categories.POST("/:id/contents", h.AddCategoryContent)
}
// Post routes
posts := api.Group("/posts")
{
posts.GET("", h.ListPosts)
posts.GET("/:slug", h.GetPost)
posts.POST("", h.CreatePost)
posts.POST("/:id/contents", h.AddPostContent)
}
// Contributor routes
contributors := api.Group("/contributors")
{
contributors.GET("", h.ListContributors)
contributors.GET("/:id", h.GetContributor)
contributors.POST("", h.CreateContributor)
contributors.POST("/:id/social-links", h.AddContributorSocialLink)
}
// Daily routes
dailies := api.Group("/dailies")
{
dailies.GET("", h.ListDailies)
dailies.GET("/:id", h.GetDaily)
dailies.POST("", h.CreateDaily)
dailies.POST("/:id/contents", h.AddDailyContent)
}
// Media routes
media := api.Group("/media")
{
media.GET("", h.ListMedia)
media.POST("", h.UploadMedia)
media.GET("/:id", h.GetMedia)
media.DELETE("/:id", h.DeleteMedia)
}
}
}
// Category handlers
func (h *Handler) ListCategories(c *gin.Context) {
langCode := c.Query("lang")
if langCode == "" {
langCode = "en" // Default to English
}
categories, err := h.service.ListCategories(c.Request.Context(), langCode)
if err != nil {
log.Error().Err(err).Msg("Failed to list categories")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list categories"})
return
}
c.JSON(http.StatusOK, categories)
}
func (h *Handler) GetCategory(c *gin.Context) {
langCode := c.Query("lang")
if langCode == "" {
langCode = "en" // Default to English
}
slug := c.Param("slug")
category, err := h.service.GetCategoryBySlug(c.Request.Context(), langCode, slug)
if err != nil {
if err == types.ErrNotFound {
c.JSON(http.StatusNotFound, gin.H{"error": "Category not found"})
return
}
log.Error().Err(err).Msg("Failed to get category")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get category"})
return
}
c.JSON(http.StatusOK, category)
}
func (h *Handler) CreateCategory(c *gin.Context) {
category, err := h.service.CreateCategory(c.Request.Context())
if err != nil {
log.Error().Err(err).Msg("Failed to create category")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create category"})
return
}
c.JSON(http.StatusCreated, category)
}
type AddCategoryContentRequest struct {
LanguageCode string `json:"language_code" binding:"required"`
Name string `json:"name" binding:"required"`
Description *string `json:"description"`
Slug string `json:"slug" binding:"required"`
}
func (h *Handler) AddCategoryContent(c *gin.Context) {
var req AddCategoryContentRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
categoryID, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid category ID"})
return
}
var description string
if req.Description != nil {
description = *req.Description
}
content, err := h.service.AddCategoryContent(c.Request.Context(), categoryID, req.LanguageCode, req.Name, description, req.Slug)
if err != nil {
log.Error().Err(err).Msg("Failed to add category content")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to add category content"})
return
}
c.JSON(http.StatusCreated, content)
}
// Post handlers
func (h *Handler) ListPosts(c *gin.Context) {
langCode := c.Query("lang")
if langCode == "" {
langCode = "en" // Default to English
}
var categoryID *int
if catIDStr := c.Query("category_id"); catIDStr != "" {
if id, err := strconv.Atoi(catIDStr); err == nil {
categoryID = &id
}
}
limit := 10 // Default limit
if limitStr := c.Query("limit"); limitStr != "" {
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
limit = l
}
}
offset := 0 // Default offset
if offsetStr := c.Query("offset"); offsetStr != "" {
if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 {
offset = o
}
}
posts, err := h.service.ListPosts(c.Request.Context(), langCode, categoryID, limit, offset)
if err != nil {
log.Error().Err(err).Msg("Failed to list posts")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list posts"})
return
}
c.JSON(http.StatusOK, posts)
}
func (h *Handler) GetPost(c *gin.Context) {
langCode := c.Query("lang")
if langCode == "" {
langCode = "en" // Default to English
}
slug := c.Param("slug")
post, err := h.service.GetPostBySlug(c.Request.Context(), langCode, slug)
if err != nil {
log.Error().Err(err).Msg("Failed to get post")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get post"})
return
}
// Convert to a map to control the fields
response := gin.H{
"id": post.ID,
"status": post.Status,
"slug": post.Slug,
"edges": gin.H{
"contents": []gin.H{},
},
}
contents := make([]gin.H, 0, len(post.Edges.Contents))
for _, content := range post.Edges.Contents {
contents = append(contents, gin.H{
"language_code": content.LanguageCode,
"title": content.Title,
"content_markdown": content.ContentMarkdown,
"summary": content.Summary,
})
}
response["edges"].(gin.H)["contents"] = contents
c.JSON(http.StatusOK, response)
}
func (h *Handler) CreatePost(c *gin.Context) {
post, err := h.service.CreatePost(c.Request.Context(), "draft") // Default to draft status
if err != nil {
log.Error().Err(err).Msg("Failed to create post")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create post"})
return
}
// Convert to a map to control the fields
response := gin.H{
"id": post.ID,
"status": post.Status,
"edges": gin.H{
"contents": []interface{}{},
},
}
c.JSON(http.StatusCreated, response)
}
type AddPostContentRequest struct {
LanguageCode string `json:"language_code" binding:"required"`
Title string `json:"title" binding:"required"`
ContentMarkdown string `json:"content_markdown" binding:"required"`
Summary string `json:"summary" binding:"required"`
MetaKeywords string `json:"meta_keywords"`
MetaDescription string `json:"meta_description"`
}
func (h *Handler) AddPostContent(c *gin.Context) {
var req AddPostContentRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
postID, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid post ID"})
return
}
content, err := h.service.AddPostContent(c.Request.Context(), postID, req.LanguageCode, req.Title, req.ContentMarkdown, req.Summary, req.MetaKeywords, req.MetaDescription)
if err != nil {
log.Error().Err(err).Msg("Failed to add post content")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to add post content"})
return
}
c.JSON(http.StatusCreated, gin.H{
"title": content.Title,
"content_markdown": content.ContentMarkdown,
"language_code": content.LanguageCode,
"summary": content.Summary,
"meta_keywords": content.MetaKeywords,
"meta_description": content.MetaDescription,
"edges": gin.H{},
})
}
// Contributor handlers
func (h *Handler) ListContributors(c *gin.Context) {
contributors, err := h.service.ListContributors(c.Request.Context())
if err != nil {
log.Error().Err(err).Msg("Failed to list contributors")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list contributors"})
return
}
response := make([]gin.H, len(contributors))
for i, contributor := range contributors {
socialLinks := make([]gin.H, len(contributor.Edges.SocialLinks))
for j, link := range contributor.Edges.SocialLinks {
socialLinks[j] = gin.H{
"type": link.Type,
"value": link.Value,
"edges": gin.H{},
}
}
response[i] = gin.H{
"id": contributor.ID,
"name": contributor.Name,
"created_at": contributor.CreatedAt,
"updated_at": contributor.UpdatedAt,
"edges": gin.H{
"social_links": socialLinks,
},
}
}
c.JSON(http.StatusOK, response)
}
func (h *Handler) GetContributor(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid contributor ID"})
return
}
contributor, err := h.service.GetContributorByID(c.Request.Context(), id)
if err != nil {
log.Error().Err(err).Msg("Failed to get contributor")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get contributor"})
return
}
c.JSON(http.StatusOK, contributor)
}
type CreateContributorRequest struct {
Name string `json:"name" binding:"required"`
AvatarURL *string `json:"avatar_url"`
Bio *string `json:"bio"`
}
func (h *Handler) CreateContributor(c *gin.Context) {
var req CreateContributorRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
contributor, err := h.service.CreateContributor(c.Request.Context(), req.Name, req.AvatarURL, req.Bio)
if err != nil {
log.Error().Err(err).Msg("Failed to create contributor")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create contributor"})
return
}
c.JSON(http.StatusCreated, contributor)
}
type AddContributorSocialLinkRequest struct {
Type string `json:"type" binding:"required"`
Name *string `json:"name"`
Value string `json:"value" binding:"required"`
}
func (h *Handler) AddContributorSocialLink(c *gin.Context) {
var req AddContributorSocialLinkRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
contributorID, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid contributor ID"})
return
}
name := ""
if req.Name != nil {
name = *req.Name
}
link, err := h.service.AddContributorSocialLink(c.Request.Context(), contributorID, req.Type, name, req.Value)
if err != nil {
log.Error().Err(err).Msg("Failed to add contributor social link")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to add contributor social link"})
return
}
c.JSON(http.StatusCreated, link)
}
// Daily handlers
func (h *Handler) ListDailies(c *gin.Context) {
langCode := c.Query("lang")
if langCode == "" {
langCode = "en" // Default to English
}
var categoryID *int
if catIDStr := c.Query("category_id"); catIDStr != "" {
if id, err := strconv.Atoi(catIDStr); err == nil {
categoryID = &id
}
}
limit := 10 // Default limit
if limitStr := c.Query("limit"); limitStr != "" {
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
limit = l
}
}
offset := 0 // Default offset
if offsetStr := c.Query("offset"); offsetStr != "" {
if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 {
offset = o
}
}
dailies, err := h.service.ListDailies(c.Request.Context(), langCode, categoryID, limit, offset)
if err != nil {
log.Error().Err(err).Msg("Failed to list dailies")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list dailies"})
return
}
c.JSON(http.StatusOK, dailies)
}
func (h *Handler) GetDaily(c *gin.Context) {
id := c.Param("id")
daily, err := h.service.GetDailyByID(c.Request.Context(), id)
if err != nil {
log.Error().Err(err).Msg("Failed to get daily")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get daily"})
return
}
c.JSON(http.StatusOK, daily)
}
type CreateDailyRequest struct {
ID string `json:"id" binding:"required"`
CategoryID int `json:"category_id" binding:"required"`
ImageURL string `json:"image_url" binding:"required"`
}
func (h *Handler) CreateDaily(c *gin.Context) {
var req CreateDailyRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
daily, err := h.service.CreateDaily(c.Request.Context(), req.ID, req.CategoryID, req.ImageURL)
if err != nil {
log.Error().Err(err).Msg("Failed to create daily")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create daily"})
return
}
c.JSON(http.StatusCreated, daily)
}
type AddDailyContentRequest struct {
LanguageCode string `json:"language_code" binding:"required"`
Quote string `json:"quote" binding:"required"`
}
func (h *Handler) AddDailyContent(c *gin.Context) {
var req AddDailyContentRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
dailyID := c.Param("id")
content, err := h.service.AddDailyContent(c.Request.Context(), dailyID, req.LanguageCode, req.Quote)
if err != nil {
log.Error().Err(err).Msg("Failed to add daily content")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to add daily content"})
return
}
c.JSON(http.StatusCreated, content)
}
// Helper functions
func stringPtr(s *string) string {
if s == nil {
return ""
}
return *s
}

View file

@ -0,0 +1,43 @@
package handler
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestStringPtr(t *testing.T) {
testCases := []struct {
name string
input *string
expected string
}{
{
name: "nil pointer",
input: nil,
expected: "",
},
{
name: "empty string",
input: strPtr(""),
expected: "",
},
{
name: "non-empty string",
input: strPtr("test"),
expected: "test",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := stringPtr(tc.input)
assert.Equal(t, tc.expected, result)
})
}
}
// Helper function to create string pointer
func strPtr(s string) *string {
return &s
}

View file

@ -0,0 +1,173 @@
package handler
import (
"fmt"
"io"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
)
// Media handlers
func (h *Handler) ListMedia(c *gin.Context) {
limit := 10 // Default limit
if limitStr := c.Query("limit"); limitStr != "" {
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
limit = l
}
}
offset := 0 // Default offset
if offsetStr := c.Query("offset"); offsetStr != "" {
if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 {
offset = o
}
}
media, err := h.service.ListMedia(c.Request.Context(), limit, offset)
if err != nil {
log.Error().Err(err).Msg("Failed to list media")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list media"})
return
}
c.JSON(http.StatusOK, media)
}
func (h *Handler) UploadMedia(c *gin.Context) {
// Get user ID from context (set by auth middleware)
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
// Get file from form
file, err := c.FormFile("file")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "No file uploaded"})
return
}
// 文件大小限制
if file.Size > 10*1024*1024 { // 10MB
c.JSON(http.StatusBadRequest, gin.H{"error": "File size exceeds the limit (10MB)"})
return
}
// 文件类型限制
allowedTypes := map[string]bool{
"image/jpeg": true,
"image/png": true,
"image/gif": true,
"video/mp4": true,
"video/webm": true,
"audio/mpeg": true,
"audio/ogg": true,
"application/pdf": true,
}
contentType := file.Header.Get("Content-Type")
if _, ok := allowedTypes[contentType]; !ok {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid file type"})
return
}
// Upload file
media, err := h.service.Upload(c.Request.Context(), file, userID.(int))
if err != nil {
log.Error().Err(err).Msg("Failed to upload media")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to upload media"})
return
}
c.JSON(http.StatusCreated, media)
}
func (h *Handler) GetMedia(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid media ID"})
return
}
// Get media metadata
media, err := h.service.GetMedia(c.Request.Context(), id)
if err != nil {
log.Error().Err(err).Msg("Failed to get media")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get media"})
return
}
// Get file content
reader, info, err := h.service.GetFile(c.Request.Context(), id)
if err != nil {
log.Error().Err(err).Msg("Failed to get media file")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get media file"})
return
}
defer reader.Close()
// Set response headers
c.Header("Content-Type", media.MimeType)
c.Header("Content-Length", fmt.Sprintf("%d", info.Size))
c.Header("Content-Disposition", fmt.Sprintf("inline; filename=%s", media.OriginalName))
// Stream the file
if _, err := io.Copy(c.Writer, reader); err != nil {
log.Error().Err(err).Msg("Failed to stream media file")
return
}
}
func (h *Handler) GetMediaFile(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid media ID"})
return
}
// Get file content
reader, info, err := h.service.GetFile(c.Request.Context(), id)
if err != nil {
log.Error().Err(err).Msg("Failed to get media file")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get media file"})
return
}
defer reader.Close()
// Set response headers
c.Header("Content-Type", info.ContentType)
c.Header("Content-Length", fmt.Sprintf("%d", info.Size))
c.Header("Content-Disposition", fmt.Sprintf("inline; filename=%s", info.Name))
// Stream the file
if _, err := io.Copy(c.Writer, reader); err != nil {
log.Error().Err(err).Msg("Failed to stream media file")
return
}
}
func (h *Handler) DeleteMedia(c *gin.Context) {
// Get user ID from context (set by auth middleware)
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid media ID"})
return
}
if err := h.service.DeleteMedia(c.Request.Context(), id, userID.(int)); err != nil {
log.Error().Err(err).Msg("Failed to delete media")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete media"})
return
}
c.JSON(http.StatusNoContent, nil)
}

View file

@ -0,0 +1,524 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"strings"
"testing"
"tss-rocks-be/ent"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/service/mock"
"tss-rocks-be/internal/storage"
"net/textproto"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
)
type MediaHandlerTestSuite struct {
suite.Suite
ctrl *gomock.Controller
service *mock.MockService
handler *Handler
router *gin.Engine
}
func (s *MediaHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl)
s.handler = NewHandler(&config.Config{}, s.service)
s.router = gin.New()
}
func (s *MediaHandlerTestSuite) TearDownTest() {
s.ctrl.Finish()
}
func TestMediaHandlerSuite(t *testing.T) {
suite.Run(t, new(MediaHandlerTestSuite))
}
func (s *MediaHandlerTestSuite) TestListMedia() {
testCases := []struct {
name string
query string
setupMock func()
expectedStatus int
expectedError string
}{
{
name: "成功列出媒体",
query: "?limit=10&offset=0",
setupMock: func() {
s.service.EXPECT().
ListMedia(gomock.Any(), 10, 0).
Return([]*ent.Media{{ID: 1}}, nil)
},
expectedStatus: http.StatusOK,
},
{
name: "使用默认限制和偏移",
query: "",
setupMock: func() {
s.service.EXPECT().
ListMedia(gomock.Any(), 10, 0).
Return([]*ent.Media{{ID: 1}}, nil)
},
expectedStatus: http.StatusOK,
},
{
name: "列出媒体失败",
query: "",
setupMock: func() {
s.service.EXPECT().
ListMedia(gomock.Any(), 10, 0).
Return(nil, errors.New("failed to list media"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to list media",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// 设置 mock
tc.setupMock()
// 创建请求
req, _ := http.NewRequest(http.MethodGet, "/media"+tc.query, nil)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// 执行请求
s.handler.ListMedia(c)
// 验证响应
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" {
var response map[string]string
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Equal(tc.expectedError, response["error"])
} else {
var response []*ent.Media
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.NotEmpty(response)
}
})
}
}
func (s *MediaHandlerTestSuite) TestUploadMedia() {
testCases := []struct {
name string
setupRequest func() (*http.Request, error)
setupMock func()
expectedStatus int
expectedError string
}{
{
name: "成功上传媒体",
setupRequest: func() (*http.Request, error) {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
// 创建文件部分
fileHeader := make(textproto.MIMEHeader)
fileHeader.Set("Content-Type", "image/jpeg")
fileHeader.Set("Content-Disposition", `form-data; name="file"; filename="test.jpg"`)
part, err := writer.CreatePart(fileHeader)
if err != nil {
return nil, err
}
testContent := "test content"
_, err = io.Copy(part, strings.NewReader(testContent))
if err != nil {
return nil, err
}
writer.Close()
req := httptest.NewRequest(http.MethodPost, "/media", body)
req.Header.Set("Content-Type", writer.FormDataContentType())
return req, nil
},
setupMock: func() {
expectedFile := &multipart.FileHeader{
Filename: "test.jpg",
Size: int64(len("test content")),
Header: textproto.MIMEHeader{
"Content-Type": []string{"image/jpeg"},
},
}
s.service.EXPECT().
Upload(gomock.Any(), gomock.Any(), 1).
DoAndReturn(func(_ context.Context, f *multipart.FileHeader, uid int) (*ent.Media, error) {
s.Equal(expectedFile.Filename, f.Filename)
s.Equal(expectedFile.Size, f.Size)
s.Equal(expectedFile.Header.Get("Content-Type"), f.Header.Get("Content-Type"))
return &ent.Media{ID: 1}, nil
})
},
expectedStatus: http.StatusCreated,
},
{
name: "未授权",
setupRequest: func() (*http.Request, error) {
req := httptest.NewRequest(http.MethodPost, "/media", nil)
return req, nil
},
setupMock: func() {},
expectedStatus: http.StatusUnauthorized,
expectedError: "Unauthorized",
},
{
name: "上传失败",
setupRequest: func() (*http.Request, error) {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
// 创建文件部分
fileHeader := make(textproto.MIMEHeader)
fileHeader.Set("Content-Type", "image/jpeg")
fileHeader.Set("Content-Disposition", `form-data; name="file"; filename="test.jpg"`)
part, err := writer.CreatePart(fileHeader)
if err != nil {
return nil, err
}
testContent := "test content"
_, err = io.Copy(part, strings.NewReader(testContent))
if err != nil {
return nil, err
}
writer.Close()
req := httptest.NewRequest(http.MethodPost, "/media", body)
req.Header.Set("Content-Type", writer.FormDataContentType())
return req, nil
},
setupMock: func() {
s.service.EXPECT().
Upload(gomock.Any(), gomock.Any(), 1).
Return(nil, errors.New("failed to upload"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to upload media",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// 设置 mock
tc.setupMock()
// 创建请求
req, err := tc.setupRequest()
s.NoError(err)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// 设置用户ID除了未授权的测试用例
if tc.expectedError != "Unauthorized" {
c.Set("user_id", 1)
}
// 执行请求
s.handler.UploadMedia(c)
// 验证响应
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" {
var response map[string]string
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Equal(tc.expectedError, response["error"])
} else {
var response *ent.Media
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.NotNil(response)
}
})
}
}
func (s *MediaHandlerTestSuite) TestGetMedia() {
testCases := []struct {
name string
mediaID string
setupMock func()
expectedStatus int
expectedError string
}{
{
name: "成功获取媒体",
mediaID: "1",
setupMock: func() {
media := &ent.Media{
ID: 1,
MimeType: "image/jpeg",
OriginalName: "test.jpg",
}
s.service.EXPECT().
GetMedia(gomock.Any(), 1).
Return(media, nil)
s.service.EXPECT().
GetFile(gomock.Any(), 1).
Return(io.NopCloser(strings.NewReader("test content")), &storage.FileInfo{
Size: 11,
Name: "test.jpg",
ContentType: "image/jpeg",
}, nil)
},
expectedStatus: http.StatusOK,
},
{
name: "无效的媒体ID",
mediaID: "invalid",
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid media ID",
},
{
name: "获取媒体元数据失败",
mediaID: "1",
setupMock: func() {
s.service.EXPECT().
GetMedia(gomock.Any(), 1).
Return(nil, errors.New("failed to get media"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to get media",
},
{
name: "获取媒体文件失败",
mediaID: "1",
setupMock: func() {
media := &ent.Media{
ID: 1,
MimeType: "image/jpeg",
OriginalName: "test.jpg",
}
s.service.EXPECT().
GetMedia(gomock.Any(), 1).
Return(media, nil)
s.service.EXPECT().
GetFile(gomock.Any(), 1).
Return(nil, nil, errors.New("failed to get file"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to get media file",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// 设置 mock
tc.setupMock()
// 创建请求
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/media/%s", tc.mediaID), nil)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// Extract ID from URL path
parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/")
if len(parts) >= 2 {
c.Params = []gin.Param{{Key: "id", Value: parts[1]}}
}
// 执行请求
s.handler.GetMedia(c)
// 验证响应
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" {
var response map[string]string
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Equal(tc.expectedError, response["error"])
} else {
s.Equal("image/jpeg", w.Header().Get("Content-Type"))
s.Equal("11", w.Header().Get("Content-Length"))
s.Equal("inline; filename=test.jpg", w.Header().Get("Content-Disposition"))
s.Equal("test content", w.Body.String())
}
})
}
}
func (s *MediaHandlerTestSuite) TestGetMediaFile() {
testCases := []struct {
name string
setupRequest func() (*http.Request, error)
setupMock func()
expectedStatus int
expectedBody []byte
}{
{
name: "成功获取媒体文件",
setupRequest: func() (*http.Request, error) {
return httptest.NewRequest(http.MethodGet, "/media/1/file", nil), nil
},
setupMock: func() {
fileContent := "test file content"
s.service.EXPECT().
GetFile(gomock.Any(), 1).
Return(io.NopCloser(strings.NewReader(fileContent)), &storage.FileInfo{
Name: "test.jpg",
Size: int64(len(fileContent)),
ContentType: "image/jpeg",
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []byte("test file content"),
},
{
name: "无效的媒体ID",
setupRequest: func() (*http.Request, error) {
return httptest.NewRequest(http.MethodGet, "/media/invalid/file", nil), nil
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
},
{
name: "获取媒体文件失败",
setupRequest: func() (*http.Request, error) {
return httptest.NewRequest(http.MethodGet, "/media/1/file", nil), nil
},
setupMock: func() {
s.service.EXPECT().
GetFile(gomock.Any(), 1).
Return(nil, nil, errors.New("failed to get file"))
},
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// Setup
req, err := tc.setupRequest()
s.Require().NoError(err)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// Extract ID from URL path
parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/")
if len(parts) >= 2 {
c.Params = []gin.Param{{Key: "id", Value: parts[1]}}
}
// Setup mock
tc.setupMock()
// Test
s.handler.GetMediaFile(c)
// Verify
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedBody != nil {
s.Equal(tc.expectedBody, w.Body.Bytes())
s.Equal("image/jpeg", w.Header().Get("Content-Type"))
s.Equal(fmt.Sprintf("%d", len(tc.expectedBody)), w.Header().Get("Content-Length"))
s.Equal("inline; filename=test.jpg", w.Header().Get("Content-Disposition"))
}
})
}
}
func (s *MediaHandlerTestSuite) TestDeleteMedia() {
testCases := []struct {
name string
mediaID string
setupMock func()
expectedStatus int
expectedError string
}{
{
name: "成功删除媒体",
mediaID: "1",
setupMock: func() {
s.service.EXPECT().
DeleteMedia(gomock.Any(), 1, 1).
Return(nil)
},
expectedStatus: http.StatusNoContent,
},
{
name: "未授权",
mediaID: "1",
setupMock: func() {},
expectedStatus: http.StatusUnauthorized,
expectedError: "Unauthorized",
},
{
name: "无效的媒体ID",
mediaID: "invalid",
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid media ID",
},
{
name: "删除媒体失败",
mediaID: "1",
setupMock: func() {
s.service.EXPECT().
DeleteMedia(gomock.Any(), 1, 1).
Return(errors.New("failed to delete"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to delete media",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// 设置 mock
tc.setupMock()
// 创建请求
req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/media/%s", tc.mediaID), nil)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// Extract ID from URL path
parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/")
if len(parts) >= 2 {
c.Params = []gin.Param{{Key: "id", Value: parts[1]}}
}
// 设置用户ID除了未授权的测试用例
if tc.expectedError != "Unauthorized" {
c.Set("user_id", 1)
}
// 执行请求
s.handler.DeleteMedia(c)
// 验证响应
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" {
var response map[string]string
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Equal(tc.expectedError, response["error"])
}
})
}
}

View file

@ -0,0 +1,611 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"tss-rocks-be/ent"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/service/mock"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
"errors"
"strings"
)
type PostHandlerTestSuite struct {
suite.Suite
ctrl *gomock.Controller
service *mock.MockService
handler *Handler
router *gin.Engine
}
func (s *PostHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl)
cfg := &config.Config{}
s.handler = NewHandler(cfg, s.service)
// Setup Gin router
gin.SetMode(gin.TestMode)
s.router = gin.New()
s.handler.RegisterRoutes(s.router)
}
func (s *PostHandlerTestSuite) TearDownTest() {
s.ctrl.Finish()
}
func TestPostHandlerSuite(t *testing.T) {
suite.Run(t, new(PostHandlerTestSuite))
}
// Test cases for ListPosts
func (s *PostHandlerTestSuite) TestListPosts() {
categoryID := 1
testCases := []struct {
name string
langCode string
categoryID string
limit string
offset string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success with default language",
langCode: "",
setupMock: func() {
s.service.EXPECT().
ListPosts(gomock.Any(), "en", nil, 10, 0).
Return([]*ent.Post{
{
ID: 1,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Post{
{
ID: 1,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
},
},
},
},
},
},
{
name: "Success with specific language",
langCode: "zh",
setupMock: func() {
s.service.EXPECT().
ListPosts(gomock.Any(), "zh", nil, 10, 0).
Return([]*ent.Post{
{
ID: 1,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "zh",
Title: "测试帖子",
ContentMarkdown: "测试内容",
Summary: "测试摘要",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Post{
{
ID: 1,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "zh",
Title: "测试帖子",
ContentMarkdown: "测试内容",
Summary: "测试摘要",
},
},
},
},
},
},
{
name: "Success with category filter",
langCode: "en",
categoryID: "1",
setupMock: func() {
s.service.EXPECT().
ListPosts(gomock.Any(), "en", &categoryID, 10, 0).
Return([]*ent.Post{
{
ID: 1,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Post{
{
ID: 1,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
},
},
},
},
},
},
{
name: "Success with pagination",
langCode: "en",
limit: "2",
offset: "1",
setupMock: func() {
s.service.EXPECT().
ListPosts(gomock.Any(), "en", nil, 2, 1).
Return([]*ent.Post{
{
ID: 2,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "en",
Title: "Test Post 2",
ContentMarkdown: "Test Content 2",
Summary: "Test Summary 2",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Post{
{
ID: 2,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "en",
Title: "Test Post 2",
ContentMarkdown: "Test Content 2",
Summary: "Test Summary 2",
},
},
},
},
},
},
{
name: "Service Error",
langCode: "en",
setupMock: func() {
s.service.EXPECT().
ListPosts(gomock.Any(), "en", nil, 10, 0).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// Setup mock
tc.setupMock()
// Create request
url := "/api/v1/posts"
if tc.langCode != "" {
url += "?lang=" + tc.langCode
}
if tc.categoryID != "" {
if strings.Contains(url, "?") {
url += "&"
} else {
url += "?"
}
url += "category_id=" + tc.categoryID
}
if tc.limit != "" {
if strings.Contains(url, "?") {
url += "&"
} else {
url += "?"
}
url += "limit=" + tc.limit
}
if tc.offset != "" {
if strings.Contains(url, "?") {
url += "&"
} else {
url += "?"
}
url += "offset=" + tc.offset
}
req := httptest.NewRequest(http.MethodGet, url, nil)
w := httptest.NewRecorder()
// Perform request
s.router.ServeHTTP(w, req)
// Assert response
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedBody != nil {
var response []*ent.Post
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Equal(tc.expectedBody, response)
}
})
}
}
// Test cases for GetPost
func (s *PostHandlerTestSuite) TestGetPost() {
testCases := []struct {
name string
langCode string
slug string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success with default language",
langCode: "",
slug: "test-post",
setupMock: func() {
s.service.EXPECT().
GetPostBySlug(gomock.Any(), "en", "test-post").
Return(&ent.Post{
ID: 1,
Status: "published",
Slug: "test-post",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: gin.H{
"id": 1,
"status": "published",
"slug": "test-post",
"edges": gin.H{
"contents": []gin.H{
{
"language_code": "en",
"title": "Test Post",
"content_markdown": "Test Content",
"summary": "Test Summary",
},
},
},
},
},
{
name: "Success with specific language",
langCode: "zh",
slug: "test-post",
setupMock: func() {
s.service.EXPECT().
GetPostBySlug(gomock.Any(), "zh", "test-post").
Return(&ent.Post{
ID: 1,
Status: "published",
Slug: "test-post",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "zh",
Title: "测试帖子",
ContentMarkdown: "测试内容",
Summary: "测试摘要",
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: gin.H{
"id": 1,
"status": "published",
"slug": "test-post",
"edges": gin.H{
"contents": []gin.H{
{
"language_code": "zh",
"title": "测试帖子",
"content_markdown": "测试内容",
"summary": "测试摘要",
},
},
},
},
},
{
name: "Service error",
slug: "test-post",
setupMock: func() {
s.service.EXPECT().
GetPostBySlug(gomock.Any(), "en", "test-post").
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to get post"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
url := "/api/v1/posts/" + tc.slug
if tc.langCode != "" {
url += "?lang=" + tc.langCode
}
req := httptest.NewRequest(http.MethodGet, url, nil)
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
// Test cases for CreatePost
func (s *PostHandlerTestSuite) TestCreatePost() {
testCases := []struct {
name string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
setupMock: func() {
s.service.EXPECT().
CreatePost(gomock.Any(), "draft").
Return(&ent.Post{
ID: 1,
Status: "draft",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{},
},
}, nil)
},
expectedStatus: http.StatusCreated,
expectedBody: gin.H{
"id": 1,
"status": "draft",
"edges": gin.H{
"contents": []gin.H{},
},
},
},
{
name: "Service error",
setupMock: func() {
s.service.EXPECT().
CreatePost(gomock.Any(), "draft").
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to create post"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
req := httptest.NewRequest(http.MethodPost, "/api/v1/posts", nil)
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
// Test cases for AddPostContent
func (s *PostHandlerTestSuite) TestAddPostContent() {
testCases := []struct {
name string
postID string
body interface{}
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
postID: "1",
body: AddPostContentRequest{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
MetaKeywords: "test,keywords",
MetaDescription: "Test meta description",
},
setupMock: func() {
s.service.EXPECT().
AddPostContent(
gomock.Any(),
1,
"en",
"Test Post",
"Test Content",
"Test Summary",
"test,keywords",
"Test meta description",
).
Return(&ent.PostContent{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
MetaKeywords: "test,keywords",
MetaDescription: "Test meta description",
Edges: ent.PostContentEdges{},
}, nil)
},
expectedStatus: http.StatusCreated,
expectedBody: gin.H{
"language_code": "en",
"title": "Test Post",
"content_markdown": "Test Content",
"summary": "Test Summary",
"meta_keywords": "test,keywords",
"meta_description": "Test meta description",
"edges": gin.H{},
},
},
{
name: "Invalid post ID",
postID: "invalid",
body: AddPostContentRequest{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Invalid post ID"},
},
{
name: "Invalid request body",
postID: "1",
body: map[string]interface{}{
"language_code": "en",
// Missing required fields
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Key: 'AddPostContentRequest.Title' Error:Field validation for 'Title' failed on the 'required' tag\nKey: 'AddPostContentRequest.ContentMarkdown' Error:Field validation for 'ContentMarkdown' failed on the 'required' tag\nKey: 'AddPostContentRequest.Summary' Error:Field validation for 'Summary' failed on the 'required' tag"},
},
{
name: "Service error",
postID: "1",
body: AddPostContentRequest{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
MetaKeywords: "test,keywords",
MetaDescription: "Test meta description",
},
setupMock: func() {
s.service.EXPECT().
AddPostContent(
gomock.Any(),
1,
"en",
"Test Post",
"Test Content",
"Test Summary",
"test,keywords",
"Test meta description",
).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to add post content"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
body, err := json.Marshal(tc.body)
s.NoError(err, "Failed to marshal request body")
req := httptest.NewRequest(http.MethodPost, "/api/v1/posts/"+tc.postID+"/contents", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}

View file

@ -0,0 +1,192 @@
package middleware
import (
"fmt"
"os"
"path/filepath"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/rs/zerolog"
"gopkg.in/natefinch/lumberjack.v2"
"tss-rocks-be/internal/types"
)
// AccessLogConfig 访问日志配置
type AccessLogConfig struct {
// 是否启用控制台输出
EnableConsole bool `yaml:"enable_console"`
// 是否启用文件日志
EnableFile bool `yaml:"enable_file"`
// 日志文件路径
FilePath string `yaml:"file_path"`
// 日志格式 (json 或 text)
Format string `yaml:"format"`
// 日志级别
Level string `yaml:"level"`
// 日志轮转配置
Rotation struct {
MaxSize int `yaml:"max_size"` // 每个日志文件的最大大小MB
MaxAge int `yaml:"max_age"` // 保留旧日志文件的最大天数
MaxBackups int `yaml:"max_backups"` // 保留的旧日志文件的最大数量
Compress bool `yaml:"compress"` // 是否压缩旧日志文件
LocalTime bool `yaml:"local_time"` // 使用本地时间作为轮转时间
} `yaml:"rotation"`
}
// accessLogger 访问日志记录器
type accessLogger struct {
consoleLogger *zerolog.Logger
fileLogger *zerolog.Logger
logWriter *lumberjack.Logger
config *types.AccessLogConfig
}
// Close 关闭日志文件
func (l *accessLogger) Close() error {
if l.logWriter != nil {
return l.logWriter.Close()
}
return nil
}
// newAccessLogger 创建新的访问日志记录器
func newAccessLogger(config *types.AccessLogConfig) (*accessLogger, error) {
var consoleLogger, fileLogger *zerolog.Logger
var logWriter *lumberjack.Logger
// 设置日志级别
level, err := zerolog.ParseLevel(config.Level)
if err != nil {
level = zerolog.InfoLevel
}
zerolog.SetGlobalLevel(level)
// 配置控制台日志
if config.EnableConsole {
logger := zerolog.New(os.Stdout).
With().
Timestamp().
Logger()
if config.Format == "text" {
logger = logger.Output(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339})
}
consoleLogger = &logger
}
// 配置文件日志
if config.EnableFile {
// 确保日志目录存在
if err := os.MkdirAll(filepath.Dir(config.FilePath), 0755); err != nil {
return nil, fmt.Errorf("failed to create log directory: %w", err)
}
// 配置日志轮转
logWriter = &lumberjack.Logger{
Filename: config.FilePath,
MaxSize: config.Rotation.MaxSize, // MB
MaxAge: config.Rotation.MaxAge, // days
MaxBackups: config.Rotation.MaxBackups, // files
Compress: config.Rotation.Compress, // 是否压缩
LocalTime: config.Rotation.LocalTime, // 使用本地时间
}
logger := zerolog.New(logWriter).
With().
Timestamp().
Logger()
fileLogger = &logger
}
return &accessLogger{
consoleLogger: consoleLogger,
fileLogger: fileLogger,
logWriter: logWriter,
config: config,
}, nil
}
// logEvent 记录日志事件
func (l *accessLogger) logEvent(fields map[string]interface{}, msg string) {
if l.consoleLogger != nil {
event := l.consoleLogger.Info()
for k, v := range fields {
event = event.Interface(k, v)
}
event.Msg(msg)
}
if l.fileLogger != nil {
event := l.fileLogger.Info()
for k, v := range fields {
event = event.Interface(k, v)
}
event.Msg(msg)
}
}
// AccessLog 创建访问日志中间件
func AccessLog(config *types.AccessLogConfig) (gin.HandlerFunc, error) {
logger, err := newAccessLogger(config)
if err != nil {
return nil, err
}
return func(c *gin.Context) {
// 用于测试时关闭日志文件
if c == nil {
if err := logger.Close(); err != nil {
fmt.Printf("Error closing log file: %v\n", err)
}
return
}
start := time.Now()
requestID := uuid.New().String()
path := c.Request.URL.Path
query := c.Request.URL.RawQuery
// 设置请求ID到上下文
c.Set("request_id", requestID)
// 处理请求
c.Next()
// 计算处理时间
latency := time.Since(start)
// 获取用户ID如果已认证
var userID interface{}
if id, exists := c.Get("user_id"); exists {
userID = id
}
// 准备日志字段
fields := map[string]interface{}{
"request_id": requestID,
"method": c.Request.Method,
"path": path,
"query": query,
"ip": c.ClientIP(),
"user_agent": c.Request.UserAgent(),
"status": c.Writer.Status(),
"size": c.Writer.Size(),
"latency_ms": latency.Milliseconds(),
"component": "access_log",
}
if userID != nil {
fields["user_id"] = userID
}
// 如果有错误,添加到日志中
if len(c.Errors) > 0 {
fields["error"] = c.Errors.String()
}
// 记录日志
logger.logEvent(fields, fmt.Sprintf("%s %s", c.Request.Method, path))
}, nil
}

View file

@ -0,0 +1,238 @@
package middleware
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"tss-rocks-be/internal/types"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestAccessLog(t *testing.T) {
// 设置测试临时目录
tmpDir := t.TempDir()
logPath := filepath.Join(tmpDir, "test.log")
testCases := []struct {
name string
config *types.AccessLogConfig
expectedError bool
setupRequest func(*http.Request)
validateOutput func(*testing.T, *httptest.ResponseRecorder, string)
}{
{
name: "Console logging only",
config: &types.AccessLogConfig{
EnableConsole: true,
EnableFile: false,
Format: "json",
Level: "info",
},
expectedError: false,
setupRequest: func(req *http.Request) {
req.Header.Set("User-Agent", "test-agent")
},
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, logOutput, "GET /test")
assert.Contains(t, logOutput, "test-agent")
},
},
{
name: "File logging only",
config: &types.AccessLogConfig{
EnableConsole: false,
EnableFile: true,
FilePath: logPath,
Format: "json",
Level: "info",
Rotation: struct {
MaxSize int `yaml:"max_size"`
MaxAge int `yaml:"max_age"`
MaxBackups int `yaml:"max_backups"`
Compress bool `yaml:"compress"`
LocalTime bool `yaml:"local_time"`
}{
MaxSize: 1,
MaxAge: 1,
MaxBackups: 1,
Compress: false,
LocalTime: true,
},
},
expectedError: false,
setupRequest: func(req *http.Request) {
req.Header.Set("User-Agent", "test-agent")
},
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
assert.Equal(t, http.StatusOK, w.Code)
// 读取日志文件内容
content, err := os.ReadFile(logPath)
assert.NoError(t, err)
assert.Contains(t, string(content), "GET /test")
assert.Contains(t, string(content), "test-agent")
},
},
{
name: "Both console and file logging",
config: &types.AccessLogConfig{
EnableConsole: true,
EnableFile: true,
FilePath: logPath,
Format: "json",
Level: "info",
Rotation: struct {
MaxSize int `yaml:"max_size"`
MaxAge int `yaml:"max_age"`
MaxBackups int `yaml:"max_backups"`
Compress bool `yaml:"compress"`
LocalTime bool `yaml:"local_time"`
}{
MaxSize: 1,
MaxAge: 1,
MaxBackups: 1,
Compress: false,
LocalTime: true,
},
},
expectedError: false,
setupRequest: func(req *http.Request) {
req.Header.Set("User-Agent", "test-agent")
},
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, logOutput, "GET /test")
assert.Contains(t, logOutput, "test-agent")
// 读取日志文件内容
content, err := os.ReadFile(logPath)
assert.NoError(t, err)
assert.Contains(t, string(content), "GET /test")
assert.Contains(t, string(content), "test-agent")
},
},
{
name: "With authenticated user",
config: &types.AccessLogConfig{
EnableConsole: true,
EnableFile: false,
Format: "json",
Level: "info",
},
expectedError: false,
setupRequest: func(req *http.Request) {
req.Header.Set("User-Agent", "test-agent")
},
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, logOutput, "GET /test")
assert.Contains(t, logOutput, "test-agent")
assert.Contains(t, logOutput, "test-user")
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// 捕获标准输出
oldStdout := os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w
// 创建一个新的 gin 引擎
gin.SetMode(gin.TestMode)
router := gin.New()
// 创建访问日志中间件
middleware, err := AccessLog(tc.config)
if tc.expectedError {
assert.Error(t, err)
return
}
assert.NoError(t, err)
// 添加测试路由
router.Use(middleware)
router.GET("/test", func(c *gin.Context) {
// 如果是测试认证用户的情况设置用户ID
if tc.name == "With authenticated user" {
c.Set("user_id", "test-user")
}
c.Status(http.StatusOK)
})
// 创建测试请求
req := httptest.NewRequest("GET", "/test", nil)
if tc.setupRequest != nil {
tc.setupRequest(req)
}
rec := httptest.NewRecorder()
// 执行请求
router.ServeHTTP(rec, req)
// 恢复标准输出并获取输出内容
w.Close()
var buf bytes.Buffer
io.Copy(&buf, r)
os.Stdout = oldStdout
// 验证输出
if tc.validateOutput != nil {
tc.validateOutput(t, rec, buf.String())
}
// 关闭日志文件
if tc.config.EnableFile {
// 调用中间件函数来关闭日志文件
middleware(nil)
// 等待一小段时间确保文件完全关闭
time.Sleep(100 * time.Millisecond)
}
})
}
}
func TestAccessLogInvalidConfig(t *testing.T) {
testCases := []struct {
name string
config *types.AccessLogConfig
expectedError bool
}{
{
name: "Invalid log level",
config: &types.AccessLogConfig{
EnableConsole: true,
Level: "invalid_level",
},
expectedError: false, // 应该使用默认的 info 级别
},
{
name: "Invalid file path",
config: &types.AccessLogConfig{
EnableFile: true,
FilePath: "/dev/null/nonexistent/test.log", // 在所有操作系统上都无效的路径
},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := AccessLog(tc.config)
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View file

@ -0,0 +1,82 @@
package middleware
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/rs/zerolog/log"
)
// AuthMiddleware creates a middleware for JWT authentication
func AuthMiddleware(jwtSecret string) gin.HandlerFunc {
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header is required"})
c.Abort()
return
}
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header format must be Bearer {token}"})
c.Abort()
return
}
token, err := jwt.Parse(parts[1], func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, jwt.ErrSignatureInvalid
}
return []byte(jwtSecret), nil
})
if err != nil {
log.Error().Err(err).Msg("Failed to parse token")
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
c.Abort()
return
}
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
c.Set("user_id", claims["sub"])
c.Set("user_role", claims["role"])
c.Next()
} else {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
c.Abort()
return
}
}
}
// RoleMiddleware creates a middleware for role-based authorization
func RoleMiddleware(roles ...string) gin.HandlerFunc {
return func(c *gin.Context) {
userRole, exists := c.Get("user_role")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "User role not found"})
c.Abort()
return
}
roleStr, ok := userRole.(string)
if !ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user role type"})
c.Abort()
return
}
for _, role := range roles {
if role == roleStr {
c.Next()
return
}
}
c.JSON(http.StatusForbidden, gin.H{"error": "Insufficient permissions"})
c.Abort()
}
}

View file

@ -0,0 +1,217 @@
package middleware
import (
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func createTestToken(secret string, claims jwt.MapClaims) string {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signedToken, _ := token.SignedString([]byte(secret))
return signedToken
}
func TestAuthMiddleware(t *testing.T) {
jwtSecret := "test-secret"
testCases := []struct {
name string
setupAuth func(*http.Request)
expectedStatus int
expectedBody map[string]string
checkUserData bool
expectedUserID string
expectedRole string
}{
{
name: "No Authorization header",
setupAuth: func(req *http.Request) {},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "Authorization header is required"},
},
{
name: "Invalid Authorization format",
setupAuth: func(req *http.Request) {
req.Header.Set("Authorization", "InvalidFormat")
},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "Authorization header format must be Bearer {token}"},
},
{
name: "Invalid token",
setupAuth: func(req *http.Request) {
req.Header.Set("Authorization", "Bearer invalid.token.here")
},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "Invalid token"},
},
{
name: "Valid token",
setupAuth: func(req *http.Request) {
claims := jwt.MapClaims{
"sub": "user123",
"role": "user",
"exp": time.Now().Add(time.Hour).Unix(),
}
token := createTestToken(jwtSecret, claims)
req.Header.Set("Authorization", "Bearer "+token)
},
expectedStatus: http.StatusOK,
checkUserData: true,
expectedUserID: "user123",
expectedRole: "user",
},
{
name: "Expired token",
setupAuth: func(req *http.Request) {
claims := jwt.MapClaims{
"sub": "user123",
"role": "user",
"exp": time.Now().Add(-time.Hour).Unix(),
}
token := createTestToken(jwtSecret, claims)
req.Header.Set("Authorization", "Bearer "+token)
},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "Invalid token"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
// 添加认证中间件
router.Use(AuthMiddleware(jwtSecret))
// 测试路由
router.GET("/test", func(c *gin.Context) {
if tc.checkUserData {
userID, exists := c.Get("user_id")
assert.True(t, exists)
assert.Equal(t, tc.expectedUserID, userID)
role, exists := c.Get("user_role")
assert.True(t, exists)
assert.Equal(t, tc.expectedRole, role)
}
c.Status(http.StatusOK)
})
// 创建请求
req := httptest.NewRequest("GET", "/test", nil)
tc.setupAuth(req)
rec := httptest.NewRecorder()
// 执行请求
router.ServeHTTP(rec, req)
// 验证响应
assert.Equal(t, tc.expectedStatus, rec.Code)
if tc.expectedBody != nil {
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err)
assert.Equal(t, tc.expectedBody, response)
}
})
}
}
func TestRoleMiddleware(t *testing.T) {
testCases := []struct {
name string
setupContext func(*gin.Context)
allowedRoles []string
expectedStatus int
expectedBody map[string]string
}{
{
name: "No user role",
setupContext: func(c *gin.Context) {
// 不设置用户角色
},
allowedRoles: []string{"admin"},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "User role not found"},
},
{
name: "Invalid role type",
setupContext: func(c *gin.Context) {
c.Set("user_role", 123) // 设置错误类型的角色
},
allowedRoles: []string{"admin"},
expectedStatus: http.StatusInternalServerError,
expectedBody: map[string]string{"error": "Invalid user role type"},
},
{
name: "Insufficient permissions",
setupContext: func(c *gin.Context) {
c.Set("user_role", "user")
},
allowedRoles: []string{"admin"},
expectedStatus: http.StatusForbidden,
expectedBody: map[string]string{"error": "Insufficient permissions"},
},
{
name: "Allowed role",
setupContext: func(c *gin.Context) {
c.Set("user_role", "admin")
},
allowedRoles: []string{"admin"},
expectedStatus: http.StatusOK,
},
{
name: "One of multiple allowed roles",
setupContext: func(c *gin.Context) {
c.Set("user_role", "editor")
},
allowedRoles: []string{"admin", "editor", "moderator"},
expectedStatus: http.StatusOK,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
// 添加角色中间件
router.Use(func(c *gin.Context) {
tc.setupContext(c)
c.Next()
})
router.Use(RoleMiddleware(tc.allowedRoles...))
// 测试路由
router.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// 创建请求
req := httptest.NewRequest("GET", "/test", nil)
rec := httptest.NewRecorder()
// 执行请求
router.ServeHTTP(rec, req)
// 验证响应
assert.Equal(t, tc.expectedStatus, rec.Code)
if tc.expectedBody != nil {
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err)
assert.Equal(t, tc.expectedBody, response)
}
})
}
}

View file

@ -0,0 +1,22 @@
package middleware
import (
"github.com/gin-gonic/gin"
)
// CORS middleware
func CORS() gin.HandlerFunc {
return func(c *gin.Context) {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}
c.Next()
}
}

View file

@ -0,0 +1,76 @@
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
)
func TestCORS(t *testing.T) {
testCases := []struct {
name string
method string
expectedStatus int
checkHeaders bool
}{
{
name: "Normal GET request",
method: "GET",
expectedStatus: http.StatusOK,
checkHeaders: true,
},
{
name: "OPTIONS request",
method: "OPTIONS",
expectedStatus: http.StatusNoContent,
checkHeaders: true,
},
{
name: "POST request",
method: "POST",
expectedStatus: http.StatusOK,
checkHeaders: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// 创建一个新的 gin 引擎
gin.SetMode(gin.TestMode)
router := gin.New()
// 添加 CORS 中间件
router.Use(CORS())
// 添加测试路由
router.Any("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// 创建测试请求
req := httptest.NewRequest(tc.method, "/test", nil)
rec := httptest.NewRecorder()
// 执行请求
router.ServeHTTP(rec, req)
// 验证状态码
assert.Equal(t, tc.expectedStatus, rec.Code)
if tc.checkHeaders {
// 验证 CORS 头部
headers := rec.Header()
assert.Equal(t, "*", headers.Get("Access-Control-Allow-Origin"))
assert.Equal(t, "true", headers.Get("Access-Control-Allow-Credentials"))
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Content-Type")
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Authorization")
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "POST")
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "GET")
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "PUT")
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "DELETE")
}
})
}
}

View file

@ -0,0 +1,107 @@
package middleware
import (
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
"tss-rocks-be/internal/types"
)
// ipLimiter IP限流器
type ipLimiter struct {
limiter *rate.Limiter
lastSeen time.Time
}
// rateLimiter 限流器管理器
type rateLimiter struct {
ips map[string]*ipLimiter
mu sync.RWMutex
config *types.RateLimitConfig
routes map[string]*rate.Limiter
}
// newRateLimiter 创建新的限流器
func newRateLimiter(config *types.RateLimitConfig) *rateLimiter {
// 初始化路由限流器
routes := make(map[string]*rate.Limiter)
for path, cfg := range config.RouteRates {
routes[path] = rate.NewLimiter(rate.Limit(cfg.Rate), cfg.Burst)
}
rl := &rateLimiter{
ips: make(map[string]*ipLimiter),
config: config,
routes: routes,
}
// 启动清理过期IP限流器的goroutine
go rl.cleanupIPLimiters()
return rl
}
// cleanupIPLimiters 清理过期的IP限流器
func (rl *rateLimiter) cleanupIPLimiters() {
for {
time.Sleep(time.Hour) // 每小时清理一次
rl.mu.Lock()
for ip, limiter := range rl.ips {
if time.Since(limiter.lastSeen) > time.Hour {
delete(rl.ips, ip)
}
}
rl.mu.Unlock()
}
}
// getLimiter 获取IP限流器
func (rl *rateLimiter) getLimiter(ip string) *rate.Limiter {
rl.mu.Lock()
defer rl.mu.Unlock()
v, exists := rl.ips[ip]
if !exists {
limiter := rate.NewLimiter(rate.Limit(rl.config.IPRate), rl.config.IPBurst)
rl.ips[ip] = &ipLimiter{limiter: limiter, lastSeen: time.Now()}
return limiter
}
v.lastSeen = time.Now()
return v.limiter
}
// RateLimit 创建限流中间件
func RateLimit(config *types.RateLimitConfig) gin.HandlerFunc {
rl := newRateLimiter(config)
return func(c *gin.Context) {
// 检查路由限流
path := c.Request.URL.Path
if limiter, ok := rl.routes[path]; ok {
if !limiter.Allow() {
c.JSON(http.StatusTooManyRequests, gin.H{
"error": "too many requests for this route",
})
c.Abort()
return
}
}
// 检查IP限流
limiter := rl.getLimiter(c.ClientIP())
if !limiter.Allow() {
c.JSON(http.StatusTooManyRequests, gin.H{
"error": "too many requests from this IP",
})
c.Abort()
return
}
c.Next()
}
}

View file

@ -0,0 +1,207 @@
package middleware
import (
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
"time"
"tss-rocks-be/internal/types"
)
func TestRateLimit(t *testing.T) {
testCases := []struct {
name string
config *types.RateLimitConfig
setupTest func(*gin.Engine)
runTest func(*testing.T, *gin.Engine)
expectedStatus int
expectedBody map[string]string
}{
{
name: "IP rate limit",
config: &types.RateLimitConfig{
IPRate: 1, // 每秒1个请求
IPBurst: 1,
},
setupTest: func(router *gin.Engine) {
router.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
},
runTest: func(t *testing.T, router *gin.Engine) {
// 第一个请求应该成功
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:1234"
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// 第二个请求应该被限制
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err)
assert.Equal(t, "too many requests from this IP", response["error"])
// 等待限流器重置
time.Sleep(time.Second)
// 第三个请求应该成功
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
},
},
{
name: "Route rate limit",
config: &types.RateLimitConfig{
IPRate: 100, // 设置较高的 IP 限流,以便测试路由限流
IPBurst: 10,
RouteRates: map[string]struct {
Rate int `yaml:"rate"`
Burst int `yaml:"burst"`
}{
"/limited": {
Rate: 1,
Burst: 1,
},
},
},
setupTest: func(router *gin.Engine) {
router.GET("/limited", func(c *gin.Context) {
c.Status(http.StatusOK)
})
router.GET("/unlimited", func(c *gin.Context) {
c.Status(http.StatusOK)
})
},
runTest: func(t *testing.T, router *gin.Engine) {
// 测试限流路由
req := httptest.NewRequest("GET", "/limited", nil)
req.RemoteAddr = "192.168.1.2:1234"
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// 等待一小段时间确保限流器生效
time.Sleep(10 * time.Millisecond)
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err)
assert.Equal(t, "too many requests for this route", response["error"])
// 测试未限流路由
req = httptest.NewRequest("GET", "/unlimited", nil)
req.RemoteAddr = "192.168.1.2:1234"
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// 等待一小段时间确保限流器生效
time.Sleep(10 * time.Millisecond)
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
},
},
{
name: "Multiple IPs",
config: &types.RateLimitConfig{
IPRate: 1,
IPBurst: 1,
},
setupTest: func(router *gin.Engine) {
router.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
},
runTest: func(t *testing.T, router *gin.Engine) {
// IP1 的请求
req1 := httptest.NewRequest("GET", "/test", nil)
req1.RemoteAddr = "192.168.1.3:1234"
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req1)
assert.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req1)
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
// IP2 的请求应该不受 IP1 的限制影响
req2 := httptest.NewRequest("GET", "/test", nil)
req2.RemoteAddr = "192.168.1.4:1234"
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req2)
assert.Equal(t, http.StatusOK, rec.Code)
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
// 添加限流中间件
router.Use(RateLimit(tc.config))
// 设置测试路由
tc.setupTest(router)
// 运行测试
tc.runTest(t, router)
})
}
}
func TestRateLimiterCleanup(t *testing.T) {
config := &types.RateLimitConfig{
IPRate: 1,
IPBurst: 1,
}
rl := newRateLimiter(config)
// 添加一些IP限流器
ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"}
for _, ip := range ips {
rl.getLimiter(ip)
}
// 验证IP限流器已创建
rl.mu.RLock()
assert.Equal(t, len(ips), len(rl.ips))
rl.mu.RUnlock()
// 修改一些IP的最后访问时间为1小时前
rl.mu.Lock()
rl.ips["192.168.1.1"].lastSeen = time.Now().Add(-2 * time.Hour)
rl.ips["192.168.1.2"].lastSeen = time.Now().Add(-2 * time.Hour)
rl.mu.Unlock()
// 手动触发清理
rl.mu.Lock()
for ip, limiter := range rl.ips {
if time.Since(limiter.lastSeen) > time.Hour {
delete(rl.ips, ip)
}
}
rl.mu.Unlock()
// 验证过期的IP限流器已被删除
rl.mu.RLock()
assert.Equal(t, 1, len(rl.ips))
_, exists := rl.ips["192.168.1.3"]
assert.True(t, exists)
rl.mu.RUnlock()
}

View file

@ -0,0 +1,110 @@
package middleware
import (
"context"
"fmt"
"net/http"
"tss-rocks-be/ent"
"tss-rocks-be/ent/user"
"tss-rocks-be/internal/auth"
"github.com/gin-gonic/gin"
)
// RequirePermission creates a middleware that checks if the user has the required permission
func RequirePermission(client *ent.Client, resource, action string) gin.HandlerFunc {
return func(c *gin.Context) {
// Get user from context
userID, exists := c.Get(auth.UserIDKey)
if !exists {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "Unauthorized",
})
return
}
// Get user with roles
user, err := client.User.Query().
Where(user.ID(userID.(int))).
WithRoles(func(q *ent.RoleQuery) {
q.WithPermissions()
}).
Only(context.Background())
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "User not found",
})
return
}
// Check if user has the required permission through any of their roles
hasPermission := false
for _, r := range user.Edges.Roles {
for _, p := range r.Edges.Permissions {
if p.Resource == resource && p.Action == action {
hasPermission = true
break
}
}
if hasPermission {
break
}
}
if !hasPermission {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"error": fmt.Sprintf("Missing required permission: %s:%s", resource, action),
})
return
}
c.Next()
}
}
// RequireRole creates a middleware that checks if the user has the required role
func RequireRole(client *ent.Client, roleName string) gin.HandlerFunc {
return func(c *gin.Context) {
// Get user from context
userID, exists := c.Get(auth.UserIDKey)
if !exists {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "Unauthorized",
})
return
}
// Get user with roles
user, err := client.User.Query().
Where(user.ID(userID.(int))).
WithRoles().
Only(context.Background())
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "User not found",
})
return
}
// Check if user has the required role
hasRole := false
for _, r := range user.Edges.Roles {
if r.Name == roleName {
hasRole = true
break
}
}
if !hasRole {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"error": fmt.Sprintf("Required role: %s", roleName),
})
return
}
c.Next()
}
}

View file

@ -0,0 +1,159 @@
package middleware
import (
"bytes"
"fmt"
"io"
"net/http"
"path/filepath"
"strings"
"github.com/gin-gonic/gin"
"tss-rocks-be/internal/types"
)
const (
defaultMaxMemory = 32 << 20 // 32 MB
maxHeaderBytes = 512 // 用于MIME类型检测的最大字节数
)
// ValidateUpload 创建文件上传验证中间件
func ValidateUpload(cfg *types.UploadConfig) gin.HandlerFunc {
return func(c *gin.Context) {
// 检查是否是multipart/form-data请求
if !strings.HasPrefix(c.GetHeader("Content-Type"), "multipart/form-data") {
c.JSON(http.StatusBadRequest, gin.H{
"error": "Content-Type must be multipart/form-data",
})
c.Abort()
return
}
// 解析multipart表单
if err := c.Request.ParseMultipartForm(defaultMaxMemory); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("Failed to parse form: %v", err),
})
c.Abort()
return
}
form := c.Request.MultipartForm
if form == nil || form.File == nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "No file uploaded",
})
c.Abort()
return
}
// 遍历所有上传的文件
for _, files := range form.File {
for _, file := range files {
// 检查文件大小
if file.Size > int64(cfg.MaxSize)<<20 { // 转换为字节
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("File %s exceeds maximum size of %d MB", file.Filename, cfg.MaxSize),
})
c.Abort()
return
}
// 检查文件扩展名
ext := strings.ToLower(filepath.Ext(file.Filename))
if !contains(cfg.AllowedExtensions, ext) {
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("File extension %s is not allowed", ext),
})
c.Abort()
return
}
// 打开文件
src, err := file.Open()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("Failed to open file: %v", err),
})
c.Abort()
return
}
defer src.Close()
// 读取文件头部用于MIME类型检测
header := make([]byte, maxHeaderBytes)
n, err := src.Read(header)
if err != nil && err != io.EOF {
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("Failed to read file: %v", err),
})
c.Abort()
return
}
header = header[:n]
// 检测MIME类型
contentType := http.DetectContentType(header)
if !contains(cfg.AllowedTypes, contentType) {
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("File type %s is not allowed", contentType),
})
c.Abort()
return
}
// 将文件指针重置到开始位置
_, err = src.Seek(0, 0)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("Failed to read file: %v", err),
})
c.Abort()
return
}
// 将文件内容读入缓冲区
buf := &bytes.Buffer{}
_, err = io.Copy(buf, src)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("Failed to read file: %v", err),
})
c.Abort()
return
}
// 将验证过的文件内容和类型保存到上下文中
c.Set("validated_file_"+file.Filename, buf)
c.Set("validated_content_type_"+file.Filename, contentType)
}
}
c.Next()
}
}
// contains 检查切片中是否包含指定的字符串
func contains(slice []string, str string) bool {
for _, s := range slice {
if s == str {
return true
}
}
return false
}
// GetValidatedFile 从上下文中获取验证过的文件内容
func GetValidatedFile(c *gin.Context, filename string) (*bytes.Buffer, string, bool) {
file, exists := c.Get("validated_file_" + filename)
if !exists {
return nil, "", false
}
contentType, exists := c.Get("validated_content_type_" + filename)
if !exists {
return nil, "", false
}
return file.(*bytes.Buffer), contentType.(string), true
}

View file

@ -0,0 +1,262 @@
package middleware
import (
"bytes"
"encoding/json"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"strings"
"testing"
"tss-rocks-be/internal/types"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func createMultipartRequest(t *testing.T, filename string, content []byte, contentType string) (*http.Request, error) {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
part, err := writer.CreateFormFile("file", filename)
if err != nil {
return nil, err
}
_, err = io.Copy(part, bytes.NewReader(content))
if err != nil {
return nil, err
}
err = writer.Close()
if err != nil {
return nil, err
}
req := httptest.NewRequest("POST", "/upload", body)
req.Header.Set("Content-Type", writer.FormDataContentType())
return req, nil
}
func TestValidateUpload(t *testing.T) {
tests := []struct {
name string
config *types.UploadConfig
filename string
content []byte
setupRequest func(*testing.T) *http.Request
expectedStatus int
expectedError string
}{
{
name: "Valid image upload",
config: &types.UploadConfig{
MaxSize: 5, // 5MB
AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
AllowedTypes: []string{"image/jpeg", "image/png"},
},
filename: "test.jpg",
content: []byte{
0xFF, 0xD8, 0xFF, 0xE0, // JPEG magic numbers
0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00,
},
expectedStatus: http.StatusOK,
},
{
name: "Invalid file extension",
config: &types.UploadConfig{
MaxSize: 5,
AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
AllowedTypes: []string{"image/jpeg", "image/png"},
},
filename: "test.txt",
content: []byte("test content"),
expectedStatus: http.StatusBadRequest,
expectedError: "File extension .txt is not allowed",
},
{
name: "File too large",
config: &types.UploadConfig{
MaxSize: 1, // 1MB
AllowedExtensions: []string{".jpg"},
AllowedTypes: []string{"image/jpeg"},
},
filename: "large.jpg",
content: make([]byte, 2<<20), // 2MB
expectedStatus: http.StatusBadRequest,
expectedError: "File large.jpg exceeds maximum size of 1 MB",
},
{
name: "Invalid content type",
config: &types.UploadConfig{
MaxSize: 5,
AllowedExtensions: []string{".jpg"},
AllowedTypes: []string{"image/jpeg"},
},
filename: "fake.jpg",
content: []byte("not a real image"),
expectedStatus: http.StatusBadRequest,
expectedError: "File type text/plain; charset=utf-8 is not allowed",
},
{
name: "Missing file",
config: &types.UploadConfig{
MaxSize: 5,
AllowedExtensions: []string{".jpg"},
AllowedTypes: []string{"image/jpeg"},
},
setupRequest: func(t *testing.T) *http.Request {
req := httptest.NewRequest("POST", "/upload", strings.NewReader(""))
req.Header.Set("Content-Type", "multipart/form-data")
return req
},
expectedStatus: http.StatusBadRequest,
expectedError: "Failed to parse form",
},
{
name: "Invalid content type header",
config: &types.UploadConfig{
MaxSize: 5,
AllowedExtensions: []string{".jpg"},
AllowedTypes: []string{"image/jpeg"},
},
setupRequest: func(t *testing.T) *http.Request {
return httptest.NewRequest("POST", "/upload", nil)
},
expectedStatus: http.StatusBadRequest,
expectedError: "Content-Type must be multipart/form-data",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
var req *http.Request
var err error
if tt.setupRequest != nil {
req = tt.setupRequest(t)
} else {
req, err = createMultipartRequest(t, tt.filename, tt.content, "")
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
}
c.Request = req
middleware := ValidateUpload(tt.config)
middleware(c)
assert.Equal(t, tt.expectedStatus, w.Code)
if tt.expectedError != "" {
var response map[string]string
err := json.NewDecoder(w.Body).Decode(&response)
assert.NoError(t, err)
assert.Contains(t, response["error"], tt.expectedError)
}
})
}
}
func TestGetValidatedFile(t *testing.T) {
tests := []struct {
name string
setupContext func(*gin.Context)
filename string
expectedFound bool
expectedError string
}{
{
name: "Get existing file",
setupContext: func(c *gin.Context) {
// 创建测试文件内容
content := []byte("test content")
buf := bytes.NewBuffer(content)
// 设置验证过的文件和内容类型
c.Set("validated_file_test.txt", buf)
c.Set("validated_content_type_test.txt", "text/plain")
},
filename: "test.txt",
expectedFound: true,
},
{
name: "File not found",
setupContext: func(c *gin.Context) {
// 不设置任何文件
},
filename: "nonexistent.txt",
expectedFound: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
if tt.setupContext != nil {
tt.setupContext(c)
}
buffer, contentType, found := GetValidatedFile(c, tt.filename)
assert.Equal(t, tt.expectedFound, found)
if tt.expectedFound {
assert.NotNil(t, buffer)
assert.NotEmpty(t, contentType)
} else {
assert.Nil(t, buffer)
assert.Empty(t, contentType)
}
})
}
}
func TestContains(t *testing.T) {
tests := []struct {
name string
slice []string
str string
expected bool
}{
{
name: "String found in slice",
slice: []string{"a", "b", "c"},
str: "b",
expected: true,
},
{
name: "String not found in slice",
slice: []string{"a", "b", "c"},
str: "d",
expected: false,
},
{
name: "Empty slice",
slice: []string{},
str: "a",
expected: false,
},
{
name: "Empty string",
slice: []string{"a", "b", "c"},
str: "",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := contains(tt.slice, tt.str)
assert.Equal(t, tt.expected, result)
})
}
}

View file

@ -0,0 +1,90 @@
package rbac
import (
"context"
"fmt"
"tss-rocks-be/ent"
"tss-rocks-be/ent/role"
)
// DefaultPermissions defines the default permissions for each resource
var DefaultPermissions = map[string][]string{
"media": {"create", "read", "update", "delete", "list"},
"post": {"create", "read", "update", "delete", "list"},
"daily": {"create", "read", "update", "delete", "list"},
"user": {"create", "read", "update", "delete", "list"},
}
// DefaultRoles defines the default roles and their permissions
var DefaultRoles = map[string]map[string][]string{
"admin": DefaultPermissions,
"editor": {
"media": {"create", "read", "update", "list"},
"post": {"create", "read", "update", "list"},
"daily": {"create", "read", "update", "list"},
"user": {"read"},
},
"contributor": {
"media": {"read", "list"},
"post": {"read", "list"},
"daily": {"read", "list"},
},
}
// InitializeRBAC initializes the RBAC system with default roles and permissions
func InitializeRBAC(ctx context.Context, client *ent.Client) error {
// Create permissions
permissionMap := make(map[string]*ent.Permission)
for resource, actions := range DefaultPermissions {
for _, action := range actions {
permission, err := client.Permission.Create().
SetResource(resource).
SetAction(action).
SetDescription(fmt.Sprintf("Permission to %s %s", action, resource)).
Save(ctx)
if err != nil {
return fmt.Errorf("failed creating permission: %w", err)
}
key := fmt.Sprintf("%s:%s", resource, action)
permissionMap[key] = permission
}
}
// Create roles with permissions
for roleName, permissions := range DefaultRoles {
roleCreate := client.Role.Create().
SetName(roleName).
SetDescription(fmt.Sprintf("Role for %s users", roleName))
// Add permissions to role
for resource, actions := range permissions {
for _, action := range actions {
key := fmt.Sprintf("%s:%s", resource, action)
if permission, exists := permissionMap[key]; exists {
roleCreate.AddPermissions(permission)
}
}
}
if _, err := roleCreate.Save(ctx); err != nil {
return fmt.Errorf("failed creating role %s: %w", roleName, err)
}
}
return nil
}
// AssignRoleToUser assigns a role to a user
func AssignRoleToUser(ctx context.Context, client *ent.Client, userID int, roleName string) error {
role, err := client.Role.Query().
Where(role.Name(roleName)).
Only(ctx)
if err != nil {
return fmt.Errorf("failed querying role: %w", err)
}
return client.User.UpdateOneID(userID).
AddRoles(role).
Exec(ctx)
}

View file

@ -0,0 +1,98 @@
package rbac
import (
"context"
"testing"
"tss-rocks-be/ent/enttest"
"tss-rocks-be/ent/role"
_ "github.com/mattn/go-sqlite3"
)
func TestInitializeRBAC(t *testing.T) {
// Create an in-memory SQLite client for testing
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
ctx := context.Background()
// Test initialization
err := InitializeRBAC(ctx, client)
if err != nil {
t.Fatalf("Failed to initialize RBAC: %v", err)
}
// Verify roles were created
for roleName := range DefaultRoles {
r, err := client.Role.Query().Where(role.Name(roleName)).Only(ctx)
if err != nil {
t.Errorf("Role %s was not created: %v", roleName, err)
}
// Verify permissions for each role
perms, err := r.QueryPermissions().All(ctx)
if err != nil {
t.Errorf("Failed to query permissions for role %s: %v", roleName, err)
}
expectedPerms := DefaultRoles[roleName]
permCount := 0
for _, actions := range expectedPerms {
permCount += len(actions)
}
if len(perms) != permCount {
t.Errorf("Role %s has %d permissions, expected %d", roleName, len(perms), permCount)
}
}
}
func TestAssignRoleToUser(t *testing.T) {
// Create an in-memory SQLite client for testing
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
ctx := context.Background()
// Initialize RBAC
err := InitializeRBAC(ctx, client)
if err != nil {
t.Fatalf("Failed to initialize RBAC: %v", err)
}
// Create a test user
user, err := client.User.Create().
SetEmail("test@example.com").
SetPasswordHash("$2a$10$hzLdXMZEIzgr8eGXL0YoCOIIrQhqEj6N.S3.wY1Jx5.4vWm1ZyHyy").
Save(ctx)
if err != nil {
t.Fatalf("Failed to create test user: %v", err)
}
// Test assigning role to user
err = AssignRoleToUser(ctx, client, user.ID, "editor")
if err != nil {
t.Fatalf("Failed to assign role to user: %v", err)
}
// Verify role assignment
assignedRoles, err := user.QueryRoles().All(ctx)
if err != nil {
t.Fatalf("Failed to query user roles: %v", err)
}
if len(assignedRoles) != 1 {
t.Errorf("Expected 1 role, got %d", len(assignedRoles))
}
if assignedRoles[0].Name != "editor" {
t.Errorf("Expected role name 'editor', got '%s'", assignedRoles[0].Name)
}
// Test assigning non-existent role
err = AssignRoleToUser(ctx, client, user.ID, "nonexistent")
if err == nil {
t.Error("Expected error when assigning non-existent role, got nil")
}
}

View file

@ -0,0 +1,24 @@
package server
import (
"context"
"tss-rocks-be/ent"
"github.com/rs/zerolog/log"
)
func InitDatabase(ctx context.Context, driver, dsn string) (*ent.Client, error) {
client, err := ent.Open(driver, dsn)
if err != nil {
log.Error().Err(err).Msg("failed opening database connection")
return nil, err
}
// Run the auto migration tool
if err := client.Schema.Create(ctx); err != nil {
log.Error().Err(err).Msg("failed creating schema resources")
return nil, err
}
return client, nil
}

View file

@ -0,0 +1,64 @@
package server
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestInitDatabase(t *testing.T) {
tests := []struct {
name string
driver string
dsn string
wantErr bool
errContains string
}{
{
name: "success with sqlite3",
driver: "sqlite3",
dsn: "file:ent?mode=memory&cache=shared&_fk=1",
},
{
name: "invalid driver",
driver: "invalid_driver",
dsn: "file:ent?mode=memory",
wantErr: true,
errContains: "unsupported driver",
},
{
name: "invalid dsn",
driver: "sqlite3",
dsn: "file::memory:?not_exist_option=1", // 使用内存数据库但带有无效选项
wantErr: true,
errContains: "foreign_keys pragma is off",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
client, err := InitDatabase(ctx, tt.driver, tt.dsn)
if tt.wantErr {
assert.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
assert.Nil(t, client)
} else {
require.NoError(t, err)
assert.NotNil(t, client)
// 测试数据库连接是否正常工作
err = client.Schema.Create(ctx)
assert.NoError(t, err)
// 清理
client.Close()
}
})
}
}

View file

@ -0,0 +1,31 @@
package server
import (
"context"
"entgo.io/ent/dialect/sql"
_ "github.com/mattn/go-sqlite3"
"github.com/rs/zerolog/log"
"tss-rocks-be/ent"
"tss-rocks-be/internal/config"
)
// NewEntClient creates a new ent client
func NewEntClient(cfg *config.Config) *ent.Client {
// TODO: Implement database connection based on config
// For now, we'll use SQLite for development
db, err := sql.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
if err != nil {
log.Fatal().Err(err).Msg("Failed to connect to database")
}
// Create ent client
client := ent.NewClient(ent.Driver(db))
// Run the auto migration tool
if err := client.Schema.Create(context.Background()); err != nil {
log.Fatal().Err(err).Msg("Failed to create schema resources")
}
return client
}

View file

@ -0,0 +1,40 @@
package server
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"tss-rocks-be/internal/config"
)
func TestNewEntClient(t *testing.T) {
tests := []struct {
name string
cfg *config.Config
}{
{
name: "default sqlite3 config",
cfg: &config.Config{
Database: config.DatabaseConfig{
Driver: "sqlite3",
DSN: "file:ent?mode=memory&cache=shared&_fk=1",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := NewEntClient(tt.cfg)
assert.NotNil(t, client)
// 验证客户端是否可以正常工作
err := client.Schema.Create(context.Background())
assert.NoError(t, err)
// 清理
client.Close()
})
}
}

View file

@ -0,0 +1,90 @@
package server
import (
"context"
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"tss-rocks-be/ent"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/handler"
"tss-rocks-be/internal/middleware"
"tss-rocks-be/internal/service"
"tss-rocks-be/internal/storage"
)
type Server struct {
config *config.Config
router *gin.Engine
handler *handler.Handler
server *http.Server
}
func New(cfg *config.Config, client *ent.Client) (*Server, error) {
// Initialize storage
store, err := storage.NewStorage(context.Background(), &cfg.Storage)
if err != nil {
return nil, fmt.Errorf("failed to initialize storage: %w", err)
}
// Initialize service
svc := service.NewService(client, store)
// Initialize RBAC
if err := svc.InitializeRBAC(context.Background()); err != nil {
return nil, fmt.Errorf("failed to initialize RBAC: %w", err)
}
// Initialize handler
h := handler.NewHandler(cfg, svc)
// Initialize router
router := gin.Default()
// Add CORS middleware if needed
router.Use(middleware.CORS())
// 添加全局中间件
router.Use(gin.Logger())
router.Use(gin.Recovery())
router.Use(middleware.RateLimit(&cfg.RateLimit))
// 添加访问日志中间件
accessLog, err := middleware.AccessLog(&cfg.AccessLog)
if err != nil {
return nil, fmt.Errorf("failed to initialize access log: %w", err)
}
router.Use(accessLog)
// 为上传路由添加文件验证中间件
router.POST("/api/v1/media/upload", middleware.ValidateUpload(&cfg.Storage.Upload))
// Register routes
h.RegisterRoutes(router)
return &Server{
config: cfg,
router: router,
handler: h,
}, nil
}
func (s *Server) Start() error {
addr := fmt.Sprintf("%s:%d", s.config.Server.Host, s.config.Server.Port)
s.server = &http.Server{
Addr: addr,
Handler: s.router,
}
log.Info().Msgf("Starting server on %s", addr)
return s.server.ListenAndServe()
}
func (s *Server) Shutdown(ctx context.Context) error {
if s.server != nil {
return s.server.Shutdown(ctx)
}
return nil
}

View file

@ -0,0 +1,220 @@
package server
import (
"context"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/types"
"tss-rocks-be/ent/enttest"
)
func TestNew(t *testing.T) {
// 创建测试配置
cfg := &config.Config{
Server: config.ServerConfig{
Host: "localhost",
Port: 8080,
},
Storage: config.StorageConfig{
Type: "local",
Local: config.LocalStorage{
RootDir: "testdata",
},
Upload: types.UploadConfig{
MaxSize: 10,
AllowedTypes: []string{"image/jpeg", "image/png"},
AllowedExtensions: []string{".jpg", ".png"},
},
},
RateLimit: types.RateLimitConfig{
IPRate: 100,
IPBurst: 200,
RouteRates: map[string]struct {
Rate int `yaml:"rate"`
Burst int `yaml:"burst"`
}{
"/api/v1/upload": {Rate: 10, Burst: 20},
},
},
AccessLog: types.AccessLogConfig{
EnableConsole: true,
EnableFile: true,
FilePath: "testdata/access.log",
Format: "json",
Level: "info",
Rotation: struct {
MaxSize int `yaml:"max_size"`
MaxAge int `yaml:"max_age"`
MaxBackups int `yaml:"max_backups"`
Compress bool `yaml:"compress"`
LocalTime bool `yaml:"local_time"`
}{
MaxSize: 100,
MaxAge: 7,
MaxBackups: 3,
Compress: true,
LocalTime: true,
},
},
}
// 创建测试数据库客户端
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
// 测试服务器初始化
s, err := New(cfg, client)
require.NoError(t, err)
assert.NotNil(t, s)
assert.NotNil(t, s.router)
assert.NotNil(t, s.handler)
assert.Equal(t, cfg, s.config)
}
func TestNew_StorageError(t *testing.T) {
// 创建一个无效的存储配置
cfg := &config.Config{
Storage: config.StorageConfig{
Type: "invalid_type", // 使用无效的存储类型
},
}
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
s, err := New(cfg, client)
assert.Error(t, err)
assert.Nil(t, s)
assert.Contains(t, err.Error(), "failed to initialize storage")
}
func TestServer_StartAndShutdown(t *testing.T) {
// 创建测试配置
cfg := &config.Config{
Server: config.ServerConfig{
Host: "localhost",
Port: 0, // 使用随机端口
},
Storage: config.StorageConfig{
Type: "local",
Local: config.LocalStorage{
RootDir: "testdata",
},
},
RateLimit: types.RateLimitConfig{
IPRate: 100,
IPBurst: 200,
},
AccessLog: types.AccessLogConfig{
EnableConsole: true,
Format: "json",
Level: "info",
},
}
// 创建测试数据库客户端
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
// 初始化服务器
s, err := New(cfg, client)
require.NoError(t, err)
// 创建一个通道来接收服务器错误
errChan := make(chan error, 1)
// 在 goroutine 中启动服务器
go func() {
err := s.Start()
if err != nil && err != http.ErrServerClosed {
errChan <- err
}
close(errChan)
}()
// 给服务器一些时间启动
time.Sleep(100 * time.Millisecond)
// 测试关闭服务器
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err = s.Shutdown(ctx)
assert.NoError(t, err)
// 检查服务器是否有错误发生
err = <-errChan
assert.NoError(t, err)
}
func TestServer_StartError(t *testing.T) {
// 创建一个配置,使用已经被占用的端口来触发错误
cfg := &config.Config{
Server: config.ServerConfig{
Host: "localhost",
Port: 8899, // 使用固定端口以便测试
},
Storage: config.StorageConfig{
Type: "local",
Local: config.LocalStorage{
RootDir: "testdata",
},
},
}
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
// 创建第一个服务器实例
s1, err := New(cfg, client)
require.NoError(t, err)
// 创建一个通道来接收服务器错误
errChan := make(chan error, 1)
// 启动第一个服务器
go func() {
err := s1.Start()
if err != nil && err != http.ErrServerClosed {
errChan <- err
}
close(errChan)
}()
// 给服务器一些时间启动
time.Sleep(100 * time.Millisecond)
// 尝试在同一端口启动第二个服务器,应该会失败
s2, err := New(cfg, client)
require.NoError(t, err)
err = s2.Start()
assert.Error(t, err)
// 清理
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// 关闭第一个服务器
err = s1.Shutdown(ctx)
assert.NoError(t, err)
// 检查第一个服务器是否有错误发生
err = <-errChan
assert.NoError(t, err)
// 关闭第二个服务器
err = s2.Shutdown(ctx)
assert.NoError(t, err)
}
func TestServer_ShutdownWithNilServer(t *testing.T) {
s := &Server{}
err := s.Shutdown(context.Background())
assert.NoError(t, err)
}

View file

@ -0,0 +1,892 @@
package service
import (
"context"
"errors"
"fmt"
"io"
"mime/multipart"
"sort"
"strconv"
"strings"
"regexp"
"tss-rocks-be/ent"
"tss-rocks-be/ent/category"
"tss-rocks-be/ent/categorycontent"
"tss-rocks-be/ent/contributor"
"tss-rocks-be/ent/contributorsociallink"
"tss-rocks-be/ent/daily"
"tss-rocks-be/ent/dailycontent"
"tss-rocks-be/ent/permission"
"tss-rocks-be/ent/post"
"tss-rocks-be/ent/postcontent"
"tss-rocks-be/ent/role"
"tss-rocks-be/ent/user"
"tss-rocks-be/internal/storage"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
// Error definitions
var (
ErrUnauthorized = errors.New("unauthorized")
)
// openFile is a variable that holds the Open method of multipart.FileHeader
// This allows us to mock it in tests
var openFile func(fh *multipart.FileHeader) (multipart.File, error) = func(fh *multipart.FileHeader) (multipart.File, error) {
return fh.Open()
}
type serviceImpl struct {
client *ent.Client
storage storage.Storage
}
// NewService creates a new Service instance
func NewService(client *ent.Client, storage storage.Storage) Service {
return &serviceImpl{
client: client,
storage: storage,
}
}
// User operations
func (s *serviceImpl) CreateUser(ctx context.Context, email, password string, roleStr string) (*ent.User, error) {
// Hash the password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("failed to hash password: %w", err)
}
// Add the user role by default
userRole, err := s.client.Role.Query().Where(role.NameEQ("user")).Only(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get user role: %w", err)
}
// If a specific role is requested and it's not "user", get that role too
var additionalRole *ent.Role
if roleStr != "" && roleStr != "user" {
additionalRole, err = s.client.Role.Query().Where(role.NameEQ(roleStr)).Only(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get role: %w", err)
}
}
// Create user with password and user role
userCreate := s.client.User.Create().
SetEmail(email).
SetPasswordHash(string(hashedPassword)).
AddRoles(userRole)
// Add the additional role if specified
if additionalRole != nil {
userCreate.AddRoles(additionalRole)
}
// Save the user
user, err := userCreate.Save(ctx)
if err != nil {
return nil, fmt.Errorf("failed to create user: %w", err)
}
return user, nil
}
func (s *serviceImpl) GetUserByEmail(ctx context.Context, email string) (*ent.User, error) {
user, err := s.client.User.Query().
Where(user.EmailEQ(email)).
Only(ctx)
if err != nil {
if ent.IsNotFound(err) {
return nil, fmt.Errorf("user not found: %s", email)
}
return nil, fmt.Errorf("failed to get user: %w", err)
}
return user, nil
}
func (s *serviceImpl) ValidatePassword(ctx context.Context, user *ent.User, password string) bool {
err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
return err == nil
}
// Category operations
func (s *serviceImpl) CreateCategory(ctx context.Context) (*ent.Category, error) {
return s.client.Category.Create().Save(ctx)
}
func (s *serviceImpl) AddCategoryContent(ctx context.Context, categoryID int, langCode, name, description, slug string) (*ent.CategoryContent, error) {
var languageCode categorycontent.LanguageCode
switch langCode {
case "en":
languageCode = categorycontent.LanguageCodeEN
case "zh-Hans":
languageCode = categorycontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = categorycontent.LanguageCodeZH_HANT
default:
return nil, fmt.Errorf("unsupported language code: %s", langCode)
}
return s.client.CategoryContent.Create().
SetCategoryID(categoryID).
SetLanguageCode(languageCode).
SetName(name).
SetDescription(description).
SetSlug(slug).
Save(ctx)
}
func (s *serviceImpl) GetCategoryBySlug(ctx context.Context, langCode, slug string) (*ent.Category, error) {
var languageCode categorycontent.LanguageCode
switch langCode {
case "en":
languageCode = categorycontent.LanguageCodeEN
case "zh-Hans":
languageCode = categorycontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = categorycontent.LanguageCodeZH_HANT
default:
return nil, fmt.Errorf("unsupported language code: %s", langCode)
}
return s.client.Category.Query().
Where(
category.HasContentsWith(
categorycontent.And(
categorycontent.LanguageCodeEQ(languageCode),
categorycontent.SlugEQ(slug),
),
),
).
WithContents(func(q *ent.CategoryContentQuery) {
q.Where(categorycontent.LanguageCodeEQ(languageCode))
}).
Only(ctx)
}
func (s *serviceImpl) GetCategories(ctx context.Context, langCode string) ([]*ent.Category, error) {
// 转换语言代码
var languageCode categorycontent.LanguageCode
switch langCode {
case "en":
languageCode = categorycontent.LanguageCodeEN
case "zh-Hans":
languageCode = categorycontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = categorycontent.LanguageCodeZH_HANT
default:
// 不支持的语言代码返回空列表而不是错误
return []*ent.Category{}, nil
}
// 从 CategoryContent 表开始查询,确保只返回有指定语言内容的分类
contents, err := s.client.CategoryContent.Query().
Where(
categorycontent.And(
categorycontent.LanguageCodeEQ(languageCode),
categorycontent.SlugHasPrefix("category-list-"),
),
).
WithCategory().
All(ctx)
if err != nil {
return nil, err
}
// 使用 map 去重,因为一个分类可能有多个相同语言的内容
categoryMap := make(map[int]*ent.Category)
for _, content := range contents {
if content.Edges.Category != nil {
categoryMap[content.Edges.Category.ID] = content.Edges.Category
}
}
// 将 map 转换为有序的切片
var categories []*ent.Category
for _, cat := range categoryMap {
// 重新查询分类以获取完整的关联数据
c, err := s.client.Category.Query().
Where(category.ID(cat.ID)).
WithContents(func(q *ent.CategoryContentQuery) {
q.Where(
categorycontent.And(
categorycontent.LanguageCodeEQ(languageCode),
categorycontent.SlugHasPrefix("category-list-"),
),
)
}).
Only(ctx)
if err != nil {
return nil, err
}
categories = append(categories, c)
}
// 按 ID 排序以保持结果稳定
sort.Slice(categories, func(i, j int) bool {
return categories[i].ID < categories[j].ID
})
return categories, nil
}
func (s *serviceImpl) ListCategories(ctx context.Context, langCode string) ([]*ent.Category, error) {
// 转换语言代码
var languageCode categorycontent.LanguageCode
switch langCode {
case "en":
languageCode = categorycontent.LanguageCodeEN
case "zh-Hans":
languageCode = categorycontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = categorycontent.LanguageCodeZH_HANT
default:
// 不支持的语言代码返回空列表而不是错误
return []*ent.Category{}, nil
}
// 从 CategoryContent 表开始查询,确保只返回有指定语言内容的分类
contents, err := s.client.CategoryContent.Query().
Where(
categorycontent.And(
categorycontent.LanguageCodeEQ(languageCode),
categorycontent.SlugHasPrefix("category-list-"),
),
).
WithCategory().
All(ctx)
if err != nil {
return nil, err
}
// 使用 map 去重,因为一个分类可能有多个相同语言的内容
categoryMap := make(map[int]*ent.Category)
for _, content := range contents {
if content.Edges.Category != nil {
categoryMap[content.Edges.Category.ID] = content.Edges.Category
}
}
// 将 map 转换为有序的切片
var categories []*ent.Category
for _, cat := range categoryMap {
// 重新查询分类以获取完整的关联数据
c, err := s.client.Category.Query().
Where(category.ID(cat.ID)).
WithContents(func(q *ent.CategoryContentQuery) {
q.Where(
categorycontent.And(
categorycontent.LanguageCodeEQ(languageCode),
categorycontent.SlugHasPrefix("category-list-"),
),
)
}).
Only(ctx)
if err != nil {
return nil, err
}
categories = append(categories, c)
}
// 按 ID 排序以保持结果稳定
sort.Slice(categories, func(i, j int) bool {
return categories[i].ID < categories[j].ID
})
return categories, nil
}
// Daily operations
func (s *serviceImpl) CreateDaily(ctx context.Context, id string, categoryID int, imageURL string) (*ent.Daily, error) {
_, err := s.client.Daily.Create().
SetID(id).
SetCategoryID(categoryID).
SetImageURL(imageURL).
Save(ctx)
if err != nil {
return nil, err
}
// 加载 Category Edge
return s.client.Daily.Query().
Where(daily.IDEQ(id)).
WithCategory().
Only(ctx)
}
func (s *serviceImpl) AddDailyContent(ctx context.Context, dailyID string, langCode string, quote string) (*ent.DailyContent, error) {
var languageCode dailycontent.LanguageCode
switch langCode {
case "en":
languageCode = dailycontent.LanguageCodeEN
case "zh-Hans":
languageCode = dailycontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = dailycontent.LanguageCodeZH_HANT
default:
return nil, fmt.Errorf("unsupported language code: %s", langCode)
}
return s.client.DailyContent.Create().
SetDailyID(dailyID).
SetLanguageCode(languageCode).
SetQuote(quote).
Save(ctx)
}
func (s *serviceImpl) GetDailyByID(ctx context.Context, id string) (*ent.Daily, error) {
return s.client.Daily.Query().
Where(daily.IDEQ(id)).
WithCategory().
WithContents().
Only(ctx)
}
func (s *serviceImpl) ListDailies(ctx context.Context, langCode string, categoryID *int, limit int, offset int) ([]*ent.Daily, error) {
var languageCode dailycontent.LanguageCode
switch langCode {
case "en":
languageCode = dailycontent.LanguageCodeEN
case "zh-Hans":
languageCode = dailycontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = dailycontent.LanguageCodeZH_HANT
default:
return nil, fmt.Errorf("unsupported language code: %s", langCode)
}
query := s.client.Daily.Query().
WithContents(func(q *ent.DailyContentQuery) {
if langCode != "" {
q.Where(dailycontent.LanguageCodeEQ(languageCode))
}
}).
WithCategory()
if categoryID != nil {
query.Where(daily.HasCategoryWith(category.ID(*categoryID)))
}
query.Order(ent.Desc(daily.FieldCreatedAt))
if limit > 0 {
query.Limit(limit)
}
if offset > 0 {
query.Offset(offset)
}
return query.All(ctx)
}
// Media operations
func (s *serviceImpl) ListMedia(ctx context.Context, limit, offset int) ([]*ent.Media, error) {
return s.client.Media.Query().
Order(ent.Desc("created_at")).
Limit(limit).
Offset(offset).
All(ctx)
}
func (s *serviceImpl) Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error) {
// Open the uploaded file
src, err := openFile(file)
if err != nil {
return nil, err
}
defer src.Close()
// Save the file to storage
fileInfo, err := s.storage.Save(ctx, file.Filename, file.Header.Get("Content-Type"), src)
if err != nil {
return nil, err
}
// Create media record
return s.client.Media.Create().
SetStorageID(fileInfo.ID).
SetOriginalName(file.Filename).
SetMimeType(fileInfo.ContentType).
SetSize(fileInfo.Size).
SetURL(fileInfo.URL).
SetCreatedBy(strconv.Itoa(userID)).
Save(ctx)
}
func (s *serviceImpl) GetMedia(ctx context.Context, id int) (*ent.Media, error) {
return s.client.Media.Get(ctx, id)
}
func (s *serviceImpl) GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error) {
media, err := s.GetMedia(ctx, id)
if err != nil {
return nil, nil, err
}
return s.storage.Get(ctx, media.StorageID)
}
func (s *serviceImpl) DeleteMedia(ctx context.Context, id int, userID int) error {
media, err := s.GetMedia(ctx, id)
if err != nil {
return err
}
// Check ownership
if media.CreatedBy != strconv.Itoa(userID) {
return ErrUnauthorized
}
// Delete from storage
if err := s.storage.Delete(ctx, media.StorageID); err != nil {
return err
}
// Delete from database
return s.client.Media.DeleteOne(media).Exec(ctx)
}
// Post operations
func (s *serviceImpl) CreatePost(ctx context.Context, status string) (*ent.Post, error) {
var postStatus post.Status
switch status {
case "draft":
postStatus = post.StatusDraft
case "published":
postStatus = post.StatusPublished
case "archived":
postStatus = post.StatusArchived
default:
return nil, fmt.Errorf("invalid status: %s", status)
}
// Generate a random slug
slug := fmt.Sprintf("post-%s", uuid.New().String()[:8])
return s.client.Post.Create().
SetStatus(postStatus).
SetSlug(slug).
Save(ctx)
}
func (s *serviceImpl) AddPostContent(ctx context.Context, postID int, langCode, title, content, summary string, metaKeywords, metaDescription string) (*ent.PostContent, error) {
var languageCode postcontent.LanguageCode
switch langCode {
case "en":
languageCode = postcontent.LanguageCodeEN
case "zh-Hans":
languageCode = postcontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = postcontent.LanguageCodeZH_HANT
default:
return nil, fmt.Errorf("unsupported language code: %s", langCode)
}
// Get the post first to check if it exists
post, err := s.client.Post.Get(ctx, postID)
if err != nil {
return nil, fmt.Errorf("failed to get post: %w", err)
}
// Generate slug from title
var slug string
if langCode == "en" {
// For English titles, convert to lowercase and replace spaces with dashes
slug = strings.ToLower(strings.ReplaceAll(title, " ", "-"))
// Remove all non-alphanumeric characters except dashes
slug = regexp.MustCompile(`[^a-z0-9-]+`).ReplaceAllString(slug, "")
// Ensure slug is not empty and has minimum length
if slug == "" || len(slug) < 4 {
slug = fmt.Sprintf("post-%s", uuid.NewString()[:8])
}
} else {
// For Chinese titles, use the title as is
slug = title
}
return s.client.PostContent.Create().
SetPost(post).
SetLanguageCode(languageCode).
SetTitle(title).
SetContentMarkdown(content).
SetSummary(summary).
SetMetaKeywords(metaKeywords).
SetMetaDescription(metaDescription).
SetSlug(slug).
Save(ctx)
}
func (s *serviceImpl) GetPostBySlug(ctx context.Context, langCode, slug string) (*ent.Post, error) {
var languageCode postcontent.LanguageCode
switch langCode {
case "en":
languageCode = postcontent.LanguageCodeEN
case "zh-Hans":
languageCode = postcontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = postcontent.LanguageCodeZH_HANT
default:
return nil, fmt.Errorf("unsupported language code: %s", langCode)
}
// Find posts that have content with the given slug and language code
posts, err := s.client.Post.Query().
Where(
post.And(
post.StatusEQ(post.StatusPublished),
post.HasContentsWith(
postcontent.And(
postcontent.LanguageCodeEQ(languageCode),
postcontent.SlugEQ(slug),
),
),
),
).
WithContents(func(q *ent.PostContentQuery) {
q.Where(postcontent.LanguageCodeEQ(languageCode))
}).
WithCategory().
All(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get posts: %w", err)
}
if len(posts) == 0 {
return nil, fmt.Errorf("post not found")
}
if len(posts) > 1 {
return nil, fmt.Errorf("multiple posts found with the same slug")
}
return posts[0], nil
}
func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error) {
var languageCode postcontent.LanguageCode
switch langCode {
case "en":
languageCode = postcontent.LanguageCodeEN
case "zh-Hans":
languageCode = postcontent.LanguageCodeZH_HANS
case "zh-Hant":
languageCode = postcontent.LanguageCodeZH_HANT
default:
return nil, fmt.Errorf("unsupported language code: %s", langCode)
}
// First find all post IDs that have content in the given language
query := s.client.PostContent.Query().
Where(postcontent.LanguageCodeEQ(languageCode)).
QueryPost().
Where(post.StatusEQ(post.StatusPublished))
// Add category filter if provided
if categoryID != nil {
query = query.Where(post.HasCategoryWith(category.ID(*categoryID)))
}
// Get unique post IDs
postIDs, err := query.
Order(ent.Desc(post.FieldCreatedAt)).
IDs(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get post IDs: %w", err)
}
// Remove duplicates while preserving order
seen := make(map[int]bool)
uniqueIDs := make([]int, 0, len(postIDs))
for _, id := range postIDs {
if !seen[id] {
seen[id] = true
uniqueIDs = append(uniqueIDs, id)
}
}
postIDs = uniqueIDs
if len(postIDs) == 0 {
return []*ent.Post{}, nil
}
// If no category filter is applied, only take the latest 5 posts
if categoryID == nil && len(postIDs) > 5 {
postIDs = postIDs[:5]
}
// Apply pagination
if offset >= len(postIDs) {
return []*ent.Post{}, nil
}
// If limit is 0, set it to the length of postIDs
if limit == 0 {
limit = len(postIDs)
}
// Adjust limit if it would exceed total
if offset+limit > len(postIDs) {
limit = len(postIDs) - offset
}
// Get the paginated post IDs
paginatedIDs := postIDs[offset : offset+limit]
// Get the posts with their contents
posts, err := s.client.Post.Query().
Where(post.IDIn(paginatedIDs...)).
WithContents(func(q *ent.PostContentQuery) {
q.Where(postcontent.LanguageCodeEQ(languageCode))
}).
WithCategory().
Order(ent.Desc(post.FieldCreatedAt)).
All(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get posts: %w", err)
}
// Sort posts by ID to match the order of postIDs
sort.Slice(posts, func(i, j int) bool {
// Find index of each post ID in postIDs
var iIndex, jIndex int
for idx, id := range paginatedIDs {
if posts[i].ID == id {
iIndex = idx
}
if posts[j].ID == id {
jIndex = idx
}
}
return iIndex < jIndex
})
return posts, nil
}
// Contributor operations
func (s *serviceImpl) CreateContributor(ctx context.Context, name string, avatarURL, bio *string) (*ent.Contributor, error) {
builder := s.client.Contributor.Create().
SetName(name)
if avatarURL != nil {
builder.SetAvatarURL(*avatarURL)
}
if bio != nil {
builder.SetBio(*bio)
}
contributor, err := builder.Save(ctx)
if err != nil {
return nil, fmt.Errorf("failed to create contributor: %w", err)
}
return contributor, nil
}
func (s *serviceImpl) AddContributorSocialLink(ctx context.Context, contributorID int, linkType, name, value string) (*ent.ContributorSocialLink, error) {
// 验证贡献者是否存在
contributor, err := s.client.Contributor.Get(ctx, contributorID)
if err != nil {
return nil, fmt.Errorf("failed to get contributor: %w", err)
}
// 创建社交链接
link, err := s.client.ContributorSocialLink.Create().
SetContributor(contributor).
SetType(contributorsociallink.Type(linkType)).
SetName(name).
SetValue(value).
Save(ctx)
if err != nil {
return nil, fmt.Errorf("failed to create social link: %w", err)
}
return link, nil
}
func (s *serviceImpl) GetContributorByID(ctx context.Context, id int) (*ent.Contributor, error) {
contributor, err := s.client.Contributor.Query().
Where(contributor.ID(id)).
WithSocialLinks().
Only(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get contributor: %w", err)
}
return contributor, nil
}
func (s *serviceImpl) ListContributors(ctx context.Context) ([]*ent.Contributor, error) {
contributors, err := s.client.Contributor.Query().
WithSocialLinks().
Order(ent.Asc(contributor.FieldName)).
All(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list contributors: %w", err)
}
return contributors, nil
}
// RBAC operations
func (s *serviceImpl) InitializeRBAC(ctx context.Context) error {
// Create roles if they don't exist
adminRole, err := s.client.Role.Create().SetName("admin").Save(ctx)
if ent.IsConstraintError(err) {
adminRole, err = s.client.Role.Query().Where(role.NameEQ("admin")).Only(ctx)
}
if err != nil {
return fmt.Errorf("failed to create admin role: %w", err)
}
editorRole, err := s.client.Role.Create().SetName("editor").Save(ctx)
if ent.IsConstraintError(err) {
editorRole, err = s.client.Role.Query().Where(role.NameEQ("editor")).Only(ctx)
}
if err != nil {
return fmt.Errorf("failed to create editor role: %w", err)
}
userRole, err := s.client.Role.Create().SetName("user").Save(ctx)
if ent.IsConstraintError(err) {
userRole, err = s.client.Role.Query().Where(role.NameEQ("user")).Only(ctx)
}
if err != nil {
return fmt.Errorf("failed to create user role: %w", err)
}
// Define permissions
permissions := []struct {
role *ent.Role
resource string
actions []string
}{
// Admin permissions (full access)
{adminRole, "users", []string{"create", "read", "update", "delete", "assign_role"}},
{adminRole, "roles", []string{"create", "read", "update", "delete"}},
{adminRole, "media", []string{"create", "read", "update", "delete"}},
{adminRole, "posts", []string{"create", "read", "update", "delete"}},
{adminRole, "categories", []string{"create", "read", "update", "delete"}},
{adminRole, "contributors", []string{"create", "read", "update", "delete"}},
{adminRole, "dailies", []string{"create", "read", "update", "delete"}},
// Editor permissions (can create and manage content)
{editorRole, "media", []string{"create", "read", "update", "delete"}},
{editorRole, "posts", []string{"create", "read", "update", "delete"}},
{editorRole, "categories", []string{"read"}},
{editorRole, "contributors", []string{"read"}},
{editorRole, "dailies", []string{"create", "read", "update", "delete"}},
// User permissions (read-only access)
{userRole, "media", []string{"read"}},
{userRole, "posts", []string{"read"}},
{userRole, "categories", []string{"read"}},
{userRole, "contributors", []string{"read"}},
{userRole, "dailies", []string{"read"}},
}
// Create permissions for each role
for _, p := range permissions {
for _, action := range p.actions {
perm, err := s.client.Permission.Create().
SetResource(p.resource).
SetAction(action).
Save(ctx)
if ent.IsConstraintError(err) {
perm, err = s.client.Permission.Query().
Where(
permission.ResourceEQ(p.resource),
permission.ActionEQ(action),
).
Only(ctx)
}
if err != nil {
return fmt.Errorf("failed to create permission %s:%s: %w", p.resource, action, err)
}
// Add permission to role
err = s.client.Role.UpdateOne(p.role).
AddPermissions(perm).
Exec(ctx)
if err != nil && !ent.IsConstraintError(err) {
return fmt.Errorf("failed to add permission %s:%s to role %s: %w", p.resource, action, p.role.Name, err)
}
}
}
return nil
}
func (s *serviceImpl) AssignRole(ctx context.Context, userID int, roleName string) error {
user, err := s.client.User.Get(ctx, userID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
role, err := s.client.Role.Query().Where(role.NameEQ(roleName)).Only(ctx)
if err != nil {
return fmt.Errorf("failed to get role: %w", err)
}
return s.client.User.UpdateOne(user).AddRoles(role).Exec(ctx)
}
func (s *serviceImpl) RemoveRole(ctx context.Context, userID int, roleName string) error {
// Don't allow removing the user role
if roleName == "user" {
return errors.New("cannot remove user role")
}
user, err := s.client.User.Get(ctx, userID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
role, err := s.client.Role.Query().Where(role.NameEQ(roleName)).Only(ctx)
if err != nil {
return fmt.Errorf("failed to get role: %w", err)
}
return s.client.User.UpdateOne(user).RemoveRoles(role).Exec(ctx)
}
func (s *serviceImpl) GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error) {
user, err := s.client.User.Query().
Where(user.ID(userID)).
WithRoles().
Only(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
return user.Edges.Roles, nil
}
func (s *serviceImpl) HasPermission(ctx context.Context, userID int, permission string) (bool, error) {
user, err := s.client.User.Query().
Where(user.ID(userID)).
WithRoles(func(q *ent.RoleQuery) {
q.WithPermissions()
}).
Only(ctx)
if err != nil {
return false, fmt.Errorf("failed to get user: %w", err)
}
parts := strings.Split(permission, ":")
if len(parts) != 2 {
return false, fmt.Errorf("invalid permission format: %s, expected format: resource:action", permission)
}
resource, action := parts[0], parts[1]
for _, r := range user.Edges.Roles {
for _, p := range r.Edges.Permissions {
if p.Resource == resource && p.Action == action {
return true, nil
}
}
}
return false, nil
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,179 @@
package service
import (
"bytes"
"context"
"fmt"
"io"
"mime/multipart"
"strings"
"path/filepath"
"tss-rocks-be/ent"
"tss-rocks-be/internal/storage"
"tss-rocks-be/pkg/imageutil"
)
type MediaService interface {
// Upload uploads a new file and creates a media record
Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error)
// Get retrieves a media file by ID
Get(ctx context.Context, id int) (*ent.Media, error)
// Delete deletes a media file
Delete(ctx context.Context, id int, userID int) error
// List lists media files with pagination
List(ctx context.Context, limit, offset int) ([]*ent.Media, error)
// GetFile gets the file content and info
GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error)
}
type mediaService struct {
client *ent.Client
storage storage.Storage
}
func NewMediaService(client *ent.Client, storage storage.Storage) MediaService {
return &mediaService{
client: client,
storage: storage,
}
}
// isValidFilename checks if a filename is valid
func isValidFilename(filename string) bool {
// Check for illegal characters
if strings.Contains(filename, "../") ||
strings.Contains(filename, "./") ||
strings.Contains(filename, "\\") {
return false
}
return true
}
func (s *mediaService) Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error) {
// Validate filename
if !isValidFilename(file.Filename) {
return nil, fmt.Errorf("invalid filename: %s", file.Filename)
}
// Open the file
src, err := file.Open()
if err != nil {
return nil, fmt.Errorf("failed to open file: %w", err)
}
defer src.Close()
// Read file content for processing
fileBytes, err := io.ReadAll(src)
if err != nil {
return nil, fmt.Errorf("failed to read file: %w", err)
}
contentType := file.Header.Get("Content-Type")
filename := file.Filename
var processedBytes []byte
// Process image if it's an image file
if imageutil.IsImageFormat(contentType) {
opts := imageutil.DefaultOptions()
processedBytes, err = imageutil.ProcessImage(bytes.NewReader(fileBytes), opts)
if err != nil {
return nil, fmt.Errorf("failed to process image: %w", err)
}
// Update content type and filename for WebP
contentType = "image/webp"
filename = strings.TrimSuffix(filename, filepath.Ext(filename)) + ".webp"
} else {
processedBytes = fileBytes
}
// Save the processed file
fileInfo, err := s.storage.Save(ctx, filename, contentType, bytes.NewReader(processedBytes))
if err != nil {
return nil, fmt.Errorf("failed to save file: %w", err)
}
// Create media record in database
media, err := s.client.Media.Create().
SetStorageID(fileInfo.ID).
SetOriginalName(filename).
SetMimeType(contentType).
SetSize(int64(len(processedBytes))).
SetURL(fmt.Sprintf("/api/media/%s", fileInfo.ID)).
SetCreatedBy(fmt.Sprint(userID)).
Save(ctx)
if err != nil {
// Try to cleanup the stored file if database operation fails
_ = s.storage.Delete(ctx, fileInfo.ID)
return nil, fmt.Errorf("failed to create media record: %w", err)
}
return media, nil
}
func (s *mediaService) Get(ctx context.Context, id int) (*ent.Media, error) {
media, err := s.client.Media.Get(ctx, id)
if err != nil {
if ent.IsNotFound(err) {
return nil, fmt.Errorf("media not found: %d", id)
}
return nil, fmt.Errorf("failed to get media: %w", err)
}
return media, nil
}
func (s *mediaService) Delete(ctx context.Context, id int, userID int) error {
media, err := s.Get(ctx, id)
if err != nil {
return err
}
// Check ownership
if media.CreatedBy != fmt.Sprintf("%d", userID) {
return fmt.Errorf("unauthorized to delete media")
}
// Delete from storage
if err := s.storage.Delete(ctx, media.StorageID); err != nil {
return fmt.Errorf("failed to delete file from storage: %w", err)
}
// Delete from database
if err := s.client.Media.DeleteOne(media).Exec(ctx); err != nil {
return fmt.Errorf("failed to delete media record: %w", err)
}
return nil
}
func (s *mediaService) List(ctx context.Context, limit, offset int) ([]*ent.Media, error) {
media, err := s.client.Media.Query().
Order(ent.Desc("created_at")).
Limit(limit).
Offset(offset).
All(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list media: %w", err)
}
return media, nil
}
func (s *mediaService) GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error) {
media, err := s.Get(ctx, id)
if err != nil {
return nil, nil, err
}
reader, info, err := s.storage.Get(ctx, media.StorageID)
if err != nil {
return nil, nil, fmt.Errorf("failed to get file from storage: %w", err)
}
return reader, info, nil
}

View file

@ -0,0 +1,332 @@
package service
import (
"bytes"
"context"
"fmt"
"io"
"mime/multipart"
"net/textproto"
"reflect"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
"tss-rocks-be/ent"
"tss-rocks-be/internal/storage"
"tss-rocks-be/internal/storage/mock"
"tss-rocks-be/internal/testutil"
"bou.ke/monkey"
)
type MediaServiceTestSuite struct {
suite.Suite
ctx context.Context
client *ent.Client
storage *mock.MockStorage
ctrl *gomock.Controller
svc MediaService
}
func (s *MediaServiceTestSuite) SetupTest() {
s.ctx = context.Background()
s.client = testutil.NewTestClient()
require.NotNil(s.T(), s.client)
s.ctrl = gomock.NewController(s.T())
s.storage = mock.NewMockStorage(s.ctrl)
s.svc = NewMediaService(s.client, s.storage)
// 清理数据库
_, err := s.client.Media.Delete().Exec(s.ctx)
require.NoError(s.T(), err)
}
func (s *MediaServiceTestSuite) TearDownTest() {
s.ctrl.Finish()
s.client.Close()
}
func TestMediaServiceSuite(t *testing.T) {
suite.Run(t, new(MediaServiceTestSuite))
}
type mockFileHeader struct {
filename string
contentType string
size int64
content []byte
}
func (h *mockFileHeader) Open() (multipart.File, error) {
return newMockMultipartFile(h.content), nil
}
func (h *mockFileHeader) Filename() string {
return h.filename
}
func (h *mockFileHeader) Size() int64 {
return h.size
}
func (h *mockFileHeader) Header() textproto.MIMEHeader {
header := make(textproto.MIMEHeader)
header.Set("Content-Type", h.contentType)
return header
}
func (s *MediaServiceTestSuite) createTestFile(filename, contentType string, content []byte) *multipart.FileHeader {
header := &multipart.FileHeader{
Filename: filename,
Header: make(textproto.MIMEHeader),
Size: int64(len(content)),
}
header.Header.Set("Content-Type", contentType)
monkey.PatchInstanceMethod(reflect.TypeOf(header), "Open", func(_ *multipart.FileHeader) (multipart.File, error) {
return newMockMultipartFile(content), nil
})
return header
}
func (s *MediaServiceTestSuite) TestUpload() {
testCases := []struct {
name string
filename string
contentType string
content []byte
setupMock func()
wantErr bool
errMsg string
}{
{
name: "Upload text file",
filename: "test.txt",
contentType: "text/plain",
content: []byte("test content"),
setupMock: func() {
s.storage.EXPECT().
Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()).
DoAndReturn(func(ctx context.Context, name, contentType string, reader io.Reader) (*storage.FileInfo, error) {
content, err := io.ReadAll(reader)
s.Require().NoError(err)
s.Equal([]byte("test content"), content)
return &storage.FileInfo{
ID: "test-id",
Name: "test.txt",
ContentType: "text/plain",
Size: int64(len(content)),
}, nil
})
},
wantErr: false,
},
{
name: "Invalid filename",
filename: "../test.txt",
contentType: "text/plain",
content: []byte("test content"),
setupMock: func() {},
wantErr: true,
errMsg: "invalid filename",
},
{
name: "Storage error",
filename: "test.txt",
contentType: "text/plain",
content: []byte("test content"),
setupMock: func() {
s.storage.EXPECT().
Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()).
Return(nil, fmt.Errorf("storage error"))
},
wantErr: true,
errMsg: "storage error",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// Setup mock
tc.setupMock()
// Create test file
fileHeader := s.createTestFile(tc.filename, tc.contentType, tc.content)
// Add debug output
s.T().Logf("Testing file: %s, content-type: %s, size: %d", fileHeader.Filename, fileHeader.Header.Get("Content-Type"), fileHeader.Size)
// Test upload
media, err := s.svc.Upload(s.ctx, fileHeader, 1)
// Add debug output
if err != nil {
s.T().Logf("Upload error: %v", err)
}
if tc.wantErr {
s.Require().Error(err)
s.Contains(err.Error(), tc.errMsg)
return
}
s.Require().NoError(err)
s.NotNil(media)
s.Equal(tc.filename, media.OriginalName)
s.Equal(tc.contentType, media.MimeType)
s.Equal(int64(len(tc.content)), media.Size)
s.Equal("1", media.CreatedBy)
})
}
}
func (s *MediaServiceTestSuite) TestGet() {
// Create test media
media, err := s.client.Media.Create().
SetStorageID("test-id").
SetOriginalName("test.txt").
SetMimeType("text/plain").
SetSize(12).
SetURL("/api/media/test-id").
SetCreatedBy("1").
Save(s.ctx)
s.Require().NoError(err)
// Test get existing media
result, err := s.svc.Get(s.ctx, media.ID)
s.Require().NoError(err)
s.Equal(media.ID, result.ID)
s.Equal(media.OriginalName, result.OriginalName)
// Test get non-existing media
_, err = s.svc.Get(s.ctx, -1)
s.Require().Error(err)
s.Contains(err.Error(), "media not found")
}
func (s *MediaServiceTestSuite) TestDelete() {
// Create test media
media, err := s.client.Media.Create().
SetStorageID("test-id").
SetOriginalName("test.txt").
SetMimeType("text/plain").
SetSize(12).
SetURL("/api/media/test-id").
SetCreatedBy("1").
Save(s.ctx)
s.Require().NoError(err)
// Test delete by unauthorized user
err = s.svc.Delete(s.ctx, media.ID, 2)
s.Require().Error(err)
s.Contains(err.Error(), "unauthorized")
// Test delete by owner
s.storage.EXPECT().
Delete(gomock.Any(), "test-id").
Return(nil)
err = s.svc.Delete(s.ctx, media.ID, 1)
s.Require().NoError(err)
// Verify media is deleted
_, err = s.svc.Get(s.ctx, media.ID)
s.Require().Error(err)
s.Contains(err.Error(), "not found")
}
func (s *MediaServiceTestSuite) TestList() {
// Create test media
for i := 0; i < 5; i++ {
_, err := s.client.Media.Create().
SetStorageID(fmt.Sprintf("test-id-%d", i)).
SetOriginalName(fmt.Sprintf("test-%d.txt", i)).
SetMimeType("text/plain").
SetSize(12).
SetURL(fmt.Sprintf("/api/media/test-id-%d", i)).
SetCreatedBy("1").
Save(s.ctx)
s.Require().NoError(err)
}
// Test list with limit and offset
media, err := s.svc.List(s.ctx, 3, 1)
s.Require().NoError(err)
s.Len(media, 3)
}
func (s *MediaServiceTestSuite) TestGetFile() {
// Create test media
media, err := s.client.Media.Create().
SetStorageID("test-id").
SetOriginalName("test.txt").
SetMimeType("text/plain").
SetSize(12).
SetURL("/api/media/test-id").
SetCreatedBy("1").
Save(s.ctx)
s.Require().NoError(err)
// Mock storage.Get
mockReader := io.NopCloser(bytes.NewReader([]byte("test content")))
mockFileInfo := &storage.FileInfo{
ID: "test-id",
Name: "test.txt",
ContentType: "text/plain",
Size: 12,
}
s.storage.EXPECT().
Get(gomock.Any(), "test-id").
Return(mockReader, mockFileInfo, nil)
// Test get file
reader, info, err := s.svc.GetFile(s.ctx, media.ID)
s.Require().NoError(err)
s.NotNil(reader)
s.Equal(mockFileInfo, info)
// Test get non-existing file
_, _, err = s.svc.GetFile(s.ctx, -1)
s.Require().Error(err)
s.Contains(err.Error(), "not found")
}
func (s *MediaServiceTestSuite) TestIsValidFilename() {
testCases := []struct {
name string
filename string
want bool
}{
{
name: "Valid filename",
filename: "test.txt",
want: true,
},
{
name: "Invalid filename with ../",
filename: "../test.txt",
want: false,
},
{
name: "Invalid filename with ./",
filename: "./test.txt",
want: false,
},
{
name: "Invalid filename with backslash",
filename: "test\\file.txt",
want: false,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
got := isValidFilename(tc.filename)
s.Equal(tc.want, got)
})
}
}

View file

@ -0,0 +1,3 @@
package mock
//go:generate mockgen -source=../service.go -destination=mock_service.go -package=mock

View file

@ -0,0 +1,105 @@
package service
import (
"context"
"fmt"
"tss-rocks-be/ent"
"tss-rocks-be/ent/permission"
"tss-rocks-be/ent/role"
)
type RBACService struct {
client *ent.Client
}
func NewRBACService(client *ent.Client) *RBACService {
return &RBACService{
client: client,
}
}
// InitializeRBAC sets up the initial RBAC configuration
func (s *RBACService) InitializeRBAC(ctx context.Context) error {
// Create admin role if it doesn't exist
adminRole, err := s.client.Role.Query().
Where(role.Name("admin")).
Only(ctx)
if ent.IsNotFound(err) {
adminRole, err = s.client.Role.Create().
SetName("admin").
Save(ctx)
if err != nil {
return fmt.Errorf("failed to create admin role: %w", err)
}
} else if err != nil {
return fmt.Errorf("failed to query admin role: %w", err)
}
// Create editor role if it doesn't exist
editorRole, err := s.client.Role.Query().
Where(role.Name("editor")).
Only(ctx)
if ent.IsNotFound(err) {
editorRole, err = s.client.Role.Create().
SetName("editor").
Save(ctx)
if err != nil {
return fmt.Errorf("failed to create editor role: %w", err)
}
} else if err != nil {
return fmt.Errorf("failed to query editor role: %w", err)
}
// Define permissions
permissions := []struct {
role *ent.Role
resource string
actions []string
}{
{adminRole, "users", []string{"create", "read", "update", "delete", "assign_role"}},
{adminRole, "roles", []string{"create", "read", "update", "delete"}},
{adminRole, "media", []string{"create", "read", "update", "delete"}},
{adminRole, "posts", []string{"create", "read", "update", "delete"}},
{adminRole, "categories", []string{"create", "read", "update", "delete"}},
{adminRole, "contributors", []string{"create", "read", "update", "delete"}},
{adminRole, "dailies", []string{"create", "read", "update", "delete"}},
{editorRole, "media", []string{"create", "read", "update"}},
{editorRole, "posts", []string{"create", "read", "update"}},
{editorRole, "categories", []string{"read"}},
{editorRole, "contributors", []string{"read"}},
{editorRole, "dailies", []string{"create", "read", "update"}},
}
// Create permissions for each role
for _, p := range permissions {
for _, action := range p.actions {
// Check if permission already exists
exists, err := s.client.Permission.Query().
Where(
permission.Resource(p.resource),
permission.Action(action),
permission.HasRolesWith(role.ID(p.role.ID)),
).
Exist(ctx)
if err != nil {
return fmt.Errorf("failed to query permission: %w", err)
}
if !exists {
// Create permission and associate it with the role
_, err = s.client.Permission.Create().
SetResource(p.resource).
SetAction(action).
AddRoles(p.role).
Save(ctx)
if err != nil {
return fmt.Errorf("failed to create permission: %w", err)
}
}
}
}
return nil
}

View file

@ -0,0 +1,59 @@
package service
//go:generate mockgen -source=service.go -destination=mock/mock_service.go -package=mock
import (
"context"
"io"
"mime/multipart"
"tss-rocks-be/ent"
"tss-rocks-be/internal/storage"
)
// Service interface defines all business logic operations
type Service interface {
// User operations
CreateUser(ctx context.Context, email, password string, role string) (*ent.User, error)
GetUserByEmail(ctx context.Context, email string) (*ent.User, error)
ValidatePassword(ctx context.Context, user *ent.User, password string) bool
// Category operations
CreateCategory(ctx context.Context) (*ent.Category, error)
AddCategoryContent(ctx context.Context, categoryID int, langCode, name, description, slug string) (*ent.CategoryContent, error)
GetCategoryBySlug(ctx context.Context, langCode, slug string) (*ent.Category, error)
ListCategories(ctx context.Context, langCode string) ([]*ent.Category, error)
GetCategories(ctx context.Context, langCode string) ([]*ent.Category, error)
// Post operations
CreatePost(ctx context.Context, status string) (*ent.Post, error)
AddPostContent(ctx context.Context, postID int, langCode, title, content, summary string, metaKeywords, metaDescription string) (*ent.PostContent, error)
GetPostBySlug(ctx context.Context, langCode, slug string) (*ent.Post, error)
ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error)
// Contributor operations
CreateContributor(ctx context.Context, name string, avatarURL, bio *string) (*ent.Contributor, error)
AddContributorSocialLink(ctx context.Context, contributorID int, linkType, name, value string) (*ent.ContributorSocialLink, error)
GetContributorByID(ctx context.Context, id int) (*ent.Contributor, error)
ListContributors(ctx context.Context) ([]*ent.Contributor, error)
// Daily operations
CreateDaily(ctx context.Context, id string, categoryID int, imageURL string) (*ent.Daily, error)
AddDailyContent(ctx context.Context, dailyID string, langCode, quote string) (*ent.DailyContent, error)
GetDailyByID(ctx context.Context, id string) (*ent.Daily, error)
ListDailies(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Daily, error)
// Media operations
ListMedia(ctx context.Context, limit, offset int) ([]*ent.Media, error)
Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error)
GetMedia(ctx context.Context, id int) (*ent.Media, error)
GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error)
DeleteMedia(ctx context.Context, id int, userID int) error
// RBAC operations
AssignRole(ctx context.Context, userID int, role string) error
RemoveRole(ctx context.Context, userID int, role string) error
GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error)
HasPermission(ctx context.Context, userID int, permission string) (bool, error)
InitializeRBAC(ctx context.Context) error
}

View file

@ -0,0 +1,66 @@
package storage
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"tss-rocks-be/internal/config"
)
// NewStorage creates a new storage instance based on the configuration
func NewStorage(ctx context.Context, cfg *config.StorageConfig) (Storage, error) {
switch cfg.Type {
case "local":
return NewLocalStorage(cfg.Local.RootDir)
case "s3":
// Load AWS configuration
var s3Client *s3.Client
if cfg.S3.Endpoint != "" {
// Custom endpoint (e.g., MinIO)
customResolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
return aws.Endpoint{
URL: cfg.S3.Endpoint,
}, nil
})
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
awsconfig.WithRegion(cfg.S3.Region),
awsconfig.WithEndpointResolverWithOptions(customResolver),
awsconfig.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
cfg.S3.AccessKeyID,
cfg.S3.SecretAccessKey,
"",
)),
)
if err != nil {
return nil, fmt.Errorf("unable to load AWS SDK config: %w", err)
}
s3Client = s3.NewFromConfig(awsCfg)
} else {
// Standard AWS S3
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
awsconfig.WithRegion(cfg.S3.Region),
awsconfig.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
cfg.S3.AccessKeyID,
cfg.S3.SecretAccessKey,
"",
)),
)
if err != nil {
return nil, fmt.Errorf("unable to load AWS SDK config: %w", err)
}
s3Client = s3.NewFromConfig(awsCfg)
}
return NewS3Storage(s3Client, cfg.S3.Bucket, cfg.S3.CustomURL, cfg.S3.ProxyS3), nil
default:
return nil, fmt.Errorf("unsupported storage type: %s", cfg.Type)
}
}

View file

@ -0,0 +1,260 @@
package storage
import (
"bytes"
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"time"
)
type LocalStorage struct {
rootDir string
metaDir string
}
func NewLocalStorage(rootDir string) (*LocalStorage, error) {
// Ensure the root directory exists
if err := os.MkdirAll(rootDir, 0755); err != nil {
return nil, fmt.Errorf("failed to create root directory: %w", err)
}
// Create metadata directory
metaDir := filepath.Join(rootDir, ".meta")
if err := os.MkdirAll(metaDir, 0755); err != nil {
return nil, fmt.Errorf("failed to create metadata directory: %w", err)
}
return &LocalStorage{
rootDir: rootDir,
metaDir: metaDir,
}, nil
}
func (s *LocalStorage) generateID() (string, error) {
bytes := make([]byte, 16)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
func (s *LocalStorage) saveMetadata(id string, info *FileInfo) error {
metaPath := filepath.Join(s.metaDir, id+".meta")
file, err := os.Create(metaPath)
if err != nil {
return fmt.Errorf("failed to create metadata file: %w", err)
}
defer file.Close()
data := fmt.Sprintf("%s\n%s", info.Name, info.ContentType)
if _, err := file.WriteString(data); err != nil {
return fmt.Errorf("failed to write metadata: %w", err)
}
return nil
}
func (s *LocalStorage) loadMetadata(id string) (string, string, error) {
metaPath := filepath.Join(s.metaDir, id+".meta")
data, err := os.ReadFile(metaPath)
if err != nil {
if os.IsNotExist(err) {
return id, "", nil // Return ID as name if metadata doesn't exist
}
return "", "", fmt.Errorf("failed to read metadata: %w", err)
}
parts := bytes.Split(data, []byte("\n"))
name := string(parts[0])
contentType := ""
if len(parts) > 1 {
contentType = string(parts[1])
}
return name, contentType, nil
}
func (s *LocalStorage) Save(ctx context.Context, name string, contentType string, reader io.Reader) (*FileInfo, error) {
if reader == nil {
return nil, fmt.Errorf("reader cannot be nil")
}
// Generate a unique ID for the file
id, err := s.generateID()
if err != nil {
return nil, fmt.Errorf("failed to generate file ID: %w", err)
}
// Create the file path
filePath := filepath.Join(s.rootDir, id)
// Create the file
file, err := os.Create(filePath)
if err != nil {
return nil, fmt.Errorf("failed to create file: %w", err)
}
defer file.Close()
// Copy the content
size, err := io.Copy(file, reader)
if err != nil {
// Clean up the file if there's an error
os.Remove(filePath)
return nil, fmt.Errorf("failed to write file content: %w", err)
}
now := time.Now()
info := &FileInfo{
ID: id,
Name: name,
Size: size,
ContentType: contentType,
CreatedAt: now,
UpdatedAt: now,
URL: fmt.Sprintf("/api/media/file/%s", id),
}
// Save metadata
if err := s.saveMetadata(id, info); err != nil {
os.Remove(filePath)
return nil, err
}
return info, nil
}
func (s *LocalStorage) Get(ctx context.Context, id string) (io.ReadCloser, *FileInfo, error) {
filePath := filepath.Join(s.rootDir, id)
// Open the file
file, err := os.Open(filePath)
if err != nil {
if os.IsNotExist(err) {
return nil, nil, fmt.Errorf("file not found: %s", id)
}
return nil, nil, fmt.Errorf("failed to open file: %w", err)
}
// Get file info
stat, err := file.Stat()
if err != nil {
file.Close()
return nil, nil, fmt.Errorf("failed to get file info: %w", err)
}
// Load metadata
name, contentType, err := s.loadMetadata(id)
if err != nil {
file.Close()
return nil, nil, err
}
info := &FileInfo{
ID: id,
Name: name,
Size: stat.Size(),
ContentType: contentType,
CreatedAt: stat.ModTime(),
UpdatedAt: stat.ModTime(),
URL: fmt.Sprintf("/api/media/file/%s", id),
}
return file, info, nil
}
func (s *LocalStorage) Delete(ctx context.Context, id string) error {
filePath := filepath.Join(s.rootDir, id)
if err := os.Remove(filePath); err != nil {
if os.IsNotExist(err) {
return fmt.Errorf("file not found: %s", id)
}
return fmt.Errorf("failed to delete file: %w", err)
}
// Remove metadata
metaPath := filepath.Join(s.metaDir, id+".meta")
if err := os.Remove(metaPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove metadata: %w", err)
}
return nil
}
func (s *LocalStorage) List(ctx context.Context, prefix string, limit int, offset int) ([]*FileInfo, error) {
var files []*FileInfo
var count int
err := filepath.Walk(s.rootDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// Skip directories and metadata directory
if info.IsDir() || path == s.metaDir {
if path == s.metaDir {
return filepath.SkipDir
}
return nil
}
// Get the file ID (basename of the path)
id := filepath.Base(path)
// Load metadata to get the original name
name, contentType, err := s.loadMetadata(id)
if err != nil {
return err
}
// Skip files that don't match the prefix
if prefix != "" && !strings.HasPrefix(name, prefix) {
return nil
}
// Skip files before offset
if count < offset {
count++
return nil
}
// Stop if we've reached the limit
if limit > 0 && len(files) >= limit {
return filepath.SkipDir
}
files = append(files, &FileInfo{
ID: id,
Name: name,
Size: info.Size(),
ContentType: contentType,
CreatedAt: info.ModTime(),
UpdatedAt: info.ModTime(),
URL: fmt.Sprintf("/api/media/file/%s", id),
})
count++
return nil
})
if err != nil {
return nil, fmt.Errorf("failed to list files: %w", err)
}
return files, nil
}
func (s *LocalStorage) Exists(ctx context.Context, id string) (bool, error) {
filePath := filepath.Join(s.rootDir, id)
_, err := os.Stat(filePath)
if err == nil {
return true, nil
}
if os.IsNotExist(err) {
return false, nil
}
return false, fmt.Errorf("failed to check file existence: %w", err)
}

View file

@ -0,0 +1,154 @@
package storage
import (
"bytes"
"context"
"io"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLocalStorage(t *testing.T) {
// Create a temporary directory for testing
tempDir, err := os.MkdirTemp("", "storage_test_*")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
// Create a new LocalStorage instance
storage, err := NewLocalStorage(tempDir)
require.NoError(t, err)
ctx := context.Background()
t.Run("Save and Get", func(t *testing.T) {
content := []byte("test content")
reader := bytes.NewReader(content)
// Save the file
fileInfo, err := storage.Save(ctx, "test.txt", "text/plain", reader)
require.NoError(t, err)
assert.NotEmpty(t, fileInfo.ID)
assert.Equal(t, "test.txt", fileInfo.Name)
assert.Equal(t, int64(len(content)), fileInfo.Size)
assert.Equal(t, "text/plain", fileInfo.ContentType)
assert.False(t, fileInfo.CreatedAt.IsZero())
// Get the file
readCloser, info, err := storage.Get(ctx, fileInfo.ID)
require.NoError(t, err)
defer readCloser.Close()
data, err := io.ReadAll(readCloser)
require.NoError(t, err)
assert.Equal(t, content, data)
assert.Equal(t, fileInfo.ID, info.ID)
assert.Equal(t, fileInfo.Name, info.Name)
assert.Equal(t, fileInfo.Size, info.Size)
})
t.Run("List", func(t *testing.T) {
// Clear the directory first
dirEntries, err := os.ReadDir(tempDir)
require.NoError(t, err)
for _, entry := range dirEntries {
if entry.Name() != ".meta" {
os.Remove(filepath.Join(tempDir, entry.Name()))
}
}
// Save multiple files
testFiles := []struct {
name string
content string
}{
{"test1.txt", "content1"},
{"test2.txt", "content2"},
{"other.txt", "content3"},
}
for _, f := range testFiles {
reader := bytes.NewReader([]byte(f.content))
_, err := storage.Save(ctx, f.name, "text/plain", reader)
require.NoError(t, err)
}
// List all files
allFiles, err := storage.List(ctx, "", 10, 0)
require.NoError(t, err)
assert.Len(t, allFiles, 3)
// List files with prefix
filesWithPrefix, err := storage.List(ctx, "test", 10, 0)
require.NoError(t, err)
assert.Len(t, filesWithPrefix, 2)
for _, f := range filesWithPrefix {
assert.True(t, strings.HasPrefix(f.Name, "test"))
}
// Test pagination
pagedFiles, err := storage.List(ctx, "", 2, 1)
require.NoError(t, err)
assert.Len(t, pagedFiles, 2)
})
t.Run("Exists", func(t *testing.T) {
// Save a file
content := []byte("test content")
reader := bytes.NewReader(content)
fileInfo, err := storage.Save(ctx, "exists.txt", "text/plain", reader)
require.NoError(t, err)
// Check if file exists
exists, err := storage.Exists(ctx, fileInfo.ID)
require.NoError(t, err)
assert.True(t, exists)
// Check non-existent file
exists, err = storage.Exists(ctx, "non-existent")
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("Delete", func(t *testing.T) {
// Save a file
content := []byte("test content")
reader := bytes.NewReader(content)
fileInfo, err := storage.Save(ctx, "delete.txt", "text/plain", reader)
require.NoError(t, err)
// Delete the file
err = storage.Delete(ctx, fileInfo.ID)
require.NoError(t, err)
// Verify file is deleted
exists, err := storage.Exists(ctx, fileInfo.ID)
require.NoError(t, err)
assert.False(t, exists)
// Try to delete non-existent file
err = storage.Delete(ctx, "non-existent")
assert.Error(t, err)
})
t.Run("Invalid operations", func(t *testing.T) {
// Try to get non-existent file
_, _, err := storage.Get(ctx, "non-existent")
assert.Error(t, err)
assert.Contains(t, err.Error(), "file not found")
// Try to save file with nil reader
_, err = storage.Save(ctx, "test.txt", "text/plain", nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "reader cannot be nil")
// Try to delete non-existent file
err = storage.Delete(ctx, "non-existent")
assert.Error(t, err)
assert.Contains(t, err.Error(), "file not found")
})
}

View file

@ -0,0 +1,232 @@
package storage
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"io"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
)
type S3Storage struct {
client s3Client
bucket string
customURL string
proxyS3 bool
}
// s3Client is the interface that wraps the basic S3 client operations we need
type s3Client interface {
PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error)
GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error)
DeleteObject(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error)
ListObjectsV2(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error)
HeadObject(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error)
}
func NewS3Storage(client s3Client, bucket string, customURL string, proxyS3 bool) *S3Storage {
return &S3Storage{
client: client,
bucket: bucket,
customURL: customURL,
proxyS3: proxyS3,
}
}
func (s *S3Storage) generateID() (string, error) {
bytes := make([]byte, 16)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
func (s *S3Storage) getObjectURL(id string) string {
if s.customURL != "" {
return fmt.Sprintf("%s/%s", strings.TrimRight(s.customURL, "/"), id)
}
if s.proxyS3 {
return fmt.Sprintf("/api/media/file/%s", id)
}
return fmt.Sprintf("https://%s.s3.amazonaws.com/%s", s.bucket, id)
}
func (s *S3Storage) Save(ctx context.Context, name string, contentType string, reader io.Reader) (*FileInfo, error) {
// Generate a unique ID for the file
id, err := s.generateID()
if err != nil {
return nil, fmt.Errorf("failed to generate file ID: %w", err)
}
// Check if the file exists
_, err = s.client.HeadObject(ctx, &s3.HeadObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(id),
})
if err == nil {
return nil, fmt.Errorf("file already exists with ID: %s", id)
}
var noSuchKey *types.NoSuchKey
if !errors.As(err, &noSuchKey) {
return nil, fmt.Errorf("failed to check if file exists: %w", err)
}
// Upload the file
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(id),
Body: reader,
ContentType: aws.String(contentType),
Metadata: map[string]string{
"x-amz-meta-original-name": name,
},
})
if err != nil {
return nil, fmt.Errorf("failed to upload file: %w", err)
}
now := time.Now()
info := &FileInfo{
ID: id,
Name: name,
Size: 0, // Size is not available until after upload
ContentType: contentType,
CreatedAt: now,
UpdatedAt: now,
URL: s.getObjectURL(id),
}
return info, nil
}
func (s *S3Storage) Get(ctx context.Context, id string) (io.ReadCloser, *FileInfo, error) {
// Get the object from S3
result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(id),
})
if err != nil {
return nil, nil, fmt.Errorf("failed to get file from S3: %w", err)
}
info := &FileInfo{
ID: id,
Name: result.Metadata["x-amz-meta-original-name"],
Size: aws.ToInt64(result.ContentLength),
ContentType: aws.ToString(result.ContentType),
CreatedAt: aws.ToTime(result.LastModified),
UpdatedAt: aws.ToTime(result.LastModified),
URL: s.getObjectURL(id),
}
return result.Body, info, nil
}
func (s *S3Storage) Delete(ctx context.Context, id string) error {
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(id),
})
if err != nil {
return fmt.Errorf("failed to delete file from S3: %w", err)
}
return nil
}
func (s *S3Storage) List(ctx context.Context, prefix string, limit int, offset int) ([]*FileInfo, error) {
var files []*FileInfo
var continuationToken *string
// Skip objects for offset
for i := 0; i < offset/1000; i++ {
output, err := s.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
Bucket: aws.String(s.bucket),
Prefix: aws.String(prefix),
ContinuationToken: continuationToken,
MaxKeys: aws.Int32(1000),
})
if err != nil {
return nil, fmt.Errorf("failed to list files from S3: %w", err)
}
if !aws.ToBool(output.IsTruncated) {
return files, nil
}
continuationToken = output.NextContinuationToken
}
// Get the actual objects
output, err := s.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
Bucket: aws.String(s.bucket),
Prefix: aws.String(prefix),
ContinuationToken: continuationToken,
MaxKeys: aws.Int32(int32(limit)),
})
if err != nil {
return nil, fmt.Errorf("failed to list files from S3: %w", err)
}
for _, obj := range output.Contents {
// Get the object metadata
head, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
Bucket: aws.String(s.bucket),
Key: obj.Key,
})
var contentType string
var originalName string
if err != nil {
var noSuchKey *types.NoSuchKey
if errors.As(err, &noSuchKey) {
// If the object doesn't exist (which shouldn't happen normally),
// we'll still include it in the list but with empty metadata
contentType = ""
originalName = aws.ToString(obj.Key)
} else {
continue
}
} else {
contentType = aws.ToString(head.ContentType)
originalName = head.Metadata["x-amz-meta-original-name"]
if originalName == "" {
originalName = aws.ToString(obj.Key)
}
}
files = append(files, &FileInfo{
ID: aws.ToString(obj.Key),
Name: originalName,
Size: aws.ToInt64(obj.Size),
ContentType: contentType,
CreatedAt: aws.ToTime(obj.LastModified),
UpdatedAt: aws.ToTime(obj.LastModified),
URL: s.getObjectURL(aws.ToString(obj.Key)),
})
}
return files, nil
}
func (s *S3Storage) Exists(ctx context.Context, id string) (bool, error) {
_, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(id),
})
if err != nil {
var nsk *types.NoSuchKey
if ok := errors.As(err, &nsk); ok {
return false, nil
}
return false, fmt.Errorf("failed to check file existence in S3: %w", err)
}
return true, nil
}

View file

@ -0,0 +1,211 @@
package storage
import (
"bytes"
"context"
"io"
"testing"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// MockS3Client is a mock implementation of the S3 client interface
type MockS3Client struct {
mock.Mock
}
func (m *MockS3Client) PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) {
args := m.Called(ctx, params)
return args.Get(0).(*s3.PutObjectOutput), args.Error(1)
}
func (m *MockS3Client) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
args := m.Called(ctx, params)
return args.Get(0).(*s3.GetObjectOutput), args.Error(1)
}
func (m *MockS3Client) DeleteObject(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error) {
args := m.Called(ctx, params)
return args.Get(0).(*s3.DeleteObjectOutput), args.Error(1)
}
func (m *MockS3Client) ListObjectsV2(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) {
args := m.Called(ctx, params)
return args.Get(0).(*s3.ListObjectsV2Output), args.Error(1)
}
func (m *MockS3Client) HeadObject(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) {
args := m.Called(ctx, params)
return args.Get(0).(*s3.HeadObjectOutput), args.Error(1)
}
func TestS3Storage(t *testing.T) {
ctx := context.Background()
mockClient := new(MockS3Client)
storage := NewS3Storage(mockClient, "test-bucket", "", false)
t.Run("Save", func(t *testing.T) {
mockClient.ExpectedCalls = nil
mockClient.Calls = nil
content := []byte("test content")
reader := bytes.NewReader(content)
// Mock HeadObject to return NotFound error
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket"
})).Return(&s3.HeadObjectOutput{}, &types.NoSuchKey{
Message: aws.String("The specified key does not exist."),
})
mockClient.On("PutObject", ctx, mock.MatchedBy(func(input *s3.PutObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.ContentType) == "text/plain"
})).Return(&s3.PutObjectOutput{}, nil)
fileInfo, err := storage.Save(ctx, "test.txt", "text/plain", reader)
require.NoError(t, err)
assert.NotEmpty(t, fileInfo.ID)
assert.Equal(t, "test.txt", fileInfo.Name)
assert.Equal(t, "text/plain", fileInfo.ContentType)
mockClient.AssertExpectations(t)
})
t.Run("Get", func(t *testing.T) {
content := []byte("test content")
mockClient.On("GetObject", ctx, mock.MatchedBy(func(input *s3.GetObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.Key) == "test-id"
})).Return(&s3.GetObjectOutput{
Body: io.NopCloser(bytes.NewReader(content)),
ContentType: aws.String("text/plain"),
ContentLength: aws.Int64(int64(len(content))),
LastModified: aws.Time(time.Now()),
}, nil)
readCloser, info, err := storage.Get(ctx, "test-id")
require.NoError(t, err)
defer readCloser.Close()
data, err := io.ReadAll(readCloser)
require.NoError(t, err)
assert.Equal(t, content, data)
assert.Equal(t, "test-id", info.ID)
assert.Equal(t, int64(len(content)), info.Size)
mockClient.AssertExpectations(t)
})
t.Run("List", func(t *testing.T) {
mockClient.ExpectedCalls = nil
mockClient.Calls = nil
mockClient.On("ListObjectsV2", ctx, mock.MatchedBy(func(input *s3.ListObjectsV2Input) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.Prefix) == "test" &&
aws.ToInt32(input.MaxKeys) == 10
})).Return(&s3.ListObjectsV2Output{
Contents: []types.Object{
{
Key: aws.String("test1"),
Size: aws.Int64(100),
LastModified: aws.Time(time.Now()),
},
{
Key: aws.String("test2"),
Size: aws.Int64(200),
LastModified: aws.Time(time.Now()),
},
},
}, nil)
// Mock HeadObject for both files
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.Key) == "test1"
})).Return(&s3.HeadObjectOutput{
ContentType: aws.String("text/plain"),
Metadata: map[string]string{
"x-amz-meta-original-name": "test1.txt",
},
}, nil).Once()
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.Key) == "test2"
})).Return(&s3.HeadObjectOutput{
ContentType: aws.String("text/plain"),
Metadata: map[string]string{
"x-amz-meta-original-name": "test2.txt",
},
}, nil).Once()
files, err := storage.List(ctx, "test", 10, 0)
require.NoError(t, err)
assert.Len(t, files, 2)
assert.Equal(t, "test1", files[0].ID)
assert.Equal(t, int64(100), files[0].Size)
assert.Equal(t, "test1.txt", files[0].Name)
assert.Equal(t, "text/plain", files[0].ContentType)
mockClient.AssertExpectations(t)
})
t.Run("Delete", func(t *testing.T) {
mockClient.On("DeleteObject", ctx, mock.MatchedBy(func(input *s3.DeleteObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.Key) == "test-id"
})).Return(&s3.DeleteObjectOutput{}, nil)
err := storage.Delete(ctx, "test-id")
require.NoError(t, err)
mockClient.AssertExpectations(t)
})
t.Run("Exists", func(t *testing.T) {
mockClient.ExpectedCalls = nil
mockClient.Calls = nil
// Mock HeadObject for existing file
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.Key) == "test-id"
})).Return(&s3.HeadObjectOutput{}, nil).Once()
exists, err := storage.Exists(ctx, "test-id")
require.NoError(t, err)
assert.True(t, exists)
// Mock HeadObject for non-existing file
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.Key) == "non-existent"
})).Return(&s3.HeadObjectOutput{}, &types.NoSuchKey{
Message: aws.String("The specified key does not exist."),
}).Once()
exists, err = storage.Exists(ctx, "non-existent")
require.NoError(t, err)
assert.False(t, exists)
mockClient.AssertExpectations(t)
})
t.Run("Custom URL", func(t *testing.T) {
customStorage := &S3Storage{
client: mockClient,
bucket: "test-bucket",
customURL: "https://custom.domain",
proxyS3: true,
}
assert.Contains(t, customStorage.getObjectURL("test-id"), "https://custom.domain")
})
}

View file

@ -0,0 +1,38 @@
package storage
//go:generate mockgen -source=storage.go -destination=mock/mock_storage.go -package=mock
import (
"context"
"io"
"time"
)
// FileInfo represents metadata about a stored file
type FileInfo struct {
ID string `json:"id"`
Name string `json:"name"`
Size int64 `json:"size"`
ContentType string `json:"content_type"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
URL string `json:"url"`
}
// Storage defines the interface for file storage operations
type Storage interface {
// Save stores a file and returns its FileInfo
Save(ctx context.Context, name string, contentType string, reader io.Reader) (*FileInfo, error)
// Get retrieves a file by its ID
Get(ctx context.Context, id string) (io.ReadCloser, *FileInfo, error)
// Delete removes a file by its ID
Delete(ctx context.Context, id string) error
// List returns a list of files with optional prefix
List(ctx context.Context, prefix string, limit int, offset int) ([]*FileInfo, error)
// Exists checks if a file exists
Exists(ctx context.Context, id string) (bool, error)
}

View file

@ -0,0 +1,57 @@
package testutil
import (
"context"
"os"
"path/filepath"
"testing"
"entgo.io/ent/dialect"
"github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/require"
"tss-rocks-be/ent"
)
// SetupTestDB creates a new test database and returns a client
func SetupTestDB(t *testing.T) *ent.Client {
// Create a temporary SQLite database for testing
dir := t.TempDir()
dbPath := filepath.Join(dir, "test.db")
client, err := ent.Open(dialect.SQLite, "file:"+dbPath+"?mode=memory&cache=shared&_fk=1")
require.NoError(t, err)
// Run the auto migration tool
err = client.Schema.Create(context.Background())
require.NoError(t, err)
// Clean up the database after the test
t.Cleanup(func() {
client.Close()
os.Remove(dbPath)
})
return client
}
// CleanupTestDB removes all data from the test database
func CleanupTestDB(t *testing.T, client *ent.Client) {
ctx := context.Background()
// Delete all data in reverse order of dependencies
_, err := client.Permission.Delete().Exec(ctx)
require.NoError(t, err)
_, err = client.Role.Delete().Exec(ctx)
require.NoError(t, err)
_, err = client.User.Delete().Exec(ctx)
require.NoError(t, err)
}
// IsSQLiteConstraintError checks if the error is a SQLite constraint error
func IsSQLiteConstraintError(err error) bool {
sqliteErr, ok := err.(sqlite3.Error)
return ok && sqliteErr.Code == sqlite3.ErrConstraint
}

View file

@ -0,0 +1,32 @@
package testutil
import (
"io"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// MockReadCloser is a mock implementation of io.ReadCloser
type MockReadCloser struct {
io.Reader
CloseFunc func() error
}
func (m MockReadCloser) Close() error {
if m.CloseFunc != nil {
return m.CloseFunc()
}
return nil
}
// NewMockReadCloser creates a new MockReadCloser with the given content
func NewMockReadCloser(content string) io.ReadCloser {
return MockReadCloser{Reader: strings.NewReader(content)}
}
// RequireMockEquals asserts that two mocks are equal
func RequireMockEquals(t *testing.T, expected, actual interface{}) {
require.Equal(t, expected, actual)
}

View file

@ -0,0 +1,70 @@
package testutil
import (
"bytes"
"encoding/json"
"io"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tss-rocks-be/ent"
"tss-rocks-be/ent/enttest"
)
// SetupTestRouter returns a new Gin engine for testing
func SetupTestRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
return gin.New()
}
// NewTestClient creates a new ent client for testing
func NewTestClient() *ent.Client {
client := enttest.Open(testing.TB(nil), "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
return client
}
// MakeTestRequest performs a test HTTP request and returns the response
func MakeTestRequest(t *testing.T, router *gin.Engine, method, path string, body interface{}) *httptest.ResponseRecorder {
var reqBody io.Reader
if body != nil {
jsonBytes, err := json.Marshal(body)
require.NoError(t, err)
reqBody = bytes.NewBuffer(jsonBytes)
}
req := httptest.NewRequest(method, path, reqBody)
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
return w
}
// AssertResponse asserts the HTTP response status code and body
func AssertResponse(t *testing.T, w *httptest.ResponseRecorder, expectedStatus int, expectedBody interface{}) {
assert.Equal(t, expectedStatus, w.Code)
if expectedBody != nil {
var actualBody interface{}
err := json.Unmarshal(w.Body.Bytes(), &actualBody)
require.NoError(t, err)
assert.Equal(t, expectedBody, actualBody)
}
}
// AssertErrorResponse asserts an error response with a specific message
func AssertErrorResponse(t *testing.T, w *httptest.ResponseRecorder, expectedStatus int, expectedMessage string) {
assert.Equal(t, expectedStatus, w.Code)
var response struct {
Error string `json:"error"`
}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, expectedMessage, response.Error)
}

View file

@ -0,0 +1,41 @@
package types
// RateLimitConfig 限流配置
type RateLimitConfig struct {
IPRate int `yaml:"ip_rate"` // IP限流速率
IPBurst int `yaml:"ip_burst"` // IP突发请求数
RouteRates map[string]struct {
Rate int `yaml:"rate"` // 路由限流速率
Burst int `yaml:"burst"` // 路由突发请求数
} `yaml:"route_rates"`
}
// AccessLogConfig 访问日志配置
type AccessLogConfig struct {
// 是否启用控制台输出
EnableConsole bool `yaml:"enable_console"`
// 是否启用文件日志
EnableFile bool `yaml:"enable_file"`
// 日志文件路径
FilePath string `yaml:"file_path"`
// 日志格式 (json 或 text)
Format string `yaml:"format"`
// 日志级别
Level string `yaml:"level"`
// 日志轮转配置
Rotation struct {
MaxSize int `yaml:"max_size"` // 每个日志文件的最大大小MB
MaxAge int `yaml:"max_age"` // 保留旧日志文件的最大天数
MaxBackups int `yaml:"max_backups"` // 保留的旧日志文件的最大数量
Compress bool `yaml:"compress"` // 是否压缩旧日志文件
LocalTime bool `yaml:"local_time"` // 使用本地时间作为轮转时间
} `yaml:"rotation"`
}
// UploadConfig 文件上传配置
type UploadConfig struct {
MaxSize int `yaml:"max_size"` // 最大文件大小MB
AllowedTypes []string `yaml:"allowed_types"` // 允许的MIME类型
AllowedExtensions []string `yaml:"allowed_extensions"` // 允许的文件扩展名
}

View file

@ -0,0 +1,116 @@
package types
import (
"testing"
)
func TestRateLimitConfig(t *testing.T) {
config := RateLimitConfig{
IPRate: 100,
IPBurst: 200,
RouteRates: map[string]struct {
Rate int `yaml:"rate"`
Burst int `yaml:"burst"`
}{
"/api/test": {
Rate: 50,
Burst: 100,
},
},
}
if config.IPRate != 100 {
t.Errorf("Expected IPRate 100, got %d", config.IPRate)
}
if config.IPBurst != 200 {
t.Errorf("Expected IPBurst 200, got %d", config.IPBurst)
}
route := config.RouteRates["/api/test"]
if route.Rate != 50 {
t.Errorf("Expected route rate 50, got %d", route.Rate)
}
if route.Burst != 100 {
t.Errorf("Expected route burst 100, got %d", route.Burst)
}
}
func TestAccessLogConfig(t *testing.T) {
config := AccessLogConfig{
EnableConsole: true,
EnableFile: true,
FilePath: "/var/log/app.log",
Format: "json",
Level: "info",
Rotation: struct {
MaxSize int `yaml:"max_size"`
MaxAge int `yaml:"max_age"`
MaxBackups int `yaml:"max_backups"`
Compress bool `yaml:"compress"`
LocalTime bool `yaml:"local_time"`
}{
MaxSize: 100,
MaxAge: 7,
MaxBackups: 5,
Compress: true,
LocalTime: true,
},
}
if !config.EnableConsole {
t.Error("Expected EnableConsole to be true")
}
if !config.EnableFile {
t.Error("Expected EnableFile to be true")
}
if config.FilePath != "/var/log/app.log" {
t.Errorf("Expected FilePath '/var/log/app.log', got '%s'", config.FilePath)
}
if config.Format != "json" {
t.Errorf("Expected Format 'json', got '%s'", config.Format)
}
if config.Level != "info" {
t.Errorf("Expected Level 'info', got '%s'", config.Level)
}
rotation := config.Rotation
if rotation.MaxSize != 100 {
t.Errorf("Expected MaxSize 100, got %d", rotation.MaxSize)
}
if rotation.MaxAge != 7 {
t.Errorf("Expected MaxAge 7, got %d", rotation.MaxAge)
}
if rotation.MaxBackups != 5 {
t.Errorf("Expected MaxBackups 5, got %d", rotation.MaxBackups)
}
if !rotation.Compress {
t.Error("Expected Compress to be true")
}
if !rotation.LocalTime {
t.Error("Expected LocalTime to be true")
}
}
func TestUploadConfig(t *testing.T) {
config := UploadConfig{
MaxSize: 10,
AllowedTypes: []string{"image/jpeg", "image/png"},
AllowedExtensions: []string{".jpg", ".png"},
}
if config.MaxSize != 10 {
t.Errorf("Expected MaxSize 10, got %d", config.MaxSize)
}
if len(config.AllowedTypes) != 2 {
t.Errorf("Expected 2 AllowedTypes, got %d", len(config.AllowedTypes))
}
if config.AllowedTypes[0] != "image/jpeg" {
t.Errorf("Expected AllowedTypes[0] 'image/jpeg', got '%s'", config.AllowedTypes[0])
}
if len(config.AllowedExtensions) != 2 {
t.Errorf("Expected 2 AllowedExtensions, got %d", len(config.AllowedExtensions))
}
if config.AllowedExtensions[0] != ".jpg" {
t.Errorf("Expected AllowedExtensions[0] '.jpg', got '%s'", config.AllowedExtensions[0])
}
}

View file

@ -0,0 +1,8 @@
package types
// FileInfo represents metadata about a file
type FileInfo struct {
Size int64
Name string
ContentType string
}

View file

@ -0,0 +1,21 @@
package types
import "testing"
func TestFileInfo(t *testing.T) {
fileInfo := FileInfo{
Size: 1024,
Name: "test.jpg",
ContentType: "image/jpeg",
}
if fileInfo.Size != 1024 {
t.Errorf("Expected Size 1024, got %d", fileInfo.Size)
}
if fileInfo.Name != "test.jpg" {
t.Errorf("Expected Name 'test.jpg', got '%s'", fileInfo.Name)
}
if fileInfo.ContentType != "image/jpeg" {
t.Errorf("Expected ContentType 'image/jpeg', got '%s'", fileInfo.ContentType)
}
}

View file

@ -0,0 +1,43 @@
package types
import "errors"
// Common errors
var (
ErrNotFound = errors.New("not found")
)
// Category represents a category in the system
type Category struct {
ID int `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
Description *string `json:"description,omitempty"`
}
// Post represents a blog post
type Post struct {
ID int `json:"id"`
Title string `json:"title"`
Slug string `json:"slug"`
ContentMarkdown string `json:"content_markdown"`
Summary string `json:"summary"`
MetaKeywords *string `json:"meta_keywords,omitempty"`
MetaDescription *string `json:"meta_description,omitempty"`
}
// Contributor represents a contributor to the blog
type Contributor struct {
ID int `json:"id"`
Name string `json:"name"`
AvatarURL *string `json:"avatar_url,omitempty"`
Bio *string `json:"bio,omitempty"`
}
// Daily represents a daily quote or message
type Daily struct {
ID string `json:"id"`
CategoryID int `json:"category_id"`
ImageURL string `json:"image_url"`
Quote string `json:"quote"`
}

View file

@ -0,0 +1,77 @@
package types
import (
"testing"
)
func TestCategory(t *testing.T) {
description := "Test Description"
category := Category{
ID: 1,
Name: "Test Category",
Slug: "test-category",
Description: &description,
}
if category.ID != 1 {
t.Errorf("Expected ID 1, got %d", category.ID)
}
if category.Name != "Test Category" {
t.Errorf("Expected name 'Test Category', got '%s'", category.Name)
}
if category.Slug != "test-category" {
t.Errorf("Expected slug 'test-category', got '%s'", category.Slug)
}
if *category.Description != description {
t.Errorf("Expected description '%s', got '%s'", description, *category.Description)
}
}
func TestPost(t *testing.T) {
metaKeywords := "test,blog"
metaDesc := "Test Description"
post := Post{
ID: 1,
Title: "Test Post",
Slug: "test-post",
ContentMarkdown: "# Test Content",
Summary: "Test Summary",
MetaKeywords: &metaKeywords,
MetaDescription: &metaDesc,
}
if post.ID != 1 {
t.Errorf("Expected ID 1, got %d", post.ID)
}
if post.Title != "Test Post" {
t.Errorf("Expected title 'Test Post', got '%s'", post.Title)
}
if post.Slug != "test-post" {
t.Errorf("Expected slug 'test-post', got '%s'", post.Slug)
}
if *post.MetaKeywords != metaKeywords {
t.Errorf("Expected meta keywords '%s', got '%s'", metaKeywords, *post.MetaKeywords)
}
}
func TestDaily(t *testing.T) {
daily := Daily{
ID: "2025-02-12",
CategoryID: 1,
ImageURL: "https://example.com/image.jpg",
Quote: "Test Quote",
}
if daily.ID != "2025-02-12" {
t.Errorf("Expected ID '2025-02-12', got '%s'", daily.ID)
}
if daily.CategoryID != 1 {
t.Errorf("Expected CategoryID 1, got %d", daily.CategoryID)
}
if daily.ImageURL != "https://example.com/image.jpg" {
t.Errorf("Expected ImageURL 'https://example.com/image.jpg', got '%s'", daily.ImageURL)
}
if daily.Quote != "Test Quote" {
t.Errorf("Expected Quote 'Test Quote', got '%s'", daily.Quote)
}
}