Skip to main content

smos_application/testkit/
providers.rs

//! Provider test doubles: scripted LLM extractor, constant/recording
//! embedders, the dual-mode scripted NLI classifier, and the dual-mode
//! scripted reranker.
3
4use std::sync::{Arc, Mutex};
5
6use smos_domain::NliResult;
7use smos_domain::chat::ToolCall;
8
9use crate::errors::ProviderError;
10use crate::ports::{EmbeddingProvider, LlmExtractor, NliClassifier, RerankProvider};
11use crate::types::RerankResult;
12
13/// LLM extractor that pops pre-scripted results in FIFO order and counts
14/// invocations. When the script is exhausted, subsequent calls return an empty
15/// `Vec` (mirroring a provider that simply finds no facts) rather than
16/// erroring, so tests that do not care about the Nth call still pass.
17///
18/// Records every `(content, tool_calls)` pair handed to `extract_facts` via
19/// [`ScriptedExtractor::inputs`] — the parity-shaped accessor the NLI/reranker
20/// fakes expose — so tests can assert on the exact input that reached the
21/// extractor (e.g. the `"User:/Assistant:"` role markup).
22pub struct ScriptedExtractor {
23    results: Mutex<Vec<Result<Vec<String>, ProviderError>>>,
24    calls: Mutex<u32>,
25    inputs: Mutex<Vec<(String, Vec<ToolCall>)>>,
26}
27
28impl ScriptedExtractor {
29    pub fn new(results: Vec<Result<Vec<String>, ProviderError>>) -> Self {
30        Self {
31            results: Mutex::new(results),
32            calls: Mutex::new(0),
33            inputs: Mutex::new(Vec::new()),
34        }
35    }
36
37    pub fn call_count(&self) -> u32 {
38        *self.calls.lock().unwrap()
39    }
40
41    /// Recorded `(content, tool_calls)` pairs in invocation order.
42    pub fn inputs(&self) -> Vec<(String, Vec<ToolCall>)> {
43        self.inputs.lock().unwrap().clone()
44    }
45}
46
47impl LlmExtractor for ScriptedExtractor {
48    async fn extract_facts(
49        &self,
50        content: &str,
51        tool_calls: &[ToolCall],
52    ) -> Result<Vec<String>, ProviderError> {
53        *self.calls.lock().unwrap() += 1;
54        self.inputs
55            .lock()
56            .unwrap()
57            .push((content.to_string(), tool_calls.to_vec()));
58        let mut guard = self.results.lock().unwrap();
59        if guard.is_empty() {
60            Ok(Vec::new())
61        } else {
62            guard.remove(0)
63        }
64    }
65}
66
67/// Embedding provider that always returns the same vector regardless of input.
68pub struct ConstantEmbedder(pub Vec<f32>);
69
70impl EmbeddingProvider for ConstantEmbedder {
71    async fn embed(&self, _text: &str) -> Result<Option<Vec<f32>>, ProviderError> {
72        Ok(Some(self.0.clone()))
73    }
74}
75
/// Embedding provider that records every `embed` call and returns a
/// deterministic, content-derived vector. Used to verify the extraction
/// pipeline hands distinct embeddings to distinct facts (so Layer 2 dedup
/// makes the right call). `new` returns the double together with the shared
/// call-log handle so the test body can assert on it.
///
/// The vector is one-hot over 1024 dimensions; the hot index is a hash of the
/// text reduced modulo 1024. Equal texts always map to equal vectors, but the
/// mapping is NOT injective: two different texts whose hashes agree modulo
/// 1024 (for example `"a"` and `"!b"`) receive the *same* vector. Tests that
/// rely on distinct embeddings must pick inputs that do not collide.
pub struct RecordingEmbedder {
    /// Texts passed to `embed`, oldest first; shared with the handle returned
    /// by [`RecordingEmbedder::new`].
    calls: Arc<Mutex<Vec<String>>>,
}

impl RecordingEmbedder {
    /// Length of every vector produced by this double.
    const DIMENSIONS: usize = 1024;

    /// Creates the double plus a second handle to the same call log.
    pub fn new() -> (Self, Arc<Mutex<Vec<String>>>) {
        let calls = Arc::new(Mutex::new(Vec::new()));
        let double = Self {
            calls: Arc::clone(&calls),
        };
        (double, calls)
    }

