smos_application/testkit/
providers.rs1use std::sync::{Arc, Mutex};
5
6use smos_domain::NliResult;
7use smos_domain::chat::ToolCall;
8
9use crate::errors::ProviderError;
10use crate::ports::{EmbeddingProvider, LlmExtractor, NliClassifier};
11
12pub struct ScriptedExtractor {
17 results: Mutex<Vec<Result<Vec<String>, ProviderError>>>,
18 calls: Mutex<u32>,
19}
20
21impl ScriptedExtractor {
22 pub fn new(results: Vec<Result<Vec<String>, ProviderError>>) -> Self {
23 Self {
24 results: Mutex::new(results),
25 calls: Mutex::new(0),
26 }
27 }
28
29 pub fn call_count(&self) -> u32 {
30 *self.calls.lock().unwrap()
31 }
32}
33
34impl LlmExtractor for ScriptedExtractor {
35 async fn extract_facts(
36 &self,
37 _content: &str,
38 _tool_calls: &[ToolCall],
39 ) -> Result<Vec<String>, ProviderError> {
40 *self.calls.lock().unwrap() += 1;
41 let mut guard = self.results.lock().unwrap();
42 if guard.is_empty() {
43 Ok(Vec::new())
44 } else {
45 guard.remove(0)
46 }
47 }
48}
49
/// Embedding stub that wraps one fixed vector (field `0`).
pub struct ConstantEmbedder(pub Vec<f32>);
52
53impl EmbeddingProvider for ConstantEmbedder {
54 async fn embed(&self, _text: &str) -> Result<Option<Vec<f32>>, ProviderError> {
55 Ok(Some(self.0.clone()))
56 }
57}
58
/// Embedding stub that keeps a log of the texts it is asked to embed.
///
/// The log lives behind an `Arc`, so a handle to it can be held outside the embedder.
pub struct RecordingEmbedder {
    /// Texts passed to `embed`, in call order.
    calls: Arc<Mutex<Vec<String>>>,
}
67
68impl RecordingEmbedder {
69 pub fn new() -> (Self, Arc<Mutex<Vec<String>>>) {
70 let calls = Arc::new(Mutex::new(Vec::new()));
71 (
72 Self {
73 calls: calls.clone(),
74 },
75 calls,
76 )
77 }
78
79 fn vector_for(text: &str) -> Vec<f32> {
80 let hash = text
85 .bytes()
86 .fold(0u64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as u64));
87 let mut vec = vec![0.0; 1024];
88 vec[(hash as usize) % 1024] = 1.0;
89 vec
90 }
91}
92
93impl EmbeddingProvider for RecordingEmbedder {
94 async fn embed(&self, text: &str) -> Result<Option<Vec<f32>>, ProviderError> {
95 self.calls.lock().unwrap().push(text.to_string());
96 Ok(Some(Self::vector_for(text)))
97 }
98}
99
100type NliResolver = Box<dyn Fn(&str, &str) -> Result<NliResult, ProviderError> + Send + Sync>;
102
103pub enum ScriptedNliClassifier {
113 Fifo {
114 verdicts: Mutex<Vec<Result<NliResult, ProviderError>>>,
115 calls: Mutex<Vec<(String, String)>>,
116 },
117 Match {
118 resolver: NliResolver,
119 calls: Mutex<Vec<(String, String)>>,
120 },
121}
122
123impl ScriptedNliClassifier {
124 pub fn new(verdicts: Vec<Result<NliResult, ProviderError>>) -> Self {
125 Self::Fifo {
126 verdicts: Mutex::new(verdicts),
127 calls: Mutex::new(Vec::new()),
128 }
129 }
130
131 pub fn matching<F>(resolver: F) -> Self
132 where
133 F: Fn(&str, &str) -> Result<NliResult, ProviderError> + Send + Sync + 'static,
134 {
135 Self::Match {
136 resolver: Box::new(resolver),
137 calls: Mutex::new(Vec::new()),
138 }
139 }
140
141 pub fn calls(&self) -> Vec<(String, String)> {
142 match self {
143 Self::Fifo { calls, .. } | Self::Match { calls, .. } => calls.lock().unwrap().clone(),
144 }
145 }
146}
147
148impl NliClassifier for ScriptedNliClassifier {
149 async fn classify(&self, premise: &str, hypothesis: &str) -> Result<NliResult, ProviderError> {
150 match self {
151 Self::Fifo { verdicts, calls } => {
152 calls
153 .lock()
154 .unwrap()
155 .push((premise.to_string(), hypothesis.to_string()));
156 let mut queue = verdicts.lock().unwrap();
157 if queue.is_empty() {
158 Err(ProviderError::Unavailable("scripted queue empty".into()))
159 } else {
160 queue.remove(0)
161 }
162 }
163 Self::Match { resolver, calls } => {
164 calls
165 .lock()
166 .unwrap()
167 .push((premise.to_string(), hypothesis.to_string()));
168 resolver(premise, hypothesis)
169 }
170 }
171 }
172}