use crate::core::error::RustChainError;
use async_trait::async_trait;
use futures::stream::Stream;
use std::pin::Pin;

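/// Pluggable abstraction over a language-model provider.
///
/// Implementors must supply `stream`, `name`, and `health_check`;
/// `generate` has a default implementation that drains `stream` and
/// concatenates the chunks into a single `String`.
///
/// ```ignore
/// // Usage sketch, assuming `backend` is any `Box<dyn LLMBackend>`:
/// let reply = backend.generate("hello").await?;
/// println!("{}: {}", backend.name(), reply);
/// ```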
#[async_trait]
pub trait LLMBackend: Send + Sync {
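    /// Produce a complete response for `prompt`.
    ///
    /// The default implementation drains `stream` and concatenates the
    /// chunks, so backends only need to override this when the provider
    /// offers a dedicated non-streaming endpoint.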
    async fn generate(&self, prompt: &str) -> Result<String, RustChainError> {
        let mut stream = self.stream(prompt).await?;
        let mut output = String::new();
        use futures::StreamExt;
        while let Some(chunk) = stream.next().await {
            output.push_str(&chunk?);
        }
        Ok(output)
    }

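    /// Stream the response for `prompt` as incremental string chunks.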
    async fn stream(
        &self,
        prompt: &str,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<String, RustChainError>> + Send>>, RustChainError>;

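    /// Short, stable identifier for this backend.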
    fn name(&self) -> &'static str;

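    /// Report whether the backend is currently reachable and healthy.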
    async fn health_check(&self) -> Result<bool, RustChainError>;
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::error::RustChainError;
    use async_trait::async_trait;
    use futures::stream;
    use std::pin::Pin;

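    /// Builder-style mock covering the success, failure, context-aware,
    /// and streaming behaviors of `LLMBackend`.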
    struct MockLLMBackend {
        responses: Vec<String>,
        should_fail: bool,
        stream_chunks: Vec<String>,
        health_status: bool,
    }

    impl MockLLMBackend {
        fn new() -> Self {
            Self {
                responses: vec!["Default mock response".to_string()],
                should_fail: false,
                stream_chunks: vec!["Hello".to_string(), " world!".to_string()],
                health_status: true,
            }
        }

        fn with_responses(mut self, responses: Vec<String>) -> Self {
            self.responses = responses;
            self
        }

        fn with_failure(mut self, should_fail: bool) -> Self {
            self.should_fail = should_fail;
            self
        }

        fn with_stream_chunks(mut self, chunks: Vec<String>) -> Self {
            self.stream_chunks = chunks;
            self
        }

        fn with_health_status(mut self, healthy: bool) -> Self {
            self.health_status = healthy;
            self
        }
    }

    #[async_trait]
    impl LLMBackend for MockLLMBackend {
        async fn generate(&self, prompt: &str) -> Result<String, RustChainError> {
            if self.should_fail {
                return Err(RustChainError::Llm(crate::core::error::LlmError::response_error(
                    "Mock LLM failure".to_string(),
                )));
            }

            if prompt.contains("error") {
                Err(RustChainError::Llm(crate::core::error::LlmError::response_error(
                    "Prompt contained error".to_string(),
                )))
            } else if prompt.contains("hello") {
                Ok("Hello! How can I help you today?".to_string())
            } else if prompt.contains("translate") {
                Ok("Translated text: Bonjour le monde!".to_string())
            } else {
                Ok(self.responses.first().cloned().unwrap_or_else(|| "Default response".to_string()))
            }
        }

        async fn stream(
            &self,
            prompt: &str,
        ) -> Result<Pin<Box<dyn Stream<Item = Result<String, RustChainError>> + Send>>, RustChainError> {
            if self.should_fail {
                return Err(RustChainError::Llm(crate::core::error::LlmError::response_error(
                    "Mock stream failure".to_string(),
                )));
            }

            if prompt.contains("stream_error") {
                let error_stream = stream::iter(vec![
                    Ok("Starting...".to_string()),
                    Err(RustChainError::Llm(crate::core::error::LlmError::response_error(
                        "Stream error during generation".to_string(),
                    ))),
                ]);
                return Ok(Box::pin(error_stream));
            }

            let chunks = self.stream_chunks.clone();
            let chunk_stream = stream::iter(chunks.into_iter().map(Ok));
            Ok(Box::pin(chunk_stream))
        }

        fn name(&self) -> &'static str {
            "MockLLM"
        }

        async fn health_check(&self) -> Result<bool, RustChainError> {
            if self.should_fail && !self.health_status {
                Err(RustChainError::Llm(crate::core::error::LlmError::service_unavailable(
                    "MockLLM",
                )))
            } else {
                Ok(self.health_status)
            }
        }
    }

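    /// Minimal mock whose `generate` override returns a fixed response
    /// without going through `stream`.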
    struct DirectGenerateMock {
        response: String,
        should_fail: bool,
    }

    impl DirectGenerateMock {
        fn new(response: String) -> Self {
            Self {
                response,
                should_fail: false,
            }
        }

        fn with_failure(mut self) -> Self {
            self.should_fail = true;
            self
        }
    }

    #[async_trait]
    impl LLMBackend for DirectGenerateMock {
        async fn generate(&self, _prompt: &str) -> Result<String, RustChainError> {
            if self.should_fail {
                Err(RustChainError::Llm(crate::core::error::LlmError::response_error(
                    "Direct generate failure".to_string(),
                )))
            } else {
                Ok(self.response.clone())
            }
        }

        async fn stream(
            &self,
            _prompt: &str,
        ) -> Result<Pin<Box<dyn Stream<Item = Result<String, RustChainError>> + Send>>, RustChainError> {
            let chunks = vec![self.response.clone()];
            let chunk_stream = stream::iter(chunks.into_iter().map(Ok));
            Ok(Box::pin(chunk_stream))
        }

        fn name(&self) -> &'static str {
            "DirectGenerateMock"
        }

        async fn health_check(&self) -> Result<bool, RustChainError> {
            Ok(!self.should_fail)
        }
    }

    #[tokio::test]
    async fn test_mock_llm_backend_basic() {
        let mock = MockLLMBackend::new();

        assert_eq!(mock.name(), "MockLLM");

        let health = mock.health_check().await.unwrap();
        assert!(health);
    }

    #[tokio::test]
    async fn test_mock_llm_generate_success() {
        let mock = MockLLMBackend::new()
            .with_responses(vec!["Test response".to_string()]);

        let result = mock.generate("test prompt").await.unwrap();
        assert_eq!(result, "Test response");
    }

    #[tokio::test]
    async fn test_mock_llm_generate_context_aware() {
        let mock = MockLLMBackend::new();

        let hello_result = mock.generate("hello world").await.unwrap();
        assert_eq!(hello_result, "Hello! How can I help you today?");

        let translate_result = mock.generate("translate this text").await.unwrap();
        assert_eq!(translate_result, "Translated text: Bonjour le monde!");

        let generic_result = mock.generate("generic prompt").await.unwrap();
        assert_eq!(generic_result, "Default mock response");
    }

    #[tokio::test]
    async fn test_mock_llm_generate_failure() {
        let mock = MockLLMBackend::new().with_failure(true);

        let result = mock.generate("test prompt").await;
        assert!(result.is_err());

        match result {
            Err(RustChainError::Llm(e)) => {
                assert!(e.to_string().contains("Mock LLM failure"));
            }
            _ => panic!("Expected LLM error"),
        }
    }

    #[tokio::test]
    async fn test_mock_llm_generate_prompt_error() {
        let mock = MockLLMBackend::new();

        let result = mock.generate("this prompt contains error").await;
        assert!(result.is_err());

        match result {
            Err(RustChainError::Llm(e)) => {
                assert!(e.to_string().contains("Prompt contained error"));
            }
            _ => panic!("Expected LLM error"),
        }
    }

    #[tokio::test]
    async fn test_mock_llm_stream_success() {
        let mock = MockLLMBackend::new()
            .with_stream_chunks(vec!["Hello".to_string(), " world!".to_string()]);

        let mut stream = mock.stream("test prompt").await.unwrap();

        use futures::StreamExt;
        let mut chunks = Vec::new();
        while let Some(chunk_result) = stream.next().await {
            chunks.push(chunk_result.unwrap());
        }

        assert_eq!(chunks, vec!["Hello", " world!"]);
    }

    #[tokio::test]
    async fn test_mock_llm_stream_failure() {
        let mock = MockLLMBackend::new().with_failure(true);

        let result = mock.stream("test prompt").await;
        assert!(result.is_err());

        match result {
            Err(RustChainError::Llm(e)) => {
                assert!(e.to_string().contains("Mock stream failure"));
            }
            _ => panic!("Expected LLM error"),
        }
    }

    #[tokio::test]
    async fn test_mock_llm_stream_error_during_generation() {
        let mock = MockLLMBackend::new();

        let mut stream = mock.stream("stream_error prompt").await.unwrap();

        use futures::StreamExt;
        let first_chunk = stream.next().await.unwrap().unwrap();
        assert_eq!(first_chunk, "Starting...");

        let second_chunk = stream.next().await.unwrap();
        assert!(second_chunk.is_err());

        match second_chunk {
            Err(RustChainError::Llm(e)) => {
                assert!(e.to_string().contains("Stream error during generation"));
            }
            _ => panic!("Expected LLM error"),
        }
    }

    #[tokio::test]
    async fn test_mock_llm_health_check_success() {
        let mock = MockLLMBackend::new().with_health_status(true);

        let health = mock.health_check().await.unwrap();
        assert!(health);
    }

    #[tokio::test]
    async fn test_mock_llm_health_check_unhealthy() {
        let mock = MockLLMBackend::new().with_health_status(false);

        let health = mock.health_check().await.unwrap();
        assert!(!health);
    }

    #[tokio::test]
    async fn test_mock_llm_health_check_failure() {
        let mock = MockLLMBackend::new()
            .with_failure(true)
            .with_health_status(false);

        let result = mock.health_check().await;
        assert!(result.is_err());

        match result {
            Err(RustChainError::Llm(e)) => {
                assert!(e.to_string().contains("service unavailable"));
            }
            _ => panic!("Expected LLM error"),
        }
    }

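    // Note: MockLLMBackend overrides `generate`, so this test verifies that
    // the override takes precedence and the configured stream chunks are
    // ignored; the default stream-collecting path itself is covered by
    // test_stream_collection_integration below.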
    #[tokio::test]
    async fn test_default_generate_implementation() {
        let mock = MockLLMBackend::new()
            .with_stream_chunks(vec!["Chunk 1".to_string(), " Chunk 2".to_string()]);

        let result = mock.generate("test prompt").await.unwrap();
        assert_eq!(result, "Default mock response");
    }

    #[tokio::test]
    async fn test_direct_generate_mock() {
        let mock = DirectGenerateMock::new("Direct response".to_string());

        assert_eq!(mock.name(), "DirectGenerateMock");

        let result = mock.generate("any prompt").await.unwrap();
        assert_eq!(result, "Direct response");

        let health = mock.health_check().await.unwrap();
        assert!(health);
    }

    #[tokio::test]
    async fn test_direct_generate_mock_failure() {
        let mock = DirectGenerateMock::new("Response".to_string()).with_failure();

        let result = mock.generate("any prompt").await;
        assert!(result.is_err());

        match result {
            Err(RustChainError::Llm(e)) => {
                assert!(e.to_string().contains("Direct generate failure"));
            }
            _ => panic!("Expected LLM error"),
        }

        let health = mock.health_check().await.unwrap();
        assert!(!health);
    }

    #[tokio::test]
    async fn test_direct_generate_mock_stream() {
        let mock = DirectGenerateMock::new("Stream response".to_string());

        let mut stream = mock.stream("test prompt").await.unwrap();

        use futures::StreamExt;
        let chunk = stream.next().await.unwrap().unwrap();
        assert_eq!(chunk, "Stream response");

        let next_chunk = stream.next().await;
        assert!(next_chunk.is_none());
    }

    #[tokio::test]
    async fn test_llm_backend_trait_object() {
        let mock: Box<dyn LLMBackend> = Box::new(MockLLMBackend::new());

        let result = mock.generate("trait object test").await.unwrap();
        assert_eq!(result, "Default mock response");

        assert_eq!(mock.name(), "MockLLM");

        let health = mock.health_check().await.unwrap();
        assert!(health);
    }

    #[tokio::test]
    async fn test_multiple_llm_backends() {
        let mock1: Box<dyn LLMBackend> = Box::new(MockLLMBackend::new()
            .with_responses(vec!["Mock1 response".to_string()]));
        let mock2: Box<dyn LLMBackend> =
            Box::new(DirectGenerateMock::new("Mock2 response".to_string()));

        let backends = vec![mock1, mock2];

        for (i, backend) in backends.iter().enumerate() {
            let result = backend.generate("test prompt").await.unwrap();
            if i == 0 {
                assert_eq!(result, "Mock1 response");
            } else {
                assert_eq!(result, "Mock2 response");
            }

            let health = backend.health_check().await.unwrap();
            assert!(health);
        }
    }

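    // The mocks below implement only `stream`, so their `generate` calls
    // exercise the trait's default stream-collecting implementation.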
    #[tokio::test]
    async fn test_stream_collection_integration() {
        struct StreamOnlyMock;

        #[async_trait]
        impl LLMBackend for StreamOnlyMock {
            async fn stream(
                &self,
                _prompt: &str,
            ) -> Result<Pin<Box<dyn Stream<Item = Result<String, RustChainError>> + Send>>, RustChainError> {
                let chunks = vec![
                    "Stream".to_string(),
                    " collected".to_string(),
                    " response".to_string(),
                ];
                let chunk_stream = stream::iter(chunks.into_iter().map(Ok));
                Ok(Box::pin(chunk_stream))
            }

            fn name(&self) -> &'static str {
                "StreamOnlyMock"
            }

            async fn health_check(&self) -> Result<bool, RustChainError> {
                Ok(true)
            }
        }

        let mock = StreamOnlyMock;

        let result = mock.generate("test prompt").await.unwrap();
        assert_eq!(result, "Stream collected response");
    }

    #[tokio::test]
    async fn test_stream_collection_with_error() {
        struct ErrorStreamMock;

        #[async_trait]
        impl LLMBackend for ErrorStreamMock {
            async fn stream(
                &self,
                _prompt: &str,
            ) -> Result<Pin<Box<dyn Stream<Item = Result<String, RustChainError>> + Send>>, RustChainError> {
                let items = vec![
                    Ok("Start".to_string()),
                    Err(RustChainError::Llm(crate::core::error::LlmError::response_error(
                        "Mid-stream error".to_string(),
                    ))),
                ];
                let error_stream = stream::iter(items);
                Ok(Box::pin(error_stream))
            }

            fn name(&self) -> &'static str {
                "ErrorStreamMock"
            }

            async fn health_check(&self) -> Result<bool, RustChainError> {
                Ok(true)
            }
        }

        let mock = ErrorStreamMock;

        let result = mock.generate("test prompt").await;
        assert!(result.is_err());

        match result {
            Err(RustChainError::Llm(e)) => {
                assert!(e.to_string().contains("Mid-stream error"));
            }
            _ => panic!("Expected LLM error"),
        }
    }

    #[tokio::test]
    async fn test_empty_stream_collection() {
        struct EmptyStreamMock;

        #[async_trait]
        impl LLMBackend for EmptyStreamMock {
            async fn stream(
                &self,
                _prompt: &str,
            ) -> Result<Pin<Box<dyn Stream<Item = Result<String, RustChainError>> + Send>>, RustChainError> {
                let empty_stream = stream::iter(Vec::<Result<String, RustChainError>>::new());
                Ok(Box::pin(empty_stream))
            }

            fn name(&self) -> &'static str {
                "EmptyStreamMock"
            }

            async fn health_check(&self) -> Result<bool, RustChainError> {
                Ok(true)
            }
        }

        let mock = EmptyStreamMock;

        let result = mock.generate("test prompt").await.unwrap();
        assert_eq!(result, "");
    }

    #[tokio::test]
    async fn test_large_stream_collection() {
        struct LargeStreamMock;

        #[async_trait]
        impl LLMBackend for LargeStreamMock {
            async fn stream(
                &self,
                _prompt: &str,
            ) -> Result<Pin<Box<dyn Stream<Item = Result<String, RustChainError>> + Send>>, RustChainError> {
                let chunks: Vec<_> = (0..100).map(|i| Ok(format!("chunk{} ", i))).collect();
                let chunk_stream = stream::iter(chunks);
                Ok(Box::pin(chunk_stream))
            }

            fn name(&self) -> &'static str {
                "LargeStreamMock"
            }

            async fn health_check(&self) -> Result<bool, RustChainError> {
                Ok(true)
            }
        }

        let mock = LargeStreamMock;

        let result = mock.generate("test prompt").await.unwrap();

        assert!(result.starts_with("chunk0 chunk1 chunk2"));
        assert!(result.contains("chunk50"));
        assert!(result.ends_with("chunk99 "));

        let chunk_count = result.matches("chunk").count();
        assert_eq!(chunk_count, 100);
    }

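    // Compile-time assertion that boxed `LLMBackend` trait objects are
    // `Send + Sync` and can therefore be shared across threads and tasks.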
    #[test]
    fn test_llm_backend_trait_bounds() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<Box<dyn LLMBackend>>();
    }
}