1use serde::{Deserialize, Serialize};
55use stakai::Model;
56use std::collections::HashMap;
57
58use super::auth::ProviderAuth;
59
/// Per-provider credentials and endpoint configuration.
///
/// Serialized with an internal `"type"` tag (lowercase variant name, except
/// the explicitly renamed `amazon-bedrock` and `github-copilot`), so a config
/// map can hold heterogeneous providers. Most variants carry both a legacy
/// inline `api_key` and a structured `auth: ProviderAuth`; accessor methods on
/// `impl ProviderConfig` give `auth` precedence.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ProviderConfig {
    OpenAI {
        /// Legacy inline API key; superseded by `auth` when both are set.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Optional endpoint override; omitted from serialization when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Structured auth (preferred over `api_key`).
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    Anthropic {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Legacy inline OAuth access token (Anthropic is the only variant
        /// with one); superseded by `auth`.
        #[serde(skip_serializing_if = "Option::is_none")]
        access_token: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    Gemini {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// OpenAI-compatible custom deployment; `api_endpoint` is required here,
    /// unlike every other variant.
    Custom {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        api_endpoint: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    Stakpak {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// AWS Bedrock: no key/endpoint/auth fields; `region` is required and
    /// `profile_name` presumably selects an AWS credentials profile — the
    /// credential resolution itself happens elsewhere.
    #[serde(rename = "amazon-bedrock")]
    Bedrock {
        region: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        profile_name: Option<String>,
    },
    /// GitHub Copilot: auth-only, no inline API key field.
    #[serde(rename = "github-copilot")]
    GitHubCopilot {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
}
250
impl ProviderConfig {
    /// Canonical provider identifier, matching the serde `"type"` tag
    /// (including the renamed `amazon-bedrock` / `github-copilot`).
    pub fn provider_type(&self) -> &'static str {
        match self {
            ProviderConfig::OpenAI { .. } => "openai",
            ProviderConfig::Anthropic { .. } => "anthropic",
            ProviderConfig::Gemini { .. } => "gemini",
            ProviderConfig::Custom { .. } => "custom",
            ProviderConfig::Stakpak { .. } => "stakpak",
            ProviderConfig::Bedrock { .. } => "amazon-bedrock",
            ProviderConfig::GitHubCopilot { .. } => "github-copilot",
        }
    }

    /// Resolves the effective API key.
    ///
    /// A key stored in the structured `auth` takes precedence over the legacy
    /// inline `api_key` field. Bedrock has neither; GitHub Copilot only ever
    /// yields a key through `auth`.
    pub fn api_key(&self) -> Option<&str> {
        // Structured auth wins when it actually carries a key.
        if let Some(auth) = self.get_auth_ref()
            && let Some(key) = auth.api_key_value()
        {
            return Some(key);
        }
        match self {
            ProviderConfig::OpenAI { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Anthropic { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Gemini { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Custom { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Stakpak { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Bedrock { .. } => None,
            ProviderConfig::GitHubCopilot { .. } => None,
        }
    }

    /// Borrows the structured auth, if any. Bedrock is the only variant
    /// without an `auth` field.
    fn get_auth_ref(&self) -> Option<&ProviderAuth> {
        match self {
            ProviderConfig::OpenAI { auth, .. } => auth.as_ref(),
            ProviderConfig::Anthropic { auth, .. } => auth.as_ref(),
            ProviderConfig::Gemini { auth, .. } => auth.as_ref(),
            ProviderConfig::Custom { auth, .. } => auth.as_ref(),
            ProviderConfig::Stakpak { auth, .. } => auth.as_ref(),
            ProviderConfig::Bedrock { .. } => None,
            ProviderConfig::GitHubCopilot { auth, .. } => auth.as_ref(),
        }
    }

    /// Returns an owned `ProviderAuth`, synthesizing one from legacy inline
    /// credentials when no structured auth is stored.
    pub fn get_auth(&self) -> Option<ProviderAuth> {
        if let Some(auth) = self.get_auth_ref() {
            return Some(auth.clone());
        }

        match self {
            ProviderConfig::OpenAI { api_key, .. }
            | ProviderConfig::Gemini { api_key, .. }
            | ProviderConfig::Custom { api_key, .. }
            | ProviderConfig::Stakpak { api_key, .. } => {
                api_key.as_ref().map(ProviderAuth::api_key)
            }
            ProviderConfig::Anthropic {
                api_key,
                access_token,
                ..
            } => {
                // Inline API key beats inline OAuth access token.
                if let Some(key) = api_key {
                    Some(ProviderAuth::api_key(key))
                } else {
                    // NOTE(review): synthesizes OAuth auth with an empty
                    // refresh token and expiry 0 — presumably ProviderAuth
                    // treats that as "no refresh / unknown expiry"; confirm.
                    access_token
                        .as_ref()
                        .map(|token| ProviderAuth::oauth(token, "", 0))
                }
            }
            ProviderConfig::Bedrock { .. } => None,
            ProviderConfig::GitHubCopilot { .. } => None,
        }
    }

    /// Stores structured auth and clears legacy inline credentials so there
    /// is a single source of truth. A no-op for Bedrock (no auth field).
    pub fn set_auth(&mut self, auth: ProviderAuth) {
        match self {
            ProviderConfig::OpenAI {
                auth: auth_field,
                api_key,
                ..
            }
            | ProviderConfig::Gemini {
                auth: auth_field,
                api_key,
                ..
            }
            | ProviderConfig::Custom {
                auth: auth_field,
                api_key,
                ..
            }
            | ProviderConfig::Stakpak {
                auth: auth_field,
                api_key,
                ..
            } => {
                *auth_field = Some(auth);
                *api_key = None;
            }
            ProviderConfig::Anthropic {
                auth: auth_field,
                api_key,
                access_token,
                ..
            } => {
                *auth_field = Some(auth);
                *api_key = None;
                *access_token = None;
            }
            ProviderConfig::GitHubCopilot {
                auth: auth_field, ..
            } => {
                *auth_field = Some(auth);
            }
            ProviderConfig::Bedrock { .. } => {
                // Bedrock has no auth/api_key fields; nothing to store.
            }
        }
    }

    /// Removes structured auth and all legacy inline credentials.
    /// Bedrock is untouched (nothing to clear).
    pub fn clear_auth(&mut self) {
        match self {
            ProviderConfig::OpenAI {
                auth: auth_field,
                api_key,
                ..
            }
            | ProviderConfig::Gemini {
                auth: auth_field,
                api_key,
                ..
            }
            | ProviderConfig::Custom {
                auth: auth_field,
                api_key,
                ..
            }
            | ProviderConfig::Stakpak {
                auth: auth_field,
                api_key,
                ..
            } => {
                *auth_field = None;
                *api_key = None;
            }
            ProviderConfig::Anthropic {
                auth: auth_field,
                api_key,
                access_token,
                ..
            } => {
                *auth_field = None;
                *api_key = None;
                *access_token = None;
            }
            ProviderConfig::GitHubCopilot {
                auth: auth_field, ..
            } => {
                *auth_field = None;
            }
            ProviderConfig::Bedrock { .. } => {
                // Bedrock has no auth/api_key fields; nothing to clear.
            }
        }
    }

    /// Configured API endpoint. `Custom` always has one (required field);
    /// Bedrock never does.
    pub fn api_endpoint(&self) -> Option<&str> {
        match self {
            ProviderConfig::OpenAI { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Anthropic { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Gemini { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Custom { api_endpoint, .. } => Some(api_endpoint.as_str()),
            ProviderConfig::Stakpak { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Bedrock { .. } => None,
            ProviderConfig::GitHubCopilot { api_endpoint, .. } => api_endpoint.as_deref(),
        }
    }

    /// Sets (or clears) the endpoint override.
    ///
    /// For `Custom` the endpoint is a required field, so `None` is ignored
    /// rather than clearing it; Bedrock ignores endpoints entirely.
    pub fn set_api_endpoint(&mut self, endpoint: Option<String>) {
        match self {
            ProviderConfig::OpenAI { api_endpoint, .. }
            | ProviderConfig::Anthropic { api_endpoint, .. }
            | ProviderConfig::Gemini { api_endpoint, .. }
            | ProviderConfig::Stakpak { api_endpoint, .. }
            | ProviderConfig::GitHubCopilot { api_endpoint, .. } => {
                *api_endpoint = endpoint;
            }
            ProviderConfig::Custom { api_endpoint, .. } => {
                if let Some(custom_endpoint) = endpoint {
                    *api_endpoint = custom_endpoint;
                }
            }
            ProviderConfig::Bedrock { .. } => {}
        }
    }

    /// Resolves the effective OAuth access token: structured auth first,
    /// then Anthropic's legacy inline `access_token` (the only variant that
    /// has one).
    pub fn access_token(&self) -> Option<&str> {
        if let Some(auth) = self.get_auth_ref()
            && let Some(token) = auth.access_token()
        {
            return Some(token);
        }
        match self {
            ProviderConfig::Anthropic { access_token, .. } => access_token.as_deref(),
            _ => None,
        }
    }

    /// OpenAI config with an optional inline API key and default endpoint.
    pub fn openai(api_key: Option<String>) -> Self {
        ProviderConfig::OpenAI {
            api_key,
            api_endpoint: None,
            auth: None,
        }
    }

    /// OpenAI config backed by structured auth only.
    pub fn openai_with_auth(auth: ProviderAuth) -> Self {
        ProviderConfig::OpenAI {
            api_key: None,
            api_endpoint: None,
            auth: None,
        }
    }

    /// Anthropic config from legacy inline credentials.
    pub fn anthropic(api_key: Option<String>, access_token: Option<String>) -> Self {
        ProviderConfig::Anthropic {
            api_key,
            api_endpoint: None,
            access_token,
            auth: None,
        }
    }

    /// Anthropic config backed by structured auth only.
    pub fn anthropic_with_auth(auth: ProviderAuth) -> Self {
        ProviderConfig::Anthropic {
            api_key: None,
            api_endpoint: None,
            access_token: None,
            auth: Some(auth),
        }
    }

    /// Gemini config with an optional inline API key.
    pub fn gemini(api_key: Option<String>) -> Self {
        ProviderConfig::Gemini {
            api_key,
            api_endpoint: None,
            auth: None,
        }
    }

    /// Gemini config backed by structured auth only.
    pub fn gemini_with_auth(auth: ProviderAuth) -> Self {
        ProviderConfig::Gemini {
            api_key: None,
            api_endpoint: None,
            auth: Some(auth),
        }
    }

    /// Custom (OpenAI-compatible) config; the endpoint is mandatory.
    pub fn custom(api_endpoint: String, api_key: Option<String>) -> Self {
        ProviderConfig::Custom {
            api_key,
            api_endpoint,
            auth: None,
        }
    }

    /// Custom config backed by structured auth only.
    pub fn custom_with_auth(api_endpoint: String, auth: ProviderAuth) -> Self {
        ProviderConfig::Custom {
            api_key: None,
            api_endpoint,
            auth: Some(auth),
        }
    }

    /// Stakpak config; note the key is required here, unlike `openai`.
    pub fn stakpak(api_key: String, api_endpoint: Option<String>) -> Self {
        ProviderConfig::Stakpak {
            api_key: Some(api_key),
            api_endpoint,
            auth: None,
        }
    }

    /// Stakpak config backed by structured auth only.
    pub fn stakpak_with_auth(auth: ProviderAuth, api_endpoint: Option<String>) -> Self {
        ProviderConfig::Stakpak {
            api_key: None,
            api_endpoint,
            auth: Some(auth),
        }
    }

    /// GitHub Copilot config; auth is the only credential it supports.
    pub fn github_copilot_with_auth(auth: ProviderAuth) -> Self {
        ProviderConfig::GitHubCopilot {
            api_endpoint: None,
            auth: Some(auth),
        }
    }

    /// Bedrock config for a region, optionally pinned to a named profile.
    pub fn bedrock(region: String, profile_name: Option<String>) -> Self {
        ProviderConfig::Bedrock {
            region,
            profile_name,
        }
    }

    /// AWS region — `Some` only for Bedrock.
    pub fn region(&self) -> Option<&str> {
        match self {
            ProviderConfig::Bedrock { region, .. } => Some(region.as_str()),
            _ => None,
        }
    }

    /// AWS profile name — `Some` only for Bedrock configs that set it.
    pub fn profile_name(&self) -> Option<&str> {
        match self {
            ProviderConfig::Bedrock { profile_name, .. } => profile_name.as_deref(),
            _ => None,
        }
    }

    /// Builds a credential-less config for providers whose fields are all
    /// optional. Returns `None` for unknown names and for `custom` /
    /// `amazon-bedrock`, whose required fields (`api_endpoint` / `region`)
    /// make an "empty" config impossible.
    pub fn empty_for_provider(provider_name: &str) -> Option<Self> {
        match provider_name {
            "openai" => Some(ProviderConfig::OpenAI {
                api_key: None,
                api_endpoint: None,
                auth: None,
            }),
            "anthropic" => Some(ProviderConfig::Anthropic {
                api_key: None,
                api_endpoint: None,
                access_token: None,
                auth: None,
            }),
            "gemini" => Some(ProviderConfig::Gemini {
                api_key: None,
                api_endpoint: None,
                auth: None,
            }),
            "stakpak" => Some(ProviderConfig::Stakpak {
                api_key: None,
                api_endpoint: None,
                auth: None,
            }),
            "github-copilot" => Some(ProviderConfig::GitHubCopilot {
                api_endpoint: None,
                auth: None,
            }),
            _ => None,
        }
    }
}
654
655#[derive(Debug, Clone, Default)]
659pub struct LLMProviderConfig {
660 pub providers: HashMap<String, ProviderConfig>,
662}
663
664impl LLMProviderConfig {
665 pub fn new() -> Self {
667 Self {
668 providers: HashMap::new(),
669 }
670 }
671
672 pub fn add_provider(&mut self, name: impl Into<String>, config: ProviderConfig) {
674 self.providers.insert(name.into(), config);
675 }
676
677 pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
679 self.providers.get(name)
680 }
681
682 pub fn is_empty(&self) -> bool {
684 self.providers.is_empty()
685 }
686}
687
/// Optional provider-specific tuning knobs attached to a request; each
/// section is omitted from serialization when unset.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMProviderOptions {
    // Anthropic-only options (extended thinking).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub anthropic: Option<LLMAnthropicOptions>,

    // OpenAI-only options (reasoning effort).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub openai: Option<LLMOpenAIOptions>,

    // Google/Gemini-only options (thinking budget).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub google: Option<LLMGoogleOptions>,
}
703
/// Anthropic-specific request options.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMAnthropicOptions {
    // Extended-thinking configuration; omitted when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<LLMThinkingOptions>,
}
711
/// Extended-thinking budget configuration.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LLMThinkingOptions {
    // Token budget for thinking; `LLMThinkingOptions::new` enforces a
    // 1024-token floor, but direct construction/deserialization does not.
    pub budget_tokens: u32,
}
718
719impl LLMThinkingOptions {
720 pub fn new(budget_tokens: u32) -> Self {
721 Self {
722 budget_tokens: budget_tokens.max(1024),
723 }
724 }
725}
726
/// OpenAI-specific request options.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMOpenAIOptions {
    // Free-form effort level; presumably "low"/"medium"/"high" — not
    // validated here, passed through as-is.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
}
734
/// Google/Gemini-specific request options.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMGoogleOptions {
    // Thinking token budget; no floor is applied here (unlike
    // `LLMThinkingOptions::new`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<u32>,
}
742
/// Non-streaming LLM request: model, conversation, limits, and optional
/// tools/options/headers. Serialize-only (no `Deserialize`).
#[derive(Clone, Debug, Serialize)]
pub struct LLMInput {
    pub model: Model,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    // NOTE: unlike the fields below, a `None` here serializes as `null`.
    pub tools: Option<Vec<LLMTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<LLMProviderOptions>,
    // Extra HTTP headers to send with the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<std::collections::HashMap<String, String>>,
}
755
/// Streaming LLM request: same shape as `LLMInput` plus the channel on
/// which incremental `GenerationDelta`s are delivered.
#[derive(Debug)]
pub struct LLMStreamInput {
    pub model: Model,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    // Sender half used by the provider driver to emit stream deltas.
    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
    pub tools: Option<Vec<LLMTool>>,
    pub provider_options: Option<LLMProviderOptions>,
    pub headers: Option<std::collections::HashMap<String, String>>,
}
767
768impl From<&LLMStreamInput> for LLMInput {
769 fn from(value: &LLMStreamInput) -> Self {
770 LLMInput {
771 model: value.model.clone(),
772 messages: value.messages.clone(),
773 max_tokens: value.max_tokens,
774 tools: value.tools.clone(),
775 provider_options: value.provider_options.clone(),
776 headers: value.headers.clone(),
777 }
778 }
779}
780
/// A chat message with a free-form role string and string-or-typed content.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct LLMMessage {
    // Role string (e.g. "user", "assistant"); not restricted to an enum here.
    pub role: String,
    // Plain string or a list of typed parts; see `LLMMessageContent`.
    pub content: LLMMessageContent,
}
786
787#[derive(Serialize, Deserialize, Debug, Clone)]
788pub struct SimpleLLMMessage {
789 #[serde(rename = "role")]
790 pub role: SimpleLLMRole,
791 pub content: String,
792}
793
/// Speaker of a `SimpleLLMMessage`; serialized as "user" / "assistant".
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SimpleLLMRole {
    User,
    Assistant,
}
800
801impl std::fmt::Display for SimpleLLMRole {
802 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
803 match self {
804 SimpleLLMRole::User => write!(f, "user"),
805 SimpleLLMRole::Assistant => write!(f, "assistant"),
806 }
807 }
808}
809
/// Message content: either a bare string or a list of typed parts.
/// `untagged` means deserialization tries each representation in order.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum LLMMessageContent {
    String(String),
    List(Vec<LLMMessageTypedContent>),
}
816
817#[allow(clippy::to_string_trait_impl)]
818impl ToString for LLMMessageContent {
819 fn to_string(&self) -> String {
820 match self {
821 LLMMessageContent::String(s) => s.clone(),
822 LLMMessageContent::List(l) => l
823 .iter()
824 .map(|c| match c {
825 LLMMessageTypedContent::Text { text } => text.clone(),
826 LLMMessageTypedContent::ToolCall { .. } => String::new(),
827 LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
828 LLMMessageTypedContent::Image { .. } => String::new(),
829 })
830 .collect::<Vec<_>>()
831 .join("\n"),
832 }
833 }
834}
835
836impl From<String> for LLMMessageContent {
837 fn from(value: String) -> Self {
838 LLMMessageContent::String(value)
839 }
840}
841
842impl Default for LLMMessageContent {
843 fn default() -> Self {
844 LLMMessageContent::String(String::new())
845 }
846}
847
848impl LLMMessageContent {
849 pub fn into_parts(self) -> Vec<LLMMessageTypedContent> {
852 match self {
853 LLMMessageContent::List(parts) => parts,
854 LLMMessageContent::String(s) if s.is_empty() => vec![],
855 LLMMessageContent::String(s) => vec![LLMMessageTypedContent::Text { text: s }],
856 }
857 }
858}
859
/// One typed part of a message, tagged by `"type"` on the wire.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum LLMMessageTypedContent {
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool invocation requested by the model.
    #[serde(rename = "tool_use")]
    ToolCall {
        id: String,
        name: String,
        // The alias lets this deserialize from "input" (Anthropic-style
        // payloads) as well as "args"; serialization always writes "args".
        #[serde(alias = "input")]
        args: serde_json::Value,
        #[serde(skip_serializing_if = "Option::is_none")]
        metadata: Option<serde_json::Value>,
    },
    /// The textual result of a prior tool invocation.
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        content: String,
    },
    #[serde(rename = "image")]
    Image { source: LLMMessageImageSource },
}
883
/// Image payload attached to a message.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMMessageImageSource {
    // Source kind discriminator; serialized as "type". Presumably values
    // like "base64" — not constrained here.
    #[serde(rename = "type")]
    pub r#type: String,
    // MIME type of the image (e.g. "image/png") — not validated here.
    pub media_type: String,
    // Encoded image bytes, interpreted according to `type`.
    pub data: String,
}
891
892impl Default for LLMMessageTypedContent {
893 fn default() -> Self {
894 LLMMessageTypedContent::Text {
895 text: String::new(),
896 }
897 }
898}
899
/// One completion choice in a non-streaming response.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMChoice {
    // e.g. provider-reported stop reason; `None` if not reported.
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: LLMMessage,
}
906
/// Full (non-streaming) completion response, OpenAI-style shape.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMChoice>,
    // Creation timestamp; presumably Unix seconds — not interpreted here.
    pub created: u64,
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
}
916
/// Incremental content carried by one streaming chunk.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}
922
/// One choice within a streaming response chunk.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    // Full message, when the provider includes one alongside the delta.
    pub message: Option<LLMMessage>,
    pub delta: LLMStreamDelta,
}
930
/// One chunk of a streaming completion response.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionStreamResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMStreamChoice>,
    pub created: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
    // Unlike `usage`, this has no skip attribute, so `None` serializes as
    // `"citations":null`.
    pub citations: Option<Vec<String>>,
}
942
/// A tool the model may call, described by name, purpose, and a JSON Schema
/// for its arguments.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMTool {
    pub name: String,
    pub description: String,
    // JSON Schema for the tool's arguments; stored as raw JSON.
    pub input_schema: serde_json::Value,
}
949
/// Aggregate token counts for a request/response pair.
#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMTokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    // NOTE(review): not enforced to equal prompt + completion here.
    pub total_tokens: u32,

    // Optional per-category breakdown (cache hits etc.).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
