diff --git a/vortex-array/src/arrays/struct_/array.rs b/vortex-array/src/arrays/struct_/array.rs index 4cbeb38bbbb..798d9a8507b 100644 --- a/vortex-array/src/arrays/struct_/array.rs +++ b/vortex-array/src/arrays/struct_/array.rs @@ -3,6 +3,7 @@ use std::fmt::Debug; use std::iter::once; +use std::ops::Not; use std::sync::Arc; use vortex_dtype::DType; @@ -18,10 +19,23 @@ use vortex_error::vortex_err; use crate::Array; use crate::ArrayRef; use crate::IntoArray; +use crate::builtins::ArrayBuiltins; +use crate::compute::mask; use crate::stats::ArrayStats; use crate::validity::Validity; use crate::vtable::ValidityHelper; +/// Metadata for StructArray serialization. +#[derive(Clone, prost::Message)] +pub struct StructMetadata { + /// If true, child validity is a superset of struct validity (validity was pushed down). + /// For nullable children, their validity already includes struct nulls. For non-nullable + /// children, we apply struct validity on field read. If false (default), no guarantee + /// about relationship - must intersect validities on read. + #[prost(bool, tag = "1", default = false)] + pub(super) validity_pushed_down: bool, +} + /// A struct array that stores multiple named fields as columns, similar to a database row. /// /// This mirrors the Apache Arrow Struct array encoding and provides a columnar representation @@ -147,6 +161,9 @@ pub struct StructArray { pub(super) fields: Arc<[ArrayRef]>, pub(super) validity: Validity, pub(super) stats_set: ArrayStats, + /// true = child validity is a superset of struct validity (validity was pushed down) + /// false = default, no guarantee about relationship + pub(super) validity_pushed_down: bool, } pub struct StructArrayParts { @@ -157,12 +174,78 @@ pub struct StructArrayParts { } impl StructArray { - /// Return the struct fields without the validity of the struct applied + /// Note this field may not have the validity of the parent struct applied. + /// Should use `masked_fields` instead, unless you know what you are doing. pub fn unmasked_fields(&self) -> &Arc<[ArrayRef]> { &self.fields } - /// Return the struct field without the validity of the struct applied + pub fn masked_fields(&self) -> VortexResult> { + if !self.dtype.is_nullable() { + // fields need not be masked + return Ok(self.fields.to_vec()); + } + + if self.has_validity_pushed_down() { + self.fields + .iter() + .cloned() + .map(|f| { + if f.dtype().is_nullable() { + Ok(f.into_array()) + } else { + let validity = self.validity().to_array(self.len); + f.mask(validity) + } + }) + .collect::>>() + } else { + // Apply struct validity to all fields + let struct_validity = self.validity().to_array(self.len); + self.fields + .iter() + .map(move |f| f.clone().mask(struct_validity.clone())) + .collect::>>() + } + } + + /// Return the struct field with name `name` with the struct validity already applied. + /// If the struct has no field with that `name` an error is returned. + pub fn field_by_name(&self, name: impl AsRef) -> VortexResult { + let name = name.as_ref(); + self.field_by_name_opt(name)?.ok_or_else(|| { + vortex_err!( + "Field {name} not found in struct array with names {:?}", + self.names() + ) + }) + } + + /// Return the struct field with name `name` with the struct validity already applied. + /// If the struct has no field with that `name` Ok(None) is returned. + pub fn field_by_name_opt(&self, name: impl AsRef) -> VortexResult> { + let name = name.as_ref(); + self.struct_fields() + .find(name) + .map(|idx| { + let field = self.fields[idx].clone(); + // Non-nullable struct: return field as-is (no struct validity to apply) + if !self.dtype.is_nullable() { + return Ok(field); + } + // Non-nullable field: always apply struct validity (even with validity_pushed_down, + // since we can't push validity to non-nullable fields) + if !field.dtype().is_nullable() { + return field.mask(self.validity().to_array(self.len)); + } + // Nullable field: return as-is (validity is either in the field or was pushed down) + Ok(field) + }) + .transpose() + } + + /// Note this field may not have the validity of the parent struct applied. + /// Should use `field_by_name` instead. pub fn unmasked_field_by_name(&self, name: impl AsRef) -> VortexResult<&ArrayRef> { let name = name.as_ref(); self.unmasked_field_by_name_opt(name).ok_or_else(|| { @@ -173,7 +256,8 @@ impl StructArray { }) } - /// Return the struct field without the validity of the struct applied + /// Note this field may not have the validity of the parent struct applied. + /// Should use `field_by_name_opt` instead. pub fn unmasked_field_by_name_opt(&self, name: impl AsRef) -> Option<&ArrayRef> { let name = name.as_ref(); self.struct_fields().find(name).map(|idx| &self.fields[idx]) @@ -286,6 +370,7 @@ impl StructArray { fields, validity, stats_set: Default::default(), + validity_pushed_down: false, } } @@ -473,4 +558,137 @@ impl StructArray { Self::try_new_with_dtype(children, new_fields, self.len, self.validity.clone()) } + + /// Returns whether validity has been pushed down into children. + /// + /// When true, child validity is a superset of struct validity (children include + /// the struct's nulls baked in). This is an optimization that allows readers to + /// skip combining struct+child validity when extracting fields. + pub fn has_validity_pushed_down(&self) -> bool { + #[cfg(debug_assertions)] + if self.validity_pushed_down { + self.validate_validity_pushed_down() + .vortex_expect("validity_pushed_down invariant violated"); + } + self.validity_pushed_down + } + + /// Checks that the validity_pushed_down invariant holds. + /// + /// When `validity_pushed_down` is true, for every nullable child field, + /// the child's validity must be a superset of the struct's validity. + /// That is, wherever the struct is invalid (null), the child must also be invalid. + fn validate_validity_pushed_down(&self) -> VortexResult<()> { + if !self.validity_pushed_down { + return Ok(()); + } + + let struct_validity = self.validity_mask()?; + + for (idx, field) in self.fields.iter().enumerate() { + // Only check nullable children - non-nullable children cannot have validity pushed down + if !field.dtype().is_nullable() { + continue; + } + + let child_validity = field.validity_mask()?; + + // Check invariant: struct_invalid => child_invalid + // Equivalently: (!struct_validity) & child_validity should be all-false + // If struct is invalid (false) but child is valid (true), that's a violation + let violation = &(!&struct_validity) & &child_validity; + if !violation.all_false() { + vortex_bail!( + "validity_pushed_down invariant violated for field {}: \ + struct has nulls at positions where child is valid", + idx + ); + } + } + + Ok(()) + } + + /// Set the validity_pushed_down flag. + /// + /// For non-nullable structs, this is a no-op (flag stays false) since there's no validity to + /// push down + /// + /// For nullable structs, setting this to true indicates that child validity + /// is a superset of struct validity (children include struct's nulls). + /// + /// # Safety + /// + /// If set all non-nullable field must have their nullability be a superset of the struct + /// validity + pub unsafe fn with_validity_pushed_down(mut self, validity_pushed_down: bool) -> Self { + // For non-nullable structs, the flag is meaningless - keep it false + if !self.dtype.is_nullable() { + return self; + } + self.validity_pushed_down = validity_pushed_down; + + #[cfg(debug_assertions)] + if validity_pushed_down { + self.validate_validity_pushed_down() + .vortex_expect("validity_pushed_down invariant violated"); + } + + self + } + + /// Push struct validity down into each child field. + /// + /// For nullable structs with non-trivial validity, this applies the validity + /// mask to each **nullable** child field, making child validity a superset + /// of parent validity. Non-nullable children are left unchanged to preserve + /// their dtype. + /// + /// The struct validity is **preserved** (DType never changes). The + /// `validity_pushed_down` flag indicates that nullable children already include + /// the parent's nulls, so readers can skip combining validities for those fields. + /// + /// For non-nullable structs or trivial validity, this is essentially a no-op. + pub fn compact(&self) -> VortexResult { + // For non-nullable structs, nothing to push down + if !self.dtype.is_nullable() { + return Ok(self.clone()); + } + + // If validity is trivial (AllValid), nothing to push down + // but mark as pushed down since children trivially include parent validity + if self.validity.all_valid(self.len)? { + // # Safety no validity to push down + return Ok(unsafe { self.clone().with_validity_pushed_down(true) }); + } + + // Get the validity mask - mask() expects true = set to null, so we invert the validity + let validity_mask = self.validity_mask()?.not(); + + // Apply mask only to nullable children - non-nullable children cannot have their + // dtype changed, so we leave them alone + let new_fields: Vec = self + .unmasked_fields() + .iter() + .map(|field| { + if field.dtype().is_nullable() { + mask(field.as_ref(), &validity_mask) + } else { + Ok(field.clone()) + } + }) + .collect::>()?; + + // Create new struct with same validity but updated children + // # Safety mask of struct validity applied to each child. + Ok(unsafe { + StructArray::try_new( + self.names().clone(), + new_fields, + self.len(), + self.validity.clone(), // Keep original validity + )? + .with_validity_pushed_down(true) + }) + } } diff --git a/vortex-array/src/arrays/struct_/mod.rs b/vortex-array/src/arrays/struct_/mod.rs index 85123d8053a..f7ff8c07b1e 100644 --- a/vortex-array/src/arrays/struct_/mod.rs +++ b/vortex-array/src/arrays/struct_/mod.rs @@ -4,6 +4,7 @@ mod array; pub use array::StructArray; pub use array::StructArrayParts; +pub use array::StructMetadata; mod compute; mod vtable; diff --git a/vortex-array/src/arrays/struct_/tests.rs b/vortex-array/src/arrays/struct_/tests.rs index 71cba8261f0..8e78f581d7b 100644 --- a/vortex-array/src/arrays/struct_/tests.rs +++ b/vortex-array/src/arrays/struct_/tests.rs @@ -1,6 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use std::ops::Not; + +use rstest::rstest; use vortex_buffer::buffer; use vortex_dtype::DType; use vortex_dtype::FieldName; @@ -150,3 +153,96 @@ fn test_uncompressed_size_in_bytes() -> VortexResult<()> { assert_eq!(uncompressed_size, Some(4000)); Ok(()) } + +// Field validity must include struct null positions for pushed_down invariant. +// STRUCT: nulls at positions 1, 3 +// FIELD: nulls at positions 1, 2, 3 (superset of struct nulls) +const STRUCT_VALIDITY: [bool; 4] = [true, false, true, false]; +const FIELD_VALIDITY: [bool; 4] = [true, false, false, false]; + +fn field_validity(nullable: bool) -> Validity { + if nullable { + Validity::Array(BoolArray::from_iter(FIELD_VALIDITY).into_array()) + } else { + Validity::NonNullable + } +} + +fn struct_validity(nullable: bool) -> Validity { + if nullable { + Validity::Array(BoolArray::from_iter(STRUCT_VALIDITY).into_array()) + } else { + Validity::NonNullable + } +} + +#[rstest] +#[case::both_non_nullable_struct_non_nullable(false, false, false)] +#[case::both_non_nullable_struct_nullable(false, false, true)] +#[case::field_a_nullable_struct_non_nullable(true, false, false)] +#[case::field_a_nullable_struct_nullable(true, false, true)] +#[case::field_b_nullable_struct_non_nullable(false, true, false)] +#[case::field_b_nullable_struct_nullable(false, true, true)] +#[case::both_nullable_struct_non_nullable(true, true, false)] +#[case::both_nullable_struct_nullable(true, true, true)] +fn test_masked_fields( + #[case] field_a_nullable: bool, + #[case] field_b_nullable: bool, + #[case] struct_nullable: bool, + #[values(false, true)] should_compact: bool, +) -> VortexResult<()> { + let field_a_val = field_validity(field_a_nullable); + let field_b_val = field_validity(field_b_nullable); + let struct_val = struct_validity(struct_nullable); + let field_a = PrimitiveArray::new(buffer![10i32, 20, 30, 40], field_a_val.clone()); + let field_b = PrimitiveArray::new(buffer![100i32, 200, 300, 400], field_b_val.clone()); + + let mut struct_array = StructArray::try_new( + FieldNames::from(["a", "b"]), + vec![field_a.into_array(), field_b.into_array()], + 4, + struct_val.clone(), + )?; + + let before_dtype = struct_array.dtype().clone(); + + if should_compact { + struct_array = struct_array.compact()?; + } + + assert_eq!(struct_array.dtype(), &before_dtype); + if struct_val.nullability().is_nullable() { + assert_eq!(struct_array.has_validity_pushed_down(), should_compact); + } + + assert_eq!(struct_array.validity()?, struct_val); + + let combined_a = + if field_a_val.nullability().is_nullable() || struct_val.nullability().is_nullable() { + field_a_val.mask(&struct_val.to_mask(struct_array.len()).not()) + } else { + Validity::NonNullable + }; + + let combined_b = + if field_b_val.nullability().is_nullable() || struct_val.nullability().is_nullable() { + field_b_val.mask(&struct_val.to_mask(struct_array.len()).not()) + } else { + Validity::NonNullable + }; + + // Test masked_fields + let masked = struct_array.masked_fields()?; + assert_eq!(masked.len(), 2); + assert_eq!(masked[0].validity()?, combined_a); + assert_eq!(masked[1].validity()?, combined_b); + + // Test field_by_name + let field_a_by_name = struct_array.field_by_name("a")?; + assert_eq!(field_a_by_name.validity()?, combined_a); + + let field_b_by_name = struct_array.field_by_name("b")?; + assert_eq!(field_b_by_name.validity()?, combined_b); + + Ok(()) +} diff --git a/vortex-array/src/arrays/struct_/vtable/mod.rs b/vortex-array/src/arrays/struct_/vtable/mod.rs index abc95c2d518..4b9928776ba 100644 --- a/vortex-array/src/arrays/struct_/vtable/mod.rs +++ b/vortex-array/src/arrays/struct_/vtable/mod.rs @@ -13,10 +13,13 @@ use vortex_error::vortex_ensure; use crate::ArrayRef; use crate::Canonical; -use crate::EmptyMetadata; +use crate::DeserializeMetadata; use crate::ExecutionCtx; use crate::IntoArray; +use crate::ProstMetadata; +use crate::SerializeMetadata; use crate::arrays::struct_::StructArray; +use crate::arrays::struct_::StructMetadata; use crate::arrays::struct_::vtable::rules::PARENT_RULES; use crate::buffer::BufferHandle; use crate::serde::ArrayChildren; @@ -40,7 +43,7 @@ vtable!(Struct); impl VTable for StructVTable { type Array = StructArray; - type Metadata = EmptyMetadata; + type Metadata = ProstMetadata; type ArrayVTable = Self; type OperationsVTable = Self; @@ -52,22 +55,25 @@ impl VTable for StructVTable { Self::ID } - fn metadata(_array: &StructArray) -> VortexResult { - Ok(EmptyMetadata) + fn metadata(array: &StructArray) -> VortexResult { + Ok(ProstMetadata(StructMetadata { + validity_pushed_down: array.validity_pushed_down, + })) } - fn serialize(_metadata: Self::Metadata) -> VortexResult>> { - Ok(Some(vec![])) + fn serialize(metadata: Self::Metadata) -> VortexResult>> { + Ok(Some(metadata.serialize())) } - fn deserialize(_buffer: &[u8]) -> VortexResult { - Ok(EmptyMetadata) + fn deserialize(buffer: &[u8]) -> VortexResult { + let metadata = ::deserialize(buffer)?; + Ok(ProstMetadata(metadata)) } fn build( dtype: &DType, len: usize, - _metadata: &Self::Metadata, + metadata: &Self::Metadata, _buffers: &[BufferHandle], children: &dyn ArrayChildren, ) -> VortexResult { @@ -99,7 +105,14 @@ impl VTable for StructVTable { }) .try_collect()?; - StructArray::try_new_with_dtype(children, struct_dtype.clone(), len, validity) + // # Safety + // + // The file was written with fields pushed down, if the file was corrupted this + // will result in invalid results, but no unsoundness. + Ok(unsafe { + StructArray::try_new_with_dtype(children, struct_dtype.clone(), len, validity)? + .with_validity_pushed_down(metadata.validity_pushed_down) + }) } fn with_children(array: &mut Self::Array, children: Vec) -> VortexResult<()> { diff --git a/vortex-array/src/compute/zip.rs b/vortex-array/src/compute/zip.rs index 49ad2b19a1d..e6f27e7539c 100644 --- a/vortex-array/src/compute/zip.rs +++ b/vortex-array/src/compute/zip.rs @@ -347,7 +347,7 @@ mod tests { let wrapped_result = zip(&wrapped1, &wrapped2, &mask).unwrap(); insta::assert_snapshot!(wrapped_result.display_tree(), @r" root: vortex.struct({nested=utf8?}, len=100) nbytes=1.66 kB (100.00%) - metadata: EmptyMetadata + metadata: StructMetadata { validity_pushed_down: false } nested: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%) [all_valid] metadata: EmptyMetadata buffer (align=1): 29 B (1.75%) diff --git a/vortex-array/src/expr/exprs/get_item.rs b/vortex-array/src/expr/exprs/get_item.rs index e72b8036a0c..c3ced54a70e 100644 --- a/vortex-array/src/expr/exprs/get_item.rs +++ b/vortex-array/src/expr/exprs/get_item.rs @@ -2,7 +2,6 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use std::fmt::Formatter; -use std::ops::Not; use prost::Message; use vortex_dtype::DType; @@ -16,7 +15,6 @@ use vortex_proto::expr as pb; use crate::arrays::StructArray; use crate::builtins::ExprBuiltins; -use crate::compute::mask; use crate::expr::Arity; use crate::expr::ChildName; use crate::expr::EmptyOptions; @@ -111,13 +109,7 @@ impl VTable for GetItem { .pop() .vortex_expect("missing input for GetItem expression") .execute::(args.ctx)?; - let field = input.unmasked_field_by_name(field_name).cloned()?; - - match input.dtype().nullability() { - Nullability::NonNullable => Ok(field), - Nullability::Nullable => mask(&field, &input.validity_mask()?.not()), - }? - .execute(args.ctx) + input.field_by_name(field_name)?.execute(args.ctx) } fn reduce( diff --git a/vortex-layout/src/layouts/compact.rs b/vortex-layout/src/layouts/compact.rs index 6f55174f391..9f6becbb26c 100644 --- a/vortex-layout/src/layouts/compact.rs +++ b/vortex-layout/src/layouts/compact.rs @@ -123,19 +123,27 @@ impl CompactCompressor { .into_array() } Canonical::Struct(struct_array) => { + // Push validity into children before compressing + let compacted = struct_array.compact()?; + // recurse - let fields = struct_array + let fields = compacted .unmasked_fields() .iter() .map(|field| self.compress(field)) .collect::>>()?; - StructArray::try_new( - struct_array.names().clone(), - fields, - struct_array.len(), - struct_array.validity().clone(), - )? + // # Safety + // compressing the fields must not affect the validity those fields. + unsafe { + StructArray::try_new( + compacted.names().clone(), + fields, + compacted.len(), + compacted.validity().clone(), + )? + .with_validity_pushed_down(compacted.has_validity_pushed_down()) + } .into_array() } Canonical::List(listview) => { diff --git a/vortex-layout/src/layouts/struct_/reader.rs b/vortex-layout/src/layouts/struct_/reader.rs index 1e256272e6b..8bfe8ea14ed 100644 --- a/vortex-layout/src/layouts/struct_/reader.rs +++ b/vortex-layout/src/layouts/struct_/reader.rs @@ -166,11 +166,7 @@ impl StructReader { let mut partitioned = partition( expr.clone(), self.dtype(), - annotate_scope_access( - self.dtype() - .as_struct_fields_opt() - .vortex_expect("We know it's a struct DType"), - ), + annotate_scope_access(self.dtype().as_struct_fields()), ) .vortex_expect("We should not fail to partition expression over struct fields"); diff --git a/vortex-layout/src/layouts/table.rs b/vortex-layout/src/layouts/table.rs index 13ad7f8f51d..4a7f758b828 100644 --- a/vortex-layout/src/layouts/table.rs +++ b/vortex-layout/src/layouts/table.rs @@ -246,17 +246,19 @@ impl LayoutStrategy for TableStrategy { let columns_vec_stream = stream.map(move |chunk| { let (sequence_id, chunk) = chunk?; let mut sequence_pointer = sequence_id.descend(); - let struct_chunk = chunk.to_struct(); + // Push struct validity into children before decomposing + let compacted = chunk.to_struct().compact()?; let mut columns: Vec<(SequenceId, ArrayRef)> = Vec::new(); + // Validity column is still written (DType unchanged, struct still nullable) if is_nullable { columns.push(( sequence_pointer.advance(), - chunk.validity_mask()?.into_array(), + compacted.validity_mask()?.into_array(), )); } columns.extend( - struct_chunk + compacted .unmasked_fields() .iter() .map(|field| (sequence_pointer.advance(), field.to_array())),