[chore/backend] remove all test for now

This commit is contained in:
CDN 2025-02-22 02:11:27 +08:00
parent 3d19ef05b3
commit 1c9628124f
Signed by: CDN
GPG key ID: 0C656827F9F80080
28 changed files with 0 additions and 6780 deletions

View file

@ -1,27 +0,0 @@
package auth
import (
"context"
"testing"
)
func TestUserIDKey(t *testing.T) {
// Test that the UserIDKey constant is defined correctly
if UserIDKey != "user_id" {
t.Errorf("UserIDKey = %v, want %v", UserIDKey, "user_id")
}
// Test context with user ID
ctx := context.WithValue(context.Background(), UserIDKey, "test-user-123")
value := ctx.Value(UserIDKey)
if value != "test-user-123" {
t.Errorf("Context value = %v, want %v", value, "test-user-123")
}
// Test context without user ID
emptyCtx := context.Background()
emptyValue := emptyCtx.Value(UserIDKey)
if emptyValue != nil {
t.Errorf("Empty context value = %v, want nil", emptyValue)
}
}

View file

@ -1,85 +0,0 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoad(t *testing.T) {
// Create a temporary test config file
content := []byte(`
database:
driver: postgres
dsn: postgres://user:pass@localhost:5432/dbname
server:
port: 8080
host: localhost
jwt:
secret: test-secret
expiration: 24h
storage:
type: local
local:
root_dir: /tmp/storage
upload:
max_size: 10485760
logging:
level: info
format: json
`)
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(configPath, content, 0644); err != nil {
t.Fatalf("Failed to write test config: %v", err)
}
// Test loading config
cfg, err := Load(configPath)
if err != nil {
t.Fatalf("Load() error = %v", err)
}
// Verify loaded values
tests := []struct {
name string
got interface{}
want interface{}
errorMsg string
}{
{"Database Driver", cfg.Database.Driver, "postgres", "incorrect database driver"},
{"Server Port", cfg.Server.Port, 8080, "incorrect server port"},
{"JWT Secret", cfg.JWT.Secret, "test-secret", "incorrect JWT secret"},
{"Storage Type", cfg.Storage.Type, "local", "incorrect storage type"},
{"Logging Level", cfg.Logging.Level, "info", "incorrect logging level"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.got != tt.want {
t.Errorf("%s = %v, want %v", tt.name, tt.got, tt.want)
}
})
}
}
func TestLoadError(t *testing.T) {
// Test loading non-existent file
_, err := Load("non-existent-file.yaml")
if err == nil {
t.Error("Load() error = nil, want error for non-existent file")
}
// Test loading invalid YAML
tmpDir := t.TempDir()
invalidPath := filepath.Join(tmpDir, "invalid.yaml")
if err := os.WriteFile(invalidPath, []byte("invalid: }{yaml"), 0644); err != nil {
t.Fatalf("Failed to write invalid config: %v", err)
}
_, err = Load(invalidPath)
if err == nil {
t.Error("Load() error = nil, want error for invalid YAML")
}
}

View file

