Skip to main content

sqlite_graphrag/extract/
composite_backend.rs

//! Composite extraction backend (v1.0.75 — G21 orchestration)
//!
//! Runs multiple backends one after another, in the order they were supplied,
//! and merges their outputs into a single result.
//! Used when the user requests `--extraction-backend both`.
5
6use super::{
7    BackendHealth, BackendKind, ExtractionBackend, ExtractionHints, ExtractionOutput, SharedBackend,
8};
9use crate::errors::AppError;
10use async_trait::async_trait;
11use std::time::Instant;
12
13pub struct CompositeBackend {
14    backends: Vec<SharedBackend>,
15}
16
17impl CompositeBackend {
18    pub fn new(backends: Vec<SharedBackend>) -> Self {
19        Self { backends }
20    }
21}
22
23#[async_trait]
24impl ExtractionBackend for CompositeBackend {
25    fn kind(&self) -> BackendKind {
26        BackendKind::Composite
27    }
28
29    fn model_name(&self) -> String {
30        self.backends
31            .iter()
32            .map(|b| b.model_name())
33            .collect::<Vec<_>>()
34            .join("+")
35    }
36
37    async fn extract(
38        &self,
39        content: &str,
40        hints: &ExtractionHints,
41    ) -> Result<ExtractionOutput, AppError> {
42        let start = Instant::now();
43        let mut merged = ExtractionOutput {
44            backend: self.kind().as_str().to_string(),
45            ..Default::default()
46        };
47        let mut first_embedding: Option<Vec<f32>> = None;
48        let mut any_error: Option<AppError> = None;
49
50        for backend in &self.backends {
51            match backend.extract(content, hints).await {
52                Ok(out) => {
53                    for entity in out.entities {
54                        if !merged.entities.iter().any(|e| e.name == entity.name) {
55                            merged.entities.push(entity);
56                        }
57                    }
58                    for rel in out.relationships {
59                        let exists = merged.relationships.iter().any(|r| {
60                            r.source == rel.source
61                                && r.target == rel.target
62                                && r.relation == rel.relation
63                        });
64                        if !exists {
65                            merged.relationships.push(rel);
66                        }
67                    }
68                    if first_embedding.is_none() && out.embedding.is_some() {
69                        first_embedding = out.embedding;
70                    }
71                }
72                Err(err) => {
73                    if any_error.is_none() {
74                        any_error = Some(err);
75                    }
76                }
77            }
78        }
79
80        merged.embedding = first_embedding;
81        merged.elapsed_ms = start.elapsed().as_millis() as u64;
82
83        if merged.entities.is_empty() && merged.relationships.is_empty() {
84            if let Some(err) = any_error {
85                return Err(err);
86            }
87        }
88        Ok(merged)
89    }
90
91    async fn health(&self) -> Result<BackendHealth, AppError> {
92        let mut healthy = true;
93        let mut messages = Vec::new();
94        for backend in &self.backends {
95            match backend.health().await {
96                Ok(h) => {
97                    if !h.healthy {
98                        healthy = false;
99                    }
100                    messages.push(format!(
101                        "{}:{}",
102                        h.kind.as_str(),
103                        if h.healthy { "ok" } else { "degraded" }
104                    ));
105                }
106                Err(err) => {
107                    healthy = false;
108                    messages.push(format!("err:{err}"));
109                }
110            }
111        }
112        Ok(BackendHealth {
113            kind: self.kind(),
114            healthy,
115            model_name: self.model_name(),
116            message: messages.join(" "),
117        })
118    }
119}
120
121/// Factory that builds the default backend for the current build configuration.
122pub fn default_backend() -> SharedBackend {
123    use std::sync::Arc;
124    Arc::new(super::llm_backend::LlmBackend::with_default_codex())
125}
126
127/// Factory that builds a backend from a CLI flag.
128pub fn backend_from_kind(kind: BackendKind) -> SharedBackend {
129    use std::sync::Arc;
130    match kind {
131        BackendKind::Llm => Arc::new(super::llm_backend::LlmBackend::with_default_codex()),
132        BackendKind::Embedding => Arc::new(super::embedding_backend::EmbeddingBackend::new()),
133        BackendKind::None => Arc::new(super::none_backend::NoneBackend::new()),
134        BackendKind::Composite => {
135            let llm: SharedBackend = Arc::new(super::llm_backend::LlmBackend::with_default_codex());
136            let embedding: SharedBackend =
137                Arc::new(super::embedding_backend::EmbeddingBackend::new());
138            Arc::new(CompositeBackend::new(vec![llm, embedding]))
139        }
140    }
141}