diff --git a/sqlx/column.go b/sqlx/column.go index 92b2dc844c2caacb89d4c8890ad0c42422c5b155..811e57b0a85eb78947f99cb98bc612ac2fa6633a 100644 --- a/sqlx/column.go +++ b/sqlx/column.go @@ -12,14 +12,14 @@ import ( type ColumnTypeBuilder struct { table *TableBuilder - column string + name string buffer *bytes.Buffer } -func newColumnType(table *TableBuilder, column string) *ColumnTypeBuilder { +func newColumnType(table *TableBuilder, name string) *ColumnTypeBuilder { ctb := &ColumnTypeBuilder{ table: table, - column: column, + name: name, buffer: new(bytes.Buffer), } return ctb @@ -163,20 +163,29 @@ func (slf *ColumnTypeBuilder) Time() *ColumnTypeBuilder { // func (slf *ColumnTypeBuilder) Unique() *ColumnTypeBuilder { - slf.table.addConstraint(fmt.Sprintf("UNIQUE(%s)", EscapeName(slf.column))) + return slf.UniqueNamed("") +} + +func (slf *ColumnTypeBuilder) UniqueNamed(name string) *ColumnTypeBuilder { + slf.table.addUnique(name, EscapeName(slf.name)) return slf } func (slf *ColumnTypeBuilder) PrimaryKey() *ColumnTypeBuilder { - slf.table.addConstraint(fmt.Sprintf("PRIMARY KEY(%s)", EscapeName(slf.column))) + slf.table.addConstraint(fmt.Sprintf("PRIMARY KEY(%s)", EscapeName(slf.name))) return slf } -func (slf *ColumnTypeBuilder) ForeignKey(refTable string, refColumn string) *ColumnTypeBuilder { - slf.table.addConstraint(fmt.Sprintf("FOREIGN KEY (%s) REFERENCES %s(%s)", slf.column, EscapeName(refTable), EscapeName(refColumn))) +func (slf *ColumnTypeBuilder) ForeignKeyNamed(fkName string, refTableName string, refColumnName string) *ColumnTypeBuilder { + slf.table.addConstraint(fmt.Sprintf("%sFOREIGN KEY (%s) REFERENCES %s(%s)", + namedConstraint(fkName), EscapeName(slf.name), EscapeName(refTableName), EscapeName(refColumnName))) return slf } +func (slf *ColumnTypeBuilder) ForeignKey(refTableName string, refColumnName string) *ColumnTypeBuilder { + return slf.ForeignKeyNamed("", refTableName, refColumnName) +} + func (slf *ColumnTypeBuilder) NotNull() *ColumnTypeBuilder { slf.addKey("NOT NULL") return slf @@ -189,8 +198,8 @@ func (slf *ColumnTypeBuilder) AutoIncrement() *ColumnTypeBuilder { // -func (slf *ColumnTypeBuilder) DefaultGetDate() *ColumnTypeBuilder { - slf.addKey("DEFAULT GETDATE()") +func (slf *ColumnTypeBuilder) DefaultNow() *ColumnTypeBuilder { + slf.addKey("DEFAULT NOW()") return slf } @@ -216,9 +225,9 @@ func (slf *ColumnTypeBuilder) Default(value interface{}) *ColumnTypeBuilder { // -func (slf *ColumnTypeBuilder) Column(column string) *ColumnTypeBuilder { +func (slf *ColumnTypeBuilder) Column(name string) *ColumnTypeBuilder { slf.columnDefineComplete() - return newColumnType(slf.table, column) + return newColumnType(slf.table, name) } func (slf *ColumnTypeBuilder) GetSQL() string { @@ -233,7 +242,7 @@ func (slf *ColumnTypeBuilder) Execute(prepare SQLPrepare) *Executor { // func (slf *ColumnTypeBuilder) columnDefineComplete() { - slf.table.addColumn(slf.column, slf.buffer.String()) + slf.table.addColumn(slf.name, slf.buffer.String()) } func (slf *ColumnTypeBuilder) addKeyWithSize(key string, size int) { diff --git a/sqlx/create_index.go b/sqlx/create_index.go new file mode 100644 index 0000000000000000000000000000000000000000..b30ae4ec99063ca57de9f3c4d8a1d747d66c92c7 --- /dev/null +++ b/sqlx/create_index.go @@ -0,0 +1,79 @@ +package sqlx + +import ( + "bytes" + "strings" +) + +// +// Author: 陈永佳 chenyongjia@parkingwang.com, yoojiachen@gmail.com +// + +type CreateIndexBuilder struct { + table string + name string + columns []string + unique bool +} + +func CreateIndex(indexName string) *CreateIndexBuilder { + return &CreateIndexBuilder{ + name: indexName, + columns: make([]string, 0), + } +} + +func (slf *CreateIndexBuilder) Unique() *CreateIndexBuilder { + slf.unique = true + return slf +} + +func (slf *CreateIndexBuilder) OnTable(table string) *CreateIndexBuilder { + slf.table = table + return slf +} + +func (slf *CreateIndexBuilder) Column(name string, desc bool) *CreateIndexBuilder { + var column string + if desc { + column = EscapeName(name) + SQLSpace + "DESC" + } else { + column = EscapeName(column) + } + slf.columns = append(slf.columns, column) + return slf +} + +func (slf *CreateIndexBuilder) Columns(columns ...string) *CreateIndexBuilder { + slf.columns = append(slf.columns, Map(columns, EscapeName)...) + return slf +} + +func (slf *CreateIndexBuilder) build() *bytes.Buffer { + if "" == slf.table { + panic("table not found, you should call 'Table(table)' method to set it") + } + + buf := new(bytes.Buffer) + buf.WriteString("CREATE ") + if slf.unique { + buf.WriteString("UNIQUE ") + } + buf.WriteString("INDEX ") + buf.WriteString(EscapeName(slf.name)) + buf.WriteString(" ON ") + buf.WriteString(EscapeName(slf.table)) + buf.WriteByte('(') + // 在输入时已经转义 + buf.WriteString(strings.Join(slf.columns, SQLComma)) + buf.WriteByte(')') + return buf +} + +func (slf *CreateIndexBuilder) GetSQL() string { + return makeSQL(slf.build()) +} + +func (slf *CreateIndexBuilder) Execute(prepare SQLPrepare) *Executor { + return newExecute(slf.GetSQL(), prepare) +} diff --git a/sqlx/create_index_test.go b/sqlx/create_index_test.go new file mode 100644 index 0000000000000000000000000000000000000000..8b286b7632b04fae6a597decd7984bfc3d4032ab --- /dev/null +++ b/sqlx/create_index_test.go @@ -0,0 +1,18 @@ +package sqlx + +import "testing" + +// +// Author: 陈永佳 chenyongjia@parkingwang.com, yoojiachen@gmail.com +// + +func TestCreateIndex(t *testing.T) { + sql := CreateIndex("PersonIndex"). + Unique(). + OnTable("t_users"). + Columns("LastName", "FirstName"). + Column("Age", true). + GetSQL() + + checkSQLMatches(sql, "CREATE UNIQUE INDEX `PersonIndex` ON `t_users`(`LastName`, `FirstName`, `Age` DESC);", t) +} diff --git a/sqlx/drop_index.go b/sqlx/drop_index.go new file mode 100644 index 0000000000000000000000000000000000000000..804c6211e17059391ea2b716bd2a24425d038d68 --- /dev/null +++ b/sqlx/drop_index.go @@ -0,0 +1,46 @@ +package sqlx + +import ( + "bytes" +) + +// +// Author: 陈永佳 chenyongjia@parkingwang.com, yoojiachen@gmail.com +// + +type DropIndexBuilder struct { + name string + table string +} + +func DropIndex(indexName string) *DropIndexBuilder { + return &DropIndexBuilder{ + name: indexName, + } +} + +func (slf *DropIndexBuilder) OnTable(table string) *DropIndexBuilder { + slf.table = table + return slf +} + +func (slf *DropIndexBuilder) build() *bytes.Buffer { + if "" == slf.table { + panic("table not found, you should call 'OnTable(table)' method to set it") + } + //ALTER TABLE table_name DROP INDEX index_name + buf := new(bytes.Buffer) + buf.WriteString("ALTER TABLE ") + buf.WriteString(EscapeName(slf.table)) + buf.WriteString(" DROP INDEX ") + buf.WriteString(EscapeName(slf.name)) + return buf +} + +func (slf *DropIndexBuilder) GetSQL() string { + return makeSQL(slf.build()) +} + +func (slf *DropIndexBuilder) Execute(prepare SQLPrepare) *Executor { + return newExecute(slf.GetSQL(), prepare) +} diff --git a/sqlx/drop_index_test.go b/sqlx/drop_index_test.go new file mode 100644 index 0000000000000000000000000000000000000000..7db27428d100e0bd6668794acdd03a490ee59ac7 --- /dev/null +++ b/sqlx/drop_index_test.go @@ -0,0 +1,14 @@ +package sqlx + +import "testing" + +// +// Author: 陈永佳 chenyongjia@parkingwang.com, yoojiachen@gmail.com +// + +func TestDropIndex(t *testing.T) { + sql := DropIndex("idx_Uid"). + OnTable("t_username"). + GetSQL() + checkSQLMatches(sql, "ALTER TABLE `t_username` DROP INDEX `idx_Uid`;", t) +} diff --git a/sqlx/limit.go b/sqlx/limit.go index 9c55f04c4dd1370bb29c9d3965d0c3c1d596ea55..4a837a5ab431c106f9f8f1395938bc38452c6b37 100644 --- a/sqlx/limit.go +++ b/sqlx/limit.go @@ -2,7 +2,7 @@ package sqlx import ( "bytes" - "fmt" + "strconv" ) // @@ -19,13 +19,13 @@ func newLimit(preStatement SQLStatement, limit int) *LimitBuilder { } lb.buffer.WriteString(preStatement.Statement()) lb.buffer.WriteString(" LIMIT ") - lb.buffer.WriteString(fmt.Sprintf("%d", limit)) + lb.buffer.WriteString(strconv.Itoa(limit)) return lb } func (slf *LimitBuilder) Offset(offset int) *LimitBuilder { slf.buffer.WriteString(" OFFSET ") - slf.buffer.WriteString(fmt.Sprintf("%d", offset)) + slf.buffer.WriteString(strconv.Itoa(offset)) return slf } diff --git a/sqlx/table.go b/sqlx/table.go index b5b5643340190d2bdb8bffc77b0bf3afd60f8d2a..b690f7471d17ac1e27307e61855610dad1bf8a45 100644 --- a/sqlx/table.go +++ b/sqlx/table.go @@ -2,7 +2,7 @@ package sqlx import ( "bytes" - "fmt" + "strconv" "strings" ) @@ -10,22 +10,27 @@ import ( // Author: 陈永佳 chenyongjia@parkingwang.com, yoojiachen@gmail.com // -//// +type columnDefine struct { + name string + defines string +} type TableBuilder struct { table string - columns map[string]string // 类似:{username: VARCHAR(255) NOT NULL} - constraints []string // 约束列表 - charset string // 表字符编码 - autoIncrement int // 自增编号起步值 + columns []columnDefine + constraints []string // 通用的约束列表 + uniques map[string][]string // Unique约束列表,根据名称来合并其Column。默认合并在 “” 组 + charset string + autoIncrement int ifNotExists bool } func CreateTable(table string) *TableBuilder { return &TableBuilder{ table: table, - columns: make(map[string]string), + columns: make([]columnDefine, 0), constraints: make([]string, 0), + uniques: make(map[string][]string), charset: "utf8", autoIncrement: 0, ifNotExists: true, @@ -52,7 +57,23 @@ func (slf *TableBuilder) Column(name string) *ColumnTypeBuilder { } func (slf *TableBuilder) addColumn(name string, defines string) { - slf.columns[name] = defines + for _, d := range slf.columns { + if d.name == name { + panic("Duplicated column define, name: " + name) + } + } + slf.columns = append(slf.columns, columnDefine{ + name: name, + defines: defines, + }) +} + +func (slf *TableBuilder) addUnique(name string, column string) { + if exists, ok := slf.uniques[name]; ok { + slf.uniques[name] = append(exists, column) + } else { + slf.uniques[name] = append(make([]string, 0), column) + } } func (slf *TableBuilder) addConstraint(constraint string) { @@ -60,9 +81,19 @@ func (slf *TableBuilder) addConstraint(constraint string) { } func (slf *TableBuilder) build() *bytes.Buffer { + // 数据列 columns := make([]string, 0) - for name, defines := range slf.columns { - columns = append(columns, EscapeName(name)+defines) + for _, define := range slf.columns { + columns = append(columns, EscapeName(define.name)+define.defines) + } + + // 通用约束 + columns = append(columns, slf.constraints...) + + // Unique约束列 + for name, colNames := range slf.uniques { + constraint := namedConstraint(name) + "UNIQUE (" + strings.Join(colNames, SQLComma) + ")" + columns = append(columns, constraint) } buf := new(bytes.Buffer) @@ -72,12 +103,12 @@ func (slf *TableBuilder) build() *bytes.Buffer { } buf.WriteString(EscapeName(slf.table)) buf.WriteByte('(') - buf.WriteString(strings.Join(append(columns, slf.constraints...), SQLComma)) + buf.WriteString(strings.Join(columns, SQLComma)) buf.WriteByte(')') buf.WriteString(" DEFAULT CHARSET=") buf.WriteString(slf.charset) buf.WriteString(" AUTO_INCREMENT=") - buf.WriteString(fmt.Sprintf("%d", slf.autoIncrement)) + buf.WriteString(strconv.Itoa(slf.autoIncrement)) return buf } @@ -85,6 +116,10 @@ func (slf *TableBuilder) GetSQL() string { return makeSQL(slf.build()) } -func (slf *TableBuilder) Execute(prepare SQLPrepare) *Executor { - return newExecute(slf.GetSQL(), prepare) +func namedConstraint(name string) string { + if len(name) > 0 { + return "CONSTRAINT " + EscapeName(name) + SQLSpace + } else { + return "" + } } diff --git a/sqlx/table_test.go b/sqlx/table_test.go index 22328458543fa589da333878782085ab85f4c2b5..42e7ea45403b331b8c814dc920659408c11e12c9 100644 --- a/sqlx/table_test.go +++ b/sqlx/table_test.go @@ -1,7 +1,6 @@ package sqlx import ( - "fmt" "testing" ) @@ -15,8 +14,54 @@ func TestCreateTable(t *testing.T) { Column("username").VarChar(255).NotNull().Unique(). Column("password").VarChar(255).NotNull(). Column("age").Int(2).Default0(). - Column("register_time").Date().DefaultGetDate(). + Column("register_time").Date().DefaultNow(). GetSQL() - // CREATE TABLE `t_users`(`register_time` DATE DEFAULT GETDATE(), `id` INT(20) NOT NULL AUTO_INCREMENT, `username` VARCHAR(255) NOT NULL, `password` VARCHAR(255) NOT NULL, `age` INT(2) DEFAULT 0, PRIMARY KEY(`id`), UNIQUE(`username`)) DEFAULT CHARSET=utf8 AUTO_INCREMENT=0; - fmt.Println(sql) + checkSQLMatches(sql, "CREATE TABLE IF NOT EXISTS `t_users`("+ + "`id` INT(20) NOT NULL AUTO_INCREMENT, "+ + "`username` VARCHAR(255) NOT NULL, "+ + "`password` VARCHAR(255) NOT NULL, "+ + "`age` INT(2) DEFAULT 0, "+ + "`register_time` DATE DEFAULT NOW(), "+ + "PRIMARY KEY(`id`), "+ + "UNIQUE (`username`)"+ + ") DEFAULT CHARSET=utf8 AUTO_INCREMENT=0;", t) +} + +func TestTableBuilder_ForeignKey(t *testing.T) { + sql := CreateTable("t_user"). + Column("id").Int(12).NotNull().PrimaryKey().AutoIncrement(). + Column("pid").Int(12).ForeignKey("t_profile", "prof_id"). + GetSQL() + checkSQLMatches(sql, "CREATE TABLE IF NOT EXISTS `t_user`("+ + "`id` INT(12) NOT NULL AUTO_INCREMENT, "+ + "`pid` INT(12), PRIMARY KEY(`id`), "+ + "FOREIGN KEY (`pid`) REFERENCES `t_profile`(`prof_id`)"+ + ") DEFAULT CHARSET=utf8 AUTO_INCREMENT=0;", t) +} + +func TestTableBuilder_ForeignKeyNamed(t *testing.T) { + sql := CreateTable("t_user"). + Column("id").Int(12).NotNull().PrimaryKey().AutoIncrement(). + Column("pid").Int(12).ForeignKeyNamed("FK_PID", "t_profile", "prof_id"). + GetSQL() + checkSQLMatches(sql, "CREATE TABLE IF NOT EXISTS `t_user`("+ + "`id` INT(12) NOT NULL AUTO_INCREMENT, "+ + "`pid` INT(12), PRIMARY KEY(`id`), "+ + "CONSTRAINT `FK_PID` FOREIGN KEY (`pid`) REFERENCES `t_profile`(`prof_id`)"+ + ") DEFAULT CHARSET=utf8 AUTO_INCREMENT=0;", t) +} + +func TestTableBuilder_UniqueNamed(t *testing.T) { + sql := CreateTable("t_user"). + Column("id").Int(12).NotNull().PrimaryKey().AutoIncrement(). + Column("pid").Int(12).UniqueNamed("uc_Id_P"). + Column("pid_bak").Int(12).UniqueNamed("uc_Id_P"). + GetSQL() + checkSQLMatches(sql, "CREATE TABLE IF NOT EXISTS `t_user`("+ + "`id` INT(12) NOT NULL AUTO_INCREMENT, "+ + "`pid` INT(12), "+ + "`pid_bak` INT(12), "+ + "PRIMARY KEY(`id`), "+ + "CONSTRAINT `uc_Id_P` UNIQUE (`pid`, `pid_bak`)"+ + ") DEFAULT CHARSET=utf8 AUTO_INCREMENT=0;", t) }