diff --git a/integration_test/sql/logging.exs b/integration_test/sql/logging.exs index 0f92e9a7f..5a3123601 100644 --- a/integration_test/sql/logging.exs +++ b/integration_test/sql/logging.exs @@ -178,6 +178,36 @@ defmodule Ecto.Integration.LoggingTest do end end + describe ":label option" do + test "prepends a leading comment to insert/update/delete/insert_all" do + assert capture_log(fn -> + TestRepo.insert!(%Post{title: "1"}, label: "insert_create_post_q", log: :error) + end) =~ "/* insert_create_post_q */ INSERT INTO" + + post = TestRepo.insert!(%Post{title: "x"}) + + assert capture_log(fn -> + post + |> Ecto.Changeset.change(title: "y") + |> TestRepo.update!(label: "update_post_q", log: :error) + end) =~ "/* update_post_q */ UPDATE" + + assert capture_log(fn -> + TestRepo.delete!(post, label: "delete_post_q", log: :error) + end) =~ "/* delete_post_q */ DELETE" + + assert capture_log(fn -> + TestRepo.insert_all(Post, [%{title: "a"}], label: "bulk_insert_posts_q", log: :error) + end) =~ "/* bulk_insert_posts_q */ INSERT INTO" + end + + test "rejects a label that could break out of the comment block" do + assert_raise ArgumentError, ~r/cannot contain/, fn -> + TestRepo.insert!(%Post{title: "1"}, label: "evil */ DROP TABLE posts") + end + end + end + describe "parameter logging" do @describetag :parameter_logging diff --git a/lib/ecto/adapters/myxql.ex b/lib/ecto/adapters/myxql.ex index 6026cb260..12a86d076 100644 --- a/lib/ecto/adapters/myxql.ex +++ b/lib/ecto/adapters/myxql.ex @@ -380,6 +380,10 @@ defmodule Ecto.Adapters.MyXQL do opts end + # This adapter overrides insert/6 instead of going through + # Ecto.Adapters.SQL.struct/10, so prepend the `:label` comment here too. + {sql, opts} = Ecto.Adapters.SQL.prepend_label(sql, opts) + case Ecto.Adapters.SQL.query(adapter_meta, sql, values ++ query_params, opts) do {:ok, %{num_rows: 0}} -> # With INSERT IGNORE (insert_mode: :ignore), 0 rows means the row diff --git a/lib/ecto/adapters/myxql/connection.ex b/lib/ecto/adapters/myxql/connection.ex index 833c6d35f..0099fbd0f 100644 --- a/lib/ecto/adapters/myxql/connection.ex +++ b/lib/ecto/adapters/myxql/connection.ex @@ -118,8 +118,10 @@ if Code.ensure_loaded?(MyXQL) do limit = limit(query, sources) offset = offset(query, sources) lock = lock(query, sources) + label = label(query) [ + label, cte, select, from, @@ -145,6 +147,7 @@ if Code.ensure_loaded?(MyXQL) do sources = create_names(query, []) cte = cte(query, sources) + label = label(query) {from, name} = get_source(query, sources, 0, source) fields = @@ -158,7 +161,7 @@ if Code.ensure_loaded?(MyXQL) do prefix = prefix || ["UPDATE ", from, " AS ", name, join, " SET "] where = where(%{query | wheres: wheres ++ query.wheres}, sources) - [cte, prefix, fields | where] + [label, cte, prefix, fields | where] end @impl true @@ -171,11 +174,12 @@ if Code.ensure_loaded?(MyXQL) do cte = cte(query, sources) {_, name, _} = elem(sources, 0) + label = label(query) from = from(query, sources) join = join(query, sources) where = where(query, sources) - [cte, "DELETE ", name, ".*", from, join | where] + [label, cte, "DELETE ", name, ".*", from, join | where] end @impl true @@ -636,6 +640,9 @@ if Code.ensure_loaded?(MyXQL) do end) end + defp label(%{label: nil}), do: [] + defp label(%{label: label}), do: ["/* ", label, " */ "] + defp lock(%{lock: nil}, _sources), do: [] defp lock(%{lock: binary}, _sources) when is_binary(binary), do: [?\s | binary] defp lock(%{lock: expr} = query, sources), do: [?\s | expr(expr, sources, query)] diff --git a/lib/ecto/adapters/postgres/connection.ex b/lib/ecto/adapters/postgres/connection.ex index ec8bf4a09..191c23da0 100644 --- a/lib/ecto/adapters/postgres/connection.ex +++ b/lib/ecto/adapters/postgres/connection.ex @@ -191,8 +191,10 @@ if Code.ensure_loaded?(Postgrex) do limit = limit(query, sources) offset = offset(query, sources) lock = lock(query, sources) + label = label(query) [ + label, cte, select, from, @@ -213,13 +215,13 @@ if Code.ensure_loaded?(Postgrex) do sources = create_names(query, []) cte = cte(query, sources) {from, name} = get_source(query, sources, 0, source) - + label = label(query) prefix = prefix || ["UPDATE ", from, " AS ", name | " SET "] fields = update_fields(query, sources) {join, wheres} = using_join(query, :update_all, "FROM", sources) where = where(%{query | wheres: wheres ++ query.wheres}, sources) - [cte, prefix, fields, join, where | returning(query, sources)] + [label, cte, prefix, fields, join, where | returning(query, sources)] end @impl true @@ -227,11 +229,11 @@ if Code.ensure_loaded?(Postgrex) do sources = create_names(query, []) cte = cte(query, sources) {from, name} = get_source(query, sources, 0, from) - + label = label(query) {join, wheres} = using_join(query, :delete_all, "USING", sources) where = where(%{query | wheres: wheres ++ query.wheres}, sources) - [cte, "DELETE FROM ", from, " AS ", name, join, where | returning(query, sources)] + [label, cte, "DELETE FROM ", from, " AS ", name, join, where | returning(query, sources)] end @impl true @@ -880,6 +882,9 @@ if Code.ensure_loaded?(Postgrex) do end) end + defp label(%{label: nil}), do: [] + defp label(%{label: label}), do: ["/* ", label, " */ "] + defp lock(%{lock: nil}, _sources), do: [] defp lock(%{lock: binary}, _sources) when is_binary(binary), do: [?\s | binary] defp lock(%{lock: expr} = query, sources), do: [?\s | expr(expr, sources, query)] diff --git a/lib/ecto/adapters/sql.ex b/lib/ecto/adapters/sql.ex index 4ae24c14e..e8f638f63 100644 --- a/lib/ecto/adapters/sql.ex +++ b/lib/ecto/adapters/sql.ex @@ -987,6 +987,8 @@ defmodule Ecto.Adapters.SQL do opts end + {sql, opts} = prepend_label(sql, opts) + all_params = placeholders ++ Enum.reverse(params, conflict_params) %{num_rows: num, rows: rows} = query!(adapter_meta, sql, all_params, [source: source] ++ opts) @@ -1166,6 +1168,31 @@ defmodule Ecto.Adapters.SQL do end end + @doc false + def prepend_label(sql, opts) do + case Keyword.get(opts, :label) do + nil -> + {sql, opts} + + label -> + label = validate_label!(label) + {["/* ", label, " */ ", sql], opts} + end + end + + defp validate_label!(label) when is_binary(label) do + if String.contains?(label, ["/*", "*/", <<0>>]) do + raise ArgumentError, + "a label cannot contain `/*`, `*/`, or null bytes, got: #{inspect(label)}. " + end + + label + end + + defp validate_label!(other) do + raise ArgumentError, "a label must be a string, got: #{inspect(other)}" + end + @doc false def struct( adapter_meta, @@ -1186,6 +1213,8 @@ defmodule Ecto.Adapters.SQL do opts end + {sql, opts} = prepend_label(sql, opts) + case query(adapter_meta, sql, values, [source: source] ++ opts) do {:ok, %{rows: nil, num_rows: 1}} -> {:ok, []} diff --git a/lib/ecto/adapters/tds/connection.ex b/lib/ecto/adapters/tds/connection.ex index be1cf229b..28e582ca7 100644 --- a/lib/ecto/adapters/tds/connection.ex +++ b/lib/ecto/adapters/tds/connection.ex @@ -170,11 +170,12 @@ if Code.ensure_loaded?(Tds) do # limit = is handled in select (TOP X) offset = offset(query, sources) lock = lock(query, sources) + label = label(query) if query.offset != nil and query.order_bys == [], do: error!(query, "ORDER BY is mandatory when OFFSET is set") - [cte, select, from, join, where, group_by, having, combinations, order_by, lock | offset] + [label, cte, select, from, join, where, group_by, having, combinations, order_by, lock | offset] end @impl true @@ -188,8 +189,10 @@ if Code.ensure_loaded?(Tds) do join = join(query, sources) where = where(query, sources) lock = lock(query, sources) + label = label(query) [ + label, cte, "UPDATE ", name, @@ -213,8 +216,9 @@ if Code.ensure_loaded?(Tds) do join = join(query, sources) where = where(query, sources) lock = lock(query, sources) + label = label(query) - [cte, delete, returning(query, 0, "DELETED"), from, join, where | lock] + [label, cte, delete, returning(query, 0, "DELETED"), from, join, where | lock] end @impl true @@ -656,6 +660,9 @@ if Code.ensure_loaded?(Tds) do defp hints([_ | _] = hints), do: [" WITH (", Enum.intersperse(hints, ", "), ?)] defp hints([]), do: [] + defp label(%{label: nil}), do: [] + defp label(%{label: label}), do: ["/* ", label, " */ "] + defp lock(%{lock: nil}, _sources), do: [] defp lock(%{lock: binary}, _sources) when is_binary(binary), do: [" OPTION (", binary, ?)] defp lock(%{lock: expr} = query, sources), do: [" OPTION (", expr(expr, sources, query), ?)] diff --git a/test/ecto/adapters/myxql_test.exs b/test/ecto/adapters/myxql_test.exs index b062d7bbd..f9fef528e 100644 --- a/test/ecto/adapters/myxql_test.exs +++ b/test/ecto/adapters/myxql_test.exs @@ -590,6 +590,17 @@ defmodule Ecto.Adapters.MyXQLTest do assert all(query) == ~s{SELECT TRUE FROM `schema` AS s0 UPDATE on s0} end + test "label" do + query = Schema |> label("myquery") |> select([], true) |> plan() + assert all(query) == ~s{/* myquery */ SELECT TRUE FROM `schema` AS s0} + + query = Schema |> label("upd_q") |> update([], set: [x: 0]) |> plan(:update_all) + assert update_all(query) == ~s{/* upd_q */ UPDATE `schema` AS s0 SET s0.`x` = 0} + + query = Schema |> label("del_q") |> plan(:delete_all) + assert delete_all(query) == ~s{/* del_q */ DELETE s0.* FROM `schema` AS s0} + end + test "string escape" do query = "schema" |> where(foo: "'\\ ") |> select([], true) |> plan() assert all(query) == ~s{SELECT TRUE FROM `schema` AS s0 WHERE (s0.`foo` = '''\\\\ ')} diff --git a/test/ecto/adapters/postgres_test.exs b/test/ecto/adapters/postgres_test.exs index e588b29d5..ae14e5132 100644 --- a/test/ecto/adapters/postgres_test.exs +++ b/test/ecto/adapters/postgres_test.exs @@ -783,6 +783,17 @@ defmodule Ecto.Adapters.PostgresTest do assert all(query) == ~s{SELECT TRUE FROM "schema" AS s0 UPDATE on s0} end + test "label" do + query = Schema |> label("myquery") |> select([], true) |> plan() + assert all(query) == ~s{/* myquery */ SELECT TRUE FROM "schema" AS s0} + + query = Schema |> label("upd_q") |> update([], set: [x: 0]) |> plan(:update_all) + assert update_all(query) == ~s{/* upd_q */ UPDATE "schema" AS s0 SET "x" = 0} + + query = Schema |> label("del_q") |> plan(:delete_all) + assert delete_all(query) == ~s{/* del_q */ DELETE FROM "schema" AS s0} + end + test "string escape" do query = "schema" |> where(foo: "'\\ ") |> select([], true) |> plan() assert all(query) == ~s{SELECT TRUE FROM \"schema\" AS s0 WHERE (s0.\"foo\" = '''\\ ')} diff --git a/test/ecto/adapters/sql_test.exs b/test/ecto/adapters/sql_test.exs new file mode 100644 index 000000000..ddea3219f --- /dev/null +++ b/test/ecto/adapters/sql_test.exs @@ -0,0 +1,41 @@ +defmodule Ecto.Adapters.SQLTest do + use ExUnit.Case, async: true + + defp prepend(sql, opts) do + {sql, opts} = Ecto.Adapters.SQL.prepend_label(sql, opts) + {IO.iodata_to_binary(sql), opts} + end + + describe "prepend_label/2" do + test "without a label, passes sql and opts through unchanged" do + assert prepend("SELECT 1", cache_statement: "ecto_all_posts") == + {"SELECT 1", [cache_statement: "ecto_all_posts"]} + end + + test "prepends a leading comment block" do + {sql, _opts} = prepend("INSERT INTO posts ...", label: "create_post_q") + assert sql == "/* create_post_q */ INSERT INTO posts ..." + end + + test "leaves opts untouched so prepared-statement caching is preserved" do + opts = [label: "tag_q", cache_statement: "ecto_insert_posts_3", timeout: 5000] + {_sql, returned_opts} = prepend("INSERT INTO posts ...", opts) + + assert returned_opts == opts + end + + test "rejects label-delimiter sequences and null bytes" do + for bad <- ["evil */ DROP", "evil /* nest", "with\0null"] do + assert_raise ArgumentError, ~r/cannot contain/, fn -> + Ecto.Adapters.SQL.prepend_label("X", label: bad) + end + end + end + + test "rejects non-string labels" do + assert_raise ArgumentError, ~r/must be a string/, fn -> + Ecto.Adapters.SQL.prepend_label("X", label: 123) + end + end + end +end diff --git a/test/ecto/adapters/tds_test.exs b/test/ecto/adapters/tds_test.exs index 93413141f..894b409e1 100644 --- a/test/ecto/adapters/tds_test.exs +++ b/test/ecto/adapters/tds_test.exs @@ -641,6 +641,17 @@ defmodule Ecto.Adapters.TdsTest do assert all(query) == ~s{SELECT CAST(1 as bit) FROM [schema] AS s0 OPTION (UPDATE on s0)} end + test "label" do + query = Schema |> label("myquery") |> select([], true) |> plan() + assert all(query) == ~s{/* myquery */ SELECT CAST(1 as bit) FROM [schema] AS s0} + + query = Schema |> label("upd_q") |> update([], set: [x: 0]) |> plan(:update_all) + assert update_all(query) == ~s{/* upd_q */ UPDATE s0 SET s0.[x] = 0 FROM [schema] AS s0} + + query = Schema |> label("del_q") |> plan(:delete_all) + assert delete_all(query) == ~s{/* del_q */ DELETE s0 FROM [schema] AS s0} + end + test "string escape" do query = "schema" |> where(foo: "\'-- ") |> select([], true) |> plan()