stygian_graph/application/
extraction.rs1use std::sync::Arc;
35
36use async_trait::async_trait;
37use serde_json::{Value, json};
38use tracing::{debug, info, warn};
39
40use crate::domain::error::{ProviderError, Result, StygianError};
41use crate::ports::{AIProvider, ScrapingService, ServiceInput, ServiceOutput};
42
/// Settings that control how [`LlmExtractionService`] prepares its input and
/// checks provider output.
#[derive(Clone, Debug)]
pub struct ExtractionConfig {
    /// Upper bound on the amount of content forwarded to a provider.
    ///
    /// The service compares this value against the content's `len()`, so it
    /// is effectively a budget in UTF-8 bytes.
    pub max_content_chars: usize,
    /// When `true`, a provider response is only accepted if it is a JSON
    /// object or array; otherwise the next provider is tried.
    pub validate_output: bool,
}
53
54impl Default for ExtractionConfig {
55 fn default() -> Self {
56 Self {
57 max_content_chars: 64_000,
58 validate_output: true,
59 }
60 }
61}
62
63pub struct LlmExtractionService {
79 providers: Vec<Arc<dyn AIProvider>>,
81 config: ExtractionConfig,
82}
83
84impl LlmExtractionService {
85 #[must_use]
100 pub fn new(providers: Vec<Arc<dyn AIProvider>>, config: ExtractionConfig) -> Self {
101 Self { providers, config }
102 }
103
104 fn resolve_content(input: &ServiceInput) -> &str {
110 input
111 .params
112 .get("content")
113 .and_then(Value::as_str)
114 .unwrap_or(&input.url)
115 }
116
117 fn truncate_content<'a>(&self, content: &'a str) -> &'a str {
119 if content.len() <= self.config.max_content_chars {
120 content
121 } else {
122 warn!(
123 limit = self.config.max_content_chars,
124 actual = content.len(),
125 "Content truncated for LLM extraction"
126 );
127 &content[..self.config.max_content_chars]
128 }
129 }
130
131 fn resolve_schema(input: &ServiceInput) -> Result<Value> {
133 input.params.get("schema").cloned().ok_or_else(|| {
134 StygianError::Provider(ProviderError::ApiError(
135 "LlmExtractionService requires 'schema' in ServiceInput.params".to_string(),
136 ))
137 })
138 }
139
140 fn validate_output(output: &Value) -> Result<()> {
142 if output.is_object() || output.is_array() {
143 Ok(())
144 } else {
145 Err(StygianError::Provider(ProviderError::ApiError(format!(
146 "Provider returned non-object output: {output}"
147 ))))
148 }
149 }
150}
151
152#[async_trait]
153impl ScrapingService for LlmExtractionService {
154 async fn execute(&self, input: ServiceInput) -> Result<ServiceOutput> {
180 if self.providers.is_empty() {
181 return Err(StygianError::Provider(ProviderError::ApiError(
182 "No AI providers configured in LlmExtractionService".to_string(),
183 )));
184 }
185
186 let schema = Self::resolve_schema(&input)?;
187 let raw_content = Self::resolve_content(&input);
188 let content = self.truncate_content(raw_content).to_string();
189
190 let start = std::time::Instant::now();
191 let mut last_error: Option<StygianError> = None;
192
193 for provider in &self.providers {
194 debug!(provider = provider.name(), "Attempting LLM extraction");
195
196 match provider.extract(content.clone(), schema.clone()).await {
197 Ok(extracted) => {
198 if self.config.validate_output
199 && let Err(e) = Self::validate_output(&extracted)
200 {
201 warn!(
202 provider = provider.name(),
203 error = %e,
204 "Provider returned invalid output, trying next"
205 );
206 last_error = Some(e);
207 continue;
208 }
209
210 let elapsed = start.elapsed();
211 info!(
212 provider = provider.name(),
213 elapsed_ms = elapsed.as_millis(),
214 "LLM extraction succeeded"
215 );
216
217 return Ok(ServiceOutput {
218 data: extracted.to_string(),
219 metadata: json!({
220 "provider": provider.name(),
221 "elapsed_ms": elapsed.as_millis(),
222 "content_chars": content.len(),
223 }),
224 });
225 }
226 Err(e) => {
227 warn!(
228 provider = provider.name(),
229 error = %e,
230 "Provider failed, trying next in chain"
231 );
232 last_error = Some(e);
233 }
234 }
235 }
236
237 Err(last_error.unwrap_or_else(|| {
239 StygianError::Provider(ProviderError::ApiError(
240 "All AI providers in fallback chain failed".to_string(),
241 ))
242 }))
243 }
244
245 fn name(&self) -> &'static str {
246 "llm-extraction"
247 }
248}
249
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::needless_pass_by_value
)]
mod tests {
    use super::*;
    use crate::ports::ProviderCapabilities;
    use futures::stream::{self, BoxStream};
    use serde_json::json;

    /// Mock provider whose `extract` always returns a clone of `response`.
    struct AlwaysSucceed {
        response: Value,
    }

    #[async_trait]
    impl AIProvider for AlwaysSucceed {
        async fn extract(&self, _content: String, _schema: Value) -> Result<Value> {
            Ok(self.response.clone())
        }

