term_guard/optimizer/combiner.rs

//! Query combination strategies for optimization.

use crate::optimizer::analyzer::{AggregationType, ConstraintAnalysis};
use crate::prelude::TermError;
use std::collections::{HashMap, HashSet};

/// A group of constraints that can be executed together.
#[derive(Debug)]
pub struct ConstraintGroup {
    /// The constraints in this group
    pub constraints: Vec<ConstraintAnalysis>,
    /// The combined SQL query for this group
    pub combined_sql: String,
    /// Mapping of per-constraint result keys to column aliases in the combined query
    pub result_mapping: HashMap<String, String>,
}

/// Combines compatible constraints into optimized query groups.
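///
/// A minimal usage sketch (assuming a prepared `Vec<ConstraintAnalysis>` named
/// `analyses` is already in hand):
///
/// ```ignore
/// let combiner = QueryCombiner::new();
/// let groups = combiner.group_constraints(analyses)?;
/// for group in &groups {
///     // An empty SQL string means the constraint runs its own query unmodified.
///     println!("{}", group.combined_sql);
/// }
/// ```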
#[derive(Debug)]
pub struct QueryCombiner {
    /// Maximum number of constraints to combine in a single query
    max_group_size: usize,
}

impl QueryCombiner {
    /// Creates a new query combiner.
    pub fn new() -> Self {
        Self {
            max_group_size: 20, // Reasonable default to avoid overly complex queries
        }
    }

    /// Groups constraints by optimization strategy.
    ///
    /// Constraints are first partitioned by table; within each table, combinable
    /// constraints are merged into shared query groups, while the remaining ones
    /// fall back to single-constraint groups that keep their own SQL.
    pub fn group_constraints(
        &self,
        analyses: Vec<ConstraintAnalysis>,
    ) -> Result<Vec<ConstraintGroup>, TermError> {
        let mut groups = Vec::new();
        let mut processed = HashSet::new();

        // Group by table first
        let by_table = self.group_by_table(analyses);

        for (table, table_constraints) in by_table {
            // Within each table, group combinable constraints
            let combinable: Vec<_> = table_constraints
                .iter()
                .filter(|a| a.is_combinable && !processed.contains(&a.name))
                .cloned()
                .collect();

            if !combinable.is_empty() {
                // Create groups of compatible constraints
                let compatible_groups = self.find_compatible_groups(&combinable);

                for group in compatible_groups {
                    let combined = self.combine_group(&table, group)?;
                    for constraint in &combined.constraints {
                        processed.insert(constraint.name.clone());
                    }
                    groups.push(combined);
                }
            }

            // Handle non-combinable constraints individually
            for analysis in table_constraints {
                if !analysis.is_combinable && !processed.contains(&analysis.name) {
                    processed.insert(analysis.name.clone());
                    let individual = self.create_individual_group(analysis)?;
                    groups.push(individual);
                }
            }
        }

        Ok(groups)
    }

    /// Groups constraints by table name.
    fn group_by_table(
        &self,
        analyses: Vec<ConstraintAnalysis>,
    ) -> HashMap<String, Vec<ConstraintAnalysis>> {
        let mut by_table: HashMap<String, Vec<ConstraintAnalysis>> = HashMap::new();

        for analysis in analyses {
            by_table
                .entry(analysis.table_name.clone())
                .or_default()
                .push(analysis);
        }

        by_table
    }

    /// Finds groups of compatible constraints.
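    ///
    /// The pass is greedy: each constraint joins the current group when its
    /// aggregations and columns are compatible and the group is under the size
    /// limit, otherwise it starts a new group. For illustration (hypothetical
    /// constraints), with a size limit of 2, mutually compatible constraints
    /// `[a, b, c]` come back as `[[a, b], [c]]`.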
    fn find_compatible_groups(
        &self,
        constraints: &[ConstraintAnalysis],
    ) -> Vec<Vec<ConstraintAnalysis>> {
        let mut groups = Vec::new();
        let mut current_group = Vec::new();
        let mut used_columns = HashSet::new();
        let mut used_aggregations = HashSet::new();

        for constraint in constraints {
            // Check if this constraint is compatible with the current group
            let is_compatible = current_group.is_empty()
                || (self.has_compatible_aggregations(&constraint.aggregations, &used_aggregations)
                    && !self.has_column_conflicts(&constraint.columns, &used_columns)
                    && current_group.len() < self.max_group_size);

            if is_compatible {
                // Add to current group
                for agg in &constraint.aggregations {
                    used_aggregations.insert(agg.clone());
                }
                for col in &constraint.columns {
                    used_columns.insert(col.clone());
                }
                current_group.push(constraint.clone());
            } else {
                // Start a new group
                if !current_group.is_empty() {
                    groups.push(current_group);
                }
                current_group = vec![constraint.clone()];
                used_columns.clear();
                used_aggregations.clear();
                for agg in &constraint.aggregations {
                    used_aggregations.insert(agg.clone());
                }
                for col in &constraint.columns {
                    used_columns.insert(col.clone());
                }
            }
        }

        if !current_group.is_empty() {
            groups.push(current_group);
        }

        groups
    }

    /// Checks if aggregations are compatible.
    fn has_compatible_aggregations(
        &self,
        new_aggs: &[AggregationType],
        existing_aggs: &HashSet<AggregationType>,
    ) -> bool {
        // Simple compatibility check - can be made more sophisticated
        new_aggs.iter().all(|agg| {
            existing_aggs.is_empty()
                || existing_aggs.contains(agg)
                || matches!(agg, AggregationType::Count) // COUNT is always compatible
        })
    }

    /// Checks for column conflicts.
    fn has_column_conflicts(&self, new_cols: &[String], existing_cols: &HashSet<String>) -> bool {
        // Flag a conflict only when more than half of the new columns already
        // appear in the current group.
        let overlap = new_cols
            .iter()
            .filter(|col| existing_cols.contains(*col))
            .count();
        overlap > new_cols.len() / 2
    }

    /// Combines a group of constraints into a single optimized query.
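    ///
    /// For illustration only (hypothetical constraint names and columns): a group
    /// on table `data` holding `c1` with a `Count` aggregation and `c2` with a
    /// `Sum` over column `amount` would produce SQL shaped roughly like
    /// `SELECT COUNT(*) as total_count, SUM(amount) as c2_1_sum_0 FROM data`,
    /// with `result_mapping` entries `c1_total -> total_count` and
    /// `c2_sum -> c2_1_sum_0`.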
    fn combine_group(
        &self,
        table: &str,
        constraints: Vec<ConstraintAnalysis>,
    ) -> Result<ConstraintGroup, TermError> {
        let mut select_parts = vec!["COUNT(*) as total_count".to_string()];
        let mut result_mapping = HashMap::new();

        // Add total_count mapping for all constraints that need it
        for constraint in &constraints {
            if constraint.aggregations.contains(&AggregationType::Count) {
                result_mapping.insert(
                    format!("{}_total", constraint.name),
                    "total_count".to_string(),
                );
            }
        }

        // Build SELECT clause with all needed aggregations
        for (i, constraint) in constraints.iter().enumerate() {
            for (j, agg) in constraint.aggregations.iter().enumerate() {
                if matches!(agg, AggregationType::Count) {
                    continue; // Already handled with total_count
                }

                let col_name = if constraint.columns.is_empty() {
                    "*".to_string()
                } else {
                    constraint.columns[0].clone() // Simplified: only the first column is used
                };

                let agg_sql = agg_to_sql(agg);
                let alias = format!("{}_{i}_{agg_sql}_{j}", constraint.name);
                let sql_expr = match agg {
                    AggregationType::CountDistinct => {
                        format!("COUNT(DISTINCT {col_name}) as {alias}")
                    }
                    AggregationType::Sum => format!("SUM({col_name}) as {alias}"),
                    AggregationType::Avg => format!("AVG({col_name}) as {alias}"),
                    AggregationType::Min => format!("MIN({col_name}) as {alias}"),
                    AggregationType::Max => format!("MAX({col_name}) as {alias}"),
                    AggregationType::StdDev => format!("STDDEV({col_name}) as {alias}"),
                    AggregationType::Variance => format!("VARIANCE({col_name}) as {alias}"),
                    _ => continue, // Count was already skipped above
                };

                select_parts.push(sql_expr);
                result_mapping.insert(format!("{}_{agg_sql}", constraint.name), alias);
            }
        }

        let select_clause = select_parts.join(", ");
        let combined_sql = format!("SELECT {select_clause} FROM {table}");

        Ok(ConstraintGroup {
            constraints,
            combined_sql,
            result_mapping,
        })
    }

    /// Creates a group for a single non-combinable constraint.
    fn create_individual_group(
        &self,
        analysis: ConstraintAnalysis,
    ) -> Result<ConstraintGroup, TermError> {
        // For non-combinable constraints, we'll let them execute their own queries
        let result_mapping = HashMap::new();

        Ok(ConstraintGroup {
            constraints: vec![analysis],
            combined_sql: String::new(), // Will use constraint's own SQL
            result_mapping,
        })
    }

    /// Sets the maximum group size.
    pub fn set_max_group_size(&mut self, size: usize) {
        self.max_group_size = size;
    }
}

/// Converts an aggregation type to the lowercase identifier used in aliases and
/// result-mapping keys.
fn agg_to_sql(agg: &AggregationType) -> &'static str {
    match agg {
        AggregationType::Count => "count",
        AggregationType::CountDistinct => "count_distinct",
        AggregationType::Sum => "sum",
        AggregationType::Avg => "avg",
        AggregationType::Min => "min",
        AggregationType::Max => "max",
        AggregationType::StdDev => "stddev",
        AggregationType::Variance => "variance",
    }
}

impl Default for QueryCombiner {
    fn default() -> Self {
        Self::new()
    }
}

// TODO: Fix tests once Completeness constraint is made public
#[cfg(test)]
mod tests {
    use super::*;
    // use crate::constraints::completeness::Completeness;

    #[test]
    fn test_combiner_creation() {
        let combiner = QueryCombiner::new();
        assert_eq!(combiner.max_group_size, 20);
    }
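
    // The checks below only exercise helpers that do not need a ConstraintAnalysis,
    // so they stay independent of the (currently private) Completeness constraint.

    #[test]
    fn test_set_max_group_size() {
        let mut combiner = QueryCombiner::new();
        combiner.set_max_group_size(5);
        assert_eq!(combiner.max_group_size, 5);
    }

    #[test]
    fn test_column_conflicts_require_majority_overlap() {
        let combiner = QueryCombiner::new();
        let existing: HashSet<String> = HashSet::from(["a".to_string()]);

        // One of two columns overlaps: not more than half, so no conflict.
        assert!(!combiner.has_column_conflicts(&["a".to_string(), "b".to_string()], &existing));
        // The single new column already appears in the group: conflict.
        assert!(combiner.has_column_conflicts(&["a".to_string()], &existing));
    }

    #[test]
    fn test_count_is_always_compatible() {
        let combiner = QueryCombiner::new();
        let existing: HashSet<AggregationType> = HashSet::from([AggregationType::Sum]);
        assert!(combiner.has_compatible_aggregations(&[AggregationType::Count], &existing));
        assert!(!combiner.has_compatible_aggregations(&[AggregationType::Avg], &existing));
    }

    #[test]
    fn test_agg_to_sql_names() {
        assert_eq!(agg_to_sql(&AggregationType::Sum), "sum");
        assert_eq!(agg_to_sql(&AggregationType::CountDistinct), "count_distinct");
    }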

    // TODO: Re-enable once Completeness is made public
    // #[test]
    // fn test_group_by_table() {
    //     let combiner = QueryCombiner::new();
    //
    //     let analyses = vec![
    //         ConstraintAnalysis {
    //             name: "c1".to_string(),
    //             constraint: Arc::new(Completeness::new("col1")),
    //             table_name: "data".to_string(),
    //             aggregations: vec![AggregationType::Count],
    //             columns: vec!["col1".to_string()],
    //             has_predicates: false,
    //             is_combinable: true,
    //         },
    //         ConstraintAnalysis {
    //             name: "c2".to_string(),
    //             constraint: Arc::new(Completeness::new("col2")),
    //             table_name: "data".to_string(),
    //             aggregations: vec![AggregationType::Count],
    //             columns: vec!["col2".to_string()],
    //             has_predicates: false,
    //             is_combinable: true,
    //         },
    //     ];
    //
    //     let by_table = combiner.group_by_table(analyses);
    //     assert_eq!(by_table.len(), 1);
    //     assert_eq!(by_table.get("data").unwrap().len(), 2);
    // }
}