diff --git a/internal/analysis/analysis.pb.go b/internal/analysis/analysis.pb.go index e039cd6162..eb63e06f1e 100644 --- a/internal/analysis/analysis.pb.go +++ b/internal/analysis/analysis.pb.go @@ -104,7 +104,8 @@ type Column struct { TableAlias string `protobuf:"bytes,14,opt,name=table_alias,json=tableAlias,proto3" json:"table_alias,omitempty"` Type *Identifier `protobuf:"bytes,15,opt,name=type,proto3" json:"type,omitempty"` EmbedTable *Identifier `protobuf:"bytes,16,opt,name=embed_table,json=embedTable,proto3" json:"embed_table,omitempty"` - IsSqlcSlice bool `protobuf:"varint,17,opt,name=is_sqlc_slice,json=isSqlcSlice,proto3" json:"is_sqlc_slice,omitempty"` + IsSqlcSlice bool `protobuf:"varint,17,opt,name=is_sqlc_slice,json=isSqlcSlice,proto3" json:"is_sqlc_slice,omitempty"` + IsNullableEmbed bool `protobuf:"varint,18,opt,name=is_nullable_embed,json=isNullableEmbed,proto3" json:"is_nullable_embed,omitempty"` } func (x *Column) Reset() { @@ -258,6 +259,13 @@ func (x *Column) GetIsSqlcSlice() bool { return false } +func (x *Column) GetIsNullableEmbed() bool { + if x != nil { + return x.IsNullableEmbed + } + return false +} + type Parameter struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache diff --git a/internal/cmd/shim.go b/internal/cmd/shim.go index 654500429a..462926b1b6 100644 --- a/internal/cmd/shim.go +++ b/internal/cmd/shim.go @@ -213,6 +213,8 @@ func pluginQueryColumn(c *compiler.Column) *plugin.Column { } } + out.IsNullableEmbed = c.NullableEmbed + return out } diff --git a/internal/codegen/golang/field.go b/internal/codegen/golang/field.go index 2a63b6d342..784ecf94da 100644 --- a/internal/codegen/golang/field.go +++ b/internal/codegen/golang/field.go @@ -10,6 +10,16 @@ import ( "github.com/sqlc-dev/sqlc/internal/plugin" ) +// NullableEmbedFieldInfo stores metadata for scanning nullable embed fields +type NullableEmbedFieldInfo struct { + TempVarName string // ex: "nembedPostID" + ScanType string // ex: "sql.NullInt32" for stdlib, "*int32" for pgx + ValidExpr string // ex: "nembedPostID.Valid" or "nembedPostID != nil" + AssignExpr string // ex: "nembedPostID.Int32" or "*nembedPostID" + StructField string // ex: "ID" + OriginalType string // original type in the model struct (ex: "int32") +} + type Field struct { Name string // CamelCased name for Go DBName string // Name as used in the DB @@ -19,6 +29,10 @@ type Field struct { Column *plugin.Column // EmbedFields contains the embedded fields that require scanning. EmbedFields []Field + // IsNullableEmbed indicates this field is a nullable embed (*Struct) + IsNullableEmbed bool + // NullableEmbedInfo stores scan metadata for each embedded field + NullableEmbedInfo []NullableEmbedFieldInfo } func (gf Field) Tag() string { diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index 7df56a0a41..f6ee700929 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -383,7 +383,9 @@ func filterUnusedStructs(enums []Enum, structs []Struct, queries []Query) ([]Enu keepTypes[query.Ret.Type()] = struct{}{} if query.Ret.IsStruct() { for _, field := range query.Ret.Struct.Fields { - keepTypes[strings.TrimPrefix(field.Type, "[]")] = struct{}{} + trimmedType := strings.TrimPrefix(field.Type, "[]") + trimmedType = strings.TrimPrefix(trimmedType, "*") + keepTypes[trimmedType] = struct{}{} for _, embedField := range field.EmbedFields { keepTypes[embedField.Type] = struct{}{} } diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index ccca4f603c..3d38e7e069 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -319,6 +319,12 @@ func (i *importer) queryImports(filename string) fileImports { if hasPrefixIgnoringSliceAndPointerPrefix(f.Type, name) { return true } + // Check nullable embed scan types + for _, info := range f.NullableEmbedInfo { + if hasPrefixIgnoringSliceAndPointerPrefix(info.ScanType, name) { + return true + } + } } } if hasPrefixIgnoringSliceAndPointerPrefix(q.Ret.Type(), name) { @@ -459,6 +465,11 @@ func (i *importer) batchImports() fileImports { if hasPrefixIgnoringSliceAndPointerPrefix(f.Type, name) { return true } + for _, info := range f.NullableEmbedInfo { + if hasPrefixIgnoringSliceAndPointerPrefix(info.ScanType, name) { + return true + } + } } } if hasPrefixIgnoringSliceAndPointerPrefix(q.Ret.Type(), name) { diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 3b4fb2fa1a..697e213c29 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -201,6 +201,14 @@ func (v QueryValue) Scan() string { } else { for _, f := range v.Struct.Fields { + // nullable embed: scan into temporary variables + if f.IsNullableEmbed && len(f.NullableEmbedInfo) > 0 { + for _, info := range f.NullableEmbedInfo { + out = append(out, "&"+info.TempVarName) + } + continue + } + // append any embedded fields if len(f.EmbedFields) > 0 { for _, embed := range f.EmbedFields { @@ -227,6 +235,77 @@ func (v QueryValue) Scan() string { return "\n" + strings.Join(out, ",\n") } +// HasNullableEmbeds returns true if the query value has nullable embed fields +func (v QueryValue) HasNullableEmbeds() bool { + if v.Struct == nil { + return false + } + for _, f := range v.Struct.Fields { + if f.IsNullableEmbed { + return true + } + } + return false +} + +// NullableEmbedDecls generates declarations for nullable embed temporary variables +func (v QueryValue) NullableEmbedDecls() string { + if v.Struct == nil { + return "" + } + var lines []string + for _, f := range v.Struct.Fields { + if !f.IsNullableEmbed { + continue + } + for _, info := range f.NullableEmbedInfo { + lines = append(lines, fmt.Sprintf("var %s %s", info.TempVarName, info.ScanType)) + } + } + if len(lines) == 0 { + return "" + } + return "\n" + strings.Join(lines, "\n") +} + +// NullableEmbedAssigns generates post-scan code to construct nullable embed structs +func (v QueryValue) NullableEmbedAssigns() string { + if v.Struct == nil { + return "" + } + var blocks []string + for _, f := range v.Struct.Fields { + if !f.IsNullableEmbed || len(f.NullableEmbedInfo) == 0 { + continue + } + + // Build the validity check: any field non-nil means the row exists + var validChecks []string + for _, info := range f.NullableEmbedInfo { + validChecks = append(validChecks, info.ValidExpr) + } + + // Build the struct assignment + modelType := strings.TrimPrefix(f.Type, "*") + var assignments []string + for _, info := range f.NullableEmbedInfo { + assignments = append(assignments, fmt.Sprintf("%s: %s,", info.StructField, info.AssignExpr)) + } + + block := fmt.Sprintf("if %s {\n%s.%s = &%s{\n%s\n}\n}", + strings.Join(validChecks, " || "), + v.Name, f.Name, + modelType, + strings.Join(assignments, "\n"), + ) + blocks = append(blocks, block) + } + if len(blocks) == 0 { + return "" + } + return "\n" + strings.Join(blocks, "\n") +} + // Deprecated: This method does not respect the Emit field set on the // QueryValue. It's used by the go-sql-driver-mysql/copyfromCopy.tmpl and should // not be used other places. diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index 0820488f9d..8d3f3f3d5a 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -120,11 +120,12 @@ type goEmbed struct { modelType string modelName string fields []Field + nullable bool } // look through all the structs and attempt to find a matching one to embed // We need the name of the struct and its field names. -func newGoEmbed(embed *plugin.Identifier, structs []Struct, defaultSchema string) *goEmbed { +func newGoEmbed(embed *plugin.Identifier, structs []Struct, defaultSchema string, nullable bool) *goEmbed { if embed == nil { return nil } @@ -147,6 +148,7 @@ func newGoEmbed(embed *plugin.Identifier, structs []Struct, defaultSchema string modelType: s.Name, modelName: s.Name, fields: fields, + nullable: nullable, } } @@ -304,7 +306,7 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs [] columns = append(columns, goColumn{ id: i, Column: c, - embed: newGoEmbed(c.EmbedTable, structs, req.Catalog.DefaultSchema), + embed: newGoEmbed(c.EmbedTable, structs, req.Catalog.DefaultSchema, c.IsNullableEmbed), }) } var err error @@ -396,6 +398,11 @@ func columnsToStruct(req *plugin.GenerateRequest, options *opts.Options, name st } if c.embed == nil { f.Type = goType(req, options, c.Column) + } else if c.embed.nullable { + f.Type = "*" + c.embed.modelType + f.EmbedFields = c.embed.fields + f.IsNullableEmbed = true + f.NullableEmbedInfo = computeNullableEmbedInfo(c.embed) } else { f.Type = c.embed.modelType f.EmbedFields = c.embed.fields @@ -435,6 +442,40 @@ func columnsToStruct(req *plugin.GenerateRequest, options *opts.Options, name st return &gs, nil } +// computeNullableEmbedInfo computes scan metadata for nullable embed fields. +// For each field in the embed, we scan into a pointer-typed temporary variable, +// then check if any temp var is non-nil to construct the struct. +func computeNullableEmbedInfo(embed *goEmbed) []NullableEmbedFieldInfo { + var infos []NullableEmbedFieldInfo + for _, f := range embed.fields { + varName := "nembed" + embed.modelName + f.Name + originalType := f.Type + var scanType, validExpr, assignExpr string + + if strings.HasPrefix(originalType, "*") { + // Already a pointer type (e.g., pgx nullable), scan directly + scanType = originalType + validExpr = varName + " != nil" + assignExpr = varName + } else { + // Wrap in pointer for nullable scan + scanType = "*" + originalType + validExpr = varName + " != nil" + assignExpr = "*" + varName + } + + infos = append(infos, NullableEmbedFieldInfo{ + TempVarName: varName, + ScanType: scanType, + ValidExpr: validExpr, + AssignExpr: assignExpr, + StructField: f.Name, + OriginalType: originalType, + }) + } + return infos +} + func checkIncompatibleFieldTypes(fields []Field) error { fieldTypes := map[string]string{} for _, field := range fields { diff --git a/internal/codegen/golang/templates/pgx/batchCode.tmpl b/internal/codegen/golang/templates/pgx/batchCode.tmpl index 35bd701bd3..12d76d2580 100644 --- a/internal/codegen/golang/templates/pgx/batchCode.tmpl +++ b/internal/codegen/golang/templates/pgx/batchCode.tmpl @@ -91,9 +91,11 @@ func (b *{{.MethodName}}BatchResults) Query(f func(int, []{{.Ret.DefineType}}, e defer rows.Close() for rows.Next() { var {{.Ret.Name}} {{.Ret.Type}} + {{- .Ret.NullableEmbedDecls}} if err := rows.Scan({{.Ret.Scan}}); err != nil { return err } + {{- .Ret.NullableEmbedAssigns}} items = append(items, {{.Ret.ReturnName}}) } return rows.Err() @@ -110,6 +112,7 @@ func (b *{{.MethodName}}BatchResults) QueryRow(f func(int, {{.Ret.DefineType}}, defer b.br.Close() for t := 0; t < b.tot; t++ { var {{.Ret.Name}} {{.Ret.Type}} + {{- .Ret.NullableEmbedDecls}} if b.closed { if f != nil { f(t, {{if .Ret.IsPointer}}nil{{else}}{{.Ret.Name}}{{end}}, ErrBatchAlreadyClosed) @@ -118,6 +121,7 @@ func (b *{{.MethodName}}BatchResults) QueryRow(f func(int, {{.Ret.DefineType}}, } row := b.br.QueryRow() err := row.Scan({{.Ret.Scan}}) + {{- .Ret.NullableEmbedAssigns}} if f != nil { f(t, {{.Ret.ReturnName}}, err) } diff --git a/internal/codegen/golang/templates/pgx/queryCode.tmpl b/internal/codegen/golang/templates/pgx/queryCode.tmpl index 59a88c880a..b118c68f5b 100644 --- a/internal/codegen/golang/templates/pgx/queryCode.tmpl +++ b/internal/codegen/golang/templates/pgx/queryCode.tmpl @@ -36,7 +36,9 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.De {{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }} var {{.Ret.Name}} {{.Ret.Type}} {{- end}} + {{- .Ret.NullableEmbedDecls}} err := row.Scan({{.Ret.Scan}}) + {{- .Ret.NullableEmbedAssigns}} {{- if $.WrapErrors}} if err != nil { err = fmt.Errorf("query {{.MethodName}}: %w", err) @@ -67,9 +69,11 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret. {{end -}} for rows.Next() { var {{.Ret.Name}} {{.Ret.Type}} + {{- .Ret.NullableEmbedDecls}} if err := rows.Scan({{.Ret.Scan}}); err != nil { return nil, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}} } + {{- .Ret.NullableEmbedAssigns}} items = append(items, {{.Ret.ReturnName}}) } if err := rows.Err(); err != nil { diff --git a/internal/codegen/golang/templates/stdlib/queryCode.tmpl b/internal/codegen/golang/templates/stdlib/queryCode.tmpl index 1e7f4e22a4..116666c530 100644 --- a/internal/codegen/golang/templates/stdlib/queryCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/queryCode.tmpl @@ -27,7 +27,9 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} {{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }} var {{.Ret.Name}} {{.Ret.Type}} {{- end}} + {{- .Ret.NullableEmbedDecls}} err := row.Scan({{.Ret.Scan}}) + {{- .Ret.NullableEmbedAssigns}} {{- if $.WrapErrors}} if err != nil { err = fmt.Errorf("query {{.MethodName}}: %w", err) @@ -53,9 +55,11 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} {{end -}} for rows.Next() { var {{.Ret.Name}} {{.Ret.Type}} + {{- .Ret.NullableEmbedDecls}} if err := rows.Scan({{.Ret.Scan}}); err != nil { return nil, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}} } + {{- .Ret.NullableEmbedAssigns}} items = append(items, {{.Ret.ReturnName}}) } if err := rows.Close(); err != nil { diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index 0d7d507575..8a2ad83a27 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -45,23 +45,24 @@ func convertTypeName(id *analyzer.Identifier) *ast.TypeName { func convertColumn(c *analyzer.Column) *Column { length := int(c.Length) return &Column{ - Name: c.Name, - OriginalName: c.OriginalName, - DataType: c.DataType, - NotNull: c.NotNull, - Unsigned: c.Unsigned, - IsArray: c.IsArray, - ArrayDims: int(c.ArrayDims), - Comment: c.Comment, - Length: &length, - IsNamedParam: c.IsNamedParam, - IsFuncCall: c.IsFuncCall, - Scope: c.Scope, - Table: convertTableName(c.Table), - TableAlias: c.TableAlias, - Type: convertTypeName(c.Type), - EmbedTable: convertTableName(c.EmbedTable), - IsSqlcSlice: c.IsSqlcSlice, + Name: c.Name, + OriginalName: c.OriginalName, + DataType: c.DataType, + NotNull: c.NotNull, + Unsigned: c.Unsigned, + IsArray: c.IsArray, + ArrayDims: int(c.ArrayDims), + Comment: c.Comment, + Length: &length, + IsNamedParam: c.IsNamedParam, + IsFuncCall: c.IsFuncCall, + Scope: c.Scope, + Table: convertTableName(c.Table), + TableAlias: c.TableAlias, + Type: convertTypeName(c.Type), + EmbedTable: convertTableName(c.EmbedTable), + IsSqlcSlice: c.IsSqlcSlice, + NullableEmbed: c.IsNullableEmbed, } } diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index dbd486359a..e1987b04fa 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -260,8 +260,9 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er // add a column with a reference to an embedded table if embed, ok := qc.embeds.Find(n); ok { cols = append(cols, &Column{ - Name: embed.Table.Name, - EmbedTable: embed.Table, + Name: embed.Table.Name, + EmbedTable: embed.Table, + NullableEmbed: embed.Nullable, }) continue } diff --git a/internal/compiler/query.go b/internal/compiler/query.go index b3cf9d6154..abbff6b044 100644 --- a/internal/compiler/query.go +++ b/internal/compiler/query.go @@ -35,7 +35,8 @@ type Column struct { Table *ast.TableName TableAlias string Type *ast.TypeName - EmbedTable *ast.TableName + EmbedTable *ast.TableName + NullableEmbed bool IsSqlcSlice bool // is this sqlc.slice() diff --git a/internal/endtoend/testdata/sqlc_nembed/postgresql/pgx/go/db.go b/internal/endtoend/testdata/sqlc_nembed/postgresql/pgx/go/db.go new file mode 100644 index 0000000000..3895084dc3 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_nembed/postgresql/pgx/go/db.go @@ -0,0 +1,32 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package querytest + +import ( + "context" + + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx pgx.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/sqlc_nembed/postgresql/pgx/go/models.go b/internal/endtoend/testdata/sqlc_nembed/postgresql/pgx/go/models.go new file mode 100644 index 0000000000..2fe9ddd532 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_nembed/postgresql/pgx/go/models.go @@ -0,0 +1,20 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package querytest + +import ( + "database/sql" +) + +type Post struct { + ID int32 `db:"id" json:"id"` + UserID int32 `db:"user_id" json:"user_id"` + Body sql.NullString `db:"body" json:"body"` +} + +type User struct { + ID int32 `db:"id" json:"id"` + Name string `db:"name" json:"name"` +} diff --git a/internal/endtoend/testdata/sqlc_nembed/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/sqlc_nembed/postgresql/pgx/go/query.sql.go new file mode 100644 index 0000000000..dc9fc2768e --- /dev/null +++ b/internal/endtoend/testdata/sqlc_nembed/postgresql/pgx/go/query.sql.go @@ -0,0 +1,160 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const leftJoin = `-- name: LeftJoin :many +SELECT users.id, users.name, posts.id, posts.user_id, posts.body FROM users +LEFT JOIN posts ON users.id = posts.user_id +` + +type LeftJoinRow struct { + User User `db:"user" json:"user"` + Post *Post `db:"post" json:"post"` +} + +func (q *Queries) LeftJoin(ctx context.Context) ([]LeftJoinRow, error) { + rows, err := q.db.Query(ctx, leftJoin) + if err != nil { + return nil, err + } + defer rows.Close() + var items []LeftJoinRow + for rows.Next() { + var i LeftJoinRow + var nembedPostID *int32 + var nembedPostUserID *int32 + var nembedPostBody *sql.NullString + if err := rows.Scan( + &i.User.ID, + &i.User.Name, + &nembedPostID, + &nembedPostUserID, + &nembedPostBody, + ); err != nil { + return nil, err + } + if nembedPostID != nil || nembedPostUserID != nil || nembedPostBody != nil { + i.Post = &Post{ + ID: *nembedPostID, + UserID: *nembedPostUserID, + Body: *nembedPostBody, + } + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const leftJoinOne = `-- name: LeftJoinOne :one +SELECT users.id, users.name, posts.id, posts.user_id, posts.body FROM users +LEFT JOIN posts ON users.id = posts.user_id +LIMIT 1 +` + +type LeftJoinOneRow struct { + User User `db:"user" json:"user"` + Post *Post `db:"post" json:"post"` +} + +func (q *Queries) LeftJoinOne(ctx context.Context) (LeftJoinOneRow, error) { + row := q.db.QueryRow(ctx, leftJoinOne) + var i LeftJoinOneRow + var nembedPostID *int32 + var nembedPostUserID *int32 + var nembedPostBody *sql.NullString + err := row.Scan( + &i.User.ID, + &i.User.Name, + &nembedPostID, + &nembedPostUserID, + &nembedPostBody, + ) + if nembedPostID != nil || nembedPostUserID != nil || nembedPostBody != nil { + i.Post = &Post{ + ID: *nembedPostID, + UserID: *nembedPostUserID, + Body: *nembedPostBody, + } + } + return i, err +} + +const nembedOnly = `-- name: NembedOnly :one +SELECT users.id, users.name FROM users WHERE id = $1 +` + +type NembedOnlyRow struct { + User *User `db:"user" json:"user"` +} + +func (q *Queries) NembedOnly(ctx context.Context, id int32) (NembedOnlyRow, error) { + row := q.db.QueryRow(ctx, nembedOnly, id) + var i NembedOnlyRow + var nembedUserID *int32 + var nembedUserName *string + err := row.Scan(&nembedUserID, &nembedUserName) + if nembedUserID != nil || nembedUserName != nil { + i.User = &User{ + ID: *nembedUserID, + Name: *nembedUserName, + } + } + return i, err +} + +const withAlias = `-- name: WithAlias :many +SELECT u.id, u.name, p.id, p.user_id, p.body FROM users u +LEFT JOIN posts p ON u.id = p.user_id +` + +type WithAliasRow struct { + User User `db:"user" json:"user"` + Post *Post `db:"post" json:"post"` +} + +func (q *Queries) WithAlias(ctx context.Context) ([]WithAliasRow, error) { + rows, err := q.db.Query(ctx, withAlias) + if err != nil { + return nil, err + } + defer rows.Close() + var items []WithAliasRow + for rows.Next() { + var i WithAliasRow + var nembedPostID *int32 + var nembedPostUserID *int32 + var nembedPostBody *sql.NullString + if err := rows.Scan( + &i.User.ID, + &i.User.Name, + &nembedPostID, + &nembedPostUserID, + &nembedPostBody, + ); err != nil { + return nil, err + } + if nembedPostID != nil || nembedPostUserID != nil || nembedPostBody != nil { + i.Post = &Post{ + ID: *nembedPostID, + UserID: *nembedPostUserID, + Body: *nembedPostBody, + } + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/sqlc_nembed/postgresql/pgx/query.sql b/internal/endtoend/testdata/sqlc_nembed/postgresql/pgx/query.sql new file mode 100644 index 0000000000..498f5bb3e7 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_nembed/postgresql/pgx/query.sql @@ -0,0 +1,15 @@ +-- name: LeftJoin :many +SELECT sqlc.embed(users), sqlc.nembed(posts) FROM users +LEFT JOIN posts ON users.id = posts.user_id; + +-- name: LeftJoinOne :one +SELECT sqlc.embed(users), sqlc.nembed(posts) FROM users +LEFT JOIN posts ON users.id = posts.user_id +LIMIT 1; + +-- name: NembedOnly :one +SELECT sqlc.nembed(users) FROM users WHERE id = $1; + +-- name: WithAlias :many +SELECT sqlc.embed(u), sqlc.nembed(p) FROM users u +LEFT JOIN posts p ON u.id = p.user_id; diff --git a/internal/endtoend/testdata/sqlc_nembed/postgresql/pgx/schema.sql b/internal/endtoend/testdata/sqlc_nembed/postgresql/pgx/schema.sql new file mode 100644 index 0000000000..dcdabe0e5d --- /dev/null +++ b/internal/endtoend/testdata/sqlc_nembed/postgresql/pgx/schema.sql @@ -0,0 +1,10 @@ +CREATE TABLE users ( + id integer NOT NULL PRIMARY KEY, + name text NOT NULL +); + +CREATE TABLE posts ( + id integer NOT NULL PRIMARY KEY, + user_id integer NOT NULL, + body text +); diff --git a/internal/endtoend/testdata/sqlc_nembed/postgresql/pgx/sqlc.json b/internal/endtoend/testdata/sqlc_nembed/postgresql/pgx/sqlc.json new file mode 100644 index 0000000000..c74e245180 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_nembed/postgresql/pgx/sqlc.json @@ -0,0 +1,15 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "sql_package": "pgx/v4", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql", + "emit_json_tags": true, + "emit_db_tags": true + } + ] +} diff --git a/internal/endtoend/testdata/sqlc_nembed/postgresql/stdlib/go/db.go b/internal/endtoend/testdata/sqlc_nembed/postgresql/stdlib/go/db.go new file mode 100644 index 0000000000..3b320aa168 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_nembed/postgresql/stdlib/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/sqlc_nembed/postgresql/stdlib/go/models.go b/internal/endtoend/testdata/sqlc_nembed/postgresql/stdlib/go/models.go new file mode 100644 index 0000000000..2fe9ddd532 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_nembed/postgresql/stdlib/go/models.go @@ -0,0 +1,20 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package querytest + +import ( + "database/sql" +) + +type Post struct { + ID int32 `db:"id" json:"id"` + UserID int32 `db:"user_id" json:"user_id"` + Body sql.NullString `db:"body" json:"body"` +} + +type User struct { + ID int32 `db:"id" json:"id"` + Name string `db:"name" json:"name"` +} diff --git a/internal/endtoend/testdata/sqlc_nembed/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/sqlc_nembed/postgresql/stdlib/go/query.sql.go new file mode 100644 index 0000000000..382dfebdb5 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_nembed/postgresql/stdlib/go/query.sql.go @@ -0,0 +1,166 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const leftJoin = `-- name: LeftJoin :many +SELECT users.id, users.name, posts.id, posts.user_id, posts.body FROM users +LEFT JOIN posts ON users.id = posts.user_id +` + +type LeftJoinRow struct { + User User `db:"user" json:"user"` + Post *Post `db:"post" json:"post"` +} + +func (q *Queries) LeftJoin(ctx context.Context) ([]LeftJoinRow, error) { + rows, err := q.db.QueryContext(ctx, leftJoin) + if err != nil { + return nil, err + } + defer rows.Close() + var items []LeftJoinRow + for rows.Next() { + var i LeftJoinRow + var nembedPostID *int32 + var nembedPostUserID *int32 + var nembedPostBody *sql.NullString + if err := rows.Scan( + &i.User.ID, + &i.User.Name, + &nembedPostID, + &nembedPostUserID, + &nembedPostBody, + ); err != nil { + return nil, err + } + if nembedPostID != nil || nembedPostUserID != nil || nembedPostBody != nil { + i.Post = &Post{ + ID: *nembedPostID, + UserID: *nembedPostUserID, + Body: *nembedPostBody, + } + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const leftJoinOne = `-- name: LeftJoinOne :one +SELECT users.id, users.name, posts.id, posts.user_id, posts.body FROM users +LEFT JOIN posts ON users.id = posts.user_id +LIMIT 1 +` + +type LeftJoinOneRow struct { + User User `db:"user" json:"user"` + Post *Post `db:"post" json:"post"` +} + +func (q *Queries) LeftJoinOne(ctx context.Context) (LeftJoinOneRow, error) { + row := q.db.QueryRowContext(ctx, leftJoinOne) + var i LeftJoinOneRow + var nembedPostID *int32 + var nembedPostUserID *int32 + var nembedPostBody *sql.NullString + err := row.Scan( + &i.User.ID, + &i.User.Name, + &nembedPostID, + &nembedPostUserID, + &nembedPostBody, + ) + if nembedPostID != nil || nembedPostUserID != nil || nembedPostBody != nil { + i.Post = &Post{ + ID: *nembedPostID, + UserID: *nembedPostUserID, + Body: *nembedPostBody, + } + } + return i, err +} + +const nembedOnly = `-- name: NembedOnly :one +SELECT users.id, users.name FROM users WHERE id = $1 +` + +type NembedOnlyRow struct { + User *User `db:"user" json:"user"` +} + +func (q *Queries) NembedOnly(ctx context.Context, id int32) (NembedOnlyRow, error) { + row := q.db.QueryRowContext(ctx, nembedOnly, id) + var i NembedOnlyRow + var nembedUserID *int32 + var nembedUserName *string + err := row.Scan(&nembedUserID, &nembedUserName) + if nembedUserID != nil || nembedUserName != nil { + i.User = &User{ + ID: *nembedUserID, + Name: *nembedUserName, + } + } + return i, err +} + +const withAlias = `-- name: WithAlias :many +SELECT u.id, u.name, p.id, p.user_id, p.body FROM users u +LEFT JOIN posts p ON u.id = p.user_id +` + +type WithAliasRow struct { + User User `db:"user" json:"user"` + Post *Post `db:"post" json:"post"` +} + +func (q *Queries) WithAlias(ctx context.Context) ([]WithAliasRow, error) { + rows, err := q.db.QueryContext(ctx, withAlias) + if err != nil { + return nil, err + } + defer rows.Close() + var items []WithAliasRow + for rows.Next() { + var i WithAliasRow + var nembedPostID *int32 + var nembedPostUserID *int32 + var nembedPostBody *sql.NullString + if err := rows.Scan( + &i.User.ID, + &i.User.Name, + &nembedPostID, + &nembedPostUserID, + &nembedPostBody, + ); err != nil { + return nil, err + } + if nembedPostID != nil || nembedPostUserID != nil || nembedPostBody != nil { + i.Post = &Post{ + ID: *nembedPostID, + UserID: *nembedPostUserID, + Body: *nembedPostBody, + } + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/sqlc_nembed/postgresql/stdlib/query.sql b/internal/endtoend/testdata/sqlc_nembed/postgresql/stdlib/query.sql new file mode 100644 index 0000000000..498f5bb3e7 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_nembed/postgresql/stdlib/query.sql @@ -0,0 +1,15 @@ +-- name: LeftJoin :many +SELECT sqlc.embed(users), sqlc.nembed(posts) FROM users +LEFT JOIN posts ON users.id = posts.user_id; + +-- name: LeftJoinOne :one +SELECT sqlc.embed(users), sqlc.nembed(posts) FROM users +LEFT JOIN posts ON users.id = posts.user_id +LIMIT 1; + +-- name: NembedOnly :one +SELECT sqlc.nembed(users) FROM users WHERE id = $1; + +-- name: WithAlias :many +SELECT sqlc.embed(u), sqlc.nembed(p) FROM users u +LEFT JOIN posts p ON u.id = p.user_id; diff --git a/internal/endtoend/testdata/sqlc_nembed/postgresql/stdlib/schema.sql b/internal/endtoend/testdata/sqlc_nembed/postgresql/stdlib/schema.sql new file mode 100644 index 0000000000..dcdabe0e5d --- /dev/null +++ b/internal/endtoend/testdata/sqlc_nembed/postgresql/stdlib/schema.sql @@ -0,0 +1,10 @@ +CREATE TABLE users ( + id integer NOT NULL PRIMARY KEY, + name text NOT NULL +); + +CREATE TABLE posts ( + id integer NOT NULL PRIMARY KEY, + user_id integer NOT NULL, + body text +); diff --git a/internal/endtoend/testdata/sqlc_nembed/postgresql/stdlib/sqlc.json b/internal/endtoend/testdata/sqlc_nembed/postgresql/stdlib/sqlc.json new file mode 100644 index 0000000000..45dbdc427c --- /dev/null +++ b/internal/endtoend/testdata/sqlc_nembed/postgresql/stdlib/sqlc.json @@ -0,0 +1,14 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql", + "emit_json_tags": true, + "emit_db_tags": true + } + ] +} diff --git a/internal/plugin/codegen.pb.go b/internal/plugin/codegen.pb.go index 525ffc72ef..6d1a147916 100644 --- a/internal/plugin/codegen.pb.go +++ b/internal/plugin/codegen.pb.go @@ -655,8 +655,9 @@ type Column struct { IsSqlcSlice bool `protobuf:"varint,13,opt,name=is_sqlc_slice,json=isSqlcSlice,proto3" json:"is_sqlc_slice,omitempty"` EmbedTable *Identifier `protobuf:"bytes,14,opt,name=embed_table,json=embedTable,proto3" json:"embed_table,omitempty"` OriginalName string `protobuf:"bytes,15,opt,name=original_name,json=originalName,proto3" json:"original_name,omitempty"` - Unsigned bool `protobuf:"varint,16,opt,name=unsigned,proto3" json:"unsigned,omitempty"` - ArrayDims int32 `protobuf:"varint,17,opt,name=array_dims,json=arrayDims,proto3" json:"array_dims,omitempty"` + Unsigned bool `protobuf:"varint,16,opt,name=unsigned,proto3" json:"unsigned,omitempty"` + ArrayDims int32 `protobuf:"varint,17,opt,name=array_dims,json=arrayDims,proto3" json:"array_dims,omitempty"` + IsNullableEmbed bool `protobuf:"varint,18,opt,name=is_nullable_embed,json=isNullableEmbed,proto3" json:"is_nullable_embed,omitempty"` } func (x *Column) Reset() { @@ -803,6 +804,13 @@ func (x *Column) GetArrayDims() int32 { return 0 } +func (x *Column) GetIsNullableEmbed() bool { + if x != nil { + return x.IsNullableEmbed + } + return false +} + type Query struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache diff --git a/internal/sql/rewrite/embeds.go b/internal/sql/rewrite/embeds.go index 596c03be89..19524f6d99 100644 --- a/internal/sql/rewrite/embeds.go +++ b/internal/sql/rewrite/embeds.go @@ -7,15 +7,19 @@ import ( "github.com/sqlc-dev/sqlc/internal/sql/astutils" ) -// Embed is an instance of `sqlc.embed(param)` +// Embed is an instance of `sqlc.embed(param)` or `sqlc.nembed(param)` type Embed struct { - Table *ast.TableName - param string - Node *ast.ColumnRef + Table *ast.TableName + param string + Node *ast.ColumnRef + Nullable bool } // Orig string to replace func (e Embed) Orig() string { + if e.Nullable { + return fmt.Sprintf("sqlc.nembed(%s)", e.param) + } return fmt.Sprintf("sqlc.embed(%s)", e.param) } @@ -41,7 +45,11 @@ func Embeds(raw *ast.RawStmt) (*ast.RawStmt, EmbedSet) { node := astutils.Apply(raw, func(cr *astutils.Cursor) bool { node := cr.Node() + nullable := false switch { + case isNullableEmbed(node): + nullable = true + fallthrough case isEmbed(node): fun := node.(*ast.FuncCall) @@ -61,9 +69,10 @@ func Embeds(raw *ast.RawStmt) (*ast.RawStmt, EmbedSet) { } embeds = append(embeds, &Embed{ - Table: &ast.TableName{Name: param}, - param: param, - Node: node, + Table: &ast.TableName{Name: param}, + param: param, + Node: node, + Nullable: nullable, }) cr.Replace(node) @@ -89,3 +98,16 @@ func isEmbed(node ast.Node) bool { isValid := call.Func.Schema == "sqlc" && call.Func.Name == "embed" return isValid } + +func isNullableEmbed(node ast.Node) bool { + call, ok := node.(*ast.FuncCall) + if !ok { + return false + } + + if call.Func == nil { + return false + } + + return call.Func.Schema == "sqlc" && call.Func.Name == "nembed" +} diff --git a/internal/sql/validate/param_style.go b/internal/sql/validate/param_style.go index 1182051d20..27881afd1d 100644 --- a/internal/sql/validate/param_style.go +++ b/internal/sql/validate/param_style.go @@ -29,7 +29,7 @@ func (v *sqlcFuncVisitor) Visit(node ast.Node) astutils.Visitor { // Custom validation for sqlc.arg, sqlc.narg and sqlc.slice // TODO: Replace this once type-checking is implemented if fn.Schema == "sqlc" { - if !(fn.Name == "arg" || fn.Name == "narg" || fn.Name == "slice" || fn.Name == "embed") { + if !(fn.Name == "arg" || fn.Name == "narg" || fn.Name == "slice" || fn.Name == "embed" || fn.Name == "nembed") { v.err = sqlerr.FunctionNotFound("sqlc." + fn.Name) return nil } diff --git a/protos/analysis/analysis.proto b/protos/analysis/analysis.proto index 60e53b56f1..7f15ab5408 100644 --- a/protos/analysis/analysis.proto +++ b/protos/analysis/analysis.proto @@ -26,6 +26,7 @@ message Column { Identifier type = 15; Identifier embed_table = 16; bool is_sqlc_slice = 17; + bool is_nullable_embed = 18; } message Parameter { diff --git a/protos/plugin/codegen.proto b/protos/plugin/codegen.proto index e6faf19bad..5f6d1285e4 100644 --- a/protos/plugin/codegen.proto +++ b/protos/plugin/codegen.proto @@ -100,6 +100,7 @@ message Column { string original_name = 15; bool unsigned = 16; int32 array_dims = 17; + bool is_nullable_embed = 18; } message Query {