Skip to main content

sh_layer3/document_loaders/
csv.rs

1//! # CSV Document Loader
2//!
3//! CSV 文件加载器。
4
5use crate::document_loaders::{DocumentLoader, LoadOptions};
6use crate::retriever_engine::Document;
7use crate::types::Layer3Result;
8use async_trait::async_trait;
9use std::path::PathBuf;
10
11/// CSV Loader 实现
12#[allow(dead_code)]
13pub struct CsvLoader {
14    #[allow(dead_code)]
15    options: LoadOptions,
16    /// 分隔符
17    delimiter: char,
18    /// 是否有表头
19    has_header: bool,
20}
21
22impl CsvLoader {
23    pub fn new() -> Self {
24        Self {
25            options: LoadOptions::default(),
26            delimiter: ',',
27            has_header: true,
28        }
29    }
30
31    pub fn with_delimiter(delimiter: char) -> Self {
32        Self {
33            options: LoadOptions::default(),
34            delimiter,
35            has_header: true,
36        }
37    }
38}
39
40impl Default for CsvLoader {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46#[async_trait]
47impl DocumentLoader for CsvLoader {
48    async fn load(&self, path: PathBuf) -> Layer3Result<Document> {
49        let content = tokio::fs::read_to_string(&path).await?;
50        Ok(Document::new(content).with_source(path.to_string_lossy().to_string()))
51    }
52
53    async fn load_and_split(&self, path: PathBuf) -> Layer3Result<Vec<Document>> {
54        let content = tokio::fs::read_to_string(&path).await?;
55        let lines: Vec<&str> = content.lines().collect();
56
57        if lines.is_empty() {
58            return Ok(Vec::new());
59        }
60
61        // 解析表头(如果有)
62        let header_line = if self.has_header { lines[0] } else { "" };
63        let headers: Vec<&str> = header_line.split(self.delimiter).collect();
64
65        let start_idx = if self.has_header { 1 } else { 0 };
66        let mut documents = Vec::new();
67
68        for (i, line) in lines.iter().enumerate().skip(start_idx) {
69            if line.trim().is_empty() {
70                continue;
71            }
72
73            let values: Vec<&str> = line.split(self.delimiter).collect();
74            let mut content_parts = Vec::new();
75
76            // 如果有表头,使用键值对格式
77            for (j, value) in values.iter().enumerate() {
78                if j < headers.len() {
79                    content_parts.push(format!("{}: {}", headers[j], value));
80                } else {
81                    content_parts.push(value.to_string());
82                }
83            }
84
85            documents.push(Document::new(content_parts.join(", ")).with_source(format!(
86                "{}#{}",
87                path.to_string_lossy(),
88                i
89            )));
90        }
91
92        Ok(documents)
93    }
94
95    fn supports(&self, path: &std::path::Path) -> bool {
96        path.extension()
97            .and_then(|e| e.to_str())
98            .map(|e| e == "csv" || e == "tsv")
99            .unwrap_or(false)
100    }
101
102    fn extensions(&self) -> &[&str] {
103        &["csv", "tsv"]
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    #[test]
    fn test_csv_loader_extensions() {
        let loader = CsvLoader::new();
        assert!(loader.extensions().contains(&"csv"));
        assert!(loader.extensions().contains(&"tsv"));
    }

    /// Both advertised extensions must be accepted by `supports`.
    #[test]
    fn test_csv_loader_supports_listed_extensions() {
        let loader = CsvLoader::default();
        assert!(loader.supports(Path::new("data/table.csv")));
        assert!(loader.supports(Path::new("data/table.tsv")));
    }

    /// Unrelated, missing, and compound extensions must be rejected.
    #[test]
    fn test_csv_loader_rejects_other_files() {
        let loader = CsvLoader::default();
        assert!(!loader.supports(Path::new("notes.txt")));
        assert!(!loader.supports(Path::new("no_extension")));
        assert!(!loader.supports(Path::new("archive.csv.gz")));
    }
}