diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index d4eeb9b1dd4..3b18915eae5 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -671,6 +671,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { ComputeGraph* compute_graph = static_cast(handle); const size_t num_inputs = compute_graph->inputs().size(); + const size_t num_outputs = compute_graph->outputs().size(); bool should_propagate_resize = false; #ifdef ET_EVENT_TRACER_ENABLED runtime::EventTracer* event_tracer = context.event_tracer(); @@ -770,14 +771,13 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { "ETVK_COPY_OUTPUTS", /* delegate_debug_id = */ -1); #endif // ET_EVENT_TRACER_ENABLED - for (size_t i = 0; i < compute_graph->outputs().size(); i++) { - const size_t o = i + num_inputs; + const size_t output_offset = args.size() - num_outputs; + for (size_t i = 0; i < num_outputs; i++) { + const size_t o = output_offset + i; const ValueRef oref = compute_graph->outputs()[i].value; if (compute_graph->val_is_tensor(oref)) { VK_CHECK_COND(args[o]->isTensor()); maybe_resize_output(compute_graph, i, args[o]->toTensor()); - // args holds inputs directly followed by outputs, so the i'th output - // for compute_graph corresponds to the o'th arg compute_graph->maybe_cast_and_copy_from_staging( compute_graph->outputs()[i].staging, args[o]->toTensor().mutable_data_ptr(), diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index bb2df30a174..1ca56b152ea 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -452,14 +452,15 @@ ValueRef ComputeGraph::add_tensor( const utils::AxisMapLayout axis_map_layout) { ValueRef idx(static_cast(values_.size())); check_no_active_value_ptrs(); - values_.emplace_back(api::vTensor( - context(), - sizes, - dtype, - storage_type, - memory_layout, - false, - axis_map_layout)); + values_.emplace_back( + api::vTensor( + context(), + sizes, + dtype, + storage_type, + memory_layout, + false, + axis_map_layout)); if (shared_object_idx >= 0) { get_shared_object(shared_object_idx).add_user(this, idx); @@ -725,6 +726,9 @@ void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) { } int32_t ComputeGraph::read_symint(const ValueRef idx) { + if (values_.at(idx).isInt()) { + return static_cast(values_.at(idx).toInt()); + } return get_symint(idx)->get(); } diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 5ce84dd705b..9935b9be51b 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -573,6 +573,9 @@ class ComputeGraph final { if (value.isBool()) { return static_cast(value.toBool()); } + if (value.isSymInt()) { + return utils::safe_downcast(read_symint(idx)); + } VK_THROW("Cannot extract scalar from Value with type ", value.type()); }