diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc
index 35cc261e4d43..2f1f89d2e559 100644
--- a/src/runtime/vm/builtin.cc
+++ b/src/runtime/vm/builtin.cc
@@ -526,11 +526,39 @@ TVM_FFI_STATIC_INIT_BLOCK() {
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef()
-      .def_method("vm.builtin.shape_of", [](Tensor data) -> ffi::Shape { return data.Shape(); })
+      .def_method("vm.builtin.shape_of",
+                  [](ffi::Any any) -> ffi::Shape {
+                    // Accept either a managed Tensor or a borrowed DLTensor* so that
+                    // externally-provided (DLPack) tensors can be queried directly.
+                    if (auto opt_tensor = any.try_cast<Tensor>()) {
+                      return opt_tensor.value().Shape();
+                    } else if (auto opt_dltensor = any.try_cast<DLTensor*>()) {
+                      DLTensor* ptr = opt_dltensor.value();
+                      return ffi::Shape(ptr->shape, ptr->shape + ptr->ndim);
+                    } else {
+                      TVM_FFI_THROW(TypeError)
+                          << "vm.builtin.shape_of expects a Tensor or DLTensor*, but get "
+                          << any.GetTypeKey();
+                    }
+                  })
       .def("vm.builtin.copy", [](ffi::Any a) -> ffi::Any { return a; })
-      .def(
-          "vm.builtin.reshape",
-          [](Tensor data, ffi::Shape new_shape) { return data.CreateView(new_shape, data->dtype); })
+      .def("vm.builtin.reshape",
+           [](ffi::Any any, ffi::Shape new_shape) {
+             if (auto opt_tensor = any.try_cast<Tensor>()) {
+               Tensor data = opt_tensor.value();
+               return data.CreateView(new_shape, data->dtype);
+             } else if (auto opt_dltensor = any.try_cast<DLTensor*>()) {
+               // Wrap the borrowed DLTensor in a non-owning DLManagedTensor
+               // (null manager_ctx/deleter) so a Tensor view can be created.
+               DLTensor* ptr = opt_dltensor.value();
+               auto tmp = std::make_unique<DLManagedTensor>();
+               tmp->dl_tensor = *ptr;
+               tmp->manager_ctx = nullptr;
+               tmp->deleter = nullptr;
+               Tensor data = Tensor::FromDLPack(tmp.release());
+               return data.CreateView(new_shape, data->dtype);
+             } else {
+               TVM_FFI_THROW(TypeError)
+                   << "vm.builtin.reshape expects a Tensor or DLTensor*, but get "
+                   << any.GetTypeKey();
+             }
+           })
       .def("vm.builtin.null_value", []() -> std::nullptr_t { return nullptr; })
       .def_packed("vm.builtin.to_device", [](ffi::PackedArgs args, ffi::Any* rv) {
         Tensor data = args[0].cast<Tensor>();