Skip to main content

rust_langgraph/channels/
binop.rs

//! Binary operator aggregate channel implementation.
//!
//! A channel that folds every write into a single stored value using a binary reducer function.
4
5use super::BaseChannel;
6use crate::errors::Result;
7use serde::{Deserialize, Serialize};
8use std::fmt::Debug;
9use std::sync::Arc;
10
/// A channel that reduces values using a binary operator.
///
/// The channel always holds exactly one value, starting from the initial value
/// passed to [`BinaryOperatorAggregate::new`]. Each value written to it is
/// combined with the current one as `reducer(current, incoming)`. This makes it
/// useful for running sums, maxima/minima, or any other binary reduction.
///
/// Writes passed to a single `update` call are folded in the order given. The
/// order in which values from separate writers arrive is decided by the caller,
/// so associative and commutative reducers are the safest choice.
///
/// Trait bounds live on the `impl` blocks rather than on the struct, so the type
/// can be named for any `F`. Constructing it still requires `F: Fn(T, T) -> T`.
///
/// # Type parameters
///
/// - `T`: the aggregated value type.
/// - `F`: the reducer, called with the current value and an incoming value.
///
/// # Examples
///
/// ```rust
/// use rust_langgraph::channels::{BaseChannel, BinaryOperatorAggregate};
///
/// // Sum reducer
/// let mut channel = BinaryOperatorAggregate::new(0, |a, b| a + b);
/// channel
///     .update(vec![serde_json::json!(1), serde_json::json!(2), serde_json::json!(3)])
///     .unwrap();
/// assert_eq!(channel.get().unwrap(), Some(serde_json::json!(6)));
/// ```
pub struct BinaryOperatorAggregate<T, F> {
    /// The current aggregated value.
    value: T,
    /// The reducer, stored behind an `Arc`.
    reducer: Arc<F>,
}
34
35impl<T, F> BinaryOperatorAggregate<T, F>
36where
37    T: Clone,
38    F: Fn(T, T) -> T + Send + Sync + 'static,
39{
40    /// Create a new BinaryOperatorAggregate with an initial value and reducer
41    pub fn new(initial: T, reducer: F) -> Self {
42        Self {
43            value: initial,
44            reducer: Arc::new(reducer),
45        }
46    }
47
48    /// Get a reference to the current value
49    pub fn value(&self) -> &T {
50        &self.value
51    }
52}
53
54impl<T, F> Debug for BinaryOperatorAggregate<T, F>
55where
56    T: Debug,
57    F: Fn(T, T) -> T + Send + Sync,
58{
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        f.debug_struct("BinaryOperatorAggregate")
61            .field("value", &self.value)
62            .field("reducer", &"<function>")
63            .finish()
64    }
65}
66
67impl<T, F> BaseChannel for BinaryOperatorAggregate<T, F>
68where
69    T: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + Debug + 'static,
70    F: Fn(T, T) -> T + Send + Sync + 'static,
71{
72    fn get(&self) -> Result<Option<serde_json::Value>> {
73        Ok(Some(serde_json::to_value(&self.value)?))
74    }
75
76    fn update(&mut self, values: Vec<serde_json::Value>) -> Result<()> {
77        for value_json in values {
78            let new_value: T = serde_json::from_value(value_json)?;
79            self.value = (self.reducer)(self.value.clone(), new_value);
80        }
81        Ok(())
82    }
83
84    fn checkpoint(&self) -> Result<serde_json::Value> {
85        serde_json::to_value(&self.value).map_err(Into::into)
86    }
87
88    fn from_checkpoint(_data: serde_json::Value) -> Result<Box<dyn BaseChannel>> {
89        // Note: We can't fully restore without the reducer function
90        // This is a limitation of the type-erased channel system
91        // In practice, channels are created by the graph and checkpoints
92        // only restore the data, not the channel instances themselves
93        Err(crate::errors::Error::channel(
94            "BinaryOperatorAggregate cannot be restored from checkpoint alone - requires reducer function",
95        ))
96    }
97
98    fn type_name(&self) -> &'static str {
99        "BinaryOperatorAggregate"
100    }
101
102    fn is_empty(&self) -> bool {
103        false // Always has a value (at least the initial value)
104    }
105}
106
107// Common reducer implementations
108
109/// Create a sum reducer channel
110pub fn sum_channel<T>(initial: T) -> BinaryOperatorAggregate<T, impl Fn(T, T) -> T + Send + Sync>
111where
112    T: std::ops::Add<Output = T> + Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + Debug + 'static,
113{
114    BinaryOperatorAggregate::new(initial, |a, b| a + b)
115}
116
117/// Create a max reducer channel
118pub fn max_channel<T>(initial: T) -> BinaryOperatorAggregate<T, impl Fn(T, T) -> T + Send + Sync>
119where
120    T: Ord + Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + Debug + 'static,
121{
122    BinaryOperatorAggregate::new(initial, |a, b| a.max(b))
123}
124
125/// Create a min reducer channel
126pub fn min_channel<T>(initial: T) -> BinaryOperatorAggregate<T, impl Fn(T, T) -> T + Send + Sync>
127where
128    T: Ord + Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + Debug + 'static,
129{
130    BinaryOperatorAggregate::new(initial, |a, b| a.min(b))
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_binop_sum() {
        let mut channel = BinaryOperatorAggregate::new(0, |a, b| a + b);
        assert_eq!(channel.get().unwrap(), Some(serde_json::json!(0)));

        channel
            .update(vec![
                serde_json::json!(1),
                serde_json::json!(2),
                serde_json::json!(3),
            ])
            .unwrap();

        assert_eq!(channel.get().unwrap(), Some(serde_json::json!(6)));
    }

    #[test]
    fn test_binop_max() {
        let mut channel = max_channel(0);
        channel
            .update(vec![
                serde_json::json!(5),
                serde_json::json!(2),
                serde_json::json!(8),
                serde_json::json!(3),
            ])
            .unwrap();

        assert_eq!(channel.get().unwrap(), Some(serde_json::json!(8)));
    }

    #[test]
    fn test_binop_min() {
        let mut channel = min_channel(100);
        channel
            .update(vec![
                serde_json::json!(50),
                serde_json::json!(75),
                serde_json::json!(25),
            ])
            .unwrap();

        assert_eq!(channel.get().unwrap(), Some(serde_json::json!(25)));
    }

    #[test]
    fn test_binop_custom() {
        // Product reducer
        let mut channel = BinaryOperatorAggregate::new(1, |a: i32, b: i32| a * b);
        channel
            .update(vec![
                serde_json::json!(2),
                serde_json::json!(3),
                serde_json::json!(4),
            ])
            .unwrap();

        assert_eq!(channel.get().unwrap(), Some(serde_json::json!(24)));
    }

    #[test]
    fn test_binop_checkpoint() {
        let mut channel = sum_channel(0);
        channel.update(vec![serde_json::json!(10)]).unwrap();

        let checkpoint = channel.checkpoint().unwrap();
        assert_eq!(checkpoint, serde_json::json!(10));
    }

    #[test]
    fn test_binop_empty_update_keeps_value() {
        let mut channel = sum_channel(5);
        channel.update(vec![]).unwrap();

        assert_eq!(channel.get().unwrap(), Some(serde_json::json!(5)));
        assert_eq!(*channel.value(), 5);
    }

    #[test]
    fn test_binop_invalid_write_is_error() {
        let mut channel = sum_channel(0);
        assert!(channel
            .update(vec![serde_json::json!("not a number")])
            .is_err());
    }

    #[test]
    fn test_binop_is_never_empty() {
        let channel = sum_channel(0);
        assert!(!channel.is_empty());
    }

    #[test]
    fn test_binop_from_checkpoint_requires_reducer() {
        let restored = BinaryOperatorAggregate::<i32, fn(i32, i32) -> i32>::from_checkpoint(
            serde_json::json!(1),
        );
        assert!(restored.is_err());
    }

    #[test]
    fn test_binop_debug_hides_reducer() {
        let channel = BinaryOperatorAggregate::new(3, |a: i32, b: i32| a + b);
        assert_eq!(
            format!("{:?}", channel),
            "BinaryOperatorAggregate { value: 3, reducer: \"<function>\" }"
        );
    }

    #[test]
    fn test_binop_folds_left_to_right() {
        // Non-commutative reducer: ((10 - 1) - 2) == 7.
        let mut channel = BinaryOperatorAggregate::new(10, |a: i32, b: i32| a - b);
        channel
            .update(vec![serde_json::json!(1), serde_json::json!(2)])
            .unwrap();

        assert_eq!(channel.get().unwrap(), Some(serde_json::json!(7)));
    }

    #[test]
    fn test_binop_non_copy_values() {
        let mut channel =
            BinaryOperatorAggregate::new(String::from("a"), |a: String, b: String| a + &b);
        channel
            .update(vec![serde_json::json!("b"), serde_json::json!("c")])
            .unwrap();

        assert_eq!(channel.value().as_str(), "abc");
        assert_eq!(channel.get().unwrap(), Some(serde_json::json!("abc")));
    }
}