@ -1,312 +0,0 @@
package handler
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"tss-rocks-be/ent"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/service/mock"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
"golang.org/x/crypto/bcrypt"
)
type AuthHandlerTestSuite struct {
suite.Suite
ctrl *gomock.Controller
service *mock.MockService
handler *Handler
router *gin.Engine
}
func (s *AuthHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl)
s.handler = NewHandler(&config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
},
Auth: config.AuthConfig{
Registration: struct {
Enabled bool `yaml:"enabled"`
Message string `yaml:"message"`
}{
Enabled: true,
Message: "Registration is disabled",
},
},
}, s.service)
s.router = gin.New()
}
func (s *AuthHandlerTestSuite) TearDownTest() {
s.ctrl.Finish()
}
func TestAuthHandlerSuite(t *testing.T) {
suite.Run(t, new(AuthHandlerTestSuite))
}
type ErrorResponse struct {
Error struct {
Code string `json:"code"`
Message string `json:"message"`
} `json:"error"`
}
func (s *AuthHandlerTestSuite) TestRegister() {
testCases := []struct {
name string
request RegisterRequest
setupMock func()
expectedStatus int
expectedError string
registration bool
}{
{
name: "成功注册",
request: RegisterRequest{
Username: "testuser",
Email: "test@example.com",
Password: "password123",
Role: "contributor",
},
setupMock: func() {
s.service.EXPECT().
CreateUser(gomock.Any(), "testuser", "test@example.com", "password123", "contributor").
Return(&ent.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
}, nil)
s.service.EXPECT().
GetUserRoles(gomock.Any(), 1).
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
},
expectedStatus: http.StatusCreated,
registration: true,
},
{
name: "注册功能已禁用",
request: RegisterRequest{
Username: "testuser",
Email: "test@example.com",
Password: "password123",
Role: "contributor",
},
setupMock: func() {},
expectedStatus: http.StatusForbidden,
expectedError: "Registration is disabled",
registration: false,
},
{
name: "无效的邮箱格式",
request: RegisterRequest{
Username: "testuser",
Email: "invalid-email",
Password: "password123",
Role: "contributor",
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedError: "Key: 'RegisterRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag",
registration: true,
},
{
name: "密码太短",
request: RegisterRequest{
Username: "testuser",
Email: "test@example.com",
Password: "short",
Role: "contributor",
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedError: "Key: 'RegisterRequest.Password' Error:Field validation for 'Password' failed on the 'min' tag",
registration: true,
},
{
name: "无效的角色",
request: RegisterRequest{
Username: "testuser",
Email: "test@example.com",
Password: "password123",
Role: "invalid-role",
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedError: "Key: 'RegisterRequest.Role' Error:Field validation for 'Role' failed on the 'oneof' tag",
registration: true,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// 设置注册功能状态
s.handler.cfg.Auth.Registration.Enabled = tc.registration
// 设置 mock
tc.setupMock()
// 创建请求
reqBody, _ := json.Marshal(tc.request)
req, _ := http.NewRequest(http.MethodPost, "/register", bytes.NewBuffer(reqBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// 执行请求
s.handler.Register(c)
// 验证响应
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" {
var response ErrorResponse
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Contains(response.Error.Message, tc.expectedError)
} else {
var response AuthResponse
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.NotEmpty(response.Token)
}
})
}
}
func (s *AuthHandlerTestSuite) TestLogin() {
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
testCases := []struct {
name string
request LoginRequest
setupMock func()
expectedStatus int
expectedError string
}{
{
name: "成功登录",
request: LoginRequest{
Username: "testuser",
Password: "password123",
},
setupMock: func() {
s.service.EXPECT().
GetUserByUsername(gomock.Any(), "testuser").
Return(&ent.User{
ID: 1,
Username: "testuser",
PasswordHash: string(hashedPassword),
}, nil)
s.service.EXPECT().
GetUserRoles(gomock.Any(), 1).
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
},
expectedStatus: http.StatusOK,
},
{
name: "无效的用户名",
request: LoginRequest{
Username: "te",
Password: "password123",
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedError: "Key: 'LoginRequest.Username' Error:Field validation for 'Username' failed on the 'min' tag",
},
{
name: "用户不存在",
request: LoginRequest{
Username: "nonexistent",
Password: "password123",
},
setupMock: func() {
s.service.EXPECT().
GetUserByUsername(gomock.Any(), "nonexistent").
Return(nil, fmt.Errorf("user not found"))
},
expectedStatus: http.StatusUnauthorized,
expectedError: "Invalid username or password",
},
{
name: "密码错误",
request: LoginRequest{
Username: "testuser",
Password: "wrongpassword",
},
setupMock: func() {
s.service.EXPECT().
GetUserByUsername(gomock.Any(), "testuser").
Return(&ent.User{
ID: 1,
Username: "testuser",
PasswordHash: string(hashedPassword),
}, nil)
},
expectedStatus: http.StatusUnauthorized,
expectedError: "Invalid username or password",
},
{
name: "获取用户角色失败",
request: LoginRequest{
Username: "testuser",
Password: "password123",
},
setupMock: func() {
s.service.EXPECT().
GetUserByUsername(gomock.Any(), "testuser").
Return(&ent.User{
ID: 1,
Username: "testuser",
PasswordHash: string(hashedPassword),
}, nil)
s.service.EXPECT().
GetUserRoles(gomock.Any(), 1).
Return(nil, fmt.Errorf("failed to get roles"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to get user roles",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// 设置 mock
tc.setupMock()
// 创建请求
reqBody, _ := json.Marshal(tc.request)
req, _ := http.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(reqBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// 执行请求
s.handler.Login(c)
// 验证响应
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" {
var response ErrorResponse
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Contains(response.Error.Message, tc.expectedError)
} else {
var response AuthResponse
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.NotEmpty(response.Token)
}
})
}
}

View file

@ -1,481 +0,0 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"tss-rocks-be/ent"
"tss-rocks-be/ent/categorycontent"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/service"
"tss-rocks-be/internal/service/mock"
"tss-rocks-be/internal/types"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
"errors"
)
// Custom assertion function for comparing categories
func assertCategoryEqual(t assert.TestingT, expected, actual *ent.Category) bool {
if expected == nil && actual == nil {
return true
}
if expected == nil || actual == nil {
return assert.Fail(t, "One category is nil while the other is not")
}
// Compare only relevant fields, ignoring time fields
return assert.Equal(t, expected.ID, actual.ID) &&
assert.Equal(t, expected.Edges.Contents, actual.Edges.Contents)
}
// Custom assertion function for comparing category slices
func assertCategorySliceEqual(t assert.TestingT, expected, actual []*ent.Category) bool {
if len(expected) != len(actual) {
return assert.Fail(t, "Category slice lengths do not match")
}
for i := range expected {
if !assertCategoryEqual(t, expected[i], actual[i]) {
return false
}
}
return true
}
type CategoryHandlerTestSuite struct {
suite.Suite
ctrl *gomock.Controller
service *mock.MockService
handler *Handler
router *gin.Engine
}
func (s *CategoryHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl)
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
},
}
s.handler = NewHandler(cfg, s.service)
// Setup Gin router
gin.SetMode(gin.TestMode)
s.router = gin.New()
// Setup mock for GetTokenBlacklist
tokenBlacklist := &service.TokenBlacklist{}
s.service.EXPECT().
GetTokenBlacklist().
Return(tokenBlacklist).
AnyTimes()
s.handler.RegisterRoutes(s.router)
}
func (s *CategoryHandlerTestSuite) TearDownTest() {
s.ctrl.Finish()
}
func TestCategoryHandlerSuite(t *testing.T) {
suite.Run(t, new(CategoryHandlerTestSuite))
}
// Test cases for ListCategories
func (s *CategoryHandlerTestSuite) TestListCategories() {
testCases := []struct {
name string
langCode string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success with default language",
langCode: "",
setupMock: func() {
s.service.EXPECT().
ListCategories(gomock.Any(), gomock.Eq("en")).
Return([]*ent.Category{
{
ID: 1,
Edges: ent.CategoryEdges{
Contents: []*ent.CategoryContent{
{
LanguageCode: categorycontent.LanguageCode("en"),
Name: "Test Category",
Description: "Test Description",
Slug: "test-category",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Category{
{
ID: 1,
Edges: ent.CategoryEdges{
Contents: []*ent.CategoryContent{
{
LanguageCode: categorycontent.LanguageCode("en"),
Name: "Test Category",
Description: "Test Description",
Slug: "test-category",
},
},
},
},
},
},
{
name: "Success with specific language",
langCode: "zh",
setupMock: func() {
s.service.EXPECT().
ListCategories(gomock.Any(), gomock.Eq("zh")).
Return([]*ent.Category{
{
ID: 1,
Edges: ent.CategoryEdges{
Contents: []*ent.CategoryContent{
{
LanguageCode: categorycontent.LanguageCode("zh"),
Name: "测试分类",
Description: "测试描述",
Slug: "test-category",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Category{
{
ID: 1,
Edges: ent.CategoryEdges{
Contents: []*ent.CategoryContent{
{
LanguageCode: categorycontent.LanguageCode("zh"),
Name: "测试分类",
Description: "测试描述",
Slug: "test-category",
},
},
},
},
},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// Setup mock
tc.setupMock()
// Create request
url := "/api/v1/categories"
if tc.langCode != "" {
url += "?lang=" + tc.langCode
}
req := httptest.NewRequest(http.MethodGet, url, nil)
w := httptest.NewRecorder()
// Perform request
s.router.ServeHTTP(w, req)
// Assert response
assert.Equal(s.T(), tc.expectedStatus, w.Code)
if tc.expectedBody != nil {
var response []*ent.Category
err := json.Unmarshal(w.Body.Bytes(), &response)
assert.NoError(s.T(), err)
assertCategorySliceEqual(s.T(), tc.expectedBody.([]*ent.Category), response)
}
})
}
}
// Test cases for GetCategory
func (s *CategoryHandlerTestSuite) TestGetCategory() {
testCases := []struct {
name string
langCode string
slug string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
langCode: "en",
slug: "test-category",
setupMock: func() {
s.service.EXPECT().
GetCategoryBySlug(gomock.Any(), gomock.Eq("en"), gomock.Eq("test-category")).
Return(&ent.Category{
ID: 1,
Edges: ent.CategoryEdges{
Contents: []*ent.CategoryContent{
{
LanguageCode: categorycontent.LanguageCode("en"),
Name: "Test Category",
Description: "Test Description",
Slug: "test-category",
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: &ent.Category{
ID: 1,
Edges: ent.CategoryEdges{
Contents: []*ent.CategoryContent{
{
LanguageCode: categorycontent.LanguageCode("en"),
Name: "Test Category",
Description: "Test Description",
Slug: "test-category",
},
},
},
},
},
{
name: "Not Found",
langCode: "en",
slug: "non-existent",
setupMock: func() {
s.service.EXPECT().
GetCategoryBySlug(gomock.Any(), gomock.Eq("en"), gomock.Eq("non-existent")).
Return(nil, types.ErrNotFound)
},
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// Setup mock
tc.setupMock()
// Create request
url := "/api/v1/categories/" + tc.slug
if tc.langCode != "" {
url += "?lang=" + tc.langCode
}
req := httptest.NewRequest(http.MethodGet, url, nil)
w := httptest.NewRecorder()
// Perform request
s.router.ServeHTTP(w, req)
// Assert response
assert.Equal(s.T(), tc.expectedStatus, w.Code)
if tc.expectedBody != nil {
var response ent.Category
err := json.Unmarshal(w.Body.Bytes(), &response)
assert.NoError(s.T(), err)
assertCategoryEqual(s.T(), tc.expectedBody.(*ent.Category), &response)
}
})
}
}
// Test cases for AddCategoryContent
func (s *CategoryHandlerTestSuite) TestAddCategoryContent() {
var description = "Test Description"
testCases := []struct {
name string
categoryID string
requestBody interface{}
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
categoryID: "1",
requestBody: AddCategoryContentRequest{
LanguageCode: "en",
Name: "Test Category",
Description: &description,
Slug: "test-category",
},
setupMock: func() {
s.service.EXPECT().
AddCategoryContent(
gomock.Any(),
1,
"en",
"Test Category",
description,
"test-category",
).
Return(&ent.CategoryContent{
LanguageCode: categorycontent.LanguageCode("en"),
Name: "Test Category",
Description: description,
Slug: "test-category",
}, nil)
},
expectedStatus: http.StatusCreated,
expectedBody: &ent.CategoryContent{
LanguageCode: categorycontent.LanguageCode("en"),
Name: "Test Category",
Description: description,
Slug: "test-category",
},
},
{
name: "Invalid JSON",
categoryID: "1",
requestBody: "invalid json",
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
},
{
name: "Invalid Category ID",
categoryID: "invalid",
requestBody: AddCategoryContentRequest{
LanguageCode: "en",
Name: "Test Category",
Description: &description,
Slug: "test-category",
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
},
{
name: "Service Error",
categoryID: "1",
requestBody: AddCategoryContentRequest{
LanguageCode: "en",
Name: "Test Category",
Description: &description,
Slug: "test-category",
},
setupMock: func() {
s.service.EXPECT().
AddCategoryContent(
gomock.Any(),
1,
"en",
"Test Category",
description,
"test-category",
).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// Setup mock
tc.setupMock()
// Create request
var body []byte
var err error
if str, ok := tc.requestBody.(string); ok {
body = []byte(str)
} else {
body, err = json.Marshal(tc.requestBody)
s.NoError(err)
}
req := httptest.NewRequest(http.MethodPost, "/api/v1/categories/"+tc.categoryID+"/contents", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
// Perform request
s.router.ServeHTTP(w, req)
// Assert response
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedBody != nil {
var response ent.CategoryContent
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Equal(tc.expectedBody, &response)
}
})
}
}
// Test cases for CreateCategory
func (s *CategoryHandlerTestSuite) TestCreateCategory() {
testCases := []struct {
name string
setupMock func()
expectedStatus int
expectedError string
}{
{
name: "成功创建分类",
setupMock: func() {
category := &ent.Category{
ID: 1,
}
s.service.EXPECT().
CreateCategory(gomock.Any()).
Return(category, nil)
},
expectedStatus: http.StatusCreated,
},
{
name: "创建分类失败",
setupMock: func() {
s.service.EXPECT().
CreateCategory(gomock.Any()).
Return(nil, errors.New("failed to create category"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to create category",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// 设置 mock
tc.setupMock()
// 创建请求
req, _ := http.NewRequest(http.MethodPost, "/categories", nil)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// 执行请求
s.handler.CreateCategory(c)
// 验证响应
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" {
var response map[string]string
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Equal(tc.expectedError, response["error"])
} else {
var response *ent.Category
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.NotNil(response)
s.Equal(1, response.ID)
}
})
}
}

View file

@ -1,456 +0,0 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"tss-rocks-be/ent"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/service"
"tss-rocks-be/internal/service/mock"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
"errors"
)
type ContributorHandlerTestSuite struct {
suite.Suite
ctrl *gomock.Controller
service *mock.MockService
handler *Handler
router *gin.Engine
}
func (s *ContributorHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl)
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
},
}
s.handler = NewHandler(cfg, s.service)
// Setup Gin router
gin.SetMode(gin.TestMode)
s.router = gin.New()
// Setup mock for GetTokenBlacklist
tokenBlacklist := &service.TokenBlacklist{}
s.service.EXPECT().
GetTokenBlacklist().
Return(tokenBlacklist).
AnyTimes()
s.handler.RegisterRoutes(s.router)
}
func (s *ContributorHandlerTestSuite) TearDownTest() {
s.ctrl.Finish()
}
func TestContributorHandlerSuite(t *testing.T) {
suite.Run(t, new(ContributorHandlerTestSuite))
}
func (s *ContributorHandlerTestSuite) TestListContributors() {
testCases := []struct {
name string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
setupMock: func() {
s.service.EXPECT().
ListContributors(gomock.Any()).
Return([]*ent.Contributor{
{
ID: 1,
Name: "John Doe",
Edges: ent.ContributorEdges{
SocialLinks: []*ent.ContributorSocialLink{
{
Type: "github",
Value: "https://github.com/johndoe",
Edges: ent.ContributorSocialLinkEdges{},
},
},
},
CreatedAt: time.Time{},
UpdatedAt: time.Time{},
},
{
ID: 2,
Name: "Jane Smith",
Edges: ent.ContributorEdges{
SocialLinks: []*ent.ContributorSocialLink{}, // Ensure empty SocialLinks array is present
},
CreatedAt: time.Time{},
UpdatedAt: time.Time{},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []gin.H{
{
"id": 1,
"name": "John Doe",
"created_at": time.Time{},
"updated_at": time.Time{},
"edges": gin.H{
"social_links": []gin.H{
{
"type": "github",
"value": "https://github.com/johndoe",
"edges": gin.H{},
},
},
},
},
{
"id": 2,
"name": "Jane Smith",
"created_at": time.Time{},
"updated_at": time.Time{},
"edges": gin.H{
"social_links": []gin.H{}, // Ensure empty SocialLinks array is present
},
},
},
},
{
name: "Service error",
setupMock: func() {
s.service.EXPECT().
ListContributors(gomock.Any()).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to list contributors"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
req := httptest.NewRequest(http.MethodGet, "/api/v1/contributors", nil)
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
func (s *ContributorHandlerTestSuite) TestGetContributor() {
testCases := []struct {
name string
id string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
id: "1",
setupMock: func() {
s.service.EXPECT().
GetContributorByID(gomock.Any(), 1).
Return(&ent.Contributor{
ID: 1,
Name: "John Doe",
Edges: ent.ContributorEdges{
SocialLinks: []*ent.ContributorSocialLink{
{
Type: "github",
Value: "https://github.com/johndoe",
Edges: ent.ContributorSocialLinkEdges{},
},
},
},
CreatedAt: time.Time{},
UpdatedAt: time.Time{},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: gin.H{
"id": 1,
"name": "John Doe",
"created_at": time.Time{},
"updated_at": time.Time{},
"edges": gin.H{
"social_links": []gin.H{
{
"type": "github",
"value": "https://github.com/johndoe",
"edges": gin.H{},
},
},
},
},
},
{
name: "Invalid ID",
id: "invalid",
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Invalid contributor ID"},
},
{
name: "Service error",
id: "1",
setupMock: func() {
s.service.EXPECT().
GetContributorByID(gomock.Any(), 1).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to get contributor"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
req := httptest.NewRequest(http.MethodGet, "/api/v1/contributors/"+tc.id, nil)
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
func (s *ContributorHandlerTestSuite) TestCreateContributor() {
testCases := []struct {
name string
body interface{}
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
body: CreateContributorRequest{
Name: "John Doe",
},
setupMock: func() {
name := "John Doe"
s.service.EXPECT().
CreateContributor(
gomock.Any(),
name,
nil,
nil,
).
Return(&ent.Contributor{
ID: 1,
Name: name,
CreatedAt: time.Time{},
UpdatedAt: time.Time{},
}, nil)
},
expectedStatus: http.StatusCreated,
expectedBody: gin.H{
"id": 1,
"name": "John Doe",
"created_at": time.Time{},
"updated_at": time.Time{},
"edges": gin.H{},
},
},
{
name: "Invalid request body",
body: map[string]interface{}{
"name": "", // Empty name is not allowed
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Key: 'CreateContributorRequest.Name' Error:Field validation for 'Name' failed on the 'required' tag"},
},
{
name: "Service error",
body: CreateContributorRequest{
Name: "John Doe",
},
setupMock: func() {
name := "John Doe"
s.service.EXPECT().
CreateContributor(
gomock.Any(),
name,
nil,
nil,
).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to create contributor"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
body, err := json.Marshal(tc.body)
s.NoError(err, "Failed to marshal request body")
req := httptest.NewRequest(http.MethodPost, "/api/v1/contributors", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
func (s *ContributorHandlerTestSuite) TestAddContributorSocialLink() {
testCases := []struct {
name string
id string
body interface{}
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
id: "1",
body: func() AddContributorSocialLinkRequest {
name := "johndoe"
return AddContributorSocialLinkRequest{
Type: "github",
Name: &name,
Value: "https://github.com/johndoe",
}
}(),
setupMock: func() {
name := "johndoe"
s.service.EXPECT().
AddContributorSocialLink(
gomock.Any(),
1,
"github",
name,
"https://github.com/johndoe",
).
Return(&ent.ContributorSocialLink{
Type: "github",
Name: name,
Value: "https://github.com/johndoe",
Edges: ent.ContributorSocialLinkEdges{},
}, nil)
},
expectedStatus: http.StatusCreated,
expectedBody: gin.H{
"type": "github",
"name": "johndoe",
"value": "https://github.com/johndoe",
"edges": gin.H{},
},
},
{
name: "Invalid contributor ID",
id: "invalid",
body: func() AddContributorSocialLinkRequest {
name := "johndoe"
return AddContributorSocialLinkRequest{
Type: "github",
Name: &name,
Value: "https://github.com/johndoe",
}
}(),
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Invalid contributor ID"},
},
{
name: "Invalid request body",
id: "1",
body: map[string]interface{}{
"type": "", // Empty type is not allowed
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Key: 'AddContributorSocialLinkRequest.Type' Error:Field validation for 'Type' failed on the 'required' tag\nKey: 'AddContributorSocialLinkRequest.Value' Error:Field validation for 'Value' failed on the 'required' tag"},
},
{
name: "Service error",
id: "1",
body: func() AddContributorSocialLinkRequest {
name := "johndoe"
return AddContributorSocialLinkRequest{
Type: "github",
Name: &name,
Value: "https://github.com/johndoe",
}
}(),
setupMock: func() {
name := "johndoe"
s.service.EXPECT().
AddContributorSocialLink(
gomock.Any(),
1,
"github",
name,
"https://github.com/johndoe",
).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to add contributor social link"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
body, err := json.Marshal(tc.body)
s.NoError(err, "Failed to marshal request body")
req := httptest.NewRequest(http.MethodPost, "/api/v1/contributors/"+tc.id+"/social-links", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}

View file

@ -1,532 +0,0 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"tss-rocks-be/ent"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/service"
"tss-rocks-be/internal/service/mock"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
"errors"
"strings"
)
type DailyHandlerTestSuite struct {
suite.Suite
ctrl *gomock.Controller
service *mock.MockService
handler *Handler
router *gin.Engine
}
func (s *DailyHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl)
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
},
}
s.handler = NewHandler(cfg, s.service)
// Setup Gin router
gin.SetMode(gin.TestMode)
s.router = gin.New()
// Setup mock for GetTokenBlacklist
tokenBlacklist := &service.TokenBlacklist{}
s.service.EXPECT().
GetTokenBlacklist().
Return(tokenBlacklist).
AnyTimes()
s.handler.RegisterRoutes(s.router)
}
func (s *DailyHandlerTestSuite) TearDownTest() {
s.ctrl.Finish()
}
func TestDailyHandlerSuite(t *testing.T) {
suite.Run(t, new(DailyHandlerTestSuite))
}
func (s *DailyHandlerTestSuite) TestListDailies() {
testCases := []struct {
name string
langCode string
categoryID string
limit string
offset string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success with default language",
langCode: "",
setupMock: func() {
s.service.EXPECT().
ListDailies(gomock.Any(), "en", nil, 10, 0).
Return([]*ent.Daily{
{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 1",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Daily{
{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 1",
},
},
},
},
},
},
{
name: "Success with specific language",
langCode: "zh",
setupMock: func() {
s.service.EXPECT().
ListDailies(gomock.Any(), "zh", nil, 10, 0).
Return([]*ent.Daily{
{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "zh",
Quote: "测试语录1",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Daily{
{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "zh",
Quote: "测试语录1",
},
},
},
},
},
},
{
name: "Success with category filter",
categoryID: "1",
setupMock: func() {
categoryID := 1
s.service.EXPECT().
ListDailies(gomock.Any(), "en", &categoryID, 10, 0).
Return([]*ent.Daily{
{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 1",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Daily{
{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 1",
},
},
},
},
},
},
{
name: "Success with pagination",
limit: "2",
offset: "1",
setupMock: func() {
s.service.EXPECT().
ListDailies(gomock.Any(), "en", nil, 2, 1).
Return([]*ent.Daily{
{
ID: "daily2",
ImageURL: "https://example.com/image2.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 2",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Daily{
{
ID: "daily2",
ImageURL: "https://example.com/image2.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 2",
},
},
},
},
},
},
{
name: "Service Error",
setupMock: func() {
s.service.EXPECT().
ListDailies(gomock.Any(), "en", nil, 10, 0).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to list dailies"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
url := "/api/v1/dailies"
if tc.langCode != "" {
url += "?lang=" + tc.langCode
}
if tc.categoryID != "" {
if strings.Contains(url, "?") {
url += "&"
} else {
url += "?"
}
url += "category_id=" + tc.categoryID
}
if tc.limit != "" {
if strings.Contains(url, "?") {
url += "&"
} else {
url += "?"
}
url += "limit=" + tc.limit
}
if tc.offset != "" {
if strings.Contains(url, "?") {
url += "&"
} else {
url += "?"
}
url += "offset=" + tc.offset
}
req := httptest.NewRequest(http.MethodGet, url, nil)
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
func (s *DailyHandlerTestSuite) TestGetDaily() {
testCases := []struct {
name string
id string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
id: "daily1",
setupMock: func() {
s.service.EXPECT().
GetDailyByID(gomock.Any(), "daily1").
Return(&ent.Daily{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 1",
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: &ent.Daily{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{
{
LanguageCode: "en",
Quote: "Test Quote 1",
},
},
},
},
},
{
name: "Service error",
id: "daily1",
setupMock: func() {
s.service.EXPECT().
GetDailyByID(gomock.Any(), "daily1").
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to get daily"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
req := httptest.NewRequest(http.MethodGet, "/api/v1/dailies/"+tc.id, nil)
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
func (s *DailyHandlerTestSuite) TestCreateDaily() {
testCases := []struct {
name string
body interface{}
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
body: CreateDailyRequest{
ID: "daily1",
CategoryID: 1,
ImageURL: "https://example.com/image1.jpg",
},
setupMock: func() {
s.service.EXPECT().
CreateDaily(gomock.Any(), "daily1", 1, "https://example.com/image1.jpg").
Return(&ent.Daily{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{},
},
}, nil)
},
expectedStatus: http.StatusCreated,
expectedBody: &ent.Daily{
ID: "daily1",
ImageURL: "https://example.com/image1.jpg",
Edges: ent.DailyEdges{
Category: &ent.Category{ID: 1},
Contents: []*ent.DailyContent{},
},
},
},
{
name: "Invalid request body",
body: map[string]interface{}{
"id": "daily1",
// Missing required fields
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Key: 'CreateDailyRequest.CategoryID' Error:Field validation for 'CategoryID' failed on the 'required' tag\nKey: 'CreateDailyRequest.ImageURL' Error:Field validation for 'ImageURL' failed on the 'required' tag"},
},
{
name: "Service error",
body: CreateDailyRequest{
ID: "daily1",
CategoryID: 1,
ImageURL: "https://example.com/image1.jpg",
},
setupMock: func() {
s.service.EXPECT().
CreateDaily(gomock.Any(), "daily1", 1, "https://example.com/image1.jpg").
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to create daily"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
body, err := json.Marshal(tc.body)
s.NoError(err, "Failed to marshal request body")
req := httptest.NewRequest(http.MethodPost, "/api/v1/dailies", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
func (s *DailyHandlerTestSuite) TestAddDailyContent() {
testCases := []struct {
name string
dailyID string
body interface{}
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
dailyID: "daily1",
body: AddDailyContentRequest{
LanguageCode: "en",
Quote: "Test Quote 1",
},
setupMock: func() {
s.service.EXPECT().
AddDailyContent(gomock.Any(), "daily1", "en", "Test Quote 1").
Return(&ent.DailyContent{
LanguageCode: "en",
Quote: "Test Quote 1",
}, nil)
},
expectedStatus: http.StatusCreated,
expectedBody: &ent.DailyContent{
LanguageCode: "en",
Quote: "Test Quote 1",
},
},
{
name: "Invalid request body",
dailyID: "daily1",
body: map[string]interface{}{
"language_code": "en",
// Missing required fields
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Key: 'AddDailyContentRequest.Quote' Error:Field validation for 'Quote' failed on the 'required' tag"},
},
{
name: "Service error",
dailyID: "daily1",
body: AddDailyContentRequest{
LanguageCode: "en",
Quote: "Test Quote 1",
},
setupMock: func() {
s.service.EXPECT().
AddDailyContent(gomock.Any(), "daily1", "en", "Test Quote 1").
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to add daily content"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
body, err := json.Marshal(tc.body)
s.NoError(err, "Failed to marshal request body")
req := httptest.NewRequest(http.MethodPost, "/api/v1/dailies/"+tc.dailyID+"/contents", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}

View file

@ -1,43 +0,0 @@
package handler
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestStringPtr(t *testing.T) {
testCases := []struct {
name string
input *string
expected string
}{
{
name: "nil pointer",
input: nil,
expected: "",
},
{
name: "empty string",
input: strPtr(""),
expected: "",
},
{
name: "non-empty string",
input: strPtr("test"),
expected: "test",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := stringPtr(tc.input)
assert.Equal(t, tc.expected, result)
})
}
}
// Helper function to create string pointer
func strPtr(s string) *string {
return &s
}

View file

@ -1,524 +0,0 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"strings"
"testing"
"tss-rocks-be/ent"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/service/mock"
"tss-rocks-be/internal/storage"
"net/textproto"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
)
type MediaHandlerTestSuite struct {
suite.Suite
ctrl *gomock.Controller
service *mock.MockService
handler *Handler
router *gin.Engine
}
func (s *MediaHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl)
s.handler = NewHandler(&config.Config{}, s.service)
s.router = gin.New()
}
func (s *MediaHandlerTestSuite) TearDownTest() {
s.ctrl.Finish()
}
func TestMediaHandlerSuite(t *testing.T) {
suite.Run(t, new(MediaHandlerTestSuite))
}
func (s *MediaHandlerTestSuite) TestListMedia() {
testCases := []struct {
name string
query string
setupMock func()
expectedStatus int
expectedError string
}{
{
name: "成功列出媒体",
query: "?limit=10&offset=0",
setupMock: func() {
s.service.EXPECT().
ListMedia(gomock.Any(), 10, 0).
Return([]*ent.Media{{ID: 1}}, nil)
},
expectedStatus: http.StatusOK,
},
{
name: "使用默认限制和偏移",
query: "",
setupMock: func() {
s.service.EXPECT().
ListMedia(gomock.Any(), 10, 0).
Return([]*ent.Media{{ID: 1}}, nil)
},
expectedStatus: http.StatusOK,
},
{
name: "列出媒体失败",
query: "",
setupMock: func() {
s.service.EXPECT().
ListMedia(gomock.Any(), 10, 0).
Return(nil, errors.New("failed to list media"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to list media",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// 设置 mock
tc.setupMock()
// 创建请求
req, _ := http.NewRequest(http.MethodGet, "/media"+tc.query, nil)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// 执行请求
s.handler.ListMedia(c)
// 验证响应
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" {
var response map[string]string
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Equal(tc.expectedError, response["error"])
} else {
var response []*ent.Media
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.NotEmpty(response)
}
})
}
}
func (s *MediaHandlerTestSuite) TestUploadMedia() {
testCases := []struct {
name string
setupRequest func() (*http.Request, error)
setupMock func()
expectedStatus int
expectedError string
}{
{
name: "成功上传媒体",
setupRequest: func() (*http.Request, error) {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
// 创建文件部分
fileHeader := make(textproto.MIMEHeader)
fileHeader.Set("Content-Type", "image/jpeg")
fileHeader.Set("Content-Disposition", `form-data; name="file"; filename="test.jpg"`)
part, err := writer.CreatePart(fileHeader)
if err != nil {
return nil, err
}
testContent := "test content"
_, err = io.Copy(part, strings.NewReader(testContent))
if err != nil {
return nil, err
}
writer.Close()
req := httptest.NewRequest(http.MethodPost, "/media", body)
req.Header.Set("Content-Type", writer.FormDataContentType())
return req, nil
},
setupMock: func() {
expectedFile := &multipart.FileHeader{
Filename: "test.jpg",
Size: int64(len("test content")),
Header: textproto.MIMEHeader{
"Content-Type": []string{"image/jpeg"},
},
}
s.service.EXPECT().
Upload(gomock.Any(), gomock.Any(), 1).
DoAndReturn(func(_ context.Context, f *multipart.FileHeader, uid int) (*ent.Media, error) {
s.Equal(expectedFile.Filename, f.Filename)
s.Equal(expectedFile.Size, f.Size)
s.Equal(expectedFile.Header.Get("Content-Type"), f.Header.Get("Content-Type"))
return &ent.Media{ID: 1}, nil
})
},
expectedStatus: http.StatusCreated,
},
{
name: "未授权",
setupRequest: func() (*http.Request, error) {
req := httptest.NewRequest(http.MethodPost, "/media", nil)
return req, nil
},
setupMock: func() {},
expectedStatus: http.StatusUnauthorized,
expectedError: "Unauthorized",
},
{
name: "上传失败",
setupRequest: func() (*http.Request, error) {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
// 创建文件部分
fileHeader := make(textproto.MIMEHeader)
fileHeader.Set("Content-Type", "image/jpeg")
fileHeader.Set("Content-Disposition", `form-data; name="file"; filename="test.jpg"`)
part, err := writer.CreatePart(fileHeader)
if err != nil {
return nil, err
}
testContent := "test content"
_, err = io.Copy(part, strings.NewReader(testContent))
if err != nil {
return nil, err
}
writer.Close()
req := httptest.NewRequest(http.MethodPost, "/media", body)
req.Header.Set("Content-Type", writer.FormDataContentType())
return req, nil
},
setupMock: func() {
s.service.EXPECT().
Upload(gomock.Any(), gomock.Any(), 1).
Return(nil, errors.New("failed to upload"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to upload media",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// 设置 mock
tc.setupMock()
// 创建请求
req, err := tc.setupRequest()
s.NoError(err)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// 设置用户ID除了未授权的测试用例
if tc.expectedError != "Unauthorized" {
c.Set("user_id", 1)
}
// 执行请求
s.handler.UploadMedia(c)
// 验证响应
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" {
var response map[string]string
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Equal(tc.expectedError, response["error"])
} else {
var response *ent.Media
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.NotNil(response)
}
})
}
}
func (s *MediaHandlerTestSuite) TestGetMedia() {
testCases := []struct {
name string
mediaID string
setupMock func()
expectedStatus int
expectedError string
}{
{
name: "成功获取媒体",
mediaID: "1",
setupMock: func() {
media := &ent.Media{
ID: 1,
MimeType: "image/jpeg",
OriginalName: "test.jpg",
}
s.service.EXPECT().
GetMedia(gomock.Any(), 1).
Return(media, nil)
s.service.EXPECT().
GetFile(gomock.Any(), 1).
Return(io.NopCloser(strings.NewReader("test content")), &storage.FileInfo{
Size: 11,
Name: "test.jpg",
ContentType: "image/jpeg",
}, nil)
},
expectedStatus: http.StatusOK,
},
{
name: "无效的媒体ID",
mediaID: "invalid",
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid media ID",
},
{
name: "获取媒体元数据失败",
mediaID: "1",
setupMock: func() {
s.service.EXPECT().
GetMedia(gomock.Any(), 1).
Return(nil, errors.New("failed to get media"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to get media",
},
{
name: "获取媒体文件失败",
mediaID: "1",
setupMock: func() {
media := &ent.Media{
ID: 1,
MimeType: "image/jpeg",
OriginalName: "test.jpg",
}
s.service.EXPECT().
GetMedia(gomock.Any(), 1).
Return(media, nil)
s.service.EXPECT().
GetFile(gomock.Any(), 1).
Return(nil, nil, errors.New("failed to get file"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to get media file",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// 设置 mock
tc.setupMock()
// 创建请求
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/media/%s", tc.mediaID), nil)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// Extract ID from URL path
parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/")
if len(parts) >= 2 {
c.Params = []gin.Param{{Key: "id", Value: parts[1]}}
}
// 执行请求
s.handler.GetMedia(c)
// 验证响应
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" {
var response map[string]string
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Equal(tc.expectedError, response["error"])
} else {
s.Equal("image/jpeg", w.Header().Get("Content-Type"))
s.Equal("11", w.Header().Get("Content-Length"))
s.Equal("inline; filename=test.jpg", w.Header().Get("Content-Disposition"))
s.Equal("test content", w.Body.String())
}
})
}
}
func (s *MediaHandlerTestSuite) TestGetMediaFile() {
testCases := []struct {
name string
setupRequest func() (*http.Request, error)
setupMock func()
expectedStatus int
expectedBody []byte
}{
{
name: "成功获取媒体文件",
setupRequest: func() (*http.Request, error) {
return httptest.NewRequest(http.MethodGet, "/media/1/file", nil), nil
},
setupMock: func() {
fileContent := "test file content"
s.service.EXPECT().
GetFile(gomock.Any(), 1).
Return(io.NopCloser(strings.NewReader(fileContent)), &storage.FileInfo{
Name: "test.jpg",
Size: int64(len(fileContent)),
ContentType: "image/jpeg",
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []byte("test file content"),
},
{
name: "无效的媒体ID",
setupRequest: func() (*http.Request, error) {
return httptest.NewRequest(http.MethodGet, "/media/invalid/file", nil), nil
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
},
{
name: "获取媒体文件失败",
setupRequest: func() (*http.Request, error) {
return httptest.NewRequest(http.MethodGet, "/media/1/file", nil), nil
},
setupMock: func() {
s.service.EXPECT().
GetFile(gomock.Any(), 1).
Return(nil, nil, errors.New("failed to get file"))
},
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// Setup
req, err := tc.setupRequest()
s.Require().NoError(err)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// Extract ID from URL path
parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/")
if len(parts) >= 2 {
c.Params = []gin.Param{{Key: "id", Value: parts[1]}}
}
// Setup mock
tc.setupMock()
// Test
s.handler.GetMediaFile(c)
// Verify
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedBody != nil {
s.Equal(tc.expectedBody, w.Body.Bytes())
s.Equal("image/jpeg", w.Header().Get("Content-Type"))
s.Equal(fmt.Sprintf("%d", len(tc.expectedBody)), w.Header().Get("Content-Length"))
s.Equal("inline; filename=test.jpg", w.Header().Get("Content-Disposition"))
}
})
}
}
func (s *MediaHandlerTestSuite) TestDeleteMedia() {
testCases := []struct {
name string
mediaID string
setupMock func()
expectedStatus int
expectedError string
}{
{
name: "成功删除媒体",
mediaID: "1",
setupMock: func() {
s.service.EXPECT().
DeleteMedia(gomock.Any(), 1, 1).
Return(nil)
},
expectedStatus: http.StatusNoContent,
},
{
name: "未授权",
mediaID: "1",
setupMock: func() {},
expectedStatus: http.StatusUnauthorized,
expectedError: "Unauthorized",
},
{
name: "无效的媒体ID",
mediaID: "invalid",
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid media ID",
},
{
name: "删除媒体失败",
mediaID: "1",
setupMock: func() {
s.service.EXPECT().
DeleteMedia(gomock.Any(), 1, 1).
Return(errors.New("failed to delete"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to delete media",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// 设置 mock
tc.setupMock()
// 创建请求
req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/media/%s", tc.mediaID), nil)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// Extract ID from URL path
parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/")
if len(parts) >= 2 {
c.Params = []gin.Param{{Key: "id", Value: parts[1]}}
}
// 设置用户ID除了未授权的测试用例
if tc.expectedError != "Unauthorized" {
c.Set("user_id", 1)
}
// 执行请求
s.handler.DeleteMedia(c)
// 验证响应
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" {
var response map[string]string
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Equal(tc.expectedError, response["error"])
}
})
}
}

View file

@ -1,624 +0,0 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"tss-rocks-be/ent"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/service"
"tss-rocks-be/internal/service/mock"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
"errors"
"strings"
)
type PostHandlerTestSuite struct {
suite.Suite
ctrl *gomock.Controller
service *mock.MockService
handler *Handler
router *gin.Engine
}
func (s *PostHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl)
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
},
}
s.handler = NewHandler(cfg, s.service)
// Setup Gin router
gin.SetMode(gin.TestMode)
s.router = gin.New()
// Setup mock for GetTokenBlacklist
tokenBlacklist := &service.TokenBlacklist{}
s.service.EXPECT().
GetTokenBlacklist().
Return(tokenBlacklist).
AnyTimes()
s.handler.RegisterRoutes(s.router)
}
func (s *PostHandlerTestSuite) TearDownTest() {
s.ctrl.Finish()
}
func TestPostHandlerSuite(t *testing.T) {
suite.Run(t, new(PostHandlerTestSuite))
}
// Test cases for ListPosts
func (s *PostHandlerTestSuite) TestListPosts() {
categoryID := 1
testCases := []struct {
name string
langCode string
categoryID string
limit string
offset string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success with default language",
langCode: "",
setupMock: func() {
s.service.EXPECT().
ListPosts(gomock.Any(), "en", nil, 10, 0).
Return([]*ent.Post{
{
ID: 1,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Post{
{
ID: 1,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
},
},
},
},
},
},
{
name: "Success with specific language",
langCode: "zh",
setupMock: func() {
s.service.EXPECT().
ListPosts(gomock.Any(), "zh", nil, 10, 0).
Return([]*ent.Post{
{
ID: 1,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "zh",
Title: "测试帖子",
ContentMarkdown: "测试内容",
Summary: "测试摘要",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Post{
{
ID: 1,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "zh",
Title: "测试帖子",
ContentMarkdown: "测试内容",
Summary: "测试摘要",
},
},
},
},
},
},
{
name: "Success with category filter",
langCode: "en",
categoryID: "1",
setupMock: func() {
s.service.EXPECT().
ListPosts(gomock.Any(), "en", &categoryID, 10, 0).
Return([]*ent.Post{
{
ID: 1,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Post{
{
ID: 1,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
},
},
},
},
},
},
{
name: "Success with pagination",
langCode: "en",
limit: "2",
offset: "1",
setupMock: func() {
s.service.EXPECT().
ListPosts(gomock.Any(), "en", nil, 2, 1).
Return([]*ent.Post{
{
ID: 2,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "en",
Title: "Test Post 2",
ContentMarkdown: "Test Content 2",
Summary: "Test Summary 2",
},
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: []*ent.Post{
{
ID: 2,
Status: "published",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "en",
Title: "Test Post 2",
ContentMarkdown: "Test Content 2",
Summary: "Test Summary 2",
},
},
},
},
},
},
{
name: "Service Error",
langCode: "en",
setupMock: func() {
s.service.EXPECT().
ListPosts(gomock.Any(), "en", nil, 10, 0).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// Setup mock
tc.setupMock()
// Create request
url := "/api/v1/posts"
if tc.langCode != "" {
url += "?lang=" + tc.langCode
}
if tc.categoryID != "" {
if strings.Contains(url, "?") {
url += "&"
} else {
url += "?"
}
url += "category_id=" + tc.categoryID
}
if tc.limit != "" {
if strings.Contains(url, "?") {
url += "&"
} else {
url += "?"
}
url += "limit=" + tc.limit
}
if tc.offset != "" {
if strings.Contains(url, "?") {
url += "&"
} else {
url += "?"
}
url += "offset=" + tc.offset
}
req := httptest.NewRequest(http.MethodGet, url, nil)
w := httptest.NewRecorder()
// Perform request
s.router.ServeHTTP(w, req)
// Assert response
s.Equal(tc.expectedStatus, w.Code)
if tc.expectedBody != nil {
var response []*ent.Post
err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err)
s.Equal(tc.expectedBody, response)
}
})
}
}
// Test cases for GetPost
func (s *PostHandlerTestSuite) TestGetPost() {
testCases := []struct {
name string
langCode string
slug string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success with default language",
langCode: "",
slug: "test-post",
setupMock: func() {
s.service.EXPECT().
GetPostBySlug(gomock.Any(), "en", "test-post").
Return(&ent.Post{
ID: 1,
Status: "published",
Slug: "test-post",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: gin.H{
"id": 1,
"status": "published",
"slug": "test-post",
"edges": gin.H{
"contents": []gin.H{
{
"language_code": "en",
"title": "Test Post",
"content_markdown": "Test Content",
"summary": "Test Summary",
},
},
},
},
},
{
name: "Success with specific language",
langCode: "zh",
slug: "test-post",
setupMock: func() {
s.service.EXPECT().
GetPostBySlug(gomock.Any(), "zh", "test-post").
Return(&ent.Post{
ID: 1,
Status: "published",
Slug: "test-post",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{
{
LanguageCode: "zh",
Title: "测试帖子",
ContentMarkdown: "测试内容",
Summary: "测试摘要",
},
},
},
}, nil)
},
expectedStatus: http.StatusOK,
expectedBody: gin.H{
"id": 1,
"status": "published",
"slug": "test-post",
"edges": gin.H{
"contents": []gin.H{
{
"language_code": "zh",
"title": "测试帖子",
"content_markdown": "测试内容",
"summary": "测试摘要",
},
},
},
},
},
{
name: "Service error",
slug: "test-post",
setupMock: func() {
s.service.EXPECT().
GetPostBySlug(gomock.Any(), "en", "test-post").
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to get post"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
url := "/api/v1/posts/" + tc.slug
if tc.langCode != "" {
url += "?lang=" + tc.langCode
}
req := httptest.NewRequest(http.MethodGet, url, nil)
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
// Test cases for CreatePost
func (s *PostHandlerTestSuite) TestCreatePost() {
testCases := []struct {
name string
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
setupMock: func() {
s.service.EXPECT().
CreatePost(gomock.Any(), "draft").
Return(&ent.Post{
ID: 1,
Status: "draft",
Edges: ent.PostEdges{
Contents: []*ent.PostContent{},
},
}, nil)
},
expectedStatus: http.StatusCreated,
expectedBody: gin.H{
"id": 1,
"status": "draft",
"edges": gin.H{
"contents": []gin.H{},
},
},
},
{
name: "Service error",
setupMock: func() {
s.service.EXPECT().
CreatePost(gomock.Any(), "draft").
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to create post"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
req := httptest.NewRequest(http.MethodPost, "/api/v1/posts", nil)
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}
// Test cases for AddPostContent
func (s *PostHandlerTestSuite) TestAddPostContent() {
testCases := []struct {
name string
postID string
body interface{}
setupMock func()
expectedStatus int
expectedBody interface{}
}{
{
name: "Success",
postID: "1",
body: AddPostContentRequest{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
MetaKeywords: "test,keywords",
MetaDescription: "Test meta description",
},
setupMock: func() {
s.service.EXPECT().
AddPostContent(
gomock.Any(),
1,
"en",
"Test Post",
"Test Content",
"Test Summary",
"test,keywords",
"Test meta description",
).
Return(&ent.PostContent{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
MetaKeywords: "test,keywords",
MetaDescription: "Test meta description",
Edges: ent.PostContentEdges{},
}, nil)
},
expectedStatus: http.StatusCreated,
expectedBody: gin.H{
"language_code": "en",
"title": "Test Post",
"content_markdown": "Test Content",
"summary": "Test Summary",
"meta_keywords": "test,keywords",
"meta_description": "Test meta description",
"edges": gin.H{},
},
},
{
name: "Invalid post ID",
postID: "invalid",
body: AddPostContentRequest{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Invalid post ID"},
},
{
name: "Invalid request body",
postID: "1",
body: map[string]interface{}{
"language_code": "en",
// Missing required fields
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedBody: gin.H{"error": "Key: 'AddPostContentRequest.Title' Error:Field validation for 'Title' failed on the 'required' tag\nKey: 'AddPostContentRequest.ContentMarkdown' Error:Field validation for 'ContentMarkdown' failed on the 'required' tag\nKey: 'AddPostContentRequest.Summary' Error:Field validation for 'Summary' failed on the 'required' tag"},
},
{
name: "Service error",
postID: "1",
body: AddPostContentRequest{
LanguageCode: "en",
Title: "Test Post",
ContentMarkdown: "Test Content",
Summary: "Test Summary",
MetaKeywords: "test,keywords",
MetaDescription: "Test meta description",
},
setupMock: func() {
s.service.EXPECT().
AddPostContent(
gomock.Any(),
1,
"en",
"Test Post",
"Test Content",
"Test Summary",
"test,keywords",
"Test meta description",
).
Return(nil, errors.New("service error"))
},
expectedStatus: http.StatusInternalServerError,
expectedBody: gin.H{"error": "Failed to add post content"},
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
tc.setupMock()
body, err := json.Marshal(tc.body)
s.NoError(err, "Failed to marshal request body")
req := httptest.NewRequest(http.MethodPost, "/api/v1/posts/"+tc.postID+"/contents", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
s.router.ServeHTTP(w, req)
s.Equal(tc.expectedStatus, w.Code, "HTTP status code mismatch")
if tc.expectedBody != nil {
expectedJSON, err := json.Marshal(tc.expectedBody)
s.NoError(err, "Failed to marshal expected body")
s.JSONEq(string(expectedJSON), w.Body.String(), "Response body mismatch")
}
})
}
}

View file

@ -1,238 +0,0 @@
package middleware
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"tss-rocks-be/internal/types"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestAccessLog(t *testing.T) {
// 设置测试临时目录
tmpDir := t.TempDir()
logPath := filepath.Join(tmpDir, "test.log")
testCases := []struct {
name string
config *types.AccessLogConfig
expectedError bool
setupRequest func(*http.Request)
validateOutput func(*testing.T, *httptest.ResponseRecorder, string)
}{
{
name: "Console logging only",
config: &types.AccessLogConfig{
EnableConsole: true,
EnableFile: false,
Format: "json",
Level: "info",
},
expectedError: false,
setupRequest: func(req *http.Request) {
req.Header.Set("User-Agent", "test-agent")
},
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, logOutput, "GET /test")
assert.Contains(t, logOutput, "test-agent")
},
},
{
name: "File logging only",
config: &types.AccessLogConfig{
EnableConsole: false,
EnableFile: true,
FilePath: logPath,
Format: "json",
Level: "info",
Rotation: struct {
MaxSize int `yaml:"max_size"`
MaxAge int `yaml:"max_age"`
MaxBackups int `yaml:"max_backups"`
Compress bool `yaml:"compress"`
LocalTime bool `yaml:"local_time"`
}{
MaxSize: 1,
MaxAge: 1,
MaxBackups: 1,
Compress: false,
LocalTime: true,
},
},
expectedError: false,
setupRequest: func(req *http.Request) {
req.Header.Set("User-Agent", "test-agent")
},
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
assert.Equal(t, http.StatusOK, w.Code)
// 读取日志文件内容
content, err := os.ReadFile(logPath)
assert.NoError(t, err)
assert.Contains(t, string(content), "GET /test")
assert.Contains(t, string(content), "test-agent")
},
},
{
name: "Both console and file logging",
config: &types.AccessLogConfig{
EnableConsole: true,
EnableFile: true,
FilePath: logPath,
Format: "json",
Level: "info",
Rotation: struct {
MaxSize int `yaml:"max_size"`
MaxAge int `yaml:"max_age"`
MaxBackups int `yaml:"max_backups"`
Compress bool `yaml:"compress"`
LocalTime bool `yaml:"local_time"`
}{
MaxSize: 1,
MaxAge: 1,
MaxBackups: 1,
Compress: false,
LocalTime: true,
},
},
expectedError: false,
setupRequest: func(req *http.Request) {
req.Header.Set("User-Agent", "test-agent")
},
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, logOutput, "GET /test")
assert.Contains(t, logOutput, "test-agent")
// 读取日志文件内容
content, err := os.ReadFile(logPath)
assert.NoError(t, err)
assert.Contains(t, string(content), "GET /test")
assert.Contains(t, string(content), "test-agent")
},
},
{
name: "With authenticated user",
config: &types.AccessLogConfig{
EnableConsole: true,
EnableFile: false,
Format: "json",
Level: "info",
},
expectedError: false,
setupRequest: func(req *http.Request) {
req.Header.Set("User-Agent", "test-agent")
},
validateOutput: func(t *testing.T, w *httptest.ResponseRecorder, logOutput string) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, logOutput, "GET /test")
assert.Contains(t, logOutput, "test-agent")
assert.Contains(t, logOutput, "test-user")
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// 捕获标准输出
oldStdout := os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w
// 创建一个新的 gin 引擎
gin.SetMode(gin.TestMode)
router := gin.New()
// 创建访问日志中间件
middleware, err := AccessLog(tc.config)
if tc.expectedError {
assert.Error(t, err)
return
}
assert.NoError(t, err)
// 添加测试路由
router.Use(middleware)
router.GET("/test", func(c *gin.Context) {
// 如果是测试认证用户的情况设置用户ID
if tc.name == "With authenticated user" {
c.Set("user_id", "test-user")
}
c.Status(http.StatusOK)
})
// 创建测试请求
req := httptest.NewRequest("GET", "/test", nil)
if tc.setupRequest != nil {
tc.setupRequest(req)
}
rec := httptest.NewRecorder()
// 执行请求
router.ServeHTTP(rec, req)
// 恢复标准输出并获取输出内容
w.Close()
var buf bytes.Buffer
io.Copy(&buf, r)
os.Stdout = oldStdout
// 验证输出
if tc.validateOutput != nil {
tc.validateOutput(t, rec, buf.String())
}
// 关闭日志文件
if tc.config.EnableFile {
// 调用中间件函数来关闭日志文件
middleware(nil)
// 等待一小段时间确保文件完全关闭
time.Sleep(100 * time.Millisecond)
}
})
}
}
func TestAccessLogInvalidConfig(t *testing.T) {
testCases := []struct {
name string
config *types.AccessLogConfig
expectedError bool
}{
{
name: "Invalid log level",
config: &types.AccessLogConfig{
EnableConsole: true,
Level: "invalid_level",
},
expectedError: false, // 应该使用默认的 info 级别
},
{
name: "Invalid file path",
config: &types.AccessLogConfig{
EnableFile: true,
FilePath: "\x00invalid\x00path", // 使用空字符的路径在所有操作系统上都是无效的
},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := AccessLog(tc.config)
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View file

@ -1,227 +0,0 @@
package middleware
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
"time"
"tss-rocks-be/internal/service"
)
func createTestToken(secret string, claims jwt.MapClaims) string {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signedToken, err := token.SignedString([]byte(secret))
if err != nil {
panic(fmt.Sprintf("Failed to sign token: %v", err))
}
return signedToken
}
func TestAuthMiddleware(t *testing.T) {
jwtSecret := "test-secret"
tokenBlacklist := service.NewTokenBlacklist()
testCases := []struct {
name string
setupAuth func(*http.Request)
expectedStatus int
expectedBody map[string]string
checkUserData bool
expectedUserID string
expectedRoles []string
}{
{
name: "No Authorization header",
setupAuth: func(req *http.Request) {},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "Authorization header is required"},
},
{
name: "Invalid Authorization format",
setupAuth: func(req *http.Request) {
req.Header.Set("Authorization", "InvalidFormat")
},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "Authorization header format must be Bearer {token}"},
},
{
name: "Invalid token",
setupAuth: func(req *http.Request) {
req.Header.Set("Authorization", "Bearer invalid.token.here")
},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "Invalid token"},
},
{
name: "Valid token",
setupAuth: func(req *http.Request) {
claims := jwt.MapClaims{
"sub": "123",
"roles": []string{"admin", "editor"},
"exp": time.Now().Add(time.Hour).Unix(),
}
token := createTestToken(jwtSecret, claims)
req.Header.Set("Authorization", "Bearer "+token)
},
expectedStatus: http.StatusOK,
checkUserData: true,
expectedUserID: "123",
expectedRoles: []string{"admin", "editor"},
},
{
name: "Expired token",
setupAuth: func(req *http.Request) {
claims := jwt.MapClaims{
"sub": "123",
"roles": []string{"user"},
"exp": time.Now().Add(-time.Hour).Unix(),
}
token := createTestToken(jwtSecret, claims)
req.Header.Set("Authorization", "Bearer "+token)
},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "Invalid token"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
// 添加认证中间件
router.Use(func(c *gin.Context) {
// 设置日志级别为 debug
gin.SetMode(gin.DebugMode)
c.Next()
}, AuthMiddleware(jwtSecret, tokenBlacklist))
// 测试路由
router.GET("/test", func(c *gin.Context) {
if tc.checkUserData {
userID, exists := c.Get("user_id")
assert.True(t, exists, "user_id should exist in context")
assert.Equal(t, tc.expectedUserID, userID, "user_id should match")
roles, exists := c.Get("user_roles")
assert.True(t, exists, "user_roles should exist in context")
assert.Equal(t, tc.expectedRoles, roles, "user_roles should match")
}
c.Status(http.StatusOK)
})
// 创建请求
req := httptest.NewRequest("GET", "/test", nil)
tc.setupAuth(req)
rec := httptest.NewRecorder()
// 执行请求
router.ServeHTTP(rec, req)
// 验证响应
assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match")
if tc.expectedBody != nil {
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err, "Response body should be valid JSON")
assert.Equal(t, tc.expectedBody, response, "Response body should match")
}
})
}
}
func TestRoleMiddleware(t *testing.T) {
testCases := []struct {
name string
setupContext func(*gin.Context)
allowedRoles []string
expectedStatus int
expectedBody map[string]string
}{
{
name: "No user roles",
setupContext: func(c *gin.Context) {
// 不设置用户角色
},
allowedRoles: []string{"admin"},
expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "User roles not found"},
},
{
name: "Invalid roles type",
setupContext: func(c *gin.Context) {
c.Set("user_roles", 123) // 设置错误类型的角色
},
allowedRoles: []string{"admin"},
expectedStatus: http.StatusInternalServerError,
expectedBody: map[string]string{"error": "Invalid user roles type"},
},
{
name: "Insufficient permissions",
setupContext: func(c *gin.Context) {
c.Set("user_roles", []string{"user"})
},
allowedRoles: []string{"admin"},
expectedStatus: http.StatusForbidden,
expectedBody: map[string]string{"error": "Insufficient permissions"},
},
{
name: "Allowed role",
setupContext: func(c *gin.Context) {
c.Set("user_roles", []string{"admin"})
},
allowedRoles: []string{"admin"},
expectedStatus: http.StatusOK,
},
{
name: "One of multiple allowed roles",
setupContext: func(c *gin.Context) {
c.Set("user_roles", []string{"user", "editor"})
},
allowedRoles: []string{"admin", "editor", "moderator"},
expectedStatus: http.StatusOK,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
// 添加角色中间件
router.Use(func(c *gin.Context) {
tc.setupContext(c)
c.Next()
}, RoleMiddleware(tc.allowedRoles...))
// 测试路由
router.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// 创建请求
req := httptest.NewRequest("GET", "/test", nil)
rec := httptest.NewRecorder()
// 执行请求
router.ServeHTTP(rec, req)
// 验证响应
assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match")
if tc.expectedBody != nil {
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err, "Response body should be valid JSON")
assert.Equal(t, tc.expectedBody, response, "Response body should match")
}
})
}
}

View file

@ -1,76 +0,0 @@
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
)
func TestCORS(t *testing.T) {
testCases := []struct {
name string
method string
expectedStatus int
checkHeaders bool
}{
{
name: "Normal GET request",
method: "GET",
expectedStatus: http.StatusOK,
checkHeaders: true,
},
{
name: "OPTIONS request",
method: "OPTIONS",
expectedStatus: http.StatusNoContent,
checkHeaders: true,
},
{
name: "POST request",
method: "POST",
expectedStatus: http.StatusOK,
checkHeaders: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// 创建一个新的 gin 引擎
gin.SetMode(gin.TestMode)
router := gin.New()
// 添加 CORS 中间件
router.Use(CORS())
// 添加测试路由
router.Any("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// 创建测试请求
req := httptest.NewRequest(tc.method, "/test", nil)
rec := httptest.NewRecorder()
// 执行请求
router.ServeHTTP(rec, req)
// 验证状态码
assert.Equal(t, tc.expectedStatus, rec.Code)
if tc.checkHeaders {
// 验证 CORS 头部
headers := rec.Header()
assert.Equal(t, "*", headers.Get("Access-Control-Allow-Origin"))
assert.Equal(t, "true", headers.Get("Access-Control-Allow-Credentials"))
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Content-Type")
assert.Contains(t, headers.Get("Access-Control-Allow-Headers"), "Authorization")
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "POST")
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "GET")
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "PUT")
assert.Contains(t, headers.Get("Access-Control-Allow-Methods"), "DELETE")
}
})
}
}

View file

@ -1,207 +0,0 @@
package middleware
import (
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
"time"
"tss-rocks-be/internal/types"
)
func TestRateLimit(t *testing.T) {
testCases := []struct {
name string
config *types.RateLimitConfig
setupTest func(*gin.Engine)
runTest func(*testing.T, *gin.Engine)
expectedStatus int
expectedBody map[string]string
}{
{
name: "IP rate limit",
config: &types.RateLimitConfig{
IPRate: 1, // 每秒1个请求
IPBurst: 1,
},
setupTest: func(router *gin.Engine) {
router.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
},
runTest: func(t *testing.T, router *gin.Engine) {
// 第一个请求应该成功
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:1234"
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// 第二个请求应该被限制
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err)
assert.Equal(t, "too many requests from this IP", response["error"])
// 等待限流器重置
time.Sleep(time.Second)
// 第三个请求应该成功
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
},
},
{
name: "Route rate limit",
config: &types.RateLimitConfig{
IPRate: 100, // 设置较高的 IP 限流,以便测试路由限流
IPBurst: 10,
RouteRates: map[string]struct {
Rate int `yaml:"rate"`
Burst int `yaml:"burst"`
}{
"/limited": {
Rate: 1,
Burst: 1,
},
},
},
setupTest: func(router *gin.Engine) {
router.GET("/limited", func(c *gin.Context) {
c.Status(http.StatusOK)
})
router.GET("/unlimited", func(c *gin.Context) {
c.Status(http.StatusOK)
})
},
runTest: func(t *testing.T, router *gin.Engine) {
// 测试限流路由
req := httptest.NewRequest("GET", "/limited", nil)
req.RemoteAddr = "192.168.1.2:1234"
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// 等待一小段时间确保限流器生效
time.Sleep(10 * time.Millisecond)
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err)
assert.Equal(t, "too many requests for this route", response["error"])
// 测试未限流路由
req = httptest.NewRequest("GET", "/unlimited", nil)
req.RemoteAddr = "192.168.1.2:1234"
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// 等待一小段时间确保限流器生效
time.Sleep(10 * time.Millisecond)
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
},
},
{
name: "Multiple IPs",
config: &types.RateLimitConfig{
IPRate: 1,
IPBurst: 1,
},
setupTest: func(router *gin.Engine) {
router.GET("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
},
runTest: func(t *testing.T, router *gin.Engine) {
// IP1 的请求
req1 := httptest.NewRequest("GET", "/test", nil)
req1.RemoteAddr = "192.168.1.3:1234"
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req1)
assert.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req1)
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
// IP2 的请求应该不受 IP1 的限制影响
req2 := httptest.NewRequest("GET", "/test", nil)
req2.RemoteAddr = "192.168.1.4:1234"
rec = httptest.NewRecorder()
router.ServeHTTP(rec, req2)
assert.Equal(t, http.StatusOK, rec.Code)
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
// 添加限流中间件
router.Use(RateLimit(tc.config))
// 设置测试路由
tc.setupTest(router)
// 运行测试
tc.runTest(t, router)
})
}
}
func TestRateLimiterCleanup(t *testing.T) {
config := &types.RateLimitConfig{
IPRate: 1,
IPBurst: 1,
}
rl := newRateLimiter(config)
// 添加一些IP限流器
ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"}
for _, ip := range ips {
rl.getLimiter(ip)
}
// 验证IP限流器已创建
rl.mu.RLock()
assert.Equal(t, len(ips), len(rl.ips))
rl.mu.RUnlock()
// 修改一些IP的最后访问时间为1小时前
rl.mu.Lock()
rl.ips["192.168.1.1"].lastSeen = time.Now().Add(-2 * time.Hour)
rl.ips["192.168.1.2"].lastSeen = time.Now().Add(-2 * time.Hour)
rl.mu.Unlock()
// 手动触发清理
rl.mu.Lock()
for ip, limiter := range rl.ips {
if time.Since(limiter.lastSeen) > time.Hour {
delete(rl.ips, ip)
}
}
rl.mu.Unlock()
// 验证过期的IP限流器已被删除
rl.mu.RLock()
assert.Equal(t, 1, len(rl.ips))
_, exists := rl.ips["192.168.1.3"]
assert.True(t, exists)
rl.mu.RUnlock()
}

View file

@ -1,262 +0,0 @@
package middleware
import (
"bytes"
"encoding/json"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"strings"
"testing"
"tss-rocks-be/internal/types"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func createMultipartRequest(t *testing.T, filename string, content []byte, contentType string) (*http.Request, error) {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
part, err := writer.CreateFormFile("file", filename)
if err != nil {
return nil, err
}
_, err = io.Copy(part, bytes.NewReader(content))
if err != nil {
return nil, err
}
err = writer.Close()
if err != nil {
return nil, err
}
req := httptest.NewRequest("POST", "/upload", body)
req.Header.Set("Content-Type", writer.FormDataContentType())
return req, nil
}
func TestValidateUpload(t *testing.T) {
tests := []struct {
name string
config *types.UploadConfig
filename string
content []byte
setupRequest func(*testing.T) *http.Request
expectedStatus int
expectedError string
}{
{
name: "Valid image upload",
config: &types.UploadConfig{
MaxSize: 5, // 5MB
AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
AllowedTypes: []string{"image/jpeg", "image/png"},
},
filename: "test.jpg",
content: []byte{
0xFF, 0xD8, 0xFF, 0xE0, // JPEG magic numbers
0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00,
},
expectedStatus: http.StatusOK,
},
{
name: "Invalid file extension",
config: &types.UploadConfig{
MaxSize: 5,
AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
AllowedTypes: []string{"image/jpeg", "image/png"},
},
filename: "test.txt",
content: []byte("test content"),
expectedStatus: http.StatusBadRequest,
expectedError: "File extension .txt is not allowed",
},
{
name: "File too large",
config: &types.UploadConfig{
MaxSize: 1, // 1MB
AllowedExtensions: []string{".jpg"},
AllowedTypes: []string{"image/jpeg"},
},
filename: "large.jpg",
content: make([]byte, 2<<20), // 2MB
expectedStatus: http.StatusBadRequest,
expectedError: "File large.jpg exceeds maximum size of 1 MB",
},
{
name: "Invalid content type",
config: &types.UploadConfig{
MaxSize: 5,
AllowedExtensions: []string{".jpg"},
AllowedTypes: []string{"image/jpeg"},
},
filename: "fake.jpg",
content: []byte("not a real image"),
expectedStatus: http.StatusBadRequest,
expectedError: "File type text/plain; charset=utf-8 is not allowed",
},
{
name: "Missing file",
config: &types.UploadConfig{
MaxSize: 5,
AllowedExtensions: []string{".jpg"},
AllowedTypes: []string{"image/jpeg"},
},
setupRequest: func(t *testing.T) *http.Request {
req := httptest.NewRequest("POST", "/upload", strings.NewReader(""))
req.Header.Set("Content-Type", "multipart/form-data")
return req
},
expectedStatus: http.StatusBadRequest,
expectedError: "Failed to parse form",
},
{
name: "Invalid content type header",
config: &types.UploadConfig{
MaxSize: 5,
AllowedExtensions: []string{".jpg"},
AllowedTypes: []string{"image/jpeg"},
},
setupRequest: func(t *testing.T) *http.Request {
return httptest.NewRequest("POST", "/upload", nil)
},
expectedStatus: http.StatusBadRequest,
expectedError: "Content-Type must be multipart/form-data",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
var req *http.Request
var err error
if tt.setupRequest != nil {
req = tt.setupRequest(t)
} else {
req, err = createMultipartRequest(t, tt.filename, tt.content, "")
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
}
c.Request = req
middleware := ValidateUpload(tt.config)
middleware(c)
assert.Equal(t, tt.expectedStatus, w.Code)
if tt.expectedError != "" {
var response map[string]string
err := json.NewDecoder(w.Body).Decode(&response)
assert.NoError(t, err)
assert.Contains(t, response["error"], tt.expectedError)
}
})
}
}
func TestGetValidatedFile(t *testing.T) {
tests := []struct {
name string
setupContext func(*gin.Context)
filename string
expectedFound bool
expectedError string
}{
{
name: "Get existing file",
setupContext: func(c *gin.Context) {
// 创建测试文件内容
content := []byte("test content")
buf := bytes.NewBuffer(content)
// 设置验证过的文件和内容类型
c.Set("validated_file_test.txt", buf)
c.Set("validated_content_type_test.txt", "text/plain")
},
filename: "test.txt",
expectedFound: true,
},
{
name: "File not found",
setupContext: func(c *gin.Context) {
// 不设置任何文件
},
filename: "nonexistent.txt",
expectedFound: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
if tt.setupContext != nil {
tt.setupContext(c)
}
buffer, contentType, found := GetValidatedFile(c, tt.filename)
assert.Equal(t, tt.expectedFound, found)
if tt.expectedFound {
assert.NotNil(t, buffer)
assert.NotEmpty(t, contentType)
} else {
assert.Nil(t, buffer)
assert.Empty(t, contentType)
}
})
}
}
func TestContains(t *testing.T) {
tests := []struct {
name string
slice []string
str string
expected bool
}{
{
name: "String found in slice",
slice: []string{"a", "b", "c"},
str: "b",
expected: true,
},
{
name: "String not found in slice",
slice: []string{"a", "b", "c"},
str: "d",
expected: false,
},
{
name: "Empty slice",
slice: []string{},
str: "a",
expected: false,
},
{
name: "Empty string",
slice: []string{"a", "b", "c"},
str: "",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := contains(tt.slice, tt.str)
assert.Equal(t, tt.expected, result)
})
}
}

View file

@ -1,99 +0,0 @@
package rbac
import (
"context"
"testing"
"tss-rocks-be/ent/enttest"
"tss-rocks-be/ent/role"
_ "github.com/mattn/go-sqlite3"
)
func TestInitializeRBAC(t *testing.T) {
// Create an in-memory SQLite client for testing
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
ctx := context.Background()
// Test initialization
err := InitializeRBAC(ctx, client)
if err != nil {
t.Fatalf("Failed to initialize RBAC: %v", err)
}
// Verify roles were created
for roleName := range DefaultRoles {
r, err := client.Role.Query().Where(role.Name(roleName)).Only(ctx)
if err != nil {
t.Errorf("Role %s was not created: %v", roleName, err)
}
// Verify permissions for each role
perms, err := r.QueryPermissions().All(ctx)
if err != nil {
t.Errorf("Failed to query permissions for role %s: %v", roleName, err)
}
expectedPerms := DefaultRoles[roleName]
permCount := 0
for _, actions := range expectedPerms {
permCount += len(actions)
}
if len(perms) != permCount {
t.Errorf("Role %s has %d permissions, expected %d", roleName, len(perms), permCount)
}
}
}
func TestAssignRoleToUser(t *testing.T) {
// Create an in-memory SQLite client for testing
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
ctx := context.Background()
// Initialize RBAC
err := InitializeRBAC(ctx, client)
if err != nil {
t.Fatalf("Failed to initialize RBAC: %v", err)
}
// Create a test user
user, err := client.User.Create().
SetEmail("test@example.com").
SetUsername("testuser").
SetPasswordHash("$2a$10$hzLdXMZEIzgr8eGXL0YoCOIIrQhqEj6N.S3.wY1Jx5.4vWm1ZyHyy").
Save(ctx)
if err != nil {
t.Fatalf("Failed to create test user: %v", err)
}
// Test assigning role to user
err = AssignRoleToUser(ctx, client, user.ID, "editor")
if err != nil {
t.Fatalf("Failed to assign role to user: %v", err)
}
// Verify role assignment
assignedRoles, err := user.QueryRoles().All(ctx)
if err != nil {
t.Fatalf("Failed to query user roles: %v", err)
}
if len(assignedRoles) != 1 {
t.Errorf("Expected 1 role, got %d", len(assignedRoles))
}
if assignedRoles[0].Name != "editor" {
t.Errorf("Expected role name 'editor', got '%s'", assignedRoles[0].Name)
}
// Test assigning non-existent role
err = AssignRoleToUser(ctx, client, user.ID, "nonexistent")
if err == nil {
t.Error("Expected error when assigning non-existent role, got nil")
}
}

View file

@ -1,64 +0,0 @@
package server
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestInitDatabase(t *testing.T) {
tests := []struct {
name string
driver string
dsn string
wantErr bool
errContains string
}{
{
name: "success with sqlite3",
driver: "sqlite3",
dsn: "file:ent?mode=memory&cache=shared&_fk=1",
},
{
name: "invalid driver",
driver: "invalid_driver",
dsn: "file:ent?mode=memory",
wantErr: true,
errContains: "unsupported driver",
},
{
name: "invalid dsn",
driver: "sqlite3",
dsn: "file::memory:?not_exist_option=1", // 使用内存数据库但带有无效选项
wantErr: true,
errContains: "foreign_keys pragma is off",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
client, err := InitDatabase(ctx, tt.driver, tt.dsn)
if tt.wantErr {
assert.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
assert.Nil(t, client)
} else {
require.NoError(t, err)
assert.NotNil(t, client)
// 测试数据库连接是否正常工作
err = client.Schema.Create(ctx)
assert.NoError(t, err)
// 清理
client.Close()
}
})
}
}

View file

@ -1,40 +0,0 @@
package server
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"tss-rocks-be/internal/config"
)
func TestNewEntClient(t *testing.T) {
tests := []struct {
name string
cfg *config.Config
}{
{
name: "default sqlite3 config",
cfg: &config.Config{
Database: config.DatabaseConfig{
Driver: "sqlite3",
DSN: "file:ent?mode=memory&cache=shared&_fk=1",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := NewEntClient(tt.cfg)
assert.NotNil(t, client)
// 验证客户端是否可以正常工作
err := client.Schema.Create(context.Background())
assert.NoError(t, err)
// 清理
client.Close()
})
}
}

View file

@ -1,220 +0,0 @@
package server
import (
"context"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/types"
"tss-rocks-be/ent/enttest"
)
func TestNew(t *testing.T) {
// 创建测试配置
cfg := &config.Config{
Server: config.ServerConfig{
Host: "localhost",
Port: 8080,
},
Storage: config.StorageConfig{
Type: "local",
Local: config.LocalStorage{
RootDir: "testdata",
},
Upload: types.UploadConfig{
MaxSize: 10,
AllowedTypes: []string{"image/jpeg", "image/png"},
AllowedExtensions: []string{".jpg", ".png"},
},
},
RateLimit: types.RateLimitConfig{
IPRate: 100,
IPBurst: 200,
RouteRates: map[string]struct {
Rate int `yaml:"rate"`
Burst int `yaml:"burst"`
}{
"/api/v1/upload": {Rate: 10, Burst: 20},
},
},
AccessLog: types.AccessLogConfig{
EnableConsole: true,
EnableFile: true,
FilePath: "testdata/access.log",
Format: "json",
Level: "info",
Rotation: struct {
MaxSize int `yaml:"max_size"`
MaxAge int `yaml:"max_age"`
MaxBackups int `yaml:"max_backups"`
Compress bool `yaml:"compress"`
LocalTime bool `yaml:"local_time"`
}{
MaxSize: 100,
MaxAge: 7,
MaxBackups: 3,
Compress: true,
LocalTime: true,
},
},
}
// 创建测试数据库客户端
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
// 测试服务器初始化
s, err := New(cfg, client)
require.NoError(t, err)
assert.NotNil(t, s)
assert.NotNil(t, s.router)
assert.NotNil(t, s.handler)
assert.Equal(t, cfg, s.config)
}
func TestNew_StorageError(t *testing.T) {
// 创建一个无效的存储配置
cfg := &config.Config{
Storage: config.StorageConfig{
Type: "invalid_type", // 使用无效的存储类型
},
}
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
s, err := New(cfg, client)
assert.Error(t, err)
assert.Nil(t, s)
assert.Contains(t, err.Error(), "failed to initialize storage")
}
func TestServer_StartAndShutdown(t *testing.T) {
// 创建测试配置
cfg := &config.Config{
Server: config.ServerConfig{
Host: "localhost",
Port: 0, // 使用随机端口
},
Storage: config.StorageConfig{
Type: "local",
Local: config.LocalStorage{
RootDir: "testdata",
},
},
RateLimit: types.RateLimitConfig{
IPRate: 100,
IPBurst: 200,
},
AccessLog: types.AccessLogConfig{
EnableConsole: true,
Format: "json",
Level: "info",
},
}
// 创建测试数据库客户端
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
// 初始化服务器
s, err := New(cfg, client)
require.NoError(t, err)
// 创建一个通道来接收服务器错误
errChan := make(chan error, 1)
// 在 goroutine 中启动服务器
go func() {
err := s.Start()
if err != nil && err != http.ErrServerClosed {
errChan <- err
}
close(errChan)
}()
// 给服务器一些时间启动
time.Sleep(100 * time.Millisecond)
// 测试关闭服务器
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err = s.Shutdown(ctx)
assert.NoError(t, err)
// 检查服务器是否有错误发生
err = <-errChan
assert.NoError(t, err)
}
func TestServer_StartError(t *testing.T) {
// 创建一个配置,使用已经被占用的端口来触发错误
cfg := &config.Config{
Server: config.ServerConfig{
Host: "localhost",
Port: 8899, // 使用固定端口以便测试
},
Storage: config.StorageConfig{
Type: "local",
Local: config.LocalStorage{
RootDir: "testdata",
},
},
}
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
// 创建第一个服务器实例
s1, err := New(cfg, client)
require.NoError(t, err)
// 创建一个通道来接收服务器错误
errChan := make(chan error, 1)
// 启动第一个服务器
go func() {
err := s1.Start()
if err != nil && err != http.ErrServerClosed {
errChan <- err
}
close(errChan)
}()
// 给服务器一些时间启动
time.Sleep(100 * time.Millisecond)
// 尝试在同一端口启动第二个服务器,应该会失败
s2, err := New(cfg, client)
require.NoError(t, err)
err = s2.Start()
assert.Error(t, err)
// 清理
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// 关闭第一个服务器
err = s1.Shutdown(ctx)
assert.NoError(t, err)
// 检查第一个服务器是否有错误发生
err = <-errChan
assert.NoError(t, err)
// 关闭第二个服务器
err = s2.Shutdown(ctx)
assert.NoError(t, err)
}
func TestServer_ShutdownWithNilServer(t *testing.T) {
s := &Server{}
err := s.Shutdown(context.Background())
assert.NoError(t, err)
}

File diff suppressed because it is too large Load diff

View file

@ -1,332 +0,0 @@
package service
import (
"bytes"
"context"
"fmt"
"io"
"mime/multipart"
"net/textproto"
"reflect"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
"tss-rocks-be/ent"
"tss-rocks-be/internal/storage"
"tss-rocks-be/internal/storage/mock"
"tss-rocks-be/internal/testutil"
"bou.ke/monkey"
)
type MediaServiceTestSuite struct {
suite.Suite
ctx context.Context
client *ent.Client
storage *mock.MockStorage
ctrl *gomock.Controller
svc MediaService
}
func (s *MediaServiceTestSuite) SetupTest() {
s.ctx = context.Background()
s.client = testutil.NewTestClient()
require.NotNil(s.T(), s.client)
s.ctrl = gomock.NewController(s.T())
s.storage = mock.NewMockStorage(s.ctrl)
s.svc = NewMediaService(s.client, s.storage)
// 清理数据库
_, err := s.client.Media.Delete().Exec(s.ctx)
require.NoError(s.T(), err)
}
func (s *MediaServiceTestSuite) TearDownTest() {
s.ctrl.Finish()
s.client.Close()
}
func TestMediaServiceSuite(t *testing.T) {
suite.Run(t, new(MediaServiceTestSuite))
}
type mockFileHeader struct {
filename string
contentType string
size int64
content []byte
}
func (h *mockFileHeader) Open() (multipart.File, error) {
return newMockMultipartFile(h.content), nil
}
func (h *mockFileHeader) Filename() string {
return h.filename
}
func (h *mockFileHeader) Size() int64 {
return h.size
}
func (h *mockFileHeader) Header() textproto.MIMEHeader {
header := make(textproto.MIMEHeader)
header.Set("Content-Type", h.contentType)
return header
}
func (s *MediaServiceTestSuite) createTestFile(filename, contentType string, content []byte) *multipart.FileHeader {
header := &multipart.FileHeader{
Filename: filename,
Header: make(textproto.MIMEHeader),
Size: int64(len(content)),
}
header.Header.Set("Content-Type", contentType)
monkey.PatchInstanceMethod(reflect.TypeOf(header), "Open", func(_ *multipart.FileHeader) (multipart.File, error) {
return newMockMultipartFile(content), nil
})
return header
}
func (s *MediaServiceTestSuite) TestUpload() {
testCases := []struct {
name string
filename string
contentType string
content []byte
setupMock func()
wantErr bool
errMsg string
}{
{
name: "Upload text file",
filename: "test.txt",
contentType: "text/plain",
content: []byte("test content"),
setupMock: func() {
s.storage.EXPECT().
Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()).
DoAndReturn(func(ctx context.Context, name, contentType string, reader io.Reader) (*storage.FileInfo, error) {
content, err := io.ReadAll(reader)
s.Require().NoError(err)
s.Equal([]byte("test content"), content)
return &storage.FileInfo{
ID: "test-id",
Name: "test.txt",
ContentType: "text/plain",
Size: int64(len(content)),
}, nil
})
},
wantErr: false,
},
{
name: "Invalid filename",
filename: "../test.txt",
contentType: "text/plain",
content: []byte("test content"),
setupMock: func() {},
wantErr: true,
errMsg: "invalid filename",
},
{
name: "Storage error",
filename: "test.txt",
contentType: "text/plain",
content: []byte("test content"),
setupMock: func() {
s.storage.EXPECT().
Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()).
Return(nil, fmt.Errorf("storage error"))
},
wantErr: true,
errMsg: "storage error",
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
// Setup mock
tc.setupMock()
// Create test file
fileHeader := s.createTestFile(tc.filename, tc.contentType, tc.content)
// Add debug output
s.T().Logf("Testing file: %s, content-type: %s, size: %d", fileHeader.Filename, fileHeader.Header.Get("Content-Type"), fileHeader.Size)
// Test upload
media, err := s.svc.Upload(s.ctx, fileHeader, 1)
// Add debug output
if err != nil {
s.T().Logf("Upload error: %v", err)
}
if tc.wantErr {
s.Require().Error(err)
s.Contains(err.Error(), tc.errMsg)
return
}
s.Require().NoError(err)
s.NotNil(media)
s.Equal(tc.filename, media.OriginalName)
s.Equal(tc.contentType, media.MimeType)
s.Equal(int64(len(tc.content)), media.Size)
s.Equal("1", media.CreatedBy)
})
}
}
func (s *MediaServiceTestSuite) TestGet() {
// Create test media
media, err := s.client.Media.Create().
SetStorageID("test-id").
SetOriginalName("test.txt").
SetMimeType("text/plain").
SetSize(12).
SetURL("/api/media/test-id").
SetCreatedBy("1").
Save(s.ctx)
s.Require().NoError(err)
// Test get existing media
result, err := s.svc.Get(s.ctx, media.ID)
s.Require().NoError(err)
s.Equal(media.ID, result.ID)
s.Equal(media.OriginalName, result.OriginalName)
// Test get non-existing media
_, err = s.svc.Get(s.ctx, -1)
s.Require().Error(err)
s.Contains(err.Error(), "media not found")
}
func (s *MediaServiceTestSuite) TestDelete() {
// Create test media
media, err := s.client.Media.Create().
SetStorageID("test-id").
SetOriginalName("test.txt").
SetMimeType("text/plain").
SetSize(12).
SetURL("/api/media/test-id").
SetCreatedBy("1").
Save(s.ctx)
s.Require().NoError(err)
// Test delete by unauthorized user
err = s.svc.Delete(s.ctx, media.ID, 2)
s.Require().Error(err)
s.Contains(err.Error(), "unauthorized")
// Test delete by owner
s.storage.EXPECT().
Delete(gomock.Any(), "test-id").
Return(nil)
err = s.svc.Delete(s.ctx, media.ID, 1)
s.Require().NoError(err)
// Verify media is deleted
_, err = s.svc.Get(s.ctx, media.ID)
s.Require().Error(err)
s.Contains(err.Error(), "not found")
}
func (s *MediaServiceTestSuite) TestList() {
// Create test media
for i := 0; i < 5; i++ {
_, err := s.client.Media.Create().
SetStorageID(fmt.Sprintf("test-id-%d", i)).
SetOriginalName(fmt.Sprintf("test-%d.txt", i)).
SetMimeType("text/plain").
SetSize(12).
SetURL(fmt.Sprintf("/api/media/test-id-%d", i)).
SetCreatedBy("1").
Save(s.ctx)
s.Require().NoError(err)
}
// Test list with limit and offset
media, err := s.svc.List(s.ctx, 3, 1)
s.Require().NoError(err)
s.Len(media, 3)
}
func (s *MediaServiceTestSuite) TestGetFile() {
// Create test media
media, err := s.client.Media.Create().
SetStorageID("test-id").
SetOriginalName("test.txt").
SetMimeType("text/plain").
SetSize(12).
SetURL("/api/media/test-id").
SetCreatedBy("1").
Save(s.ctx)
s.Require().NoError(err)
// Mock storage.Get
mockReader := io.NopCloser(bytes.NewReader([]byte("test content")))
mockFileInfo := &storage.FileInfo{
ID: "test-id",
Name: "test.txt",
ContentType: "text/plain",
Size: 12,
}
s.storage.EXPECT().
Get(gomock.Any(), "test-id").
Return(mockReader, mockFileInfo, nil)
// Test get file
reader, info, err := s.svc.GetFile(s.ctx, media.ID)
s.Require().NoError(err)
s.NotNil(reader)
s.Equal(mockFileInfo, info)
// Test get non-existing file
_, _, err = s.svc.GetFile(s.ctx, -1)
s.Require().Error(err)
s.Contains(err.Error(), "not found")
}
func (s *MediaServiceTestSuite) TestIsValidFilename() {
testCases := []struct {
name string
filename string
want bool
}{
{
name: "Valid filename",
filename: "test.txt",
want: true,
},
{
name: "Invalid filename with ../",
filename: "../test.txt",
want: false,
},
{
name: "Invalid filename with ./",
filename: "./test.txt",
want: false,
},
{
name: "Invalid filename with backslash",
filename: "test\\file.txt",
want: false,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
got := isValidFilename(tc.filename)
s.Equal(tc.want, got)
})
}
}

View file

@ -1,154 +0,0 @@
package storage
import (
"bytes"
"context"
"io"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLocalStorage(t *testing.T) {
// Create a temporary directory for testing
tempDir, err := os.MkdirTemp("", "storage_test_*")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
// Create a new LocalStorage instance
storage, err := NewLocalStorage(tempDir)
require.NoError(t, err)
ctx := context.Background()
t.Run("Save and Get", func(t *testing.T) {
content := []byte("test content")
reader := bytes.NewReader(content)
// Save the file
fileInfo, err := storage.Save(ctx, "test.txt", "text/plain", reader)
require.NoError(t, err)
assert.NotEmpty(t, fileInfo.ID)
assert.Equal(t, "test.txt", fileInfo.Name)
assert.Equal(t, int64(len(content)), fileInfo.Size)
assert.Equal(t, "text/plain", fileInfo.ContentType)
assert.False(t, fileInfo.CreatedAt.IsZero())
// Get the file
readCloser, info, err := storage.Get(ctx, fileInfo.ID)
require.NoError(t, err)
defer readCloser.Close()
data, err := io.ReadAll(readCloser)
require.NoError(t, err)
assert.Equal(t, content, data)
assert.Equal(t, fileInfo.ID, info.ID)
assert.Equal(t, fileInfo.Name, info.Name)
assert.Equal(t, fileInfo.Size, info.Size)
})
t.Run("List", func(t *testing.T) {
// Clear the directory first
dirEntries, err := os.ReadDir(tempDir)
require.NoError(t, err)
for _, entry := range dirEntries {
if entry.Name() != ".meta" {
os.Remove(filepath.Join(tempDir, entry.Name()))
}
}
// Save multiple files
testFiles := []struct {
name string
content string
}{
{"test1.txt", "content1"},
{"test2.txt", "content2"},
{"other.txt", "content3"},
}
for _, f := range testFiles {
reader := bytes.NewReader([]byte(f.content))
_, err := storage.Save(ctx, f.name, "text/plain", reader)
require.NoError(t, err)
}
// List all files
allFiles, err := storage.List(ctx, "", 10, 0)
require.NoError(t, err)
assert.Len(t, allFiles, 3)
// List files with prefix
filesWithPrefix, err := storage.List(ctx, "test", 10, 0)
require.NoError(t, err)
assert.Len(t, filesWithPrefix, 2)
for _, f := range filesWithPrefix {
assert.True(t, strings.HasPrefix(f.Name, "test"))
}
// Test pagination
pagedFiles, err := storage.List(ctx, "", 2, 1)
require.NoError(t, err)
assert.Len(t, pagedFiles, 2)
})
t.Run("Exists", func(t *testing.T) {
// Save a file
content := []byte("test content")
reader := bytes.NewReader(content)
fileInfo, err := storage.Save(ctx, "exists.txt", "text/plain", reader)
require.NoError(t, err)
// Check if file exists
exists, err := storage.Exists(ctx, fileInfo.ID)
require.NoError(t, err)
assert.True(t, exists)
// Check non-existent file
exists, err = storage.Exists(ctx, "non-existent")
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("Delete", func(t *testing.T) {
// Save a file
content := []byte("test content")
reader := bytes.NewReader(content)
fileInfo, err := storage.Save(ctx, "delete.txt", "text/plain", reader)
require.NoError(t, err)
// Delete the file
err = storage.Delete(ctx, fileInfo.ID)
require.NoError(t, err)
// Verify file is deleted
exists, err := storage.Exists(ctx, fileInfo.ID)
require.NoError(t, err)
assert.False(t, exists)
// Try to delete non-existent file
err = storage.Delete(ctx, "non-existent")
assert.Error(t, err)
})
t.Run("Invalid operations", func(t *testing.T) {
// Try to get non-existent file
_, _, err := storage.Get(ctx, "non-existent")
assert.Error(t, err)
assert.Contains(t, err.Error(), "file not found")
// Try to save file with nil reader
_, err = storage.Save(ctx, "test.txt", "text/plain", nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "reader cannot be nil")
// Try to delete non-existent file
err = storage.Delete(ctx, "non-existent")
assert.Error(t, err)
assert.Contains(t, err.Error(), "file not found")
})
}

View file

@ -1,211 +0,0 @@
package storage
import (
"bytes"
"context"
"io"
"testing"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// MockS3Client is a mock implementation of the S3 client interface
type MockS3Client struct {
mock.Mock
}
func (m *MockS3Client) PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) {
args := m.Called(ctx, params)
return args.Get(0).(*s3.PutObjectOutput), args.Error(1)
}
func (m *MockS3Client) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
args := m.Called(ctx, params)
return args.Get(0).(*s3.GetObjectOutput), args.Error(1)
}
func (m *MockS3Client) DeleteObject(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error) {
args := m.Called(ctx, params)
return args.Get(0).(*s3.DeleteObjectOutput), args.Error(1)
}
func (m *MockS3Client) ListObjectsV2(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) {
args := m.Called(ctx, params)
return args.Get(0).(*s3.ListObjectsV2Output), args.Error(1)
}
func (m *MockS3Client) HeadObject(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) {
args := m.Called(ctx, params)
return args.Get(0).(*s3.HeadObjectOutput), args.Error(1)
}
func TestS3Storage(t *testing.T) {
ctx := context.Background()
mockClient := new(MockS3Client)
storage := NewS3Storage(mockClient, "test-bucket", "", false)
t.Run("Save", func(t *testing.T) {
mockClient.ExpectedCalls = nil
mockClient.Calls = nil
content := []byte("test content")
reader := bytes.NewReader(content)
// Mock HeadObject to return NotFound error
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket"
})).Return(&s3.HeadObjectOutput{}, &types.NoSuchKey{
Message: aws.String("The specified key does not exist."),
})
mockClient.On("PutObject", ctx, mock.MatchedBy(func(input *s3.PutObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.ContentType) == "text/plain"
})).Return(&s3.PutObjectOutput{}, nil)
fileInfo, err := storage.Save(ctx, "test.txt", "text/plain", reader)
require.NoError(t, err)
assert.NotEmpty(t, fileInfo.ID)
assert.Equal(t, "test.txt", fileInfo.Name)
assert.Equal(t, "text/plain", fileInfo.ContentType)
mockClient.AssertExpectations(t)
})
t.Run("Get", func(t *testing.T) {
content := []byte("test content")
mockClient.On("GetObject", ctx, mock.MatchedBy(func(input *s3.GetObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.Key) == "test-id"
})).Return(&s3.GetObjectOutput{
Body: io.NopCloser(bytes.NewReader(content)),
ContentType: aws.String("text/plain"),
ContentLength: aws.Int64(int64(len(content))),
LastModified: aws.Time(time.Now()),
}, nil)
readCloser, info, err := storage.Get(ctx, "test-id")
require.NoError(t, err)
defer readCloser.Close()
data, err := io.ReadAll(readCloser)
require.NoError(t, err)
assert.Equal(t, content, data)
assert.Equal(t, "test-id", info.ID)
assert.Equal(t, int64(len(content)), info.Size)
mockClient.AssertExpectations(t)
})
t.Run("List", func(t *testing.T) {
mockClient.ExpectedCalls = nil
mockClient.Calls = nil
mockClient.On("ListObjectsV2", ctx, mock.MatchedBy(func(input *s3.ListObjectsV2Input) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.Prefix) == "test" &&
aws.ToInt32(input.MaxKeys) == 10
})).Return(&s3.ListObjectsV2Output{
Contents: []types.Object{
{
Key: aws.String("test1"),
Size: aws.Int64(100),
LastModified: aws.Time(time.Now()),
},
{
Key: aws.String("test2"),
Size: aws.Int64(200),
LastModified: aws.Time(time.Now()),
},
},
}, nil)
// Mock HeadObject for both files
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.Key) == "test1"
})).Return(&s3.HeadObjectOutput{
ContentType: aws.String("text/plain"),
Metadata: map[string]string{
"x-amz-meta-original-name": "test1.txt",
},
}, nil).Once()
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.Key) == "test2"
})).Return(&s3.HeadObjectOutput{
ContentType: aws.String("text/plain"),
Metadata: map[string]string{
"x-amz-meta-original-name": "test2.txt",
},
}, nil).Once()
files, err := storage.List(ctx, "test", 10, 0)
require.NoError(t, err)
assert.Len(t, files, 2)
assert.Equal(t, "test1", files[0].ID)
assert.Equal(t, int64(100), files[0].Size)
assert.Equal(t, "test1.txt", files[0].Name)
assert.Equal(t, "text/plain", files[0].ContentType)
mockClient.AssertExpectations(t)
})
t.Run("Delete", func(t *testing.T) {
mockClient.On("DeleteObject", ctx, mock.MatchedBy(func(input *s3.DeleteObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.Key) == "test-id"
})).Return(&s3.DeleteObjectOutput{}, nil)
err := storage.Delete(ctx, "test-id")
require.NoError(t, err)
mockClient.AssertExpectations(t)
})
t.Run("Exists", func(t *testing.T) {
mockClient.ExpectedCalls = nil
mockClient.Calls = nil
// Mock HeadObject for existing file
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.Key) == "test-id"
})).Return(&s3.HeadObjectOutput{}, nil).Once()
exists, err := storage.Exists(ctx, "test-id")
require.NoError(t, err)
assert.True(t, exists)
// Mock HeadObject for non-existing file
mockClient.On("HeadObject", ctx, mock.MatchedBy(func(input *s3.HeadObjectInput) bool {
return aws.ToString(input.Bucket) == "test-bucket" &&
aws.ToString(input.Key) == "non-existent"
})).Return(&s3.HeadObjectOutput{}, &types.NoSuchKey{
Message: aws.String("The specified key does not exist."),
}).Once()
exists, err = storage.Exists(ctx, "non-existent")
require.NoError(t, err)
assert.False(t, exists)
mockClient.AssertExpectations(t)
})
t.Run("Custom URL", func(t *testing.T) {
customStorage := &S3Storage{
client: mockClient,
bucket: "test-bucket",
customURL: "https://custom.domain",
proxyS3: true,
}
assert.Contains(t, customStorage.getObjectURL("test-id"), "https://custom.domain")
})
}

View file

@ -1,116 +0,0 @@
package types
import (
"testing"
)
func TestRateLimitConfig(t *testing.T) {
config := RateLimitConfig{
IPRate: 100,
IPBurst: 200,
RouteRates: map[string]struct {
Rate int `yaml:"rate"`
Burst int `yaml:"burst"`
}{
"/api/test": {
Rate: 50,
Burst: 100,
},
},
}
if config.IPRate != 100 {
t.Errorf("Expected IPRate 100, got %d", config.IPRate)
}
if config.IPBurst != 200 {
t.Errorf("Expected IPBurst 200, got %d", config.IPBurst)
}
route := config.RouteRates["/api/test"]
if route.Rate != 50 {
t.Errorf("Expected route rate 50, got %d", route.Rate)
}
if route.Burst != 100 {
t.Errorf("Expected route burst 100, got %d", route.Burst)
}
}
func TestAccessLogConfig(t *testing.T) {
config := AccessLogConfig{
EnableConsole: true,
EnableFile: true,
FilePath: "/var/log/app.log",
Format: "json",
Level: "info",
Rotation: struct {
MaxSize int `yaml:"max_size"`
MaxAge int `yaml:"max_age"`
MaxBackups int `yaml:"max_backups"`
Compress bool `yaml:"compress"`
LocalTime bool `yaml:"local_time"`
}{
MaxSize: 100,
MaxAge: 7,
MaxBackups: 5,
Compress: true,
LocalTime: true,
},
}
if !config.EnableConsole {
t.Error("Expected EnableConsole to be true")
}
if !config.EnableFile {
t.Error("Expected EnableFile to be true")
}
if config.FilePath != "/var/log/app.log" {
t.Errorf("Expected FilePath '/var/log/app.log', got '%s'", config.FilePath)
}
if config.Format != "json" {
t.Errorf("Expected Format 'json', got '%s'", config.Format)
}
if config.Level != "info" {
t.Errorf("Expected Level 'info', got '%s'", config.Level)
}
rotation := config.Rotation
if rotation.MaxSize != 100 {
t.Errorf("Expected MaxSize 100, got %d", rotation.MaxSize)
}
if rotation.MaxAge != 7 {
t.Errorf("Expected MaxAge 7, got %d", rotation.MaxAge)
}
if rotation.MaxBackups != 5 {
t.Errorf("Expected MaxBackups 5, got %d", rotation.MaxBackups)
}
if !rotation.Compress {
t.Error("Expected Compress to be true")
}
if !rotation.LocalTime {
t.Error("Expected LocalTime to be true")
}
}
func TestUploadConfig(t *testing.T) {
config := UploadConfig{
MaxSize: 10,
AllowedTypes: []string{"image/jpeg", "image/png"},
AllowedExtensions: []string{".jpg", ".png"},
}
if config.MaxSize != 10 {
t.Errorf("Expected MaxSize 10, got %d", config.MaxSize)
}
if len(config.AllowedTypes) != 2 {
t.Errorf("Expected 2 AllowedTypes, got %d", len(config.AllowedTypes))
}
if config.AllowedTypes[0] != "image/jpeg" {
t.Errorf("Expected AllowedTypes[0] 'image/jpeg', got '%s'", config.AllowedTypes[0])
}
if len(config.AllowedExtensions) != 2 {
t.Errorf("Expected 2 AllowedExtensions, got %d", len(config.AllowedExtensions))
}
if config.AllowedExtensions[0] != ".jpg" {
t.Errorf("Expected AllowedExtensions[0] '.jpg', got '%s'", config.AllowedExtensions[0])
}
}

View file

@ -1,21 +0,0 @@
package types
import "testing"
func TestFileInfo(t *testing.T) {
fileInfo := FileInfo{
Size: 1024,
Name: "test.jpg",
ContentType: "image/jpeg",
}
if fileInfo.Size != 1024 {
t.Errorf("Expected Size 1024, got %d", fileInfo.Size)
}
if fileInfo.Name != "test.jpg" {
t.Errorf("Expected Name 'test.jpg', got '%s'", fileInfo.Name)
}
if fileInfo.ContentType != "image/jpeg" {
t.Errorf("Expected ContentType 'image/jpeg', got '%s'", fileInfo.ContentType)
}
}

View file

@ -1,77 +0,0 @@
package types
import (
"testing"
)
func TestCategory(t *testing.T) {
description := "Test Description"
category := Category{
ID: 1,
Name: "Test Category",
Slug: "test-category",
Description: &description,
}
if category.ID != 1 {
t.Errorf("Expected ID 1, got %d", category.ID)
}
if category.Name != "Test Category" {
t.Errorf("Expected name 'Test Category', got '%s'", category.Name)
}
if category.Slug != "test-category" {
t.Errorf("Expected slug 'test-category', got '%s'", category.Slug)
}
if *category.Description != description {
t.Errorf("Expected description '%s', got '%s'", description, *category.Description)
}
}
func TestPost(t *testing.T) {
metaKeywords := "test,blog"
metaDesc := "Test Description"
post := Post{
ID: 1,
Title: "Test Post",
Slug: "test-post",
ContentMarkdown: "# Test Content",
Summary: "Test Summary",
MetaKeywords: &metaKeywords,
MetaDescription: &metaDesc,
}
if post.ID != 1 {
t.Errorf("Expected ID 1, got %d", post.ID)
}
if post.Title != "Test Post" {
t.Errorf("Expected title 'Test Post', got '%s'", post.Title)
}
if post.Slug != "test-post" {
t.Errorf("Expected slug 'test-post', got '%s'", post.Slug)
}
if *post.MetaKeywords != metaKeywords {
t.Errorf("Expected meta keywords '%s', got '%s'", metaKeywords, *post.MetaKeywords)
}
}
func TestDaily(t *testing.T) {
daily := Daily{
ID: "2025-02-12",
CategoryID: 1,
ImageURL: "https://example.com/image.jpg",
Quote: "Test Quote",
}
if daily.ID != "2025-02-12" {
t.Errorf("Expected ID '2025-02-12', got '%s'", daily.ID)
}
if daily.CategoryID != 1 {
t.Errorf("Expected CategoryID 1, got %d", daily.CategoryID)
}
if daily.ImageURL != "https://example.com/image.jpg" {
t.Errorf("Expected ImageURL 'https://example.com/image.jpg', got '%s'", daily.ImageURL)
}
if daily.Quote != "Test Quote" {
t.Errorf("Expected Quote 'Test Quote', got '%s'", daily.Quote)
}
}

View file

@ -1,77 +0,0 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoad(t *testing.T) {
// Create a temporary test config file
testConfig := `
database:
driver: postgres
dsn: postgres://user:pass@localhost:5432/db
server:
port: 8080
host: localhost
jwt:
secret: test-secret
expiration: 24h
logging:
level: debug
format: console
`
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(configPath, []byte(testConfig), 0644); err != nil {
t.Fatalf("Failed to create test config file: %v", err)
}
// Test successful config loading
cfg, err := Load(configPath)
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
// Verify loaded values
tests := []struct {
name string
got interface{}
expected interface{}
}{
{"database.driver", cfg.Database.Driver, "postgres"},
{"database.dsn", cfg.Database.DSN, "postgres://user:pass@localhost:5432/db"},
{"server.port", cfg.Server.Port, 8080},
{"server.host", cfg.Server.Host, "localhost"},
{"jwt.secret", cfg.JWT.Secret, "test-secret"},
{"jwt.expiration", cfg.JWT.Expiration, "24h"},
{"logging.level", cfg.Logging.Level, "debug"},
{"logging.format", cfg.Logging.Format, "console"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.got != tt.expected {
t.Errorf("Config %s = %v, want %v", tt.name, tt.got, tt.expected)
}
})
}
// Test loading non-existent file
_, err = Load("non-existent.yaml")
if err == nil {
t.Error("Expected error when loading non-existent file, got nil")
}
// Test loading invalid YAML
invalidPath := filepath.Join(tmpDir, "invalid.yaml")
if err := os.WriteFile(invalidPath, []byte("invalid: yaml: content"), 0644); err != nil {
t.Fatalf("Failed to create invalid config file: %v", err)
}
_, err = Load(invalidPath)
if err == nil {
t.Error("Expected error when loading invalid YAML, got nil")
}
}

View file

@ -1,100 +0,0 @@
package imageutil
import (
"bytes"
"image"
"image/color"
"image/png"
"testing"
)
func TestIsImageFormat(t *testing.T) {
tests := []struct {
name string
contentType string
want bool
}{
{"JPEG", "image/jpeg", true},
{"PNG", "image/png", true},
{"GIF", "image/gif", true},
{"WebP", "image/webp", true},
{"Invalid", "image/invalid", false},
{"Empty", "", false},
{"Text", "text/plain", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsImageFormat(tt.contentType); got != tt.want {
t.Errorf("IsImageFormat(%q) = %v, want %v", tt.contentType, got, tt.want)
}
})
}
}
func TestDefaultOptions(t *testing.T) {
opts := DefaultOptions()
if !opts.Lossless {
t.Error("DefaultOptions().Lossless = false, want true")
}
if opts.Quality != 90 {
t.Errorf("DefaultOptions().Quality = %v, want 90", opts.Quality)
}
if opts.Compression != 4 {
t.Errorf("DefaultOptions().Compression = %v, want 4", opts.Compression)
}
}
func TestProcessImage(t *testing.T) {
// Create a test image
img := image.NewRGBA(image.Rect(0, 0, 100, 100))
for y := 0; y < 100; y++ {
for x := 0; x < 100; x++ {
img.Set(x, y, color.RGBA{R: 255, G: 0, B: 0, A: 255})
}
}
var buf bytes.Buffer
if err := png.Encode(&buf, img); err != nil {
t.Fatalf("Failed to create test PNG: %v", err)
}
tests := []struct {
name string
opts ProcessOptions
wantErr bool
}{
{
name: "Default options",
opts: DefaultOptions(),
},
{
name: "Custom quality",
opts: ProcessOptions{
Lossless: false,
Quality: 75,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := bytes.NewReader(buf.Bytes())
result, err := ProcessImage(reader, tt.opts)
if (err != nil) != tt.wantErr {
t.Errorf("ProcessImage() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && len(result) == 0 {
t.Error("ProcessImage() returned empty result")
}
})
}
// Test with invalid input
_, err := ProcessImage(bytes.NewReader([]byte("invalid image data")), DefaultOptions())
if err == nil {
t.Error("ProcessImage() with invalid input should return error")
}
}

View file

@ -1,85 +0,0 @@
package logger
import (
"testing"
"tss-rocks-be/internal/config"
"github.com/rs/zerolog"
)
func TestSetup(t *testing.T) {
tests := []struct {
name string
config *config.Config
expectedLevel zerolog.Level
}{
{
name: "Debug level",
config: &config.Config{
Logging: struct {
Level string `yaml:"level"`
Format string `yaml:"format"`
}{
Level: "debug",
Format: "json",
},
},
expectedLevel: zerolog.DebugLevel,
},
{
name: "Info level",
config: &config.Config{
Logging: struct {
Level string `yaml:"level"`
Format string `yaml:"format"`
}{
Level: "info",
Format: "json",
},
},
expectedLevel: zerolog.InfoLevel,
},
{
name: "Error level",
config: &config.Config{
Logging: struct {
Level string `yaml:"level"`
Format string `yaml:"format"`
}{
Level: "error",
Format: "json",
},
},
expectedLevel: zerolog.ErrorLevel,
},
{
name: "Invalid level defaults to Info",
config: &config.Config{
Logging: struct {
Level string `yaml:"level"`
Format string `yaml:"format"`
}{
Level: "invalid",
Format: "json",
},
},
expectedLevel: zerolog.InfoLevel,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Setup(tt.config)
if zerolog.GlobalLevel() != tt.expectedLevel {
t.Errorf("Setup() set level to %v, want %v", zerolog.GlobalLevel(), tt.expectedLevel)
}
})
}
}
func TestGetLogger(t *testing.T) {
logger := GetLogger()
if logger == nil {
t.Error("GetLogger() returned nil")
}
}