// rust_langgraph/channels/binop.rs

use std::fmt::Debug;
use std::sync::Arc;

use serde::{Deserialize, Serialize};

use crate::errors::Result;
use super::BaseChannel;

/// A channel whose value is combined with every incoming update through a
/// user-supplied binary reducer, as `value = reducer(value, update)`.
///
/// The reducer is held in an [`Arc`]. Trait bounds are declared on the `impl`
/// blocks that need them rather than on the type itself, so merely naming the
/// type does not require restating `Fn(T, T) -> T + Send + Sync`.
pub struct BinaryOperatorAggregate<T, F> {
    /// Current accumulated value.
    value: T,
    /// Binary operation applied as `reducer(current, incoming)`.
    reducer: Arc<F>,
}

35impl<T, F> BinaryOperatorAggregate<T, F>
36where
37 T: Clone,
38 F: Fn(T, T) -> T + Send + Sync + 'static,
39{
40 pub fn new(initial: T, reducer: F) -> Self {
42 Self {
43 value: initial,
44 reducer: Arc::new(reducer),
45 }
46 }
47
48 pub fn value(&self) -> &T {
50 &self.value
51 }
52}
53
54impl<T, F> Debug for BinaryOperatorAggregate<T, F>
55where
56 T: Debug,
57 F: Fn(T, T) -> T + Send + Sync,
58{
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 f.debug_struct("BinaryOperatorAggregate")
61 .field("value", &self.value)
62 .field("reducer", &"<function>")
63 .finish()
64 }
65}
66
67impl<T, F> BaseChannel for BinaryOperatorAggregate<T, F>
68where
69 T: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + Debug + 'static,
70 F: Fn(T, T) -> T + Send + Sync + 'static,
71{
72 fn get(&self) -> Result<Option<serde_json::Value>> {
73 Ok(Some(serde_json::to_value(&self.value)?))
74 }
75
76 fn update(&mut self, values: Vec<serde_json::Value>) -> Result<()> {
77 for value_json in values {
78 let new_value: T = serde_json::from_value(value_json)?;
79 self.value = (self.reducer)(self.value.clone(), new_value);
80 }
81 Ok(())
82 }
83
84 fn checkpoint(&self) -> Result<serde_json::Value> {
85 serde_json::to_value(&self.value).map_err(Into::into)
86 }
87
88 fn from_checkpoint(_data: serde_json::Value) -> Result<Box<dyn BaseChannel>> {
89 Err(crate::errors::Error::channel(
94 "BinaryOperatorAggregate cannot be restored from checkpoint alone - requires reducer function",
95 ))
96 }
97
98 fn type_name(&self) -> &'static str {
99 "BinaryOperatorAggregate"
100 }
101
102 fn is_empty(&self) -> bool {
103 false }
105}
106
107pub fn sum_channel<T>(initial: T) -> BinaryOperatorAggregate<T, impl Fn(T, T) -> T + Send + Sync>
111where
112 T: std::ops::Add<Output = T> + Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + Debug + 'static,
113{
114 BinaryOperatorAggregate::new(initial, |a, b| a + b)
115}
116
117pub fn max_channel<T>(initial: T) -> BinaryOperatorAggregate<T, impl Fn(T, T) -> T + Send + Sync>
119where
120 T: Ord + Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + Debug + 'static,
121{
122 BinaryOperatorAggregate::new(initial, |a, b| a.max(b))
123}
124
125pub fn min_channel<T>(initial: T) -> BinaryOperatorAggregate<T, impl Fn(T, T) -> T + Send + Sync>
127where
128 T: Ord + Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + Debug + 'static,
129{
130 BinaryOperatorAggregate::new(initial, |a, b| a.min(b))
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_binop_sum() {
        let mut channel = BinaryOperatorAggregate::new(0, |a, b| a + b);
        assert_eq!(channel.get().unwrap(), Some(serde_json::json!(0)));

        channel
            .update(vec![
                serde_json::json!(1),
                serde_json::json!(2),
                serde_json::json!(3),
            ])
            .unwrap();

        assert_eq!(channel.get().unwrap(), Some(serde_json::json!(6)));
        assert_eq!(*channel.value(), 6);
    }

    #[test]
    fn test_binop_max() {
        let mut channel = max_channel(0);
        channel
            .update(vec![
                serde_json::json!(5),
                serde_json::json!(2),
                serde_json::json!(8),
                serde_json::json!(3),
            ])
            .unwrap();

        assert_eq!(channel.get().unwrap(), Some(serde_json::json!(8)));
    }

    #[test]
    fn test_binop_min() {
        let mut channel = min_channel(100);
        channel
            .update(vec![
                serde_json::json!(50),
                serde_json::json!(75),
                serde_json::json!(25),
            ])
            .unwrap();

        assert_eq!(channel.get().unwrap(), Some(serde_json::json!(25)));
    }

    #[test]
    fn test_binop_custom() {
        let mut channel = BinaryOperatorAggregate::new(1, |a: i32, b: i32| a * b);
        channel
            .update(vec![
                serde_json::json!(2),
                serde_json::json!(3),
                serde_json::json!(4),
            ])
            .unwrap();

        assert_eq!(channel.get().unwrap(), Some(serde_json::json!(24)));
    }

    #[test]
    fn test_binop_checkpoint() {
        let mut channel = sum_channel(0);
        channel.update(vec![serde_json::json!(10)]).unwrap();

        let checkpoint = channel.checkpoint().unwrap();
        assert_eq!(checkpoint, serde_json::json!(10));
    }

    /// An empty batch must leave the accumulated value untouched.
    #[test]
    fn test_binop_empty_update_is_noop() {
        let mut channel = sum_channel(5);
        channel.update(Vec::new()).unwrap();

        assert_eq!(channel.get().unwrap(), Some(serde_json::json!(5)));
    }

    /// A value that cannot be deserialized as `T` is reported as an error and
    /// does not change the channel.
    #[test]
    fn test_binop_rejects_mistyped_update() {
        let mut channel = sum_channel(1);
        assert!(channel
            .update(vec![serde_json::json!("not a number")])
            .is_err());

        assert_eq!(channel.get().unwrap(), Some(serde_json::json!(1)));
    }

    #[test]
    fn test_binop_is_never_empty() {
        let channel = min_channel(0);
        assert!(!channel.is_empty());
        assert_eq!(channel.type_name(), "BinaryOperatorAggregate");
    }

    /// The reducer is not serialized, so restoring from a checkpoint alone must fail.
    #[test]
    fn test_binop_from_checkpoint_requires_reducer() {
        let restored =
            <BinaryOperatorAggregate<i32, fn(i32, i32) -> i32> as BaseChannel>::from_checkpoint(
                serde_json::json!(3),
            );
        assert!(restored.is_err());
    }
}