// File: sh_layer3/document_loaders/pdf.rs

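//! # PDF Document Loader
//!
//! PDF 文件加载器:解析 PDF 文件并提取文本内容和元数据。
//! (PDF file loader: parses PDF files and extracts their text content and metadata.)
//!
//! ## 功能 (Features)
//!
//! - 基于页面的文本提取 (page-based text extraction)
//! - 元数据提取(标题、作者、创建日期等) (metadata extraction: title, author, creation date, …)
//! - 异步加载接口 (async `DocumentLoader` interface; parsing itself runs synchronously)
//! - 错误处理 (error handling: unreadable files return an error, unreadable pages are skipped)
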
use std::borrow::Cow;
use std::collections::HashMap;
use std::path::PathBuf;

use async_trait::async_trait;
use lopdf::Document as PdfDoc;
use tracing::{debug, warn};

use crate::document_loaders::{DocumentLoader, LoadOptions};
use crate::retriever_engine::Document;
use crate::types::Layer3Error;
use crate::types::Layer3Result;

22/// PDF Loader 实现
23pub struct PdfLoader {
24    #[allow(dead_code)]
25    options: LoadOptions,
26}
27
28impl PdfLoader {
29    pub fn new() -> Self {
30        Self {
31            options: LoadOptions::default(),
32        }
33    }
34
35    pub fn with_options(options: LoadOptions) -> Self {
36        Self { options }
37    }
38
39    /// 从 PDF 提取元数据
40    fn extract_metadata(&self, pdf: &PdfDoc) -> HashMap<String, serde_json::Value> {
41        let mut metadata = HashMap::new();
42
43        // Helper to extract string from PDF dictionary
44        fn get_string_from_dict(dict: &lopdf::Dictionary, key: &[u8]) -> Option<String> {
45            let obj = dict.get(key).ok()?;
46            if let lopdf::Object::String(bytes, _) = obj {
47                PdfLoader::decode_pdf_string(bytes).ok()
48            } else {
49                None
50            }
51        }
52
53        // 提取文档信息
54        if let Ok(trailer) = pdf.trailer.get(b"Info") {
55            if let Ok(info_ref) = trailer.as_reference() {
56                if let Ok(lopdf::Object::Dictionary(dict)) = pdf.get_object(info_ref) {
57                    // 标题
58                    if let Some(title) = get_string_from_dict(dict, b"Title") {
59                        metadata.insert("title".to_string(), serde_json::json!(title));
60                    }
61
62                    // 作者
63                    if let Some(author) = get_string_from_dict(dict, b"Author") {
64                        metadata.insert("author".to_string(), serde_json::json!(author));
65                    }
66
67                    // 主题
68                    if let Some(subject) = get_string_from_dict(dict, b"Subject") {
69                        metadata.insert("subject".to_string(), serde_json::json!(subject));
70                    }
71
72                    // 创建者
73                    if let Some(creator) = get_string_from_dict(dict, b"Creator") {
74                        metadata.insert("creator".to_string(), serde_json::json!(creator));
75                    }
76
77                    // 生产者
78                    if let Some(producer) = get_string_from_dict(dict, b"Producer") {
79                        metadata.insert("producer".to_string(), serde_json::json!(producer));
80                    }
81
82                    // 创建日期
83                    if let Some(creation_date) = get_string_from_dict(dict, b"CreationDate") {
84                        metadata.insert(
85                            "creation_date".to_string(),
86                            serde_json::json!(creation_date),
87                        );
88                    }
89
90                    // 修改日期
91                    if let Some(mod_date) = get_string_from_dict(dict, b"ModDate") {
92                        metadata
93                            .insert("modification_date".to_string(), serde_json::json!(mod_date));
94                    }
95                }
96            }
97        }
98
99        // 页数
100        let page_count = pdf.get_pages().len();
101        metadata.insert("page_count".to_string(), serde_json::json!(page_count));
102
103        metadata
104    }
105
106    /// 解码 PDF 字符串(处理编码)
107    fn decode_pdf_string(bytes: &[u8]) -> Layer3Result<String> {
108        // 尝试 UTF-8
109        if let Ok(s) = std::str::from_utf8(bytes) {
110            return Ok(s.to_string());
111        }
112
113        // 尝试 Latin-1 (ISO-8859-1)
114        let decoded: String = bytes.iter().map(|&b| b as char).collect();
115        Ok(decoded)
116    }
117
118    /// 从单个页面提取文本
119    fn extract_page_text(pdf: &PdfDoc, page_id: (u32, u16)) -> Layer3Result<String> {
120        let mut text = String::new();
121
122        if let Ok(lopdf::Object::Dictionary(dict)) = pdf.get_object(page_id) {
123            if let Ok(contents) = dict.get(b"Contents") {
124                match contents {
125                    lopdf::Object::Reference(ref_id) => {
126                        if let Ok(lopdf::Object::Stream(stream_obj)) = pdf.get_object(*ref_id) {
127                            if let Ok(content) = stream_obj.decompressed_content() {
128                                text.push_str(&Self::parse_content_stream(&content));
129                            }
130                        }
131                    }
132                    lopdf::Object::Array(arr) => {
133                        for obj in arr {
134                            if let lopdf::Object::Reference(ref_id) = obj {
135                                if let Ok(lopdf::Object::Stream(stream_obj)) =
136                                    pdf.get_object(*ref_id)
137                                {
138                                    if let Ok(content) = stream_obj.decompressed_content() {
139                                        text.push_str(&Self::parse_content_stream(&content));
140                                    }
141                                }
142                            }
143                        }
144                    }
145                    _ => {}
146                }
147            }
148        }
149
150        Ok(text)
151    }
152
153    /// 解析 PDF 内容流,提取文本
154    fn parse_content_stream(content: &[u8]) -> String {
155        let mut text = String::new();
156        let content_str = String::from_utf8_lossy(content);
157
158        // 简单的文本提取:查找 Tj 和 TJ 操作符
159        let mut current_text = String::new();
160        let mut in_string = false;
161        let mut escape_next = false;
162
163        for ch in content_str.chars() {
164            if escape_next {
165                current_text.push(ch);
166                escape_next = false;
167                continue;
168            }
169
170            match ch {
171                '\\' if in_string => {
172                    escape_next = true;
173                }
174                '(' => {
175                    if !in_string {
176                        in_string = true;
177                        current_text.clear();
178                    } else {
179                        current_text.push(ch);
180                    }
181                }
182                ')' => {
183                    if in_string {
184                        in_string = false;
185                        if !current_text.is_empty() {
186                            // 过滤控制字符
187                            let cleaned: String = current_text
188                                .chars()
189                                .filter(|c| {
190                                    c.is_alphabetic()
191                                        || c.is_numeric()
192                                        || c.is_whitespace()
193                                        || *c == '-'
194                                        || *c == '.'
195                                        || *c == ','
196                                })
197                                .collect();
198                            if !cleaned.trim().is_empty() {
199                                text.push_str(&cleaned);
200                                text.push(' ');
201                            }
202                        }
203                    } else {
204                        current_text.push(ch);
205                    }
206                }
207                _ => {
208                    if in_string {
209                        current_text.push(ch);
210                    }
211                }
212            }
213        }
214
215        // 清理多余空格
216        let cleaned: String = text.split_whitespace().collect::<Vec<_>>().join(" ");
217
218        cleaned
219    }
220
221    /// 从 PDF 提取所有页面文本
222    fn extract_all_text(&self, pdf: &PdfDoc) -> Layer3Result<Vec<(usize, String)>> {
223        let pages = pdf.get_pages();
224        let mut result = Vec::new();
225
226        for (page_num, page_id) in pages.iter() {
227            match Self::extract_page_text(pdf, *page_id) {
228                Ok(page_text) => {
229                    if !page_text.trim().is_empty() {
230                        result.push((*page_num as usize, page_text));
231                    }
232                }
233                Err(e) => {
234                    warn!("Failed to extract text from page {}: {}", page_num, e);
235                }
236            }
237        }
238
239        Ok(result)
240    }
241}
242
243impl Default for PdfLoader {
244    fn default() -> Self {
245        Self::new()
246    }
247}
248
249#[async_trait]
250impl DocumentLoader for PdfLoader {
251    async fn load(&self, path: PathBuf) -> Layer3Result<Document> {
252        debug!("Loading PDF file: {:?}", path);
253
254        // 读取 PDF 文件
255        let pdf = PdfDoc::load(&path).map_err(|e| {
256            Layer3Error::PersistenceError(format!(
257                "Failed to load PDF file '{}': {}",
258                path.display(),
259                e
260            ))
261        })?;
262
263        // 提取元数据
264        let metadata = self.extract_metadata(&pdf);
265
266        // 提取所有文本
267        let pages = self.extract_all_text(&pdf)?;
268        let full_text: String = pages
269            .iter()
270            .map(|(_, text)| text.as_str())
271            .collect::<Vec<_>>()
272            .join("\n\n");
273
274        // 创建文档
275        let mut doc = Document::new(full_text).with_source(path.to_string_lossy().to_string());
276
277        // 添加元数据
278        doc.metadata = metadata;
279
280        Ok(doc)
281    }
282
283    async fn load_and_split(&self, path: PathBuf) -> Layer3Result<Vec<Document>> {
284        debug!("Loading and splitting PDF file: {:?}", path);
285
286        // 读取 PDF 文件
287        let pdf = PdfDoc::load(&path).map_err(|e| {
288            Layer3Error::PersistenceError(format!(
289                "Failed to load PDF file '{}': {}",
290                path.display(),
291                e
292            ))
293        })?;
294
295        // 提取元数据
296        let base_metadata = self.extract_metadata(&pdf);
297
298        // 提取所有页面文本
299        let pages = self.extract_all_text(&pdf)?;
300
301        if pages.is_empty() {
302            // 如果没有提取到文本,返回空数组
303            return Ok(Vec::new());
304        }
305
306        // 为每个页面创建一个文档
307        let source = path.to_string_lossy().to_string();
308        let documents: Vec<Document> = pages
309            .into_iter()
310            .map(|(page_num, text)| {
311                let mut metadata = base_metadata.clone();
312                metadata.insert("page".to_string(), serde_json::json!(page_num));
313                metadata.insert(
314                    "total_pages".to_string(),
315                    serde_json::json!(pdf.get_pages().len()),
316                );
317
318                Document {
319                    id: None,
320                    content: text,
321                    metadata,
322                    source: Some(source.clone()),
323                }
324            })
325            .collect();
326
327        Ok(documents)
328    }
329
330    fn supports(&self, path: &std::path::Path) -> bool {
331        path.extension()
332            .and_then(|e| e.to_str())
333            .map(|e| e.to_lowercase() == "pdf")
334            .unwrap_or(false)
335    }
336
337    fn extensions(&self) -> &[&str] {
338        &["pdf"]
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use lopdf::Dictionary;
    use lopdf::Object as PdfObject;
    use lopdf::Stream;
    use tempfile::NamedTempFile;

    /// Builds a one-page PDF whose (uncompressed) content stream shows "Hello World".
    fn create_minimal_pdf() -> NamedTempFile {
        let mut pdf = lopdf::Document::new();

        // Object 1: Catalog
        pdf.add_object(PdfObject::Dictionary(Dictionary::from_iter([
            ("Type", PdfObject::Name("Catalog".as_bytes().to_vec())),
            ("Pages", PdfObject::Reference((2, 0))),
        ])));

        // Object 2: Pages
        pdf.add_object(PdfObject::Dictionary(Dictionary::from_iter([
            ("Type", PdfObject::Name("Pages".as_bytes().to_vec())),
            ("Kids", PdfObject::Array(vec![PdfObject::Reference((3, 0))])),
            ("Count", PdfObject::Integer(1)),
        ])));

        // Object 3: Page
        pdf.add_object(PdfObject::Dictionary(Dictionary::from_iter([
            ("Type", PdfObject::Name("Page".as_bytes().to_vec())),
            ("Parent", PdfObject::Reference((2, 0))),
            (
                "MediaBox",
                PdfObject::Array(vec![
                    PdfObject::Integer(0),
                    PdfObject::Integer(0),
                    PdfObject::Integer(612),
                    PdfObject::Integer(792),
                ]),
            ),
            ("Contents", PdfObject::Reference((4, 0))),
        ])));

        // Object 4: Content stream with text
        let content = b"BT /F1 12 Tf 100 700 Td (Hello World) Tj ET";
        pdf.add_object(PdfObject::Stream(Stream::new(
            Dictionary::from_iter([("Length", PdfObject::Integer(content.len() as i64))]),
            content.to_vec(),
        )));

        // Without a trailer `/Root` the catalog (and thus the page tree) is
        // unreachable, and the file reports zero pages.
        pdf.trailer.set("Root", PdfObject::Reference((1, 0)));

        let file = NamedTempFile::with_suffix(".pdf").unwrap();
        pdf.save(file.path()).expect("Failed to save PDF");
        file
    }

    #[test]
    fn test_pdf_loader_extensions() {
        let loader = PdfLoader::new();
        assert!(loader.extensions().contains(&"pdf"));
    }

    #[test]
    fn test_pdf_loader_supports() {
        let loader = PdfLoader::new();
        assert!(loader.supports(std::path::Path::new("test.pdf")));
        assert!(loader.supports(std::path::Path::new("test.PDF")));
        assert!(!loader.supports(std::path::Path::new("test.txt")));
    }

    #[tokio::test]
    async fn test_pdf_loader_load() {
        let loader = PdfLoader::new();
        let pdf_file = create_minimal_pdf();

        let result = loader.load(pdf_file.path().to_path_buf()).await;
        if let Err(ref err) = result {
            eprintln!("Error loading PDF: {:?}", err);
        }
        assert!(result.is_ok(), "PDF should load successfully");

        let doc = result.unwrap();
        assert!(doc.source.is_some());
        // page_count is reported even when no text is extracted.
        assert_eq!(
            doc.metadata.get("page_count").and_then(|v| v.as_u64()),
            Some(1)
        );
    }

    #[tokio::test]
    async fn test_pdf_loader_load_and_split() {
        let loader = PdfLoader::new();
        let pdf_file = create_minimal_pdf();

        let result = loader.load_and_split(pdf_file.path().to_path_buf()).await;
        if let Err(ref err) = result {
            eprintln!("Error loading PDF: {:?}", err);
        }
        assert!(result.is_ok(), "PDF should load successfully");

        let docs = result.unwrap();
        // One page: at most one document (zero if no text could be extracted).
        assert!(docs.len() <= 1);
    }

    #[test]
    fn test_decode_pdf_string_utf8() {
        let bytes = b"Hello World";
        let result = PdfLoader::decode_pdf_string(bytes);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Hello World");
    }

    #[test]
    fn test_decode_pdf_string_latin1() {
        // "Café" in Latin-1: 0xE9 is 'é' and is not valid UTF-8 on its own.
        let bytes = vec![b'C', b'a', b'f', 0xE9];
        let result = PdfLoader::decode_pdf_string(&bytes);
        assert_eq!(result.ok().as_deref(), Some("Café"));
    }

    #[test]
    fn test_parse_content_stream_simple() {
        assert_eq!(
            PdfLoader::parse_content_stream(b"BT /F1 12 Tf 100 700 Td (Hello World) Tj ET"),
            "Hello World"
        );
        assert_eq!(PdfLoader::parse_content_stream(b""), "");
        assert_eq!(PdfLoader::parse_content_stream(b"BT ET"), "");
    }
}