diff --git a/Cargo.lock b/Cargo.lock index e4fd2307b9f..4bd2006c54e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10444,6 +10444,17 @@ dependencies = [ "vortex-duckdb", ] +[[package]] +name = "vortex-tensor" +version = "0.1.0" +dependencies = [ + "itertools 0.14.0", + "num-traits", + "prost 0.14.3", + "rstest", + "vortex", +] + [[package]] name = "vortex-test-e2e" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 5bbc2ebab02..f5c28544599 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ members = [ "vortex-io", "vortex-proto", "vortex-array", + "vortex-tensor", "vortex-btrblocks", "vortex-layout", "vortex-scan", @@ -271,6 +272,7 @@ vortex-scan = { version = "0.1.0", path = "./vortex-scan", default-features = fa vortex-sequence = { version = "0.1.0", path = "encodings/sequence", default-features = false } vortex-session = { version = "0.1.0", path = "./vortex-session", default-features = false } vortex-sparse = { version = "0.1.0", path = "./encodings/sparse", default-features = false } +vortex-tensor = { version = "0.1.0", path = "./vortex-tensor", default-features = false } vortex-utils = { version = "0.1.0", path = "./vortex-utils", default-features = false } vortex-zigzag = { version = "0.1.0", path = "./encodings/zigzag", default-features = false } vortex-zstd = { version = "0.1.0", path = "./encodings/zstd", default-features = false } diff --git a/vortex-tensor/Cargo.toml b/vortex-tensor/Cargo.toml new file mode 100644 index 00000000000..6f4fe4511af --- /dev/null +++ b/vortex-tensor/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "vortex-tensor" +authors = { workspace = true } +categories = { workspace = true } +description = "Vortex tensor extension type" +edition = { workspace = true } +homepage = { workspace = true } +include = { workspace = true } +keywords = { workspace = true } +license = { workspace = true } +readme = { workspace = true } +repository = { workspace = true } +rust-version = { workspace = true } +version = { workspace = 
true } + +[lints] +workspace = true + +[dependencies] +vortex = { workspace = true } + +itertools = { workspace = true } +num-traits = { workspace = true } +prost = { workspace = true } + +[dev-dependencies] +rstest = { workspace = true } diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock new file mode 100644 index 00000000000..3046784af96 --- /dev/null +++ b/vortex-tensor/public-api.lock @@ -0,0 +1,121 @@ +pub mod vortex_tensor + +pub mod vortex_tensor::scalar_fns + +pub mod vortex_tensor::scalar_fns::cosine_similarity + +pub struct vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity + +impl core::clone::Clone for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::clone(&self) -> vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity + +impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity + +pub type vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::Options = vortex_array::scalar_fn::vtable::EmptyOptions + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::vtable::Arity + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub fn 
vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::id(&self) -> vortex_array::scalar_fn::ScalarFnId + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::is_fallible(&self, _options: &Self::Options) -> bool + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::is_null_sensitive(&self, _options: &Self::Options) -> bool + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::return_dtype(&self, _options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult> + +pub struct vortex_tensor::FixedShapeTensor + +impl core::clone::Clone for vortex_tensor::FixedShapeTensor + +pub fn vortex_tensor::FixedShapeTensor::clone(&self) -> vortex_tensor::FixedShapeTensor + +impl core::cmp::Eq for vortex_tensor::FixedShapeTensor + +impl core::cmp::PartialEq for vortex_tensor::FixedShapeTensor + +pub fn vortex_tensor::FixedShapeTensor::eq(&self, other: &vortex_tensor::FixedShapeTensor) -> bool + +impl core::default::Default for vortex_tensor::FixedShapeTensor + +pub fn vortex_tensor::FixedShapeTensor::default() -> vortex_tensor::FixedShapeTensor + +impl core::fmt::Debug for vortex_tensor::FixedShapeTensor + +pub fn vortex_tensor::FixedShapeTensor::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_tensor::FixedShapeTensor + +pub fn vortex_tensor::FixedShapeTensor::hash<__H: core::hash::Hasher>(&self, state: &mut __H) + +impl core::marker::StructuralPartialEq for vortex_tensor::FixedShapeTensor + +impl vortex_array::dtype::extension::vtable::ExtVTable for vortex_tensor::FixedShapeTensor + +pub type vortex_tensor::FixedShapeTensor::Metadata = vortex_tensor::FixedShapeTensorMetadata + +pub type 
vortex_tensor::FixedShapeTensor::NativeValue<'a> = &'a vortex_array::scalar::scalar_value::ScalarValue + +pub fn vortex_tensor::FixedShapeTensor::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult + +pub fn vortex_tensor::FixedShapeTensor::id(&self) -> vortex_array::dtype::extension::ExtId + +pub fn vortex_tensor::FixedShapeTensor::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> + +pub fn vortex_tensor::FixedShapeTensor::unpack_native<'a>(&self, _metadata: &'a Self::Metadata, _storage_dtype: &'a vortex_array::dtype::DType, storage_value: &'a vortex_array::scalar::scalar_value::ScalarValue) -> vortex_error::VortexResult + +pub fn vortex_tensor::FixedShapeTensor::validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<()> + +pub struct vortex_tensor::FixedShapeTensorMetadata + +impl vortex_tensor::FixedShapeTensorMetadata + +pub fn vortex_tensor::FixedShapeTensorMetadata::dim_names(&self) -> core::option::Option<&[alloc::string::String]> + +pub fn vortex_tensor::FixedShapeTensorMetadata::logical_shape(&self) -> &[usize] + +pub fn vortex_tensor::FixedShapeTensorMetadata::ndim(&self) -> usize + +pub fn vortex_tensor::FixedShapeTensorMetadata::new(shape: alloc::vec::Vec) -> Self + +pub fn vortex_tensor::FixedShapeTensorMetadata::permutation(&self) -> core::option::Option<&[usize]> + +pub fn vortex_tensor::FixedShapeTensorMetadata::physical_shape(&self) -> impl core::iter::traits::iterator::Iterator + '_ + +pub fn vortex_tensor::FixedShapeTensorMetadata::strides(&self) -> impl core::iter::traits::iterator::Iterator + '_ + +pub fn vortex_tensor::FixedShapeTensorMetadata::with_dim_names(self, names: alloc::vec::Vec) -> vortex_error::VortexResult + +pub fn vortex_tensor::FixedShapeTensorMetadata::with_permutation(self, permutation: alloc::vec::Vec) -> vortex_error::VortexResult + +impl core::clone::Clone for vortex_tensor::FixedShapeTensorMetadata + 
+pub fn vortex_tensor::FixedShapeTensorMetadata::clone(&self) -> vortex_tensor::FixedShapeTensorMetadata + +impl core::cmp::Eq for vortex_tensor::FixedShapeTensorMetadata + +impl core::cmp::PartialEq for vortex_tensor::FixedShapeTensorMetadata + +pub fn vortex_tensor::FixedShapeTensorMetadata::eq(&self, other: &vortex_tensor::FixedShapeTensorMetadata) -> bool + +impl core::fmt::Debug for vortex_tensor::FixedShapeTensorMetadata + +pub fn vortex_tensor::FixedShapeTensorMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::fmt::Display for vortex_tensor::FixedShapeTensorMetadata + +pub fn vortex_tensor::FixedShapeTensorMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_tensor::FixedShapeTensorMetadata + +pub fn vortex_tensor::FixedShapeTensorMetadata::hash<__H: core::hash::Hasher>(&self, state: &mut __H) + +impl core::marker::StructuralPartialEq for vortex_tensor::FixedShapeTensorMetadata diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs new file mode 100644 index 00000000000..ab18826c6b6 --- /dev/null +++ b/vortex-tensor/src/lib.rs @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Tensor extension type. + +mod metadata; +pub use metadata::FixedShapeTensorMetadata; + +mod proto; +mod vtable; + +pub mod scalar_fns; + +/// The VTable for the Tensor extension type. 
+#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
+pub struct FixedShapeTensor;
diff --git a/vortex-tensor/src/metadata.rs b/vortex-tensor/src/metadata.rs
new file mode 100644
index 00000000000..fb46c67d213
--- /dev/null
+++ b/vortex-tensor/src/metadata.rs
@@ -0,0 +1,379 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+use std::fmt;
+
+use itertools::Either;
+use vortex::error::VortexExpect;
+use vortex::error::VortexResult;
+use vortex::error::vortex_ensure;
+use vortex::error::vortex_ensure_eq;
+
+/// Metadata for a `FixedShapeTensor` extension type.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub struct FixedShapeTensorMetadata {
+    /// The logical shape of the tensor.
+    ///
+    /// `logical_shape[i]` is the size of the `i`-th logical dimension. When a `permutation` is
+    /// present, the physical shape (i.e., the row-major memory layout) is derived as
+    /// `physical_shape[permutation[i]] = logical_shape[i]`.
+    ///
+    /// May be empty (0D scalar tensor) or contain dimensions of size 0 (degenerate tensor).
+    logical_shape: Vec<usize>,
+
+    /// Optional names for each logical dimension. Each name corresponds to an entry in
+    /// `logical_shape`.
+    ///
+    /// If names exist, there must be an equal number of names to logical dimensions.
+    dim_names: Option<Vec<String>>,
+
+    /// The permutation of the tensor's dimensions. `permutation[i]` is the physical dimension
+    /// index that logical dimension `i` maps to.
+    ///
+    /// If this is `None`, then the logical and physical layouts are identical, equivalent to
+    /// the identity permutation `[0, 1, ..., N-1]`.
+    permutation: Option<Vec<usize>>,
+}
+
+impl FixedShapeTensorMetadata {
+    /// Creates a new [`FixedShapeTensorMetadata`] with the given logical `shape`.
+    ///
+    /// Use [`with_dim_names`](Self::with_dim_names) and
+    /// [`with_permutation`](Self::with_permutation) to further configure the metadata.
+    pub fn new(shape: Vec<usize>) -> Self {
+        Self {
+            logical_shape: shape,
+            dim_names: None,
+            permutation: None,
+        }
+    }
+
+    /// Sets the dimension names for this tensor. An empty vec is normalized to `None` since a
+    /// 0-dimensional tensor has no dimensions to name.
+    ///
+    /// The number of names must match the number of logical dimensions.
+    pub fn with_dim_names(mut self, names: Vec<String>) -> VortexResult<Self> {
+        if !names.is_empty() {
+            vortex_ensure_eq!(
+                names.len(),
+                self.logical_shape.len(),
+                "dim_names length ({}) must match logical_shape length ({})",
+                names.len(),
+                self.logical_shape.len()
+            );
+            self.dim_names = Some(names);
+        }
+
+        Ok(self)
+    }
+
+    /// Sets the permutation for this tensor. An empty vec is normalized to `None` since a
+    /// 0-dimensional tensor has no dimensions to permute.
+    ///
+    /// The permutation must be a valid permutation of `[0, 1, ..., N-1]` where `N` is the
+    /// number of logical dimensions.
+    pub fn with_permutation(mut self, permutation: Vec<usize>) -> VortexResult<Self> {
+        if !permutation.is_empty() {
+            vortex_ensure_eq!(
+                permutation.len(),
+                self.logical_shape.len(),
+                "permutation length ({}) must match logical_shape length ({})",
+                permutation.len(),
+                self.logical_shape.len()
+            );
+
+            // Verify this is actually a permutation of [0..N).
+            let mut seen = vec![false; permutation.len()];
+            for &p in &permutation {
+                vortex_ensure!(
+                    p < permutation.len(),
+                    "permutation index {p} is out of range for {} dimensions",
+                    permutation.len()
+                );
+                vortex_ensure!(!seen[p], "permutation contains duplicate index {p}");
+                seen[p] = true;
+            }
+
+            self.permutation = Some(permutation);
+        }
+
+        Ok(self)
+    }
+
+    /// Returns the number of dimensions (rank) of the tensor.
+    pub fn ndim(&self) -> usize {
+        self.logical_shape.len()
+    }
+
+    /// Returns the logical dimensions of the tensor as a slice.
+    pub fn logical_shape(&self) -> &[usize] {
+        &self.logical_shape
+    }
+
+    /// Returns the dimension names, if set.
+    pub fn dim_names(&self) -> Option<&[String]> {
+        self.dim_names.as_deref()
+    }
+
+    /// Returns the permutation, if set.
+    pub fn permutation(&self) -> Option<&[usize]> {
+        self.permutation.as_deref()
+    }
+
+    /// Returns an iterator over the physical shape of the tensor.
+    ///
+    /// The physical shape describes the row-major memory layout. It is derived from the logical
+    /// shape by placing each logical dimension's size at its physical position:
+    /// `physical_shape[permutation[i]] = logical_shape[i]`.
+    ///
+    /// When no permutation is present, the physical shape is identical to the logical shape.
+    pub fn physical_shape(&self) -> impl Iterator<Item = usize> + '_ {
+        let ndim = self.logical_shape.len();
+        let permutation = self.permutation.as_deref();
+
+        match permutation {
+            None => Either::Left(self.logical_shape.iter().copied()),
+            Some(perm) => Either::Right(
+                (0..ndim).map(move |p| self.logical_shape[Self::inverse_perm(perm, p)]),
+            ),
+        }
+    }
+
+    /// Returns an iterator over the strides for each logical dimension of the tensor.
+    ///
+    /// The stride for a logical dimension is the number of elements to skip in the flat backing
+    /// array in order to move one step along that logical dimension.
+    ///
+    /// When a permutation is present, the physical memory is laid out in row-major order over the
+    /// physical dimensions (the logical dimensions reordered by the permutation), so the strides
+    /// are derived from that physical layout.
+    pub fn strides(&self) -> impl Iterator<Item = usize> + '_ {
+        let ndim = self.logical_shape.len();
+        let permutation = self.permutation.as_deref();
+
+        match permutation {
+            None => Either::Left(
+                (0..ndim).map(|i| self.logical_shape[i + 1..].iter().product::<usize>()),
+            ),
+            Some(permutation) => {
+                Either::Right((0..ndim).map(|i| self.permuted_stride(i, permutation)))
+            }
+        }
+    }
+
+    /// Computes the stride for logical dimension `i` given a `permutation`.
+    ///
+    /// The stride is the product of `logical_shape[j]` for all logical dimensions `j` whose
+    /// physical position (`perm[j]`) comes after the physical position of dimension `i`
+    /// (`perm[i]`).
+    fn permuted_stride(&self, i: usize, perm: &[usize]) -> usize {
+        let phys = perm[i];
+
+        // Each call scans the full permutation, making `strides()` O(ndim^2) overall. Tensor rank
+        // is typically small, so avoiding a Vec allocation is a net win.
+        perm.iter()
+            .enumerate()
+            .filter(|&(_, &p)| p > phys)
+            .map(|(l, _)| self.logical_shape[l])
+            .product::<usize>()
+    }
+
+    /// Returns the logical dimension index that maps to physical position `p`. This is the
+    /// inverse of the permutation: if `perm[i] == p`, returns `i`.
+    ///
+    /// Each call is a linear scan of `perm`, making callers that invoke this for every physical
+    /// position O(ndim^2) overall. Tensor rank is typically small (2–5), so avoiding a Vec
+    /// allocation for the full inverse permutation is a net win.
+    fn inverse_perm(perm: &[usize], p: usize) -> usize {
+        perm.iter()
+            .position(|&pi| pi == p)
+            .vortex_expect("permutation must contain every physical position exactly once")
+    }
+}
+
+impl fmt::Display for FixedShapeTensorMetadata {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "Tensor(")?;
+
+        match &self.dim_names {
+            Some(names) => {
+                for (i, (dim, name)) in self.logical_shape.iter().zip(names.iter()).enumerate() {
+                    if i > 0 {
+                        write!(f, ", ")?;
+                    }
+                    write!(f, "{name}: {dim}")?;
+                }
+            }
+            None => {
+                for (i, dim) in self.logical_shape.iter().enumerate() {
+                    if i > 0 {
+                        write!(f, ", ")?;
+                    }
+                    write!(f, "{dim}")?;
+                }
+            }
+        }
+
+        if let Some(perm) = &self.permutation {
+            // NOTE(review): the original wrote the closing "]" without ever emitting an
+            // opening bracket or a separator after the shape, yielding unbalanced output
+            // like `Tensor(3, 41, 0])`. Emit the opening delimiter so the output is
+            // balanced, e.g. `Tensor(3, 4; permutation=[1, 0])` — confirm the exact
+            // intended format against any upstream display convention.
+            write!(f, "; permutation=[")?;
+            for (i, p) in perm.iter().enumerate() {
+                if i > 0 {
+                    write!(f, ", ")?;
+                }
+                write!(f, "{p}")?;
+            }
+            write!(f, "]")?;
+        }
+
+        write!(f, ")")
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use rstest::rstest;
+
+    use super::*;
+
+    /// Reference implementation that computes permuted strides in an explicit, step-by-step way.
+    ///
+    /// 1. Build the physical shape: `physical_shape[perm[i]] = logical_shape[i]`.
+    /// 2. Compute row-major strides over the physical shape.
+    /// 3. Map back to logical: `logical_stride[i] = physical_strides[perm[i]]`.
+    fn slow_strides(shape: &[usize], perm: &[usize]) -> Vec<usize> {
+        let ndim = shape.len();
+
+        // Derive the physical shape from the logical shape and the permutation.
+        let mut physical_shape = vec![0usize; ndim];
+        for l in 0..ndim {
+            physical_shape[perm[l]] = shape[l];
+        }
+
+        // Compute row-major strides over the physical shape.
+        let mut physical_strides = vec![1usize; ndim];
+        for i in (0..ndim.saturating_sub(1)).rev() {
+            physical_strides[i] = physical_strides[i + 1] * physical_shape[i + 1];
+        }
+
+        // Map physical strides back to logical dimension order.
+        (0..ndim).map(|l| physical_strides[perm[l]]).collect()
+    }
+
+    // -- Row-major strides (no permutation) --
+
+    #[rstest]
+    #[case::scalar_0d(vec![], vec![])]
+    #[case::vector_1d(vec![5], vec![1])]
+    #[case::matrix_2d(vec![3, 4], vec![4, 1])]
+    #[case::tensor_3d(vec![2, 3, 4], vec![12, 4, 1])]
+    #[case::zero_dim( vec![3, 0, 4], vec![0, 4, 1])]
+    fn strides_row_major(#[case] shape: Vec<usize>, #[case] expected: Vec<usize>) {
+        let m = FixedShapeTensorMetadata::new(shape);
+        assert_eq!(m.strides().collect::<Vec<_>>(), expected);
+    }
+
+    // -- Permuted strides --
+    //
+    // Each case is checked against the expected value and cross-validated against the
+    // `slow_strides` reference implementation.
+
+    #[rstest]
+    // 2D transpose: physical shape = [4, 3].
+    #[case::transpose_2d(vec![3, 4], vec![1, 0], vec![1, 3])]
+    // 3D: physical shape = [3, 4, 2].
+    #[case::perm_3d_201( vec![2, 3, 4], vec![2, 0, 1], vec![1, 8, 2])]
+    // 3D with zero-sized dimension: physical shape = [4, 3, 0].
+    #[case::zero_dim(    vec![3, 0, 4], vec![1, 2, 0], vec![0, 1, 0])]
+    fn strides_permuted(
+        #[case] shape: Vec<usize>,
+        #[case] perm: Vec<usize>,
+        #[case] expected: Vec<usize>,
+    ) -> VortexResult<()> {
+        let m = FixedShapeTensorMetadata::new(shape.clone()).with_permutation(perm.clone())?;
+        let actual: Vec<usize> = m.strides().collect();
+        assert_eq!(actual, expected);
+        assert_eq!(actual, slow_strides(&shape, &perm));
+        Ok(())
+    }
+
+    #[test]
+    fn strides_identity_permutation_matches_row_major() -> VortexResult<()> {
+        let row_major = FixedShapeTensorMetadata::new(vec![2, 3, 4]);
+        let identity =
+            FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![0, 1, 2])?;
+        assert_eq!(
+            row_major.strides().collect::<Vec<_>>(),
+            identity.strides().collect::<Vec<_>>(),
+        );
+        Ok(())
+    }
+
+    /// Cross-validates the fast `permuted_stride` against the reference `slow_strides` across a
+    /// broader set of shapes and permutations.
+    #[rstest]
+    #[case::perm_3d_120(vec![2, 3, 4], vec![1, 2, 0])]
+    #[case::perm_3d_021(vec![2, 3, 4], vec![0, 2, 1])]
+    #[case::identity_3d(vec![2, 3, 4], vec![0, 1, 2])]
+    #[case::zero_lead(  vec![0, 3, 4], vec![2, 0, 1])]
+    #[case::rev_4d(     vec![2, 3, 4, 5], vec![3, 2, 1, 0])]
+    #[case::swap_4d(    vec![2, 3, 4, 5], vec![1, 0, 3, 2])]
+    #[case::half_4d(    vec![2, 3, 4, 5], vec![2, 3, 0, 1])]
+    fn strides_match_slow_reference(
+        #[case] shape: Vec<usize>,
+        #[case] perm: Vec<usize>,
+    ) -> VortexResult<()> {
+        let m = FixedShapeTensorMetadata::new(shape.clone()).with_permutation(perm.clone())?;
+        assert_eq!(m.strides().collect::<Vec<_>>(), slow_strides(&shape, &perm));
+        Ok(())
+    }
+
+    // -- Physical shape --
+
+    #[test]
+    fn physical_shape_no_permutation() {
+        let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]);
+        assert_eq!(m.physical_shape().collect::<Vec<_>>(), vec![2, 3, 4]);
+    }
+
+    #[rstest]
+    // Logical [3, 4] with perm [1, 0] → physical [4, 3].
+    #[case::transpose_2d(vec![3, 4], vec![1, 0], vec![4, 3])]
+    // Logical [2, 3, 4] with perm [2, 0, 1] → physical [3, 4, 2].
+    #[case::perm_3d(  vec![2, 3, 4], vec![2, 0, 1], vec![3, 4, 2])]
+    // Identity: physical = logical.
+    #[case::identity( vec![2, 3, 4], vec![0, 1, 2], vec![2, 3, 4])]
+    // Logical [3, 0, 4] with perm [1, 2, 0] → physical [4, 3, 0].
+    #[case::zero_dim( vec![3, 0, 4], vec![1, 2, 0], vec![4, 3, 0])]
+    fn physical_shape_permuted(
+        #[case] shape: Vec<usize>,
+        #[case] perm: Vec<usize>,
+        #[case] expected: Vec<usize>,
+    ) -> VortexResult<()> {
+        let m = FixedShapeTensorMetadata::new(shape).with_permutation(perm)?;
+        assert_eq!(m.physical_shape().collect::<Vec<_>>(), expected);
+        Ok(())
+    }
+
+    #[test]
+    fn dim_names_wrong_length() {
+        let result = FixedShapeTensorMetadata::new(vec![2, 3]).with_dim_names(vec!["x".into()]);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn permutation_wrong_length() {
+        let result = FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![0]);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn permutation_out_of_range() {
+        let result = FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![0, 5]);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn permutation_duplicate_index() {
+        let result = FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![0, 0]);
+        assert!(result.is_err());
+    }
+}
diff --git a/vortex-tensor/src/proto.rs b/vortex-tensor/src/proto.rs
new file mode 100644
index 00000000000..f454531fcca
--- /dev/null
+++ b/vortex-tensor/src/proto.rs
@@ -0,0 +1,90 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+//! Protobuf serialization for [`FixedShapeTensorMetadata`].
+
+use prost::Message;
+use vortex::error::VortexExpect;
+use vortex::error::VortexResult;
+use vortex::error::vortex_err;
+
+use crate::FixedShapeTensorMetadata;
+
+/// Protobuf representation of [`FixedShapeTensorMetadata`].
+///
+/// Protobuf does not distinguish between an absent repeated field and an empty one (both will
+/// deserialize as an empty `Vec`). This is fine because the semantic meaning is unambiguous:
+///
+/// - `logical_shape` empty: 0-dimensional (scalar) tensor.
+/// - `dim_names` empty: no dimension names (`None`).
+/// - `permutation` empty: no permutation, i.e., identity layout (`None`).
+#[derive(Clone, PartialEq, Message)]
+struct FixedShapeTensorMetadataProto {
+    /// The size of each logical dimension. Empty for a 0-dimensional scalar tensor.
+    #[prost(uint32, repeated, tag = "1")]
+    logical_shape: Vec<u32>,
+
+    /// Optional human-readable names for each logical dimension. When present, must have the
+    /// same length as `logical_shape`. Empty means no names are set.
+    #[prost(string, repeated, tag = "2")]
+    dim_names: Vec<String>,
+
+    /// Optional dimension permutation mapping logical to physical indices. When present, must
+    /// be a permutation of `[0, 1, ..., N-1]`. Empty means identity (row-major) layout.
+    #[prost(uint32, repeated, tag = "3")]
+    permutation: Vec<u32>,
+}
+
+/// Serializes [`FixedShapeTensorMetadata`] to protobuf bytes.
+pub(crate) fn serialize(metadata: &FixedShapeTensorMetadata) -> Vec<u8> {
+    let logical_shape = metadata
+        .logical_shape()
+        .iter()
+        .map(|&d| u32::try_from(d).vortex_expect("dimension size exceeds u32"))
+        .collect();
+
+    let dim_names = metadata.dim_names().map(|n| n.to_vec()).unwrap_or_default();
+
+    let permutation = metadata
+        .permutation()
+        .map(|p| {
+            p.iter()
+                .map(|&i| u32::try_from(i).vortex_expect("permutation index exceeds u32"))
+                .collect()
+        })
+        .unwrap_or_default();
+
+    let proto = FixedShapeTensorMetadataProto {
+        logical_shape,
+        dim_names,
+        permutation,
+    };
+    proto.encode_to_vec()
+}
+
+/// Deserializes [`FixedShapeTensorMetadata`] from protobuf bytes.
+///
+/// For 0-dimensional tensors, all three repeated fields are empty, which correctly produces a
+/// metadata with an empty shape and no names or permutation.
+pub(crate) fn deserialize(bytes: &[u8]) -> VortexResult<FixedShapeTensorMetadata> {
+    let proto = FixedShapeTensorMetadataProto::decode(bytes).map_err(|e| vortex_err!("{e}"))?;
+
+    let logical_shape = proto
+        .logical_shape
+        .into_iter()
+        .map(|d| d as usize)
+        .collect();
+    let mut m = FixedShapeTensorMetadata::new(logical_shape);
+
+    // Note that this is fine for 0 dimensions since if we do not have any dimensions, we cannot
+    // have any names or permutations.
+    if !proto.dim_names.is_empty() {
+        m = m.with_dim_names(proto.dim_names)?;
+    }
+    if !proto.permutation.is_empty() {
+        let permutation = proto.permutation.into_iter().map(|i| i as usize).collect();
+        m = m.with_permutation(permutation)?;
+    }
+
+    Ok(m)
+}
diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs
new file mode 100644
index 00000000000..1746e6a5a75
--- /dev/null
+++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs
@@ -0,0 +1,460 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+//! Cosine similarity expression for [`FixedShapeTensor`](crate::FixedShapeTensor) arrays.
+ +use std::fmt::Formatter; + +use num_traits::Float; +use vortex::array::ArrayRef; +use vortex::array::DynArray; +use vortex::array::ExecutionCtx; +use vortex::array::IntoArray; +use vortex::array::arrays::ConstantArray; +use vortex::array::arrays::ConstantVTable; +use vortex::array::arrays::ExtensionVTable; +use vortex::array::arrays::PrimitiveArray; +use vortex::array::match_each_float_ptype; +use vortex::dtype::DType; +use vortex::dtype::NativePType; +use vortex::dtype::Nullability; +use vortex::error::VortexResult; +use vortex::error::vortex_bail; +use vortex::error::vortex_ensure; +use vortex::error::vortex_err; +use vortex::expr::Expression; +use vortex::scalar_fn::Arity; +use vortex::scalar_fn::ChildName; +use vortex::scalar_fn::EmptyOptions; +use vortex::scalar_fn::ExecutionArgs; +use vortex::scalar_fn::ScalarFnId; +use vortex::scalar_fn::ScalarFnVTable; + +/// Cosine similarity between two [`FixedShapeTensor`] columns. +/// +/// Computes `dot(a, b) / (||a|| * ||b||)` over the flat backing buffer of each tensor. The +/// shape and permutation do not affect the result because cosine similarity only depends on the +/// element values, not their logical arrangement. +/// +/// Both inputs must be [`FixedShapeTensor`] extension arrays with the same dtype and a float +/// element type (`f32` or `f64`). The output is a primitive column of the same float type. 
+/// +/// [`FixedShapeTensor`]: crate::FixedShapeTensor +#[derive(Clone)] +pub struct CosineSimilarity; + +impl ScalarFnVTable for CosineSimilarity { + type Options = EmptyOptions; + + fn id(&self) -> ScalarFnId { + ScalarFnId::new_ref("vortex.cosine_similarity") + } + + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(2) + } + + fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName { + match child_idx { + 0 => ChildName::from("lhs"), + 1 => ChildName::from("rhs"), + _ => unreachable!("CosineSimilarity must have exactly two children"), + } + } + + fn fmt_sql( + &self, + _options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { + write!(f, "cosine_similarity(")?; + expr.child(0).fmt_sql(f)?; + write!(f, ", ")?; + expr.child(1).fmt_sql(f)?; + write!(f, ")") + } + + fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + debug_assert_eq!(arg_dtypes.len(), 2); + + let lhs = &arg_dtypes[0]; + let rhs = &arg_dtypes[1]; + + // Both must have the same dtype (ignoring top-level nullability). + vortex_ensure!( + lhs.eq_ignore_nullability(rhs), + "cosine_similarity requires both inputs to have the same dtype, got {lhs} and {rhs}" + ); + + // We don't need to look at rhs anymore since we know lhs and rhs are equal. + + // Both inputs must be extension types. + let lhs_ext = lhs.as_extension_opt().ok_or_else(|| { + vortex_err!("cosine_similarity lhs must be an extension type, got {lhs}") + })?; + + // Extract the element dtype from the storage FixedSizeList. + let element_dtype = lhs_ext + .storage_dtype() + .as_fixed_size_list_element_opt() + .ok_or_else(|| { + vortex_err!( + "cosine_similarity storage dtype must be a FixedSizeList, got {}", + lhs_ext.storage_dtype() + ) + })?; + + // Element dtype must be a non-nullable float primitive. 
+ vortex_ensure!( + element_dtype.is_float(), + "cosine_similarity element dtype must be a float primitive, got {element_dtype}" + ); + vortex_ensure!( + !element_dtype.is_nullable(), + "cosine_similarity element dtype must be non-nullable" + ); + + let ptype = element_dtype.as_ptype(); + let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable()); + Ok(DType::Primitive(ptype, nullability)) + } + + fn execute( + &self, + _options: &Self::Options, + args: &dyn ExecutionArgs, + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + let lhs = args.get(0)?; + let rhs = args.get(1)?; + let row_count = args.row_count(); + + // Get list size from the dtype. Both sides should have the same dtype. + let ext = lhs.dtype().as_extension_opt().ok_or_else(|| { + vortex_err!( + "cosine_similarity input must be an extension type, got {}", + lhs.dtype() + ) + })?; + let DType::FixedSizeList(_, list_size, _) = ext.storage_dtype() else { + vortex_bail!("expected FixedSizeList storage dtype"); + }; + let list_size = *list_size as usize; + + // Extract the storage array from each extension input. We pass the storage (FSL) rather + // than the extension array to avoid canonicalizing the extension wrapper. + let lhs_storage = extension_storage(&lhs)?; + let rhs_storage = extension_storage(&rhs)?; + + // Extract the flat primitive elements from each tensor column. When an input is a + // `ConstantArray` (e.g., a literal query vector), we materialize only a single row + // instead of expanding it to the full row count. 
+ let (lhs_elems, lhs_stride) = extract_flat_elements(&lhs_storage, list_size)?; + let (rhs_elems, rhs_stride) = extract_flat_elements(&rhs_storage, list_size)?; + + match_each_float_ptype!(lhs_elems.ptype(), |T| { + let lhs_slice = lhs_elems.as_slice::(); + let rhs_slice = rhs_elems.as_slice::(); + + let result: PrimitiveArray = (0..row_count) + .map(|i| { + let a = &lhs_slice[i * lhs_stride..i * lhs_stride + list_size]; + let b = &rhs_slice[i * rhs_stride..i * rhs_stride + list_size]; + cosine_similarity_row(a, b) + }) + .collect(); + + Ok(result.into_array()) + }) + } + + fn validity( + &self, + _options: &Self::Options, + expression: &Expression, + ) -> VortexResult> { + // The result is null if either input tensor is null. + let lhs_validity = expression.child(0).validity()?; + let rhs_validity = expression.child(1).validity()?; + + Ok(Some(vortex::expr::and(lhs_validity, rhs_validity))) + } + + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { + false + } + + fn is_fallible(&self, _options: &Self::Options) -> bool { + // Canonicalization of the storage arrays can fail. + true + } +} + +/// Extracts the storage array from an extension array without canonicalizing. +fn extension_storage(array: &ArrayRef) -> VortexResult { + let ext = array + .as_opt::() + .ok_or_else(|| vortex_err!("cosine_similarity input must be an extension array"))?; + Ok(ext.storage().clone()) +} + +/// Extracts the flat primitive elements from a tensor storage array (FixedSizeList). +/// +/// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is +/// materialized to avoid expanding it to the full column length. Returns `(elements, stride)` +/// where `stride` is `list_size` for a full array and `0` for a constant. 
+fn extract_flat_elements( + storage: &ArrayRef, + list_size: usize, +) -> VortexResult<(PrimitiveArray, usize)> { + if let Some(constant) = storage.as_opt::() { + // Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a + // huge amount of data. + let single = ConstantArray::new(constant.scalar().clone(), 1).into_array(); + let fsl = single.to_canonical()?.into_fixed_size_list(); + let elems = fsl.elements().to_canonical()?.into_primitive(); + Ok((elems, 0)) + } else { + // Otherwise we have to fully expand all of the data. + let fsl = storage.to_canonical()?.into_fixed_size_list(); + let elems = fsl.elements().to_canonical()?.into_primitive(); + Ok((elems, list_size)) + } +} + +// TODO(connor): We should try to use a more performant library instead of doing this ourselves. +/// Computes cosine similarity between two equal-length float slices. +/// +/// Returns `dot(a, b) / (||a|| * ||b||)`. When either vector has zero norm, this naturally +/// produces `NaN` via `0.0 / 0.0`, matching standard floating-point semantics. 
+/// Generic over any float type via `num_traits::Float` (the patch text had the generic
+/// parameter list stripped, leaving `T` unbound — restored here; num-traits is a declared
+/// dependency of this crate).
+fn cosine_similarity_row<T: num_traits::Float>(a: &[T], b: &[T]) -> T {
+    let mut dot = T::zero();
+    let mut norm_a = T::zero();
+    let mut norm_b = T::zero();
+    // Iterate in lockstep; callers always pass two `list_size`-length slices, so zip visits
+    // every element and avoids per-index bounds checks.
+    for (&x, &y) in a.iter().zip(b.iter()) {
+        dot = dot + x * y;
+        norm_a = norm_a + x * x;
+        norm_b = norm_b + y * y;
+    }
+    dot / (norm_a.sqrt() * norm_b.sqrt())
+}
+
+#[cfg(test)]
+mod tests {
+    use vortex::array::ArrayRef;
+    use vortex::array::IntoArray;
+    use vortex::array::ToCanonical;
+    use vortex::array::arrays::ConstantArray;
+    use vortex::array::arrays::ExtensionArray;
+    use vortex::array::arrays::FixedSizeListArray;
+    use vortex::array::arrays::ScalarFnArray;
+    use vortex::array::validity::Validity;
+    use vortex::buffer::Buffer;
+    use vortex::dtype::DType;
+    use vortex::dtype::Nullability;
+    use vortex::dtype::extension::ExtDType;
+    use vortex::error::VortexResult;
+    use vortex::scalar::Scalar;
+    use vortex::scalar_fn::EmptyOptions;
+    use vortex::scalar_fn::ScalarFn;
+
+    use crate::FixedShapeTensor;
+    use crate::FixedShapeTensorMetadata;
+    use crate::scalar_fns::cosine_similarity::CosineSimilarity;
+
+    /// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape.
+    ///
+    /// The number of rows is inferred from the total element count divided by the product of the
+    /// shape dimensions. For 0-dimensional tensors (scalar), each element is one row.
+    // NOTE(review): throughout the remainder of this patch, type arguments inside angle
+    // brackets appear stripped by extraction (e.g. `product::()`, `ExtDType::::try_new`,
+    // `as_slice::()`, bare `VortexResult`/`VortexResult>`, bare `Vec`). Presumably
+    // `product::<usize>()`, `ExtDType::<FixedShapeTensor>::try_new`, `as_slice::<f64>()`,
+    // `VortexResult<ArrayRef>`/`VortexResult<Vec<f64>>`, `Vec<Scalar>` — confirm against
+    // the committed source before relying on this text.
+    fn tensor_array(shape: &[usize], elements: &[f64]) -> VortexResult {
+        // `.max(1)` makes a 0-d (empty-shape) tensor occupy one element per row.
+        let list_size: u32 = shape.iter().product::().max(1).try_into().unwrap();
+        let row_count = elements.len() / list_size as usize;
+
+        let elems: ArrayRef = Buffer::copy_from(elements).into_array();
+        let fsl = FixedSizeListArray::new(elems, list_size, Validity::NonNullable, row_count);
+
+        let metadata = FixedShapeTensorMetadata::new(shape.to_vec());
+        let ext_dtype =
+            ExtDType::::try_new(metadata, fsl.dtype().clone())?.erased();
+
+        Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array())
+    }
+
+    /// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec`.
+    fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> {
+        let scalar_fn = ScalarFn::new(CosineSimilarity, EmptyOptions).erased();
+        let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], len)?;
+        let prim = result.to_primitive();
+        Ok(prim.as_slice::().to_vec())
+    }
+
+    /// Asserts that each element in `actual` is within `1e-10` of the corresponding `expected`
+    /// value, with support for NaN (NaN == NaN is considered equal).
+    #[track_caller]
+    fn assert_close(actual: &[f64], expected: &[f64]) {
+        assert_eq!(
+            actual.len(),
+            expected.len(),
+            "length mismatch: got {} elements, expected {}",
+            actual.len(),
+            expected.len()
+        );
+
+        // Absolute (not relative) tolerance; fine here since expected values are O(1).
+        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
+            if a.is_nan() && e.is_nan() {
+                continue;
+            }
+            assert!(
+                (a - e).abs() < 1e-10,
+                "element {i}: got {a}, expected {e} (diff = {})",
+                (a - e).abs()
+            );
+        }
+    }
+
+    #[test]
+    fn unit_vectors_1d() -> VortexResult<()> {
+        let lhs = tensor_array(
+            &[3],
+            &[
+                1.0, 0.0, 0.0, // Tensor 1
+                0.0, 1.0, 0.0, // Tensor 2
+            ],
+        )?;
+        let rhs = tensor_array(
+            &[3],
+            &[
+                1.0, 0.0, 0.0, // Tensor 1
+                1.0, 0.0, 0.0, // Tensor 2
+            ],
+        )?;
+
+        // Row 0: identical → 1.0, row 1: orthogonal → 0.0.
+        assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]);
+        Ok(())
+    }
+
+    // NOTE(review): this `use` sits mid-module; convention is to group imports at the top of
+    // the `mod tests` block alongside the others.
+    use rstest::rstest;
+
+    /// Single-row cosine similarity for various vector pairs.
+    #[rstest]
+    // Antiparallel → -1.0.
+    #[case::opposite(&[3], &[1.0, 0.0, 0.0], &[-1.0, 0.0, 0.0], &[-1.0])]
+    // dot=24, both magnitudes=5 → 24/25 = 0.96.
+    #[case::non_unit(&[2], &[3.0, 4.0], &[4.0, 3.0], &[0.96])]
+    // Zero vector → 0/0 → NaN.
+    #[case::zero_norm(&[2], &[0.0, 0.0], &[1.0, 0.0], &[f64::NAN])]
+    fn single_row(
+        #[case] shape: &[usize],
+        #[case] lhs_elems: &[f64],
+        #[case] rhs_elems: &[f64],
+        #[case] expected: &[f64],
+    ) -> VortexResult<()> {
+        let lhs = tensor_array(shape, lhs_elems)?;
+        let rhs = tensor_array(shape, rhs_elems)?;
+        assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, expected);
+        Ok(())
+    }
+
+    /// Self-similarity across various tensor shapes should always produce 1.0.
+    #[rstest]
+    // 2x3 matrix, flattened to 6 elements.
+    #[case::matrix_2d(
+        &[2, 3],
+        &[
+            1.0, 0.0, 0.0, // row 0
+            0.0, 0.0, 0.0, // row 1
+        ],
+    )]
+    // 2x2x2 tensor, 8 elements.
+    #[case::tensor_3d(&[2, 2, 2], &[1.0; 8])]
+    fn self_similarity(#[case] shape: &[usize], #[case] elements: &[f64]) -> VortexResult<()> {
+        let lhs = tensor_array(shape, elements)?;
+        let rhs = tensor_array(shape, elements)?;
+        assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[1.0]);
+        Ok(())
+    }
+
+    #[test]
+    fn scalar_0d() -> VortexResult<()> {
+        // 0-dimensional tensor: each "tensor" is a single scalar value.
+        let lhs = tensor_array(&[], &[5.0, 3.0])?;
+        let rhs = tensor_array(&[], &[5.0, -3.0])?;
+
+        // Same sign → 1.0, opposite sign → -1.0.
+        assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, -1.0]);
+        Ok(())
+    }
+
+    #[test]
+    fn many_rows() -> VortexResult<()> {
+        // 5 tensors of shape [4] compared against themselves → all 1.0.
+        let lhs = tensor_array(
+            &[4],
+            &[
+                1.0, 2.0, 3.0, 4.0, // tensor 0
+                0.0, 1.0, 0.0, 0.0, // tensor 1
+                5.0, 0.0, 5.0, 0.0, // tensor 2
+                1.0, 1.0, 1.0, 1.0, // tensor 3
+                0.0, 0.0, 0.0, 7.0, // tensor 4
+            ],
+        )?;
+        let rhs = lhs.clone();
+
+        assert_close(
+            &eval_cosine_similarity(lhs, rhs, 5)?,
+            &[1.0, 1.0, 1.0, 1.0, 1.0],
+        );
+        Ok(())
+    }
+
+    /// Builds an extension array whose storage is a [`ConstantArray`], representing a single
+    /// query tensor broadcast to `len` rows.
+    fn constant_tensor_array(
+        shape: &[usize],
+        elements: &[f64],
+        len: usize,
+    ) -> VortexResult {
+        let element_dtype = DType::Primitive(vortex::dtype::PType::F64, Nullability::NonNullable);
+
+        // Build the FSL storage scalar from individual element scalars.
+        let children: Vec = elements
+            .iter()
+            .map(|&v| Scalar::primitive(v, Nullability::NonNullable))
+            .collect();
+        let storage_scalar =
+            Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable);
+
+        // Wrap the FSL scalar in a ConstantArray to avoid materializing `len` copies.
+        let storage = ConstantArray::new(storage_scalar, len).into_array();
+
+        let metadata = FixedShapeTensorMetadata::new(shape.to_vec());
+        let ext_dtype =
+            ExtDType::::try_new(metadata, storage.dtype().clone())?.erased();
+
+        Ok(ExtensionArray::new(ext_dtype, storage).into_array())
+    }
+
+    #[test]
+    fn constant_query_vector() -> VortexResult<()> {
+        // Compare 4 tensors of shape [3] against a single constant query tensor [1,0,0].
+        let data = tensor_array(
+            &[3],
+            &[
+                1.0, 0.0, 0.0, // tensor 0
+                0.0, 1.0, 0.0, // tensor 1
+                0.0, 0.0, 1.0, // tensor 2
+                1.0, 0.0, 0.0, // tensor 3
+            ],
+        )?;
+        let query = constant_tensor_array(&[3], &[1.0, 0.0, 0.0], 4)?;
+
+        // Tensors 0 and 3 are aligned with the query; tensors 1 and 2 are orthogonal.
+        assert_close(
+            &eval_cosine_similarity(data, query, 4)?,
+            &[1.0, 0.0, 0.0, 1.0],
+        );
+        Ok(())
+    }
+}
diff --git a/vortex-tensor/src/scalar_fns/mod.rs b/vortex-tensor/src/scalar_fns/mod.rs
new file mode 100644
index 00000000000..2797589e03f
--- /dev/null
+++ b/vortex-tensor/src/scalar_fns/mod.rs
@@ -0,0 +1,4 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+pub mod cosine_similarity;
diff --git a/vortex-tensor/src/vtable.rs b/vortex-tensor/src/vtable.rs
new file mode 100644
index 00000000000..ecec816516a
--- /dev/null
+++ b/vortex-tensor/src/vtable.rs
@@ -0,0 +1,121 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+use vortex::dtype::DType;
+use vortex::dtype::extension::ExtId;
+use vortex::dtype::extension::ExtVTable;
+use vortex::error::VortexResult;
+use vortex::error::vortex_bail;
+use vortex::error::vortex_ensure;
+use vortex::error::vortex_ensure_eq;
+use vortex::scalar::ScalarValue;
+
+use crate::FixedShapeTensor;
+use crate::FixedShapeTensorMetadata;
+use crate::proto;
+
+impl ExtVTable for FixedShapeTensor {
+    type Metadata = FixedShapeTensorMetadata;
+
+    // TODO(connor): This is just a placeholder for now!!!
+    type NativeValue<'a> = &'a ScalarValue;
+
+    fn id(&self) -> ExtId {
+        ExtId::new_ref("vortex.fixed_shape_tensor")
+    }
+
+    // Metadata is (de)serialized through the crate's protobuf helpers.
+    fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult> {
+        Ok(proto::serialize(metadata))
+    }
+
+    fn deserialize_metadata(&self, metadata: &[u8]) -> VortexResult {
+        proto::deserialize(metadata)
+    }
+
+    /// Validates that the storage dtype is a non-nullable-element primitive FixedSizeList whose
+    /// size equals the product of the logical shape dimensions.
+    fn validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &DType) -> VortexResult<()> {
+        let DType::FixedSizeList(element_dtype, list_size, _nullability) = storage_dtype else {
+            vortex_bail!(
+                "FixedShapeTensor storage dtype must be a FixedSizeList, got {storage_dtype}"
+            );
+        };
+
+        // Note that these constraints may be relaxed in the future.
+        vortex_ensure!(
+            element_dtype.is_primitive(),
+            "FixedShapeTensor element dtype must be primitive, got {element_dtype} \
+            (may change in the future)"
+        );
+        vortex_ensure!(
+            !element_dtype.is_nullable(),
+            "FixedShapeTensor element dtype must be non-nullable (may change in the future)"
+        );
+
+        // An empty logical shape (0-d tensor) has product 1, so the list size must be 1.
+        let element_count: usize = metadata.logical_shape().iter().product();
+        vortex_ensure_eq!(
+            element_count,
+            *list_size as usize,
+            "FixedShapeTensor logical shape product ({element_count}) does not match \
+            FixedSizeList size ({list_size})"
+        );
+
+        Ok(())
+    }
+
+    fn unpack_native<'a>(
+        &self,
+        _metadata: &'a Self::Metadata,
+        _storage_dtype: &'a DType,
+        storage_value: &'a ScalarValue,
+    ) -> VortexResult> {
+        // TODO(connor): This is just a placeholder. However, even if we have a dedicated native
+        // type for a singular tensor, we do not need to validate anything as any backing memory
+        // should be valid for a given tensor.
+        Ok(storage_value)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use rstest::rstest;
+    use vortex::dtype::extension::ExtVTable;
+    use vortex::error::VortexResult;
+
+    use crate::FixedShapeTensor;
+    use crate::FixedShapeTensorMetadata;
+
+    /// Serializes and deserializes the given metadata through protobuf, asserting equality.
+    fn assert_roundtrip(metadata: &FixedShapeTensorMetadata) -> VortexResult<()> {
+        let vtable = FixedShapeTensor;
+        let bytes = vtable.serialize_metadata(metadata)?;
+        let deserialized = vtable.deserialize_metadata(&bytes)?;
+        assert_eq!(&deserialized, metadata);
+        Ok(())
+    }
+
+    #[rstest]
+    #[case::scalar_0d(FixedShapeTensorMetadata::new(vec![]))]
+    #[case::shape_only(FixedShapeTensorMetadata::new(vec![2, 3, 4]))]
+    fn roundtrip_simple(#[case] metadata: FixedShapeTensorMetadata) -> VortexResult<()> {
+        assert_roundtrip(&metadata)
+    }
+
+    // These builder methods (`with_permutation`, `with_dim_names`) are fallible — each case
+    // hands a `VortexResult` to the test, which unwraps via `?` before the roundtrip.
+    #[rstest]
+    #[case::with_permutation(
+        FixedShapeTensorMetadata::new(vec![2, 3, 4])
+            .with_permutation(vec![2, 0, 1])
+    )]
+    #[case::with_dim_names(
+        FixedShapeTensorMetadata::new(vec![3, 4])
+            .with_dim_names(vec!["rows".into(), "cols".into()])
+    )]
+    #[case::all_fields(
+        FixedShapeTensorMetadata::new(vec![2, 3, 4])
+            .with_dim_names(vec!["x".into(), "y".into(), "z".into()])
+            .and_then(|m| m.with_permutation(vec![1, 2, 0]))
+    )]
+    fn roundtrip_with_options(
+        #[case] metadata: VortexResult,
+    ) -> VortexResult<()> {
+        assert_roundtrip(&metadata?)
+    }
+}