1use serde::{Deserialize, Serialize};
55use stakai::Model;
56use std::collections::HashMap;
57
/// Configuration for a single LLM provider backend.
///
/// Internally tagged by serde via a `type` field; variant names are lowercased
/// on the wire (e.g. `{"type": "openai", ...}`), except `Bedrock`, which is
/// explicitly renamed to `"amazon-bedrock"`.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ProviderConfig {
    /// OpenAI-hosted models.
    OpenAI {
        /// Optional API key; omitted from serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Optional base-URL override; `None` means the provider default.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
    /// Anthropic-hosted models.
    Anthropic {
        /// Optional API key.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Optional base-URL override.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Optional access token; only this variant carries one
        /// (see [`ProviderConfig::access_token`]).
        #[serde(skip_serializing_if = "Option::is_none")]
        access_token: Option<String>,
    },
    /// Google Gemini-hosted models.
    Gemini {
        /// Optional API key.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Optional base-URL override.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
    /// A custom provider; unlike the hosted variants, the endpoint is required.
    Custom {
        /// Optional API key (e.g. a local gateway may not need one).
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Required endpoint URL of the custom provider.
        api_endpoint: String,
    },
    /// The Stakpak provider; an API key is mandatory here.
    Stakpak {
        /// Required API key.
        api_key: String,
        /// Optional base-URL override.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
    /// Amazon Bedrock, addressed by AWS region rather than an endpoint/key.
    #[serde(rename = "amazon-bedrock")]
    Bedrock {
        /// Required AWS region (e.g. "us-east-1").
        region: String,
        /// Optional named credentials profile — presumably an AWS CLI/SDK
        /// profile; confirm against the Bedrock client setup.
        #[serde(skip_serializing_if = "Option::is_none")]
        profile_name: Option<String>,
    },
}
188
189impl ProviderConfig {
190 pub fn provider_type(&self) -> &'static str {
192 match self {
193 ProviderConfig::OpenAI { .. } => "openai",
194 ProviderConfig::Anthropic { .. } => "anthropic",
195 ProviderConfig::Gemini { .. } => "gemini",
196 ProviderConfig::Custom { .. } => "custom",
197 ProviderConfig::Stakpak { .. } => "stakpak",
198 ProviderConfig::Bedrock { .. } => "amazon-bedrock",
199 }
200 }
201
202 pub fn api_key(&self) -> Option<&str> {
204 match self {
205 ProviderConfig::OpenAI { api_key, .. } => api_key.as_deref(),
206 ProviderConfig::Anthropic { api_key, .. } => api_key.as_deref(),
207 ProviderConfig::Gemini { api_key, .. } => api_key.as_deref(),
208 ProviderConfig::Custom { api_key, .. } => api_key.as_deref(),
209 ProviderConfig::Stakpak { api_key, .. } => Some(api_key.as_str()),
210 ProviderConfig::Bedrock { .. } => None, }
212 }
213
214 pub fn api_endpoint(&self) -> Option<&str> {
216 match self {
217 ProviderConfig::OpenAI { api_endpoint, .. } => api_endpoint.as_deref(),
218 ProviderConfig::Anthropic { api_endpoint, .. } => api_endpoint.as_deref(),
219 ProviderConfig::Gemini { api_endpoint, .. } => api_endpoint.as_deref(),
220 ProviderConfig::Custom { api_endpoint, .. } => Some(api_endpoint.as_str()),
221 ProviderConfig::Stakpak { api_endpoint, .. } => api_endpoint.as_deref(),
222 ProviderConfig::Bedrock { .. } => None, }
224 }
225
226 pub fn set_api_endpoint(&mut self, endpoint: Option<String>) {
231 match self {
232 ProviderConfig::OpenAI { api_endpoint, .. }
233 | ProviderConfig::Anthropic { api_endpoint, .. }
234 | ProviderConfig::Gemini { api_endpoint, .. }
235 | ProviderConfig::Stakpak { api_endpoint, .. } => {
236 *api_endpoint = endpoint;
237 }
238 ProviderConfig::Custom { api_endpoint, .. } => {
239 if let Some(custom_endpoint) = endpoint {
240 *api_endpoint = custom_endpoint;
241 }
242 }
243 ProviderConfig::Bedrock { .. } => {}
244 }
245 }
246
247 pub fn access_token(&self) -> Option<&str> {
249 match self {
250 ProviderConfig::Anthropic { access_token, .. } => access_token.as_deref(),
251 _ => None,
252 }
253 }
254
255 pub fn openai(api_key: Option<String>) -> Self {
257 ProviderConfig::OpenAI {
258 api_key,
259 api_endpoint: None,
260 }
261 }
262
263 pub fn anthropic(api_key: Option<String>, access_token: Option<String>) -> Self {
265 ProviderConfig::Anthropic {
266 api_key,
267 api_endpoint: None,
268 access_token,
269 }
270 }
271
272 pub fn gemini(api_key: Option<String>) -> Self {
274 ProviderConfig::Gemini {
275 api_key,
276 api_endpoint: None,
277 }
278 }
279
280 pub fn custom(api_endpoint: String, api_key: Option<String>) -> Self {
282 ProviderConfig::Custom {
283 api_key,
284 api_endpoint,
285 }
286 }
287
288 pub fn stakpak(api_key: String, api_endpoint: Option<String>) -> Self {
290 ProviderConfig::Stakpak {
291 api_key,
292 api_endpoint,
293 }
294 }
295
296 pub fn bedrock(region: String, profile_name: Option<String>) -> Self {
298 ProviderConfig::Bedrock {
299 region,
300 profile_name,
301 }
302 }
303
304 pub fn region(&self) -> Option<&str> {
306 match self {
307 ProviderConfig::Bedrock { region, .. } => Some(region.as_str()),
308 _ => None,
309 }
310 }
311
312 pub fn profile_name(&self) -> Option<&str> {
314 match self {
315 ProviderConfig::Bedrock { profile_name, .. } => profile_name.as_deref(),
316 _ => None,
317 }
318 }
319}
320
/// Registry of named [`ProviderConfig`] entries, keyed by a user-chosen name
/// (e.g. "openai", "litellm").
#[derive(Debug, Clone, Default)]
pub struct LLMProviderConfig {
    /// Map from provider name to its configuration.
    pub providers: HashMap<String, ProviderConfig>,
}
329
330impl LLMProviderConfig {
331 pub fn new() -> Self {
333 Self {
334 providers: HashMap::new(),
335 }
336 }
337
338 pub fn add_provider(&mut self, name: impl Into<String>, config: ProviderConfig) {
340 self.providers.insert(name.into(), config);
341 }
342
343 pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
345 self.providers.get(name)
346 }
347
348 pub fn is_empty(&self) -> bool {
350 self.providers.is_empty()
351 }
352}
353
/// Per-provider tuning options attached to a request; each section is only
/// serialized when present.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMProviderOptions {
    /// Anthropic-specific options (e.g. extended thinking).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub anthropic: Option<LLMAnthropicOptions>,

    /// OpenAI-specific options (e.g. reasoning effort).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub openai: Option<LLMOpenAIOptions>,

    /// Google-specific options (e.g. thinking budget).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub google: Option<LLMGoogleOptions>,
}
369
/// Options specific to Anthropic models.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMAnthropicOptions {
    /// Extended-thinking configuration; omitted when disabled.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<LLMThinkingOptions>,
}
377
/// Thinking/reasoning budget configuration.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LLMThinkingOptions {
    /// Token budget for thinking; [`LLMThinkingOptions::new`] enforces a
    /// 1024-token floor.
    pub budget_tokens: u32,
}
384
385impl LLMThinkingOptions {
386 pub fn new(budget_tokens: u32) -> Self {
387 Self {
388 budget_tokens: budget_tokens.max(1024),
389 }
390 }
391}
392
/// Options specific to OpenAI models.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMOpenAIOptions {
    /// Reasoning-effort level as a free-form string — presumably values like
    /// "low"/"medium"/"high"; confirm against the API caller.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
}
400
/// Options specific to Google (Gemini) models.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMGoogleOptions {
    /// Token budget for model thinking; omitted when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<u32>,
}
408
409#[derive(Clone, Debug, Serialize)]
410pub struct LLMInput {
411 pub model: Model,
412 pub messages: Vec<LLMMessage>,
413 pub max_tokens: u32,
414 pub tools: Option<Vec<LLMTool>>,
415 #[serde(skip_serializing_if = "Option::is_none")]
416 pub provider_options: Option<LLMProviderOptions>,
417 #[serde(skip_serializing_if = "Option::is_none")]
419 pub headers: Option<std::collections::HashMap<String, String>>,
420}
421
422#[derive(Debug)]
423pub struct LLMStreamInput {
424 pub model: Model,
425 pub messages: Vec<LLMMessage>,
426 pub max_tokens: u32,
427 pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
428 pub tools: Option<Vec<LLMTool>>,
429 pub provider_options: Option<LLMProviderOptions>,
430 pub headers: Option<std::collections::HashMap<String, String>>,
432}
433
434impl From<&LLMStreamInput> for LLMInput {
435 fn from(value: &LLMStreamInput) -> Self {
436 LLMInput {
437 model: value.model.clone(),
438 messages: value.messages.clone(),
439 max_tokens: value.max_tokens,
440 tools: value.tools.clone(),
441 provider_options: value.provider_options.clone(),
442 headers: value.headers.clone(),
443 }
444 }
445}
446
/// A chat message with a free-form string role and string-or-typed content.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct LLMMessage {
    /// Message role as a raw string (unlike the closed [`SimpleLLMRole`] enum).
    pub role: String,
    /// Either one plain string or a list of typed parts.
    pub content: LLMMessageContent,
}
452
453#[derive(Serialize, Deserialize, Debug, Clone)]
454pub struct SimpleLLMMessage {
455 #[serde(rename = "role")]
456 pub role: SimpleLLMRole,
457 pub content: String,
458}
459
/// Role of a [`SimpleLLMMessage`], limited to the two conversational parties.
///
/// Serialized in lowercase: `"user"` / `"assistant"`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SimpleLLMRole {
    User,
    Assistant,
}
466
467impl std::fmt::Display for SimpleLLMRole {
468 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
469 match self {
470 SimpleLLMRole::User => write!(f, "user"),
471 SimpleLLMRole::Assistant => write!(f, "assistant"),
472 }
473 }
474}
475
/// Message content: either one plain string or a list of typed parts.
///
/// `untagged` lets serde accept both shapes from the same JSON field.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum LLMMessageContent {
    String(String),
    List(Vec<LLMMessageTypedContent>),
}
482
483#[allow(clippy::to_string_trait_impl)]
484impl ToString for LLMMessageContent {
485 fn to_string(&self) -> String {
486 match self {
487 LLMMessageContent::String(s) => s.clone(),
488 LLMMessageContent::List(l) => l
489 .iter()
490 .map(|c| match c {
491 LLMMessageTypedContent::Text { text } => text.clone(),
492 LLMMessageTypedContent::ToolCall { .. } => String::new(),
493 LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
494 LLMMessageTypedContent::Image { .. } => String::new(),
495 })
496 .collect::<Vec<_>>()
497 .join("\n"),
498 }
499 }
500}
501
502impl From<String> for LLMMessageContent {
503 fn from(value: String) -> Self {
504 LLMMessageContent::String(value)
505 }
506}
507
508impl Default for LLMMessageContent {
509 fn default() -> Self {
510 LLMMessageContent::String(String::new())
511 }
512}
513
514impl LLMMessageContent {
515 pub fn into_parts(self) -> Vec<LLMMessageTypedContent> {
518 match self {
519 LLMMessageContent::List(parts) => parts,
520 LLMMessageContent::String(s) if s.is_empty() => vec![],
521 LLMMessageContent::String(s) => vec![LLMMessageTypedContent::Text { text: s }],
522 }
523 }
524}
525
/// One typed part of a message body.
///
/// Internally tagged by serde via a `type` field: `text`, `tool_use`,
/// `tool_result`, or `image`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum LLMMessageTypedContent {
    /// Plain text.
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool invocation requested by the model.
    #[serde(rename = "tool_use")]
    ToolCall {
        id: String,
        name: String,
        /// Tool arguments; also accepted under the key `input` when
        /// deserializing (Anthropic-style payloads — confirm).
        #[serde(alias = "input")]
        args: serde_json::Value,
        /// Optional opaque extra data attached to the call; skipped when absent.
        #[serde(skip_serializing_if = "Option::is_none")]
        metadata: Option<serde_json::Value>,
    },
    /// The textual result of a previously issued tool call.
    #[serde(rename = "tool_result")]
    ToolResult {
        /// Identifier linking back to the originating tool call's `id`.
        tool_use_id: String,
        content: String,
    },
    /// An inline image payload.
    #[serde(rename = "image")]
    Image { source: LLMMessageImageSource },
}
549
/// Source payload for an `Image` message part.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMMessageImageSource {
    /// Source kind discriminator, serialized as `type` — presumably values
    /// like "base64"; confirm against the providers consuming this.
    #[serde(rename = "type")]
    pub r#type: String,
    /// Media type of the image — assumed to be a MIME string such as
    /// `image/png`; confirm with callers.
    pub media_type: String,
    /// The image payload itself.
    pub data: String,
}
557
558impl Default for LLMMessageTypedContent {
559 fn default() -> Self {
560 LLMMessageTypedContent::Text {
561 text: String::new(),
562 }
563 }
564}
565
/// One candidate completion within a non-streaming response.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMChoice {
    /// Why generation stopped (e.g. provider-reported stop reason); may be absent.
    pub finish_reason: Option<String>,
    /// Position of this choice in the response.
    pub index: u32,
    /// The generated message.
    pub message: LLMMessage,
}
572
/// A complete (non-streaming) completion response.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionResponse {
    /// Model that produced the completion.
    pub model: String,
    /// Response object type label.
    pub object: String,
    /// Candidate completions.
    pub choices: Vec<LLMChoice>,
    /// Creation timestamp — presumably Unix seconds; confirm with producers.
    pub created: u64,
    /// Token accounting, when the provider reports it.
    pub usage: Option<LLMTokenUsage>,
    /// Response identifier.
    pub id: String,
}
582
/// Incremental content carried by one streaming chunk.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamDelta {
    /// New text appended by this chunk; omitted when the chunk carries none.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}
588
/// One choice within a streaming response chunk.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamChoice {
    /// Stop reason — NOTE(review): presumably only set on the final chunk;
    /// confirm against the stream producer.
    pub finish_reason: Option<String>,
    /// Position of this choice in the response.
    pub index: u32,
    /// Full message, when the provider includes one alongside the delta.
    pub message: Option<LLMMessage>,
    /// The incremental content for this chunk.
    pub delta: LLMStreamDelta,
}
596
597#[derive(Serialize, Deserialize, Debug, Clone)]
598pub struct LLMCompletionStreamResponse {
599 pub model: String,
600 pub object: String,
601 pub choices: Vec<LLMStreamChoice>,
602 pub created: u64,
603 #[serde(skip_serializing_if = "Option::is_none")]
604 pub usage: Option<LLMTokenUsage>,
605 pub id: String,
606 pub citations: Option<Vec<String>>,
607}
608
/// Declaration of a tool the model may call.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMTool {
    /// Tool name the model refers to in tool calls.
    pub name: String,
    /// Human/model-readable description of what the tool does.
    pub description: String,
    /// JSON Schema describing the tool's arguments.
    pub input_schema: serde_json::Value,
}
615
/// Aggregate token usage for a request, in the familiar
/// prompt/completion/total triplet.
#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMTokenUsage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens produced by the completion.
    pub completion_tokens: u32,
    /// Sum of prompt and completion tokens — NOTE(review): assumed; this type
    /// does not enforce the invariant.
    pub total_tokens: u32,

    /// Finer-grained prompt breakdown (cache reads/writes), when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
625
/// Categories of token counters tracked in usage accounting.
///
/// Serialized in `snake_case` (e.g. `cache_read_input_tokens`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    InputTokens,
    OutputTokens,
    CacheReadInputTokens,
    CacheWriteInputTokens,
}
634
/// Fine-grained prompt-token breakdown; each counter is `None` when the
/// provider did not report it (treated as zero by `iter`/`Add`/`AddAssign`).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct PromptTokensDetails {
    /// Plain input tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    /// Output tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    /// Input tokens served from a prompt cache.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
    /// Input tokens written into a prompt cache.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_input_tokens: Option<u32>,
}
646
647impl PromptTokensDetails {
648 pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
650 [
651 (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
652 (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
653 (
654 TokenType::CacheReadInputTokens,
655 self.cache_read_input_tokens.unwrap_or(0),
656 ),
657 (
658 TokenType::CacheWriteInputTokens,
659 self.cache_write_input_tokens.unwrap_or(0),
660 ),
661 ]
662 .into_iter()
663 }
664}
665
666impl std::ops::Add for PromptTokensDetails {
667 type Output = Self;
668
669 fn add(self, rhs: Self) -> Self::Output {
670 Self {
671 input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
672 output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
673 cache_read_input_tokens: Some(
674 self.cache_read_input_tokens.unwrap_or(0)
675 + rhs.cache_read_input_tokens.unwrap_or(0),
676 ),
677 cache_write_input_tokens: Some(
678 self.cache_write_input_tokens.unwrap_or(0)
679 + rhs.cache_write_input_tokens.unwrap_or(0),
680 ),
681 }
682 }
683}
684
685impl std::ops::AddAssign for PromptTokensDetails {
686 fn add_assign(&mut self, rhs: Self) {
687 self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
688 self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
689 self.cache_read_input_tokens = Some(
690 self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
691 );
692 self.cache_write_input_tokens = Some(
693 self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
694 );
695 }
696}
697
/// One incremental event emitted while streaming a generation.
///
/// Internally tagged by serde via a `type` field.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum GenerationDelta {
    /// A chunk of assistant output text.
    Content { content: String },
    /// A chunk of thinking/reasoning text.
    Thinking { thinking: String },
    /// A partial tool-use update.
    ToolUse { tool_use: GenerationDeltaToolUse },
    /// Token-usage accounting.
    Usage { usage: LLMTokenUsage },
    /// Arbitrary metadata as raw JSON.
    Metadata { metadata: serde_json::Value },
}
707
/// Incremental fields of a streamed tool call.
///
/// NOTE(review): the `Option` fields suggest the id/name and the input
/// arguments arrive piecemeal across successive deltas — confirm against the
/// stream-handling code.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenerationDeltaToolUse {
    /// Tool-call identifier, when present in this fragment.
    pub id: Option<String>,
    /// Tool name, when present in this fragment.
    pub name: Option<String>,
    /// A fragment of the (serialized) tool arguments.
    pub input: Option<String>,
    /// Index distinguishing concurrent tool calls — assumed; confirm.
    pub index: usize,
    /// Optional opaque extra data; skipped when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
718
#[cfg(test)]
mod tests {
    use super::*;

    // Fix in this module: `test_provider_config_toml_parsing` parsed JSON with
    // serde_json, not TOML — renamed to `test_provider_config_json_providers_map`
    // (the real TOML coverage lives in the `*_toml_*` Bedrock tests below).

    #[test]
    fn test_provider_config_openai_serialization() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"openai\""));
        assert!(json.contains("\"api_key\":\"sk-test\""));
        // `None` optional fields are skipped entirely, not serialized as null.
        assert!(!json.contains("api_endpoint"));
    }

    #[test]
    fn test_provider_config_openai_with_endpoint() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: Some("https://custom.openai.com/v1".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"api_endpoint\":\"https://custom.openai.com/v1\""));
    }

    #[test]
    fn test_provider_config_anthropic_serialization() {
        let config = ProviderConfig::Anthropic {
            api_key: Some("sk-ant-test".to_string()),
            api_endpoint: None,
            access_token: Some("oauth-token".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"anthropic\""));
        assert!(json.contains("\"api_key\":\"sk-ant-test\""));
        assert!(json.contains("\"access_token\":\"oauth-token\""));
    }

    #[test]
    fn test_provider_config_gemini_serialization() {
        let config = ProviderConfig::Gemini {
            api_key: Some("gemini-key".to_string()),
            api_endpoint: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"gemini\""));
        assert!(json.contains("\"api_key\":\"gemini-key\""));
    }

    #[test]
    fn test_provider_config_custom_serialization() {
        let config = ProviderConfig::Custom {
            api_key: Some("sk-custom".to_string()),
            api_endpoint: "http://localhost:4000".to_string(),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\":\"http://localhost:4000\""));
        assert!(json.contains("\"api_key\":\"sk-custom\""));
    }

    #[test]
    fn test_provider_config_custom_without_key() {
        let config = ProviderConfig::Custom {
            api_key: None,
            api_endpoint: "http://localhost:11434/v1".to_string(),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\""));
        assert!(!json.contains("api_key"));
    }

    #[test]
    fn test_provider_config_deserialization_openai() {
        let json = r#"{"type":"openai","api_key":"sk-test"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::OpenAI { .. }));
        assert_eq!(config.api_key(), Some("sk-test"));
    }

    #[test]
    fn test_provider_config_deserialization_anthropic() {
        let json = r#"{"type":"anthropic","api_key":"sk-ant","access_token":"oauth"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Anthropic { .. }));
        assert_eq!(config.api_key(), Some("sk-ant"));
        assert_eq!(config.access_token(), Some("oauth"));
    }

    #[test]
    fn test_provider_config_deserialization_gemini() {
        let json = r#"{"type":"gemini","api_key":"gemini-key"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Gemini { .. }));
        assert_eq!(config.api_key(), Some("gemini-key"));
    }

    #[test]
    fn test_provider_config_deserialization_custom() {
        let json =
            r#"{"type":"custom","api_endpoint":"http://localhost:4000","api_key":"sk-custom"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Custom { .. }));
        assert_eq!(config.api_key(), Some("sk-custom"));
        assert_eq!(config.api_endpoint(), Some("http://localhost:4000"));
    }

    #[test]
    fn test_provider_config_helper_methods() {
        let openai = ProviderConfig::openai(Some("sk-openai".to_string()));
        assert_eq!(openai.provider_type(), "openai");
        assert_eq!(openai.api_key(), Some("sk-openai"));

        let anthropic =
            ProviderConfig::anthropic(Some("sk-ant".to_string()), Some("oauth".to_string()));
        assert_eq!(anthropic.provider_type(), "anthropic");
        assert_eq!(anthropic.access_token(), Some("oauth"));

        let gemini = ProviderConfig::gemini(Some("gemini-key".to_string()));
        assert_eq!(gemini.provider_type(), "gemini");

        let custom = ProviderConfig::custom(
            "http://localhost:4000".to_string(),
            Some("sk-custom".to_string()),
        );
        assert_eq!(custom.provider_type(), "custom");
        assert_eq!(custom.api_endpoint(), Some("http://localhost:4000"));
    }

    #[test]
    fn test_set_api_endpoint_updates_supported_providers() {
        let mut openai = ProviderConfig::openai(Some("sk-openai".to_string()));
        openai.set_api_endpoint(Some("https://proxy.example.com/v1".to_string()));
        assert_eq!(openai.api_endpoint(), Some("https://proxy.example.com/v1"));

        // Bedrock has no endpoint, so the setter must be a no-op.
        let mut bedrock = ProviderConfig::bedrock("us-east-1".to_string(), None);
        bedrock.set_api_endpoint(Some("https://ignored.example.com".to_string()));
        assert_eq!(bedrock.api_endpoint(), None);
    }

    #[test]
    fn test_llm_provider_config_new() {
        let config = LLMProviderConfig::new();
        assert!(config.is_empty());
    }

    #[test]
    fn test_llm_provider_config_add_and_get() {
        let mut config = LLMProviderConfig::new();
        config.add_provider(
            "openai",
            ProviderConfig::openai(Some("sk-test".to_string())),
        );
        config.add_provider(
            "anthropic",
            ProviderConfig::anthropic(Some("sk-ant".to_string()), None),
        );

        assert!(!config.is_empty());
        assert!(config.get_provider("openai").is_some());
        assert!(config.get_provider("anthropic").is_some());
        assert!(config.get_provider("unknown").is_none());
    }

    #[test]
    fn test_provider_config_json_providers_map() {
        let json = r#"{
            "openai": {"type": "openai", "api_key": "sk-openai"},
            "anthropic": {"type": "anthropic", "api_key": "sk-ant", "access_token": "oauth"},
            "litellm": {"type": "custom", "api_endpoint": "http://localhost:4000", "api_key": "sk-litellm"}
        }"#;

        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 3);

        assert!(matches!(
            providers.get("openai"),
            Some(ProviderConfig::OpenAI { .. })
        ));
        assert!(matches!(
            providers.get("anthropic"),
            Some(ProviderConfig::Anthropic { .. })
        ));
        assert!(matches!(
            providers.get("litellm"),
            Some(ProviderConfig::Custom { .. })
        ));
    }

    #[test]
    fn test_provider_config_bedrock_serialization() {
        let config = ProviderConfig::Bedrock {
            region: "us-east-1".to_string(),
            profile_name: Some("my-profile".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"amazon-bedrock\""));
        assert!(json.contains("\"region\":\"us-east-1\""));
        assert!(json.contains("\"profile_name\":\"my-profile\""));
    }

    #[test]
    fn test_provider_config_bedrock_serialization_without_profile() {
        let config = ProviderConfig::Bedrock {
            region: "us-west-2".to_string(),
            profile_name: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"amazon-bedrock\""));
        assert!(json.contains("\"region\":\"us-west-2\""));
        assert!(!json.contains("profile_name"));
    }

    #[test]
    fn test_provider_config_bedrock_deserialization() {
        let json = r#"{"type":"amazon-bedrock","region":"us-east-1","profile_name":"prod"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Bedrock { .. }));
        assert_eq!(config.region(), Some("us-east-1"));
        assert_eq!(config.profile_name(), Some("prod"));
    }

    #[test]
    fn test_provider_config_bedrock_deserialization_minimal() {
        let json = r#"{"type":"amazon-bedrock","region":"eu-west-1"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Bedrock { .. }));
        assert_eq!(config.region(), Some("eu-west-1"));
        assert_eq!(config.profile_name(), None);
    }

    #[test]
    fn test_provider_config_bedrock_no_api_key() {
        let config = ProviderConfig::bedrock("us-east-1".to_string(), None);
        assert_eq!(config.api_key(), None);
        assert_eq!(config.api_endpoint(), None);
    }

    #[test]
    fn test_provider_config_bedrock_helper_methods() {
        let bedrock = ProviderConfig::bedrock("us-east-1".to_string(), Some("prod".to_string()));
        assert_eq!(bedrock.provider_type(), "amazon-bedrock");
        assert_eq!(bedrock.region(), Some("us-east-1"));
        assert_eq!(bedrock.profile_name(), Some("prod"));
        assert_eq!(bedrock.api_key(), None);
        assert_eq!(bedrock.api_endpoint(), None);
        assert_eq!(bedrock.access_token(), None);
    }

    #[test]
    fn test_provider_config_bedrock_toml_roundtrip() {
        let config = ProviderConfig::Bedrock {
            region: "us-east-1".to_string(),
            profile_name: Some("my-profile".to_string()),
        };
        let toml_str = toml::to_string(&config).unwrap();
        let parsed: ProviderConfig = toml::from_str(&toml_str).unwrap();
        assert_eq!(config, parsed);
    }

    #[test]
    fn test_provider_config_bedrock_toml_parsing() {
        let toml_str = r#"
            type = "amazon-bedrock"
            region = "us-east-1"
            profile_name = "production"
        "#;
        let config: ProviderConfig = toml::from_str(toml_str).unwrap();
        assert!(matches!(
            config,
            ProviderConfig::Bedrock {
                ref region,
                ref profile_name,
            } if region == "us-east-1" && profile_name.as_deref() == Some("production")
        ));
    }

    #[test]
    fn test_provider_config_bedrock_missing_region_fails() {
        // `region` has no default, so deserialization must reject its absence.
        let json = r#"{"type":"amazon-bedrock"}"#;
        let result: Result<ProviderConfig, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_provider_config_bedrock_in_providers_map() {
        let json = r#"{
            "anthropic": {"type": "anthropic", "api_key": "sk-ant"},
            "amazon-bedrock": {"type": "amazon-bedrock", "region": "us-east-1"}
        }"#;
        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 2);
        assert!(matches!(
            providers.get("amazon-bedrock"),
            Some(ProviderConfig::Bedrock { .. })
        ));
    }

    #[test]
    fn test_region_returns_none_for_non_bedrock() {
        let openai = ProviderConfig::openai(Some("key".to_string()));
        assert_eq!(openai.region(), None);

        let anthropic = ProviderConfig::anthropic(Some("key".to_string()), None);
        assert_eq!(anthropic.region(), None);
    }

    #[test]
    fn test_profile_name_returns_none_for_non_bedrock() {
        let openai = ProviderConfig::openai(Some("key".to_string()));
        assert_eq!(openai.profile_name(), None);
    }
}