1use serde::{Deserialize, Serialize};
55use stakai::Model;
56use std::collections::HashMap;
57
58use super::auth::ProviderAuth;
59
60#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
100#[serde(tag = "type", rename_all = "lowercase")]
101pub enum ProviderConfig {
102 OpenAI {
104 #[serde(skip_serializing_if = "Option::is_none")]
106 api_key: Option<String>,
107 #[serde(skip_serializing_if = "Option::is_none")]
108 api_endpoint: Option<String>,
109 #[serde(skip_serializing_if = "Option::is_none")]
111 auth: Option<ProviderAuth>,
112 },
113 Anthropic {
115 #[serde(skip_serializing_if = "Option::is_none")]
117 api_key: Option<String>,
118 #[serde(skip_serializing_if = "Option::is_none")]
119 api_endpoint: Option<String>,
120 #[serde(skip_serializing_if = "Option::is_none")]
122 access_token: Option<String>,
123 #[serde(skip_serializing_if = "Option::is_none")]
125 auth: Option<ProviderAuth>,
126 },
127 Gemini {
129 #[serde(skip_serializing_if = "Option::is_none")]
131 api_key: Option<String>,
132 #[serde(skip_serializing_if = "Option::is_none")]
133 api_endpoint: Option<String>,
134 #[serde(skip_serializing_if = "Option::is_none")]
136 auth: Option<ProviderAuth>,
137 },
138 Custom {
155 #[serde(skip_serializing_if = "Option::is_none")]
157 api_key: Option<String>,
158 api_endpoint: String,
161 #[serde(skip_serializing_if = "Option::is_none")]
163 auth: Option<ProviderAuth>,
164 },
165 Stakpak {
186 #[serde(skip_serializing_if = "Option::is_none")]
189 api_key: Option<String>,
190 #[serde(skip_serializing_if = "Option::is_none")]
192 api_endpoint: Option<String>,
193 #[serde(skip_serializing_if = "Option::is_none")]
195 auth: Option<ProviderAuth>,
196 },
197 #[serde(rename = "amazon-bedrock")]
213 Bedrock {
214 region: String,
216 #[serde(skip_serializing_if = "Option::is_none")]
218 profile_name: Option<String>,
219 },
220 #[serde(rename = "github-copilot")]
241 GitHubCopilot {
242 #[serde(skip_serializing_if = "Option::is_none")]
244 api_endpoint: Option<String>,
245 #[serde(skip_serializing_if = "Option::is_none")]
247 auth: Option<ProviderAuth>,
248 },
249
250 #[serde(rename = "openrouter")]
251 OpenRouter {
252 #[serde(skip_serializing_if = "Option::is_none")]
254 api_key: Option<String>,
255 #[serde(skip_serializing_if = "Option::is_none")]
257 api_endpoint: Option<String>,
258 #[serde(skip_serializing_if = "Option::is_none")]
260 auth: Option<ProviderAuth>,
261 },
262}
263
264impl ProviderConfig {
265 pub fn provider_type(&self) -> &'static str {
267 match self {
268 ProviderConfig::OpenAI { .. } => "openai",
269 ProviderConfig::Anthropic { .. } => "anthropic",
270 ProviderConfig::Gemini { .. } => "gemini",
271 ProviderConfig::Custom { .. } => "custom",
272 ProviderConfig::Stakpak { .. } => "stakpak",
273 ProviderConfig::Bedrock { .. } => "amazon-bedrock",
274 ProviderConfig::GitHubCopilot { .. } => "github-copilot",
275 ProviderConfig::OpenRouter { .. } => "openrouter",
276 }
277 }
278
279 pub fn api_key(&self) -> Option<&str> {
281 if let Some(auth) = self.get_auth_ref()
283 && let Some(key) = auth.api_key_value()
284 {
285 return Some(key);
286 }
287 match self {
289 ProviderConfig::OpenAI { api_key, .. } => api_key.as_deref(),
290 ProviderConfig::Anthropic { api_key, .. } => api_key.as_deref(),
291 ProviderConfig::Gemini { api_key, .. } => api_key.as_deref(),
292 ProviderConfig::Custom { api_key, .. } => api_key.as_deref(),
293 ProviderConfig::Stakpak { api_key, .. } => api_key.as_deref(),
294 ProviderConfig::OpenRouter { api_key, .. } => api_key.as_deref(),
295 ProviderConfig::Bedrock { .. } => None, ProviderConfig::GitHubCopilot { .. } => None, }
298 }
299
300 fn get_auth_ref(&self) -> Option<&ProviderAuth> {
302 match self {
303 ProviderConfig::OpenAI { auth, .. } => auth.as_ref(),
304 ProviderConfig::Anthropic { auth, .. } => auth.as_ref(),
305 ProviderConfig::Gemini { auth, .. } => auth.as_ref(),
306 ProviderConfig::Custom { auth, .. } => auth.as_ref(),
307 ProviderConfig::Stakpak { auth, .. } => auth.as_ref(),
308 ProviderConfig::Bedrock { .. } => None,
309 ProviderConfig::GitHubCopilot { auth, .. } => auth.as_ref(),
310 ProviderConfig::OpenRouter { auth, .. } => auth.as_ref(),
311 }
312 }
313
314 pub fn get_auth(&self) -> Option<ProviderAuth> {
321 if let Some(auth) = self.get_auth_ref() {
323 return Some(auth.clone());
324 }
325
326 match self {
328 ProviderConfig::OpenAI { api_key, .. }
329 | ProviderConfig::Gemini { api_key, .. }
330 | ProviderConfig::Custom { api_key, .. }
331 | ProviderConfig::Stakpak { api_key, .. }
332 | ProviderConfig::OpenRouter { api_key, .. } => {
333 api_key.as_ref().map(ProviderAuth::api_key)
334 }
335 ProviderConfig::Anthropic {
336 api_key,
337 access_token,
338 ..
339 } => {
340 if let Some(key) = api_key {
342 Some(ProviderAuth::api_key(key))
343 } else {
344 access_token
348 .as_ref()
349 .map(|token| ProviderAuth::oauth(token, "", 0))
350 }
351 }
352 ProviderConfig::Bedrock { .. } => None,
353 ProviderConfig::GitHubCopilot { .. } => None,
355 }
356 }
357
358 pub fn set_auth(&mut self, auth: ProviderAuth) {
363 match self {
364 ProviderConfig::OpenAI {
365 auth: auth_field,
366 api_key,
367 ..
368 }
369 | ProviderConfig::Gemini {
370 auth: auth_field,
371 api_key,
372 ..
373 }
374 | ProviderConfig::Custom {
375 auth: auth_field,
376 api_key,
377 ..
378 }
379 | ProviderConfig::Stakpak {
380 auth: auth_field,
381 api_key,
382 ..
383 }
384 | ProviderConfig::OpenRouter {
385 auth: auth_field,
386 api_key,
387 ..
388 } => {
389 *auth_field = Some(auth);
390 *api_key = None;
391 }
392 ProviderConfig::Anthropic {
393 auth: auth_field,
394 api_key,
395 access_token,
396 ..
397 } => {
398 *auth_field = Some(auth);
399 *api_key = None;
400 *access_token = None;
401 }
402 ProviderConfig::GitHubCopilot {
403 auth: auth_field, ..
404 } => {
405 *auth_field = Some(auth);
406 }
407 ProviderConfig::Bedrock { .. } => {
408 }
410 }
411 }
412
413 pub fn clear_auth(&mut self) {
418 match self {
419 ProviderConfig::OpenAI {
420 auth: auth_field,
421 api_key,
422 ..
423 }
424 | ProviderConfig::Gemini {
425 auth: auth_field,
426 api_key,
427 ..
428 }
429 | ProviderConfig::Custom {
430 auth: auth_field,
431 api_key,
432 ..
433 }
434 | ProviderConfig::Stakpak {
435 auth: auth_field,
436 api_key,
437 ..
438 }
439 | ProviderConfig::OpenRouter {
440 auth: auth_field,
441 api_key,
442 ..
443 } => {
444 *auth_field = None;
445 *api_key = None;
446 }
447 ProviderConfig::Anthropic {
448 auth: auth_field,
449 api_key,
450 access_token,
451 ..
452 } => {
453 *auth_field = None;
454 *api_key = None;
455 *access_token = None;
456 }
457 ProviderConfig::GitHubCopilot {
458 auth: auth_field, ..
459 } => {
460 *auth_field = None;
461 }
462 ProviderConfig::Bedrock { .. } => {
463 }
465 }
466 }
467
468 pub fn api_endpoint(&self) -> Option<&str> {
470 match self {
471 ProviderConfig::OpenAI { api_endpoint, .. } => api_endpoint.as_deref(),
472 ProviderConfig::Anthropic { api_endpoint, .. } => api_endpoint.as_deref(),
473 ProviderConfig::Gemini { api_endpoint, .. } => api_endpoint.as_deref(),
474 ProviderConfig::Custom { api_endpoint, .. } => Some(api_endpoint.as_str()),
475 ProviderConfig::Stakpak { api_endpoint, .. } => api_endpoint.as_deref(),
476 ProviderConfig::OpenRouter { api_endpoint, .. } => api_endpoint.as_deref(),
477 ProviderConfig::Bedrock { .. } => None, ProviderConfig::GitHubCopilot { api_endpoint, .. } => api_endpoint.as_deref(),
479 }
480 }
481
482 pub fn set_api_endpoint(&mut self, endpoint: Option<String>) {
487 match self {
488 ProviderConfig::OpenAI { api_endpoint, .. }
489 | ProviderConfig::Anthropic { api_endpoint, .. }
490 | ProviderConfig::Gemini { api_endpoint, .. }
491 | ProviderConfig::Stakpak { api_endpoint, .. }
492 | ProviderConfig::GitHubCopilot { api_endpoint, .. }
493 | ProviderConfig::OpenRouter { api_endpoint, .. } => {
494 *api_endpoint = endpoint;
495 }
496 ProviderConfig::Custom { api_endpoint, .. } => {
497 if let Some(custom_endpoint) = endpoint {
498 *api_endpoint = custom_endpoint;
499 }
500 }
501 ProviderConfig::Bedrock { .. } => {}
502 }
503 }
504
505 pub fn access_token(&self) -> Option<&str> {
510 if let Some(auth) = self.get_auth_ref()
512 && let Some(token) = auth.access_token()
513 {
514 return Some(token);
515 }
516 match self {
518 ProviderConfig::Anthropic { access_token, .. } => access_token.as_deref(),
519 _ => None,
520 }
521 }
522
523 pub fn openai(api_key: Option<String>) -> Self {
525 ProviderConfig::OpenAI {
526 api_key,
527 api_endpoint: None,
528 auth: None,
529 }
530 }
531
532 pub fn openai_with_auth(auth: ProviderAuth) -> Self {
534 ProviderConfig::OpenAI {
535 api_key: None,
536 api_endpoint: None,
537 auth: Some(auth),
538 }
539 }
540
541 pub fn anthropic(api_key: Option<String>, access_token: Option<String>) -> Self {
543 ProviderConfig::Anthropic {
544 api_key,
545 api_endpoint: None,
546 access_token,
547 auth: None,
548 }
549 }
550
551 pub fn anthropic_with_auth(auth: ProviderAuth) -> Self {
553 ProviderConfig::Anthropic {
554 api_key: None,
555 api_endpoint: None,
556 access_token: None,
557 auth: Some(auth),
558 }
559 }
560
561 pub fn gemini(api_key: Option<String>) -> Self {
563 ProviderConfig::Gemini {
564 api_key,
565 api_endpoint: None,
566 auth: None,
567 }
568 }
569
570 pub fn gemini_with_auth(auth: ProviderAuth) -> Self {
572 ProviderConfig::Gemini {
573 api_key: None,
574 api_endpoint: None,
575 auth: Some(auth),
576 }
577 }
578
579 pub fn custom(api_endpoint: String, api_key: Option<String>) -> Self {
581 ProviderConfig::Custom {
582 api_key,
583 api_endpoint,
584 auth: None,
585 }
586 }
587
588 pub fn custom_with_auth(api_endpoint: String, auth: ProviderAuth) -> Self {
590 ProviderConfig::Custom {
591 api_key: None,
592 api_endpoint,
593 auth: Some(auth),
594 }
595 }
596
597 pub fn stakpak(api_key: String, api_endpoint: Option<String>) -> Self {
599 ProviderConfig::Stakpak {
600 api_key: Some(api_key),
601 api_endpoint,
602 auth: None,
603 }
604 }
605
606 pub fn stakpak_with_auth(auth: ProviderAuth, api_endpoint: Option<String>) -> Self {
608 ProviderConfig::Stakpak {
609 api_key: None,
610 api_endpoint,
611 auth: Some(auth),
612 }
613 }
614
615 pub fn openrouter(api_key: Option<String>, api_endpoint: Option<String>) -> Self {
617 ProviderConfig::OpenRouter {
618 api_key,
619 api_endpoint,
620 auth: None,
621 }
622 }
623
624 pub fn openrouter_with_auth(auth: ProviderAuth, api_endpoint: Option<String>) -> Self {
626 ProviderConfig::OpenRouter {
627 api_key: None,
628 api_endpoint,
629 auth: Some(auth),
630 }
631 }
632
633 pub fn github_copilot_with_auth(auth: ProviderAuth) -> Self {
635 ProviderConfig::GitHubCopilot {
636 api_endpoint: None,
637 auth: Some(auth),
638 }
639 }
640
641 pub fn bedrock(region: String, profile_name: Option<String>) -> Self {
643 ProviderConfig::Bedrock {
644 region,
645 profile_name,
646 }
647 }
648
649 pub fn region(&self) -> Option<&str> {
651 match self {
652 ProviderConfig::Bedrock { region, .. } => Some(region.as_str()),
653 _ => None,
654 }
655 }
656
657 pub fn profile_name(&self) -> Option<&str> {
659 match self {
660 ProviderConfig::Bedrock { profile_name, .. } => profile_name.as_deref(),
661 _ => None,
662 }
663 }
664
665 pub fn empty_for_provider(provider_name: &str) -> Option<Self> {
670 match provider_name {
671 "openai" => Some(ProviderConfig::OpenAI {
672 api_key: None,
673 api_endpoint: None,
674 auth: None,
675 }),
676 "anthropic" => Some(ProviderConfig::Anthropic {
677 api_key: None,
678 api_endpoint: None,
679 access_token: None,
680 auth: None,
681 }),
682 "gemini" => Some(ProviderConfig::Gemini {
683 api_key: None,
684 api_endpoint: None,
685 auth: None,
686 }),
687 "stakpak" => Some(ProviderConfig::Stakpak {
688 api_key: None,
689 api_endpoint: None,
690 auth: None,
691 }),
692 "github-copilot" => Some(ProviderConfig::GitHubCopilot {
693 api_endpoint: None,
694 auth: None,
695 }),
696 "openrouter" => Some(ProviderConfig::OpenRouter {
697 api_key: None,
698 api_endpoint: None,
699 auth: None,
700 }),
701 _ => None,
703 }
704 }
705}
706
707#[derive(Debug, Clone, Default)]
711pub struct LLMProviderConfig {
712 pub providers: HashMap<String, ProviderConfig>,
714}
715
716impl LLMProviderConfig {
717 pub fn new() -> Self {
719 Self {
720 providers: HashMap::new(),
721 }
722 }
723
724 pub fn add_provider(&mut self, name: impl Into<String>, config: ProviderConfig) {
726 self.providers.insert(name.into(), config);
727 }
728
729 pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
731 self.providers.get(name)
732 }
733
734 pub fn is_empty(&self) -> bool {
736 self.providers.is_empty()
737 }
738}
739
740#[derive(Clone, Debug, Serialize, Deserialize, Default)]
742pub struct LLMProviderOptions {
743 #[serde(skip_serializing_if = "Option::is_none")]
745 pub anthropic: Option<LLMAnthropicOptions>,
746
747 #[serde(skip_serializing_if = "Option::is_none")]
749 pub openai: Option<LLMOpenAIOptions>,
750
751 #[serde(skip_serializing_if = "Option::is_none")]
753 pub google: Option<LLMGoogleOptions>,
754}
755
756#[derive(Clone, Debug, Serialize, Deserialize, Default)]
758pub struct LLMAnthropicOptions {
759 #[serde(skip_serializing_if = "Option::is_none")]
761 pub thinking: Option<LLMThinkingOptions>,
762}
763
764#[derive(Clone, Debug, Serialize, Deserialize)]
766pub struct LLMThinkingOptions {
767 pub budget_tokens: u32,
769}
770
771impl LLMThinkingOptions {
772 pub fn new(budget_tokens: u32) -> Self {
773 Self {
774 budget_tokens: budget_tokens.max(1024),
775 }
776 }
777}
778
779#[derive(Clone, Debug, Serialize, Deserialize, Default)]
781pub struct LLMOpenAIOptions {
782 #[serde(skip_serializing_if = "Option::is_none")]
784 pub reasoning_effort: Option<String>,
785}
786
787#[derive(Clone, Debug, Serialize, Deserialize, Default)]
789pub struct LLMGoogleOptions {
790 #[serde(skip_serializing_if = "Option::is_none")]
792 pub thinking_budget: Option<u32>,
793}
794
795#[derive(Clone, Debug, Serialize)]
796pub struct LLMInput {
797 pub model: Model,
798 pub messages: Vec<LLMMessage>,
799 pub max_tokens: u32,
800 pub tools: Option<Vec<LLMTool>>,
801 #[serde(skip_serializing_if = "Option::is_none")]
802 pub provider_options: Option<LLMProviderOptions>,
803 #[serde(skip_serializing_if = "Option::is_none")]
805 pub headers: Option<std::collections::HashMap<String, String>>,
806}
807
808#[derive(Debug)]
809pub struct LLMStreamInput {
810 pub model: Model,
811 pub messages: Vec<LLMMessage>,
812 pub max_tokens: u32,
813 pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
814 pub tools: Option<Vec<LLMTool>>,
815 pub provider_options: Option<LLMProviderOptions>,
816 pub headers: Option<std::collections::HashMap<String, String>>,
818}
819
820impl From<&LLMStreamInput> for LLMInput {
821 fn from(value: &LLMStreamInput) -> Self {
822 LLMInput {
823 model: value.model.clone(),
824 messages: value.messages.clone(),
825 max_tokens: value.max_tokens,
826 tools: value.tools.clone(),
827 provider_options: value.provider_options.clone(),
828 headers: value.headers.clone(),
829 }
830 }
831}
832
833#[derive(Serialize, Deserialize, Debug, Clone, Default)]
834pub struct LLMMessage {
835 pub role: String,
836 pub content: LLMMessageContent,
837}
838
839#[derive(Serialize, Deserialize, Debug, Clone)]
840pub struct SimpleLLMMessage {
841 #[serde(rename = "role")]
842 pub role: SimpleLLMRole,
843 pub content: String,
844}
845
846#[derive(Serialize, Deserialize, Debug, Clone)]
847#[serde(rename_all = "lowercase")]
848pub enum SimpleLLMRole {
849 User,
850 Assistant,
851}
852
853impl std::fmt::Display for SimpleLLMRole {
854 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
855 match self {
856 SimpleLLMRole::User => write!(f, "user"),
857 SimpleLLMRole::Assistant => write!(f, "assistant"),
858 }
859 }
860}
861
862#[derive(Serialize, Deserialize, Debug, Clone)]
863#[serde(untagged)]
864pub enum LLMMessageContent {
865 String(String),
866 List(Vec<LLMMessageTypedContent>),
867}
868
869#[allow(clippy::to_string_trait_impl)]
870impl ToString for LLMMessageContent {
871 fn to_string(&self) -> String {
872 match self {
873 LLMMessageContent::String(s) => s.clone(),
874 LLMMessageContent::List(l) => l
875 .iter()
876 .map(|c| match c {
877 LLMMessageTypedContent::Text { text } => text.clone(),
878 LLMMessageTypedContent::ToolCall { .. } => String::new(),
879 LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
880 LLMMessageTypedContent::Image { .. } => String::new(),
881 })
882 .collect::<Vec<_>>()
883 .join("\n"),
884 }
885 }
886}
887
888impl From<String> for LLMMessageContent {
889 fn from(value: String) -> Self {
890 LLMMessageContent::String(value)
891 }
892}
893
894impl Default for LLMMessageContent {
895 fn default() -> Self {
896 LLMMessageContent::String(String::new())
897 }
898}
899
900impl LLMMessageContent {
901 pub fn into_parts(self) -> Vec<LLMMessageTypedContent> {
904 match self {
905 LLMMessageContent::List(parts) => parts,
906 LLMMessageContent::String(s) if s.is_empty() => vec![],
907 LLMMessageContent::String(s) => vec![LLMMessageTypedContent::Text { text: s }],
908 }
909 }
910}
911
912#[derive(Serialize, Deserialize, Debug, Clone)]
913#[serde(tag = "type")]
914pub enum LLMMessageTypedContent {
915 #[serde(rename = "text")]
916 Text { text: String },
917 #[serde(rename = "tool_use")]
918 ToolCall {
919 id: String,
920 name: String,
921 #[serde(alias = "input")]
922 args: serde_json::Value,
923 #[serde(skip_serializing_if = "Option::is_none")]
925 metadata: Option<serde_json::Value>,
926 },
927 #[serde(rename = "tool_result")]
928 ToolResult {
929 tool_use_id: String,
930 content: String,
931 },
932 #[serde(rename = "image")]
933 Image { source: LLMMessageImageSource },
934}
935
936#[derive(Serialize, Deserialize, Debug, Clone)]
937pub struct LLMMessageImageSource {
938 #[serde(rename = "type")]
939 pub r#type: String,
940 pub media_type: String,
941 pub data: String,
942}
943
944impl Default for LLMMessageTypedContent {
945 fn default() -> Self {
946 LLMMessageTypedContent::Text {
947 text: String::new(),
948 }
949 }
950}
951
952#[derive(Serialize, Deserialize, Debug, Clone)]
953pub struct LLMChoice {
954 pub finish_reason: Option<String>,
955 pub index: u32,
956 pub message: LLMMessage,
957}
958
959#[derive(Serialize, Deserialize, Debug, Clone)]
960pub struct LLMCompletionResponse {
961 pub model: String,
962 pub object: String,
963 pub choices: Vec<LLMChoice>,
964 pub created: u64,
965 pub usage: Option<LLMTokenUsage>,
966 pub id: String,
967}
968
969#[derive(Serialize, Deserialize, Debug, Clone)]
970pub struct LLMStreamDelta {
971 #[serde(skip_serializing_if = "Option::is_none")]
972 pub content: Option<String>,
973}
974
975#[derive(Serialize, Deserialize, Debug, Clone)]
976pub struct LLMStreamChoice {
977 pub finish_reason: Option<String>,
978 pub index: u32,
979 pub message: Option<LLMMessage>,
980 pub delta: LLMStreamDelta,
981}
982
983#[derive(Serialize, Deserialize, Debug, Clone)]
984pub struct LLMCompletionStreamResponse {
985 pub model: String,
986 pub object: String,
987 pub choices: Vec<LLMStreamChoice>,
988 pub created: u64,
989 #[serde(skip_serializing_if = "Option::is_none")]
990 pub usage: Option<LLMTokenUsage>,
991 pub id: String,
992 pub citations: Option<Vec<String>>,
993}
994
995#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
996pub struct LLMTool {
997 pub name: String,
998 pub description: String,
999 pub input_schema: serde_json::Value,
1000}
1001
1002#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
1003pub struct LLMTokenUsage {
1004 pub prompt_tokens: u32,
1005 pub completion_tokens: u32,
1006 pub total_tokens: u32,
1007
1008 #[serde(skip_serializing_if = "Option::is_none")]
1009 pub prompt_tokens_details: Option<PromptTokensDetails>,
1010}
1011
1012#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
1013#[serde(rename_all = "snake_case")]
1014pub enum TokenType {
1015 InputTokens,
1016 OutputTokens,
1017 CacheReadInputTokens,
1018 CacheWriteInputTokens,
1019}
1020
1021#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
1022pub struct PromptTokensDetails {
1023 #[serde(skip_serializing_if = "Option::is_none")]
1024 pub input_tokens: Option<u32>,
1025 #[serde(skip_serializing_if = "Option::is_none")]
1026 pub output_tokens: Option<u32>,
1027 #[serde(skip_serializing_if = "Option::is_none")]
1028 pub cache_read_input_tokens: Option<u32>,
1029 #[serde(skip_serializing_if = "Option::is_none")]
1030 pub cache_write_input_tokens: Option<u32>,
1031}
1032
1033impl PromptTokensDetails {
1034 pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
1036 [
1037 (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
1038 (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
1039 (
1040 TokenType::CacheReadInputTokens,
1041 self.cache_read_input_tokens.unwrap_or(0),
1042 ),
1043 (
1044 TokenType::CacheWriteInputTokens,
1045 self.cache_write_input_tokens.unwrap_or(0),
1046 ),
1047 ]
1048 .into_iter()
1049 }
1050}
1051
1052impl std::ops::Add for PromptTokensDetails {
1053 type Output = Self;
1054
1055 fn add(self, rhs: Self) -> Self::Output {
1056 Self {
1057 input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
1058 output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
1059 cache_read_input_tokens: Some(
1060 self.cache_read_input_tokens.unwrap_or(0)
1061 + rhs.cache_read_input_tokens.unwrap_or(0),
1062 ),
1063 cache_write_input_tokens: Some(
1064 self.cache_write_input_tokens.unwrap_or(0)
1065 + rhs.cache_write_input_tokens.unwrap_or(0),
1066 ),
1067 }
1068 }
1069}
1070
1071impl std::ops::AddAssign for PromptTokensDetails {
1072 fn add_assign(&mut self, rhs: Self) {
1073 self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
1074 self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
1075 self.cache_read_input_tokens = Some(
1076 self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
1077 );
1078 self.cache_write_input_tokens = Some(
1079 self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
1080 );
1081 }
1082}
1083
1084#[derive(Serialize, Deserialize, Debug, Clone)]
1085#[serde(tag = "type")]
1086pub enum GenerationDelta {
1087 Content { content: String },
1088 Thinking { thinking: String },
1089 ToolUse { tool_use: GenerationDeltaToolUse },
1090 Usage { usage: LLMTokenUsage },
1091 Metadata { metadata: serde_json::Value },
1092}
1093
1094#[derive(Serialize, Deserialize, Debug, Clone)]
1095pub struct GenerationDeltaToolUse {
1096 pub id: Option<String>,
1097 pub name: Option<String>,
1098 pub input: Option<String>,
1099 pub index: usize,
1100 #[serde(skip_serializing_if = "Option::is_none")]
1102 pub metadata: Option<serde_json::Value>,
1103}
1104
#[cfg(test)]
mod tests {
    //! Serialization and accessor tests for provider configuration types.
    use super::*;

    #[test]
    fn test_provider_config_openai_serialization() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: None,
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"openai\""));
        assert!(json.contains("\"api_key\":\"sk-test\""));
        // `None` fields are skipped entirely.
        assert!(!json.contains("api_endpoint"));
    }

    #[test]
    fn test_provider_config_openai_with_endpoint() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: Some("https://custom.openai.com/v1".to_string()),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"api_endpoint\":\"https://custom.openai.com/v1\""));
    }

    #[test]
    fn test_provider_config_anthropic_serialization() {
        let config = ProviderConfig::Anthropic {
            api_key: Some("sk-ant-test".to_string()),
            api_endpoint: None,
            access_token: Some("oauth-token".to_string()),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"anthropic\""));
        assert!(json.contains("\"api_key\":\"sk-ant-test\""));
        assert!(json.contains("\"access_token\":\"oauth-token\""));
    }

    #[test]
    fn test_provider_config_gemini_serialization() {
        let config = ProviderConfig::Gemini {
            api_key: Some("gemini-key".to_string()),
            api_endpoint: None,
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"gemini\""));
        assert!(json.contains("\"api_key\":\"gemini-key\""));
    }

    #[test]
    fn test_provider_config_custom_serialization() {
        let config = ProviderConfig::Custom {
            api_key: Some("sk-custom".to_string()),
            api_endpoint: "http://localhost:4000".to_string(),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\":\"http://localhost:4000\""));
        assert!(json.contains("\"api_key\":\"sk-custom\""));
    }

    #[test]
    fn test_provider_config_custom_without_key() {
        let config = ProviderConfig::Custom {
            api_key: None,
            api_endpoint: "http://localhost:11434/v1".to_string(),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\""));
        // `None` fields are skipped entirely.
        assert!(!json.contains("api_key"));
    }

    #[test]
    fn test_provider_config_deserialization_openai() {
        let json = r#"{"type":"openai","api_key":"sk-test"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::OpenAI { .. }));
        assert_eq!(config.api_key(), Some("sk-test"));
    }

    #[test]
    fn test_provider_config_deserialization_anthropic() {
        let json = r#"{"type":"anthropic","api_key":"sk-ant","access_token":"oauth"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Anthropic { .. }));
        assert_eq!(config.api_key(), Some("sk-ant"));
        assert_eq!(config.access_token(), Some("oauth"));
    }

    #[test]
    fn test_provider_config_deserialization_gemini() {
        let json = r#"{"type":"gemini","api_key":"gemini-key"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Gemini { .. }));
        assert_eq!(config.api_key(), Some("gemini-key"));
    }

    #[test]
    fn test_provider_config_deserialization_custom() {
        let json =
            r#"{"type":"custom","api_endpoint":"http://localhost:4000","api_key":"sk-custom"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Custom { .. }));
        assert_eq!(config.api_key(), Some("sk-custom"));
        assert_eq!(config.api_endpoint(), Some("http://localhost:4000"));
    }

    #[test]
    fn test_provider_config_helper_methods() {
        let openai = ProviderConfig::openai(Some("sk-openai".to_string()));
        assert_eq!(openai.provider_type(), "openai");
        assert_eq!(openai.api_key(), Some("sk-openai"));

        let anthropic =
            ProviderConfig::anthropic(Some("sk-ant".to_string()), Some("oauth".to_string()));
        assert_eq!(anthropic.provider_type(), "anthropic");
        assert_eq!(anthropic.access_token(), Some("oauth"));

        let gemini = ProviderConfig::gemini(Some("gemini-key".to_string()));
        assert_eq!(gemini.provider_type(), "gemini");

        let custom = ProviderConfig::custom(
            "http://localhost:4000".to_string(),
            Some("sk-custom".to_string()),
        );
        assert_eq!(custom.provider_type(), "custom");
        assert_eq!(custom.api_endpoint(), Some("http://localhost:4000"));
    }

    #[test]
    fn test_set_api_endpoint_updates_supported_providers() {
        let mut openai = ProviderConfig::openai(Some("sk-openai".to_string()));
        openai.set_api_endpoint(Some("https://proxy.example.com/v1".to_string()));
        assert_eq!(openai.api_endpoint(), Some("https://proxy.example.com/v1"));

        let mut bedrock = ProviderConfig::bedrock("us-east-1".to_string(), None);
        bedrock.set_api_endpoint(Some("https://ignored.example.com".to_string()));
        assert_eq!(bedrock.api_endpoint(), None);
    }

    #[test]
    fn test_llm_provider_config_new() {
        let config = LLMProviderConfig::new();
        assert!(config.is_empty());
    }

    #[test]
    fn test_llm_provider_config_add_and_get() {
        let mut config = LLMProviderConfig::new();
        config.add_provider(
            "openai",
            ProviderConfig::openai(Some("sk-test".to_string())),
        );
        config.add_provider(
            "anthropic",
            ProviderConfig::anthropic(Some("sk-ant".to_string()), None),
        );

        assert!(!config.is_empty());
        assert!(config.get_provider("openai").is_some());
        assert!(config.get_provider("anthropic").is_some());
        assert!(config.get_provider("unknown").is_none());
    }

    /// Parses a JSON name → provider map (previously misnamed `..._toml_parsing`).
    #[test]
    fn test_provider_config_json_map_parsing() {
        let json = r#"{
            "openai": {"type": "openai", "api_key": "sk-openai"},
            "anthropic": {"type": "anthropic", "api_key": "sk-ant", "access_token": "oauth"},
            "litellm": {"type": "custom", "api_endpoint": "http://localhost:4000", "api_key": "sk-litellm"}
        }"#;

        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 3);

        assert!(matches!(
            providers.get("openai"),
            Some(ProviderConfig::OpenAI { .. })
        ));
        assert!(matches!(
            providers.get("anthropic"),
            Some(ProviderConfig::Anthropic { .. })
        ));
        assert!(matches!(
            providers.get("litellm"),
            Some(ProviderConfig::Custom { .. })
        ));
    }

    /// Parses a TOML name → provider map, one table per provider.
    #[test]
    fn test_provider_config_toml_map_parsing() {
        let toml_str = r#"
            [openai]
            type = "openai"
            api_key = "sk-openai"

            [litellm]
            type = "custom"
            api_endpoint = "http://localhost:4000"
        "#;

        let providers: HashMap<String, ProviderConfig> = toml::from_str(toml_str).unwrap();
        assert_eq!(providers.len(), 2);
        assert_eq!(
            providers.get("openai").and_then(|c| c.api_key()),
            Some("sk-openai")
        );
        assert_eq!(
            providers.get("litellm").and_then(|c| c.api_endpoint()),
            Some("http://localhost:4000")
        );
    }

    #[test]
    fn test_provider_config_bedrock_serialization() {
        let config = ProviderConfig::Bedrock {
            region: "us-east-1".to_string(),
            profile_name: Some("my-profile".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"amazon-bedrock\""));
        assert!(json.contains("\"region\":\"us-east-1\""));
        assert!(json.contains("\"profile_name\":\"my-profile\""));
    }

    #[test]
    fn test_provider_config_bedrock_serialization_without_profile() {
        let config = ProviderConfig::Bedrock {
            region: "us-west-2".to_string(),
            profile_name: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"amazon-bedrock\""));
        assert!(json.contains("\"region\":\"us-west-2\""));
        assert!(!json.contains("profile_name"));
    }

    #[test]
    fn test_provider_config_bedrock_deserialization() {
        let json = r#"{"type":"amazon-bedrock","region":"us-east-1","profile_name":"prod"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Bedrock { .. }));
        assert_eq!(config.region(), Some("us-east-1"));
        assert_eq!(config.profile_name(), Some("prod"));
    }

    #[test]
    fn test_provider_config_bedrock_deserialization_minimal() {
        let json = r#"{"type":"amazon-bedrock","region":"eu-west-1"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Bedrock { .. }));
        assert_eq!(config.region(), Some("eu-west-1"));
        assert_eq!(config.profile_name(), None);
    }

    #[test]
    fn test_provider_config_bedrock_no_api_key() {
        let config = ProviderConfig::bedrock("us-east-1".to_string(), None);
        assert_eq!(config.api_key(), None);
        assert_eq!(config.api_endpoint(), None);
    }

    #[test]
    fn test_provider_config_bedrock_helper_methods() {
        let bedrock = ProviderConfig::bedrock("us-east-1".to_string(), Some("prod".to_string()));
        assert_eq!(bedrock.provider_type(), "amazon-bedrock");
        assert_eq!(bedrock.region(), Some("us-east-1"));
        assert_eq!(bedrock.profile_name(), Some("prod"));
        assert_eq!(bedrock.api_key(), None);
        assert_eq!(bedrock.api_endpoint(), None);
        assert_eq!(bedrock.access_token(), None);
    }

    #[test]
    fn test_provider_config_bedrock_toml_roundtrip() {
        let config = ProviderConfig::Bedrock {
            region: "us-east-1".to_string(),
            profile_name: Some("my-profile".to_string()),
        };
        let toml_str = toml::to_string(&config).unwrap();
        let parsed: ProviderConfig = toml::from_str(&toml_str).unwrap();
        assert_eq!(config, parsed);
    }

    #[test]
    fn test_provider_config_bedrock_toml_parsing() {
        let toml_str = r#"
            type = "amazon-bedrock"
            region = "us-east-1"
            profile_name = "production"
        "#;
        let config: ProviderConfig = toml::from_str(toml_str).unwrap();
        assert!(matches!(
            config,
            ProviderConfig::Bedrock {
                ref region,
                ref profile_name,
            } if region == "us-east-1" && profile_name.as_deref() == Some("production")
        ));
    }

    #[test]
    fn test_provider_config_bedrock_missing_region_fails() {
        let json = r#"{"type":"amazon-bedrock"}"#;
        let result: Result<ProviderConfig, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_provider_config_bedrock_in_providers_map() {
        let json = r#"{
            "anthropic": {"type": "anthropic", "api_key": "sk-ant"},
            "amazon-bedrock": {"type": "amazon-bedrock", "region": "us-east-1"}
        }"#;
        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 2);
        assert!(matches!(
            providers.get("amazon-bedrock"),
            Some(ProviderConfig::Bedrock { .. })
        ));
    }

    #[test]
    fn test_region_returns_none_for_non_bedrock() {
        let openai = ProviderConfig::openai(Some("key".to_string()));
        assert_eq!(openai.region(), None);

        let anthropic = ProviderConfig::anthropic(Some("key".to_string()), None);
        assert_eq!(anthropic.region(), None);
    }

    #[test]
    fn test_profile_name_returns_none_for_non_bedrock() {
        let openai = ProviderConfig::openai(Some("key".to_string()));
        assert_eq!(openai.profile_name(), None);
    }
}