From 3070b7afc337f8e1e2f7a52a842290e30330b4ab Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 10 Mar 2026 10:00:41 -0700 Subject: [PATCH] [ET-VK] Add symint infrastructure to VulkanBackend and ComputeGraph Extend the Vulkan backend runtime infrastructure to better support symbolic integer (symint) arguments. This is a prerequisite for operators that need to handle dynamic shapes via symint values. Changes: - VulkanBackend.cpp: Compute output offset from end of args instead of assuming outputs follow inputs directly. Add scalar-to-tensor input handling so that Int/Bool EValues can populate tensor inputs. Support symint inputs provided as raw Int EValues (not just scalar tensors). Add symint output handling to write values back as tensor or Int EValue. - ComputeGraph.h: Add SymInt case to extract_scalar() so operators can transparently read symint values as scalars. - ComputeGraph.cpp: Add Int fallback in read_symint() so values stored as plain Int (rather than SymInt objects) can be read uniformly. 
Differential Revision: [D95970167](https://our.internmc.facebook.com/intern/diff/D95970167/) [ghstack-poisoned] --- backends/vulkan/runtime/VulkanBackend.cpp | 81 ++++++++++++++----- .../vulkan/runtime/graph/ComputeGraph.cpp | 20 +++-- backends/vulkan/runtime/graph/ComputeGraph.h | 3 + 3 files changed, 77 insertions(+), 27 deletions(-) diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index d4eeb9b1dd4..a084ce9a146 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -671,6 +671,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { ComputeGraph* compute_graph = static_cast<ComputeGraph*>(handle); const size_t num_inputs = compute_graph->inputs().size(); + const size_t num_outputs = compute_graph->outputs().size(); bool should_propagate_resize = false; #ifdef ET_EVENT_TRACER_ENABLED runtime::EventTracer* event_tracer = context.event_tracer(); @@ -690,22 +691,51 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { for (size_t i = 0; i < num_inputs; i++) { const ValueRef iref = compute_graph->inputs()[i].value; if (compute_graph->val_is_tensor(iref)) { - VK_CHECK_COND(args[i]->isTensor()); - bool was_resized = - maybe_resize_input(compute_graph, i, args[i]->toTensor()); - should_propagate_resize = should_propagate_resize || was_resized; - compute_graph->maybe_cast_and_copy_into_staging( - compute_graph->inputs()[i].staging, - args[i]->toTensor().const_data_ptr(), - args[i]->toTensor().numel(), - equivalent_scalar_type(args[i]->toTensor().scalar_type())); + if (args[i]->isTensor()) { + bool was_resized = + maybe_resize_input(compute_graph, i, args[i]->toTensor()); + should_propagate_resize = should_propagate_resize || was_resized; + compute_graph->maybe_cast_and_copy_into_staging( + compute_graph->inputs()[i].staging, + args[i]->toTensor().const_data_ptr(), + args[i]->toTensor().numel(), + 
equivalent_scalar_type(args[i]->toTensor().scalar_type())); + } else if (args[i]->isInt() || args[i]->isBool()) { + int64_t val = + args[i]->isInt() ? args[i]->toInt() : (args[i]->toBool() ? 1 : 0); + vkapi::ScalarType tensor_dtype = compute_graph->dtype_of(iref); + if (tensor_dtype == vkapi::kFloat) { + float fval = static_cast<float>(val); + compute_graph->maybe_cast_and_copy_into_staging( + compute_graph->inputs()[i].staging, &fval, 1, vkapi::kFloat); + } else if (tensor_dtype == vkapi::kInt) { + int32_t ival = static_cast<int32_t>(val); + compute_graph->maybe_cast_and_copy_into_staging( + compute_graph->inputs()[i].staging, &ival, 1, vkapi::kInt); + } else { + compute_graph->maybe_cast_and_copy_into_staging( + compute_graph->inputs()[i].staging, &val, 1, vkapi::kLong); + } + } else { + VK_THROW( + "Tensor input[", + i, + "] has unsupported EValue tag ", + static_cast<int>(args[i]->tag)); + } } else if (compute_graph->val_is_symint(iref)) { - VK_CHECK_COND( - args[i]->isTensor(), - "Cannot handle symint arg to graph that is not derived from a " - "scalar tensor at the moment."); - bool was_updated = maybe_update_scalar_tensor( - compute_graph, iref, args[i]->toTensor()); + bool was_updated = false; + if (args[i]->isTensor()) { + was_updated = maybe_update_scalar_tensor( + compute_graph, iref, args[i]->toTensor()); + } else if (args[i]->isInt()) { + const int32_t new_val = static_cast<int32_t>(args[i]->toInt()); + const int32_t cur_val = compute_graph->read_symint(iref); + if (new_val != cur_val) { + compute_graph->set_symint(iref, new_val); + was_updated = true; + } + } // Since symint inputs may impact tensor's sizes, trigger a resize if // any symbolic integer shapes are updated. 
should_propagate_resize = should_propagate_resize || was_updated; @@ -770,14 +800,13 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { "ETVK_COPY_OUTPUTS", /* delegate_debug_id = */ -1); #endif // ET_EVENT_TRACER_ENABLED - for (size_t i = 0; i < compute_graph->outputs().size(); i++) { - const size_t o = i + num_inputs; + const size_t output_offset = args.size() - num_outputs; + for (size_t i = 0; i < num_outputs; i++) { + const size_t o = output_offset + i; const ValueRef oref = compute_graph->outputs()[i].value; if (compute_graph->val_is_tensor(oref)) { VK_CHECK_COND(args[o]->isTensor()); maybe_resize_output(compute_graph, i, args[o]->toTensor()); - // args holds inputs directly followed by outputs, so the i'th output - // for compute_graph corresponds to the o'th arg compute_graph->maybe_cast_and_copy_from_staging( compute_graph->outputs()[i].staging, args[o]->toTensor().mutable_data_ptr(), @@ -789,6 +818,20 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { // returned as an output, no action is required. 
else if (compute_graph->val_is_tref(oref)) { continue; + } else if (compute_graph->val_is_symint(oref)) { + const int32_t symint_val = compute_graph->read_symint(oref); + if (args[o]->isTensor()) { + executorch::aten::Tensor& out_tensor = args[o]->toTensor(); + executorch::aten::ScalarType dtype = out_tensor.scalar_type(); + if (dtype == executorch::aten::ScalarType::Int) { + *out_tensor.mutable_data_ptr<int32_t>() = symint_val; + } else if (dtype == executorch::aten::ScalarType::Long) { + *out_tensor.mutable_data_ptr<int64_t>() = + static_cast<int64_t>(symint_val); + } + } else if (args[o]->isInt()) { + *args[o] = EValue(static_cast<int64_t>(symint_val)); + } } else { VK_THROW( "Could not handle output with type ", diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index bb2df30a174..1ca56b152ea 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -452,14 +452,15 @@ ValueRef ComputeGraph::add_tensor( const utils::AxisMapLayout axis_map_layout) { ValueRef idx(static_cast<int>(values_.size())); check_no_active_value_ptrs(); - values_.emplace_back(api::vTensor( - context(), - sizes, - dtype, - storage_type, - memory_layout, - false, - axis_map_layout)); + values_.emplace_back( + api::vTensor( + context(), + sizes, + dtype, + storage_type, + memory_layout, + false, + axis_map_layout)); if (shared_object_idx >= 0) { get_shared_object(shared_object_idx).add_user(this, idx); @@ -725,6 +726,9 @@ void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) { } int32_t ComputeGraph::read_symint(const ValueRef idx) { + if (values_.at(idx).isInt()) { + return static_cast<int32_t>(values_.at(idx).toInt()); + } return get_symint(idx)->get(); } diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 5ce84dd705b..9935b9be51b 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -573,6 +573,9 
@@ class ComputeGraph final { if (value.isBool()) { return static_cast<T>(value.toBool()); } + if (value.isSymInt()) { + return utils::safe_downcast<T>(read_symint(idx)); + } VK_THROW("Cannot extract scalar from Value with type ", value.type()); }