1use serde::{Deserialize, Serialize};
55use stakai::Model;
56use std::collections::HashMap;
57
58use super::auth::ProviderAuth;
59
/// Configuration for a single LLM provider, serialized with an internal
/// `type` tag (e.g. `{"type": "openai", ...}`) so it round-trips through
/// JSON/TOML config files.
///
/// Two credential styles coexist: legacy flat fields (`api_key`,
/// `access_token`) and the newer structured `auth: ProviderAuth`. The
/// accessor methods on this type give `auth` precedence over the flat
/// fields when both are present.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ProviderConfig {
    /// OpenAI (or an OpenAI-compatible endpoint via `api_endpoint`).
    OpenAI {
        /// Legacy API key; superseded by `auth` when both are set.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Optional override of the default API base URL.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Structured credentials; takes precedence over `api_key`.
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// Anthropic; supports both an API key and a legacy OAuth access token.
    Anthropic {
        /// Legacy API key; superseded by `auth` when both are set.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Optional override of the default API base URL.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Legacy OAuth access token; consulted only when `api_key` and
        /// `auth` are absent (see `get_auth`/`access_token`).
        #[serde(skip_serializing_if = "Option::is_none")]
        access_token: Option<String>,
        /// Structured credentials; takes precedence over the flat fields.
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// Google Gemini.
    Gemini {
        /// Legacy API key; superseded by `auth` when both are set.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Optional override of the default API base URL.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Structured credentials; takes precedence over `api_key`.
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// Any OpenAI-compatible self-hosted/proxy provider (LiteLLM, Ollama,
    /// ...). The endpoint is mandatory here, unlike the named providers.
    Custom {
        /// Optional API key; superseded by `auth` when both are set.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Required base URL of the provider.
        api_endpoint: String,
        /// Structured credentials; takes precedence over `api_key`.
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// Stakpak's own API.
    Stakpak {
        /// Legacy API key; superseded by `auth` when both are set.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Optional override of the default API base URL.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Structured credentials; takes precedence over `api_key`.
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// AWS Bedrock. Carries no key/auth fields: authentication comes from
    /// the AWS credential chain (region + optional named profile).
    #[serde(rename = "amazon-bedrock")]
    Bedrock {
        /// AWS region to call (required).
        region: String,
        /// Optional AWS shared-config profile name.
        #[serde(skip_serializing_if = "Option::is_none")]
        profile_name: Option<String>,
    },
}
221
impl ProviderConfig {
    /// Stable identifier of this variant; matches the serde `type` tag
    /// used in config files (see the enum's `rename_all`/`rename` attrs).
    pub fn provider_type(&self) -> &'static str {
        match self {
            ProviderConfig::OpenAI { .. } => "openai",
            ProviderConfig::Anthropic { .. } => "anthropic",
            ProviderConfig::Gemini { .. } => "gemini",
            ProviderConfig::Custom { .. } => "custom",
            ProviderConfig::Stakpak { .. } => "stakpak",
            ProviderConfig::Bedrock { .. } => "amazon-bedrock",
        }
    }

    /// Effective API key for this provider.
    ///
    /// Precedence: a key held inside the structured `auth` field wins over
    /// the legacy flat `api_key` field. Bedrock never has a key (AWS
    /// credentials are used instead).
    pub fn api_key(&self) -> Option<&str> {
        // `api_key_value` presumably yields the key only when `auth` holds
        // an API-key style credential — confirm against ProviderAuth.
        if let Some(auth) = self.get_auth_ref()
            && let Some(key) = auth.api_key_value()
        {
            return Some(key);
        }
        match self {
            ProviderConfig::OpenAI { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Anthropic { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Gemini { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Custom { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Stakpak { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Bedrock { .. } => None,
        }
    }

    /// Borrow the structured `auth` field, if the variant has one.
    fn get_auth_ref(&self) -> Option<&ProviderAuth> {
        match self {
            ProviderConfig::OpenAI { auth, .. } => auth.as_ref(),
            ProviderConfig::Anthropic { auth, .. } => auth.as_ref(),
            ProviderConfig::Gemini { auth, .. } => auth.as_ref(),
            ProviderConfig::Custom { auth, .. } => auth.as_ref(),
            ProviderConfig::Stakpak { auth, .. } => auth.as_ref(),
            ProviderConfig::Bedrock { .. } => None,
        }
    }

    /// Resolve credentials as a `ProviderAuth`, synthesizing one from the
    /// legacy flat fields when no structured `auth` is stored.
    ///
    /// Precedence: stored `auth` > flat `api_key` > (Anthropic only) flat
    /// `access_token`. Returns `None` for Bedrock or when no credential is
    /// configured.
    pub fn get_auth(&self) -> Option<ProviderAuth> {
        if let Some(auth) = self.get_auth_ref() {
            return Some(auth.clone());
        }

        match self {
            ProviderConfig::OpenAI { api_key, .. }
            | ProviderConfig::Gemini { api_key, .. }
            | ProviderConfig::Custom { api_key, .. }
            | ProviderConfig::Stakpak { api_key, .. } => {
                api_key.as_ref().map(ProviderAuth::api_key)
            }
            ProviderConfig::Anthropic {
                api_key,
                access_token,
                ..
            } => {
                if let Some(key) = api_key {
                    Some(ProviderAuth::api_key(key))
                } else {
                    // NOTE(review): the empty refresh token and 0 expiry
                    // look like placeholders for a legacy token with no
                    // refresh metadata — confirm ProviderAuth::oauth treats
                    // them as "unknown" rather than "expired".
                    access_token
                        .as_ref()
                        .map(|token| ProviderAuth::oauth(token, "", 0))
                }
            }
            ProviderConfig::Bedrock { .. } => None,
        }
    }

    /// Store structured credentials, clearing the legacy flat fields so
    /// there is exactly one source of truth afterwards. No-op for Bedrock.
    pub fn set_auth(&mut self, auth: ProviderAuth) {
        match self {
            ProviderConfig::OpenAI {
                auth: auth_field,
                api_key,
                ..
            }
            | ProviderConfig::Gemini {
                auth: auth_field,
                api_key,
                ..
            }
            | ProviderConfig::Custom {
                auth: auth_field,
                api_key,
                ..
            }
            | ProviderConfig::Stakpak {
                auth: auth_field,
                api_key,
                ..
            } => {
                *auth_field = Some(auth);
                *api_key = None;
            }
            ProviderConfig::Anthropic {
                auth: auth_field,
                api_key,
                access_token,
                ..
            } => {
                *auth_field = Some(auth);
                *api_key = None;
                // Anthropic also carries a legacy OAuth token to clear.
                *access_token = None;
            }
            ProviderConfig::Bedrock { .. } => {
                // Bedrock authenticates via AWS credentials; nothing to set.
            }
        }
    }

    /// Remove every stored credential (structured and legacy). No-op for
    /// Bedrock, which stores none.
    pub fn clear_auth(&mut self) {
        match self {
            ProviderConfig::OpenAI {
                auth: auth_field,
                api_key,
                ..
            }
            | ProviderConfig::Gemini {
                auth: auth_field,
                api_key,
                ..
            }
            | ProviderConfig::Custom {
                auth: auth_field,
                api_key,
                ..
            }
            | ProviderConfig::Stakpak {
                auth: auth_field,
                api_key,
                ..
            } => {
                *auth_field = None;
                *api_key = None;
            }
            ProviderConfig::Anthropic {
                auth: auth_field,
                api_key,
                access_token,
                ..
            } => {
                *auth_field = None;
                *api_key = None;
                *access_token = None;
            }
            ProviderConfig::Bedrock { .. } => {
                // No credential fields on this variant.
            }
        }
    }

    /// Configured API endpoint, if any. Always `Some` for `Custom` (the
    /// field is mandatory there); always `None` for Bedrock.
    pub fn api_endpoint(&self) -> Option<&str> {
        match self {
            ProviderConfig::OpenAI { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Anthropic { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Gemini { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Custom { api_endpoint, .. } => Some(api_endpoint.as_str()),
            ProviderConfig::Stakpak { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Bedrock { .. } => None,
        }
    }

    /// Update the endpoint override.
    ///
    /// For `Custom`, whose endpoint is mandatory, `None` is ignored rather
    /// than erasing the endpoint. Bedrock has no endpoint; ignored.
    pub fn set_api_endpoint(&mut self, endpoint: Option<String>) {
        match self {
            ProviderConfig::OpenAI { api_endpoint, .. }
            | ProviderConfig::Anthropic { api_endpoint, .. }
            | ProviderConfig::Gemini { api_endpoint, .. }
            | ProviderConfig::Stakpak { api_endpoint, .. } => {
                *api_endpoint = endpoint;
            }
            ProviderConfig::Custom { api_endpoint, .. } => {
                if let Some(custom_endpoint) = endpoint {
                    *api_endpoint = custom_endpoint;
                }
            }
            ProviderConfig::Bedrock { .. } => {}
        }
    }

    /// Effective OAuth access token: a token inside the structured `auth`
    /// wins; otherwise only Anthropic's legacy `access_token` field applies.
    pub fn access_token(&self) -> Option<&str> {
        if let Some(auth) = self.get_auth_ref()
            && let Some(token) = auth.access_token()
        {
            return Some(token);
        }
        match self {
            ProviderConfig::Anthropic { access_token, .. } => access_token.as_deref(),
            _ => None,
        }
    }

    /// OpenAI config with an optional legacy API key and default endpoint.
    pub fn openai(api_key: Option<String>) -> Self {
        ProviderConfig::OpenAI {
            api_key,
            api_endpoint: None,
            auth: None,
        }
    }

    /// OpenAI config using structured credentials only.
    pub fn openai_with_auth(auth: ProviderAuth) -> Self {
        ProviderConfig::OpenAI {
            api_key: None,
            api_endpoint: None,
            auth: Some(auth),
        }
    }

    /// Anthropic config from legacy flat credentials.
    pub fn anthropic(api_key: Option<String>, access_token: Option<String>) -> Self {
        ProviderConfig::Anthropic {
            api_key,
            api_endpoint: None,
            access_token,
            auth: None,
        }
    }

    /// Anthropic config using structured credentials only.
    pub fn anthropic_with_auth(auth: ProviderAuth) -> Self {
        ProviderConfig::Anthropic {
            api_key: None,
            api_endpoint: None,
            access_token: None,
            auth: Some(auth),
        }
    }

    /// Gemini config with an optional legacy API key.
    pub fn gemini(api_key: Option<String>) -> Self {
        ProviderConfig::Gemini {
            api_key,
            api_endpoint: None,
            auth: None,
        }
    }

    /// Gemini config using structured credentials only.
    pub fn gemini_with_auth(auth: ProviderAuth) -> Self {
        ProviderConfig::Gemini {
            api_key: None,
            api_endpoint: None,
            auth: Some(auth),
        }
    }

    /// Custom provider config; the endpoint is required, the key optional.
    pub fn custom(api_endpoint: String, api_key: Option<String>) -> Self {
        ProviderConfig::Custom {
            api_key,
            api_endpoint,
            auth: None,
        }
    }

    /// Custom provider config using structured credentials.
    pub fn custom_with_auth(api_endpoint: String, auth: ProviderAuth) -> Self {
        ProviderConfig::Custom {
            api_key: None,
            api_endpoint,
            auth: Some(auth),
        }
    }

    /// Stakpak config with a required legacy key (note: unlike the other
    /// constructors, the key here is not optional).
    pub fn stakpak(api_key: String, api_endpoint: Option<String>) -> Self {
        ProviderConfig::Stakpak {
            api_key: Some(api_key),
            api_endpoint,
            auth: None,
        }
    }

    /// Stakpak config using structured credentials.
    pub fn stakpak_with_auth(auth: ProviderAuth, api_endpoint: Option<String>) -> Self {
        ProviderConfig::Stakpak {
            api_key: None,
            api_endpoint,
            auth: Some(auth),
        }
    }

    /// Bedrock config for a region and optional AWS profile.
    pub fn bedrock(region: String, profile_name: Option<String>) -> Self {
        ProviderConfig::Bedrock {
            region,
            profile_name,
        }
    }

    /// AWS region; `Some` only for Bedrock.
    pub fn region(&self) -> Option<&str> {
        match self {
            ProviderConfig::Bedrock { region, .. } => Some(region.as_str()),
            _ => None,
        }
    }

    /// AWS profile name; `Some` only for Bedrock when configured.
    pub fn profile_name(&self) -> Option<&str> {
        match self {
            ProviderConfig::Bedrock { profile_name, .. } => profile_name.as_deref(),
            _ => None,
        }
    }

    /// Credential-less skeleton config for a provider name (serde tag).
    ///
    /// Returns `None` for unknown names and for "custom"/"amazon-bedrock",
    /// whose mandatory fields (endpoint / region) cannot be defaulted.
    pub fn empty_for_provider(provider_name: &str) -> Option<Self> {
        match provider_name {
            "openai" => Some(ProviderConfig::OpenAI {
                api_key: None,
                api_endpoint: None,
                auth: None,
            }),
            "anthropic" => Some(ProviderConfig::Anthropic {
                api_key: None,
                api_endpoint: None,
                access_token: None,
                auth: None,
            }),
            "gemini" => Some(ProviderConfig::Gemini {
                api_key: None,
                api_endpoint: None,
                auth: None,
            }),
            "stakpak" => Some(ProviderConfig::Stakpak {
                api_key: None,
                api_endpoint: None,
                auth: None,
            }),
            _ => None,
        }
    }
}
596
/// Registry of provider configurations keyed by a user-chosen name
/// (often, but not necessarily, the provider type tag).
#[derive(Debug, Clone, Default)]
pub struct LLMProviderConfig {
    /// Name → configuration map.
    pub providers: HashMap<String, ProviderConfig>,
}
605
606impl LLMProviderConfig {
607 pub fn new() -> Self {
609 Self {
610 providers: HashMap::new(),
611 }
612 }
613
614 pub fn add_provider(&mut self, name: impl Into<String>, config: ProviderConfig) {
616 self.providers.insert(name.into(), config);
617 }
618
619 pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
621 self.providers.get(name)
622 }
623
624 pub fn is_empty(&self) -> bool {
626 self.providers.is_empty()
627 }
628}
629
/// Per-provider tuning knobs attached to a request; each field is only
/// serialized when set.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMProviderOptions {
    /// Anthropic-specific options (extended thinking).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub anthropic: Option<LLMAnthropicOptions>,

    /// OpenAI-specific options (reasoning effort).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub openai: Option<LLMOpenAIOptions>,

    /// Google-specific options (thinking budget).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub google: Option<LLMGoogleOptions>,
}
645
/// Anthropic request options.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMAnthropicOptions {
    /// Extended-thinking configuration; omitted from serialization when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<LLMThinkingOptions>,
}
653
/// Extended-thinking budget. Construct via [`LLMThinkingOptions::new`],
/// which enforces a floor of 1024 tokens.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LLMThinkingOptions {
    /// Maximum tokens the model may spend on thinking.
    pub budget_tokens: u32,
}
660
661impl LLMThinkingOptions {
662 pub fn new(budget_tokens: u32) -> Self {
663 Self {
664 budget_tokens: budget_tokens.max(1024),
665 }
666 }
667}
668
/// OpenAI request options.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMOpenAIOptions {
    /// Reasoning effort level, passed through as a string (e.g. "low",
    /// "medium", "high" — presumably; confirm against the API in use).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
}
676
/// Google request options.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMGoogleOptions {
    /// Token budget for model thinking; omitted when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<u32>,
}
684
/// A complete, non-streaming LLM request.
#[derive(Clone, Debug, Serialize)]
pub struct LLMInput {
    /// Target model.
    pub model: Model,
    /// Conversation messages sent to the model.
    pub messages: Vec<LLMMessage>,
    /// Completion token cap.
    pub max_tokens: u32,
    /// Tools the model may call, if any.
    pub tools: Option<Vec<LLMTool>>,
    /// Provider-specific tuning options.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<LLMProviderOptions>,
    /// Extra HTTP headers to send with the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<std::collections::HashMap<String, String>>,
}
697
/// A streaming LLM request: the same payload as [`LLMInput`] plus a
/// channel on which incremental [`GenerationDelta`] events are delivered.
#[derive(Debug)]
pub struct LLMStreamInput {
    pub model: Model,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    /// Sender side of the delta stream; the consumer holds the receiver.
    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
    pub tools: Option<Vec<LLMTool>>,
    pub provider_options: Option<LLMProviderOptions>,
    pub headers: Option<std::collections::HashMap<String, String>>,
}
709
710impl From<&LLMStreamInput> for LLMInput {
711 fn from(value: &LLMStreamInput) -> Self {
712 LLMInput {
713 model: value.model.clone(),
714 messages: value.messages.clone(),
715 max_tokens: value.max_tokens,
716 tools: value.tools.clone(),
717 provider_options: value.provider_options.clone(),
718 headers: value.headers.clone(),
719 }
720 }
721}
722
/// A chat message with a free-form role string and typed-or-plain content.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct LLMMessage {
    /// Message role (e.g. "user", "assistant", "system") — stringly typed,
    /// unlike [`SimpleLLMMessage`].
    pub role: String,
    /// Plain string or a list of typed parts (text / tool use / tool
    /// result / image).
    pub content: LLMMessageContent,
}
728
729#[derive(Serialize, Deserialize, Debug, Clone)]
730pub struct SimpleLLMMessage {
731 #[serde(rename = "role")]
732 pub role: SimpleLLMRole,
733 pub content: String,
734}
735
/// Role of a [`SimpleLLMMessage`]; serialized lowercase ("user" /
/// "assistant").
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SimpleLLMRole {
    User,
    Assistant,
}
742
743impl std::fmt::Display for SimpleLLMRole {
744 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
745 match self {
746 SimpleLLMRole::User => write!(f, "user"),
747 SimpleLLMRole::Assistant => write!(f, "assistant"),
748 }
749 }
750}
751
/// Message content: either a bare string or a list of typed parts.
/// `untagged` lets both JSON shapes (`"text"` and `[{...}]`) deserialize.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum LLMMessageContent {
    String(String),
    List(Vec<LLMMessageTypedContent>),
}
758
759#[allow(clippy::to_string_trait_impl)]
760impl ToString for LLMMessageContent {
761 fn to_string(&self) -> String {
762 match self {
763 LLMMessageContent::String(s) => s.clone(),
764 LLMMessageContent::List(l) => l
765 .iter()
766 .map(|c| match c {
767 LLMMessageTypedContent::Text { text } => text.clone(),
768 LLMMessageTypedContent::ToolCall { .. } => String::new(),
769 LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
770 LLMMessageTypedContent::Image { .. } => String::new(),
771 })
772 .collect::<Vec<_>>()
773 .join("\n"),
774 }
775 }
776}
777
778impl From<String> for LLMMessageContent {
779 fn from(value: String) -> Self {
780 LLMMessageContent::String(value)
781 }
782}
783
784impl Default for LLMMessageContent {
785 fn default() -> Self {
786 LLMMessageContent::String(String::new())
787 }
788}
789
790impl LLMMessageContent {
791 pub fn into_parts(self) -> Vec<LLMMessageTypedContent> {
794 match self {
795 LLMMessageContent::List(parts) => parts,
796 LLMMessageContent::String(s) if s.is_empty() => vec![],
797 LLMMessageContent::String(s) => vec![LLMMessageTypedContent::Text { text: s }],
798 }
799 }
800}
801
/// One typed part of a message, tagged by `type` in the serialized form
/// (Anthropic-style tags: "text" / "tool_use" / "tool_result" / "image").
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum LLMMessageTypedContent {
    /// Plain text.
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool invocation requested by the model.
    #[serde(rename = "tool_use")]
    ToolCall {
        id: String,
        name: String,
        /// Tool arguments; accepts the Anthropic field name "input" on
        /// deserialize but serializes as "args".
        #[serde(alias = "input")]
        args: serde_json::Value,
        /// Optional provider metadata carried alongside the call.
        #[serde(skip_serializing_if = "Option::is_none")]
        metadata: Option<serde_json::Value>,
    },
    /// The result of a previously requested tool call.
    #[serde(rename = "tool_result")]
    ToolResult {
        /// Matches the `id` of the originating `ToolCall`.
        tool_use_id: String,
        content: String,
    },
    /// An inline image.
    #[serde(rename = "image")]
    Image { source: LLMMessageImageSource },
}
825
/// Source descriptor for an inline image part.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMMessageImageSource {
    /// Source kind; serialized as "type" (presumably "base64" — confirm
    /// against the producing code).
    #[serde(rename = "type")]
    pub r#type: String,
    /// MIME type, e.g. "image/png".
    pub media_type: String,
    /// Image payload (encoding per `r#type`).
    pub data: String,
}
833
834impl Default for LLMMessageTypedContent {
835 fn default() -> Self {
836 LLMMessageTypedContent::Text {
837 text: String::new(),
838 }
839 }
840}
841
/// One completion choice in an OpenAI-style response.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMChoice {
    /// Why generation stopped (e.g. "stop", "tool_calls"), if reported.
    pub finish_reason: Option<String>,
    /// Position of this choice within the response.
    pub index: u32,
    pub message: LLMMessage,
}
848
/// A full (non-streaming) completion response, OpenAI-style shape.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMChoice>,
    /// Creation timestamp (Unix seconds, per the OpenAI convention —
    /// confirm against the producing provider).
    pub created: u64,
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
}
858
/// Incremental content payload within a streaming choice.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamDelta {
    /// New text appended by this chunk, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}
