Skip to main content

sh_layer3/document_loaders/
pdf.rs

1//! # PDF Document Loader
2//!
3//! PDF 文件加载器:解析 PDF 文件并提取文本内容和元数据。
4//!
5//! ## 功能
6//!
7//! - 基于页面的文本提取
8//! - 元数据提取(标题、作者、创建日期等)
9//! - 支持异步加载
10//! - 完整错误处理
11
12use crate::document_loaders::{DocumentLoader, LoadOptions};
13use crate::retriever_engine::Document;
14use crate::types::Layer3Error;
15use crate::types::Layer3Result;
16use async_trait::async_trait;
17use lopdf::Document as PdfDoc;
18use std::collections::HashMap;
19use std::path::PathBuf;
20use tracing::{debug, warn};
21
22/// PDF Loader 实现
23pub struct PdfLoader {
24    options: LoadOptions,
25}
26
27impl PdfLoader {
28    pub fn new() -> Self {
29        Self {
30            options: LoadOptions::default(),
31        }
32    }
33
34    pub fn with_options(options: LoadOptions) -> Self {
35        Self { options }
36    }
37
38    /// 从 PDF 提取元数据
39    fn extract_metadata(&self, pdf: &PdfDoc) -> HashMap<String, serde_json::Value> {
40        let mut metadata = HashMap::new();
41
42        // 提取文档信息
43        if let Ok(trailer) = pdf.trailer.get(b"Info") {
44            if let Ok(info_ref) = trailer.as_reference() {
45                if let Some(info_obj) = pdf.get_object(info_ref).ok() {
46                    if let lopdf::Object::Dictionary(dict) = info_obj {
47                        // 标题
48                        if let Some(title) = dict.get(b"Title").ok() {
49                            if let lopdf::Object::String(bytes, _) = title {
50                                if let Ok(title_str) = Self::decode_pdf_string(bytes) {
51                                    metadata
52                                        .insert("title".to_string(), serde_json::json!(title_str));
53                                }
54                            }
55                        }
56
57                        // 作者
58                        if let Some(author) = dict.get(b"Author").ok() {
59                            if let lopdf::Object::String(bytes, _) = author {
60                                if let Ok(author_str) = Self::decode_pdf_string(bytes) {
61                                    metadata.insert(
62                                        "author".to_string(),
63                                        serde_json::json!(author_str),
64                                    );
65                                }
66                            }
67                        }
68
69                        // 主题
70                        if let Some(subject) = dict.get(b"Subject").ok() {
71                            if let lopdf::Object::String(bytes, _) = subject {
72                                if let Ok(subject_str) = Self::decode_pdf_string(bytes) {
73                                    metadata.insert(
74                                        "subject".to_string(),
75                                        serde_json::json!(subject_str),
76                                    );
77                                }
78                            }
79                        }
80
81                        // 创建者
82                        if let Some(creator) = dict.get(b"Creator").ok() {
83                            if let lopdf::Object::String(bytes, _) = creator {
84                                if let Ok(creator_str) = Self::decode_pdf_string(bytes) {
85                                    metadata.insert(
86                                        "creator".to_string(),
87                                        serde_json::json!(creator_str),
88                                    );
89                                }
90                            }
91                        }
92
93                        // 生产者
94                        if let Some(producer) = dict.get(b"Producer").ok() {
95                            if let lopdf::Object::String(bytes, _) = producer {
96                                if let Ok(producer_str) = Self::decode_pdf_string(bytes) {
97                                    metadata.insert(
98                                        "producer".to_string(),
99                                        serde_json::json!(producer_str),
100                                    );
101                                }
102                            }
103                        }
104
105                        // 创建日期
106                        if let Some(creation_date) = dict.get(b"CreationDate").ok() {
107                            if let lopdf::Object::String(bytes, _) = creation_date {
108                                if let Ok(date_str) = Self::decode_pdf_string(bytes) {
109                                    metadata.insert(
110                                        "creation_date".to_string(),
111                                        serde_json::json!(date_str),
112                                    );
113                                }
114                            }
115                        }
116
117                        // 修改日期
118                        if let Some(mod_date) = dict.get(b"ModDate").ok() {
119                            if let lopdf::Object::String(bytes, _) = mod_date {
120                                if let Ok(date_str) = Self::decode_pdf_string(bytes) {
121                                    metadata.insert(
122                                        "modification_date".to_string(),
123                                        serde_json::json!(date_str),
124                                    );
125                                }
126                            }
127                        }
128                    }
129                }
130            }
131        }
132
133        // 页数
134        let page_count = pdf.get_pages().len();
135        metadata.insert("page_count".to_string(), serde_json::json!(page_count));
136
137        metadata
138    }
139
140    /// 解码 PDF 字符串(处理编码)
141    fn decode_pdf_string(bytes: &[u8]) -> Layer3Result<String> {
142        // 尝试 UTF-8
143        if let Ok(s) = std::str::from_utf8(bytes) {
144            return Ok(s.to_string());
145        }
146
147        // 尝试 Latin-1 (ISO-8859-1)
148        let decoded: String = bytes.iter().map(|&b| b as char).collect();
149        Ok(decoded)
150    }
151
152    /// 从单个页面提取文本
153    fn extract_page_text(pdf: &PdfDoc, page_id: (u32, u16)) -> Layer3Result<String> {
154        let mut text = String::new();
155
156        if let Ok(page) = pdf.get_object(page_id) {
157            if let lopdf::Object::Dictionary(dict) = page {
158                if let Some(contents) = dict.get(b"Contents").ok() {
159                    match contents {
160                        lopdf::Object::Reference(ref_id) => {
161                            if let Ok(stream) = pdf.get_object(*ref_id) {
162                                if let lopdf::Object::Stream(stream_obj) = stream {
163                                    if let Ok(content) = stream_obj.decompressed_content() {
164                                        text.push_str(&Self::parse_content_stream(&content));
165                                    }
166                                }
167                            }
168                        }
169                        lopdf::Object::Array(arr) => {
170                            for obj in arr {
171                                if let lopdf::Object::Reference(ref_id) = obj {
172                                    if let Ok(stream) = pdf.get_object(*ref_id) {
173                                        if let lopdf::Object::Stream(stream_obj) = stream {
174                                            if let Ok(content) = stream_obj.decompressed_content() {
175                                                text.push_str(&Self::parse_content_stream(
176                                                    &content,
177                                                ));
178                                            }
179                                        }
180                                    }
181                                }
182                            }
183                        }
184                        _ => {}
185                    }
186                }
187            }
188        }
189
190        Ok(text)
191    }
192
193    /// 解析 PDF 内容流,提取文本
194    fn parse_content_stream(content: &[u8]) -> String {
195        let mut text = String::new();
196        let content_str = String::from_utf8_lossy(content);
197
198        // 简单的文本提取:查找 Tj 和 TJ 操作符
199        let mut current_text = String::new();
200        let mut in_string = false;
201        let mut escape_next = false;
202
203        for ch in content_str.chars() {
204            if escape_next {
205                current_text.push(ch);
206                escape_next = false;
207                continue;
208            }
209
210            match ch {
211                '\\' if in_string => {
212                    escape_next = true;
213                }
214                '(' => {
215                    if !in_string {
216                        in_string = true;
217                        current_text.clear();
218                    } else {
219                        current_text.push(ch);
220                    }
221                }
222                ')' => {
223                    if in_string {
224                        in_string = false;
225                        if !current_text.is_empty() {
226                            // 过滤控制字符
227                            let cleaned: String = current_text
228                                .chars()
229                                .filter(|c| {
230                                    c.is_alphabetic()
231                                        || c.is_numeric()
232                                        || c.is_whitespace()
233                                        || *c == '-'
234                                        || *c == '.'
235                                        || *c == ','
236                                })
237                                .collect();
238                            if !cleaned.trim().is_empty() {
239                                text.push_str(&cleaned);
240                                text.push(' ');
241                            }
242                        }
243                    } else {
244                        current_text.push(ch);
245                    }
246                }
247                _ => {
248                    if in_string {
249                        current_text.push(ch);
250                    }
251                }
252            }
253        }
254
255        // 清理多余空格
256        let cleaned: String = text.split_whitespace().collect::<Vec<_>>().join(" ");
257
258        cleaned
259    }
260
261    /// 从 PDF 提取所有页面文本
262    fn extract_all_text(&self, pdf: &PdfDoc) -> Layer3Result<Vec<(usize, String)>> {
263        let pages = pdf.get_pages();
264        let mut result = Vec::new();
265
266        for (page_num, page_id) in pages.iter() {
267            match Self::extract_page_text(pdf, *page_id) {
268                Ok(page_text) => {
269                    if !page_text.trim().is_empty() {
270                        result.push((*page_num as usize, page_text));
271                    }
272                }
273                Err(e) => {
274                    warn!("Failed to extract text from page {}: {}", page_num, e);
275                }
276            }
277        }
278
279        Ok(result)
280    }
281}
282
283impl Default for PdfLoader {
284    fn default() -> Self {
285        Self::new()
286    }
287}
288
289#[async_trait]
290impl DocumentLoader for PdfLoader {
291    async fn load(&self, path: PathBuf) -> Layer3Result<Document> {
292        debug!("Loading PDF file: {:?}", path);
293
294        // 读取 PDF 文件
295        let pdf = PdfDoc::load(&path).map_err(|e| {
296            Layer3Error::PersistenceError(format!(
297                "Failed to load PDF file '{}': {}",
298                path.display(),
299                e
300            ))
301        })?;
302
303        // 提取元数据
304        let metadata = self.extract_metadata(&pdf);
305
306        // 提取所有文本
307        let pages = self.extract_all_text(&pdf)?;
308        let full_text: String = pages
309            .iter()
310            .map(|(_, text)| text.as_str())
311            .collect::<Vec<_>>()
312            .join("\n\n");
313
314        // 创建文档
315        let mut doc = Document::new(full_text).with_source(path.to_string_lossy().to_string());
316
317        // 添加元数据
318        doc.metadata = metadata;
319
320        Ok(doc)
321    }
322
323    async fn load_and_split(&self, path: PathBuf) -> Layer3Result<Vec<Document>> {
324        debug!("Loading and splitting PDF file: {:?}", path);
325
326        // 读取 PDF 文件
327        let pdf = PdfDoc::load(&path).map_err(|e| {
328            Layer3Error::PersistenceError(format!(
329                "Failed to load PDF file '{}': {}",
330                path.display(),
331                e
332            ))
333        })?;
334
335        // 提取元数据
336        let base_metadata = self.extract_metadata(&pdf);
337
338        // 提取所有页面文本
339        let pages = self.extract_all_text(&pdf)?;
340
341        if pages.is_empty() {
342            // 如果没有提取到文本,返回空数组
343            return Ok(Vec::new());
344        }
345
346        // 为每个页面创建一个文档
347        let source = path.to_string_lossy().to_string();
348        let documents: Vec<Document> = pages
349            .into_iter()
350            .map(|(page_num, text)| {
351                let mut metadata = base_metadata.clone();
352                metadata.insert("page".to_string(), serde_json::json!(page_num));
353                metadata.insert(
354                    "total_pages".to_string(),
355                    serde_json::json!(pdf.get_pages().len()),
356                );
357
358                Document {
359                    id: None,
360                    content: text,
361                    metadata,
362                    source: Some(source.clone()),
363                }
364            })
365            .collect();
366
367        Ok(documents)
368    }
369
370    fn supports(&self, path: &std::path::Path) -> bool {
371        path.extension()
372            .and_then(|e| e.to_str())
373            .map(|e| e.to_lowercase() == "pdf")
374            .unwrap_or(false)
375    }
376
377    fn extensions(&self) -> &[&str] {
378        &["pdf"]
379    }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use lopdf::Dictionary;
    use lopdf::Object as PdfObject;
    use lopdf::Stream;
    use tempfile::NamedTempFile;

    /// Writes a one-page PDF whose content stream shows "Hello World" to a
    /// temporary `.pdf` file.
    ///
    /// The objects refer to each other through hard-coded ids, which assumes
    /// `add_object` numbers the objects of a fresh document 1, 2, 3, 4 in call
    /// order.
    fn create_minimal_pdf() -> NamedTempFile {
        let mut pdf = lopdf::Document::new();

        // Object 1: Catalog
        let catalog_id = pdf.add_object(PdfObject::Dictionary(Dictionary::from_iter([
            ("Type", PdfObject::Name("Catalog".as_bytes().to_vec())),
            ("Pages", PdfObject::Reference((2, 0))),
        ])));

        // Object 2: Pages
        pdf.add_object(PdfObject::Dictionary(Dictionary::from_iter([
            ("Type", PdfObject::Name("Pages".as_bytes().to_vec())),
            ("Kids", PdfObject::Array(vec![PdfObject::Reference((3, 0))])),
            ("Count", PdfObject::Integer(1)),
        ])));

        // Object 3: Page
        pdf.add_object(PdfObject::Dictionary(Dictionary::from_iter([
            ("Type", PdfObject::Name("Page".as_bytes().to_vec())),
            ("Parent", PdfObject::Reference((2, 0))),
            (
                "MediaBox",
                PdfObject::Array(vec![
                    PdfObject::Integer(0),
                    PdfObject::Integer(0),
                    PdfObject::Integer(612),
                    PdfObject::Integer(792),
                ]),
            ),
            ("Contents", PdfObject::Reference((4, 0))),
        ])));

        // Object 4: Content stream with text
        let content = b"BT /F1 12 Tf 100 700 Td (Hello World) Tj ET";
        pdf.add_object(PdfObject::Stream(Stream::new(
            Dictionary::from_iter([("Length", PdfObject::Integer(content.len() as i64))]),
            content.to_vec(),
        )));

        // Without a `Root` entry in the trailer the catalog, and therefore the
        // whole page tree, cannot be reached when the file is read back.
        pdf.trailer.set("Root", PdfObject::Reference(catalog_id));

        // Save to temp file
        let file = NamedTempFile::with_suffix(".pdf").unwrap();
        pdf.save(file.path()).expect("Failed to save PDF");
        file
    }

    #[test]
    fn test_pdf_loader_extensions() {
        let loader = PdfLoader::new();
        assert!(loader.extensions().contains(&"pdf"));
    }

    #[test]
    fn test_pdf_loader_supports() {
        let loader = PdfLoader::new();
        assert!(loader.supports(std::path::Path::new("test.pdf")));
        assert!(loader.supports(std::path::Path::new("test.PDF")));
        assert!(!loader.supports(std::path::Path::new("test.txt")));
        // No extension at all.
        assert!(!loader.supports(std::path::Path::new("test")));
    }

    #[tokio::test]
    async fn test_pdf_loader_load() {
        let loader = PdfLoader::new();
        let pdf_file = create_minimal_pdf();

        let doc = loader
            .load(pdf_file.path().to_path_buf())
            .await
            .expect("PDF should load successfully");

        assert!(doc.source.is_some());
        // The fixture's single page must be reachable through the catalog.
        assert_eq!(doc.metadata.get("page_count"), Some(&serde_json::json!(1)));
    }

    #[tokio::test]
    async fn test_pdf_loader_load_and_split() {
        let loader = PdfLoader::new();
        let pdf_file = create_minimal_pdf();

        let docs = loader
            .load_and_split(pdf_file.path().to_path_buf())
            .await
            .expect("PDF should load successfully");

        // One page in the fixture: at most one document, none if the page
        // yields no text.
        assert!(docs.len() <= 1);
    }

    #[test]
    fn test_decode_pdf_string_utf8() {
        let bytes = b"Hello World";
        let result = PdfLoader::decode_pdf_string(bytes);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Hello World");
    }

    #[test]
    fn test_decode_pdf_string_latin1() {
        // "Café" encoded as Latin-1: 0xE9 is 'é' and is not valid UTF-8 here.
        let bytes = vec![b'C', b'a', b'f', 0xE9];
        let result = PdfLoader::decode_pdf_string(&bytes);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Caf\u{e9}");
    }
}