swiftide_integrations/redis/
persist.rs

1use anyhow::{Context as _, Result};
2use async_trait::async_trait;
3
4use swiftide_core::{
5    Persist,
6    indexing::{IndexingStream, Node},
7};
8
9use super::Redis;
10
11#[async_trait]
12#[allow(dependency_on_unit_never_type_fallback)]
13impl Persist for Redis {
14    async fn setup(&self) -> Result<()> {
15        Ok(())
16    }
17
18    fn batch_size(&self) -> Option<usize> {
19        Some(self.batch_size)
20    }
21
22    /// Stores a node in Redis using the SET command.
23    ///
24    /// By default nodes are stored with the path and hash as key and the node serialized as JSON as
25    /// value.
26    ///
27    /// You can customize the key and value used for storing nodes by setting the `persist_key_fn`
28    /// and `persist_value_fn` fields.
29    async fn store(&self, node: Node) -> Result<Node> {
30        if let Some(mut cm) = self.lazy_connect().await {
31            redis::cmd("SET")
32                .arg(self.persist_key_for_node(&node)?)
33                .arg(self.persist_value_for_node(&node)?)
34                .query_async::<()>(&mut cm)
35                .await
36                .context("Error persisting to redis")?;
37
38            Ok(node)
39        } else {
40            anyhow::bail!("Failed to connect to Redis")
41        }
42    }
43
44    /// Stores a batch of nodes in Redis using the MSET command.
45    ///
46    /// By default nodes are stored with the path and hash as key and the node serialized as JSON as
47    /// value.
48    ///
49    /// You can customize the key and value used for storing nodes by setting the `persist_key_fn`
50    /// and `persist_value_fn` fields.
51    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
52        // use mset for batch store
53        if let Some(mut cm) = self.lazy_connect().await {
54            let args = nodes
55                .iter()
56                .map(|node| -> Result<Vec<String>> {
57                    let key = self.persist_key_for_node(node)?;
58                    let value = self.persist_value_for_node(node)?;
59
60                    Ok(vec![key, value])
61                })
62                .collect::<Result<Vec<_>>>();
63
64            if args.is_err() {
65                return vec![Err(args.unwrap_err())].into();
66            }
67
68            let args = args.unwrap();
69
70            let result: Result<()> = redis::cmd("MSET")
71                .arg(args)
72                .query_async(&mut cm)
73                .await
74                .context("Error persisting to redis");
75
76            if result.is_ok() {
77                IndexingStream::iter(nodes.into_iter().map(Ok))
78            } else {
79                IndexingStream::iter([Err(result.unwrap_err())])
80            }
81        } else {
82            IndexingStream::iter([Err(anyhow::anyhow!("Failed to connect to Redis"))])
83        }
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::TryStreamExt;
    use testcontainers::{ContainerAsync, GenericImage, runners::AsyncRunner};

    /// Starts a `redis:7.2.4` container and waits until it reports readiness on stdout.
    async fn start_redis() -> ContainerAsync<GenericImage> {
        let ready_signal =
            testcontainers::core::WaitFor::message_on_stdout("Ready to accept connections");

        testcontainers::GenericImage::new("redis", "7.2.4")
            .with_exposed_port(6379.into())
            .with_wait_for(ready_signal)
            .start()
            .await
            .expect("Redis started")
    }

    /// Builds the `redis://host:port` URL for the mapped port 6379 of the given container.
    async fn redis_url(container: &ContainerAsync<GenericImage>) -> String {
        let host = container.get_host().await.unwrap();
        let port = container.get_host_port_ipv4(6379).await.unwrap();

        format!("redis://{host}:{port}")
    }

    /// A single stored node can be read back and deserializes to the original node.
    #[test_log::test(tokio::test)]
    async fn test_redis_persist() {
        // Keep the container bound for the whole test so it is not dropped early.
        let container = start_redis().await;
        let url = redis_url(&container).await;
        let redis = Redis::try_build_from_url(url).unwrap().build().unwrap();

        let node = Node::new("chunk");
        redis.store(node.clone()).await.unwrap();

        let raw = redis.get_node(&node).await.unwrap().unwrap();
        let stored_node: Node = serde_json::from_str(&raw).unwrap();

        assert_eq!(node, stored_node);
    }

    /// A batch of nodes is streamed back in full and every node can be read back from Redis.
    #[test_log::test(tokio::test)]
    async fn test_redis_batch_persist() {
        let container = start_redis().await;
        let url = redis_url(&container).await;
        let redis = Redis::try_build_from_url(url)
            .unwrap()
            .batch_size(20)
            .build()
            .unwrap();

        let nodes = vec![Node::new("test"), Node::new("other")];
        let streamed_nodes: Vec<Node> = redis
            .batch_store(nodes)
            .await
            .try_collect()
            .await
            .unwrap();

        assert_eq!(streamed_nodes.len(), 2);

        for node in streamed_nodes {
            let raw = redis.get_node(&node).await.unwrap().unwrap();
            let stored_node: Node = serde_json::from_str(&raw).unwrap();
            assert_eq!(node, stored_node);
        }
    }

    /// Custom key and value functions override what is written to Redis.
    #[test_log::test(tokio::test)]
    async fn test_redis_custom_persist() {
        let container = start_redis().await;
        let url = redis_url(&container).await;
        let redis = Redis::try_build_from_url(url)
            .unwrap()
            .persist_key_fn(|_node| Ok("test".to_string()))
            .persist_value_fn(|_node| Ok("hello world".to_string()))
            .build()
            .unwrap();

        let node = Node::default();
        redis.store(node.clone()).await.unwrap();

        let stored_value = redis.get_node(&node).await.unwrap();
        assert_eq!(stored_value.unwrap(), "hello world");

        let key = redis.persist_key_for_node(&node).unwrap();
        assert_eq!(key, "test".to_string());
    }
}