Gorm源码学习-创建行记录

发布时间 2023-03-22 19:06:49作者: Amos01

1. 前言

Gorm源码学习系列

此文是Gorm源码学习系列的第二篇,主要梳理下通过Gorm创建表的流程。

 

2. 创建行记录代码示例

gorm提供了以下几个接口来创建行记录

  • 一次创建一行 func (db *DB) Create(value interface{}) (tx *DB)
  • 批量创建 func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB)
  • 数据库不存在主键时创建,存在时更新 func (db *DB) Save(value interface{}) (tx *DB)

详细请看教程及源码finisher_api.go,这里使用func (db *DB) Create(value interface{}) (tx *DB)来说明创建行记录等大致流程。

 

2.1 声明模型

type Stu struct {
	ID     int64 `gorm:"column:id; primary_key" json:"id"`
	Age    int64 `gorm:"column:age;"`
	Height int64 `gorm:"column:height;"`
	Weight int64 `gorm:"column:weight;"`
}

// 设置表名
func (Stu) TableName() string {
	return "t_student"
}

模型代码的主要用途如下,

  • 申明的表中有哪些列及每列的名称、特性等,如gorm标签指定每个字断对于的表的列名
  • 通过实现Tabler接口指定了固定的表名,接口定义如下
type Tabler interface {
	TableName() string
}

关于模型定义中更多的约定和约束等,请看教程

出于分表等业务场景,我们并不希望固定模型等表名,gorm提供了func (db *DB) Table(name string, args ...interface{}) (tx *DB)等方法

来动态指定表名,详情请看教程

 

2.2 创建行

func main() {
	// 数据库连接, 具体查看https://www.cnblogs.com/amos01/p/16890747.html 连接数据库代码示例
	db, _ := dbOpen()
	// 打开调试模式、会打印DML
	db = db.Debug()
	stu := &Stu{
		Age:    18,
		Height: 185,
		Weight: 70,
	}
	db = db.Create(stu)
	fmt.Printf("Error:%v ID:%v RowsAffected:%v\n", db.Error, stu.ID, db.RowsAffected)
}

代码输出如下

$ go run main.go
2022/12/11 14:59:59 /Users/zbw/workspace/test/main.go:33
[1.910ms] [rows:1] INSERT INTO `t_student` (`age`,`height`,`weight`) VALUES (18,185,70)
Error:<nil> ID:1027 RowsAffected:1

从代码输出可以看,行记录的ID为1027,连接数据库查询,结果如下。

mysql> select * from t_student where id = 1027\G
*************************** 1. row ***************************
    id: 1027
   age: 18
height: 185
weight: 70
1 row in set (0.01 sec)

因此,我们带着以下问题来梳理下Gorm创建行记录的流程

  • 如何从model到DML语句的
  • 如何将ID写入到model的 

 

3. 从Model到DML

func (db *DB) Create(value interface{}) (tx *DB)的实现如下

// Create inserts value, returning the inserted data's primary key in value's id
func (db *DB) Create(value interface{}) (tx *DB) {
	if db.CreateBatchSize > 0 {
		return db.CreateInBatches(value, db.CreateBatchSize)
	}

	tx = db.getInstance()
	tx.Statement.Dest = value
	return tx.callbacks.Create().Execute(tx)
}

func (p *processor) Execute(db *DB) *DB 的实现比较长,具体代码见github

总结下来,做了两件主要的事情,

  • 解析model获取表名、每列的定义等
  • 执行钩子函数以及创建行函数

X

3.1 数据结构理解

  • gorm.Statement
