swiftide_integrations/redis/
persist.rs1use anyhow::{Context as _, Result};
2use async_trait::async_trait;
3
4use serde::Serialize;
5use swiftide_core::{
6 Persist,
7 indexing::{Chunk, IndexingStream, Node},
8};
9
10use super::Redis;
11
12#[async_trait]
13#[allow(dependency_on_unit_never_type_fallback)]
14impl<T: Chunk + Serialize> Persist for Redis<T> {
15 type Input = T;
16 type Output = T;
17 async fn setup(&self) -> Result<()> {
18 Ok(())
19 }
20
21 fn batch_size(&self) -> Option<usize> {
22 Some(self.batch_size)
23 }
24
25 async fn store(&self, node: Node<T>) -> Result<Node<T>> {
33 if let Some(mut cm) = self.lazy_connect().await {
34 redis::cmd("SET")
35 .arg(self.persist_key_for_node(&node)?)
36 .arg(self.persist_value_for_node(&node)?)
37 .query_async::<()>(&mut cm)
38 .await
39 .context("Error persisting to redis")?;
40
41 Ok(node)
42 } else {
43 anyhow::bail!("Failed to connect to Redis")
44 }
45 }
46
47 async fn batch_store(&self, nodes: Vec<Node<T>>) -> IndexingStream<T> {
55 if let Some(mut cm) = self.lazy_connect().await {
57 let args = nodes
58 .iter()
59 .map(|node| -> Result<Vec<String>> {
60 let key = self.persist_key_for_node(node)?;
61 let value = self.persist_value_for_node(node)?;
62
63 Ok(vec![key, value])
64 })
65 .collect::<Result<Vec<_>>>();
66
67 if args.is_err() {
68 return vec![Err(args.unwrap_err())].into();
69 }
70
71 let args = args.unwrap();
72
73 let result: Result<()> = redis::cmd("MSET")
74 .arg(args)
75 .query_async(&mut cm)
76 .await
77 .context("Error persisting to redis");
78
79 if let Err(e) = result {
80 IndexingStream::iter([Err(e)])
81 } else {
82 IndexingStream::iter(nodes.into_iter().map(Ok))
83 }
84 } else {
85 IndexingStream::iter([Err(anyhow::anyhow!("Failed to connect to Redis"))])
86 }
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::TryStreamExt;
    use swiftide_core::indexing::TextNode;
    use testcontainers::{ContainerAsync, GenericImage, runners::AsyncRunner};

    /// Starts a Redis 7.2.4 container and waits for its readiness message on stdout.
    async fn start_redis() -> ContainerAsync<GenericImage> {
        let ready = testcontainers::core::WaitFor::message_on_stdout("Ready to accept connections");

        GenericImage::new("redis", "7.2.4")
            .with_exposed_port(6379.into())
            .with_wait_for(ready)
            .start()
            .await
            .expect("Redis started")
    }

    /// Builds the `redis://host:port` URL pointing at the container's mapped Redis port.
    async fn redis_url(container: &ContainerAsync<GenericImage>) -> String {
        let host = container.get_host().await.unwrap();
        let port = container.get_host_port_ipv4(6379).await.unwrap();

        format!("redis://{host}:{port}")
    }

    /// A node written with `store` can be read back and deserializes to the same node.
    #[test_log::test(tokio::test)]
    async fn test_redis_persist() {
        let container = start_redis().await;
        let url = redis_url(&container).await;
        let redis = Redis::try_build_from_url(url).unwrap().build().unwrap();

        let node = TextNode::new("chunk");
        redis.store(node.clone()).await.unwrap();

        let raw = redis.get_node(&node).await.unwrap().unwrap();
        let stored_node: TextNode = serde_json::from_str(&raw).unwrap();

        assert_eq!(node, stored_node);
    }

    /// Every node written with `batch_store` is streamed back and can be read from Redis.
    #[test_log::test(tokio::test)]
    async fn test_redis_batch_persist() {
        let container = start_redis().await;
        let url = redis_url(&container).await;
        let redis = Redis::try_build_from_url(url)
            .unwrap()
            .batch_size(20)
            .build()
            .unwrap();

        let input = vec![TextNode::new("test"), TextNode::new("other")];
        let output = redis.batch_store(input).await;
        let streamed_nodes: Vec<TextNode> = output.try_collect().await.unwrap();

        assert_eq!(streamed_nodes.len(), 2);

        for node in streamed_nodes {
            let raw = redis.get_node(&node).await.unwrap().unwrap();
            let stored_node: TextNode = serde_json::from_str(&raw).unwrap();

            assert_eq!(node, stored_node);
        }
    }

    /// Custom key and value functions determine what is written to Redis.
    #[test_log::test(tokio::test)]
    async fn test_redis_custom_persist() {
        let container = start_redis().await;
        let url = redis_url(&container).await;
        let redis = Redis::<String>::try_build_from_url(url)
            .unwrap()
            .persist_key_fn(|_node| Ok("test".to_string()))
            .persist_value_fn(|_node| Ok("hello world".to_string()))
            .build()
            .unwrap();

        let node = Node::default();
        redis.store(node.clone()).await.unwrap();

        let stored_value = redis.get_node(&node).await.unwrap();
        assert_eq!(stored_value.unwrap(), "hello world");

        let key = redis.persist_key_for_node(&node).unwrap();
        assert_eq!(key, "test".to_string());
    }
}