use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum LLMError {
    #[error("Configuration error: {message}")]
    Configuration { message: String },

    #[error("Authentication failed: {details}")]
    Authentication { details: String },

    #[error("Network error: {message}")]
    Network { message: String },

    #[error("Rate limited. Retry after: {retry_after_seconds}s")]
    RateLimit { retry_after_seconds: u64 },

    #[error("Invalid parameters: {details}")]
    InvalidParameters { details: String },

    #[error("Provider error [{code}]: {message}")]
    ProviderError { code: i32, message: String },

    #[error("Request timed out after {seconds}s")]
    Timeout { seconds: u64 },

    #[error("Content processing error: {details}")]
    ContentProcessing { details: String },

    #[error("Token limit exceeded: {used} > {limit}")]
    TokenLimitExceeded { used: usize, limit: usize },

    #[error("Model '{model}' not available")]
    ModelNotAvailable { model: String },

    #[error("Provider '{provider}' not found")]
    ProviderNotFound { provider: String },

    #[error("Session error: {message}")]
    Session { message: String },

    #[error("LLM error: {message}")]
    Generic { message: String },
}

impl LLMError {
    pub fn configuration(message: impl Into<String>) -> Self {
        Self::Configuration {
            message: message.into(),
        }
    }

    pub fn authentication(details: impl Into<String>) -> Self {
        Self::Authentication {
            details: details.into(),
        }
    }

    pub fn network(message: impl Into<String>) -> Self {
        Self::Network {
            message: message.into(),
        }
    }

    pub fn invalid_parameters(details: impl Into<String>) -> Self {
        Self::InvalidParameters {
            details: details.into(),
        }
    }

    pub fn provider_error(code: i32, message: impl Into<String>) -> Self {
        Self::ProviderError {
            code,
            message: message.into(),
        }
    }

    pub fn timeout(seconds: u64) -> Self {
        Self::Timeout { seconds }
    }

    pub fn content_processing(details: impl Into<String>) -> Self {
        Self::ContentProcessing {
            details: details.into(),
        }
    }

    pub fn token_limit_exceeded(used: usize, limit: usize) -> Self {
        Self::TokenLimitExceeded { used, limit }
    }

    pub fn model_not_available(model: impl Into<String>) -> Self {
        Self::ModelNotAvailable {
            model: model.into(),
        }
    }

    pub fn provider_not_found(provider: impl Into<String>) -> Self {
        Self::ProviderNotFound {
            provider: provider.into(),
        }
    }

    pub fn session(message: impl Into<String>) -> Self {
        Self::Session {
            message: message.into(),
        }
    }

    pub fn generic(message: impl Into<String>) -> Self {
        Self::Generic {
            message: message.into(),
        }
    }
}

pub type LLMResult<T> = Result<T, LLMError>;
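
// Illustrative sketch (not part of the original API): callers will often want
// to branch on `LLMError` variants to decide whether a failed request is worth
// retrying. `should_retry` is a hypothetical helper demonstrating that.
#[allow(dead_code)]
fn should_retry(error: &LLMError) -> bool {
    matches!(
        error,
        LLMError::Network { .. } | LLMError::RateLimit { .. } | LLMError::Timeout { .. }
    )
}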

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    User,
    Assistant,
    System,
    Function,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum MessageContent {
    #[serde(rename = "text")]
    Text { text: String },

    #[serde(rename = "image")]
    Image { url: String, detail: Option<String> },

    #[serde(rename = "tool_call")]
    ToolCall {
        id: String,
        function: String,
        arguments: serde_json::Value,
    },

    #[serde(rename = "tool_result")]
    ToolResult {
        tool_call_id: String,
        result: serde_json::Value,
        is_error: bool,
    },
}

impl MessageContent {
    pub fn text(text: impl Into<String>) -> Self {
        Self::Text { text: text.into() }
    }

    pub fn image(url: impl Into<String>, detail: Option<String>) -> Self {
        Self::Image {
            url: url.into(),
            detail,
        }
    }

    pub fn as_text(&self) -> Option<&str> {
        match self {
            Self::Text { text } => Some(text),
            _ => None,
        }
    }

    pub fn is_text(&self) -> bool {
        matches!(self, Self::Text { .. })
    }

    pub fn is_image(&self) -> bool {
        matches!(self, Self::Image { .. })
    }
}
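
// Illustrative sketch: `MessageContent` uses an internally tagged serde
// representation, so every variant serializes with a "type" discriminator.
// `example_content_json` is a hypothetical helper showing the resulting shape.
#[allow(dead_code)]
fn example_content_json() -> serde_json::Result<String> {
    // Produces: {"type":"text","text":"Hello"}
    serde_json::to_string(&MessageContent::text("Hello"))
}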

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMMessage {
    pub role: MessageRole,

    pub content: MessageContent,

    pub metadata: HashMap<String, serde_json::Value>,

    pub timestamp: DateTime<Utc>,
}

impl LLMMessage {
    pub fn user(content: impl Into<String>) -> Self {
        Self {
            role: MessageRole::User,
            content: MessageContent::text(content),
            metadata: HashMap::new(),
            timestamp: Utc::now(),
        }
    }

    pub fn assistant(content: impl Into<String>) -> Self {
        Self {
            role: MessageRole::Assistant,
            content: MessageContent::text(content),
            metadata: HashMap::new(),
            timestamp: Utc::now(),
        }
    }

    pub fn system(content: impl Into<String>) -> Self {
        Self {
            role: MessageRole::System,
            content: MessageContent::text(content),
            metadata: HashMap::new(),
            timestamp: Utc::now(),
        }
    }

    pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
        self.metadata.insert(key, value);
        self
    }

    pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
        self.metadata.get(key)
    }
}
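
// Illustrative sketch (hypothetical helper): assembling a short conversation
// from the role constructors and attaching metadata. The "trace_id" key is an
// assumption for demonstration, not a key this module defines.
#[allow(dead_code)]
fn example_conversation() -> Vec<LLMMessage> {
    vec![
        LLMMessage::system("You are a concise assistant."),
        LLMMessage::user("What is Rust?")
            .with_metadata("trace_id".to_string(), serde_json::json!("abc-123")),
    ]
}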

#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct GenerationParams {
    pub temperature: Option<f32>,

    pub top_p: Option<f32>,

    pub top_k: Option<i32>,

    pub max_tokens: Option<i32>,

    pub stop_sequences: Option<Vec<String>>,

    pub frequency_penalty: Option<f32>,

    pub presence_penalty: Option<f32>,

    pub seed: Option<i64>,
}
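
// Illustrative sketch: every field is optional, so struct-update syntax over
// `Default` keeps call sites terse. The values below are assumptions chosen
// for a deterministic, low-temperature completion.
#[allow(dead_code)]
fn example_params() -> GenerationParams {
    GenerationParams {
        temperature: Some(0.2),
        max_tokens: Some(256),
        seed: Some(42),
        ..Default::default()
    }
}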

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
    pub model: String,

    pub messages: Vec<LLMMessage>,

    pub params: GenerationParams,

    pub stream: bool,

    pub metadata: HashMap<String, serde_json::Value>,

    pub timeout: Option<Duration>,
}

impl LLMRequest {
    pub fn new(model: impl Into<String>, messages: Vec<LLMMessage>) -> Self {
        Self {
            model: model.into(),
            messages,
            params: GenerationParams::default(),
            stream: false,
            metadata: HashMap::new(),
            timeout: None,
        }
    }

    pub fn with_params(mut self, params: GenerationParams) -> Self {
        self.params = params;
        self
    }

    pub fn with_streaming(mut self, stream: bool) -> Self {
        self.stream = stream;
        self
    }

    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
        self.metadata.insert(key, value);
        self
    }

    pub fn get_user_message(&self) -> Option<&str> {
        self.messages
            .iter()
            .find(|msg| msg.role == MessageRole::User)
            .and_then(|msg| msg.content.as_text())
    }

    pub fn message_count(&self) -> usize {
        self.messages.len()
    }
}
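
// Illustrative sketch: the builder methods chain into a complete request. The
// model name and timeout are assumptions for demonstration; `example_params`
// is the hypothetical helper defined above.
#[allow(dead_code)]
fn example_request() -> LLMRequest {
    LLMRequest::new("gpt-4", vec![LLMMessage::user("Summarize this text.")])
        .with_params(example_params())
        .with_timeout(Duration::from_secs(30))
}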

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
    pub prompt_tokens: usize,

    pub completion_tokens: usize,

    pub total_tokens: usize,
}

impl TokenUsage {
    pub fn new(prompt_tokens: usize, completion_tokens: usize) -> Self {
        Self {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
        }
    }

    pub fn empty() -> Self {
        Self::new(0, 0)
    }
}
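
// Illustrative sketch (hypothetical helper): summing usage across several
// responses, e.g. to meter a whole session. Folding through `TokenUsage::new`
// recomputes `total_tokens`, so the invariant is preserved.
#[allow(dead_code)]
fn total_usage(responses: &[LLMResponse]) -> TokenUsage {
    responses.iter().fold(TokenUsage::empty(), |acc, r| {
        TokenUsage::new(
            acc.prompt_tokens + r.usage.prompt_tokens,
            acc.completion_tokens + r.usage.completion_tokens,
        )
    })
}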

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    pub message: LLMMessage,

    pub model: String,

    pub usage: TokenUsage,

    pub stop_reason: Option<String>,

    pub metadata: HashMap<String, serde_json::Value>,

    pub timestamp: DateTime<Utc>,
}

impl LLMResponse {
    pub fn new(message: LLMMessage, model: impl Into<String>, usage: TokenUsage) -> Self {
        Self {
            message,
            model: model.into(),
            usage,
            stop_reason: None,
            metadata: HashMap::new(),
            timestamp: Utc::now(),
        }
    }

    pub fn with_stop_reason(mut self, stop_reason: impl Into<String>) -> Self {
        self.stop_reason = Some(stop_reason.into());
        self
    }

    pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
        self.metadata.insert(key, value);
        self
    }

    pub fn text(&self) -> Option<&str> {
        self.message.content.as_text()
    }

    pub fn is_complete(&self) -> bool {
        self.stop_reason.is_some()
    }
}
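
// Illustrative sketch (hypothetical helper): extracting response text with a
// typed error instead of silently falling back to an empty string.
#[allow(dead_code)]
fn response_text(response: &LLMResponse) -> LLMResult<&str> {
    response
        .text()
        .ok_or_else(|| LLMError::content_processing("response contained no text content"))
}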

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMProviderConfig {
    pub api_key: String,

    pub base_url: Option<String>,

    pub model: String,

    pub timeout_seconds: u64,

    pub max_retries: u32,

    pub headers: HashMap<String, String>,

    pub options: HashMap<String, serde_json::Value>,
}

impl Default for LLMProviderConfig {
    fn default() -> Self {
        Self {
            api_key: String::new(),
            base_url: None,
            model: String::new(),
            timeout_seconds: 30,
            max_retries: 3,
            headers: HashMap::new(),
            options: HashMap::new(),
        }
    }
}

impl LLMProviderConfig {
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            ..Default::default()
        }
    }

    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = Some(base_url.into());
        self
    }

    pub fn with_timeout(mut self, timeout_seconds: u64) -> Self {
        self.timeout_seconds = timeout_seconds;
        self
    }

    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = max_retries;
        self
    }

    pub fn with_header(mut self, key: String, value: String) -> Self {
        self.headers.insert(key, value);
        self
    }

    pub fn with_option(mut self, key: String, value: serde_json::Value) -> Self {
        self.options.insert(key, value);
        self
    }

    pub fn validate(&self) -> LLMResult<()> {
        if self.api_key.trim().is_empty() {
            return Err(LLMError::configuration("API key cannot be empty"));
        }

        if self.model.trim().is_empty() {
            return Err(LLMError::configuration("Model cannot be empty"));
        }

        if self.timeout_seconds == 0 {
            return Err(LLMError::configuration("Timeout must be greater than 0"));
        }

        Ok(())
    }
}
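
// Illustrative sketch: building and validating a config before handing it to a
// provider. The key, model, and URL are placeholder assumptions; real values
// would typically come from the environment or a config file.
#[allow(dead_code)]
fn example_config() -> LLMResult<LLMProviderConfig> {
    let config = LLMProviderConfig::new("sk-example", "gpt-4")
        .with_base_url("https://api.example.com/v1")
        .with_timeout(60)
        .with_max_retries(5);
    config.validate()?;
    Ok(config)
}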

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMCapabilities {
    pub streaming: bool,

    pub vision: bool,

    pub function_calling: bool,

    pub json_mode: bool,

    pub max_context_tokens: Option<usize>,

    pub max_output_tokens: Option<usize>,

    pub content_types: Vec<String>,
}

impl Default for LLMCapabilities {
    fn default() -> Self {
        Self {
            streaming: false,
            vision: false,
            function_calling: false,
            json_mode: false,
            max_context_tokens: None,
            max_output_tokens: None,
            content_types: vec!["text".to_string()],
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    pub name: String,

    pub display_name: String,

    pub description: Option<String>,

    pub capabilities: LLMCapabilities,

    pub version: Option<String>,

    pub pricing: Option<ModelPricing>,

    pub available: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricing {
    pub input_cost_per_1k: Option<f64>,

    pub output_cost_per_1k: Option<f64>,
}
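
// Illustrative sketch (hypothetical helper): turning the per-1k-token pricing
// fields into a cost estimate for a completed request, when pricing is known.
#[allow(dead_code)]
fn estimate_cost(usage: &TokenUsage, pricing: &ModelPricing) -> Option<f64> {
    let input = pricing.input_cost_per_1k? * (usage.prompt_tokens as f64 / 1000.0);
    let output = pricing.output_cost_per_1k? * (usage.completion_tokens as f64 / 1000.0);
    Some(input + output)
}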

#[async_trait]
pub trait LLMProvider: Send + Sync + std::fmt::Debug {
    fn name(&self) -> &str;

    fn version(&self) -> &str {
        "1.0.0"
    }

    async fn generate(&self, request: &LLMRequest) -> LLMResult<LLMResponse>;

    async fn list_models(&self) -> LLMResult<Vec<ModelInfo>>;

    fn capabilities(&self) -> &LLMCapabilities;

    async fn get_model_info(&self, model: &str) -> LLMResult<ModelInfo> {
        let models = self.list_models().await?;
        models
            .into_iter()
            .find(|m| m.name == model)
            .ok_or_else(|| LLMError::model_not_available(model))
    }

    async fn supports_model(&self, model: &str) -> bool {
        self.get_model_info(model).await.is_ok()
    }

    fn estimate_tokens(&self, text: &str) -> usize {
        text.len().div_ceil(4)
    }

    async fn validate_request(&self, request: &LLMRequest) -> LLMResult<()> {
        if !self.supports_model(&request.model).await {
            return Err(LLMError::model_not_available(&request.model));
        }

        if request.messages.is_empty() {
            return Err(LLMError::invalid_parameters(
                "At least one message is required",
            ));
        }

        let model_info = self.get_model_info(&request.model).await?;
        if let Some(max_tokens) = model_info.capabilities.max_context_tokens {
            let estimated_tokens: usize = request
                .messages
                .iter()
                .filter_map(|msg| msg.content.as_text())
                .map(|text| self.estimate_tokens(text))
                .sum();

            if estimated_tokens > max_tokens {
                return Err(LLMError::token_limit_exceeded(estimated_tokens, max_tokens));
            }
        }

        Ok(())
    }

    async fn health_check(&self) -> LLMResult<()> {
        self.list_models().await?;
        Ok(())
    }

    async fn handle_create_message(
        &self,
        request: turbomcp_protocol::types::CreateMessageRequest,
    ) -> LLMResult<turbomcp_protocol::types::CreateMessageResult> {
        use turbomcp_protocol::types::{Content, CreateMessageResult, Role, TextContent};

        let llm_messages: Vec<LLMMessage> = request
            .messages
            .iter()
            .map(|msg| {
                let role = match msg.role {
                    Role::User => MessageRole::User,
                    Role::Assistant => MessageRole::Assistant,
                };

                let content = match &msg.content {
                    Content::Text(text) => MessageContent::text(&text.text),
                    _ => MessageContent::text("Non-text content not yet supported"),
                };

                LLMMessage {
                    role,
                    content,
                    metadata: std::collections::HashMap::new(),
                    timestamp: chrono::Utc::now(),
                }
            })
            .collect();

        let mut all_messages = Vec::new();
        if let Some(system_prompt) = &request.system_prompt {
            all_messages.push(LLMMessage::system(system_prompt));
        }
        all_messages.extend(llm_messages);

        let params = GenerationParams {
            max_tokens: Some(request.max_tokens as i32),
            temperature: request.temperature.map(|t| t as f32),
            stop_sequences: request.stop_sequences.clone(),
            ..Default::default()
        };

        let model = if let Some(prefs) = &request.model_preferences
            && let Some(hints) = &prefs.hints
            && let Some(model_hint) = hints.first()
        {
            model_hint.name.clone()
        } else {
            let models = self.list_models().await.unwrap_or_default();
            models
                .first()
                .map(|m| m.name.clone())
                .unwrap_or_else(|| "default".to_string())
        };

        let llm_request = LLMRequest::new(model, all_messages).with_params(params);

        let llm_response = self.generate(&llm_request).await?;

        let text = llm_response.text().unwrap_or("").to_string();
        let result = CreateMessageResult {
            role: Role::Assistant,
            content: Content::Text(TextContent {
                text,
                annotations: None,
                meta: None,
            }),
            model: Some(llm_response.model),
            stop_reason: llm_response.stop_reason,
            _meta: None,
        };

        Ok(result)
    }
}
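
// A minimal illustrative implementor (hypothetical; not part of the original
// API) showing the surface a concrete `LLMProvider` must supply: `name`,
// `generate`, `list_models`, and `capabilities`. It echoes the first user
// message back, so the sketch stays self-contained with no network I/O, and it
// relies on the trait's default methods (e.g. `estimate_tokens`) for the rest.
#[allow(dead_code)]
#[derive(Debug)]
struct EchoProvider {
    capabilities: LLMCapabilities,
}

#[async_trait]
impl LLMProvider for EchoProvider {
    fn name(&self) -> &str {
        "echo"
    }

    async fn generate(&self, request: &LLMRequest) -> LLMResult<LLMResponse> {
        let text = request.get_user_message().unwrap_or("").to_string();
        // Reuse the trait's default heuristic for both sides of the usage count.
        let usage = TokenUsage::new(self.estimate_tokens(&text), self.estimate_tokens(&text));
        Ok(
            LLMResponse::new(LLMMessage::assistant(text), request.model.as_str(), usage)
                .with_stop_reason("end_turn"),
        )
    }

    async fn list_models(&self) -> LLMResult<Vec<ModelInfo>> {
        Ok(vec![ModelInfo {
            name: "echo-1".to_string(),
            display_name: "Echo 1".to_string(),
            description: None,
            capabilities: self.capabilities.clone(),
            version: None,
            pricing: None,
            available: true,
        }])
    }

    fn capabilities(&self) -> &LLMCapabilities {
        &self.capabilities
    }
}
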
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_llm_error_creation() {
        let config_error = LLMError::configuration("Test error");
        assert!(config_error.to_string().contains("Configuration error"));

        let auth_error = LLMError::authentication("Invalid key");
        assert!(auth_error.to_string().contains("Authentication failed"));

        let token_error = LLMError::token_limit_exceeded(1000, 800);
        assert!(token_error.to_string().contains("1000 > 800"));
    }

    #[test]
    fn test_message_creation() {
        let user_msg = LLMMessage::user("Hello, world!");
        assert_eq!(user_msg.role, MessageRole::User);
        assert_eq!(user_msg.content.as_text(), Some("Hello, world!"));

        let assistant_msg = LLMMessage::assistant("Hi there!");
        assert_eq!(assistant_msg.role, MessageRole::Assistant);

        let system_msg = LLMMessage::system("You are a helpful assistant");
        assert_eq!(system_msg.role, MessageRole::System);
    }

    #[test]
    fn test_message_content() {
        let text_content = MessageContent::text("Hello");
        assert!(text_content.is_text());
        assert!(!text_content.is_image());
        assert_eq!(text_content.as_text(), Some("Hello"));

        let image_content = MessageContent::image("https://example.com/image.jpg", None);
        assert!(!image_content.is_text());
        assert!(image_content.is_image());
        assert_eq!(image_content.as_text(), None);
    }

    #[test]
    fn test_generation_params() {
        let params = GenerationParams {
            temperature: Some(0.7),
            max_tokens: Some(100),
            ..Default::default()
        };

        assert_eq!(params.temperature, Some(0.7));
        assert_eq!(params.max_tokens, Some(100));
        assert_eq!(params.top_p, None);
    }

    #[test]
    fn test_llm_request() {
        let messages = vec![
            LLMMessage::user("What's 2+2?"),
            LLMMessage::assistant("2+2 equals 4."),
        ];

        let request = LLMRequest::new("gpt-4", messages.clone())
            .with_streaming(true)
            .with_metadata("session_id".to_string(), json!("session123"));

        assert_eq!(request.model, "gpt-4");
        assert_eq!(request.messages.len(), 2);
        assert!(request.stream);
        assert_eq!(request.get_user_message(), Some("What's 2+2?"));
        assert_eq!(request.message_count(), 2);
    }

    #[test]
    fn test_token_usage() {
        let usage = TokenUsage::new(100, 50);
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);

        let empty_usage = TokenUsage::empty();
        assert_eq!(empty_usage.total_tokens, 0);
    }

    #[test]
    fn test_llm_response() {
        let message = LLMMessage::assistant("The answer is 4.");
        let usage = TokenUsage::new(20, 10);

        let response = LLMResponse::new(message, "gpt-4", usage)
            .with_stop_reason("complete")
            .with_metadata("finish_reason".to_string(), json!("stop"));

        assert_eq!(response.model, "gpt-4");
        assert_eq!(response.text(), Some("The answer is 4."));
        assert_eq!(response.stop_reason, Some("complete".to_string()));
        assert!(response.is_complete());
        assert_eq!(response.usage.total_tokens, 30);
    }

    #[test]
    fn test_provider_config() {
        let config = LLMProviderConfig::new("test-key", "gpt-4")
            .with_base_url("https://custom.api.com")
            .with_timeout(60)
            .with_max_retries(5)
            .with_header("Custom-Header".to_string(), "value".to_string())
            .with_option("custom_option".to_string(), json!(true));

        assert_eq!(config.api_key, "test-key");
        assert_eq!(config.model, "gpt-4");
        assert_eq!(config.base_url, Some("https://custom.api.com".to_string()));
        assert_eq!(config.timeout_seconds, 60);
        assert_eq!(config.max_retries, 5);
        assert_eq!(
            config.headers.get("Custom-Header"),
            Some(&"value".to_string())
        );

        assert!(config.validate().is_ok());

        let invalid_config = LLMProviderConfig {
            api_key: "".to_string(),
            ..config
        };
        assert!(invalid_config.validate().is_err());
    }

    #[test]
    fn test_capabilities() {
        let mut capabilities = LLMCapabilities::default();
        assert!(!capabilities.streaming);
        assert!(!capabilities.vision);
        assert!(!capabilities.function_calling);
        assert_eq!(capabilities.content_types, vec!["text".to_string()]);

        capabilities.streaming = true;
        capabilities.vision = true;
        capabilities.max_context_tokens = Some(128000);

        assert!(capabilities.streaming);
        assert!(capabilities.vision);
        assert_eq!(capabilities.max_context_tokens, Some(128000));
    }

    #[test]
    fn test_model_info() {
        let capabilities = LLMCapabilities {
            streaming: true,
            vision: false,
            max_context_tokens: Some(128000),
            ..Default::default()
        };

        let model = ModelInfo {
            name: "gpt-4".to_string(),
            display_name: "GPT-4".to_string(),
            description: Some("Advanced language model".to_string()),
            capabilities,
            version: Some("2024-01".to_string()),
            pricing: None,
            available: true,
        };

        assert_eq!(model.name, "gpt-4");
        assert_eq!(model.display_name, "GPT-4");
        assert!(model.available);
        assert!(model.capabilities.streaming);
    }
}