959
/// Token accounting category; serialized in snake_case
/// (e.g. "cache_read_input_tokens").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    InputTokens,
    OutputTokens,
    CacheReadInputTokens,
    CacheWriteInputTokens,
}
968
/// Per-category token breakdown; every category is optional because
/// providers report different subsets.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct PromptTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_input_tokens: Option<u32>,
}
980
981impl PromptTokensDetails {
982 pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
984 [
985 (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
986 (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
987 (
988 TokenType::CacheReadInputTokens,
989 self.cache_read_input_tokens.unwrap_or(0),
990 ),
991 (
992 TokenType::CacheWriteInputTokens,
993 self.cache_write_input_tokens.unwrap_or(0),
994 ),
995 ]
996 .into_iter()
997 }
998}
999
1000impl std::ops::Add for PromptTokensDetails {
1001 type Output = Self;
1002
1003 fn add(self, rhs: Self) -> Self::Output {
1004 Self {
1005 input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
1006 output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
1007 cache_read_input_tokens: Some(
1008 self.cache_read_input_tokens.unwrap_or(0)
1009 + rhs.cache_read_input_tokens.unwrap_or(0),
1010 ),
1011 cache_write_input_tokens: Some(
1012 self.cache_write_input_tokens.unwrap_or(0)
1013 + rhs.cache_write_input_tokens.unwrap_or(0),
1014 ),
1015 }
1016 }
1017}
1018
1019impl std::ops::AddAssign for PromptTokensDetails {
1020 fn add_assign(&mut self, rhs: Self) {
1021 self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
1022 self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
1023 self.cache_read_input_tokens = Some(
1024 self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
1025 );
1026 self.cache_write_input_tokens = Some(
1027 self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
1028 );
1029 }
1030}
1031
/// One incremental event on a generation stream, tagged by `"type"`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum GenerationDelta {
    /// Visible output text fragment.
    Content { content: String },
    /// Extended-thinking text fragment.
    Thinking { thinking: String },
    /// Partial tool-call information; see `GenerationDeltaToolUse`.
    ToolUse { tool_use: GenerationDeltaToolUse },
    /// Token usage report.
    Usage { usage: LLMTokenUsage },
    /// Provider-specific metadata blob.
    Metadata { metadata: serde_json::Value },
}
1041
/// Partial tool-call data within a stream; fields are optional because a
/// single call is typically assembled from several deltas.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenerationDeltaToolUse {
    pub id: Option<String>,
    pub name: Option<String>,
    // Argument JSON as a raw string fragment; presumably concatenated
    // across deltas before parsing — done by the consumer, not here.
    pub input: Option<String>,
    // Position of the tool call this fragment belongs to.
    pub index: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
