streamweave_transformers/moving-average/moving_average_transformer.rs

1//! Builder and configuration for the MovingAverageTransformer.
2
3use std::collections::VecDeque;
4use std::sync::Arc;
5use streamweave::TransformerConfig;
6use streamweave_error::ErrorStrategy;
7use streamweave_stateful::InMemoryStateStore;
8
/// State for the moving average calculation.
///
/// Holds a sliding window of the most recently added values, oldest at the
/// front and newest at the back.
#[derive(Debug, Clone)]
pub struct MovingAverageState {
  /// The sliding window of values (oldest first).
  pub window: VecDeque<f64>,
  /// Maximum number of values retained in `window`.
  ///
  /// A value of 0 behaves like 1: the newest value is always kept.
  pub window_size: usize,
}

impl MovingAverageState {
  /// Creates a new, empty state that retains at most `window_size` values.
  ///
  /// Storage for `window_size` values is reserved up front.
  pub fn new(window_size: usize) -> Self {
    Self {
      window: VecDeque::with_capacity(window_size),
      window_size,
    }
  }

  /// Appends `value` to the window, evicting the oldest values first so that
  /// the window never exceeds `window_size` entries after the call.
  ///
  /// Both fields are public, so `window_size` may have been lowered after
  /// values were added. All surplus values are therefore evicted here, not
  /// just one.
  pub fn add_value(&mut self, value: f64) {
    // Clamp to 1 so a zero-sized window keeps only the newest value. Without
    // the clamp this loop would never terminate (`len >= 0` is always true).
    let capacity = self.window_size.max(1);
    while self.window.len() >= capacity {
      self.window.pop_front();
    }
    self.window.push_back(value);
  }

  /// Returns the arithmetic mean of the values currently in the window.
  ///
  /// Returns `0.0` when the window is empty.
  pub fn average(&self) -> f64 {
    if self.window.is_empty() {
      return 0.0;
    }
    let sum: f64 = self.window.iter().sum();
    sum / self.window.len() as f64
  }
}
46
47/// A stateful transformer that calculates a moving average over a sliding window.
48///
49/// # Example
50///
51/// ```rust
52/// use streamweave::transformers::moving_average::MovingAverageTransformer;
53/// use streamweave::transformer::Transformer;
54/// use futures::StreamExt;
55///
56/// # async fn example() {
57/// let mut transformer = MovingAverageTransformer::new(3); // 3-item window
58/// let input = Box::pin(futures::stream::iter(vec![1.0, 2.0, 3.0, 4.0, 5.0]));
59/// let output = transformer.transform(input);
60/// let results: Vec<f64> = output.collect().await;
61/// // Window: [1] -> avg 1.0
62/// // Window: [1,2] -> avg 1.5
63/// // Window: [1,2,3] -> avg 2.0
64/// // Window: [2,3,4] -> avg 3.0
65/// // Window: [3,4,5] -> avg 4.0
66/// assert_eq!(results, vec![1.0, 1.5, 2.0, 3.0, 4.0]);
67/// # }
68/// ```
69#[derive(Debug)]
70pub struct MovingAverageTransformer {
71  /// Configuration for the transformer.
72  pub(crate) config: TransformerConfig<f64>,
73  /// State store for maintaining the window (wrapped in Arc for sharing).
74  pub(crate) state_store: Arc<InMemoryStateStore<MovingAverageState>>,
75  /// Window size for the moving average.
76  pub(crate) window_size: usize,
77}
78
79impl Clone for MovingAverageTransformer {
80  fn clone(&self) -> Self {
81    Self {
82      config: self.config.clone(),
83      state_store: Arc::clone(&self.state_store),
84      window_size: self.window_size,
85    }
86  }
87}
88
89impl MovingAverageTransformer {
90  /// Creates a new MovingAverageTransformer with the specified window size.
91  ///
92  /// # Arguments
93  ///
94  /// * `window_size` - The number of recent items to include in the average.
95  ///
96  /// # Panics
97  ///
98  /// Panics if window_size is 0.
99  pub fn new(window_size: usize) -> Self {
100    assert!(window_size > 0, "Window size must be greater than 0");
101    Self {
102      config: TransformerConfig::default(),
103      state_store: Arc::new(InMemoryStateStore::new(MovingAverageState::new(
104        window_size,
105      ))),
106      window_size,
107    }
108  }
109
110  /// Sets the name for this transformer.
111  pub fn with_name(mut self, name: String) -> Self {
112    self.config.name = Some(name);
113    self
114  }
115
116  /// Sets the error strategy for this transformer.
117  pub fn with_error_strategy(mut self, strategy: ErrorStrategy<f64>) -> Self {
118    self.config.error_strategy = strategy;
119    self
120  }
121
122  /// Returns the window size.
123  pub fn window_size(&self) -> usize {
124    self.window_size
125  }
126}
127
#[cfg(test)]
mod tests {
  use super::*;

