1use std::collections::{BTreeMap, HashMap};
7use std::sync::Arc;
8
9use anyhow::{anyhow, Result};
10
11use crate::data::data_view::DataView;
12use crate::data::datatable::{DataTable, DataValue};
13use crate::sql::recursive_parser::{OrderByColumn, SortDirection};
14
15#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
18struct PartitionKey(String);
19
20impl PartitionKey {
21 fn from_values(values: Vec<DataValue>) -> Self {
23 let key_parts: Vec<String> = values
25 .iter()
26 .map(|v| match v {
27 DataValue::String(s) => format!("S:{}", s),
28 DataValue::InternedString(s) => format!("S:{}", s),
29 DataValue::Integer(i) => format!("I:{}", i),
30 DataValue::Float(f) => format!("F:{}", f),
31 DataValue::Boolean(b) => format!("B:{}", b),
32 DataValue::DateTime(dt) => format!("D:{}", dt),
33 DataValue::Null => "N".to_string(),
34 })
35 .collect();
36 let key = key_parts.join("|");
37 PartitionKey(key)
38 }
39}
40
/// The rows of one partition, kept in window order with a reverse lookup
/// from row index to position.
#[derive(Debug, Clone)]
pub struct OrderedPartition {
    /// Row indices in the partition's order.
    rows: Vec<usize>,

    /// Maps a row index to its position inside `rows`.
    row_positions: HashMap<usize, usize>,
}

impl OrderedPartition {
    /// Wraps already-ordered `rows` and builds the row-index -> position map.
    ///
    /// If a row index occurs more than once, the map keeps its last position.
    fn new(rows: Vec<usize>) -> Self {
        let row_positions = rows
            .iter()
            .enumerate()
            .map(|(pos, &row_idx)| (row_idx, pos))
            .collect();

        Self {
            rows,
            row_positions,
        }
    }

    /// Returns the row `offset` positions away from `current_row` within
    /// this partition (negative = earlier, positive = later).
    ///
    /// Returns `None` when `current_row` is not in the partition or the
    /// target position falls outside it. The arithmetic is checked, so
    /// extreme offsets yield `None` instead of overflowing.
    pub fn get_row_at_offset(&self, current_row: usize, offset: i32) -> Option<usize> {
        let current_pos = *self.row_positions.get(&current_row)?;

        // Both casts widen a non-negative 32-bit magnitude to `usize`.
        let target_pos = if offset >= 0 {
            current_pos.checked_add(offset as usize)?
        } else {
            current_pos.checked_sub(offset.unsigned_abs() as usize)?
        };

        self.rows.get(target_pos).copied()
    }

    /// Returns the zero-based position of `row_index` in this partition.
    pub fn get_position(&self, row_index: usize) -> Option<usize> {
        self.row_positions.get(&row_index).copied()
    }

    /// Returns the first row of the partition, or `None` if it is empty.
    pub fn first_row(&self) -> Option<usize> {
        self.rows.first().copied()
    }

