smos_application/testkit/
providers.rs1use std::sync::{Arc, Mutex};
5
6use smos_domain::NliResult;
7use smos_domain::chat::ToolCall;
8
9use crate::errors::ProviderError;
10use crate::ports::{EmbeddingProvider, LlmExtractor, NliClassifier, RerankProvider};
11use crate::types::RerankResult;
12
13pub struct ScriptedExtractor {
23 results: Mutex<Vec<Result<Vec<String>, ProviderError>>>,
24 calls: Mutex<u32>,
25 inputs: Mutex<Vec<(String, Vec<ToolCall>)>>,
26}
27
28impl ScriptedExtractor {
29 pub fn new(results: Vec<Result<Vec<String>, ProviderError>>) -> Self {
30 Self {
31 results: Mutex::new(results),
32 calls: Mutex::new(0),
33 inputs: Mutex::new(Vec::new()),
34 }
35 }
36
37 pub fn call_count(&self) -> u32 {
38 *self.calls.lock().unwrap()
39 }
40
41 pub fn inputs(&self) -> Vec<(String, Vec<ToolCall>)> {
43 self.inputs.lock().unwrap().clone()
44 }
45}
46
47impl LlmExtractor for ScriptedExtractor {
48 async fn extract_facts(
49 &self,
50 content: &str,
51 tool_calls: &[ToolCall],
52 ) -> Result<Vec<String>, ProviderError> {
53 *self.calls.lock().unwrap() += 1;
54 self.inputs
55 .lock()
56 .unwrap()
57 .push((content.to_string(), tool_calls.to_vec()));
58 let mut guard = self.results.lock().unwrap();
59 if guard.is_empty() {
60 Ok(Vec::new())
61 } else {
62 guard.remove(0)
63 }
64 }
65}
66
67pub struct ConstantEmbedder(pub Vec<f32>);
69
70impl EmbeddingProvider for ConstantEmbedder {
71 async fn embed(&self, _text: &str) -> Result<Option<Vec<f32>>, ProviderError> {
72 Ok(Some(self.0.clone()))
73 }
74}
75
/// Embedding provider that records every text it is asked to embed.
///
/// Produced vectors are deterministic one-hot vectors derived from a hash of the text,
/// so equal texts always embed identically (distinct texts may share a slot on a
/// hash collision).
pub struct RecordingEmbedder {
    /// Shared log of embedded texts, in call order.
    calls: Arc<Mutex<Vec<String>>>,
}

impl RecordingEmbedder {
    /// Length of every vector produced by `vector_for`.
    const DIMENSIONS: usize = 1024;

    /// Creates an embedder together with a handle to its call log.
    ///
    /// The returned `Arc` points at the same log the embedder appends to, so a test can
    /// keep inspecting recorded texts after handing the embedder to the code under test.
    pub fn new() -> (Self, Arc<Mutex<Vec<String>>>) {
        let log = Arc::new(Mutex::new(Vec::new()));
        let embedder = Self {
            calls: Arc::clone(&log),
        };
        (embedder, log)
    }

