swiftide_integrations/lancedb/
persist.rs1use std::sync::Arc;
2
3use anyhow::Context as _;
4use anyhow::Result;
5use arrow_array::Array;
6use arrow_array::FixedSizeListArray;
7use arrow_array::GenericByteArray;
8use arrow_array::RecordBatch;
9use arrow_array::RecordBatchIterator;
10use arrow_array::types::Float32Type;
11use arrow_array::types::UInt8Type;
12use arrow_array::types::Utf8Type;
13use async_trait::async_trait;
14use swiftide_core::Persist;
15use swiftide_core::indexing::IndexingStream;
16use swiftide_core::indexing::TextNode;
17
18use super::FieldConfig;
19use super::LanceDB;
20
21#[async_trait]
22impl Persist for LanceDB {
23 type Input = String;
24 type Output = String;
25
26 #[tracing::instrument(skip_all)]
27 async fn setup(&self) -> Result<()> {
28 let conn = self.get_connection().await?;
29 let schema = self.schema.clone();
30
31 if let Err(err) = conn.open_table(&self.table_name).execute().await {
32 if matches!(err, lancedb::Error::TableNotFound { .. }) {
33 conn.create_empty_table(&self.table_name, schema)
34 .execute()
35 .await
36 .map(|_| ())
37 .map_err(anyhow::Error::from)?;
38 } else {
39 return Err(err.into());
40 }
41 }
42
43 Ok(())
44 }
45
46 #[tracing::instrument(skip_all)]
47 async fn store(&self, node: TextNode) -> Result<TextNode> {
48 let mut nodes = vec![node; 1];
49 self.store_nodes(&nodes).await?;
50
51 let node = nodes.swap_remove(0);
52
53 Ok(node)
54 }
55
56 #[tracing::instrument(skip_all)]
57 async fn batch_store(&self, nodes: Vec<TextNode>) -> IndexingStream<String> {
58 self.store_nodes(&nodes).await.map(|()| nodes).into()
59 }
60
61 fn batch_size(&self) -> Option<usize> {
62 Some(self.batch_size)
63 }
64}
65
66impl LanceDB {
67 async fn store_nodes(&self, nodes: &[TextNode]) -> Result<()> {
68 let schema = self.schema.clone();
69
70 let batches = self.extract_arrow_batches_from_nodes(nodes)?;
71
72 let data = RecordBatchIterator::new(
73 vec![
74 RecordBatch::try_new(schema.clone(), batches)
75 .context("Could not create batches")?,
76 ]
77 .into_iter()
78 .map(Ok),
79 schema.clone(),
80 );
81
82 let conn = self.get_connection().await?;
83 let table = conn.open_table(&self.table_name).execute().await?;
84 let mut merge_insert = table.merge_insert(&["id"]);
85
86 merge_insert
87 .when_matched_update_all(None)
88 .when_not_matched_insert_all();
89
90 merge_insert.execute(Box::new(data)).await?;
91
92 Ok(())
93 }
94
95 fn extract_arrow_batches_from_nodes(
96 &self,
97 nodes: &[TextNode],
98 ) -> core::result::Result<Vec<Arc<dyn Array>>, anyhow::Error> {
99 let fields = self.fields.as_slice();
100 let mut batches: Vec<Arc<dyn Array>> = Vec::with_capacity(fields.len());
101
102 for field in fields {
103 match field {
104 FieldConfig::Vector(config) => {
105 let mut row = Vec::with_capacity(nodes.len());
106 let vector_size = config
107 .vector_size
108 .or(self.vector_size)
109 .context("Expected vector size to be set for field")?;
110
111 for node in nodes {
112 let data = node
113 .vectors
114 .as_ref()
115 .and_then(|v| v.get(&config.embedded_field))
117 .map(|v| v.iter().map(|f| Some(*f)));
118
119 row.push(data);
120 }
121 batches.push(Arc::new(FixedSizeListArray::from_iter_primitive::<
122 Float32Type,
123 _,
124 _,
125 >(row, vector_size)));
126 }
127 FieldConfig::Metadata(config) => {
128 let mut row = Vec::with_capacity(nodes.len());
129
130 for node in nodes {
131 let data = node
132 .metadata
133 .get(&config.original_field)
134 .and_then(|v| v.as_str());
136
137 row.push(data);
138 }
139 batches.push(Arc::new(GenericByteArray::<Utf8Type>::from_iter(row)));
140 }
141 FieldConfig::Chunk => {
142 let mut row = Vec::with_capacity(nodes.len());
143
144 for node in nodes {
145 let data = Some(node.chunk.as_str());
146 row.push(data);
147 }
148 batches.push(Arc::new(GenericByteArray::<Utf8Type>::from_iter(row)));
149 }
150 FieldConfig::ID => {
151 let mut row = Vec::with_capacity(nodes.len());
152 for node in nodes {
153 let data = Some(node.id().as_bytes().map(Some));
154 row.push(data);
155 }
156 batches.push(Arc::new(FixedSizeListArray::from_iter_primitive::<
157 UInt8Type,
158 _,
159 _,
160 >(row, 16)));
161 }
162 }
163 }
164 Ok(batches)
165 }
166}
167
#[cfg(test)]
mod test {
    use swiftide_core::{Persist as _, indexing::EmbeddedField};
    use temp_dir::TempDir;

    use super::*;

    /// Builds a `LanceDB` store rooted in a fresh temporary directory and runs
    /// `setup` on it once, so the test table already exists afterwards.
    ///
    /// The `TempDir` is returned alongside the store so the caller can keep the
    /// directory alive for as long as the store is in use.
    async fn temp_lancedb() -> (TempDir, LanceDB) {
        let dir = TempDir::new().unwrap();
        let db_path = dir.child("lancedb");

        let store = LanceDB::builder()
            .uri(db_path.to_str().unwrap())
            .vector_size(384)
            .with_metadata("filter")
            .with_vector(EmbeddedField::Combined)
            .table_name("swiftide_test")
            .build()
            .unwrap();

        store.setup().await.unwrap();

        (dir, store)
    }

    /// Running `setup` against a table that already exists must succeed.
    #[tokio::test]
    async fn test_no_error_when_table_exists() {
        let (_dir, store) = temp_lancedb().await;

        let rerun = store.setup().await;
        rerun.expect("Should not error if table exists");
    }
}