Skip to main content

scouter_dataframe/parquet/
utils.rs

1use crate::error::DataFrameError;
2use arrow::array::AsArray;
3use arrow::array::{BooleanBuilder, StringArray};
4use arrow::datatypes::DataType;
5use arrow::datatypes::UInt32Type;
6use arrow_array::types::Float64Type;
7use arrow_array::types::TimestampNanosecondType;
8use arrow_array::Array;
9use arrow_array::RecordBatch;
10use arrow_array::StringViewArray;
11use chrono::{DateTime, TimeZone, Utc};
12use datafusion::error::{DataFusionError, Result};
13use datafusion::logical_expr::ScalarFunctionArgs;
14use datafusion::logical_expr::{
15    ColumnarValue, Expr, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature, Volatility,
16};
17use datafusion::prelude::DataFrame;
18use datafusion::scalar::ScalarValue;
19use deltalake::logstore::{
20    default_logstore, logstore_factories, LogStore, LogStoreFactory, ObjectStoreRef, StorageConfig,
21};
22use deltalake::DeltaResult;
23use scouter_types::{BinnedMetric, BinnedMetricStats, BinnedMetrics};
24use std::sync::Arc;
25use tracing::{debug, error, instrument};
26use url::Url;
/// Shared column-extraction helpers for parquet record batches.
///
/// Now that we have at least 2 metric types that calculate avg, lower_bound, and
/// upper_bound as part of their stats, it makes sense to share the common
/// extraction logic here. (NOTE(review): the original comment says "generic
/// trait", but this is a plain struct with associated functions.)
pub struct ParquetHelper {}
30
31impl ParquetHelper {
32    #[instrument(skip_all)]
33    pub fn extract_feature_array(batch: &RecordBatch) -> Result<&StringViewArray, DataFrameError> {
34        let feature_array = batch
35            .column_by_name("feature")
36            .ok_or_else(|| {
37                error!("Missing 'feature' field in RecordBatch");
38                DataFrameError::MissingFieldError("feature")
39            })?
40            .as_string_view_opt()
41            .ok_or_else(|| {
42                error!("Failed to downcast 'feature' field to StringViewArray");
43                DataFrameError::DowncastError("StringViewArray")
44            })?;
45        Ok(feature_array)
46    }
47
48    #[instrument(skip_all)]
49    pub fn extract_created_at(batch: &RecordBatch) -> Result<Vec<DateTime<Utc>>, DataFrameError> {
50        let created_at_list = batch
51            .column_by_name("created_at")
52            .ok_or_else(|| {
53                error!("Missing 'created_at' field in RecordBatch");
54                DataFrameError::MissingFieldError("created_at")
55            })?
56            .as_list_opt::<i32>()
57            .ok_or_else(|| {
58                error!("Failed to downcast 'created_at' field to ListArray");
59                DataFrameError::DowncastError("ListArray")
60            })?;
61
62        let created_at_array = created_at_list.value(0);
63        Ok(created_at_array
64            .as_primitive::<TimestampNanosecondType>()
65            .iter()
66            .filter_map(|ts| ts.map(|t| Utc.timestamp_nanos(t)))
67            .collect())
68    }
69}
/// Converts DataFusion query results (record batches) into `BinnedMetrics`.
pub struct BinnedMetricsExtractor {}
71
72impl BinnedMetricsExtractor {
73    #[instrument(skip_all)]
74    fn extract_stats(batch: &RecordBatch) -> Result<Vec<BinnedMetricStats>, DataFrameError> {
75        let stats_list = batch
76            .column_by_name("stats")
77            .ok_or_else(|| {
78                error!("Missing 'stats' field in RecordBatch");
79                DataFrameError::MissingFieldError("stats")
80            })?
81            .as_list_opt::<i32>()
82            .ok_or_else(|| {
83                error!("Failed to downcast 'stats' field to ListArray");
84                DataFrameError::DowncastError("ListArray")
85            })?
86            .value(0);
87
88        let stats_structs = stats_list.as_struct_opt().ok_or_else(|| {
89            error!("Failed to downcast 'stats' field to StructArray");
90            DataFrameError::DowncastError("StructArray")
91        })?;
92
93        let avg_array = stats_structs
94            .column_by_name("avg")
95            .ok_or_else(|| DataFrameError::MissingFieldError("avg"))
96            .inspect_err(|e| error!("Failed to get 'avg' field from stats: {:?}", e))?
97            .as_primitive_opt::<Float64Type>()
98            .ok_or_else(|| DataFrameError::DowncastError("Float64Array"))?;
99
100        let lower_bound_array = stats_structs
101            .column_by_name("lower_bound")
102            .ok_or_else(|| DataFrameError::MissingFieldError("lower_bound"))
103            .inspect_err(|e| error!("Failed to get 'lower_bound' field from stats: {:?}", e))?
104            .as_primitive_opt::<Float64Type>()
105            .ok_or_else(|| DataFrameError::DowncastError("Float64Array"))?;
106
107        let upper_bound_array = stats_structs
108            .column_by_name("upper_bound")
109            .ok_or_else(|| DataFrameError::MissingFieldError("upper_bound"))
110            .inspect_err(|e| error!("Failed to get 'upper_bound' field from stats: {:?}", e))?
111            .as_primitive_opt::<Float64Type>()
112            .ok_or_else(|| DataFrameError::DowncastError("Float64Array"))?;
113
114        Ok((0..stats_structs.len())
115            .map(|i| BinnedMetricStats {
116                avg: avg_array.value(i),
117                lower_bound: lower_bound_array.value(i),
118                upper_bound: upper_bound_array.value(i),
119            })
120            .collect())
121    }
122
123    #[instrument(skip_all)]
124    fn process_metric_record_batch(batch: &RecordBatch) -> Result<BinnedMetric, DataFrameError> {
125        debug!("Processing metric record batch");
126
127        let metric_column = batch.column_by_name("metric").ok_or_else(|| {
128            error!("Missing 'metric' field in RecordBatch");
129            DataFrameError::MissingFieldError("metric")
130        })?;
131
132        // Handle both Dictionary and plain string types
133        let metric_name = if let Some(dict_array) = metric_column.as_dictionary_opt::<UInt32Type>()
134        {
135            // Dictionary-encoded string (e.g., from GenAI task_id)
136            let values = dict_array.values();
137            let string_values = values.as_string_opt::<i32>().ok_or_else(|| {
138                error!("Failed to downcast dictionary values to StringArray");
139                DataFrameError::DowncastError("StringArray")
140            })?;
141            let key = dict_array.key(0).ok_or_else(|| {
142                error!("Failed to get key from dictionary array");
143                DataFrameError::MissingFieldError("dictionary key")
144            })?;
145            string_values.value(key).to_string()
146        } else if let Some(string_view_array) = metric_column.as_string_view_opt() {
147            // StringView type
148            string_view_array.value(0).to_string()
149        } else if let Some(string_array) = metric_column.as_string_opt::<i32>() {
150            // Plain string type
151            string_array.value(0).to_string()
152        } else {
153            error!("Failed to downcast 'metric' field to any supported string type");
154            return Err(DataFrameError::DowncastError("String type"));
155        };
156
157        let created_at_list = ParquetHelper::extract_created_at(batch)?;
158        let stats = Self::extract_stats(batch)?;
159
160        Ok(BinnedMetric {
161            metric: metric_name,
162            created_at: created_at_list,
163            stats,
164        })
165    }
166
167    /// Convert a DataFrame to BinnedMetrics.
168    ///
169    /// # Arguments
170    /// * `df` - The DataFrame to convert
171    ///
172    /// # Returns
173    /// * `BinnedMetrics` - The converted BinnedMetrics
174    #[instrument(skip_all)]
175    pub async fn dataframe_to_binned_metrics(
176        df: DataFrame,
177    ) -> Result<BinnedMetrics, DataFrameError> {
178        debug!("Converting DataFrame to binned metrics");
179
180        let batches = df.collect().await?;
181
182        let metrics: Vec<BinnedMetric> = batches
183            .iter()
184            .map(Self::process_metric_record_batch)
185            .collect::<Result<Vec<_>, _>>()
186            .inspect_err(|e| {
187                error!("Failed to process metric record batch: {:?}", e);
188            })?;
189
190        Ok(BinnedMetrics::from_vec(metrics))
191    }
192}
193
194pub(crate) struct PassthroughLogStoreFactory;
195
196impl LogStoreFactory for PassthroughLogStoreFactory {
197    fn with_options(
198        &self,
199        prefixed_store: ObjectStoreRef,
200        root_store: ObjectStoreRef,
201        location: &Url,
202        options: &StorageConfig,
203    ) -> DeltaResult<Arc<dyn LogStore>> {
204        // For az:// URLs, object_store's ObjectStoreScheme::parse uses strip_bucket()
205        // which assumes az://account/container/blob-path format. Scouter uses
206        // az://container/blob-path (container in host, subpath in URL path).
207        // strip_bucket() finds no second path segment → returns "" → delta-rs
208        // applies no PrefixStore for Azure. Manually apply the correct prefix here.
209        //
210        // For gs://, s3://, s3a://, abfs://, abfss:// — delta-rs correctly derives
211        // the subpath prefix from url.path() and applies PrefixStore via decorate_prefix.
212        // Do not re-wrap those: use the already-prefixed `prefixed_store` as-is.
213        let store = if location.scheme() == "az" {
214            let subpath = location.path().trim_start_matches('/');
215            if subpath.is_empty() {
216                prefixed_store
217            } else {
218                let prefix = object_store::path::Path::from(subpath);
219                Arc::new(object_store::prefix::PrefixStore::new(
220                    root_store.clone(),
221                    prefix,
222                )) as ObjectStoreRef
223            }
224        } else {
225            prefixed_store
226        };
227        Ok(default_logstore(store, root_store, location, options))
228    }
229}
230
231pub(crate) fn register_cloud_logstore_factories() {
232    let factories = logstore_factories();
233    let factory = Arc::new(PassthroughLogStoreFactory) as Arc<dyn LogStoreFactory>;
234    for scheme in ["gs", "s3", "s3a", "az", "abfs", "abfss"] {
235        let key = Url::parse(&format!("{}://", scheme)).expect("scheme is a valid URL prefix");
236        if !factories.contains_key(&key) {
237            factories.insert(key, factory.clone());
238        }
239    }
240}
241
/// DataFusion 52 scalar UDF for attribute-pattern matching on `search_blob`.
///
/// `match_attr(search_blob, '%key=value%')` → `Boolean`
///
/// The pattern argument is a pre-normalized LIKE string produced by `normalize_attr_filter`:
/// it wraps the inner substring in `%...%`, so `match_attr` strips the outer `%` characters
/// and performs a `.contains(inner)` check — semantically identical to `LIKE '%inner%'`
/// but with zero regex compilation overhead and native `Utf8View` support.
///
/// **Accepted types for `search_blob` (first arg):**
/// - `Utf8View` — the canonical storage type written by `TraceSpanBatchBuilder`
/// - `Utf8` — the normalized form returned by DataFusion after some plan transformations
///
/// **Pattern (second arg):**
/// - Must always be a `Utf8` scalar literal (i.e. `lit("...")`). Array patterns are rejected.
///
/// Register once on the `SessionContext`:
/// ```rust
/// ctx.register_udf(create_attr_match_udf());
/// ```
///
/// Use in the DataFrame API via `match_attr_expr`:
/// ```rust
/// df = df.filter(match_attr_expr(col("search_blob"), lit("%svc=auth%")))?;
/// ```
/// `DynHash` (required by `ScalarUDFImpl`) is satisfied by `Hash + PartialEq + Eq`.
/// Identity is name-based — two `AttrMatchUdf` instances with the same name are equal.
#[derive(Debug)]
struct AttrMatchUdf {
    // Accepted argument-type combinations; built once in `new()`.
    signature: Signature,
}
273
274impl PartialEq for AttrMatchUdf {
275    fn eq(&self, _other: &Self) -> bool {
276        true // singleton UDF; all instances are equivalent
277    }
278}
279
280impl Eq for AttrMatchUdf {}
281
282impl std::hash::Hash for AttrMatchUdf {
283    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
284        self.name().hash(state);
285    }
286}
287
288impl AttrMatchUdf {
289    fn new() -> Self {
290        Self {
291            // Accept both Utf8View (Delta Lake read path) and Utf8 (post-cast path),
292            // plus a Utf8 literal pattern. one_of covers both schema variants cleanly.
293            signature: Signature::one_of(
294                vec![
295                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8]),
296                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
297                ],
298                Volatility::Immutable,
299            ),
300        }
301    }
302}
303
304impl ScalarUDFImpl for AttrMatchUdf {
305    fn as_any(&self) -> &dyn std::any::Any {
306        self
307    }
308
309    fn name(&self) -> &str {
310        "match_attr"
311    }
312
313    fn signature(&self) -> &Signature {
314        &self.signature
315    }
316
317    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
318        Ok(DataType::Boolean)
319    }
320
321    /// Vectorized execution: match each `search_blob` value against a constant pattern.
322    ///
323    /// Pattern is always a scalar literal — DataFusion folds constant expressions before
324    /// dispatch, so the substring lookup is compiled exactly once per batch.
325    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
326        let args_slice = args.args;
327        let batch_size = args.number_rows;
328
329        // ── Pattern (second arg) — scalar literal only ───────────────────────
330        let pattern_str = match &args_slice[1] {
331            ColumnarValue::Scalar(ScalarValue::Utf8(Some(p)))
332            | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(p))) => p.clone(),
333            _ => {
334                return Err(DataFusionError::Execution(
335                    "match_attr: second arg must be a non-null Utf8 scalar literal".into(),
336                ))
337            }
338        };
339
340        // Strip the '%...%' LIKE wrappers produced by normalize_attr_filter.
341        // LIKE '%inner%'  ≡  .contains("inner")  for substring matching.
342        let inner = pattern_str.trim_matches('%');
343
344        // ── Search blob (first arg) ───────────────────────────────────────────
345        match &args_slice[0] {
346            // Scalar fold path — constant propagation without allocating an array.
347            ColumnarValue::Scalar(s) => {
348                let matched = match s {
349                    ScalarValue::Utf8(Some(v))
350                    | ScalarValue::LargeUtf8(Some(v))
351                    | ScalarValue::Utf8View(Some(v)) => v.contains(inner),
352                    _ => false,
353                };
354                Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(matched))))
355            }
356
357            // Array path — vectorized substring scan.
358            ColumnarValue::Array(arr) => {
359                let mut builder = BooleanBuilder::with_capacity(batch_size);
360
361                if arr.data_type() == &DataType::Utf8View {
362                    // Zero-copy: StringViewArray::value() returns &str into inline or heap buffer.
363                    let view_arr = arr
364                        .as_any()
365                        .downcast_ref::<arrow_array::StringViewArray>()
366                        .ok_or_else(|| {
367                            DataFusionError::Execution(
368                                "match_attr: expected StringViewArray for search_blob".into(),
369                            )
370                        })?;
371                    for i in 0..arr.len() {
372                        if view_arr.is_null(i) {
373                            builder.append_null();
374                        } else {
375                            builder.append_value(view_arr.value(i).contains(inner));
376                        }
377                    }
378                } else {
379                    // Utf8 / LargeUtf8 — normalize via Arrow cast (zero-copy reinterpret).
380                    let cast_arr =
381                        arrow::compute::cast(arr.as_ref(), &DataType::Utf8).map_err(|e| {
382                            DataFusionError::Execution(format!(
383                                "match_attr: cast to Utf8 failed: {e}"
384                            ))
385                        })?;
386                    let str_arr =
387                        cast_arr
388                            .as_any()
389                            .downcast_ref::<StringArray>()
390                            .ok_or_else(|| {
391                                DataFusionError::Execution(
392                                    "match_attr: downcast to StringArray failed".into(),
393                                )
394                            })?;
395                    for i in 0..arr.len() {
396                        if str_arr.is_null(i) {
397                            builder.append_null();
398                        } else {
399                            builder.append_value(str_arr.value(i).contains(inner));
400                        }
401                    }
402                }
403
404                Ok(ColumnarValue::Array(Arc::new(builder.finish())))
405            }
406        }
407    }
408}
409
410/// Create the `match_attr` [`ScalarUDF`] using the DataFusion 52 `ScalarUDFImpl` API.
411///
412/// Register with a [`SessionContext`] once during initialization:
413/// ```rust
414/// ctx.register_udf(create_attr_match_udf());
415/// ```
416pub fn create_attr_match_udf() -> ScalarUDF {
417    ScalarUDF::from(AttrMatchUdf::new())
418}
419
420/// Build a DataFusion [`Expr`] that calls `match_attr(search_blob, pattern)`.
421///
422/// Drop-in replacement for `col(blob).like(lit(pattern))` in any DataFrame
423/// `.filter()`, `when()`, or aggregate context.  Handles `Utf8View` natively
424/// without an intermediate cast allocation.
425///
426/// # Example
427/// ```rust
428/// // Attribute filter in a query pipeline:
429/// let cond = match_attr_expr(col("search_blob"), lit("%key=value%"));
430/// df = df.filter(cond)?;
431///
432/// // Aggregate HAVING equivalent — fold into MAX for single-pass scan:
433/// let attr_agg = max(datafusion::logical_expr::cast(
434///     match_attr_expr(col("search_blob"), lit("%key=value%")),
435///     arrow::datatypes::DataType::Int64,
436/// )).alias("attr_match");
437/// ```
438pub fn match_attr_expr(search_blob: Expr, pattern: Expr) -> Expr {
439    create_attr_match_udf().call(vec![search_blob, pattern])
440}