1use super::{CompressionType, DataSource};
4use crate::prelude::*;
5use async_trait::async_trait;
6use datafusion::arrow::datatypes::Schema;
7use datafusion::datasource::file_format::csv::CsvFormat;
8use datafusion::datasource::listing::{
9 ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
10};
11use datafusion::prelude::*;
12use std::sync::Arc;
13use tracing::{debug, info, instrument};
14
15#[derive(Debug, Clone)]
17pub struct CsvOptions {
18 pub has_header: bool,
20 pub delimiter: u8,
22 pub quote: u8,
24 pub escape: Option<u8>,
26 pub comment: Option<u8>,
28 pub schema: Option<Arc<Schema>>,
30 pub compression: CompressionType,
32 pub schema_infer_max_records: usize,
34}
35
36impl Default for CsvOptions {
37 fn default() -> Self {
38 Self {
39 has_header: true,
40 delimiter: b',',
41 quote: b'"',
42 escape: None,
43 comment: None,
44 schema: None,
45 compression: CompressionType::Auto,
46 schema_infer_max_records: 1000,
47 }
48 }
49}
50
51#[derive(Debug, Clone)]
76pub struct CsvSource {
77 paths: Vec<String>,
78 options: CsvOptions,
79 inferred_schema: Option<Arc<Schema>>,
80}
81
82impl CsvSource {
83 pub fn new(path: impl Into<String>) -> Result<Self> {
85 Ok(Self {
86 paths: vec![path.into()],
87 options: CsvOptions::default(),
88 inferred_schema: None,
89 })
90 }
91
92 pub fn with_options(path: impl Into<String>, options: CsvOptions) -> Result<Self> {
94 Ok(Self {
95 paths: vec![path.into()],
96 options,
97 inferred_schema: None,
98 })
99 }
100
101 pub fn from_paths(paths: Vec<String>) -> Result<Self> {
103 if paths.is_empty() {
104 return Err(TermError::Configuration(
105 "At least one path must be provided".to_string(),
106 ));
107 }
108 Ok(Self {
109 paths,
110 options: CsvOptions::default(),
111 inferred_schema: None,
112 })
113 }
114
115 pub async fn from_glob(pattern: impl Into<String>) -> Result<Self> {
117 let patterns = vec![pattern.into()];
118 let paths = super::expand_globs(&patterns).await?;
119 Self::from_paths(paths)
120 }
121
122 pub async fn from_globs(patterns: Vec<String>) -> Result<Self> {
124 let paths = super::expand_globs(&patterns).await?;
125 Self::from_paths(paths)
126 }
127
128 pub fn with_custom_options(mut self, options: CsvOptions) -> Self {
130 self.options = options;
131 self
132 }
133
134 #[instrument(skip(self))]
136 #[allow(dead_code)]
137 async fn infer_schema(&mut self) -> Result<Arc<Schema>> {
138 if let Some(schema) = &self.options.schema {
139 return Ok(schema.clone());
140 }
141
142 if let Some(schema) = &self.inferred_schema {
143 return Ok(schema.clone());
144 }
145
146 let _format = CsvFormat::default()
148 .with_has_header(self.options.has_header)
149 .with_delimiter(self.options.delimiter)
150 .with_quote(self.options.quote)
151 .with_escape(self.options.escape)
152 .with_comment(self.options.comment)
153 .with_schema_infer_max_rec(self.options.schema_infer_max_records);
154
155 let ctx = SessionContext::new();
157
158 let first_path = &self.paths[0];
160 let schema = if first_path.ends_with(".csv") {
161 let csv_options = CsvReadOptions::new()
163 .has_header(self.options.has_header)
164 .delimiter(self.options.delimiter)
165 .schema_infer_max_records(self.options.schema_infer_max_records);
166
167 let df = ctx.read_csv(first_path, csv_options).await?;
168 df.schema().inner().clone()
169 } else {
170 let first_path_obj = std::path::Path::new(first_path);
172 let dir_path = first_path_obj
173 .parent()
174 .ok_or_else(|| TermError::Configuration("Invalid file path".to_string()))?;
175 let dir_path_str = dir_path.to_str().ok_or_else(|| {
176 TermError::Configuration("Path contains invalid UTF-8".to_string())
177 })?;
178 let table_path = ListingTableUrl::parse(dir_path_str)?;
179
180 let extension = if first_path.ends_with(".tsv") {
181 ".tsv"
182 } else if first_path.ends_with(".txt") {
183 ".txt"
184 } else {
185 ".csv"
186 };
187
188 let format = CsvFormat::default()
189 .with_has_header(self.options.has_header)
190 .with_delimiter(self.options.delimiter)
191 .with_quote(self.options.quote)
192 .with_escape(self.options.escape)
193 .with_comment(self.options.comment)
194 .with_schema_infer_max_rec(self.options.schema_infer_max_records);
195
196 let listing_options =
197 ListingOptions::new(Arc::new(format)).with_file_extension(extension);
198
199 let config = ListingTableConfig::new(table_path).with_listing_options(listing_options);
200 let table = ListingTable::try_new(config)?;
201
202 let timestamp = std::time::SystemTime::now()
204 .duration_since(std::time::UNIX_EPOCH)
205 .map_err(|e| TermError::Internal(format!("Failed to get system time: {e}")))?
206 .as_nanos();
207 let temp_table_name = format!("_temp_schema_inference_{timestamp}");
208 ctx.register_table(&temp_table_name, Arc::new(table))?;
209
210 let df = ctx.table(&temp_table_name).await?;
211 let schema = df.schema().inner().clone();
212
213 ctx.deregister_table(&temp_table_name)?;
215
216 schema
217 };
218
219 self.inferred_schema = Some(schema.clone());
220 Ok(schema)
221 }
222}
223
224#[async_trait]
225impl DataSource for CsvSource {
226 #[instrument(skip(self, ctx, telemetry), fields(
227 table.name = %table_name,
228 source.type = "csv",
229 source.files = self.paths.len(),
230 csv.delimiter = %self.options.delimiter as char,
231 csv.has_header = self.options.has_header
232 ))]
233 async fn register_with_telemetry(
234 &self,
235 ctx: &SessionContext,
236 table_name: &str,
237 telemetry: Option<&Arc<TermTelemetry>>,
238 ) -> Result<()> {
239 info!(
240 table.name = %table_name,
241 source.type = "csv",
242 source.paths = ?self.paths,
243 csv.delimiter = %self.options.delimiter as char,
244 csv.has_header = self.options.has_header,
245 csv.compression = ?self.options.compression,
246 "Registering CSV data source"
247 );
248
249 let mut _datasource_span = if let Some(tel) = telemetry {
251 tel.start_datasource_span("csv", table_name)
252 } else {
253 TermSpan::noop()
254 };
255 let mut format = CsvFormat::default()
257 .with_has_header(self.options.has_header)
258 .with_delimiter(self.options.delimiter)
259 .with_quote(self.options.quote);
260
261 if let Some(escape) = self.options.escape {
262 format = format.with_escape(Some(escape));
263 }
264 if let Some(comment) = self.options.comment {
265 format = format.with_comment(Some(comment));
266 }
267
268 if self.paths.len() == 1 {
270 let path = &self.paths[0];
271
272 if path.ends_with(".csv") {
274 let mut csv_options = CsvReadOptions::new()
275 .has_header(self.options.has_header)
276 .delimiter(self.options.delimiter)
277 .quote(self.options.quote)
278 .schema_infer_max_records(self.options.schema_infer_max_records);
279
280 if let Some(escape) = self.options.escape {
281 csv_options = csv_options.escape(escape);
282 }
283 if let Some(comment) = self.options.comment {
284 csv_options = csv_options.comment(comment);
285 }
286 if let Some(schema) = &self.options.schema {
287 csv_options = csv_options.schema(schema);
288 }
289
290 ctx.register_csv(table_name, path, csv_options).await?;
291 } else {
292 let table_path = ListingTableUrl::parse(path)?;
294
295 let extension = if path.ends_with(".tsv") {
297 ".tsv".to_string()
298 } else if path.ends_with(".txt") {
299 ".txt".to_string()
300 } else {
301 std::path::Path::new(path)
303 .extension()
304 .and_then(|ext| ext.to_str())
305 .map(|ext| format!(".{ext}"))
306 .unwrap_or_else(|| ".csv".to_string())
307 };
308
309 let listing_options =
310 ListingOptions::new(Arc::new(format)).with_file_extension(&extension);
311
312 let config = if let Some(schema) = &self.options.schema {
314 ListingTableConfig::new(table_path)
315 .with_listing_options(listing_options)
316 .with_schema(schema.clone())
317 } else {
318 ListingTableConfig::new(table_path)
319 .with_listing_options(listing_options)
320 .infer_schema(&ctx.state())
321 .await?
322 };
323
324 let table = ListingTable::try_new(config)?;
325 ctx.register_table(table_name, Arc::new(table))?;
326 }
327 } else {
328 let first_path = std::path::Path::new(&self.paths[0]);
331 let dir_path = first_path
332 .parent()
333 .ok_or_else(|| TermError::Configuration("Invalid file path".to_string()))?;
334 let dir_path_str = dir_path.to_str().ok_or_else(|| {
335 TermError::Configuration("Path contains invalid UTF-8".to_string())
336 })?;
337 let table_path = ListingTableUrl::parse(dir_path_str)?;
338
339 let extension = if self.paths[0].ends_with(".tsv") {
341 ".tsv"
342 } else if self.paths[0].ends_with(".txt") {
343 ".txt"
344 } else {
345 ".csv"
346 };
347
348 let listing_options =
349 ListingOptions::new(Arc::new(format)).with_file_extension(extension);
350
351 let schema = if let Some(schema) = &self.options.schema {
353 schema.clone()
354 } else {
355 let mut source_clone = self.clone();
357 source_clone.infer_schema().await?
358 };
359
360 let config = ListingTableConfig::new(table_path)
361 .with_listing_options(listing_options)
362 .with_schema(schema);
363
364 let table = ListingTable::try_new(config)?;
365 ctx.register_table(table_name, Arc::new(table))?;
366 }
367
368 debug!(
369 table.name = %table_name,
370 source.type = "csv",
371 source.files = self.paths.len(),
372 "CSV data source registered successfully"
373 );
374
375 Ok(())
376 }
377
378 fn schema(&self) -> Option<&Arc<Schema>> {
379 self.options
380 .schema
381 .as_ref()
382 .or(self.inferred_schema.as_ref())
383 }
384
385 fn description(&self) -> String {
386 if self.paths.len() == 1 {
387 let path = &self.paths[0];
388 format!("CSV file: {path}")
389 } else {
390 let count = self.paths.len();
391 format!("CSV files: {count} files")
392 }
393 }
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes a small three-row CSV (plus header) to a `.csv` temp file.
    async fn create_test_csv() -> NamedTempFile {
        let mut csv = NamedTempFile::with_suffix(".csv").unwrap();
        for row in &["id,name,age", "1,Alice,30", "2,Bob,25", "3,Charlie,35"] {
            writeln!(csv, "{row}").unwrap();
        }
        csv.flush().unwrap();
        csv
    }

    /// A source built from one path reports one file and a single-file label.
    #[tokio::test]
    async fn test_csv_source_single_file() {
        let csv = create_test_csv().await;
        let path = csv.path().to_str().unwrap();

        let source = CsvSource::new(path).unwrap();

        assert_eq!(source.paths.len(), 1);
        assert!(source.description().contains("CSV file"));
    }

    /// Options passed to `with_options` are stored on the source unchanged.
    #[tokio::test]
    async fn test_csv_source_with_options() {
        let csv = create_test_csv().await;
        let path = csv.path().to_str().unwrap();
        let tab_separated = CsvOptions {
            delimiter: b'\t',
            has_header: false,
            ..Default::default()
        };

        let source = CsvSource::with_options(path, tab_separated).unwrap();

        assert_eq!(source.options.delimiter, b'\t');
        assert!(!source.options.has_header);
    }

    /// A source built from two paths reports both in its description.
    #[tokio::test]
    async fn test_csv_source_multiple_files() {
        let first = create_test_csv().await;
        let second = create_test_csv().await;
        let paths: Vec<String> = [&first, &second]
            .iter()
            .map(|file| file.path().to_str().unwrap().to_string())
            .collect();

        let source = CsvSource::from_paths(paths).unwrap();

        assert_eq!(source.paths.len(), 2);
        assert!(source.description().contains("2 files"));
    }

    /// An empty path list is rejected.
    #[tokio::test]
    async fn test_csv_source_empty_paths() {
        assert!(CsvSource::from_paths(vec![]).is_err());
    }

    /// A registered source can be queried through SQL.
    #[tokio::test]
    async fn test_csv_registration() {
        let csv = create_test_csv().await;
        let source = CsvSource::new(csv.path().to_str().unwrap()).unwrap();
        let ctx = SessionContext::new();

        source.register(&ctx, "test_table").await.unwrap();

        let batches = ctx
            .sql("SELECT COUNT(*) as count FROM test_table")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();
        assert!(!batches.is_empty());
    }
}