swiftide_integrations/redis/
message_history.rs1use anyhow::{Context as _, Result};
2use async_trait::async_trait;
3use swiftide_core::{MessageHistory, chat_completion::ChatMessage, indexing::Chunk};
4
5use super::Redis;
6
7#[async_trait]
8impl<T: Chunk> MessageHistory for Redis<T> {
9 async fn history(&self) -> Result<Vec<ChatMessage>> {
10 if let Some(mut cm) = self.lazy_connect().await {
11 let messages: Vec<String> = redis::cmd("LRANGE")
12 .arg(&self.message_history_key)
13 .arg(0)
14 .arg(-1)
15 .query_async(&mut cm)
16 .await
17 .context("Error fetching message history")?;
18 let chat_messages: Result<Vec<ChatMessage>> = messages
19 .into_iter()
20 .map(|msg| serde_json::from_str(&msg).context("Error deserializing message"))
21 .collect();
22 chat_messages
23 } else {
24 anyhow::bail!("Failed to connect to Redis")
25 }
26 }
27
28 async fn push_owned(&self, item: ChatMessage) -> Result<()> {
29 if let Some(mut cm) = self.lazy_connect().await {
30 redis::cmd("RPUSH")
31 .arg(&self.message_history_key)
32 .arg(serde_json::to_string(&item)?)
33 .query_async::<()>(&mut cm)
34 .await
35 .context("Error pushing to message history")?;
36 Ok(())
37 } else {
38 anyhow::bail!("Failed to connect to Redis")
39 }
40 }
41
42 async fn extend_owned(&self, items: Vec<ChatMessage>) -> Result<()> {
43 if let Some(mut cm) = self.lazy_connect().await {
44 let _ = redis::cmd("DEL")
46 .arg(&self.message_history_key)
47 .query_async::<()>(&mut cm)
48 .await;
49
50 redis::cmd("RPUSH")
51 .arg(&self.message_history_key)
52 .arg(
53 items
54 .iter()
55 .map(serde_json::to_string)
56 .collect::<Result<Vec<_>, _>>()?,
57 )
58 .query_async::<()>(&mut cm)
59 .await
60 .context("Error pushing to message history")?;
61 Ok(())
62 } else {
63 anyhow::bail!("Failed to connect to Redis")
64 }
65 }
66
67 async fn overwrite(&self, items: Vec<ChatMessage>) -> Result<()> {
68 if let Some(mut cm) = self.lazy_connect().await {
69 let _ = redis::cmd("DEL")
71 .arg(&self.message_history_key)
72 .query_async::<()>(&mut cm)
73 .await;
74
75 if items.is_empty() {
76 return Ok(());
78 }
79
80 redis::cmd("RPUSH")
81 .arg(&self.message_history_key)
82 .arg(
83 items
84 .iter()
85 .map(serde_json::to_string)
86 .collect::<Result<Vec<_>, _>>()?,
87 )
88 .query_async::<()>(&mut cm)
89 .await
90 .context("Error pushing to message history")?;
91 Ok(())
92 } else {
93 anyhow::bail!("Failed to connect to Redis")
94 }
95 }
96}
97
98#[cfg(test)]
99mod tests {
100 use testcontainers::{ContainerAsync, GenericImage, runners::AsyncRunner as _};
101
102 use super::*;
103
104 async fn start_redis() -> (String, ContainerAsync<GenericImage>) {
105 let redis_container = testcontainers::GenericImage::new("redis", "7.2.4")
106 .with_exposed_port(6379.into())
107 .with_wait_for(testcontainers::core::WaitFor::message_on_stdout(
108 "Ready to accept connections",
109 ))
110 .start()
111 .await
112 .expect("Redis started");
113
114 let host = redis_container.get_host().await.unwrap();
115 let port = redis_container.get_host_port_ipv4(6379).await.unwrap();
116
117 let url = format!("redis://{host}:{port}/");
118
119 (url, redis_container)
120 }
121
122 #[tokio::test]
123 async fn test_no_messages_yet() {
124 let (url, _container) = start_redis().await;
125 let redis = Redis::try_from_url(url, "tests").unwrap();
126
127 let messages = redis.history().await.unwrap();
128 assert!(
129 messages.is_empty(),
130 "Expected history to be empty for new Redis key"
131 );
132 }
133
134 #[tokio::test]
135 async fn test_adding_and_next_completions() {
136 let (url, _container) = start_redis().await;
137 let redis = Redis::try_from_url(url, "tests").unwrap();
138
139 let m1 = ChatMessage::System("System test".to_string());
140 let m2 = ChatMessage::User("User test".to_string());
141
142 redis.push_owned(m1.clone()).await.unwrap();
143 redis.push_owned(m2.clone()).await.unwrap();
144
145 let hist = redis.history().await.unwrap();
146 assert_eq!(
147 hist,
148 vec![m1.clone(), m2.clone()],
149 "History should match what's pushed"
150 );
151
152 let hist2 = redis.history().await.unwrap();
153 assert_eq!(
154 hist2,
155 vec![m1, m2],
156 "History should be unchanged on repeated call"
157 );
158 }
159
160 #[tokio::test]
161 async fn test_overwrite_history() {
162 let (url, _container) = start_redis().await;
163 let redis = Redis::try_from_url(url, "tests").unwrap();
164
165 redis.overwrite(vec![]).await.unwrap();
167
168 let m1 = ChatMessage::System("First".to_string());
169 let m2 = ChatMessage::User("Second".to_string());
170 redis.push_owned(m1.clone()).await.unwrap();
171 redis.push_owned(m2.clone()).await.unwrap();
172
173 let m3 = ChatMessage::Assistant(Some("Overwritten".to_string()), None);
174 redis.overwrite(vec![m3.clone()]).await.unwrap();
175
176 let hist = redis.history().await.unwrap();
177 assert_eq!(
178 hist,
179 vec![m3],
180 "History should only contain the overwritten message"
181 );
182 }
183
184 #[tokio::test]
185 async fn test_extend() {
186 let (url, _container) = start_redis().await;
187 let redis = Redis::try_from_url(url, "tests").unwrap();
188
189 let m1 = ChatMessage::System("First".to_string());
190 let m2 = ChatMessage::User("Second".to_string());
191 redis
192 .extend_owned(vec![m1.clone(), m2.clone()])
193 .await
194 .unwrap();
195
196 let hist = redis.history().await.unwrap();
197 assert_eq!(
198 hist,
199 vec![m1, m2],
200 "History should only contain the overwritten message"
201 );
202 }
203}