swiftide_integrations/kafka/persist.rs

use std::{sync::Arc, time::Duration};

use anyhow::Result;
use async_trait::async_trait;

use rdkafka::producer::FutureRecord;
use swiftide_core::{
    Persist,
    indexing::{IndexingStream, Node},
};

use super::Kafka;

#[async_trait]
impl Persist for Kafka {
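    /// Ensures the configured topic exists, creating it when `create_topic_if_not_exists`
    /// is enabled and returning an error otherwise.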
    async fn setup(&self) -> Result<()> {
        if self.topic_exists()? {
            return Ok(());
        }
        if !self.create_topic_if_not_exists {
            return Err(anyhow::anyhow!("Topic {} does not exist", self.topic));
        }
        self.create_topic().await?;
        Ok(())
    }

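    /// Returns the configured batch size used when batching nodes into `batch_store`.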
    fn batch_size(&self) -> Option<usize> {
        Some(self.batch_size)
    }

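    /// Serializes a single node into a key/payload pair and publishes it to the topic.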
    async fn store(&self, node: Node) -> Result<Node> {
        let (key, payload) = self.node_to_key_payload(&node)?;
        self.producer()?
            .send(
                FutureRecord::to(&self.topic).key(&key).payload(&payload),
                Duration::from_secs(0),
            )
            .await
            .map_err(|(e, _)| anyhow::anyhow!("Failed to send node: {:?}", e))?;
        Ok(node)
    }

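    /// Publishes each node to the topic, returning an error stream on the first failure;
    /// otherwise the original nodes are passed through unchanged.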
    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
        let producer = Arc::new(self.producer().expect("Failed to create producer"));

        for node in &nodes {
            match self.node_to_key_payload(node) {
                Ok((key, payload)) => {
                    if let Err(e) = producer
                        .send(
                            FutureRecord::to(&self.topic).payload(&payload).key(&key),
                            Duration::from_secs(0),
                        )
                        .await
                    {
                        return vec![Err(anyhow::anyhow!("failed to send node: {:?}", e))].into();
                    }
                }
                Err(e) => {
                    return vec![Err(e)].into();
                }
            }
        }

        IndexingStream::iter(nodes.into_iter().map(Ok))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::TryStreamExt;
    use rdkafka::ClientConfig;
    use testcontainers::runners::AsyncRunner;
    use testcontainers_modules::kafka::apache::{self};

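    // Starts a single-node Kafka container and verifies that `setup` and `store`
    // succeed against it.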
    #[test_log::test(tokio::test)]
    async fn test_kafka_persist() {
        static TOPIC_NAME: &str = "topic";

        let kafka_node = apache::Kafka::default()
            .start()
            .await
            .expect("failed to start kafka");
        let bootstrap_servers = format!(
            "127.0.0.1:{}",
            kafka_node
                .get_host_port_ipv4(apache::KAFKA_PORT)
                .await
                .expect("failed to get kafka port")
        );

        let mut client_config = ClientConfig::new();
        client_config.set("bootstrap.servers", &bootstrap_servers);
        let storage = Kafka::builder()
            .client_config(client_config)
            .topic(TOPIC_NAME)
            .build()
            .unwrap();

        let node = Node::new("chunk");

        storage.setup().await.unwrap();
        storage.store(node.clone()).await.unwrap();
    }

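    // Verifies that `batch_store` publishes every node and yields them all back on the
    // returned stream.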
    #[test_log::test(tokio::test)]
    async fn test_kafka_batch_persist() {
        static TOPIC_NAME: &str = "topic";

        let kafka_node = apache::Kafka::default()
            .start()
            .await
            .expect("failed to start kafka");
        let bootstrap_servers = format!(
            "127.0.0.1:{}",
            kafka_node
                .get_host_port_ipv4(apache::KAFKA_PORT)
                .await
                .expect("failed to get kafka port")
        );

        let mut client_config = ClientConfig::new();
        client_config.set("bootstrap.servers", &bootstrap_servers);
        let storage = Kafka::builder()
            .client_config(client_config)
            .topic(TOPIC_NAME)
            .create_topic_if_not_exists(true)
            .batch_size(2usize)
            .build()
            .unwrap();

        let nodes = vec![Node::default(); 6];

        storage.setup().await.unwrap();

        let stream = storage.batch_store(nodes.clone()).await;

        let result: Vec<Node> = stream.try_collect().await.unwrap();

        assert_eq!(result.len(), 6);
        assert_eq!(result[0], nodes[0]);
        assert_eq!(result[1], nodes[1]);
        assert_eq!(result[2], nodes[2]);
        assert_eq!(result[3], nodes[3]);
        assert_eq!(result[4], nodes[4]);
        assert_eq!(result[5], nodes[5]);
    }
}