查看gorm.Statement代码
// Statement statement
type Statement struct {
	*DB
	TableExpr            *clause.Expr
	Table                string      // 表名
	Model                interface{} // model定义
	Unscoped             bool
	Dest                 interface{} // model的另外一种表达,如map
	ReflectValue         reflect.Value
	Clauses              map[string]clause.Clause
	BuildClauses         []string
	Distinct             bool
	Selects              []string // selected columns
	Omits                []string // omit columns
	Joins                []join
	Preloads             map[string][]interface{}
	Settings             sync.Map
	ConnPool             ConnPool       // 数据库连接
	Schema               *schema.Schema // 表结构化信息
	Context              context.Context
	RaiseErrorOnNotFound bool
	SkipHooks            bool
	SQL                  strings.Builder // 最终的DML语句
	Vars                 []interface{}   // DML语句的参数值
	CurDestIndex         int             // 批量创建/更新时,gorm当前操作的数组/slice的下标
	attrs                []interface{}
	assigns              []interface{}
	scopes               []func(*DB) *DB
}
  • schema.Schem
查看schema.Schema代码
type Schema struct {
	Name                      string
	ModelType                 reflect.Type
	Table                     string // 表名
	PrioritizedPrimaryField   *Field
	DBNames                   []string // 表每列的名字
	PrimaryFields             []*Field
	PrimaryFieldDBNames       []string // 表的主键列明
	Fields                    []*Field // gorm自定义的model每个字短
	FieldsByName              map[string]*Field
	FieldsByDBName            map[string]*Field
	FieldsWithDefaultDBValue  []*Field // fields with default value assigned by database
	Relationships             Relationships
	CreateClauses             []clause.Interface // 创建行的子句
	QueryClauses              []clause.Interface
	UpdateClauses             []clause.Interface
	DeleteClauses             []clause.Interface
	BeforeCreate, AfterCreate bool
	BeforeUpdate, AfterUpdate bool
	BeforeDelete, AfterDelete bool
	BeforeSave, AfterSave     bool
	AfterFind                 bool
	err                       error
	initialized               chan struct{}
	namer                     Namer
	cacheStore                *sync.Map
}
  • schema.Field
查看schema.Field代码
// Field is the representation of model schema's field
type Field struct {
	Name                   string // model的字段名
	DBName                 string // 对应表的列名
	BindNames              []string
	DataType               DataType
	GORMDataType           DataType
	PrimaryKey             bool
	AutoIncrement          bool
	AutoIncrementIncrement int64
	Creatable              bool
	Updatable              bool
	Readable               bool
	AutoCreateTime         TimeType
	AutoUpdateTime         TimeType
	HasDefaultValue        bool
	DefaultValue           string
	DefaultValueInterface  interface{}
	NotNull                bool
	Unique                 bool
	Comment                string
	Size                   int
	Precision              int
	Scale                  int
	IgnoreMigration        bool
	FieldType              reflect.Type // 反射类型
	IndirectFieldType      reflect.Type // 反射类型
	StructField            reflect.StructField // model字段信息
	Tag                    reflect.StructTag // tag
	TagSettings            map[string]string
	Schema                 *Schema
	EmbeddedSchema         *Schema
	OwnerSchema            *Schema
	ReflectValueOf         func(context.Context, reflect.Value) reflect.Value                  // 通过反射获取该字段的反射对象
	ValueOf                func(context.Context, reflect.Value) (value interface{}, zero bool) // 通过反射获取该字段的值 get方法
	Set                    func(context.Context, reflect.Value, interface{}) error             // 通过反射设置该字段的值 set方法
	Serializer             SerializerInterface
	NewValuePool           FieldNewValuePool
}
  • clause.Interfaceclause.Clause

gorm定义了多种clause,包括

查看clause.Interface代码
// Interface clause interface
type Interface interface {
	Name() string
	Build(Builder)
	MergeClause(*Clause)
}
查看clause.Clause代码
// Clause
type Clause struct {
	Name                string // WHERE
	BeforeExpression    Expression
	AfterNameExpression Expression
	AfterExpression     Expression
	Expression          Expression
	Builder             ClauseBuilder
}

 

3.2 解析Model

通过调用stmt.Parse(stmt.Model)进行model解析

stmt.Parse(stmt.Model)会调用到函数func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) 进行解析。

