turbomcp_client/llm/
core.rs

//! Core LLM abstractions and types
//!
//! Defines the fundamental traits and types for the LLM system, providing a
//! provider-agnostic interface for different LLM backends.
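//!
//! # Example
//!
//! A minimal, illustrative sketch of assembling a provider-agnostic request from
//! the types in this module (the model name is a placeholder, and the paths
//! assume these items are re-exported from `turbomcp_client::llm`):
//!
//! ```rust,no_run
//! use turbomcp_client::llm::{GenerationParams, LLMMessage, LLMRequest};
//!
//! // Build a short conversation and wrap it in a request usable with any provider.
//! let messages = vec![
//!     LLMMessage::system("You are a helpful assistant"),
//!     LLMMessage::user("What's 2+2?"),
//! ];
//! let request = LLMRequest::new("gpt-4", messages).with_params(GenerationParams {
//!     temperature: Some(0.7),
//!     max_tokens: Some(100),
//!     ..Default::default()
//! });
//! assert_eq!(request.message_count(), 2);
//! ```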

use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
use thiserror::Error;

// ============================================================================
// ERROR TYPES
// ============================================================================

/// Errors that can occur during LLM operations
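///
/// Each variant has a constructor helper below, and the `Display` output follows
/// the `#[error(...)]` format strings. A small sketch (assuming the error type is
/// re-exported from `turbomcp_client::llm`):
///
/// ```rust,no_run
/// use turbomcp_client::llm::LLMError;
///
/// let err = LLMError::token_limit_exceeded(1000, 800);
/// // Renders as "Token limit exceeded: 1000 > 800"
/// assert!(err.to_string().contains("1000 > 800"));
/// ```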
#[derive(Error, Debug)]
pub enum LLMError {
    /// Configuration errors
    #[error("Configuration error: {message}")]
    Configuration { message: String },

    /// Authentication failures
    #[error("Authentication failed: {details}")]
    Authentication { details: String },

    /// Network and connectivity issues
    #[error("Network error: {message}")]
    Network { message: String },

    /// API rate limiting
    #[error("Rate limited. Retry after: {retry_after_seconds}s")]
    RateLimit { retry_after_seconds: u64 },

    /// Model or parameter validation errors
    #[error("Invalid parameters: {details}")]
    InvalidParameters { details: String },

    /// Provider-specific errors
    #[error("Provider error [{code}]: {message}")]
    ProviderError { code: i32, message: String },

    /// Request timeout
    #[error("Request timed out after {seconds}s")]
    Timeout { seconds: u64 },

    /// Content processing errors
    #[error("Content processing error: {details}")]
    ContentProcessing { details: String },

    /// Token limit exceeded
    #[error("Token limit exceeded: {used} > {limit}")]
    TokenLimitExceeded { used: usize, limit: usize },

    /// Model not available
    #[error("Model '{model}' not available")]
    ModelNotAvailable { model: String },

    /// Provider not found
    #[error("Provider '{provider}' not found")]
    ProviderNotFound { provider: String },

    /// Session management errors
    #[error("Session error: {message}")]
    Session { message: String },

    /// Generic errors
    #[error("LLM error: {message}")]
    Generic { message: String },
}

impl LLMError {
    /// Create a configuration error
    pub fn configuration(message: impl Into<String>) -> Self {
        Self::Configuration {
            message: message.into(),
        }
    }

    /// Create an authentication error
    pub fn authentication(details: impl Into<String>) -> Self {
        Self::Authentication {
            details: details.into(),
        }
    }

    /// Create a network error
    pub fn network(message: impl Into<String>) -> Self {
        Self::Network {
            message: message.into(),
        }
    }

    /// Create an invalid parameters error
    pub fn invalid_parameters(details: impl Into<String>) -> Self {
        Self::InvalidParameters {
            details: details.into(),
        }
    }

    /// Create a provider error
    pub fn provider_error(code: i32, message: impl Into<String>) -> Self {
        Self::ProviderError {
            code,
            message: message.into(),
        }
    }

    /// Create a timeout error
    pub fn timeout(seconds: u64) -> Self {
        Self::Timeout { seconds }
    }

    /// Create a content processing error
    pub fn content_processing(details: impl Into<String>) -> Self {
        Self::ContentProcessing {
            details: details.into(),
        }
    }

    /// Create a token limit error
    pub fn token_limit_exceeded(used: usize, limit: usize) -> Self {
        Self::TokenLimitExceeded { used, limit }
    }

    /// Create a model not available error
    pub fn model_not_available(model: impl Into<String>) -> Self {
        Self::ModelNotAvailable {
            model: model.into(),
        }
    }

    /// Create a provider not found error
    pub fn provider_not_found(provider: impl Into<String>) -> Self {
        Self::ProviderNotFound {
            provider: provider.into(),
        }
    }

    /// Create a session error
    pub fn session(message: impl Into<String>) -> Self {
        Self::Session {
            message: message.into(),
        }
    }

    /// Create a generic error
    pub fn generic(message: impl Into<String>) -> Self {
        Self::Generic {
            message: message.into(),
        }
    }
}

/// Convenience result alias used throughout the LLM system
pub type LLMResult<T> = Result<T, LLMError>;

// ============================================================================
// MESSAGE TYPES
// ============================================================================

/// Role of a message in a conversation
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// User message
    User,
    /// Assistant/AI response
    Assistant,
    /// System instruction
    System,
    /// Function/tool call
    Function,
}

/// Content type for messages
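///
/// Serialized with an internal `type` tag. A small sketch of the JSON shape
/// (assuming `MessageContent` is re-exported from `turbomcp_client::llm`):
///
/// ```rust,no_run
/// use turbomcp_client::llm::MessageContent;
///
/// let content = MessageContent::text("Hello");
/// // Serializes as: {"type":"text","text":"Hello"}
/// let json = serde_json::to_string(&content).unwrap();
/// assert!(json.contains("\"type\":\"text\""));
/// ```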
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum MessageContent {
    /// Plain text content
    #[serde(rename = "text")]
    Text { text: String },

    /// Image content
    #[serde(rename = "image")]
    Image { url: String, detail: Option<String> },

    /// Tool call content
    #[serde(rename = "tool_call")]
    ToolCall {
        id: String,
        function: String,
        arguments: serde_json::Value,
    },

    /// Tool result content
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_call_id: String,
        result: serde_json::Value,
        is_error: bool,
    },
}

impl MessageContent {
    /// Create text content
    pub fn text(text: impl Into<String>) -> Self {
        Self::Text { text: text.into() }
    }

    /// Create image content
    pub fn image(url: impl Into<String>, detail: Option<String>) -> Self {
        Self::Image {
            url: url.into(),
            detail,
        }
    }

    /// Extract text content if available
    pub fn as_text(&self) -> Option<&str> {
        match self {
            Self::Text { text } => Some(text),
            _ => None,
        }
    }

    /// Check if content is text
    pub fn is_text(&self) -> bool {
        matches!(self, Self::Text { .. })
    }

    /// Check if content is an image
    pub fn is_image(&self) -> bool {
        matches!(self, Self::Image { .. })
    }
}

/// A message in an LLM conversation
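///
/// A sketch of the constructor and metadata helpers defined below (paths assume
/// re-export from `turbomcp_client::llm`):
///
/// ```rust,no_run
/// use turbomcp_client::llm::{LLMMessage, MessageRole};
/// use serde_json::json;
///
/// let msg = LLMMessage::user("Hello, world!")
///     .with_metadata("session_id".to_string(), json!("session123"));
/// assert_eq!(msg.role, MessageRole::User);
/// assert_eq!(msg.content.as_text(), Some("Hello, world!"));
/// assert_eq!(msg.get_metadata("session_id"), Some(&json!("session123")));
/// ```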
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMMessage {
    /// Message role
    pub role: MessageRole,

    /// Message content
    pub content: MessageContent,

    /// Optional message metadata
    pub metadata: HashMap<String, serde_json::Value>,

    /// Timestamp
    pub timestamp: DateTime<Utc>,
}

impl LLMMessage {
    /// Create a user message
    pub fn user(content: impl Into<String>) -> Self {
        Self {
            role: MessageRole::User,
            content: MessageContent::text(content),
            metadata: HashMap::new(),
            timestamp: Utc::now(),
        }
    }

    /// Create an assistant message
    pub fn assistant(content: impl Into<String>) -> Self {
        Self {
            role: MessageRole::Assistant,
            content: MessageContent::text(content),
            metadata: HashMap::new(),
            timestamp: Utc::now(),
        }
    }

    /// Create a system message
    pub fn system(content: impl Into<String>) -> Self {
        Self {
            role: MessageRole::System,
            content: MessageContent::text(content),
            metadata: HashMap::new(),
            timestamp: Utc::now(),
        }
    }

    /// Add metadata to the message
    pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
        self.metadata.insert(key, value);
        self
    }

    /// Get metadata value
    pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
        self.metadata.get(key)
    }
}

// ============================================================================
// REQUEST AND RESPONSE TYPES
// ============================================================================

/// LLM generation parameters
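///
/// Every field is optional; anything left unset is simply not constrained and is
/// typically left to the provider. A construction sketch using struct-update
/// syntax:
///
/// ```rust,no_run
/// use turbomcp_client::llm::GenerationParams;
///
/// let params = GenerationParams {
///     temperature: Some(0.7),
///     max_tokens: Some(256),
///     ..Default::default()
/// };
/// assert!(params.top_p.is_none());
/// ```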
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct GenerationParams {
    /// Temperature (0.0 to 2.0)
    pub temperature: Option<f32>,

    /// Top-p sampling (0.0 to 1.0)
    pub top_p: Option<f32>,

    /// Top-k sampling
    pub top_k: Option<i32>,

    /// Maximum tokens to generate
    pub max_tokens: Option<i32>,

    /// Stop sequences
    pub stop_sequences: Option<Vec<String>>,

    /// Frequency penalty
    pub frequency_penalty: Option<f32>,

    /// Presence penalty
    pub presence_penalty: Option<f32>,

    /// Random seed for reproducibility
    pub seed: Option<i64>,
}

// Default implementation is now derived

/// Request to an LLM provider
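///
/// Built with [`LLMRequest::new`] plus the `with_*` builder methods. A sketch
/// (the model name is illustrative):
///
/// ```rust,no_run
/// use std::time::Duration;
/// use turbomcp_client::llm::{LLMMessage, LLMRequest};
///
/// let request = LLMRequest::new("gpt-4", vec![LLMMessage::user("Hi")])
///     .with_streaming(true)
///     .with_timeout(Duration::from_secs(30));
/// assert!(request.stream);
/// assert_eq!(request.get_user_message(), Some("Hi"));
/// ```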
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    /// Model to use for generation
    pub model: String,

    /// Conversation messages
    pub messages: Vec<LLMMessage>,

    /// Generation parameters
    pub params: GenerationParams,

    /// Enable streaming response
    pub stream: bool,

    /// Request metadata
    pub metadata: HashMap<String, serde_json::Value>,

    /// Timeout for the request
    pub timeout: Option<Duration>,
}

impl LLMRequest {
    /// Create a new LLM request
    pub fn new(model: impl Into<String>, messages: Vec<LLMMessage>) -> Self {
        Self {
            model: model.into(),
            messages,
            params: GenerationParams::default(),
            stream: false,
            metadata: HashMap::new(),
            timeout: None,
        }
    }

    /// Set generation parameters
    pub fn with_params(mut self, params: GenerationParams) -> Self {
        self.params = params;
        self
    }

    /// Enable streaming
    pub fn with_streaming(mut self, stream: bool) -> Self {
        self.stream = stream;
        self
    }

    /// Set timeout
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Add metadata
    pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
        self.metadata.insert(key, value);
        self
    }

    /// Get the first user message
    pub fn get_user_message(&self) -> Option<&str> {
        self.messages
            .iter()
            .find(|msg| msg.role == MessageRole::User)
            .and_then(|msg| msg.content.as_text())
    }

    /// Count total messages
    pub fn message_count(&self) -> usize {
        self.messages.len()
    }
}

/// Token usage information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Input tokens consumed
    pub prompt_tokens: usize,

    /// Output tokens generated
    pub completion_tokens: usize,

    /// Total tokens used
    pub total_tokens: usize,
}

impl TokenUsage {
    /// Create new token usage
    pub fn new(prompt_tokens: usize, completion_tokens: usize) -> Self {
        Self {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
        }
    }

    /// Create empty token usage
    pub fn empty() -> Self {
        Self::new(0, 0)
    }
}

/// Response from an LLM provider
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    /// Generated message
    pub message: LLMMessage,

    /// Model used for generation
    pub model: String,

    /// Token usage information
    pub usage: TokenUsage,

    /// Stop reason
    pub stop_reason: Option<String>,

    /// Response metadata
    pub metadata: HashMap<String, serde_json::Value>,

    /// Generation timestamp
    pub timestamp: DateTime<Utc>,
}

impl LLMResponse {
    /// Create a new LLM response
    pub fn new(message: LLMMessage, model: impl Into<String>, usage: TokenUsage) -> Self {
        Self {
            message,
            model: model.into(),
            usage,
            stop_reason: None,
            metadata: HashMap::new(),
            timestamp: Utc::now(),
        }
    }

    /// Set stop reason
    pub fn with_stop_reason(mut self, stop_reason: impl Into<String>) -> Self {
        self.stop_reason = Some(stop_reason.into());
        self
    }

    /// Add metadata
    pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
        self.metadata.insert(key, value);
        self
    }

    /// Get response text
    pub fn text(&self) -> Option<&str> {
        self.message.content.as_text()
    }

    /// Check if response is complete
    pub fn is_complete(&self) -> bool {
        self.stop_reason.is_some()
    }
}

// ============================================================================
// PROVIDER CONFIGURATION
// ============================================================================

/// Configuration for an LLM provider
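///
/// A construction-and-validation sketch (the key, URL, and model are placeholders):
///
/// ```rust,no_run
/// use turbomcp_client::llm::LLMProviderConfig;
///
/// let config = LLMProviderConfig::new("sk-example", "gpt-4")
///     .with_base_url("https://api.example.com/v1")
///     .with_timeout(60)
///     .with_max_retries(5);
/// assert!(config.validate().is_ok());
/// ```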
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMProviderConfig {
    /// API key or credentials
    pub api_key: String,

    /// Base URL for API requests
    pub base_url: Option<String>,

    /// Default model to use
    pub model: String,

    /// Request timeout in seconds
    pub timeout_seconds: u64,

    /// Maximum retry attempts
    pub max_retries: u32,

    /// Custom headers
    pub headers: HashMap<String, String>,

    /// Provider-specific options
    pub options: HashMap<String, serde_json::Value>,
}

impl Default for LLMProviderConfig {
    fn default() -> Self {
        Self {
            api_key: String::new(),
            base_url: None,
            model: String::new(),
            timeout_seconds: 30,
            max_retries: 3,
            headers: HashMap::new(),
            options: HashMap::new(),
        }
    }
}

impl LLMProviderConfig {
    /// Create a new provider config
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            ..Default::default()
        }
    }

    /// Set base URL
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = Some(base_url.into());
        self
    }

    /// Set timeout
    pub fn with_timeout(mut self, timeout_seconds: u64) -> Self {
        self.timeout_seconds = timeout_seconds;
        self
    }

    /// Set max retries
    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Add custom header
    pub fn with_header(mut self, key: String, value: String) -> Self {
        self.headers.insert(key, value);
        self
    }

    /// Add custom option
    pub fn with_option(mut self, key: String, value: serde_json::Value) -> Self {
        self.options.insert(key, value);
        self
    }

    /// Validate the configuration
    pub fn validate(&self) -> LLMResult<()> {
        if self.api_key.trim().is_empty() {
            return Err(LLMError::configuration("API key cannot be empty"));
        }

        if self.model.trim().is_empty() {
            return Err(LLMError::configuration("Model cannot be empty"));
        }

        if self.timeout_seconds == 0 {
            return Err(LLMError::configuration("Timeout must be greater than 0"));
        }

        Ok(())
    }
}

// ============================================================================
// PROVIDER CAPABILITIES AND MODEL INFO
// ============================================================================

/// Capabilities of an LLM provider
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMCapabilities {
    /// Supports streaming responses
    pub streaming: bool,

    /// Supports image inputs
    pub vision: bool,

    /// Supports function/tool calling
    pub function_calling: bool,

    /// Supports JSON mode
    pub json_mode: bool,

    /// Maximum context window size
    pub max_context_tokens: Option<usize>,

    /// Maximum output tokens
    pub max_output_tokens: Option<usize>,

    /// Supported content types
    pub content_types: Vec<String>,
}

impl Default for LLMCapabilities {
    fn default() -> Self {
        Self {
            streaming: false,
            vision: false,
            function_calling: false,
            json_mode: false,
            max_context_tokens: None,
            max_output_tokens: None,
            content_types: vec!["text".to_string()],
        }
    }
}

/// Information about a specific model
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model name/identifier
    pub name: String,

    /// Human-readable display name
    pub display_name: String,

    /// Model description
    pub description: Option<String>,

    /// Model capabilities
    pub capabilities: LLMCapabilities,

    /// Model version
    pub version: Option<String>,

    /// Model pricing (cost per 1K tokens)
    pub pricing: Option<ModelPricing>,

    /// Whether the model is available
    pub available: bool,
}

/// Pricing information for a model
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricing {
    /// Input token cost (per 1K tokens)
    pub input_cost_per_1k: Option<f64>,

    /// Output token cost (per 1K tokens)
    pub output_cost_per_1k: Option<f64>,
}

// ============================================================================
// CORE PROVIDER TRAIT
// ============================================================================

/// Core trait for LLM providers
///
/// Implement this trait to add support for new LLM providers. The trait provides
/// a standardized interface for generating text, managing models, and handling
/// provider-specific functionality.
///
/// # Examples
///
/// ```rust,no_run
/// use turbomcp_client::llm::{LLMProvider, LLMRequest, LLMResponse, LLMResult, ModelInfo, LLMCapabilities};
/// use async_trait::async_trait;
///
/// #[derive(Debug)]
/// struct CustomProvider;
///
/// #[async_trait]
/// impl LLMProvider for CustomProvider {
///     fn name(&self) -> &str {
///         "custom"
///     }
///
///     async fn generate(&self, request: &LLMRequest) -> LLMResult<LLMResponse> {
///         // Implementation here
///         todo!()
///     }
///
///     async fn list_models(&self) -> LLMResult<Vec<ModelInfo>> {
///         // Return available models
///         todo!()
///     }
///
///     fn capabilities(&self) -> &LLMCapabilities {
///         // Return provider capabilities
///         todo!()
///     }
/// }
/// ```
#[async_trait]
pub trait LLMProvider: Send + Sync + std::fmt::Debug {
    /// Provider name (e.g., "openai", "anthropic")
    fn name(&self) -> &str;

    /// Provider version
    fn version(&self) -> &str {
        "1.0.0"
    }

    /// Generate a response for the given request
    ///
    /// This is the core method that handles text generation. Implementations
    /// should convert the request to the provider's format, make the API call,
    /// and return a standardized response.
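    ///
    /// A caller-side sketch working against the trait object (the model name is
    /// illustrative):
    ///
    /// ```rust,no_run
    /// use turbomcp_client::llm::{LLMMessage, LLMProvider, LLMRequest, LLMResult};
    ///
    /// async fn ask(provider: &dyn LLMProvider) -> LLMResult<String> {
    ///     let request = LLMRequest::new("gpt-4", vec![LLMMessage::user("Hello")]);
    ///     provider.validate_request(&request).await?;
    ///     let response = provider.generate(&request).await?;
    ///     Ok(response.text().unwrap_or_default().to_string())
    /// }
    /// ```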
    async fn generate(&self, request: &LLMRequest) -> LLMResult<LLMResponse>;

    /// List available models for this provider
    async fn list_models(&self) -> LLMResult<Vec<ModelInfo>>;

    /// Get provider capabilities
    fn capabilities(&self) -> &LLMCapabilities;

    /// Get information about a specific model
    async fn get_model_info(&self, model: &str) -> LLMResult<ModelInfo> {
        let models = self.list_models().await?;
        models
            .into_iter()
            .find(|m| m.name == model)
            .ok_or_else(|| LLMError::model_not_available(model))
    }

    /// Check if a model is supported
    async fn supports_model(&self, model: &str) -> bool {
        self.get_model_info(model).await.is_ok()
    }

    /// Estimate token count for text (optional override)
    fn estimate_tokens(&self, text: &str) -> usize {
        // Simple estimation: ~1 token per 4 characters
        text.len().div_ceil(4)
    }

    /// Validate a request before sending
    async fn validate_request(&self, request: &LLMRequest) -> LLMResult<()> {
        // Check if model is supported
        if !self.supports_model(&request.model).await {
            return Err(LLMError::model_not_available(&request.model));
        }

        // Validate messages
        if request.messages.is_empty() {
            return Err(LLMError::invalid_parameters(
                "At least one message is required",
            ));
        }

        // Check token limits if available
        let model_info = self.get_model_info(&request.model).await?;
        if let Some(max_tokens) = model_info.capabilities.max_context_tokens {
            let estimated_tokens: usize = request
                .messages
                .iter()
                .filter_map(|msg| msg.content.as_text())
                .map(|text| self.estimate_tokens(text))
                .sum();

            if estimated_tokens > max_tokens {
                return Err(LLMError::token_limit_exceeded(estimated_tokens, max_tokens));
            }
        }

        Ok(())
    }

    /// Health check for the provider
    async fn health_check(&self) -> LLMResult<()> {
        // Try to list models as a basic health check
        self.list_models().await?;
        Ok(())
    }

    /// Handle MCP CreateMessageRequest (adapts to LLM types)
    ///
    /// This method provides a bridge between MCP protocol types and the LLM system.
    /// It converts CreateMessageRequest to LLMRequest, calls generate(), and converts
    /// the response back to CreateMessageResult.
    async fn handle_create_message(
        &self,
        request: turbomcp_protocol::types::CreateMessageRequest,
    ) -> LLMResult<turbomcp_protocol::types::CreateMessageResult> {
        use turbomcp_protocol::types::{Content, CreateMessageResult, Role, TextContent};

        // Convert MCP messages to LLM messages
        let llm_messages: Vec<LLMMessage> = request
            .messages
            .iter()
            .map(|msg| {
                let role = match msg.role {
                    Role::User => MessageRole::User,
                    Role::Assistant => MessageRole::Assistant,
                };

                let content = match &msg.content {
                    Content::Text(text) => MessageContent::text(&text.text),
                    _ => MessageContent::text("Non-text content not yet supported"),
                };

                LLMMessage {
                    role,
                    content,
                    metadata: std::collections::HashMap::new(),
                    timestamp: chrono::Utc::now(),
                }
            })
            .collect();

        // Add system message if provided
        let mut all_messages = Vec::new();
        if let Some(system_prompt) = &request.system_prompt {
            all_messages.push(LLMMessage::system(system_prompt));
        }
        all_messages.extend(llm_messages);

        // Build generation parameters
        let params = GenerationParams {
            max_tokens: Some(request.max_tokens as i32),
            temperature: request.temperature.map(|t| t as f32),
            stop_sequences: request.stop_sequences.clone(),
            ..Default::default()
        };

        // Determine model to use
        let model = if let Some(prefs) = &request.model_preferences
            && let Some(hints) = &prefs.hints
            && let Some(model_hint) = hints.first()
        {
            model_hint.name.clone()
        } else {
            // Use first available model
            let models = self.list_models().await.unwrap_or_default();
            models
                .first()
                .map(|m| m.name.clone())
                .unwrap_or_else(|| "default".to_string())
        };

        // Create LLM request
        let llm_request = LLMRequest::new(model, all_messages).with_params(params);

        // Generate response
        let llm_response = self.generate(&llm_request).await?;

        // Convert back to MCP format
        let text = llm_response.text().unwrap_or("").to_string();
        let result = CreateMessageResult {
            role: Role::Assistant,
            content: Content::Text(TextContent {
                text,
                annotations: None,
                meta: None,
            }),
            model: Some(llm_response.model),
            stop_reason: llm_response.stop_reason,
            _meta: None,
        };

        Ok(result)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_llm_error_creation() {
        let config_error = LLMError::configuration("Test error");
        assert!(config_error.to_string().contains("Configuration error"));

        let auth_error = LLMError::authentication("Invalid key");
        assert!(auth_error.to_string().contains("Authentication failed"));

        let token_error = LLMError::token_limit_exceeded(1000, 800);
        assert!(token_error.to_string().contains("1000 > 800"));
    }

    #[test]
    fn test_message_creation() {
        let user_msg = LLMMessage::user("Hello, world!");
        assert_eq!(user_msg.role, MessageRole::User);
        assert_eq!(user_msg.content.as_text(), Some("Hello, world!"));

        let assistant_msg = LLMMessage::assistant("Hi there!");
        assert_eq!(assistant_msg.role, MessageRole::Assistant);

        let system_msg = LLMMessage::system("You are a helpful assistant");
        assert_eq!(system_msg.role, MessageRole::System);
    }

    #[test]
    fn test_message_content() {
        let text_content = MessageContent::text("Hello");
        assert!(text_content.is_text());
        assert!(!text_content.is_image());
        assert_eq!(text_content.as_text(), Some("Hello"));

        let image_content = MessageContent::image("https://example.com/image.jpg", None);
        assert!(!image_content.is_text());
        assert!(image_content.is_image());
        assert_eq!(image_content.as_text(), None);
    }

    #[test]
    fn test_generation_params() {
        let params = GenerationParams {
            temperature: Some(0.7),
            max_tokens: Some(100),
            ..Default::default()
        };

        assert_eq!(params.temperature, Some(0.7));
        assert_eq!(params.max_tokens, Some(100));
        assert_eq!(params.top_p, None);
    }

    #[test]
    fn test_llm_request() {
        let messages = vec![
            LLMMessage::user("What's 2+2?"),
            LLMMessage::assistant("2+2 equals 4."),
        ];

        let request = LLMRequest::new("gpt-4", messages.clone())
            .with_streaming(true)
            .with_metadata("session_id".to_string(), json!("session123"));

        assert_eq!(request.model, "gpt-4");
        assert_eq!(request.messages.len(), 2);
        assert!(request.stream);
        assert_eq!(request.get_user_message(), Some("What's 2+2?"));
        assert_eq!(request.message_count(), 2);
    }

    #[test]
    fn test_token_usage() {
        let usage = TokenUsage::new(100, 50);
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);

        let empty_usage = TokenUsage::empty();
        assert_eq!(empty_usage.total_tokens, 0);
    }

    #[test]
    fn test_llm_response() {
        let message = LLMMessage::assistant("The answer is 4.");
        let usage = TokenUsage::new(20, 10);

        let response = LLMResponse::new(message, "gpt-4", usage)
            .with_stop_reason("complete")
            .with_metadata("finish_reason".to_string(), json!("stop"));

        assert_eq!(response.model, "gpt-4");
        assert_eq!(response.text(), Some("The answer is 4."));
        assert_eq!(response.stop_reason, Some("complete".to_string()));
        assert!(response.is_complete());
        assert_eq!(response.usage.total_tokens, 30);
    }

    #[test]
    fn test_provider_config() {
        let config = LLMProviderConfig::new("test-key", "gpt-4")
            .with_base_url("https://custom.api.com")
            .with_timeout(60)
            .with_max_retries(5)
            .with_header("Custom-Header".to_string(), "value".to_string())
            .with_option("custom_option".to_string(), json!(true));

        assert_eq!(config.api_key, "test-key");
        assert_eq!(config.model, "gpt-4");
        assert_eq!(config.base_url, Some("https://custom.api.com".to_string()));
        assert_eq!(config.timeout_seconds, 60);
        assert_eq!(config.max_retries, 5);
        assert_eq!(
            config.headers.get("Custom-Header"),
            Some(&"value".to_string())
        );

        assert!(config.validate().is_ok());

        let invalid_config = LLMProviderConfig {
            api_key: "".to_string(),
            ..config
        };
        assert!(invalid_config.validate().is_err());
    }

    #[test]
    fn test_capabilities() {
        let mut capabilities = LLMCapabilities::default();
        assert!(!capabilities.streaming);
        assert!(!capabilities.vision);
        assert!(!capabilities.function_calling);
        assert_eq!(capabilities.content_types, vec!["text".to_string()]);

        capabilities.streaming = true;
        capabilities.vision = true;
        capabilities.max_context_tokens = Some(128000);

        assert!(capabilities.streaming);
        assert!(capabilities.vision);
        assert_eq!(capabilities.max_context_tokens, Some(128000));
    }

    #[test]
    fn test_model_info() {
        let capabilities = LLMCapabilities {
            streaming: true,
            vision: false,
            max_context_tokens: Some(128000),
            ..Default::default()
        };

        let model = ModelInfo {
            name: "gpt-4".to_string(),
            display_name: "GPT-4".to_string(),
            description: Some("Advanced language model".to_string()),
            capabilities,
            version: Some("2024-01".to_string()),
            pricing: None,
            available: true,
        };

        assert_eq!(model.name, "gpt-4");
        assert_eq!(model.display_name, "GPT-4");
        assert!(model.available);
        assert!(model.capabilities.streaming);
    }
}