package server

import (
	"context"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestInitDatabase exercises InitDatabase with a working SQLite configuration
// and with configurations that InitDatabase is expected to reject.
func TestInitDatabase(t *testing.T) {
	tests := []struct {
		name        string
		driver      string
		dsn         string
		wantErr     bool
		errContains string
	}{
		{
			name:   "success with sqlite3",
			driver: "sqlite3",
			// Named shared-cache in-memory database with foreign keys enabled (_fk=1).
			dsn: "file:ent?mode=memory&cache=shared&_fk=1",
		},
		{
			name:        "invalid driver",
			driver:      "invalid_driver",
			dsn:         "file:ent?mode=memory",
			wantErr:     true,
			errContains: "unsupported driver",
		},
		{
			name:   "invalid dsn",
			driver: "sqlite3",
			// In-memory database with an unrecognized option and WITHOUT _fk=1.
			// The asserted error concerns the foreign_keys pragma, so this case
			// fails because foreign keys are not enabled, not because of the
			// unknown option.
			dsn:         "file::memory:?not_exist_option=1",
			wantErr:     true,
			errContains: "foreign_keys pragma is off",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := context.Background()

			client, err := InitDatabase(ctx, tt.driver, tt.dsn)
			if tt.wantErr {
				// require (not assert): err.Error() below would panic on a nil err.
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
				assert.Nil(t, client)
				return
			}

			require.NoError(t, err)
			// require (not assert): client.Schema below would panic on a nil client.
			require.NotNil(t, client)
			// Close the client even if a later assertion stops the subtest,
			// and surface any error from closing it.
			t.Cleanup(func() {
				assert.NoError(t, client.Close())
			})

			// Creating the schema verifies that the connection is actually usable.
			assert.NoError(t, client.Schema.Create(ctx))
		})
	}
}