swiftide_integrations/redis/persist.rs


1use anyhow::{Context as _, Result};
2use async_trait::async_trait;
3
4use serde::Serialize;
5use swiftide_core::{
6    Persist,
7    indexing::{Chunk, IndexingStream, Node},
8};
9
10use super::Redis;
11
12#[async_trait]
13#[allow(dependency_on_unit_never_type_fallback)]
14impl<T: Chunk + Serialize> Persist for Redis<T> {
15    type Input = T;
16    type Output = T;
17    async fn setup(&self) -> Result<()> {
18        Ok(())
19    }
20
21    fn batch_size(&self) -> Option<usize> {
22        Some(self.batch_size)
23    }
24
25    /// Stores a node in Redis using the SET command.
26    ///
27    /// By default nodes are stored with the path and hash as key and the node serialized as JSON as
28    /// value.
29    ///
30    /// You can customize the key and value used for storing nodes by setting the `persist_key_fn`
31    /// and `persist_value_fn` fields.
32    async fn store(&self, node: Node<T>) -> Result<Node<T>> {
33        if let Some(mut cm) = self.lazy_connect().await {
34            redis::cmd("SET")
35                .arg(self.persist_key_for_node(&node)?)
36                .arg(self.persist_value_for_node(&node)?)
37                .query_async::<()>(&mut cm)
38                .await
39                .context("Error persisting to redis")?;
40
41            Ok(node)
42        } else {
43            anyhow::bail!("Failed to connect to Redis")
44        }
45    }
46
47    /// Stores a batch of nodes in Redis using the MSET command.
48    ///
49    /// By default nodes are stored with the path and hash as key and the node serialized as JSON as
50    /// value.
51    ///
52    /// You can customize the key and value used for storing nodes by setting the `persist_key_fn`
53    /// and `persist_value_fn` fields.
54    async fn batch_store(&self, nodes: Vec<Node<T>>) -> IndexingStream<T> {
55        // use mset for batch store
56        if let Some(mut cm) = self.lazy_connect().await {
57            let args = nodes
58                .iter()
59                .map(|node| -> Result<Vec<String>> {
60                    let key = self.persist_key_for_node(node)?;
61                    let value = self.persist_value_for_node(node)?;
62
63                    Ok(vec![key, value])
64                })
65                .collect::<Result<Vec<_>>>();
66
67            if args.is_err() {
68                return vec![Err(args.unwrap_err())].into();
69            }
70
71            let args = args.unwrap();
72
73            let result: Result<()> = redis::cmd("MSET")
74                .arg(args)
75                .query_async(&mut cm)
76                .await
77                .context("Error persisting to redis");
78
79            if let Err(e) = result {
80                IndexingStream::iter([Err(e)])
81            } else {
82                IndexingStream::iter(nodes.into_iter().map(Ok))
83            }
84        } else {
85            IndexingStream::iter([Err(anyhow::anyhow!("Failed to connect to Redis"))])
86        }
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::TryStreamExt;
    use swiftide_core::indexing::TextNode;
    use testcontainers::{ContainerAsync, GenericImage, runners::AsyncRunner};

    /// Boots a throwaway `redis:7.2.4` container and waits until it logs readiness.
    async fn start_redis() -> ContainerAsync<GenericImage> {
        let ready = testcontainers::core::WaitFor::message_on_stdout("Ready to accept connections");

        GenericImage::new("redis", "7.2.4")
            .with_exposed_port(6379.into())
            .with_wait_for(ready)
            .start()
            .await
            .expect("Redis started")
    }

    /// Builds a `redis://host:port` URL pointing at the container's mapped 6379 port.
    async fn redis_url(container: &ContainerAsync<GenericImage>) -> String {
        let host = container.get_host().await.unwrap();
        let port = container.get_host_port_ipv4(6379).await.unwrap();
        format!("redis://{host}:{port}")
    }

    #[test_log::test(tokio::test)]
    async fn test_redis_persist() {
        let container = start_redis().await;
        let redis = Redis::try_build_from_url(redis_url(&container).await)
            .unwrap()
            .build()
            .unwrap();

        let node = TextNode::new("chunk");
        redis.store(node.clone()).await.unwrap();

        let raw = redis.get_node(&node).await.unwrap().unwrap();
        let stored: TextNode = serde_json::from_str(&raw).unwrap();
        assert_eq!(node, stored);
    }

    /// Batch persistence: every streamed node must be readable back from Redis.
    #[test_log::test(tokio::test)]
    async fn test_redis_batch_persist() {
        let container = start_redis().await;
        let redis = Redis::try_build_from_url(redis_url(&container).await)
            .unwrap()
            .batch_size(20)
            .build()
            .unwrap();

        let streamed: Vec<TextNode> = redis
            .batch_store(vec![TextNode::new("test"), TextNode::new("other")])
            .await
            .try_collect()
            .await
            .unwrap();
        assert_eq!(streamed.len(), 2);

        for node in &streamed {
            let raw = redis.get_node(node).await.unwrap().unwrap();
            let stored: TextNode = serde_json::from_str(&raw).unwrap();
            assert_eq!(*node, stored);
        }
    }

    #[test_log::test(tokio::test)]
    async fn test_redis_custom_persist() {
        let container = start_redis().await;
        let redis = Redis::<String>::try_build_from_url(redis_url(&container).await)
            .unwrap()
            .persist_key_fn(|_node| Ok("test".to_string()))
            .persist_value_fn(|_node| Ok("hello world".to_string()))
            .build()
            .unwrap();
        let node = Node::default();

        redis.store(node.clone()).await.unwrap();

        let stored = redis.get_node(&node).await.unwrap();
        assert_eq!(stored.unwrap(), "hello world");
        assert_eq!(redis.persist_key_for_node(&node).unwrap(), "test".to_string());
    }
}