1052
#[cfg(test)]
mod tests {
    use super::*;

    // --- ProviderConfig JSON serialization ---

    #[test]
    fn test_provider_config_openai_serialization() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: None,
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"openai\""));
        assert!(json.contains("\"api_key\":\"sk-test\""));
        // `None` fields must be omitted entirely (skip_serializing_if).
        assert!(!json.contains("api_endpoint"));
    }

    #[test]
    fn test_provider_config_openai_with_endpoint() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: Some("https://custom.openai.com/v1".to_string()),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"api_endpoint\":\"https://custom.openai.com/v1\""));
    }

    #[test]
    fn test_provider_config_anthropic_serialization() {
        let config = ProviderConfig::Anthropic {
            api_key: Some("sk-ant-test".to_string()),
            api_endpoint: None,
            access_token: Some("oauth-token".to_string()),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"anthropic\""));
        assert!(json.contains("\"api_key\":\"sk-ant-test\""));
        assert!(json.contains("\"access_token\":\"oauth-token\""));
    }

    #[test]
    fn test_provider_config_gemini_serialization() {
        let config = ProviderConfig::Gemini {
            api_key: Some("gemini-key".to_string()),
            api_endpoint: None,
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"gemini\""));
        assert!(json.contains("\"api_key\":\"gemini-key\""));
    }

    #[test]
    fn test_provider_config_custom_serialization() {
        let config = ProviderConfig::Custom {
            api_key: Some("sk-custom".to_string()),
            api_endpoint: "http://localhost:4000".to_string(),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\":\"http://localhost:4000\""));
        assert!(json.contains("\"api_key\":\"sk-custom\""));
    }

    #[test]
    fn test_provider_config_custom_without_key() {
        let config = ProviderConfig::Custom {
            api_key: None,
            api_endpoint: "http://localhost:11434/v1".to_string(),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\""));
        // Absent key is omitted, not serialized as null.
        assert!(!json.contains("api_key"));
    }

    // --- ProviderConfig JSON deserialization + accessors ---

    #[test]
    fn test_provider_config_deserialization_openai() {
        let json = r#"{"type":"openai","api_key":"sk-test"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::OpenAI { .. }));
        assert_eq!(config.api_key(), Some("sk-test"));
    }

    #[test]
    fn test_provider_config_deserialization_anthropic() {
        let json = r#"{"type":"anthropic","api_key":"sk-ant","access_token":"oauth"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Anthropic { .. }));
        assert_eq!(config.api_key(), Some("sk-ant"));
        assert_eq!(config.access_token(), Some("oauth"));
    }

    #[test]
    fn test_provider_config_deserialization_gemini() {
        let json = r#"{"type":"gemini","api_key":"gemini-key"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Gemini { .. }));
        assert_eq!(config.api_key(), Some("gemini-key"));
    }

    #[test]
    fn test_provider_config_deserialization_custom() {
        let json =
            r#"{"type":"custom","api_endpoint":"http://localhost:4000","api_key":"sk-custom"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Custom { .. }));
        assert_eq!(config.api_key(), Some("sk-custom"));
        assert_eq!(config.api_endpoint(), Some("http://localhost:4000"));
    }

    #[test]
    fn test_provider_config_helper_methods() {
        let openai = ProviderConfig::openai(Some("sk-openai".to_string()));
        assert_eq!(openai.provider_type(), "openai");
        assert_eq!(openai.api_key(), Some("sk-openai"));

        let anthropic =
            ProviderConfig::anthropic(Some("sk-ant".to_string()), Some("oauth".to_string()));
        assert_eq!(anthropic.provider_type(), "anthropic");
        assert_eq!(anthropic.access_token(), Some("oauth"));

        let gemini = ProviderConfig::gemini(Some("gemini-key".to_string()));
        assert_eq!(gemini.provider_type(), "gemini");

        let custom = ProviderConfig::custom(
            "http://localhost:4000".to_string(),
            Some("sk-custom".to_string()),
        );
        assert_eq!(custom.provider_type(), "custom");
        assert_eq!(custom.api_endpoint(), Some("http://localhost:4000"));
    }

    #[test]
    fn test_set_api_endpoint_updates_supported_providers() {
        let mut openai = ProviderConfig::openai(Some("sk-openai".to_string()));
        openai.set_api_endpoint(Some("https://proxy.example.com/v1".to_string()));
        assert_eq!(openai.api_endpoint(), Some("https://proxy.example.com/v1"));

        // Bedrock has no endpoint field; the setter must be a silent no-op.
        let mut bedrock = ProviderConfig::bedrock("us-east-1".to_string(), None);
        bedrock.set_api_endpoint(Some("https://ignored.example.com".to_string()));
        assert_eq!(bedrock.api_endpoint(), None);
    }

    // --- LLMProviderConfig registry ---

    #[test]
    fn test_llm_provider_config_new() {
        let config = LLMProviderConfig::new();
        assert!(config.is_empty());
    }

    #[test]
    fn test_llm_provider_config_add_and_get() {
        let mut config = LLMProviderConfig::new();
        config.add_provider(
            "openai",
            ProviderConfig::openai(Some("sk-test".to_string())),
        );
        config.add_provider(
            "anthropic",
            ProviderConfig::anthropic(Some("sk-ant".to_string()), None),
        );

        assert!(!config.is_empty());
        assert!(config.get_provider("openai").is_some());
        assert!(config.get_provider("anthropic").is_some());
        assert!(config.get_provider("unknown").is_none());
    }

    // NOTE(review): despite the name, this parses a JSON provider map (the
    // shape mirrors what a TOML config deserializes into).
    #[test]
    fn test_provider_config_toml_parsing() {
        let json = r#"{
            "openai": {"type": "openai", "api_key": "sk-openai"},
            "anthropic": {"type": "anthropic", "api_key": "sk-ant", "access_token": "oauth"},
            "litellm": {"type": "custom", "api_endpoint": "http://localhost:4000", "api_key": "sk-litellm"}
        }"#;

        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 3);

        assert!(matches!(
            providers.get("openai"),
            Some(ProviderConfig::OpenAI { .. })
        ));
        assert!(matches!(
            providers.get("anthropic"),
            Some(ProviderConfig::Anthropic { .. })
        ));
        assert!(matches!(
            providers.get("litellm"),
            Some(ProviderConfig::Custom { .. })
        ));
    }

    // --- Bedrock-specific behavior ---

    #[test]
    fn test_provider_config_bedrock_serialization() {
        let config = ProviderConfig::Bedrock {
            region: "us-east-1".to_string(),
            profile_name: Some("my-profile".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"amazon-bedrock\""));
        assert!(json.contains("\"region\":\"us-east-1\""));
        assert!(json.contains("\"profile_name\":\"my-profile\""));
    }

    #[test]
    fn test_provider_config_bedrock_serialization_without_profile() {
        let config = ProviderConfig::Bedrock {
            region: "us-west-2".to_string(),
            profile_name: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"amazon-bedrock\""));
        assert!(json.contains("\"region\":\"us-west-2\""));
        assert!(!json.contains("profile_name"));
    }

    #[test]
    fn test_provider_config_bedrock_deserialization() {
        let json = r#"{"type":"amazon-bedrock","region":"us-east-1","profile_name":"prod"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Bedrock { .. }));
        assert_eq!(config.region(), Some("us-east-1"));
        assert_eq!(config.profile_name(), Some("prod"));
    }

    #[test]
    fn test_provider_config_bedrock_deserialization_minimal() {
        let json = r#"{"type":"amazon-bedrock","region":"eu-west-1"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Bedrock { .. }));
        assert_eq!(config.region(), Some("eu-west-1"));
        assert_eq!(config.profile_name(), None);
    }

    #[test]
    fn test_provider_config_bedrock_no_api_key() {
        let config = ProviderConfig::bedrock("us-east-1".to_string(), None);
        assert_eq!(config.api_key(), None);
        assert_eq!(config.api_endpoint(), None);
    }

    #[test]
    fn test_provider_config_bedrock_helper_methods() {
        let bedrock = ProviderConfig::bedrock("us-east-1".to_string(), Some("prod".to_string()));
        assert_eq!(bedrock.provider_type(), "amazon-bedrock");
        assert_eq!(bedrock.region(), Some("us-east-1"));
        assert_eq!(bedrock.profile_name(), Some("prod"));
        assert_eq!(bedrock.api_key(), None);
        assert_eq!(bedrock.api_endpoint(), None);
        assert_eq!(bedrock.access_token(), None);
    }

    #[test]
    fn test_provider_config_bedrock_toml_roundtrip() {
        let config = ProviderConfig::Bedrock {
            region: "us-east-1".to_string(),
            profile_name: Some("my-profile".to_string()),
        };
        let toml_str = toml::to_string(&config).unwrap();
        let parsed: ProviderConfig = toml::from_str(&toml_str).unwrap();
        assert_eq!(config, parsed);
    }

    #[test]
    fn test_provider_config_bedrock_toml_parsing() {
        let toml_str = r#"
            type = "amazon-bedrock"
            region = "us-east-1"
            profile_name = "production"
        "#;
        let config: ProviderConfig = toml::from_str(toml_str).unwrap();
        assert!(matches!(
            config,
            ProviderConfig::Bedrock {
                ref region,
                ref profile_name,
            } if region == "us-east-1" && profile_name.as_deref() == Some("production")
        ));
    }

    #[test]
    fn test_provider_config_bedrock_missing_region_fails() {
        // `region` is a required field for Bedrock.
        let json = r#"{"type":"amazon-bedrock"}"#;
        let result: Result<ProviderConfig, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_provider_config_bedrock_in_providers_map() {
        let json = r#"{
            "anthropic": {"type": "anthropic", "api_key": "sk-ant"},
            "amazon-bedrock": {"type": "amazon-bedrock", "region": "us-east-1"}
        }"#;
        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 2);
        assert!(matches!(
            providers.get("amazon-bedrock"),
            Some(ProviderConfig::Bedrock { .. })
        ));
    }

    #[test]
    fn test_region_returns_none_for_non_bedrock() {
        let openai = ProviderConfig::openai(Some("key".to_string()));
        assert_eq!(openai.region(), None);

        let anthropic = ProviderConfig::anthropic(Some("key".to_string()), None);
        assert_eq!(anthropic.region(), None);
    }

    #[test]
    fn test_profile_name_returns_none_for_non_bedrock() {
        let openai = ProviderConfig::openai(Some("key".to_string()));
        assert_eq!(openai.profile_name(), None);
    }
}