864
/// One choice within a streaming response chunk.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamChoice {
    /// Set on the final chunk for this choice.
    pub finish_reason: Option<String>,
    pub index: u32,
    /// Full message, when the chunk carries one (typically absent mid-stream).
    pub message: Option<LLMMessage>,
    /// The incremental update for this chunk.
    pub delta: LLMStreamDelta,
}
872
/// One chunk of a streaming completion response, OpenAI-style shape.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionStreamResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMStreamChoice>,
    pub created: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
    // NOTE(review): unlike `usage`, `citations` has no skip_serializing_if,
    // so `"citations": null` is emitted when absent — confirm consumers
    // expect the field before changing this.
    pub citations: Option<Vec<String>>,
}
884
/// A tool the model may invoke.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMTool {
    pub name: String,
    pub description: String,
    /// JSON Schema describing the tool's arguments.
    pub input_schema: serde_json::Value,
}
891
/// Token counts for a single call, OpenAI-style field names.
#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMTokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,

    /// Per-category breakdown (cache reads/writes, etc.), when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
901
/// Categories of token consumption, serialized snake_case (e.g.
/// "cache_read_input_tokens").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    InputTokens,
    OutputTokens,
    CacheReadInputTokens,
    CacheWriteInputTokens,
}
910
/// Optional per-category token counts; `None` means "not reported" and is
/// treated as zero by the arithmetic impls below.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct PromptTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_input_tokens: Option<u32>,
}
922
923impl PromptTokensDetails {
924 pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
926 [
927 (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
928 (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
929 (
930 TokenType::CacheReadInputTokens,
931 self.cache_read_input_tokens.unwrap_or(0),
932 ),
933 (
934 TokenType::CacheWriteInputTokens,
935 self.cache_write_input_tokens.unwrap_or(0),
936 ),
937 ]
938 .into_iter()
939 }
940}
941
942impl std::ops::Add for PromptTokensDetails {
943 type Output = Self;
944
945 fn add(self, rhs: Self) -> Self::Output {
946 Self {
947 input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
948 output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
949 cache_read_input_tokens: Some(
950 self.cache_read_input_tokens.unwrap_or(0)
951 + rhs.cache_read_input_tokens.unwrap_or(0),
952 ),
953 cache_write_input_tokens: Some(
954 self.cache_write_input_tokens.unwrap_or(0)
955 + rhs.cache_write_input_tokens.unwrap_or(0),
956 ),
957 }
958 }
959}
960
961impl std::ops::AddAssign for PromptTokensDetails {
962 fn add_assign(&mut self, rhs: Self) {
963 self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
964 self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
965 self.cache_read_input_tokens = Some(
966 self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
967 );
968 self.cache_write_input_tokens = Some(
969 self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
970 );
971 }
972}
973
/// One incremental event emitted while streaming a generation;
/// serialized with a `type` tag.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum GenerationDelta {
    /// A chunk of assistant output text.
    Content { content: String },
    /// A chunk of model thinking/reasoning text.
    Thinking { thinking: String },
    /// A partial tool-call update.
    ToolUse { tool_use: GenerationDeltaToolUse },
    /// Token-usage report.
    Usage { usage: LLMTokenUsage },
    /// Provider-specific metadata blob.
    Metadata { metadata: serde_json::Value },
}
983
/// Incremental fragment of a streamed tool call.
///
/// `id`/`name`/`input` are each optional because a fragment may carry
/// only part of the call (assumes providers chunk arguments across
/// fragments — confirm against the streaming adapters). `index`
/// identifies which tool call a fragment belongs to.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenerationDeltaToolUse {
    pub id: Option<String>,
    pub name: Option<String>,
    /// Raw (possibly partial) JSON argument text.
    pub input: Option<String>,
    pub index: usize,
    /// Optional provider metadata carried with the call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
