[chore/backend] remove all test for now
This commit is contained in:
parent
3d19ef05b3
commit
1c9628124f
28 changed files with 0 additions and 6780 deletions
|
@ -1,27 +0,0 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUserIDKey(t *testing.T) {
|
||||
// Test that the UserIDKey constant is defined correctly
|
||||
if UserIDKey != "user_id" {
|
||||
t.Errorf("UserIDKey = %v, want %v", UserIDKey, "user_id")
|
||||
}
|
||||
|
||||
// Test context with user ID
|
||||
ctx := context.WithValue(context.Background(), UserIDKey, "test-user-123")
|
||||
value := ctx.Value(UserIDKey)
|
||||
if value != "test-user-123" {
|
||||
t.Errorf("Context value = %v, want %v", value, "test-user-123")
|
||||
}
|
||||
|
||||
// Test context without user ID
|
||||
emptyCtx := context.Background()
|
||||
emptyValue := emptyCtx.Value(UserIDKey)
|
||||
if emptyValue != nil {
|
||||
t.Errorf("Empty context value = %v, want nil", emptyValue)
|
||||
}
|
||||
}
|
|
@ -1,85 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
// Create a temporary test config file
|
||||
content := []byte(`
|
||||
database:
|
||||
driver: postgres
|
||||
dsn: postgres://user:pass@localhost:5432/dbname
|
||||
server:
|
||||
port: 8080
|
||||
host: localhost
|
||||
jwt:
|
||||
secret: test-secret
|
||||
expiration: 24h
|
||||
storage:
|
||||
type: local
|
||||
local:
|
||||
root_dir: /tmp/storage
|
||||
upload:
|
||||
max_size: 10485760
|
||||
logging:
|
||||
level: info
|
||||
format: json
|
||||
`)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
if err := os.WriteFile(configPath, content, 0644); err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
// Test loading config
|
||||
cfg, err := Load(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify loaded values
|
||||
tests := []struct {
|
||||
name string
|
||||
got interface{}
|
||||
want interface{}
|
||||
errorMsg string
|
||||
}{
|
||||
{"Database Driver", cfg.Database.Driver, "postgres", "incorrect database driver"},
|
||||
{"Server Port", cfg.Server.Port, 8080, "incorrect server port"},
|
||||
{"JWT Secret", cfg.JWT.Secret, "test-secret", "incorrect JWT secret"},
|
||||
{"Storage Type", cfg.Storage.Type, "local", "incorrect storage type"},
|
||||
{"Logging Level", cfg.Logging.Level, "info", "incorrect logging level"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.got != tt.want {
|
||||
t.Errorf("%s = %v, want %v", tt.name, tt.got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadError(t *testing.T) {
|
||||
// Test loading non-existent file
|
||||
_, err := Load("non-existent-file.yaml")
|
||||
if err == nil {
|
||||
t.Error("Load() error = nil, want error for non-existent file")
|
||||
}
|
||||
|
||||
// Test loading invalid YAML
|
||||
tmpDir := t.TempDir()
|
||||
invalidPath := filepath.Join(tmpDir, "invalid.yaml")
|
||||
if err := os.WriteFile(invalidPath, []byte("invalid: }{yaml"), 0644); err != nil {
|
||||
t.Fatalf("Failed to write invalid config: %v", err)
|
||||
}
|
||||
|
||||
_, err = Load(invalidPath)
|
||||
if err == nil {
|
||||
t.Error("Load() error = nil, want error for invalid YAML")
|
||||
}
|
||||
}
|
|
@ -1,312 +0,0 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type AuthHandlerTestSuite struct {
|
||||
suite.Suite
|
||||
ctrl *gomock.Controller
|
||||
service *mock.MockService
|
||||
handler *Handler
|
||||
router *gin.Engine
|
||||
}
|
||||
|
||||
func (s *AuthHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
s.handler = NewHandler(&config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
},
|
||||
Auth: config.AuthConfig{
|
||||
Registration: struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Message string `yaml:"message"`
|
||||
}{
|
||||
Enabled: true,
|
||||
Message: "Registration is disabled",
|
||||
},
|
||||
},
|
||||
}, s.service)
|
||||
s.router = gin.New()
|
||||
}
|
||||
|
||||
func (s *AuthHandlerTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestAuthHandlerSuite(t *testing.T) {
|
||||
suite.Run(t, new(AuthHandlerTestSuite))
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Error struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func (s *AuthHandlerTestSuite) TestRegister() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
request RegisterRequest
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
registration bool
|
||||
}{
|
||||
{
|
||||
name: "成功注册",
|
||||
request: RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "password123",
|
||||
Role: "contributor",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
CreateUser(gomock.Any(), "testuser", "test@example.com", "password123", "contributor").
|
||||
Return(&ent.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
}, nil)
|
||||
s.service.EXPECT().
|
||||
GetUserRoles(gomock.Any(), 1).
|
||||
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
registration: true,
|
||||
},
|
||||
{
|
||||
name: "注册功能已禁用",
|
||||
request: RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "password123",
|
||||
Role: "contributor",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedError: "Registration is disabled",
|
||||
registration: false,
|
||||
},
|
||||
{
|
||||
name: "无效的邮箱格式",
|
||||
request: RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "invalid-email",
|
||||
Password: "password123",
|
||||
Role: "contributor",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Key: 'RegisterRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag",
|
||||
registration: true,
|
||||
},
|
||||
{
|
||||
name: "密码太短",
|
||||
request: RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "short",
|
||||
Role: "contributor",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Key: 'RegisterRequest.Password' Error:Field validation for 'Password' failed on the 'min' tag",
|
||||
registration: true,
|
||||
},
|
||||
{
|
||||
name: "无效的角色",
|
||||
request: RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "password123",
|
||||
Role: "invalid-role",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Key: 'RegisterRequest.Role' Error:Field validation for 'Role' failed on the 'oneof' tag",
|
||||
registration: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置注册功能状态
|
||||
s.handler.cfg.Auth.Registration.Enabled = tc.registration
|
||||
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
// 创建请求
|
||||
reqBody, _ := json.Marshal(tc.request)
|
||||
req, _ := http.NewRequest(http.MethodPost, "/register", bytes.NewBuffer(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// 执行请求
|
||||
s.handler.Register(c)
|
||||
|
||||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response ErrorResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Contains(response.Error.Message, tc.expectedError)
|
||||
} else {
|
||||
var response AuthResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.NotEmpty(response.Token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthHandlerTestSuite) TestLogin() {
|
||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
request LoginRequest
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "成功登录",
|
||||
request: LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "password123",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetUserByUsername(gomock.Any(), "testuser").
|
||||
Return(&ent.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
PasswordHash: string(hashedPassword),
|
||||
}, nil)
|
||||
s.service.EXPECT().
|
||||
GetUserRoles(gomock.Any(), 1).
|
||||
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "无效的用户名",
|
||||
request: LoginRequest{
|
||||
Username: "te",
|
||||
Password: "password123",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Key: 'LoginRequest.Username' Error:Field validation for 'Username' failed on the 'min' tag",
|
||||
},
|
||||
{
|
||||
name: "用户不存在",
|
||||
request: LoginRequest{
|
||||
Username: "nonexistent",
|
||||
Password: "password123",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetUserByUsername(gomock.Any(), "nonexistent").
|
||||
Return(nil, fmt.Errorf("user not found"))
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedError: "Invalid username or password",
|
||||
},
|
||||
{
|
||||
name: "密码错误",
|
||||
request: LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "wrongpassword",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetUserByUsername(gomock.Any(), "testuser").
|
||||
Return(&ent.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
PasswordHash: string(hashedPassword),
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedError: "Invalid username or password",
|
||||
},
|
||||
{
|
||||
name: "获取用户角色失败",
|
||||
request: LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "password123",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetUserByUsername(gomock.Any(), "testuser").
|
||||
Return(&ent.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
PasswordHash: string(hashedPassword),
|
||||
}, nil)
|
||||
s.service.EXPECT().
|
||||
GetUserRoles(gomock.Any(), 1).
|
||||
Return(nil, fmt.Errorf("failed to get roles"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to get user roles",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
// 创建请求
|
||||
reqBody, _ := json.Marshal(tc.request)
|
||||
req, _ := http.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// 执行请求
|
||||
s.handler.Login(c)
|
||||
|
||||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response ErrorResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Contains(response.Error.Message, tc.expectedError)
|
||||
} else {
|
||||
var response AuthResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.NotEmpty(response.Token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,481 +0,0 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/ent/categorycontent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
"tss-rocks-be/internal/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"errors"
|
||||
)
|
||||
|
||||
// Custom assertion function for comparing categories
|
||||
func assertCategoryEqual(t assert.TestingT, expected, actual *ent.Category) bool {
|
||||
if expected == nil && actual == nil {
|
||||
return true
|
||||
}
|
||||
if expected == nil || actual == nil {
|
||||
return assert.Fail(t, "One category is nil while the other is not")
|
||||
}
|
||||
|
||||
// Compare only relevant fields, ignoring time fields
|
||||
return assert.Equal(t, expected.ID, actual.ID) &&
|
||||
assert.Equal(t, expected.Edges.Contents, actual.Edges.Contents)
|
||||
}
|
||||
|
||||
// Custom assertion function for comparing category slices
|
||||
func assertCategorySliceEqual(t assert.TestingT, expected, actual []*ent.Category) bool {
|
||||
if len(expected) != len(actual) {
|
||||
return assert.Fail(t, "Category slice lengths do not match")
|
||||
}
|
||||
|
||||
for i := range expected {
|
||||
if !assertCategoryEqual(t, expected[i], actual[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type CategoryHandlerTestSuite struct {
|
||||
suite.Suite
|
||||
ctrl *gomock.Controller
|
||||
service *mock.MockService
|
||||
handler *Handler
|
||||
router *gin.Engine
|
||||
}
|
||||
|
||||
func (s *CategoryHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
},
|
||||
}
|
||||
s.handler = NewHandler(cfg, s.service)
|
||||
|
||||
// Setup Gin router
|
||||
gin.SetMode(gin.TestMode)
|
||||
s.router = gin.New()
|
||||
|
||||
// Setup mock for GetTokenBlacklist
|
||||
tokenBlacklist := &service.TokenBlacklist{}
|
||||
s.service.EXPECT().
|
||||
GetTokenBlacklist().
|
||||
Return(tokenBlacklist).
|
||||
AnyTimes()
|
||||
|
||||
s.handler.RegisterRoutes(s.router)
|
||||
}
|
||||
|
||||
func (s *CategoryHandlerTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestCategoryHandlerSuite(t *testing.T) {
|
||||
suite.Run(t, new(CategoryHandlerTestSuite))
|
||||
}
|
||||
|
||||
// Test cases for ListCategories
|
||||
func (s *CategoryHandlerTestSuite) TestListCategories() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
langCode string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success with default language",
|
||||
langCode: "",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListCategories(gomock.Any(), gomock.Eq("en")).
|
||||
Return([]*ent.Category{
|
||||
{
|
||||
ID: 1,
|
||||
Edges: ent.CategoryEdges{
|
||||
Contents: []*ent.CategoryContent{
|
||||
{
|
||||
LanguageCode: categorycontent.LanguageCode("en"),
|
||||
Name: "Test Category",
|
||||
Description: "Test Description",
|
||||
Slug: "test-category",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Category{
|
||||
{
|
||||
ID: 1,
|
||||
Edges: ent.CategoryEdges{
|
||||
Contents: []*ent.CategoryContent{
|
||||
{
|
||||
LanguageCode: categorycontent.LanguageCode("en"),
|
||||
Name: "Test Category",
|
||||
Description: "Test Description",
|
||||
Slug: "test-category",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with specific language",
|
||||
langCode: "zh",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListCategories(gomock.Any(), gomock.Eq("zh")).
|
||||
Return([]*ent.Category{
|
||||
{
|
||||
ID: 1,
|
||||
Edges: ent.CategoryEdges{
|
||||
Contents: []*ent.CategoryContent{
|
||||
{
|
||||
LanguageCode: categorycontent.LanguageCode("zh"),
|
||||
Name: "测试分类",
|
||||
Description: "测试描述",
|
||||
Slug: "test-category",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Category{
|
||||
{
|
||||
ID: 1,
|
||||
Edges: ent.CategoryEdges{
|
||||
Contents: []*ent.CategoryContent{
|
||||
{
|
||||
LanguageCode: categorycontent.LanguageCode("zh"),
|
||||
Name: "测试分类",
|
||||
Description: "测试描述",
|
||||
Slug: "test-category",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// Setup mock
|
||||
tc.setupMock()
|
||||
|
||||
// Create request
|
||||
url := "/api/v1/categories"
|
||||
if tc.langCode != "" {
|
||||
url += "?lang=" + tc.langCode
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Perform request
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
// Assert response
|
||||
assert.Equal(s.T(), tc.expectedStatus, w.Code)
|
||||
if tc.expectedBody != nil {
|
||||
var response []*ent.Category
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(s.T(), err)
|
||||
assertCategorySliceEqual(s.T(), tc.expectedBody.([]*ent.Category), response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases for GetCategory
|
||||
func (s *CategoryHandlerTestSuite) TestGetCategory() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
langCode string
|
||||
slug string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
langCode: "en",
|
||||
slug: "test-category",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetCategoryBySlug(gomock.Any(), gomock.Eq("en"), gomock.Eq("test-category")).
|
||||
Return(&ent.Category{
|
||||
ID: 1,
|
||||
Edges: ent.CategoryEdges{
|
||||
Contents: []*ent.CategoryContent{
|
||||
{
|
||||
LanguageCode: categorycontent.LanguageCode("en"),
|
||||
Name: "Test Category",
|
||||
Description: "Test Description",
|
||||
Slug: "test-category",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: &ent.Category{
|
||||
ID: 1,
|
||||
Edges: ent.CategoryEdges{
|
||||
Contents: []*ent.CategoryContent{
|
||||
{
|
||||
LanguageCode: categorycontent.LanguageCode("en"),
|
||||
Name: "Test Category",
|
||||
Description: "Test Description",
|
||||
Slug: "test-category",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Not Found",
|
||||
langCode: "en",
|
||||
slug: "non-existent",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetCategoryBySlug(gomock.Any(), gomock.Eq("en"), gomock.Eq("non-existent")).
|
||||
Return(nil, types.ErrNotFound)
|
||||
},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// Setup mock
|
||||
tc.setupMock()
|
||||
|
||||
// Create request
|
||||
url := "/api/v1/categories/" + tc.slug
|
||||
if tc.langCode != "" {
|
||||
url += "?lang=" + tc.langCode
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Perform request
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
// Assert response
|
||||
assert.Equal(s.T(), tc.expectedStatus, w.Code)
|
||||
if tc.expectedBody != nil {
|
||||
var response ent.Category
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
assert.NoError(s.T(), err)
|
||||
assertCategoryEqual(s.T(), tc.expectedBody.(*ent.Category), &response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases for AddCategoryContent
|
||||
func (s *CategoryHandlerTestSuite) TestAddCategoryContent() {
|
||||
var description = "Test Description"
|
||||
testCases := []struct {
|
||||
name string
|
||||
categoryID string
|
||||
requestBody interface{}
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
categoryID: "1",
|
||||
requestBody: AddCategoryContentRequest{
|
||||
LanguageCode: "en",
|
||||
Name: "Test Category",
|
||||
Description: &description,
|
||||
Slug: "test-category",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
AddCategoryContent(
|
||||
gomock.Any(),
|
||||
1,
|
||||
"en",
|
||||
"Test Category",
|
||||
description,
|
||||
"test-category",
|
||||
).
|
||||
Return(&ent.CategoryContent{
|
||||
LanguageCode: categorycontent.LanguageCode("en"),
|
||||
Name: "Test Category",
|
||||
Description: description,
|
||||
Slug: "test-category",
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: &ent.CategoryContent{
|
||||
LanguageCode: categorycontent.LanguageCode("en"),
|
||||
Name: "Test Category",
|
||||
Description: description,
|
||||
Slug: "test-category",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid JSON",
|
||||
categoryID: "1",
|
||||
requestBody: "invalid json",
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "Invalid Category ID",
|
||||
categoryID: "invalid",
|
||||
requestBody: AddCategoryContentRequest{
|
||||
LanguageCode: "en",
|
||||
Name: "Test Category",
|
||||
Description: &description,
|
||||
Slug: "test-category",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "Service Error",
|
||||
categoryID: "1",
|
||||
requestBody: AddCategoryContentRequest{
|
||||
LanguageCode: "en",
|
||||
Name: "Test Category",
|
||||
Description: &description,
|
||||
Slug: "test-category",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
AddCategoryContent(
|
||||
gomock.Any(),
|
||||
1,
|
||||
"en",
|
||||
"Test Category",
|
||||
description,
|
||||
"test-category",
|
||||
).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// Setup mock
|
||||
tc.setupMock()
|
||||
|
||||
// Create request
|
||||
var body []byte
|
||||
var err error
|
||||
if str, ok := tc.requestBody.(string); ok {
|
||||
body = []byte(str)
|
||||
} else {
|
||||
body, err = json.Marshal(tc.requestBody)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/categories/"+tc.categoryID+"/contents", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Perform request
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
// Assert response
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedBody != nil {
|
||||
var response ent.CategoryContent
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Equal(tc.expectedBody, &response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases for CreateCategory
|
||||
func (s *CategoryHandlerTestSuite) TestCreateCategory() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "成功创建分类",
|
||||
setupMock: func() {
|
||||
category := &ent.Category{
|
||||
ID: 1,
|
||||
}
|
||||
s.service.EXPECT().
|
||||
CreateCategory(gomock.Any()).
|
||||
Return(category, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "创建分类失败",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
CreateCategory(gomock.Any()).
|
||||
Return(nil, errors.New("failed to create category"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to create category",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
// 创建请求
|
||||
req, _ := http.NewRequest(http.MethodPost, "/categories", nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// 执行请求
|
||||
s.handler.CreateCategory(c)
|
||||
|
||||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Equal(tc.expectedError, response["error"])
|
||||
} else {
|
||||
var response *ent.Category
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.NotNil(response)
|
||||
s.Equal(1, response.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,456 +0,0 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"errors"
|
||||
)
|
||||
|
||||
type ContributorHandlerTestSuite struct {
|
||||
suite.Suite
|
||||
ctrl *gomock.Controller
|
||||
service *mock.MockService
|
||||
handler *Handler
|
||||
router *gin.Engine
|
||||
}
|
||||
|
||||
func (s *ContributorHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
},
|
||||
}
|
||||
s.handler = NewHandler(cfg, s.service)
|
||||
|
||||
// Setup Gin router
|
||||
gin.SetMode(gin.TestMode)
|
||||
s.router = gin.New()
|
||||
|
||||
// Setup mock for GetTokenBlacklist
|
||||
tokenBlacklist := &service.TokenBlacklist{}
|
||||
s.service.EXPECT().
|
||||
GetTokenBlacklist().
|
||||
Return(tokenBlacklist).
|
||||
AnyTimes()
|
||||
|
||||
s.handler.RegisterRoutes(s.router)
|
||||
}
|
||||
|
||||
func (s *ContributorHandlerTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestContributorHandlerSuite(t *testing.T) {
|
||||
suite.Run(t, new(ContributorHandlerTestSuite))
|
||||
}
|
||||
|
||||
func (s *ContributorHandlerTestSuite) TestListContributors() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListContributors(gomock.Any()).
|
||||
Return([]*ent.Contributor{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "John Doe",
|
||||
Edges: ent.ContributorEdges{
|
||||
SocialLinks: []*ent.ContributorSocialLink{
|
||||
{
|
||||
Type: "github",
|
||||
Value: "https://github.com/johndoe",
|
||||
Edges: ent.ContributorSocialLinkEdges{},
|
||||
},
|
||||
},
|
||||
},
|
||||
CreatedAt: time.Time{},
|
||||
UpdatedAt: time.Time{},
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Name: "Jane Smith",
|
||||
Edges: ent.ContributorEdges{
|
||||
SocialLinks: []*ent.ContributorSocialLink{}, // Ensure empty SocialLinks array is present
|
||||
},
|
||||
CreatedAt: time.Time{},
|
||||
UpdatedAt: time.Time{},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []gin.H{
|
||||
{
|
||||
"id": 1,
|
||||
"name": "John Doe",
|
||||
"created_at": time.Time{},
|
||||
"updated_at": time.Time{},
|
||||
"edges": gin.H{
|
||||
"social_links": []gin.H{
|
||||
{
|
||||
"type": "github",
|
||||
"value": "https://github.com/johndoe",
|
||||
"edges": gin.H{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"name": "Jane Smith",
|
||||
"created_at": time.Time{},
|
||||
"updated_at": time.Time{},
|
||||
"edges": gin.H{
|
||||
"social_links": []gin.H{}, // Ensure empty SocialLinks array is present
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListContributors(gomock.Any()).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to list contributors"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/contributors", nil)
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ContributorHandlerTestSuite) TestGetContributor() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
id string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
id: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetContributorByID(gomock.Any(), 1).
|
||||
Return(&ent.Contributor{
|
||||
ID: 1,
|
||||
Name: "John Doe",
|
||||
Edges: ent.ContributorEdges{
|
||||
SocialLinks: []*ent.ContributorSocialLink{
|
||||
{
|
||||
Type: "github",
|
||||
Value: "https://github.com/johndoe",
|
||||
Edges: ent.ContributorSocialLinkEdges{},
|
||||
},
|
||||
},
|
||||
},
|
||||
CreatedAt: time.Time{},
|
||||
UpdatedAt: time.Time{},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: gin.H{
|
||||
"id": 1,
|
||||
"name": "John Doe",
|
||||
"created_at": time.Time{},
|
||||
"updated_at": time.Time{},
|
||||
"edges": gin.H{
|
||||
"social_links": []gin.H{
|
||||
{
|
||||
"type": "github",
|
||||
"value": "https://github.com/johndoe",
|
||||
"edges": gin.H{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid ID",
|
||||
id: "invalid",
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Invalid contributor ID"},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
id: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetContributorByID(gomock.Any(), 1).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to get contributor"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/contributors/"+tc.id, nil)
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ContributorHandlerTestSuite) TestCreateContributor() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
body interface{}
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
body: CreateContributorRequest{
|
||||
Name: "John Doe",
|
||||
},
|
||||
setupMock: func() {
|
||||
name := "John Doe"
|
||||
s.service.EXPECT().
|
||||
CreateContributor(
|
||||
gomock.Any(),
|
||||
name,
|
||||
nil,
|
||||
nil,
|
||||
).
|
||||
Return(&ent.Contributor{
|
||||
ID: 1,
|
||||
Name: name,
|
||||
CreatedAt: time.Time{},
|
||||
UpdatedAt: time.Time{},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: gin.H{
|
||||
"id": 1,
|
||||
"name": "John Doe",
|
||||
"created_at": time.Time{},
|
||||
"updated_at": time.Time{},
|
||||
"edges": gin.H{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid request body",
|
||||
body: map[string]interface{}{
|
||||
"name": "", // Empty name is not allowed
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Key: 'CreateContributorRequest.Name' Error:Field validation for 'Name' failed on the 'required' tag"},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
body: CreateContributorRequest{
|
||||
Name: "John Doe",
|
||||
},
|
||||
setupMock: func() {
|
||||
name := "John Doe"
|
||||
s.service.EXPECT().
|
||||
CreateContributor(
|
||||
gomock.Any(),
|
||||
name,
|
||||
nil,
|
||||
nil,
|
||||
).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to create contributor"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
body, err := json.Marshal(tc.body)
|
||||
s.NoError(err, "Failed to marshal request body")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/contributors", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ContributorHandlerTestSuite) TestAddContributorSocialLink() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
id string
|
||||
body interface{}
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
id: "1",
|
||||
body: func() AddContributorSocialLinkRequest {
|
||||
name := "johndoe"
|
||||
return AddContributorSocialLinkRequest{
|
||||
Type: "github",
|
||||
Name: &name,
|
||||
Value: "https://github.com/johndoe",
|
||||
}
|
||||
}(),
|
||||
setupMock: func() {
|
||||
name := "johndoe"
|
||||
s.service.EXPECT().
|
||||
AddContributorSocialLink(
|
||||
gomock.Any(),
|
||||
1,
|
||||
"github",
|
||||
name,
|
||||
"https://github.com/johndoe",
|
||||
).
|
||||
Return(&ent.ContributorSocialLink{
|
||||
Type: "github",
|
||||
Name: name,
|
||||
Value: "https://github.com/johndoe",
|
||||
Edges: ent.ContributorSocialLinkEdges{},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: gin.H{
|
||||
"type": "github",
|
||||
"name": "johndoe",
|
||||
"value": "https://github.com/johndoe",
|
||||
"edges": gin.H{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid contributor ID",
|
||||
id: "invalid",
|
||||
body: func() AddContributorSocialLinkRequest {
|
||||
name := "johndoe"
|
||||
return AddContributorSocialLinkRequest{
|
||||
Type: "github",
|
||||
Name: &name,
|
||||
Value: "https://github.com/johndoe",
|
||||
}
|
||||
}(),
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Invalid contributor ID"},
|
||||
},
|
||||
{
|
||||
name: "Invalid request body",
|
||||
id: "1",
|
||||
body: map[string]interface{}{
|
||||
"type": "", // Empty type is not allowed
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Key: 'AddContributorSocialLinkRequest.Type' Error:Field validation for 'Type' failed on the 'required' tag\nKey: 'AddContributorSocialLinkRequest.Value' Error:Field validation for 'Value' failed on the 'required' tag"},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
id: "1",
|
||||
body: func() AddContributorSocialLinkRequest {
|
||||
name := "johndoe"
|
||||
return AddContributorSocialLinkRequest{
|
||||
Type: "github",
|
||||
Name: &name,
|
||||
Value: "https://github.com/johndoe",
|
||||
}
|
||||
}(),
|
||||
setupMock: func() {
|
||||
name := "johndoe"
|
||||
s.service.EXPECT().
|
||||
AddContributorSocialLink(
|
||||
gomock.Any(),
|
||||
1,
|
||||
"github",
|
||||
name,
|
||||
"https://github.com/johndoe",
|
||||
).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to add contributor social link"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
body, err := json.Marshal(tc.body)
|
||||
s.NoError(err, "Failed to marshal request body")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/contributors/"+tc.id+"/social-links", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,532 +0,0 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type DailyHandlerTestSuite struct {
|
||||
suite.Suite
|
||||
ctrl *gomock.Controller
|
||||
service *mock.MockService
|
||||
handler *Handler
|
||||
router *gin.Engine
|
||||
}
|
||||
|
||||
func (s *DailyHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
},
|
||||
}
|
||||
s.handler = NewHandler(cfg, s.service)
|
||||
|
||||
// Setup Gin router
|
||||
gin.SetMode(gin.TestMode)
|
||||
s.router = gin.New()
|
||||
|
||||
// Setup mock for GetTokenBlacklist
|
||||
tokenBlacklist := &service.TokenBlacklist{}
|
||||
s.service.EXPECT().
|
||||
GetTokenBlacklist().
|
||||
Return(tokenBlacklist).
|
||||
AnyTimes()
|
||||
|
||||
s.handler.RegisterRoutes(s.router)
|
||||
}
|
||||
|
||||
func (s *DailyHandlerTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestDailyHandlerSuite(t *testing.T) {
|
||||
suite.Run(t, new(DailyHandlerTestSuite))
|
||||
}
|
||||
|
||||
func (s *DailyHandlerTestSuite) TestListDailies() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
langCode string
|
||||
categoryID string
|
||||
limit string
|
||||
offset string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success with default language",
|
||||
langCode: "",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListDailies(gomock.Any(), "en", nil, 10, 0).
|
||||
Return([]*ent.Daily{
|
||||
{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Daily{
|
||||
{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with specific language",
|
||||
langCode: "zh",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListDailies(gomock.Any(), "zh", nil, 10, 0).
|
||||
Return([]*ent.Daily{
|
||||
{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "zh",
|
||||
Quote: "测试语录1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Daily{
|
||||
{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "zh",
|
||||
Quote: "测试语录1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with category filter",
|
||||
categoryID: "1",
|
||||
setupMock: func() {
|
||||
categoryID := 1
|
||||
s.service.EXPECT().
|
||||
ListDailies(gomock.Any(), "en", &categoryID, 10, 0).
|
||||
Return([]*ent.Daily{
|
||||
{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Daily{
|
||||
{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with pagination",
|
||||
limit: "2",
|
||||
offset: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListDailies(gomock.Any(), "en", nil, 2, 1).
|
||||
Return([]*ent.Daily{
|
||||
{
|
||||
ID: "daily2",
|
||||
ImageURL: "https://example.com/image2.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Daily{
|
||||
{
|
||||
ID: "daily2",
|
||||
ImageURL: "https://example.com/image2.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Service Error",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListDailies(gomock.Any(), "en", nil, 10, 0).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to list dailies"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
url := "/api/v1/dailies"
|
||||
if tc.langCode != "" {
|
||||
url += "?lang=" + tc.langCode
|
||||
}
|
||||
if tc.categoryID != "" {
|
||||
if strings.Contains(url, "?") {
|
||||
url += "&"
|
||||
} else {
|
||||
url += "?"
|
||||
}
|
||||
url += "category_id=" + tc.categoryID
|
||||
}
|
||||
if tc.limit != "" {
|
||||
if strings.Contains(url, "?") {
|
||||
url += "&"
|
||||
} else {
|
||||
url += "?"
|
||||
}
|
||||
url += "limit=" + tc.limit
|
||||
}
|
||||
if tc.offset != "" {
|
||||
if strings.Contains(url, "?") {
|
||||
url += "&"
|
||||
} else {
|
||||
url += "?"
|
||||
}
|
||||
url += "offset=" + tc.offset
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DailyHandlerTestSuite) TestGetDaily() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
id string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
id: "daily1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetDailyByID(gomock.Any(), "daily1").
|
||||
Return(&ent.Daily{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: &ent.Daily{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
id: "daily1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetDailyByID(gomock.Any(), "daily1").
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to get daily"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/dailies/"+tc.id, nil)
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DailyHandlerTestSuite) TestCreateDaily() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
body interface{}
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
body: CreateDailyRequest{
|
||||
ID: "daily1",
|
||||
CategoryID: 1,
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
CreateDaily(gomock.Any(), "daily1", 1, "https://example.com/image1.jpg").
|
||||
Return(&ent.Daily{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: &ent.Daily{
|
||||
ID: "daily1",
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
Edges: ent.DailyEdges{
|
||||
Category: &ent.Category{ID: 1},
|
||||
Contents: []*ent.DailyContent{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid request body",
|
||||
body: map[string]interface{}{
|
||||
"id": "daily1",
|
||||
// Missing required fields
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Key: 'CreateDailyRequest.CategoryID' Error:Field validation for 'CategoryID' failed on the 'required' tag\nKey: 'CreateDailyRequest.ImageURL' Error:Field validation for 'ImageURL' failed on the 'required' tag"},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
body: CreateDailyRequest{
|
||||
ID: "daily1",
|
||||
CategoryID: 1,
|
||||
ImageURL: "https://example.com/image1.jpg",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
CreateDaily(gomock.Any(), "daily1", 1, "https://example.com/image1.jpg").
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to create daily"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
body, err := json.Marshal(tc.body)
|
||||
s.NoError(err, "Failed to marshal request body")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/dailies", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DailyHandlerTestSuite) TestAddDailyContent() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
dailyID string
|
||||
body interface{}
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
dailyID: "daily1",
|
||||
body: AddDailyContentRequest{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
AddDailyContent(gomock.Any(), "daily1", "en", "Test Quote 1").
|
||||
Return(&ent.DailyContent{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: &ent.DailyContent{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid request body",
|
||||
dailyID: "daily1",
|
||||
body: map[string]interface{}{
|
||||
"language_code": "en",
|
||||
// Missing required fields
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Key: 'AddDailyContentRequest.Quote' Error:Field validation for 'Quote' failed on the 'required' tag"},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
dailyID: "daily1",
|
||||
body: AddDailyContentRequest{
|
||||
LanguageCode: "en",
|
||||
Quote: "Test Quote 1",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
AddDailyContent(gomock.Any(), "daily1", "en", "Test Quote 1").
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to add daily content"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
body, err := json.Marshal(tc.body)
|
||||
s.NoError(err, "Failed to marshal request body")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/dailies/"+tc.dailyID+"/contents", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,43 +0,0 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStringPtr(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input *string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "nil pointer",
|
||||
input: nil,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: strPtr(""),
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "non-empty string",
|
||||
input: strPtr("test"),
|
||||
expected: "test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := stringPtr(tc.input)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create string pointer
|
||||
func strPtr(s string) *string {
|
||||
return &s
|
||||
}
|
|
@ -1,524 +0,0 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
"tss-rocks-be/internal/storage"
|
||||
|
||||
"net/textproto"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
type MediaHandlerTestSuite struct {
|
||||
suite.Suite
|
||||
ctrl *gomock.Controller
|
||||
service *mock.MockService
|
||||
handler *Handler
|
||||
router *gin.Engine
|
||||
}
|
||||
|
||||
func (s *MediaHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
s.handler = NewHandler(&config.Config{}, s.service)
|
||||
s.router = gin.New()
|
||||
}
|
||||
|
||||
func (s *MediaHandlerTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestMediaHandlerSuite(t *testing.T) {
|
||||
suite.Run(t, new(MediaHandlerTestSuite))
|
||||
}
|
||||
|
||||
func (s *MediaHandlerTestSuite) TestListMedia() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
query string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "成功列出媒体",
|
||||
query: "?limit=10&offset=0",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListMedia(gomock.Any(), 10, 0).
|
||||
Return([]*ent.Media{{ID: 1}}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "使用默认限制和偏移",
|
||||
query: "",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListMedia(gomock.Any(), 10, 0).
|
||||
Return([]*ent.Media{{ID: 1}}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "列出媒体失败",
|
||||
query: "",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListMedia(gomock.Any(), 10, 0).
|
||||
Return(nil, errors.New("failed to list media"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to list media",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
// 创建请求
|
||||
req, _ := http.NewRequest(http.MethodGet, "/media"+tc.query, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// 执行请求
|
||||
s.handler.ListMedia(c)
|
||||
|
||||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Equal(tc.expectedError, response["error"])
|
||||
} else {
|
||||
var response []*ent.Media
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.NotEmpty(response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MediaHandlerTestSuite) TestUploadMedia() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupRequest func() (*http.Request, error)
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "成功上传媒体",
|
||||
setupRequest: func() (*http.Request, error) {
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
// 创建文件部分
|
||||
fileHeader := make(textproto.MIMEHeader)
|
||||
fileHeader.Set("Content-Type", "image/jpeg")
|
||||
fileHeader.Set("Content-Disposition", `form-data; name="file"; filename="test.jpg"`)
|
||||
part, err := writer.CreatePart(fileHeader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
testContent := "test content"
|
||||
_, err = io.Copy(part, strings.NewReader(testContent))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writer.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/media", body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return req, nil
|
||||
},
|
||||
setupMock: func() {
|
||||
expectedFile := &multipart.FileHeader{
|
||||
Filename: "test.jpg",
|
||||
Size: int64(len("test content")),
|
||||
Header: textproto.MIMEHeader{
|
||||
"Content-Type": []string{"image/jpeg"},
|
||||
},
|
||||
}
|
||||
s.service.EXPECT().
|
||||
Upload(gomock.Any(), gomock.Any(), 1).
|
||||
DoAndReturn(func(_ context.Context, f *multipart.FileHeader, uid int) (*ent.Media, error) {
|
||||
s.Equal(expectedFile.Filename, f.Filename)
|
||||
s.Equal(expectedFile.Size, f.Size)
|
||||
s.Equal(expectedFile.Header.Get("Content-Type"), f.Header.Get("Content-Type"))
|
||||
return &ent.Media{ID: 1}, nil
|
||||
})
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "未授权",
|
||||
setupRequest: func() (*http.Request, error) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/media", nil)
|
||||
return req, nil
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedError: "Unauthorized",
|
||||
},
|
||||
{
|
||||
name: "上传失败",
|
||||
setupRequest: func() (*http.Request, error) {
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
// 创建文件部分
|
||||
fileHeader := make(textproto.MIMEHeader)
|
||||
fileHeader.Set("Content-Type", "image/jpeg")
|
||||
fileHeader.Set("Content-Disposition", `form-data; name="file"; filename="test.jpg"`)
|
||||
part, err := writer.CreatePart(fileHeader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
testContent := "test content"
|
||||
_, err = io.Copy(part, strings.NewReader(testContent))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writer.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/media", body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return req, nil
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
Upload(gomock.Any(), gomock.Any(), 1).
|
||||
Return(nil, errors.New("failed to upload"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to upload media",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
// 创建请求
|
||||
req, err := tc.setupRequest()
|
||||
s.NoError(err)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// 设置用户ID(除了未授权的测试用例)
|
||||
if tc.expectedError != "Unauthorized" {
|
||||
c.Set("user_id", 1)
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
s.handler.UploadMedia(c)
|
||||
|
||||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Equal(tc.expectedError, response["error"])
|
||||
} else {
|
||||
var response *ent.Media
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.NotNil(response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MediaHandlerTestSuite) TestGetMedia() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
mediaID string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "成功获取媒体",
|
||||
mediaID: "1",
|
||||
setupMock: func() {
|
||||
media := &ent.Media{
|
||||
ID: 1,
|
||||
MimeType: "image/jpeg",
|
||||
OriginalName: "test.jpg",
|
||||
}
|
||||
s.service.EXPECT().
|
||||
GetMedia(gomock.Any(), 1).
|
||||
Return(media, nil)
|
||||
s.service.EXPECT().
|
||||
GetFile(gomock.Any(), 1).
|
||||
Return(io.NopCloser(strings.NewReader("test content")), &storage.FileInfo{
|
||||
Size: 11,
|
||||
Name: "test.jpg",
|
||||
ContentType: "image/jpeg",
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "无效的媒体ID",
|
||||
mediaID: "invalid",
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Invalid media ID",
|
||||
},
|
||||
{
|
||||
name: "获取媒体元数据失败",
|
||||
mediaID: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetMedia(gomock.Any(), 1).
|
||||
Return(nil, errors.New("failed to get media"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to get media",
|
||||
},
|
||||
{
|
||||
name: "获取媒体文件失败",
|
||||
mediaID: "1",
|
||||
setupMock: func() {
|
||||
media := &ent.Media{
|
||||
ID: 1,
|
||||
MimeType: "image/jpeg",
|
||||
OriginalName: "test.jpg",
|
||||
}
|
||||
s.service.EXPECT().
|
||||
GetMedia(gomock.Any(), 1).
|
||||
Return(media, nil)
|
||||
s.service.EXPECT().
|
||||
GetFile(gomock.Any(), 1).
|
||||
Return(nil, nil, errors.New("failed to get file"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to get media file",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
// 创建请求
|
||||
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/media/%s", tc.mediaID), nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// Extract ID from URL path
|
||||
parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/")
|
||||
if len(parts) >= 2 {
|
||||
c.Params = []gin.Param{{Key: "id", Value: parts[1]}}
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
s.handler.GetMedia(c)
|
||||
|
||||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Equal(tc.expectedError, response["error"])
|
||||
} else {
|
||||
s.Equal("image/jpeg", w.Header().Get("Content-Type"))
|
||||
s.Equal("11", w.Header().Get("Content-Length"))
|
||||
s.Equal("inline; filename=test.jpg", w.Header().Get("Content-Disposition"))
|
||||
s.Equal("test content", w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MediaHandlerTestSuite) TestGetMediaFile() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupRequest func() (*http.Request, error)
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody []byte
|
||||
}{
|
||||
{
|
||||
name: "成功获取媒体文件",
|
||||
setupRequest: func() (*http.Request, error) {
|
||||
return httptest.NewRequest(http.MethodGet, "/media/1/file", nil), nil
|
||||
},
|
||||
setupMock: func() {
|
||||
fileContent := "test file content"
|
||||
s.service.EXPECT().
|
||||
GetFile(gomock.Any(), 1).
|
||||
Return(io.NopCloser(strings.NewReader(fileContent)), &storage.FileInfo{
|
||||
Name: "test.jpg",
|
||||
Size: int64(len(fileContent)),
|
||||
ContentType: "image/jpeg",
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []byte("test file content"),
|
||||
},
|
||||
{
|
||||
name: "无效的媒体ID",
|
||||
setupRequest: func() (*http.Request, error) {
|
||||
return httptest.NewRequest(http.MethodGet, "/media/invalid/file", nil), nil
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "获取媒体文件失败",
|
||||
setupRequest: func() (*http.Request, error) {
|
||||
return httptest.NewRequest(http.MethodGet, "/media/1/file", nil), nil
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetFile(gomock.Any(), 1).
|
||||
Return(nil, nil, errors.New("failed to get file"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// Setup
|
||||
req, err := tc.setupRequest()
|
||||
s.Require().NoError(err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// Extract ID from URL path
|
||||
parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/")
|
||||
if len(parts) >= 2 {
|
||||
c.Params = []gin.Param{{Key: "id", Value: parts[1]}}
|
||||
}
|
||||
|
||||
// Setup mock
|
||||
tc.setupMock()
|
||||
|
||||
// Test
|
||||
s.handler.GetMediaFile(c)
|
||||
|
||||
// Verify
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedBody != nil {
|
||||
s.Equal(tc.expectedBody, w.Body.Bytes())
|
||||
s.Equal("image/jpeg", w.Header().Get("Content-Type"))
|
||||
s.Equal(fmt.Sprintf("%d", len(tc.expectedBody)), w.Header().Get("Content-Length"))
|
||||
s.Equal("inline; filename=test.jpg", w.Header().Get("Content-Disposition"))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MediaHandlerTestSuite) TestDeleteMedia() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
mediaID string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "成功删除媒体",
|
||||
mediaID: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
DeleteMedia(gomock.Any(), 1, 1).
|
||||
Return(nil)
|
||||
},
|
||||
expectedStatus: http.StatusNoContent,
|
||||
},
|
||||
{
|
||||
name: "未授权",
|
||||
mediaID: "1",
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedError: "Unauthorized",
|
||||
},
|
||||
{
|
||||
name: "无效的媒体ID",
|
||||
mediaID: "invalid",
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Invalid media ID",
|
||||
},
|
||||
{
|
||||
name: "删除媒体失败",
|
||||
mediaID: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
DeleteMedia(gomock.Any(), 1, 1).
|
||||
Return(errors.New("failed to delete"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to delete media",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
// 创建请求
|
||||
req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/media/%s", tc.mediaID), nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// Extract ID from URL path
|
||||
parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/")
|
||||
if len(parts) >= 2 {
|
||||
c.Params = []gin.Param{{Key: "id", Value: parts[1]}}
|
||||
}
|
||||
|
||||
// 设置用户ID(除了未授权的测试用例)
|
||||
if tc.expectedError != "Unauthorized" {
|
||||
c.Set("user_id", 1)
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
s.handler.DeleteMedia(c)
|
||||
|
||||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Equal(tc.expectedError, response["error"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,624 +0,0 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type PostHandlerTestSuite struct {
|
||||
suite.Suite
|
||||
ctrl *gomock.Controller
|
||||
service *mock.MockService
|
||||
handler *Handler
|
||||
router *gin.Engine
|
||||
}
|
||||
|
||||
func (s *PostHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
},
|
||||
}
|
||||
s.handler = NewHandler(cfg, s.service)
|
||||
|
||||
// Setup Gin router
|
||||
gin.SetMode(gin.TestMode)
|
||||
s.router = gin.New()
|
||||
|
||||
// Setup mock for GetTokenBlacklist
|
||||
tokenBlacklist := &service.TokenBlacklist{}
|
||||
s.service.EXPECT().
|
||||
GetTokenBlacklist().
|
||||
Return(tokenBlacklist).
|
||||
AnyTimes()
|
||||
|
||||
s.handler.RegisterRoutes(s.router)
|
||||
}
|
||||
|
||||
func (s *PostHandlerTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestPostHandlerSuite(t *testing.T) {
|
||||
suite.Run(t, new(PostHandlerTestSuite))
|
||||
}
|
||||
|
||||
// Test cases for ListPosts
|
||||
func (s *PostHandlerTestSuite) TestListPosts() {
|
||||
categoryID := 1
|
||||
testCases := []struct {
|
||||
name string
|
||||
langCode string
|
||||
categoryID string
|
||||
limit string
|
||||
offset string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success with default language",
|
||||
langCode: "",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListPosts(gomock.Any(), "en", nil, 10, 0).
|
||||
Return([]*ent.Post{
|
||||
{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Post{
|
||||
{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with specific language",
|
||||
langCode: "zh",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListPosts(gomock.Any(), "zh", nil, 10, 0).
|
||||
Return([]*ent.Post{
|
||||
{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "zh",
|
||||
Title: "测试帖子",
|
||||
ContentMarkdown: "测试内容",
|
||||
Summary: "测试摘要",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Post{
|
||||
{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "zh",
|
||||
Title: "测试帖子",
|
||||
ContentMarkdown: "测试内容",
|
||||
Summary: "测试摘要",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with category filter",
|
||||
langCode: "en",
|
||||
categoryID: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListPosts(gomock.Any(), "en", &categoryID, 10, 0).
|
||||
Return([]*ent.Post{
|
||||
{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Post{
|
||||
{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with pagination",
|
||||
langCode: "en",
|
||||
limit: "2",
|
||||
offset: "1",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListPosts(gomock.Any(), "en", nil, 2, 1).
|
||||
Return([]*ent.Post{
|
||||
{
|
||||
ID: 2,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post 2",
|
||||
ContentMarkdown: "Test Content 2",
|
||||
Summary: "Test Summary 2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []*ent.Post{
|
||||
{
|
||||
ID: 2,
|
||||
Status: "published",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post 2",
|
||||
ContentMarkdown: "Test Content 2",
|
||||
Summary: "Test Summary 2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Service Error",
|
||||
langCode: "en",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
ListPosts(gomock.Any(), "en", nil, 10, 0).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// Setup mock
|
||||
tc.setupMock()
|
||||
|
||||
// Create request
|
||||
url := "/api/v1/posts"
|
||||
if tc.langCode != "" {
|
||||
url += "?lang=" + tc.langCode
|
||||
}
|
||||
if tc.categoryID != "" {
|
||||
if strings.Contains(url, "?") {
|
||||
url += "&"
|
||||
} else {
|
||||
url += "?"
|
||||
}
|
||||
url += "category_id=" + tc.categoryID
|
||||
}
|
||||
if tc.limit != "" {
|
||||
if strings.Contains(url, "?") {
|
||||
url += "&"
|
||||
} else {
|
||||
url += "?"
|
||||
}
|
||||
url += "limit=" + tc.limit
|
||||
}
|
||||
if tc.offset != "" {
|
||||
if strings.Contains(url, "?") {
|
||||
url += "&"
|
||||
} else {
|
||||
url += "?"
|
||||
}
|
||||
url += "offset=" + tc.offset
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Perform request
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
// Assert response
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedBody != nil {
|
||||
var response []*ent.Post
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Equal(tc.expectedBody, response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases for GetPost
|
||||
func (s *PostHandlerTestSuite) TestGetPost() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
langCode string
|
||||
slug string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success with default language",
|
||||
langCode: "",
|
||||
slug: "test-post",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetPostBySlug(gomock.Any(), "en", "test-post").
|
||||
Return(&ent.Post{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Slug: "test-post",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: gin.H{
|
||||
"id": 1,
|
||||
"status": "published",
|
||||
"slug": "test-post",
|
||||
"edges": gin.H{
|
||||
"contents": []gin.H{
|
||||
{
|
||||
"language_code": "en",
|
||||
"title": "Test Post",
|
||||
"content_markdown": "Test Content",
|
||||
"summary": "Test Summary",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success with specific language",
|
||||
langCode: "zh",
|
||||
slug: "test-post",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetPostBySlug(gomock.Any(), "zh", "test-post").
|
||||
Return(&ent.Post{
|
||||
ID: 1,
|
||||
Status: "published",
|
||||
Slug: "test-post",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{
|
||||
{
|
||||
LanguageCode: "zh",
|
||||
Title: "测试帖子",
|
||||
ContentMarkdown: "测试内容",
|
||||
Summary: "测试摘要",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: gin.H{
|
||||
"id": 1,
|
||||
"status": "published",
|
||||
"slug": "test-post",
|
||||
"edges": gin.H{
|
||||
"contents": []gin.H{
|
||||
{
|
||||
"language_code": "zh",
|
||||
"title": "测试帖子",
|
||||
"content_markdown": "测试内容",
|
||||
"summary": "测试摘要",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
slug: "test-post",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetPostBySlug(gomock.Any(), "en", "test-post").
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to get post"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
url := "/api/v1/posts/" + tc.slug
|
||||
if tc.langCode != "" {
|
||||
url += "?lang=" + tc.langCode
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases for CreatePost
|
||||
func (s *PostHandlerTestSuite) TestCreatePost() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
CreatePost(gomock.Any(), "draft").
|
||||
Return(&ent.Post{
|
||||
ID: 1,
|
||||
Status: "draft",
|
||||
Edges: ent.PostEdges{
|
||||
Contents: []*ent.PostContent{},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: gin.H{
|
||||
"id": 1,
|
||||
"status": "draft",
|
||||
"edges": gin.H{
|
||||
"contents": []gin.H{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
CreatePost(gomock.Any(), "draft").
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to create post"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/posts", nil)
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test cases for AddPostContent
|
||||
func (s *PostHandlerTestSuite) TestAddPostContent() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
postID string
|
||||
body interface{}
|
||||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedBody interface{}
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
postID: "1",
|
||||
body: AddPostContentRequest{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
MetaKeywords: "test,keywords",
|
||||
MetaDescription: "Test meta description",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
AddPostContent(
|
||||
gomock.Any(),
|
||||
1,
|
||||
"en",
|
||||
"Test Post",
|
||||
"Test Content",
|
||||
"Test Summary",
|
||||
"test,keywords",
|
||||
"Test meta description",
|
||||
).
|
||||
Return(&ent.PostContent{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
MetaKeywords: "test,keywords",
|
||||
MetaDescription: "Test meta description",
|
||||
Edges: ent.PostContentEdges{},
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: gin.H{
|
||||
"language_code": "en",
|
||||
"title": "Test Post",
|
||||
"content_markdown": "Test Content",
|
||||
"summary": "Test Summary",
|
||||
"meta_keywords": "test,keywords",
|
||||
"meta_description": "Test meta description",
|
||||
"edges": gin.H{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid post ID",
|
||||
postID: "invalid",
|
||||
body: AddPostContentRequest{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Invalid post ID"},
|
||||
},
|
||||
{
|
||||
name: "Invalid request body",
|
||||
postID: "1",
|
||||
body: map[string]interface{}{
|
||||
"language_code": "en",
|
||||
// Missing required fields
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: gin.H{"error": "Key: 'AddPostContentRequest.Title' Error:Field validation for 'Title' failed on the 'required' tag\nKey: 'AddPostContentRequest.ContentMarkdown' Error:Field validation for 'ContentMarkdown' failed on the 'required' tag\nKey: 'AddPostContentRequest.Summary' Error:Field validation for 'Summary' failed on the 'required' tag"},
|
||||
},
|
||||
{
|
||||
name: "Service error",
|
||||
postID: "1",
|
||||
body: AddPostContentRequest{
|
||||
LanguageCode: "en",
|
||||
Title: "Test Post",
|
||||
ContentMarkdown: "Test Content",
|
||||
Summary: "Test Summary",
|
||||
MetaKeywords: "test,keywords",
|
||||
MetaDescription: "Test meta description",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
AddPostContent(
|
||||
gomock.Any(),
|
||||
1,
|
||||
"en",
|
||||
"Test Post",
|
||||
"Test Content",
|
||||
"Test Summary",
|
||||
"test,keywords",
|
||||
"Test meta description",
|
||||
).
|
||||
Return(nil, errors.New("service error"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: gin.H{"error": "Failed to add post content"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
tc.setupMock()
|
||||
|
||||
body, err := json.Marshal(tc.body)
|
||||
s.NoError(err, "Failed to marshal request body")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/posts/"+tc.postID+"/contents", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
s.router.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
expectedJSON, err := json.Marshal(tc.expectedBody)
|
||||
s.NoError(err, "Failed to marshal expected body")
|
||||
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,238 +0,0 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
"tss-rocks-be/internal/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAccessLog(t *testing.T) {
|
||||
// 设置测试临时目录
|
||||
tmpDir := t.TempDir()
|
||||
logPath := filepath.Join(tmpDir, "test.log")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
config *types.AccessLogConfig
|
||||
expectedError bool
|
||||
setupRequest func(*http.Request)
|
||||
validateOutput func(*testing.T, *httptest.ResponseRecorder, string)
|
||||
}{
|
||||
{
|
||||
name: "Console logging only",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
EnableFile: false,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
},
|
||||
expectedError: false,
|
||||
setupRequest: func(req *http.Request) {
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
},
|
||||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, logOutput, "GET /test")
|
||||
assert.Contains(t, logOutput, "test-agent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "File logging only",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: false,
|
||||
EnableFile: true,
|
||||
FilePath: logPath,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
Rotation: struct {
|
||||
MaxSize int `yaml:"max_size"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
Compress bool `yaml:"compress"`
|
||||
LocalTime bool `yaml:"local_time"`
|
||||
}{
|
||||
MaxSize: 1,
|
||||
MaxAge: 1,
|
||||
MaxBackups: 1,
|
||||
Compress: false,
|
||||
LocalTime: true,
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
setupRequest: func(req *http.Request) {
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
},
|
||||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// 读取日志文件内容
|
||||
content, err := os.ReadFile(logPath)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(content), "GET /test")
|
||||
assert.Contains(t, string(content), "test-agent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Both console and file logging",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
EnableFile: true,
|
||||
FilePath: logPath,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
Rotation: struct {
|
||||
MaxSize int `yaml:"max_size"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
Compress bool `yaml:"compress"`
|
||||
LocalTime bool `yaml:"local_time"`
|
||||
}{
|
||||
MaxSize: 1,
|
||||
MaxAge: 1,
|
||||
MaxBackups: 1,
|
||||
Compress: false,
|
||||
LocalTime: true,
|
||||
},
|
||||
},
|
||||
expectedError: false,
|
||||
setupRequest: func(req *http.Request) {
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
},
|
||||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, logOutput, "GET /test")
|
||||
assert.Contains(t, logOutput, "test-agent")
|
||||
|
||||
// 读取日志文件内容
|
||||
content, err := os.ReadFile(logPath)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(content), "GET /test")
|
||||
assert.Contains(t, string(content), "test-agent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "With authenticated user",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
EnableFile: false,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
},
|
||||
expectedError: false,
|
||||
setupRequest: func(req *http.Request) {
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
},
|
||||
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, logOutput, "GET /test")
|
||||
assert.Contains(t, logOutput, "test-agent")
|
||||
assert.Contains(t, logOutput, "test-user")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// 捕获标准输出
|
||||
oldStdout := os.Stdout
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stdout = w
|
||||
|
||||
// 创建一个新的 gin 引擎
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 创建访问日志中间件
|
||||
middleware, err := AccessLog(tc.config)
|
||||
if tc.expectedError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 添加测试路由
|
||||
router.Use(middleware)
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
// 如果是测试认证用户的情况,设置用户ID
|
||||
if tc.name == "With authenticated user" {
|
||||
c.Set("user_id", "test-user")
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// 创建测试请求
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
if tc.setupRequest != nil {
|
||||
tc.setupRequest(req)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// 恢复标准输出并获取输出内容
|
||||
w.Close()
|
||||
var buf bytes.Buffer
|
||||
io.Copy(&buf, r)
|
||||
os.Stdout = oldStdout
|
||||
|
||||
// 验证输出
|
||||
if tc.validateOutput != nil {
|
||||
tc.validateOutput(t, rec, buf.String())
|
||||
}
|
||||
|
||||
// 关闭日志文件
|
||||
if tc.config.EnableFile {
|
||||
// 调用中间件函数来关闭日志文件
|
||||
middleware(nil)
|
||||
// 等待一小段时间确保文件完全关闭
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogInvalidConfig(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
config *types.AccessLogConfig
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "Invalid log level",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
Level: "invalid_level",
|
||||
},
|
||||
expectedError: false, // 应该使用默认的 info 级别
|
||||
},
|
||||
{
|
||||
name: "Invalid file path",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableFile: true,
|
||||
FilePath: "\x00invalid\x00path", // 使用空字符的路径在所有操作系统上都是无效的
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := AccessLog(tc.config)
|
||||
if tc.expectedError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,227 +0,0 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tss-rocks-be/internal/service"
|
||||
)
|
||||
|
||||
func createTestToken(secret string, claims jwt.MapClaims) string {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signedToken, err := token.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to sign token: %v", err))
|
||||
}
|
||||
return signedToken
|
||||
}
|
||||
|
||||
func TestAuthMiddleware(t *testing.T) {
|
||||
jwtSecret := "test-secret"
|
||||
tokenBlacklist := service.NewTokenBlacklist()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupAuth func(*http.Request)
|
||||
expectedStatus int
|
||||
expectedBody map[string]string
|
||||
checkUserData bool
|
||||
expectedUserID string
|
||||
expectedRoles []string
|
||||
}{
|
||||
{
|
||||
name: "No Authorization header",
|
||||
setupAuth: func(req *http.Request) {},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "Authorization header is required"},
|
||||
},
|
||||
{
|
||||
name: "Invalid Authorization format",
|
||||
setupAuth: func(req *http.Request) {
|
||||
req.Header.Set("Authorization", "InvalidFormat")
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "Authorization header format must be Bearer {token}"},
|
||||
},
|
||||
{
|
||||
name: "Invalid token",
|
||||
setupAuth: func(req *http.Request) {
|
||||
req.Header.Set("Authorization", "Bearer invalid.token.here")
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "Invalid token"},
|
||||
},
|
||||
{
|
||||
name: "Valid token",
|
||||
setupAuth: func(req *http.Request) {
|
||||
claims := jwt.MapClaims{
|
||||
"sub": "123",
|
||||
"roles": []string{"admin", "editor"},
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
}
|
||||
token := createTestToken(jwtSecret, claims)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
checkUserData: true,
|
||||
expectedUserID: "123",
|
||||
expectedRoles: []string{"admin", "editor"},
|
||||
},
|
||||
{
|
||||
name: "Expired token",
|
||||
setupAuth: func(req *http.Request) {
|
||||
claims := jwt.MapClaims{
|
||||
"sub": "123",
|
||||
"roles": []string{"user"},
|
||||
"exp": time.Now().Add(-time.Hour).Unix(),
|
||||
}
|
||||
token := createTestToken(jwtSecret, claims)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "Invalid token"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 添加认证中间件
|
||||
router.Use(func(c *gin.Context) {
|
||||
// 设置日志级别为 debug
|
||||
gin.SetMode(gin.DebugMode)
|
||||
c.Next()
|
||||
}, AuthMiddleware(jwtSecret, tokenBlacklist))
|
||||
|
||||
// 测试路由
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
if tc.checkUserData {
|
||||
userID, exists := c.Get("user_id")
|
||||
assert.True(t, exists, "user_id should exist in context")
|
||||
assert.Equal(t, tc.expectedUserID, userID, "user_id should match")
|
||||
|
||||
roles, exists := c.Get("user_roles")
|
||||
assert.True(t, exists, "user_roles should exist in context")
|
||||
assert.Equal(t, tc.expectedRoles, roles, "user_roles should match")
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// 创建请求
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
tc.setupAuth(req)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// 验证响应
|
||||
assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
assert.NoError(t, err, "Response body should be valid JSON")
|
||||
assert.Equal(t, tc.expectedBody, response, "Response body should match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoleMiddleware(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupContext func(*gin.Context)
|
||||
allowedRoles []string
|
||||
expectedStatus int
|
||||
expectedBody map[string]string
|
||||
}{
|
||||
{
|
||||
name: "No user roles",
|
||||
setupContext: func(c *gin.Context) {
|
||||
// 不设置用户角色
|
||||
},
|
||||
allowedRoles: []string{"admin"},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: map[string]string{"error": "User roles not found"},
|
||||
},
|
||||
{
|
||||
name: "Invalid roles type",
|
||||
setupContext: func(c *gin.Context) {
|
||||
c.Set("user_roles", 123) // 设置错误类型的角色
|
||||
},
|
||||
allowedRoles: []string{"admin"},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: map[string]string{"error": "Invalid user roles type"},
|
||||
},
|
||||
{
|
||||
name: "Insufficient permissions",
|
||||
setupContext: func(c *gin.Context) {
|
||||
c.Set("user_roles", []string{"user"})
|
||||
},
|
||||
allowedRoles: []string{"admin"},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedBody: map[string]string{"error": "Insufficient permissions"},
|
||||
},
|
||||
{
|
||||
name: "Allowed role",
|
||||
setupContext: func(c *gin.Context) {
|
||||
c.Set("user_roles", []string{"admin"})
|
||||
},
|
||||
allowedRoles: []string{"admin"},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "One of multiple allowed roles",
|
||||
setupContext: func(c *gin.Context) {
|
||||
c.Set("user_roles", []string{"user", "editor"})
|
||||
},
|
||||
allowedRoles: []string{"admin", "editor", "moderator"},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 添加角色中间件
|
||||
router.Use(func(c *gin.Context) {
|
||||
tc.setupContext(c)
|
||||
c.Next()
|
||||
}, RoleMiddleware(tc.allowedRoles...))
|
||||
|
||||
// 测试路由
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// 创建请求
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// 验证响应
|
||||
assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match")
|
||||
|
||||
if tc.expectedBody != nil {
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
assert.NoError(t, err, "Response body should be valid JSON")
|
||||
assert.Equal(t, tc.expectedBody, response, "Response body should match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,76 +0,0 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCORS(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
method string
|
||||
expectedStatus int
|
||||
checkHeaders bool
|
||||
}{
|
||||
{
|
||||
name: "Normal GET request",
|
||||
method: "GET",
|
||||
expectedStatus: http.StatusOK,
|
||||
checkHeaders: true,
|
||||
},
|
||||
{
|
||||
name: "OPTIONS request",
|
||||
method: "OPTIONS",
|
||||
expectedStatus: http.StatusNoContent,
|
||||
checkHeaders: true,
|
||||
},
|
||||
{
|
||||
name: "POST request",
|
||||
method: "POST",
|
||||
expectedStatus: http.StatusOK,
|
||||
checkHeaders: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// 创建一个新的 gin 引擎
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 添加 CORS 中间件
|
||||
router.Use(CORS())
|
||||
|
||||
// 添加测试路由
|
||||
router.Any("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// 创建测试请求
|
||||
req := httptest.NewRequest(tc.method, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// 验证状态码
|
||||
assert.Equal(t, tc.expectedStatus, rec.Code)
|
||||
|
||||
if tc.checkHeaders {
|
||||
// 验证 CORS 头部
|
||||
headers := rec.Header()
|
||||
assert.Equal(t, "*", headers.Get("Access-Control-Allow-Origin"))
|
||||
assert.Equal(t, "true", headers.Get("Access-Control-Allow-Credentials"))
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Content-Type")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Authorization")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "POST")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "GET")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "PUT")
|
||||
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "DELETE")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,207 +0,0 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
"tss-rocks-be/internal/types"
|
||||
)
|
||||
|
||||
func TestRateLimit(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
config *types.RateLimitConfig
|
||||
setupTest func(*gin.Engine)
|
||||
runTest func(*testing.T, *gin.Engine)
|
||||
expectedStatus int
|
||||
expectedBody map[string]string
|
||||
}{
|
||||
{
|
||||
name: "IP rate limit",
|
||||
config: &types.RateLimitConfig{
|
||||
IPRate: 1, // 每秒1个请求
|
||||
IPBurst: 1,
|
||||
},
|
||||
setupTest: func(router *gin.Engine) {
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
},
|
||||
runTest: func(t *testing.T, router *gin.Engine) {
|
||||
// 第一个请求应该成功
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:1234"
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// 第二个请求应该被限制
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "too many requests from this IP", response["error"])
|
||||
|
||||
// 等待限流器重置
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// 第三个请求应该成功
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Route rate limit",
|
||||
config: &types.RateLimitConfig{
|
||||
IPRate: 100, // 设置较高的 IP 限流,以便测试路由限流
|
||||
IPBurst: 10,
|
||||
RouteRates: map[string]struct {
|
||||
Rate int `yaml:"rate"`
|
||||
Burst int `yaml:"burst"`
|
||||
}{
|
||||
"/limited": {
|
||||
Rate: 1,
|
||||
Burst: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
setupTest: func(router *gin.Engine) {
|
||||
router.GET("/limited", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
router.GET("/unlimited", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
},
|
||||
runTest: func(t *testing.T, router *gin.Engine) {
|
||||
// 测试限流路由
|
||||
req := httptest.NewRequest("GET", "/limited", nil)
|
||||
req.RemoteAddr = "192.168.1.2:1234"
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// 等待一小段时间确保限流器生效
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "too many requests for this route", response["error"])
|
||||
|
||||
// 测试未限流路由
|
||||
req = httptest.NewRequest("GET", "/unlimited", nil)
|
||||
req.RemoteAddr = "192.168.1.2:1234"
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// 等待一小段时间确保限流器生效
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Multiple IPs",
|
||||
config: &types.RateLimitConfig{
|
||||
IPRate: 1,
|
||||
IPBurst: 1,
|
||||
},
|
||||
setupTest: func(router *gin.Engine) {
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
},
|
||||
runTest: func(t *testing.T, router *gin.Engine) {
|
||||
// IP1 的请求
|
||||
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||
req1.RemoteAddr = "192.168.1.3:1234"
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req1)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req1)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
|
||||
// IP2 的请求应该不受 IP1 的限制影响
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
req2.RemoteAddr = "192.168.1.4:1234"
|
||||
rec = httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req2)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 添加限流中间件
|
||||
router.Use(RateLimit(tc.config))
|
||||
|
||||
// 设置测试路由
|
||||
tc.setupTest(router)
|
||||
|
||||
// 运行测试
|
||||
tc.runTest(t, router)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterCleanup(t *testing.T) {
|
||||
config := &types.RateLimitConfig{
|
||||
IPRate: 1,
|
||||
IPBurst: 1,
|
||||
}
|
||||
|
||||
rl := newRateLimiter(config)
|
||||
|
||||
// 添加一些IP限流器
|
||||
ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"}
|
||||
for _, ip := range ips {
|
||||
rl.getLimiter(ip)
|
||||
}
|
||||
|
||||
// 验证IP限流器已创建
|
||||
rl.mu.RLock()
|
||||
assert.Equal(t, len(ips), len(rl.ips))
|
||||
rl.mu.RUnlock()
|
||||
|
||||
// 修改一些IP的最后访问时间为1小时前
|
||||
rl.mu.Lock()
|
||||
rl.ips["192.168.1.1"].lastSeen = time.Now().Add(-2 * time.Hour)
|
||||
rl.ips["192.168.1.2"].lastSeen = time.Now().Add(-2 * time.Hour)
|
||||
rl.mu.Unlock()
|
||||
|
||||
// 手动触发清理
|
||||
rl.mu.Lock()
|
||||
for ip, limiter := range rl.ips {
|
||||
if time.Since(limiter.lastSeen) > time.Hour {
|
||||
delete(rl.ips, ip)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
|
||||
// 验证过期的IP限流器已被删除
|
||||
rl.mu.RLock()
|
||||
assert.Equal(t, 1, len(rl.ips))
|
||||
_, exists := rl.ips["192.168.1.3"]
|
||||
assert.True(t, exists)
|
||||
rl.mu.RUnlock()
|
||||
}
|
|
@ -1,262 +0,0 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"tss-rocks-be/internal/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func createMultipartRequest(t *testing.T, filename string, content []byte, contentType string) (*http.Request, error) {
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
part, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = io.Copy(part, bytes.NewReader(content))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/upload", body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func TestValidateUpload(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *types.UploadConfig
|
||||
filename string
|
||||
content []byte
|
||||
setupRequest func(*testing.T) *http.Request
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "Valid image upload",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5, // 5MB
|
||||
AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
|
||||
AllowedTypes: []string{"image/jpeg", "image/png"},
|
||||
},
|
||||
filename: "test.jpg",
|
||||
content: []byte{
|
||||
0xFF, 0xD8, 0xFF, 0xE0, // JPEG magic numbers
|
||||
0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00,
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Invalid file extension",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5,
|
||||
AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
|
||||
AllowedTypes: []string{"image/jpeg", "image/png"},
|
||||
},
|
||||
filename: "test.txt",
|
||||
content: []byte("test content"),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "File extension .txt is not allowed",
|
||||
},
|
||||
{
|
||||
name: "File too large",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 1, // 1MB
|
||||
AllowedExtensions: []string{".jpg"},
|
||||
AllowedTypes: []string{"image/jpeg"},
|
||||
},
|
||||
filename: "large.jpg",
|
||||
content: make([]byte, 2<<20), // 2MB
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "File large.jpg exceeds maximum size of 1 MB",
|
||||
},
|
||||
{
|
||||
name: "Invalid content type",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5,
|
||||
AllowedExtensions: []string{".jpg"},
|
||||
AllowedTypes: []string{"image/jpeg"},
|
||||
},
|
||||
filename: "fake.jpg",
|
||||
content: []byte("not a real image"),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "File type text/plain; charset=utf-8 is not allowed",
|
||||
},
|
||||
{
|
||||
name: "Missing file",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5,
|
||||
AllowedExtensions: []string{".jpg"},
|
||||
AllowedTypes: []string{"image/jpeg"},
|
||||
},
|
||||
setupRequest: func(t *testing.T) *http.Request {
|
||||
req := httptest.NewRequest("POST", "/upload", strings.NewReader(""))
|
||||
req.Header.Set("Content-Type", "multipart/form-data")
|
||||
return req
|
||||
},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Failed to parse form",
|
||||
},
|
||||
{
|
||||
name: "Invalid content type header",
|
||||
config: &types.UploadConfig{
|
||||
MaxSize: 5,
|
||||
AllowedExtensions: []string{".jpg"},
|
||||
AllowedTypes: []string{"image/jpeg"},
|
||||
},
|
||||
setupRequest: func(t *testing.T) *http.Request {
|
||||
return httptest.NewRequest("POST", "/upload", nil)
|
||||
},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Content-Type must be multipart/form-data",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
var req *http.Request
|
||||
var err error
|
||||
|
||||
if tt.setupRequest != nil {
|
||||
req = tt.setupRequest(t)
|
||||
} else {
|
||||
req, err = createMultipartRequest(t, tt.filename, tt.content, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.Request = req
|
||||
|
||||
middleware := ValidateUpload(tt.config)
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||
if tt.expectedError != "" {
|
||||
var response map[string]string
|
||||
err := json.NewDecoder(w.Body).Decode(&response)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, response["error"], tt.expectedError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetValidatedFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupContext func(*gin.Context)
|
||||
filename string
|
||||
expectedFound bool
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "Get existing file",
|
||||
setupContext: func(c *gin.Context) {
|
||||
// 创建测试文件内容
|
||||
content := []byte("test content")
|
||||
buf := bytes.NewBuffer(content)
|
||||
|
||||
// 设置验证过的文件和内容类型
|
||||
c.Set("validated_file_test.txt", buf)
|
||||
c.Set("validated_content_type_test.txt", "text/plain")
|
||||
},
|
||||
filename: "test.txt",
|
||||
expectedFound: true,
|
||||
},
|
||||
{
|
||||
name: "File not found",
|
||||
setupContext: func(c *gin.Context) {
|
||||
// 不设置任何文件
|
||||
},
|
||||
filename: "nonexistent.txt",
|
||||
expectedFound: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
if tt.setupContext != nil {
|
||||
tt.setupContext(c)
|
||||
}
|
||||
|
||||
buffer, contentType, found := GetValidatedFile(c, tt.filename)
|
||||
|
||||
assert.Equal(t, tt.expectedFound, found)
|
||||
if tt.expectedFound {
|
||||
assert.NotNil(t, buffer)
|
||||
assert.NotEmpty(t, contentType)
|
||||
} else {
|
||||
assert.Nil(t, buffer)
|
||||
assert.Empty(t, contentType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
slice []string
|
||||
str string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "String found in slice",
|
||||
slice: []string{"a", "b", "c"},
|
||||
str: "b",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "String not found in slice",
|
||||
slice: []string{"a", "b", "c"},
|
||||
str: "d",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty slice",
|
||||
slice: []string{},
|
||||
str: "a",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
slice: []string{"a", "b", "c"},
|
||||
str: "",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := contains(tt.slice, tt.str)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,99 +0,0 @@
|
|||
package rbac
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"tss-rocks-be/ent/enttest"
|
||||
"tss-rocks-be/ent/role"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestInitializeRBAC(t *testing.T) {
|
||||
// Create an in-memory SQLite client for testing
|
||||
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Test initialization
|
||||
err := InitializeRBAC(ctx, client)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize RBAC: %v", err)
|
||||
}
|
||||
|
||||
// Verify roles were created
|
||||
for roleName := range DefaultRoles {
|
||||
r, err := client.Role.Query().Where(role.Name(roleName)).Only(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Role %s was not created: %v", roleName, err)
|
||||
}
|
||||
|
||||
// Verify permissions for each role
|
||||
perms, err := r.QueryPermissions().All(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to query permissions for role %s: %v", roleName, err)
|
||||
}
|
||||
|
||||
expectedPerms := DefaultRoles[roleName]
|
||||
permCount := 0
|
||||
for _, actions := range expectedPerms {
|
||||
permCount += len(actions)
|
||||
}
|
||||
|
||||
if len(perms) != permCount {
|
||||
t.Errorf("Role %s has %d permissions, expected %d", roleName, len(perms), permCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssignRoleToUser(t *testing.T) {
|
||||
// Create an in-memory SQLite client for testing
|
||||
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initialize RBAC
|
||||
err := InitializeRBAC(ctx, client)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize RBAC: %v", err)
|
||||
}
|
||||
|
||||
// Create a test user
|
||||
user, err := client.User.Create().
|
||||
SetEmail("test@example.com").
|
||||
SetUsername("testuser").
|
||||
SetPasswordHash("$2a$10$hzLdXMZEIzgr8eGXL0YoCOIIrQhqEj6N.S3.wY1Jx5.4vWm1ZyHyy").
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test user: %v", err)
|
||||
}
|
||||
|
||||
// Test assigning role to user
|
||||
err = AssignRoleToUser(ctx, client, user.ID, "editor")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to assign role to user: %v", err)
|
||||
}
|
||||
|
||||
// Verify role assignment
|
||||
assignedRoles, err := user.QueryRoles().All(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query user roles: %v", err)
|
||||
}
|
||||
|
||||
if len(assignedRoles) != 1 {
|
||||
t.Errorf("Expected 1 role, got %d", len(assignedRoles))
|
||||
}
|
||||
|
||||
if assignedRoles[0].Name != "editor" {
|
||||
t.Errorf("Expected role name 'editor', got '%s'", assignedRoles[0].Name)
|
||||
}
|
||||
|
||||
// Test assigning non-existent role
|
||||
err = AssignRoleToUser(ctx, client, user.ID, "nonexistent")
|
||||
if err == nil {
|
||||
t.Error("Expected error when assigning non-existent role, got nil")
|
||||
}
|
||||
}
|
|
@ -1,64 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInitDatabase(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
driver string
|
||||
dsn string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "success with sqlite3",
|
||||
driver: "sqlite3",
|
||||
dsn: "file:ent?mode=memory&cache=shared&_fk=1",
|
||||
},
|
||||
{
|
||||
name: "invalid driver",
|
||||
driver: "invalid_driver",
|
||||
dsn: "file:ent?mode=memory",
|
||||
wantErr: true,
|
||||
errContains: "unsupported driver",
|
||||
},
|
||||
{
|
||||
name: "invalid dsn",
|
||||
driver: "sqlite3",
|
||||
dsn: "file::memory:?not_exist_option=1", // 使用内存数据库但带有无效选项
|
||||
wantErr: true,
|
||||
errContains: "foreign_keys pragma is off",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client, err := InitDatabase(ctx, tt.driver, tt.dsn)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
assert.Nil(t, client)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, client)
|
||||
|
||||
// 测试数据库连接是否正常工作
|
||||
err = client.Schema.Create(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 清理
|
||||
client.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,40 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"tss-rocks-be/internal/config"
|
||||
)
|
||||
|
||||
func TestNewEntClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg *config.Config
|
||||
}{
|
||||
{
|
||||
name: "default sqlite3 config",
|
||||
cfg: &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
Driver: "sqlite3",
|
||||
DSN: "file:ent?mode=memory&cache=shared&_fk=1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
client := NewEntClient(tt.cfg)
|
||||
assert.NotNil(t, client)
|
||||
|
||||
// 验证客户端是否可以正常工作
|
||||
err := client.Schema.Create(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 清理
|
||||
client.Close()
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,220 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/types"
|
||||
"tss-rocks-be/ent/enttest"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
// 创建测试配置
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
},
|
||||
Storage: config.StorageConfig{
|
||||
Type: "local",
|
||||
Local: config.LocalStorage{
|
||||
RootDir: "testdata",
|
||||
},
|
||||
Upload: types.UploadConfig{
|
||||
MaxSize: 10,
|
||||
AllowedTypes: []string{"image/jpeg", "image/png"},
|
||||
AllowedExtensions: []string{".jpg", ".png"},
|
||||
},
|
||||
},
|
||||
RateLimit: types.RateLimitConfig{
|
||||
IPRate: 100,
|
||||
IPBurst: 200,
|
||||
RouteRates: map[string]struct {
|
||||
Rate int `yaml:"rate"`
|
||||
Burst int `yaml:"burst"`
|
||||
}{
|
||||
"/api/v1/upload": {Rate: 10, Burst: 20},
|
||||
},
|
||||
},
|
||||
AccessLog: types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
EnableFile: true,
|
||||
FilePath: "testdata/access.log",
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
Rotation: struct {
|
||||
MaxSize int `yaml:"max_size"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
Compress bool `yaml:"compress"`
|
||||
LocalTime bool `yaml:"local_time"`
|
||||
}{
|
||||
MaxSize: 100,
|
||||
MaxAge: 7,
|
||||
MaxBackups: 3,
|
||||
Compress: true,
|
||||
LocalTime: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 创建测试数据库客户端
|
||||
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
defer client.Close()
|
||||
|
||||
// 测试服务器初始化
|
||||
s, err := New(cfg, client)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, s)
|
||||
assert.NotNil(t, s.router)
|
||||
assert.NotNil(t, s.handler)
|
||||
assert.Equal(t, cfg, s.config)
|
||||
}
|
||||
|
||||
func TestNew_StorageError(t *testing.T) {
|
||||
// 创建一个无效的存储配置
|
||||
cfg := &config.Config{
|
||||
Storage: config.StorageConfig{
|
||||
Type: "invalid_type", // 使用无效的存储类型
|
||||
},
|
||||
}
|
||||
|
||||
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
defer client.Close()
|
||||
|
||||
s, err := New(cfg, client)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, s)
|
||||
assert.Contains(t, err.Error(), "failed to initialize storage")
|
||||
}
|
||||
|
||||
func TestServer_StartAndShutdown(t *testing.T) {
|
||||
// 创建测试配置
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Host: "localhost",
|
||||
Port: 0, // 使用随机端口
|
||||
},
|
||||
Storage: config.StorageConfig{
|
||||
Type: "local",
|
||||
Local: config.LocalStorage{
|
||||
RootDir: "testdata",
|
||||
},
|
||||
},
|
||||
RateLimit: types.RateLimitConfig{
|
||||
IPRate: 100,
|
||||
IPBurst: 200,
|
||||
},
|
||||
AccessLog: types.AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
},
|
||||
}
|
||||
|
||||
// 创建测试数据库客户端
|
||||
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
defer client.Close()
|
||||
|
||||
// 初始化服务器
|
||||
s, err := New(cfg, client)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 创建一个通道来接收服务器错误
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
// 在 goroutine 中启动服务器
|
||||
go func() {
|
||||
err := s.Start()
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
errChan <- err
|
||||
}
|
||||
close(errChan)
|
||||
}()
|
||||
|
||||
// 给服务器一些时间启动
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 测试关闭服务器
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = s.Shutdown(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 检查服务器是否有错误发生
|
||||
err = <-errChan
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestServer_StartError(t *testing.T) {
|
||||
// 创建一个配置,使用已经被占用的端口来触发错误
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Host: "localhost",
|
||||
Port: 8899, // 使用固定端口以便测试
|
||||
},
|
||||
Storage: config.StorageConfig{
|
||||
Type: "local",
|
||||
Local: config.LocalStorage{
|
||||
RootDir: "testdata",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
defer client.Close()
|
||||
|
||||
// 创建第一个服务器实例
|
||||
s1, err := New(cfg, client)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 创建一个通道来接收服务器错误
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
// 启动第一个服务器
|
||||
go func() {
|
||||
err := s1.Start()
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
errChan <- err
|
||||
}
|
||||
close(errChan)
|
||||
}()
|
||||
|
||||
// 给服务器一些时间启动
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 尝试在同一端口启动第二个服务器,应该会失败
|
||||
s2, err := New(cfg, client)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s2.Start()
|
||||
assert.Error(t, err)
|
||||
|
||||
// 清理
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 关闭第一个服务器
|
||||
err = s1.Shutdown(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 检查第一个服务器是否有错误发生
|
||||
err = <-errChan
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 关闭第二个服务器
|
||||
err = s2.Shutdown(ctx)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestServer_ShutdownWithNilServer(t *testing.T) {
|
||||
s := &Server{}
|
||||
err := s.Shutdown(context.Background())
|
||||
assert.NoError(t, err)
|
||||
}
|
File diff suppressed because it is too large
Load diff
|
@ -1,332 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/textproto"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/storage"
|
||||
"tss-rocks-be/internal/storage/mock"
|
||||
"tss-rocks-be/internal/testutil"
|
||||
|
||||
"bou.ke/monkey"
|
||||
)
|
||||
|
||||
type MediaServiceTestSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *ent.Client
|
||||
storage *mock.MockStorage
|
||||
ctrl *gomock.Controller
|
||||
svc MediaService
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.client = testutil.NewTestClient()
|
||||
require.NotNil(s.T(), s.client)
|
||||
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.storage = mock.NewMockStorage(s.ctrl)
|
||||
s.svc = NewMediaService(s.client, s.storage)
|
||||
|
||||
// 清理数据库
|
||||
_, err := s.client.Media.Delete().Exec(s.ctx)
|
||||
require.NoError(s.T(), err)
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
s.client.Close()
|
||||
}
|
||||
|
||||
func TestMediaServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(MediaServiceTestSuite))
|
||||
}
|
||||
|
||||
type mockFileHeader struct {
|
||||
filename string
|
||||
contentType string
|
||||
size int64
|
||||
content []byte
|
||||
}
|
||||
|
||||
func (h *mockFileHeader) Open() (multipart.File, error) {
|
||||
return newMockMultipartFile(h.content), nil
|
||||
}
|
||||
|
||||
func (h *mockFileHeader) Filename() string {
|
||||
return h.filename
|
||||
}
|
||||
|
||||
func (h *mockFileHeader) Size() int64 {
|
||||
return h.size
|
||||
}
|
||||
|
||||
func (h *mockFileHeader) Header() textproto.MIMEHeader {
|
||||
header := make(textproto.MIMEHeader)
|
||||
header.Set("Content-Type", h.contentType)
|
||||
return header
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) createTestFile(filename, contentType string, content []byte) *multipart.FileHeader {
|
||||
header := &multipart.FileHeader{
|
||||
Filename: filename,
|
||||
Header: make(textproto.MIMEHeader),
|
||||
Size: int64(len(content)),
|
||||
}
|
||||
header.Header.Set("Content-Type", contentType)
|
||||
|
||||
monkey.PatchInstanceMethod(reflect.TypeOf(header), "Open", func(_ *multipart.FileHeader) (multipart.File, error) {
|
||||
return newMockMultipartFile(content), nil
|
||||
})
|
||||
|
||||
return header
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestUpload() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
filename string
|
||||
contentType string
|
||||
content []byte
|
||||
setupMock func()
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "Upload text file",
|
||||
filename: "test.txt",
|
||||
contentType: "text/plain",
|
||||
content: []byte("test content"),
|
||||
setupMock: func() {
|
||||
s.storage.EXPECT().
|
||||
Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, name, contentType string, reader io.Reader) (*storage.FileInfo, error) {
|
||||
content, err := io.ReadAll(reader)
|
||||
s.Require().NoError(err)
|
||||
s.Equal([]byte("test content"), content)
|
||||
return &storage.FileInfo{
|
||||
ID: "test-id",
|
||||
Name: "test.txt",
|
||||
ContentType: "text/plain",
|
||||
Size: int64(len(content)),
|
||||
}, nil
|
||||
})
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid filename",
|
||||
filename: "../test.txt",
|
||||
contentType: "text/plain",
|
||||
content: []byte("test content"),
|
||||
setupMock: func() {},
|
||||
wantErr: true,
|
||||
errMsg: "invalid filename",
|
||||
},
|
||||
{
|
||||
name: "Storage error",
|
||||
filename: "test.txt",
|
||||
contentType: "text/plain",
|
||||
content: []byte("test content"),
|
||||
setupMock: func() {
|
||||
s.storage.EXPECT().
|
||||
Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()).
|
||||
Return(nil, fmt.Errorf("storage error"))
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "storage error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// Setup mock
|
||||
tc.setupMock()
|
||||
|
||||
// Create test file
|
||||
fileHeader := s.createTestFile(tc.filename, tc.contentType, tc.content)
|
||||
|
||||
// Add debug output
|
||||
s.T().Logf("Testing file: %s, content-type: %s, size: %d", fileHeader.Filename, fileHeader.Header.Get("Content-Type"), fileHeader.Size)
|
||||
|
||||
// Test upload
|
||||
media, err := s.svc.Upload(s.ctx, fileHeader, 1)
|
||||
|
||||
// Add debug output
|
||||
if err != nil {
|
||||
s.T().Logf("Upload error: %v", err)
|
||||
}
|
||||
|
||||
if tc.wantErr {
|
||||
s.Require().Error(err)
|
||||
s.Contains(err.Error(), tc.errMsg)
|
||||
return
|
||||
}
|
||||
|
||||
s.Require().NoError(err)
|
||||
s.NotNil(media)
|
||||
s.Equal(tc.filename, media.OriginalName)
|
||||
s.Equal(tc.contentType, media.MimeType)
|
||||
s.Equal(int64(len(tc.content)), media.Size)
|
||||
s.Equal("1", media.CreatedBy)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestGet() {
|
||||
// Create test media
|
||||
media, err := s.client.Media.Create().
|
||||
SetStorageID("test-id").
|
||||
SetOriginalName("test.txt").
|
||||
SetMimeType("text/plain").
|
||||
SetSize(12).
|
||||
SetURL("/api/media/test-id").
|
||||
SetCreatedBy("1").
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Test get existing media
|
||||
result, err := s.svc.Get(s.ctx, media.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(media.ID, result.ID)
|
||||
s.Equal(media.OriginalName, result.OriginalName)
|
||||
|
||||
// Test get non-existing media
|
||||
_, err = s.svc.Get(s.ctx, -1)
|
||||
s.Require().Error(err)
|
||||
s.Contains(err.Error(), "media not found")
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestDelete() {
|
||||
// Create test media
|
||||
media, err := s.client.Media.Create().
|
||||
SetStorageID("test-id").
|
||||
SetOriginalName("test.txt").
|
||||
SetMimeType("text/plain").
|
||||
SetSize(12).
|
||||
SetURL("/api/media/test-id").
|
||||
SetCreatedBy("1").
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Test delete by unauthorized user
|
||||
err = s.svc.Delete(s.ctx, media.ID, 2)
|
||||
s.Require().Error(err)
|
||||
s.Contains(err.Error(), "unauthorized")
|
||||
|
||||
// Test delete by owner
|
||||
s.storage.EXPECT().
|
||||
Delete(gomock.Any(), "test-id").
|
||||
Return(nil)
|
||||
err = s.svc.Delete(s.ctx, media.ID, 1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Verify media is deleted
|
||||
_, err = s.svc.Get(s.ctx, media.ID)
|
||||
s.Require().Error(err)
|
||||
s.Contains(err.Error(), "not found")
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestList() {
|
||||
// Create test media
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err := s.client.Media.Create().
|
||||
SetStorageID(fmt.Sprintf("test-id-%d", i)).
|
||||
SetOriginalName(fmt.Sprintf("test-%d.txt", i)).
|
||||
SetMimeType("text/plain").
|
||||
SetSize(12).
|
||||
SetURL(fmt.Sprintf("/api/media/test-id-%d", i)).
|
||||
SetCreatedBy("1").
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
// Test list with limit and offset
|
||||
media, err := s.svc.List(s.ctx, 3, 1)
|
||||
s.Require().NoError(err)
|
||||
s.Len(media, 3)
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestGetFile() {
|
||||
// Create test media
|
||||
media, err := s.client.Media.Create().
|
||||
SetStorageID("test-id").
|
||||
SetOriginalName("test.txt").
|
||||
SetMimeType("text/plain").
|
||||
SetSize(12).
|
||||
SetURL("/api/media/test-id").
|
||||
SetCreatedBy("1").
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Mock storage.Get
|
||||
mockReader := io.NopCloser(bytes.NewReader([]byte("test content")))
|
||||
mockFileInfo := &storage.FileInfo{
|
||||
ID: "test-id",
|
||||
Name: "test.txt",
|
||||
ContentType: "text/plain",
|
||||
Size: 12,
|
||||
}
|
||||
s.storage.EXPECT().
|
||||
Get(gomock.Any(), "test-id").
|
||||
Return(mockReader, mockFileInfo, nil)
|
||||
|
||||
// Test get file
|
||||
reader, info, err := s.svc.GetFile(s.ctx, media.ID)
|
||||
s.Require().NoError(err)
|
||||
s.NotNil(reader)
|
||||
s.Equal(mockFileInfo, info)
|
||||
|
||||
// Test get non-existing file
|
||||
_, _, err = s.svc.GetFile(s.ctx, -1)
|
||||
s.Require().Error(err)
|
||||
s.Contains(err.Error(), "not found")
|
||||
}
|
||||
|
||||
func (s *MediaServiceTestSuite) TestIsValidFilename() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
filename string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "Valid filename",
|
||||
filename: "test.txt",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid filename with ../",
|
||||
filename: "../test.txt",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid filename with ./",
|
||||
filename: "./test.txt",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid filename with backslash",
|
||||
filename: "test\\file.txt",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
got := isValidFilename(tc.filename)
|
||||
s.Equal(tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,154 +0,0 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLocalStorage(t *testing.T) {
|
||||
// Create a temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "storage_test_*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create a new LocalStorage instance
|
||||
storage, err := NewLocalStorage(tempDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Save and Get", func(t *testing.T) {
|
||||
content := []byte("test content")
|
||||
reader := bytes.NewReader(content)
|
||||
|
||||
// Save the file
|
||||
fileInfo, err := storage.Save(ctx, "test.txt", "text/plain", reader)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, fileInfo.ID)
|
||||
assert.Equal(t, "test.txt", fileInfo.Name)
|
||||
assert.Equal(t, int64(len(content)), fileInfo.Size)
|
||||
assert.Equal(t, "text/plain", fileInfo.ContentType)
|
||||
assert.False(t, fileInfo.CreatedAt.IsZero())
|
||||
|
||||
// Get the file
|
||||
readCloser, info, err := storage.Get(ctx, fileInfo.ID)
|
||||
require.NoError(t, err)
|
||||
defer readCloser.Close()
|
||||
|
||||
data, err := io.ReadAll(readCloser)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, data)
|
||||
assert.Equal(t, fileInfo.ID, info.ID)
|
||||
assert.Equal(t, fileInfo.Name, info.Name)
|
||||
assert.Equal(t, fileInfo.Size, info.Size)
|
||||
})
|
||||
|
||||
t.Run("List", func(t *testing.T) {
|
||||
// Clear the directory first
|
||||
dirEntries, err := os.ReadDir(tempDir)
|
||||
require.NoError(t, err)
|
||||
for _, entry := range dirEntries {
|
||||
if entry.Name() != ".meta" {
|
||||
os.Remove(filepath.Join(tempDir, entry.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
// Save multiple files
|
||||
testFiles := []struct {
|
||||
name string
|
||||
content string
|
||||
}{
|
||||
{"test1.txt", "content1"},
|
||||
{"test2.txt", "content2"},
|
||||
{"other.txt", "content3"},
|
||||
}
|
||||
|
||||
for _, f := range testFiles {
|
||||
reader := bytes.NewReader([]byte(f.content))
|
||||
_, err := storage.Save(ctx, f.name, "text/plain", reader)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// List all files
|
||||
allFiles, err := storage.List(ctx, "", 10, 0)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, allFiles, 3)
|
||||
|
||||
// List files with prefix
|
||||
filesWithPrefix, err := storage.List(ctx, "test", 10, 0)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, filesWithPrefix, 2)
|
||||
for _, f := range filesWithPrefix {
|
||||
assert.True(t, strings.HasPrefix(f.Name, "test"))
|
||||
}
|
||||
|
||||
// Test pagination
|
||||
pagedFiles, err := storage.List(ctx, "", 2, 1)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pagedFiles, 2)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
// Save a file
|
||||
content := []byte("test content")
|
||||
reader := bytes.NewReader(content)
|
||||
fileInfo, err := storage.Save(ctx, "exists.txt", "text/plain", reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check if file exists
|
||||
exists, err := storage.Exists(ctx, fileInfo.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Check non-existent file
|
||||
exists, err = storage.Exists(ctx, "non-existent")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
// Save a file
|
||||
content := []byte("test content")
|
||||
reader := bytes.NewReader(content)
|
||||
fileInfo, err := storage.Save(ctx, "delete.txt", "text/plain", reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete the file
|
||||
err = storage.Delete(ctx, fileInfo.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify file is deleted
|
||||
exists, err := storage.Exists(ctx, fileInfo.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
// Try to delete non-existent file
|
||||
err = storage.Delete(ctx, "non-existent")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Invalid operations", func(t *testing.T) {
|
||||
// Try to get non-existent file
|
||||
_, _, err := storage.Get(ctx, "non-existent")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "file not found")
|
||||
|
||||
// Try to save file with nil reader
|
||||
_, err = storage.Save(ctx, "test.txt", "text/plain", nil)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "reader cannot be nil")
|
||||
|
||||
// Try to delete non-existent file
|
||||
err = storage.Delete(ctx, "non-existent")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "file not found")
|
||||
})
|
||||
}
|
|
@ -1,211 +0,0 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// MockS3Client is a mock implementation of the S3 client interface
|
||||
type MockS3Client struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockS3Client) PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) {
|
||||
args := m.Called(ctx, params)
|
||||
return args.Get(0).(*s3.PutObjectOutput), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockS3Client) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
|
||||
args := m.Called(ctx, params)
|
||||
return args.Get(0).(*s3.GetObjectOutput), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockS3Client) DeleteObject(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error) {
|
||||
args := m.Called(ctx, params)
|
||||
return args.Get(0).(*s3.DeleteObjectOutput), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockS3Client) ListObjectsV2(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) {
|
||||
args := m.Called(ctx, params)
|
||||
return args.Get(0).(*s3.ListObjectsV2Output), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockS3Client) HeadObject(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) {
|
||||
args := m.Called(ctx, params)
|
||||
return args.Get(0).(*s3.HeadObjectOutput), args.Error(1)
|
||||
}
|
||||
|
||||
func TestS3Storage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockClient := new(MockS3Client)
|
||||
storage := NewS3Storage(mockClient, "test-bucket", "", false)
|
||||
|
||||
t.Run("Save", func(t *testing.T) {
|
||||
mockClient.ExpectedCalls = nil
|
||||
mockClient.Calls = nil
|
||||
|
||||
content := []byte("test content")
|
||||
reader := bytes.NewReader(content)
|
||||
|
||||
// Mock HeadObject to return NotFound error
|
||||
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket"
|
||||
})).Return(&s3.HeadObjectOutput{}, &types.NoSuchKey{
|
||||
Message: aws.String("The specified key does not exist."),
|
||||
})
|
||||
|
||||
mockClient.On("PutObject", ctx, mock.MatchedBy(func(input *s3.PutObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.ContentType) == "text/plain"
|
||||
})).Return(&s3.PutObjectOutput{}, nil)
|
||||
|
||||
fileInfo, err := storage.Save(ctx, "test.txt", "text/plain", reader)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, fileInfo.ID)
|
||||
assert.Equal(t, "test.txt", fileInfo.Name)
|
||||
assert.Equal(t, "text/plain", fileInfo.ContentType)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("Get", func(t *testing.T) {
|
||||
content := []byte("test content")
|
||||
mockClient.On("GetObject", ctx, mock.MatchedBy(func(input *s3.GetObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.Key) == "test-id"
|
||||
})).Return(&s3.GetObjectOutput{
|
||||
Body: io.NopCloser(bytes.NewReader(content)),
|
||||
ContentType: aws.String("text/plain"),
|
||||
ContentLength: aws.Int64(int64(len(content))),
|
||||
LastModified: aws.Time(time.Now()),
|
||||
}, nil)
|
||||
|
||||
readCloser, info, err := storage.Get(ctx, "test-id")
|
||||
require.NoError(t, err)
|
||||
defer readCloser.Close()
|
||||
|
||||
data, err := io.ReadAll(readCloser)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, data)
|
||||
assert.Equal(t, "test-id", info.ID)
|
||||
assert.Equal(t, int64(len(content)), info.Size)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("List", func(t *testing.T) {
|
||||
mockClient.ExpectedCalls = nil
|
||||
mockClient.Calls = nil
|
||||
|
||||
mockClient.On("ListObjectsV2", ctx, mock.MatchedBy(func(input *s3.ListObjectsV2Input) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.Prefix) == "test" &&
|
||||
aws.ToInt32(input.MaxKeys) == 10
|
||||
})).Return(&s3.ListObjectsV2Output{
|
||||
Contents: []types.Object{
|
||||
{
|
||||
Key: aws.String("test1"),
|
||||
Size: aws.Int64(100),
|
||||
LastModified: aws.Time(time.Now()),
|
||||
},
|
||||
{
|
||||
Key: aws.String("test2"),
|
||||
Size: aws.Int64(200),
|
||||
LastModified: aws.Time(time.Now()),
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
// Mock HeadObject for both files
|
||||
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.Key) == "test1"
|
||||
})).Return(&s3.HeadObjectOutput{
|
||||
ContentType: aws.String("text/plain"),
|
||||
Metadata: map[string]string{
|
||||
"x-amz-meta-original-name": "test1.txt",
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.Key) == "test2"
|
||||
})).Return(&s3.HeadObjectOutput{
|
||||
ContentType: aws.String("text/plain"),
|
||||
Metadata: map[string]string{
|
||||
"x-amz-meta-original-name": "test2.txt",
|
||||
},
|
||||
}, nil).Once()
|
||||
|
||||
files, err := storage.List(ctx, "test", 10, 0)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, files, 2)
|
||||
assert.Equal(t, "test1", files[0].ID)
|
||||
assert.Equal(t, int64(100), files[0].Size)
|
||||
assert.Equal(t, "test1.txt", files[0].Name)
|
||||
assert.Equal(t, "text/plain", files[0].ContentType)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
mockClient.On("DeleteObject", ctx, mock.MatchedBy(func(input *s3.DeleteObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.Key) == "test-id"
|
||||
})).Return(&s3.DeleteObjectOutput{}, nil)
|
||||
|
||||
err := storage.Delete(ctx, "test-id")
|
||||
require.NoError(t, err)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
mockClient.ExpectedCalls = nil
|
||||
mockClient.Calls = nil
|
||||
|
||||
// Mock HeadObject for existing file
|
||||
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.Key) == "test-id"
|
||||
})).Return(&s3.HeadObjectOutput{}, nil).Once()
|
||||
|
||||
exists, err := storage.Exists(ctx, "test-id")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Mock HeadObject for non-existing file
|
||||
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
|
||||
return aws.ToString(input.Bucket) == "test-bucket" &&
|
||||
aws.ToString(input.Key) == "non-existent"
|
||||
})).Return(&s3.HeadObjectOutput{}, &types.NoSuchKey{
|
||||
Message: aws.String("The specified key does not exist."),
|
||||
}).Once()
|
||||
|
||||
exists, err = storage.Exists(ctx, "non-existent")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("Custom URL", func(t *testing.T) {
|
||||
customStorage := &S3Storage{
|
||||
client: mockClient,
|
||||
bucket: "test-bucket",
|
||||
customURL: "https://custom.domain",
|
||||
proxyS3: true,
|
||||
}
|
||||
assert.Contains(t, customStorage.getObjectURL("test-id"), "https://custom.domain")
|
||||
})
|
||||
}
|
|
@ -1,116 +0,0 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRateLimitConfig(t *testing.T) {
|
||||
config := RateLimitConfig{
|
||||
IPRate: 100,
|
||||
IPBurst: 200,
|
||||
RouteRates: map[string]struct {
|
||||
Rate int `yaml:"rate"`
|
||||
Burst int `yaml:"burst"`
|
||||
}{
|
||||
"/api/test": {
|
||||
Rate: 50,
|
||||
Burst: 100,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if config.IPRate != 100 {
|
||||
t.Errorf("Expected IPRate 100, got %d", config.IPRate)
|
||||
}
|
||||
if config.IPBurst != 200 {
|
||||
t.Errorf("Expected IPBurst 200, got %d", config.IPBurst)
|
||||
}
|
||||
|
||||
route := config.RouteRates["/api/test"]
|
||||
if route.Rate != 50 {
|
||||
t.Errorf("Expected route rate 50, got %d", route.Rate)
|
||||
}
|
||||
if route.Burst != 100 {
|
||||
t.Errorf("Expected route burst 100, got %d", route.Burst)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogConfig(t *testing.T) {
|
||||
config := AccessLogConfig{
|
||||
EnableConsole: true,
|
||||
EnableFile: true,
|
||||
FilePath: "/var/log/app.log",
|
||||
Format: "json",
|
||||
Level: "info",
|
||||
Rotation: struct {
|
||||
MaxSize int `yaml:"max_size"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
Compress bool `yaml:"compress"`
|
||||
LocalTime bool `yaml:"local_time"`
|
||||
}{
|
||||
MaxSize: 100,
|
||||
MaxAge: 7,
|
||||
MaxBackups: 5,
|
||||
Compress: true,
|
||||
LocalTime: true,
|
||||
},
|
||||
}
|
||||
|
||||
if !config.EnableConsole {
|
||||
t.Error("Expected EnableConsole to be true")
|
||||
}
|
||||
if !config.EnableFile {
|
||||
t.Error("Expected EnableFile to be true")
|
||||
}
|
||||
if config.FilePath != "/var/log/app.log" {
|
||||
t.Errorf("Expected FilePath '/var/log/app.log', got '%s'", config.FilePath)
|
||||
}
|
||||
if config.Format != "json" {
|
||||
t.Errorf("Expected Format 'json', got '%s'", config.Format)
|
||||
}
|
||||
if config.Level != "info" {
|
||||
t.Errorf("Expected Level 'info', got '%s'", config.Level)
|
||||
}
|
||||
|
||||
rotation := config.Rotation
|
||||
if rotation.MaxSize != 100 {
|
||||
t.Errorf("Expected MaxSize 100, got %d", rotation.MaxSize)
|
||||
}
|
||||
if rotation.MaxAge != 7 {
|
||||
t.Errorf("Expected MaxAge 7, got %d", rotation.MaxAge)
|
||||
}
|
||||
if rotation.MaxBackups != 5 {
|
||||
t.Errorf("Expected MaxBackups 5, got %d", rotation.MaxBackups)
|
||||
}
|
||||
if !rotation.Compress {
|
||||
t.Error("Expected Compress to be true")
|
||||
}
|
||||
if !rotation.LocalTime {
|
||||
t.Error("Expected LocalTime to be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUploadConfig(t *testing.T) {
|
||||
config := UploadConfig{
|
||||
MaxSize: 10,
|
||||
AllowedTypes: []string{"image/jpeg", "image/png"},
|
||||
AllowedExtensions: []string{".jpg", ".png"},
|
||||
}
|
||||
|
||||
if config.MaxSize != 10 {
|
||||
t.Errorf("Expected MaxSize 10, got %d", config.MaxSize)
|
||||
}
|
||||
if len(config.AllowedTypes) != 2 {
|
||||
t.Errorf("Expected 2 AllowedTypes, got %d", len(config.AllowedTypes))
|
||||
}
|
||||
if config.AllowedTypes[0] != "image/jpeg" {
|
||||
t.Errorf("Expected AllowedTypes[0] 'image/jpeg', got '%s'", config.AllowedTypes[0])
|
||||
}
|
||||
if len(config.AllowedExtensions) != 2 {
|
||||
t.Errorf("Expected 2 AllowedExtensions, got %d", len(config.AllowedExtensions))
|
||||
}
|
||||
if config.AllowedExtensions[0] != ".jpg" {
|
||||
t.Errorf("Expected AllowedExtensions[0] '.jpg', got '%s'", config.AllowedExtensions[0])
|
||||
}
|
||||
}
|
|
@ -1,21 +0,0 @@
|
|||
package types
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestFileInfo(t *testing.T) {
|
||||
fileInfo := FileInfo{
|
||||
Size: 1024,
|
||||
Name: "test.jpg",
|
||||
ContentType: "image/jpeg",
|
||||
}
|
||||
|
||||
if fileInfo.Size != 1024 {
|
||||
t.Errorf("Expected Size 1024, got %d", fileInfo.Size)
|
||||
}
|
||||
if fileInfo.Name != "test.jpg" {
|
||||
t.Errorf("Expected Name 'test.jpg', got '%s'", fileInfo.Name)
|
||||
}
|
||||
if fileInfo.ContentType != "image/jpeg" {
|
||||
t.Errorf("Expected ContentType 'image/jpeg', got '%s'", fileInfo.ContentType)
|
||||
}
|
||||
}
|
|
@ -1,77 +0,0 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCategory(t *testing.T) {
|
||||
description := "Test Description"
|
||||
category := Category{
|
||||
ID: 1,
|
||||
Name: "Test Category",
|
||||
Slug: "test-category",
|
||||
Description: &description,
|
||||
}
|
||||
|
||||
if category.ID != 1 {
|
||||
t.Errorf("Expected ID 1, got %d", category.ID)
|
||||
}
|
||||
if category.Name != "Test Category" {
|
||||
t.Errorf("Expected name 'Test Category', got '%s'", category.Name)
|
||||
}
|
||||
if category.Slug != "test-category" {
|
||||
t.Errorf("Expected slug 'test-category', got '%s'", category.Slug)
|
||||
}
|
||||
if *category.Description != description {
|
||||
t.Errorf("Expected description '%s', got '%s'", description, *category.Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPost(t *testing.T) {
|
||||
metaKeywords := "test,blog"
|
||||
metaDesc := "Test Description"
|
||||
post := Post{
|
||||
ID: 1,
|
||||
Title: "Test Post",
|
||||
Slug: "test-post",
|
||||
ContentMarkdown: "# Test Content",
|
||||
Summary: "Test Summary",
|
||||
MetaKeywords: &metaKeywords,
|
||||
MetaDescription: &metaDesc,
|
||||
}
|
||||
|
||||
if post.ID != 1 {
|
||||
t.Errorf("Expected ID 1, got %d", post.ID)
|
||||
}
|
||||
if post.Title != "Test Post" {
|
||||
t.Errorf("Expected title 'Test Post', got '%s'", post.Title)
|
||||
}
|
||||
if post.Slug != "test-post" {
|
||||
t.Errorf("Expected slug 'test-post', got '%s'", post.Slug)
|
||||
}
|
||||
if *post.MetaKeywords != metaKeywords {
|
||||
t.Errorf("Expected meta keywords '%s', got '%s'", metaKeywords, *post.MetaKeywords)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDaily(t *testing.T) {
|
||||
daily := Daily{
|
||||
ID: "2025-02-12",
|
||||
CategoryID: 1,
|
||||
ImageURL: "https://example.com/image.jpg",
|
||||
Quote: "Test Quote",
|
||||
}
|
||||
|
||||
if daily.ID != "2025-02-12" {
|
||||
t.Errorf("Expected ID '2025-02-12', got '%s'", daily.ID)
|
||||
}
|
||||
if daily.CategoryID != 1 {
|
||||
t.Errorf("Expected CategoryID 1, got %d", daily.CategoryID)
|
||||
}
|
||||
if daily.ImageURL != "https://example.com/image.jpg" {
|
||||
t.Errorf("Expected ImageURL 'https://example.com/image.jpg', got '%s'", daily.ImageURL)
|
||||
}
|
||||
if daily.Quote != "Test Quote" {
|
||||
t.Errorf("Expected Quote 'Test Quote', got '%s'", daily.Quote)
|
||||
}
|
||||
}
|
|
@ -1,77 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
// Create a temporary test config file
|
||||
testConfig := `
|
||||
database:
|
||||
driver: postgres
|
||||
dsn: postgres://user:pass@localhost:5432/db
|
||||
server:
|
||||
port: 8080
|
||||
host: localhost
|
||||
jwt:
|
||||
secret: test-secret
|
||||
expiration: 24h
|
||||
logging:
|
||||
level: debug
|
||||
format: console
|
||||
`
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
if err := os.WriteFile(configPath, []byte(testConfig), 0644); err != nil {
|
||||
t.Fatalf("Failed to create test config file: %v", err)
|
||||
}
|
||||
|
||||
// Test successful config loading
|
||||
cfg, err := Load(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Verify loaded values
|
||||
tests := []struct {
|
||||
name string
|
||||
got interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{"database.driver", cfg.Database.Driver, "postgres"},
|
||||
{"database.dsn", cfg.Database.DSN, "postgres://user:pass@localhost:5432/db"},
|
||||
{"server.port", cfg.Server.Port, 8080},
|
||||
{"server.host", cfg.Server.Host, "localhost"},
|
||||
{"jwt.secret", cfg.JWT.Secret, "test-secret"},
|
||||
{"jwt.expiration", cfg.JWT.Expiration, "24h"},
|
||||
{"logging.level", cfg.Logging.Level, "debug"},
|
||||
{"logging.format", cfg.Logging.Format, "console"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.got != tt.expected {
|
||||
t.Errorf("Config %s = %v, want %v", tt.name, tt.got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test loading non-existent file
|
||||
_, err = Load("non-existent.yaml")
|
||||
if err == nil {
|
||||
t.Error("Expected error when loading non-existent file, got nil")
|
||||
}
|
||||
|
||||
// Test loading invalid YAML
|
||||
invalidPath := filepath.Join(tmpDir, "invalid.yaml")
|
||||
if err := os.WriteFile(invalidPath, []byte("invalid: yaml: content"), 0644); err != nil {
|
||||
t.Fatalf("Failed to create invalid config file: %v", err)
|
||||
}
|
||||
|
||||
_, err = Load(invalidPath)
|
||||
if err == nil {
|
||||
t.Error("Expected error when loading invalid YAML, got nil")
|
||||
}
|
||||
}
|
|
@ -1,100 +0,0 @@
|
|||
package imageutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/png"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsImageFormat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
contentType string
|
||||
want bool
|
||||
}{
|
||||
{"JPEG", "image/jpeg", true},
|
||||
{"PNG", "image/png", true},
|
||||
{"GIF", "image/gif", true},
|
||||
{"WebP", "image/webp", true},
|
||||
{"Invalid", "image/invalid", false},
|
||||
{"Empty", "", false},
|
||||
{"Text", "text/plain", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := IsImageFormat(tt.contentType); got != tt.want {
|
||||
t.Errorf("IsImageFormat(%q) = %v, want %v", tt.contentType, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultOptions(t *testing.T) {
|
||||
opts := DefaultOptions()
|
||||
|
||||
if !opts.Lossless {
|
||||
t.Error("DefaultOptions().Lossless = false, want true")
|
||||
}
|
||||
if opts.Quality != 90 {
|
||||
t.Errorf("DefaultOptions().Quality = %v, want 90", opts.Quality)
|
||||
}
|
||||
if opts.Compression != 4 {
|
||||
t.Errorf("DefaultOptions().Compression = %v, want 4", opts.Compression)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessImage(t *testing.T) {
|
||||
// Create a test image
|
||||
img := image.NewRGBA(image.Rect(0, 0, 100, 100))
|
||||
for y := 0; y < 100; y++ {
|
||||
for x := 0; x < 100; x++ {
|
||||
img.Set(x, y, color.RGBA{R: 255, G: 0, B: 0, A: 255})
|
||||
}
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := png.Encode(&buf, img); err != nil {
|
||||
t.Fatalf("Failed to create test PNG: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts ProcessOptions
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Default options",
|
||||
opts: DefaultOptions(),
|
||||
},
|
||||
{
|
||||
name: "Custom quality",
|
||||
opts: ProcessOptions{
|
||||
Lossless: false,
|
||||
Quality: 75,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := bytes.NewReader(buf.Bytes())
|
||||
result, err := ProcessImage(reader, tt.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ProcessImage() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && len(result) == 0 {
|
||||
t.Error("ProcessImage() returned empty result")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test with invalid input
|
||||
_, err := ProcessImage(bytes.NewReader([]byte("invalid image data")), DefaultOptions())
|
||||
if err == nil {
|
||||
t.Error("ProcessImage() with invalid input should return error")
|
||||
}
|
||||
}
|
|
@ -1,85 +0,0 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"tss-rocks-be/internal/config"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *config.Config
|
||||
expectedLevel zerolog.Level
|
||||
}{
|
||||
{
|
||||
name: "Debug level",
|
||||
config: &config.Config{
|
||||
Logging: struct {
|
||||
Level string `yaml:"level"`
|
||||
Format string `yaml:"format"`
|
||||
}{
|
||||
Level: "debug",
|
||||
Format: "json",
|
||||
},
|
||||
},
|
||||
expectedLevel: zerolog.DebugLevel,
|
||||
},
|
||||
{
|
||||
name: "Info level",
|
||||
config: &config.Config{
|
||||
Logging: struct {
|
||||
Level string `yaml:"level"`
|
||||
Format string `yaml:"format"`
|
||||
}{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
},
|
||||
},
|
||||
expectedLevel: zerolog.InfoLevel,
|
||||
},
|
||||
{
|
||||
name: "Error level",
|
||||
config: &config.Config{
|
||||
Logging: struct {
|
||||
Level string `yaml:"level"`
|
||||
Format string `yaml:"format"`
|
||||
}{
|
||||
Level: "error",
|
||||
Format: "json",
|
||||
},
|
||||
},
|
||||
expectedLevel: zerolog.ErrorLevel,
|
||||
},
|
||||
{
|
||||
name: "Invalid level defaults to Info",
|
||||
config: &config.Config{
|
||||
Logging: struct {
|
||||
Level string `yaml:"level"`
|
||||
Format string `yaml:"format"`
|
||||
}{
|
||||
Level: "invalid",
|
||||
Format: "json",
|
||||
},
|
||||
},
|
||||
expectedLevel: zerolog.InfoLevel,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
Setup(tt.config)
|
||||
if zerolog.GlobalLevel() != tt.expectedLevel {
|
||||
t.Errorf("Setup() set level to %v, want %v", zerolog.GlobalLevel(), tt.expectedLevel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLogger(t *testing.T) {
|
||||
logger := GetLogger()
|
||||
if logger == nil {
|
||||
t.Error("GetLogger() returned nil")
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue