term_guard/sources/
csv.rs

1//! CSV file source implementation.
2
3use super::{CompressionType, DataSource};
4use crate::prelude::*;
5use async_trait::async_trait;
6use datafusion::arrow::datatypes::Schema;
7use datafusion::datasource::file_format::csv::CsvFormat;
8use datafusion::datasource::listing::{
9    ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
10};
11use datafusion::prelude::*;
12use std::sync::Arc;
13use tracing::{debug, info, instrument};
14
15/// Options for configuring CSV file reading.
16#[derive(Debug, Clone)]
17pub struct CsvOptions {
18    /// Whether the CSV file has a header row
19    pub has_header: bool,
20    /// Field delimiter (default: ',')
21    pub delimiter: u8,
22    /// Quote character (default: '"')
23    pub quote: u8,
24    /// Escape character (default: None)
25    pub escape: Option<u8>,
26    /// Comment prefix (lines starting with this are ignored)
27    pub comment: Option<u8>,
28    /// Schema to use (if None, will be inferred)
29    pub schema: Option<Arc<Schema>>,
30    /// Compression type (default: Auto)
31    pub compression: CompressionType,
32    /// Maximum records to read for schema inference
33    pub schema_infer_max_records: usize,
34}
35
36impl Default for CsvOptions {
37    fn default() -> Self {
38        Self {
39            has_header: true,
40            delimiter: b',',
41            quote: b'"',
42            escape: None,
43            comment: None,
44            schema: None,
45            compression: CompressionType::Auto,
46            schema_infer_max_records: 1000,
47        }
48    }
49}
50
51/// A CSV file data source with schema inference and compression support.
52///
53/// # Examples
54///
55/// ```rust,ignore
56/// use term_guard::sources::{CsvSource, CsvOptions};
57///
58/// # async fn example() -> Result<()> {
59/// // Simple CSV file
60/// let source = CsvSource::new("data/users.csv")?;
61///
62/// // CSV with custom options
63/// let options = CsvOptions {
64///     delimiter: b'\t',
65///     has_header: false,
66///     ..Default::default()
67/// };
68/// let source = CsvSource::with_options("data/users.tsv", options)?;
69///
70/// // Multiple files with glob pattern
71/// let source = CsvSource::from_glob("data/*.csv")?;
72/// # Ok(())
73/// # }
74/// ```
75#[derive(Debug, Clone)]
76pub struct CsvSource {
77    paths: Vec<String>,
78    options: CsvOptions,
79    inferred_schema: Option<Arc<Schema>>,
80}
81
82impl CsvSource {
83    /// Creates a new CSV source from a single file path.
84    pub fn new(path: impl Into<String>) -> Result<Self> {
85        Ok(Self {
86            paths: vec![path.into()],
87            options: CsvOptions::default(),
88            inferred_schema: None,
89        })
90    }
91
92    /// Creates a new CSV source with custom options.
93    pub fn with_options(path: impl Into<String>, options: CsvOptions) -> Result<Self> {
94        Ok(Self {
95            paths: vec![path.into()],
96            options,
97            inferred_schema: None,
98        })
99    }
100
101    /// Creates a CSV source from multiple file paths.
102    pub fn from_paths(paths: Vec<String>) -> Result<Self> {
103        if paths.is_empty() {
104            return Err(TermError::Configuration(
105                "At least one path must be provided".to_string(),
106            ));
107        }
108        Ok(Self {
109            paths,
110            options: CsvOptions::default(),
111            inferred_schema: None,
112        })
113    }
114
115    /// Creates a CSV source from a glob pattern.
116    pub async fn from_glob(pattern: impl Into<String>) -> Result<Self> {
117        let patterns = vec![pattern.into()];
118        let paths = super::expand_globs(&patterns).await?;
119        Self::from_paths(paths)
120    }
121
122    /// Creates a CSV source from multiple glob patterns.
123    pub async fn from_globs(patterns: Vec<String>) -> Result<Self> {
124        let paths = super::expand_globs(&patterns).await?;
125        Self::from_paths(paths)
126    }
127
128    /// Sets custom options for this CSV source.
129    pub fn with_custom_options(mut self, options: CsvOptions) -> Self {
130        self.options = options;
131        self
132    }
133
134    /// Infers schema from the CSV files.
135    #[instrument(skip(self))]
136    #[allow(dead_code)]
137    async fn infer_schema(&mut self) -> Result<Arc<Schema>> {
138        if let Some(schema) = &self.options.schema {
139            return Ok(schema.clone());
140        }
141
142        if let Some(schema) = &self.inferred_schema {
143            return Ok(schema.clone());
144        }
145
146        // Use DataFusion's CSV format for schema inference
147        let _format = CsvFormat::default()
148            .with_has_header(self.options.has_header)
149            .with_delimiter(self.options.delimiter)
150            .with_quote(self.options.quote)
151            .with_escape(self.options.escape)
152            .with_comment(self.options.comment)
153            .with_schema_infer_max_rec(self.options.schema_infer_max_records);
154
155        // Create a temporary context for schema inference
156        let ctx = SessionContext::new();
157
158        // For schema inference, we can use the first file path
159        let first_path = &self.paths[0];
160        let schema = if first_path.ends_with(".csv") {
161            // For .csv files, use read_csv
162            let csv_options = CsvReadOptions::new()
163                .has_header(self.options.has_header)
164                .delimiter(self.options.delimiter)
165                .schema_infer_max_records(self.options.schema_infer_max_records);
166
167            let df = ctx.read_csv(first_path, csv_options).await?;
168            df.schema().inner().clone()
169        } else {
170            // For non-.csv files (e.g., .tsv), use ListingTable
171            let first_path_obj = std::path::Path::new(first_path);
172            let dir_path = first_path_obj
173                .parent()
174                .ok_or_else(|| TermError::Configuration("Invalid file path".to_string()))?;
175            let dir_path_str = dir_path.to_str().ok_or_else(|| {
176                TermError::Configuration("Path contains invalid UTF-8".to_string())
177            })?;
178            let table_path = ListingTableUrl::parse(dir_path_str)?;
179
180            let extension = if first_path.ends_with(".tsv") {
181                ".tsv"
182            } else if first_path.ends_with(".txt") {
183                ".txt"
184            } else {
185                ".csv"
186            };
187
188            let format = CsvFormat::default()
189                .with_has_header(self.options.has_header)
190                .with_delimiter(self.options.delimiter)
191                .with_quote(self.options.quote)
192                .with_escape(self.options.escape)
193                .with_comment(self.options.comment)
194                .with_schema_infer_max_rec(self.options.schema_infer_max_records);
195
196            let listing_options =
197                ListingOptions::new(Arc::new(format)).with_file_extension(extension);
198
199            let config = ListingTableConfig::new(table_path).with_listing_options(listing_options);
200            let table = ListingTable::try_new(config)?;
201
202            // Register temporarily to infer schema
203            let timestamp = std::time::SystemTime::now()
204                .duration_since(std::time::UNIX_EPOCH)
205                .map_err(|e| TermError::Internal(format!("Failed to get system time: {e}")))?
206                .as_nanos();
207            let temp_table_name = format!("_temp_schema_inference_{timestamp}");
208            ctx.register_table(&temp_table_name, Arc::new(table))?;
209
210            let df = ctx.table(&temp_table_name).await?;
211            let schema = df.schema().inner().clone();
212
213            // Deregister the temporary table
214            ctx.deregister_table(&temp_table_name)?;
215
216            schema
217        };
218
219        self.inferred_schema = Some(schema.clone());
220        Ok(schema)
221    }
222}
223
224#[async_trait]
225impl DataSource for CsvSource {
226    #[instrument(skip(self, ctx, telemetry), fields(
227        table.name = %table_name,
228        source.type = "csv",
229        source.files = self.paths.len(),
230        csv.delimiter = %self.options.delimiter as char,
231        csv.has_header = self.options.has_header
232    ))]
233    async fn register_with_telemetry(
234        &self,
235        ctx: &SessionContext,
236        table_name: &str,
237        telemetry: Option<&Arc<TermTelemetry>>,
238    ) -> Result<()> {
239        info!(
240            table.name = %table_name,
241            source.type = "csv",
242            source.paths = ?self.paths,
243            csv.delimiter = %self.options.delimiter as char,
244            csv.has_header = self.options.has_header,
245            csv.compression = ?self.options.compression,
246            "Registering CSV data source"
247        );
248
249        // Create telemetry span for data source loading
250        let mut _datasource_span = if let Some(tel) = telemetry {
251            tel.start_datasource_span("csv", table_name)
252        } else {
253            TermSpan::noop()
254        };
255        // Create CSV format configuration
256        let mut format = CsvFormat::default()
257            .with_has_header(self.options.has_header)
258            .with_delimiter(self.options.delimiter)
259            .with_quote(self.options.quote);
260
261        if let Some(escape) = self.options.escape {
262            format = format.with_escape(Some(escape));
263        }
264        if let Some(comment) = self.options.comment {
265            format = format.with_comment(Some(comment));
266        }
267
268        // Handle single vs multiple paths
269        if self.paths.len() == 1 {
270            let path = &self.paths[0];
271
272            // For single files ending with .csv, use register_csv for simplicity
273            if path.ends_with(".csv") {
274                let mut csv_options = CsvReadOptions::new()
275                    .has_header(self.options.has_header)
276                    .delimiter(self.options.delimiter)
277                    .quote(self.options.quote)
278                    .schema_infer_max_records(self.options.schema_infer_max_records);
279
280                if let Some(escape) = self.options.escape {
281                    csv_options = csv_options.escape(escape);
282                }
283                if let Some(comment) = self.options.comment {
284                    csv_options = csv_options.comment(comment);
285                }
286                if let Some(schema) = &self.options.schema {
287                    csv_options = csv_options.schema(schema);
288                }
289
290                ctx.register_csv(table_name, path, csv_options).await?;
291            } else {
292                // Single non-.csv file (like .tsv) - use ListingTable with specific file
293                let table_path = ListingTableUrl::parse(path)?;
294
295                // Determine the file extension
296                let extension = if path.ends_with(".tsv") {
297                    ".tsv".to_string()
298                } else if path.ends_with(".txt") {
299                    ".txt".to_string()
300                } else {
301                    // Get the extension from the path
302                    std::path::Path::new(path)
303                        .extension()
304                        .and_then(|ext| ext.to_str())
305                        .map(|ext| format!(".{ext}"))
306                        .unwrap_or_else(|| ".csv".to_string())
307                };
308
309                let listing_options =
310                    ListingOptions::new(Arc::new(format)).with_file_extension(&extension);
311
312                // Infer schema if not provided
313                let config = if let Some(schema) = &self.options.schema {
314                    ListingTableConfig::new(table_path)
315                        .with_listing_options(listing_options)
316                        .with_schema(schema.clone())
317                } else {
318                    ListingTableConfig::new(table_path)
319                        .with_listing_options(listing_options)
320                        .infer_schema(&ctx.state())
321                        .await?
322                };
323
324                let table = ListingTable::try_new(config)?;
325                ctx.register_table(table_name, Arc::new(table))?;
326            }
327        } else {
328            // Multiple files - use ListingTable
329            // For multiple files, we need to use the directory path, not a specific file
330            let first_path = std::path::Path::new(&self.paths[0]);
331            let dir_path = first_path
332                .parent()
333                .ok_or_else(|| TermError::Configuration("Invalid file path".to_string()))?;
334            let dir_path_str = dir_path.to_str().ok_or_else(|| {
335                TermError::Configuration("Path contains invalid UTF-8".to_string())
336            })?;
337            let table_path = ListingTableUrl::parse(dir_path_str)?;
338
339            // Determine the file extension from the actual files
340            let extension = if self.paths[0].ends_with(".tsv") {
341                ".tsv"
342            } else if self.paths[0].ends_with(".txt") {
343                ".txt"
344            } else {
345                ".csv"
346            };
347
348            let listing_options =
349                ListingOptions::new(Arc::new(format)).with_file_extension(extension);
350
351            // Infer schema if not provided
352            let schema = if let Some(schema) = &self.options.schema {
353                schema.clone()
354            } else {
355                // Infer schema using a mutable clone
356                let mut source_clone = self.clone();
357                source_clone.infer_schema().await?
358            };
359
360            let config = ListingTableConfig::new(table_path)
361                .with_listing_options(listing_options)
362                .with_schema(schema);
363
364            let table = ListingTable::try_new(config)?;
365            ctx.register_table(table_name, Arc::new(table))?;
366        }
367
368        debug!(
369            table.name = %table_name,
370            source.type = "csv",
371            source.files = self.paths.len(),
372            "CSV data source registered successfully"
373        );
374
375        Ok(())
376    }
377
378    fn schema(&self) -> Option<&Arc<Schema>> {
379        self.options
380            .schema
381            .as_ref()
382            .or(self.inferred_schema.as_ref())
383    }
384
385    fn description(&self) -> String {
386        if self.paths.len() == 1 {
387            let path = &self.paths[0];
388            format!("CSV file: {path}")
389        } else {
390            let count = self.paths.len();
391            format!("CSV files: {count} files")
392        }
393    }
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes a header plus three data rows to a fresh `.csv` temp file.
    async fn create_test_csv() -> NamedTempFile {
        let mut csv_file = NamedTempFile::with_suffix(".csv").unwrap();
        for row in ["id,name,age", "1,Alice,30", "2,Bob,25", "3,Charlie,35"] {
            writeln!(csv_file, "{row}").unwrap();
        }
        csv_file.flush().unwrap();
        csv_file
    }

    /// A source built from one path reports one path and a single-file description.
    #[tokio::test]
    async fn test_csv_source_single_file() {
        let csv_file = create_test_csv().await;
        let path = csv_file.path().to_str().unwrap();

        let source = CsvSource::new(path).unwrap();

        assert_eq!(source.paths.len(), 1);
        assert!(source.description().contains("CSV file"));
    }

    /// Custom options passed to the constructor are stored unchanged.
    #[tokio::test]
    async fn test_csv_source_with_options() {
        let csv_file = create_test_csv().await;
        let path = csv_file.path().to_str().unwrap();
        let tab_separated = CsvOptions {
            delimiter: b'\t',
            has_header: false,
            ..Default::default()
        };

        let source = CsvSource::with_options(path, tab_separated).unwrap();

        assert_eq!(source.options.delimiter, b'\t');
        assert!(!source.options.has_header);
    }

    /// A source built from two paths keeps both and describes itself by count.
    #[tokio::test]
    async fn test_csv_source_multiple_files() {
        let first = create_test_csv().await;
        let second = create_test_csv().await;
        let paths: Vec<String> = [&first, &second]
            .iter()
            .map(|file| file.path().to_str().unwrap().to_string())
            .collect();

        let source = CsvSource::from_paths(paths).unwrap();

        assert_eq!(source.paths.len(), 2);
        assert!(source.description().contains("2 files"));
    }

    /// An empty path list is rejected.
    #[tokio::test]
    async fn test_csv_source_empty_paths() {
        assert!(CsvSource::from_paths(vec![]).is_err());
    }

    /// A registered source can be queried through SQL.
    #[tokio::test]
    async fn test_csv_registration() {
        let csv_file = create_test_csv().await;
        let source = CsvSource::new(csv_file.path().to_str().unwrap()).unwrap();
        let ctx = SessionContext::new();

        source.register(&ctx, "test_table").await.unwrap();

        // Verify table is registered
        let batches = ctx
            .sql("SELECT COUNT(*) as count FROM test_table")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();
        assert!(!batches.is_empty());
    }
}