// streamweave_transformers/sample/sample_transformer.rs

1use std::marker::PhantomData;
2use streamweave::TransformerConfig;
3use streamweave_error::ErrorStrategy;
4
5/// A transformer that randomly samples items from the stream.
6///
7/// This transformer passes through each item with a given probability,
8/// effectively creating a random sample of the input stream.
9pub struct SampleTransformer<T>
10where
11  T: std::fmt::Debug + Clone + Send + Sync + 'static,
12{
13  /// The probability (0.0 to 1.0) that an item will be passed through.
14  pub probability: f64,
15  /// Configuration for the transformer, including error handling strategy.
16  pub config: TransformerConfig<T>,
17  /// Phantom data to track the type parameter.
18  pub _phantom: PhantomData<T>,
19}
20
21impl<T> SampleTransformer<T>
22where
23  T: std::fmt::Debug + Clone + Send + Sync + 'static,
24{
25  /// Creates a new `SampleTransformer` with the given probability.
26  ///
27  /// # Arguments
28  ///
29  /// * `probability` - The probability (0.0 to 1.0) that an item will be passed through.
30  ///
31  /// # Panics
32  ///
33  /// Panics if `probability` is not between 0.0 and 1.0 (inclusive).
34  pub fn new(probability: f64) -> Self {
35    assert!(
36      (0.0..=1.0).contains(&probability),
37      "Probability must be between 0 and 1"
38    );
39    Self {
40      probability,
41      config: TransformerConfig::default(),
42      _phantom: PhantomData,
43    }
44  }
45
46  /// Sets the error handling strategy for this transformer.
47  ///
48  /// # Arguments
49  ///
50  /// * `strategy` - The error handling strategy to use.
51  pub fn with_error_strategy(mut self, strategy: ErrorStrategy<T>) -> Self {
52    self.config.error_strategy = strategy;
53    self
54  }
55
56  /// Sets the name for this transformer.
57  ///
58  /// # Arguments
59  ///
60  /// * `name` - The name to assign to this transformer.
61  pub fn with_name(mut self, name: String) -> Self {
62    self.config.name = Some(name);
63    self
64  }
65}
66
#[cfg(test)]
mod tests {
  // Unit tests for `SampleTransformer` construction, validation and builder methods.
  use super::*;

  #[test]
  fn test_sample_transformer_new() {
    let transformer = SampleTransformer::<i32>::new(0.5);
    assert_eq!(transformer.probability, 0.5);
  }

  #[test]
  fn test_sample_transformer_new_zero() {
    let transformer = SampleTransformer::<i32>::new(0.0);
    assert_eq!(transformer.probability, 0.0);
  }

  #[test]
  fn test_sample_transformer_new_one() {
    let transformer = SampleTransformer::<i32>::new(1.0);
    assert_eq!(transformer.probability, 1.0);
  }

  #[test]
  #[should_panic(expected = "Probability must be between 0 and 1")]
  fn test_sample_transformer_new_invalid_negative() {
    let _ = SampleTransformer::<i32>::new(-0.1);
  }

  #[test]
  #[should_panic(expected = "Probability must be between 0 and 1")]
  fn test_sample_transformer_new_invalid_above_one() {
    let _ = SampleTransformer::<i32>::new(1.1);
  }

  // NaN compares false against both range bounds, so `contains` rejects it;
  // pin that so a future refactor of the check cannot silently accept NaN.
  #[test]
  #[should_panic(expected = "Probability must be between 0 and 1")]
  fn test_sample_transformer_new_invalid_nan() {
    let _ = SampleTransformer::<i32>::new(f64::NAN);
  }

  #[test]
  #[should_panic(expected = "Probability must be between 0 and 1")]
  fn test_sample_transformer_new_invalid_infinite() {
    let _ = SampleTransformer::<i32>::new(f64::INFINITY);
  }

  #[test]
  fn test_sample_transformer_with_error_strategy() {
    let transformer =
      SampleTransformer::<i32>::new(0.5).with_error_strategy(ErrorStrategy::<i32>::Skip);
    assert!(matches!(
      transformer.config.error_strategy,
      ErrorStrategy::Skip
    ));
  }

  #[test]
  fn test_sample_transformer_with_name() {
    let transformer = SampleTransformer::<i32>::new(0.5).with_name("test_sample".to_string());
    assert_eq!(transformer.config.name, Some("test_sample".to_string()));
  }

  // The builder overwrites the stored name, so the last call wins.
  #[test]
  fn test_sample_transformer_with_name_overrides_previous() {
    let transformer = SampleTransformer::<i32>::new(0.5)
      .with_name("first".to_string())
      .with_name("second".to_string());
    assert_eq!(transformer.config.name, Some("second".to_string()));
  }

  #[test]
  fn test_sample_transformer_chaining() {
    let transformer = SampleTransformer::<i32>::new(0.75)
      .with_error_strategy(ErrorStrategy::<i32>::Retry(3))
      .with_name("chained_sample".to_string());
    assert_eq!(transformer.probability, 0.75);
    assert!(matches!(
      transformer.config.error_strategy,
      ErrorStrategy::Retry(3)
    ));
    assert_eq!(transformer.config.name, Some("chained_sample".to_string()));
  }
}