From d27370cc1db892a50fb4dd8f7678bbdae38db2ac Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 5 Mar 2026 10:20:53 -0500 Subject: [PATCH 01/14] add vortex-tensor crate Signed-off-by: Connor Tsui --- Cargo.lock | 9 ++ Cargo.toml | 2 + vortex-tensor/Cargo.toml | 22 ++++ vortex-tensor/src/lib.rs | 31 +++++ vortex-tensor/src/metadata.rs | 225 ++++++++++++++++++++++++++++++++++ vortex-tensor/src/vtable.rs | 43 +++++++ 6 files changed, 332 insertions(+) create mode 100644 vortex-tensor/Cargo.toml create mode 100644 vortex-tensor/src/lib.rs create mode 100644 vortex-tensor/src/metadata.rs create mode 100644 vortex-tensor/src/vtable.rs diff --git a/Cargo.lock b/Cargo.lock index e4fd2307b9f..f42d19d37c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10444,6 +10444,15 @@ dependencies = [ "vortex-duckdb", ] +[[package]] +name = "vortex-tensor" +version = "0.1.0" +dependencies = [ + "itertools 0.14.0", + "vortex-array", + "vortex-error", +] + [[package]] name = "vortex-test-e2e" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 5bbc2ebab02..f5c28544599 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ members = [ "vortex-io", "vortex-proto", "vortex-array", + "vortex-tensor", "vortex-btrblocks", "vortex-layout", "vortex-scan", @@ -271,6 +272,7 @@ vortex-scan = { version = "0.1.0", path = "./vortex-scan", default-features = fa vortex-sequence = { version = "0.1.0", path = "encodings/sequence", default-features = false } vortex-session = { version = "0.1.0", path = "./vortex-session", default-features = false } vortex-sparse = { version = "0.1.0", path = "./encodings/sparse", default-features = false } +vortex-tensor = { version = "0.1.0", path = "./vortex-tensor", default-features = false } vortex-utils = { version = "0.1.0", path = "./vortex-utils", default-features = false } vortex-zigzag = { version = "0.1.0", path = "./encodings/zigzag", default-features = false } vortex-zstd = { version = "0.1.0", path = "./encodings/zstd", default-features = false } 
diff --git a/vortex-tensor/Cargo.toml b/vortex-tensor/Cargo.toml new file mode 100644 index 00000000000..cbb7665860a --- /dev/null +++ b/vortex-tensor/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "vortex-tensor" +authors = { workspace = true } +categories = { workspace = true } +description = "Vortex tensor extension type" +edition = { workspace = true } +homepage = { workspace = true } +include = { workspace = true } +keywords = { workspace = true } +license = { workspace = true } +readme = "README.md" +repository = { workspace = true } +rust-version = { workspace = true } +version = { workspace = true } + +[lints] +workspace = true + +[dependencies] +itertools = { workspace = true } +vortex-array = { workspace = true } +vortex-error = { workspace = true } diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs new file mode 100644 index 00000000000..8a9f893e201 --- /dev/null +++ b/vortex-tensor/src/lib.rs @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Tensor extension type. + +use vortex_array::dtype::DType; +use vortex_array::dtype::extension::ExtDType; +use vortex_error::VortexResult; + +mod metadata; +pub use metadata::FixedShapeTensorMetadata; + +mod vtable; + +/// The VTable for the Tensor extension type. +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +pub struct FixedShapeTensor; + +impl FixedShapeTensor { + /// Creates a new [`Tensor`] extension type. + /// + /// TODO docs. + pub fn new( + metadata: FixedShapeTensorMetadata, + dtype: DType, + ) -> VortexResult> { + // TODO verify that the dtype matches the metadata. 
+ + ExtDType::try_new(metadata, dtype) + } +} diff --git a/vortex-tensor/src/metadata.rs b/vortex-tensor/src/metadata.rs new file mode 100644 index 00000000000..780b3cf4ca4 --- /dev/null +++ b/vortex-tensor/src/metadata.rs @@ -0,0 +1,225 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::fmt; + +use itertools::Either; + +/// Metadata for a [`Tensor`] extension type. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct FixedShapeTensorMetadata { + /// The shape of the tensor. + /// + /// The shape is always defined over row-major storage. May be empty (0D scalar tensor) or + /// contain dimensions of size 0 (degenerate tensor). + shape: Vec, + + /// Optional names for each dimension. Each name corresponds to a dimension in the `shape`. + /// + /// If names exist, there must be an equal number of names to dimensions. + dim_names: Option>, + + /// The permutation of the tensor's dimensions, mapping each logical dimension to its + /// corresponding physical dimension: `permutation[logical] = physical`. + /// + /// If this is `None`, then the logical and physical layout are equal, and the permutation is + /// in-order `[0, 1, ..., N-1]`. + permutation: Option>, +} + +impl FixedShapeTensorMetadata { + /// Creates a new [`FixedShapeTensorMetadata`] with the given `shape`. + /// + /// The shape defines the logical dimensions in row-major order. Use + /// [`with_dim_names`][Self::with_dim_names] and [`with_permutation`][Self::with_permutation] + /// to further configure the metadata. + pub fn new(shape: Vec) -> Self { + Self { + shape, + dim_names: None, + permutation: None, + } + } + + /// Sets the dimension names for this tensor. + pub fn with_dim_names(mut self, names: Vec) -> Self { + self.dim_names = Some(names); + self + } + + /// Sets the permutation for this tensor. 
+ pub fn with_permutation(mut self, permutation: Vec) -> Self { + self.permutation = Some(permutation); + self + } + + /// Returns the dimensions of the tensor as a slice. + pub fn dimensions(&self) -> &[usize] { + &self.shape + } + + /// Returns an iterator over the strides for each logical dimension of the tensor. + /// + /// The stride for each dimension is the number of elements to skip in the flat backing array + /// in order to move one step along that dimension. + /// + /// When a permutation is present, the physical memory is laid out in row-major order over the + /// permuted dimensions, so the strides are computed accordingly. + pub fn strides(&self) -> impl Iterator + '_ { + let ndim = self.shape.len(); + let permutation = self.permutation.as_deref(); + + match permutation { + None => Either::Left((0..ndim).map(|i| self.shape[i + 1..].iter().product::())), + Some(permutation) => { + Either::Right((0..ndim).map(|i| self.permuted_stride(i, permutation))) + } + } + } + + /// Computes the stride for logical dimension `i` given a `permutation`. + /// + /// The stride is the product of `shape[l]` for all logical dimensions `l` whose physical + /// position comes after `perm[i]`. + fn permuted_stride(&self, i: usize, perm: &[usize]) -> usize { + let phys = perm[i]; + + // Note that this is O(n^2), but since the number of dimensions is likely low and doing this + // avoids allocations, this is usually much faster. 
+ perm.iter() + .enumerate() + .filter(|&(_, &p)| p > phys) + .map(|(l, _)| self.shape[l]) + .product::() + } +} + +impl fmt::Display for FixedShapeTensorMetadata { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.dim_names { + Some(names) => { + write!(f, "Tensor(")?; + for (i, (dim, name)) in self.shape.iter().zip(names.iter()).enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{name}: {dim}")?; + } + write!(f, ")") + } + None => { + write!(f, "Tensor(")?; + for (i, dim) in self.shape.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{dim}")?; + } + write!(f, ")") + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn strides_1d() { + let m = FixedShapeTensorMetadata::new(vec![5]); + assert_eq!(m.strides().collect::>(), vec![1]); + } + + #[test] + fn strides_2d_row_major() { + // Logical shape [3, 4], no permutation. + // Physical shape is [3, 4], physical strides are [4, 1]. + let m = FixedShapeTensorMetadata::new(vec![3, 4]); + assert_eq!(m.strides().collect::>(), vec![4, 1]); + } + + #[test] + fn strides_3d_row_major() { + // Logical shape [2, 3, 4], no permutation. + // Physical shape is [2, 3, 4], physical strides are [12, 4, 1]. + let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]); + assert_eq!(m.strides().collect::>(), vec![12, 4, 1]); + } + + #[test] + fn strides_2d_transposed() { + // Logical shape [3, 4], permutation [1, 0]. + // perm[logical] = physical: logical 0 -> phys 1, logical 1 -> phys 0. + // Physical shape (arrange logical sizes by phys pos): + // phys 0 <- logical 1 (size 4) + // phys 1 <- logical 0 (size 3) + // => physical shape [4, 3], physical strides [3, 1]. 
+ // Logical strides = phys_stride[perm[i]]: + // logical 0 -> phys 1 -> stride 1 + // logical 1 -> phys 0 -> stride 3 + let m = FixedShapeTensorMetadata::new(vec![3, 4]).with_permutation(vec![1, 0]); + assert_eq!(m.strides().collect::>(), vec![1, 3]); + } + + #[test] + fn strides_3d_permuted() { + // Logical shape [2, 3, 4], permutation [2, 0, 1]. + // perm[logical] = physical: logical 0 -> phys 2, logical 1 -> phys 0, logical 2 -> phys 1. + // Physical shape (arrange logical sizes by phys pos): + // phys 0 <- logical 1 (size 3) + // phys 1 <- logical 2 (size 4) + // phys 2 <- logical 0 (size 2) + // => physical shape [3, 4, 2], physical strides [8, 2, 1]. + // Logical strides = phys_stride[perm[i]]: + // logical 0 -> phys 2 -> stride 1 + // logical 1 -> phys 0 -> stride 8 + // logical 2 -> phys 1 -> stride 2 + let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![2, 0, 1]); + assert_eq!(m.strides().collect::>(), vec![1, 8, 2]); + } + + #[test] + fn strides_0d_scalar() { + // A 0D tensor (scalar) has no dimensions and thus no strides. + // numel = 1 (empty product), not 0. + let m = FixedShapeTensorMetadata::new(vec![]); + assert_eq!(m.strides().collect::>(), Vec::::new()); + } + + #[test] + fn strides_zero_size_dimension() { + // Logical shape [3, 0, 4], no permutation. numel = 0. + // Physical strides are still well-defined products of trailing dimensions. + let m = FixedShapeTensorMetadata::new(vec![3, 0, 4]); + assert_eq!(m.strides().collect::>(), vec![0, 4, 1]); + } + + #[test] + fn strides_zero_size_dimension_permuted() { + // Logical shape [3, 0, 4], permutation [1, 2, 0]. + // perm[logical] = physical: logical 0 -> phys 1, logical 1 -> phys 2, logical 2 -> phys 0. + // Physical shape (arrange logical sizes by phys pos): + // phys 0 <- logical 2 (size 4) + // phys 1 <- logical 0 (size 3) + // phys 2 <- logical 1 (size 0) + // => physical shape [4, 3, 0], physical strides [0, 0, 1]. 
+ // Logical strides = phys_stride[perm[i]]: + // logical 0 -> phys 1 -> stride 0 + // logical 1 -> phys 2 -> stride 1 + // logical 2 -> phys 0 -> stride 0 + let m = FixedShapeTensorMetadata::new(vec![3, 0, 4]).with_permutation(vec![1, 2, 0]); + assert_eq!(m.strides().collect::>(), vec![0, 1, 0]); + } + + #[test] + fn strides_identity_permutation_matches_row_major() { + // An identity permutation should produce the same strides as no permutation. + let row_major = FixedShapeTensorMetadata::new(vec![2, 3, 4]); + let identity = FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![0, 1, 2]); + assert_eq!( + row_major.strides().collect::>(), + identity.strides().collect::>(), + ); + } +} diff --git a/vortex-tensor/src/vtable.rs b/vortex-tensor/src/vtable.rs new file mode 100644 index 00000000000..0d4c4d94d60 --- /dev/null +++ b/vortex-tensor/src/vtable.rs @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::dtype::DType; +use vortex_array::dtype::extension::ExtId; +use vortex_array::dtype::extension::ExtVTable; +use vortex_array::scalar::ScalarValue; +use vortex_error::VortexResult; + +use crate::FixedShapeTensor; +use crate::FixedShapeTensorMetadata; + +impl ExtVTable for FixedShapeTensor { + type Metadata = FixedShapeTensorMetadata; + + // TODO(connor): This is just a placeholder for now!!! 
+ type NativeValue<'a> = ScalarValue; + + fn id(&self) -> ExtId { + ExtId::new_ref("vortex.tensor") + } + + fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult> { + todo!() + } + + fn deserialize_metadata(&self, metadata: &[u8]) -> VortexResult { + todo!() + } + + fn validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &DType) -> VortexResult<()> { + todo!() + } + + fn unpack_native<'a>( + &self, + metadata: &'a Self::Metadata, + storage_dtype: &'a DType, + storage_value: &'a ScalarValue, + ) -> VortexResult> { + todo!() + } +} From 18a28a5f7f4e379d052cc7e171d53add311c4e90 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 5 Mar 2026 11:34:53 -0500 Subject: [PATCH 02/14] fix metadata logic Signed-off-by: Connor Tsui --- vortex-tensor/src/metadata.rs | 245 ++++++++++++++++++++++++---------- 1 file changed, 174 insertions(+), 71 deletions(-) diff --git a/vortex-tensor/src/metadata.rs b/vortex-tensor/src/metadata.rs index 780b3cf4ca4..8c108da9af3 100644 --- a/vortex-tensor/src/metadata.rs +++ b/vortex-tensor/src/metadata.rs @@ -4,38 +4,42 @@ use std::fmt; use itertools::Either; +use vortex_error::VortexExpect; -/// Metadata for a [`Tensor`] extension type. +/// Metadata for a `FixedShapeTensor` extension type. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct FixedShapeTensorMetadata { - /// The shape of the tensor. + /// The logical shape of the tensor. /// - /// The shape is always defined over row-major storage. May be empty (0D scalar tensor) or - /// contain dimensions of size 0 (degenerate tensor). - shape: Vec, + /// `logical_shape[i]` is the size of the `i`-th logical dimension. When a `permutation` is + /// present, the physical shape (i.e., the row-major memory layout) is derived as + /// `physical_shape[permutation[i]] = logical_shape[i]`. + /// + /// May be empty (0D scalar tensor) or contain dimensions of size 0 (degenerate tensor). + logical_shape: Vec, - /// Optional names for each dimension. 
Each name corresponds to a dimension in the `shape`. + /// Optional names for each logical dimension. Each name corresponds to an entry in + /// `logical_shape`. /// - /// If names exist, there must be an equal number of names to dimensions. + /// If names exist, there must be an equal number of names to logical dimensions. dim_names: Option>, - /// The permutation of the tensor's dimensions, mapping each logical dimension to its - /// corresponding physical dimension: `permutation[logical] = physical`. + /// The permutation of the tensor's dimensions. `permutation[i]` is the physical dimension + /// index that logical dimension `i` maps to. /// - /// If this is `None`, then the logical and physical layout are equal, and the permutation is - /// in-order `[0, 1, ..., N-1]`. + /// If this is `None`, then the logical and physical layouts are identical, equivalent to + /// the identity permutation `[0, 1, ..., N-1]`. permutation: Option>, } impl FixedShapeTensorMetadata { - /// Creates a new [`FixedShapeTensorMetadata`] with the given `shape`. + /// Creates a new [`FixedShapeTensorMetadata`] with the given logical `shape`. /// - /// The shape defines the logical dimensions in row-major order. Use - /// [`with_dim_names`][Self::with_dim_names] and [`with_permutation`][Self::with_permutation] - /// to further configure the metadata. + /// Use [`with_dim_names`](Self::with_dim_names) and + /// [`with_permutation`](Self::with_permutation) to further configure the metadata. pub fn new(shape: Vec) -> Self { Self { - shape, + logical_shape: shape, dim_names: None, permutation: None, } @@ -53,24 +57,51 @@ impl FixedShapeTensorMetadata { self } - /// Returns the dimensions of the tensor as a slice. - pub fn dimensions(&self) -> &[usize] { - &self.shape + /// Returns the number of dimensions (rank) of the tensor. + pub fn ndim(&self) -> usize { + self.logical_shape.len() + } + + /// Returns the logical dimensions of the tensor as a slice. 
+ pub fn logical_shape(&self) -> &[usize] { + &self.logical_shape + } + + /// Returns an iterator over the physical shape of the tensor. + /// + /// The physical shape describes the row-major memory layout. It is derived from the logical + /// shape by placing each logical dimension's size at its physical position: + /// `physical_shape[permutation[i]] = logical_shape[i]`. + /// + /// When no permutation is present, the physical shape is identical to the logical shape. + pub fn physical_shape(&self) -> impl Iterator + '_ { + let ndim = self.logical_shape.len(); + let permutation = self.permutation.as_deref(); + + match permutation { + None => Either::Left(self.logical_shape.iter().copied()), + Some(perm) => Either::Right( + (0..ndim).map(move |p| self.logical_shape[Self::inverse_perm(perm, p)]), + ), + } } /// Returns an iterator over the strides for each logical dimension of the tensor. /// - /// The stride for each dimension is the number of elements to skip in the flat backing array - /// in order to move one step along that dimension. + /// The stride for a logical dimension is the number of elements to skip in the flat backing + /// array in order to move one step along that logical dimension. /// /// When a permutation is present, the physical memory is laid out in row-major order over the - /// permuted dimensions, so the strides are computed accordingly. + /// physical dimensions (the logical dimensions reordered by the permutation), so the strides + /// are derived from that physical layout. 
pub fn strides(&self) -> impl Iterator + '_ { - let ndim = self.shape.len(); + let ndim = self.logical_shape.len(); let permutation = self.permutation.as_deref(); match permutation { - None => Either::Left((0..ndim).map(|i| self.shape[i + 1..].iter().product::())), + None => Either::Left( + (0..ndim).map(|i| self.logical_shape[i + 1..].iter().product::()), + ), Some(permutation) => { Either::Right((0..ndim).map(|i| self.permuted_stride(i, permutation))) } @@ -79,19 +110,32 @@ impl FixedShapeTensorMetadata { /// Computes the stride for logical dimension `i` given a `permutation`. /// - /// The stride is the product of `shape[l]` for all logical dimensions `l` whose physical - /// position comes after `perm[i]`. + /// The stride is the product of `logical_shape[j]` for all logical dimensions `j` whose + /// physical position (`perm[j]`) comes after the physical position of dimension `i` + /// (`perm[i]`). fn permuted_stride(&self, i: usize, perm: &[usize]) -> usize { let phys = perm[i]; - // Note that this is O(n^2), but since the number of dimensions is likely low and doing this - // avoids allocations, this is usually much faster. + // Each call scans the full permutation, making `strides()` O(ndim^2) overall. Tensor rank + // is typically small (2–5), so avoiding a Vec allocation is a net win. perm.iter() .enumerate() .filter(|&(_, &p)| p > phys) - .map(|(l, _)| self.shape[l]) + .map(|(l, _)| self.logical_shape[l]) .product::() } + + /// Returns the logical dimension index that maps to physical position `p`. This is the + /// inverse of the permutation: if `perm[i] == p`, returns `i`. + /// + /// Each call is a linear scan of `perm`, making callers that invoke this for every physical + /// position O(ndim^2) overall. Tensor rank is typically small (2–5), so avoiding a Vec + /// allocation for the full inverse permutation is a net win. 
+ fn inverse_perm(perm: &[usize], p: usize) -> usize { + perm.iter() + .position(|&pi| pi == p) + .vortex_expect("permutation must contain every physical position exactly once") + } } impl fmt::Display for FixedShapeTensorMetadata { @@ -99,7 +143,7 @@ impl fmt::Display for FixedShapeTensorMetadata { match &self.dim_names { Some(names) => { write!(f, "Tensor(")?; - for (i, (dim, name)) in self.shape.iter().zip(names.iter()).enumerate() { + for (i, (dim, name)) in self.logical_shape.iter().zip(names.iter()).enumerate() { if i > 0 { write!(f, ", ")?; } @@ -109,7 +153,7 @@ impl fmt::Display for FixedShapeTensorMetadata { } None => { write!(f, "Tensor(")?; - for (i, dim) in self.shape.iter().enumerate() { + for (i, dim) in self.logical_shape.iter().enumerate() { if i > 0 { write!(f, ", ")?; } @@ -125,6 +169,30 @@ impl fmt::Display for FixedShapeTensorMetadata { mod tests { use super::*; + /// Reference implementation that computes permuted strides in an explicit, step-by-step way. + /// + /// 1. Build the physical shape: `physical_shape[perm[i]] = logical_shape[i]`. + /// 2. Compute row-major strides over the physical shape. + /// 3. Map back to logical: `logical_stride[i] = physical_strides[perm[i]]`. + fn slow_strides(shape: &[usize], perm: &[usize]) -> Vec { + let ndim = shape.len(); + + // Derive the physical shape from the logical shape and the permutation. + let mut physical_shape = vec![0usize; ndim]; + for l in 0..ndim { + physical_shape[perm[l]] = shape[l]; + } + + // Compute row-major strides over the physical shape. + let mut physical_strides = vec![1usize; ndim]; + for i in (0..ndim.saturating_sub(1)).rev() { + physical_strides[i] = physical_strides[i + 1] * physical_shape[i + 1]; + } + + // Map physical strides back to logical dimension order. 
+ (0..ndim).map(|l| physical_strides[perm[l]]).collect() + } + #[test] fn strides_1d() { let m = FixedShapeTensorMetadata::new(vec![5]); @@ -133,88 +201,54 @@ mod tests { #[test] fn strides_2d_row_major() { - // Logical shape [3, 4], no permutation. - // Physical shape is [3, 4], physical strides are [4, 1]. let m = FixedShapeTensorMetadata::new(vec![3, 4]); assert_eq!(m.strides().collect::>(), vec![4, 1]); } #[test] fn strides_3d_row_major() { - // Logical shape [2, 3, 4], no permutation. - // Physical shape is [2, 3, 4], physical strides are [12, 4, 1]. let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]); assert_eq!(m.strides().collect::>(), vec![12, 4, 1]); } #[test] fn strides_2d_transposed() { - // Logical shape [3, 4], permutation [1, 0]. - // perm[logical] = physical: logical 0 -> phys 1, logical 1 -> phys 0. - // Physical shape (arrange logical sizes by phys pos): - // phys 0 <- logical 1 (size 4) - // phys 1 <- logical 0 (size 3) - // => physical shape [4, 3], physical strides [3, 1]. - // Logical strides = phys_stride[perm[i]]: - // logical 0 -> phys 1 -> stride 1 - // logical 1 -> phys 0 -> stride 3 + // Logical shape [3, 4] with perm [1, 0] (transpose). + // Physical shape = [4, 3], so logical strides = [1, 3]. let m = FixedShapeTensorMetadata::new(vec![3, 4]).with_permutation(vec![1, 0]); assert_eq!(m.strides().collect::>(), vec![1, 3]); } #[test] fn strides_3d_permuted() { - // Logical shape [2, 3, 4], permutation [2, 0, 1]. - // perm[logical] = physical: logical 0 -> phys 2, logical 1 -> phys 0, logical 2 -> phys 1. - // Physical shape (arrange logical sizes by phys pos): - // phys 0 <- logical 1 (size 3) - // phys 1 <- logical 2 (size 4) - // phys 2 <- logical 0 (size 2) - // => physical shape [3, 4, 2], physical strides [8, 2, 1]. - // Logical strides = phys_stride[perm[i]]: - // logical 0 -> phys 2 -> stride 1 - // logical 1 -> phys 0 -> stride 8 - // logical 2 -> phys 1 -> stride 2 + // Logical shape [2, 3, 4] with perm [2, 0, 1]. 
+ // Physical shape = [3, 4, 2], so logical strides = [1, 8, 2]. let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![2, 0, 1]); assert_eq!(m.strides().collect::>(), vec![1, 8, 2]); } #[test] fn strides_0d_scalar() { - // A 0D tensor (scalar) has no dimensions and thus no strides. - // numel = 1 (empty product), not 0. let m = FixedShapeTensorMetadata::new(vec![]); assert_eq!(m.strides().collect::>(), Vec::::new()); } #[test] fn strides_zero_size_dimension() { - // Logical shape [3, 0, 4], no permutation. numel = 0. - // Physical strides are still well-defined products of trailing dimensions. let m = FixedShapeTensorMetadata::new(vec![3, 0, 4]); assert_eq!(m.strides().collect::>(), vec![0, 4, 1]); } #[test] fn strides_zero_size_dimension_permuted() { - // Logical shape [3, 0, 4], permutation [1, 2, 0]. - // perm[logical] = physical: logical 0 -> phys 1, logical 1 -> phys 2, logical 2 -> phys 0. - // Physical shape (arrange logical sizes by phys pos): - // phys 0 <- logical 2 (size 4) - // phys 1 <- logical 0 (size 3) - // phys 2 <- logical 1 (size 0) - // => physical shape [4, 3, 0], physical strides [0, 0, 1]. - // Logical strides = phys_stride[perm[i]]: - // logical 0 -> phys 1 -> stride 0 - // logical 1 -> phys 2 -> stride 1 - // logical 2 -> phys 0 -> stride 0 + // Logical shape [3, 0, 4] with perm [1, 2, 0]. + // Physical shape = [4, 3, 0], so logical strides = [0, 1, 0]. let m = FixedShapeTensorMetadata::new(vec![3, 0, 4]).with_permutation(vec![1, 2, 0]); assert_eq!(m.strides().collect::>(), vec![0, 1, 0]); } #[test] fn strides_identity_permutation_matches_row_major() { - // An identity permutation should produce the same strides as no permutation. 
let row_major = FixedShapeTensorMetadata::new(vec![2, 3, 4]); let identity = FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![0, 1, 2]); assert_eq!( @@ -222,4 +256,73 @@ mod tests { identity.strides().collect::>(), ); } + + #[test] + fn physical_shape_no_permutation() { + let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]); + assert_eq!(m.physical_shape().collect::>(), vec![2, 3, 4]); + } + + #[test] + fn physical_shape_2d_transposed() { + // Logical [3, 4] with perm [1, 0]: physical dim 0 gets logical dim 1's size (4), + // physical dim 1 gets logical dim 0's size (3). + let m = FixedShapeTensorMetadata::new(vec![3, 4]).with_permutation(vec![1, 0]); + assert_eq!(m.physical_shape().collect::>(), vec![4, 3]); + } + + #[test] + fn physical_shape_3d_permuted() { + // Logical [2, 3, 4] with perm [2, 0, 1]: logical 0 -> phys 2, logical 1 -> phys 0, + // logical 2 -> phys 1. So physical shape = [3, 4, 2]. + let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![2, 0, 1]); + assert_eq!(m.physical_shape().collect::>(), vec![3, 4, 2]); + } + + #[test] + fn physical_shape_identity_permutation() { + let no_perm = FixedShapeTensorMetadata::new(vec![2, 3, 4]); + let identity = FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![0, 1, 2]); + assert_eq!( + no_perm.physical_shape().collect::>(), + identity.physical_shape().collect::>(), + ); + } + + #[test] + fn physical_shape_zero_size_dimension() { + // Logical [3, 0, 4] with perm [1, 2, 0]: physical shape = [4, 3, 0]. + let m = FixedShapeTensorMetadata::new(vec![3, 0, 4]).with_permutation(vec![1, 2, 0]); + assert_eq!(m.physical_shape().collect::>(), vec![4, 3, 0]); + } + + /// Verifies that the fast `permuted_stride` matches the explicit reference `slow_strides` + /// across a variety of shapes and permutations. + #[test] + fn fast_strides_match_slow_reference() { + let cases: Vec<(Vec, Vec)> = vec![ + // 2D transpose. + (vec![3, 4], vec![1, 0]), + // 3D permutations. 
+ (vec![2, 3, 4], vec![2, 0, 1]), + (vec![2, 3, 4], vec![1, 2, 0]), + (vec![2, 3, 4], vec![0, 2, 1]), + // 3D identity. + (vec![2, 3, 4], vec![0, 1, 2]), + // 3D with a zero-sized dimension. + (vec![3, 0, 4], vec![1, 2, 0]), + (vec![0, 3, 4], vec![2, 0, 1]), + // 4D permutations. + (vec![2, 3, 4, 5], vec![3, 2, 1, 0]), + (vec![2, 3, 4, 5], vec![1, 0, 3, 2]), + (vec![2, 3, 4, 5], vec![2, 3, 0, 1]), + ]; + + for (shape, perm) in &cases { + let m = FixedShapeTensorMetadata::new(shape.clone()).with_permutation(perm.clone()); + let fast: Vec = m.strides().collect(); + let slow = slow_strides(shape, perm); + assert_eq!(fast, slow, "mismatch for shape={shape:?}, perm={perm:?}"); + } + } } From d175bcccee3d8ab82b86695bcff978f3a749c27c Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 5 Mar 2026 12:08:08 -0500 Subject: [PATCH 03/14] use vortex import Signed-off-by: Connor Tsui --- Cargo.lock | 5 ++-- vortex-tensor/Cargo.toml | 6 ++-- vortex-tensor/src/lib.rs | 18 ------------ vortex-tensor/src/metadata.rs | 2 +- vortex-tensor/src/vtable.rs | 55 +++++++++++++++++++++++++++-------- 5 files changed, 51 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f42d19d37c1..42d4fcb4352 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10449,8 +10449,9 @@ name = "vortex-tensor" version = "0.1.0" dependencies = [ "itertools 0.14.0", - "vortex-array", - "vortex-error", + "serde", + "serde_json", + "vortex", ] [[package]] diff --git a/vortex-tensor/Cargo.toml b/vortex-tensor/Cargo.toml index cbb7665860a..bdfbc2be534 100644 --- a/vortex-tensor/Cargo.toml +++ b/vortex-tensor/Cargo.toml @@ -17,6 +17,8 @@ version = { workspace = true } workspace = true [dependencies] +vortex = { workspace = true } + itertools = { workspace = true } -vortex-array = { workspace = true } -vortex-error = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 
8a9f893e201..4c52ecda8f3 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -3,10 +3,6 @@ //! Tensor extension type. -use vortex_array::dtype::DType; -use vortex_array::dtype::extension::ExtDType; -use vortex_error::VortexResult; - mod metadata; pub use metadata::FixedShapeTensorMetadata; @@ -15,17 +11,3 @@ mod vtable; /// The VTable for the Tensor extension type. #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] pub struct FixedShapeTensor; - -impl FixedShapeTensor { - /// Creates a new [`Tensor`] extension type. - /// - /// TODO docs. - pub fn new( - metadata: FixedShapeTensorMetadata, - dtype: DType, - ) -> VortexResult> { - // TODO verify that the dtype matches the metadata. - - ExtDType::try_new(metadata, dtype) - } -} diff --git a/vortex-tensor/src/metadata.rs b/vortex-tensor/src/metadata.rs index 8c108da9af3..fa537dfb2e7 100644 --- a/vortex-tensor/src/metadata.rs +++ b/vortex-tensor/src/metadata.rs @@ -4,7 +4,7 @@ use std::fmt; use itertools::Either; -use vortex_error::VortexExpect; +use vortex::error::VortexExpect; /// Metadata for a `FixedShapeTensor` extension type. 
#[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/vortex-tensor/src/vtable.rs b/vortex-tensor/src/vtable.rs index 0d4c4d94d60..fc9451573ed 100644 --- a/vortex-tensor/src/vtable.rs +++ b/vortex-tensor/src/vtable.rs @@ -1,11 +1,14 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use vortex_array::dtype::DType; -use vortex_array::dtype::extension::ExtId; -use vortex_array::dtype::extension::ExtVTable; -use vortex_array::scalar::ScalarValue; -use vortex_error::VortexResult; +use vortex::dtype::DType; +use vortex::dtype::extension::ExtId; +use vortex::dtype::extension::ExtVTable; +use vortex::error::VortexResult; +use vortex::error::vortex_bail; +use vortex::error::vortex_ensure; +use vortex::error::vortex_ensure_eq; +use vortex::scalar::ScalarValue; use crate::FixedShapeTensor; use crate::FixedShapeTensorMetadata; @@ -14,30 +17,58 @@ impl ExtVTable for FixedShapeTensor { type Metadata = FixedShapeTensorMetadata; // TODO(connor): This is just a placeholder for now!!! - type NativeValue<'a> = ScalarValue; + type NativeValue<'a> = &'a ScalarValue; fn id(&self) -> ExtId { ExtId::new_ref("vortex.tensor") } - fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult> { + fn serialize_metadata(&self, _metadata: &Self::Metadata) -> VortexResult> { todo!() } - fn deserialize_metadata(&self, metadata: &[u8]) -> VortexResult { + fn deserialize_metadata(&self, _metadata: &[u8]) -> VortexResult { todo!() } fn validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &DType) -> VortexResult<()> { - todo!() + let DType::FixedSizeList(element_dtype, list_size, _nullability) = storage_dtype else { + vortex_bail!( + "FixedShapeTensor storage dtype must be a FixedSizeList, got {storage_dtype}" + ); + }; + + // Note that these constraints may be relaxed in the future. 
+ vortex_ensure!( + element_dtype.is_primitive(), + "FixedShapeTensor element dtype must be primitive, got {element_dtype} \ + (may change in the future)" + ); + vortex_ensure!( + !element_dtype.is_nullable(), + "FixedShapeTensor element dtype must be non-nullable (may change in the future)" + ); + + let element_count: usize = metadata.logical_shape().iter().product(); + vortex_ensure_eq!( + element_count, + *list_size as usize, + "FixedShapeTensor logical shape product ({element_count}) does not match \ + FixedSizeList size ({list_size})" + ); + + Ok(()) } fn unpack_native<'a>( &self, - metadata: &'a Self::Metadata, - storage_dtype: &'a DType, + _metadata: &'a Self::Metadata, + _storage_dtype: &'a DType, storage_value: &'a ScalarValue, ) -> VortexResult> { - todo!() + // TODO(connor): This is just a placeholder. However, even if we have a dedicated native + // type for a singular tensor, we do not need to validate anything as any backing memory + // should be valid for a given tensor. + Ok(storage_value) } } From 43e7c6c00547acb8516cd140c28538bd7920a422 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 5 Mar 2026 13:03:25 -0500 Subject: [PATCH 04/14] add proto for tensor metadata Signed-off-by: Connor Tsui --- Cargo.lock | 3 +- vortex-tensor/Cargo.toml | 3 +- vortex-tensor/src/lib.rs | 1 + vortex-tensor/src/metadata.rs | 10 +++++ vortex-tensor/src/proto.rs | 71 +++++++++++++++++++++++++++++++++++ vortex-tensor/src/vtable.rs | 62 +++++++++++++++++++++++++++--- 6 files changed, 141 insertions(+), 9 deletions(-) create mode 100644 vortex-tensor/src/proto.rs diff --git a/Cargo.lock b/Cargo.lock index 42d4fcb4352..5005e4e197d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10449,8 +10449,7 @@ name = "vortex-tensor" version = "0.1.0" dependencies = [ "itertools 0.14.0", - "serde", - "serde_json", + "prost 0.14.3", "vortex", ] diff --git a/vortex-tensor/Cargo.toml b/vortex-tensor/Cargo.toml index bdfbc2be534..6d257c3e204 100644 --- a/vortex-tensor/Cargo.toml +++ 
b/vortex-tensor/Cargo.toml @@ -20,5 +20,4 @@ workspace = true vortex = { workspace = true } itertools = { workspace = true } -serde = { workspace = true, features = ["derive"] } -serde_json = { workspace = true } +prost = { workspace = true } diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 4c52ecda8f3..f9dd9cebcf3 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -6,6 +6,7 @@ mod metadata; pub use metadata::FixedShapeTensorMetadata; +mod proto; mod vtable; /// The VTable for the Tensor extension type. diff --git a/vortex-tensor/src/metadata.rs b/vortex-tensor/src/metadata.rs index fa537dfb2e7..bd974f5479e 100644 --- a/vortex-tensor/src/metadata.rs +++ b/vortex-tensor/src/metadata.rs @@ -67,6 +67,16 @@ impl FixedShapeTensorMetadata { &self.logical_shape } + /// Returns the dimension names, if set. + pub fn dim_names(&self) -> Option<&[String]> { + self.dim_names.as_deref() + } + + /// Returns the permutation, if set. + pub fn permutation(&self) -> Option<&[usize]> { + self.permutation.as_deref() + } + /// Returns an iterator over the physical shape of the tensor. /// /// The physical shape describes the row-major memory layout. It is derived from the logical diff --git a/vortex-tensor/src/proto.rs b/vortex-tensor/src/proto.rs new file mode 100644 index 00000000000..f00e6d9e3fb --- /dev/null +++ b/vortex-tensor/src/proto.rs @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Protobuf serialization for [`FixedShapeTensorMetadata`]. + +use prost::Message; +use vortex::error::VortexExpect; +use vortex::error::VortexResult; +use vortex::error::vortex_err; + +use crate::FixedShapeTensorMetadata; + +/// Protobuf representation of [`FixedShapeTensorMetadata`]. 
+#[derive(Clone, PartialEq, Message)] +struct FixedShapeTensorMetadataProto { + #[prost(uint32, repeated, tag = "1")] + logical_shape: Vec, + #[prost(string, repeated, tag = "2")] + dim_names: Vec, + #[prost(uint32, repeated, tag = "3")] + permutation: Vec, +} + +/// Serializes [`FixedShapeTensorMetadata`] to protobuf bytes. +pub(crate) fn serialize(metadata: &FixedShapeTensorMetadata) -> Vec { + let logical_shape = metadata + .logical_shape() + .iter() + .map(|&d| u32::try_from(d).vortex_expect("dimension size exceeds u32")) + .collect(); + + let dim_names = metadata.dim_names().map(|n| n.to_vec()).unwrap_or_default(); + + let permutation = metadata + .permutation() + .map(|p| { + p.iter() + .map(|&i| u32::try_from(i).vortex_expect("permutation index exceeds u32")) + .collect() + }) + .unwrap_or_default(); + + let proto = FixedShapeTensorMetadataProto { + logical_shape, + dim_names, + permutation, + }; + proto.encode_to_vec() +} + +/// Deserializes [`FixedShapeTensorMetadata`] from protobuf bytes. 
+pub(crate) fn deserialize(bytes: &[u8]) -> VortexResult { + let proto = FixedShapeTensorMetadataProto::decode(bytes).map_err(|e| vortex_err!("{e}"))?; + + let logical_shape = proto + .logical_shape + .into_iter() + .map(|d| d as usize) + .collect(); + let mut m = FixedShapeTensorMetadata::new(logical_shape); + + if !proto.dim_names.is_empty() { + m = m.with_dim_names(proto.dim_names); + } + if !proto.permutation.is_empty() { + let permutation = proto.permutation.into_iter().map(|i| i as usize).collect(); + m = m.with_permutation(permutation); + } + + Ok(m) +} diff --git a/vortex-tensor/src/vtable.rs b/vortex-tensor/src/vtable.rs index fc9451573ed..5d45450cb40 100644 --- a/vortex-tensor/src/vtable.rs +++ b/vortex-tensor/src/vtable.rs @@ -12,6 +12,7 @@ use vortex::scalar::ScalarValue; use crate::FixedShapeTensor; use crate::FixedShapeTensorMetadata; +use crate::proto; impl ExtVTable for FixedShapeTensor { type Metadata = FixedShapeTensorMetadata; @@ -20,15 +21,15 @@ impl ExtVTable for FixedShapeTensor { type NativeValue<'a> = &'a ScalarValue; fn id(&self) -> ExtId { - ExtId::new_ref("vortex.tensor") + ExtId::new_ref("vortex.fixedshapetensor") } - fn serialize_metadata(&self, _metadata: &Self::Metadata) -> VortexResult> { - todo!() + fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult> { + Ok(proto::serialize(metadata)) } - fn deserialize_metadata(&self, _metadata: &[u8]) -> VortexResult { - todo!() + fn deserialize_metadata(&self, metadata: &[u8]) -> VortexResult { + proto::deserialize(metadata) } fn validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &DType) -> VortexResult<()> { @@ -72,3 +73,54 @@ impl ExtVTable for FixedShapeTensor { Ok(storage_value) } } + +#[cfg(test)] +mod tests { + use vortex::dtype::extension::ExtVTable; + use vortex::error::VortexResult; + + use crate::FixedShapeTensor; + use crate::FixedShapeTensorMetadata; + + fn assert_roundtrip(metadata: &FixedShapeTensorMetadata) -> VortexResult<()> { + let vtable = 
FixedShapeTensor; + let bytes = vtable.serialize_metadata(metadata)?; + let deserialized = vtable.deserialize_metadata(&bytes)?; + assert_eq!(&deserialized, metadata); + Ok(()) + } + + #[test] + fn roundtrip_shape_only() -> VortexResult<()> { + assert_roundtrip(&FixedShapeTensorMetadata::new(vec![2, 3, 4])) + } + + #[test] + fn roundtrip_with_permutation() -> VortexResult<()> { + assert_roundtrip( + &FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![2, 0, 1]), + ) + } + + #[test] + fn roundtrip_with_dim_names() -> VortexResult<()> { + assert_roundtrip( + &FixedShapeTensorMetadata::new(vec![3, 4]) + .with_dim_names(vec!["rows".into(), "cols".into()]), + ) + } + + #[test] + fn roundtrip_all_fields() -> VortexResult<()> { + assert_roundtrip( + &FixedShapeTensorMetadata::new(vec![2, 3, 4]) + .with_dim_names(vec!["x".into(), "y".into(), "z".into()]) + .with_permutation(vec![1, 2, 0]), + ) + } + + #[test] + fn roundtrip_scalar_0d() -> VortexResult<()> { + assert_roundtrip(&FixedShapeTensorMetadata::new(vec![])) + } +} From 6537a8619ba65bda8e67dd0cae7113b38535c263 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 5 Mar 2026 14:08:13 -0500 Subject: [PATCH 05/14] add a cosine similarity expression stub Signed-off-by: Connor Tsui --- vortex-tensor/src/cosine_similarity.rs | 86 ++++++++++++++++++++++++++ vortex-tensor/src/lib.rs | 3 + 2 files changed, 89 insertions(+) create mode 100644 vortex-tensor/src/cosine_similarity.rs diff --git a/vortex-tensor/src/cosine_similarity.rs b/vortex-tensor/src/cosine_similarity.rs new file mode 100644 index 00000000000..4f9647f8c7c --- /dev/null +++ b/vortex-tensor/src/cosine_similarity.rs @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Cosine similarity expression for [`FixedShapeTensor`] arrays. 
+ +use std::fmt::Formatter; + +use vortex::array::ArrayRef; +use vortex::array::ExecutionCtx; +use vortex::dtype::DType; +use vortex::error::VortexResult; +use vortex::expr::Expression; +use vortex::scalar_fn::Arity; +use vortex::scalar_fn::ChildName; +use vortex::scalar_fn::EmptyOptions; +use vortex::scalar_fn::ExecutionArgs; +use vortex::scalar_fn::ScalarFnId; +use vortex::scalar_fn::ScalarFnVTable; + +/// Cosine similarity between two [`FixedShapeTensor`] columns. +/// +/// Computes `dot(a, b) / (||a|| * ||b||)` over the flat backing buffer of each tensor. The +/// shape and permutation do not affect the result because cosine similarity only depends on the +/// element values, not their logical arrangement. +/// +/// Both inputs must be [`FixedShapeTensor`] extension arrays with the same dtype and a float +/// element type (`f32` or `f64`). The output is a primitive column of the same float type. +/// +/// [`FixedShapeTensor`]: crate::FixedShapeTensor +#[derive(Clone)] +pub struct CosineSimilarity; + +impl ScalarFnVTable for CosineSimilarity { + type Options = EmptyOptions; + + fn id(&self) -> ScalarFnId { + ScalarFnId::new_ref("vortex.cosine_similarity") + } + + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(2) + } + + fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName { + match child_idx { + 0 => ChildName::from("lhs"), + 1 => ChildName::from("rhs"), + _ => unreachable!("CosineSimilarity has exactly two children"), + } + } + + fn fmt_sql( + &self, + _options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { + write!(f, "cosine_similarity(")?; + expr.child(0).fmt_sql(f)?; + write!(f, ", ")?; + expr.child(1).fmt_sql(f)?; + write!(f, ")") + } + + fn return_dtype(&self, _options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult { + todo!("return_dtype") + } + + fn execute( + &self, + _options: &Self::Options, + _args: &dyn ExecutionArgs, + _ctx: &mut ExecutionCtx, + ) -> 
VortexResult { + todo!("execute") + } + + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { + false + } + + fn is_fallible(&self, _options: &Self::Options) -> bool { + // TODO(connor): Is this correct since we need to canonicalize? + false + } +} diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index f9dd9cebcf3..35626b7a6bb 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -6,6 +6,9 @@ mod metadata; pub use metadata::FixedShapeTensorMetadata; +mod cosine_similarity; +pub use cosine_similarity::CosineSimilarity; + mod proto; mod vtable; From c43165f755ab7a492ad2f58764a8a245acbd8ec1 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 5 Mar 2026 14:14:40 -0500 Subject: [PATCH 06/14] add validity reduction Signed-off-by: Connor Tsui --- vortex-tensor/src/cosine_similarity.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/vortex-tensor/src/cosine_similarity.rs b/vortex-tensor/src/cosine_similarity.rs index 4f9647f8c7c..3f99f246af3 100644 --- a/vortex-tensor/src/cosine_similarity.rs +++ b/vortex-tensor/src/cosine_similarity.rs @@ -75,6 +75,18 @@ impl ScalarFnVTable for CosineSimilarity { todo!("execute") } + fn validity( + &self, + _options: &Self::Options, + expression: &Expression, + ) -> VortexResult> { + // The result is null if either input tensor is null. 
+ let lhs_validity = expression.child(0).validity()?; + let rhs_validity = expression.child(1).validity()?; + + Ok(Some(vortex::expr::and(lhs_validity, rhs_validity))) + } + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { false } From 4eba73e8f18581bcb584782ca6aa1221134eb330 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 5 Mar 2026 14:32:17 -0500 Subject: [PATCH 07/14] implement return_dtype Signed-off-by: Connor Tsui --- vortex-tensor/src/cosine_similarity.rs | 50 ++++++++++++++++++++++++-- vortex-tensor/src/vtable.rs | 2 +- 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/vortex-tensor/src/cosine_similarity.rs b/vortex-tensor/src/cosine_similarity.rs index 3f99f246af3..8156204b229 100644 --- a/vortex-tensor/src/cosine_similarity.rs +++ b/vortex-tensor/src/cosine_similarity.rs @@ -8,7 +8,10 @@ use std::fmt::Formatter; use vortex::array::ArrayRef; use vortex::array::ExecutionCtx; use vortex::dtype::DType; +use vortex::dtype::Nullability; use vortex::error::VortexResult; +use vortex::error::vortex_ensure; +use vortex::error::vortex_err; use vortex::expr::Expression; use vortex::scalar_fn::Arity; use vortex::scalar_fn::ChildName; @@ -45,7 +48,7 @@ impl ScalarFnVTable for CosineSimilarity { match child_idx { 0 => ChildName::from("lhs"), 1 => ChildName::from("rhs"), - _ => unreachable!("CosineSimilarity has exactly two children"), + _ => unreachable!("CosineSimilarity must have exactly two children"), } } @@ -62,8 +65,49 @@ impl ScalarFnVTable for CosineSimilarity { write!(f, ")") } - fn return_dtype(&self, _options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult { - todo!("return_dtype") + fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + debug_assert_eq!(arg_dtypes.len(), 2); + + let lhs = &arg_dtypes[0]; + let rhs = &arg_dtypes[1]; + + // Both must have the same dtype (ignoring top-level nullability). 
+ vortex_ensure!( + lhs.eq_ignore_nullability(rhs), + "cosine_similarity requires both inputs to have the same dtype, got {lhs} and {rhs}" + ); + + // We don't need to look at rhs anymore since we know lhs and rhs are equal. + + // Both inputs must be extension types. + let lhs_ext = lhs.as_extension_opt().ok_or_else(|| { + vortex_err!("cosine_similarity lhs must be an extension type, got {lhs}") + })?; + + // Extract the element dtype from the storage FixedSizeList. + let element_dtype = lhs_ext + .storage_dtype() + .as_fixed_size_list_element_opt() + .ok_or_else(|| { + vortex_err!( + "cosine_similarity storage dtype must be a FixedSizeList, got {}", + lhs_ext.storage_dtype() + ) + })?; + + // Element dtype must be a non-nullable float primitive. + vortex_ensure!( + element_dtype.is_float(), + "cosine_similarity element dtype must be a float primitive, got {element_dtype}" + ); + vortex_ensure!( + !element_dtype.is_nullable(), + "cosine_similarity element dtype must be non-nullable" + ); + + let ptype = element_dtype.as_ptype(); + let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable()); + Ok(DType::Primitive(ptype, nullability)) } fn execute( diff --git a/vortex-tensor/src/vtable.rs b/vortex-tensor/src/vtable.rs index 5d45450cb40..17f296cb526 100644 --- a/vortex-tensor/src/vtable.rs +++ b/vortex-tensor/src/vtable.rs @@ -21,7 +21,7 @@ impl ExtVTable for FixedShapeTensor { type NativeValue<'a> = &'a ScalarValue; fn id(&self) -> ExtId { - ExtId::new_ref("vortex.fixedshapetensor") + ExtId::new_ref("vortex.fixed_shape_tensor") } fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult> { From 0c01516628b9aa0b4129988a72ef3e0a94a8acf9 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 5 Mar 2026 15:09:24 -0500 Subject: [PATCH 08/14] implement execute Signed-off-by: Connor Tsui --- Cargo.lock | 1 + vortex-tensor/Cargo.toml | 1 + vortex-tensor/src/cosine_similarity.rs | 98 +++++++++++++++++++++++++- 3 files changed, 98 insertions(+), 2 
deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5005e4e197d..7c71b24a18e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10449,6 +10449,7 @@ name = "vortex-tensor" version = "0.1.0" dependencies = [ "itertools 0.14.0", + "num-traits", "prost 0.14.3", "vortex", ] diff --git a/vortex-tensor/Cargo.toml b/vortex-tensor/Cargo.toml index 6d257c3e204..3bb47455167 100644 --- a/vortex-tensor/Cargo.toml +++ b/vortex-tensor/Cargo.toml @@ -20,4 +20,5 @@ workspace = true vortex = { workspace = true } itertools = { workspace = true } +num-traits = { workspace = true } prost = { workspace = true } diff --git a/vortex-tensor/src/cosine_similarity.rs b/vortex-tensor/src/cosine_similarity.rs index 8156204b229..4918519fe77 100644 --- a/vortex-tensor/src/cosine_similarity.rs +++ b/vortex-tensor/src/cosine_similarity.rs @@ -5,11 +5,21 @@ use std::fmt::Formatter; +use num_traits::Float; use vortex::array::ArrayRef; use vortex::array::ExecutionCtx; +use vortex::array::IntoArray; +use vortex::array::ToCanonical; +use vortex::array::arrays::ConstantArray; +use vortex::array::arrays::ConstantVTable; +use vortex::array::arrays::ExtensionVTable; +use vortex::array::arrays::PrimitiveArray; +use vortex::array::match_each_float_ptype; use vortex::dtype::DType; +use vortex::dtype::NativePType; use vortex::dtype::Nullability; use vortex::error::VortexResult; +use vortex::error::vortex_bail; use vortex::error::vortex_ensure; use vortex::error::vortex_err; use vortex::expr::Expression; @@ -113,10 +123,50 @@ impl ScalarFnVTable for CosineSimilarity { fn execute( &self, _options: &Self::Options, - _args: &dyn ExecutionArgs, + args: &dyn ExecutionArgs, _ctx: &mut ExecutionCtx, ) -> VortexResult { - todo!("execute") + let lhs = args.get(0)?; + let rhs = args.get(1)?; + let row_count = args.row_count(); + + // Get list size from the dtype. Both sides should have the same dtype. 
+ let ext = lhs.dtype().as_extension_opt().ok_or_else(|| { + vortex_err!( + "cosine_similarity input must be an extension type, got {}", + lhs.dtype() + ) + })?; + let DType::FixedSizeList(_, list_size, _) = ext.storage_dtype() else { + vortex_bail!("expected FixedSizeList storage dtype"); + }; + let list_size = *list_size as usize; + + // Extract the storage array from each extension input. We pass the storage (FSL) rather + // than the extension array to avoid canonicalizing the extension wrapper. + let lhs_storage = extension_storage(&lhs)?; + let rhs_storage = extension_storage(&rhs)?; + + // Extract the flat primitive elements from each tensor column. When an input is a + // `ConstantArray` (e.g., a literal query vector), we materialize only a single row + // instead of expanding it to the full row count. + let (lhs_elems, lhs_stride) = extract_flat_elements(&lhs_storage, list_size); + let (rhs_elems, rhs_stride) = extract_flat_elements(&rhs_storage, list_size); + + match_each_float_ptype!(lhs_elems.ptype(), |T| { + let lhs_slice = lhs_elems.as_slice::(); + let rhs_slice = rhs_elems.as_slice::(); + + let result: PrimitiveArray = (0..row_count) + .map(|i| { + let a = &lhs_slice[i * lhs_stride..i * lhs_stride + list_size]; + let b = &rhs_slice[i * rhs_stride..i * rhs_stride + list_size]; + cosine_similarity_row(a, b) + }) + .collect(); + + Ok(result.into_array()) + }) } fn validity( @@ -140,3 +190,47 @@ impl ScalarFnVTable for CosineSimilarity { false } } + +/// Extracts the storage array from an extension array without canonicalizing. +fn extension_storage(array: &ArrayRef) -> VortexResult { + let ext = array + .as_opt::() + .ok_or_else(|| vortex_err!("cosine_similarity input must be an extension array"))?; + Ok(ext.storage().clone()) +} + +/// Extracts the flat primitive elements from a tensor storage array (FixedSizeList). 
+/// +/// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is +/// materialized to avoid expanding it to the full column length. Returns `(elements, stride)` +/// where `stride` is `list_size` for a full array and `0` for a constant. +fn extract_flat_elements(storage: &ArrayRef, list_size: usize) -> (PrimitiveArray, usize) { + if let Some(constant) = storage.as_opt::() { + // Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a + // huge amount of data. + let single = ConstantArray::new(constant.scalar().clone(), 1).into_array(); + let elems = single.to_fixed_size_list().elements().to_primitive(); + (elems, 0) + } else { + // Otherwise we have to fully expand all of the data. + let elems = storage.to_fixed_size_list().elements().to_primitive(); + (elems, list_size) + } +} + +// TODO(connor): We should try to use a more performant library instead of doing this ourselves. +/// Computes cosine similarity between two equal-length float slices. +/// +/// Returns `dot(a, b) / (||a|| * ||b||)`. When either vector has zero norm, this naturally +/// produces `NaN` via `0.0 / 0.0`, matching standard floating-point semantics. 
+fn cosine_similarity_row(a: &[T], b: &[T]) -> T { + let mut dot = T::zero(); + let mut norm_a = T::zero(); + let mut norm_b = T::zero(); + for i in 0..a.len() { + dot = dot + a[i] * b[i]; + norm_a = norm_a + a[i] * a[i]; + norm_b = norm_b + b[i] * b[i]; + } + dot / (norm_a.sqrt() * norm_b.sqrt()) +} From 57f397462af7f10bb9fabceaf18e4a93a6aa3b0e Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 5 Mar 2026 15:35:43 -0500 Subject: [PATCH 09/14] add simple tests Signed-off-by: Connor Tsui --- vortex-tensor/src/cosine_similarity.rs | 242 +++++++++++++++++++++++++ 1 file changed, 242 insertions(+) diff --git a/vortex-tensor/src/cosine_similarity.rs b/vortex-tensor/src/cosine_similarity.rs index 4918519fe77..f3fa4fb9f23 100644 --- a/vortex-tensor/src/cosine_similarity.rs +++ b/vortex-tensor/src/cosine_similarity.rs @@ -234,3 +234,245 @@ fn cosine_similarity_row(a: &[T], b: &[T]) -> T { } dot / (norm_a.sqrt() * norm_b.sqrt()) } + +#[cfg(test)] +mod tests { + use vortex::array::ArrayRef; + use vortex::array::IntoArray; + use vortex::array::ToCanonical; + use vortex::array::arrays::ConstantArray; + use vortex::array::arrays::ExtensionArray; + use vortex::array::arrays::FixedSizeListArray; + use vortex::array::arrays::ScalarFnArray; + use vortex::array::validity::Validity; + use vortex::buffer::Buffer; + use vortex::dtype::DType; + use vortex::dtype::Nullability; + use vortex::dtype::extension::ExtDType; + use vortex::error::VortexResult; + use vortex::scalar::Scalar; + use vortex::scalar_fn::EmptyOptions; + use vortex::scalar_fn::ScalarFn; + + use crate::CosineSimilarity; + use crate::FixedShapeTensor; + use crate::FixedShapeTensorMetadata; + + /// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape. + /// + /// The number of rows is inferred from the total element count divided by the product of the + /// shape dimensions. For 0-dimensional tensors (scalar), each element is one row. 
+ fn tensor_array(shape: &[usize], elements: &[f64]) -> VortexResult { + let list_size: u32 = shape.iter().product::().max(1).try_into().unwrap(); + let row_count = elements.len() / list_size as usize; + + let elems: ArrayRef = Buffer::copy_from(elements).into_array(); + let fsl = FixedSizeListArray::new(elems, list_size, Validity::NonNullable, row_count); + + let metadata = FixedShapeTensorMetadata::new(shape.to_vec()); + let ext_dtype = + ExtDType::::try_new(metadata, fsl.dtype().clone())?.erased(); + + Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + } + + /// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec`. + fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { + let scalar_fn = ScalarFn::new(CosineSimilarity, EmptyOptions).erased(); + let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], len)?; + let prim = result.to_primitive(); + Ok(prim.as_slice::().to_vec()) + } + + /// Asserts that each element in `actual` is within `1e-10` of the corresponding `expected` + /// value, with support for NaN (NaN == NaN is considered equal). + #[track_caller] + fn assert_close(actual: &[f64], expected: &[f64]) { + assert_eq!( + actual.len(), + expected.len(), + "length mismatch: got {} elements, expected {}", + actual.len(), + expected.len() + ); + + for (i, (a, e)) in actual.iter().zip(expected).enumerate() { + if a.is_nan() && e.is_nan() { + continue; + } + assert!( + (a - e).abs() < 1e-10, + "element {i}: got {a}, expected {e} (diff = {})", + (a - e).abs() + ); + } + } + + #[test] + fn unit_vectors_1d() -> VortexResult<()> { + let lhs = tensor_array( + &[3], + &[ + 1.0, 0.0, 0.0, // Tensor 1 + 0.0, 1.0, 0.0, // Tensor 2 + ], + )?; + let rhs = tensor_array( + &[3], + &[ + 1.0, 0.0, 0.0, // Tensor 1 + 1.0, 0.0, 0.0, // Tensor 2 + ], + )?; + + // Row 0: identical → 1.0, row 1: orthogonal → 0.0. 
+ assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]); + Ok(()) + } + + #[test] + fn opposite_vectors() -> VortexResult<()> { + let lhs = tensor_array(&[3], &[1.0, 0.0, 0.0])?; + let rhs = tensor_array(&[3], &[-1.0, 0.0, 0.0])?; + + // Antiparallel → -1.0. + assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[-1.0]); + Ok(()) + } + + #[test] + fn non_unit_vectors() -> VortexResult<()> { + let lhs = tensor_array(&[2], &[3.0, 4.0])?; + let rhs = tensor_array(&[2], &[4.0, 3.0])?; + + // dot=24, both magnitudes=5 → 24/25 = 0.96. + assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.96]); + Ok(()) + } + + #[test] + fn zero_norm_produces_nan() -> VortexResult<()> { + let lhs = tensor_array(&[2], &[0.0, 0.0])?; + let rhs = tensor_array(&[2], &[1.0, 0.0])?; + + // Zero vector → 0/0 → NaN. + assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[f64::NAN]); + Ok(()) + } + + #[test] + fn matrix_2d() -> VortexResult<()> { + // Each tensor is a 2x3 matrix, flattened to 6 elements. + let lhs = tensor_array( + &[2, 3], + &[ + 1.0, 0.0, 0.0, // Row 1 + 0.0, 0.0, 0.0, // Row 2 + ], + )?; + let rhs = tensor_array( + &[2, 3], + &[ + 1.0, 0.0, 0.0, // Row 1 + 0.0, 0.0, 0.0, // Row 2 + ], + )?; + + // Identical flat buffers → 1.0. + assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[1.0]); + Ok(()) + } + + #[test] + fn tensor_3d() -> VortexResult<()> { + // shape: [2, 2, 2] — 8 elements per tensor, all ones. + let lhs = tensor_array(&[2, 2, 2], &[1.0; 8])?; + let rhs = tensor_array(&[2, 2, 2], &[1.0; 8])?; + + // Identical → 1.0. + assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[1.0]); + Ok(()) + } + + #[test] + fn scalar_0d() -> VortexResult<()> { + // 0-dimensional tensor: each "tensor" is a single scalar value. + let lhs = tensor_array(&[], &[5.0, 3.0])?; + let rhs = tensor_array(&[], &[5.0, -3.0])?; + + // Same sign → 1.0, opposite sign → -1.0. 
+ assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, -1.0]); + Ok(()) + } + + #[test] + fn many_rows() -> VortexResult<()> { + // 5 tensors of shape [4] compared against themselves → all 1.0. + let lhs = tensor_array( + &[4], + &[ + 1.0, 2.0, 3.0, 4.0, // tensor 0 + 0.0, 1.0, 0.0, 0.0, // tensor 1 + 5.0, 0.0, 5.0, 0.0, // tensor 2 + 1.0, 1.0, 1.0, 1.0, // tensor 3 + 0.0, 0.0, 0.0, 7.0, // tensor 4 + ], + )?; + let rhs = lhs.clone(); + + assert_close( + &eval_cosine_similarity(lhs, rhs, 5)?, + &[1.0, 1.0, 1.0, 1.0, 1.0], + ); + Ok(()) + } + + /// Builds an extension array whose storage is a [`ConstantArray`], representing a single + /// query tensor broadcast to `len` rows. + fn constant_tensor_array( + shape: &[usize], + elements: &[f64], + len: usize, + ) -> VortexResult { + let element_dtype = DType::Primitive(vortex::dtype::PType::F64, Nullability::NonNullable); + + // Build the FSL storage scalar from individual element scalars. + let children: Vec = elements + .iter() + .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) + .collect(); + let storage_scalar = + Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); + + // Wrap the FSL scalar in a ConstantArray to avoid materializing `len` copies. + let storage = ConstantArray::new(storage_scalar, len).into_array(); + + let metadata = FixedShapeTensorMetadata::new(shape.to_vec()); + let ext_dtype = + ExtDType::::try_new(metadata, storage.dtype().clone())?.erased(); + + Ok(ExtensionArray::new(ext_dtype, storage).into_array()) + } + + #[test] + fn constant_query_vector() -> VortexResult<()> { + // Compare 4 tensors of shape [3] against a single constant query tensor [1,0,0]. + let data = tensor_array( + &[3], + &[ + 1.0, 0.0, 0.0, // tensor 0 + 0.0, 1.0, 0.0, // tensor 1 + 0.0, 0.0, 1.0, // tensor 2 + 1.0, 0.0, 0.0, // tensor 3 + ], + )?; + let query = constant_tensor_array(&[3], &[1.0, 0.0, 0.0], 4)?; + + // Only tensor 0 is aligned with the query. 
+ assert_close( + &eval_cosine_similarity(data, query, 4)?, + &[1.0, 0.0, 0.0, 1.0], + ); + Ok(()) + } +} From 886a9045326f796cbc6fda305db9c0da7a33dc4c Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 5 Mar 2026 16:36:51 -0500 Subject: [PATCH 10/14] clean up Signed-off-by: Connor Tsui --- vortex-tensor/src/metadata.rs | 18 +++++++++++++----- vortex-tensor/src/proto.rs | 19 +++++++++++++++++++ 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/vortex-tensor/src/metadata.rs b/vortex-tensor/src/metadata.rs index bd974f5479e..26a3639d202 100644 --- a/vortex-tensor/src/metadata.rs +++ b/vortex-tensor/src/metadata.rs @@ -45,15 +45,23 @@ impl FixedShapeTensorMetadata { } } - /// Sets the dimension names for this tensor. + /// Sets the dimension names for this tensor. An empty vec is normalized to `None` since a + /// 0-dimensional tensor has no dimensions to name. pub fn with_dim_names(mut self, names: Vec) -> Self { - self.dim_names = Some(names); + self.dim_names = if names.is_empty() { None } else { Some(names) }; + self } - /// Sets the permutation for this tensor. + /// Sets the permutation for this tensor. An empty vec is normalized to `None` since a + /// 0-dimensional tensor has no dimensions to permute. pub fn with_permutation(mut self, permutation: Vec) -> Self { - self.permutation = Some(permutation); + self.permutation = if permutation.is_empty() { + None + } else { + Some(permutation) + }; + self } @@ -127,7 +135,7 @@ impl FixedShapeTensorMetadata { let phys = perm[i]; // Each call scans the full permutation, making `strides()` O(ndim^2) overall. Tensor rank - // is typically small (2–5), so avoiding a Vec allocation is a net win. + // is typically small, so avoiding a Vec allocation is a net win. 
perm.iter() .enumerate() .filter(|&(_, &p)| p > phys) diff --git a/vortex-tensor/src/proto.rs b/vortex-tensor/src/proto.rs index f00e6d9e3fb..f10979a864d 100644 --- a/vortex-tensor/src/proto.rs +++ b/vortex-tensor/src/proto.rs @@ -11,12 +11,26 @@ use vortex::error::vortex_err; use crate::FixedShapeTensorMetadata; /// Protobuf representation of [`FixedShapeTensorMetadata`]. +/// +/// Protobuf does not distinguish between an absent repeated field and an empty one (both will +/// deserialize as an empty `Vec`). This is fine because the semantic meaning is unambiguous: +/// +/// - `logical_shape` empty: 0-dimensional (scalar) tensor. +/// - `dim_names` empty: no dimension names (`None`). +/// - `permutation` empty: no permutation, i.e., identity layout (`None`). #[derive(Clone, PartialEq, Message)] struct FixedShapeTensorMetadataProto { + /// The size of each logical dimension. Empty for a 0-dimensional scalar tensor. #[prost(uint32, repeated, tag = "1")] logical_shape: Vec, + + /// Optional human-readable names for each logical dimension. When present, must have the + /// same length as `logical_shape`. Empty means no names are set. #[prost(string, repeated, tag = "2")] dim_names: Vec, + + /// Optional dimension permutation mapping logical to physical indices. When present, must + /// be a permutation of `[0, 1, ..., N-1]`. Empty means identity (row-major) layout. #[prost(uint32, repeated, tag = "3")] permutation: Vec, } @@ -49,6 +63,9 @@ pub(crate) fn serialize(metadata: &FixedShapeTensorMetadata) -> Vec { } /// Deserializes [`FixedShapeTensorMetadata`] from protobuf bytes. +/// +/// For 0-dimensional tensors, all three repeated fields are empty, which correctly produces a +/// metadata with an empty shape and no names or permutation. 
pub(crate) fn deserialize(bytes: &[u8]) -> VortexResult { let proto = FixedShapeTensorMetadataProto::decode(bytes).map_err(|e| vortex_err!("{e}"))?; @@ -59,6 +76,8 @@ pub(crate) fn deserialize(bytes: &[u8]) -> VortexResult Date: Thu, 5 Mar 2026 17:01:15 -0500 Subject: [PATCH 11/14] add validation in builders Signed-off-by: Connor Tsui --- vortex-tensor/src/metadata.rs | 127 ++++++++++++++++++++++++++-------- vortex-tensor/src/proto.rs | 4 +- vortex-tensor/src/vtable.rs | 8 +-- 3 files changed, 105 insertions(+), 34 deletions(-) diff --git a/vortex-tensor/src/metadata.rs b/vortex-tensor/src/metadata.rs index 26a3639d202..54ed660eff7 100644 --- a/vortex-tensor/src/metadata.rs +++ b/vortex-tensor/src/metadata.rs @@ -5,6 +5,9 @@ use std::fmt; use itertools::Either; use vortex::error::VortexExpect; +use vortex::error::VortexResult; +use vortex::error::vortex_ensure; +use vortex::error::vortex_ensure_eq; /// Metadata for a `FixedShapeTensor` extension type. #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -47,22 +50,54 @@ impl FixedShapeTensorMetadata { /// Sets the dimension names for this tensor. An empty vec is normalized to `None` since a /// 0-dimensional tensor has no dimensions to name. - pub fn with_dim_names(mut self, names: Vec) -> Self { - self.dim_names = if names.is_empty() { None } else { Some(names) }; + /// + /// The number of names must match the number of logical dimensions. + pub fn with_dim_names(mut self, names: Vec) -> VortexResult { + if !names.is_empty() { + vortex_ensure_eq!( + names.len(), + self.logical_shape.len(), + "dim_names length ({}) must match logical_shape length ({})", + names.len(), + self.logical_shape.len() + ); + self.dim_names = Some(names); + } - self + Ok(self) } /// Sets the permutation for this tensor. An empty vec is normalized to `None` since a /// 0-dimensional tensor has no dimensions to permute. 
- pub fn with_permutation(mut self, permutation: Vec) -> Self { - self.permutation = if permutation.is_empty() { - None - } else { - Some(permutation) - }; + /// + /// The permutation must be a valid permutation of `[0, 1, ..., N-1]` where `N` is the + /// number of logical dimensions. + pub fn with_permutation(mut self, permutation: Vec) -> VortexResult { + if !permutation.is_empty() { + vortex_ensure_eq!( + permutation.len(), + self.logical_shape.len(), + "permutation length ({}) must match logical_shape length ({})", + permutation.len(), + self.logical_shape.len() + ); + + // Verify this is actually a permutation of [0..N). + let mut seen = vec![false; permutation.len()]; + for &p in &permutation { + vortex_ensure!( + p < permutation.len(), + "permutation index {p} is out of range for {} dimensions", + permutation.len() + ); + vortex_ensure!(!seen[p], "permutation contains duplicate index {p}"); + seen[p] = true; + } + + self.permutation = Some(permutation); + } - self + Ok(self) } /// Returns the number of dimensions (rank) of the tensor. @@ -230,19 +265,21 @@ mod tests { } #[test] - fn strides_2d_transposed() { + fn strides_2d_transposed() -> VortexResult<()> { // Logical shape [3, 4] with perm [1, 0] (transpose). // Physical shape = [4, 3], so logical strides = [1, 3]. - let m = FixedShapeTensorMetadata::new(vec![3, 4]).with_permutation(vec![1, 0]); + let m = FixedShapeTensorMetadata::new(vec![3, 4]).with_permutation(vec![1, 0])?; assert_eq!(m.strides().collect::>(), vec![1, 3]); + Ok(()) } #[test] - fn strides_3d_permuted() { + fn strides_3d_permuted() -> VortexResult<()> { // Logical shape [2, 3, 4] with perm [2, 0, 1]. // Physical shape = [3, 4, 2], so logical strides = [1, 8, 2]. 
- let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![2, 0, 1]); + let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![2, 0, 1])?; assert_eq!(m.strides().collect::>(), vec![1, 8, 2]); + Ok(()) } #[test] @@ -258,21 +295,24 @@ mod tests { } #[test] - fn strides_zero_size_dimension_permuted() { + fn strides_zero_size_dimension_permuted() -> VortexResult<()> { // Logical shape [3, 0, 4] with perm [1, 2, 0]. // Physical shape = [4, 3, 0], so logical strides = [0, 1, 0]. - let m = FixedShapeTensorMetadata::new(vec![3, 0, 4]).with_permutation(vec![1, 2, 0]); + let m = FixedShapeTensorMetadata::new(vec![3, 0, 4]).with_permutation(vec![1, 2, 0])?; assert_eq!(m.strides().collect::>(), vec![0, 1, 0]); + Ok(()) } #[test] - fn strides_identity_permutation_matches_row_major() { + fn strides_identity_permutation_matches_row_major() -> VortexResult<()> { let row_major = FixedShapeTensorMetadata::new(vec![2, 3, 4]); - let identity = FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![0, 1, 2]); + let identity = + FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![0, 1, 2])?; assert_eq!( row_major.strides().collect::>(), identity.strides().collect::>(), ); + Ok(()) } #[test] @@ -282,42 +322,47 @@ mod tests { } #[test] - fn physical_shape_2d_transposed() { + fn physical_shape_2d_transposed() -> VortexResult<()> { // Logical [3, 4] with perm [1, 0]: physical dim 0 gets logical dim 1's size (4), // physical dim 1 gets logical dim 0's size (3). - let m = FixedShapeTensorMetadata::new(vec![3, 4]).with_permutation(vec![1, 0]); + let m = FixedShapeTensorMetadata::new(vec![3, 4]).with_permutation(vec![1, 0])?; assert_eq!(m.physical_shape().collect::>(), vec![4, 3]); + Ok(()) } #[test] - fn physical_shape_3d_permuted() { + fn physical_shape_3d_permuted() -> VortexResult<()> { // Logical [2, 3, 4] with perm [2, 0, 1]: logical 0 -> phys 2, logical 1 -> phys 0, // logical 2 -> phys 1. So physical shape = [3, 4, 2]. 
- let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![2, 0, 1]); + let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![2, 0, 1])?; assert_eq!(m.physical_shape().collect::>(), vec![3, 4, 2]); + Ok(()) } #[test] - fn physical_shape_identity_permutation() { + fn physical_shape_identity_permutation() -> VortexResult<()> { let no_perm = FixedShapeTensorMetadata::new(vec![2, 3, 4]); - let identity = FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![0, 1, 2]); + let identity = + FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![0, 1, 2])?; assert_eq!( no_perm.physical_shape().collect::>(), identity.physical_shape().collect::>(), ); + Ok(()) } #[test] - fn physical_shape_zero_size_dimension() { + fn physical_shape_zero_size_dimension() -> VortexResult<()> { // Logical [3, 0, 4] with perm [1, 2, 0]: physical shape = [4, 3, 0]. - let m = FixedShapeTensorMetadata::new(vec![3, 0, 4]).with_permutation(vec![1, 2, 0]); + let m = FixedShapeTensorMetadata::new(vec![3, 0, 4]).with_permutation(vec![1, 2, 0])?; assert_eq!(m.physical_shape().collect::>(), vec![4, 3, 0]); + Ok(()) } /// Verifies that the fast `permuted_stride` matches the explicit reference `slow_strides` /// across a variety of shapes and permutations. #[test] - fn fast_strides_match_slow_reference() { + fn fast_strides_match_slow_reference() -> VortexResult<()> { let cases: Vec<(Vec, Vec)> = vec![ // 2D transpose. 
(vec![3, 4], vec![1, 0]), @@ -337,10 +382,36 @@ mod tests { ]; for (shape, perm) in &cases { - let m = FixedShapeTensorMetadata::new(shape.clone()).with_permutation(perm.clone()); + let m = FixedShapeTensorMetadata::new(shape.clone()).with_permutation(perm.clone())?; let fast: Vec = m.strides().collect(); let slow = slow_strides(shape, perm); assert_eq!(fast, slow, "mismatch for shape={shape:?}, perm={perm:?}"); } + + Ok(()) + } + + #[test] + fn dim_names_wrong_length() { + let result = FixedShapeTensorMetadata::new(vec![2, 3]).with_dim_names(vec!["x".into()]); + assert!(result.is_err()); + } + + #[test] + fn permutation_wrong_length() { + let result = FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![0]); + assert!(result.is_err()); + } + + #[test] + fn permutation_out_of_range() { + let result = FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![0, 5]); + assert!(result.is_err()); + } + + #[test] + fn permutation_duplicate_index() { + let result = FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![0, 0]); + assert!(result.is_err()); } } diff --git a/vortex-tensor/src/proto.rs b/vortex-tensor/src/proto.rs index f10979a864d..f454531fcca 100644 --- a/vortex-tensor/src/proto.rs +++ b/vortex-tensor/src/proto.rs @@ -79,11 +79,11 @@ pub(crate) fn deserialize(bytes: &[u8]) -> VortexResult VortexResult<()> { assert_roundtrip( - &FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![2, 0, 1]), + &FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![2, 0, 1])?, ) } @@ -106,7 +106,7 @@ mod tests { fn roundtrip_with_dim_names() -> VortexResult<()> { assert_roundtrip( &FixedShapeTensorMetadata::new(vec![3, 4]) - .with_dim_names(vec!["rows".into(), "cols".into()]), + .with_dim_names(vec!["rows".into(), "cols".into()])?, ) } @@ -114,8 +114,8 @@ mod tests { fn roundtrip_all_fields() -> VortexResult<()> { assert_roundtrip( &FixedShapeTensorMetadata::new(vec![2, 3, 4]) - .with_dim_names(vec!["x".into(), "y".into(), 
"z".into()]) - .with_permutation(vec![1, 2, 0]), + .with_dim_names(vec!["x".into(), "y".into(), "z".into()])? + .with_permutation(vec![1, 2, 0])?, ) } From b18aa81dd0cf8f640d7ff8c6a85b7bab9bc6d9c2 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 5 Mar 2026 17:12:32 -0500 Subject: [PATCH 12/14] fix error handling Signed-off-by: Connor Tsui --- vortex-tensor/src/cosine_similarity.rs | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/vortex-tensor/src/cosine_similarity.rs b/vortex-tensor/src/cosine_similarity.rs index f3fa4fb9f23..6ff65e6c20b 100644 --- a/vortex-tensor/src/cosine_similarity.rs +++ b/vortex-tensor/src/cosine_similarity.rs @@ -7,9 +7,9 @@ use std::fmt::Formatter; use num_traits::Float; use vortex::array::ArrayRef; +use vortex::array::DynArray; use vortex::array::ExecutionCtx; use vortex::array::IntoArray; -use vortex::array::ToCanonical; use vortex::array::arrays::ConstantArray; use vortex::array::arrays::ConstantVTable; use vortex::array::arrays::ExtensionVTable; @@ -150,8 +150,8 @@ impl ScalarFnVTable for CosineSimilarity { // Extract the flat primitive elements from each tensor column. When an input is a // `ConstantArray` (e.g., a literal query vector), we materialize only a single row // instead of expanding it to the full row count. - let (lhs_elems, lhs_stride) = extract_flat_elements(&lhs_storage, list_size); - let (rhs_elems, rhs_stride) = extract_flat_elements(&rhs_storage, list_size); + let (lhs_elems, lhs_stride) = extract_flat_elements(&lhs_storage, list_size)?; + let (rhs_elems, rhs_stride) = extract_flat_elements(&rhs_storage, list_size)?; match_each_float_ptype!(lhs_elems.ptype(), |T| { let lhs_slice = lhs_elems.as_slice::(); @@ -186,8 +186,8 @@ impl ScalarFnVTable for CosineSimilarity { } fn is_fallible(&self, _options: &Self::Options) -> bool { - // TODO(connor): Is this correct since we need to canonicalize? - false + // Canonicalization of the storage arrays can fail. 
+ true } } @@ -204,17 +204,22 @@ fn extension_storage(array: &ArrayRef) -> VortexResult { /// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is /// materialized to avoid expanding it to the full column length. Returns `(elements, stride)` /// where `stride` is `list_size` for a full array and `0` for a constant. -fn extract_flat_elements(storage: &ArrayRef, list_size: usize) -> (PrimitiveArray, usize) { +fn extract_flat_elements( + storage: &ArrayRef, + list_size: usize, +) -> VortexResult<(PrimitiveArray, usize)> { if let Some(constant) = storage.as_opt::() { // Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a // huge amount of data. let single = ConstantArray::new(constant.scalar().clone(), 1).into_array(); - let elems = single.to_fixed_size_list().elements().to_primitive(); - (elems, 0) + let fsl = single.to_canonical()?.into_fixed_size_list(); + let elems = fsl.elements().to_canonical()?.into_primitive(); + Ok((elems, 0)) } else { // Otherwise we have to fully expand all of the data. 
- let elems = storage.to_fixed_size_list().elements().to_primitive(); - (elems, list_size) + let fsl = storage.to_canonical()?.into_fixed_size_list(); + let elems = fsl.elements().to_canonical()?.into_primitive(); + Ok((elems, list_size)) } } From 0a5356b7534c38cbb5faf248234efba1be4b5642 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 5 Mar 2026 17:27:59 -0500 Subject: [PATCH 13/14] better tests Signed-off-by: Connor Tsui --- Cargo.lock | 1 + vortex-tensor/Cargo.toml | 5 +- vortex-tensor/src/cosine_similarity.rs | 91 ++++------- vortex-tensor/src/metadata.rs | 210 ++++++++++--------------- vortex-tensor/src/vtable.rs | 55 +++---- 5 files changed, 150 insertions(+), 212 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7c71b24a18e..4bd2006c54e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10451,6 +10451,7 @@ dependencies = [ "itertools 0.14.0", "num-traits", "prost 0.14.3", + "rstest", "vortex", ] diff --git a/vortex-tensor/Cargo.toml b/vortex-tensor/Cargo.toml index 3bb47455167..6f4fe4511af 100644 --- a/vortex-tensor/Cargo.toml +++ b/vortex-tensor/Cargo.toml @@ -8,7 +8,7 @@ homepage = { workspace = true } include = { workspace = true } keywords = { workspace = true } license = { workspace = true } -readme = "README.md" +readme = { workspace = true } repository = { workspace = true } rust-version = { workspace = true } version = { workspace = true } @@ -22,3 +22,6 @@ vortex = { workspace = true } itertools = { workspace = true } num-traits = { workspace = true } prost = { workspace = true } + +[dev-dependencies] +rstest = { workspace = true } diff --git a/vortex-tensor/src/cosine_similarity.rs b/vortex-tensor/src/cosine_similarity.rs index 6ff65e6c20b..6871256ca4f 100644 --- a/vortex-tensor/src/cosine_similarity.rs +++ b/vortex-tensor/src/cosine_similarity.rs @@ -335,66 +335,43 @@ mod tests { Ok(()) } - #[test] - fn opposite_vectors() -> VortexResult<()> { - let lhs = tensor_array(&[3], &[1.0, 0.0, 0.0])?; - let rhs = tensor_array(&[3], &[-1.0, 0.0, 
0.0])?; - - // Antiparallel → -1.0. - assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[-1.0]); - Ok(()) - } - - #[test] - fn non_unit_vectors() -> VortexResult<()> { - let lhs = tensor_array(&[2], &[3.0, 4.0])?; - let rhs = tensor_array(&[2], &[4.0, 3.0])?; - - // dot=24, both magnitudes=5 → 24/25 = 0.96. - assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.96]); - Ok(()) - } - - #[test] - fn zero_norm_produces_nan() -> VortexResult<()> { - let lhs = tensor_array(&[2], &[0.0, 0.0])?; - let rhs = tensor_array(&[2], &[1.0, 0.0])?; - - // Zero vector → 0/0 → NaN. - assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[f64::NAN]); - Ok(()) - } - - #[test] - fn matrix_2d() -> VortexResult<()> { - // Each tensor is a 2x3 matrix, flattened to 6 elements. - let lhs = tensor_array( - &[2, 3], - &[ - 1.0, 0.0, 0.0, // Row 1 - 0.0, 0.0, 0.0, // Row 2 - ], - )?; - let rhs = tensor_array( - &[2, 3], - &[ - 1.0, 0.0, 0.0, // Row 1 - 0.0, 0.0, 0.0, // Row 2 - ], - )?; - - // Identical flat buffers → 1.0. - assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[1.0]); + use rstest::rstest; + + /// Single-row cosine similarity for various vector pairs. + #[rstest] + // Antiparallel → -1.0. + #[case::opposite(&[3], &[1.0, 0.0, 0.0], &[-1.0, 0.0, 0.0], &[-1.0])] + // dot=24, both magnitudes=5 → 24/25 = 0.96. + #[case::non_unit(&[2], &[3.0, 4.0], &[4.0, 3.0], &[0.96])] + // Zero vector → 0/0 → NaN. + #[case::zero_norm(&[2], &[0.0, 0.0], &[1.0, 0.0], &[f64::NAN])] + fn single_row( + #[case] shape: &[usize], + #[case] lhs_elems: &[f64], + #[case] rhs_elems: &[f64], + #[case] expected: &[f64], + ) -> VortexResult<()> { + let lhs = tensor_array(shape, lhs_elems)?; + let rhs = tensor_array(shape, rhs_elems)?; + assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, expected); Ok(()) } - #[test] - fn tensor_3d() -> VortexResult<()> { - // shape: [2, 2, 2] — 8 elements per tensor, all ones. 
-        let lhs = tensor_array(&[2, 2, 2], &[1.0; 8])?;
-        let rhs = tensor_array(&[2, 2, 2], &[1.0; 8])?;
-
-        // Identical → 1.0.
+    /// Self-similarity across various tensor shapes should always produce 1.0.
+    #[rstest]
+    // 2x3 matrix, flattened to 6 elements.
+    #[case::matrix_2d(
+        &[2, 3],
+        &[
+            1.0, 0.0, 0.0, // row 0
+            0.0, 0.0, 0.0, // row 1
+        ],
+    )]
+    // 2x2x2 tensor, 8 elements.
+    #[case::tensor_3d(&[2, 2, 2], &[1.0; 8])]
+    fn self_similarity(#[case] shape: &[usize], #[case] elements: &[f64]) -> VortexResult<()> {
+        let lhs = tensor_array(shape, elements)?;
+        let rhs = tensor_array(shape, elements)?;
         assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[1.0]);
         Ok(())
     }
