@ -12,13 +12,17 @@ import (
"slices"
"strings"
"code.gitea.io/gitea/models/db"
"gopkg.in/yaml.v3"
"xorm.io/xorm"
"xorm.io/xorm/schemas"
)
type fixtureItem struct {
tableName string
type FixtureItem struct {
fileFullPath string
tableName string
tableNameQuoted string
sqlInserts [ ] string
sqlInsertArgs [ ] [ ] any
@ -27,10 +31,11 @@ type fixtureItem struct {
}
type fixturesLoaderInternal struct {
xormEngine * xorm . Engine
xormTableNames map [ string ] bool
db * sql . DB
dbType schemas . DBType
files [ ] string
fixtures map [ string ] * fixtureItem
fixtures map [ string ] * FixtureItem
quoteObject func ( string ) string
paramPlaceholder func ( idx int ) string
}
@ -59,29 +64,27 @@ func (f *fixturesLoaderInternal) preprocessFixtureRow(row []map[string]any) (err
return nil
}
func ( f * fixturesLoaderInternal ) prepareFixtureItem ( file string ) ( _ * fixtureItem , err error ) {
fixture := & fixtureItem { }
fixture . tableName , _ , _ = strings . Cut ( filepath . Base ( file ) , "." )
func ( f * fixturesLoaderInternal ) prepareFixtureItem ( fixture * FixtureItem ) ( err error ) {
fixture . tableNameQuoted = f . quoteObject ( fixture . tableName )
if f . dbType == schemas . MSSQL {
fixture . mssqlHasIdentityColumn , err = f . mssqlTableHasIdentityColumn ( f . db , fixture . tableName )
if err != nil {
return nil , err
return err
}
}
data , err := os . ReadFile ( fi le)
data , err := os . ReadFile ( fi xture. fi leFullPath )
if err != nil {
return nil , fmt . Errorf ( "failed to read file %q: %w" , fi le, err )
return fmt . Errorf ( "failed to read file %q: %w" , fi xture. fi leFullPath , err )
}
var rows [ ] map [ string ] any
if err = yaml . Unmarshal ( data , & rows ) ; err != nil {
return nil , fmt . Errorf ( "failed to unmarshal yaml data from %q: %w" , fi le, err )
return fmt . Errorf ( "failed to unmarshal yaml data from %q: %w" , fi xture. fi leFullPath , err )
}
if err = f . preprocessFixtureRow ( rows ) ; err != nil {
return nil , fmt . Errorf ( "failed to preprocess fixture rows from %q: %w" , fi le, err )
return fmt . Errorf ( "failed to preprocess fixture rows from %q: %w" , fi xture. fi leFullPath , err )
}
var sqlBuf [ ] byte
@ -107,16 +110,14 @@ func (f *fixturesLoaderInternal) prepareFixtureItem(file string) (_ *fixtureItem
sqlBuf = sqlBuf [ : 0 ]
sqlArguments = sqlArguments [ : 0 ]
}
return fixture , nil
return nil
}
func ( f * fixturesLoaderInternal ) loadFixtures ( tx * sql . Tx , file string ) ( err error ) {
fixture := f . fixtures [ file ]
if fixture == nil {
if fixture , err = f . prepareFixtureItem ( file ) ; err != nil {
func ( f * fixturesLoaderInternal ) loadFixtures ( tx * sql . Tx , fixture * FixtureItem ) ( err error ) {
if fixture . tableNameQuoted == "" {
if err = f . prepareFixtureItem ( fixture ) ; err != nil {
return err
}
f . fixtures [ file ] = fixture
}
_ , err = tx . Exec ( fmt . Sprintf ( "DELETE FROM %s" , fixture . tableNameQuoted ) ) // sqlite3 doesn't support truncate
@ -147,15 +148,26 @@ func (f *fixturesLoaderInternal) Load() error {
}
defer func ( ) { _ = tx . Rollback ( ) } ( )
for _ , file := range f . files {
if err := f . loadFixtures ( tx , file ) ; err != nil {
return fmt . Errorf ( "failed to load fixtures from %s: %w" , file , err )
for _ , fixture := range f . fixtures {
if ! f . xormTableNames [ fixture . tableName ] {
continue
}
if err := f . loadFixtures ( tx , fixture ) ; err != nil {
return fmt . Errorf ( "failed to load fixtures from %s: %w" , fixture . fileFullPath , err )
}
}
return tx . Commit ( )
if err = tx . Commit ( ) ; err != nil {
return err
}
for xormTableName := range f . xormTableNames {
if f . fixtures [ xormTableName ] == nil {
_ , _ = f . xormEngine . Exec ( "DELETE FROM `" + xormTableName + "`" )
}
}
return nil
}
func FixturesFileFullPaths ( dir string , files [ ] string ) ( [ ] string , error ) {
func FixturesFileFullPaths ( dir string , files [ ] string ) ( map [ string ] * FixtureItem , error ) {
if files != nil && len ( files ) == 0 {
return nil , nil // load nothing
}
@ -169,20 +181,25 @@ func FixturesFileFullPaths(dir string, files []string) ([]string, error) {
files = append ( files , e . Name ( ) )
}
}
for i , file := range files {
if ! filepath . IsAbs ( file ) {
files [ i ] = filepath . Join ( dir , file )
fixtureItems := map [ string ] * FixtureItem { }
for _ , file := range files {
fileFillPath := file
if ! filepath . IsAbs ( fileFillPath ) {
fileFillPath = filepath . Join ( dir , file )
}
tableName , _ , _ := strings . Cut ( filepath . Base ( file ) , "." )
fixtureItems [ tableName ] = & FixtureItem { fileFullPath : fileFillPath , tableName : tableName }
}
return files , nil
return fi xtureItem s, nil
}
func NewFixturesLoader ( x * xorm . Engine , opts FixturesOptions ) ( FixturesLoader , error ) {
fi le s, err := FixturesFileFullPaths ( opts . Dir , opts . Files )
fi xtureItem s, err := FixturesFileFullPaths ( opts . Dir , opts . Files )
if err != nil {
return nil , fmt . Errorf ( "failed to get fixtures files: %w" , err )
}
f := & fixturesLoaderInternal { db : x . DB ( ) . DB , dbType : x . Dialect ( ) . URI ( ) . DBType , files : files , fixtures : map [ string ] * fixtureItem { } }
f := & fixturesLoaderInternal { xormEngine : x , db : x . DB ( ) . DB , dbType : x . Dialect ( ) . URI ( ) . DBType , fixtures : fixtureItems }
switch f . dbType {
case schemas . SQLITE :
f . quoteObject = func ( s string ) string { return fmt . Sprintf ( ` "%s" ` , s ) }
@ -197,5 +214,12 @@ func NewFixturesLoader(x *xorm.Engine, opts FixturesOptions) (FixturesLoader, er
f . quoteObject = func ( s string ) string { return fmt . Sprintf ( "[%s]" , s ) }
f . paramPlaceholder = func ( idx int ) string { return "?" }
}
xormBeans , _ := db . NamesToBean ( )
f . xormTableNames = map [ string ] bool { }
for _ , bean := range xormBeans {
f . xormTableNames [ db . TableName ( bean ) ] = true
}
return f , nil
}