term_guard/analyzers/basic/
mean.rs1use async_trait::async_trait;
4use datafusion::prelude::*;
5use serde::{Deserialize, Serialize};
6use tracing::instrument;
7
8use crate::analyzers::{Analyzer, AnalyzerError, AnalyzerResult, AnalyzerState, MetricValue};
9
10use crate::core::current_validation_context;
/// Analyzer that computes the arithmetic mean of a single column.
///
/// The analyzer itself only stores the column name; the aggregation is
/// performed by its `Analyzer` implementation.
#[derive(Debug, Clone)]
pub struct MeanAnalyzer {
    /// Name of the column whose values are averaged.
    column: String,
}
41
42impl MeanAnalyzer {
43 pub fn new(column: impl Into<String>) -> Self {
45 Self {
46 column: column.into(),
47 }
48 }
49
50 pub fn column(&self) -> &str {
52 &self.column
53 }
54}
55
/// Intermediate state for the mean computation.
///
/// Holds a running sum and a running count rather than the mean itself, so
/// that several states can be combined by adding their fields together.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeanState {
    /// Sum of the values that contributed to this state.
    pub sum: f64,
    /// Number of values that contributed to `sum`.
    pub count: u64,
}
64
65impl MeanState {
66 pub fn mean(&self) -> Option<f64> {
68 if self.count == 0 {
69 None
70 } else {
71 Some(self.sum / self.count as f64)
72 }
73 }
74}
75
76impl AnalyzerState for MeanState {
77 fn merge(states: Vec<Self>) -> AnalyzerResult<Self> {
78 let sum = states.iter().map(|s| s.sum).sum();
79 let count = states.iter().map(|s| s.count).sum();
80
81 Ok(MeanState { sum, count })
82 }
83
84 fn is_empty(&self) -> bool {
85 self.count == 0
86 }
87}
88
89#[async_trait]
90impl Analyzer for MeanAnalyzer {
91 type State = MeanState;
92 type Metric = MetricValue;
93
94 #[instrument(skip(ctx), fields(analyzer = "mean", column = %self.column))]
95 async fn compute_state_from_data(&self, ctx: &SessionContext) -> AnalyzerResult<Self::State> {
96 let validation_ctx = current_validation_context();
100
101 let table_name = validation_ctx.table_name();
102
103 let sql = format!(
104 "SELECT SUM({0}) as sum, COUNT({0}) as count FROM {table_name}",
105 self.column
106 );
107
108 let df = ctx.sql(&sql).await?;
110 let batches = df.collect().await?;
111
112 let (sum, count) = if let Some(batch) = batches.first() {
114 if batch.num_rows() > 0 {
115 let sum = if batch.column(0).is_null(0) {
117 0.0
118 } else {
119 let sum_array = batch
120 .column(0)
121 .as_any()
122 .downcast_ref::<arrow::array::Float64Array>()
123 .ok_or_else(|| {
124 AnalyzerError::invalid_data("Expected Float64 array for sum")
125 })?;
126 sum_array.value(0)
127 };
128
129 let count_array = batch
130 .column(1)
131 .as_any()
132 .downcast_ref::<arrow::array::Int64Array>()
133 .ok_or_else(|| AnalyzerError::invalid_data("Expected Int64 array for count"))?;
134 let count = count_array.value(0) as u64;
135
136 (sum, count)
137 } else {
138 (0.0, 0)
139 }
140 } else {
141 (0.0, 0)
142 };
143
144 Ok(MeanState { sum, count })
145 }
146
147 fn compute_metric_from_state(&self, state: &Self::State) -> AnalyzerResult<Self::Metric> {
148 match state.mean() {
149 Some(mean) => Ok(MetricValue::Double(mean)),
150 None => Err(AnalyzerError::NoData),
151 }
152 }
153
154 fn name(&self) -> &str {
155 "mean"
156 }
157
158 fn description(&self) -> &str {
159 "Computes the average value of a numeric column"
160 }
161
162 fn metric_key(&self) -> String {
163 format!("{}.{}", self.name(), self.column)
164 }
165
166 fn columns(&self) -> Vec<&str> {
167 vec![&self.column]
168 }
169}