1use chrono::{DateTime, Duration, Utc};
13use shape_value::ValueWord;
14use std::collections::HashMap;
15
16use shape_ast::error::Result;
17#[derive(Debug, Clone)]
19pub enum WindowType {
20 Tumbling { size: Duration },
22 Sliding { size: Duration, slide: Duration },
24 Session { gap: Duration },
26 Count { size: usize },
28 Cumulative,
30}
31
32impl WindowType {
33 pub fn tumbling(size: Duration) -> Self {
35 WindowType::Tumbling { size }
36 }
37
38 pub fn sliding(size: Duration, slide: Duration) -> Self {
40 WindowType::Sliding { size, slide }
41 }
42
43 pub fn session(gap: Duration) -> Self {
45 WindowType::Session { gap }
46 }
47
48 pub fn count(size: usize) -> Self {
50 WindowType::Count { size }
51 }
52
53 pub fn cumulative() -> Self {
55 WindowType::Cumulative
56 }
57}
58
59#[derive(Debug, Clone)]
61pub struct WindowDataPoint {
62 pub timestamp: DateTime<Utc>,
63 pub fields: HashMap<String, ValueWord>,
64}
65
66#[derive(Debug, Clone)]
68pub struct WindowResult {
69 pub start: DateTime<Utc>,
71 pub end: DateTime<Utc>,
73 pub count: usize,
75 pub aggregates: HashMap<String, f64>,
77}
78
/// Reduction applied to the numeric values of one field within a window.
///
/// Every function yields `NaN` when the window has no numeric values for the field.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AggregateFunction {
    /// Sum of the values.
    Sum,
    /// Arithmetic mean of the values.
    Avg,
    /// Smallest value.
    Min,
    /// Largest value.
    Max,
    /// Number of values that could be read as `f64`.
    Count,
    /// First value in arrival order.
    First,
    /// Last value in arrival order.
    Last,
    /// Population standard deviation (divides by `n`).
    StdDev,
    /// Population variance (divides by `n`).
    Variance,
}
92
93#[derive(Debug, Clone)]
95pub struct AggregateSpec {
96 pub field: String,
97 pub function: AggregateFunction,
98 pub output_name: String,
99}
100
101#[derive(Debug)]
103struct WindowState {
104 start: DateTime<Utc>,
105 data: Vec<WindowDataPoint>,
106 last_timestamp: Option<DateTime<Utc>>,
107}
108
109pub struct WindowManager {
111 window_type: WindowType,
113 aggregates: Vec<AggregateSpec>,
115 active_windows: Vec<WindowState>,
117 current_window: Option<WindowState>,
119 current_count: usize,
121 cumulative_data: Vec<WindowDataPoint>,
123 completed_windows: Vec<WindowResult>,
125}
126
127impl WindowManager {
128 pub fn new(window_type: WindowType) -> Self {
130 Self {
131 window_type,
132 aggregates: Vec::new(),
133 active_windows: Vec::new(),
134 current_window: None,
135 current_count: 0,
136 cumulative_data: Vec::new(),
137 completed_windows: Vec::new(),
138 }
139 }
140
141 pub fn aggregate(
143 &mut self,
144 field: &str,
145 function: AggregateFunction,
146 output_name: &str,
147 ) -> &mut Self {
148 self.aggregates.push(AggregateSpec {
149 field: field.to_string(),
150 function,
151 output_name: output_name.to_string(),
152 });
153 self
154 }
155
156 pub fn process(
158 &mut self,
159 timestamp: DateTime<Utc>,
160 fields: HashMap<String, ValueWord>,
161 ) -> Result<()> {
162 let data_point = WindowDataPoint { timestamp, fields };
163
164 match &self.window_type {
165 WindowType::Tumbling { size } => {
166 self.process_tumbling(&data_point, *size)?;
167 }
168 WindowType::Sliding { size, slide } => {
169 self.process_sliding(&data_point, *size, *slide)?;
170 }
171 WindowType::Session { gap } => {
172 self.process_session(&data_point, *gap)?;
173 }
174 WindowType::Count { size } => {
175 self.process_count(&data_point, *size)?;
176 }
177 WindowType::Cumulative => {
178 self.process_cumulative(&data_point)?;
179 }
180 }
181
182 Ok(())
183 }
184
185 fn process_tumbling(&mut self, data_point: &WindowDataPoint, size: Duration) -> Result<()> {
187 let window_start = self.align_to_window(data_point.timestamp, size);
188
189 let should_close = self
191 .current_window
192 .as_ref()
193 .map(|w| data_point.timestamp >= w.start + size)
194 .unwrap_or(false);
195
196 if should_close {
197 if let Some(window) = self.current_window.take() {
199 let result = self.compute_window_result(&window)?;
200 self.completed_windows.push(result);
201 }
202 }
203
204 match &mut self.current_window {
206 Some(window) => {
207 window.data.push(data_point.clone());
208 window.last_timestamp = Some(data_point.timestamp);
209 }
210 None => {
211 self.current_window = Some(WindowState {
212 start: window_start,
213 data: vec![data_point.clone()],
214 last_timestamp: Some(data_point.timestamp),
215 });
216 }
217 }
218
219 Ok(())
220 }
221
222 fn process_sliding(
224 &mut self,
225 data_point: &WindowDataPoint,
226 size: Duration,
227 slide: Duration,
228 ) -> Result<()> {
229 let ts = data_point.timestamp;
231
232 let window_start = self.align_to_window(ts, slide);
234
235 let needs_new_window = self.active_windows.is_empty()
237 || self
238 .active_windows
239 .last()
240 .map(|w| ts >= w.start + slide)
241 .unwrap_or(true);
242
243 if needs_new_window {
244 self.active_windows.push(WindowState {
245 start: window_start,
246 data: Vec::new(),
247 last_timestamp: None,
248 });
249 }
250
251 for window in &mut self.active_windows {
253 if ts >= window.start && ts < window.start + size {
254 window.data.push(data_point.clone());
255 window.last_timestamp = Some(ts);
256 }
257 }
258
259 let mut closed_indices = Vec::new();
261 for (i, window) in self.active_windows.iter().enumerate() {
262 if ts >= window.start + size {
263 let result = self.compute_window_result(window)?;
264 self.completed_windows.push(result);
265 closed_indices.push(i);
266 }
267 }
268
269 for i in closed_indices.into_iter().rev() {
271 self.active_windows.remove(i);
272 }
273
274 Ok(())
275 }
276
277 fn process_session(&mut self, data_point: &WindowDataPoint, gap: Duration) -> Result<()> {
279 let should_close = self
281 .current_window
282 .as_ref()
283 .and_then(|w| w.last_timestamp)
284 .map(|last_ts| data_point.timestamp - last_ts > gap)
285 .unwrap_or(false);
286
287 if should_close {
288 if let Some(window) = self.current_window.take() {
289 let result = self.compute_window_result(&window)?;
290 self.completed_windows.push(result);
291 }
292 }
293
294 match &mut self.current_window {
296 Some(window) => {
297 window.data.push(data_point.clone());
298 window.last_timestamp = Some(data_point.timestamp);
299 }
300 None => {
301 self.current_window = Some(WindowState {
302 start: data_point.timestamp,
303 data: vec![data_point.clone()],
304 last_timestamp: Some(data_point.timestamp),
305 });
306 }
307 }
308
309 Ok(())
310 }
311
312 fn process_count(&mut self, data_point: &WindowDataPoint, size: usize) -> Result<()> {
314 if self.current_window.is_none() {
315 self.current_window = Some(WindowState {
316 start: data_point.timestamp,
317 data: Vec::new(),
318 last_timestamp: None,
319 });
320 }
321
322 if let Some(window) = &mut self.current_window {
324 window.data.push(data_point.clone());
325 window.last_timestamp = Some(data_point.timestamp);
326 }
327 self.current_count += 1;
328
329 if self.current_count >= size {
331 if let Some(window) = self.current_window.take() {
332 let result = self.compute_window_result(&window)?;
333 self.completed_windows.push(result);
334 }
335 self.current_count = 0;
336 }
337
338 Ok(())
339 }
340
341 fn process_cumulative(&mut self, data_point: &WindowDataPoint) -> Result<()> {
343 self.cumulative_data.push(data_point.clone());
344
345 let start = self
347 .cumulative_data
348 .first()
349 .map(|d| d.timestamp)
350 .unwrap_or(data_point.timestamp);
351 let end = data_point.timestamp;
352
353 let window = WindowState {
354 start,
355 data: self.cumulative_data.clone(),
356 last_timestamp: Some(end),
357 };
358
359 let result = self.compute_window_result(&window)?;
360 self.completed_windows.push(result);
361
362 Ok(())
363 }
364
365 fn align_to_window(&self, ts: DateTime<Utc>, size: Duration) -> DateTime<Utc> {
367 let epoch = DateTime::UNIX_EPOCH;
368 let since_epoch = ts - epoch;
369 let size_millis = size.num_milliseconds();
370
371 if size_millis == 0 {
372 return ts;
373 }
374
375 let aligned_millis = (since_epoch.num_milliseconds() / size_millis) * size_millis;
376 epoch + Duration::milliseconds(aligned_millis)
377 }
378
379 fn compute_window_result(&self, window: &WindowState) -> Result<WindowResult> {
381 let mut aggregates = HashMap::new();
382
383 for spec in &self.aggregates {
384 let values: Vec<f64> = window
385 .data
386 .iter()
387 .filter_map(|d| d.fields.get(&spec.field).and_then(|v| v.as_f64()))
388 .collect();
389
390 let result = self.compute_aggregate(&values, spec.function)?;
391 aggregates.insert(spec.output_name.clone(), result);
392 }
393
394 let end = window.last_timestamp.unwrap_or(window.start);
395
396 Ok(WindowResult {
397 start: window.start,
398 end,
399 count: window.data.len(),
400 aggregates,
401 })
402 }
403
404 fn compute_aggregate(&self, values: &[f64], function: AggregateFunction) -> Result<f64> {
406 if values.is_empty() {
407 return Ok(f64::NAN);
408 }
409
410 Ok(match function {
411 AggregateFunction::Sum => values.iter().sum(),
412 AggregateFunction::Avg => values.iter().sum::<f64>() / values.len() as f64,
413 AggregateFunction::Min => values.iter().cloned().fold(f64::INFINITY, f64::min),
414 AggregateFunction::Max => values.iter().cloned().fold(f64::NEG_INFINITY, f64::max),
415 AggregateFunction::Count => values.len() as f64,
416 AggregateFunction::First => values.first().copied().unwrap_or(f64::NAN),
417 AggregateFunction::Last => values.last().copied().unwrap_or(f64::NAN),
418 AggregateFunction::StdDev => {
419 let mean = values.iter().sum::<f64>() / values.len() as f64;
420 let variance =
421 values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
422 variance.sqrt()
423 }
424 AggregateFunction::Variance => {
425 let mean = values.iter().sum::<f64>() / values.len() as f64;
426 values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64
427 }
428 })
429 }
430
431 pub fn take_completed(&mut self) -> Vec<WindowResult> {
433 std::mem::take(&mut self.completed_windows)
434 }
435
436 pub fn flush(&mut self) -> Result<Vec<WindowResult>> {
438 if let Some(ref window) = self.current_window {
440 let result = self.compute_window_result(window)?;
441 self.completed_windows.push(result);
442 }
443
444 for window in &self.active_windows {
445 let result = self.compute_window_result(window)?;
446 self.completed_windows.push(result);
447 }
448
449 self.current_window = None;
450 self.active_windows.clear();
451
452 Ok(self.take_completed())
453 }
454}
455
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed anchor instant shared by every test.
    fn base_time() -> DateTime<Utc> {
        DateTime::from_timestamp(1_000_000_000, 0).unwrap()
    }

    /// Builds the `(timestamp, fields)` pair for a point carrying one `value` field.
    fn make_data_point(
        timestamp: DateTime<Utc>,
        value: f64,
    ) -> (DateTime<Utc>, HashMap<String, ValueWord>) {
        let fields = HashMap::from([("value".to_string(), ValueWord::from_f64(value))]);
        (timestamp, fields)
    }

    /// Feeds a `value` sample at `base_time() + offset_secs` into `manager`.
    fn feed(manager: &mut WindowManager, offset_secs: i64, value: f64) {
        let (ts, fields) = make_data_point(base_time() + Duration::seconds(offset_secs), value);
        manager.process(ts, fields).unwrap();
    }

    #[test]
    fn test_tumbling_window() {
        let mut manager = WindowManager::new(WindowType::tumbling(Duration::seconds(10)));
        manager
            .aggregate("value", AggregateFunction::Sum, "sum")
            .aggregate("value", AggregateFunction::Avg, "avg");

        (0..5).for_each(|offset| feed(&mut manager, offset, 10.0));
        assert!(
            manager.take_completed().is_empty(),
            "Expected no completed windows within first window"
        );

        feed(&mut manager, 15, 20.0);

        let completed = manager.take_completed();
        assert_eq!(completed.len(), 1, "Expected exactly 1 completed window");
        let window = &completed[0];
        assert_eq!(window.count, 5, "Expected 5 data points in window");
        assert_eq!(window.aggregates.get("sum"), Some(&50.0));
        assert_eq!(window.aggregates.get("avg"), Some(&10.0));
    }

    #[test]
    fn test_count_window() {
        let mut manager = WindowManager::new(WindowType::count(3));
        manager.aggregate("value", AggregateFunction::Sum, "sum");

        for offset in 0..3i64 {
            feed(&mut manager, offset, (offset + 1) as f64);
        }

        let completed = manager.take_completed();
        assert_eq!(completed.len(), 1);
        assert_eq!(completed[0].count, 3);
        assert_eq!(completed[0].aggregates.get("sum"), Some(&6.0));
    }

    #[test]
    fn test_session_window() {
        let mut manager = WindowManager::new(WindowType::session(Duration::seconds(5)));
        manager.aggregate("value", AggregateFunction::Count, "count");

        (0..3).for_each(|offset| feed(&mut manager, offset, 1.0));
        // A 10 s jump exceeds the 5 s gap and closes the first session.
        feed(&mut manager, 10, 1.0);

        let completed = manager.take_completed();
        assert_eq!(completed.len(), 1);
        assert_eq!(completed[0].count, 3);
    }

    #[test]
    fn test_aggregate_functions() {
        let mut manager = WindowManager::new(WindowType::count(5));
        manager
            .aggregate("value", AggregateFunction::Min, "min")
            .aggregate("value", AggregateFunction::Max, "max")
            .aggregate("value", AggregateFunction::StdDev, "std");

        let samples = [1.0, 2.0, 3.0, 4.0, 5.0];
        for (offset, &value) in (0i64..).zip(samples.iter()) {
            feed(&mut manager, offset, value);
        }

        let completed = manager.take_completed();
        assert_eq!(completed.len(), 1);
        let aggregates = &completed[0].aggregates;
        assert_eq!(aggregates.get("min"), Some(&1.0));
        assert_eq!(aggregates.get("max"), Some(&5.0));
        // Population standard deviation of 1..=5 is sqrt(2) ~= 1.414.
        let std_dev = *aggregates.get("std").unwrap();
        assert!((std_dev - 1.414).abs() < 0.01);
    }

    #[test]
    fn test_flush() {
        let mut manager = WindowManager::new(WindowType::tumbling(Duration::seconds(10)));
        manager.aggregate("value", AggregateFunction::Sum, "sum");

        feed(&mut manager, 0, 42.0);

        let results = manager.flush().unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].aggregates.get("sum"), Some(&42.0));
    }
}