994
#[cfg(test)]
mod tests {
    //! Serialization round-trips and accessor behavior for the provider
    //! config types, grouped as: serialization, deserialization, helper
    //! constructors/accessors, registry, and Bedrock-specific cases.
    use super::*;

    // --- ProviderConfig serialization ---

    #[test]
    fn test_provider_config_openai_serialization() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: None,
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"openai\""));
        assert!(json.contains("\"api_key\":\"sk-test\""));
        // `None` fields are skipped entirely (skip_serializing_if).
        assert!(!json.contains("api_endpoint"));
    }

    #[test]
    fn test_provider_config_openai_with_endpoint() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: Some("https://custom.openai.com/v1".to_string()),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"api_endpoint\":\"https://custom.openai.com/v1\""));
    }

    #[test]
    fn test_provider_config_anthropic_serialization() {
        let config = ProviderConfig::Anthropic {
            api_key: Some("sk-ant-test".to_string()),
            api_endpoint: None,
            access_token: Some("oauth-token".to_string()),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"anthropic\""));
        assert!(json.contains("\"api_key\":\"sk-ant-test\""));
        assert!(json.contains("\"access_token\":\"oauth-token\""));
    }

    #[test]
    fn test_provider_config_gemini_serialization() {
        let config = ProviderConfig::Gemini {
            api_key: Some("gemini-key".to_string()),
            api_endpoint: None,
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"gemini\""));
        assert!(json.contains("\"api_key\":\"gemini-key\""));
    }

    #[test]
    fn test_provider_config_custom_serialization() {
        let config = ProviderConfig::Custom {
            api_key: Some("sk-custom".to_string()),
            api_endpoint: "http://localhost:4000".to_string(),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\":\"http://localhost:4000\""));
        assert!(json.contains("\"api_key\":\"sk-custom\""));
    }

    #[test]
    fn test_provider_config_custom_without_key() {
        let config = ProviderConfig::Custom {
            api_key: None,
            api_endpoint: "http://localhost:11434/v1".to_string(),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        // Endpoint is mandatory on Custom, so it always serializes.
        assert!(json.contains("\"api_endpoint\""));
        assert!(!json.contains("api_key"));
    }

    // --- ProviderConfig deserialization ---

    #[test]
    fn test_provider_config_deserialization_openai() {
        let json = r#"{"type":"openai","api_key":"sk-test"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::OpenAI { .. }));
        assert_eq!(config.api_key(), Some("sk-test"));
    }

    #[test]
    fn test_provider_config_deserialization_anthropic() {
        let json = r#"{"type":"anthropic","api_key":"sk-ant","access_token":"oauth"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Anthropic { .. }));
        assert_eq!(config.api_key(), Some("sk-ant"));
        assert_eq!(config.access_token(), Some("oauth"));
    }

    #[test]
    fn test_provider_config_deserialization_gemini() {
        let json = r#"{"type":"gemini","api_key":"gemini-key"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Gemini { .. }));
        assert_eq!(config.api_key(), Some("gemini-key"));
    }

    #[test]
    fn test_provider_config_deserialization_custom() {
        let json =
            r#"{"type":"custom","api_endpoint":"http://localhost:4000","api_key":"sk-custom"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Custom { .. }));
        assert_eq!(config.api_key(), Some("sk-custom"));
        assert_eq!(config.api_endpoint(), Some("http://localhost:4000"));
    }

    // --- Constructors and accessors ---

    #[test]
    fn test_provider_config_helper_methods() {
        let openai = ProviderConfig::openai(Some("sk-openai".to_string()));
        assert_eq!(openai.provider_type(), "openai");
        assert_eq!(openai.api_key(), Some("sk-openai"));

        let anthropic =
            ProviderConfig::anthropic(Some("sk-ant".to_string()), Some("oauth".to_string()));
        assert_eq!(anthropic.provider_type(), "anthropic");
        assert_eq!(anthropic.access_token(), Some("oauth"));

        let gemini = ProviderConfig::gemini(Some("gemini-key".to_string()));
        assert_eq!(gemini.provider_type(), "gemini");

        let custom = ProviderConfig::custom(
            "http://localhost:4000".to_string(),
            Some("sk-custom".to_string()),
        );
        assert_eq!(custom.provider_type(), "custom");
        assert_eq!(custom.api_endpoint(), Some("http://localhost:4000"));
    }

    #[test]
    fn test_set_api_endpoint_updates_supported_providers() {
        let mut openai = ProviderConfig::openai(Some("sk-openai".to_string()));
        openai.set_api_endpoint(Some("https://proxy.example.com/v1".to_string()));
        assert_eq!(openai.api_endpoint(), Some("https://proxy.example.com/v1"));

        // Bedrock has no endpoint field; the setter is a documented no-op.
        let mut bedrock = ProviderConfig::bedrock("us-east-1".to_string(), None);
        bedrock.set_api_endpoint(Some("https://ignored.example.com".to_string()));
        assert_eq!(bedrock.api_endpoint(), None);
    }

    // --- LLMProviderConfig registry ---

    #[test]
    fn test_llm_provider_config_new() {
        let config = LLMProviderConfig::new();
        assert!(config.is_empty());
    }

    #[test]
    fn test_llm_provider_config_add_and_get() {
        let mut config = LLMProviderConfig::new();
        config.add_provider(
            "openai",
            ProviderConfig::openai(Some("sk-test".to_string())),
        );
        config.add_provider(
            "anthropic",
            ProviderConfig::anthropic(Some("sk-ant".to_string()), None),
        );

        assert!(!config.is_empty());
        assert!(config.get_provider("openai").is_some());
        assert!(config.get_provider("anthropic").is_some());
        assert!(config.get_provider("unknown").is_none());
    }

    #[test]
    fn test_provider_config_toml_parsing() {
        // Map keys are user-chosen names and need not equal the type tag
        // (see "litellm" below).
        let json = r#"{
            "openai": {"type": "openai", "api_key": "sk-openai"},
            "anthropic": {"type": "anthropic", "api_key": "sk-ant", "access_token": "oauth"},
            "litellm": {"type": "custom", "api_endpoint": "http://localhost:4000", "api_key": "sk-litellm"}
        }"#;

        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 3);

        assert!(matches!(
            providers.get("openai"),
            Some(ProviderConfig::OpenAI { .. })
        ));
        assert!(matches!(
            providers.get("anthropic"),
            Some(ProviderConfig::Anthropic { .. })
        ));
        assert!(matches!(
            providers.get("litellm"),
            Some(ProviderConfig::Custom { .. })
        ));
    }

    // --- Bedrock-specific cases ---

    #[test]
    fn test_provider_config_bedrock_serialization() {
        let config = ProviderConfig::Bedrock {
            region: "us-east-1".to_string(),
            profile_name: Some("my-profile".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"amazon-bedrock\""));
        assert!(json.contains("\"region\":\"us-east-1\""));
        assert!(json.contains("\"profile_name\":\"my-profile\""));
    }

    #[test]
    fn test_provider_config_bedrock_serialization_without_profile() {
        let config = ProviderConfig::Bedrock {
            region: "us-west-2".to_string(),
            profile_name: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"amazon-bedrock\""));
        assert!(json.contains("\"region\":\"us-west-2\""));
        assert!(!json.contains("profile_name"));
    }

    #[test]
    fn test_provider_config_bedrock_deserialization() {
        let json = r#"{"type":"amazon-bedrock","region":"us-east-1","profile_name":"prod"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Bedrock { .. }));
        assert_eq!(config.region(), Some("us-east-1"));
        assert_eq!(config.profile_name(), Some("prod"));
    }

    #[test]
    fn test_provider_config_bedrock_deserialization_minimal() {
        let json = r#"{"type":"amazon-bedrock","region":"eu-west-1"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Bedrock { .. }));
        assert_eq!(config.region(), Some("eu-west-1"));
        assert_eq!(config.profile_name(), None);
    }

    #[test]
    fn test_provider_config_bedrock_no_api_key() {
        let config = ProviderConfig::bedrock("us-east-1".to_string(), None);
        assert_eq!(config.api_key(), None);
        assert_eq!(config.api_endpoint(), None);
    }

    #[test]
    fn test_provider_config_bedrock_helper_methods() {
        let bedrock = ProviderConfig::bedrock("us-east-1".to_string(), Some("prod".to_string()));
        assert_eq!(bedrock.provider_type(), "amazon-bedrock");
        assert_eq!(bedrock.region(), Some("us-east-1"));
        assert_eq!(bedrock.profile_name(), Some("prod"));
        assert_eq!(bedrock.api_key(), None);
        assert_eq!(bedrock.api_endpoint(), None);
        assert_eq!(bedrock.access_token(), None);
    }

    #[test]
    fn test_provider_config_bedrock_toml_roundtrip() {
        let config = ProviderConfig::Bedrock {
            region: "us-east-1".to_string(),
            profile_name: Some("my-profile".to_string()),
        };
        let toml_str = toml::to_string(&config).unwrap();
        let parsed: ProviderConfig = toml::from_str(&toml_str).unwrap();
        assert_eq!(config, parsed);
    }

    #[test]
    fn test_provider_config_bedrock_toml_parsing() {
        let toml_str = r#"
            type = "amazon-bedrock"
            region = "us-east-1"
            profile_name = "production"
        "#;
        let config: ProviderConfig = toml::from_str(toml_str).unwrap();
        assert!(matches!(
            config,
            ProviderConfig::Bedrock {
                ref region,
                ref profile_name,
            } if region == "us-east-1" && profile_name.as_deref() == Some("production")
        ));
    }

    #[test]
    fn test_provider_config_bedrock_missing_region_fails() {
        // `region` has no default, so deserialization must reject this.
        let json = r#"{"type":"amazon-bedrock"}"#;
        let result: Result<ProviderConfig, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_provider_config_bedrock_in_providers_map() {
        let json = r#"{
            "anthropic": {"type": "anthropic", "api_key": "sk-ant"},
            "amazon-bedrock": {"type": "amazon-bedrock", "region": "us-east-1"}
        }"#;
        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 2);
        assert!(matches!(
            providers.get("amazon-bedrock"),
            Some(ProviderConfig::Bedrock { .. })
        ));
    }

    #[test]
    fn test_region_returns_none_for_non_bedrock() {
        let openai = ProviderConfig::openai(Some("key".to_string()));
        assert_eq!(openai.region(), None);

        let anthropic = ProviderConfig::anthropic(Some("key".to_string()), None);
        assert_eq!(anthropic.region(), None);
    }

    #[test]
    fn test_profile_name_returns_none_for_non_bedrock() {
        let openai = ProviderConfig::openai(Some("key".to_string()));
        assert_eq!(openai.profile_name(), None);
    }
}