    /// Maps `text` to a one-hot vector of [`Self::DIMENSIONS`] entries.
    ///
    /// The bytes are folded into a `u64` with a wrapping base-31 polynomial
    /// hash, and the hash modulo the dimension count selects the single entry
    /// set to `1.0`. Texts landing on different indices have cosine
    /// similarity 0; texts landing on the same index get identical vectors.
    fn vector_for(text: &str) -> Vec<f32> {
        let hash = text
            .bytes()
            .fold(0u64, |acc, b| acc.wrapping_mul(31).wrapping_add(u64::from(b)));
        // Reduce in u64 *before* narrowing so the cast to usize is lossless
        // on every target width.
        let hot = (hash % Self::DIMENSIONS as u64) as usize;
        let mut vector = vec![0.0; Self::DIMENSIONS];
        vector[hot] = 1.0;
        vector
    }
}
109
110impl EmbeddingProvider for RecordingEmbedder {
111    async fn embed(&self, text: &str) -> Result<Option<Vec<f32>>, ProviderError> {
112        self.calls.lock().unwrap().push(text.to_string());
113        Ok(Some(Self::vector_for(text)))
114    }
115}
116
117/// Closure type used by the matcher variant of [`ScriptedNliClassifier`].
118type NliResolver = Box<dyn Fn(&str, &str) -> Result<NliResult, ProviderError> + Send + Sync>;
119
120/// Scripted NLI classifier with two modes:
121/// - [`ScriptedNliClassifier::new`] (FIFO): each call pops the next verdict
122///   from the queue. Use when the test controls call order.
123/// - [`ScriptedNliClassifier::matching`] (Match): each call dispatches to the
124///   supplied closure. Use when pending iteration order is not deterministic
125///   (`HashMap` order) and the test keys verdicts on the candidate text.
126///
127/// Both modes record every (premise, hypothesis) pair so tests can assert on
128/// the exact set of pairs the use case asked about.
129pub enum ScriptedNliClassifier {
130    Fifo {
131        verdicts: Mutex<Vec<Result<NliResult, ProviderError>>>,
132        calls: Mutex<Vec<(String, String)>>,
133    },
134    Match {
135        resolver: NliResolver,
136        calls: Mutex<Vec<(String, String)>>,
137    },
138}
139
140impl ScriptedNliClassifier {
141    pub fn new(verdicts: Vec<Result<NliResult, ProviderError>>) -> Self {
142        Self::Fifo {
143            verdicts: Mutex::new(verdicts),
144            calls: Mutex::new(Vec::new()),
145        }
146    }
147
148    pub fn matching<F>(resolver: F) -> Self
149    where
150        F: Fn(&str, &str) -> Result<NliResult, ProviderError> + Send + Sync + 'static,
151    {
152        Self::Match {
153            resolver: Box::new(resolver),
154            calls: Mutex::new(Vec::new()),
155        }
156    }
157
158    pub fn calls(&self) -> Vec<(String, String)> {
159        match self {
160            Self::Fifo { calls, .. } | Self::Match { calls, .. } => calls.lock().unwrap().clone(),
161        }
162    }
163}
164
165impl NliClassifier for ScriptedNliClassifier {
166    async fn classify(&self, premise: &str, hypothesis: &str) -> Result<NliResult, ProviderError> {
167        match self {
168            Self::Fifo { verdicts, calls } => {
169                calls
170                    .lock()
171                    .unwrap()
172                    .push((premise.to_string(), hypothesis.to_string()));
173                let mut queue = verdicts.lock().unwrap();
174                if queue.is_empty() {
175                    Err(ProviderError::Unavailable("scripted queue empty".into()))
176                } else {
177                    queue.remove(0)
178                }
179            }
180            Self::Match { resolver, calls } => {
181                calls
182                    .lock()
183                    .unwrap()
184                    .push((premise.to_string(), hypothesis.to_string()));
185                resolver(premise, hypothesis)
186            }
187        }
188    }
189}
190
191/// Closure type used by the matcher variant of [`ScriptedReranker`].
192type RerankResolver =
193    Box<dyn Fn(&str, &[String], usize) -> Result<Vec<RerankResult>, ProviderError> + Send + Sync>;
194
195/// Scripted reranker, parity-shaped with [`ScriptedNliClassifier`]:
196/// - [`ScriptedReranker::new`] (FIFO): each call pops the next scripted
197///   result set in order. When the script is exhausted the reranker returns
198///   `Ok(vec![])` (the legitimate "provider found nothing" shape) so the
199///   fail-closed contract of the rerank stage is exercisable without an
200///   explicit error — mirroring a real provider that responded with zero
201///   results rather than a transport failure.
202/// - [`ScriptedReranker::matching`] (Match): each call dispatches to the
203///   supplied closure. Use when survivor ordering is not deterministic
204///   (`HashMap` order) and the test keys scores on the document text, or to
205///   honour the `top_k` argument for truncation assertions.
206///
207/// Both modes record every `(query, document_count, top_k)` triple so tests
208/// can assert on the exact calls the use case made.
209pub enum ScriptedReranker {
210    Fifo {
211        results: Mutex<Vec<Result<Vec<RerankResult>, ProviderError>>>,
212        calls: Mutex<Vec<(String, usize, usize)>>,
213    },
214    Match {
215        resolver: RerankResolver,
216        calls: Mutex<Vec<(String, usize, usize)>>,
217    },
218}
219
220impl ScriptedReranker {
221    pub fn new(results: Vec<Result<Vec<RerankResult>, ProviderError>>) -> Self {
222        Self::Fifo {
223            results: Mutex::new(results),
224            calls: Mutex::new(Vec::new()),
225        }
226    }
227
228    pub fn matching<F>(resolver: F) -> Self
229    where
230        F: Fn(&str, &[String], usize) -> Result<Vec<RerankResult>, ProviderError>
231            + Send
232            + Sync
233            + 'static,
234    {
235        Self::Match {
236            resolver: Box::new(resolver),
237            calls: Mutex::new(Vec::new()),
238        }
239    }
240
241    /// Recorded `(query, document_count, top_k)` triples in invocation order.
242    pub fn calls(&self) -> Vec<(String, usize, usize)> {
243        match self {
244            Self::Fifo { calls, .. } | Self::Match { calls, .. } => calls.lock().unwrap().clone(),
245        }
246    }
247}
248
249impl RerankProvider for ScriptedReranker {
250    async fn rerank(
251        &self,
252        query: &str,
253        documents: &[String],
254        top_k: usize,
255    ) -> Result<Vec<RerankResult>, ProviderError> {
256        match self {
257            Self::Fifo { results, calls } => {
258                calls
259                    .lock()
260                    .unwrap()
261                    .push((query.to_string(), documents.len(), top_k));
262                let mut queue = results.lock().unwrap();
263                if queue.is_empty() {
264                    Ok(Vec::new())
265                } else {
266                    queue.remove(0)
267                }
268            }
269            Self::Match { resolver, calls } => {
270                calls
271                    .lock()
272                    .unwrap()
273                    .push((query.to_string(), documents.len(), top_k));
274                resolver(query, documents, top_k)
275            }
276        }
277    }
278}