term_guard/analyzers/basic/
sum.rs1use async_trait::async_trait;
4use datafusion::prelude::*;
5use serde::{Deserialize, Serialize};
6use tracing::instrument;
7
8use crate::analyzers::{Analyzer, AnalyzerError, AnalyzerResult, AnalyzerState, MetricValue};
9
10use crate::core::current_validation_context;
/// Analyzer that sums a single column of the table under validation.
///
/// The column name is stored as given and is substituted into a
/// `SELECT SUM(..), COUNT(..)` query when the analyzer runs.
#[derive(Clone, Debug)]
pub struct SumAnalyzer {
    /// Name of the column whose values are summed.
    column: String,
}
38
39impl SumAnalyzer {
40 pub fn new(column: impl Into<String>) -> Self {
42 Self {
43 column: column.into(),
44 }
45 }
46
47 pub fn column(&self) -> &str {
49 &self.column
50 }
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct SumState {
56 pub sum: f64,
58 pub has_values: bool,
60}
61
62impl AnalyzerState for SumState {
63 fn merge(states: Vec<Self>) -> AnalyzerResult<Self> {
64 let sum = states.iter().map(|s| s.sum).sum();
65 let has_values = states.iter().any(|s| s.has_values);
66
67 Ok(SumState { sum, has_values })
68 }
69
70 fn is_empty(&self) -> bool {
71 !self.has_values
72 }
73}
74
75#[async_trait]
76impl Analyzer for SumAnalyzer {
77 type State = SumState;
78 type Metric = MetricValue;
79
80 #[instrument(skip(ctx), fields(analyzer = "sum", column = %self.column))]
81 async fn compute_state_from_data(&self, ctx: &SessionContext) -> AnalyzerResult<Self::State> {
82 let validation_ctx = current_validation_context();
86
87 let table_name = validation_ctx.table_name();
88
89 let sql = format!(
90 "SELECT SUM({0}) as sum, COUNT({0}) as count FROM {table_name}",
91 self.column
92 );
93
94 let df = ctx.sql(&sql).await?;
96 let batches = df.collect().await?;
97
98 let (sum, has_values) = if let Some(batch) = batches.first() {
100 if batch.num_rows() > 0 {
101 let sum = if batch.column(0).is_null(0) {
103 0.0
104 } else {
105 if let Some(arr) = batch
107 .column(0)
108 .as_any()
109 .downcast_ref::<arrow::array::Float64Array>()
110 {
111 arr.value(0)
112 } else if let Some(arr) = batch
113 .column(0)
114 .as_any()
115 .downcast_ref::<arrow::array::Int64Array>()
116 {
117 arr.value(0) as f64
118 } else {
119 return Err(AnalyzerError::invalid_data(format!(
120 "Expected numeric array for sum, got {:?}",
121 batch.column(0).data_type()
122 )));
123 }
124 };
125
126 let count_array = batch
128 .column(1)
129 .as_any()
130 .downcast_ref::<arrow::array::Int64Array>()
131 .ok_or_else(|| AnalyzerError::invalid_data("Expected Int64 array for count"))?;
132 let has_values = count_array.value(0) > 0;
133
134 (sum, has_values)
135 } else {
136 (0.0, false)
137 }
138 } else {
139 (0.0, false)
140 };
141
142 Ok(SumState { sum, has_values })
143 }
144
145 fn compute_metric_from_state(&self, state: &Self::State) -> AnalyzerResult<Self::Metric> {
146 if state.has_values {
147 Ok(MetricValue::Double(state.sum))
148 } else {
149 Err(AnalyzerError::NoData)
150 }
151 }
152
153 fn name(&self) -> &str {
154 "sum"
155 }
156
157 fn description(&self) -> &str {
158 "Computes the sum of values in a numeric column"
159 }
160
161 fn metric_key(&self) -> String {
162 format!("{}.{}", self.name(), self.column)
163 }
164
165 fn columns(&self) -> Vec<&str> {
166 vec![&self.column]
167 }
168}