    /// Maps `text` to a one-hot vector of length `DIMENSIONS`.
    ///
    /// The hot index is a base-31 polynomial hash of the text's UTF-8 bytes, computed
    /// with wrapping `u64` arithmetic and reduced modulo the dimension.
    fn vector_for(text: &str) -> Vec<f32> {
        let mut hash = 0u64;
        for byte in text.bytes() {
            hash = hash.wrapping_mul(31).wrapping_add(u64::from(byte));
        }
        let mut one_hot = vec![0.0f32; Self::DIMENSIONS];
        one_hot[(hash as usize) % Self::DIMENSIONS] = 1.0;
        one_hot
    }
}
109
110impl EmbeddingProvider for RecordingEmbedder {
111 async fn embed(&self, text: &str) -> Result<Option<Vec<f32>>, ProviderError> {
112 self.calls.lock().unwrap().push(text.to_string());
113 Ok(Some(Self::vector_for(text)))
114 }
115}
116
117type NliResolver = Box<dyn Fn(&str, &str) -> Result<NliResult, ProviderError> + Send + Sync>;
119
120pub enum ScriptedNliClassifier {
130 Fifo {
131 verdicts: Mutex<Vec<Result<NliResult, ProviderError>>>,
132 calls: Mutex<Vec<(String, String)>>,
133 },
134 Match {
135 resolver: NliResolver,
136 calls: Mutex<Vec<(String, String)>>,
137 },
138}
139
140impl ScriptedNliClassifier {
141 pub fn new(verdicts: Vec<Result<NliResult, ProviderError>>) -> Self {
142 Self::Fifo {
143 verdicts: Mutex::new(verdicts),
144 calls: Mutex::new(Vec::new()),
145 }
146 }
147
148 pub fn matching<F>(resolver: F) -> Self
149 where
150 F: Fn(&str, &str) -> Result<NliResult, ProviderError> + Send + Sync + 'static,
151 {
152 Self::Match {
153 resolver: Box::new(resolver),
154 calls: Mutex::new(Vec::new()),
155 }
156 }
157
158 pub fn calls(&self) -> Vec<(String, String)> {
159 match self {
160 Self::Fifo { calls, .. } | Self::Match { calls, .. } => calls.lock().unwrap().clone(),
161 }
162 }
163}
164
165impl NliClassifier for ScriptedNliClassifier {
166 async fn classify(&self, premise: &str, hypothesis: &str) -> Result<NliResult, ProviderError> {
167 match self {
168 Self::Fifo { verdicts, calls } => {
169 calls
170 .lock()
171 .unwrap()
172 .push((premise.to_string(), hypothesis.to_string()));
173 let mut queue = verdicts.lock().unwrap();
174 if queue.is_empty() {
175 Err(ProviderError::Unavailable("scripted queue empty".into()))
176 } else {
177 queue.remove(0)
178 }
179 }
180 Self::Match { resolver, calls } => {
181 calls
182 .lock()
183 .unwrap()
184 .push((premise.to_string(), hypothesis.to_string()));
185 resolver(premise, hypothesis)
186 }
187 }
188 }
189}
190
191type RerankResolver =
193 Box<dyn Fn(&str, &[String], usize) -> Result<Vec<RerankResult>, ProviderError> + Send + Sync>;
194
195pub enum ScriptedReranker {
210 Fifo {
211 results: Mutex<Vec<Result<Vec<RerankResult>, ProviderError>>>,
212 calls: Mutex<Vec<(String, usize, usize)>>,
213 },
214 Match {
215 resolver: RerankResolver,
216 calls: Mutex<Vec<(String, usize, usize)>>,
217 },
218}
219
220impl ScriptedReranker {
221 pub fn new(results: Vec<Result<Vec<RerankResult>, ProviderError>>) -> Self {
222 Self::Fifo {
223 results: Mutex::new(results),
224 calls: Mutex::new(Vec::new()),
225 }
226 }
227
228 pub fn matching<F>(resolver: F) -> Self
229 where
230 F: Fn(&str, &[String], usize) -> Result<Vec<RerankResult>, ProviderError>
231 + Send
232 + Sync
233 + 'static,
234 {
235 Self::Match {
236 resolver: Box::new(resolver),
237 calls: Mutex::new(Vec::new()),
238 }
239 }
240
241 pub fn calls(&self) -> Vec<(String, usize, usize)> {
243 match self {
244 Self::Fifo { calls, .. } | Self::Match { calls, .. } => calls.lock().unwrap().clone(),
245 }
246 }
247}
248
249impl RerankProvider for ScriptedReranker {
250 async fn rerank(
251 &self,
252 query: &str,
253 documents: &[String],
254 top_k: usize,
255 ) -> Result<Vec<RerankResult>, ProviderError> {
256 match self {
257 Self::Fifo { results, calls } => {
258 calls
259 .lock()
260 .unwrap()
261 .push((query.to_string(), documents.len(), top_k));
262 let mut queue = results.lock().unwrap();
263 if queue.is_empty() {
264 Ok(Vec::new())
265 } else {
266 queue.remove(0)
267 }
268 }
269 Self::Match { resolver, calls } => {
270 calls
271 .lock()
272 .unwrap()
273 .push((query.to_string(), documents.len(), top_k));
274 resolver(query, documents, top_k)
275 }
276 }
277 }
278}