// rustchain/core/llm.rs

1use crate::core::error::RustChainError;
2use async_trait::async_trait;
3use futures::stream::Stream;
4use std::pin::Pin;
5
6#[async_trait]
7pub trait LLMBackend: Send + Sync {
8    async fn generate(&self, prompt: &str) -> Result<String, RustChainError> {
9        let mut stream = self.stream(prompt).await?;
10        let mut output = String::new();
11        use futures::StreamExt;
12        while let Some(chunk) = stream.next().await {
13            output.push_str(&chunk?);
14        }
15        Ok(output)
16    }
17
18    async fn stream(
19        &self,
20        prompt: &str,
21    ) -> Result<Pin<Box<dyn Stream<Item = Result<String, RustChainError>> + Send>>, RustChainError>;
22
23    fn name(&self) -> &'static str;
24
25    async fn health_check(&self) -> Result<bool, RustChainError>;
26}
27
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::error::{LlmError, RustChainError};
    use async_trait::async_trait;
    use futures::stream;
    use futures::StreamExt;
    use std::pin::Pin;

    /// Builds the `RustChainError::Llm` response error used throughout these tests.
    fn response_err(msg: &str) -> RustChainError {
        RustChainError::Llm(LlmError::response_error(msg.to_string()))
    }

    /// Asserts that `result` is `Err(RustChainError::Llm(_))` whose rendered
    /// message contains `needle`; panics on `Ok` or on any other error variant.
    fn assert_llm_error<T>(result: Result<T, RustChainError>, needle: &str) {
        match result {
            Err(RustChainError::Llm(e)) => {
                let message = e.to_string();
                assert!(
                    message.contains(needle),
                    "expected LLM error containing {:?}, got {:?}",
                    needle,
                    message
                );
            }
            _ => panic!("Expected LLM error"),
        }
    }

    /// Configurable mock backend.
    ///
    /// `generate` is overridden with keyword-based routing, so it never uses the
    /// trait's default stream-collecting implementation. `stream` replays
    /// `stream_chunks`, or yields a failing stream for prompts containing
    /// `stream_error`.
    struct MockLLMBackend {
        /// Fallback responses for `generate`; only the first entry is used.
        responses: Vec<String>,
        /// Makes `generate` and `stream` fail immediately (and, combined with an
        /// unhealthy status, makes `health_check` return `Err`).
        should_fail: bool,
        /// Chunks yielded by `stream` for ordinary prompts.
        stream_chunks: Vec<String>,
        /// Value reported by `health_check` when it does not error.
        health_status: bool,
    }

    impl MockLLMBackend {
        /// Healthy, non-failing mock with one default response and two chunks.
        fn new() -> Self {
            Self {
                responses: vec!["Default mock response".to_string()],
                should_fail: false,
                stream_chunks: vec!["Hello".to_string(), " world!".to_string()],
                health_status: true,
            }
        }

        /// Replaces the fallback responses used by `generate`.
        fn with_responses(mut self, responses: Vec<String>) -> Self {
            self.responses = responses;
            self
        }

        /// Sets whether `generate`/`stream` fail immediately.
        fn with_failure(mut self, should_fail: bool) -> Self {
            self.should_fail = should_fail;
            self
        }

        /// Replaces the chunks yielded by `stream`.
        fn with_stream_chunks(mut self, chunks: Vec<String>) -> Self {
            self.stream_chunks = chunks;
            self
        }

        /// Sets the status reported by `health_check`.
        fn with_health_status(mut self, healthy: bool) -> Self {
            self.health_status = healthy;
            self
        }
    }

    #[async_trait]
    impl LLMBackend for MockLLMBackend {
        async fn generate(&self, prompt: &str) -> Result<String, RustChainError> {
            if self.should_fail {
                return Err(response_err("Mock LLM failure"));
            }

            // Keyword routing gives tests deterministic, prompt-dependent output.
            if prompt.contains("error") {
                Err(response_err("Prompt contained error"))
            } else if prompt.contains("hello") {
                Ok("Hello! How can I help you today?".to_string())
            } else if prompt.contains("translate") {
                Ok("Translated text: Bonjour le monde!".to_string())
            } else {
                // Lazily build the fallback so nothing is allocated when a
                // configured response exists.
                Ok(self
                    .responses
                    .first()
                    .cloned()
                    .unwrap_or_else(|| "Default response".to_string()))
            }
        }

        async fn stream(
            &self,
            prompt: &str,
        ) -> Result<Pin<Box<dyn Stream<Item = Result<String, RustChainError>> + Send>>, RustChainError> {
            if self.should_fail {
                return Err(response_err("Mock stream failure"));
            }

            if prompt.contains("stream_error") {
                // One good chunk followed by an error, to exercise mid-stream failures.
                let error_stream = stream::iter(vec![
                    Ok("Starting...".to_string()),
                    Err(response_err("Stream error during generation")),
                ]);
                return Ok(Box::pin(error_stream));
            }

            // Clone so the boxed stream owns its data (it must be 'static).
            let chunk_stream = stream::iter(self.stream_chunks.clone().into_iter().map(Ok));
            Ok(Box::pin(chunk_stream))
        }

        fn name(&self) -> &'static str {
            "MockLLM"
        }

        async fn health_check(&self) -> Result<bool, RustChainError> {
            if self.should_fail && !self.health_status {
                Err(RustChainError::Llm(LlmError::service_unavailable("MockLLM")))
            } else {
                Ok(self.health_status)
            }
        }
    }

    /// Mock that overrides `generate` to return a fixed response directly and
    /// streams that same response as a single chunk.
    struct DirectGenerateMock {
        response: String,
        should_fail: bool,
    }

    impl DirectGenerateMock {
        /// Non-failing mock returning `response`.
        fn new(response: String) -> Self {
            Self {
                response,
                should_fail: false,
            }
        }

        /// Makes `generate` fail and `health_check` report unhealthy.
        fn with_failure(mut self) -> Self {
            self.should_fail = true;
            self
        }
    }

    #[async_trait]
    impl LLMBackend for DirectGenerateMock {
        // Overrides the trait's default stream-collecting `generate`.
        async fn generate(&self, _prompt: &str) -> Result<String, RustChainError> {
            if self.should_fail {
                Err(response_err("Direct generate failure"))
            } else {
                Ok(self.response.clone())
            }
        }

        async fn stream(
            &self,
            _prompt: &str,
        ) -> Result<Pin<Box<dyn Stream<Item = Result<String, RustChainError>> + Send>>, RustChainError> {
            let chunk_stream = stream::iter(vec![Ok(self.response.clone())]);
            Ok(Box::pin(chunk_stream))
        }

        fn name(&self) -> &'static str {
            "DirectGenerateMock"
        }

        async fn health_check(&self) -> Result<bool, RustChainError> {
            Ok(!self.should_fail)
        }
    }

    #[tokio::test]
    async fn test_mock_llm_backend_basic() {
        let mock = MockLLMBackend::new();

        assert_eq!(mock.name(), "MockLLM");
        assert!(mock.health_check().await.unwrap());
    }

    #[tokio::test]
    async fn test_mock_llm_generate_success() {
        let mock = MockLLMBackend::new().with_responses(vec!["Test response".to_string()]);

        let result = mock.generate("test prompt").await.unwrap();
        assert_eq!(result, "Test response");
    }

    #[tokio::test]
    async fn test_mock_llm_generate_context_aware() {
        let mock = MockLLMBackend::new();

        let hello_result = mock.generate("hello world").await.unwrap();
        assert_eq!(hello_result, "Hello! How can I help you today?");

        let translate_result = mock.generate("translate this text").await.unwrap();
        assert_eq!(translate_result, "Translated text: Bonjour le monde!");

        let generic_result = mock.generate("generic prompt").await.unwrap();
        assert_eq!(generic_result, "Default mock response");
    }

    #[tokio::test]
    async fn test_mock_llm_generate_failure() {
        let mock = MockLLMBackend::new().with_failure(true);

        assert_llm_error(mock.generate("test prompt").await, "Mock LLM failure");
    }

    #[tokio::test]
    async fn test_mock_llm_generate_prompt_error() {
        let mock = MockLLMBackend::new();

        assert_llm_error(
            mock.generate("this prompt contains error").await,
            "Prompt contained error",
        );
    }

    #[tokio::test]
    async fn test_mock_llm_stream_success() {
        let mock = MockLLMBackend::new()
            .with_stream_chunks(vec!["Hello".to_string(), " world!".to_string()]);

        let mut stream = mock.stream("test prompt").await.unwrap();
        let mut chunks = Vec::new();
        while let Some(chunk_result) = stream.next().await {
            chunks.push(chunk_result.unwrap());
        }

        assert_eq!(chunks, vec!["Hello", " world!"]);
    }

    #[tokio::test]
    async fn test_mock_llm_stream_failure() {
        let mock = MockLLMBackend::new().with_failure(true);

        assert_llm_error(mock.stream("test prompt").await, "Mock stream failure");
    }

    #[tokio::test]
    async fn test_mock_llm_stream_error_during_generation() {
        let mock = MockLLMBackend::new();
        let mut stream = mock.stream("stream_error prompt").await.unwrap();

        let first_chunk = stream.next().await.unwrap().unwrap();
        assert_eq!(first_chunk, "Starting...");

        let second_chunk = stream.next().await.unwrap();
        assert_llm_error(second_chunk, "Stream error during generation");

        // The mock stream ends right after the error item.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_mock_llm_health_check_success() {
        let mock = MockLLMBackend::new().with_health_status(true);

        assert!(mock.health_check().await.unwrap());
    }

    #[tokio::test]
    async fn test_mock_llm_health_check_unhealthy() {
        let mock = MockLLMBackend::new().with_health_status(false);

        assert!(!mock.health_check().await.unwrap());
    }

    #[tokio::test]
    async fn test_mock_llm_health_check_failure() {
        let mock = MockLLMBackend::new()
            .with_failure(true)
            .with_health_status(false);

        assert_llm_error(mock.health_check().await, "service unavailable");
    }

    #[tokio::test]
    async fn test_default_generate_implementation() {
        // MockLLMBackend OVERRIDES `generate`, so the trait's default
        // stream-collecting implementation is not used here: the configured
        // stream chunks are ignored and the fallback response wins. The default
        // implementation itself is covered by the `*stream_collection*` tests.
        let mock = MockLLMBackend::new()
            .with_stream_chunks(vec!["Chunk 1".to_string(), " Chunk 2".to_string()]);

        let result = mock.generate("test prompt").await.unwrap();
        assert_eq!(result, "Default mock response");

        // The stream still carries the configured chunks, proving the override
        // bypassed them rather than the chunks being absent.
        let streamed: Vec<String> = mock
            .stream("test prompt")
            .await
            .unwrap()
            .map(|chunk| chunk.unwrap())
            .collect()
            .await;
        assert_eq!(streamed, vec!["Chunk 1", " Chunk 2"]);
    }

    #[tokio::test]
    async fn test_direct_generate_mock() {
        let mock = DirectGenerateMock::new("Direct response".to_string());

        assert_eq!(mock.name(), "DirectGenerateMock");

        let result = mock.generate("any prompt").await.unwrap();
        assert_eq!(result, "Direct response");

        assert!(mock.health_check().await.unwrap());
    }

    #[tokio::test]
    async fn test_direct_generate_mock_failure() {
        let mock = DirectGenerateMock::new("Response".to_string()).with_failure();

        assert_llm_error(mock.generate("any prompt").await, "Direct generate failure");

        assert!(!mock.health_check().await.unwrap());
    }

    #[tokio::test]
    async fn test_direct_generate_mock_stream() {
        let mock = DirectGenerateMock::new("Stream response".to_string());

        let mut stream = mock.stream("test prompt").await.unwrap();

        let chunk = stream.next().await.unwrap().unwrap();
        assert_eq!(chunk, "Stream response");

        // Stream should be exhausted after the single chunk.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_llm_backend_trait_object() {
        // LLMBackend must be usable as a trait object.
        let mock: Box<dyn LLMBackend> = Box::new(MockLLMBackend::new());

        let result = mock.generate("trait object test").await.unwrap();
        assert_eq!(result, "Default mock response");

        assert_eq!(mock.name(), "MockLLM");
        assert!(mock.health_check().await.unwrap());
    }

    #[tokio::test]
    async fn test_multiple_llm_backends() {
        let mock1: Box<dyn LLMBackend> = Box::new(
            MockLLMBackend::new().with_responses(vec!["Mock1 response".to_string()]),
        );
        let mock2: Box<dyn LLMBackend> =
            Box::new(DirectGenerateMock::new("Mock2 response".to_string()));

        // Pair each backend with the response it is expected to produce.
        let cases = [(mock1, "Mock1 response"), (mock2, "Mock2 response")];

        for (backend, expected) in &cases {
            let result = backend.generate("test prompt").await.unwrap();
            assert_eq!(result, *expected, "backend {}", backend.name());

            assert!(backend.health_check().await.unwrap());
        }
    }

    #[tokio::test]
    async fn test_stream_collection_integration() {
        // Exercises the trait's default `generate`, which collects the stream.
        struct StreamOnlyMock;

        #[async_trait]
        impl LLMBackend for StreamOnlyMock {
            // `generate` is intentionally not overridden.

            async fn stream(
                &self,
                _prompt: &str,
            ) -> Result<Pin<Box<dyn Stream<Item = Result<String, RustChainError>> + Send>>, RustChainError> {
                let chunks = vec![
                    "Stream".to_string(),
                    " collected".to_string(),
                    " response".to_string(),
                ];
                Ok(Box::pin(stream::iter(chunks.into_iter().map(Ok))))
            }

            fn name(&self) -> &'static str {
                "StreamOnlyMock"
            }

            async fn health_check(&self) -> Result<bool, RustChainError> {
                Ok(true)
            }
        }

        let result = StreamOnlyMock.generate("test prompt").await.unwrap();
        assert_eq!(result, "Stream collected response");
    }

    #[tokio::test]
    async fn test_stream_collection_with_error() {
        // The default `generate` must surface the first mid-stream error.
        struct ErrorStreamMock;

        #[async_trait]
        impl LLMBackend for ErrorStreamMock {
            // `generate` is intentionally not overridden.

            async fn stream(
                &self,
                _prompt: &str,
            ) -> Result<Pin<Box<dyn Stream<Item = Result<String, RustChainError>> + Send>>, RustChainError> {
                let items = vec![Ok("Start".to_string()), Err(response_err("Mid-stream error"))];
                Ok(Box::pin(stream::iter(items)))
            }

            fn name(&self) -> &'static str {
                "ErrorStreamMock"
            }

            async fn health_check(&self) -> Result<bool, RustChainError> {
                Ok(true)
            }
        }

        assert_llm_error(ErrorStreamMock.generate("test prompt").await, "Mid-stream error");
    }

    #[tokio::test]
    async fn test_empty_stream_collection() {
        // The default `generate` over an empty stream yields an empty string.
        struct EmptyStreamMock;

        #[async_trait]
        impl LLMBackend for EmptyStreamMock {
            async fn stream(
                &self,
                _prompt: &str,
            ) -> Result<Pin<Box<dyn Stream<Item = Result<String, RustChainError>> + Send>>, RustChainError> {
                // Spell out the item type instead of relying on coercion-driven inference.
                let empty_stream = stream::iter(Vec::<Result<String, RustChainError>>::new());
                Ok(Box::pin(empty_stream))
            }

            fn name(&self) -> &'static str {
                "EmptyStreamMock"
            }

            async fn health_check(&self) -> Result<bool, RustChainError> {
                Ok(true)
            }
        }

        let result = EmptyStreamMock.generate("test prompt").await.unwrap();
        assert_eq!(result, "");
    }

    #[tokio::test]
    async fn test_large_stream_collection() {
        // The default `generate` must keep every chunk, in order.
        struct LargeStreamMock;

        #[async_trait]
        impl LLMBackend for LargeStreamMock {
            async fn stream(
                &self,
                _prompt: &str,
            ) -> Result<Pin<Box<dyn Stream<Item = Result<String, RustChainError>> + Send>>, RustChainError> {
                let chunks: Vec<Result<String, RustChainError>> =
                    (0..100).map(|i| Ok(format!("chunk{} ", i))).collect();
                Ok(Box::pin(stream::iter(chunks)))
            }

            fn name(&self) -> &'static str {
                "LargeStreamMock"
            }

            async fn health_check(&self) -> Result<bool, RustChainError> {
                Ok(true)
            }
        }

        let result = LargeStreamMock.generate("test prompt").await.unwrap();

        // Exact comparison verifies completeness and ordering of all 100 chunks.
        let expected: String = (0..100).map(|i| format!("chunk{} ", i)).collect();
        assert_eq!(result, expected);
        assert_eq!(result.matches("chunk").count(), 100);
    }

    #[test]
    fn test_llm_backend_trait_bounds() {
        // Boxed backends must be shareable across threads.
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<Box<dyn LLMBackend>>();
    }
}