diff --git a/include/xtensor/misc/xmanipulation.hpp b/include/xtensor/misc/xmanipulation.hpp index 0118329bc..e20095f3e 100644 --- a/include/xtensor/misc/xmanipulation.hpp +++ b/include/xtensor/misc/xmanipulation.hpp @@ -954,12 +954,13 @@ namespace xt inline auto roll(E&& e, std::ptrdiff_t shift) { auto cpy = empty_like(e); - auto flat_size = std::accumulate( - cpy.shape().begin(), - cpy.shape().end(), - 1L, - std::multiplies() - ); + + if (cpy.size() == 0) + { + return cpy; + } + + const auto flat_size = static_cast(cpy.size()); while (shift < 0) { shift += flat_size; @@ -1059,9 +1060,14 @@ namespace xt XTENSOR_THROW(std::runtime_error, "axis is not within shape dimension."); } - std::size_t saxis = normalize_axis(dim, axis); + if (cpy.size() == 0) + { + return cpy; + } + std::size_t saxis = normalize_axis(dim, axis); const auto axis_dim = static_cast(shape[saxis]); + while (shift < 0) { shift += axis_dim; diff --git a/test/test_xmanipulation.cpp b/test/test_xmanipulation.cpp index f105a5a7d..bbc2663a2 100644 --- a/test/test_xmanipulation.cpp +++ b/test/test_xmanipulation.cpp @@ -514,6 +514,15 @@ namespace xt xarray expected11 = {{{4, 5, 6}}, {{7, 8, 9}}, {{1, 2, 3}}}; ASSERT_EQ(expected11, xt::roll(e2, 2, /*axis*/ -3)); + + xarray empty_1d = xt::xarray::from_shape({0}); + EXPECT_EQ(xt::roll(empty_1d, 5).shape(), empty_1d.shape()); + + xarray partial_empty = xt::xarray::from_shape({3, 0}); + EXPECT_EQ(xt::roll(partial_empty, 1, 1).shape(), partial_empty.shape()); + + xarray mixed_empty = xt::xarray::from_shape({3, 0, 4}); + EXPECT_EQ(xt::roll(mixed_empty, 1, 0).shape(), mixed_empty.shape()); } TEST(xmanipulation, repeat_all_elements_of_axis_0_of_int_array_2_times)