swiftide_integrations/redis/
persist.rs1use anyhow::{Context as _, Result};
2use async_trait::async_trait;
3
4use swiftide_core::{
5 Persist,
6 indexing::{IndexingStream, Node},
7};
8
9use super::Redis;
10
11#[async_trait]
12#[allow(dependency_on_unit_never_type_fallback)]
13impl Persist for Redis {
14 async fn setup(&self) -> Result<()> {
15 Ok(())
16 }
17
18 fn batch_size(&self) -> Option<usize> {
19 Some(self.batch_size)
20 }
21
22 async fn store(&self, node: Node) -> Result<Node> {
30 if let Some(mut cm) = self.lazy_connect().await {
31 redis::cmd("SET")
32 .arg(self.persist_key_for_node(&node)?)
33 .arg(self.persist_value_for_node(&node)?)
34 .query_async::<()>(&mut cm)
35 .await
36 .context("Error persisting to redis")?;
37
38 Ok(node)
39 } else {
40 anyhow::bail!("Failed to connect to Redis")
41 }
42 }
43
44 async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
52 if let Some(mut cm) = self.lazy_connect().await {
54 let args = nodes
55 .iter()
56 .map(|node| -> Result<Vec<String>> {
57 let key = self.persist_key_for_node(node)?;
58 let value = self.persist_value_for_node(node)?;
59
60 Ok(vec![key, value])
61 })
62 .collect::<Result<Vec<_>>>();
63
64 if args.is_err() {
65 return vec![Err(args.unwrap_err())].into();
66 }
67
68 let args = args.unwrap();
69
70 let result: Result<()> = redis::cmd("MSET")
71 .arg(args)
72 .query_async(&mut cm)
73 .await
74 .context("Error persisting to redis");
75
76 if result.is_ok() {
77 IndexingStream::iter(nodes.into_iter().map(Ok))
78 } else {
79 IndexingStream::iter([Err(result.unwrap_err())])
80 }
81 } else {
82 IndexingStream::iter([Err(anyhow::anyhow!("Failed to connect to Redis"))])
83 }
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::TryStreamExt;
    use testcontainers::{ContainerAsync, GenericImage, runners::AsyncRunner};

    /// Starts a throwaway Redis 7.2.4 container and waits until it reports readiness.
    async fn start_redis() -> ContainerAsync<GenericImage> {
        let ready = testcontainers::core::WaitFor::message_on_stdout("Ready to accept connections");

        GenericImage::new("redis", "7.2.4")
            .with_exposed_port(6379.into())
            .with_wait_for(ready)
            .start()
            .await
            .expect("Redis started")
    }

    /// Builds the `redis://host:port` URL for the mapped port of a running container.
    async fn redis_url(container: &ContainerAsync<GenericImage>) -> String {
        let host = container.get_host().await.unwrap();
        let port = container.get_host_port_ipv4(6379).await.unwrap();

        format!("redis://{host}:{port}")
    }

    /// A single stored node can be read back and deserializes to the same node.
    #[test_log::test(tokio::test)]
    async fn test_redis_persist() {
        let container = start_redis().await;
        let url = redis_url(&container).await;
        let redis = Redis::try_build_from_url(url).unwrap().build().unwrap();

        let original = Node::new("chunk");
        redis.store(original.clone()).await.unwrap();

        let raw = redis.get_node(&original).await.unwrap().unwrap();
        let roundtripped: Node = serde_json::from_str(&raw).unwrap();

        assert_eq!(original, roundtripped);
    }

    /// Every node of a batch is streamed back and can be read from Redis afterwards.
    #[test_log::test(tokio::test)]
    async fn test_redis_batch_persist() {
        let container = start_redis().await;
        let url = redis_url(&container).await;
        let redis = Redis::try_build_from_url(url)
            .unwrap()
            .batch_size(20)
            .build()
            .unwrap();

        let batch = vec![Node::new("test"), Node::new("other")];
        let persisted: Vec<Node> = redis
            .batch_store(batch)
            .await
            .try_collect()
            .await
            .unwrap();

        assert_eq!(persisted.len(), 2);

        for expected in persisted {
            let raw = redis.get_node(&expected).await.unwrap().unwrap();
            let roundtripped: Node = serde_json::from_str(&raw).unwrap();
            assert_eq!(expected, roundtripped);
        }
    }

    /// Custom key and value functions are used for both persisting and key lookup.
    #[test_log::test(tokio::test)]
    async fn test_redis_custom_persist() {
        let container = start_redis().await;
        let url = redis_url(&container).await;
        let redis = Redis::try_build_from_url(url)
            .unwrap()
            .persist_key_fn(|_node| Ok("test".to_string()))
            .persist_value_fn(|_node| Ok("hello world".to_string()))
            .build()
            .unwrap();

        let node = Node::default();
        redis.store(node.clone()).await.unwrap();

        let stored_value = redis.get_node(&node).await.unwrap();
        assert_eq!(stored_value.unwrap(), "hello world");

        let key = redis.persist_key_for_node(&node).unwrap();
        assert_eq!(key, "test".to_string());
    }
}