Row codec: resolve each column's validity once per encoder call.

Adds `ValidityKind` and `resolve_validity` to vortex-row/src/codec.rs and makes
`add_size_varbinview`, `encode_bool`, `encode_primitive_typed` and
`encode_varbinview` match on the resolved validity before their row loop. The
`AllValid` arm runs without per-row `mask.value(i)` checks; the `Mask` arm keeps
the previous null handling. `encode_primitive_typed` also hoists the per-row
stride, and `encode_varbinview` hoists `field.descending`.

NOTE(review): this patch was rebuilt from a whitespace-collapsed copy. Line
breaks and indentation assume rustfmt's 4-space style - confirm against the
target file before applying. The generic arguments `<ValidityKind>` and `<T>`
were restored from the surrounding code.

diff --git a/vortex-row/src/codec.rs b/vortex-row/src/codec.rs
index 8468301e5b3..37295536d1d 100644
--- a/vortex-row/src/codec.rs
+++ b/vortex-row/src/codec.rs
@@ -45,6 +45,7 @@ use vortex_array::dtype::NativePType;
 use vortex_array::dtype::PType;
 use vortex_array::dtype::half::f16;
 use vortex_array::match_each_native_ptype;
+use vortex_array::validity::Validity;
 use vortex_buffer::ByteBufferMut;
 use vortex_error::VortexResult;
 use vortex_error::vortex_bail;
@@ -77,6 +78,33 @@ const fn encoded_size_for_fixed(value_bytes: u32) -> u32 {
     1 + value_bytes
 }
 
+/// Pre-resolved per-row validity for the row encoders.
+///
+/// Encoders pattern-match on this once before their inner loop so the
+/// no-nulls fast path avoids per-row `mask.value(i)` branches entirely,
+/// and the nullable path holds the materialized mask exactly once.
+pub(crate) enum ValidityKind {
+    /// Column statically has no nulls (`Validity::NonNullable` or `AllValid`); no mask
+    /// allocation needed.
+    AllValid,
+    /// Column may have nulls; the materialized per-row mask is included.
+    Mask(vortex_mask::Mask),
+}
+
+/// Resolve a [`Validity`] into a [`ValidityKind`], materializing the mask only when
+/// the column may actually have nulls.
+#[inline]
+pub(crate) fn resolve_validity(
+    validity: Validity,
+    len: usize,
+    ctx: &mut ExecutionCtx,
+) -> VortexResult<ValidityKind> {
+    Ok(match validity {
+        Validity::NonNullable | Validity::AllValid => ValidityKind::AllValid,
+        other => ValidityKind::Mask(other.execute_mask(len, ctx)?),
+    })
+}
+
 /// Per-row width classification for a column.
 ///
 /// `Fixed(w)` means every row encodes to exactly `w` bytes (sentinel + value), regardless