详细代码见schema.go,下面列举重要的几个点。

  • 通过反射判断dest interface{}是否为reflect.Struct
  • 通过接口获取表名,其中stu实现了Tabler接口
	// 获取表名
	modelValue := reflect.New(modelType)
	tableName := namer.TableName(modelType.Name())
	if tabler, ok := modelValue.Interface().(Tabler); ok {
		tableName = tabler.TableName()
	}
	if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
		tableName = tabler.TableName(namer)
	}
	if en, ok := namer.(embeddedNamer); ok {
		tableName = en.Table
	}
	if specialTableName != "" && specialTableName != tableName {
		tableName = specialTableName
	}
  • 解析model每个字段
// 通过反射获取每个字段
for i := 0; i < modelType.NumField(); i++ {
    if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
        // 解析每个字段
        if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
            schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
        } else {
            schema.Fields = append(schema.Fields, field)
        }
    }
}
  • 放到map方便查找,并且通过func (field *Field) setupValuerAndSetter()初始化每个Field的ReflectValueOfValueOfSet方法。
    for _, field := range schema.Fields {
        if field.DBName == "" && field.DataType != "" {
            field.DBName = namer.ColumnName(schema.Table, field.Name)
        }
        if field.DBName != "" {
            // nonexistence or shortest path or first appear prioritized if has permission
            if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
                if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
                    schema.DBNames = append(schema.DBNames, field.DBName)
                }
                // gorm tag字段到field的映射
                schema.FieldsByDBName[field.DBName] = field
                // model 字段到field的映射
                schema.FieldsByName[field.Name] = field
                if v != nil && v.PrimaryKey {
                    for idx, f := range schema.PrimaryFields {
                        if f == v {
                            schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
                        }
                    }
                }
                // 主键
                if field.PrimaryKey {
                    schema.PrimaryFields = append(schema.PrimaryFields, field)
                }
            }
        }
        if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
            schema.FieldsByName[field.Name] = field
        }
        // 挂载字段的set方法和get方法
        field.setupValuerAndSetter()
    }

 

值得一提的是,每个model解析后的结果是一致,可以将结果解析的结构缓存下来,并且通过chan来解决并发的问题。

解析model之后,通过process获取到钩子函数及创建行的函数,具体代码见Github

	for _, f := range p.fns {
		f(db)
	}

 

3.3 执行钩子函数及创建行的函数

创建行的函数及对应的钩子函数位于create.go

  • 创建行记录
if db.Statement.SQL.Len() == 0 {
    db.Statement.SQL.Grow(180)
    db.Statement.AddClauseIfNotExists(clause.Insert{})
    db.Statement.AddClause(ConvertToCreateValues(db.Statement))

    db.Statement.Build(db.Statement.BuildClauses...)
}

 这里插入两个clause.Clause,分别为clause.Insert以及clause.Values,然后调用这两种clause.Clausebuild方法生成SQL语句。

首先,看下ConvertToCreateValues的实现,这里只截取部分代码

values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
// 获取每一列的名字
for _, db := range stmt.Schema.DBNames {
    if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
	    if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) {
		    values.Columns = append(values.Columns, clause.Column{Name: db})
		}
	}
}

// 获取每一列对应的值
switch stmt.ReflectValue.Kind() {
    case reflect.Slice, reflect.Array:
    case reflect.Struct:
        values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
        for idx, column := range values.Columns {
            field := stmt.Schema.FieldsByDBName[column.Name]
            // func (field *Field) setupValuerAndSetter() 挂载的方法
            if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
                if field.DefaultValueInterface != nil {
                    values.Values[0][idx] = field.DefaultValueInterface
                    stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface))
                } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
                    stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
                    values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
                }
            } else if field.AutoUpdateTime > 0 && updateTrackTime {
                stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
                values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
            }
        }

通过ConvertToCreateValues获取了每一列的名称及对应的值。

接下来,看clause.ClauseSQL语句的过程。

遍历加入clause,此时分别为clause.Insert以及clause.Values

// Build build sql with clauses names
func (stmt *Statement) Build(clauses ...string) {
    var firstClauseWritten bool
    for _, name := range clauses {
        if c, ok := stmt.Clauses[name]; ok {
            // 代码有删减
            c.Build(stmt)
        }
    }
}

