@@ -22,28 +22,34 @@ const (
2222 LowerRestoreFlag = format .RestoreStringSingleQuotes | format .RestoreKeyWordLowercase | format .RestoreNameLowercase | format .RestoreNameBackQuotes
2323)
2424
25- // Column ...
26- type Column struct {
27- Node
25+ type SqlAttr struct {
2826 MysqlType * types.FieldType
2927 PgType * ptypes.T
3028 LiteType * sqlite.Type
3129 Options []* ast.ColumnOption
3230 Comment string
3331}
3432
33+ // Column ...
34+ type Column struct {
35+ Node
36+
37+ CurrentAttr SqlAttr
38+ PreviousAttr SqlAttr
39+ }
40+
3541// GetType ...
3642func (c Column ) GetType () byte {
37- if c .MysqlType != nil {
38- return c .MysqlType .Tp
43+ if c .CurrentAttr . MysqlType != nil {
44+ return c .CurrentAttr . MysqlType .Tp
3945 }
4046
4147 return 0
4248}
4349
4450// HasDefaultValue ...
4551func (c Column ) HasDefaultValue () bool {
46- for _ , opt := range c .Options {
52+ for _ , opt := range c .CurrentAttr . Options {
4753 if opt .Tp == ast .ColumnOptionDefaultValue {
4854 return true
4955 }
@@ -54,7 +60,7 @@ func (c Column) HasDefaultValue() bool {
5460
5561func (c Column ) hashValue () string {
5662 strHash := sql .EscapeSqlName (c .Name )
57- strHash += c .typeDefinition ()
63+ strHash += c .typeDefinition (false )
5864 hash := md5 .Sum ([]byte (strHash ))
5965 return hex .EncodeToString (hash [:])
6066}
@@ -71,7 +77,7 @@ func (c Column) migrationUp(tbName, after string, ident int) []string {
7177 strSql += strings .Repeat (" " , ident - len (c .Name ))
7278 }
7379
74- strSql += c .definition ()
80+ strSql += c .definition (false )
7581
7682 if ident < 0 {
7783 if after != "" {
@@ -90,10 +96,27 @@ func (c Column) migrationUp(tbName, after string, ident int) []string {
9096 return []string {fmt .Sprintf (sql .AlterTableDropColumnStm (), sql .EscapeSqlName (tbName ), sql .EscapeSqlName (c .Name ))}
9197
9298 case MigrateModifyAction :
93- def := strings .Replace (c .definition (), sql .PrimaryOption (), "" , 1 )
99+ def , isPk := c .pkDefinition (false )
100+ if isPk {
101+ if _ , isPrevPk := c .pkDefinition (true ); isPrevPk {
102+ // avoid repeat define primary key
103+ def = strings .Replace (def , " " + sql .PrimaryOption (), "" , 1 )
104+ }
105+ }
94106
95107 return []string {fmt .Sprintf (sql .AlterTableModifyColumnStm (), sql .EscapeSqlName (tbName ), sql .EscapeSqlName (c .Name )+ def )}
96108
109+ case MigrateRevertAction :
110+ prevDef , isPrevPk := c .pkDefinition (true )
111+ if isPrevPk {
112+ if _ , isPk := c .pkDefinition (false ); isPk {
113+ // avoid repeat define primary key
114+ prevDef = strings .Replace (prevDef , " " + sql .PrimaryOption (), "" , 1 )
115+ }
116+ }
117+
118+ return []string {fmt .Sprintf (sql .AlterTableModifyColumnStm (), sql .EscapeSqlName (tbName ), sql .EscapeSqlName (c .Name )+ prevDef )}
119+
97120 case MigrateRenameAction :
98121 return []string {fmt .Sprintf (sql .AlterTableRenameColumnStm (), sql .EscapeSqlName (tbName ), sql .EscapeSqlName (c .OldName ), sql .EscapeSqlName (c .Name ))}
99122
@@ -103,12 +126,12 @@ func (c Column) migrationUp(tbName, after string, ident int) []string {
103126}
104127
105128func (c Column ) migrationCommentUp (tbName string ) []string {
106- if c .Comment == "" || sql .GetDialect () != sql_templates .PostgresDialect {
129+ if c .CurrentAttr . Comment == "" || sql .GetDialect () != sql_templates .PostgresDialect {
107130 return nil
108131 }
109132
110133 // apply for postgres only
111- return []string {fmt .Sprintf (sql .ColumnComment (), tbName , c .Name , c .Comment )}
134+ return []string {fmt .Sprintf (sql .ColumnComment (), tbName , c .Name , c .CurrentAttr . Comment )}
112135}
113136
114137func (c Column ) migrationDown (tbName , after string ) []string {
@@ -123,7 +146,7 @@ func (c Column) migrationDown(tbName, after string) []string {
123146 c .Action = MigrateAddAction
124147
125148 case MigrateModifyAction :
126- return nil
149+ c . Action = MigrateRevertAction
127150
128151 case MigrateRenameAction :
129152 c .Name , c .OldName = c .OldName , c .Name
@@ -135,10 +158,19 @@ func (c Column) migrationDown(tbName, after string) []string {
135158 return c .migrationUp (tbName , after , - 1 )
136159}
137160
138- func (c Column ) definition () string {
139- strSql := c .typeDefinition ()
161+ func (c Column ) pkDefinition (isPrev bool ) (string , bool ) {
162+ attr := c .CurrentAttr
163+ if isPrev {
164+ attr = c .PreviousAttr
165+ }
166+ strSql := c .typeDefinition (isPrev )
167+
168+ isPrimaryKey := false
169+ for _ , opt := range attr .Options {
170+ if opt .Tp == ast .ColumnOptionPrimaryKey {
171+ isPrimaryKey = true
172+ }
140173
141- for _ , opt := range c .Options {
142174 b := bytes .NewBufferString ("" )
143175 var ctx * format.RestoreCtx
144176
@@ -157,17 +189,27 @@ func (c Column) definition() string {
157189 strSql += " " + b .String ()
158190 }
159191
160- return strSql
192+ return strSql , isPrimaryKey
193+ }
194+
195+ func (c Column ) definition (isPrev bool ) string {
196+ def , _ := c .pkDefinition (isPrev )
197+ return def
161198}
162199
163- func (c Column ) typeDefinition () string {
200+ func (c Column ) typeDefinition (isPrev bool ) string {
201+ attr := c .CurrentAttr
202+ if isPrev {
203+ attr = c .PreviousAttr
204+ }
205+
164206 switch {
165- case sql .IsPostgres () && c .PgType != nil :
166- return " " + c .PgType .SQLString ()
167- case sql .IsSqlite () && c .LiteType != nil :
168- return " " + c .LiteType .Name .Name
169- case c .MysqlType != nil :
170- return " " + c .MysqlType .String ()
207+ case sql .IsPostgres () && attr .PgType != nil :
208+ return " " + attr .PgType .SQLString ()
209+ case sql .IsSqlite () && attr .LiteType != nil :
210+ return " " + attr .LiteType .Name .Name
211+ case attr .MysqlType != nil :
212+ return " " + attr .MysqlType .String ()
171213 }
172214
173215 return "" // column type is empty
0 commit comments