@@ -245,15 +273,21 @@ fn add_size_varbinview(
     sizes: &mut [u32],
     ctx: &mut ExecutionCtx,
 ) -> VortexResult<()> {
-    let mask = arr.as_ref().validity()?.execute_mask(arr.len(), ctx)?;
     let views = arr.views();
-    for (i, view) in views.iter().enumerate() {
-        let valid = mask.value(i);
-        if !valid {
-            sizes[i] += 1; // sentinel only
-        } else {
-            let len = view.len() as usize;
-            sizes[i] += encoded_size_for_varlen(len);
+    match resolve_validity(arr.as_ref().validity()?, arr.len(), ctx)? {
+        ValidityKind::AllValid => {
+            for (i, view) in views.iter().enumerate() {
+                sizes[i] += encoded_size_for_varlen(view.len() as usize);
+            }
+        }
+        ValidityKind::Mask(mask) => {
+            for (i, view) in views.iter().enumerate() {
+                if mask.value(i) {
+                    sizes[i] += encoded_size_for_varlen(view.len() as usize);
+                } else {
+                    sizes[i] += 1; // sentinel only
+                }
+            }
         }
     }
     Ok(())
@@ -336,23 +370,35 @@ fn encode_bool(
     out: &mut [u8],
     ctx: &mut ExecutionCtx,
 ) -> VortexResult<()> {
-    let mask = arr.as_ref().validity()?.execute_mask(arr.len(), ctx)?;
     let bits = arr.clone().into_bit_buffer();
     let non_null = field.non_null_sentinel();
-    let null = field.null_sentinel();
     let xor = if field.descending { 0xFF } else { 0x00 };
-    for i in 0..bits.len() {
-        let pos = (row_offsets[i] + col_offset[i]) as usize;
-        if mask.value(i) {
-            out[pos] = non_null;
-            // false=0x01, true=0x02 so false < true; XOR for descending
-            let raw = if bits.value(i) { 0x02u8 } else { 0x01u8 };
-            out[pos + 1] = raw ^ xor;
-        } else {
-            out[pos] = null;
-            out[pos + 1] = 0;
+    match resolve_validity(arr.as_ref().validity()?, arr.len(), ctx)? {
+        ValidityKind::AllValid => {
+            for i in 0..bits.len() {
+                let pos = (row_offsets[i] + col_offset[i]) as usize;
+                out[pos] = non_null;
+                let raw = if bits.value(i) { 0x02u8 } else { 0x01u8 };
+                out[pos + 1] = raw ^ xor;
+                col_offset[i] += BOOL_ENCODED_SIZE;
+            }
+        }
+        ValidityKind::Mask(mask) => {
+            let null = field.null_sentinel();
+            for i in 0..bits.len() {
+                let pos = (row_offsets[i] + col_offset[i]) as usize;
+                if mask.value(i) {
+                    out[pos] = non_null;
+                    // false=0x01, true=0x02 so false < true; XOR for descending
+                    let raw = if bits.value(i) { 0x02u8 } else { 0x01u8 };
+                    out[pos + 1] = raw ^ xor;
+                } else {
+                    out[pos] = null;
+                    out[pos + 1] = 0;
+                }
+                col_offset[i] += BOOL_ENCODED_SIZE;
+            }
         }
-        col_offset[i] += BOOL_ENCODED_SIZE;
     }
     Ok(())
 }
@@ -379,24 +425,35 @@ fn encode_primitive_typed(
     out: &mut [u8],
     ctx: &mut ExecutionCtx,
 ) -> VortexResult<()> {
-    let mask = arr.as_ref().validity()?.execute_mask(arr.len(), ctx)?;
     let slice: &[T] = arr.as_slice();
     let non_null = field.non_null_sentinel();
-    let null = field.null_sentinel();
     let value_bytes = size_of::<T>();
-    for (i, &v) in slice.iter().enumerate() {
-        let pos = (row_offsets[i] + col_offset[i]) as usize;
-        if mask.value(i) {
-            out[pos] = non_null;
-            v.encode_to(&mut out[pos + 1..pos + 1 + value_bytes], field.descending);
-        } else {
-            out[pos] = null;
-            // Zero-fill the value bytes.
-            for b in &mut out[pos + 1..pos + 1 + value_bytes] {
-                *b = 0;
+    let stride = encoded_size_for_fixed(value_bytes as u32);
+    match resolve_validity(arr.as_ref().validity()?, arr.len(), ctx)? {
+        ValidityKind::AllValid => {
+            for (i, &v) in slice.iter().enumerate() {
+                let pos = (row_offsets[i] + col_offset[i]) as usize;
+                out[pos] = non_null;
+                v.encode_to(&mut out[pos + 1..pos + 1 + value_bytes], field.descending);
+                col_offset[i] += stride;
+            }
+        }
+        ValidityKind::Mask(mask) => {
+            let null = field.null_sentinel();
+            for (i, &v) in slice.iter().enumerate() {
+                let pos = (row_offsets[i] + col_offset[i]) as usize;
+                if mask.value(i) {
+                    out[pos] = non_null;
+                    v.encode_to(&mut out[pos + 1..pos + 1 + value_bytes], field.descending);
+                } else {
+                    out[pos] = null;
+                    for b in &mut out[pos + 1..pos + 1 + value_bytes] {
+                        *b = 0;
+                    }
+                }
+                col_offset[i] += stride;
             }
         }
-        col_offset[i] += encoded_size_for_fixed(value_bytes as u32);
     }
     Ok(())
 }
@@ -471,24 +528,38 @@ fn encode_varbinview(
     out: &mut [u8],
     ctx: &mut ExecutionCtx,
 ) -> VortexResult<()> {
-    let mask = arr.as_ref().validity()?.execute_mask(arr.len(), ctx)?;
     let non_null = field.non_null_sentinel();
-    let null = field.null_sentinel();
-
-    arr.with_iterator(|iter| {
-        for (i, maybe) in iter.enumerate() {
-            let pos = (row_offsets[i] + col_offset[i]) as usize;
-            if !mask.value(i) {
-                out[pos] = null;
-                col_offset[i] += 1;
-                continue;
-            }
-            let bytes: &[u8] = maybe.unwrap_or(&[]);
-            out[pos] = non_null;
-            let written = encode_varlen_value(bytes, &mut out[pos + 1..], field.descending);
-            col_offset[i] += 1 + written;
+    let descending = field.descending;
+    match resolve_validity(arr.as_ref().validity()?, arr.len(), ctx)? {
+        ValidityKind::AllValid => {
+            arr.with_iterator(|iter| {
+                for (i, maybe) in iter.enumerate() {
+                    let pos = (row_offsets[i] + col_offset[i]) as usize;
+                    let bytes: &[u8] = maybe.unwrap_or(&[]);
+                    out[pos] = non_null;
+                    let written = encode_varlen_value(bytes, &mut out[pos + 1..], descending);
+                    col_offset[i] += 1 + written;
+                }
+            });
         }
-    });
+        ValidityKind::Mask(mask) => {
+            let null = field.null_sentinel();
+            arr.with_iterator(|iter| {
+                for (i, maybe) in iter.enumerate() {
+                    let pos = (row_offsets[i] + col_offset[i]) as usize;
+                    if !mask.value(i) {
+                        out[pos] = null;
+                        col_offset[i] += 1;
+                        continue;
+                    }
+                    let bytes: &[u8] = maybe.unwrap_or(&[]);
+                    out[pos] = non_null;
+                    let written = encode_varlen_value(bytes, &mut out[pos + 1..], descending);
+                    col_offset[i] += 1 + written;
+                }
+            });
+        }
+    }
     Ok(())
 }
 