    /// Returns the last row of the partition, or `None` if it is empty.
    pub fn last_row(&self) -> Option<usize> {
        self.rows.last().copied()
    }
}
94
/// The PARTITION BY / ORDER BY description of a window.
#[derive(Debug, Clone)]
pub struct WindowSpec {
    /// Names of the columns whose values define the partitions; empty means
    /// no partitioning.
    pub partition_by: Vec<String>,
    /// Columns (with sort direction) that order the rows inside a partition.
    pub order_by: Vec<OrderByColumn>,
}
101
102pub struct WindowContext {
104 source: Arc<DataView>,
106
107 partitions: BTreeMap<PartitionKey, OrderedPartition>,
109
110 row_to_partition: HashMap<usize, PartitionKey>,
112
113 spec: WindowSpec,
115}
116
117impl WindowContext {
118 pub fn new(
120 view: Arc<DataView>,
121 partition_by: Vec<String>,
122 order_by: Vec<OrderByColumn>,
123 ) -> Result<Self> {
124 let spec = WindowSpec {
125 partition_by: partition_by.clone(),
126 order_by: order_by.clone(),
127 };
128
129 if partition_by.is_empty() {
131 let single_partition = Self::create_single_partition(&view, &order_by)?;
132 let partition_key = PartitionKey::from_values(vec![]);
133
134 let mut row_to_partition = HashMap::new();
136 for &row_idx in &single_partition.rows {
137 row_to_partition.insert(row_idx, partition_key.clone());
138 }
139
140 let mut partitions = BTreeMap::new();
141 partitions.insert(partition_key, single_partition);
142
143 return Ok(Self {
144 source: view,
145 partitions,
146 row_to_partition,
147 spec,
148 });
149 }
150
151 let mut partition_map: BTreeMap<PartitionKey, Vec<usize>> = BTreeMap::new();
153 let mut row_to_partition = HashMap::new();
154
155 let source_table = view.source();
157 let partition_col_indices: Vec<usize> = partition_by
158 .iter()
159 .map(|col| {
160 source_table
161 .get_column_index(col)
162 .ok_or_else(|| anyhow!("Invalid partition column: {}", col))
163 })
164 .collect::<Result<Vec<_>>>()?;
165
166 for row_idx in view.get_visible_rows() {
168 let mut key_values = Vec::new();
170 for &col_idx in &partition_col_indices {
171 let value = source_table
172 .get_value(row_idx, col_idx)
173 .ok_or_else(|| anyhow!("Failed to get value for partition"))?
174 .clone();
175 key_values.push(value);
176 }
177 let key = PartitionKey::from_values(key_values);
178
179 partition_map.entry(key.clone()).or_default().push(row_idx);
181 row_to_partition.insert(row_idx, key);
182 }
183
184 let mut partitions = BTreeMap::new();
186 for (key, mut rows) in partition_map {
187 if !order_by.is_empty() {
189 Self::sort_rows(&mut rows, source_table, &order_by)?;
190 }
191
192 partitions.insert(key, OrderedPartition::new(rows));
193 }
194
195 Ok(Self {
196 source: view,
197 partitions,
198 row_to_partition,
199 spec,
200 })
201 }
202
203 fn create_single_partition(
205 view: &DataView,
206 order_by: &[OrderByColumn],
207 ) -> Result<OrderedPartition> {
208 let mut rows: Vec<usize> = view.get_visible_rows();
209
210 if !order_by.is_empty() {
211 Self::sort_rows(&mut rows, view.source(), order_by)?;
212 }
213
214 Ok(OrderedPartition::new(rows))
215 }
216
217 fn sort_rows(
219 rows: &mut Vec<usize>,
220 table: &DataTable,
221 order_by: &[OrderByColumn],
222 ) -> Result<()> {
223 let sort_cols: Vec<(usize, bool)> = order_by
225 .iter()
226 .map(|col| {
227 let idx = table
228 .get_column_index(&col.column)
229 .ok_or_else(|| anyhow!("Invalid ORDER BY column: {}", col.column))?;
230 let ascending = matches!(col.direction, SortDirection::Asc);
231 Ok((idx, ascending))
232 })
233 .collect::<Result<Vec<_>>>()?;
234
235 rows.sort_by(|&a, &b| {
237 for &(col_idx, ascending) in &sort_cols {
238 let val_a = table.get_value(a, col_idx);
239 let val_b = table.get_value(b, col_idx);
240
241 match (val_a, val_b) {
242 (None, None) => continue,
243 (None, Some(_)) => {
244 return if ascending {
245 std::cmp::Ordering::Less
246 } else {
247 std::cmp::Ordering::Greater
248 }
249 }
250 (Some(_), None) => {
251 return if ascending {
252 std::cmp::Ordering::Greater
253 } else {
254 std::cmp::Ordering::Less
255 }
256 }
257 (Some(v_a), Some(v_b)) => {
258 let ord = v_a.partial_cmp(&v_b).unwrap_or(std::cmp::Ordering::Equal);
260 if ord != std::cmp::Ordering::Equal {
261 return if ascending { ord } else { ord.reverse() };
262 }
263 }
264 }
265 }
266 std::cmp::Ordering::Equal
267 });
268
269 Ok(())
270 }
271
272 pub fn get_offset_value(
274 &self,
275 current_row: usize,
276 offset: i32,
277 column: &str,
278 ) -> Option<DataValue> {
279 let partition_key = self.row_to_partition.get(¤t_row)?;
281 let partition = self.partitions.get(partition_key)?;
282
283 let target_row = partition.get_row_at_offset(current_row, offset)?;
285
286 let source_table = self.source.source();
288 let col_idx = source_table.get_column_index(column)?;
289 source_table.get_value(target_row, col_idx).cloned()
290 }
291
292 pub fn get_row_number(&self, row_index: usize) -> usize {
294 if let Some(partition_key) = self.row_to_partition.get(&row_index) {
295 if let Some(partition) = self.partitions.get(partition_key) {
296 if let Some(position) = partition.get_position(row_index) {
297 return position + 1; }
299 }
300 }
301 0 }
303
304 pub fn get_first_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
306 let partition_key = self.row_to_partition.get(&row_index)?;
307 let partition = self.partitions.get(partition_key)?;
308 let first_row = partition.first_row()?;
309
310 let source_table = self.source.source();
311 let col_idx = source_table.get_column_index(column)?;
312 source_table.get_value(first_row, col_idx).cloned()
313 }
314
315 pub fn get_last_value(&self, row_index: usize, column: &str) -> Option<DataValue> {
317 let partition_key = self.row_to_partition.get(&row_index)?;
318 let partition = self.partitions.get(partition_key)?;
319 let last_row = partition.last_row()?;
320
321 let source_table = self.source.source();
322 let col_idx = source_table.get_column_index(column)?;
323 source_table.get_value(last_row, col_idx).cloned()
324 }
325
326 pub fn partition_count(&self) -> usize {
328 self.partitions.len()
329 }
330
331 pub fn has_partitions(&self) -> bool {
333 !self.spec.partition_by.is_empty()
334 }
335
336 pub fn get_partition_sum(&self, row_index: usize, column: &str) -> Option<DataValue> {
338 let partition_key = self.row_to_partition.get(&row_index)?;
339 let partition = self.partitions.get(partition_key)?;
340 let source_table = self.source.source();
341 let col_idx = source_table.get_column_index(column)?;
342
343 let mut sum = 0.0;
344 let mut has_float = false;
345 let mut has_value = false;
346
347 for &row_idx in &partition.rows {
349 if let Some(value) = source_table.get_value(row_idx, col_idx) {
350 match value {
351 DataValue::Integer(i) => {
352 sum += *i as f64;
353 has_value = true;
354 }
355 DataValue::Float(f) => {
356 sum += f;
357 has_float = true;
358 has_value = true;
359 }
360 DataValue::Null => {
361 }
363 _ => {
364 return Some(DataValue::Null);
366 }
367 }
368 }
369 }
370
371 if !has_value {
372 return Some(DataValue::Null);
373 }
374
375 if !has_float && sum.fract() == 0.0 && sum >= i64::MIN as f64 && sum <= i64::MAX as f64 {
377 Some(DataValue::Integer(sum as i64))
378 } else {
379 Some(DataValue::Float(sum))
380 }
381 }
382
383 pub fn get_partition_count(&self, row_index: usize, column: Option<&str>) -> Option<DataValue> {
385 let partition_key = self.row_to_partition.get(&row_index)?;
386 let partition = self.partitions.get(partition_key)?;
387
388 if let Some(col_name) = column {
389 let source_table = self.source.source();
391 let col_idx = source_table.get_column_index(col_name)?;
392
393 let count = partition
394 .rows
395 .iter()
396 .filter_map(|&row_idx| source_table.get_value(row_idx, col_idx))
397 .filter(|v| !matches!(v, DataValue::Null))
398 .count();
399
400 Some(DataValue::Integer(count as i64))
401 } else {
402 Some(DataValue::Integer(partition.rows.len() as i64))
404 }
405 }
406}