swiftide_integrations/redis/
message_history.rs

1use anyhow::{Context as _, Result};
2use async_trait::async_trait;
3use swiftide_core::{MessageHistory, chat_completion::ChatMessage, indexing::Chunk};
4
5use super::Redis;
6
7#[async_trait]
8impl<T: Chunk> MessageHistory for Redis<T> {
9    async fn history(&self) -> Result<Vec<ChatMessage>> {
10        if let Some(mut cm) = self.lazy_connect().await {
11            let messages: Vec<String> = redis::cmd("LRANGE")
12                .arg(&self.message_history_key)
13                .arg(0)
14                .arg(-1)
15                .query_async(&mut cm)
16                .await
17                .context("Error fetching message history")?;
18            let chat_messages: Result<Vec<ChatMessage>> = messages
19                .into_iter()
20                .map(|msg| serde_json::from_str(&msg).context("Error deserializing message"))
21                .collect();
22            chat_messages
23        } else {
24            anyhow::bail!("Failed to connect to Redis")
25        }
26    }
27
28    async fn push_owned(&self, item: ChatMessage) -> Result<()> {
29        if let Some(mut cm) = self.lazy_connect().await {
30            redis::cmd("RPUSH")
31                .arg(&self.message_history_key)
32                .arg(serde_json::to_string(&item)?)
33                .query_async::<()>(&mut cm)
34                .await
35                .context("Error pushing to message history")?;
36            Ok(())
37        } else {
38            anyhow::bail!("Failed to connect to Redis")
39        }
40    }
41
42    async fn extend_owned(&self, items: Vec<ChatMessage>) -> Result<()> {
43        if let Some(mut cm) = self.lazy_connect().await {
44            // If it does not exist yet, we can just push the items
45            let _ = redis::cmd("DEL")
46                .arg(&self.message_history_key)
47                .query_async::<()>(&mut cm)
48                .await;
49
50            redis::cmd("RPUSH")
51                .arg(&self.message_history_key)
52                .arg(
53                    items
54                        .iter()
55                        .map(serde_json::to_string)
56                        .collect::<Result<Vec<_>, _>>()?,
57                )
58                .query_async::<()>(&mut cm)
59                .await
60                .context("Error pushing to message history")?;
61            Ok(())
62        } else {
63            anyhow::bail!("Failed to connect to Redis")
64        }
65    }
66
67    async fn overwrite(&self, items: Vec<ChatMessage>) -> Result<()> {
68        if let Some(mut cm) = self.lazy_connect().await {
69            // If it does not exist yet, we can just push the items
70            let _ = redis::cmd("DEL")
71                .arg(&self.message_history_key)
72                .query_async::<()>(&mut cm)
73                .await;
74
75            if items.is_empty() {
76                // If we are overwriting with an empty history, we can just return
77                return Ok(());
78            }
79
80            redis::cmd("RPUSH")
81                .arg(&self.message_history_key)
82                .arg(
83                    items
84                        .iter()
85                        .map(serde_json::to_string)
86                        .collect::<Result<Vec<_>, _>>()?,
87                )
88                .query_async::<()>(&mut cm)
89                .await
90                .context("Error pushing to message history")?;
91            Ok(())
92        } else {
93            anyhow::bail!("Failed to connect to Redis")
94        }
95    }
96}
97
98#[cfg(test)]
99mod tests {
100    use testcontainers::{ContainerAsync, GenericImage, runners::AsyncRunner as _};
101
102    use super::*;
103
104    async fn start_redis() -> (String, ContainerAsync<GenericImage>) {
105        let redis_container = testcontainers::GenericImage::new("redis", "7.2.4")
106            .with_exposed_port(6379.into())
107            .with_wait_for(testcontainers::core::WaitFor::message_on_stdout(
108                "Ready to accept connections",
109            ))
110            .start()
111            .await
112            .expect("Redis started");
113
114        let host = redis_container.get_host().await.unwrap();
115        let port = redis_container.get_host_port_ipv4(6379).await.unwrap();
116
117        let url = format!("redis://{host}:{port}/");
118
119        (url, redis_container)
120    }
121
122    #[tokio::test]
123    async fn test_no_messages_yet() {
124        let (url, _container) = start_redis().await;
125        let redis = Redis::try_from_url(url, "tests").unwrap();
126
127        let messages = redis.history().await.unwrap();
128        assert!(
129            messages.is_empty(),
130            "Expected history to be empty for new Redis key"
131        );
132    }
133
134    #[tokio::test]
135    async fn test_adding_and_next_completions() {
136        let (url, _container) = start_redis().await;
137        let redis = Redis::try_from_url(url, "tests").unwrap();
138
139        let m1 = ChatMessage::System("System test".to_string());
140        let m2 = ChatMessage::User("User test".to_string());
141
142        redis.push_owned(m1.clone()).await.unwrap();
143        redis.push_owned(m2.clone()).await.unwrap();
144
145        let hist = redis.history().await.unwrap();
146        assert_eq!(
147            hist,
148            vec![m1.clone(), m2.clone()],
149            "History should match what's pushed"
150        );
151
152        let hist2 = redis.history().await.unwrap();
153        assert_eq!(
154            hist2,
155            vec![m1, m2],
156            "History should be unchanged on repeated call"
157        );
158    }
159
160    #[tokio::test]
161    async fn test_overwrite_history() {
162        let (url, _container) = start_redis().await;
163        let redis = Redis::try_from_url(url, "tests").unwrap();
164
165        // Check that overwrite on empty also works
166        redis.overwrite(vec![]).await.unwrap();
167
168        let m1 = ChatMessage::System("First".to_string());
169        let m2 = ChatMessage::User("Second".to_string());
170        redis.push_owned(m1.clone()).await.unwrap();
171        redis.push_owned(m2.clone()).await.unwrap();
172
173        let m3 = ChatMessage::Assistant(Some("Overwritten".to_string()), None);
174        redis.overwrite(vec![m3.clone()]).await.unwrap();
175
176        let hist = redis.history().await.unwrap();
177        assert_eq!(
178            hist,
179            vec![m3],
180            "History should only contain the overwritten message"
181        );
182    }
183
184    #[tokio::test]
185    async fn test_extend() {
186        let (url, _container) = start_redis().await;
187        let redis = Redis::try_from_url(url, "tests").unwrap();
188
189        let m1 = ChatMessage::System("First".to_string());
190        let m2 = ChatMessage::User("Second".to_string());
191        redis
192            .extend_owned(vec![m1.clone(), m2.clone()])
193            .await
194            .unwrap();
195
196        let hist = redis.history().await.unwrap();
197        assert_eq!(
198            hist,
199            vec![m1, m2],
200            "History should only contain the overwritten message"
201        );
202    }
203}