diff --git a/vortex-tensor/src/metadata.rs b/vortex-tensor/src/metadata.rs
index 54ed660eff7..fb46c67d213 100644
--- a/vortex-tensor/src/metadata.rs
+++ b/vortex-tensor/src/metadata.rs
@@ -193,33 +193,45 @@ impl FixedShapeTensorMetadata {
 impl fmt::Display for FixedShapeTensorMetadata {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "Tensor(")?;
+
         match &self.dim_names {
             Some(names) => {
-                write!(f, "Tensor(")?;
                 for (i, (dim, name)) in self.logical_shape.iter().zip(names.iter()).enumerate() {
                     if i > 0 {
                         write!(f, ", ")?;
                     }
                     write!(f, "{name}: {dim}")?;
                 }
-                write!(f, ")")
             }
             None => {
-                write!(f, "Tensor(")?;
                 for (i, dim) in self.logical_shape.iter().enumerate() {
                     if i > 0 {
                         write!(f, ", ")?;
                     }
                     write!(f, "{dim}")?;
                 }
-                write!(f, ")")
             }
         }
+
+        if let Some(perm) = &self.permutation {
+            // Separate the permutation from the shape list and balance the "]".
+            write!(f, "; permutation=[")?;
+            for (i, p) in perm.iter().enumerate() {
+                let sep = if i > 0 { ", " } else { "" };
+                write!(f, "{sep}{p}")?;
+            }
+            write!(f, "]")?;
+        }
+
+        write!(f, ")")
     }
 }
 
 #[cfg(test)]
 mod tests {
+    use rstest::rstest;
+
     use super::*;
 
     /// Reference implementation that computes permuted strides in an explicit, step-by-step way.
@@ -246,60 +258,40 @@ mod tests { (0..ndim).map(|l| physical_strides[perm[l]]).collect() } - #[test] - fn strides_1d() { - let m = FixedShapeTensorMetadata::new(vec![5]); - assert_eq!(m.strides().collect::>(), vec![1]); - } - - #[test] - fn strides_2d_row_major() { - let m = FixedShapeTensorMetadata::new(vec![3, 4]); - assert_eq!(m.strides().collect::>(), vec![4, 1]); - } - - #[test] - fn strides_3d_row_major() { - let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]); - assert_eq!(m.strides().collect::>(), vec![12, 4, 1]); - } - - #[test] - fn strides_2d_transposed() -> VortexResult<()> { - // Logical shape [3, 4] with perm [1, 0] (transpose). - // Physical shape = [4, 3], so logical strides = [1, 3]. - let m = FixedShapeTensorMetadata::new(vec![3, 4]).with_permutation(vec![1, 0])?; - assert_eq!(m.strides().collect::>(), vec![1, 3]); - Ok(()) - } - - #[test] - fn strides_3d_permuted() -> VortexResult<()> { - // Logical shape [2, 3, 4] with perm [2, 0, 1]. - // Physical shape = [3, 4, 2], so logical strides = [1, 8, 2]. - let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![2, 0, 1])?; - assert_eq!(m.strides().collect::>(), vec![1, 8, 2]); - Ok(()) - } - - #[test] - fn strides_0d_scalar() { - let m = FixedShapeTensorMetadata::new(vec![]); - assert_eq!(m.strides().collect::>(), Vec::::new()); - } - - #[test] - fn strides_zero_size_dimension() { - let m = FixedShapeTensorMetadata::new(vec![3, 0, 4]); - assert_eq!(m.strides().collect::>(), vec![0, 4, 1]); - } - - #[test] - fn strides_zero_size_dimension_permuted() -> VortexResult<()> { - // Logical shape [3, 0, 4] with perm [1, 2, 0]. - // Physical shape = [4, 3, 0], so logical strides = [0, 1, 0]. 
- let m = FixedShapeTensorMetadata::new(vec![3, 0, 4]).with_permutation(vec![1, 2, 0])?; - assert_eq!(m.strides().collect::>(), vec![0, 1, 0]); + // -- Row-major strides (no permutation) -- + + #[rstest] + #[case::scalar_0d(vec![], vec![])] + #[case::vector_1d(vec![5], vec![1])] + #[case::matrix_2d(vec![3, 4], vec![4, 1])] + #[case::tensor_3d(vec![2, 3, 4], vec![12, 4, 1])] + #[case::zero_dim( vec![3, 0, 4], vec![0, 4, 1])] + fn strides_row_major(#[case] shape: Vec, #[case] expected: Vec) { + let m = FixedShapeTensorMetadata::new(shape); + assert_eq!(m.strides().collect::>(), expected); + } + + // -- Permuted strides -- + // + // Each case is checked against the expected value and cross-validated against the + // `slow_strides` reference implementation. + + #[rstest] + // 2D transpose: physical shape = [4, 3]. + #[case::transpose_2d(vec![3, 4], vec![1, 0], vec![1, 3])] + // 3D: physical shape = [3, 4, 2]. + #[case::perm_3d_201( vec![2, 3, 4], vec![2, 0, 1], vec![1, 8, 2])] + // 3D with zero-sized dimension: physical shape = [4, 3, 0]. + #[case::zero_dim( vec![3, 0, 4], vec![1, 2, 0], vec![0, 1, 0])] + fn strides_permuted( + #[case] shape: Vec, + #[case] perm: Vec, + #[case] expected: Vec, + ) -> VortexResult<()> { + let m = FixedShapeTensorMetadata::new(shape.clone()).with_permutation(perm.clone())?; + let actual: Vec = m.strides().collect(); + assert_eq!(actual, expected); + assert_eq!(actual, slow_strides(&shape, &perm)); Ok(()) } @@ -315,79 +307,49 @@ mod tests { Ok(()) } - #[test] - fn physical_shape_no_permutation() { - let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]); - assert_eq!(m.physical_shape().collect::>(), vec![2, 3, 4]); - } - - #[test] - fn physical_shape_2d_transposed() -> VortexResult<()> { - // Logical [3, 4] with perm [1, 0]: physical dim 0 gets logical dim 1's size (4), - // physical dim 1 gets logical dim 0's size (3). 
- let m = FixedShapeTensorMetadata::new(vec![3, 4]).with_permutation(vec![1, 0])?; - assert_eq!(m.physical_shape().collect::>(), vec![4, 3]); - Ok(()) - } - - #[test] - fn physical_shape_3d_permuted() -> VortexResult<()> { - // Logical [2, 3, 4] with perm [2, 0, 1]: logical 0 -> phys 2, logical 1 -> phys 0, - // logical 2 -> phys 1. So physical shape = [3, 4, 2]. - let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![2, 0, 1])?; - assert_eq!(m.physical_shape().collect::>(), vec![3, 4, 2]); + /// Cross-validates the fast `permuted_stride` against the reference `slow_strides` across a + /// broader set of shapes and permutations. + #[rstest] + #[case::perm_3d_120(vec![2, 3, 4], vec![1, 2, 0])] + #[case::perm_3d_021(vec![2, 3, 4], vec![0, 2, 1])] + #[case::identity_3d(vec![2, 3, 4], vec![0, 1, 2])] + #[case::zero_lead( vec![0, 3, 4], vec![2, 0, 1])] + #[case::rev_4d( vec![2, 3, 4, 5], vec![3, 2, 1, 0])] + #[case::swap_4d( vec![2, 3, 4, 5], vec![1, 0, 3, 2])] + #[case::half_4d( vec![2, 3, 4, 5], vec![2, 3, 0, 1])] + fn strides_match_slow_reference( + #[case] shape: Vec, + #[case] perm: Vec, + ) -> VortexResult<()> { + let m = FixedShapeTensorMetadata::new(shape.clone()).with_permutation(perm.clone())?; + assert_eq!(m.strides().collect::>(), slow_strides(&shape, &perm)); Ok(()) } - #[test] - fn physical_shape_identity_permutation() -> VortexResult<()> { - let no_perm = FixedShapeTensorMetadata::new(vec![2, 3, 4]); - let identity = - FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![0, 1, 2])?; - assert_eq!( - no_perm.physical_shape().collect::>(), - identity.physical_shape().collect::>(), - ); - Ok(()) - } + // -- Physical shape -- #[test] - fn physical_shape_zero_size_dimension() -> VortexResult<()> { - // Logical [3, 0, 4] with perm [1, 2, 0]: physical shape = [4, 3, 0]. 
- let m = FixedShapeTensorMetadata::new(vec![3, 0, 4]).with_permutation(vec![1, 2, 0])?; - assert_eq!(m.physical_shape().collect::>(), vec![4, 3, 0]); - Ok(()) + fn physical_shape_no_permutation() { + let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]); + assert_eq!(m.physical_shape().collect::>(), vec![2, 3, 4]); } - /// Verifies that the fast `permuted_stride` matches the explicit reference `slow_strides` - /// across a variety of shapes and permutations. - #[test] - fn fast_strides_match_slow_reference() -> VortexResult<()> { - let cases: Vec<(Vec, Vec)> = vec![ - // 2D transpose. - (vec![3, 4], vec![1, 0]), - // 3D permutations. - (vec![2, 3, 4], vec![2, 0, 1]), - (vec![2, 3, 4], vec![1, 2, 0]), - (vec![2, 3, 4], vec![0, 2, 1]), - // 3D identity. - (vec![2, 3, 4], vec![0, 1, 2]), - // 3D with a zero-sized dimension. - (vec![3, 0, 4], vec![1, 2, 0]), - (vec![0, 3, 4], vec![2, 0, 1]), - // 4D permutations. - (vec![2, 3, 4, 5], vec![3, 2, 1, 0]), - (vec![2, 3, 4, 5], vec![1, 0, 3, 2]), - (vec![2, 3, 4, 5], vec![2, 3, 0, 1]), - ]; - - for (shape, perm) in &cases { - let m = FixedShapeTensorMetadata::new(shape.clone()).with_permutation(perm.clone())?; - let fast: Vec = m.strides().collect(); - let slow = slow_strides(shape, perm); - assert_eq!(fast, slow, "mismatch for shape={shape:?}, perm={perm:?}"); - } - + #[rstest] + // Logical [3, 4] with perm [1, 0] → physical [4, 3]. + #[case::transpose_2d(vec![3, 4], vec![1, 0], vec![4, 3])] + // Logical [2, 3, 4] with perm [2, 0, 1] → physical [3, 4, 2]. + #[case::perm_3d( vec![2, 3, 4], vec![2, 0, 1], vec![3, 4, 2])] + // Identity: physical = logical. + #[case::identity( vec![2, 3, 4], vec![0, 1, 2], vec![2, 3, 4])] + // Logical [3, 0, 4] with perm [1, 2, 0] → physical [4, 3, 0]. 
+ #[case::zero_dim( vec![3, 0, 4], vec![1, 2, 0], vec![4, 3, 0])] + fn physical_shape_permuted( + #[case] shape: Vec, + #[case] perm: Vec, + #[case] expected: Vec, + ) -> VortexResult<()> { + let m = FixedShapeTensorMetadata::new(shape).with_permutation(perm)?; + assert_eq!(m.physical_shape().collect::>(), expected); Ok(()) } diff --git a/vortex-tensor/src/vtable.rs b/vortex-tensor/src/vtable.rs index 70c119d2c77..ecec816516a 100644 --- a/vortex-tensor/src/vtable.rs +++ b/vortex-tensor/src/vtable.rs @@ -76,12 +76,14 @@ impl ExtVTable for FixedShapeTensor { #[cfg(test)] mod tests { + use rstest::rstest; use vortex::dtype::extension::ExtVTable; use vortex::error::VortexResult; use crate::FixedShapeTensor; use crate::FixedShapeTensorMetadata; + /// Serializes and deserializes the given metadata through protobuf, asserting equality. fn assert_roundtrip(metadata: &FixedShapeTensorMetadata) -> VortexResult<()> { let vtable = FixedShapeTensor; let bytes = vtable.serialize_metadata(metadata)?; @@ -90,37 +92,30 @@ mod tests { Ok(()) } - #[test] - fn roundtrip_shape_only() -> VortexResult<()> { - assert_roundtrip(&FixedShapeTensorMetadata::new(vec![2, 3, 4])) + #[rstest] + #[case::scalar_0d(FixedShapeTensorMetadata::new(vec![]))] + #[case::shape_only(FixedShapeTensorMetadata::new(vec![2, 3, 4]))] + fn roundtrip_simple(#[case] metadata: FixedShapeTensorMetadata) -> VortexResult<()> { + assert_roundtrip(&metadata) } - #[test] - fn roundtrip_with_permutation() -> VortexResult<()> { - assert_roundtrip( - &FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![2, 0, 1])?, - ) - } - - #[test] - fn roundtrip_with_dim_names() -> VortexResult<()> { - assert_roundtrip( - &FixedShapeTensorMetadata::new(vec![3, 4]) - .with_dim_names(vec!["rows".into(), "cols".into()])?, - ) - } - - #[test] - fn roundtrip_all_fields() -> VortexResult<()> { - assert_roundtrip( - &FixedShapeTensorMetadata::new(vec![2, 3, 4]) - .with_dim_names(vec!["x".into(), "y".into(), "z".into()])? 
- .with_permutation(vec![1, 2, 0])?, - ) - } - - #[test] - fn roundtrip_scalar_0d() -> VortexResult<()> { - assert_roundtrip(&FixedShapeTensorMetadata::new(vec![])) + #[rstest] + #[case::with_permutation( + FixedShapeTensorMetadata::new(vec![2, 3, 4]) + .with_permutation(vec![2, 0, 1]) + )] + #[case::with_dim_names( + FixedShapeTensorMetadata::new(vec![3, 4]) + .with_dim_names(vec!["rows".into(), "cols".into()]) + )] + #[case::all_fields( + FixedShapeTensorMetadata::new(vec![2, 3, 4]) + .with_dim_names(vec!["x".into(), "y".into(), "z".into()]) + .and_then(|m| m.with_permutation(vec![1, 2, 0])) + )] + fn roundtrip_with_options( + #[case] metadata: VortexResult, + ) -> VortexResult<()> { + assert_roundtrip(&metadata?) } } From 3da1787c48847c3197f406256c9d97b6b3cfd714 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 6 Mar 2026 13:19:28 -0500 Subject: [PATCH 14/14] move some stuff around Signed-off-by: Connor Tsui --- vortex-tensor/public-api.lock | 121 ++++++++++++++++++ vortex-tensor/src/lib.rs | 5 +- .../src/{ => scalar_fns}/cosine_similarity.rs | 4 +- vortex-tensor/src/scalar_fns/mod.rs | 4 + 4 files changed, 129 insertions(+), 5 deletions(-) create mode 100644 vortex-tensor/public-api.lock rename vortex-tensor/src/{ => scalar_fns}/cosine_similarity.rs (99%) create mode 100644 vortex-tensor/src/scalar_fns/mod.rs diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock new file mode 100644 index 00000000000..3046784af96 --- /dev/null +++ b/vortex-tensor/public-api.lock @@ -0,0 +1,121 @@ +pub mod vortex_tensor + +pub mod vortex_tensor::scalar_fns + +pub mod vortex_tensor::scalar_fns::cosine_similarity + +pub struct vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity + +impl core::clone::Clone for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::clone(&self) -> vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity + +impl 
vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity + +pub type vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::Options = vortex_array::scalar_fn::vtable::EmptyOptions + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::vtable::Arity + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::id(&self) -> vortex_array::scalar_fn::ScalarFnId + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::is_fallible(&self, _options: &Self::Options) -> bool + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::is_null_sensitive(&self, _options: &Self::Options) -> bool + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::return_dtype(&self, _options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult> + +pub struct vortex_tensor::FixedShapeTensor + +impl core::clone::Clone for vortex_tensor::FixedShapeTensor + +pub fn 
vortex_tensor::FixedShapeTensor::clone(&self) -> vortex_tensor::FixedShapeTensor + +impl core::cmp::Eq for vortex_tensor::FixedShapeTensor + +impl core::cmp::PartialEq for vortex_tensor::FixedShapeTensor + +pub fn vortex_tensor::FixedShapeTensor::eq(&self, other: &vortex_tensor::FixedShapeTensor) -> bool + +impl core::default::Default for vortex_tensor::FixedShapeTensor + +pub fn vortex_tensor::FixedShapeTensor::default() -> vortex_tensor::FixedShapeTensor + +impl core::fmt::Debug for vortex_tensor::FixedShapeTensor + +pub fn vortex_tensor::FixedShapeTensor::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_tensor::FixedShapeTensor + +pub fn vortex_tensor::FixedShapeTensor::hash<__H: core::hash::Hasher>(&self, state: &mut __H) + +impl core::marker::StructuralPartialEq for vortex_tensor::FixedShapeTensor + +impl vortex_array::dtype::extension::vtable::ExtVTable for vortex_tensor::FixedShapeTensor + +pub type vortex_tensor::FixedShapeTensor::Metadata = vortex_tensor::FixedShapeTensorMetadata + +pub type vortex_tensor::FixedShapeTensor::NativeValue<'a> = &'a vortex_array::scalar::scalar_value::ScalarValue + +pub fn vortex_tensor::FixedShapeTensor::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult + +pub fn vortex_tensor::FixedShapeTensor::id(&self) -> vortex_array::dtype::extension::ExtId + +pub fn vortex_tensor::FixedShapeTensor::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> + +pub fn vortex_tensor::FixedShapeTensor::unpack_native<'a>(&self, _metadata: &'a Self::Metadata, _storage_dtype: &'a vortex_array::dtype::DType, storage_value: &'a vortex_array::scalar::scalar_value::ScalarValue) -> vortex_error::VortexResult + +pub fn vortex_tensor::FixedShapeTensor::validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<()> + +pub struct vortex_tensor::FixedShapeTensorMetadata + +impl 
vortex_tensor::FixedShapeTensorMetadata + +pub fn vortex_tensor::FixedShapeTensorMetadata::dim_names(&self) -> core::option::Option<&[alloc::string::String]> + +pub fn vortex_tensor::FixedShapeTensorMetadata::logical_shape(&self) -> &[usize] + +pub fn vortex_tensor::FixedShapeTensorMetadata::ndim(&self) -> usize + +pub fn vortex_tensor::FixedShapeTensorMetadata::new(shape: alloc::vec::Vec) -> Self + +pub fn vortex_tensor::FixedShapeTensorMetadata::permutation(&self) -> core::option::Option<&[usize]> + +pub fn vortex_tensor::FixedShapeTensorMetadata::physical_shape(&self) -> impl core::iter::traits::iterator::Iterator + '_ + +pub fn vortex_tensor::FixedShapeTensorMetadata::strides(&self) -> impl core::iter::traits::iterator::Iterator + '_ + +pub fn vortex_tensor::FixedShapeTensorMetadata::with_dim_names(self, names: alloc::vec::Vec) -> vortex_error::VortexResult + +pub fn vortex_tensor::FixedShapeTensorMetadata::with_permutation(self, permutation: alloc::vec::Vec) -> vortex_error::VortexResult + +impl core::clone::Clone for vortex_tensor::FixedShapeTensorMetadata + +pub fn vortex_tensor::FixedShapeTensorMetadata::clone(&self) -> vortex_tensor::FixedShapeTensorMetadata + +impl core::cmp::Eq for vortex_tensor::FixedShapeTensorMetadata + +impl core::cmp::PartialEq for vortex_tensor::FixedShapeTensorMetadata + +pub fn vortex_tensor::FixedShapeTensorMetadata::eq(&self, other: &vortex_tensor::FixedShapeTensorMetadata) -> bool + +impl core::fmt::Debug for vortex_tensor::FixedShapeTensorMetadata + +pub fn vortex_tensor::FixedShapeTensorMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::fmt::Display for vortex_tensor::FixedShapeTensorMetadata + +pub fn vortex_tensor::FixedShapeTensorMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_tensor::FixedShapeTensorMetadata + +pub fn vortex_tensor::FixedShapeTensorMetadata::hash<__H: core::hash::Hasher>(&self, state: &mut __H) + +impl 
core::marker::StructuralPartialEq for vortex_tensor::FixedShapeTensorMetadata diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 35626b7a6bb..ab18826c6b6 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -6,12 +6,11 @@ mod metadata; pub use metadata::FixedShapeTensorMetadata; -mod cosine_similarity; -pub use cosine_similarity::CosineSimilarity; - mod proto; mod vtable; +pub mod scalar_fns; + /// The VTable for the Tensor extension type. #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] pub struct FixedShapeTensor; diff --git a/vortex-tensor/src/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs similarity index 99% rename from vortex-tensor/src/cosine_similarity.rs rename to vortex-tensor/src/scalar_fns/cosine_similarity.rs index 6871256ca4f..1746e6a5a75 100644 --- a/vortex-tensor/src/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! Cosine similarity expression for [`FixedShapeTensor`] arrays. +//! Cosine similarity expression for [`FixedShapeTensor`](crate::FixedShapeTensor) arrays. use std::fmt::Formatter; @@ -259,9 +259,9 @@ mod tests { use vortex::scalar_fn::EmptyOptions; use vortex::scalar_fn::ScalarFn; - use crate::CosineSimilarity; use crate::FixedShapeTensor; use crate::FixedShapeTensorMetadata; + use crate::scalar_fns::cosine_similarity::CosineSimilarity; /// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape. /// diff --git a/vortex-tensor/src/scalar_fns/mod.rs b/vortex-tensor/src/scalar_fns/mod.rs new file mode 100644 index 00000000000..2797589e03f --- /dev/null +++ b/vortex-tensor/src/scalar_fns/mod.rs @@ -0,0 +1,4 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +pub mod cosine_similarity;