[feature] migrate to monorepo
This commit is contained in:
commit
05ddc1f783
267 changed files with 75165 additions and 0 deletions
6
backend/internal/auth/auth.go
Normal file
6
backend/internal/auth/auth.go
Normal file
|
@ -0,0 +1,6 @@
|
|||
package auth
|
||||
|
||||
// Constants for auth-related context keys
|
||||
const (
|
||||
UserIDKey = "user_id"
|
||||
)
|
27
backend/internal/auth/auth_test.go
Normal file
27
backend/internal/auth/auth_test.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUserIDKey(t *testing.T) {
|
||||
// Test that the UserIDKey constant is defined correctly
|
||||
if UserIDKey != "user_id" {
|
||||
t.Errorf("UserIDKey = %v, want %v", UserIDKey, "user_id")
|
||||
}
|
||||
|
||||
// Test context with user ID
|
||||
ctx := context.WithValue(context.Background(), UserIDKey, "test-user-123")
|
||||
value := ctx.Value(UserIDKey)
|
||||
if value != "test-user-123" {
|
||||
t.Errorf("Context value = %v, want %v", value, "test-user-123")
|
||||
}
|
||||
|
||||
// Test context without user ID
|
||||
emptyCtx := context.Background()
|
||||
emptyValue := emptyCtx.Value(UserIDKey)
|
||||
if emptyValue != nil {
|
||||
t.Errorf("Empty context value = %v, want nil", emptyValue)
|
||||
}
|
||||
}
|
74
backend/internal/config/config.go
Normal file
74
backend/internal/config/config.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
"tss-rocks-be/internal/types"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Server ServerConfig `yaml:"server"`
|
||||
JWT JWTConfig `yaml:"jwt"`
|
||||
Storage StorageConfig `yaml:"storage"`
|
||||
Logging LoggingConfig `yaml:"logging"`
|
||||
RateLimit types.RateLimitConfig `yaml:"rate_limit"`
|
||||
AccessLog types.AccessLogConfig `yaml:"access_log"`
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Driver string `yaml:"driver"`
|
||||
DSN string `yaml:"dsn"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Port int `yaml:"port"`
|
||||
Host string `yaml:"host"`
|
||||
}
|
||||
|
||||
type JWTConfig struct {
|
||||
Secret string `yaml:"secret"`
|
||||
Expiration string `yaml:"expiration"`
|
||||
}
|
||||
|
||||
type LoggingConfig struct {
|
||||
Level string `yaml:"level"`
|
||||
Format string `yaml:"format"`
|
||||
}
|
||||
|
||||
type StorageConfig struct {
|
||||
Type string `yaml:"type"`
|
||||
Local LocalStorage `yaml:"local"`
|
||||
S3 S3Storage `yaml:"s3"`
|
||||
Upload types.UploadConfig `yaml:"upload"`
|
||||
}
|
||||
|
||||
type LocalStorage struct {
|
||||
RootDir string `yaml:"root_dir"`
|
||||
}
|
||||
|
||||
type S3Storage struct {
|
||||
Region string `yaml:"region"`
|
||||
Bucket string `yaml:"bucket"`
|
||||
AccessKeyID string `yaml:"access_key_id"`
|
||||
SecretAccessKey string `yaml:"secret_access_key"`
|
||||
Endpoint string `yaml:"endpoint"`
|
||||
CustomURL string `yaml:"custom_url"`
|
||||
ProxyS3 bool `yaml:"proxy_s3"`
|
||||
}
|
||||
|
||||
// Load loads configuration from a YAML file
|
||||
func Load(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
85
backend/internal/config/config_test.go
Normal file
85
backend/internal/config/config_test.go
Normal file
|
@ -0,0 +1,85 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
// Create a temporary test config file
|
||||
content := []byte(`
|
||||
database:
|
||||
driver: postgres
|
||||
dsn: postgres://user:pass@localhost:5432/dbname
|
||||
server:
|
||||
port: 8080
|
||||
host: localhost
|
||||
jwt:
|
||||
secret: test-secret
|
||||
expiration: 24h
|
||||
storage:
|
||||
type: local
|
||||
local:
|
||||
root_dir: /tmp/storage
|
||||
upload:
|
||||
max_size: 10485760
|
||||
logging:
|
||||
level: info
|
||||
format: json
|
||||
`)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
if err := os.WriteFile(configPath, content, 0644); err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
// Test loading config
|
||||
cfg, err := Load(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify loaded values
|
||||
tests := []struct {
|
||||
name string
|
||||
got interface{}
|
||||
want interface{}
|
||||
errorMsg string
|
||||
}{
|
||||
{"Database Driver", cfg.Database.Driver, "postgres", "incorrect database driver"},
|
||||
{"Server Port", cfg.Server.Port, 8080, "incorrect server port"},
|
||||
{"JWT Secret", cfg.JWT.Secret, "test-secret", "incorrect JWT secret"},
|
||||
{"Storage Type", cfg.Storage.Type, "local", "incorrect storage type"},
|
||||
{"Logging Level", cfg.Logging.Level, "info", "incorrect logging level"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.got != tt.want {
|
||||
t.Errorf("%s = %v, want %v", tt.name, tt.got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadError(t *testing.T) {
|
||||
// Test loading non-existent file
|
||||
_, err := Load("non-existent-file.yaml")
|
||||
if err == nil {
|
||||
t.Error("Load() error = nil, want error for non-existent file")
|
||||
}
|
||||
|
||||
// Test loading invalid YAML
|
||||
tmpDir := t.TempDir()
|
||||
invalidPath := filepath.Join(tmpDir, "invalid.yaml")
|
||||
if err := os.WriteFile(invalidPath, []byte("invalid: }{yaml"), 0644); err != nil {
|
||||
t.Fatalf("Failed to write invalid config: %v", err)
|
||||
}
|
||||
|
||||
_, err = Load(invalidPath)
|
||||
if err == nil {
|
||||
t.Error("Load() error = nil, want error for invalid YAML")
|
||||
}
|
||||
}
|
119
backend/internal/handler/auth.go
Normal file
119
backend/internal/handler/auth.go
Normal file
|
@ -0,0 +1,119 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type RegisterRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=8"`
|
||||
Role string `json:"role" binding:"required,oneof=admin editor contributor"`
|
||||
}
|
||||
|
||||
type LoginRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
type AuthResponse struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
func (h *Handler) Register(c *gin.Context) {
|
||||
var req RegisterRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.service.CreateUser(c.Request.Context(), req.Email, req.Password, req.Role)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to create user")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user roles
|
||||
roles, err := h.service.GetUserRoles(c.Request.Context(), user.ID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get user roles")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user roles"})
|
||||
return
|
||||
}
|
||||
|
||||
// Extract role names for JWT
|
||||
roleNames := make([]string, len(roles))
|
||||
for i, r := range roles {
|
||||
roleNames[i] = r.Name
|
||||
}
|
||||
|
||||
// Generate JWT token
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": user.ID,
|
||||
"roles": roleNames,
|
||||
"exp": time.Now().Add(24 * time.Hour).Unix(),
|
||||
})
|
||||
|
||||
tokenString, err := token.SignedString([]byte(h.cfg.JWT.Secret))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to generate token")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, AuthResponse{Token: tokenString})
|
||||
}
|
||||
|
||||
func (h *Handler) Login(c *gin.Context) {
|
||||
var req LoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.service.GetUserByEmail(c.Request.Context(), req.Email)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid credentials"})
|
||||
return
|
||||
}
|
||||
|
||||
if !h.service.ValidatePassword(c.Request.Context(), user, req.Password) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid credentials"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user roles
|
||||
roles, err := h.service.GetUserRoles(c.Request.Context(), user.ID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get user roles")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user roles"})
|
||||
return
|
||||
}
|
||||
|
||||
// Extract role names for JWT
|
||||
roleNames := make([]string, len(roles))
|
||||
for i, r := range roles {
|
||||
roleNames[i] = r.Name
|
||||
}
|
||||
|
||||
// Generate JWT token
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": user.ID,
|
||||
"roles": roleNames,
|
||||
"exp": time.Now().Add(24 * time.Hour).Unix(),
|
||||
})
|
||||
|
||||
tokenString, err := token.SignedString([]byte(h.cfg.JWT.Secret))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to generate token")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, AuthResponse{Token: tokenString})
|
||||
}
|
276
backend/internal/handler/auth_handler_test.go
Normal file
276
backend/internal/handler/auth_handler_test.go
Normal file
|
@ -0,0 +1,276 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
type AuthHandlerTestSuite struct {
|
||||
suite.Suite
|
||||
ctrl *gomock.Controller
|
||||
service *mock.MockService
|
||||
handler *Handler
|
||||
router *gin.Engine
|
||||
}
|
||||
|
||||
func (s *AuthHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
s.handler = NewHandler(&config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
},
|
||||
}, s.service)
|
||||
s.router = gin.New()
|
||||
}
|
||||
|
||||
func (s *AuthHandlerTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestAuthHandlerSuite(t *testing.T) {
|
||||
suite.Run(t, new(AuthHandlerTestSuite))
|
||||
}
|
||||
|
||||
func (s *AuthHandlerTestSuite) TestRegister() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
request RegisterRequest
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "成功注册",
|
||||
request: RegisterRequest{
|
||||
Email: "test@example.com",
|
||||
Password: "password123",
|
||||
Role: "contributor",
|
||||
},
|
||||
setupMock: func() {
|
||||
user := &ent.User{
|
||||
ID: 1,
|
||||
Email: "test@example.com",
|
||||
}
|
||||
s.service.EXPECT().
|
||||
CreateUser(gomock.Any(), "test@example.com", "password123", "contributor").
|
||||
Return(user, nil)
|
||||
s.service.EXPECT().
|
||||
GetUserRoles(gomock.Any(), user.ID).
|
||||
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "无效的邮箱格式",
|
||||
request: RegisterRequest{
|
||||
Email: "invalid-email",
|
||||
Password: "password123",
|
||||
Role: "contributor",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Key: 'RegisterRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag",
|
||||
},
|
||||
{
|
||||
name: "密码太短",
|
||||
request: RegisterRequest{
|
||||
Email: "test@example.com",
|
||||
Password: "short",
|
||||
Role: "contributor",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Key: 'RegisterRequest.Password' Error:Field validation for 'Password' failed on the 'min' tag",
|
||||
},
|
||||
{
|
||||
name: "无效的角色",
|
||||
request: RegisterRequest{
|
||||
Email: "test@example.com",
|
||||
Password: "password123",
|
||||
Role: "invalid-role",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Key: 'RegisterRequest.Role' Error:Field validation for 'Role' failed on the 'oneof' tag",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
// 创建请求
|
||||
reqBody, _ := json.Marshal(tc.request)
|
||||
req, _ := http.NewRequest(http.MethodPost, "/register", bytes.NewBuffer(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// 执行请求
|
||||
s.handler.Register(c)
|
||||
|
||||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Contains(response["error"], tc.expectedError)
|
||||
} else {
|
||||
var response AuthResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.NotEmpty(response.Token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthHandlerTestSuite) TestLogin() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
request LoginRequest
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "成功登录",
|
||||
request: LoginRequest{
|
||||
Email: "test@example.com",
|
||||
Password: "password123",
|
||||
},
|
||||
setupMock: func() {
|
||||
user := &ent.User{
|
||||
ID: 1,
|
||||
Email: "test@example.com",
|
||||
}
|
||||
s.service.EXPECT().
|
||||
GetUserByEmail(gomock.Any(), "test@example.com").
|
||||
Return(user, nil)
|
||||
s.service.EXPECT().
|
||||
ValidatePassword(gomock.Any(), user, "password123").
|
||||
Return(true)
|
||||
s.service.EXPECT().
|
||||
GetUserRoles(gomock.Any(), user.ID).
|
||||
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "无效的邮箱格式",
|
||||
request: LoginRequest{
|
||||
Email: "invalid-email",
|
||||
Password: "password123",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Key: 'LoginRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag",
|
||||
},
|
||||
{
|
||||
name: "用户不存在",
|
||||
request: LoginRequest{
|
||||
Email: "nonexistent@example.com",
|
||||
Password: "password123",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetUserByEmail(gomock.Any(), "nonexistent@example.com").
|
||||
Return(nil, errors.New("user not found"))
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedError: "Invalid credentials",
|
||||
},
|
||||
{
|
||||
name: "密码错误",
|
||||
request: LoginRequest{
|
||||
Email: "test@example.com",
|
||||
Password: "wrong-password",
|
||||
},
|
||||
setupMock: func() {
|
||||
user := &ent.User{
|
||||
ID: 1,
|
||||
Email: "test@example.com",
|
||||
}
|
||||
s.service.EXPECT().
|
||||
GetUserByEmail(gomock.Any(), "test@example.com").
|
||||
Return(user, nil)
|
||||
s.service.EXPECT().
|
||||
ValidatePassword(gomock.Any(), user, "wrong-password").
|
||||
Return(false)
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedError: "Invalid credentials",
|
||||
},
|
||||
{
|
||||
name: "获取用户角色失败",
|
||||
request: LoginRequest{
|
||||
Email: "test@example.com",
|
||||
Password: "password123",
|
||||
},
|
||||
setupMock: func() {
|
||||
user := &ent.User{
|
||||
ID: 1,
|
||||
Email: "test@example.com",
|
||||
}
|
||||
s.service.EXPECT().
|
||||
GetUserByEmail(gomock.Any(), "test@example.com").
|
||||
Return(user, nil)
|
||||
s.service.EXPECT().
|
||||
ValidatePassword(gomock.Any(), user, "password123").
|
||||
Return(true)
|
||||
s.service.EXPECT().
|
||||
GetUserRoles(gomock.Any(), user.ID).
|
||||
Return(nil, errors.New("failed to get roles"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to get user roles",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
// 创建请求
|
||||
reqBody, _ := json.Marshal(tc.request)
|
||||
req, _ := http.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// 执行请求
|
||||
s.handler.Login(c)
|
||||
|
||||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Contains(response["error"], tc.expectedError)
|
||||
} else {
|
||||
var response AuthResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.NotEmpty(response.Token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
468
backend/internal/handler/category_handler_test.go
Normal file
468
backend/internal/handler/category_handler_test.go
Normal file
|
@ -0,0 +1,468 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/ent/categorycontent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
"tss-rocks-be/internal/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"errors"
|
||||
)
|
||||
|
||||
// Custom assertion function for comparing categories
|
||||
func assertCategoryEqual(t assert.TestingT, expected, actual *ent.Category) bool {
|
||||
if expected == nil && actual == nil {
|
||||
return true
|
||||
}
|
||||
if expected == nil || actual == nil {
|
||||
return assert.Fail(t, "One category is nil while the other is not")
|
||||
}
|
||||
|
||||
// Compare only relevant fields, ignoring time fields
|
||||
return assert.Equal(t, expected.ID, actual.ID) &&
|
||||
assert.Equal(t, expected.Edges.Contents, actual.Edges.Contents)
|
||||
}
|
||||
|
||||
// Custom assertion function for comparing category slices
|
||||
func assertCategorySliceEqual(t assert.TestingT, expected, actual []*ent.Category) bool {
|
||||
if len(expected) != len(actual) {
|
||||
return assert.Fail(t, "Category slice lengths do not match")
|
||||
}
|
||||
|
||||
for i := range expected {
|
||||
if !assertCategoryEqual(t, expected[i], actual[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type CategoryHandlerTestSuite struct {
|
||||
suite.Suite
|
||||
ctrl *gomock.Controller
|
||||
service *mock.MockService
|
||||
handler *Handler
|
||||
router *gin.Engine
|
||||
}
|
||||
|
||||
func (s *CategoryHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
cfg := &config.Config{}
|
||||
s.handler = NewHandler(cfg, s.service)
|
||||
|
||||
// Setup Gin router
|
||||
gin.SetMode(gin.TestMode)
|
||||
s.router = gin.New()
|
||||
s.handler.RegisterRoutes(s.router)
|
||||
}
|
||||
|
||||
func (s *CategoryHandlerTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestCategoryHandlerSuite(t *testing.T) {
|
||||
suite.Run(t, new(CategoryHandlerTestSuite))
|
||||
}
|
||||
|
||||
// Test cases for ListCategories
|
||||
func (s *CategoryHandlerTestSuite) TestListCategories() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
langCode string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success with default language",
|
||||
langCode: "",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListCategories(gomock.Any(), gomock.Eq("en")).
|
||||
Return([]*ent.Category{
|
||||
{
|
||||
ID: 1,
|
||||
Edges: ent.CategoryEdges{
|
||||
Contents: []*ent.CategoryContent{
|
||||
{
|
||||
LanguageCode: categorycontent.LanguageCode("en"),
|
||||
Name: "Test Category",
|
||||
Description: "Test Description",
|
||||
Slug: "test-category",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Category{
|
||||
{
|
||||
ID: 1,
|
||||
Edges: ent.CategoryEdges{
|
||||
Contents: []*ent.CategoryContent{
|
||||
{
|
||||
LanguageCode: categorycontent.LanguageCode("en"),
|
||||
Name: "Test Category",
|
||||
Description: "Test Description",
|
||||
Slug: "test-category",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with specific language",
|
||||
langCode: "zh",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListCategories(gomock.Any(), gomock.Eq("zh")).
|
||||
Return([]*ent.Category{
|
||||
{
|
||||
ID: 1,
|
||||
Edges: ent.CategoryEdges{
|
||||
Contents: []*ent.CategoryContent{
|
||||
{
|
||||
LanguageCode: categorycontent.LanguageCode("zh"),
|
||||
Name: "测试分类",
|
||||
Description: "测试描述",
|
||||
Slug: "test-category",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Category{
|
||||
{
|
||||
ID: 1,
|
||||
Edges: ent.CategoryEdges{
|
||||
Contents: []*ent.CategoryContent{
|
||||
{
|
||||
LanguageCode: categorycontent.LanguageCode("zh"),
|
||||
Name: "测试分类",
|
||||
Description: "测试描述",
|
||||
Slug: "test-category",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// Setup mock
|
||||
tc.setupMock()
|
||||
|
||||
// Create request
|
||||
url := "/api/v1/categories"
|
||||
if tc.langCode != "" {
|
||||
url += "?lang=" + tc.langCode
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Perform request
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
// Assert response
|
||||
assert.Equal(s.T(), tc.expectedStatus, w.Code)
|
||||
if tc.expectedBody != nil {
|
||||
var response []*ent.Category
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(s.T(), err)
|
||||
assertCategorySliceEqual(s.T(), tc.expectedBody.([]*ent.Category), response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases for GetCategory
|
||||
func (s *CategoryHandlerTestSuite) TestGetCategory() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
langCode string
|
||||
slug string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
langCode: "en",
|
||||
slug: "test-category",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetCategoryBySlug(gomock.Any(), gomock.Eq("en"), gomock.Eq("test-category")).
|
||||
Return(&ent.Category{
|
||||
ID: 1,
|
||||
Edges: ent.CategoryEdges{
|
||||
Contents: []*ent.CategoryContent{
|
||||
{
|
||||
LanguageCode: categorycontent.LanguageCode("en"),
|
||||
Name: "Test Category",
|
||||
Description: "Test Description",
|
||||
Slug: "test-category",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: &ent.Category{
|
||||
ID: 1,
|
||||
Edges: ent.CategoryEdges{
|
||||
Contents: []*ent.CategoryContent{
|
||||
{
|
||||
LanguageCode: categorycontent.LanguageCode("en"),
|
||||
Name: "Test Category",
|
||||
Description: "Test Description",
|
||||
Slug: "test-category",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Not Found",
|
||||
langCode: "en",
|
||||
slug: "non-existent",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetCategoryBySlug(gomock.Any(), gomock.Eq("en"), gomock.Eq("non-existent")).
|
||||
Return(nil, types.ErrNotFound)
|
||||
},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// Setup mock
|
||||
tc.setupMock()
|
||||
|
||||
// Create request
|
||||
url := "/api/v1/categories/" + tc.slug
|
||||
if tc.langCode != "" {
|
||||
url += "?lang=" + tc.langCode
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Perform request
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
// Assert response
|
||||
assert.Equal(s.T(), tc.expectedStatus, w.Code)
|
||||
if tc.expectedBody != nil {
|
||||
var response ent.Category
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(s.T(), err)
|
||||
assertCategoryEqual(s.T(), tc.expectedBody.(*ent.Category), &response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases for AddCategoryContent
|
||||
func (s *CategoryHandlerTestSuite) TestAddCategoryContent() {
|
||||
var description = "Test Description"
|
||||
testCases := []struct {
|
||||
name string
|
||||
categoryID string
|
||||
requestBody interface{}
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
categoryID: "1",
|
||||
requestBody: AddCategoryContentRequest{
|
||||
LanguageCode: "en",
|
||||
Name: "Test Category",
|
||||
Description: &description,
|
||||
Slug: "test-category",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
AddCategoryContent(
|
||||
gomock.Any(),
|
||||
1,
|
||||
"en",
|
||||
"Test Category",
|
||||
description,
|
||||
"test-category",
|
||||
).
|
||||
Return(&ent.CategoryContent{
|
||||
LanguageCode: categorycontent.LanguageCode("en"),
|
||||
Name: "Test Category",
|
||||
Description: description,
|
||||
Slug: "test-category",
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: &ent.CategoryContent{
|
||||
LanguageCode: categorycontent.LanguageCode("en"),
|
||||
Name: "Test Category",
|
||||
Description: description,
|
||||
Slug: "test-category",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid JSON",
|
||||
categoryID: "1",
|
||||
requestBody: "invalid json",
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "Invalid Category ID",
|
||||
categoryID: "invalid",
|
||||
requestBody: AddCategoryContentRequest{
|
||||
LanguageCode: "en",
|
||||
Name: "Test Category",
|
||||
Description: &description,
|
||||
Slug: "test-category",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "Service Error",
|
||||
categoryID: "1",
|
||||
requestBody: AddCategoryContentRequest{
|
||||
LanguageCode: "en",
|
||||
Name: "Test Category",
|
||||
Description: &description,
|
||||
Slug: "test-category",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
AddCategoryContent(
|
||||
gomock.Any(),
|
||||
1,
|
||||
"en",
|
||||
"Test Category",
|
||||
description,
|
||||
"test-category",
|
||||
).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// Setup mock
|
||||
tc.setupMock()
|
||||
|
||||
// Create request
|
||||
var body []byte
|
||||
var err error
|
||||
if str, ok := tc.requestBody.(string); ok {
|
||||
body = []byte(str)
|
||||
} else {
|
||||
body, err = json.Marshal(tc.requestBody)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/categories/"+tc.categoryID+"/contents", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Perform request
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
// Assert response
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedBody != nil {
|
||||
var response ent.CategoryContent
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Equal(tc.expectedBody, &response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases for CreateCategory
|
||||
func (s *CategoryHandlerTestSuite) TestCreateCategory() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "成功创建分类",
|
||||
setupMock: func() {
|
||||
category := &ent.Category{
|
||||
ID: 1,
|
||||
}
|
||||
s.service.EXPECT().
|
||||
CreateCategory(gomock.Any()).
|
||||
Return(category, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "创建分类失败",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
CreateCategory(gomock.Any()).
|
||||
Return(nil, errors.New("failed to create category"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to create category",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
// 创建请求
|
||||
req, _ := http.NewRequest(http.MethodPost, "/categories", nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// 执行请求
|
||||
s.handler.CreateCategory(c)
|
||||
|
||||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Equal(tc.expectedError, response["error"])
|
||||
} else {
|
||||
var response *ent.Category
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.NotNil(response)
|
||||
s.Equal(1, response.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
443
backend/internal/handler/contributor_handler_test.go
Normal file
443
backend/internal/handler/contributor_handler_test.go
Normal file
|
@ -0,0 +1,443 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"errors"
|
||||
)
|
||||
|
||||
type ContributorHandlerTestSuite struct {
|
||||
suite.Suite
|
||||
ctrl *gomock.Controller
|
||||
service *mock.MockService
|
||||
handler *Handler
|
||||
router *gin.Engine
|
||||
}
|
||||
|
||||
func (s *ContributorHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
cfg := &config.Config{}
|
||||
s.handler = NewHandler(cfg, s.service)
|
||||
|
||||
// Setup Gin router
|
||||
gin.SetMode(gin.TestMode)
|
||||
s.router = gin.New()
|
||||
s.handler.RegisterRoutes(s.router)
|
||||
}
|
||||
|
||||
func (s *ContributorHandlerTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestContributorHandlerSuite(t *testing.T) {
|
||||
suite.Run(t, new(ContributorHandlerTestSuite))
|
||||
}
|
||||
|
||||
func (s *ContributorHandlerTestSuite) TestListContributors() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListContributors(gomock.Any()).
|
||||
Return([]*ent.Contributor{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "John Doe",
|
||||
Edges: ent.ContributorEdges{
|
||||
SocialLinks: []*ent.ContributorSocialLink{
|
||||
{
|
||||
Type: "github",
|
||||
Value: "https://github.com/johndoe",
|
||||
Edges: ent.ContributorSocialLinkEdges{},
|
||||
},
|
||||
},
|
||||
},
|
||||
CreatedAt: time.Time{},
|
||||
UpdatedAt: time.Time{},
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Name: "Jane Smith",
|
||||
Edges: ent.ContributorEdges{
|
||||
SocialLinks: []*ent.ContributorSocialLink{}, // Ensure empty SocialLinks array is present
|
||||
},
|
||||
CreatedAt: time.Time{},
|
||||
UpdatedAt: time.Time{},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []gin.H{
|
||||
{
|
||||
"id": 1,
|
||||
"name": "John Doe",
|
||||
"created_at": time.Time{},
|
||||
"updated_at": time.Time{},
|
||||
"edges": gin.H{
|
||||
"social_links": []gin.H{
|
||||
{
|
||||
"type": "github",
|
||||
"value": "https://github.com/johndoe",
|
||||
"edges": gin.H{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"name": "Jane Smith",
|
||||
"created_at": time.Time{},
|
||||
"updated_at": time.Time{},
|
||||
"edges": gin.H{
|
||||
"social_links": []gin.H{}, // Ensure empty SocialLinks array is present
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListContributors(gomock.Any()).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to list contributors"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/contributors", nil)
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ContributorHandlerTestSuite) TestGetContributor() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
id string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
id: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetContributorByID(gomock.Any(), 1).
|
||||
Return(&ent.Contributor{
|
||||
ID: 1,
|
||||
Name: "John Doe",
|
||||
Edges: ent.ContributorEdges{
|
||||
SocialLinks: []*ent.ContributorSocialLink{
|
||||
{
|
||||
Type: "github",
|
||||
Value: "https://github.com/johndoe",
|
||||
Edges: ent.ContributorSocialLinkEdges{},
|
||||
},
|
||||
},
|
||||
},
|
||||
CreatedAt: time.Time{},
|
||||
UpdatedAt: time.Time{},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: gin.H{
|
||||
"id": 1,
|
||||
"name": "John Doe",
|
||||
"created_at": time.Time{},
|
||||
"updated_at": time.Time{},
|
||||
"edges": gin.H{
|
||||
"social_links": []gin.H{
|
||||
{
|
||||
"type": "github",
|
||||
"value": "https://github.com/johndoe",
|
||||
"edges": gin.H{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid ID",
|
||||
id: "invalid",
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Invalid contributor ID"},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
id: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetContributorByID(gomock.Any(), 1).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to get contributor"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/contributors/"+tc.id, nil)
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ContributorHandlerTestSuite) TestCreateContributor() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
body interface{}
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
body: CreateContributorRequest{
|
||||
Name: "John Doe",
|
||||
},
|
||||
setupMock: func() {
|
||||
name := "John Doe"
|
||||
s.service.EXPECT().
|
||||
CreateContributor(
|
||||
gomock.Any(),
|
||||
name,
|
||||
nil,
|
||||
nil,
|
||||
).
|
||||
Return(&ent.Contributor{
|
||||
ID: 1,
|
||||
Name: name,
|
||||
CreatedAt: time.Time{},
|
||||
UpdatedAt: time.Time{},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: gin.H{
|
||||
"id": 1,
|
||||
"name": "John Doe",
|
||||
"created_at": time.Time{},
|
||||
"updated_at": time.Time{},
|
||||
"edges": gin.H{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid request body",
|
||||
body: map[string]interface{}{
|
||||
"name": "", // Empty name is not allowed
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Key: 'CreateContributorRequest.Name' Error:Field validation for 'Name' failed on the 'required' tag"},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
body: CreateContributorRequest{
|
||||
Name: "John Doe",
|
||||
},
|
||||
setupMock: func() {
|
||||
name := "John Doe"
|
||||
s.service.EXPECT().
|
||||
CreateContributor(
|
||||
gomock.Any(),
|
||||
name,
|
||||
nil,
|
||||
nil,
|
||||
).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to create contributor"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
body, err := json.Marshal(tc.body)
|
||||
s.NoError(err, "Failed to marshal request body")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/contributors", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ContributorHandlerTestSuite) TestAddContributorSocialLink() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
id string
|
||||
body interface{}
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
id: "1",
|
||||
body: func() AddContributorSocialLinkRequest {
|
||||
name := "johndoe"
|
||||
return AddContributorSocialLinkRequest{
|
||||
Type: "github",
|
||||
Name: &name,
|
||||
Value: "https://github.com/johndoe",
|
||||
}
|
||||
}(),
|
||||
setupMock: func() {
|
||||
name := "johndoe"
|
||||
s.service.EXPECT().
|
||||
AddContributorSocialLink(
|
||||
gomock.Any(),
|
||||
1,
|
||||
"github",
|
||||
name,
|
||||
"https://github.com/johndoe",
|
||||
).
|
||||
Return(&ent.ContributorSocialLink{
|
||||
Type: "github",
|
||||
Name: name,
|
||||
Value: "https://github.com/johndoe",
|
||||
Edges: ent.ContributorSocialLinkEdges{},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: gin.H{
|
||||
"type": "github",
|
||||
"name": "johndoe",
|
||||
"value": "https://github.com/johndoe",
|
||||
"edges": gin.H{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid contributor ID",
|
||||
id: "invalid",
|
||||
body: func() AddContributorSocialLinkRequest {
|
||||
name := "johndoe"
|
||||
return AddContributorSocialLinkRequest{
|
||||
Type: "github",
|
||||
Name: &name,
|
||||
Value: "https://github.com/johndoe",
|
||||
}
|
||||
}(),
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Invalid contributor ID"},
|
||||
},
|
||||
{
|
||||
name: "Invalid request body",
|
||||
id: "1",
|
||||
body: map[string]interface{}{
|
||||
"type": "", // Empty type is not allowed
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Key: 'AddContributorSocialLinkRequest.Type' Error:Field validation for 'Type' failed on the 'required' tag\nKey: 'AddContributorSocialLinkRequest.Value' Error:Field validation for 'Value' failed on the 'required' tag"},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
id: "1",
|
||||
body: func() AddContributorSocialLinkRequest {
|
||||
name := "johndoe"
|
||||
return AddContributorSocialLinkRequest{
|
||||
Type: "github",
|
||||
Name: &name,
|
||||
Value: "https://github.com/johndoe",
|
||||
}
|
||||
}(),
|
||||
setupMock: func() {
|
||||
name := "johndoe"
|
||||
s.service.EXPECT().
|
||||
AddContributorSocialLink(
|
||||
gomock.Any(),
|
||||
1,
|
||||
"github",
|
||||
name,
|
||||
"https://github.com/johndoe",
|
||||
).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to add contributor social link"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
body, err := json.Marshal(tc.body)
|
||||
s.NoError(err, "Failed to marshal request body")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/contributors/"+tc.id+"/social-links", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
519
backend/internal/handler/daily_handler_test.go
Normal file
519
backend/internal/handler/daily_handler_test.go
Normal file
|
@ -0,0 +1,519 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type DailyHandlerTestSuite struct {
|
||||
suite.Suite
|
||||
ctrl *gomock.Controller
|
||||
service *mock.MockService
|
||||
handler *Handler
|
||||
router *gin.Engine
|
||||
}
|
||||
|
||||
func (s *DailyHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
cfg := &config.Config{}
|
||||
s.handler = NewHandler(cfg, s.service)
|
||||
|
||||
// Setup Gin router
|
||||
gin.SetMode(gin.TestMode)
|
||||
s.router = gin.New()
|
||||
s.handler.RegisterRoutes(s.router)
|
||||
}
|
||||
|
||||
func (s *DailyHandlerTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestDailyHandlerSuite(t *testing.T) {
|
||||
suite.Run(t, new(DailyHandlerTestSuite))
|
||||
}
|
||||
|
||||
func (s *DailyHandlerTestSuite) TestListDailies() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
langCode string
|
||||
categoryID string
|
||||
limit string
|
||||
offset string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success with default language",
|
||||
langCode: "",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListDailies(gomock.Any(), "en", nil, 10, 0).
|
||||
Return([]*ent.Daily{
|
||||
{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Daily{
|
||||
{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with specific language",
|
||||
langCode: "zh",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListDailies(gomock.Any(), "zh", nil, 10, 0).
|
||||
Return([]*ent.Daily{
|
||||
{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "zh",
|
||||
Quote: "测试语录1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Daily{
|
||||
{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "zh",
|
||||
Quote: "测试语录1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with category filter",
|
||||
categoryID: "1",
|
||||
setupMock: func() {
|
||||
categoryID := 1
|
||||
s.service.EXPECT().
|
||||
ListDailies(gomock.Any(), "en", &categoryID, 10, 0).
|
||||
Return([]*ent.Daily{
|
||||
{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Daily{
|
||||
{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with pagination",
|
||||
limit: "2",
|
||||
offset: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListDailies(gomock.Any(), "en", nil, 2, 1).
|
||||
Return([]*ent.Daily{
|
||||
{
|
||||
ID: "daily2",
|
||||
ImageURL: "https://example.com/image2.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Daily{
|
||||
{
|
||||
ID: "daily2",
|
||||
ImageURL: "https://example.com/image2.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Service Error",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListDailies(gomock.Any(), "en", nil, 10, 0).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to list dailies"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
url := "/api/v1/dailies"
|
||||
if tc.langCode != "" {
|
||||
url += "?lang=" + tc.langCode
|
||||
}
|
||||
if tc.categoryID != "" {
|
||||
if strings.Contains(url, "?") {
|
||||
url += "&"
|
||||
} else {
|
||||
url += "?"
|
||||
}
|
||||
url += "category_id=" + tc.categoryID
|
||||
}
|
||||
if tc.limit != "" {
|
||||
if strings.Contains(url, "?") {
|
||||
url += "&"
|
||||
} else {
|
||||
url += "?"
|
||||
}
|
||||
url += "limit=" + tc.limit
|
||||
}
|
||||
if tc.offset != "" {
|
||||
if strings.Contains(url, "?") {
|
||||
url += "&"
|
||||
} else {
|
||||
url += "?"
|
||||
}
|
||||
url += "offset=" + tc.offset
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DailyHandlerTestSuite) TestGetDaily() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
id string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
id: "daily1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetDailyByID(gomock.Any(), "daily1").
|
||||
Return(&ent.Daily{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: &ent.Daily{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
id: "daily1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetDailyByID(gomock.Any(), "daily1").
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to get daily"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/dailies/"+tc.id, nil)
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DailyHandlerTestSuite) TestCreateDaily() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
body interface{}
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
body: CreateDailyRequest{
|
||||
ID: "daily1",
|
||||
CategoryID: 1,
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
CreateDaily(gomock.Any(), "daily1", 1, "https://example.com/image1.jpg").
|
||||
Return(&ent.Daily{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: &ent.Daily{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid request body",
|
||||
body: map[string]interface{}{
|
||||
"id": "daily1",
|
||||
// Missing required fields
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Key: 'CreateDailyRequest.CategoryID' Error:Field validation for 'CategoryID' failed on the 'required' tag\nKey: 'CreateDailyRequest.ImageURL' Error:Field validation for 'ImageURL' failed on the 'required' tag"},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
body: CreateDailyRequest{
|
||||
ID: "daily1",
|
||||
CategoryID: 1,
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
CreateDaily(gomock.Any(), "daily1", 1, "https://example.com/image1.jpg").
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to create daily"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
body, err := json.Marshal(tc.body)
|
||||
s.NoError(err, "Failed to marshal request body")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/dailies", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DailyHandlerTestSuite) TestAddDailyContent() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
dailyID string
|
||||
body interface{}
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
dailyID: "daily1",
|
||||
body: AddDailyContentRequest{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
AddDailyContent(gomock.Any(), "daily1", "en", "Test Quote 1").
|
||||
Return(&ent.DailyContent{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: &ent.DailyContent{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid request body",
|
||||
dailyID: "daily1",
|
||||
body: map[string]interface{}{
|
||||
"language_code": "en",
|
||||
// Missing required fields
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Key: 'AddDailyContentRequest.Quote' Error:Field validation for 'Quote' failed on the 'required' tag"},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
dailyID: "daily1",
|
||||
body: AddDailyContentRequest{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
AddDailyContent(gomock.Any(), "daily1", "en", "Test Quote 1").
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to add daily content"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
body, err := json.Marshal(tc.body)
|
||||
s.NoError(err, "Failed to marshal request body")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/dailies/"+tc.dailyID+"/contents", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
513
backend/internal/handler/handler.go
Normal file
513
backend/internal/handler/handler.go
Normal file
|
@ -0,0 +1,513 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service"
|
||||
"tss-rocks-be/internal/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
cfg *config.Config
|
||||
service service.Service
|
||||
}
|
||||
|
||||
func NewHandler(cfg *config.Config, service service.Service) *Handler {
|
||||
return &Handler{
|
||||
cfg: cfg,
|
||||
service: service,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers all the routes
|
||||
func (h *Handler) RegisterRoutes(r *gin.Engine) {
|
||||
api := r.Group("/api/v1")
|
||||
{
|
||||
// Auth routes
|
||||
auth := api.Group("/auth")
|
||||
{
|
||||
auth.POST("/register", h.Register)
|
||||
auth.POST("/login", h.Login)
|
||||
}
|
||||
|
||||
// Category routes
|
||||
categories := api.Group("/categories")
|
||||
{
|
||||
categories.GET("", h.ListCategories)
|
||||
categories.GET("/:slug", h.GetCategory)
|
||||
categories.POST("", h.CreateCategory)
|
||||
categories.POST("/:id/contents", h.AddCategoryContent)
|
||||
}
|
||||
|
||||
// Post routes
|
||||
posts := api.Group("/posts")
|
||||
{
|
||||
posts.GET("", h.ListPosts)
|
||||
posts.GET("/:slug", h.GetPost)
|
||||
posts.POST("", h.CreatePost)
|
||||
posts.POST("/:id/contents", h.AddPostContent)
|
||||
}
|
||||
|
||||
// Contributor routes
|
||||
contributors := api.Group("/contributors")
|
||||
{
|
||||
contributors.GET("", h.ListContributors)
|
||||
contributors.GET("/:id", h.GetContributor)
|
||||
contributors.POST("", h.CreateContributor)
|
||||
contributors.POST("/:id/social-links", h.AddContributorSocialLink)
|
||||
}
|
||||
|
||||
// Daily routes
|
||||
dailies := api.Group("/dailies")
|
||||
{
|
||||
dailies.GET("", h.ListDailies)
|
||||
dailies.GET("/:id", h.GetDaily)
|
||||
dailies.POST("", h.CreateDaily)
|
||||
dailies.POST("/:id/contents", h.AddDailyContent)
|
||||
}
|
||||
|
||||
// Media routes
|
||||
media := api.Group("/media")
|
||||
{
|
||||
media.GET("", h.ListMedia)
|
||||
media.POST("", h.UploadMedia)
|
||||
media.GET("/:id", h.GetMedia)
|
||||
media.DELETE("/:id", h.DeleteMedia)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Category handlers
|
||||
func (h *Handler) ListCategories(c *gin.Context) {
|
||||
langCode := c.Query("lang")
|
||||
if langCode == "" {
|
||||
langCode = "en" // Default to English
|
||||
}
|
||||
|
||||
categories, err := h.service.ListCategories(c.Request.Context(), langCode)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to list categories")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list categories"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, categories)
|
||||
}
|
||||
|
||||
func (h *Handler) GetCategory(c *gin.Context) {
|
||||
langCode := c.Query("lang")
|
||||
if langCode == "" {
|
||||
langCode = "en" // Default to English
|
||||
}
|
||||
|
||||
slug := c.Param("slug")
|
||||
category, err := h.service.GetCategoryBySlug(c.Request.Context(), langCode, slug)
|
||||
if err != nil {
|
||||
if err == types.ErrNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Category not found"})
|
||||
return
|
||||
}
|
||||
log.Error().Err(err).Msg("Failed to get category")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get category"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, category)
|
||||
}
|
||||
|
||||
func (h *Handler) CreateCategory(c *gin.Context) {
|
||||
category, err := h.service.CreateCategory(c.Request.Context())
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to create category")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create category"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, category)
|
||||
}
|
||||
|
||||
type AddCategoryContentRequest struct {
|
||||
LanguageCode string `json:"language_code" binding:"required"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description *string `json:"description"`
|
||||
Slug string `json:"slug" binding:"required"`
|
||||
}
|
||||
|
||||
func (h *Handler) AddCategoryContent(c *gin.Context) {
|
||||
var req AddCategoryContentRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
categoryID, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid category ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var description string
|
||||
if req.Description != nil {
|
||||
description = *req.Description
|
||||
}
|
||||
|
||||
content, err := h.service.AddCategoryContent(c.Request.Context(), categoryID, req.LanguageCode, req.Name, description, req.Slug)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to add category content")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to add category content"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, content)
|
||||
}
|
||||
|
||||
// Post handlers
|
||||
func (h *Handler) ListPosts(c *gin.Context) {
|
||||
langCode := c.Query("lang")
|
||||
if langCode == "" {
|
||||
langCode = "en" // Default to English
|
||||
}
|
||||
|
||||
var categoryID *int
|
||||
if catIDStr := c.Query("category_id"); catIDStr != "" {
|
||||
if id, err := strconv.Atoi(catIDStr); err == nil {
|
||||
categoryID = &id
|
||||
}
|
||||
}
|
||||
|
||||
limit := 10 // Default limit
|
||||
if limitStr := c.Query("limit"); limitStr != "" {
|
||||
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
|
||||
limit = l
|
||||
}
|
||||
}
|
||||
|
||||
offset := 0 // Default offset
|
||||
if offsetStr := c.Query("offset"); offsetStr != "" {
|
||||
if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 {
|
||||
offset = o
|
||||
}
|
||||
}
|
||||
|
||||
posts, err := h.service.ListPosts(c.Request.Context(), langCode, categoryID, limit, offset)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to list posts")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list posts"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, posts)
|
||||
}
|
||||
|
||||
func (h *Handler) GetPost(c *gin.Context) {
|
||||
langCode := c.Query("lang")
|
||||
if langCode == "" {
|
||||
langCode = "en" // Default to English
|
||||
}
|
||||
|
||||
slug := c.Param("slug")
|
||||
post, err := h.service.GetPostBySlug(c.Request.Context(), langCode, slug)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get post")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get post"})
|
||||
return
|
||||
}
|
||||
|
||||
// Convert to a map to control the fields
|
||||
response := gin.H{
|
||||
"id": post.ID,
|
||||
"status": post.Status,
|
||||
"slug": post.Slug,
|
||||
"edges": gin.H{
|
||||
"contents": []gin.H{},
|
||||
},
|
||||
}
|
||||
|
||||
contents := make([]gin.H, 0, len(post.Edges.Contents))
|
||||
for _, content := range post.Edges.Contents {
|
||||
contents = append(contents, gin.H{
|
||||
"language_code": content.LanguageCode,
|
||||
"title": content.Title,
|
||||
"content_markdown": content.ContentMarkdown,
|
||||
"summary": content.Summary,
|
||||
})
|
||||
}
|
||||
response["edges"].(gin.H)["contents"] = contents
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (h *Handler) CreatePost(c *gin.Context) {
|
||||
post, err := h.service.CreatePost(c.Request.Context(), "draft") // Default to draft status
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to create post")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create post"})
|
||||
return
|
||||
}
|
||||
|
||||
// Convert to a map to control the fields
|
||||
response := gin.H{
|
||||
"id": post.ID,
|
||||
"status": post.Status,
|
||||
"edges": gin.H{
|
||||
"contents": []interface{}{},
|
||||
},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, response)
|
||||
}
|
||||
|
||||
type AddPostContentRequest struct {
|
||||
LanguageCode string `json:"language_code" binding:"required"`
|
||||
Title string `json:"title" binding:"required"`
|
||||
ContentMarkdown string `json:"content_markdown" binding:"required"`
|
||||
Summary string `json:"summary" binding:"required"`
|
||||
MetaKeywords string `json:"meta_keywords"`
|
||||
MetaDescription string `json:"meta_description"`
|
||||
}
|
||||
|
||||
func (h *Handler) AddPostContent(c *gin.Context) {
|
||||
var req AddPostContentRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
postID, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid post ID"})
|
||||
return
|
||||
}
|
||||
|
||||
content, err := h.service.AddPostContent(c.Request.Context(), postID, req.LanguageCode, req.Title, req.ContentMarkdown, req.Summary, req.MetaKeywords, req.MetaDescription)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to add post content")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to add post content"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"title": content.Title,
|
||||
"content_markdown": content.ContentMarkdown,
|
||||
"language_code": content.LanguageCode,
|
||||
"summary": content.Summary,
|
||||
"meta_keywords": content.MetaKeywords,
|
||||
"meta_description": content.MetaDescription,
|
||||
"edges": gin.H{},
|
||||
})
|
||||
}
|
||||
|
||||
// Contributor handlers
|
||||
func (h *Handler) ListContributors(c *gin.Context) {
|
||||
contributors, err := h.service.ListContributors(c.Request.Context())
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to list contributors")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list contributors"})
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]gin.H, len(contributors))
|
||||
for i, contributor := range contributors {
|
||||
socialLinks := make([]gin.H, len(contributor.Edges.SocialLinks))
|
||||
for j, link := range contributor.Edges.SocialLinks {
|
||||
socialLinks[j] = gin.H{
|
||||
"type": link.Type,
|
||||
"value": link.Value,
|
||||
"edges": gin.H{},
|
||||
}
|
||||
}
|
||||
|
||||
response[i] = gin.H{
|
||||
"id": contributor.ID,
|
||||
"name": contributor.Name,
|
||||
"created_at": contributor.CreatedAt,
|
||||
"updated_at": contributor.UpdatedAt,
|
||||
"edges": gin.H{
|
||||
"social_links": socialLinks,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (h *Handler) GetContributor(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid contributor ID"})
|
||||
return
|
||||
}
|
||||
|
||||
contributor, err := h.service.GetContributorByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get contributor")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get contributor"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, contributor)
|
||||
}
|
||||
|
||||
type CreateContributorRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
AvatarURL *string `json:"avatar_url"`
|
||||
Bio *string `json:"bio"`
|
||||
}
|
||||
|
||||
func (h *Handler) CreateContributor(c *gin.Context) {
|
||||
var req CreateContributorRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
contributor, err := h.service.CreateContributor(c.Request.Context(), req.Name, req.AvatarURL, req.Bio)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to create contributor")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create contributor"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, contributor)
|
||||
}
|
||||
|
||||
type AddContributorSocialLinkRequest struct {
|
||||
Type string `json:"type" binding:"required"`
|
||||
Name *string `json:"name"`
|
||||
Value string `json:"value" binding:"required"`
|
||||
}
|
||||
|
||||
func (h *Handler) AddContributorSocialLink(c *gin.Context) {
|
||||
var req AddContributorSocialLinkRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
contributorID, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid contributor ID"})
|
||||
return
|
||||
}
|
||||
|
||||
name := ""
|
||||
if req.Name != nil {
|
||||
name = *req.Name
|
||||
}
|
||||
link, err := h.service.AddContributorSocialLink(c.Request.Context(), contributorID, req.Type, name, req.Value)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to add contributor social link")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to add contributor social link"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, link)
|
||||
}
|
||||
|
||||
// Daily handlers
|
||||
func (h *Handler) ListDailies(c *gin.Context) {
|
||||
langCode := c.Query("lang")
|
||||
if langCode == "" {
|
||||
langCode = "en" // Default to English
|
||||
}
|
||||
|
||||
var categoryID *int
|
||||
if catIDStr := c.Query("category_id"); catIDStr != "" {
|
||||
if id, err := strconv.Atoi(catIDStr); err == nil {
|
||||
categoryID = &id
|
||||
}
|
||||
}
|
||||
|
||||
limit := 10 // Default limit
|
||||
if limitStr := c.Query("limit"); limitStr != "" {
|
||||
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
|
||||
limit = l
|
||||
}
|
||||
}
|
||||
|
||||
offset := 0 // Default offset
|
||||
if offsetStr := c.Query("offset"); offsetStr != "" {
|
||||
if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 {
|
||||
offset = o
|
||||
}
|
||||
}
|
||||
|
||||
dailies, err := h.service.ListDailies(c.Request.Context(), langCode, categoryID, limit, offset)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to list dailies")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list dailies"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, dailies)
|
||||
}
|
||||
|
||||
func (h *Handler) GetDaily(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
daily, err := h.service.GetDailyByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get daily")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get daily"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, daily)
|
||||
}
|
||||
|
||||
type CreateDailyRequest struct {
|
||||
ID string `json:"id" binding:"required"`
|
||||
CategoryID int `json:"category_id" binding:"required"`
|
||||
ImageURL string `json:"image_url" binding:"required"`
|
||||
}
|
||||
|
||||
func (h *Handler) CreateDaily(c *gin.Context) {
|
||||
var req CreateDailyRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
daily, err := h.service.CreateDaily(c.Request.Context(), req.ID, req.CategoryID, req.ImageURL)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to create daily")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create daily"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, daily)
|
||||
}
|
||||
|
||||
type AddDailyContentRequest struct {
|
||||
LanguageCode string `json:"language_code" binding:"required"`
|
||||
Quote string `json:"quote" binding:"required"`
|
||||
}
|
||||
|
||||
func (h *Handler) AddDailyContent(c *gin.Context) {
|
||||
var req AddDailyContentRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
dailyID := c.Param("id")
|
||||
content, err := h.service.AddDailyContent(c.Request.Context(), dailyID, req.LanguageCode, req.Quote)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to add daily content")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to add daily content"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, content)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func stringPtr(s *string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return *s
|
||||
}
|
43
backend/internal/handler/handler_test.go
Normal file
43
backend/internal/handler/handler_test.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStringPtr(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input *string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "nil pointer",
|
||||
input: nil,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: strPtr(""),
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "non-empty string",
|
||||
input: strPtr("test"),
|
||||
expected: "test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := stringPtr(tc.input)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create string pointer
|
||||
func strPtr(s string) *string {
|
||||
return &s
|
||||
}
|
173
backend/internal/handler/media.go
Normal file
173
backend/internal/handler/media.go
Normal file
|
@ -0,0 +1,173 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Media handlers
|
||||
func (h *Handler) ListMedia(c *gin.Context) {
|
||||
limit := 10 // Default limit
|
||||
if limitStr := c.Query("limit"); limitStr != "" {
|
||||
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
|
||||
limit = l
|
||||
}
|
||||
}
|
||||
|
||||
offset := 0 // Default offset
|
||||
if offsetStr := c.Query("offset"); offsetStr != "" {
|
||||
if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 {
|
||||
offset = o
|
||||
}
|
||||
}
|
||||
|
||||
media, err := h.service.ListMedia(c.Request.Context(), limit, offset)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to list media")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list media"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, media)
|
||||
}
|
||||
|
||||
func (h *Handler) UploadMedia(c *gin.Context) {
|
||||
// Get user ID from context (set by auth middleware)
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get file from form
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "No file uploaded"})
|
||||
return
|
||||
}
|
||||
|
||||
// 文件大小限制
|
||||
if file.Size > 10*1024*1024 { // 10MB
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "File size exceeds the limit (10MB)"})
|
||||
return
|
||||
}
|
||||
|
||||
// 文件类型限制
|
||||
allowedTypes := map[string]bool{
|
||||
"image/jpeg": true,
|
||||
"image/png": true,
|
||||
"image/gif": true,
|
||||
"video/mp4": true,
|
||||
"video/webm": true,
|
||||
"audio/mpeg": true,
|
||||
"audio/ogg": true,
|
||||
"application/pdf": true,
|
||||
}
|
||||
contentType := file.Header.Get("Content-Type")
|
||||
if _, ok := allowedTypes[contentType]; !ok {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid file type"})
|
||||
return
|
||||
}
|
||||
|
||||
// Upload file
|
||||
media, err := h.service.Upload(c.Request.Context(), file, userID.(int))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to upload media")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to upload media"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, media)
|
||||
}
|
||||
|
||||
func (h *Handler) GetMedia(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid media ID"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get media metadata
|
||||
media, err := h.service.GetMedia(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get media")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get media"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get file content
|
||||
reader, info, err := h.service.GetFile(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get media file")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get media file"})
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Set response headers
|
||||
c.Header("Content-Type", media.MimeType)
|
||||
c.Header("Content-Length", fmt.Sprintf("%d", info.Size))
|
||||
c.Header("Content-Disposition", fmt.Sprintf("inline; filename=%s", media.OriginalName))
|
||||
|
||||
// Stream the file
|
||||
if _, err := io.Copy(c.Writer, reader); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to stream media file")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) GetMediaFile(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid media ID"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get file content
|
||||
reader, info, err := h.service.GetFile(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get media file")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get media file"})
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Set response headers
|
||||
c.Header("Content-Type", info.ContentType)
|
||||
c.Header("Content-Length", fmt.Sprintf("%d", info.Size))
|
||||
c.Header("Content-Disposition", fmt.Sprintf("inline; filename=%s", info.Name))
|
||||
|
||||
// Stream the file
|
||||
if _, err := io.Copy(c.Writer, reader); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to stream media file")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteMedia(c *gin.Context) {
|
||||
// Get user ID from context (set by auth middleware)
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid media ID"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.service.DeleteMedia(c.Request.Context(), id, userID.(int)); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to delete media")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete media"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusNoContent, nil)
|
||||
}
|
524
backend/internal/handler/media_handler_test.go
Normal file
524
backend/internal/handler/media_handler_test.go
Normal file
|
@ -0,0 +1,524 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
"tss-rocks-be/internal/storage"
|
||||
|
||||
"net/textproto"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
type MediaHandlerTestSuite struct {
|
||||
suite.Suite
|
||||
ctrl *gomock.Controller
|
||||
service *mock.MockService
|
||||
handler *Handler
|
||||
router *gin.Engine
|
||||
}
|
||||
|
||||
func (s *MediaHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
s.handler = NewHandler(&config.Config{}, s.service)
|
||||
s.router = gin.New()
|
||||
}
|
||||
|
||||
func (s *MediaHandlerTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestMediaHandlerSuite(t *testing.T) {
|
||||
suite.Run(t, new(MediaHandlerTestSuite))
|
||||
}
|
||||
|
||||
func (s *MediaHandlerTestSuite) TestListMedia() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
query string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "成功列出媒体",
|
||||
query: "?limit=10&offset=0",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListMedia(gomock.Any(), 10, 0).
|
||||
Return([]*ent.Media{{ID: 1}}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "使用默认限制和偏移",
|
||||
query: "",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListMedia(gomock.Any(), 10, 0).
|
||||
Return([]*ent.Media{{ID: 1}}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "列出媒体失败",
|
||||
query: "",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListMedia(gomock.Any(), 10, 0).
|
||||
Return(nil, errors.New("failed to list media"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to list media",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
// 创建请求
|
||||
req, _ := http.NewRequest(http.MethodGet, "/media"+tc.query, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// 执行请求
|
||||
s.handler.ListMedia(c)
|
||||
|
||||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Equal(tc.expectedError, response["error"])
|
||||
} else {
|
||||
var response []*ent.Media
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.NotEmpty(response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MediaHandlerTestSuite) TestUploadMedia() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupRequest func() (*http.Request, error)
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "成功上传媒体",
|
||||
setupRequest: func() (*http.Request, error) {
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
// 创建文件部分
|
||||
fileHeader := make(textproto.MIMEHeader)
|
||||
fileHeader.Set("Content-Type", "image/jpeg")
|
||||
fileHeader.Set("Content-Disposition", `form-data; name="file"; filename="test.jpg"`)
|
||||
part, err := writer.CreatePart(fileHeader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
testContent := "test content"
|
||||
_, err = io.Copy(part, strings.NewReader(testContent))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writer.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/media", body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return req, nil
|
||||
},
|
||||
setupMock: func() {
|
||||
expectedFile := &multipart.FileHeader{
|
||||
Filename: "test.jpg",
|
||||
Size: int64(len("test content")),
|
||||
Header: textproto.MIMEHeader{
|
||||
"Content-Type": []string{"image/jpeg"},
|
||||
},
|
||||
}
|
||||
s.service.EXPECT().
|
||||
Upload(gomock.Any(), gomock.Any(), 1).
|
||||
DoAndReturn(func(_ context.Context, f *multipart.FileHeader, uid int) (*ent.Media, error) {
|
||||
s.Equal(expectedFile.Filename, f.Filename)
|
||||
s.Equal(expectedFile.Size, f.Size)
|
||||
s.Equal(expectedFile.Header.Get("Content-Type"), f.Header.Get("Content-Type"))
|
||||
return &ent.Media{ID: 1}, nil
|
||||
})
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "未授权",
|
||||
setupRequest: func() (*http.Request, error) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/media", nil)
|
||||
return req, nil
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedError: "Unauthorized",
|
||||
},
|
||||
{
|
||||
name: "上传失败",
|
||||
setupRequest: func() (*http.Request, error) {
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
// 创建文件部分
|
||||
fileHeader := make(textproto.MIMEHeader)
|
||||
fileHeader.Set("Content-Type", "image/jpeg")
|
||||
fileHeader.Set("Content-Disposition", `form-data; name="file"; filename="test.jpg"`)
|
||||
part, err := writer.CreatePart(fileHeader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
testContent := "test content"
|
||||
_, err = io.Copy(part, strings.NewReader(testContent))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writer.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/media", body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return req, nil
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
Upload(gomock.Any(), gomock.Any(), 1).
|
||||
Return(nil, errors.New("failed to upload"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to upload media",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
// 创建请求
|
||||
req, err := tc.setupRequest()
|
||||
s.NoError(err)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// 设置用户ID(除了未授权的测试用例)
|
||||
if tc.expectedError != "Unauthorized" {
|
||||
c.Set("user_id", 1)
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
s.handler.UploadMedia(c)
|
||||
|
||||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Equal(tc.expectedError, response["error"])
|
||||
} else {
|
||||
var response *ent.Media
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.NotNil(response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MediaHandlerTestSuite) TestGetMedia() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
mediaID string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "成功获取媒体",
|
||||
mediaID: "1",
|
||||
setupMock: func() {
|
||||
media := &ent.Media{
|
||||
ID: 1,
|
||||
MimeType: "image/jpeg",
|
||||
OriginalName: "test.jpg",
|
||||
}
|
||||
s.service.EXPECT().
|
||||
GetMedia(gomock.Any(), 1).
|
||||
Return(media, nil)
|
||||
s.service.EXPECT().
|
||||
GetFile(gomock.Any(), 1).
|
||||
Return(io.NopCloser(strings.NewReader("test content")), &storage.FileInfo{
|
||||
Size: 11,
|
||||
Name: "test.jpg",
|
||||
ContentType: "image/jpeg",
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "无效的媒体ID",
|
||||
mediaID: "invalid",
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Invalid media ID",
|
||||
},
|
||||
{
|
||||
name: "获取媒体元数据失败",
|
||||
mediaID: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetMedia(gomock.Any(), 1).
|
||||
Return(nil, errors.New("failed to get media"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to get media",
|
||||
},
|
||||
{
|
||||
name: "获取媒体文件失败",
|
||||
mediaID: "1",
|
||||
setupMock: func() {
|
||||
media := &ent.Media{
|
||||
ID: 1,
|
||||
MimeType: "image/jpeg",
|
||||
OriginalName: "test.jpg",
|
||||
}
|
||||
s.service.EXPECT().
|
||||
GetMedia(gomock.Any(), 1).
|
||||
Return(media, nil)
|
||||
s.service.EXPECT().
|
||||
GetFile(gomock.Any(), 1).
|
||||
Return(nil, nil, errors.New("failed to get file"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to get media file",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
// 创建请求
|
||||
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/media/%s", tc.mediaID), nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// Extract ID from URL path
|
||||
parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/")
|
||||
if len(parts) >= 2 {
|
||||
c.Params = []gin.Param{{Key: "id", Value: parts[1]}}
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
s.handler.GetMedia(c)
|
||||
|
||||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Equal(tc.expectedError, response["error"])
|
||||
} else {
|
||||
s.Equal("image/jpeg", w.Header().Get("Content-Type"))
|
||||
s.Equal("11", w.Header().Get("Content-Length"))
|
||||
s.Equal("inline; filename=test.jpg", w.Header().Get("Content-Disposition"))
|
||||
s.Equal("test content", w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MediaHandlerTestSuite) TestGetMediaFile() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupRequest func() (*http.Request, error)
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody []byte
|
||||
}{
|
||||
{
|
||||
name: "成功获取媒体文件",
|
||||
setupRequest: func() (*http.Request, error) {
|
||||
return httptest.NewRequest(http.MethodGet, "/media/1/file", nil), nil
|
||||
},
|
||||
setupMock: func() {
|
||||
fileContent := "test file content"
|
||||
s.service.EXPECT().
|
||||
GetFile(gomock.Any(), 1).
|
||||
Return(io.NopCloser(strings.NewReader(fileContent)), &storage.FileInfo{
|
||||
Name: "test.jpg",
|
||||
Size: int64(len(fileContent)),
|
||||
ContentType: "image/jpeg",
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []byte("test file content"),
|
||||
},
|
||||
{
|
||||
name: "无效的媒体ID",
|
||||
setupRequest: func() (*http.Request, error) {
|
||||
return httptest.NewRequest(http.MethodGet, "/media/invalid/file", nil), nil
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "获取媒体文件失败",
|
||||
setupRequest: func() (*http.Request, error) {
|
||||
return httptest.NewRequest(http.MethodGet, "/media/1/file", nil), nil
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetFile(gomock.Any(), 1).
|
||||
Return(nil, nil, errors.New("failed to get file"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// Setup
|
||||
req, err := tc.setupRequest()
|
||||
s.Require().NoError(err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// Extract ID from URL path
|
||||
parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/")
|
||||
if len(parts) >= 2 {
|
||||
c.Params = []gin.Param{{Key: "id", Value: parts[1]}}
|
||||
}
|
||||
|
||||
// Setup mock
|
||||
tc.setupMock()
|
||||
|
||||
// Test
|
||||
s.handler.GetMediaFile(c)
|
||||
|
||||
// Verify
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedBody != nil {
|
||||
s.Equal(tc.expectedBody, w.Body.Bytes())
|
||||
s.Equal("image/jpeg", w.Header().Get("Content-Type"))
|
||||
s.Equal(fmt.Sprintf("%d", len(tc.expectedBody)), w.Header().Get("Content-Length"))
|
||||
s.Equal("inline; filename=test.jpg", w.Header().Get("Content-Disposition"))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MediaHandlerTestSuite) TestDeleteMedia() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
mediaID string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "成功删除媒体",
|
||||
mediaID: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
DeleteMedia(gomock.Any(), 1, 1).
|
||||
Return(nil)
|
||||
},
|
||||
expectedStatus: http.StatusNoContent,
|
||||
},
|
||||
{
|
||||
name: "未授权",
|
||||
mediaID: "1",
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedError: "Unauthorized",
|
||||
},
|
||||
{
|
||||
name: "无效的媒体ID",
|
||||
mediaID: "invalid",
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Invalid media ID",
|
||||
},
|
||||
{
|
||||
name: "删除媒体失败",
|
||||
mediaID: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
DeleteMedia(gomock.Any(), 1, 1).
|
||||
Return(errors.New("failed to delete"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to delete media",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
// 创建请求
|
||||
req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/media/%s", tc.mediaID), nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// Extract ID from URL path
|
||||
parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/")
|
||||
if len(parts) >= 2 {
|
||||
c.Params = []gin.Param{{Key: "id", Value: parts[1]}}
|
||||
}
|
||||
|
||||
// 设置用户ID(除了未授权的测试用例)
|
||||
if tc.expectedError != "Unauthorized" {
|
||||
c.Set("user_id", 1)
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
s.handler.DeleteMedia(c)
|
||||
|
||||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Equal(tc.expectedError, response["error"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
611
backend/internal/handler/post_handler_test.go
Normal file
611
backend/internal/handler/post_handler_test.go
Normal file
|
@ -0,0 +1,611 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type PostHandlerTestSuite struct {
|
||||
suite.Suite
|
||||
ctrl *gomock.Controller
|
||||
service *mock.MockService
|
||||
handler *Handler
|
||||
router *gin.Engine
|
||||
}
|
||||
|
||||
func (s *PostHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
cfg := &config.Config{}
|
||||
s.handler = NewHandler(cfg, s.service)
|
||||
|
||||
// Setup Gin router
|
||||
gin.SetMode(gin.TestMode)
|
||||
s.router = gin.New()
|
||||
s.handler.RegisterRoutes(s.router)
|
||||
}
|
||||
|
||||
func (s *PostHandlerTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestPostHandlerSuite(t *testing.T) {
|
||||
suite.Run(t, new(PostHandlerTestSuite))
|
||||
}
|
||||
|
||||
// Test cases for ListPosts
|
||||
func (s *PostHandlerTestSuite) TestListPosts() {
|
||||
categoryID := 1
|
||||
testCases := []struct {
|
||||
name string
|
||||
langCode string
|
||||
categoryID string
|
||||
limit string
|
||||
offset string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success with default language",
|
||||
langCode: "",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListPosts(gomock.Any(), "en", nil, 10, 0).
|
||||
Return([]*ent.Post{
|
||||
{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Post{
|
||||
{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with specific language",
|
||||
langCode: "zh",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListPosts(gomock.Any(), "zh", nil, 10, 0).
|
||||
Return([]*ent.Post{
|
||||
{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "zh",
|
||||
Title: "测试帖子",
|
||||
ContentMarkdown: "测试内容",
|
||||
Summary: "测试摘要",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Post{
|
||||
{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "zh",
|
||||
Title: "测试帖子",
|
||||
ContentMarkdown: "测试内容",
|
||||
Summary: "测试摘要",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with category filter",
|
||||
langCode: "en",
|
||||
categoryID: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListPosts(gomock.Any(), "en", &categoryID, 10, 0).
|
||||
Return([]*ent.Post{
|
||||
{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Post{
|
||||
{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with pagination",
|
||||
langCode: "en",
|
||||
limit: "2",
|
||||
offset: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListPosts(gomock.Any(), "en", nil, 2, 1).
|
||||
Return([]*ent.Post{
|
||||
{
|
||||
ID: 2,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post 2",
|
||||
ContentMarkdown: "Test Content 2",
|
||||
Summary: "Test Summary 2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Post{
|
||||
{
|
||||
ID: 2,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post 2",
|
||||
ContentMarkdown: "Test Content 2",
|
||||
Summary: "Test Summary 2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Service Error",
|
||||
langCode: "en",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListPosts(gomock.Any(), "en", nil, 10, 0).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// Setup mock
|
||||
tc.setupMock()
|
||||
|
||||
// Create request
|
||||
url := "/api/v1/posts"
|
||||
if tc.langCode != "" {
|
||||
url += "?lang=" + tc.langCode
|
||||
}
|
||||
if tc.categoryID != "" {
|
||||
if strings.Contains(url, "?") {
|
||||
url += "&"
|
||||
} else {
|
||||
url += "?"
|
||||
}
|
||||
url += "category_id=" + tc.categoryID
|
||||
}
|
||||
if tc.limit != "" {
|
||||
if strings.Contains(url, "?") {
|
||||
url += "&"
|
||||
} else {
|
||||
url += "?"
|
||||
}
|
||||
url += "limit=" + tc.limit
|
||||
}
|
||||
if tc.offset != "" {
|
||||
if strings.Contains(url, "?") {
|
||||
url += "&"
|
||||
} else {
|
||||
url += "?"
|
||||
}
|
||||
url += "offset=" + tc.offset
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Perform request
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
// Assert response
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedBody != nil {
|
||||
var response []*ent.Post
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Equal(tc.expectedBody, response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases for GetPost
|
||||
func (s *PostHandlerTestSuite) TestGetPost() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
langCode string
|
||||
slug string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success with default language",
|
||||
langCode: "",
|
||||
slug: "test-post",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetPostBySlug(gomock.Any(), "en", "test-post").
|
||||
Return(&ent.Post{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Slug: "test-post",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: gin.H{
|
||||
"id": 1,
|
||||
"status": "published",
|
||||
"slug": "test-post",
|
||||
"edges": gin.H{
|
||||
"contents": []gin.H{
|
||||
{
|
||||
"language_code": "en",
|
||||
"title": "Test Post",
|
||||
"content_markdown": "Test Content",
|
||||
"summary": "Test Summary",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with specific language",
|
||||
langCode: "zh",
|
||||
slug: "test-post",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetPostBySlug(gomock.Any(), "zh", "test-post").
|
||||
Return(&ent.Post{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Slug: "test-post",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "zh",
|
||||
Title: "测试帖子",
|
||||
ContentMarkdown: "测试内容",
|
||||
Summary: "测试摘要",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: gin.H{
|
||||
"id": 1,
|
||||
"status": "published",
|
||||
"slug": "test-post",
|
||||
"edges": gin.H{
|
||||
"contents": []gin.H{
|
||||
{
|
||||
"language_code": "zh",
|
||||
"title": "测试帖子",
|
||||
"content_markdown": "测试内容",
|
||||
"summary": "测试摘要",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
slug: "test-post",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetPostBySlug(gomock.Any(), "en", "test-post").
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to get post"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
url := "/api/v1/posts/" + tc.slug
|
||||
if tc.langCode != "" {
|
||||
url += "?lang=" + tc.langCode
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases for CreatePost
|
||||
func (s *PostHandlerTestSuite) TestCreatePost() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
CreatePost(gomock.Any(), "draft").
|
||||
Return(&ent.Post{
|
||||
ID: 1,
|
||||
Status: "draft",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: gin.H{
|
||||
"id": 1,
|
||||
"status": "draft",
|
||||
"edges": gin.H{
|
||||
"contents": []gin.H{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
CreatePost(gomock.Any(), "draft").
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to create post"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/posts", nil)
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases for AddPostContent
|
||||
func (s *PostHandlerTestSuite) TestAddPostContent() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
postID string
|
||||
body interface{}
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
postID: "1",
|
||||
body: AddPostContentRequest{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
MetaKeywords: "test,keywords",
|
||||
MetaDescription: "Test meta description",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
AddPostContent(
|
||||
gomock.Any(),
|
||||
1,
|
||||
"en",
|
||||
"Test Post",
|
||||
"Test Content",
|
||||
"Test Summary",
|
||||
"test,keywords",
|
||||
"Test meta description",
|
||||
).
|
||||
Return(&ent.PostContent{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
MetaKeywords: "test,keywords",
|
||||
MetaDescription: "Test meta description",
|
||||
Edges: ent.PostContentEdges{},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: gin.H{
|
||||
"language_code": "en",
|
||||
"title": "Test Post",
|
||||
"content_markdown": "Test Content",
|
||||
"summary": "Test Summary",
|
||||
"meta_keywords": "test,keywords",
|
||||
"meta_description": "Test meta description",
|
||||
"edges": gin.H{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid post ID",
|
||||
postID: "invalid",
|
||||
body: AddPostContentRequest{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Invalid post ID"},
|
||||
},
|
||||
{
|
||||
name: "Invalid request body",
|
||||
postID: "1",
|
||||
body: map[string]interface{}{
|
||||
"language_code": "en",
|
||||
// Missing required fields
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Key: 'AddPostContentRequest.Title' Error:Field validation for 'Title' failed on the 'required' tag\nKey: 'AddPostContentRequest.ContentMarkdown' Error:Field validation for 'ContentMarkdown' failed on the 'required' tag\nKey: 'AddPostContentRequest.Summary' Error:Field validation for 'Summary' failed on the 'required' tag"},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
postID: "1",
|
||||
body: AddPostContentRequest{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
MetaKeywords: "test,keywords",
|
||||
MetaDescription: "Test meta description",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
AddPostContent(
|
||||
gomock.Any(),
|
||||
1,
|
||||
"en",
|
||||
"Test Post",
|
||||
"Test Content",
|
||||
"Test Summary",
|
||||
"test,keywords",
|
||||
"Test meta description",
|
||||
).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to add post content"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
body, err := json.Marshal(tc.body)
|
||||
s.NoError(err, "Failed to marshal request body")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/posts/"+tc.postID+"/contents", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
192
backend/internal/middleware/accesslog.go
Normal file
192
backend/internal/middleware/accesslog.go
Normal file
|
@ -0,0 +1,192 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
"tss-rocks-be/internal/types"
|
||||
)
|
||||
|
||||
// AccessLogConfig 访问日志配置
|
||||
type AccessLogConfig struct {
|
||||
// 是否启用控制台输出
|
||||
EnableConsole bool `yaml:"enable_console"`
|
||||
// 是否启用文件日志
|
||||
EnableFile bool `yaml:"enable_file"`
|
||||
// 日志文件路径
|
||||
FilePath string `yaml:"file_path"`
|
||||
// 日志格式 (json 或 text)
|
||||
Format string `yaml:"format"`
|
||||
// 日志级别
|
||||
Level string `yaml:"level"`
|
||||
// 日志轮转配置
|
||||
Rotation struct {
|
||||
MaxSize int `yaml:"max_size"` // 每个日志文件的最大大小(MB)
|
||||
MaxAge int `yaml:"max_age"` // 保留旧日志文件的最大天数
|
||||
MaxBackups int `yaml:"max_backups"` // 保留的旧日志文件的最大数量
|
||||
Compress bool `yaml:"compress"` // 是否压缩旧日志文件
|
||||
LocalTime bool `yaml:"local_time"` // 使用本地时间作为轮转时间
|
||||
} `yaml:"rotation"`
|
||||
}
|
||||
|
||||
// accessLogger 访问日志记录器
|
||||
type accessLogger struct {
|
||||
consoleLogger *zerolog.Logger
|
||||
fileLogger *zerolog.Logger
|
||||
logWriter *lumberjack.Logger
|
||||
config *types.AccessLogConfig
|
||||
}
|
||||
|
||||
// Close 关闭日志文件
|
||||
func (l *accessLogger) Close() error {
|
||||
if l.logWriter != nil {
|
||||
return l.logWriter.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// newAccessLogger 创建新的访问日志记录器
|
||||
func newAccessLogger(config *types.AccessLogConfig) (*accessLogger, error) {
|
||||
var consoleLogger, fileLogger *zerolog.Logger
|
||||
var logWriter *lumberjack.Logger
|
||||
|
||||
// 设置日志级别
|
||||
level, err := zerolog.ParseLevel(config.Level)
|
||||
if err != nil {
|
||||
level = zerolog.InfoLevel
|
||||
}
|
||||
zerolog.SetGlobalLevel(level)
|
||||
|
||||
// 配置控制台日志
|
||||
if config.EnableConsole {
|
||||
logger := zerolog.New(os.Stdout).
|
||||
With().
|
||||
Timestamp().
|
||||
Logger()
|
||||
|
||||
if config.Format == "text" {
|
||||
logger = logger.Output(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339})
|
||||
}
|
||||
consoleLogger = &logger
|
||||
}
|
||||
|
||||
// 配置文件日志
|
||||
if config.EnableFile {
|
||||
// 确保日志目录存在
|
||||
if err := os.MkdirAll(filepath.Dir(config.FilePath), 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create log directory: %w", err)
|
||||
}
|
||||
|
||||
// 配置日志轮转
|
||||
logWriter = &lumberjack.Logger{
|
||||
Filename: config.FilePath,
|
||||
MaxSize: config.Rotation.MaxSize, // MB
|
||||
MaxAge: config.Rotation.MaxAge, // days
|
||||
MaxBackups: config.Rotation.MaxBackups, // files
|
||||
Compress: config.Rotation.Compress, // 是否压缩
|
||||
LocalTime: config.Rotation.LocalTime, // 使用本地时间
|
||||
}
|
||||
|
||||
logger := zerolog.New(logWriter).
|
||||
With().
|
||||
Timestamp().
|
||||
Logger()
|
||||
|
||||
fileLogger = &logger
|
||||
}
|
||||
|
||||
return &accessLogger{
|
||||
consoleLogger: consoleLogger,
|
||||
fileLogger: fileLogger,
|
||||
logWriter: logWriter,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// logEvent 记录日志事件
|
||||
func (l *accessLogger) logEvent(fields map[string]interface{}, msg string) {
|
||||
if l.consoleLogger != nil {
|
||||
event := l.consoleLogger.Info()
|
||||
for k, v := range fields {
|
||||
event = event.Interface(k, v)
|
||||
}
|
||||
event.Msg(msg)
|
||||
}
|
||||
if l.fileLogger != nil {
|
||||
event := l.fileLogger.Info()
|
||||
for k, v := range fields {
|
||||
event = event.Interface(k, v)
|
||||
}
|
||||
event.Msg(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// AccessLog 创建访问日志中间件
|
||||
func AccessLog(config *types.AccessLogConfig) (gin.HandlerFunc, error) {
|
||||
logger, err := newAccessLogger(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
// 用于测试时关闭日志文件
|
||||
if c == nil {
|
||||
if err := logger.Close(); err != nil {
|
||||
fmt.Printf("Error closing log file: %v\n", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
requestID := uuid.New().String()
|
||||
path := c.Request.URL.Path
|
||||
query := c.Request.URL.RawQuery
|
||||
|
||||
// 设置请求ID到上下文
|
||||
c.Set("request_id", requestID)
|
||||
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 计算处理时间
|
||||
latency := time.Since(start)
|
||||
|
||||
// 获取用户ID(如果已认证)
|
||||
var userID interface{}
|
||||
if id, exists := c.Get("user_id"); exists {
|
||||
userID = id
|
||||
}
|
||||
|
||||
// 准备日志字段
|
||||
fields := map[string]interface{}{
|
||||
"request_id": requestID,
|
||||
"method": c.Request.Method,
|
||||
"path": path,
|
||||
"query": query,
|
||||
"ip": c.ClientIP(),
|
||||
"user_agent": c.Request.UserAgent(),
|
||||
"status": c.Writer.Status(),
|
||||
"size": c.Writer.Size(),
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
"component": "access_log",
|
||||
}
|
||||
|
||||
if userID != nil {
|
||||
fields["user_id"] = userID
|
||||
}
|
||||
|
||||
// 如果有错误,添加到日志中
|
||||
if len(c.Errors) > 0 {
|
||||
fields["error"] = c.Errors.String()
|
||||
}
|
||||
|
||||
// 记录日志
|
||||
logger.logEvent(fields, fmt.Sprintf("%s %s", c.Request.Method, path))
|
||||
}, nil
|
||||
}
|
238
backend/internal/middleware/accesslog_test.go
Normal file
238
backend/internal/middleware/accesslog_test.go
Normal file
|
@ -0,0 +1,238 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
"tss-rocks-be/internal/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAccessLog(t *testing.T) {
|
||||
// 设置测试临时目录
|
||||
tmpDir := t.TempDir()
|
||||
logPath := filepath.Join(tmpDir, "test.log")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
config *types.AccessLogConfig
|
||||
expectedError bool
|
||||
setupRequest func(*http.Request)
|
||||
validateOutput func(*testing.T, *httptest.ResponseRecorder, string)
|
||||
}{
|
||||
{
|
||||
name: "Console logging only",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
EnableFile: false,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
},
|
||||
expectedError: false,
|
||||
setupRequest: func(req *http.Request) {
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
},
|
||||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, logOutput, "GET /test")
|
||||
assert.Contains(t, logOutput, "test-agent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "File logging only",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: false,
|
||||
EnableFile: true,
|
||||
FilePath: logPath,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
Rotation: struct {
|
||||
MaxSize int `yaml:"max_size"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
Compress bool `yaml:"compress"`
|
||||
LocalTime bool `yaml:"local_time"`
|
||||
}{
|
||||
MaxSize: 1,
|
||||
MaxAge: 1,
|
||||
MaxBackups: 1,
|
||||
Compress: false,
|
||||
LocalTime: true,
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
setupRequest: func(req *http.Request) {
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
},
|
||||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// 读取日志文件内容
|
||||
content, err := os.ReadFile(logPath)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(content), "GET /test")
|
||||
assert.Contains(t, string(content), "test-agent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Both console and file logging",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
EnableFile: true,
|
||||
FilePath: logPath,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
Rotation: struct {
|
||||
MaxSize int `yaml:"max_size"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
Compress bool `yaml:"compress"`
|
||||
LocalTime bool `yaml:"local_time"`
|
||||
}{
|
||||
MaxSize: 1,
|
||||
MaxAge: 1,
|
||||
MaxBackups: 1,
|
||||
Compress: false,
|
||||
LocalTime: true,
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
setupRequest: func(req *http.Request) {
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
},
|
||||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, logOutput, "GET /test")
|
||||
assert.Contains(t, logOutput, "test-agent")
|
||||
|
||||
// 读取日志文件内容
|
||||
content, err := os.ReadFile(logPath)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(content), "GET /test")
|
||||
assert.Contains(t, string(content), "test-agent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "With authenticated user",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
EnableFile: false,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
},
|
||||
expectedError: false,
|
||||
setupRequest: func(req *http.Request) {
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
},
|
||||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, logOutput, "GET /test")
|
||||
assert.Contains(t, logOutput, "test-agent")
|
||||
assert.Contains(t, logOutput, "test-user")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// 捕获标准输出
|
||||
oldStdout := os.Stdout
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stdout = w
|
||||
|
||||
// 创建一个新的 gin 引擎
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 创建访问日志中间件
|
||||
middleware, err := AccessLog(tc.config)
|
||||
if tc.expectedError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 添加测试路由
|
||||
router.Use(middleware)
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
// 如果是测试认证用户的情况,设置用户ID
|
||||
if tc.name == "With authenticated user" {
|
||||
c.Set("user_id", "test-user")
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// 创建测试请求
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
if tc.setupRequest != nil {
|
||||
tc.setupRequest(req)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// 恢复标准输出并获取输出内容
|
||||
w.Close()
|
||||
var buf bytes.Buffer
|
||||
io.Copy(&buf, r)
|
||||
os.Stdout = oldStdout
|
||||
|
||||
// 验证输出
|
||||
if tc.validateOutput != nil {
|
||||
tc.validateOutput(t, rec, buf.String())
|
||||
}
|
||||
|
||||
// 关闭日志文件
|
||||
if tc.config.EnableFile {
|
||||
// 调用中间件函数来关闭日志文件
|
||||
middleware(nil)
|
||||
// 等待一小段时间确保文件完全关闭
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogInvalidConfig(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
config *types.AccessLogConfig
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "Invalid log level",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
Level: "invalid_level",
|
||||
},
|
||||
expectedError: false, // 应该使用默认的 info 级别
|
||||
},
|
||||
{
|
||||
name: "Invalid file path",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableFile: true,
|
||||
FilePath: "/dev/null/nonexistent/test.log", // 在所有操作系统上都无效的路径
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := AccessLog(tc.config)
|
||||
if tc.expectedError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
82
backend/internal/middleware/auth.go
Normal file
82
backend/internal/middleware/auth.go
Normal file
|
@ -0,0 +1,82 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// AuthMiddleware creates a middleware for JWT authentication
|
||||
func AuthMiddleware(jwtSecret string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header is required"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header format must be Bearer {token}"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(parts[1], func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
return []byte(jwtSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to parse token")
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
||||
c.Set("user_id", claims["sub"])
|
||||
c.Set("user_role", claims["role"])
|
||||
c.Next()
|
||||
} else {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RoleMiddleware creates a middleware for role-based authorization
|
||||
func RoleMiddleware(roles ...string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
userRole, exists := c.Get("user_role")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "User role not found"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
roleStr, ok := userRole.(string)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user role type"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
for _, role := range roles {
|
||||
if role == roleStr {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Insufficient permissions"})
|
||||
c.Abort()
|
||||
}
|
||||
}
|
217
backend/internal/middleware/auth_test.go
Normal file
217
backend/internal/middleware/auth_test.go
Normal file
|
@ -0,0 +1,217 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func createTestToken(secret string, claims jwt.MapClaims) string {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signedToken, _ := token.SignedString([]byte(secret))
|
||||
return signedToken
|
||||
}
|
||||
|
||||
func TestAuthMiddleware(t *testing.T) {
|
||||
jwtSecret := "test-secret"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupAuth func(*http.Request)
|
||||
expectedStatus int
|
||||
expectedBody map[string]string
|
||||
checkUserData bool
|
||||
expectedUserID string
|
||||
expectedRole string
|
||||
}{
|
||||
{
|
||||
name: "No Authorization header",
|
||||
setupAuth: func(req *http.Request) {},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "Authorization header is required"},
|
||||
},
|
||||
{
|
||||
name: "Invalid Authorization format",
|
||||
setupAuth: func(req *http.Request) {
|
||||
req.Header.Set("Authorization", "InvalidFormat")
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "Authorization header format must be Bearer {token}"},
|
||||
},
|
||||
{
|
||||
name: "Invalid token",
|
||||
setupAuth: func(req *http.Request) {
|
||||
req.Header.Set("Authorization", "Bearer invalid.token.here")
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "Invalid token"},
|
||||
},
|
||||
{
|
||||
name: "Valid token",
|
||||
setupAuth: func(req *http.Request) {
|
||||
claims := jwt.MapClaims{
|
||||
"sub": "user123",
|
||||
"role": "user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
}
|
||||
token := createTestToken(jwtSecret, claims)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
checkUserData: true,
|
||||
expectedUserID: "user123",
|
||||
expectedRole: "user",
|
||||
},
|
||||
{
|
||||
name: "Expired token",
|
||||
setupAuth: func(req *http.Request) {
|
||||
claims := jwt.MapClaims{
|
||||
"sub": "user123",
|
||||
"role": "user",
|
||||
"exp": time.Now().Add(-time.Hour).Unix(),
|
||||
}
|
||||
token := createTestToken(jwtSecret, claims)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "Invalid token"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 添加认证中间件
|
||||
router.Use(AuthMiddleware(jwtSecret))
|
||||
|
||||
// 测试路由
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
if tc.checkUserData {
|
||||
userID, exists := c.Get("user_id")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, tc.expectedUserID, userID)
|
||||
|
||||
role, exists := c.Get("user_role")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, tc.expectedRole, role)
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// 创建请求
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
tc.setupAuth(req)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// 验证响应
|
||||
assert.Equal(t, tc.expectedStatus, rec.Code)
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expectedBody, response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoleMiddleware(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupContext func(*gin.Context)
|
||||
allowedRoles []string
|
||||
expectedStatus int
|
||||
expectedBody map[string]string
|
||||
}{
|
||||
{
|
||||
name: "No user role",
|
||||
setupContext: func(c *gin.Context) {
|
||||
// 不设置用户角色
|
||||
},
|
||||
allowedRoles: []string{"admin"},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "User role not found"},
|
||||
},
|
||||
{
|
||||
name: "Invalid role type",
|
||||
setupContext: func(c *gin.Context) {
|
||||
c.Set("user_role", 123) // 设置错误类型的角色
|
||||
},
|
||||
allowedRoles: []string{"admin"},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: map[string]string{"error": "Invalid user role type"},
|
||||
},
|
||||
{
|
||||
name: "Insufficient permissions",
|
||||
setupContext: func(c *gin.Context) {
|
||||
c.Set("user_role", "user")
|
||||
},
|
||||
allowedRoles: []string{"admin"},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedBody: map[string]string{"error": "Insufficient permissions"},
|
||||
},
|
||||
{
|
||||
name: "Allowed role",
|
||||
setupContext: func(c *gin.Context) {
|
||||
c.Set("user_role", "admin")
|
||||
},
|
||||
allowedRoles: []string{"admin"},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "One of multiple allowed roles",
|
||||
setupContext: func(c *gin.Context) {
|
||||
c.Set("user_role", "editor")
|
||||
},
|
||||
allowedRoles: []string{"admin", "editor", "moderator"},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 添加角色中间件
|
||||
router.Use(func(c *gin.Context) {
|
||||
tc.setupContext(c)
|
||||
c.Next()
|
||||
})
|
||||
router.Use(RoleMiddleware(tc.allowedRoles...))
|
||||
|
||||
// 测试路由
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// 创建请求
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// 验证响应
|
||||
assert.Equal(t, tc.expectedStatus, rec.Code)
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expectedBody, response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
22
backend/internal/middleware/cors.go
Normal file
22
backend/internal/middleware/cors.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORS middleware
|
||||
func CORS() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
76
backend/internal/middleware/cors_test.go
Normal file
76
backend/internal/middleware/cors_test.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCORS(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
method string
|
||||
expectedStatus int
|
||||
checkHeaders bool
|
||||
}{
|
||||
{
|
||||
name: "Normal GET request",
|
||||
method: "GET",
|
||||
expectedStatus: http.StatusOK,
|
||||
checkHeaders: true,
|
||||
},
|
||||
{
|
||||
name: "OPTIONS request",
|
||||
method: "OPTIONS",
|
||||
expectedStatus: http.StatusNoContent,
|
||||
checkHeaders: true,
|
||||
},
|
||||
{
|
||||
name: "POST request",
|
||||
method: "POST",
|
||||
expectedStatus: http.StatusOK,
|
||||
checkHeaders: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// 创建一个新的 gin 引擎
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 添加 CORS 中间件
|
||||
router.Use(CORS())
|
||||
|
||||
// 添加测试路由
|
||||
router.Any("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// 创建测试请求
|
||||
req := httptest.NewRequest(tc.method, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// 验证状态码
|
||||
assert.Equal(t, tc.expectedStatus, rec.Code)
|
||||
|
||||
if tc.checkHeaders {
|
||||
// 验证 CORS 头部
|
||||
headers := rec.Header()
|
||||
assert.Equal(t, "*", headers.Get("Access-Control-Allow-Origin"))
|
||||
assert.Equal(t, "true", headers.Get("Access-Control-Allow-Credentials"))
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Content-Type")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Authorization")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "POST")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "GET")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "PUT")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "DELETE")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
107
backend/internal/middleware/ratelimit.go
Normal file
107
backend/internal/middleware/ratelimit.go
Normal file
|
@ -0,0 +1,107 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
"tss-rocks-be/internal/types"
|
||||
)
|
||||
|
||||
// ipLimiter IP限流器
|
||||
type ipLimiter struct {
|
||||
limiter *rate.Limiter
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
// rateLimiter 限流器管理器
|
||||
type rateLimiter struct {
|
||||
ips map[string]*ipLimiter
|
||||
mu sync.RWMutex
|
||||
config *types.RateLimitConfig
|
||||
routes map[string]*rate.Limiter
|
||||
}
|
||||
|
||||
// newRateLimiter 创建新的限流器
|
||||
func newRateLimiter(config *types.RateLimitConfig) *rateLimiter {
|
||||
// 初始化路由限流器
|
||||
routes := make(map[string]*rate.Limiter)
|
||||
for path, cfg := range config.RouteRates {
|
||||
routes[path] = rate.NewLimiter(rate.Limit(cfg.Rate), cfg.Burst)
|
||||
}
|
||||
|
||||
rl := &rateLimiter{
|
||||
ips: make(map[string]*ipLimiter),
|
||||
config: config,
|
||||
routes: routes,
|
||||
}
|
||||
|
||||
// 启动清理过期IP限流器的goroutine
|
||||
go rl.cleanupIPLimiters()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
// cleanupIPLimiters 清理过期的IP限流器
|
||||
func (rl *rateLimiter) cleanupIPLimiters() {
|
||||
for {
|
||||
time.Sleep(time.Hour) // 每小时清理一次
|
||||
|
||||
rl.mu.Lock()
|
||||
for ip, limiter := range rl.ips {
|
||||
if time.Since(limiter.lastSeen) > time.Hour {
|
||||
delete(rl.ips, ip)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// getLimiter 获取IP限流器
|
||||
func (rl *rateLimiter) getLimiter(ip string) *rate.Limiter {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
v, exists := rl.ips[ip]
|
||||
if !exists {
|
||||
limiter := rate.NewLimiter(rate.Limit(rl.config.IPRate), rl.config.IPBurst)
|
||||
rl.ips[ip] = &ipLimiter{limiter: limiter, lastSeen: time.Now()}
|
||||
return limiter
|
||||
}
|
||||
|
||||
v.lastSeen = time.Now()
|
||||
return v.limiter
|
||||
}
|
||||
|
||||
// RateLimit 创建限流中间件
|
||||
func RateLimit(config *types.RateLimitConfig) gin.HandlerFunc {
|
||||
rl := newRateLimiter(config)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
// 检查路由限流
|
||||
path := c.Request.URL.Path
|
||||
if limiter, ok := rl.routes[path]; ok {
|
||||
if !limiter.Allow() {
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "too many requests for this route",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 检查IP限流
|
||||
limiter := rl.getLimiter(c.ClientIP())
|
||||
if !limiter.Allow() {
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "too many requests from this IP",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
207
backend/internal/middleware/ratelimit_test.go
Normal file
207
backend/internal/middleware/ratelimit_test.go
Normal file
|
@ -0,0 +1,207 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
"tss-rocks-be/internal/types"
|
||||
)
|
||||
|
||||
func TestRateLimit(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
config *types.RateLimitConfig
|
||||
setupTest func(*gin.Engine)
|
||||
runTest func(*testing.T, *gin.Engine)
|
||||
expectedStatus int
|
||||
expectedBody map[string]string
|
||||
}{
|
||||
{
|
||||
name: "IP rate limit",
|
||||
config: &types.RateLimitConfig{
|
||||
IPRate: 1, // 每秒1个请求
|
||||
IPBurst: 1,
|
||||
},
|
||||
setupTest: func(router *gin.Engine) {
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
},
|
||||
runTest: func(t *testing.T, router *gin.Engine) {
|
||||
// 第一个请求应该成功
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:1234"
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// 第二个请求应该被限制
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "too many requests from this IP", response["error"])
|
||||
|
||||
// 等待限流器重置
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// 第三个请求应该成功
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Route rate limit",
|
||||
config: &types.RateLimitConfig{
|
||||
IPRate: 100, // 设置较高的 IP 限流,以便测试路由限流
|
||||
IPBurst: 10,
|
||||
RouteRates: map[string]struct {
|
||||
Rate int `yaml:"rate"`
|
||||
Burst int `yaml:"burst"`
|
||||
}{
|
||||
"/limited": {
|
||||
Rate: 1,
|
||||
Burst: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
setupTest: func(router *gin.Engine) {
|
||||
router.GET("/limited", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
router.GET("/unlimited", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
},
|
||||
runTest: func(t *testing.T, router *gin.Engine) {
|
||||
// 测试限流路由
|
||||
req := httptest.NewRequest("GET", "/limited", nil)
|
||||
req.RemoteAddr = "192.168.1.2:1234"
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// 等待一小段时间确保限流器生效
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "too many requests for this route", response["error"])
|
||||
|
||||
// 测试未限流路由
|
||||
req = httptest.NewRequest("GET", "/unlimited", nil)
|
||||
req.RemoteAddr = "192.168.1.2:1234"
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// 等待一小段时间确保限流器生效
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Multiple IPs",
|
||||
config: &types.RateLimitConfig{
|
||||
IPRate: 1,
|
||||
IPBurst: 1,
|
||||
},
|
||||
setupTest: func(router *gin.Engine) {
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
},
|
||||
runTest: func(t *testing.T, router *gin.Engine) {
|
||||
// IP1 的请求
|
||||
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||
req1.RemoteAddr = "192.168.1.3:1234"
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req1)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req1)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
|
||||
// IP2 的请求应该不受 IP1 的限制影响
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
req2.RemoteAddr = "192.168.1.4:1234"
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req2)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 添加限流中间件
|
||||
router.Use(RateLimit(tc.config))
|
||||
|
||||
// 设置测试路由
|
||||
tc.setupTest(router)
|
||||
|
||||
// 运行测试
|
||||
tc.runTest(t, router)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterCleanup(t *testing.T) {
|
||||
config := &types.RateLimitConfig{
|
||||
IPRate: 1,
|
||||
IPBurst: 1,
|
||||
}
|
||||
|
||||
rl := newRateLimiter(config)
|
||||
|
||||
// 添加一些IP限流器
|
||||
ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"}
|
||||
for _, ip := range ips {
|
||||
rl.getLimiter(ip)
|
||||
}
|
||||
|
||||
// 验证IP限流器已创建
|
||||
rl.mu.RLock()
|
||||
assert.Equal(t, len(ips), len(rl.ips))
|
||||
rl.mu.RUnlock()
|
||||
|
||||
// 修改一些IP的最后访问时间为1小时前
|
||||
rl.mu.Lock()
|
||||
rl.ips["192.168.1.1"].lastSeen = time.Now().Add(-2 * time.Hour)
|
||||
rl.ips["192.168.1.2"].lastSeen = time.Now().Add(-2 * time.Hour)
|
||||
rl.mu.Unlock()
|
||||
|
||||
// 手动触发清理
|
||||
rl.mu.Lock()
|
||||
for ip, limiter := range rl.ips {
|
||||
if time.Since(limiter.lastSeen) > time.Hour {
|
||||
delete(rl.ips, ip)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
|
||||
// 验证过期的IP限流器已被删除
|
||||
rl.mu.RLock()
|
||||
assert.Equal(t, 1, len(rl.ips))
|
||||
_, exists := rl.ips["192.168.1.3"]
|
||||
assert.True(t, exists)
|
||||
rl.mu.RUnlock()
|
||||
}
|
110
backend/internal/middleware/rbac.go
Normal file
110
backend/internal/middleware/rbac.go
Normal file
|
@ -0,0 +1,110 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/ent/user"
|
||||
"tss-rocks-be/internal/auth"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RequirePermission creates a middleware that checks if the user has the required permission
|
||||
func RequirePermission(client *ent.Client, resource, action string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Get user from context
|
||||
userID, exists := c.Get(auth.UserIDKey)
|
||||
if !exists {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Unauthorized",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user with roles
|
||||
user, err := client.User.Query().
|
||||
Where(user.ID(userID.(int))).
|
||||
WithRoles(func(q *ent.RoleQuery) {
|
||||
q.WithPermissions()
|
||||
}).
|
||||
Only(context.Background())
|
||||
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "User not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user has the required permission through any of their roles
|
||||
hasPermission := false
|
||||
for _, r := range user.Edges.Roles {
|
||||
for _, p := range r.Edges.Permissions {
|
||||
if p.Resource == resource && p.Action == action {
|
||||
hasPermission = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasPermission {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasPermission {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
||||
"error": fmt.Sprintf("Missing required permission: %s:%s", resource, action),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireRole creates a middleware that checks if the user has the required role
|
||||
func RequireRole(client *ent.Client, roleName string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Get user from context
|
||||
userID, exists := c.Get(auth.UserIDKey)
|
||||
if !exists {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Unauthorized",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user with roles
|
||||
user, err := client.User.Query().
|
||||
Where(user.ID(userID.(int))).
|
||||
WithRoles().
|
||||
Only(context.Background())
|
||||
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "User not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user has the required role
|
||||
hasRole := false
|
||||
for _, r := range user.Edges.Roles {
|
||||
if r.Name == roleName {
|
||||
hasRole = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasRole {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
||||
"error": fmt.Sprintf("Required role: %s", roleName),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
159
backend/internal/middleware/upload.go
Normal file
159
backend/internal/middleware/upload.go
Normal file
|
@ -0,0 +1,159 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"tss-rocks-be/internal/types"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMaxMemory = 32 << 20 // 32 MB
|
||||
maxHeaderBytes = 512 // 用于MIME类型检测的最大字节数
|
||||
)
|
||||
|
||||
// ValidateUpload 创建文件上传验证中间件
|
||||
func ValidateUpload(cfg *types.UploadConfig) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 检查是否是multipart/form-data请求
|
||||
if !strings.HasPrefix(c.GetHeader("Content-Type"), "multipart/form-data") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "Content-Type must be multipart/form-data",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 解析multipart表单
|
||||
if err := c.Request.ParseMultipartForm(defaultMaxMemory); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("Failed to parse form: %v", err),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
form := c.Request.MultipartForm
|
||||
if form == nil || form.File == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "No file uploaded",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 遍历所有上传的文件
|
||||
for _, files := range form.File {
|
||||
for _, file := range files {
|
||||
// 检查文件大小
|
||||
if file.Size > int64(cfg.MaxSize)<<20 { // 转换为字节
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("File %s exceeds maximum size of %d MB", file.Filename, cfg.MaxSize),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查文件扩展名
|
||||
ext := strings.ToLower(filepath.Ext(file.Filename))
|
||||
if !contains(cfg.AllowedExtensions, ext) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("File extension %s is not allowed", ext),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 打开文件
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": fmt.Sprintf("Failed to open file: %v", err),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
// 读取文件头部用于MIME类型检测
|
||||
header := make([]byte, maxHeaderBytes)
|
||||
n, err := src.Read(header)
|
||||
if err != nil && err != io.EOF {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": fmt.Sprintf("Failed to read file: %v", err),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
header = header[:n]
|
||||
|
||||
// 检测MIME类型
|
||||
contentType := http.DetectContentType(header)
|
||||
if !contains(cfg.AllowedTypes, contentType) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("File type %s is not allowed", contentType),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将文件指针重置到开始位置
|
||||
_, err = src.Seek(0, 0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": fmt.Sprintf("Failed to read file: %v", err),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将文件内容读入缓冲区
|
||||
buf := &bytes.Buffer{}
|
||||
_, err = io.Copy(buf, src)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": fmt.Sprintf("Failed to read file: %v", err),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将验证过的文件内容和类型保存到上下文中
|
||||
c.Set("validated_file_"+file.Filename, buf)
|
||||
c.Set("validated_content_type_"+file.Filename, contentType)
|
||||
}
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// contains 检查切片中是否包含指定的字符串
|
||||
func contains(slice []string, str string) bool {
|
||||
for _, s := range slice {
|
||||
if s == str {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetValidatedFile 从上下文中获取验证过的文件内容
|
||||
func GetValidatedFile(c *gin.Context, filename string) (*bytes.Buffer, string, bool) {
|
||||
file, exists := c.Get("validated_file_" + filename)
|
||||
if !exists {
|
||||
return nil, "", false
|
||||
}
|
||||
|
||||
contentType, exists := c.Get("validated_content_type_" + filename)
|
||||
if !exists {
|
||||
return nil, "", false
|
||||
}
|
||||
|
||||
return file.(*bytes.Buffer), contentType.(string), true
|
||||
}
|
262
backend/internal/middleware/upload_test.go
Normal file
262
backend/internal/middleware/upload_test.go
Normal file
|
@ -0,0 +1,262 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"tss-rocks-be/internal/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func createMultipartRequest(t *testing.T, filename string, content []byte, contentType string) (*http.Request, error) {
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
part, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = io.Copy(part, bytes.NewReader(content))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/upload", body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func TestValidateUpload(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *types.UploadConfig
|
||||
filename string
|
||||
content []byte
|
||||
setupRequest func(*testing.T) *http.Request
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "Valid image upload",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5, // 5MB
|
||||
AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
|
||||
AllowedTypes: []string{"image/jpeg", "image/png"},
|
||||
},
|
||||
filename: "test.jpg",
|
||||
content: []byte{
|
||||
0xFF, 0xD8, 0xFF, 0xE0, // JPEG magic numbers
|
||||
0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00,
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Invalid file extension",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5,
|
||||
AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
|
||||
AllowedTypes: []string{"image/jpeg", "image/png"},
|
||||
},
|
||||
filename: "test.txt",
|
||||
content: []byte("test content"),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "File extension .txt is not allowed",
|
||||
},
|
||||
{
|
||||
name: "File too large",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 1, // 1MB
|
||||
AllowedExtensions: []string{".jpg"},
|
||||
AllowedTypes: []string{"image/jpeg"},
|
||||
},
|
||||
filename: "large.jpg",
|
||||
content: make([]byte, 2<<20), // 2MB
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "File large.jpg exceeds maximum size of 1 MB",
|
||||
},
|
||||
{
|
||||
name: "Invalid content type",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5,
|
||||
AllowedExtensions: []string{".jpg"},
|
||||
AllowedTypes: []string{"image/jpeg"},
|
||||
},
|
||||
filename: "fake.jpg",
|
||||
content: []byte("not a real image"),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "File type text/plain; charset=utf-8 is not allowed",
|
||||
},
|
||||
{
|
||||
name: "Missing file",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5,
|
||||
AllowedExtensions: []string{".jpg"},
|
||||
AllowedTypes: []string{"image/jpeg"},
|
||||
},
|
||||
setupRequest: func(t *testing.T) *http.Request {
|
||||
req := httptest.NewRequest("POST", "/upload", strings.NewReader(""))
|
||||
req.Header.Set("Content-Type", "multipart/form-data")
|
||||
return req
|
||||
},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Failed to parse form",
|
||||
},
|
||||
{
|
||||
name: "Invalid content type header",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5,
|
||||
AllowedExtensions: []string{".jpg"},
|
||||
AllowedTypes: []string{"image/jpeg"},
|
||||
},
|
||||
setupRequest: func(t *testing.T) *http.Request {
|
||||
return httptest.NewRequest("POST", "/upload", nil)
|
||||
},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Content-Type must be multipart/form-data",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
var req *http.Request
|
||||
var err error
|
||||
|
||||
if tt.setupRequest != nil {
|
||||
req = tt.setupRequest(t)
|
||||
} else {
|
||||
req, err = createMultipartRequest(t, tt.filename, tt.content, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.Request = req
|
||||
|
||||
middleware := ValidateUpload(tt.config)
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||
if tt.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(w.Body).Decode(&response)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, response["error"], tt.expectedError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetValidatedFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupContext func(*gin.Context)
|
||||
filename string
|
||||
expectedFound bool
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "Get existing file",
|
||||
setupContext: func(c *gin.Context) {
|
||||
// 创建测试文件内容
|
||||
content := []byte("test content")
|
||||
buf := bytes.NewBuffer(content)
|
||||
|
||||
// 设置验证过的文件和内容类型
|
||||
c.Set("validated_file_test.txt", buf)
|
||||
c.Set("validated_content_type_test.txt", "text/plain")
|
||||
},
|
||||
filename: "test.txt",
|
||||
expectedFound: true,
|
||||
},
|
||||
{
|
||||
name: "File not found",
|
||||
setupContext: func(c *gin.Context) {
|
||||
// 不设置任何文件
|
||||
},
|
||||
filename: "nonexistent.txt",
|
||||
expectedFound: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
if tt.setupContext != nil {
|
||||
tt.setupContext(c)
|
||||
}
|
||||
|
||||
buffer, contentType, found := GetValidatedFile(c, tt.filename)
|
||||
|
||||
assert.Equal(t, tt.expectedFound, found)
|
||||
if tt.expectedFound {
|
||||
assert.NotNil(t, buffer)
|
||||
assert.NotEmpty(t, contentType)
|
||||
} else {
|
||||
assert.Nil(t, buffer)
|
||||
assert.Empty(t, contentType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
slice []string
|
||||
str string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "String found in slice",
|
||||
slice: []string{"a", "b", "c"},
|
||||
str: "b",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "String not found in slice",
|
||||
slice: []string{"a", "b", "c"},
|
||||
str: "d",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty slice",
|
||||
slice: []string{},
|
||||
str: "a",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
slice: []string{"a", "b", "c"},
|
||||
str: "",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := contains(tt.slice, tt.str)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
90
backend/internal/rbac/init.go
Normal file
90
backend/internal/rbac/init.go
Normal file
|
@ -0,0 +1,90 @@
|
|||
package rbac
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/ent/role"
|
||||
)
|
||||
|
||||
// DefaultPermissions defines the default permissions for each resource
|
||||
var DefaultPermissions = map[string][]string{
|
||||
"media": {"create", "read", "update", "delete", "list"},
|
||||
"post": {"create", "read", "update", "delete", "list"},
|
||||
"daily": {"create", "read", "update", "delete", "list"},
|
||||
"user": {"create", "read", "update", "delete", "list"},
|
||||
}
|
||||
|
||||
// DefaultRoles defines the default roles and their permissions
|
||||
var DefaultRoles = map[string]map[string][]string{
|
||||
"admin": DefaultPermissions,
|
||||
"editor": {
|
||||
"media": {"create", "read", "update", "list"},
|
||||
"post": {"create", "read", "update", "list"},
|
||||
"daily": {"create", "read", "update", "list"},
|
||||
"user": {"read"},
|
||||
},
|
||||
"contributor": {
|
||||
"media": {"read", "list"},
|
||||
"post": {"read", "list"},
|
||||
"daily": {"read", "list"},
|
||||
},
|
||||
}
|
||||
|
||||
// InitializeRBAC initializes the RBAC system with default roles and permissions
|
||||
func InitializeRBAC(ctx context.Context, client *ent.Client) error {
|
||||
// Create permissions
|
||||
permissionMap := make(map[string]*ent.Permission)
|
||||
for resource, actions := range DefaultPermissions {
|
||||
for _, action := range actions {
|
||||
permission, err := client.Permission.Create().
|
||||
SetResource(resource).
|
||||
SetAction(action).
|
||||
SetDescription(fmt.Sprintf("Permission to %s %s", action, resource)).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating permission: %w", err)
|
||||
}
|
||||
key := fmt.Sprintf("%s:%s", resource, action)
|
||||
permissionMap[key] = permission
|
||||
}
|
||||
}
|
||||
|
||||
// Create roles with permissions
|
||||
for roleName, permissions := range DefaultRoles {
|
||||
roleCreate := client.Role.Create().
|
||||
SetName(roleName).
|
||||
SetDescription(fmt.Sprintf("Role for %s users", roleName))
|
||||
|
||||
// Add permissions to role
|
||||
for resource, actions := range permissions {
|
||||
for _, action := range actions {
|
||||
key := fmt.Sprintf("%s:%s", resource, action)
|
||||
if permission, exists := permissionMap[key]; exists {
|
||||
roleCreate.AddPermissions(permission)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := roleCreate.Save(ctx); err != nil {
|
||||
return fmt.Errorf("failed creating role %s: %w", roleName, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AssignRoleToUser assigns a role to a user
|
||||
func AssignRoleToUser(ctx context.Context, client *ent.Client, userID int, roleName string) error {
|
||||
role, err := client.Role.Query().
|
||||
Where(role.Name(roleName)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed querying role: %w", err)
|
||||
}
|
||||
|
||||
return client.User.UpdateOneID(userID).
|
||||
AddRoles(role).
|
||||
Exec(ctx)
|
||||
}
|
98
backend/internal/rbac/init_test.go
Normal file
98
backend/internal/rbac/init_test.go
Normal file
|
@ -0,0 +1,98 @@
|
|||
package rbac
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"tss-rocks-be/ent/enttest"
|
||||
"tss-rocks-be/ent/role"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestInitializeRBAC(t *testing.T) {
|
||||
// Create an in-memory SQLite client for testing
|
||||
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Test initialization
|
||||
err := InitializeRBAC(ctx, client)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize RBAC: %v", err)
|
||||
}
|
||||
|
||||
// Verify roles were created
|
||||
for roleName := range DefaultRoles {
|
||||
r, err := client.Role.Query().Where(role.Name(roleName)).Only(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Role %s was not created: %v", roleName, err)
|
||||
}
|
||||
|
||||
// Verify permissions for each role
|
||||
perms, err := r.QueryPermissions().All(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to query permissions for role %s: %v", roleName, err)
|
||||
}
|
||||
|
||||
expectedPerms := DefaultRoles[roleName]
|
||||
permCount := 0
|
||||
for _, actions := range expectedPerms {
|
||||
permCount += len(actions)
|
||||
}
|
||||
|
||||
if len(perms) != permCount {
|
||||
t.Errorf("Role %s has %d permissions, expected %d", roleName, len(perms), permCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssignRoleToUser(t *testing.T) {
|
||||
// Create an in-memory SQLite client for testing
|
||||
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initialize RBAC
|
||||
err := InitializeRBAC(ctx, client)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize RBAC: %v", err)
|
||||
}
|
||||
|
||||
// Create a test user
|
||||
user, err := client.User.Create().
|
||||
SetEmail("test@example.com").
|
||||
SetPasswordHash("$2a$10$hzLdXMZEIzgr8eGXL0YoCOIIrQhqEj6N.S3.wY1Jx5.4vWm1ZyHyy").
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test user: %v", err)
|
||||
}
|
||||
|
||||
// Test assigning role to user
|
||||
err = AssignRoleToUser(ctx, client, user.ID, "editor")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to assign role to user: %v", err)
|
||||
}
|
||||
|
||||
// Verify role assignment
|
||||
assignedRoles, err := user.QueryRoles().All(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query user roles: %v", err)
|
||||
}
|
||||
|
||||
if len(assignedRoles) != 1 {
|
||||
t.Errorf("Expected 1 role, got %d", len(assignedRoles))
|
||||
}
|
||||
|
||||
if assignedRoles[0].Name != "editor" {
|
||||
t.Errorf("Expected role name 'editor', got '%s'", assignedRoles[0].Name)
|
||||
}
|
||||
|
||||
// Test assigning non-existent role
|
||||
err = AssignRoleToUser(ctx, client, user.ID, "nonexistent")
|
||||
if err == nil {
|
||||
t.Error("Expected error when assigning non-existent role, got nil")
|
||||
}
|
||||
}
|
24
backend/internal/server/database.go
Normal file
24
backend/internal/server/database.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"tss-rocks-be/ent"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func InitDatabase(ctx context.Context, driver, dsn string) (*ent.Client, error) {
|
||||
client, err := ent.Open(driver, dsn)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed opening database connection")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Run the auto migration tool
|
||||
if err := client.Schema.Create(ctx); err != nil {
|
||||
log.Error().Err(err).Msg("failed creating schema resources")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
64
backend/internal/server/database_test.go
Normal file
64
backend/internal/server/database_test.go
Normal file
|
@ -0,0 +1,64 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInitDatabase(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
driver string
|
||||
dsn string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "success with sqlite3",
|
||||
driver: "sqlite3",
|
||||
dsn: "file:ent?mode=memory&cache=shared&_fk=1",
|
||||
},
|
||||
{
|
||||
name: "invalid driver",
|
||||
driver: "invalid_driver",
|
||||
dsn: "file:ent?mode=memory",
|
||||
wantErr: true,
|
||||
errContains: "unsupported driver",
|
||||
},
|
||||
{
|
||||
name: "invalid dsn",
|
||||
driver: "sqlite3",
|
||||
dsn: "file::memory:?not_exist_option=1", // 使用内存数据库但带有无效选项
|
||||
wantErr: true,
|
||||
errContains: "foreign_keys pragma is off",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client, err := InitDatabase(ctx, tt.driver, tt.dsn)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
assert.Nil(t, client)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, client)
|
||||
|
||||
// 测试数据库连接是否正常工作
|
||||
err = client.Schema.Create(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 清理
|
||||
client.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
31
backend/internal/server/ent.go
Normal file
31
backend/internal/server/ent.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
)
|
||||
|
||||
// NewEntClient creates a new ent client
|
||||
func NewEntClient(cfg *config.Config) *ent.Client {
|
||||
// TODO: Implement database connection based on config
|
||||
// For now, we'll use SQLite for development
|
||||
db, err := sql.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to connect to database")
|
||||
}
|
||||
|
||||
// Create ent client
|
||||
client := ent.NewClient(ent.Driver(db))
|
||||
|
||||
// Run the auto migration tool
|
||||
if err := client.Schema.Create(context.Background()); err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to create schema resources")
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
40
backend/internal/server/ent_test.go
Normal file
40
backend/internal/server/ent_test.go
Normal file
|
@ -0,0 +1,40 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"tss-rocks-be/internal/config"
|
||||
)
|
||||
|
||||
func TestNewEntClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg *config.Config
|
||||
}{
|
||||
{
|
||||
name: "default sqlite3 config",
|
||||
cfg: &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
Driver: "sqlite3",
|
||||
DSN: "file:ent?mode=memory&cache=shared&_fk=1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
client := NewEntClient(tt.cfg)
|
||||
assert.NotNil(t, client)
|
||||
|
||||
// 验证客户端是否可以正常工作
|
||||
err := client.Schema.Create(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 清理
|
||||
client.Close()
|
||||
})
|
||||
}
|
||||
}
|
90
backend/internal/server/server.go
Normal file
90
backend/internal/server/server.go
Normal file
|
@ -0,0 +1,90 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/handler"
|
||||
"tss-rocks-be/internal/middleware"
|
||||
"tss-rocks-be/internal/service"
|
||||
"tss-rocks-be/internal/storage"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
config *config.Config
|
||||
router *gin.Engine
|
||||
handler *handler.Handler
|
||||
server *http.Server
|
||||
}
|
||||
|
||||
func New(cfg *config.Config, client *ent.Client) (*Server, error) {
|
||||
// Initialize storage
|
||||
store, err := storage.NewStorage(context.Background(), &cfg.Storage)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize storage: %w", err)
|
||||
}
|
||||
|
||||
// Initialize service
|
||||
svc := service.NewService(client, store)
|
||||
|
||||
// Initialize RBAC
|
||||
if err := svc.InitializeRBAC(context.Background()); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize RBAC: %w", err)
|
||||
}
|
||||
|
||||
// Initialize handler
|
||||
h := handler.NewHandler(cfg, svc)
|
||||
|
||||
// Initialize router
|
||||
router := gin.Default()
|
||||
|
||||
// Add CORS middleware if needed
|
||||
router.Use(middleware.CORS())
|
||||
|
||||
// 添加全局中间件
|
||||
router.Use(gin.Logger())
|
||||
router.Use(gin.Recovery())
|
||||
router.Use(middleware.RateLimit(&cfg.RateLimit))
|
||||
|
||||
// 添加访问日志中间件
|
||||
accessLog, err := middleware.AccessLog(&cfg.AccessLog)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize access log: %w", err)
|
||||
}
|
||||
router.Use(accessLog)
|
||||
|
||||
// 为上传路由添加文件验证中间件
|
||||
router.POST("/api/v1/media/upload", middleware.ValidateUpload(&cfg.Storage.Upload))
|
||||
|
||||
// Register routes
|
||||
h.RegisterRoutes(router)
|
||||
|
||||
return &Server{
|
||||
config: cfg,
|
||||
router: router,
|
||||
handler: h,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) Start() error {
|
||||
addr := fmt.Sprintf("%s:%d", s.config.Server.Host, s.config.Server.Port)
|
||||
s.server = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: s.router,
|
||||
}
|
||||
|
||||
log.Info().Msgf("Starting server on %s", addr)
|
||||
return s.server.ListenAndServe()
|
||||
}
|
||||
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
if s.server != nil {
|
||||
return s.server.Shutdown(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
220
backend/internal/server/server_test.go
Normal file
220
backend/internal/server/server_test.go
Normal file
|
@ -0,0 +1,220 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/types"
|
||||
"tss-rocks-be/ent/enttest"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
// 创建测试配置
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
},
|
||||
Storage: config.StorageConfig{
|
||||
Type: "local",
|
||||
Local: config.LocalStorage{
|
||||
RootDir: "testdata",
|
||||
},
|
||||
Upload: types.UploadConfig{
|
||||
MaxSize: 10,
|
||||
AllowedTypes: []string{"image/jpeg", "image/png"},
|
||||
AllowedExtensions: []string{".jpg", ".png"},
|
||||
},
|
||||
},
|
||||
RateLimit: types.RateLimitConfig{
|
||||
IPRate: 100,
|
||||
IPBurst: 200,
|
||||
RouteRates: map[string]struct {
|
||||
Rate int `yaml:"rate"`
|
||||
Burst int `yaml:"burst"`
|
||||
}{
|
||||
"/api/v1/upload": {Rate: 10, Burst: 20},
|
||||
},
|
||||
},
|
||||
AccessLog: types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
EnableFile: true,
|
||||
FilePath: "testdata/access.log",
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
Rotation: struct {
|
||||
MaxSize int `yaml:"max_size"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
Compress bool `yaml:"compress"`
|
||||
LocalTime bool `yaml:"local_time"`
|
||||
}{
|
||||
MaxSize: 100,
|
||||
MaxAge: 7,
|
||||
MaxBackups: 3,
|
||||
Compress: true,
|
||||
LocalTime: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 创建测试数据库客户端
|
||||
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
defer client.Close()
|
||||
|
||||
// 测试服务器初始化
|
||||
s, err := New(cfg, client)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, s)
|
||||
assert.NotNil(t, s.router)
|
||||
assert.NotNil(t, s.handler)
|
||||
assert.Equal(t, cfg, s.config)
|
||||
}
|
||||
|
||||
func TestNew_StorageError(t *testing.T) {
|
||||
// 创建一个无效的存储配置
|
||||
cfg := &config.Config{
|
||||
Storage: config.StorageConfig{
|
||||
Type: "invalid_type", // 使用无效的存储类型
|
||||
},
|
||||
}
|
||||
|
||||
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
defer client.Close()
|
||||
|
||||
s, err := New(cfg, client)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, s)
|
||||
assert.Contains(t, err.Error(), "failed to initialize storage")
|
||||
}
|
||||
|
||||
func TestServer_StartAndShutdown(t *testing.T) {
|
||||
// 创建测试配置
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Host: "localhost",
|
||||
Port: 0, // 使用随机端口
|
||||
},
|
||||
Storage: config.StorageConfig{
|
||||
Type: "local",
|
||||
Local: config.LocalStorage{
|
||||
RootDir: "testdata",
|
||||
},
|
||||
},
|
||||
RateLimit: types.RateLimitConfig{
|
||||
IPRate: 100,
|
||||
IPBurst: 200,
|
||||
},
|
||||
AccessLog: types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
},
|
||||
}
|
||||
|
||||
// 创建测试数据库客户端
|
||||
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
defer client.Close()
|
||||
|
||||
// 初始化服务器
|
||||
s, err := New(cfg, client)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 创建一个通道来接收服务器错误
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
// 在 goroutine 中启动服务器
|
||||
go func() {
|
||||
err := s.Start()
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
errChan <- err
|
||||
}
|
||||
close(errChan)
|
||||
}()
|
||||
|
||||
// 给服务器一些时间启动
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 测试关闭服务器
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = s.Shutdown(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 检查服务器是否有错误发生
|
||||
err = <-errChan
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestServer_StartError(t *testing.T) {
|
||||
// 创建一个配置,使用已经被占用的端口来触发错误
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Host: "localhost",
|
||||
Port: 8899, // 使用固定端口以便测试
|
||||
},
|
||||
Storage: config.StorageConfig{
|
||||
Type: "local",
|
||||
Local: config.LocalStorage{
|
||||
RootDir: "testdata",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
defer client.Close()
|
||||
|
||||
// 创建第一个服务器实例
|
||||
s1, err := New(cfg, client)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 创建一个通道来接收服务器错误
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
// 启动第一个服务器
|
||||
go func() {
|
||||
err := s1.Start()
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
errChan <- err
|
||||
}
|
||||
close(errChan)
|
||||
}()
|
||||
|
||||
// 给服务器一些时间启动
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 尝试在同一端口启动第二个服务器,应该会失败
|
||||
s2, err := New(cfg, client)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s2.Start()
|
||||
assert.Error(t, err)
|
||||
|
||||
// 清理
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 关闭第一个服务器
|
||||
err = s1.Shutdown(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 检查第一个服务器是否有错误发生
|
||||
err = <-errChan
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 关闭第二个服务器
|
||||
err = s2.Shutdown(ctx)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestServer_ShutdownWithNilServer(t *testing.T) {
|
||||
s := &Server{}
|
||||
err := s.Shutdown(context.Background())
|
||||
assert.NoError(t, err)
|
||||
}
|
892
backend/internal/service/impl.go
Normal file
892
backend/internal/service/impl.go
Normal file
|
@ -0,0 +1,892 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"regexp"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/ent/category"
|
||||
"tss-rocks-be/ent/categorycontent"
|
||||
"tss-rocks-be/ent/contributor"
|
||||
"tss-rocks-be/ent/contributorsociallink"
|
||||
"tss-rocks-be/ent/daily"
|
||||
"tss-rocks-be/ent/dailycontent"
|
||||
"tss-rocks-be/ent/permission"
|
||||
"tss-rocks-be/ent/post"
|
||||
"tss-rocks-be/ent/postcontent"
|
||||
"tss-rocks-be/ent/role"
|
||||
"tss-rocks-be/ent/user"
|
||||
"tss-rocks-be/internal/storage"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// Error definitions
|
||||
var (
|
||||
ErrUnauthorized = errors.New("unauthorized")
|
||||
)
|
||||
|
||||
// openFile is a variable that holds the Open method of multipart.FileHeader
|
||||
// This allows us to mock it in tests
|
||||
var openFile func(fh *multipart.FileHeader) (multipart.File, error) = func(fh *multipart.FileHeader) (multipart.File, error) {
|
||||
return fh.Open()
|
||||
}
|
||||
|
||||
type serviceImpl struct {
|
||||
client *ent.Client
|
||||
storage storage.Storage
|
||||
}
|
||||
|
||||
// NewService creates a new Service instance
|
||||
func NewService(client *ent.Client, storage storage.Storage) Service {
|
||||
return &serviceImpl{
|
||||
client: client,
|
||||
storage: storage,
|
||||
}
|
||||
}
|
||||
|
||||
// User operations
|
||||
func (s *serviceImpl) CreateUser(ctx context.Context, email, password string, roleStr string) (*ent.User, error) {
|
||||
// Hash the password
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
// Add the user role by default
|
||||
userRole, err := s.client.Role.Query().Where(role.NameEQ("user")).Only(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user role: %w", err)
|
||||
}
|
||||
|
||||
// If a specific role is requested and it's not "user", get that role too
|
||||
var additionalRole *ent.Role
|
||||
if roleStr != "" && roleStr != "user" {
|
||||
additionalRole, err = s.client.Role.Query().Where(role.NameEQ(roleStr)).Only(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get role: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create user with password and user role
|
||||
userCreate := s.client.User.Create().
|
||||
SetEmail(email).
|
||||
SetPasswordHash(string(hashedPassword)).
|
||||
AddRoles(userRole)
|
||||
|
||||
// Add the additional role if specified
|
||||
if additionalRole != nil {
|
||||
userCreate.AddRoles(additionalRole)
|
||||
}
|
||||
|
||||
// Save the user
|
||||
user, err := userCreate.Save(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) GetUserByEmail(ctx context.Context, email string) (*ent.User, error) {
|
||||
user, err := s.client.User.Query().
|
||||
Where(user.EmailEQ(email)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, fmt.Errorf("user not found: %s", email)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) ValidatePassword(ctx context.Context, user *ent.User, password string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// Category operations
|
||||
func (s *serviceImpl) CreateCategory(ctx context.Context) (*ent.Category, error) {
|
||||
return s.client.Category.Create().Save(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) AddCategoryContent(ctx context.Context, categoryID int, langCode, name, description, slug string) (*ent.CategoryContent, error) {
|
||||
var languageCode categorycontent.LanguageCode
|
||||
switch langCode {
|
||||
case "en":
|
||||
languageCode = categorycontent.LanguageCodeEN
|
||||
case "zh-Hans":
|
||||
languageCode = categorycontent.LanguageCodeZH_HANS
|
||||
case "zh-Hant":
|
||||
languageCode = categorycontent.LanguageCodeZH_HANT
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported language code: %s", langCode)
|
||||
}
|
||||
|
||||
return s.client.CategoryContent.Create().
|
||||
SetCategoryID(categoryID).
|
||||
SetLanguageCode(languageCode).
|
||||
SetName(name).
|
||||
SetDescription(description).
|
||||
SetSlug(slug).
|
||||
Save(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) GetCategoryBySlug(ctx context.Context, langCode, slug string) (*ent.Category, error) {
|
||||
var languageCode categorycontent.LanguageCode
|
||||
switch langCode {
|
||||
case "en":
|
||||
languageCode = categorycontent.LanguageCodeEN
|
||||
case "zh-Hans":
|
||||
languageCode = categorycontent.LanguageCodeZH_HANS
|
||||
case "zh-Hant":
|
||||
languageCode = categorycontent.LanguageCodeZH_HANT
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported language code: %s", langCode)
|
||||
}
|
||||
|
||||
return s.client.Category.Query().
|
||||
Where(
|
||||
category.HasContentsWith(
|
||||
categorycontent.And(
|
||||
categorycontent.LanguageCodeEQ(languageCode),
|
||||
categorycontent.SlugEQ(slug),
|
||||
),
|
||||
),
|
||||
).
|
||||
WithContents(func(q *ent.CategoryContentQuery) {
|
||||
q.Where(categorycontent.LanguageCodeEQ(languageCode))
|
||||
}).
|
||||
Only(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) GetCategories(ctx context.Context, langCode string) ([]*ent.Category, error) {
|
||||
// 转换语言代码
|
||||
var languageCode categorycontent.LanguageCode
|
||||
switch langCode {
|
||||
case "en":
|
||||
languageCode = categorycontent.LanguageCodeEN
|
||||
case "zh-Hans":
|
||||
languageCode = categorycontent.LanguageCodeZH_HANS
|
||||
case "zh-Hant":
|
||||
languageCode = categorycontent.LanguageCodeZH_HANT
|
||||
default:
|
||||
// 不支持的语言代码返回空列表而不是错误
|
||||
return []*ent.Category{}, nil
|
||||
}
|
||||
|
||||
// 从 CategoryContent 表开始查询,确保只返回有指定语言内容的分类
|
||||
contents, err := s.client.CategoryContent.Query().
|
||||
Where(
|
||||
categorycontent.And(
|
||||
categorycontent.LanguageCodeEQ(languageCode),
|
||||
categorycontent.SlugHasPrefix("category-list-"),
|
||||
),
|
||||
).
|
||||
WithCategory().
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 使用 map 去重,因为一个分类可能有多个相同语言的内容
|
||||
categoryMap := make(map[int]*ent.Category)
|
||||
for _, content := range contents {
|
||||
if content.Edges.Category != nil {
|
||||
categoryMap[content.Edges.Category.ID] = content.Edges.Category
|
||||
}
|
||||
}
|
||||
|
||||
// 将 map 转换为有序的切片
|
||||
var categories []*ent.Category
|
||||
for _, cat := range categoryMap {
|
||||
// 重新查询分类以获取完整的关联数据
|
||||
c, err := s.client.Category.Query().
|
||||
Where(category.ID(cat.ID)).
|
||||
WithContents(func(q *ent.CategoryContentQuery) {
|
||||
q.Where(
|
||||
categorycontent.And(
|
||||
categorycontent.LanguageCodeEQ(languageCode),
|
||||
categorycontent.SlugHasPrefix("category-list-"),
|
||||
),
|
||||
)
|
||||
}).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
categories = append(categories, c)
|
||||
}
|
||||
|
||||
// 按 ID 排序以保持结果稳定
|
||||
sort.Slice(categories, func(i, j int) bool {
|
||||
return categories[i].ID < categories[j].ID
|
||||
})
|
||||
|
||||
return categories, nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) ListCategories(ctx context.Context, langCode string) ([]*ent.Category, error) {
|
||||
// 转换语言代码
|
||||
var languageCode categorycontent.LanguageCode
|
||||
switch langCode {
|
||||
case "en":
|
||||
languageCode = categorycontent.LanguageCodeEN
|
||||
case "zh-Hans":
|
||||
languageCode = categorycontent.LanguageCodeZH_HANS
|
||||
case "zh-Hant":
|
||||
languageCode = categorycontent.LanguageCodeZH_HANT
|
||||
default:
|
||||
// 不支持的语言代码返回空列表而不是错误
|
||||
return []*ent.Category{}, nil
|
||||
}
|
||||
|
||||
// 从 CategoryContent 表开始查询,确保只返回有指定语言内容的分类
|
||||
contents, err := s.client.CategoryContent.Query().
|
||||
Where(
|
||||
categorycontent.And(
|
||||
categorycontent.LanguageCodeEQ(languageCode),
|
||||
categorycontent.SlugHasPrefix("category-list-"),
|
||||
),
|
||||
).
|
||||
WithCategory().
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 使用 map 去重,因为一个分类可能有多个相同语言的内容
|
||||
categoryMap := make(map[int]*ent.Category)
|
||||
for _, content := range contents {
|
||||
if content.Edges.Category != nil {
|
||||
categoryMap[content.Edges.Category.ID] = content.Edges.Category
|
||||
}
|
||||
}
|
||||
|
||||
// 将 map 转换为有序的切片
|
||||
var categories []*ent.Category
|
||||
for _, cat := range categoryMap {
|
||||
// 重新查询分类以获取完整的关联数据
|
||||
c, err := s.client.Category.Query().
|
||||
Where(category.ID(cat.ID)).
|
||||
WithContents(func(q *ent.CategoryContentQuery) {
|
||||
q.Where(
|
||||
categorycontent.And(
|
||||
categorycontent.LanguageCodeEQ(languageCode),
|
||||
categorycontent.SlugHasPrefix("category-list-"),
|
||||
),
|
||||
)
|
||||
}).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
categories = append(categories, c)
|
||||
}
|
||||
|
||||
// 按 ID 排序以保持结果稳定
|
||||
sort.Slice(categories, func(i, j int) bool {
|
||||
return categories[i].ID < categories[j].ID
|
||||
})
|
||||
|
||||
return categories, nil
|
||||
}
|
||||
|
||||
// Daily operations
|
||||
func (s *serviceImpl) CreateDaily(ctx context.Context, id string, categoryID int, imageURL string) (*ent.Daily, error) {
|
||||
_, err := s.client.Daily.Create().
|
||||
SetID(id).
|
||||
SetCategoryID(categoryID).
|
||||
SetImageURL(imageURL).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 加载 Category Edge
|
||||
return s.client.Daily.Query().
|
||||
Where(daily.IDEQ(id)).
|
||||
WithCategory().
|
||||
Only(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) AddDailyContent(ctx context.Context, dailyID string, langCode string, quote string) (*ent.DailyContent, error) {
|
||||
var languageCode dailycontent.LanguageCode
|
||||
switch langCode {
|
||||
case "en":
|
||||
languageCode = dailycontent.LanguageCodeEN
|
||||
case "zh-Hans":
|
||||
languageCode = dailycontent.LanguageCodeZH_HANS
|
||||
case "zh-Hant":
|
||||
languageCode = dailycontent.LanguageCodeZH_HANT
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported language code: %s", langCode)
|
||||
}
|
||||
|
||||
return s.client.DailyContent.Create().
|
||||
SetDailyID(dailyID).
|
||||
SetLanguageCode(languageCode).
|
||||
SetQuote(quote).
|
||||
Save(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) GetDailyByID(ctx context.Context, id string) (*ent.Daily, error) {
|
||||
return s.client.Daily.Query().
|
||||
Where(daily.IDEQ(id)).
|
||||
WithCategory().
|
||||
WithContents().
|
||||
Only(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) ListDailies(ctx context.Context, langCode string, categoryID *int, limit int, offset int) ([]*ent.Daily, error) {
|
||||
var languageCode dailycontent.LanguageCode
|
||||
switch langCode {
|
||||
case "en":
|
||||
languageCode = dailycontent.LanguageCodeEN
|
||||
case "zh-Hans":
|
||||
languageCode = dailycontent.LanguageCodeZH_HANS
|
||||
case "zh-Hant":
|
||||
languageCode = dailycontent.LanguageCodeZH_HANT
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported language code: %s", langCode)
|
||||
}
|
||||
|
||||
query := s.client.Daily.Query().
|
||||
WithContents(func(q *ent.DailyContentQuery) {
|
||||
if langCode != "" {
|
||||
q.Where(dailycontent.LanguageCodeEQ(languageCode))
|
||||
}
|
||||
}).
|
||||
WithCategory()
|
||||
|
||||
if categoryID != nil {
|
||||
query.Where(daily.HasCategoryWith(category.ID(*categoryID)))
|
||||
}
|
||||
|
||||
query.Order(ent.Desc(daily.FieldCreatedAt))
|
||||
|
||||
if limit > 0 {
|
||||
query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query.Offset(offset)
|
||||
}
|
||||
|
||||
return query.All(ctx)
|
||||
}
|
||||
|
||||
// Media operations
|
||||
func (s *serviceImpl) ListMedia(ctx context.Context, limit, offset int) ([]*ent.Media, error) {
|
||||
return s.client.Media.Query().
|
||||
Order(ent.Desc("created_at")).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
All(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error) {
|
||||
// Open the uploaded file
|
||||
src, err := openFile(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
// Save the file to storage
|
||||
fileInfo, err := s.storage.Save(ctx, file.Filename, file.Header.Get("Content-Type"), src)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create media record
|
||||
return s.client.Media.Create().
|
||||
SetStorageID(fileInfo.ID).
|
||||
SetOriginalName(file.Filename).
|
||||
SetMimeType(fileInfo.ContentType).
|
||||
SetSize(fileInfo.Size).
|
||||
SetURL(fileInfo.URL).
|
||||
SetCreatedBy(strconv.Itoa(userID)).
|
||||
Save(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) GetMedia(ctx context.Context, id int) (*ent.Media, error) {
|
||||
return s.client.Media.Get(ctx, id)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error) {
|
||||
media, err := s.GetMedia(ctx, id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return s.storage.Get(ctx, media.StorageID)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) DeleteMedia(ctx context.Context, id int, userID int) error {
|
||||
media, err := s.GetMedia(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check ownership
|
||||
if media.CreatedBy != strconv.Itoa(userID) {
|
||||
return ErrUnauthorized
|
||||
}
|
||||
|
||||
// Delete from storage
|
||||
if err := s.storage.Delete(ctx, media.StorageID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete from database
|
||||
return s.client.Media.DeleteOne(media).Exec(ctx)
|
||||
}
|
||||
|
||||
// Post operations
|
||||
func (s *serviceImpl) CreatePost(ctx context.Context, status string) (*ent.Post, error) {
|
||||
var postStatus post.Status
|
||||
switch status {
|
||||
case "draft":
|
||||
postStatus = post.StatusDraft
|
||||
case "published":
|
||||
postStatus = post.StatusPublished
|
||||
case "archived":
|
||||
postStatus = post.StatusArchived
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid status: %s", status)
|
||||
}
|
||||
|
||||
// Generate a random slug
|
||||
slug := fmt.Sprintf("post-%s", uuid.New().String()[:8])
|
||||
|
||||
return s.client.Post.Create().
|
||||
SetStatus(postStatus).
|
||||
SetSlug(slug).
|
||||
Save(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) AddPostContent(ctx context.Context, postID int, langCode, title, content, summary string, metaKeywords, metaDescription string) (*ent.PostContent, error) {
|
||||
var languageCode postcontent.LanguageCode
|
||||
switch langCode {
|
||||
case "en":
|
||||
languageCode = postcontent.LanguageCodeEN
|
||||
case "zh-Hans":
|
||||
languageCode = postcontent.LanguageCodeZH_HANS
|
||||
case "zh-Hant":
|
||||
languageCode = postcontent.LanguageCodeZH_HANT
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported language code: %s", langCode)
|
||||
}
|
||||
|
||||
// Get the post first to check if it exists
|
||||
post, err := s.client.Post.Get(ctx, postID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get post: %w", err)
|
||||
}
|
||||
|
||||
// Generate slug from title
|
||||
var slug string
|
||||
if langCode == "en" {
|
||||
// For English titles, convert to lowercase and replace spaces with dashes
|
||||
slug = strings.ToLower(strings.ReplaceAll(title, " ", "-"))
|
||||
// Remove all non-alphanumeric characters except dashes
|
||||
slug = regexp.MustCompile(`[^a-z0-9-]+`).ReplaceAllString(slug, "")
|
||||
// Ensure slug is not empty and has minimum length
|
||||
if slug == "" || len(slug) < 4 {
|
||||
slug = fmt.Sprintf("post-%s", uuid.NewString()[:8])
|
||||
}
|
||||
} else {
|
||||
// For Chinese titles, use the title as is
|
||||
slug = title
|
||||
}
|
||||
|
||||
return s.client.PostContent.Create().
|
||||
SetPost(post).
|
||||
SetLanguageCode(languageCode).
|
||||
SetTitle(title).
|
||||
SetContentMarkdown(content).
|
||||
SetSummary(summary).
|
||||
SetMetaKeywords(metaKeywords).
|
||||
SetMetaDescription(metaDescription).
|
||||
SetSlug(slug).
|
||||
Save(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) GetPostBySlug(ctx context.Context, langCode, slug string) (*ent.Post, error) {
|
||||
var languageCode postcontent.LanguageCode
|
||||
switch langCode {
|
||||
case "en":
|
||||
languageCode = postcontent.LanguageCodeEN
|
||||
case "zh-Hans":
|
||||
languageCode = postcontent.LanguageCodeZH_HANS
|
||||
case "zh-Hant":
|
||||
languageCode = postcontent.LanguageCodeZH_HANT
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported language code: %s", langCode)
|
||||
}
|
||||
|
||||
// Find posts that have content with the given slug and language code
|
||||
posts, err := s.client.Post.Query().
|
||||
Where(
|
||||
post.And(
|
||||
post.StatusEQ(post.StatusPublished),
|
||||
post.HasContentsWith(
|
||||
postcontent.And(
|
||||
postcontent.LanguageCodeEQ(languageCode),
|
||||
postcontent.SlugEQ(slug),
|
||||
),
|
||||
),
|
||||
),
|
||||
).
|
||||
WithContents(func(q *ent.PostContentQuery) {
|
||||
q.Where(postcontent.LanguageCodeEQ(languageCode))
|
||||
}).
|
||||
WithCategory().
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get posts: %w", err)
|
||||
}
|
||||
|
||||
if len(posts) == 0 {
|
||||
return nil, fmt.Errorf("post not found")
|
||||
}
|
||||
if len(posts) > 1 {
|
||||
return nil, fmt.Errorf("multiple posts found with the same slug")
|
||||
}
|
||||
|
||||
return posts[0], nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error) {
|
||||
var languageCode postcontent.LanguageCode
|
||||
switch langCode {
|
||||
case "en":
|
||||
languageCode = postcontent.LanguageCodeEN
|
||||
case "zh-Hans":
|
||||
languageCode = postcontent.LanguageCodeZH_HANS
|
||||
case "zh-Hant":
|
||||
languageCode = postcontent.LanguageCodeZH_HANT
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported language code: %s", langCode)
|
||||
}
|
||||
|
||||
// First find all post IDs that have content in the given language
|
||||
query := s.client.PostContent.Query().
|
||||
Where(postcontent.LanguageCodeEQ(languageCode)).
|
||||
QueryPost().
|
||||
Where(post.StatusEQ(post.StatusPublished))
|
||||
|
||||
// Add category filter if provided
|
||||
if categoryID != nil {
|
||||
query = query.Where(post.HasCategoryWith(category.ID(*categoryID)))
|
||||
}
|
||||
|
||||
// Get unique post IDs
|
||||
postIDs, err := query.
|
||||
Order(ent.Desc(post.FieldCreatedAt)).
|
||||
IDs(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get post IDs: %w", err)
|
||||
}
|
||||
|
||||
// Remove duplicates while preserving order
|
||||
seen := make(map[int]bool)
|
||||
uniqueIDs := make([]int, 0, len(postIDs))
|
||||
for _, id := range postIDs {
|
||||
if !seen[id] {
|
||||
seen[id] = true
|
||||
uniqueIDs = append(uniqueIDs, id)
|
||||
}
|
||||
}
|
||||
postIDs = uniqueIDs
|
||||
|
||||
if len(postIDs) == 0 {
|
||||
return []*ent.Post{}, nil
|
||||
}
|
||||
|
||||
// If no category filter is applied, only take the latest 5 posts
|
||||
if categoryID == nil && len(postIDs) > 5 {
|
||||
postIDs = postIDs[:5]
|
||||
}
|
||||
|
||||
// Apply pagination
|
||||
if offset >= len(postIDs) {
|
||||
return []*ent.Post{}, nil
|
||||
}
|
||||
|
||||
// If limit is 0, set it to the length of postIDs
|
||||
if limit == 0 {
|
||||
limit = len(postIDs)
|
||||
}
|
||||
|
||||
// Adjust limit if it would exceed total
|
||||
if offset+limit > len(postIDs) {
|
||||
limit = len(postIDs) - offset
|
||||
}
|
||||
|
||||
// Get the paginated post IDs
|
||||
paginatedIDs := postIDs[offset : offset+limit]
|
||||
|
||||
// Get the posts with their contents
|
||||
posts, err := s.client.Post.Query().
|
||||
Where(post.IDIn(paginatedIDs...)).
|
||||
WithContents(func(q *ent.PostContentQuery) {
|
||||
q.Where(postcontent.LanguageCodeEQ(languageCode))
|
||||
}).
|
||||
WithCategory().
|
||||
Order(ent.Desc(post.FieldCreatedAt)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get posts: %w", err)
|
||||
}
|
||||
|
||||
// Sort posts by ID to match the order of postIDs
|
||||
sort.Slice(posts, func(i, j int) bool {
|
||||
// Find index of each post ID in postIDs
|
||||
var iIndex, jIndex int
|
||||
for idx, id := range paginatedIDs {
|
||||
if posts[i].ID == id {
|
||||
iIndex = idx
|
||||
}
|
||||
if posts[j].ID == id {
|
||||
jIndex = idx
|
||||
}
|
||||
}
|
||||
return iIndex < jIndex
|
||||
})
|
||||
|
||||
return posts, nil
|
||||
}
|
||||
|
||||
// Contributor operations
|
||||
func (s *serviceImpl) CreateContributor(ctx context.Context, name string, avatarURL, bio *string) (*ent.Contributor, error) {
|
||||
builder := s.client.Contributor.Create().
|
||||
SetName(name)
|
||||
|
||||
if avatarURL != nil {
|
||||
builder.SetAvatarURL(*avatarURL)
|
||||
}
|
||||
if bio != nil {
|
||||
builder.SetBio(*bio)
|
||||
}
|
||||
|
||||
contributor, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create contributor: %w", err)
|
||||
}
|
||||
|
||||
return contributor, nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) AddContributorSocialLink(ctx context.Context, contributorID int, linkType, name, value string) (*ent.ContributorSocialLink, error) {
|
||||
// 验证贡献者是否存在
|
||||
contributor, err := s.client.Contributor.Get(ctx, contributorID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get contributor: %w", err)
|
||||
}
|
||||
|
||||
// 创建社交链接
|
||||
link, err := s.client.ContributorSocialLink.Create().
|
||||
SetContributor(contributor).
|
||||
SetType(contributorsociallink.Type(linkType)).
|
||||
SetName(name).
|
||||
SetValue(value).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create social link: %w", err)
|
||||
}
|
||||
|
||||
return link, nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) GetContributorByID(ctx context.Context, id int) (*ent.Contributor, error) {
|
||||
contributor, err := s.client.Contributor.Query().
|
||||
Where(contributor.ID(id)).
|
||||
WithSocialLinks().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get contributor: %w", err)
|
||||
}
|
||||
return contributor, nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) ListContributors(ctx context.Context) ([]*ent.Contributor, error) {
|
||||
contributors, err := s.client.Contributor.Query().
|
||||
WithSocialLinks().
|
||||
Order(ent.Asc(contributor.FieldName)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list contributors: %w", err)
|
||||
}
|
||||
return contributors, nil
|
||||
}
|
||||
|
||||
// RBAC operations
|
||||
func (s *serviceImpl) InitializeRBAC(ctx context.Context) error {
|
||||
// Create roles if they don't exist
|
||||
adminRole, err := s.client.Role.Create().SetName("admin").Save(ctx)
|
||||
if ent.IsConstraintError(err) {
|
||||
adminRole, err = s.client.Role.Query().Where(role.NameEQ("admin")).Only(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create admin role: %w", err)
|
||||
}
|
||||
|
||||
editorRole, err := s.client.Role.Create().SetName("editor").Save(ctx)
|
||||
if ent.IsConstraintError(err) {
|
||||
editorRole, err = s.client.Role.Query().Where(role.NameEQ("editor")).Only(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create editor role: %w", err)
|
||||
}
|
||||
|
||||
userRole, err := s.client.Role.Create().SetName("user").Save(ctx)
|
||||
if ent.IsConstraintError(err) {
|
||||
userRole, err = s.client.Role.Query().Where(role.NameEQ("user")).Only(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create user role: %w", err)
|
||||
}
|
||||
|
||||
// Define permissions
|
||||
permissions := []struct {
|
||||
role *ent.Role
|
||||
resource string
|
||||
actions []string
|
||||
}{
|
||||
// Admin permissions (full access)
|
||||
{adminRole, "users", []string{"create", "read", "update", "delete", "assign_role"}},
|
||||
{adminRole, "roles", []string{"create", "read", "update", "delete"}},
|
||||
{adminRole, "media", []string{"create", "read", "update", "delete"}},
|
||||
{adminRole, "posts", []string{"create", "read", "update", "delete"}},
|
||||
{adminRole, "categories", []string{"create", "read", "update", "delete"}},
|
||||
{adminRole, "contributors", []string{"create", "read", "update", "delete"}},
|
||||
{adminRole, "dailies", []string{"create", "read", "update", "delete"}},
|
||||
|
||||
// Editor permissions (can create and manage content)
|
||||
{editorRole, "media", []string{"create", "read", "update", "delete"}},
|
||||
{editorRole, "posts", []string{"create", "read", "update", "delete"}},
|
||||
{editorRole, "categories", []string{"read"}},
|
||||
{editorRole, "contributors", []string{"read"}},
|
||||
{editorRole, "dailies", []string{"create", "read", "update", "delete"}},
|
||||
|
||||
// User permissions (read-only access)
|
||||
{userRole, "media", []string{"read"}},
|
||||
{userRole, "posts", []string{"read"}},
|
||||
{userRole, "categories", []string{"read"}},
|
||||
{userRole, "contributors", []string{"read"}},
|
||||
{userRole, "dailies", []string{"read"}},
|
||||
}
|
||||
|
||||
// Create permissions for each role
|
||||
for _, p := range permissions {
|
||||
for _, action := range p.actions {
|
||||
perm, err := s.client.Permission.Create().
|
||||
SetResource(p.resource).
|
||||
SetAction(action).
|
||||
Save(ctx)
|
||||
if ent.IsConstraintError(err) {
|
||||
perm, err = s.client.Permission.Query().
|
||||
Where(
|
||||
permission.ResourceEQ(p.resource),
|
||||
permission.ActionEQ(action),
|
||||
).
|
||||
Only(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create permission %s:%s: %w", p.resource, action, err)
|
||||
}
|
||||
|
||||
// Add permission to role
|
||||
err = s.client.Role.UpdateOne(p.role).
|
||||
AddPermissions(perm).
|
||||
Exec(ctx)
|
||||
if err != nil && !ent.IsConstraintError(err) {
|
||||
return fmt.Errorf("failed to add permission %s:%s to role %s: %w", p.resource, action, p.role.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) AssignRole(ctx context.Context, userID int, roleName string) error {
|
||||
user, err := s.client.User.Get(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
role, err := s.client.Role.Query().Where(role.NameEQ(roleName)).Only(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get role: %w", err)
|
||||
}
|
||||
|
||||
return s.client.User.UpdateOne(user).AddRoles(role).Exec(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) RemoveRole(ctx context.Context, userID int, roleName string) error {
|
||||
// Don't allow removing the user role
|
||||
if roleName == "user" {
|
||||
return errors.New("cannot remove user role")
|
||||
}
|
||||
|
||||
user, err := s.client.User.Get(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
role, err := s.client.Role.Query().Where(role.NameEQ(roleName)).Only(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get role: %w", err)
|
||||
}
|
||||
|
||||
return s.client.User.UpdateOne(user).RemoveRoles(role).Exec(ctx)
|
||||
}
|
||||
|
||||
func (s *serviceImpl) GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error) {
|
||||
user, err := s.client.User.Query().
|
||||
Where(user.ID(userID)).
|
||||
WithRoles().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
return user.Edges.Roles, nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) HasPermission(ctx context.Context, userID int, permission string) (bool, error) {
|
||||
user, err := s.client.User.Query().
|
||||
Where(user.ID(userID)).
|
||||
WithRoles(func(q *ent.RoleQuery) {
|
||||
q.WithPermissions()
|
||||
}).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
parts := strings.Split(permission, ":")
|
||||
if len(parts) != 2 {
|
||||
return false, fmt.Errorf("invalid permission format: %s, expected format: resource:action", permission)
|
||||
}
|
||||
resource, action := parts[0], parts[1]
|
||||
|
||||
for _, r := range user.Edges.Roles {
|
||||
for _, p := range r.Edges.Permissions {
|
||||
if p.Resource == resource && p.Action == action {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
1092
backend/internal/service/impl_test.go
Normal file
1092
backend/internal/service/impl_test.go
Normal file
File diff suppressed because it is too large
Load diff
179
backend/internal/service/media.go
Normal file
179
backend/internal/service/media.go
Normal file
|
@ -0,0 +1,179 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"strings"
|
||||
|
||||
"path/filepath"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/storage"
|
||||
"tss-rocks-be/pkg/imageutil"
|
||||
)
|
||||
|
||||
type MediaService interface {
|
||||
// Upload uploads a new file and creates a media record
|
||||
Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error)
|
||||
|
||||
// Get retrieves a media file by ID
|
||||
Get(ctx context.Context, id int) (*ent.Media, error)
|
||||
|
||||
// Delete deletes a media file
|
||||
Delete(ctx context.Context, id int, userID int) error
|
||||
|
||||
// List lists media files with pagination
|
||||
List(ctx context.Context, limit, offset int) ([]*ent.Media, error)
|
||||
|
||||
// GetFile gets the file content and info
|
||||
GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error)
|
||||
}
|
||||
|
||||
type mediaService struct {
|
||||
client *ent.Client
|
||||
storage storage.Storage
|
||||
}
|
||||
|
||||
func NewMediaService(client *ent.Client, storage storage.Storage) MediaService {
|
||||
return &mediaService{
|
||||
client: client,
|
||||
storage: storage,
|
||||
}
|
||||
}
|
||||
|
||||
// isValidFilename checks if a filename is valid
|
||||
func isValidFilename(filename string) bool {
|
||||
// Check for illegal characters
|
||||
if strings.Contains(filename, "../") ||
|
||||
strings.Contains(filename, "./") ||
|
||||
strings.Contains(filename, "\\") {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *mediaService) Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error) {
|
||||
// Validate filename
|
||||
if !isValidFilename(file.Filename) {
|
||||
return nil, fmt.Errorf("invalid filename: %s", file.Filename)
|
||||
}
|
||||
|
||||
// Open the file
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
// Read file content for processing
|
||||
fileBytes, err := io.ReadAll(src)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
|
||||
contentType := file.Header.Get("Content-Type")
|
||||
filename := file.Filename
|
||||
var processedBytes []byte
|
||||
|
||||
// Process image if it's an image file
|
||||
if imageutil.IsImageFormat(contentType) {
|
||||
opts := imageutil.DefaultOptions()
|
||||
processedBytes, err = imageutil.ProcessImage(bytes.NewReader(fileBytes), opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process image: %w", err)
|
||||
}
|
||||
|
||||
// Update content type and filename for WebP
|
||||
contentType = "image/webp"
|
||||
filename = strings.TrimSuffix(filename, filepath.Ext(filename)) + ".webp"
|
||||
} else {
|
||||
processedBytes = fileBytes
|
||||
}
|
||||
|
||||
// Save the processed file
|
||||
fileInfo, err := s.storage.Save(ctx, filename, contentType, bytes.NewReader(processedBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save file: %w", err)
|
||||
}
|
||||
|
||||
// Create media record in database
|
||||
media, err := s.client.Media.Create().
|
||||
SetStorageID(fileInfo.ID).
|
||||
SetOriginalName(filename).
|
||||
SetMimeType(contentType).
|
||||
SetSize(int64(len(processedBytes))).
|
||||
SetURL(fmt.Sprintf("/api/media/%s", fileInfo.ID)).
|
||||
SetCreatedBy(fmt.Sprint(userID)).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
// Try to cleanup the stored file if database operation fails
|
||||
_ = s.storage.Delete(ctx, fileInfo.ID)
|
||||
return nil, fmt.Errorf("failed to create media record: %w", err)
|
||||
}
|
||||
|
||||
return media, nil
|
||||
}
|
||||
|
||||
func (s *mediaService) Get(ctx context.Context, id int) (*ent.Media, error) {
|
||||
media, err := s.client.Media.Get(ctx, id)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, fmt.Errorf("media not found: %d", id)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get media: %w", err)
|
||||
}
|
||||
return media, nil
|
||||
}
|
||||
|
||||
func (s *mediaService) Delete(ctx context.Context, id int, userID int) error {
|
||||
media, err := s.Get(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check ownership
|
||||
if media.CreatedBy != fmt.Sprintf("%d", userID) {
|
||||
return fmt.Errorf("unauthorized to delete media")
|
||||
}
|
||||
|
||||
// Delete from storage
|
||||
if err := s.storage.Delete(ctx, media.StorageID); err != nil {
|
||||
return fmt.Errorf("failed to delete file from storage: %w", err)
|
||||
}
|
||||
|
||||
// Delete from database
|
||||
if err := s.client.Media.DeleteOne(media).Exec(ctx); err != nil {
|
||||
return fmt.Errorf("failed to delete media record: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *mediaService) List(ctx context.Context, limit, offset int) ([]*ent.Media, error) {
|
||||
media, err := s.client.Media.Query().
|
||||
Order(ent.Desc("created_at")).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list media: %w", err)
|
||||
}
|
||||
return media, nil
|
||||
}
|
||||
|
||||
func (s *mediaService) GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error) {
|
||||
media, err := s.Get(ctx, id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
reader, info, err := s.storage.Get(ctx, media.StorageID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get file from storage: %w", err)
|
||||
}
|
||||
|
||||
return reader, info, nil
|
||||
}
|
332
backend/internal/service/media_test.go
Normal file
332
backend/internal/service/media_test.go
Normal file
|
@ -0,0 +1,332 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/textproto"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/storage"
|
||||
"tss-rocks-be/internal/storage/mock"
|
||||
"tss-rocks-be/internal/testutil"
|
||||
|
||||
"bou.ke/monkey"
|
||||
)
|
||||
|
||||
type MediaServiceTestSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *ent.Client
|
||||
storage *mock.MockStorage
|
||||
ctrl *gomock.Controller
|
||||
svc MediaService
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.client = testutil.NewTestClient()
|
||||
require.NotNil(s.T(), s.client)
|
||||
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.storage = mock.NewMockStorage(s.ctrl)
|
||||
s.svc = NewMediaService(s.client, s.storage)
|
||||
|
||||
// 清理数据库
|
||||
_, err := s.client.Media.Delete().Exec(s.ctx)
|
||||
require.NoError(s.T(), err)
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
s.client.Close()
|
||||
}
|
||||
|
||||
func TestMediaServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(MediaServiceTestSuite))
|
||||
}
|
||||
|
||||
type mockFileHeader struct {
|
||||
filename string
|
||||
contentType string
|
||||
size int64
|
||||
content []byte
|
||||
}
|
||||
|
||||
func (h *mockFileHeader) Open() (multipart.File, error) {
|
||||
return newMockMultipartFile(h.content), nil
|
||||
}
|
||||
|
||||
func (h *mockFileHeader) Filename() string {
|
||||
return h.filename
|
||||
}
|
||||
|
||||
func (h *mockFileHeader) Size() int64 {
|
||||
return h.size
|
||||
}
|
||||
|
||||
func (h *mockFileHeader) Header() textproto.MIMEHeader {
|
||||
header := make(textproto.MIMEHeader)
|
||||
header.Set("Content-Type", h.contentType)
|
||||
return header
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) createTestFile(filename, contentType string, content []byte) *multipart.FileHeader {
|
||||
header := &multipart.FileHeader{
|
||||
Filename: filename,
|
||||
Header: make(textproto.MIMEHeader),
|
||||
Size: int64(len(content)),
|
||||
}
|
||||
header.Header.Set("Content-Type", contentType)
|
||||
|
||||
monkey.PatchInstanceMethod(reflect.TypeOf(header), "Open", func(_ *multipart.FileHeader) (multipart.File, error) {
|
||||
return newMockMultipartFile(content), nil
|
||||
})
|
||||
|
||||
return header
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestUpload() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
filename string
|
||||
contentType string
|
||||
content []byte
|
||||
setupMock func()
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "Upload text file",
|
||||
filename: "test.txt",
|
||||
contentType: "text/plain",
|
||||
content: []byte("test content"),
|
||||
setupMock: func() {
|
||||
s.storage.EXPECT().
|
||||
Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, name, contentType string, reader io.Reader) (*storage.FileInfo, error) {
|
||||
content, err := io.ReadAll(reader)
|
||||
s.Require().NoError(err)
|
||||
s.Equal([]byte("test content"), content)
|
||||
return &storage.FileInfo{
|
||||
ID: "test-id",
|
||||
Name: "test.txt",
|
||||
ContentType: "text/plain",
|
||||
Size: int64(len(content)),
|
||||
}, nil
|
||||
})
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid filename",
|
||||
filename: "../test.txt",
|
||||
contentType: "text/plain",
|
||||
content: []byte("test content"),
|
||||
setupMock: func() {},
|
||||
wantErr: true,
|
||||
errMsg: "invalid filename",
|
||||
},
|
||||
{
|
||||
name: "Storage error",
|
||||
filename: "test.txt",
|
||||
contentType: "text/plain",
|
||||
content: []byte("test content"),
|
||||
setupMock: func() {
|
||||
s.storage.EXPECT().
|
||||
Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()).
|
||||
Return(nil, fmt.Errorf("storage error"))
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "storage error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// Setup mock
|
||||
tc.setupMock()
|
||||
|
||||
// Create test file
|
||||
fileHeader := s.createTestFile(tc.filename, tc.contentType, tc.content)
|
||||
|
||||
// Add debug output
|
||||
s.T().Logf("Testing file: %s, content-type: %s, size: %d", fileHeader.Filename, fileHeader.Header.Get("Content-Type"), fileHeader.Size)
|
||||
|
||||
// Test upload
|
||||
media, err := s.svc.Upload(s.ctx, fileHeader, 1)
|
||||
|
||||
// Add debug output
|
||||
if err != nil {
|
||||
s.T().Logf("Upload error: %v", err)
|
||||
}
|
||||
|
||||
if tc.wantErr {
|
||||
s.Require().Error(err)
|
||||
s.Contains(err.Error(), tc.errMsg)
|
||||
return
|
||||
}
|
||||
|
||||
s.Require().NoError(err)
|
||||
s.NotNil(media)
|
||||
s.Equal(tc.filename, media.OriginalName)
|
||||
s.Equal(tc.contentType, media.MimeType)
|
||||
s.Equal(int64(len(tc.content)), media.Size)
|
||||
s.Equal("1", media.CreatedBy)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestGet() {
|
||||
// Create test media
|
||||
media, err := s.client.Media.Create().
|
||||
SetStorageID("test-id").
|
||||
SetOriginalName("test.txt").
|
||||
SetMimeType("text/plain").
|
||||
SetSize(12).
|
||||
SetURL("/api/media/test-id").
|
||||
SetCreatedBy("1").
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Test get existing media
|
||||
result, err := s.svc.Get(s.ctx, media.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(media.ID, result.ID)
|
||||
s.Equal(media.OriginalName, result.OriginalName)
|
||||
|
||||
// Test get non-existing media
|
||||
_, err = s.svc.Get(s.ctx, -1)
|
||||
s.Require().Error(err)
|
||||
s.Contains(err.Error(), "media not found")
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestDelete() {
|
||||
// Create test media
|
||||
media, err := s.client.Media.Create().
|
||||
SetStorageID("test-id").
|
||||
SetOriginalName("test.txt").
|
||||
SetMimeType("text/plain").
|
||||
SetSize(12).
|
||||
SetURL("/api/media/test-id").
|
||||
SetCreatedBy("1").
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Test delete by unauthorized user
|
||||
err = s.svc.Delete(s.ctx, media.ID, 2)
|
||||
s.Require().Error(err)
|
||||
s.Contains(err.Error(), "unauthorized")
|
||||
|
||||
// Test delete by owner
|
||||
s.storage.EXPECT().
|
||||
Delete(gomock.Any(), "test-id").
|
||||
Return(nil)
|
||||
err = s.svc.Delete(s.ctx, media.ID, 1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Verify media is deleted
|
||||
_, err = s.svc.Get(s.ctx, media.ID)
|
||||
s.Require().Error(err)
|
||||
s.Contains(err.Error(), "not found")
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestList() {
|
||||
// Create test media
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err := s.client.Media.Create().
|
||||
SetStorageID(fmt.Sprintf("test-id-%d", i)).
|
||||
SetOriginalName(fmt.Sprintf("test-%d.txt", i)).
|
||||
SetMimeType("text/plain").
|
||||
SetSize(12).
|
||||
SetURL(fmt.Sprintf("/api/media/test-id-%d", i)).
|
||||
SetCreatedBy("1").
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
// Test list with limit and offset
|
||||
media, err := s.svc.List(s.ctx, 3, 1)
|
||||
s.Require().NoError(err)
|
||||
s.Len(media, 3)
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestGetFile() {
|
||||
// Create test media
|
||||
media, err := s.client.Media.Create().
|
||||
SetStorageID("test-id").
|
||||
SetOriginalName("test.txt").
|
||||
SetMimeType("text/plain").
|
||||
SetSize(12).
|
||||
SetURL("/api/media/test-id").
|
||||
SetCreatedBy("1").
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Mock storage.Get
|
||||
mockReader := io.NopCloser(bytes.NewReader([]byte("test content")))
|
||||
mockFileInfo := &storage.FileInfo{
|
||||
ID: "test-id",
|
||||
Name: "test.txt",
|
||||
ContentType: "text/plain",
|
||||
Size: 12,
|
||||
}
|
||||
s.storage.EXPECT().
|
||||
Get(gomock.Any(), "test-id").
|
||||
Return(mockReader, mockFileInfo, nil)
|
||||
|
||||
// Test get file
|
||||
reader, info, err := s.svc.GetFile(s.ctx, media.ID)
|
||||
s.Require().NoError(err)
|
||||
s.NotNil(reader)
|
||||
s.Equal(mockFileInfo, info)
|
||||
|
||||
// Test get non-existing file
|
||||
_, _, err = s.svc.GetFile(s.ctx, -1)
|
||||
s.Require().Error(err)
|
||||
s.Contains(err.Error(), "not found")
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestIsValidFilename() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
filename string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "Valid filename",
|
||||
filename: "test.txt",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid filename with ../",
|
||||
filename: "../test.txt",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid filename with ./",
|
||||
filename: "./test.txt",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid filename with backslash",
|
||||
filename: "test\\file.txt",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
got := isValidFilename(tc.filename)
|
||||
s.Equal(tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
3
backend/internal/service/mock/mock.go
Normal file
3
backend/internal/service/mock/mock.go
Normal file
|
@ -0,0 +1,3 @@
|
|||
package mock
|
||||
|
||||
//go:generate mockgen -source=../service.go -destination=mock_service.go -package=mock
|
105
backend/internal/service/rbac_service.go
Normal file
105
backend/internal/service/rbac_service.go
Normal file
|
@ -0,0 +1,105 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/ent/permission"
|
||||
"tss-rocks-be/ent/role"
|
||||
)
|
||||
|
||||
type RBACService struct {
|
||||
client *ent.Client
|
||||
}
|
||||
|
||||
func NewRBACService(client *ent.Client) *RBACService {
|
||||
return &RBACService{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// InitializeRBAC sets up the initial RBAC configuration
|
||||
func (s *RBACService) InitializeRBAC(ctx context.Context) error {
|
||||
// Create admin role if it doesn't exist
|
||||
adminRole, err := s.client.Role.Query().
|
||||
Where(role.Name("admin")).
|
||||
Only(ctx)
|
||||
if ent.IsNotFound(err) {
|
||||
adminRole, err = s.client.Role.Create().
|
||||
SetName("admin").
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create admin role: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to query admin role: %w", err)
|
||||
}
|
||||
|
||||
// Create editor role if it doesn't exist
|
||||
editorRole, err := s.client.Role.Query().
|
||||
Where(role.Name("editor")).
|
||||
Only(ctx)
|
||||
if ent.IsNotFound(err) {
|
||||
editorRole, err = s.client.Role.Create().
|
||||
SetName("editor").
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create editor role: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to query editor role: %w", err)
|
||||
}
|
||||
|
||||
// Define permissions
|
||||
permissions := []struct {
|
||||
role *ent.Role
|
||||
resource string
|
||||
actions []string
|
||||
}{
|
||||
{adminRole, "users", []string{"create", "read", "update", "delete", "assign_role"}},
|
||||
{adminRole, "roles", []string{"create", "read", "update", "delete"}},
|
||||
{adminRole, "media", []string{"create", "read", "update", "delete"}},
|
||||
{adminRole, "posts", []string{"create", "read", "update", "delete"}},
|
||||
{adminRole, "categories", []string{"create", "read", "update", "delete"}},
|
||||
{adminRole, "contributors", []string{"create", "read", "update", "delete"}},
|
||||
{adminRole, "dailies", []string{"create", "read", "update", "delete"}},
|
||||
|
||||
{editorRole, "media", []string{"create", "read", "update"}},
|
||||
{editorRole, "posts", []string{"create", "read", "update"}},
|
||||
{editorRole, "categories", []string{"read"}},
|
||||
{editorRole, "contributors", []string{"read"}},
|
||||
{editorRole, "dailies", []string{"create", "read", "update"}},
|
||||
}
|
||||
|
||||
// Create permissions for each role
|
||||
for _, p := range permissions {
|
||||
for _, action := range p.actions {
|
||||
// Check if permission already exists
|
||||
exists, err := s.client.Permission.Query().
|
||||
Where(
|
||||
permission.Resource(p.resource),
|
||||
permission.Action(action),
|
||||
permission.HasRolesWith(role.ID(p.role.ID)),
|
||||
).
|
||||
Exist(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query permission: %w", err)
|
||||
}
|
||||
|
||||
if !exists {
|
||||
// Create permission and associate it with the role
|
||||
_, err = s.client.Permission.Create().
|
||||
SetResource(p.resource).
|
||||
SetAction(action).
|
||||
AddRoles(p.role).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create permission: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
59
backend/internal/service/service.go
Normal file
59
backend/internal/service/service.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
package service
|
||||
|
||||
//go:generate mockgen -source=service.go -destination=mock/mock_service.go -package=mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/storage"
|
||||
)
|
||||
|
||||
// Service interface defines all business logic operations
|
||||
type Service interface {
|
||||
// User operations
|
||||
CreateUser(ctx context.Context, email, password string, role string) (*ent.User, error)
|
||||
GetUserByEmail(ctx context.Context, email string) (*ent.User, error)
|
||||
ValidatePassword(ctx context.Context, user *ent.User, password string) bool
|
||||
|
||||
// Category operations
|
||||
CreateCategory(ctx context.Context) (*ent.Category, error)
|
||||
AddCategoryContent(ctx context.Context, categoryID int, langCode, name, description, slug string) (*ent.CategoryContent, error)
|
||||
GetCategoryBySlug(ctx context.Context, langCode, slug string) (*ent.Category, error)
|
||||
ListCategories(ctx context.Context, langCode string) ([]*ent.Category, error)
|
||||
GetCategories(ctx context.Context, langCode string) ([]*ent.Category, error)
|
||||
|
||||
// Post operations
|
||||
CreatePost(ctx context.Context, status string) (*ent.Post, error)
|
||||
AddPostContent(ctx context.Context, postID int, langCode, title, content, summary string, metaKeywords, metaDescription string) (*ent.PostContent, error)
|
||||
GetPostBySlug(ctx context.Context, langCode, slug string) (*ent.Post, error)
|
||||
ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error)
|
||||
|
||||
// Contributor operations
|
||||
CreateContributor(ctx context.Context, name string, avatarURL, bio *string) (*ent.Contributor, error)
|
||||
AddContributorSocialLink(ctx context.Context, contributorID int, linkType, name, value string) (*ent.ContributorSocialLink, error)
|
||||
GetContributorByID(ctx context.Context, id int) (*ent.Contributor, error)
|
||||
ListContributors(ctx context.Context) ([]*ent.Contributor, error)
|
||||
|
||||
// Daily operations
|
||||
CreateDaily(ctx context.Context, id string, categoryID int, imageURL string) (*ent.Daily, error)
|
||||
AddDailyContent(ctx context.Context, dailyID string, langCode, quote string) (*ent.DailyContent, error)
|
||||
GetDailyByID(ctx context.Context, id string) (*ent.Daily, error)
|
||||
ListDailies(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Daily, error)
|
||||
|
||||
// Media operations
|
||||
ListMedia(ctx context.Context, limit, offset int) ([]*ent.Media, error)
|
||||
Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error)
|
||||
GetMedia(ctx context.Context, id int) (*ent.Media, error)
|
||||
GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error)
|
||||
DeleteMedia(ctx context.Context, id int, userID int) error
|
||||
|
||||
// RBAC operations
|
||||
AssignRole(ctx context.Context, userID int, role string) error
|
||||
RemoveRole(ctx context.Context, userID int, role string) error
|
||||
GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error)
|
||||
HasPermission(ctx context.Context, userID int, permission string) (bool, error)
|
||||
InitializeRBAC(ctx context.Context) error
|
||||
}
|
66
backend/internal/storage/factory.go
Normal file
66
backend/internal/storage/factory.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"tss-rocks-be/internal/config"
|
||||
)
|
||||
|
||||
// NewStorage creates a new storage instance based on the configuration
|
||||
func NewStorage(ctx context.Context, cfg *config.StorageConfig) (Storage, error) {
|
||||
switch cfg.Type {
|
||||
case "local":
|
||||
return NewLocalStorage(cfg.Local.RootDir)
|
||||
case "s3":
|
||||
// Load AWS configuration
|
||||
var s3Client *s3.Client
|
||||
|
||||
if cfg.S3.Endpoint != "" {
|
||||
// Custom endpoint (e.g., MinIO)
|
||||
customResolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
|
||||
return aws.Endpoint{
|
||||
URL: cfg.S3.Endpoint,
|
||||
}, nil
|
||||
})
|
||||
|
||||
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
|
||||
awsconfig.WithRegion(cfg.S3.Region),
|
||||
awsconfig.WithEndpointResolverWithOptions(customResolver),
|
||||
awsconfig.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
|
||||
cfg.S3.AccessKeyID,
|
||||
cfg.S3.SecretAccessKey,
|
||||
"",
|
||||
)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to load AWS SDK config: %w", err)
|
||||
}
|
||||
|
||||
s3Client = s3.NewFromConfig(awsCfg)
|
||||
} else {
|
||||
// Standard AWS S3
|
||||
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
|
||||
awsconfig.WithRegion(cfg.S3.Region),
|
||||
awsconfig.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
|
||||
cfg.S3.AccessKeyID,
|
||||
cfg.S3.SecretAccessKey,
|
||||
"",
|
||||
)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to load AWS SDK config: %w", err)
|
||||
}
|
||||
|
||||
s3Client = s3.NewFromConfig(awsCfg)
|
||||
}
|
||||
|
||||
return NewS3Storage(s3Client, cfg.S3.Bucket, cfg.S3.CustomURL, cfg.S3.ProxyS3), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported storage type: %s", cfg.Type)
|
||||
}
|
||||
}
|
260
backend/internal/storage/local.go
Normal file
260
backend/internal/storage/local.go
Normal file
|
@ -0,0 +1,260 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type LocalStorage struct {
|
||||
rootDir string
|
||||
metaDir string
|
||||
}
|
||||
|
||||
func NewLocalStorage(rootDir string) (*LocalStorage, error) {
|
||||
// Ensure the root directory exists
|
||||
if err := os.MkdirAll(rootDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create root directory: %w", err)
|
||||
}
|
||||
|
||||
// Create metadata directory
|
||||
metaDir := filepath.Join(rootDir, ".meta")
|
||||
if err := os.MkdirAll(metaDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create metadata directory: %w", err)
|
||||
}
|
||||
|
||||
return &LocalStorage{
|
||||
rootDir: rootDir,
|
||||
metaDir: metaDir,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *LocalStorage) generateID() (string, error) {
|
||||
bytes := make([]byte, 16)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
func (s *LocalStorage) saveMetadata(id string, info *FileInfo) error {
|
||||
metaPath := filepath.Join(s.metaDir, id+".meta")
|
||||
file, err := os.Create(metaPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create metadata file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
data := fmt.Sprintf("%s\n%s", info.Name, info.ContentType)
|
||||
if _, err := file.WriteString(data); err != nil {
|
||||
return fmt.Errorf("failed to write metadata: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *LocalStorage) loadMetadata(id string) (string, string, error) {
|
||||
metaPath := filepath.Join(s.metaDir, id+".meta")
|
||||
data, err := os.ReadFile(metaPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return id, "", nil // Return ID as name if metadata doesn't exist
|
||||
}
|
||||
return "", "", fmt.Errorf("failed to read metadata: %w", err)
|
||||
}
|
||||
|
||||
parts := bytes.Split(data, []byte("\n"))
|
||||
name := string(parts[0])
|
||||
contentType := ""
|
||||
if len(parts) > 1 {
|
||||
contentType = string(parts[1])
|
||||
}
|
||||
return name, contentType, nil
|
||||
}
|
||||
|
||||
func (s *LocalStorage) Save(ctx context.Context, name string, contentType string, reader io.Reader) (*FileInfo, error) {
|
||||
if reader == nil {
|
||||
return nil, fmt.Errorf("reader cannot be nil")
|
||||
}
|
||||
|
||||
// Generate a unique ID for the file
|
||||
id, err := s.generateID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate file ID: %w", err)
|
||||
}
|
||||
|
||||
// Create the file path
|
||||
filePath := filepath.Join(s.rootDir, id)
|
||||
|
||||
// Create the file
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Copy the content
|
||||
size, err := io.Copy(file, reader)
|
||||
if err != nil {
|
||||
// Clean up the file if there's an error
|
||||
os.Remove(filePath)
|
||||
return nil, fmt.Errorf("failed to write file content: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
info := &FileInfo{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Size: size,
|
||||
ContentType: contentType,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
URL: fmt.Sprintf("/api/media/file/%s", id),
|
||||
}
|
||||
|
||||
// Save metadata
|
||||
if err := s.saveMetadata(id, info); err != nil {
|
||||
os.Remove(filePath)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func (s *LocalStorage) Get(ctx context.Context, id string) (io.ReadCloser, *FileInfo, error) {
|
||||
filePath := filepath.Join(s.rootDir, id)
|
||||
|
||||
// Open the file
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil, fmt.Errorf("file not found: %s", id)
|
||||
}
|
||||
return nil, nil, fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
|
||||
// Get file info
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return nil, nil, fmt.Errorf("failed to get file info: %w", err)
|
||||
}
|
||||
|
||||
// Load metadata
|
||||
name, contentType, err := s.loadMetadata(id)
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
info := &FileInfo{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Size: stat.Size(),
|
||||
ContentType: contentType,
|
||||
CreatedAt: stat.ModTime(),
|
||||
UpdatedAt: stat.ModTime(),
|
||||
URL: fmt.Sprintf("/api/media/file/%s", id),
|
||||
}
|
||||
|
||||
return file, info, nil
|
||||
}
|
||||
|
||||
func (s *LocalStorage) Delete(ctx context.Context, id string) error {
|
||||
filePath := filepath.Join(s.rootDir, id)
|
||||
if err := os.Remove(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("file not found: %s", id)
|
||||
}
|
||||
return fmt.Errorf("failed to delete file: %w", err)
|
||||
}
|
||||
|
||||
// Remove metadata
|
||||
metaPath := filepath.Join(s.metaDir, id+".meta")
|
||||
if err := os.Remove(metaPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove metadata: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *LocalStorage) List(ctx context.Context, prefix string, limit int, offset int) ([]*FileInfo, error) {
|
||||
var files []*FileInfo
|
||||
var count int
|
||||
|
||||
err := filepath.Walk(s.rootDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip directories and metadata directory
|
||||
if info.IsDir() || path == s.metaDir {
|
||||
if path == s.metaDir {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the file ID (basename of the path)
|
||||
id := filepath.Base(path)
|
||||
|
||||
// Load metadata to get the original name
|
||||
name, contentType, err := s.loadMetadata(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip files that don't match the prefix
|
||||
if prefix != "" && !strings.HasPrefix(name, prefix) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Skip files before offset
|
||||
if count < offset {
|
||||
count++
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop if we've reached the limit
|
||||
if limit > 0 && len(files) >= limit {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
|
||||
files = append(files, &FileInfo{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Size: info.Size(),
|
||||
ContentType: contentType,
|
||||
CreatedAt: info.ModTime(),
|
||||
UpdatedAt: info.ModTime(),
|
||||
URL: fmt.Sprintf("/api/media/file/%s", id),
|
||||
})
|
||||
|
||||
count++
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list files: %w", err)
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (s *LocalStorage) Exists(ctx context.Context, id string) (bool, error) {
|
||||
filePath := filepath.Join(s.rootDir, id)
|
||||
_, err := os.Stat(filePath)
|
||||
if err == nil {
|
||||
return true, nil
|
||||
}
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("failed to check file existence: %w", err)
|
||||
}
|
154
backend/internal/storage/local_test.go
Normal file
154
backend/internal/storage/local_test.go
Normal file
|
@ -0,0 +1,154 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLocalStorage(t *testing.T) {
|
||||
// Create a temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "storage_test_*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create a new LocalStorage instance
|
||||
storage, err := NewLocalStorage(tempDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Save and Get", func(t *testing.T) {
|
||||
content := []byte("test content")
|
||||
reader := bytes.NewReader(content)
|
||||
|
||||
// Save the file
|
||||
fileInfo, err := storage.Save(ctx, "test.txt", "text/plain", reader)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, fileInfo.ID)
|
||||
assert.Equal(t, "test.txt", fileInfo.Name)
|
||||
assert.Equal(t, int64(len(content)), fileInfo.Size)
|
||||
assert.Equal(t, "text/plain", fileInfo.ContentType)
|
||||
assert.False(t, fileInfo.CreatedAt.IsZero())
|
||||
|
||||
// Get the file
|
||||
readCloser, info, err := storage.Get(ctx, fileInfo.ID)
|
||||
require.NoError(t, err)
|
||||
defer readCloser.Close()
|
||||
|
||||
data, err := io.ReadAll(readCloser)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, data)
|
||||
assert.Equal(t, fileInfo.ID, info.ID)
|
||||
assert.Equal(t, fileInfo.Name, info.Name)
|
||||
assert.Equal(t, fileInfo.Size, info.Size)
|
||||
})
|
||||
|
||||
t.Run("List", func(t *testing.T) {
|
||||
// Clear the directory first
|
||||
dirEntries, err := os.ReadDir(tempDir)
|
||||
require.NoError(t, err)
|
||||
for _, entry := range dirEntries {
|
||||
if entry.Name() != ".meta" {
|
||||
os.Remove(filepath.Join(tempDir, entry.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
// Save multiple files
|
||||
testFiles := []struct {
|
||||
name string
|
||||
content string
|
||||
}{
|
||||
{"test1.txt", "content1"},
|
||||
{"test2.txt", "content2"},
|
||||
{"other.txt", "content3"},
|
||||
}
|
||||
|
||||
for _, f := range testFiles {
|
||||
reader := bytes.NewReader([]byte(f.content))
|
||||
_, err := storage.Save(ctx, f.name, "text/plain", reader)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// List all files
|
||||
allFiles, err := storage.List(ctx, "", 10, 0)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, allFiles, 3)
|
||||
|
||||
// List files with prefix
|
||||
filesWithPrefix, err := storage.List(ctx, "test", 10, 0)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, filesWithPrefix, 2)
|
||||
for _, f := range filesWithPrefix {
|
||||
assert.True(t, strings.HasPrefix(f.Name, "test"))
|
||||
}
|
||||
|
||||
// Test pagination
|
||||
pagedFiles, err := storage.List(ctx, "", 2, 1)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pagedFiles, 2)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
// Save a file
|
||||
content := []byte("test content")
|
||||
reader := bytes.NewReader(content)
|
||||
fileInfo, err := storage.Save(ctx, "exists.txt", "text/plain", reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check if file exists
|
||||
exists, err := storage.Exists(ctx, fileInfo.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Check non-existent file
|
||||
exists, err = storage.Exists(ctx, "non-existent")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
// Save a file
|
||||
content := []byte("test content")
|
||||
reader := bytes.NewReader(content)
|
||||
fileInfo, err := storage.Save(ctx, "delete.txt", "text/plain", reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete the file
|
||||
err = storage.Delete(ctx, fileInfo.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify file is deleted
|
||||
exists, err := storage.Exists(ctx, fileInfo.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
// Try to delete non-existent file
|
||||
err = storage.Delete(ctx, "non-existent")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Invalid operations", func(t *testing.T) {
|
||||
// Try to get non-existent file
|
||||
_, _, err := storage.Get(ctx, "non-existent")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "file not found")
|
||||
|
||||
// Try to save file with nil reader
|
||||
_, err = storage.Save(ctx, "test.txt", "text/plain", nil)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "reader cannot be nil")
|
||||
|
||||
// Try to delete non-existent file
|
||||
err = storage.Delete(ctx, "non-existent")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "file not found")
|
||||
})
|
||||
}
|
232
backend/internal/storage/s3.go
Normal file
232
backend/internal/storage/s3.go
Normal file
|
@ -0,0 +1,232 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3/types"
|
||||
)
|
||||
|
||||
type S3Storage struct {
|
||||
client s3Client
|
||||
bucket string
|
||||
customURL string
|
||||
proxyS3 bool
|
||||
}
|
||||
|
||||
// s3Client is the interface that wraps the basic S3 client operations we need
|
||||
type s3Client interface {
|
||||
PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error)
|
||||
GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error)
|
||||
DeleteObject(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error)
|
||||
ListObjectsV2(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error)
|
||||
HeadObject(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error)
|
||||
}
|
||||
|
||||
func NewS3Storage(client s3Client, bucket string, customURL string, proxyS3 bool) *S3Storage {
|
||||
return &S3Storage{
|
||||
client: client,
|
||||
bucket: bucket,
|
||||
customURL: customURL,
|
||||
proxyS3: proxyS3,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *S3Storage) generateID() (string, error) {
|
||||
bytes := make([]byte, 16)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
func (s *S3Storage) getObjectURL(id string) string {
|
||||
if s.customURL != "" {
|
||||
return fmt.Sprintf("%s/%s", strings.TrimRight(s.customURL, "/"), id)
|
||||
}
|
||||
if s.proxyS3 {
|
||||
return fmt.Sprintf("/api/media/file/%s", id)
|
||||
}
|
||||
return fmt.Sprintf("https://%s.s3.amazonaws.com/%s", s.bucket, id)
|
||||
}
|
||||
|
||||
func (s *S3Storage) Save(ctx context.Context, name string, contentType string, reader io.Reader) (*FileInfo, error) {
|
||||
// Generate a unique ID for the file
|
||||
id, err := s.generateID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate file ID: %w", err)
|
||||
}
|
||||
|
||||
// Check if the file exists
|
||||
_, err = s.client.HeadObject(ctx, &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(id),
|
||||
})
|
||||
if err == nil {
|
||||
return nil, fmt.Errorf("file already exists with ID: %s", id)
|
||||
}
|
||||
|
||||
var noSuchKey *types.NoSuchKey
|
||||
if !errors.As(err, &noSuchKey) {
|
||||
return nil, fmt.Errorf("failed to check if file exists: %w", err)
|
||||
}
|
||||
|
||||
// Upload the file
|
||||
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(id),
|
||||
Body: reader,
|
||||
ContentType: aws.String(contentType),
|
||||
Metadata: map[string]string{
|
||||
"x-amz-meta-original-name": name,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to upload file: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
info := &FileInfo{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Size: 0, // Size is not available until after upload
|
||||
ContentType: contentType,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
URL: s.getObjectURL(id),
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func (s *S3Storage) Get(ctx context.Context, id string) (io.ReadCloser, *FileInfo, error) {
|
||||
// Get the object from S3
|
||||
result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(id),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get file from S3: %w", err)
|
||||
}
|
||||
|
||||
info := &FileInfo{
|
||||
ID: id,
|
||||
Name: result.Metadata["x-amz-meta-original-name"],
|
||||
Size: aws.ToInt64(result.ContentLength),
|
||||
ContentType: aws.ToString(result.ContentType),
|
||||
CreatedAt: aws.ToTime(result.LastModified),
|
||||
UpdatedAt: aws.ToTime(result.LastModified),
|
||||
URL: s.getObjectURL(id),
|
||||
}
|
||||
|
||||
return result.Body, info, nil
|
||||
}
|
||||
|
||||
func (s *S3Storage) Delete(ctx context.Context, id string) error {
|
||||
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(id),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete file from S3: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *S3Storage) List(ctx context.Context, prefix string, limit int, offset int) ([]*FileInfo, error) {
|
||||
var files []*FileInfo
|
||||
var continuationToken *string
|
||||
|
||||
// Skip objects for offset
|
||||
for i := 0; i < offset/1000; i++ {
|
||||
output, err := s.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Prefix: aws.String(prefix),
|
||||
ContinuationToken: continuationToken,
|
||||
MaxKeys: aws.Int32(1000),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list files from S3: %w", err)
|
||||
}
|
||||
if !aws.ToBool(output.IsTruncated) {
|
||||
return files, nil
|
||||
}
|
||||
continuationToken = output.NextContinuationToken
|
||||
}
|
||||
|
||||
// Get the actual objects
|
||||
output, err := s.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Prefix: aws.String(prefix),
|
||||
ContinuationToken: continuationToken,
|
||||
MaxKeys: aws.Int32(int32(limit)),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list files from S3: %w", err)
|
||||
}
|
||||
|
||||
for _, obj := range output.Contents {
|
||||
// Get the object metadata
|
||||
head, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: obj.Key,
|
||||
})
|
||||
|
||||
var contentType string
|
||||
var originalName string
|
||||
|
||||
if err != nil {
|
||||
var noSuchKey *types.NoSuchKey
|
||||
if errors.As(err, &noSuchKey) {
|
||||
// If the object doesn't exist (which shouldn't happen normally),
|
||||
// we'll still include it in the list but with empty metadata
|
||||
contentType = ""
|
||||
originalName = aws.ToString(obj.Key)
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
contentType = aws.ToString(head.ContentType)
|
||||
originalName = head.Metadata["x-amz-meta-original-name"]
|
||||
if originalName == "" {
|
||||
originalName = aws.ToString(obj.Key)
|
||||
}
|
||||
}
|
||||
|
||||
files = append(files, &FileInfo{
|
||||
ID: aws.ToString(obj.Key),
|
||||
Name: originalName,
|
||||
Size: aws.ToInt64(obj.Size),
|
||||
ContentType: contentType,
|
||||
CreatedAt: aws.ToTime(obj.LastModified),
|
||||
UpdatedAt: aws.ToTime(obj.LastModified),
|
||||
URL: s.getObjectURL(aws.ToString(obj.Key)),
|
||||
})
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (s *S3Storage) Exists(ctx context.Context, id string) (bool, error) {
|
||||
_, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(id),
|
||||
})
|
||||
if err != nil {
|
||||
var nsk *types.NoSuchKey
|
||||
if ok := errors.As(err, &nsk); ok {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("failed to check file existence in S3: %w", err)
|
||||
}
|
||||
return true, nil
|
||||
}
|
211
backend/internal/storage/s3_test.go
Normal file
211
backend/internal/storage/s3_test.go
Normal file
|
@ -0,0 +1,211 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// MockS3Client is a mock implementation of the S3 client interface
|
||||
type MockS3Client struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockS3Client) PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) {
|
||||
args := m.Called(ctx, params)
|
||||
return args.Get(0).(*s3.PutObjectOutput), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockS3Client) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
|
||||
args := m.Called(ctx, params)
|
||||
return args.Get(0).(*s3.GetObjectOutput), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockS3Client) DeleteObject(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error) {
|
||||
args := m.Called(ctx, params)
|
||||
return args.Get(0).(*s3.DeleteObjectOutput), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockS3Client) ListObjectsV2(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) {
|
||||
args := m.Called(ctx, params)
|
||||
return args.Get(0).(*s3.ListObjectsV2Output), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockS3Client) HeadObject(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) {
|
||||
args := m.Called(ctx, params)
|
||||
return args.Get(0).(*s3.HeadObjectOutput), args.Error(1)
|
||||
}
|
||||
|
||||
func TestS3Storage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockClient := new(MockS3Client)
|
||||
storage := NewS3Storage(mockClient, "test-bucket", "", false)
|
||||
|
||||
t.Run("Save", func(t *testing.T) {
|
||||
mockClient.ExpectedCalls = nil
|
||||
mockClient.Calls = nil
|
||||
|
||||
content := []byte("test content")
|
||||
reader := bytes.NewReader(content)
|
||||
|
||||
// Mock HeadObject to return NotFound error
|
||||
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket"
|
||||
})).Return(&s3.HeadObjectOutput{}, &types.NoSuchKey{
|
||||
Message: aws.String("The specified key does not exist."),
|
||||
})
|
||||
|
||||
mockClient.On("PutObject", ctx, mock.MatchedBy(func(input *s3.PutObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.ContentType) == "text/plain"
|
||||
})).Return(&s3.PutObjectOutput{}, nil)
|
||||
|
||||
fileInfo, err := storage.Save(ctx, "test.txt", "text/plain", reader)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, fileInfo.ID)
|
||||
assert.Equal(t, "test.txt", fileInfo.Name)
|
||||
assert.Equal(t, "text/plain", fileInfo.ContentType)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("Get", func(t *testing.T) {
|
||||
content := []byte("test content")
|
||||
mockClient.On("GetObject", ctx, mock.MatchedBy(func(input *s3.GetObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.Key) == "test-id"
|
||||
})).Return(&s3.GetObjectOutput{
|
||||
Body: io.NopCloser(bytes.NewReader(content)),
|
||||
ContentType: aws.String("text/plain"),
|
||||
ContentLength: aws.Int64(int64(len(content))),
|
||||
LastModified: aws.Time(time.Now()),
|
||||
}, nil)
|
||||
|
||||
readCloser, info, err := storage.Get(ctx, "test-id")
|
||||
require.NoError(t, err)
|
||||
defer readCloser.Close()
|
||||
|
||||
data, err := io.ReadAll(readCloser)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, data)
|
||||
assert.Equal(t, "test-id", info.ID)
|
||||
assert.Equal(t, int64(len(content)), info.Size)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("List", func(t *testing.T) {
|
||||
mockClient.ExpectedCalls = nil
|
||||
mockClient.Calls = nil
|
||||
|
||||
mockClient.On("ListObjectsV2", ctx, mock.MatchedBy(func(input *s3.ListObjectsV2Input) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.Prefix) == "test" &&
|
||||
aws.ToInt32(input.MaxKeys) == 10
|
||||
})).Return(&s3.ListObjectsV2Output{
|
||||
Contents: []types.Object{
|
||||
{
|
||||
Key: aws.String("test1"),
|
||||
Size: aws.Int64(100),
|
||||
LastModified: aws.Time(time.Now()),
|
||||
},
|
||||
{
|
||||
Key: aws.String("test2"),
|
||||
Size: aws.Int64(200),
|
||||
LastModified: aws.Time(time.Now()),
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
// Mock HeadObject for both files
|
||||
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.Key) == "test1"
|
||||
})).Return(&s3.HeadObjectOutput{
|
||||
ContentType: aws.String("text/plain"),
|
||||
Metadata: map[string]string{
|
||||
"x-amz-meta-original-name": "test1.txt",
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.Key) == "test2"
|
||||
})).Return(&s3.HeadObjectOutput{
|
||||
ContentType: aws.String("text/plain"),
|
||||
Metadata: map[string]string{
|
||||
"x-amz-meta-original-name": "test2.txt",
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
files, err := storage.List(ctx, "test", 10, 0)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, files, 2)
|
||||
assert.Equal(t, "test1", files[0].ID)
|
||||
assert.Equal(t, int64(100), files[0].Size)
|
||||
assert.Equal(t, "test1.txt", files[0].Name)
|
||||
assert.Equal(t, "text/plain", files[0].ContentType)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
mockClient.On("DeleteObject", ctx, mock.MatchedBy(func(input *s3.DeleteObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.Key) == "test-id"
|
||||
})).Return(&s3.DeleteObjectOutput{}, nil)
|
||||
|
||||
err := storage.Delete(ctx, "test-id")
|
||||
require.NoError(t, err)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
mockClient.ExpectedCalls = nil
|
||||
mockClient.Calls = nil
|
||||
|
||||
// Mock HeadObject for existing file
|
||||
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.Key) == "test-id"
|
||||
})).Return(&s3.HeadObjectOutput{}, nil).Once()
|
||||
|
||||
exists, err := storage.Exists(ctx, "test-id")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Mock HeadObject for non-existing file
|
||||
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.Key) == "non-existent"
|
||||
})).Return(&s3.HeadObjectOutput{}, &types.NoSuchKey{
|
||||
Message: aws.String("The specified key does not exist."),
|
||||
}).Once()
|
||||
|
||||
exists, err = storage.Exists(ctx, "non-existent")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("Custom URL", func(t *testing.T) {
|
||||
customStorage := &S3Storage{
|
||||
client: mockClient,
|
||||
bucket: "test-bucket",
|
||||
customURL: "https://custom.domain",
|
||||
proxyS3: true,
|
||||
}
|
||||
assert.Contains(t, customStorage.getObjectURL("test-id"), "https://custom.domain")
|
||||
})
|
||||
}
|
38
backend/internal/storage/storage.go
Normal file
38
backend/internal/storage/storage.go
Normal file
|
@ -0,0 +1,38 @@
|
|||
package storage
|
||||
|
||||
//go:generate mockgen -source=storage.go -destination=mock/mock_storage.go -package=mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FileInfo represents metadata about a stored file
|
||||
type FileInfo struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Size int64 `json:"size"`
|
||||
ContentType string `json:"content_type"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// Storage defines the interface for file storage operations
|
||||
type Storage interface {
|
||||
// Save stores a file and returns its FileInfo
|
||||
Save(ctx context.Context, name string, contentType string, reader io.Reader) (*FileInfo, error)
|
||||
|
||||
// Get retrieves a file by its ID
|
||||
Get(ctx context.Context, id string) (io.ReadCloser, *FileInfo, error)
|
||||
|
||||
// Delete removes a file by its ID
|
||||
Delete(ctx context.Context, id string) error
|
||||
|
||||
// List returns a list of files with optional prefix
|
||||
List(ctx context.Context, prefix string, limit int, offset int) ([]*FileInfo, error)
|
||||
|
||||
// Exists checks if a file exists
|
||||
Exists(ctx context.Context, id string) (bool, error)
|
||||
}
|
57
backend/internal/testutil/db.go
Normal file
57
backend/internal/testutil/db.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package testutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
"github.com/mattn/go-sqlite3"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"tss-rocks-be/ent"
|
||||
)
|
||||
|
||||
// SetupTestDB creates a new test database and returns a client
|
||||
func SetupTestDB(t *testing.T) *ent.Client {
|
||||
// Create a temporary SQLite database for testing
|
||||
dir := t.TempDir()
|
||||
dbPath := filepath.Join(dir, "test.db")
|
||||
|
||||
client, err := ent.Open(dialect.SQLite, "file:"+dbPath+"?mode=memory&cache=shared&_fk=1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Run the auto migration tool
|
||||
err = client.Schema.Create(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Clean up the database after the test
|
||||
t.Cleanup(func() {
|
||||
client.Close()
|
||||
os.Remove(dbPath)
|
||||
})
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
// CleanupTestDB removes all data from the test database
|
||||
func CleanupTestDB(t *testing.T, client *ent.Client) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Delete all data in reverse order of dependencies
|
||||
_, err := client.Permission.Delete().Exec(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.Role.Delete().Exec(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.User.Delete().Exec(ctx)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// IsSQLiteConstraintError checks if the error is a SQLite constraint error
|
||||
func IsSQLiteConstraintError(err error) bool {
|
||||
sqliteErr, ok := err.(sqlite3.Error)
|
||||
return ok && sqliteErr.Code == sqlite3.ErrConstraint
|
||||
}
|
32
backend/internal/testutil/mock.go
Normal file
32
backend/internal/testutil/mock.go
Normal file
|
@ -0,0 +1,32 @@
|
|||
package testutil
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// MockReadCloser is a mock implementation of io.ReadCloser
|
||||
type MockReadCloser struct {
|
||||
io.Reader
|
||||
CloseFunc func() error
|
||||
}
|
||||
|
||||
func (m MockReadCloser) Close() error {
|
||||
if m.CloseFunc != nil {
|
||||
return m.CloseFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewMockReadCloser creates a new MockReadCloser with the given content
|
||||
func NewMockReadCloser(content string) io.ReadCloser {
|
||||
return MockReadCloser{Reader: strings.NewReader(content)}
|
||||
}
|
||||
|
||||
// RequireMockEquals asserts that two mocks are equal
|
||||
func RequireMockEquals(t *testing.T, expected, actual interface{}) {
|
||||
require.Equal(t, expected, actual)
|
||||
}
|
70
backend/internal/testutil/testutil.go
Normal file
70
backend/internal/testutil/testutil.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
package testutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/ent/enttest"
|
||||
)
|
||||
|
||||
// SetupTestRouter returns a new Gin engine for testing
|
||||
func SetupTestRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
return gin.New()
|
||||
}
|
||||
|
||||
// NewTestClient creates a new ent client for testing
|
||||
func NewTestClient() *ent.Client {
|
||||
client := enttest.Open(testing.TB(nil), "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
return client
|
||||
}
|
||||
|
||||
// MakeTestRequest performs a test HTTP request and returns the response
|
||||
func MakeTestRequest(t *testing.T, router *gin.Engine, method, path string, body interface{}) *httptest.ResponseRecorder {
|
||||
var reqBody io.Reader
|
||||
if body != nil {
|
||||
jsonBytes, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
reqBody = bytes.NewBuffer(jsonBytes)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(method, path, reqBody)
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
return w
|
||||
}
|
||||
|
||||
// AssertResponse asserts the HTTP response status code and body
|
||||
func AssertResponse(t *testing.T, w *httptest.ResponseRecorder, expectedStatus int, expectedBody interface{}) {
|
||||
assert.Equal(t, expectedStatus, w.Code)
|
||||
|
||||
if expectedBody != nil {
|
||||
var actualBody interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &actualBody)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedBody, actualBody)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertErrorResponse asserts an error response with a specific message
|
||||
func AssertErrorResponse(t *testing.T, w *httptest.ResponseRecorder, expectedStatus int, expectedMessage string) {
|
||||
assert.Equal(t, expectedStatus, w.Code)
|
||||
|
||||
var response struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedMessage, response.Error)
|
||||
}
|
41
backend/internal/types/config.go
Normal file
41
backend/internal/types/config.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
package types
|
||||
|
||||
// RateLimitConfig 限流配置
|
||||
type RateLimitConfig struct {
|
||||
IPRate int `yaml:"ip_rate"` // IP限流速率
|
||||
IPBurst int `yaml:"ip_burst"` // IP突发请求数
|
||||
|
||||
RouteRates map[string]struct {
|
||||
Rate int `yaml:"rate"` // 路由限流速率
|
||||
Burst int `yaml:"burst"` // 路由突发请求数
|
||||
} `yaml:"route_rates"`
|
||||
}
|
||||
|
||||
// AccessLogConfig 访问日志配置
|
||||
type AccessLogConfig struct {
|
||||
// 是否启用控制台输出
|
||||
EnableConsole bool `yaml:"enable_console"`
|
||||
// 是否启用文件日志
|
||||
EnableFile bool `yaml:"enable_file"`
|
||||
// 日志文件路径
|
||||
FilePath string `yaml:"file_path"`
|
||||
// 日志格式 (json 或 text)
|
||||
Format string `yaml:"format"`
|
||||
// 日志级别
|
||||
Level string `yaml:"level"`
|
||||
// 日志轮转配置
|
||||
Rotation struct {
|
||||
MaxSize int `yaml:"max_size"` // 每个日志文件的最大大小(MB)
|
||||
MaxAge int `yaml:"max_age"` // 保留旧日志文件的最大天数
|
||||
MaxBackups int `yaml:"max_backups"` // 保留的旧日志文件的最大数量
|
||||
Compress bool `yaml:"compress"` // 是否压缩旧日志文件
|
||||
LocalTime bool `yaml:"local_time"` // 使用本地时间作为轮转时间
|
||||
} `yaml:"rotation"`
|
||||
}
|
||||
|
||||
// UploadConfig 文件上传配置
|
||||
type UploadConfig struct {
|
||||
MaxSize int `yaml:"max_size"` // 最大文件大小(MB)
|
||||
AllowedTypes []string `yaml:"allowed_types"` // 允许的MIME类型
|
||||
AllowedExtensions []string `yaml:"allowed_extensions"` // 允许的文件扩展名
|
||||
}
|
116
backend/internal/types/config_test.go
Normal file
116
backend/internal/types/config_test.go
Normal file
|
@ -0,0 +1,116 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRateLimitConfig(t *testing.T) {
|
||||
config := RateLimitConfig{
|
||||
IPRate: 100,
|
||||
IPBurst: 200,
|
||||
RouteRates: map[string]struct {
|
||||
Rate int `yaml:"rate"`
|
||||
Burst int `yaml:"burst"`
|
||||
}{
|
||||
"/api/test": {
|
||||
Rate: 50,
|
||||
Burst: 100,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if config.IPRate != 100 {
|
||||
t.Errorf("Expected IPRate 100, got %d", config.IPRate)
|
||||
}
|
||||
if config.IPBurst != 200 {
|
||||
t.Errorf("Expected IPBurst 200, got %d", config.IPBurst)
|
||||
}
|
||||
|
||||
route := config.RouteRates["/api/test"]
|
||||
if route.Rate != 50 {
|
||||
t.Errorf("Expected route rate 50, got %d", route.Rate)
|
||||
}
|
||||
if route.Burst != 100 {
|
||||
t.Errorf("Expected route burst 100, got %d", route.Burst)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogConfig(t *testing.T) {
|
||||
config := AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
EnableFile: true,
|
||||
FilePath: "/var/log/app.log",
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
Rotation: struct {
|
||||
MaxSize int `yaml:"max_size"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
Compress bool `yaml:"compress"`
|
||||
LocalTime bool `yaml:"local_time"`
|
||||
}{
|
||||
MaxSize: 100,
|
||||
MaxAge: 7,
|
||||
MaxBackups: 5,
|
||||
Compress: true,
|
||||
LocalTime: true,
|
||||
},
|
||||
}
|
||||
|
||||
if !config.EnableConsole {
|
||||
t.Error("Expected EnableConsole to be true")
|
||||
}
|
||||
if !config.EnableFile {
|
||||
t.Error("Expected EnableFile to be true")
|
||||
}
|
||||
if config.FilePath != "/var/log/app.log" {
|
||||
t.Errorf("Expected FilePath '/var/log/app.log', got '%s'", config.FilePath)
|
||||
}
|
||||
if config.Format != "json" {
|
||||
t.Errorf("Expected Format 'json', got '%s'", config.Format)
|
||||
}
|
||||
if config.Level != "info" {
|
||||
t.Errorf("Expected Level 'info', got '%s'", config.Level)
|
||||
}
|
||||
|
||||
rotation := config.Rotation
|
||||
if rotation.MaxSize != 100 {
|
||||
t.Errorf("Expected MaxSize 100, got %d", rotation.MaxSize)
|
||||
}
|
||||
if rotation.MaxAge != 7 {
|
||||
t.Errorf("Expected MaxAge 7, got %d", rotation.MaxAge)
|
||||
}
|
||||
if rotation.MaxBackups != 5 {
|
||||
t.Errorf("Expected MaxBackups 5, got %d", rotation.MaxBackups)
|
||||
}
|
||||
if !rotation.Compress {
|
||||
t.Error("Expected Compress to be true")
|
||||
}
|
||||
if !rotation.LocalTime {
|
||||
t.Error("Expected LocalTime to be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUploadConfig(t *testing.T) {
|
||||
config := UploadConfig{
|
||||
MaxSize: 10,
|
||||
AllowedTypes: []string{"image/jpeg", "image/png"},
|
||||
AllowedExtensions: []string{".jpg", ".png"},
|
||||
}
|
||||
|
||||
if config.MaxSize != 10 {
|
||||
t.Errorf("Expected MaxSize 10, got %d", config.MaxSize)
|
||||
}
|
||||
if len(config.AllowedTypes) != 2 {
|
||||
t.Errorf("Expected 2 AllowedTypes, got %d", len(config.AllowedTypes))
|
||||
}
|
||||
if config.AllowedTypes[0] != "image/jpeg" {
|
||||
t.Errorf("Expected AllowedTypes[0] 'image/jpeg', got '%s'", config.AllowedTypes[0])
|
||||
}
|
||||
if len(config.AllowedExtensions) != 2 {
|
||||
t.Errorf("Expected 2 AllowedExtensions, got %d", len(config.AllowedExtensions))
|
||||
}
|
||||
if config.AllowedExtensions[0] != ".jpg" {
|
||||
t.Errorf("Expected AllowedExtensions[0] '.jpg', got '%s'", config.AllowedExtensions[0])
|
||||
}
|
||||
}
|
8
backend/internal/types/file.go
Normal file
8
backend/internal/types/file.go
Normal file
|
@ -0,0 +1,8 @@
|
|||
package types
|
||||
|
||||
// FileInfo represents metadata about a file
|
||||
type FileInfo struct {
|
||||
Size int64
|
||||
Name string
|
||||
ContentType string
|
||||
}
|
21
backend/internal/types/file_test.go
Normal file
21
backend/internal/types/file_test.go
Normal file
|
@ -0,0 +1,21 @@
|
|||
package types
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestFileInfo(t *testing.T) {
|
||||
fileInfo := FileInfo{
|
||||
Size: 1024,
|
||||
Name: "test.jpg",
|
||||
ContentType: "image/jpeg",
|
||||
}
|
||||
|
||||
if fileInfo.Size != 1024 {
|
||||
t.Errorf("Expected Size 1024, got %d", fileInfo.Size)
|
||||
}
|
||||
if fileInfo.Name != "test.jpg" {
|
||||
t.Errorf("Expected Name 'test.jpg', got '%s'", fileInfo.Name)
|
||||
}
|
||||
if fileInfo.ContentType != "image/jpeg" {
|
||||
t.Errorf("Expected ContentType 'image/jpeg', got '%s'", fileInfo.ContentType)
|
||||
}
|
||||
}
|
43
backend/internal/types/types.go
Normal file
43
backend/internal/types/types.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package types
|
||||
|
||||
import "errors"
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
ErrNotFound = errors.New("not found")
|
||||
)
|
||||
|
||||
// Category represents a category in the system
|
||||
type Category struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
// Post represents a blog post
|
||||
type Post struct {
|
||||
ID int `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Slug string `json:"slug"`
|
||||
ContentMarkdown string `json:"content_markdown"`
|
||||
Summary string `json:"summary"`
|
||||
MetaKeywords *string `json:"meta_keywords,omitempty"`
|
||||
MetaDescription *string `json:"meta_description,omitempty"`
|
||||
}
|
||||
|
||||
// Contributor represents a contributor to the blog
|
||||
type Contributor struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
AvatarURL *string `json:"avatar_url,omitempty"`
|
||||
Bio *string `json:"bio,omitempty"`
|
||||
}
|
||||
|
||||
// Daily represents a daily quote or message
|
||||
type Daily struct {
|
||||
ID string `json:"id"`
|
||||
CategoryID int `json:"category_id"`
|
||||
ImageURL string `json:"image_url"`
|
||||
Quote string `json:"quote"`
|
||||
}
|
77
backend/internal/types/types_test.go
Normal file
77
backend/internal/types/types_test.go
Normal file
|
@ -0,0 +1,77 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCategory(t *testing.T) {
|
||||
description := "Test Description"
|
||||
category := Category{
|
||||
ID: 1,
|
||||
Name: "Test Category",
|
||||
Slug: "test-category",
|
||||
Description: &description,
|
||||
}
|
||||
|
||||
if category.ID != 1 {
|
||||
t.Errorf("Expected ID 1, got %d", category.ID)
|
||||
}
|
||||
if category.Name != "Test Category" {
|
||||
t.Errorf("Expected name 'Test Category', got '%s'", category.Name)
|
||||
}
|
||||
if category.Slug != "test-category" {
|
||||
t.Errorf("Expected slug 'test-category', got '%s'", category.Slug)
|
||||
}
|
||||
if *category.Description != description {
|
||||
t.Errorf("Expected description '%s', got '%s'", description, *category.Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPost(t *testing.T) {
|
||||
metaKeywords := "test,blog"
|
||||
metaDesc := "Test Description"
|
||||
post := Post{
|
||||
ID: 1,
|
||||
Title: "Test Post",
|
||||
Slug: "test-post",
|
||||
ContentMarkdown: "# Test Content",
|
||||
Summary: "Test Summary",
|
||||
MetaKeywords: &metaKeywords,
|
||||
MetaDescription: &metaDesc,
|
||||
}
|
||||
|
||||
if post.ID != 1 {
|
||||
t.Errorf("Expected ID 1, got %d", post.ID)
|
||||
}
|
||||
if post.Title != "Test Post" {
|
||||
t.Errorf("Expected title 'Test Post', got '%s'", post.Title)
|
||||
}
|
||||
if post.Slug != "test-post" {
|
||||
t.Errorf("Expected slug 'test-post', got '%s'", post.Slug)
|
||||
}
|
||||
if *post.MetaKeywords != metaKeywords {
|
||||
t.Errorf("Expected meta keywords '%s', got '%s'", metaKeywords, *post.MetaKeywords)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDaily(t *testing.T) {
|
||||
daily := Daily{
|
||||
ID: "2025-02-12",
|
||||
CategoryID: 1,
|
||||
ImageURL: "https://example.com/image.jpg",
|
||||
Quote: "Test Quote",
|
||||
}
|
||||
|
||||
if daily.ID != "2025-02-12" {
|
||||
t.Errorf("Expected ID '2025-02-12', got '%s'", daily.ID)
|
||||
}
|
||||
if daily.CategoryID != 1 {
|
||||
t.Errorf("Expected CategoryID 1, got %d", daily.CategoryID)
|
||||
}
|
||||
if daily.ImageURL != "https://example.com/image.jpg" {
|
||||
t.Errorf("Expected ImageURL 'https://example.com/image.jpg', got '%s'", daily.ImageURL)
|
||||
}
|
||||
if daily.Quote != "Test Quote" {
|
||||
t.Errorf("Expected Quote 'Test Quote', got '%s'", daily.Quote)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue