1use serde::{Deserialize, Serialize};
55use stakai::Model;
56use std::collections::HashMap;
57
/// Connection settings for a single LLM provider backend.
///
/// Serialized with an internal `"type"` tag; variant names are lowercased
/// (`"openai"`, `"anthropic"`, ...) except `Bedrock`, which is explicitly
/// renamed to `"amazon-bedrock"`.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ProviderConfig {
    /// OpenAI (or OpenAI-compatible) API; key and endpoint both optional.
    OpenAI {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Override for the default API base URL.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
    /// Anthropic API; supports an API key and/or an OAuth access token.
    Anthropic {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// OAuth access token, an alternative credential to `api_key`.
        #[serde(skip_serializing_if = "Option::is_none")]
        access_token: Option<String>,
    },
    /// Google Gemini API.
    Gemini {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
    /// Any custom endpoint (e.g. a local gateway); unlike the other variants
    /// the endpoint is required here, while the key stays optional.
    Custom {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        api_endpoint: String,
    },
    /// Stakpak API; the key is required.
    Stakpak {
        api_key: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
    /// AWS Bedrock. Stores no API key (`api_key()` returns `None` for this
    /// variant) — presumably authenticated via ambient AWS credentials;
    /// confirm against the Bedrock client setup.
    #[serde(rename = "amazon-bedrock")]
    Bedrock {
        region: String,
        /// Named AWS profile to use — TODO confirm how the client consumes it.
        #[serde(skip_serializing_if = "Option::is_none")]
        profile_name: Option<String>,
    },
}
188
189impl ProviderConfig {
190 pub fn provider_type(&self) -> &'static str {
192 match self {
193 ProviderConfig::OpenAI { .. } => "openai",
194 ProviderConfig::Anthropic { .. } => "anthropic",
195 ProviderConfig::Gemini { .. } => "gemini",
196 ProviderConfig::Custom { .. } => "custom",
197 ProviderConfig::Stakpak { .. } => "stakpak",
198 ProviderConfig::Bedrock { .. } => "amazon-bedrock",
199 }
200 }
201
202 pub fn api_key(&self) -> Option<&str> {
204 match self {
205 ProviderConfig::OpenAI { api_key, .. } => api_key.as_deref(),
206 ProviderConfig::Anthropic { api_key, .. } => api_key.as_deref(),
207 ProviderConfig::Gemini { api_key, .. } => api_key.as_deref(),
208 ProviderConfig::Custom { api_key, .. } => api_key.as_deref(),
209 ProviderConfig::Stakpak { api_key, .. } => Some(api_key.as_str()),
210 ProviderConfig::Bedrock { .. } => None, }
212 }
213
214 pub fn api_endpoint(&self) -> Option<&str> {
216 match self {
217 ProviderConfig::OpenAI { api_endpoint, .. } => api_endpoint.as_deref(),
218 ProviderConfig::Anthropic { api_endpoint, .. } => api_endpoint.as_deref(),
219 ProviderConfig::Gemini { api_endpoint, .. } => api_endpoint.as_deref(),
220 ProviderConfig::Custom { api_endpoint, .. } => Some(api_endpoint.as_str()),
221 ProviderConfig::Stakpak { api_endpoint, .. } => api_endpoint.as_deref(),
222 ProviderConfig::Bedrock { .. } => None, }
224 }
225
226 pub fn access_token(&self) -> Option<&str> {
228 match self {
229 ProviderConfig::Anthropic { access_token, .. } => access_token.as_deref(),
230 _ => None,
231 }
232 }
233
234 pub fn openai(api_key: Option<String>) -> Self {
236 ProviderConfig::OpenAI {
237 api_key,
238 api_endpoint: None,
239 }
240 }
241
242 pub fn anthropic(api_key: Option<String>, access_token: Option<String>) -> Self {
244 ProviderConfig::Anthropic {
245 api_key,
246 api_endpoint: None,
247 access_token,
248 }
249 }
250
251 pub fn gemini(api_key: Option<String>) -> Self {
253 ProviderConfig::Gemini {
254 api_key,
255 api_endpoint: None,
256 }
257 }
258
259 pub fn custom(api_endpoint: String, api_key: Option<String>) -> Self {
261 ProviderConfig::Custom {
262 api_key,
263 api_endpoint,
264 }
265 }
266
267 pub fn stakpak(api_key: String, api_endpoint: Option<String>) -> Self {
269 ProviderConfig::Stakpak {
270 api_key,
271 api_endpoint,
272 }
273 }
274
275 pub fn bedrock(region: String, profile_name: Option<String>) -> Self {
277 ProviderConfig::Bedrock {
278 region,
279 profile_name,
280 }
281 }
282
283 pub fn region(&self) -> Option<&str> {
285 match self {
286 ProviderConfig::Bedrock { region, .. } => Some(region.as_str()),
287 _ => None,
288 }
289 }
290
291 pub fn profile_name(&self) -> Option<&str> {
293 match self {
294 ProviderConfig::Bedrock { profile_name, .. } => profile_name.as_deref(),
295 _ => None,
296 }
297 }
298}
299
/// A set of provider configurations, keyed by a caller-chosen provider name
/// (e.g. "openai", "litellm" — see the tests for example maps).
#[derive(Debug, Clone, Default)]
pub struct LLMProviderConfig {
    pub providers: HashMap<String, ProviderConfig>,
}
308
309impl LLMProviderConfig {
310 pub fn new() -> Self {
312 Self {
313 providers: HashMap::new(),
314 }
315 }
316
317 pub fn add_provider(&mut self, name: impl Into<String>, config: ProviderConfig) {
319 self.providers.insert(name.into(), config);
320 }
321
322 pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
324 self.providers.get(name)
325 }
326
327 pub fn is_empty(&self) -> bool {
329 self.providers.is_empty()
330 }
331}
332
/// Optional provider-specific tuning options attached to a request.
///
/// Each field applies to exactly one provider family and is omitted from
/// serialization when `None`.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMProviderOptions {
    /// Anthropic-only options (extended thinking).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub anthropic: Option<LLMAnthropicOptions>,

    /// OpenAI-only options (reasoning effort).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub openai: Option<LLMOpenAIOptions>,

    /// Google-only options (thinking budget).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub google: Option<LLMGoogleOptions>,
}
348
/// Anthropic-specific request options.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMAnthropicOptions {
    /// Extended-thinking configuration; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<LLMThinkingOptions>,
}
356
/// Extended-thinking configuration.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LLMThinkingOptions {
    /// Token budget for thinking; `LLMThinkingOptions::new` clamps this to a
    /// minimum of 1024.
    pub budget_tokens: u32,
}
363
364impl LLMThinkingOptions {
365 pub fn new(budget_tokens: u32) -> Self {
366 Self {
367 budget_tokens: budget_tokens.max(1024),
368 }
369 }
370}
371
/// OpenAI-specific request options.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMOpenAIOptions {
    /// Reasoning-effort level, passed through as a string — presumably one of
    /// "low"/"medium"/"high"; not validated here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
}
379
/// Google (Gemini)-specific request options.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMGoogleOptions {
    /// Thinking token budget; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<u32>,
}
387
/// A complete, non-streaming LLM generation request.
#[derive(Clone, Debug, Serialize)]
pub struct LLMInput {
    pub model: Model,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    /// Tool definitions the model may call, if any.
    pub tools: Option<Vec<LLMTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<LLMProviderOptions>,
    /// Extra HTTP headers to send with the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<std::collections::HashMap<String, String>>,
}
400
/// A streaming LLM generation request; deltas are pushed into
/// `stream_channel_tx` as they are produced.
#[derive(Debug)]
pub struct LLMStreamInput {
    pub model: Model,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    /// Channel on which incremental `GenerationDelta`s are emitted.
    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
    pub tools: Option<Vec<LLMTool>>,
    pub provider_options: Option<LLMProviderOptions>,
    pub headers: Option<std::collections::HashMap<String, String>>,
}
412
impl From<&LLMStreamInput> for LLMInput {
    /// Converts a streaming request into its non-streaming equivalent by
    /// cloning the shared fields; the delta channel is intentionally dropped.
    fn from(value: &LLMStreamInput) -> Self {
        LLMInput {
            model: value.model.clone(),
            messages: value.messages.clone(),
            max_tokens: value.max_tokens,
            tools: value.tools.clone(),
            provider_options: value.provider_options.clone(),
            headers: value.headers.clone(),
        }
    }
}
425
/// A single chat message: a free-form role string plus content that is either
/// plain text or a list of typed parts.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct LLMMessage {
    /// Role name (e.g. "user", "assistant") — not validated here.
    pub role: String,
    pub content: LLMMessageContent,
}
431
432#[derive(Serialize, Deserialize, Debug, Clone)]
433pub struct SimpleLLMMessage {
434 #[serde(rename = "role")]
435 pub role: SimpleLLMRole,
436 pub content: String,
437}
438
/// Role of a `SimpleLLMMessage`, serialized lowercase ("user" / "assistant").
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SimpleLLMRole {
    User,
    Assistant,
}
445
446impl std::fmt::Display for SimpleLLMRole {
447 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
448 match self {
449 SimpleLLMRole::User => write!(f, "user"),
450 SimpleLLMRole::Assistant => write!(f, "assistant"),
451 }
452 }
453}
454
/// Message content: either a bare string or a list of typed parts.
///
/// `untagged` lets deserialization accept both `"hello"` and
/// `[{"type": "text", ...}]` shapes.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum LLMMessageContent {
    String(String),
    List(Vec<LLMMessageTypedContent>),
}
461
462#[allow(clippy::to_string_trait_impl)]
463impl ToString for LLMMessageContent {
464 fn to_string(&self) -> String {
465 match self {
466 LLMMessageContent::String(s) => s.clone(),
467 LLMMessageContent::List(l) => l
468 .iter()
469 .map(|c| match c {
470 LLMMessageTypedContent::Text { text } => text.clone(),
471 LLMMessageTypedContent::ToolCall { .. } => String::new(),
472 LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
473 LLMMessageTypedContent::Image { .. } => String::new(),
474 })
475 .collect::<Vec<_>>()
476 .join("\n"),
477 }
478 }
479}
480
481impl From<String> for LLMMessageContent {
482 fn from(value: String) -> Self {
483 LLMMessageContent::String(value)
484 }
485}
486
487impl Default for LLMMessageContent {
488 fn default() -> Self {
489 LLMMessageContent::String(String::new())
490 }
491}
492
493impl LLMMessageContent {
494 pub fn into_parts(self) -> Vec<LLMMessageTypedContent> {
497 match self {
498 LLMMessageContent::List(parts) => parts,
499 LLMMessageContent::String(s) if s.is_empty() => vec![],
500 LLMMessageContent::String(s) => vec![LLMMessageTypedContent::Text { text: s }],
501 }
502 }
503}
504
/// One typed part of a message, tagged by `"type"` in JSON.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum LLMMessageTypedContent {
    /// Plain text.
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool invocation requested by the model.
    #[serde(rename = "tool_use")]
    ToolCall {
        id: String,
        name: String,
        /// Tool arguments; `alias = "input"` also accepts the
        /// Anthropic-style `"input"` key on deserialization.
        #[serde(alias = "input")]
        args: serde_json::Value,
        #[serde(skip_serializing_if = "Option::is_none")]
        metadata: Option<serde_json::Value>,
    },
    /// The result of a previously issued tool call, keyed by `tool_use_id`.
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        content: String,
    },
    /// An inline image.
    #[serde(rename = "image")]
    Image { source: LLMMessageImageSource },
}
528
/// Source of an inline image part.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMMessageImageSource {
    /// Source kind — presumably "base64"; confirm against provider API docs.
    #[serde(rename = "type")]
    pub r#type: String,
    /// MIME type of the image (e.g. "image/png") — TODO confirm expected values.
    pub media_type: String,
    /// Image payload, encoded per `r#type`.
    pub data: String,
}
536
537impl Default for LLMMessageTypedContent {
538 fn default() -> Self {
539 LLMMessageTypedContent::Text {
540 text: String::new(),
541 }
542 }
543}
544
/// One completion choice in a non-streaming response.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMChoice {
    /// Why generation stopped (e.g. "stop", "tool_use") — TODO confirm the
    /// value set against the upstream API.
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: LLMMessage,
}
551
/// A complete (non-streaming) completion response.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMChoice>,
    /// Creation timestamp — presumably Unix seconds; confirm with the caller.
    pub created: u64,
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
}
561
/// Incremental content carried by one streaming chunk.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamDelta {
    /// New text since the previous chunk; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}
567
/// One choice within a streaming response chunk.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    /// Full message, when the chunk carries one (typically the final chunk) —
    /// TODO confirm against the producer.
    pub message: Option<LLMMessage>,
    pub delta: LLMStreamDelta,
}
575
/// One chunk of a streaming completion response.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionStreamResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMStreamChoice>,
    pub created: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
    // NOTE(review): unlike `usage`, this field is serialized as `null` when
    // `None` (no skip_serializing_if) — confirm whether consumers rely on the
    // key being present before changing it.
    pub citations: Option<Vec<String>>,
}
587
/// A tool the model may invoke: name, description, and a JSON Schema for its
/// arguments.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMTool {
    pub name: String,
    pub description: String,
    /// JSON Schema describing the tool's expected arguments.
    pub input_schema: serde_json::Value,
}
594
/// Token accounting for a single request/response.
#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMTokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,

    /// Finer-grained breakdown (cache reads/writes etc.); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
604
/// Category of a token counter, serialized in snake_case
/// (e.g. "cache_read_input_tokens").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    InputTokens,
    OutputTokens,
    CacheReadInputTokens,
    CacheWriteInputTokens,
}
613
/// Per-category token counters; each counter is optional and treated as zero
/// when absent (see `iter`, `Add`, `AddAssign`).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct PromptTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_input_tokens: Option<u32>,
}
625
626impl PromptTokensDetails {
627 pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
629 [
630 (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
631 (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
632 (
633 TokenType::CacheReadInputTokens,
634 self.cache_read_input_tokens.unwrap_or(0),
635 ),
636 (
637 TokenType::CacheWriteInputTokens,
638 self.cache_write_input_tokens.unwrap_or(0),
639 ),
640 ]
641 .into_iter()
642 }
643}
644
645impl std::ops::Add for PromptTokensDetails {
646 type Output = Self;
647
648 fn add(self, rhs: Self) -> Self::Output {
649 Self {
650 input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
651 output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
652 cache_read_input_tokens: Some(
653 self.cache_read_input_tokens.unwrap_or(0)
654 + rhs.cache_read_input_tokens.unwrap_or(0),
655 ),
656 cache_write_input_tokens: Some(
657 self.cache_write_input_tokens.unwrap_or(0)
658 + rhs.cache_write_input_tokens.unwrap_or(0),
659 ),
660 }
661 }
662}
663
664impl std::ops::AddAssign for PromptTokensDetails {
665 fn add_assign(&mut self, rhs: Self) {
666 self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
667 self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
668 self.cache_read_input_tokens = Some(
669 self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
670 );
671 self.cache_write_input_tokens = Some(
672 self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
673 );
674 }
675}
676
/// One incremental event in a generation stream, tagged by `"type"`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum GenerationDelta {
    /// New visible output text.
    Content { content: String },
    /// New thinking/reasoning text.
    Thinking { thinking: String },
    /// A fragment of a tool invocation.
    ToolUse { tool_use: GenerationDeltaToolUse },
    /// Token-usage update.
    Usage { usage: LLMTokenUsage },
    /// Arbitrary provider metadata.
    Metadata { metadata: serde_json::Value },
}
686
/// A tool-use fragment within a generation stream; `id`/`name`/`input` arrive
/// piecemeal across deltas — presumably correlated by `index`; confirm with
/// the stream consumer.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenerationDeltaToolUse {
    pub id: Option<String>,
    pub name: Option<String>,
    /// Partial tool arguments — presumably a JSON string accumulated by the
    /// consumer; TODO confirm.
    pub input: Option<String>,
    /// Position of this tool call within the response.
    pub index: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
