Skip to main content

rust_langgraph/channels/
last_value.rs

1//! LastValue channel implementation.
2//!
3//! A channel that stores only the most recent value written to it.
4
5use super::BaseChannel;
6use crate::errors::Result;
7use serde::{Deserialize, Serialize};
8use std::fmt::Debug;
9use std::marker::PhantomData;
10
/// Channel that remembers only the most recently written value.
///
/// If several values are written during a single step, every value except
/// the final one is discarded. This is the most common channel type.
///
/// # Examples
///
/// ```rust
/// use rust_langgraph::channels::{BaseChannel, LastValue};
///
/// let mut channel = LastValue::<i32>::new();
/// channel.update(vec![serde_json::json!(1), serde_json::json!(2)]).unwrap();
/// assert_eq!(channel.get().unwrap(), Some(serde_json::json!(2)));
/// ```
#[derive(Clone, Debug)]
pub struct LastValue<T> {
    /// Current contents of the channel; `None` until a value is stored.
    value: Option<T>,
    /// Zero-sized marker for `T`; carries no data.
    _phantom: PhantomData<T>,
}
30
31impl<T> LastValue<T> {
32    /// Create a new empty LastValue channel
33    pub fn new() -> Self {
34        Self {
35            value: None,
36            _phantom: PhantomData,
37        }
38    }
39
40    /// Create a LastValue channel with an initial value
41    pub fn with_value(value: T) -> Self {
42        Self {
43            value: Some(value),
44            _phantom: PhantomData,
45        }
46    }
47}
48
49impl<T> Default for LastValue<T> {
50    fn default() -> Self {
51        Self::new()
52    }
53}
54
55impl<T> BaseChannel for LastValue<T>
56where
57    T: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + Debug + 'static,
58{
59    fn get(&self) -> Result<Option<serde_json::Value>> {
60        match &self.value {
61            Some(v) => Ok(Some(serde_json::to_value(v)?)),
62            None => Ok(None),
63        }
64    }
65
66    fn update(&mut self, values: Vec<serde_json::Value>) -> Result<()> {
67        if let Some(last) = values.last() {
68            self.value = Some(serde_json::from_value(last.clone())?);
69        }
70        Ok(())
71    }
72
73    fn checkpoint(&self) -> Result<serde_json::Value> {
74        match &self.value {
75            Some(v) => serde_json::to_value(v).map_err(Into::into),
76            None => Ok(serde_json::Value::Null),
77        }
78    }
79
80    fn from_checkpoint(data: serde_json::Value) -> Result<Box<dyn BaseChannel>> {
81        if data.is_null() {
82            Ok(Box::new(Self::new()))
83        } else {
84            let value: T = serde_json::from_value(data)?;
85            Ok(Box::new(Self::with_value(value)))
86        }
87    }
88
89    fn type_name(&self) -> &'static str {
90        "LastValue"
91    }
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A fresh channel is empty and reports a value once one is written.
    #[test]
    fn test_last_value_basic() {
        let mut channel = LastValue::<i32>::new();
        assert!(channel.get().unwrap().is_none());

        channel.update(vec![json!(42)]).unwrap();
        assert_eq!(channel.get().unwrap(), Some(json!(42)));
    }

    /// Only the final value of a multi-value write is kept.
    #[test]
    fn test_last_value_multiple_writes() {
        let mut channel = LastValue::<i32>::new();
        channel.update(vec![json!(1), json!(2), json!(3)]).unwrap();

        assert_eq!(channel.get().unwrap(), Some(json!(3)));
    }

    /// A checkpoint restores to a channel holding the same value.
    #[test]
    fn test_last_value_checkpoint() {
        let mut channel = LastValue::<String>::new();
        channel.update(vec![json!("hello")]).unwrap();

        let checkpoint = channel.checkpoint().unwrap();
        assert_eq!(checkpoint, json!("hello"));

        let restored = LastValue::<String>::from_checkpoint(checkpoint).unwrap();
        assert_eq!(restored.get().unwrap(), Some(json!("hello")));
    }

    /// `with_value` seeds the channel without going through `update`.
    #[test]
    fn test_last_value_with_value() {
        let channel = LastValue::with_value(100);
        assert_eq!(channel.get().unwrap(), Some(json!(100)));
    }

    /// An update carrying no values leaves the stored value untouched.
    #[test]
    fn test_last_value_empty_update_keeps_value() {
        let mut channel = LastValue::<i32>::with_value(7);
        channel.update(Vec::new()).unwrap();
        assert_eq!(channel.get().unwrap(), Some(json!(7)));
    }

    /// A write that cannot be deserialized fails and keeps the old value.
    #[test]
    fn test_last_value_rejected_update_keeps_value() {
        let mut channel = LastValue::<i32>::with_value(7);
        assert!(channel.update(vec![json!("not a number")]).is_err());
        assert_eq!(channel.get().unwrap(), Some(json!(7)));
    }

    /// An empty channel checkpoints to `null` and restores as empty.
    #[test]
    fn test_last_value_empty_checkpoint_round_trip() {
        let channel = LastValue::<i32>::new();
        let checkpoint = channel.checkpoint().unwrap();
        assert!(checkpoint.is_null());

        let restored = LastValue::<i32>::from_checkpoint(checkpoint).unwrap();
        assert!(restored.get().unwrap().is_none());
    }
}