diff --git a/crates/core/common/src/context/exec.rs b/crates/core/common/src/context/exec.rs
index 83be9ed8c..15f5ac658 100644
--- a/crates/core/common/src/context/exec.rs
+++ b/crates/core/common/src/context/exec.rs
@@ -1,12 +1,10 @@
 use std::{
     collections::{BTreeMap, BTreeSet, HashMap},
-    pin::Pin,
     sync::{Arc, LazyLock},
-    task::{Context, Poll},
 };
 
 use amp_data_store::DataStore;
-use arrow::{array::ArrayRef, compute::concat_batches, datatypes::SchemaRef};
+use arrow::{array::ArrayRef, compute::concat_batches};
 use datafusion::{
     self,
     arrow::array::RecordBatch,
@@ -15,11 +13,8 @@ use datafusion::{
     datasource::{DefaultTableSource, TableType},
     error::DataFusionError,
     execution::{
-        RecordBatchStream, SendableRecordBatchStream, TaskContext,
-        cache::cache_manager::CacheManager,
-        config::SessionConfig,
-        disk_manager::DiskManager,
-        memory_pool::{MemoryPool, human_readable_size},
+        SendableRecordBatchStream, TaskContext, cache::cache_manager::CacheManager,
+        config::SessionConfig, disk_manager::DiskManager, memory_pool::MemoryPool,
         object_store::ObjectStoreRegistry,
     },
     logical_expr::{LogicalPlan, ScalarUDF, TableScan, expr::ScalarFunction},
@@ -32,7 +27,7 @@ use datafusion_tracing::{
     InstrumentationOptions, instrument_with_info_spans, pretty_format_compact_batch,
 };
 use datasets_common::network_id::NetworkId;
-use futures::{Stream, TryStreamExt, stream};
+use futures::{TryStreamExt, stream};
 use js_runtime::isolate_pool::IsolatePool;
 use regex::Regex;
 use tracing::field;
@@ -90,6 +85,11 @@ impl ExecContext {
         &self.isolate_pool
     }
 
+    /// Returns the tiered memory pool for this query context.
+    pub fn memory_pool(&self) -> &Arc<TieredMemoryPool> {
+        &self.tiered_memory_pool
+    }
+
     /// Attaches a detached logical plan to this query context in a single
     /// traversal that, for each plan node:
     ///
@@ -174,10 +174,7 @@ impl ExecContext {
             .await
             .map_err(ExecutePlanError::Execute)?;
 
-        Ok(PeakMemoryStream::wrap(
-            result,
-            self.tiered_memory_pool.clone(),
-        ))
+        Ok(result)
     }
 
     /// This will load the result set entirely in memory, so it should be used with caution.
@@ -952,48 +949,6 @@ fn print_physical_plan(plan: &dyn ExecutionPlan) -> String {
     sanitize_parquet_paths(&plan_str)
 }
 
-/// A stream wrapper that logs peak memory usage when dropped.
-///
-/// Because `execute_plan` returns a lazy `SendableRecordBatchStream`, memory is only
-/// allocated when the stream is consumed. This wrapper defers the peak memory log to
-/// when the stream is dropped (i.e., after consumption or cancellation).
-struct PeakMemoryStream {
-    inner: SendableRecordBatchStream,
-    pool: Arc<TieredMemoryPool>,
-}
-
-impl PeakMemoryStream {
-    fn wrap(
-        inner: SendableRecordBatchStream,
-        pool: Arc<TieredMemoryPool>,
-    ) -> SendableRecordBatchStream {
-        Box::pin(Self { inner, pool })
-    }
-}
-
-impl Drop for PeakMemoryStream {
-    fn drop(&mut self) {
-        tracing::debug!(
-            peak_memory_mb = human_readable_size(self.pool.peak_reserved()),
-            "Query memory usage"
-        );
-    }
-}
-
-impl Stream for PeakMemoryStream {
-    type Item = Result<RecordBatch, DataFusionError>;
-
-    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        self.inner.as_mut().poll_next(cx)
-    }
-}
-
-impl RecordBatchStream for PeakMemoryStream {
-    fn schema(&self) -> SchemaRef {
-        self.inner.schema()
-    }
-}
-
 /// Creates an instrumentation rule that captures metrics and provides previews of data during execution.
 pub fn create_instrumentation_rule() -> Arc<dyn PhysicalOptimizerRule + Send + Sync> {
     let options_builder = InstrumentationOptions::builder()
diff --git a/crates/services/server/src/flight.rs b/crates/services/server/src/flight.rs
index fa1eaa458..3c7640130 100644
--- a/crates/services/server/src/flight.rs
+++ b/crates/services/server/src/flight.rs
@@ -47,6 +47,7 @@ use common::{
     datasets_cache::{DatasetsCache, GetDatasetError},
     detached_logical_plan::{AttachPlanError, DetachedLogicalPlan},
     exec_env::ExecEnv,
+    memory_pool::TieredMemoryPool,
     plan_visitors::plan_has_block_num_udf,
     sql::{ResolveFunctionReferencesError, ResolveTableReferencesError, resolve_table_references},
     sql_str::SqlStr,
@@ -223,6 +224,7 @@ impl Service {
             );
         }
 
+        let memory_pool = ctx.memory_pool().clone();
         let record_batches = ctx
             .execute_plan(plan, true)
             .await
@@ -241,6 +243,7 @@ impl Service {
                 metrics,
                 query_start_time,
                 dataset_labels,
+                Some(memory_pool),
             ))
         } else {
             Ok(stream)
@@ -302,6 +305,7 @@ impl Service {
                 metrics,
                 query_start_time,
                 dataset_labels,
+                None,
             ))
         } else {
             Ok(stream)
@@ -678,6 +682,7 @@ fn track_query_metrics(
     metrics: &Arc<MetricsRegistry>,
     start_time: std::time::Instant,
     dataset_labels: Vec<HashReference>,
+    memory_pool: Option<Arc<TieredMemoryPool>>,
 ) -> QueryResultStream {
     let metrics = metrics.clone();
 
@@ -709,7 +714,7 @@ fn track_query_metrics(
                 let err_msg = e.to_string();
                 for dataset in &dataset_labels {
                     metrics.record_query_error(&err_msg, dataset);
-                    metrics.record_query_execution(duration, total_rows, total_bytes, dataset);
+                    metrics.record_query_execution(duration, total_rows, total_bytes, dataset, memory_pool.as_ref());
                 }
 
                 yield Err(e);
@@ -721,7 +726,7 @@ fn track_query_metrics(
         // Stream completed successfully, record metrics once per dataset
         let duration = start_time.elapsed().as_millis() as f64;
         for dataset in &dataset_labels {
-            metrics.record_query_execution(duration, total_rows, total_bytes, dataset);
+            metrics.record_query_execution(duration, total_rows, total_bytes, dataset, memory_pool.as_ref());
         }
     };
 
diff --git a/crates/services/server/src/metrics.rs b/crates/services/server/src/metrics.rs
index 437cadf8f..a7fb14e32 100644
--- a/crates/services/server/src/metrics.rs
+++ b/crates/services/server/src/metrics.rs
@@ -1,3 +1,7 @@
+use std::sync::Arc;
+
+use common::memory_pool::TieredMemoryPool;
+use datafusion::execution::memory_pool::human_readable_size;
 use datasets_common::hash_reference::HashReference;
 use monitoring::telemetry;
 
@@ -135,6 +139,7 @@ impl MetricsRegistry {
         rows_returned: u64,
         bytes_egress: u64,
         dataset: &HashReference,
+        memory_pool: Option<&Arc<TieredMemoryPool>>,
     ) {
         let labels = dataset_kvs(dataset);
         self.query_count.inc_with_kvs(&labels);
@@ -144,6 +149,14 @@ impl MetricsRegistry {
             .inc_by_with_kvs(rows_returned, &labels);
         self.query_bytes_egress
             .inc_by_with_kvs(bytes_egress, &labels);
+        if let Some(pool) = memory_pool {
+            let peak = pool.peak_reserved() as u64;
+            self.query_memory_peak_bytes.record(peak);
+            tracing::debug!(
+                peak_memory = human_readable_size(peak as usize),
+                "Query memory usage"
+            );
+        }
     }
 
     /// Record query error
@@ -156,11 +169,6 @@ impl MetricsRegistry {
         self.query_errors.inc_with_kvs(&labels);
     }
 
-    /// Record query memory usage
-    pub fn record_query_memory(&self, peak_bytes: u64) {
-        self.query_memory_peak_bytes.record(peak_bytes);
-    }
-
     /// Record streaming microbatch size and throughput
     pub fn record_streaming_batch(
         &self,