// swiftide_integrations/parquet/loader.rs

1use anyhow::{Context as _, Result};
2use arrow_array::StringArray;
3use fs_err::tokio::File;
4use futures_util::StreamExt as _;
5use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask};
6use swiftide_core::{
7    Loader,
8    indexing::{IndexingStream, TextNode},
9};
10use tokio::runtime::Handle;
11
12use super::Parquet;
13
14impl Loader for Parquet {
15    type Output = String;
16
17    fn into_stream(self) -> IndexingStream<String> {
18        let mut builder = tokio::task::block_in_place(|| {
19            Handle::current().block_on(async {
20                let file = File::open(self.path).await.expect("Failed to open file");
21
22                ParquetRecordBatchStreamBuilder::new(file)
23                    .await
24                    .context("Failed to load builder")
25                    .unwrap()
26                    .with_batch_size(self.batch_size)
27            })
28        });
29
30        let file_metadata = builder.metadata().file_metadata().clone();
31        dbg!(file_metadata.schema_descr().columns());
32        let column_idx = file_metadata
33            .schema()
34            .get_fields()
35            .iter()
36            .enumerate()
37            .find_map(|(pos, column)| {
38                if self.column_name == column.name() {
39                    Some(pos)
40                } else {
41                    None
42                }
43            })
44            .unwrap_or_else(|| panic!("Column {} not found in dataset", &self.column_name));
45
46        let mask = ProjectionMask::roots(file_metadata.schema_descr(), [column_idx]);
47        builder = builder.with_projection(mask);
48
49        let stream = builder.build().expect("Failed to build parquet builder");
50
51        let swiftide_stream = stream.flat_map_unordered(None, move |result_batch| {
52            let Ok(batch) = result_batch else {
53                let new_result: Result<TextNode> = Err(anyhow::anyhow!(result_batch.unwrap_err()));
54
55                return vec![new_result].into();
56            };
57            assert!(batch.num_columns() == 1, "Number of columns _must_ be 1");
58
59            let node_values = batch
60                .column(0) // Should only have one column at this point
61                .as_any()
62                .downcast_ref::<StringArray>()
63                .unwrap()
64                .into_iter()
65                .flatten()
66                .map(TextNode::from)
67                .map(Ok)
68                .collect::<Vec<_>>();
69
70            IndexingStream::iter(node_values)
71        });
72
73        swiftide_stream.boxed().into()
74
75        // let mask = ProjectionMask::
76    }
77
78    fn into_stream_boxed(self: Box<Self>) -> IndexingStream<String> {
79        self.into_stream()
80    }
81}
82
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use futures_util::TryStreamExt as _;

    use super::*;

    /// End-to-end check: loading the bundled fixture yields one `TextNode`
    /// per row of the projected "chunk" column, in order.
    ///
    /// Needs the multi-threaded runtime flavor because `into_stream` uses
    /// `tokio::task::block_in_place`.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_parquet_loader() {
        // The fixture lives inside the crate, next to this module.
        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("src/parquet/test.parquet");

        let loader = Parquet::builder()
            .path(path)
            .column_name("chunk")
            .build()
            .unwrap();

        let result = loader.into_stream().try_collect::<Vec<_>>().await.unwrap();

        let expected = [TextNode::new("hello"), TextNode::new("world")];
        assert_eq!(result, expected);
    }
}