// sqlite_graphrag/extraction.rs
use serde::{Deserialize, Serialize};
18
19#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
21pub struct ExtractedUrl {
22 pub url: String,
23 pub start: usize,
24 pub end: usize,
25}
26
27#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
30pub struct ExtractedEntity {
31 pub name: String,
32 pub entity_type: String,
33 pub start: usize,
34 pub end: usize,
35}
36
37#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
42pub struct ExtractionResult {
43 pub entities: Vec<ExtractedEntity>,
44 pub urls: Vec<ExtractedUrl>,
45 pub elapsed_ms: u64,
47}
48
49#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
53pub enum GlinerVariant {
54 Fp32,
55 Int8,
56}
57
58impl GlinerVariant {
59 pub fn as_filename(self) -> &'static str {
60 match self {
61 Self::Fp32 => "model.onnx",
62 Self::Int8 => "model_int8.onnx",
63 }
64 }
65 pub fn display_size(self) -> &'static str {
66 match self {
67 Self::Fp32 => "1.1 GB",
68 Self::Int8 => "349 MB",
69 }
70 }
71}
72
73pub trait Extractor: Send + Sync {
76 fn name(&self) -> &'static str;
77 fn extract(&self, body: &str) -> Result<ExtractionResult, crate::errors::AppError>;
78}
79
/// Stateless extractor that reports URLs only.
///
/// It carries no data, so it is cheap to copy and has a trivial default.
#[derive(Debug, Clone, Copy, Default)]
pub struct RegexExtractor;
83
84impl Extractor for RegexExtractor {
85 fn name(&self) -> &'static str {
86 "regex"
87 }
88 fn extract(&self, body: &str) -> Result<ExtractionResult, crate::errors::AppError> {
89 Ok(ExtractionResult {
90 entities: Vec::new(),
91 urls: extract_urls(body),
92 elapsed_ms: 0,
93 })
94 }
95}
96
97pub fn extract_urls(body: &str) -> Vec<ExtractedUrl> {
100 let mut out = Vec::new();
101 let mut cursor = 0usize;
102 while cursor < body.len() {
103 let hay = &body[cursor..];
104 let http_at = hay.find("http://");
106 let https_at = hay.find("https://");
107 let (rel_start, scheme_len) = match (http_at, https_at) {
108 (Some(a), Some(b)) => {
109 if a <= b {
110 (a, 7)
111 } else {
112 (b, 8)
113 }
114 }
115 (Some(a), None) => (a, 7),
116 (None, Some(b)) => (b, 8),
117 (None, None) => break,
118 };
119 let abs_start = cursor + rel_start;
120 let after_scheme = abs_start + scheme_len;
121 let mut end = after_scheme;
122 for (i, c) in body[after_scheme..].char_indices() {
123 if c.is_whitespace() || matches!(c, ')' | ']' | '}' | '"' | '\'' | '<') {
124 end = after_scheme + i;
125 break;
126 }
127 end = after_scheme + i + c.len_utf8();
128 }
129 out.push(ExtractedUrl {
130 url: body[abs_start..end].to_string(),
131 start: abs_start,
132 end,
133 });
134 cursor = end;
135 }
136 out
137}
138
139pub fn extract_graph_auto(
143 body: &str,
144 _paths: &crate::paths::AppPaths,
145 _gliner_variant: GlinerVariant,
146) -> Result<ExtractionResult, crate::errors::AppError> {
147 let start = std::time::Instant::now();
148 let urls = extract_urls(body);
149 Ok(ExtractionResult {
150 entities: Vec::new(),
151 urls,
152 elapsed_ms: start.elapsed().as_millis() as u64,
153 })
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// Both schemes are recognised and reported in order of appearance.
    #[test]
    fn extract_urls_finds_http_and_https() {
        let body = "see https://example.com/foo and http://bar.baz/qux end";
        let urls = extract_urls(body);
        assert_eq!(urls.len(), 2, "got {urls:?} for body {body:?}");
        let texts: Vec<&str> = urls.iter().map(|hit| hit.url.as_str()).collect();
        assert_eq!(texts, ["https://example.com/foo", "http://bar.baz/qux"]);
    }

    /// A closing parenthesis right after the URL is not part of the match.
    #[test]
    fn extract_urls_handles_trailing_punctuation() {
        let urls = extract_urls("see https://example.com/foo).");
        let texts: Vec<&str> = urls.iter().map(|hit| hit.url.as_str()).collect();
        assert_eq!(texts, ["https://example.com/foo"]);
    }

    /// Empty input yields no matches.
    #[test]
    fn extract_urls_empty_body() {
        assert_eq!(extract_urls("").len(), 0);
    }

    /// Size labels are fixed strings per variant.
    #[test]
    fn gliner_variant_size_strings() {
        let expectations = [
            (GlinerVariant::Fp32, "1.1 GB"),
            (GlinerVariant::Int8, "349 MB"),
        ];
        for (variant, label) in expectations {
            assert_eq!(variant.display_size(), label);
        }
    }

    /// The regex extractor reports URLs and leaves entities empty.
    #[test]
    fn regex_extractor_returns_only_urls() {
        let ExtractionResult { entities, urls, .. } =
            RegexExtractor.extract("see https://example.com").unwrap();
        assert!(entities.is_empty());
        assert_eq!(urls.len(), 1);
    }
}