synaptic_splitters/
html_header.rs1use std::collections::HashMap;
2
3use serde_json::Value;
4use synaptic_retrieval::Document;
5
6use crate::TextSplitter;
7
/// Splits HTML text into chunks at configured header tags.
///
/// Each configured header is a `(tag, metadata_key)` pair such as
/// `("h1", "Header 1")`. When a header tag is encountered its text is recorded
/// under `metadata_key` for the chunks that follow it.
#[derive(Debug, Clone)]
pub struct HtmlHeaderTextSplitter {
    /// `(tag, metadata_key)` pairs. The position of a pair in the list is its
    /// header level: earlier entries rank above later ones.
    headers_to_split_on: Vec<(String, String)>,
}
16
17impl HtmlHeaderTextSplitter {
18 pub fn new(headers_to_split_on: Vec<(String, String)>) -> Self {
22 Self {
23 headers_to_split_on,
24 }
25 }
26
27 pub fn default_headers() -> Self {
29 Self::new(vec![
30 ("h1".to_string(), "Header 1".to_string()),
31 ("h2".to_string(), "Header 2".to_string()),
32 ("h3".to_string(), "Header 3".to_string()),
33 ])
34 }
35
36 pub fn split_html(&self, text: &str) -> Vec<Document> {
38 let mut documents = Vec::new();
39 let mut current_headers: HashMap<String, String> = HashMap::new();
40 let mut current_content = String::new();
41 let mut doc_index = 0;
42
43 let header_levels: HashMap<String, usize> = self
46 .headers_to_split_on
47 .iter()
48 .enumerate()
49 .map(|(i, (tag, _))| (tag.to_lowercase(), i))
50 .collect();
51
52 for line in text.lines() {
53 let trimmed = line.trim();
54
55 let mut matched = None;
57 for (tag, metadata_key) in &self.headers_to_split_on {
58 let open_tag = format!("<{}", tag.to_lowercase());
59 let trimmed_lower = trimmed.to_lowercase();
60
61 if trimmed_lower.starts_with(&open_tag) {
62 let header_text = extract_tag_content(trimmed, tag);
64 matched = Some((tag.clone(), metadata_key.clone(), header_text));
65 break;
66 }
67 }
68
69 if let Some((tag, metadata_key, header_text)) = matched {
70 let content = current_content.trim().to_string();
72 if !content.is_empty() {
73 let mut metadata: HashMap<String, Value> = current_headers
74 .iter()
75 .map(|(k, v)| (k.clone(), Value::String(v.clone())))
76 .collect();
77 metadata.insert("chunk_index".to_string(), Value::Number(doc_index.into()));
78 documents.push(Document::with_metadata(
79 format!("chunk-{doc_index}"),
80 content,
81 metadata,
82 ));
83 doc_index += 1;
84 }
85
86 let current_level = header_levels.get(&tag.to_lowercase()).copied().unwrap_or(0);
88 let keys_to_remove: Vec<String> = current_headers
89 .keys()
90 .filter(|k| {
91 self.headers_to_split_on
92 .iter()
93 .find(|(_, mk)| mk == *k)
94 .and_then(|(t, _)| header_levels.get(&t.to_lowercase()))
95 .map(|level| *level >= current_level)
96 .unwrap_or(false)
97 })
98 .cloned()
99 .collect();
100 for key in keys_to_remove {
101 current_headers.remove(&key);
102 }
103
104 current_headers.insert(metadata_key, header_text);
105 current_content.clear();
106 } else {
107 let stripped = strip_simple_tags(trimmed);
109 let stripped = stripped.trim();
110 if !stripped.is_empty() {
111 if !current_content.is_empty() {
112 current_content.push('\n');
113 }
114 current_content.push_str(stripped);
115 }
116 }
117 }
118
119 let content = current_content.trim().to_string();
121 if !content.is_empty() {
122 let mut metadata: HashMap<String, Value> = current_headers
123 .iter()
124 .map(|(k, v)| (k.clone(), Value::String(v.clone())))
125 .collect();
126 metadata.insert("chunk_index".to_string(), Value::Number(doc_index.into()));
127 documents.push(Document::with_metadata(
128 format!("chunk-{doc_index}"),
129 content,
130 metadata,
131 ));
132 }
133
134 documents
135 }
136}
137
/// Returns the trimmed text between the first `>` on `line` and the closing
/// `</tag>`.
///
/// The closing tag is matched ASCII-case-insensitively. If the line has no
/// closing tag, everything after the first `>` is returned; if the line has no
/// `>` at all, the result is an empty string.
fn extract_tag_content(line: &str, tag: &str) -> String {
    let rest = match line.find('>') {
        Some(open_end) => &line[open_end + 1..],
        None => return String::new(),
    };

    // ASCII lowercasing never changes byte lengths, so an offset found in the
    // lowered copy is also a valid offset into `rest`. Full Unicode lowercasing
    // can grow or shrink characters (e.g. 'İ' or the Kelvin sign), which would
    // misplace the cut or slice in the middle of a character.
    let close_tag = format!("</{}>", tag.to_ascii_lowercase());
    let inner = match rest.to_ascii_lowercase().find(&close_tag) {
        Some(close_start) => &rest[..close_start],
        None => rest,
    };

    inner.trim().to_string()
}
154
/// Removes tag markup from `text` with a simple two-state scan.
///
/// Every `<` switches to "inside a tag" and every `>` switches back; both
/// delimiters and all characters seen while inside a tag are dropped. A `<`
/// without a matching `>` therefore discards the rest of the input.
fn strip_simple_tags(text: &str) -> String {
    let mut inside_markup = false;
    text.chars()
        .filter(|&symbol| match symbol {
            '<' => {
                inside_markup = true;
                false
            }
            '>' => {
                inside_markup = false;
                false
            }
            _ => !inside_markup,
        })
        .collect()
}
170
171impl TextSplitter for HtmlHeaderTextSplitter {
172 fn split_text(&self, text: &str) -> Vec<String> {
173 self.split_html(text)
174 .into_iter()
175 .map(|d| d.content)
176 .collect()
177 }
178}