接着调用func (c Clause) Build(builder Builder)

// Build build clause 
func (c Clause) Build(builder Builder) {
    // 有删减
    // c为clause.Insert以及clause.Values
    if c.Name != "" {
        // builder写入 INSERT 或者 VALUES
        builder.WriteString(c.Name)
        builder.WriteByte(' ')
    }
    // 通过clause.Insert以及clause.Values的MergeClause函数,c.Expression为clause.Insert以及clause.Values
    // 因此,这里调用clause.Insert或者clause.Values的Build的方法
    c.Expression.Build(builder)
}

接下来分别看clause.Insert以及clause.Values

// Build build insert clause
func (insert Insert) Build(builder Builder) {
    // builder写入INTO,此时builder为INSERT INTO
    builder.WriteString("INTO ")
    // builder写入表名
    builder.WriteQuoted(currentTable)
}

从调用的链路可以得出,这里builderstmt *Statement,并且currentTable类型为clause.Table,因此

// WriteQuoted write quoted value
func (stmt *Statement) WriteQuoted(value interface{}) {
    stmt.QuoteTo(&stmt.SQL, value)
}

// QuoteTo write quoted value to writer 有删减
func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
    write := func(raw bool, str string) {
        // mysql驱动Dialector
        stmt.DB.Dialector.QuoteTo(writer, str)
    }
    switch v := field.(type) {
    case clause.Table:
        write(v.Raw, stmt.Table)
    }
}

至此,builder已经拼装出INSERT INTO `t_student` ,解析来再看clause.Valuesbuild方法

// Build build from clause
func (values Values) Build(builder Builder) {
    if len(values.Columns) > 0 {
        builder.WriteByte('(')
        for idx, column := range values.Columns {
            if idx > 0 {
                builder.WriteByte(',')
            }
            builder.WriteQuoted(column)
        }
        builder.WriteByte(')')
        builder.WriteString(" VALUES ")
        for idx, value := range values.Values {
            if idx > 0 {
                builder.WriteByte(',')
            }
            builder.WriteByte('(')
            builder.AddVar(builder, value...)
            builder.WriteByte(')')
        }
    } else {
        builder.WriteString("DEFAULT VALUES")
    }
}

func (values Values) Build(builder Builder)取出所有列名和列对应的值

最终builder拼装成例子的完整SQL语句INSERT INTO `t_student` (`age`,`height`,`weight`) VALUES (18,185,70)

有了SQL语句,就可以执行了

result, err := db.Statement.ConnPool.ExecContext(
    db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
)

通过前一面学习,db.Statement.ConnPool的值为sql.DB,实际执行的函数为func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (Result, error)

至此,从Model到DML到流程已经完成。

 

4. 将ID写入到model的 

看返回参数sql.Result,因此通过LastInsertId() (int64, error)可以获取到插入行的ID值。

// A Result summarizes an executed SQL command.
type Result interface {
	// LastInsertId returns the integer generated by the database
	// in response to a command. Typically this will be from an
	// "auto increment" column when inserting a new row. Not all
	// databases support this feature, and the syntax of such
	// statements varies.
	LastInsertId() (int64, error)

	// RowsAffected returns the number of rows affected by an
	// update, insert, or delete. Not every database or database
	// driver may support this.
	RowsAffected() (int64, error)
}

获取到刚插入的行ID值,再通过反射写入model的ID字段即可。

db.RowsAffected, _ = result.RowsAffected()
if db.RowsAffected != 0 && db.Statement.Schema != nil &&
    db.Statement.Schema.PrioritizedPrimaryField != nil &&
    db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
    insertID, err := result.LastInsertId()
    switch db.Statement.ReflectValue.Kind() {
    case reflect.Struct:
        _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
        if isZero {
            // 通过反射更新ID
            db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
        }
    }
}

 

5. 总结

使用反射解析Model,获得每个成员对应的表的列名、值等信息。

定义SQL各个关键词如INSERTVALUESFROMDELETE的结构体,并实现clause.Interface接口

进而对SQL语句的构造进行抽象封装。