term_guard/optimizer/
combiner.rs1use crate::optimizer::analyzer::{AggregationType, ConstraintAnalysis};
4use crate::prelude::TermError;
5use std::collections::{HashMap, HashSet};
6
7#[derive(Debug)]
9pub struct ConstraintGroup {
10 pub constraints: Vec<ConstraintAnalysis>,
12 pub combined_sql: String,
14 pub result_mapping: HashMap<String, String>,
16}
17
/// Batches constraint analyses so that compatible ones can share one SQL query.
#[derive(Debug)]
pub struct QueryCombiner {
    /// Upper bound on how many constraints may be placed in one combined group.
    max_group_size: usize,
}
24
25impl QueryCombiner {
26 pub fn new() -> Self {
28 Self {
29 max_group_size: 20, }
31 }
32
33 pub fn group_constraints(
35 &self,
36 analyses: Vec<ConstraintAnalysis>,
37 ) -> Result<Vec<ConstraintGroup>, TermError> {
38 let mut groups = Vec::new();
39 let mut processed = HashSet::new();
40
41 let by_table = self.group_by_table(analyses);
43
44 for (table, table_constraints) in by_table {
45 let combinable: Vec<_> = table_constraints
47 .iter()
48 .filter(|a| a.is_combinable && !processed.contains(&a.name))
49 .cloned()
50 .collect();
51
52 if !combinable.is_empty() {
53 let compatible_groups = self.find_compatible_groups(&combinable);
55
56 for group in compatible_groups {
57 let combined = self.combine_group(&table, group)?;
58 for constraint in &combined.constraints {
59 processed.insert(constraint.name.clone());
60 }
61 groups.push(combined);
62 }
63 }
64
65 for analysis in table_constraints {
67 if !analysis.is_combinable && !processed.contains(&analysis.name) {
68 processed.insert(analysis.name.clone());
69 let individual = self.create_individual_group(analysis)?;
70 groups.push(individual);
71 }
72 }
73 }
74
75 Ok(groups)
76 }
77
78 fn group_by_table(
80 &self,
81 analyses: Vec<ConstraintAnalysis>,
82 ) -> HashMap<String, Vec<ConstraintAnalysis>> {
83 let mut by_table: HashMap<String, Vec<ConstraintAnalysis>> = HashMap::new();
84
85 for analysis in analyses {
86 by_table
87 .entry(analysis.table_name.clone())
88 .or_default()
89 .push(analysis);
90 }
91
92 by_table
93 }
94
95 fn find_compatible_groups(
97 &self,
98 constraints: &[ConstraintAnalysis],
99 ) -> Vec<Vec<ConstraintAnalysis>> {
100 let mut groups = Vec::new();
101 let mut current_group = Vec::new();
102 let mut used_columns = HashSet::new();
103 let mut used_aggregations = HashSet::new();
104
105 for constraint in constraints {
106 let is_compatible = current_group.is_empty()
108 || (self.has_compatible_aggregations(&constraint.aggregations, &used_aggregations)
109 && !self.has_column_conflicts(&constraint.columns, &used_columns)
110 && current_group.len() < self.max_group_size);
111
112 if is_compatible {
113 for agg in &constraint.aggregations {
115 used_aggregations.insert(agg.clone());
116 }
117 for col in &constraint.columns {
118 used_columns.insert(col.clone());
119 }
120 current_group.push(constraint.clone());
121 } else {
122 if !current_group.is_empty() {
124 groups.push(current_group);
125 }
126 current_group = vec![constraint.clone()];
127 used_columns.clear();
128 used_aggregations.clear();
129 for agg in &constraint.aggregations {
130 used_aggregations.insert(agg.clone());
131 }
132 for col in &constraint.columns {
133 used_columns.insert(col.clone());
134 }
135 }
136 }
137
138 if !current_group.is_empty() {
139 groups.push(current_group);
140 }
141
142 groups
143 }
144
145 fn has_compatible_aggregations(
147 &self,
148 new_aggs: &[AggregationType],
149 existing_aggs: &HashSet<AggregationType>,
150 ) -> bool {
151 new_aggs.iter().all(|agg| {
153 existing_aggs.is_empty()
154 || existing_aggs.contains(agg)
155 || matches!(agg, AggregationType::Count) })
157 }
158
159 fn has_column_conflicts(&self, new_cols: &[String], existing_cols: &HashSet<String>) -> bool {
161 let overlap = new_cols
163 .iter()
164 .filter(|col| existing_cols.contains(*col))
165 .count();
166 overlap > new_cols.len() / 2
167 }
168
169 fn combine_group(
171 &self,
172 table: &str,
173 constraints: Vec<ConstraintAnalysis>,
174 ) -> Result<ConstraintGroup, TermError> {
175 let mut select_parts = vec!["COUNT(*) as total_count".to_string()];
176 let mut result_mapping = HashMap::new();
177
178 for constraint in &constraints {
180 if constraint.aggregations.contains(&AggregationType::Count) {
181 result_mapping.insert(
182 format!("{}_total", constraint.name),
183 "total_count".to_string(),
184 );
185 }
186 }
187
188 for (i, constraint) in constraints.iter().enumerate() {
190 for (j, agg) in constraint.aggregations.iter().enumerate() {
191 if matches!(agg, AggregationType::Count) {
192 continue; }
194
195 let col_name = if constraint.columns.is_empty() {
196 "*".to_string()
197 } else {
198 constraint.columns[0].clone() };
200
201 let agg_sql = agg_to_sql(agg);
202 let alias = format!("{}_{i}_{agg_sql}_{j}", constraint.name);
203 let sql_expr = match agg {
204 AggregationType::CountDistinct => {
205 format!("COUNT(DISTINCT {col_name}) as {alias}")
206 }
207 AggregationType::Sum => format!("SUM({col_name}) as {alias}"),
208 AggregationType::Avg => format!("AVG({col_name}) as {alias}"),
209 AggregationType::Min => format!("MIN({col_name}) as {alias}"),
210 AggregationType::Max => format!("MAX({col_name}) as {alias}"),
211 AggregationType::StdDev => format!("STDDEV({col_name}) as {alias}"),
212 AggregationType::Variance => format!("VARIANCE({col_name}) as {alias}"),
213 _ => continue,
214 };
215
216 select_parts.push(sql_expr);
217 let agg_sql = agg_to_sql(agg);
218 result_mapping.insert(format!("{}_{agg_sql}", constraint.name), alias);
219 }
220 }
221
222 let select_clause = select_parts.join(", ");
223 let combined_sql = format!("SELECT {select_clause} FROM {table}");
224
225 Ok(ConstraintGroup {
226 constraints,
227 combined_sql,
228 result_mapping,
229 })
230 }
231
232 fn create_individual_group(
234 &self,
235 analysis: ConstraintAnalysis,
236 ) -> Result<ConstraintGroup, TermError> {
237 let result_mapping = HashMap::new();
239
240 Ok(ConstraintGroup {
241 constraints: vec![analysis],
242 combined_sql: String::new(), result_mapping,
244 })
245 }
246
247 pub fn set_max_group_size(&mut self, size: usize) {
249 self.max_group_size = size;
250 }
251}
252
253fn agg_to_sql(agg: &AggregationType) -> &'static str {
255 match agg {
256 AggregationType::Count => "count",
257 AggregationType::CountDistinct => "count_distinct",
258 AggregationType::Sum => "sum",
259 AggregationType::Avg => "avg",
260 AggregationType::Min => "min",
261 AggregationType::Max => "max",
262 AggregationType::StdDev => "stddev",
263 AggregationType::Variance => "variance",
264 }
265}
266
267impl Default for QueryCombiner {
268 fn default() -> Self {
269 Self::new()
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed combiner caps combined groups at 20 constraints.
    #[test]
    fn test_combiner_creation() {
        assert_eq!(QueryCombiner::new().max_group_size, 20);
    }
}