From 38cf885c86b72556ec8070d16f6eda38e064574e Mon Sep 17 00:00:00 2001 From: cerredz <422michaelcerreto@gmail.com> Date: Fri, 15 May 2026 22:23:18 -0400 Subject: [PATCH] Add returning clause support --- QueryBuilder.Tests/InsertTests.cs | 20 ++++++ QueryBuilder.Tests/UpdateTests.cs | 22 ++++++ QueryBuilder/Compilers/Compiler.cs | 76 +++++++++++++++++++-- QueryBuilder/Compilers/PostgresCompiler.cs | 11 +++ QueryBuilder/Compilers/SqlServerCompiler.cs | 42 ++++++++++++ QueryBuilder/Query.Returning.cs | 48 +++++++++++++ 6 files changed, 213 insertions(+), 6 deletions(-) create mode 100644 QueryBuilder/Query.Returning.cs diff --git a/QueryBuilder.Tests/InsertTests.cs b/QueryBuilder.Tests/InsertTests.cs index 926e18b2..813b6ff0 100644 --- a/QueryBuilder.Tests/InsertTests.cs +++ b/QueryBuilder.Tests/InsertTests.cs @@ -75,6 +75,26 @@ public void InsertFromSubQueryWithCte() c[EngineCodes.PostgreSql]); } + [Fact] + public void InsertWithReturning() + { + var query = new Query("Books") + .AsInsert( + new[] { "Title", "Author" }, + new object[] { "SqlKata", "Kata" }) + .Returning("Id", "CreatedAt"); + + var postgres = Compilers.CompileFor(EngineCodes.PostgreSql, query); + Assert.Equal( + "INSERT INTO \"Books\" (\"Title\", \"Author\") VALUES ('SqlKata', 'Kata') RETURNING \"Id\", \"CreatedAt\"", + postgres.ToString()); + + var sqlServer = Compilers.CompileFor(EngineCodes.SqlServer, query); + Assert.Equal( + "INSERT INTO [Books] ([Title], [Author]) OUTPUT INSERTED.[Id], INSERTED.[CreatedAt] VALUES ('SqlKata', 'Kata')", + sqlServer.ToString()); + } + [Fact] public void InsertMultiRecords() { diff --git a/QueryBuilder.Tests/UpdateTests.cs b/QueryBuilder.Tests/UpdateTests.cs index bf3dd7d9..aeef6c8b 100644 --- a/QueryBuilder.Tests/UpdateTests.cs +++ b/QueryBuilder.Tests/UpdateTests.cs @@ -94,6 +94,28 @@ public void UpdateWithNullValues() c[EngineCodes.Firebird]); } + [Fact] + public void UpdateWithReturning() + { + var query = new Query("Books") + .Where("Id", 10) + .AsUpdate(new Dictionary + { + { "Title", "SqlKata" } + }) + .Returning("Id", "UpdatedAt"); + + var postgres = Compilers.CompileFor(EngineCodes.PostgreSql, query); + Assert.Equal( + "UPDATE \"Books\" SET \"Title\" = 'SqlKata' WHERE \"Id\" = 10 RETURNING \"Id\", \"UpdatedAt\"", + postgres.ToString()); + + var sqlServer = Compilers.CompileFor(EngineCodes.SqlServer, query); + Assert.Equal( + "UPDATE [Books] SET [Title] = 'SqlKata' OUTPUT INSERTED.[Id], INSERTED.[UpdatedAt] WHERE [Id] = 10", + sqlServer.ToString()); + } + [Fact] public void UpdateWithEmptyString() { diff --git a/QueryBuilder/Compilers/Compiler.cs b/QueryBuilder/Compilers/Compiler.cs index 5ac6c080..1ee82214 100644 --- a/QueryBuilder/Compilers/Compiler.cs +++ b/QueryBuilder/Compilers/Compiler.cs @@ -353,6 +353,8 @@ protected virtual SqlResult CompileUpdateQuery(Query query) var value = Parameter(ctx, Math.Abs(increment.Value)); var sign = increment.Value >= 0 ? "+" : "-"; + var returningBeforeWhere = CompileUpdateReturningBeforeWhere(ctx); + wheres = CompileWheres(ctx); if (!string.IsNullOrEmpty(wheres)) @@ -360,7 +362,9 @@ protected virtual SqlResult CompileUpdateQuery(Query query) wheres = " " + wheres; } - ctx.RawSql = $"UPDATE {table} SET {column} = {column} {sign} {value}{wheres}"; + var returningAfterWhere = CompileUpdateReturningAfterWhere(ctx); + + ctx.RawSql = $"UPDATE {table} SET {column} = {column} {sign} {value}{returningBeforeWhere}{wheres}{returningAfterWhere}"; return ctx; } @@ -376,6 +380,8 @@ protected virtual SqlResult CompileUpdateQuery(Query query) var sets = string.Join(", ", parts); + var returningBefore = CompileUpdateReturningBeforeWhere(ctx); + wheres = CompileWheres(ctx); if (!string.IsNullOrEmpty(wheres)) @@ -383,7 +389,9 @@ protected virtual SqlResult CompileUpdateQuery(Query query) wheres = " " + wheres; } - ctx.RawSql = $"UPDATE {table} SET {sets}{wheres}"; + var returningAfter = CompileUpdateReturningAfterWhere(ctx); + + ctx.RawSql = $"UPDATE {table} SET {sets}{returningBefore}{wheres}{returningAfter}"; return ctx; } @@ -426,10 +434,12 @@ protected virtual SqlResult CompileInsertQueryClause( { string columns = GetInsertColumnsList(clause.Columns); + var returningBefore = CompileInsertReturningBeforeSource(ctx); var subCtx = CompileSelectQuery(clause.Query); ctx.Bindings.AddRange(subCtx.Bindings); + var returningAfter = CompileInsertReturningAfterSource(ctx); - ctx.RawSql = $"{SingleInsertStartClause} {table}{columns} {subCtx.RawSql}"; + ctx.RawSql = $"{SingleInsertStartClause} {table}{columns}{returningBefore} {subCtx.RawSql}{returningAfter}"; return ctx; } @@ -443,15 +453,24 @@ protected virtual SqlResult CompileValueInsertClauses( var firstInsert = insertClauses.First(); string columns = GetInsertColumnsList(firstInsert.Columns); + + var returningBefore = CompileInsertReturningBeforeSource(ctx); var values = string.Join(", ", Parameterize(ctx, firstInsert.Values)); + var returningAfter = CompileInsertReturningAfterSource(ctx); - ctx.RawSql = $"{insertInto} {table}{columns} VALUES ({values})"; + ctx.RawSql = $"{insertInto} {table}{columns}{returningBefore} VALUES ({values})"; if (isMultiValueInsert) - return CompileRemainingInsertClauses(ctx, table, insertClauses); + { + ctx = CompileRemainingInsertClauses(ctx, table, insertClauses); + ctx.RawSql += returningAfter; + return ctx; + } - if (firstInsert.ReturnId && !string.IsNullOrEmpty(LastId)) + if (firstInsert.ReturnId && !HasReturning(ctx) && !string.IsNullOrEmpty(LastId)) ctx.RawSql += ";" + LastId; + else + ctx.RawSql += returningAfter; return ctx; } @@ -475,6 +494,51 @@ protected string GetInsertColumnsList(List columnList) return columns; } + protected bool HasReturning(SqlResult ctx) + { + return ctx.Query.HasComponent("returning", EngineCode); + } + + protected string CompileStandardReturning(SqlResult ctx) + { + var columns = ctx.Query + .GetComponents("returning", EngineCode) + .Select(x => CompileColumn(ctx, x)) + .ToList(); + + return columns.Any() ? $"RETURNING {string.Join(", ", columns)}" : null; + } + + protected virtual string CompileInsertReturningBeforeSource(SqlResult ctx) + { + return ""; + } + + protected virtual string CompileInsertReturningAfterSource(SqlResult ctx) + { + ThrowIfReturningUnsupported(ctx); + return ""; + } + + protected virtual string CompileUpdateReturningBeforeWhere(SqlResult ctx) + { + return ""; + } + + protected virtual string CompileUpdateReturningAfterWhere(SqlResult ctx) + { + ThrowIfReturningUnsupported(ctx); + return ""; + } + + protected void ThrowIfReturningUnsupported(SqlResult ctx) + { + if (HasReturning(ctx)) + { + throw new InvalidOperationException($"{EngineCode} compiler does not support returning clauses"); + } + } + protected virtual SqlResult CompileCteQuery(SqlResult ctx, Query query) { var cteFinder = new CteFinder(query, EngineCode); diff --git a/QueryBuilder/Compilers/PostgresCompiler.cs b/QueryBuilder/Compilers/PostgresCompiler.cs index 3b45d0e6..0e758d93 100644 --- a/QueryBuilder/Compilers/PostgresCompiler.cs +++ b/QueryBuilder/Compilers/PostgresCompiler.cs @@ -13,6 +13,17 @@ public PostgresCompiler() public override string EngineCode { get; } = EngineCodes.PostgreSql; public override bool SupportsFilterClause { get; set; } = true; + protected override string CompileInsertReturningAfterSource(SqlResult ctx) + { + var returning = CompileStandardReturning(ctx); + return string.IsNullOrEmpty(returning) ? "" : $" {returning}"; + } + + protected override string CompileUpdateReturningAfterWhere(SqlResult ctx) + { + var returning = CompileStandardReturning(ctx); + return string.IsNullOrEmpty(returning) ? "" : $" {returning}"; + } protected override string CompileBasicStringCondition(SqlResult ctx, BasicStringCondition x) { diff --git a/QueryBuilder/Compilers/SqlServerCompiler.cs b/QueryBuilder/Compilers/SqlServerCompiler.cs index 98b488bf..c5090010 100644 --- a/QueryBuilder/Compilers/SqlServerCompiler.cs +++ b/QueryBuilder/Compilers/SqlServerCompiler.cs @@ -14,6 +14,48 @@ public SqlServerCompiler() public override string EngineCode { get; } = EngineCodes.SqlServer; public bool UseLegacyPagination { get; set; } = false; + protected override string CompileInsertReturningBeforeSource(SqlResult ctx) + { + var output = CompileOutputReturning(ctx); + return string.IsNullOrEmpty(output) ? "" : $" {output}"; + } + + protected override string CompileUpdateReturningBeforeWhere(SqlResult ctx) + { + var output = CompileOutputReturning(ctx); + return string.IsNullOrEmpty(output) ? "" : $" {output}"; + } + + protected override string CompileInsertReturningAfterSource(SqlResult ctx) + { + return ""; + } + + protected override string CompileUpdateReturningAfterWhere(SqlResult ctx) + { + return ""; + } + + protected string CompileOutputReturning(SqlResult ctx) + { + var columns = ctx.Query + .GetComponents("returning", EngineCode) + .Select(x => CompileOutputColumn(ctx, x)) + .ToList(); + + return columns.Any() ? $"OUTPUT {string.Join(", ", columns)}" : null; + } + + protected string CompileOutputColumn(SqlResult ctx, AbstractColumn column) + { + if (column is Column normalColumn) + { + return $"INSERTED.{Wrap(normalColumn.Name)}"; + } + + return CompileColumn(ctx, column); + } + protected override SqlResult CompileSelectQuery(Query query) { if (!UseLegacyPagination || !query.HasOffset(EngineCode)) diff --git a/QueryBuilder/Query.Returning.cs b/QueryBuilder/Query.Returning.cs new file mode 100644 index 00000000..c2c3cb07 --- /dev/null +++ b/QueryBuilder/Query.Returning.cs @@ -0,0 +1,48 @@ +using System; +using System.Collections.Generic; +using System.Linq; + +namespace SqlKata +{ + public partial class Query + { + public Query Returning(params string[] columns) + { + return Returning(columns.AsEnumerable()); + } + + public Query Returning(IEnumerable columns) + { + var columnsList = columns?.ToList(); + + if ((columnsList?.Count ?? 0) == 0) + { + throw new InvalidOperationException($"{nameof(columns)} cannot be null or empty"); + } + + columnsList = columnsList + .Select(x => Helper.ExpandExpression(x)) + .SelectMany(x => x) + .ToList(); + + foreach (var column in columnsList) + { + AddComponent("returning", new Column + { + Name = column + }); + } + + return this; + } + + public Query ReturningRaw(string sql, params object[] bindings) + { + return AddComponent("returning", new RawColumn + { + Expression = sql, + Bindings = bindings, + }); + } + } +}