synwire_core/language_models/fake.rs

//! Fake chat model for testing.

use crate::BoxFuture;
use crate::BoxStream;
use crate::error::{ModelError, SynwireError};
use crate::language_models::traits::BaseChatModel;
use crate::language_models::types::{ChatChunk, ChatResult};
use crate::messages::Message;
use crate::runnables::RunnableConfig;
use crate::tools::ToolSchema;
use std::sync::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};

/// A fake chat model for testing without API calls.
///
/// Returns pre-configured responses in order. Tracks call count
/// and can inject errors at specified positions.
///
/// # Examples
///
/// ```
/// use synwire_core::language_models::fake::FakeChatModel;
/// use synwire_core::language_models::traits::BaseChatModel;
/// use synwire_core::messages::Message;
///
/// # tokio_test::block_on(async {
/// let model = FakeChatModel::new(vec!["Hello!".into()]);
/// let messages = vec![Message::human("Hi")];
/// let result = model.invoke(&messages, None).await.unwrap();
/// assert_eq!(result.message.content().as_text(), "Hello!");
/// # });
/// ```
pub struct FakeChatModel {
    responses: Vec<String>,
    call_count: AtomicUsize,
    error_at: Option<usize>,
    calls: Mutex<Vec<Vec<Message>>>,
    /// When `Some(n)`, the stream method splits each response into chunks
    /// of `n` characters instead of yielding the entire text as one chunk.
    chunk_size: Option<usize>,
    /// When `Some(n)`, the stream method injects an error after yielding
    /// `n` chunks, causing all subsequent chunks to be errors.
    stream_error_after: Option<usize>,
}

impl FakeChatModel {
    /// Creates a new fake chat model with the given responses.
    ///
    /// Responses are returned in order, cycling back to the start
    /// when all responses have been used.
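    ///
    /// # Examples
    ///
    /// With two responses, the third call cycles back to the first:
    ///
    /// ```
    /// use synwire_core::language_models::fake::FakeChatModel;
    /// use synwire_core::language_models::traits::BaseChatModel;
    /// use synwire_core::messages::Message;
    ///
    /// # tokio_test::block_on(async {
    /// let model = FakeChatModel::new(vec!["A".into(), "B".into()]);
    /// let messages = vec![Message::human("Hi")];
    /// let r1 = model.invoke(&messages, None).await.unwrap();
    /// let r2 = model.invoke(&messages, None).await.unwrap();
    /// let r3 = model.invoke(&messages, None).await.unwrap();
    /// assert_eq!(r1.message.content().as_text(), "A");
    /// assert_eq!(r2.message.content().as_text(), "B");
    /// assert_eq!(r3.message.content().as_text(), "A");
    /// # });
    /// ```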
    pub const fn new(responses: Vec<String>) -> Self {
        Self {
            responses,
            call_count: AtomicUsize::new(0),
            error_at: None,
            calls: Mutex::new(Vec::new()),
            chunk_size: None,
            stream_error_after: None,
        }
    }

    /// Sets the call index at which to return an error.
    ///
    /// When the zero-based call count matches `index`, the model
    /// returns a [`ModelError::Other`] instead of a response.
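    ///
    /// # Examples
    ///
    /// With `with_error_at(1)`, the first call succeeds and the second fails:
    ///
    /// ```
    /// use synwire_core::language_models::fake::FakeChatModel;
    /// use synwire_core::language_models::traits::BaseChatModel;
    /// use synwire_core::messages::Message;
    ///
    /// # tokio_test::block_on(async {
    /// let model = FakeChatModel::new(vec!["ok".into()]).with_error_at(1);
    /// let messages = vec![Message::human("Hi")];
    /// assert!(model.invoke(&messages, None).await.is_ok());
    /// assert!(model.invoke(&messages, None).await.is_err());
    /// # });
    /// ```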
    #[must_use]
    pub const fn with_error_at(mut self, index: usize) -> Self {
        self.error_at = Some(index);
        self
    }

    /// Sets the chunk size for streaming responses.
    ///
    /// When set, the [`stream`](BaseChatModel::stream) method splits each
    /// response into chunks of at most `size` characters.
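    ///
    /// # Examples
    ///
    /// An eight-character response with a chunk size of 3 streams as
    /// `"abc"`, `"def"`, `"gh"`:
    ///
    /// ```
    /// use futures_util::StreamExt as _;
    /// use synwire_core::language_models::fake::FakeChatModel;
    /// use synwire_core::language_models::traits::BaseChatModel;
    /// use synwire_core::messages::Message;
    ///
    /// # tokio_test::block_on(async {
    /// let model = FakeChatModel::new(vec!["abcdefgh".into()]).with_chunk_size(3);
    /// let messages = vec![Message::human("Hi")];
    /// let mut stream = model.stream(&messages, None).await.unwrap();
    /// let mut chunks = Vec::new();
    /// while let Some(result) = stream.next().await {
    ///     if let Some(text) = result.unwrap().delta_content {
    ///         chunks.push(text);
    ///     }
    /// }
    /// assert_eq!(chunks, vec!["abc", "def", "gh"]);
    /// # });
    /// ```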
    #[must_use]
    pub const fn with_chunk_size(mut self, size: usize) -> Self {
        self.chunk_size = Some(size);
        self
    }

    /// Configures an error to be injected after `n` chunks during streaming.
    ///
    /// The stream yields the first `n` chunks successfully, then returns
    /// an error for every subsequent chunk position.
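    ///
    /// # Examples
    ///
    /// With a chunk size of 2 and an error after 2 chunks, the stream
    /// yields `"ab"` and `"cd"`, then errors:
    ///
    /// ```
    /// use futures_util::StreamExt as _;
    /// use synwire_core::language_models::fake::FakeChatModel;
    /// use synwire_core::language_models::traits::BaseChatModel;
    /// use synwire_core::messages::Message;
    ///
    /// # tokio_test::block_on(async {
    /// let model = FakeChatModel::new(vec!["abcdef".into()])
    ///     .with_chunk_size(2)
    ///     .with_stream_error_after(2);
    /// let messages = vec![Message::human("Hi")];
    /// let mut stream = model.stream(&messages, None).await.unwrap();
    /// assert!(stream.next().await.unwrap().is_ok()); // "ab"
    /// assert!(stream.next().await.unwrap().is_ok()); // "cd"
    /// assert!(stream.next().await.unwrap().is_err()); // injected error
    /// # });
    /// ```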
    #[must_use]
    pub const fn with_stream_error_after(mut self, n: usize) -> Self {
        self.stream_error_after = Some(n);
        self
    }

    /// Returns the number of times [`invoke`](BaseChatModel::invoke) has been
    /// called, including the internal call made by [`stream`](BaseChatModel::stream).
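    ///
    /// # Examples
    ///
    /// ```
    /// use synwire_core::language_models::fake::FakeChatModel;
    /// use synwire_core::language_models::traits::BaseChatModel;
    /// use synwire_core::messages::Message;
    ///
    /// # tokio_test::block_on(async {
    /// let model = FakeChatModel::new(vec!["A".into()]);
    /// let _ = model.invoke(&[Message::human("Q1")], None).await.unwrap();
    /// let _ = model.invoke(&[Message::human("Q2")], None).await.unwrap();
    /// assert_eq!(model.call_count(), 2);
    /// # });
    /// ```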
    pub fn call_count(&self) -> usize {
        self.call_count.load(Ordering::Relaxed)
    }

    /// Returns a clone of all recorded input message lists.
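    ///
    /// # Examples
    ///
    /// ```
    /// use synwire_core::language_models::fake::FakeChatModel;
    /// use synwire_core::language_models::traits::BaseChatModel;
    /// use synwire_core::messages::Message;
    ///
    /// # tokio_test::block_on(async {
    /// let model = FakeChatModel::new(vec!["A".into()]);
    /// let _ = model.invoke(&[Message::human("Q1")], None).await.unwrap();
    /// let calls = model.calls();
    /// assert_eq!(calls.len(), 1);
    /// assert_eq!(calls[0].len(), 1);
    /// # });
    /// ```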
    pub fn calls(&self) -> Vec<Vec<Message>> {
        self.calls.lock().map_or_else(|_| Vec::new(), |g| g.clone())
    }
}

impl std::fmt::Debug for FakeChatModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FakeChatModel")
            .field("responses", &self.responses)
            .field("call_count", &self.call_count.load(Ordering::Relaxed))
            .field("error_at", &self.error_at)
            .field("calls", &self.calls)
            .field("chunk_size", &self.chunk_size)
            .field("stream_error_after", &self.stream_error_after)
            .finish()
    }
}

impl BaseChatModel for FakeChatModel {
    fn invoke<'a>(
        &'a self,
        messages: &'a [Message],
        _config: Option<&'a RunnableConfig>,
    ) -> BoxFuture<'a, Result<ChatResult, SynwireError>> {
        Box::pin(async move {
            let idx = self.call_count.fetch_add(1, Ordering::Relaxed);

            // Record call
            if let Ok(mut calls) = self.calls.lock() {
                calls.push(messages.to_vec());
            }

            // Check for error injection
            if self.error_at == Some(idx) {
                return Err(SynwireError::from(ModelError::Other {
                    message: format!("injected error at call {idx}"),
                }));
            }

            // Cycle through responses; an empty response list yields an
            // empty string instead of panicking on `idx % 0`.
            let response_text = if self.responses.is_empty() {
                String::new()
            } else {
                self.responses[idx % self.responses.len()].clone()
            };

            Ok(ChatResult {
                message: Message::ai(response_text),
                generation_info: None,
                cost: None,
            })
        })
    }

    fn stream<'a>(
        &'a self,
        messages: &'a [Message],
        config: Option<&'a RunnableConfig>,
    ) -> BoxFuture<'a, Result<BoxStream<'a, Result<ChatChunk, SynwireError>>, SynwireError>> {
        Box::pin(async move {
            let result = self.invoke(messages, config).await?;
            let full_text = result.message.content().as_text();

            let chunk_size = self.chunk_size.unwrap_or(full_text.len()).max(1);
            let error_after = self.stream_error_after;

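            // Chunk on `char` boundaries rather than byte offsets so that
            // multi-byte UTF-8 characters are never split across chunks.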
            let chunks: Vec<String> = full_text
                .chars()
                .collect::<Vec<_>>()
                .chunks(chunk_size)
                .map(|c| c.iter().collect())
                .collect();

            let total = chunks.len();
            let stream =
                futures_util::stream::iter(chunks.into_iter().enumerate().map(move |(i, text)| {
                    if let Some(error_at) = error_after
                        && i >= error_at
                    {
                        return Err(SynwireError::from(ModelError::Other {
                            message: "stream error injected".into(),
                        }));
                    }
                    let finish_reason = if i + 1 == total {
                        Some("stop".into())
                    } else {
                        None
                    };
                    Ok(ChatChunk {
                        delta_content: Some(text),
                        delta_tool_calls: Vec::new(),
                        finish_reason,
                        usage: None,
                    })
                }));

            Ok(Box::pin(stream) as BoxStream<'_, Result<ChatChunk, SynwireError>>)
        })
    }

    fn model_type(&self) -> &'static str {
        "fake"
    }

    fn bind_tools(&self, _tools: &[ToolSchema]) -> Result<Box<dyn BaseChatModel>, SynwireError> {
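        // The bound copy shares the response list but starts with fresh
        // call-tracking state; `error_at` is not carried over.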
        let mut model = Self::new(self.responses.clone());
        model.chunk_size = self.chunk_size;
        model.stream_error_after = self.stream_error_after;
        Ok(Box::new(model))
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_fake_chat_model_invoke_returns_chat_result() {
        let model = FakeChatModel::new(vec!["Hello!".into()]);
        let messages = vec![Message::human("Hi")];
        let result = model.invoke(&messages, None).await.unwrap();
        assert_eq!(result.message.content().as_text(), "Hello!");
        assert_eq!(result.message.message_type(), "ai");
    }

    #[tokio::test]
    async fn test_fake_chat_model_invoke_with_error() {
        let model = FakeChatModel::new(vec!["ok".into()]).with_error_at(0);
        let messages = vec![Message::human("Hi")];
        let result = model.invoke(&messages, None).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_fake_chat_model_swap_compiles() {
        let model_a: Box<dyn BaseChatModel> = Box::new(FakeChatModel::new(vec!["A".into()]));
        let model_b: Box<dyn BaseChatModel> = Box::new(FakeChatModel::new(vec!["B".into()]));
        let messages = vec![Message::human("test")];

        let result_a = model_a.invoke(&messages, None).await.unwrap();
        let result_b = model_b.invoke(&messages, None).await.unwrap();
        assert_eq!(result_a.message.content().as_text(), "A");
        assert_eq!(result_b.message.content().as_text(), "B");
    }

    #[tokio::test]
    async fn test_fake_chat_model_batch() {
        let model = FakeChatModel::new(vec!["R1".into(), "R2".into()]);
        let inputs = vec![vec![Message::human("Q1")], vec![Message::human("Q2")]];
        let results = model.batch(&inputs, None).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].message.content().as_text(), "R1");
        assert_eq!(results[1].message.content().as_text(), "R2");
    }

    #[tokio::test]
    async fn test_invoke_empty_messages_returns_result() {
        let model = FakeChatModel::new(vec!["response".into()]);
        let result = model.invoke(&[], None).await.unwrap();
        assert_eq!(result.message.content().as_text(), "response");
    }

    #[tokio::test]
    async fn test_bind_tools_returns_model() {
        let model = FakeChatModel::new(vec!["ok".into()]);
        let tools = vec![crate::tools::ToolSchema {
            name: "search".into(),
            description: "Search".into(),
            parameters: serde_json::json!({}),
        }];
        let bound = model.bind_tools(&tools).unwrap();
        assert_eq!(bound.model_type(), "fake");
    }

    #[tokio::test]
    async fn test_call_tracking() {
        let model = FakeChatModel::new(vec!["A".into(), "B".into()]);
        let _r1 = model.invoke(&[Message::human("Q1")], None).await.unwrap();
        let _r2 = model.invoke(&[Message::human("Q2")], None).await.unwrap();
        assert_eq!(model.call_count(), 2);
        let calls = model.calls();
        assert_eq!(calls.len(), 2);
    }

    #[tokio::test]
    async fn test_fake_stream_yields_chunks_in_order() {
        use futures_util::StreamExt as _;

        let model = FakeChatModel::new(vec!["abcdefgh".into()]).with_chunk_size(3);
        let messages = vec![Message::human("Hi")];
        let mut stream = model.stream(&messages, None).await.unwrap();

        let mut chunks = Vec::new();
        while let Some(result) = stream.next().await {
            let chunk = result.unwrap();
            if let Some(text) = &chunk.delta_content {
                chunks.push(text.clone());
            }
        }

        assert_eq!(chunks, vec!["abc", "def", "gh"]);
    }

    #[tokio::test]
    async fn test_concatenated_stream_equals_invoke() {
        use futures_util::StreamExt as _;

        let response = "Hello, this is a test response!";
        let model = FakeChatModel::new(vec![response.into()]).with_chunk_size(5);
        let messages = vec![Message::human("Hi")];

        // Stream and concatenate
        let mut stream = model.stream(&messages, None).await.unwrap();
        let mut streamed = String::new();
        while let Some(result) = stream.next().await {
            let chunk = result.unwrap();
            if let Some(text) = &chunk.delta_content {
                streamed.push_str(text);
            }
        }

        // Invoke (call_count is now 1 from stream's internal invoke, so
        // the second call cycles to index 1 % 1 == 0, returning the same response)
        let invoke_result = model.invoke(&messages, None).await.unwrap();
        let invoked = invoke_result.message.content().as_text();

        assert_eq!(streamed, invoked);
    }

    #[tokio::test]
    async fn test_stream_mid_error() {
        use futures_util::StreamExt as _;

        let model = FakeChatModel::new(vec!["abcdefghij".into()])
            .with_chunk_size(2)
            .with_stream_error_after(2);

        let messages = vec![Message::human("Hi")];
        let mut stream = model.stream(&messages, None).await.unwrap();

        let mut ok_chunks = Vec::new();
        let mut saw_error = false;

        while let Some(result) = stream.next().await {
            if let Ok(chunk) = result {
                if let Some(text) = &chunk.delta_content {
                    ok_chunks.push(text.clone());
                }
            } else {
                saw_error = true;
                break;
            }
        }

        assert_eq!(ok_chunks, vec!["ab", "cd"]);
        assert!(saw_error, "expected an error after 2 chunks");
    }

    #[tokio::test]
    async fn test_stream_drop_no_leak() {
        use futures_util::StreamExt as _;

        let model = FakeChatModel::new(vec!["abcdefghij".into()]).with_chunk_size(2);
        let messages = vec![Message::human("Hi")];
        let mut stream = model.stream(&messages, None).await.unwrap();

        // Consume only the first chunk, then drop the stream
        let first = stream.next().await;
        assert!(first.is_some());
        drop(stream);
        // No panic or resource leak -- test passes by completing successfully
    }

    #[tokio::test]
    async fn test_runnable_core_default_stream() {
        use crate::runnables::core::RunnableCore;
        use futures_util::StreamExt as _;

        struct EchoRunnable;

        impl RunnableCore for EchoRunnable {
            fn invoke<'a>(
                &'a self,
                input: serde_json::Value,
                _config: Option<&'a crate::runnables::RunnableConfig>,
            ) -> crate::BoxFuture<'a, Result<serde_json::Value, crate::error::SynwireError>>
            {
                Box::pin(async move { Ok(input) })
            }
        }

        let runnable = EchoRunnable;
        let input = serde_json::json!({"greeting": "hello"});
        let mut stream = runnable.stream(input.clone(), None).await.unwrap();

        let first = stream.next().await;
        assert!(first.is_some());
        let value = first.unwrap().unwrap();
        assert_eq!(value, input);

        // Should have no more items
        let second = stream.next().await;
        assert!(second.is_none());
    }
}