Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 10 additions & 55 deletions crates/core/common/src/context/exec.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
use std::{
collections::{BTreeMap, BTreeSet, HashMap},
pin::Pin,
sync::{Arc, LazyLock},
task::{Context, Poll},
};

use amp_data_store::DataStore;
use arrow::{array::ArrayRef, compute::concat_batches, datatypes::SchemaRef};
use arrow::{array::ArrayRef, compute::concat_batches};
use datafusion::{
self,
arrow::array::RecordBatch,
Expand All @@ -15,11 +13,8 @@ use datafusion::{
datasource::{DefaultTableSource, TableType},
error::DataFusionError,
execution::{
RecordBatchStream, SendableRecordBatchStream, TaskContext,
cache::cache_manager::CacheManager,
config::SessionConfig,
disk_manager::DiskManager,
memory_pool::{MemoryPool, human_readable_size},
SendableRecordBatchStream, TaskContext, cache::cache_manager::CacheManager,
config::SessionConfig, disk_manager::DiskManager, memory_pool::MemoryPool,
object_store::ObjectStoreRegistry,
},
logical_expr::{LogicalPlan, ScalarUDF, TableScan, expr::ScalarFunction},
Expand All @@ -32,7 +27,7 @@ use datafusion_tracing::{
InstrumentationOptions, instrument_with_info_spans, pretty_format_compact_batch,
};
use datasets_common::network_id::NetworkId;
use futures::{Stream, TryStreamExt, stream};
use futures::{TryStreamExt, stream};
use js_runtime::isolate_pool::IsolatePool;
use regex::Regex;
use tracing::field;
Expand Down Expand Up @@ -90,6 +85,11 @@ impl ExecContext {
&self.isolate_pool
}

/// Returns the tiered memory pool for this query context.
pub fn memory_pool(&self) -> &Arc<TieredMemoryPool> {
&self.tiered_memory_pool
}

/// Attaches a detached logical plan to this query context in a single
/// traversal that, for each plan node:
///
Expand Down Expand Up @@ -174,10 +174,7 @@ impl ExecContext {
.await
.map_err(ExecutePlanError::Execute)?;

Ok(PeakMemoryStream::wrap(
result,
self.tiered_memory_pool.clone(),
))
Ok(result)
}

/// This will load the result set entirely in memory, so it should be used with caution.
Expand Down Expand Up @@ -952,48 +949,6 @@ fn print_physical_plan(plan: &dyn ExecutionPlan) -> String {
sanitize_parquet_paths(&plan_str)
}

/// A stream wrapper that logs peak memory usage when dropped.
///
/// Because `execute_plan` returns a lazy `SendableRecordBatchStream`, memory is only
/// allocated when the stream is consumed. This wrapper defers the peak memory log to
/// when the stream is dropped (i.e., after consumption or cancellation).
struct PeakMemoryStream {
inner: SendableRecordBatchStream,
pool: Arc<TieredMemoryPool>,
}

impl PeakMemoryStream {
fn wrap(
inner: SendableRecordBatchStream,
pool: Arc<TieredMemoryPool>,
) -> SendableRecordBatchStream {
Box::pin(Self { inner, pool })
}
}

impl Drop for PeakMemoryStream {
fn drop(&mut self) {
tracing::debug!(
peak_memory_mb = human_readable_size(self.pool.peak_reserved()),
"Query memory usage"
);
}
}

impl Stream for PeakMemoryStream {
type Item = Result<RecordBatch, DataFusionError>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.inner.as_mut().poll_next(cx)
}
}

impl RecordBatchStream for PeakMemoryStream {
fn schema(&self) -> SchemaRef {
self.inner.schema()
}
}

/// Creates an instrumentation rule that captures metrics and provides previews of data during execution.
pub fn create_instrumentation_rule() -> Arc<dyn PhysicalOptimizerRule + Send + Sync> {
let options_builder = InstrumentationOptions::builder()
Expand Down
9 changes: 7 additions & 2 deletions crates/services/server/src/flight.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ use common::{
datasets_cache::{DatasetsCache, GetDatasetError},
detached_logical_plan::{AttachPlanError, DetachedLogicalPlan},
exec_env::ExecEnv,
memory_pool::TieredMemoryPool,
plan_visitors::plan_has_block_num_udf,
sql::{ResolveFunctionReferencesError, ResolveTableReferencesError, resolve_table_references},
sql_str::SqlStr,
Expand Down Expand Up @@ -223,6 +224,7 @@ impl Service {
);
}

let memory_pool = ctx.memory_pool().clone();
let record_batches = ctx
.execute_plan(plan, true)
.await
Expand All @@ -241,6 +243,7 @@ impl Service {
metrics,
query_start_time,
dataset_labels,
Some(memory_pool),
))
} else {
Ok(stream)
Expand Down Expand Up @@ -302,6 +305,7 @@ impl Service {
metrics,
query_start_time,
dataset_labels,
None,
))
} else {
Ok(stream)
Expand Down Expand Up @@ -678,6 +682,7 @@ fn track_query_metrics(
metrics: &Arc<MetricsRegistry>,
start_time: std::time::Instant,
dataset_labels: Vec<HashReference>,
memory_pool: Option<Arc<TieredMemoryPool>>,
) -> QueryResultStream {
let metrics = metrics.clone();

Expand Down Expand Up @@ -709,7 +714,7 @@ fn track_query_metrics(
let err_msg = e.to_string();
for dataset in &dataset_labels {
metrics.record_query_error(&err_msg, dataset);
metrics.record_query_execution(duration, total_rows, total_bytes, dataset);
metrics.record_query_execution(duration, total_rows, total_bytes, dataset, memory_pool.as_ref());
}

yield Err(e);
Expand All @@ -721,7 +726,7 @@ fn track_query_metrics(
// Stream completed successfully, record metrics once per dataset
let duration = start_time.elapsed().as_millis() as f64;
for dataset in &dataset_labels {
metrics.record_query_execution(duration, total_rows, total_bytes, dataset);
metrics.record_query_execution(duration, total_rows, total_bytes, dataset, memory_pool.as_ref());
}
};

Expand Down
18 changes: 13 additions & 5 deletions crates/services/server/src/metrics.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
use std::sync::Arc;

use common::memory_pool::TieredMemoryPool;
use datafusion::execution::memory_pool::human_readable_size;
use datasets_common::hash_reference::HashReference;
use monitoring::telemetry;

Expand Down Expand Up @@ -135,6 +139,7 @@ impl MetricsRegistry {
rows_returned: u64,
bytes_egress: u64,
dataset: &HashReference,
memory_pool: Option<&Arc<TieredMemoryPool>>,
) {
let labels = dataset_kvs(dataset);
self.query_count.inc_with_kvs(&labels);
Expand All @@ -144,6 +149,14 @@ impl MetricsRegistry {
.inc_by_with_kvs(rows_returned, &labels);
self.query_bytes_egress
.inc_by_with_kvs(bytes_egress, &labels);
if let Some(pool) = memory_pool {
let peak = pool.peak_reserved() as u64;
self.query_memory_peak_bytes.record(peak);
tracing::debug!(
peak_memory = human_readable_size(peak as usize),
"Query memory usage"
);
}
}

/// Record query error
Expand All @@ -156,11 +169,6 @@ impl MetricsRegistry {
self.query_errors.inc_with_kvs(&labels);
}

/// Record query memory usage
pub fn record_query_memory(&self, peak_bytes: u64) {
self.query_memory_peak_bytes.record(peak_bytes);
}

/// Record streaming microbatch size and throughput
pub fn record_streaming_batch(
&self,
Expand Down
Loading