  #[test]
  fn test_moving_average_state_new() {
    let state = MovingAverageState::new(5);
    assert_eq!(state.window_size, 5);
    assert!(state.window.is_empty());
  }

  #[test]
  fn test_moving_average_state_add_value() {
    let mut state = MovingAverageState::new(3);
    state.add_value(10.0);
    assert_eq!(state.window.len(), 1);
    assert_eq!(state.window[0], 10.0);
  }

  #[test]
  fn test_moving_average_state_add_value_overflow() {
    let mut state = MovingAverageState::new(2);
    state.add_value(1.0);
    state.add_value(2.0);
    state.add_value(3.0); // Should remove 1.0
    assert_eq!(state.window.len(), 2);
    assert_eq!(state.window[0], 2.0);
    assert_eq!(state.window[1], 3.0);
  }

  #[test]
  fn test_moving_average_state_zero_window_keeps_latest() {
    // A zero-sized state window still retains exactly the newest value.
    let mut state = MovingAverageState::new(0);
    state.add_value(4.0);
    state.add_value(8.0);
    assert_eq!(state.window.len(), 1);
    assert_eq!(state.average(), 8.0);
  }

  #[test]
  fn test_moving_average_state_average_empty() {
    let state = MovingAverageState::new(3);
    assert_eq!(state.average(), 0.0);
  }

  #[test]
  fn test_moving_average_state_average_single() {
    let mut state = MovingAverageState::new(3);
    state.add_value(10.0);
    assert_eq!(state.average(), 10.0);
  }

  #[test]
  fn test_moving_average_state_average_multiple() {
    let mut state = MovingAverageState::new(3);
    state.add_value(10.0);
    state.add_value(20.0);
    state.add_value(30.0);
    assert_eq!(state.average(), 20.0);
  }

  #[test]
  fn test_moving_average_state_clone() {
    let mut state1 = MovingAverageState::new(3);
    state1.add_value(5.0);
    state1.add_value(10.0);

    let state2 = state1.clone();
    assert_eq!(state1.window, state2.window);
    assert_eq!(state1.window_size, state2.window_size);
    assert_eq!(state1.average(), state2.average());
  }

  #[test]
  fn test_moving_average_transformer_new() {
    let transformer = MovingAverageTransformer::new(5);
    assert_eq!(transformer.window_size(), 5);
  }

  #[test]
  #[should_panic(expected = "Window size must be greater than 0")]
  fn test_moving_average_transformer_new_zero_panics() {
    let _ = MovingAverageTransformer::new(0);
  }

  #[test]
  fn test_moving_average_transformer_with_name() {
    let transformer = MovingAverageTransformer::new(3).with_name("test_moving_avg".to_string());
    assert_eq!(transformer.config.name, Some("test_moving_avg".to_string()));
  }

  #[test]
  fn test_moving_average_transformer_with_error_strategy() {
    let transformer =
      MovingAverageTransformer::new(3).with_error_strategy(ErrorStrategy::<f64>::Skip);
    assert!(matches!(
      transformer.config.error_strategy,
      ErrorStrategy::Skip
    ));
  }

  #[test]
  fn test_moving_average_transformer_clone() {
    let transformer1 = MovingAverageTransformer::new(5);
    let transformer2 = transformer1.clone();

    assert_eq!(transformer1.window_size(), transformer2.window_size());
    // Cloning shares the state store (Arc) instead of copying it.
    assert!(Arc::ptr_eq(
      &transformer1.state_store,
      &transformer2.state_store
    ));
  }

  #[test]
  fn test_moving_average_transformer_chaining() {
    let transformer = MovingAverageTransformer::new(7)
      .with_error_strategy(ErrorStrategy::<f64>::Retry(3))
      .with_name("chained_moving_avg".to_string());

    assert_eq!(transformer.window_size(), 7);
    assert!(matches!(
      transformer.config.error_strategy,
      ErrorStrategy::Retry(3)
    ));
    assert_eq!(
      transformer.config.name,
      Some("chained_moving_avg".to_string())
    );
  }
}