use crate::models::{
    integrations::{anthropic::AnthropicModel, gemini::GeminiModel, openai::OpenAIModel},
    model_pricing::{ContextAware, ModelContextInfo},
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Display;

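/// Connection settings for a single LLM provider.
///
/// The serialized form is internally tagged by `type`, with variant names
/// lowercased (`openai`, `anthropic`, `gemini`, `custom`, `stakpak`).
/// A minimal usage sketch (constructor names come from the impl below;
/// the full module path is assumed, so the example is not compiled):
///
/// ```ignore
/// let cfg = ProviderConfig::custom("http://localhost:4000".to_string(), None);
/// assert_eq!(cfg.provider_type(), "custom");
/// assert!(serde_json::to_string(&cfg).unwrap().contains("\"type\":\"custom\""));
/// ```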
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ProviderConfig {
    OpenAI {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
    Anthropic {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// OAuth access token credential (only supported for Anthropic).
        #[serde(skip_serializing_if = "Option::is_none")]
        access_token: Option<String>,
    },
    Gemini {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
    Custom {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Required: custom providers have no default endpoint.
        api_endpoint: String,
    },
    Stakpak {
        api_key: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
}

impl ProviderConfig {
    /// Lowercase identifier matching the serde `type` tag.
    pub fn provider_type(&self) -> &'static str {
        match self {
            ProviderConfig::OpenAI { .. } => "openai",
            ProviderConfig::Anthropic { .. } => "anthropic",
            ProviderConfig::Gemini { .. } => "gemini",
            ProviderConfig::Custom { .. } => "custom",
            ProviderConfig::Stakpak { .. } => "stakpak",
        }
    }

    /// The configured API key, if any (always present for Stakpak).
    pub fn api_key(&self) -> Option<&str> {
        match self {
            ProviderConfig::OpenAI { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Anthropic { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Gemini { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Custom { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Stakpak { api_key, .. } => Some(api_key.as_str()),
        }
    }

    /// The configured endpoint override, if any (always present for Custom).
    pub fn api_endpoint(&self) -> Option<&str> {
        match self {
            ProviderConfig::OpenAI { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Anthropic { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Gemini { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Custom { api_endpoint, .. } => Some(api_endpoint.as_str()),
            ProviderConfig::Stakpak { api_endpoint, .. } => api_endpoint.as_deref(),
        }
    }

    /// The OAuth access token; only Anthropic configs can carry one.
    pub fn access_token(&self) -> Option<&str> {
        match self {
            ProviderConfig::Anthropic { access_token, .. } => access_token.as_deref(),
            _ => None,
        }
    }

    pub fn openai(api_key: Option<String>) -> Self {
        ProviderConfig::OpenAI {
            api_key,
            api_endpoint: None,
        }
    }

    pub fn anthropic(api_key: Option<String>, access_token: Option<String>) -> Self {
        ProviderConfig::Anthropic {
            api_key,
            api_endpoint: None,
            access_token,
        }
    }

    pub fn gemini(api_key: Option<String>) -> Self {
        ProviderConfig::Gemini {
            api_key,
            api_endpoint: None,
        }
    }

    pub fn custom(api_endpoint: String, api_key: Option<String>) -> Self {
        ProviderConfig::Custom {
            api_key,
            api_endpoint,
        }
    }

    pub fn stakpak(api_key: String, api_endpoint: Option<String>) -> Self {
        ProviderConfig::Stakpak {
            api_key,
            api_endpoint,
        }
    }
}

#[derive(Clone, Debug, PartialEq, Serialize)]
pub enum LLMModel {
    Anthropic(AnthropicModel),
    Gemini(GeminiModel),
    OpenAI(OpenAIModel),
    Custom {
        /// Provider identifier (e.g. "litellm" or "ollama").
        provider: String,
        /// Model identifier as sent to the provider.
        model: String,
        /// Optional human-readable display name.
        name: Option<String>,
    },
}

impl ContextAware for LLMModel {
    fn context_info(&self) -> ModelContextInfo {
        match self {
            LLMModel::Anthropic(model) => model.context_info(),
            LLMModel::Gemini(model) => model.context_info(),
            LLMModel::OpenAI(model) => model.context_info(),
            LLMModel::Custom { .. } => ModelContextInfo::default(),
        }
    }

    fn model_name(&self) -> String {
        match self {
            LLMModel::Anthropic(model) => model.model_name(),
            LLMModel::Gemini(model) => model.model_name(),
            LLMModel::OpenAI(model) => model.model_name(),
            LLMModel::Custom {
                provider,
                model,
                name,
            } => name
                .clone()
                .unwrap_or_else(|| format!("{}/{}", provider, model)),
        }
    }
}

/// Named collection of provider configurations, keyed by provider name.
#[derive(Debug, Clone, Default)]
pub struct LLMProviderConfig {
    pub providers: HashMap<String, ProviderConfig>,
}

impl LLMProviderConfig {
    pub fn new() -> Self {
        Self {
            providers: HashMap::new(),
        }
    }

    pub fn add_provider(&mut self, name: impl Into<String>, config: ProviderConfig) {
        self.providers.insert(name.into(), config);
    }

    pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
        self.providers.get(name)
    }

    pub fn is_empty(&self) -> bool {
        self.providers.is_empty()
    }
}

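/// Parses a model string of the form `provider/model`, or a bare model id.
///
/// Known provider prefixes (`anthropic`, `openai`, `google`/`gemini`) are
/// routed to the matching built-in model; any other prefix produces an
/// [`LLMModel::Custom`]. A sketch of the mapping (mirrors the tests below;
/// not compiled because the full module path is assumed):
///
/// ```ignore
/// assert!(matches!(LLMModel::from("openai/gpt-5".to_string()), LLMModel::OpenAI(_)));
/// assert_eq!(LLMModel::from("ollama/llama3".to_string()).provider_name(), "ollama");
/// ```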
impl From<String> for LLMModel {
    fn from(value: String) -> Self {
        if let Some((provider, model)) = value.split_once('/') {
            match provider {
                "anthropic" => return Self::from_model_name(model),
                "openai" => return Self::from_model_name(model),
                "google" | "gemini" => return Self::from_model_name(model),
                _ => {
                    // Nested paths like "stakpak/anthropic/claude-sonnet-4-5" keep the
                    // full remainder as the model id; only the last segment is displayed.
                    let display_name = model.rsplit('/').next().unwrap_or(model).to_string();
                    return LLMModel::Custom {
                        provider: provider.to_string(),
                        model: model.to_string(),
                        name: Some(display_name),
                    };
                }
            }
        }

        Self::from_model_name(&value)
    }
}

impl LLMModel {
    fn from_model_name(model: &str) -> Self {
        // Order matters: more specific prefixes (e.g. "gpt-5-mini") must be
        // checked before their shorter counterparts (e.g. "gpt-5").
        if model.starts_with("claude-haiku-4-5") {
            LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
        } else if model.starts_with("claude-sonnet-4-5") {
            LLMModel::Anthropic(AnthropicModel::Claude45Sonnet)
        } else if model.starts_with("claude-opus-4-5") {
            LLMModel::Anthropic(AnthropicModel::Claude45Opus)
        } else if model == "gemini-2.5-flash-lite" {
            LLMModel::Gemini(GeminiModel::Gemini25FlashLite)
        } else if model.starts_with("gemini-2.5-flash") {
            LLMModel::Gemini(GeminiModel::Gemini25Flash)
        } else if model.starts_with("gemini-2.5-pro") {
            LLMModel::Gemini(GeminiModel::Gemini25Pro)
        } else if model.starts_with("gemini-3-pro-preview") {
            LLMModel::Gemini(GeminiModel::Gemini3Pro)
        } else if model.starts_with("gemini-3-flash-preview") {
            LLMModel::Gemini(GeminiModel::Gemini3Flash)
        } else if model.starts_with("gpt-5-mini") {
            LLMModel::OpenAI(OpenAIModel::GPT5Mini)
        } else if model.starts_with("gpt-5") {
            LLMModel::OpenAI(OpenAIModel::GPT5)
        } else {
            LLMModel::Custom {
                provider: "custom".to_string(),
                model: model.to_string(),
                name: Some(model.to_string()),
            }
        }
    }

    /// Lowercase provider identifier ("anthropic", "google", "openai", or the
    /// custom provider name).
    pub fn provider_name(&self) -> &str {
        match self {
            LLMModel::Anthropic(_) => "anthropic",
            LLMModel::Gemini(_) => "google",
            LLMModel::OpenAI(_) => "openai",
            LLMModel::Custom { provider, .. } => provider,
        }
    }

    /// The model identifier as sent to the provider API.
    pub fn model_id(&self) -> String {
        match self {
            LLMModel::Anthropic(m) => m.to_string(),
            LLMModel::Gemini(m) => m.to_string(),
            LLMModel::OpenAI(m) => m.to_string(),
            LLMModel::Custom { model, .. } => model.clone(),
        }
    }

    /// Attaches a display name; a no-op for non-custom models.
    pub fn with_name(self, name: impl Into<String>) -> Self {
        match self {
            LLMModel::Custom {
                provider, model, ..
            } => LLMModel::Custom {
                provider,
                model,
                name: Some(name.into()),
            },
            other => other,
        }
    }

    /// The display name, if set; only custom models carry one.
    pub fn display_name(&self) -> Option<&str> {
        match self {
            LLMModel::Custom { name, .. } => name.as_deref(),
            _ => None,
        }
    }
}

impl Display for LLMModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LLMModel::Anthropic(model) => write!(f, "{}", model),
            LLMModel::Gemini(model) => write!(f, "{}", model),
            LLMModel::OpenAI(model) => write!(f, "{}", model),
            LLMModel::Custom {
                provider,
                model,
                name,
            } => {
                if let Some(name) = name {
                    write!(f, "{}", name)
                } else {
                    write!(f, "{}/{}", provider, model)
                }
            }
        }
    }
}

/// Per-provider request options, forwarded only to the matching provider.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMProviderOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub anthropic: Option<LLMAnthropicOptions>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub openai: Option<LLMOpenAIOptions>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub google: Option<LLMGoogleOptions>,
}

#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMAnthropicOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<LLMThinkingOptions>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LLMThinkingOptions {
    pub budget_tokens: u32,
}

impl LLMThinkingOptions {
    pub fn new(budget_tokens: u32) -> Self {
        Self {
            // Clamp to Anthropic's documented minimum thinking budget.
            budget_tokens: budget_tokens.max(1024),
        }
    }
}

#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMOpenAIOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
}

#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMGoogleOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<u32>,
}

#[derive(Clone, Debug, Serialize)]
pub struct LLMInput {
    pub model: LLMModel,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub tools: Option<Vec<LLMTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<LLMProviderOptions>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<HashMap<String, String>>,
}

/// Like [`LLMInput`], plus a channel on which streamed deltas are sent.
#[derive(Debug)]
pub struct LLMStreamInput {
    pub model: LLMModel,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
    pub tools: Option<Vec<LLMTool>>,
    pub provider_options: Option<LLMProviderOptions>,
    pub headers: Option<HashMap<String, String>>,
}

impl From<&LLMStreamInput> for LLMInput {
    fn from(value: &LLMStreamInput) -> Self {
        LLMInput {
            model: value.model.clone(),
            messages: value.messages.clone(),
            max_tokens: value.max_tokens,
            tools: value.tools.clone(),
            provider_options: value.provider_options.clone(),
            headers: value.headers.clone(),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct LLMMessage {
    pub role: String,
    pub content: LLMMessageContent,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SimpleLLMMessage {
    pub role: SimpleLLMRole,
    pub content: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SimpleLLMRole {
    User,
    Assistant,
}

impl std::fmt::Display for SimpleLLMRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SimpleLLMRole::User => write!(f, "user"),
            SimpleLLMRole::Assistant => write!(f, "assistant"),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum LLMMessageContent {
    String(String),
    List(Vec<LLMMessageTypedContent>),
}

#[allow(clippy::to_string_trait_impl)]
impl ToString for LLMMessageContent {
    fn to_string(&self) -> String {
        match self {
            LLMMessageContent::String(s) => s.clone(),
            LLMMessageContent::List(l) => l
                .iter()
                .map(|c| match c {
                    LLMMessageTypedContent::Text { text } => text.clone(),
                    // Tool calls and images carry no plain-text payload.
                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
                    LLMMessageTypedContent::Image { .. } => String::new(),
                })
                .collect::<Vec<_>>()
                .join("\n"),
        }
    }
}

impl From<String> for LLMMessageContent {
    fn from(value: String) -> Self {
        LLMMessageContent::String(value)
    }
}

impl Default for LLMMessageContent {
    fn default() -> Self {
        LLMMessageContent::String(String::new())
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum LLMMessageTypedContent {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "tool_use")]
    ToolCall {
        id: String,
        name: String,
        /// The alias accepts Anthropic-style `"input"` keys on deserialization.
        #[serde(alias = "input")]
        args: serde_json::Value,
    },
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        content: String,
    },
    #[serde(rename = "image")]
    Image { source: LLMMessageImageSource },
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMMessageImageSource {
    #[serde(rename = "type")]
    pub r#type: String,
    pub media_type: String,
    pub data: String,
}

impl Default for LLMMessageTypedContent {
    fn default() -> Self {
        LLMMessageTypedContent::Text {
            text: String::new(),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: LLMMessage,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMChoice>,
    pub created: u64,
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: Option<LLMMessage>,
    pub delta: LLMStreamDelta,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionStreamResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMStreamChoice>,
    pub created: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
    pub citations: Option<Vec<String>>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMTokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    InputTokens,
    OutputTokens,
    CacheReadInputTokens,
    CacheWriteInputTokens,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct PromptTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_input_tokens: Option<u32>,
}

impl PromptTokensDetails {
    /// Iterates over all four counters, treating missing values as zero.
    pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
        [
            (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
            (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
            (
                TokenType::CacheReadInputTokens,
                self.cache_read_input_tokens.unwrap_or(0),
            ),
            (
                TokenType::CacheWriteInputTokens,
                self.cache_write_input_tokens.unwrap_or(0),
            ),
        ]
        .into_iter()
    }
}

impl std::ops::Add for PromptTokensDetails {
    type Output = Self;

    fn add(self, rhs: Self) -> Self::Output {
        Self {
            input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
            output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
            cache_read_input_tokens: Some(
                self.cache_read_input_tokens.unwrap_or(0)
                    + rhs.cache_read_input_tokens.unwrap_or(0),
            ),
            cache_write_input_tokens: Some(
                self.cache_write_input_tokens.unwrap_or(0)
                    + rhs.cache_write_input_tokens.unwrap_or(0),
            ),
        }
    }
}

impl std::ops::AddAssign for PromptTokensDetails {
    fn add_assign(&mut self, rhs: Self) {
        self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
        self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
        self.cache_read_input_tokens = Some(
            self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
        );
        self.cache_write_input_tokens = Some(
            self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
        );
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum GenerationDelta {
    Content { content: String },
    Thinking { thinking: String },
    ToolUse { tool_use: GenerationDeltaToolUse },
    Usage { usage: LLMTokenUsage },
    Metadata { metadata: serde_json::Value },
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenerationDeltaToolUse {
    pub id: Option<String>,
    pub name: Option<String>,
    pub input: Option<String>,
    pub index: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_model_from_known_anthropic_model() {
        let model = LLMModel::from("claude-opus-4-5-20251101".to_string());
        assert!(matches!(
            model,
            LLMModel::Anthropic(AnthropicModel::Claude45Opus)
        ));
    }

    #[test]
    fn test_llm_model_from_known_openai_model() {
        let model = LLMModel::from("gpt-5".to_string());
        assert!(matches!(model, LLMModel::OpenAI(OpenAIModel::GPT5)));
    }

    #[test]
    fn test_llm_model_from_known_gemini_model() {
        let model = LLMModel::from("gemini-2.5-flash".to_string());
        assert!(matches!(
            model,
            LLMModel::Gemini(GeminiModel::Gemini25Flash)
        ));
    }

    #[test]
    fn test_llm_model_from_custom_provider_with_slash() {
        let model = LLMModel::from("litellm/claude-opus-4-5".to_string());
        match model {
            LLMModel::Custom {
                provider,
                model,
                name,
            } => {
                assert_eq!(provider, "litellm");
                assert_eq!(model, "claude-opus-4-5");
                assert_eq!(name, Some("claude-opus-4-5".to_string()));
            }
            _ => panic!("Expected Custom model"),
        }
    }

    #[test]
    fn test_llm_model_from_ollama_provider() {
        let model = LLMModel::from("ollama/llama3".to_string());
        match model {
            LLMModel::Custom {
                provider,
                model,
                name,
            } => {
                assert_eq!(provider, "ollama");
                assert_eq!(model, "llama3");
                assert_eq!(name, Some("llama3".to_string()));
            }
            _ => panic!("Expected Custom model"),
        }
    }

    #[test]
    fn test_llm_model_from_nested_provider() {
        let model = LLMModel::from("stakpak/anthropic/claude-sonnet-4-5".to_string());
        match model {
            LLMModel::Custom {
                provider,
                model,
                name,
            } => {
                assert_eq!(provider, "stakpak");
                assert_eq!(model, "anthropic/claude-sonnet-4-5");
                assert_eq!(name, Some("claude-sonnet-4-5".to_string()));
            }
            _ => panic!("Expected Custom model"),
        }
    }

    #[test]
    fn test_llm_model_explicit_anthropic_prefix() {
        let model = LLMModel::from("anthropic/claude-opus-4-5".to_string());
        assert!(matches!(
            model,
            LLMModel::Anthropic(AnthropicModel::Claude45Opus)
        ));
    }

    #[test]
    fn test_llm_model_explicit_openai_prefix() {
        let model = LLMModel::from("openai/gpt-5".to_string());
        assert!(matches!(model, LLMModel::OpenAI(OpenAIModel::GPT5)));
    }

    #[test]
    fn test_llm_model_explicit_google_prefix() {
        let model = LLMModel::from("google/gemini-2.5-flash".to_string());
        assert!(matches!(
            model,
            LLMModel::Gemini(GeminiModel::Gemini25Flash)
        ));
    }

    #[test]
    fn test_llm_model_explicit_gemini_prefix() {
        let model = LLMModel::from("gemini/gemini-2.5-flash".to_string());
        assert!(matches!(
            model,
            LLMModel::Gemini(GeminiModel::Gemini25Flash)
        ));
    }

    #[test]
    fn test_llm_model_unknown_model_becomes_custom() {
        let model = LLMModel::from("some-random-model".to_string());
        match model {
            LLMModel::Custom {
                provider,
                model,
                name,
            } => {
                assert_eq!(provider, "custom");
                assert_eq!(model, "some-random-model");
                assert_eq!(name, Some("some-random-model".to_string()));
            }
            _ => panic!("Expected Custom model"),
        }
    }

    #[test]
    fn test_llm_model_display_anthropic() {
        let model = LLMModel::Anthropic(AnthropicModel::Claude45Sonnet);
        let s = model.to_string();
        assert!(s.contains("claude"));
    }

    #[test]
    fn test_llm_model_display_custom() {
        let model = LLMModel::Custom {
            provider: "litellm".to_string(),
            model: "claude-opus".to_string(),
            name: None,
        };
        assert_eq!(model.to_string(), "litellm/claude-opus");
    }

    #[test]
    fn test_llm_model_display_custom_with_name() {
        let model = LLMModel::Custom {
            provider: "litellm".to_string(),
            model: "claude-opus".to_string(),
            name: Some("My Custom Model".to_string()),
        };
        assert_eq!(model.to_string(), "My Custom Model");
    }

    #[test]
    fn test_llm_model_with_name() {
        let model = LLMModel::from("ollama/llama3".to_string()).with_name("Local Llama");
        assert_eq!(model.to_string(), "Local Llama");
        assert_eq!(model.display_name(), Some("Local Llama"));
        assert_eq!(model.model_id(), "llama3");
    }

    #[test]
    fn test_llm_model_provider_name() {
        assert_eq!(
            LLMModel::Anthropic(AnthropicModel::Claude45Sonnet).provider_name(),
            "anthropic"
        );
        assert_eq!(
            LLMModel::OpenAI(OpenAIModel::GPT5).provider_name(),
            "openai"
        );
        assert_eq!(
            LLMModel::Gemini(GeminiModel::Gemini25Flash).provider_name(),
            "google"
        );
        assert_eq!(
            LLMModel::Custom {
                provider: "litellm".to_string(),
                model: "test".to_string(),
                name: None,
            }
            .provider_name(),
            "litellm"
        );
    }
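
    // with_name is a no-op for built-in models: only the Custom variant
    // carries a display name (a direct check of the match arm in with_name).
    #[test]
    fn test_llm_model_with_name_noop_for_builtin() {
        let model = LLMModel::Anthropic(AnthropicModel::Claude45Sonnet).with_name("Alias");
        assert!(matches!(model, LLMModel::Anthropic(_)));
        assert_eq!(model.display_name(), None);
    }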

    #[test]
    fn test_llm_model_model_id() {
        let model = LLMModel::Custom {
            provider: "litellm".to_string(),
            model: "claude-opus-4-5".to_string(),
            name: None,
        };
        assert_eq!(model.model_id(), "claude-opus-4-5");
    }
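
    // A small check of how LLMMessageContent flattens to plain text: text and
    // tool results keep their content, while tool calls and images contribute
    // empty strings.
    #[test]
    fn test_llm_message_content_to_string_flattens_list() {
        let content = LLMMessageContent::List(vec![
            LLMMessageTypedContent::Text {
                text: "hello".to_string(),
            },
            LLMMessageTypedContent::ToolResult {
                tool_use_id: "t1".to_string(),
                content: "ok".to_string(),
            },
        ]);
        assert_eq!(content.to_string(), "hello\nok");
    }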

    #[test]
    fn test_provider_config_openai_serialization() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"openai\""));
        assert!(json.contains("\"api_key\":\"sk-test\""));
        assert!(!json.contains("api_endpoint"));
    }

    #[test]
    fn test_provider_config_openai_with_endpoint() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: Some("https://custom.openai.com/v1".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"api_endpoint\":\"https://custom.openai.com/v1\""));
    }

    #[test]
    fn test_provider_config_anthropic_serialization() {
        let config = ProviderConfig::Anthropic {
            api_key: Some("sk-ant-test".to_string()),
            api_endpoint: None,
            access_token: Some("oauth-token".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"anthropic\""));
        assert!(json.contains("\"api_key\":\"sk-ant-test\""));
        assert!(json.contains("\"access_token\":\"oauth-token\""));
    }
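
    // The `args` field accepts Anthropic's `input` key through the serde
    // alias; this deserializes a tool_use block in the Anthropic wire shape.
    #[test]
    fn test_tool_call_accepts_input_alias() {
        let json = r#"{"type":"tool_use","id":"tc1","name":"search","input":{"q":"rust"}}"#;
        let content: LLMMessageTypedContent = serde_json::from_str(json).unwrap();
        match content {
            LLMMessageTypedContent::ToolCall { id, name, args } => {
                assert_eq!(id, "tc1");
                assert_eq!(name, "search");
                assert_eq!(args["q"], "rust");
            }
            _ => panic!("Expected ToolCall"),
        }
    }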

    #[test]
    fn test_provider_config_gemini_serialization() {
        let config = ProviderConfig::Gemini {
            api_key: Some("gemini-key".to_string()),
            api_endpoint: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"gemini\""));
        assert!(json.contains("\"api_key\":\"gemini-key\""));
    }

    #[test]
    fn test_provider_config_custom_serialization() {
        let config = ProviderConfig::Custom {
            api_key: Some("sk-custom".to_string()),
            api_endpoint: "http://localhost:4000".to_string(),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\":\"http://localhost:4000\""));
        assert!(json.contains("\"api_key\":\"sk-custom\""));
    }

    #[test]
    fn test_provider_config_custom_without_key() {
        let config = ProviderConfig::Custom {
            api_key: None,
            api_endpoint: "http://localhost:11434/v1".to_string(),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\""));
        assert!(!json.contains("api_key"));
    }

    #[test]
    fn test_provider_config_deserialization_openai() {
        let json = r#"{"type":"openai","api_key":"sk-test"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::OpenAI { .. }));
        assert_eq!(config.api_key(), Some("sk-test"));
    }

    #[test]
    fn test_provider_config_deserialization_anthropic() {
        let json = r#"{"type":"anthropic","api_key":"sk-ant","access_token":"oauth"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Anthropic { .. }));
        assert_eq!(config.api_key(), Some("sk-ant"));
        assert_eq!(config.access_token(), Some("oauth"));
    }

    #[test]
    fn test_provider_config_deserialization_gemini() {
        let json = r#"{"type":"gemini","api_key":"gemini-key"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Gemini { .. }));
        assert_eq!(config.api_key(), Some("gemini-key"));
    }

    #[test]
    fn test_provider_config_deserialization_custom() {
        let json =
            r#"{"type":"custom","api_endpoint":"http://localhost:4000","api_key":"sk-custom"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Custom { .. }));
        assert_eq!(config.api_key(), Some("sk-custom"));
        assert_eq!(config.api_endpoint(), Some("http://localhost:4000"));
    }

    #[test]
    fn test_provider_config_helper_methods() {
        let openai = ProviderConfig::openai(Some("sk-openai".to_string()));
        assert_eq!(openai.provider_type(), "openai");
        assert_eq!(openai.api_key(), Some("sk-openai"));

        let anthropic =
            ProviderConfig::anthropic(Some("sk-ant".to_string()), Some("oauth".to_string()));
        assert_eq!(anthropic.provider_type(), "anthropic");
        assert_eq!(anthropic.access_token(), Some("oauth"));

        let gemini = ProviderConfig::gemini(Some("gemini-key".to_string()));
        assert_eq!(gemini.provider_type(), "gemini");

        let custom = ProviderConfig::custom(
            "http://localhost:4000".to_string(),
            Some("sk-custom".to_string()),
        );
        assert_eq!(custom.provider_type(), "custom");
        assert_eq!(custom.api_endpoint(), Some("http://localhost:4000"));
    }
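
    // LLMThinkingOptions::new clamps to the 1024-token floor applied in its
    // constructor (matching Anthropic's documented minimum thinking budget).
    #[test]
    fn test_thinking_options_budget_floor() {
        assert_eq!(LLMThinkingOptions::new(100).budget_tokens, 1024);
        assert_eq!(LLMThinkingOptions::new(4096).budget_tokens, 4096);
    }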

    #[test]
    fn test_llm_provider_config_new() {
        let config = LLMProviderConfig::new();
        assert!(config.is_empty());
    }

    #[test]
    fn test_llm_provider_config_add_and_get() {
        let mut config = LLMProviderConfig::new();
        config.add_provider(
            "openai",
            ProviderConfig::openai(Some("sk-test".to_string())),
        );
        config.add_provider(
            "anthropic",
            ProviderConfig::anthropic(Some("sk-ant".to_string()), None),
        );

        assert!(!config.is_empty());
        assert!(config.get_provider("openai").is_some());
        assert!(config.get_provider("anthropic").is_some());
        assert!(config.get_provider("unknown").is_none());
    }
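
    // Sanity-check the Add impl on PromptTokensDetails: missing counters are
    // treated as zero and every field of the sum is populated.
    #[test]
    fn test_prompt_tokens_details_add() {
        let a = PromptTokensDetails {
            input_tokens: Some(10),
            cache_read_input_tokens: Some(5),
            ..Default::default()
        };
        let b = PromptTokensDetails {
            input_tokens: Some(3),
            output_tokens: Some(7),
            ..Default::default()
        };
        let sum = a + b;
        assert_eq!(sum.input_tokens, Some(13));
        assert_eq!(sum.output_tokens, Some(7));
        assert_eq!(sum.cache_read_input_tokens, Some(5));
        assert_eq!(sum.cache_write_input_tokens, Some(0));
    }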

    #[test]
    fn test_provider_config_toml_parsing() {
        // Providers are stored as a name -> config map; JSON stands in here
        // for the equivalent TOML config shape.
        let json = r#"{
            "openai": {"type": "openai", "api_key": "sk-openai"},
            "anthropic": {"type": "anthropic", "api_key": "sk-ant", "access_token": "oauth"},
            "litellm": {"type": "custom", "api_endpoint": "http://localhost:4000", "api_key": "sk-litellm"}
        }"#;

        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 3);

        assert!(matches!(
            providers.get("openai"),
            Some(ProviderConfig::OpenAI { .. })
        ));
        assert!(matches!(
            providers.get("anthropic"),
            Some(ProviderConfig::Anthropic { .. })
        ));
        assert!(matches!(
            providers.get("litellm"),
            Some(ProviderConfig::Custom { .. })
        ));
    }
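
    // Converting a stream input to a plain LLMInput drops only the delta
    // channel; this checks that the borrowed From impl clones the rest.
    #[test]
    fn test_llm_input_from_stream_input() {
        let (tx, _rx) = tokio::sync::mpsc::channel::<GenerationDelta>(8);
        let stream_input = LLMStreamInput {
            model: LLMModel::OpenAI(OpenAIModel::GPT5),
            messages: vec![LLMMessage {
                role: "user".to_string(),
                content: LLMMessageContent::from("hi".to_string()),
            }],
            max_tokens: 256,
            stream_channel_tx: tx,
            tools: None,
            provider_options: None,
            headers: None,
        };
        let input = LLMInput::from(&stream_input);
        assert_eq!(input.max_tokens, 256);
        assert_eq!(input.messages.len(), 1);
        assert!(matches!(input.model, LLMModel::OpenAI(OpenAIModel::GPT5)));
    }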
}