1use super::{
6 CheckpointState, CreateCheckpointRequest, CreateCheckpointResponse, CreateSessionRequest,
7 CreateSessionResponse, GetCheckpointResponse, GetSessionResponse, ListCheckpointsQuery,
8 ListCheckpointsResponse, ListSessionsQuery, ListSessionsResponse, SessionVisibility,
9 StakpakApiConfig, UpdateSessionRequest, UpdateSessionResponse, knowledge::AccountCacheState,
10 models::*,
11};
12use crate::models::{
13 CreateRuleBookInput, CreateRuleBookResponse, GetMyAccountResponse, ListRuleBook,
14 ListRulebooksResponse, RuleBook,
15};
16use reqwest::{Response, header};
17use rmcp::model::Content;
18use serde::de::DeserializeOwned;
19use serde_json::{Value, json};
20use stakpak_shared::models::billing::BillingResponse;
21use stakpak_shared::tls_client::{TlsClientConfig, create_tls_client};
22use std::sync::Arc;
23use tokio::sync::Mutex;
24use uuid::Uuid;
25
26#[derive(Clone, Debug)]
28pub struct StakpakApiClient {
29 pub(super) client: reqwest::Client,
30 pub(super) base_url: String,
31 pub(super) account_name: Arc<Mutex<AccountCacheState>>,
32}
33
34#[derive(Debug, serde::Deserialize)]
36pub(super) struct ApiError {
37 pub(super) error: ApiErrorDetail,
38}
39
40#[derive(Debug, serde::Deserialize)]
41pub(super) struct ApiErrorDetail {
42 pub(super) key: String,
43 pub(super) message: String,
44}
45
46impl StakpakApiClient {
47 pub fn new(config: &StakpakApiConfig) -> Result<Self, String> {
49 if config.api_key.is_empty() {
50 return Err("Stakpak API key is required".to_string());
51 }
52
53 let mut headers = header::HeaderMap::new();
54 headers.insert(
55 header::AUTHORIZATION,
56 header::HeaderValue::from_str(&format!("Bearer {}", config.api_key))
57 .map_err(|e| e.to_string())?,
58 );
59 headers.insert(
60 header::USER_AGENT,
61 header::HeaderValue::from_str(&format!("Stakpak/{}", env!("CARGO_PKG_VERSION")))
62 .map_err(|e| e.to_string())?,
63 );
64
65 let client = create_tls_client(
66 TlsClientConfig::default()
67 .with_headers(headers)
68 .with_timeout(std::time::Duration::from_secs(300)),
69 )?;
70
71 Ok(Self {
72 client,
73 base_url: config.api_endpoint.clone(),
74 account_name: Arc::new(Mutex::new(AccountCacheState::Unknown)),
75 })
76 }
77
78 pub async fn create_session(
84 &self,
85 req: &CreateSessionRequest,
86 ) -> Result<CreateSessionResponse, String> {
87 let url = format!("{}/v1/sessions", self.base_url);
88 let response = self
89 .client
90 .post(&url)
91 .json(req)
92 .send()
93 .await
94 .map_err(|e| e.to_string())?;
95 self.handle_response(response).await
96 }
97
98 pub async fn create_checkpoint(
100 &self,
101 session_id: Uuid,
102 req: &CreateCheckpointRequest,
103 ) -> Result<CreateCheckpointResponse, String> {
104 let url = format!("{}/v1/sessions/{}/checkpoints", self.base_url, session_id);
105 let response = self
106 .client
107 .post(&url)
108 .json(req)
109 .send()
110 .await
111 .map_err(|e| e.to_string())?;
112 self.handle_response(response).await
113 }
114
115 pub async fn list_sessions(
117 &self,
118 query: &ListSessionsQuery,
119 ) -> Result<ListSessionsResponse, String> {
120 let url = format!("{}/v1/sessions", self.base_url);
121 let response = self
122 .client
123 .get(&url)
124 .query(query)
125 .send()
126 .await
127 .map_err(|e| e.to_string())?;
128 self.handle_response(response).await
129 }
130
131 pub async fn get_session(&self, id: Uuid) -> Result<GetSessionResponse, String> {
133 let url = format!("{}/v1/sessions/{}", self.base_url, id);
134 let response = self
135 .client
136 .get(&url)
137 .send()
138 .await
139 .map_err(|e| e.to_string())?;
140 self.handle_response(response).await
141 }
142
143 pub async fn update_session(
145 &self,
146 id: Uuid,
147 req: &UpdateSessionRequest,
148 ) -> Result<UpdateSessionResponse, String> {
149 let url = format!("{}/v1/sessions/{}", self.base_url, id);
150 let response = self
151 .client
152 .patch(&url)
153 .json(req)
154 .send()
155 .await
156 .map_err(|e| e.to_string())?;
157 self.handle_response(response).await
158 }
159
160 pub async fn delete_session(&self, id: Uuid) -> Result<(), String> {
162 let url = format!("{}/v1/sessions/{}", self.base_url, id);
163 let response = self
164 .client
165 .delete(&url)
166 .send()
167 .await
168 .map_err(|e| e.to_string())?;
169 self.handle_response_no_body(response).await
170 }
171
172 pub async fn list_checkpoints(
174 &self,
175 session_id: Uuid,
176 query: &ListCheckpointsQuery,
177 ) -> Result<ListCheckpointsResponse, String> {
178 let url = format!("{}/v1/sessions/{}/checkpoints", self.base_url, session_id);
179 let response = self
180 .client
181 .get(&url)
182 .query(query)
183 .send()
184 .await
185 .map_err(|e| e.to_string())?;
186 self.handle_response(response).await
187 }
188
189 pub async fn get_checkpoint(&self, id: Uuid) -> Result<GetCheckpointResponse, String> {
191 let url = format!("{}/v1/sessions/checkpoints/{}", self.base_url, id);
192 let response = self
193 .client
194 .get(&url)
195 .send()
196 .await
197 .map_err(|e| e.to_string())?;
198 self.handle_response(response).await
199 }
200
201 pub async fn cancel_request(&self, request_id: &str) -> Result<(), String> {
207 let url = format!("{}/v1/chat/requests/{}/cancel", self.base_url, request_id);
208 let response = self
209 .client
210 .post(&url)
211 .send()
212 .await
213 .map_err(|e| e.to_string())?;
214 self.handle_response_no_body(response).await
215 }
216
217 pub async fn get_account(&self) -> Result<GetMyAccountResponse, String> {
223 let url = format!("{}/v1/account", self.base_url);
224 let response = self
225 .client
226 .get(&url)
227 .send()
228 .await
229 .map_err(|e| e.to_string())?;
230 self.handle_response(response).await
231 }
232
233 pub async fn get_billing(&self, username: &str) -> Result<BillingResponse, String> {
235 let url = format!("{}/v2/{}/billing", self.base_url, username);
236 let response = self
237 .client
238 .get(&url)
239 .send()
240 .await
241 .map_err(|e| e.to_string())?;
242 self.handle_response(response).await
243 }
244
245 pub async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
251 let url = format!("{}/v1/rules", self.base_url);
252 let response = self
253 .client
254 .get(&url)
255 .send()
256 .await
257 .map_err(|e| e.to_string())?;
258
259 let response = self.handle_response_error(response).await?;
260 let value: Value = response.json().await.map_err(|e| e.to_string())?;
261
262 match serde_json::from_value::<ListRulebooksResponse>(value) {
263 Ok(response) => Ok(response.results),
264 Err(e) => Err(format!("Failed to deserialize rulebooks response: {}", e)),
265 }
266 }
267
268 pub async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
270 let encoded_uri = urlencoding::encode(uri);
271 let url = format!("{}/v1/rules/{}", self.base_url, encoded_uri);
272 let response = self
273 .client
274 .get(&url)
275 .send()
276 .await
277 .map_err(|e| e.to_string())?;
278 self.handle_response(response).await
279 }
280
281 pub async fn create_rulebook(
283 &self,
284 input: &CreateRuleBookInput,
285 ) -> Result<CreateRuleBookResponse, String> {
286 let url = format!("{}/v1/rules", self.base_url);
287 let response = self
288 .client
289 .post(&url)
290 .json(input)
291 .send()
292 .await
293 .map_err(|e| e.to_string())?;
294 self.handle_response(response).await
295 }
296
297 pub async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
299 let encoded_uri = urlencoding::encode(uri);
300 let url = format!("{}/v1/rules/{}", self.base_url, encoded_uri);
301 let response = self
302 .client
303 .delete(&url)
304 .send()
305 .await
306 .map_err(|e| e.to_string())?;
307 self.handle_response_no_body(response).await
308 }
309
310 pub async fn search_docs(&self, req: &SearchDocsRequest) -> Result<Vec<Content>, String> {
316 self.call_mcp_tool(&ToolsCallParams {
317 name: "search_docs".to_string(),
318 arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
319 })
320 .await
321 }
322
323 pub async fn search_memory(&self, req: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
325 self.call_mcp_tool(&ToolsCallParams {
326 name: "search_memory".to_string(),
327 arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
328 })
329 .await
330 }
331
332 pub async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
334 let url = format!(
335 "{}/v1/agents/sessions/checkpoints/{}/extract-memory",
336 self.base_url, checkpoint_id
337 );
338 let response = self
339 .client
340 .post(&url)
341 .send()
342 .await
343 .map_err(|e| e.to_string())?;
344 self.handle_response_no_body(response).await
345 }
346
347 pub async fn slack_read_messages(
349 &self,
350 req: &SlackReadMessagesRequest,
351 ) -> Result<Vec<Content>, String> {
352 self.call_mcp_tool(&ToolsCallParams {
353 name: "slack_read_messages".to_string(),
354 arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
355 })
356 .await
357 }
358
359 pub async fn slack_read_replies(
361 &self,
362 req: &SlackReadRepliesRequest,
363 ) -> Result<Vec<Content>, String> {
364 self.call_mcp_tool(&ToolsCallParams {
365 name: "slack_read_replies".to_string(),
366 arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
367 })
368 .await
369 }
370
371 pub async fn slack_send_message(
373 &self,
374 req: &SlackSendMessageRequest,
375 ) -> Result<Vec<Content>, String> {
376 self.call_mcp_tool(&ToolsCallParams {
377 name: "slack_send_message".to_string(),
378 arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
379 })
380 .await
381 }
382
383 async fn call_mcp_tool(&self, params: &ToolsCallParams) -> Result<Vec<Content>, String> {
389 let url = format!("{}/v1/mcp", self.base_url);
390 let body = json!({
391 "jsonrpc": "2.0",
392 "id": 1,
393 "method": "tools/call",
394 "params": params
395 });
396
397 let response = self
398 .client
399 .post(&url)
400 .json(&body)
401 .send()
402 .await
403 .map_err(|e| e.to_string())?;
404
405 let resp: Value = self.handle_response(response).await?;
406
407 if let Some(result) = resp.get("result")
409 && let Some(content) = result.get("content")
410 {
411 let content: Vec<Content> =
412 serde_json::from_value(content.clone()).map_err(|e| e.to_string())?;
413 return Ok(content);
414 }
415
416 if let Some(error) = resp.get("error") {
418 let msg = error
419 .get("message")
420 .and_then(|m| m.as_str())
421 .unwrap_or("Unknown error");
422 return Err(msg.to_string());
423 }
424
425 Err("Invalid MCP response format".to_string())
426 }
427
428 async fn handle_response<T: DeserializeOwned>(&self, response: Response) -> Result<T, String> {
430 let response = self.handle_response_error(response).await?;
431 let url = response.url().to_string();
432 let status = response.status();
433 let body = response.text().await.map_err(|e| {
434 format!(
435 "Failed to read response body from {} (status {}): {}",
436 url, status, e
437 )
438 })?;
439 serde_json::from_str(&body).map_err(|e| {
440 let truncated_body: String = body.chars().take(500).collect();
442 format!(
443 "Failed to decode response from {} (status {}): {} | body: {}",
444 url, status, e, truncated_body
445 )
446 })
447 }
448
449 async fn handle_response_no_body(&self, response: Response) -> Result<(), String> {
451 self.handle_response_error(response).await?;
452 Ok(())
453 }
454
455 async fn handle_response_error(&self, response: Response) -> Result<Response, String> {
457 if response.status().is_success() {
458 return Ok(response);
459 }
460
461 let status = response.status();
462 let error_body = response.text().await.unwrap_or_default();
463
464 if let Ok(api_error) = serde_json::from_str::<ApiError>(&error_body) {
466 if api_error.error.key == "EXCEEDED_API_LIMIT" {
468 return Err(format!(
469 "{}. You can top up your billing at https://stakpak.dev/settings/billing",
470 api_error.error.message
471 ));
472 }
473 return Err(api_error.error.message);
474 }
475
476 Err(format!("API error {}: {}", status, error_body))
477 }
478}
479
480impl CreateSessionRequest {
485 pub fn new(title: impl Into<String>, state: CheckpointState) -> Self {
487 Self {
488 title: title.into(),
489 visibility: Some(SessionVisibility::Private),
490 cwd: None,
491 state,
492 }
493 }
494
495 pub fn with_cwd(mut self, cwd: impl Into<String>) -> Self {
497 self.cwd = Some(cwd.into());
498 self
499 }
500
501 pub fn with_visibility(mut self, visibility: SessionVisibility) -> Self {
503 self.visibility = Some(visibility);
504 self
505 }
506}
507
508impl CreateCheckpointRequest {
509 pub fn new(state: CheckpointState) -> Self {
511 Self {
512 state,
513 parent_id: None,
514 }
515 }
516
517 pub fn with_parent(mut self, parent_id: Uuid) -> Self {
519 self.parent_id = Some(parent_id);
520 self
521 }
522}