Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion internal/analysis/analysis.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions internal/cmd/shim.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ func pluginQueryColumn(c *compiler.Column) *plugin.Column {
}
}

out.IsNullableEmbed = c.NullableEmbed

return out
}

Expand Down
14 changes: 14 additions & 0 deletions internal/codegen/golang/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@ import (
"github.com/sqlc-dev/sqlc/internal/plugin"
)

// NullableEmbedFieldInfo stores metadata for scanning nullable embed fields
type NullableEmbedFieldInfo struct {
TempVarName string // ex: "nembedPostID"
ScanType string // ex: "sql.NullInt32" for stdlib, "*int32" for pgx
ValidExpr string // ex: "nembedPostID.Valid" or "nembedPostID != nil"
AssignExpr string // ex: "nembedPostID.Int32" or "*nembedPostID"
StructField string // ex: "ID"
OriginalType string // original type in the model struct (ex: "int32")
}

type Field struct {
Name string // CamelCased name for Go
DBName string // Name as used in the DB
Expand All @@ -19,6 +29,10 @@ type Field struct {
Column *plugin.Column
// EmbedFields contains the embedded fields that require scanning.
EmbedFields []Field
// IsNullableEmbed indicates this field is a nullable embed (*Struct)
IsNullableEmbed bool
// NullableEmbedInfo stores scan metadata for each embedded field
NullableEmbedInfo []NullableEmbedFieldInfo
}

func (gf Field) Tag() string {
Expand Down
4 changes: 3 additions & 1 deletion internal/codegen/golang/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,9 @@ func filterUnusedStructs(enums []Enum, structs []Struct, queries []Query) ([]Enu
keepTypes[query.Ret.Type()] = struct{}{}
if query.Ret.IsStruct() {
for _, field := range query.Ret.Struct.Fields {
keepTypes[strings.TrimPrefix(field.Type, "[]")] = struct{}{}
trimmedType := strings.TrimPrefix(field.Type, "[]")
trimmedType = strings.TrimPrefix(trimmedType, "*")
keepTypes[trimmedType] = struct{}{}
for _, embedField := range field.EmbedFields {
keepTypes[embedField.Type] = struct{}{}
}
Expand Down
11 changes: 11 additions & 0 deletions internal/codegen/golang/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,12 @@ func (i *importer) queryImports(filename string) fileImports {
if hasPrefixIgnoringSliceAndPointerPrefix(f.Type, name) {
return true
}
// Check nullable embed scan types
for _, info := range f.NullableEmbedInfo {
if hasPrefixIgnoringSliceAndPointerPrefix(info.ScanType, name) {
return true
}
}
}
}
if hasPrefixIgnoringSliceAndPointerPrefix(q.Ret.Type(), name) {
Expand Down Expand Up @@ -459,6 +465,11 @@ func (i *importer) batchImports() fileImports {
if hasPrefixIgnoringSliceAndPointerPrefix(f.Type, name) {
return true
}
for _, info := range f.NullableEmbedInfo {
if hasPrefixIgnoringSliceAndPointerPrefix(info.ScanType, name) {
return true
}
}
}
}
if hasPrefixIgnoringSliceAndPointerPrefix(q.Ret.Type(), name) {
Expand Down
79 changes: 79 additions & 0 deletions internal/codegen/golang/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,14 @@ func (v QueryValue) Scan() string {
} else {
for _, f := range v.Struct.Fields {

// nullable embed: scan into temporary variables
if f.IsNullableEmbed && len(f.NullableEmbedInfo) > 0 {
for _, info := range f.NullableEmbedInfo {
out = append(out, "&"+info.TempVarName)
}
continue
}

// append any embedded fields
if len(f.EmbedFields) > 0 {
for _, embed := range f.EmbedFields {
Expand All @@ -227,6 +235,77 @@ func (v QueryValue) Scan() string {
return "\n" + strings.Join(out, ",\n")
}

// HasNullableEmbeds returns true if the query value has nullable embed fields
func (v QueryValue) HasNullableEmbeds() bool {
if v.Struct == nil {
return false
}
for _, f := range v.Struct.Fields {
if f.IsNullableEmbed {
return true
}
}
return false
}

// NullableEmbedDecls generates declarations for nullable embed temporary variables
func (v QueryValue) NullableEmbedDecls() string {
if v.Struct == nil {
return ""
}
var lines []string
for _, f := range v.Struct.Fields {
if !f.IsNullableEmbed {
continue
}
for _, info := range f.NullableEmbedInfo {
lines = append(lines, fmt.Sprintf("var %s %s", info.TempVarName, info.ScanType))
}
}
if len(lines) == 0 {
return ""
}
return "\n" + strings.Join(lines, "\n")
}

// NullableEmbedAssigns generates post-scan code to construct nullable embed structs
func (v QueryValue) NullableEmbedAssigns() string {
if v.Struct == nil {
return ""
}
var blocks []string
for _, f := range v.Struct.Fields {
if !f.IsNullableEmbed || len(f.NullableEmbedInfo) == 0 {
continue
}

// Build the validity check: any field non-nil means the row exists
var validChecks []string
for _, info := range f.NullableEmbedInfo {
validChecks = append(validChecks, info.ValidExpr)
}

// Build the struct assignment
modelType := strings.TrimPrefix(f.Type, "*")
var assignments []string
for _, info := range f.NullableEmbedInfo {
assignments = append(assignments, fmt.Sprintf("%s: %s,", info.StructField, info.AssignExpr))
}

block := fmt.Sprintf("if %s {\n%s.%s = &%s{\n%s\n}\n}",
strings.Join(validChecks, " || "),
v.Name, f.Name,
modelType,
strings.Join(assignments, "\n"),
)
blocks = append(blocks, block)
}
if len(blocks) == 0 {
return ""
}
return "\n" + strings.Join(blocks, "\n")
}

// Deprecated: This method does not respect the Emit field set on the
// QueryValue. It's used by the go-sql-driver-mysql/copyfromCopy.tmpl and should
// not be used other places.
Expand Down
45 changes: 43 additions & 2 deletions internal/codegen/golang/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,12 @@ type goEmbed struct {
modelType string
modelName string
fields []Field
nullable bool
}

// look through all the structs and attempt to find a matching one to embed
// We need the name of the struct and its field names.
func newGoEmbed(embed *plugin.Identifier, structs []Struct, defaultSchema string) *goEmbed {
func newGoEmbed(embed *plugin.Identifier, structs []Struct, defaultSchema string, nullable bool) *goEmbed {
if embed == nil {
return nil
}
Expand All @@ -147,6 +148,7 @@ func newGoEmbed(embed *plugin.Identifier, structs []Struct, defaultSchema string
modelType: s.Name,
modelName: s.Name,
fields: fields,
nullable: nullable,
}
}

Expand Down Expand Up @@ -304,7 +306,7 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs []
columns = append(columns, goColumn{
id: i,
Column: c,
embed: newGoEmbed(c.EmbedTable, structs, req.Catalog.DefaultSchema),
embed: newGoEmbed(c.EmbedTable, structs, req.Catalog.DefaultSchema, c.IsNullableEmbed),
})
}
var err error
Expand Down Expand Up @@ -396,6 +398,11 @@ func columnsToStruct(req *plugin.GenerateRequest, options *opts.Options, name st
}
if c.embed == nil {
f.Type = goType(req, options, c.Column)
} else if c.embed.nullable {
f.Type = "*" + c.embed.modelType
f.EmbedFields = c.embed.fields
f.IsNullableEmbed = true
f.NullableEmbedInfo = computeNullableEmbedInfo(c.embed)
} else {
f.Type = c.embed.modelType
f.EmbedFields = c.embed.fields
Expand Down Expand Up @@ -435,6 +442,40 @@ func columnsToStruct(req *plugin.GenerateRequest, options *opts.Options, name st
return &gs, nil
}

// computeNullableEmbedInfo computes scan metadata for nullable embed fields.
// For each field in the embed, we scan into a pointer-typed temporary variable,
// then check if any temp var is non-nil to construct the struct.
func computeNullableEmbedInfo(embed *goEmbed) []NullableEmbedFieldInfo {
var infos []NullableEmbedFieldInfo
for _, f := range embed.fields {
varName := "nembed" + embed.modelName + f.Name
originalType := f.Type
var scanType, validExpr, assignExpr string

if strings.HasPrefix(originalType, "*") {
// Already a pointer type (e.g., pgx nullable), scan directly
scanType = originalType
validExpr = varName + " != nil"
assignExpr = varName
} else {
// Wrap in pointer for nullable scan
scanType = "*" + originalType
validExpr = varName + " != nil"
assignExpr = "*" + varName
}

infos = append(infos, NullableEmbedFieldInfo{
TempVarName: varName,
ScanType: scanType,
ValidExpr: validExpr,
AssignExpr: assignExpr,
StructField: f.Name,
OriginalType: originalType,
})
}
return infos
}

func checkIncompatibleFieldTypes(fields []Field) error {
fieldTypes := map[string]string{}
for _, field := range fields {
Expand Down
4 changes: 4 additions & 0 deletions internal/codegen/golang/templates/pgx/batchCode.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,11 @@ func (b *{{.MethodName}}BatchResults) Query(f func(int, []{{.Ret.DefineType}}, e
defer rows.Close()
for rows.Next() {
var {{.Ret.Name}} {{.Ret.Type}}
{{- .Ret.NullableEmbedDecls}}
if err := rows.Scan({{.Ret.Scan}}); err != nil {
return err
}
{{- .Ret.NullableEmbedAssigns}}
items = append(items, {{.Ret.ReturnName}})
}
return rows.Err()
Expand All @@ -110,6 +112,7 @@ func (b *{{.MethodName}}BatchResults) QueryRow(f func(int, {{.Ret.DefineType}},
defer b.br.Close()
for t := 0; t < b.tot; t++ {
var {{.Ret.Name}} {{.Ret.Type}}
{{- .Ret.NullableEmbedDecls}}
if b.closed {
if f != nil {
f(t, {{if .Ret.IsPointer}}nil{{else}}{{.Ret.Name}}{{end}}, ErrBatchAlreadyClosed)
Expand All @@ -118,6 +121,7 @@ func (b *{{.MethodName}}BatchResults) QueryRow(f func(int, {{.Ret.DefineType}},
}
row := b.br.QueryRow()
err := row.Scan({{.Ret.Scan}})
{{- .Ret.NullableEmbedAssigns}}
if f != nil {
f(t, {{.Ret.ReturnName}}, err)
}
Expand Down
4 changes: 4 additions & 0 deletions internal/codegen/golang/templates/pgx/queryCode.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.De
{{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }}
var {{.Ret.Name}} {{.Ret.Type}}
{{- end}}
{{- .Ret.NullableEmbedDecls}}
err := row.Scan({{.Ret.Scan}})
{{- .Ret.NullableEmbedAssigns}}
{{- if $.WrapErrors}}
if err != nil {
err = fmt.Errorf("query {{.MethodName}}: %w", err)
Expand Down Expand Up @@ -67,9 +69,11 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.
{{end -}}
for rows.Next() {
var {{.Ret.Name}} {{.Ret.Type}}
{{- .Ret.NullableEmbedDecls}}
if err := rows.Scan({{.Ret.Scan}}); err != nil {
return nil, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}}
}
{{- .Ret.NullableEmbedAssigns}}
items = append(items, {{.Ret.ReturnName}})
}
if err := rows.Err(); err != nil {
Expand Down
4 changes: 4 additions & 0 deletions internal/codegen/golang/templates/stdlib/queryCode.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}
{{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }}
var {{.Ret.Name}} {{.Ret.Type}}
{{- end}}
{{- .Ret.NullableEmbedDecls}}
err := row.Scan({{.Ret.Scan}})
{{- .Ret.NullableEmbedAssigns}}
{{- if $.WrapErrors}}
if err != nil {
err = fmt.Errorf("query {{.MethodName}}: %w", err)
Expand All @@ -53,9 +55,11 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}
{{end -}}
for rows.Next() {
var {{.Ret.Name}} {{.Ret.Type}}
{{- .Ret.NullableEmbedDecls}}
if err := rows.Scan({{.Ret.Scan}}); err != nil {
return nil, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}}
}
{{- .Ret.NullableEmbedAssigns}}
items = append(items, {{.Ret.ReturnName}})
}
if err := rows.Close(); err != nil {
Expand Down
35 changes: 18 additions & 17 deletions internal/compiler/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,24 @@ func convertTypeName(id *analyzer.Identifier) *ast.TypeName {
func convertColumn(c *analyzer.Column) *Column {
length := int(c.Length)
return &Column{
Name: c.Name,
OriginalName: c.OriginalName,
DataType: c.DataType,
NotNull: c.NotNull,
Unsigned: c.Unsigned,
IsArray: c.IsArray,
ArrayDims: int(c.ArrayDims),
Comment: c.Comment,
Length: &length,
IsNamedParam: c.IsNamedParam,
IsFuncCall: c.IsFuncCall,
Scope: c.Scope,
Table: convertTableName(c.Table),
TableAlias: c.TableAlias,
Type: convertTypeName(c.Type),
EmbedTable: convertTableName(c.EmbedTable),
IsSqlcSlice: c.IsSqlcSlice,
Name: c.Name,
OriginalName: c.OriginalName,
DataType: c.DataType,
NotNull: c.NotNull,
Unsigned: c.Unsigned,
IsArray: c.IsArray,
ArrayDims: int(c.ArrayDims),
Comment: c.Comment,
Length: &length,
IsNamedParam: c.IsNamedParam,
IsFuncCall: c.IsFuncCall,
Scope: c.Scope,
Table: convertTableName(c.Table),
TableAlias: c.TableAlias,
Type: convertTypeName(c.Type),
EmbedTable: convertTableName(c.EmbedTable),
IsSqlcSlice: c.IsSqlcSlice,
NullableEmbed: c.IsNullableEmbed,
}
}

Expand Down
5 changes: 3 additions & 2 deletions internal/compiler/output_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,9 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
// add a column with a reference to an embedded table
if embed, ok := qc.embeds.Find(n); ok {
cols = append(cols, &Column{
Name: embed.Table.Name,
EmbedTable: embed.Table,
Name: embed.Table.Name,
EmbedTable: embed.Table,
NullableEmbed: embed.Nullable,
})
continue
}
Expand Down
Loading
Loading