1use async_trait::async_trait;
2use futures_util::Stream;
3use models::*;
4use reqwest::header::HeaderMap;
5use rmcp::model::Content;
6use stakpak_shared::models::integrations::openai::{
7 ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage, Tool,
8};
9use uuid::Uuid;
10
11pub mod client;
12pub mod error;
13pub mod local;
14pub mod models;
15pub mod stakpak;
16pub mod storage;
17
18pub use client::{
20 AgentClient, AgentClientConfig, DEFAULT_STAKPAK_ENDPOINT, ModelOptions, StakpakConfig,
21};
22
23pub use stakai::{Model, ModelCost, ModelLimit};
25
26pub use storage::{
28 BoxedSessionStorage, Checkpoint, CheckpointState, CheckpointSummary, CreateCheckpointRequest,
29 CreateSessionRequest as StorageCreateSessionRequest, CreateSessionResult, ListCheckpointsQuery,
30 ListCheckpointsResult, ListSessionsQuery, ListSessionsResult, LocalStorage, Session,
31 SessionStats, SessionStatus, SessionStorage, SessionSummary, SessionVisibility, StakpakStorage,
32 StorageError, UpdateSessionRequest as StorageUpdateSessionRequest,
33};
34
35pub fn find_model(model_str: &str, use_stakpak: bool) -> Option<Model> {
43 const PROVIDERS: &[&str] = &["anthropic", "openai", "google"];
44
45 let (provider_hint, model_id) = parse_model_string(model_str);
46
47 let model = provider_hint
49 .and_then(|p| find_in_provider(p, model_id))
50 .or_else(|| {
51 PROVIDERS
52 .iter()
53 .find_map(|&p| find_in_provider(p, model_id))
54 })?;
55
56 Some(if use_stakpak {
57 transform_for_stakpak(model)
58 } else {
59 model
60 })
61}
62
/// Split a model string of the form `"provider/model-id"` into its parts.
///
/// Returns `(Some(provider), model_id)` when a `/` separator is present,
/// or `(None, s)` for a bare model id. Only the first `/` is significant,
/// so `"a/b/c"` yields `(Some("a"), "b/c")`. The `"gemini"` provider
/// alias is normalized to `"google"`.
//
// Uses `str::split_once` instead of `find` + manual slicing, which is
// panic-free by construction and removes the need for
// `#[allow(clippy::string_slice)]`.
fn parse_model_string(s: &str) -> (Option<&str>, &str) {
    match s.split_once('/') {
        Some((provider, model_id)) => {
            // Normalize the legacy provider alias.
            let normalized = if provider == "gemini" { "google" } else { provider };
            (Some(normalized), model_id)
        }
        None => (None, s),
    }
}
79
80fn find_in_provider(provider_id: &str, model_id: &str) -> Option<Model> {
82 let models = stakai::load_models_for_provider(provider_id).ok()?;
83
84 if let Some(model) = models.iter().find(|m| m.id == model_id) {
86 return Some(model.clone());
87 }
88
89 let mut best_match: Option<&Model> = None;
92 let mut best_len = 0;
93
94 for model in &models {
95 if model_id.starts_with(&model.id) && model.id.len() > best_len {
96 best_match = Some(model);
97 best_len = model.id.len();
98 }
99 }
100
101 best_match.cloned()
102}
103
104pub fn transform_for_stakpak(model: Model) -> Model {
109 Model {
110 id: format!("{}/{}", model.provider, model.id),
111 provider: "stakpak".into(),
112 name: model.name,
113 reasoning: model.reasoning,
114 cost: model.cost,
115 limit: model.limit,
116 release_date: model.release_date,
117 }
118}
119
/// Unified async interface to an agent backend.
///
/// Extends [`SessionStorage`] with account/billing queries, rulebook CRUD,
/// chat completions (blocking and streaming), documentation/memory search,
/// and Slack messaging helpers. All fallible methods report errors as plain
/// `String`s rather than a typed error.
#[async_trait]
pub trait AgentProvider: SessionStorage + Send + Sync {
    /// Fetch the authenticated user's account details.
    async fn get_my_account(&self) -> Result<GetMyAccountResponse, String>;
    /// Fetch billing information for the given account username.
    async fn get_billing_info(
        &self,
        account_username: &str,
    ) -> Result<stakpak_shared::models::billing::BillingResponse, String>;

    /// List all rulebooks visible to the current account.
    async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String>;
    /// Fetch a single rulebook by its URI.
    async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String>;
    /// Create a rulebook; `visibility` of `None` presumably falls back to a
    /// backend default — confirm against the implementations.
    async fn create_rulebook(
        &self,
        uri: &str,
        description: &str,
        content: &str,
        tags: Vec<String>,
        visibility: Option<RuleBookVisibility>,
    ) -> Result<CreateRuleBookResponse, String>;
    /// Delete the rulebook identified by `uri`.
    async fn delete_rulebook(&self, uri: &str) -> Result<(), String>;

    /// Run a single (non-streaming) chat completion with the given model,
    /// message history, and optional tool definitions. `session_id` and
    /// `metadata` are optional per-request context.
    async fn chat_completion(
        &self,
        model: Model,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        session_id: Option<Uuid>,
        metadata: Option<serde_json::Value>,
    ) -> Result<ChatCompletionResponse, String>;
    /// Streaming variant of [`Self::chat_completion`].
    ///
    /// Returns a boxed stream of response chunks plus an `Option<String>` —
    /// presumably a request id usable with [`Self::cancel_stream`];
    /// confirm against the implementations.
    async fn chat_completion_stream(
        &self,
        model: Model,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        headers: Option<HeaderMap>,
        session_id: Option<Uuid>,
        metadata: Option<serde_json::Value>,
    ) -> Result<
        (
            std::pin::Pin<
                Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
            >,
            Option<String>,
        ),
        String,
    >;
    /// Cancel an in-flight completion stream identified by `request_id`.
    async fn cancel_stream(&self, request_id: String) -> Result<(), String>;

    /// Search documentation; results come back as MCP `Content` items.
    async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String>;

    /// Persist the session state at `checkpoint_id` into long-term memory.
    async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String>;
    /// Search previously memorized content; results are MCP `Content` items.
    async fn search_memory(&self, input: &SearchMemoryRequest) -> Result<Vec<Content>, String>;

    /// Read messages from a Slack channel/conversation.
    async fn slack_read_messages(
        &self,
        input: &SlackReadMessagesRequest,
    ) -> Result<Vec<Content>, String>;
    /// Read replies in a Slack thread.
    async fn slack_read_replies(
        &self,
        input: &SlackReadRepliesRequest,
    ) -> Result<Vec<Content>, String>;
    /// Send a message to Slack.
    async fn slack_send_message(
        &self,
        input: &SlackSendMessageRequest,
    ) -> Result<Vec<Content>, String>;

    /// List the models this provider can serve. Infallible by signature.
    async fn list_models(&self) -> Vec<Model>;
}