1use super::{
6 CheckpointState, CreateCheckpointRequest, CreateCheckpointResponse, CreateSessionRequest,
7 CreateSessionResponse, GetCheckpointResponse, GetSessionResponse, ListCheckpointsQuery,
8 ListCheckpointsResponse, ListSessionsQuery, ListSessionsResponse, SessionVisibility,
9 StakpakApiConfig, UpdateSessionRequest, UpdateSessionResponse, models::*,
10};
11use crate::models::{
12 CreateRuleBookInput, CreateRuleBookResponse, GetMyAccountResponse, ListRuleBook,
13 ListRulebooksResponse, RuleBook,
14};
15use reqwest::{Response, header};
16use rmcp::model::Content;
17use serde::de::DeserializeOwned;
18use serde_json::{Value, json};
19use stakpak_shared::models::billing::BillingResponse;
20use stakpak_shared::tls_client::{TlsClientConfig, create_tls_client};
21use uuid::Uuid;
22
23#[derive(Clone, Debug)]
25pub struct StakpakApiClient {
26 client: reqwest::Client,
27 base_url: String,
28}
29
30#[derive(Debug, serde::Deserialize)]
32struct ApiError {
33 error: ApiErrorDetail,
34}
35
36#[derive(Debug, serde::Deserialize)]
37struct ApiErrorDetail {
38 key: String,
39 message: String,
40}
41
42impl StakpakApiClient {
43 pub fn new(config: &StakpakApiConfig) -> Result<Self, String> {
45 if config.api_key.is_empty() {
46 return Err("Stakpak API key is required".to_string());
47 }
48
49 let mut headers = header::HeaderMap::new();
50 headers.insert(
51 header::AUTHORIZATION,
52 header::HeaderValue::from_str(&format!("Bearer {}", config.api_key))
53 .map_err(|e| e.to_string())?,
54 );
55 headers.insert(
56 header::USER_AGENT,
57 header::HeaderValue::from_str(&format!("Stakpak/{}", env!("CARGO_PKG_VERSION")))
58 .map_err(|e| e.to_string())?,
59 );
60
61 let client = create_tls_client(
62 TlsClientConfig::default()
63 .with_headers(headers)
64 .with_timeout(std::time::Duration::from_secs(300)),
65 )?;
66
67 Ok(Self {
68 client,
69 base_url: config.api_endpoint.clone(),
70 })
71 }
72
73 pub async fn create_session(
79 &self,
80 req: &CreateSessionRequest,
81 ) -> Result<CreateSessionResponse, String> {
82 let url = format!("{}/v1/sessions", self.base_url);
83 let response = self
84 .client
85 .post(&url)
86 .json(req)
87 .send()
88 .await
89 .map_err(|e| e.to_string())?;
90 self.handle_response(response).await
91 }
92
93 pub async fn create_checkpoint(
95 &self,
96 session_id: Uuid,
97 req: &CreateCheckpointRequest,
98 ) -> Result<CreateCheckpointResponse, String> {
99 let url = format!("{}/v1/sessions/{}/checkpoints", self.base_url, session_id);
100 let response = self
101 .client
102 .post(&url)
103 .json(req)
104 .send()
105 .await
106 .map_err(|e| e.to_string())?;
107 self.handle_response(response).await
108 }
109
110 pub async fn list_sessions(
112 &self,
113 query: &ListSessionsQuery,
114 ) -> Result<ListSessionsResponse, String> {
115 let url = format!("{}/v1/sessions", self.base_url);
116 let response = self
117 .client
118 .get(&url)
119 .query(query)
120 .send()
121 .await
122 .map_err(|e| e.to_string())?;
123 self.handle_response(response).await
124 }
125
126 pub async fn get_session(&self, id: Uuid) -> Result<GetSessionResponse, String> {
128 let url = format!("{}/v1/sessions/{}", self.base_url, id);
129 let response = self
130 .client
131 .get(&url)
132 .send()
133 .await
134 .map_err(|e| e.to_string())?;
135 self.handle_response(response).await
136 }
137
138 pub async fn update_session(
140 &self,
141 id: Uuid,
142 req: &UpdateSessionRequest,
143 ) -> Result<UpdateSessionResponse, String> {
144 let url = format!("{}/v1/sessions/{}", self.base_url, id);
145 let response = self
146 .client
147 .patch(&url)
148 .json(req)
149 .send()
150 .await
151 .map_err(|e| e.to_string())?;
152 self.handle_response(response).await
153 }
154
155 pub async fn delete_session(&self, id: Uuid) -> Result<(), String> {
157 let url = format!("{}/v1/sessions/{}", self.base_url, id);
158 let response = self
159 .client
160 .delete(&url)
161 .send()
162 .await
163 .map_err(|e| e.to_string())?;
164 self.handle_response_no_body(response).await
165 }
166
167 pub async fn list_checkpoints(
169 &self,
170 session_id: Uuid,
171 query: &ListCheckpointsQuery,
172 ) -> Result<ListCheckpointsResponse, String> {
173 let url = format!("{}/v1/sessions/{}/checkpoints", self.base_url, session_id);
174 let response = self
175 .client
176 .get(&url)
177 .query(query)
178 .send()
179 .await
180 .map_err(|e| e.to_string())?;
181 self.handle_response(response).await
182 }
183
184 pub async fn get_checkpoint(&self, id: Uuid) -> Result<GetCheckpointResponse, String> {
186 let url = format!("{}/v1/sessions/checkpoints/{}", self.base_url, id);
187 let response = self
188 .client
189 .get(&url)
190 .send()
191 .await
192 .map_err(|e| e.to_string())?;
193 self.handle_response(response).await
194 }
195
196 pub async fn cancel_request(&self, request_id: &str) -> Result<(), String> {
202 let url = format!("{}/v1/chat/requests/{}/cancel", self.base_url, request_id);
203 let response = self
204 .client
205 .post(&url)
206 .send()
207 .await
208 .map_err(|e| e.to_string())?;
209 self.handle_response_no_body(response).await
210 }
211
212 pub async fn get_account(&self) -> Result<GetMyAccountResponse, String> {
218 let url = format!("{}/v1/account", self.base_url);
219 let response = self
220 .client
221 .get(&url)
222 .send()
223 .await
224 .map_err(|e| e.to_string())?;
225 self.handle_response(response).await
226 }
227
228 pub async fn get_billing(&self, username: &str) -> Result<BillingResponse, String> {
230 let url = format!("{}/v2/{}/billing", self.base_url, username);
231 let response = self
232 .client
233 .get(&url)
234 .send()
235 .await
236 .map_err(|e| e.to_string())?;
237 self.handle_response(response).await
238 }
239
240 pub async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
246 let url = format!("{}/v1/rules", self.base_url);
247 let response = self
248 .client
249 .get(&url)
250 .send()
251 .await
252 .map_err(|e| e.to_string())?;
253
254 let response = self.handle_response_error(response).await?;
255 let value: Value = response.json().await.map_err(|e| e.to_string())?;
256
257 match serde_json::from_value::<ListRulebooksResponse>(value) {
258 Ok(response) => Ok(response.results),
259 Err(e) => Err(format!("Failed to deserialize rulebooks response: {}", e)),
260 }
261 }
262
263 pub async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
265 let encoded_uri = urlencoding::encode(uri);
266 let url = format!("{}/v1/rules/{}", self.base_url, encoded_uri);
267 let response = self
268 .client
269 .get(&url)
270 .send()
271 .await
272 .map_err(|e| e.to_string())?;
273 self.handle_response(response).await
274 }
275
276 pub async fn create_rulebook(
278 &self,
279 input: &CreateRuleBookInput,
280 ) -> Result<CreateRuleBookResponse, String> {
281 let url = format!("{}/v1/rules", self.base_url);
282 let response = self
283 .client
284 .post(&url)
285 .json(input)
286 .send()
287 .await
288 .map_err(|e| e.to_string())?;
289 self.handle_response(response).await
290 }
291
292 pub async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
294 let encoded_uri = urlencoding::encode(uri);
295 let url = format!("{}/v1/rules/{}", self.base_url, encoded_uri);
296 let response = self
297 .client
298 .delete(&url)
299 .send()
300 .await
301 .map_err(|e| e.to_string())?;
302 self.handle_response_no_body(response).await
303 }
304
305 pub async fn search_docs(&self, req: &SearchDocsRequest) -> Result<Vec<Content>, String> {
311 self.call_mcp_tool(&ToolsCallParams {
312 name: "search_docs".to_string(),
313 arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
314 })
315 .await
316 }
317
318 pub async fn search_memory(&self, req: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
320 self.call_mcp_tool(&ToolsCallParams {
321 name: "search_memory".to_string(),
322 arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
323 })
324 .await
325 }
326
327 pub async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
329 let url = format!(
330 "{}/v1/agents/sessions/checkpoints/{}/extract-memory",
331 self.base_url, checkpoint_id
332 );
333 let response = self
334 .client
335 .post(&url)
336 .send()
337 .await
338 .map_err(|e| e.to_string())?;
339 self.handle_response_no_body(response).await
340 }
341
342 pub async fn slack_read_messages(
344 &self,
345 req: &SlackReadMessagesRequest,
346 ) -> Result<Vec<Content>, String> {
347 self.call_mcp_tool(&ToolsCallParams {
348 name: "slack_read_messages".to_string(),
349 arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
350 })
351 .await
352 }
353
354 pub async fn slack_read_replies(
356 &self,
357 req: &SlackReadRepliesRequest,
358 ) -> Result<Vec<Content>, String> {
359 self.call_mcp_tool(&ToolsCallParams {
360 name: "slack_read_replies".to_string(),
361 arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
362 })
363 .await
364 }
365
366 pub async fn slack_send_message(
368 &self,
369 req: &SlackSendMessageRequest,
370 ) -> Result<Vec<Content>, String> {
371 self.call_mcp_tool(&ToolsCallParams {
372 name: "slack_send_message".to_string(),
373 arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
374 })
375 .await
376 }
377
378 async fn call_mcp_tool(&self, params: &ToolsCallParams) -> Result<Vec<Content>, String> {
384 let url = format!("{}/v1/mcp", self.base_url);
385 let body = json!({
386 "jsonrpc": "2.0",
387 "id": 1,
388 "method": "tools/call",
389 "params": params
390 });
391
392 let response = self
393 .client
394 .post(&url)
395 .json(&body)
396 .send()
397 .await
398 .map_err(|e| e.to_string())?;
399
400 let resp: Value = self.handle_response(response).await?;
401
402 if let Some(result) = resp.get("result")
404 && let Some(content) = result.get("content")
405 {
406 let content: Vec<Content> =
407 serde_json::from_value(content.clone()).map_err(|e| e.to_string())?;
408 return Ok(content);
409 }
410
411 if let Some(error) = resp.get("error") {
413 let msg = error
414 .get("message")
415 .and_then(|m| m.as_str())
416 .unwrap_or("Unknown error");
417 return Err(msg.to_string());
418 }
419
420 Err("Invalid MCP response format".to_string())
421 }
422
423 async fn handle_response<T: DeserializeOwned>(&self, response: Response) -> Result<T, String> {
425 let response = self.handle_response_error(response).await?;
426 response.json().await.map_err(|e| e.to_string())
427 }
428
429 async fn handle_response_no_body(&self, response: Response) -> Result<(), String> {
431 self.handle_response_error(response).await?;
432 Ok(())
433 }
434
435 async fn handle_response_error(&self, response: Response) -> Result<Response, String> {
437 if response.status().is_success() {
438 return Ok(response);
439 }
440
441 let status = response.status();
442 let error_body = response.text().await.unwrap_or_default();
443
444 if let Ok(api_error) = serde_json::from_str::<ApiError>(&error_body) {
446 if api_error.error.key == "EXCEEDED_API_LIMIT" {
448 return Err(format!(
449 "{}. You can top up your billing at https://stakpak.dev/settings/billing",
450 api_error.error.message
451 ));
452 }
453 return Err(api_error.error.message);
454 }
455
456 Err(format!("API error {}: {}", status, error_body))
457 }
458}
459
460impl CreateSessionRequest {
465 pub fn new(title: impl Into<String>, state: CheckpointState) -> Self {
467 Self {
468 title: title.into(),
469 visibility: Some(SessionVisibility::Private),
470 cwd: None,
471 state,
472 }
473 }
474
475 pub fn with_cwd(mut self, cwd: impl Into<String>) -> Self {
477 self.cwd = Some(cwd.into());
478 self
479 }
480
481 pub fn with_visibility(mut self, visibility: SessionVisibility) -> Self {
483 self.visibility = Some(visibility);
484 self
485 }
486}
487
488impl CreateCheckpointRequest {
489 pub fn new(state: CheckpointState) -> Self {
491 Self {
492 state,
493 parent_id: None,
494 }
495 }
496
497 pub fn with_parent(mut self, parent_id: Uuid) -> Self {
499 self.parent_id = Some(parent_id);
500 self
501 }
502}