synwire_core/language_models/
fake.rs1use crate::BoxFuture;
4use crate::BoxStream;
5use crate::error::{ModelError, SynwireError};
6use crate::language_models::traits::BaseChatModel;
7use crate::language_models::types::{ChatChunk, ChatResult};
8use crate::messages::Message;
9use crate::runnables::RunnableConfig;
10use crate::tools::ToolSchema;
11use std::sync::Mutex;
12use std::sync::atomic::{AtomicUsize, Ordering};
13
14pub struct FakeChatModel {
34 responses: Vec<String>,
35 call_count: AtomicUsize,
36 error_at: Option<usize>,
37 calls: Mutex<Vec<Vec<Message>>>,
38 chunk_size: Option<usize>,
41 stream_error_after: Option<usize>,
44}
45
46impl FakeChatModel {
47 pub const fn new(responses: Vec<String>) -> Self {
52 Self {
53 responses,
54 call_count: AtomicUsize::new(0),
55 error_at: None,
56 calls: Mutex::new(Vec::new()),
57 chunk_size: None,
58 stream_error_after: None,
59 }
60 }
61
62 #[must_use]
67 pub const fn with_error_at(mut self, index: usize) -> Self {
68 self.error_at = Some(index);
69 self
70 }
71
72 #[must_use]
77 pub const fn with_chunk_size(mut self, size: usize) -> Self {
78 self.chunk_size = Some(size);
79 self
80 }
81
82 #[must_use]
87 pub const fn with_stream_error_after(mut self, n: usize) -> Self {
88 self.stream_error_after = Some(n);
89 self
90 }
91
92 pub fn call_count(&self) -> usize {
94 self.call_count.load(Ordering::Relaxed)
95 }
96
97 pub fn calls(&self) -> Vec<Vec<Message>> {
99 self.calls.lock().map_or_else(|_| Vec::new(), |g| g.clone())
100 }
101}
102
103impl std::fmt::Debug for FakeChatModel {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 f.debug_struct("FakeChatModel")
106 .field("responses", &self.responses)
107 .field("call_count", &self.call_count.load(Ordering::Relaxed))
108 .field("error_at", &self.error_at)
109 .field("calls", &self.calls)
110 .field("chunk_size", &self.chunk_size)
111 .field("stream_error_after", &self.stream_error_after)
112 .finish()
113 }
114}
115
116impl BaseChatModel for FakeChatModel {
117 fn invoke<'a>(
118 &'a self,
119 messages: &'a [Message],
120 _config: Option<&'a RunnableConfig>,
121 ) -> BoxFuture<'a, Result<ChatResult, SynwireError>> {
122 Box::pin(async move {
123 let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
124
125 if let Ok(mut calls) = self.calls.lock() {
127 calls.push(messages.to_vec());
128 }
129
130 if self.error_at == Some(idx) {
132 return Err(SynwireError::from(ModelError::Other {
133 message: format!("injected error at call {idx}"),
134 }));
135 }
136
137 let response_text = self
138 .responses
139 .get(idx % self.responses.len())
140 .cloned()
141 .unwrap_or_default();
142
143 Ok(ChatResult {
144 message: Message::ai(response_text),
145 generation_info: None,
146 cost: None,
147 })
148 })
149 }
150
151 fn stream<'a>(
152 &'a self,
153 messages: &'a [Message],
154 config: Option<&'a RunnableConfig>,
155 ) -> BoxFuture<'a, Result<BoxStream<'a, Result<ChatChunk, SynwireError>>, SynwireError>> {
156 Box::pin(async move {
157 let result = self.invoke(messages, config).await?;
158 let full_text = result.message.content().as_text();
159
160 let chunk_size = self.chunk_size.unwrap_or(full_text.len()).max(1);
161 let error_after = self.stream_error_after;
162
163 let chunks: Vec<String> = full_text
164 .chars()
165 .collect::<Vec<_>>()
166 .chunks(chunk_size)
167 .map(|c| c.iter().collect())
168 .collect();
169
170 let total = chunks.len();
171 let stream =
172 futures_util::stream::iter(chunks.into_iter().enumerate().map(move |(i, text)| {
173 if let Some(error_at) = error_after
174 && i >= error_at
175 {
176 return Err(SynwireError::from(ModelError::Other {
177 message: "stream error injected".into(),
178 }));
179 }
180 let finish_reason = if i + 1 == total {
181 Some("stop".into())
182 } else {
183 None
184 };
185 Ok(ChatChunk {
186 delta_content: Some(text),
187 delta_tool_calls: Vec::new(),
188 finish_reason,
189 usage: None,
190 })
191 }));
192
193 Ok(Box::pin(stream) as BoxStream<'_, Result<ChatChunk, SynwireError>>)
194 })
195 }
196
197 fn model_type(&self) -> &'static str {
198 "fake"
199 }
200
201 fn bind_tools(&self, _tools: &[ToolSchema]) -> Result<Box<dyn BaseChatModel>, SynwireError> {
202 let mut model = Self::new(self.responses.clone());
203 model.chunk_size = self.chunk_size;
204 model.stream_error_after = self.stream_error_after;
205 Ok(Box::new(model))
206 }
207}
208
#[cfg(test)]
// Unwrapping is the most direct way for a test to fail on an unexpected error.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use futures_util::StreamExt as _;

    /// Drains `stream`, panicking on any error item, and returns the text of
    /// every chunk that carried content.
    async fn drain_text(
        mut stream: BoxStream<'_, Result<ChatChunk, SynwireError>>,
    ) -> Vec<String> {
        let mut texts = Vec::new();
        while let Some(item) = stream.next().await {
            texts.extend(item.unwrap().delta_content);
        }
        texts
    }

    #[tokio::test]
    async fn test_fake_chat_model_invoke_returns_chat_result() {
        let model = FakeChatModel::new(vec!["Hello!".into()]);
        let messages = vec![Message::human("Hi")];

        let outcome = model.invoke(&messages, None).await.unwrap();
        let reply = &outcome.message;

        assert_eq!(reply.content().as_text(), "Hello!");
        assert_eq!(reply.message_type(), "ai");
    }

    #[tokio::test]
    async fn test_fake_chat_model_invoke_with_error() {
        let model = FakeChatModel::new(vec!["ok".into()]).with_error_at(0);
        let messages = vec![Message::human("Hi")];

        assert!(model.invoke(&messages, None).await.is_err());
    }

    #[tokio::test]
    async fn test_fake_chat_model_swap_compiles() {
        let models: [Box<dyn BaseChatModel>; 2] = [
            Box::new(FakeChatModel::new(vec!["A".into()])),
            Box::new(FakeChatModel::new(vec!["B".into()])),
        ];
        let messages = vec![Message::human("test")];

        for (model, expected) in models.iter().zip(["A", "B"]) {
            let outcome = model.invoke(&messages, None).await.unwrap();
            assert_eq!(outcome.message.content().as_text(), expected);
        }
    }

    #[tokio::test]
    async fn test_fake_chat_model_batch() {
        let model = FakeChatModel::new(vec!["R1".into(), "R2".into()]);
        let inputs = vec![vec![Message::human("Q1")], vec![Message::human("Q2")]];

        let results = model.batch(&inputs, None).await.unwrap();

        assert_eq!(results.len(), 2);
        for (outcome, expected) in results.iter().zip(["R1", "R2"]) {
            assert_eq!(outcome.message.content().as_text(), expected);
        }
    }

    #[tokio::test]
    async fn test_invoke_empty_messages_returns_result() {
        let model = FakeChatModel::new(vec!["response".into()]);

        let outcome = model.invoke(&[], None).await.unwrap();

        assert_eq!(outcome.message.content().as_text(), "response");
    }

    #[tokio::test]
    async fn test_bind_tools_returns_model() {
        let tools = [ToolSchema {
            name: "search".into(),
            description: "Search".into(),
            parameters: serde_json::json!({}),
        }];
        let model = FakeChatModel::new(vec!["ok".into()]);

        let bound = model.bind_tools(&tools).unwrap();

        assert_eq!(bound.model_type(), "fake");
    }

    #[tokio::test]
    async fn test_call_tracking() {
        let model = FakeChatModel::new(vec!["A".into(), "B".into()]);

        for prompt in ["Q1", "Q2"] {
            let _reply = model.invoke(&[Message::human(prompt)], None).await.unwrap();
        }

        assert_eq!(model.call_count(), 2);
        assert_eq!(model.calls().len(), 2);
    }

    #[tokio::test]
    async fn test_fake_stream_yields_chunks_in_order() {
        let model = FakeChatModel::new(vec!["abcdefgh".into()]).with_chunk_size(3);
        let messages = vec![Message::human("Hi")];

        let stream = model.stream(&messages, None).await.unwrap();

        assert_eq!(drain_text(stream).await, vec!["abc", "def", "gh"]);
    }

    #[tokio::test]
    async fn test_concatenated_stream_equals_invoke() {
        let response = "Hello, this is a test response!";
        let model = FakeChatModel::new(vec![response.into()]).with_chunk_size(5);
        let messages = vec![Message::human("Hi")];

        // Stream first, then invoke, so both read the same single response.
        let stream = model.stream(&messages, None).await.unwrap();
        let streamed = drain_text(stream).await.concat();

        let invoke_result = model.invoke(&messages, None).await.unwrap();
        let invoked = invoke_result.message.content().as_text();

        assert_eq!(streamed, invoked);
    }

    #[tokio::test]
    async fn test_stream_mid_error() {
        let model = FakeChatModel::new(vec!["abcdefghij".into()])
            .with_chunk_size(2)
            .with_stream_error_after(2);
        let messages = vec![Message::human("Hi")];
        let mut stream = model.stream(&messages, None).await.unwrap();

        let mut ok_chunks = Vec::new();
        let mut saw_error = false;
        while let Some(item) = stream.next().await {
            match item {
                Ok(chunk) => ok_chunks.extend(chunk.delta_content),
                Err(_) => {
                    saw_error = true;
                    break;
                }
            }
        }

        assert_eq!(ok_chunks, vec!["ab", "cd"]);
        assert!(saw_error, "expected an error after 2 chunks");
    }

    #[tokio::test]
    async fn test_stream_drop_no_leak() {
        let model = FakeChatModel::new(vec!["abcdefghij".into()]).with_chunk_size(2);
        let messages = vec![Message::human("Hi")];
        let mut stream = model.stream(&messages, None).await.unwrap();

        // Consume one chunk, then drop the partially read stream.
        assert!(stream.next().await.is_some());
        drop(stream);
    }

    #[tokio::test]
    async fn test_runnable_core_default_stream() {
        use crate::runnables::core::RunnableCore;

        /// Runnable whose `invoke` hands its input straight back.
        struct EchoRunnable;

        impl RunnableCore for EchoRunnable {
            fn invoke<'a>(
                &'a self,
                input: serde_json::Value,
                _config: Option<&'a RunnableConfig>,
            ) -> BoxFuture<'a, Result<serde_json::Value, SynwireError>> {
                Box::pin(async move { Ok(input) })
            }
        }

        let runnable = EchoRunnable;
        let input = serde_json::json!({"greeting": "hello"});
        let mut stream = runnable.stream(input.clone(), None).await.unwrap();

        // The default stream yields exactly one item: the invoke result.
        let first = stream.next().await;
        assert!(first.is_some());
        assert_eq!(first.unwrap().unwrap(), input);

        assert!(stream.next().await.is_none());
    }
}