        async fn stream_extract(
            &self,
            _content: String,
            _schema: Value,
        ) -> Result<BoxStream<'static, Result<Value>>> {
            Ok(Box::pin(stream::once(async { Ok(json!({})) })))
        }

        fn capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities::default()
        }

        fn name(&self) -> &'static str {
            "mock-succeed"
        }
    }

    /// Mock provider whose every call fails with a "mock failure" API error.
    struct AlwaysFail;

    /// Builds the error result shared by both `AlwaysFail` methods.
    fn mock_failure<T>() -> Result<T> {
        Err(StygianError::Provider(ProviderError::ApiError(
            "mock failure".to_string(),
        )))
    }

    #[async_trait]
    impl AIProvider for AlwaysFail {
        async fn extract(&self, _content: String, _schema: Value) -> Result<Value> {
            mock_failure()
        }

        async fn stream_extract(
            &self,
            _content: String,
            _schema: Value,
        ) -> Result<BoxStream<'static, Result<Value>>> {
            mock_failure()
        }

        fn capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities::default()
        }

        fn name(&self) -> &'static str {
            "mock-fail"
        }
    }

    /// Input whose content comes from `url` and whose params carry only the schema.
    fn make_input(schema: Value) -> ServiceInput {
        let params = json!({ "schema": schema });
        ServiceInput {
            url: "<h1>Hello</h1>".to_string(),
            params,
        }
    }

    /// A succeeding mock, already erased to the trait object the service stores.
    fn succeeding(response: Value) -> Arc<dyn AIProvider> {
        Arc::new(AlwaysSucceed { response })
    }

    /// A failing mock, already erased to the trait object the service stores.
    fn failing() -> Arc<dyn AIProvider> {
        Arc::new(AlwaysFail)
    }

    /// Service over `providers` using the default configuration.
    fn default_service(providers: Vec<Arc<dyn AIProvider>>) -> LlmExtractionService {
        LlmExtractionService::new(providers, ExtractionConfig::default())
    }

    #[tokio::test]
    async fn test_service_name() {
        let service = default_service(Vec::new());
        assert_eq!(service.name(), "llm-extraction");
    }

    #[tokio::test]
    async fn test_no_providers_returns_error() {
        let service = default_service(Vec::new());
        let outcome = service.execute(make_input(json!({}))).await;
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("No AI providers"));
    }

    #[tokio::test]
    async fn test_missing_schema_returns_error() {
        let service = default_service(vec![succeeding(json!({}))]);
        // No "schema" key in params.
        let input = ServiceInput {
            url: "some content".to_string(),
            params: json!({}),
        };
        let message = service.execute(input).await.unwrap_err().to_string();
        assert!(message.contains("schema"));
    }

    #[tokio::test]
    async fn test_single_succeeding_provider() {
        let service = default_service(vec![succeeding(json!({"title": "Hello"}))]);
        let output = service.execute(make_input(json!({}))).await.unwrap();

        let provider = output.metadata["provider"].as_str().unwrap();
        assert_eq!(provider, "mock-succeed");

        let parsed: Value = serde_json::from_str(&output.data).unwrap();
        assert_eq!(parsed["title"].as_str().unwrap(), "Hello");
    }

    #[tokio::test]
    async fn test_fallback_to_second_provider() {
        let chain = vec![failing(), succeeding(json!({"score": 42}))];
        let output = default_service(chain)
            .execute(make_input(json!({})))
            .await
            .unwrap();

        let provider = output.metadata["provider"].as_str().unwrap();
        assert_eq!(provider, "mock-succeed");
    }

    #[tokio::test]
    async fn test_all_providers_fail() {
        let service = default_service(vec![failing(), failing()]);
        let outcome = service.execute(make_input(json!({}))).await;
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("mock failure"));
    }

    #[tokio::test]
    async fn test_content_from_params_overrides_url() {
        let service = default_service(vec![succeeding(json!({"ok": true}))]);
        let input = ServiceInput {
            url: "should-not-be-used".to_string(),
            params: json!({
                "schema": {"type": "object"},
                "content": "actual content here"
            }),
        };
        let output = service.execute(input).await.unwrap();
        // "actual content here" is 19 bytes long.
        let reported = output.metadata["content_chars"].as_u64().unwrap();
        assert_eq!(reported, 19);
    }

    #[test]
    fn test_truncate_content_short() {
        let service = default_service(Vec::new());
        let text = "hello";
        assert_eq!(service.truncate_content(text), text);
    }

    #[test]
    fn test_truncate_content_long() {
        let config = ExtractionConfig {
            max_content_chars: 5,
            ..ExtractionConfig::default()
        };
        let service = LlmExtractionService::new(Vec::new(), config);
        assert_eq!(service.truncate_content("hello world"), "hello");
    }

    #[test]
    fn test_validate_output_object_ok() {
        let object = json!({"k": "v"});
        assert!(LlmExtractionService::validate_output(&object).is_ok());
    }

    #[test]
    fn test_validate_output_array_ok() {
        let array = json!([1, 2, 3]);
        assert!(LlmExtractionService::validate_output(&array).is_ok());
    }

    #[test]
    fn test_validate_output_scalar_err() {
        let scalar = json!("just a string");
        assert!(LlmExtractionService::validate_output(&scalar).is_err());
    }
}