term_guard/analyzers/basic/
sum.rs

1//! Sum analyzer for computing the total of numeric values.
2
3use async_trait::async_trait;
4use datafusion::prelude::*;
5use serde::{Deserialize, Serialize};
6use tracing::instrument;
7
8use crate::analyzers::{Analyzer, AnalyzerError, AnalyzerResult, AnalyzerState, MetricValue};
9
10use crate::core::current_validation_context;
/// Analyzer that computes the sum of the non-null values in a numeric column.
///
/// The aggregate query runs against the table named by the current validation
/// context, so the data must be registered in the `SessionContext` under that name.
///
/// # Example
///
/// ```rust,ignore
/// use term_guard::analyzers::basic::SumAnalyzer;
/// use term_guard::analyzers::{Analyzer, MetricValue};
/// use datafusion::prelude::*;
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let ctx = SessionContext::new();
/// // Register your data table under the name used by the validation context
///
/// let analyzer = SumAnalyzer::new("amount");
/// let state = analyzer.compute_state_from_data(&ctx).await?;
/// let metric = analyzer.compute_metric_from_state(&state)?;
///
/// if let MetricValue::Double(sum) = metric {
///     println!("Total amount: ${:.2}", sum);
/// }
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct SumAnalyzer {
    /// Name of the column whose values are summed.
    column: String,
}
38
39impl SumAnalyzer {
40    /// Creates a new sum analyzer for the specified column.
41    pub fn new(column: impl Into<String>) -> Self {
42        Self {
43            column: column.into(),
44        }
45    }
46
47    /// Returns the column being analyzed.
48    pub fn column(&self) -> &str {
49        &self.column
50    }
51}
52
53/// State for the sum analyzer.
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct SumState {
56    /// Sum of all non-null values.
57    pub sum: f64,
58    /// Whether any non-null values were found.
59    pub has_values: bool,
60}
61
62impl AnalyzerState for SumState {
63    fn merge(states: Vec<Self>) -> AnalyzerResult<Self> {
64        let sum = states.iter().map(|s| s.sum).sum();
65        let has_values = states.iter().any(|s| s.has_values);
66
67        Ok(SumState { sum, has_values })
68    }
69
70    fn is_empty(&self) -> bool {
71        !self.has_values
72    }
73}
74
75#[async_trait]
76impl Analyzer for SumAnalyzer {
77    type State = SumState;
78    type Metric = MetricValue;
79
80    #[instrument(skip(ctx), fields(analyzer = "sum", column = %self.column))]
81    async fn compute_state_from_data(&self, ctx: &SessionContext) -> AnalyzerResult<Self::State> {
82        // Build SQL query to compute sum and check for non-null values
83        // Get the table name from the validation context
84
85        let validation_ctx = current_validation_context();
86
87        let table_name = validation_ctx.table_name();
88
89        let sql = format!(
90            "SELECT SUM({0}) as sum, COUNT({0}) as count FROM {table_name}",
91            self.column
92        );
93
94        // Execute query
95        let df = ctx.sql(&sql).await?;
96        let batches = df.collect().await?;
97
98        // Extract sum from result
99        let (sum, has_values) = if let Some(batch) = batches.first() {
100            if batch.num_rows() > 0 {
101                // Sum can be null if all values are null
102                let sum = if batch.column(0).is_null(0) {
103                    0.0
104                } else {
105                    // Try Float64 first, then Int64
106                    if let Some(arr) = batch
107                        .column(0)
108                        .as_any()
109                        .downcast_ref::<arrow::array::Float64Array>()
110                    {
111                        arr.value(0)
112                    } else if let Some(arr) = batch
113                        .column(0)
114                        .as_any()
115                        .downcast_ref::<arrow::array::Int64Array>()
116                    {
117                        arr.value(0) as f64
118                    } else {
119                        return Err(AnalyzerError::invalid_data(format!(
120                            "Expected numeric array for sum, got {:?}",
121                            batch.column(0).data_type()
122                        )));
123                    }
124                };
125
126                // Check if we had any non-null values
127                let count_array = batch
128                    .column(1)
129                    .as_any()
130                    .downcast_ref::<arrow::array::Int64Array>()
131                    .ok_or_else(|| AnalyzerError::invalid_data("Expected Int64 array for count"))?;
132                let has_values = count_array.value(0) > 0;
133
134                (sum, has_values)
135            } else {
136                (0.0, false)
137            }
138        } else {
139            (0.0, false)
140        };
141
142        Ok(SumState { sum, has_values })
143    }
144
145    fn compute_metric_from_state(&self, state: &Self::State) -> AnalyzerResult<Self::Metric> {
146        if state.has_values {
147            Ok(MetricValue::Double(state.sum))
148        } else {
149            Err(AnalyzerError::NoData)
150        }
151    }
152
153    fn name(&self) -> &str {
154        "sum"
155    }
156
157    fn description(&self) -> &str {
158        "Computes the sum of values in a numeric column"
159    }
160
161    fn metric_key(&self) -> String {
162        format!("{}.{}", self.name(), self.column)
163    }
164
165    fn columns(&self) -> Vec<&str> {
166        vec![&self.column]
167    }
168}