use serde::{Deserialize, Serialize};
use stakai::Model;
use std::collections::HashMap;

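/// Connection settings for one upstream LLM provider. The serialized form is
/// tagged by `type` (e.g. `{"type":"openai","api_key":"sk-..."}`), with the
/// tag lowercased from the variant name.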
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ProviderConfig {
    OpenAI {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
    Anthropic {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        access_token: Option<String>,
    },
    Gemini {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
    Custom {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        api_endpoint: String,
    },
    Stakpak {
        api_key: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
}

impl ProviderConfig {
    /// Returns the lowercase provider tag used in serialized form.
    pub fn provider_type(&self) -> &'static str {
        match self {
            ProviderConfig::OpenAI { .. } => "openai",
            ProviderConfig::Anthropic { .. } => "anthropic",
            ProviderConfig::Gemini { .. } => "gemini",
            ProviderConfig::Custom { .. } => "custom",
            ProviderConfig::Stakpak { .. } => "stakpak",
        }
    }

    /// Returns the API key, if one is configured; `Stakpak` always has one.
    pub fn api_key(&self) -> Option<&str> {
        match self {
            ProviderConfig::OpenAI { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Anthropic { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Gemini { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Custom { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Stakpak { api_key, .. } => Some(api_key.as_str()),
        }
    }

    /// Returns the endpoint override, if any; `Custom` always has one.
    pub fn api_endpoint(&self) -> Option<&str> {
        match self {
            ProviderConfig::OpenAI { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Anthropic { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Gemini { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Custom { api_endpoint, .. } => Some(api_endpoint.as_str()),
            ProviderConfig::Stakpak { api_endpoint, .. } => api_endpoint.as_deref(),
        }
    }

    /// Returns the OAuth access token; only `Anthropic` supports one.
    pub fn access_token(&self) -> Option<&str> {
        match self {
            ProviderConfig::Anthropic { access_token, .. } => access_token.as_deref(),
            _ => None,
        }
    }

    pub fn openai(api_key: Option<String>) -> Self {
        ProviderConfig::OpenAI {
            api_key,
            api_endpoint: None,
        }
    }

    pub fn anthropic(api_key: Option<String>, access_token: Option<String>) -> Self {
        ProviderConfig::Anthropic {
            api_key,
            api_endpoint: None,
            access_token,
        }
    }

    pub fn gemini(api_key: Option<String>) -> Self {
        ProviderConfig::Gemini {
            api_key,
            api_endpoint: None,
        }
    }

    pub fn custom(api_endpoint: String, api_key: Option<String>) -> Self {
        ProviderConfig::Custom {
            api_key,
            api_endpoint,
        }
    }

    pub fn stakpak(api_key: String, api_endpoint: Option<String>) -> Self {
        ProviderConfig::Stakpak {
            api_key,
            api_endpoint,
        }
    }
}

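/// Named provider configurations, keyed by user-chosen identifiers such as
/// "openai" or "litellm".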
#[derive(Debug, Clone, Default)]
pub struct LLMProviderConfig {
    pub providers: HashMap<String, ProviderConfig>,
}

impl LLMProviderConfig {
    pub fn new() -> Self {
        Self {
            providers: HashMap::new(),
        }
    }

    pub fn add_provider(&mut self, name: impl Into<String>, config: ProviderConfig) {
        self.providers.insert(name.into(), config);
    }

    pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
        self.providers.get(name)
    }

    pub fn is_empty(&self) -> bool {
        self.providers.is_empty()
    }
}

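/// Vendor-specific generation options; unset groups are omitted from the
/// serialized payload.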
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMProviderOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub anthropic: Option<LLMAnthropicOptions>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub openai: Option<LLMOpenAIOptions>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub google: Option<LLMGoogleOptions>,
}

#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMAnthropicOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<LLMThinkingOptions>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LLMThinkingOptions {
    pub budget_tokens: u32,
}

impl LLMThinkingOptions {
    pub fn new(budget_tokens: u32) -> Self {
        Self {
            // Clamp to a 1024-token floor (Anthropic's documented minimum for
            // extended-thinking budgets) rather than failing at request time.
            budget_tokens: budget_tokens.max(1024),
        }
    }
}

#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMOpenAIOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
}

#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMGoogleOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<u32>,
}

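/// A complete, non-streaming completion request.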
#[derive(Clone, Debug, Serialize)]
pub struct LLMInput {
    pub model: Model,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub tools: Option<Vec<LLMTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<LLMProviderOptions>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<HashMap<String, String>>,
}

/// Streaming variant of `LLMInput`; deltas are pushed through
/// `stream_channel_tx` as they arrive.
#[derive(Debug)]
pub struct LLMStreamInput {
    pub model: Model,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
    pub tools: Option<Vec<LLMTool>>,
    pub provider_options: Option<LLMProviderOptions>,
    pub headers: Option<HashMap<String, String>>,
}

impl From<&LLMStreamInput> for LLMInput {
    fn from(value: &LLMStreamInput) -> Self {
        LLMInput {
            model: value.model.clone(),
            messages: value.messages.clone(),
            max_tokens: value.max_tokens,
            tools: value.tools.clone(),
            provider_options: value.provider_options.clone(),
            headers: value.headers.clone(),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct LLMMessage {
    pub role: String,
    pub content: LLMMessageContent,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SimpleLLMMessage {
    pub role: SimpleLLMRole,
    pub content: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SimpleLLMRole {
    User,
    Assistant,
}

impl std::fmt::Display for SimpleLLMRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SimpleLLMRole::User => write!(f, "user"),
            SimpleLLMRole::Assistant => write!(f, "assistant"),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum LLMMessageContent {
    String(String),
    List(Vec<LLMMessageTypedContent>),
}

#[allow(clippy::to_string_trait_impl)]
impl ToString for LLMMessageContent {
    /// Flattens the content to plain text: text and tool-result blocks are
    /// joined with newlines; tool calls and images contribute empty strings.
    fn to_string(&self) -> String {
        match self {
            LLMMessageContent::String(s) => s.clone(),
            LLMMessageContent::List(l) => l
                .iter()
                .map(|c| match c {
                    LLMMessageTypedContent::Text { text } => text.clone(),
                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
                    LLMMessageTypedContent::Image { .. } => String::new(),
                })
                .collect::<Vec<_>>()
                .join("\n"),
        }
    }
}

impl From<String> for LLMMessageContent {
    fn from(value: String) -> Self {
        LLMMessageContent::String(value)
    }
}

impl Default for LLMMessageContent {
    fn default() -> Self {
        LLMMessageContent::String(String::new())
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum LLMMessageTypedContent {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "tool_use")]
    ToolCall {
        id: String,
        name: String,
        #[serde(alias = "input")]
        args: serde_json::Value,
        #[serde(skip_serializing_if = "Option::is_none")]
        metadata: Option<serde_json::Value>,
    },
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        content: String,
    },
    #[serde(rename = "image")]
    Image { source: LLMMessageImageSource },
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMMessageImageSource {
    #[serde(rename = "type")]
    pub r#type: String,
    pub media_type: String,
    pub data: String,
}

impl Default for LLMMessageTypedContent {
    fn default() -> Self {
        LLMMessageTypedContent::Text {
            text: String::new(),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: LLMMessage,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMChoice>,
    pub created: u64,
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: Option<LLMMessage>,
    pub delta: LLMStreamDelta,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionStreamResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMStreamChoice>,
    pub created: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
    pub citations: Option<Vec<String>>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMTokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    InputTokens,
    OutputTokens,
    CacheReadInputTokens,
    CacheWriteInputTokens,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct PromptTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_input_tokens: Option<u32>,
}

impl PromptTokensDetails {
    /// Iterates over all four counters, treating absent values as zero.
    pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
        [
            (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
            (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
            (
                TokenType::CacheReadInputTokens,
                self.cache_read_input_tokens.unwrap_or(0),
            ),
            (
                TokenType::CacheWriteInputTokens,
                self.cache_write_input_tokens.unwrap_or(0),
            ),
        ]
        .into_iter()
    }
}

// Field-wise sums: `None` counts as zero, so every field of a sum is `Some`.
impl std::ops::Add for PromptTokensDetails {
    type Output = Self;

    fn add(self, rhs: Self) -> Self::Output {
        Self {
            input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
            output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
            cache_read_input_tokens: Some(
                self.cache_read_input_tokens.unwrap_or(0)
                    + rhs.cache_read_input_tokens.unwrap_or(0),
            ),
            cache_write_input_tokens: Some(
                self.cache_write_input_tokens.unwrap_or(0)
                    + rhs.cache_write_input_tokens.unwrap_or(0),
            ),
        }
    }
}

impl std::ops::AddAssign for PromptTokensDetails {
    fn add_assign(&mut self, rhs: Self) {
        self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
        self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
        self.cache_read_input_tokens = Some(
            self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
        );
        self.cache_write_input_tokens = Some(
            self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
        );
    }
}

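/// A single incremental update emitted while a streaming completion is in
/// flight: text content, thinking tokens, partial tool calls, usage counts,
/// or provider metadata.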
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum GenerationDelta {
    Content { content: String },
    Thinking { thinking: String },
    ToolUse { tool_use: GenerationDeltaToolUse },
    Usage { usage: LLMTokenUsage },
    Metadata { metadata: serde_json::Value },
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenerationDeltaToolUse {
    pub id: Option<String>,
    pub name: Option<String>,
    pub input: Option<String>,
    pub index: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_config_openai_serialization() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"openai\""));
        assert!(json.contains("\"api_key\":\"sk-test\""));
        assert!(!json.contains("api_endpoint"));
    }

    #[test]
    fn test_provider_config_openai_with_endpoint() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: Some("https://custom.openai.com/v1".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"api_endpoint\":\"https://custom.openai.com/v1\""));
    }

    #[test]
    fn test_provider_config_anthropic_serialization() {
        let config = ProviderConfig::Anthropic {
            api_key: Some("sk-ant-test".to_string()),
            api_endpoint: None,
            access_token: Some("oauth-token".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"anthropic\""));
        assert!(json.contains("\"api_key\":\"sk-ant-test\""));
        assert!(json.contains("\"access_token\":\"oauth-token\""));
    }

    #[test]
    fn test_provider_config_gemini_serialization() {
        let config = ProviderConfig::Gemini {
            api_key: Some("gemini-key".to_string()),
            api_endpoint: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"gemini\""));
        assert!(json.contains("\"api_key\":\"gemini-key\""));
    }

    #[test]
    fn test_provider_config_custom_serialization() {
        let config = ProviderConfig::Custom {
            api_key: Some("sk-custom".to_string()),
            api_endpoint: "http://localhost:4000".to_string(),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\":\"http://localhost:4000\""));
        assert!(json.contains("\"api_key\":\"sk-custom\""));
    }

    #[test]
    fn test_provider_config_custom_without_key() {
        let config = ProviderConfig::Custom {
            api_key: None,
            api_endpoint: "http://localhost:11434/v1".to_string(),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\""));
        assert!(!json.contains("api_key"));
    }

    #[test]
    fn test_provider_config_deserialization_openai() {
        let json = r#"{"type":"openai","api_key":"sk-test"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::OpenAI { .. }));
        assert_eq!(config.api_key(), Some("sk-test"));
    }

    #[test]
    fn test_provider_config_deserialization_anthropic() {
        let json = r#"{"type":"anthropic","api_key":"sk-ant","access_token":"oauth"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Anthropic { .. }));
        assert_eq!(config.api_key(), Some("sk-ant"));
        assert_eq!(config.access_token(), Some("oauth"));
    }

    #[test]
    fn test_provider_config_deserialization_gemini() {
        let json = r#"{"type":"gemini","api_key":"gemini-key"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Gemini { .. }));
        assert_eq!(config.api_key(), Some("gemini-key"));
    }

    #[test]
    fn test_provider_config_deserialization_custom() {
        let json =
            r#"{"type":"custom","api_endpoint":"http://localhost:4000","api_key":"sk-custom"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Custom { .. }));
        assert_eq!(config.api_key(), Some("sk-custom"));
        assert_eq!(config.api_endpoint(), Some("http://localhost:4000"));
    }

    #[test]
    fn test_provider_config_helper_methods() {
        let openai = ProviderConfig::openai(Some("sk-openai".to_string()));
        assert_eq!(openai.provider_type(), "openai");
        assert_eq!(openai.api_key(), Some("sk-openai"));

        let anthropic =
            ProviderConfig::anthropic(Some("sk-ant".to_string()), Some("oauth".to_string()));
        assert_eq!(anthropic.provider_type(), "anthropic");
        assert_eq!(anthropic.access_token(), Some("oauth"));

        let gemini = ProviderConfig::gemini(Some("gemini-key".to_string()));
        assert_eq!(gemini.provider_type(), "gemini");

        let custom = ProviderConfig::custom(
            "http://localhost:4000".to_string(),
            Some("sk-custom".to_string()),
        );
        assert_eq!(custom.provider_type(), "custom");
        assert_eq!(custom.api_endpoint(), Some("http://localhost:4000"));
    }
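
    // A quick check of `LLMMessageContent::to_string` (not in the original
    // suite): text and tool-result blocks are joined with newlines, while
    // tool calls map to empty strings, hence the doubled newline below.
    #[test]
    fn test_llm_message_content_to_string_flattens_blocks() {
        let content = LLMMessageContent::List(vec![
            LLMMessageTypedContent::Text {
                text: "first".to_string(),
            },
            LLMMessageTypedContent::ToolCall {
                id: "call_1".to_string(),
                name: "search".to_string(),
                args: serde_json::json!({}),
                metadata: None,
            },
            LLMMessageTypedContent::ToolResult {
                tool_use_id: "call_1".to_string(),
                content: "result".to_string(),
            },
        ]);
        assert_eq!(content.to_string(), "first\n\nresult");
    }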

    #[test]
    fn test_llm_provider_config_new() {
        let config = LLMProviderConfig::new();
        assert!(config.is_empty());
    }
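
    // Sanity check: `LLMThinkingOptions::new` enforces the 1024-token floor;
    // values at or above it pass through unchanged.
    #[test]
    fn test_llm_thinking_options_budget_floor() {
        assert_eq!(LLMThinkingOptions::new(10).budget_tokens, 1024);
        assert_eq!(LLMThinkingOptions::new(4096).budget_tokens, 4096);
    }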

    #[test]
    fn test_llm_provider_config_add_and_get() {
        let mut config = LLMProviderConfig::new();
        config.add_provider(
            "openai",
            ProviderConfig::openai(Some("sk-test".to_string())),
        );
        config.add_provider(
            "anthropic",
            ProviderConfig::anthropic(Some("sk-ant".to_string()), None),
        );

        assert!(!config.is_empty());
        assert!(config.get_provider("openai").is_some());
        assert!(config.get_provider("anthropic").is_some());
        assert!(config.get_provider("unknown").is_none());
    }
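
    // `PromptTokensDetails` addition is field-wise with `None` treated as
    // zero, so every field of the sum comes back as `Some`.
    #[test]
    fn test_prompt_tokens_details_add() {
        let a = PromptTokensDetails {
            input_tokens: Some(10),
            cache_read_input_tokens: Some(5),
            ..Default::default()
        };
        let b = PromptTokensDetails {
            input_tokens: Some(2),
            output_tokens: Some(7),
            ..Default::default()
        };
        let sum = a + b;
        assert_eq!(sum.input_tokens, Some(12));
        assert_eq!(sum.output_tokens, Some(7));
        assert_eq!(sum.cache_read_input_tokens, Some(5));
        assert_eq!(sum.cache_write_input_tokens, Some(0));
    }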

    // Parses a map of named providers; JSON stands in here for the on-disk
    // config format, since the serde shapes are format-agnostic.
    #[test]
    fn test_provider_config_map_parsing() {
        let json = r#"{
            "openai": {"type": "openai", "api_key": "sk-openai"},
            "anthropic": {"type": "anthropic", "api_key": "sk-ant", "access_token": "oauth"},
            "litellm": {"type": "custom", "api_endpoint": "http://localhost:4000", "api_key": "sk-litellm"}
        }"#;

        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 3);

        assert!(matches!(
            providers.get("openai"),
            Some(ProviderConfig::OpenAI { .. })
        ));
        assert!(matches!(
            providers.get("anthropic"),
            Some(ProviderConfig::Anthropic { .. })
        ));
        assert!(matches!(
            providers.get("litellm"),
            Some(ProviderConfig::Custom { .. })
        ));
    }
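
    // The `args` field carries #[serde(alias = "input")], so Anthropic-style
    // tool-use payloads that spell the arguments "input" should also parse.
    #[test]
    fn test_tool_call_accepts_input_alias() {
        let json = r#"{"type":"tool_use","id":"call_1","name":"search","input":{"q":"rust"}}"#;
        let content: LLMMessageTypedContent = serde_json::from_str(json).unwrap();
        match content {
            LLMMessageTypedContent::ToolCall { name, args, .. } => {
                assert_eq!(name, "search");
                assert_eq!(args["q"], "rust");
            }
            other => panic!("expected a tool call, got {:?}", other),
        }
    }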
}