Skip to main content

rust_langgraph/channels/
topic.rs

1//! Topic channel implementation.
2//!
3//! A channel that accumulates all written values as a sequence.
4
5use super::BaseChannel;
6use crate::errors::Result;
7use serde::{Deserialize, Serialize};
8use std::fmt::Debug;
9use std::marker::PhantomData;
10
/// An append-only channel: every value written to it is kept, in write order.
///
/// Where `LastValue` retains only the most recent write, `Topic` stores each
/// write in a growing list. That makes it a good fit for gathering several
/// results or for keeping a running history.
///
/// # Examples
///
/// ```rust
/// use rust_langgraph::channels::{BaseChannel, Topic};
///
/// let mut topic = Topic::<String>::new();
/// topic.update(vec![serde_json::json!("first")]).unwrap();
/// topic.update(vec![serde_json::json!("second")]).unwrap();
///
/// let values: Vec<String> = serde_json::from_value(
///     topic.get().unwrap().unwrap()
/// ).unwrap();
/// assert_eq!(values, vec!["first", "second"]);
/// ```
#[derive(Clone, Debug)]
pub struct Topic<T> {
    /// Accumulated values, oldest first.
    values: Vec<T>,
    /// Zero-sized marker for the element type.
    // NOTE(review): `values: Vec<T>` already uses `T`, so this marker looks
    // redundant — confirm before removing, since the constructors set it.
    _phantom: PhantomData<T>,
}
36
37impl<T> Topic<T> {
38    /// Create a new empty Topic channel
39    pub fn new() -> Self {
40        Self {
41            values: Vec::new(),
42            _phantom: PhantomData,
43        }
44    }
45
46    /// Create a Topic with initial values
47    pub fn with_values(values: Vec<T>) -> Self {
48        Self {
49            values,
50            _phantom: PhantomData,
51        }
52    }
53
54    /// Get the number of accumulated values
55    pub fn len(&self) -> usize {
56        self.values.len()
57    }
58
59    /// Check if the topic is empty
60    pub fn is_empty(&self) -> bool {
61        self.values.is_empty()
62    }
63}
64
65impl<T> Default for Topic<T> {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71impl<T> BaseChannel for Topic<T>
72where
73    T: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + Debug + 'static,
74{
75    fn get(&self) -> Result<Option<serde_json::Value>> {
76        if self.values.is_empty() {
77            Ok(None)
78        } else {
79            Ok(Some(serde_json::to_value(&self.values)?))
80        }
81    }
82
83    fn update(&mut self, values: Vec<serde_json::Value>) -> Result<()> {
84        for value in values {
85            let typed_value: T = serde_json::from_value(value)?;
86            self.values.push(typed_value);
87        }
88        Ok(())
89    }
90
91    fn checkpoint(&self) -> Result<serde_json::Value> {
92        serde_json::to_value(&self.values).map_err(Into::into)
93    }
94
95    fn from_checkpoint(data: serde_json::Value) -> Result<Box<dyn BaseChannel>> {
96        let values: Vec<T> = serde_json::from_value(data)?;
97        Ok(Box::new(Self::with_values(values)))
98    }
99
100    fn type_name(&self) -> &'static str {
101        "Topic"
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Reads the channel's current value and decodes the JSON array into a
    /// typed vector. Panics if the channel is empty or decoding fails.
    fn read_all<T>(channel: &dyn BaseChannel) -> Vec<T>
    where
        T: for<'de> Deserialize<'de>,
    {
        let raw = channel.get().unwrap().unwrap();
        serde_json::from_value(raw).unwrap()
    }

    /// A fresh topic reads as `None`; a single write becomes a one-element list.
    #[test]
    fn test_topic_basic() {
        let mut topic: Topic<i32> = Topic::new();
        assert!(topic.get().unwrap().is_none());
        assert_eq!(topic.len(), 0);

        topic.update(vec![json!(1)]).unwrap();
        assert_eq!(topic.len(), 1);

        assert_eq!(read_all::<i32>(&topic), vec![1]);
    }

    /// Values from separate updates, and from one multi-value update, are
    /// kept in write order.
    #[test]
    fn test_topic_accumulation() {
        let mut topic: Topic<String> = Topic::new();

        let batches = vec![
            vec![json!("first")],
            vec![json!("second")],
            vec![json!("third"), json!("fourth")],
        ];
        for batch in batches {
            topic.update(batch).unwrap();
        }

        assert_eq!(
            read_all::<String>(&topic),
            vec!["first", "second", "third", "fourth"]
        );
    }

    /// A checkpoint restores to a channel holding the same values.
    #[test]
    fn test_topic_checkpoint() {
        let mut topic: Topic<i32> = Topic::new();
        topic.update(vec![json!(1), json!(2)]).unwrap();

        let snapshot = topic.checkpoint().unwrap();
        let restored = Topic::<i32>::from_checkpoint(snapshot).unwrap();

        assert_eq!(read_all::<i32>(&*restored), vec![1, 2]);
    }
}