697
// Unit tests for ProviderConfig (de)serialization, the accessor/constructor
// helpers, and the LLMProviderConfig registry.
#[cfg(test)]
mod tests {
    use super::*;

    // --- ProviderConfig serialization ---

    #[test]
    fn test_provider_config_openai_serialization() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"openai\""));
        assert!(json.contains("\"api_key\":\"sk-test\""));
        // None endpoints are skipped entirely, not serialized as null.
        assert!(!json.contains("api_endpoint")); }

    #[test]
    fn test_provider_config_openai_with_endpoint() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: Some("https://custom.openai.com/v1".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"api_endpoint\":\"https://custom.openai.com/v1\""));
    }

    #[test]
    fn test_provider_config_anthropic_serialization() {
        let config = ProviderConfig::Anthropic {
            api_key: Some("sk-ant-test".to_string()),
            api_endpoint: None,
            access_token: Some("oauth-token".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"anthropic\""));
        assert!(json.contains("\"api_key\":\"sk-ant-test\""));
        assert!(json.contains("\"access_token\":\"oauth-token\""));
    }

    #[test]
    fn test_provider_config_gemini_serialization() {
        let config = ProviderConfig::Gemini {
            api_key: Some("gemini-key".to_string()),
            api_endpoint: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"gemini\""));
        assert!(json.contains("\"api_key\":\"gemini-key\""));
    }

    #[test]
    fn test_provider_config_custom_serialization() {
        let config = ProviderConfig::Custom {
            api_key: Some("sk-custom".to_string()),
            api_endpoint: "http://localhost:4000".to_string(),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\":\"http://localhost:4000\""));
        assert!(json.contains("\"api_key\":\"sk-custom\""));
    }

    #[test]
    fn test_provider_config_custom_without_key() {
        let config = ProviderConfig::Custom {
            api_key: None,
            api_endpoint: "http://localhost:11434/v1".to_string(),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\""));
        assert!(!json.contains("api_key")); }

    // --- ProviderConfig deserialization ---

    #[test]
    fn test_provider_config_deserialization_openai() {
        let json = r#"{"type":"openai","api_key":"sk-test"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::OpenAI { .. }));
        assert_eq!(config.api_key(), Some("sk-test"));
    }

    #[test]
    fn test_provider_config_deserialization_anthropic() {
        let json = r#"{"type":"anthropic","api_key":"sk-ant","access_token":"oauth"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Anthropic { .. }));
        assert_eq!(config.api_key(), Some("sk-ant"));
        assert_eq!(config.access_token(), Some("oauth"));
    }

    #[test]
    fn test_provider_config_deserialization_gemini() {
        let json = r#"{"type":"gemini","api_key":"gemini-key"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Gemini { .. }));
        assert_eq!(config.api_key(), Some("gemini-key"));
    }

    #[test]
    fn test_provider_config_deserialization_custom() {
        let json =
            r#"{"type":"custom","api_endpoint":"http://localhost:4000","api_key":"sk-custom"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Custom { .. }));
        assert_eq!(config.api_key(), Some("sk-custom"));
        assert_eq!(config.api_endpoint(), Some("http://localhost:4000"));
    }

    // --- Constructor and accessor helpers ---

    #[test]
    fn test_provider_config_helper_methods() {
        let openai = ProviderConfig::openai(Some("sk-openai".to_string()));
        assert_eq!(openai.provider_type(), "openai");
        assert_eq!(openai.api_key(), Some("sk-openai"));

        let anthropic =
            ProviderConfig::anthropic(Some("sk-ant".to_string()), Some("oauth".to_string()));
        assert_eq!(anthropic.provider_type(), "anthropic");
        assert_eq!(anthropic.access_token(), Some("oauth"));

        let gemini = ProviderConfig::gemini(Some("gemini-key".to_string()));
        assert_eq!(gemini.provider_type(), "gemini");

        let custom = ProviderConfig::custom(
            "http://localhost:4000".to_string(),
            Some("sk-custom".to_string()),
        );
        assert_eq!(custom.provider_type(), "custom");
        assert_eq!(custom.api_endpoint(), Some("http://localhost:4000"));
    }

    // --- LLMProviderConfig registry ---

    #[test]
    fn test_llm_provider_config_new() {
        let config = LLMProviderConfig::new();
        assert!(config.is_empty());
    }

    #[test]
    fn test_llm_provider_config_add_and_get() {
        let mut config = LLMProviderConfig::new();
        config.add_provider(
            "openai",
            ProviderConfig::openai(Some("sk-test".to_string())),
        );
        config.add_provider(
            "anthropic",
            ProviderConfig::anthropic(Some("sk-ant".to_string()), None),
        );

        assert!(!config.is_empty());
        assert!(config.get_provider("openai").is_some());
        assert!(config.get_provider("anthropic").is_some());
        assert!(config.get_provider("unknown").is_none());
    }

    #[test]
    fn test_provider_config_toml_parsing() {
        let json = r#"{
            "openai": {"type": "openai", "api_key": "sk-openai"},
            "anthropic": {"type": "anthropic", "api_key": "sk-ant", "access_token": "oauth"},
            "litellm": {"type": "custom", "api_endpoint": "http://localhost:4000", "api_key": "sk-litellm"}
        }"#;

        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 3);

        assert!(matches!(
            providers.get("openai"),
            Some(ProviderConfig::OpenAI { .. })
        ));
        assert!(matches!(
            providers.get("anthropic"),
            Some(ProviderConfig::Anthropic { .. })
        ));
        assert!(matches!(
            providers.get("litellm"),
            Some(ProviderConfig::Custom { .. })
        ));
    }

    // --- Bedrock variant ---

    #[test]
    fn test_provider_config_bedrock_serialization() {
        let config = ProviderConfig::Bedrock {
            region: "us-east-1".to_string(),
            profile_name: Some("my-profile".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"amazon-bedrock\""));
        assert!(json.contains("\"region\":\"us-east-1\""));
        assert!(json.contains("\"profile_name\":\"my-profile\""));
    }

    #[test]
    fn test_provider_config_bedrock_serialization_without_profile() {
        let config = ProviderConfig::Bedrock {
            region: "us-west-2".to_string(),
            profile_name: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"amazon-bedrock\""));
        assert!(json.contains("\"region\":\"us-west-2\""));
        assert!(!json.contains("profile_name")); }

    #[test]
    fn test_provider_config_bedrock_deserialization() {
        let json = r#"{"type":"amazon-bedrock","region":"us-east-1","profile_name":"prod"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Bedrock { .. }));
        assert_eq!(config.region(), Some("us-east-1"));
        assert_eq!(config.profile_name(), Some("prod"));
    }

    #[test]
    fn test_provider_config_bedrock_deserialization_minimal() {
        let json = r#"{"type":"amazon-bedrock","region":"eu-west-1"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Bedrock { .. }));
        assert_eq!(config.region(), Some("eu-west-1"));
        assert_eq!(config.profile_name(), None);
    }

    #[test]
    fn test_provider_config_bedrock_no_api_key() {
        let config = ProviderConfig::bedrock("us-east-1".to_string(), None);
        assert_eq!(config.api_key(), None); assert_eq!(config.api_endpoint(), None); }

    #[test]
    fn test_provider_config_bedrock_helper_methods() {
        let bedrock = ProviderConfig::bedrock("us-east-1".to_string(), Some("prod".to_string()));
        assert_eq!(bedrock.provider_type(), "amazon-bedrock");
        assert_eq!(bedrock.region(), Some("us-east-1"));
        assert_eq!(bedrock.profile_name(), Some("prod"));
        assert_eq!(bedrock.api_key(), None);
        assert_eq!(bedrock.api_endpoint(), None);
        assert_eq!(bedrock.access_token(), None);
    }

    #[test]
    fn test_provider_config_bedrock_toml_roundtrip() {
        let config = ProviderConfig::Bedrock {
            region: "us-east-1".to_string(),
            profile_name: Some("my-profile".to_string()),
        };
        let toml_str = toml::to_string(&config).unwrap();
        let parsed: ProviderConfig = toml::from_str(&toml_str).unwrap();
        assert_eq!(config, parsed);
    }

    #[test]
    fn test_provider_config_bedrock_toml_parsing() {
        let toml_str = r#"
            type = "amazon-bedrock"
            region = "us-east-1"
            profile_name = "production"
        "#;
        let config: ProviderConfig = toml::from_str(toml_str).unwrap();
        assert!(matches!(
            config,
            ProviderConfig::Bedrock {
                ref region,
                ref profile_name,
            } if region == "us-east-1" && profile_name.as_deref() == Some("production")
        ));
    }

    #[test]
    fn test_provider_config_bedrock_missing_region_fails() {
        let json = r#"{"type":"amazon-bedrock"}"#;
        let result: Result<ProviderConfig, _> = serde_json::from_str(json);
        // `region` has no default, so deserialization must fail without it.
        assert!(result.is_err()); }

    #[test]
    fn test_provider_config_bedrock_in_providers_map() {
        let json = r#"{
            "anthropic": {"type": "anthropic", "api_key": "sk-ant"},
            "amazon-bedrock": {"type": "amazon-bedrock", "region": "us-east-1"}
        }"#;
        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 2);
        assert!(matches!(
            providers.get("amazon-bedrock"),
            Some(ProviderConfig::Bedrock { .. })
        ));
    }

    #[test]
    fn test_region_returns_none_for_non_bedrock() {
        let openai = ProviderConfig::openai(Some("key".to_string()));
        assert_eq!(openai.region(), None);

        let anthropic = ProviderConfig::anthropic(Some("key".to_string()), None);
        assert_eq!(anthropic.region(), None);
    }

    #[test]
    fn test_profile_name_returns_none_for_non_bedrock() {
        let openai = ProviderConfig::openai(Some("key".to_string()));
        assert_eq!(openai.profile_name(), None);
    }
}