Skip to main content

stakpak_api/stakpak/
client.rs

1//! StakpakApiClient implementation
2//!
3//! Provides access to Stakpak's non-inference APIs.
4
5use super::{
6    CheckpointState, CreateCheckpointRequest, CreateCheckpointResponse, CreateSessionRequest,
7    CreateSessionResponse, GetCheckpointResponse, GetSessionResponse, ListCheckpointsQuery,
8    ListCheckpointsResponse, ListSessionsQuery, ListSessionsResponse, SessionVisibility,
9    StakpakApiConfig, UpdateSessionRequest, UpdateSessionResponse, models::*,
10};
11use crate::models::{
12    CreateRuleBookInput, CreateRuleBookResponse, GetMyAccountResponse, ListRuleBook,
13    ListRulebooksResponse, RuleBook,
14};
15use reqwest::{Client as ReqwestClient, Response, header};
16use rmcp::model::Content;
17use serde::de::DeserializeOwned;
18use serde_json::{Value, json};
19use stakpak_shared::models::billing::BillingResponse;
20use uuid::Uuid;
21
22/// Client for Stakpak's non-inference APIs
23#[derive(Clone, Debug)]
24pub struct StakpakApiClient {
25    client: ReqwestClient,
26    base_url: String,
27}
28
29/// API error response format
30#[derive(Debug, serde::Deserialize)]
31struct ApiError {
32    error: ApiErrorDetail,
33}
34
35#[derive(Debug, serde::Deserialize)]
36struct ApiErrorDetail {
37    key: String,
38    message: String,
39}
40
41impl StakpakApiClient {
42    /// Create a new StakpakApiClient
43    pub fn new(config: &StakpakApiConfig) -> Result<Self, String> {
44        if config.api_key.is_empty() {
45            return Err("Stakpak API key is required".to_string());
46        }
47
48        let mut headers = header::HeaderMap::new();
49        headers.insert(
50            header::AUTHORIZATION,
51            header::HeaderValue::from_str(&format!("Bearer {}", config.api_key))
52                .map_err(|e| e.to_string())?,
53        );
54        headers.insert(
55            header::USER_AGENT,
56            header::HeaderValue::from_str(&format!("Stakpak/{}", env!("CARGO_PKG_VERSION")))
57                .map_err(|e| e.to_string())?,
58        );
59
60        let client = ReqwestClient::builder()
61            .default_headers(headers)
62            .timeout(std::time::Duration::from_secs(300))
63            .build()
64            .map_err(|e| e.to_string())?;
65
66        Ok(Self {
67            client,
68            base_url: config.api_endpoint.clone(),
69        })
70    }
71
72    // =========================================================================
73    // Session APIs - New /v1/sessions endpoints
74    // =========================================================================
75
76    /// Create a new session
77    pub async fn create_session(
78        &self,
79        req: &CreateSessionRequest,
80    ) -> Result<CreateSessionResponse, String> {
81        let url = format!("{}/v1/sessions", self.base_url);
82        let response = self
83            .client
84            .post(&url)
85            .json(req)
86            .send()
87            .await
88            .map_err(|e| e.to_string())?;
89        self.handle_response(response).await
90    }
91
92    /// Create a checkpoint for a session
93    pub async fn create_checkpoint(
94        &self,
95        session_id: Uuid,
96        req: &CreateCheckpointRequest,
97    ) -> Result<CreateCheckpointResponse, String> {
98        let url = format!("{}/v1/sessions/{}/checkpoints", self.base_url, session_id);
99        let response = self
100            .client
101            .post(&url)
102            .json(req)
103            .send()
104            .await
105            .map_err(|e| e.to_string())?;
106        self.handle_response(response).await
107    }
108
109    /// List sessions
110    pub async fn list_sessions(
111        &self,
112        query: &ListSessionsQuery,
113    ) -> Result<ListSessionsResponse, String> {
114        let url = format!("{}/v1/sessions", self.base_url);
115        let response = self
116            .client
117            .get(&url)
118            .query(query)
119            .send()
120            .await
121            .map_err(|e| e.to_string())?;
122        self.handle_response(response).await
123    }
124
125    /// Get a session by ID
126    pub async fn get_session(&self, id: Uuid) -> Result<GetSessionResponse, String> {
127        let url = format!("{}/v1/sessions/{}", self.base_url, id);
128        let response = self
129            .client
130            .get(&url)
131            .send()
132            .await
133            .map_err(|e| e.to_string())?;
134        self.handle_response(response).await
135    }
136
137    /// Update a session
138    pub async fn update_session(
139        &self,
140        id: Uuid,
141        req: &UpdateSessionRequest,
142    ) -> Result<UpdateSessionResponse, String> {
143        let url = format!("{}/v1/sessions/{}", self.base_url, id);
144        let response = self
145            .client
146            .patch(&url)
147            .json(req)
148            .send()
149            .await
150            .map_err(|e| e.to_string())?;
151        self.handle_response(response).await
152    }
153
154    /// Delete a session
155    pub async fn delete_session(&self, id: Uuid) -> Result<(), String> {
156        let url = format!("{}/v1/sessions/{}", self.base_url, id);
157        let response = self
158            .client
159            .delete(&url)
160            .send()
161            .await
162            .map_err(|e| e.to_string())?;
163        self.handle_response_no_body(response).await
164    }
165
166    /// List checkpoints for a session
167    pub async fn list_checkpoints(
168        &self,
169        session_id: Uuid,
170        query: &ListCheckpointsQuery,
171    ) -> Result<ListCheckpointsResponse, String> {
172        let url = format!("{}/v1/sessions/{}/checkpoints", self.base_url, session_id);
173        let response = self
174            .client
175            .get(&url)
176            .query(query)
177            .send()
178            .await
179            .map_err(|e| e.to_string())?;
180        self.handle_response(response).await
181    }
182
183    /// Get a checkpoint by ID
184    pub async fn get_checkpoint(&self, id: Uuid) -> Result<GetCheckpointResponse, String> {
185        let url = format!("{}/v1/sessions/checkpoints/{}", self.base_url, id);
186        let response = self
187            .client
188            .get(&url)
189            .send()
190            .await
191            .map_err(|e| e.to_string())?;
192        self.handle_response(response).await
193    }
194
195    // =========================================================================
196    // Cancel API
197    // =========================================================================
198
199    /// Cancel an active inference request
200    pub async fn cancel_request(&self, request_id: &str) -> Result<(), String> {
201        let url = format!("{}/v1/chat/requests/{}/cancel", self.base_url, request_id);
202        let response = self
203            .client
204            .post(&url)
205            .send()
206            .await
207            .map_err(|e| e.to_string())?;
208        self.handle_response_no_body(response).await
209    }
210
211    // =========================================================================
212    // Account APIs
213    // =========================================================================
214
215    /// Get the current user's account info
216    pub async fn get_account(&self) -> Result<GetMyAccountResponse, String> {
217        let url = format!("{}/v1/account", self.base_url);
218        let response = self
219            .client
220            .get(&url)
221            .send()
222            .await
223            .map_err(|e| e.to_string())?;
224        self.handle_response(response).await
225    }
226
227    /// Get billing info for a user
228    pub async fn get_billing(&self, username: &str) -> Result<BillingResponse, String> {
229        let url = format!("{}/v2/{}/billing", self.base_url, username);
230        let response = self
231            .client
232            .get(&url)
233            .send()
234            .await
235            .map_err(|e| e.to_string())?;
236        self.handle_response(response).await
237    }
238
239    // =========================================================================
240    // Rulebook APIs
241    // =========================================================================
242
243    /// List all rulebooks
244    pub async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
245        let url = format!("{}/v1/rules", self.base_url);
246        let response = self
247            .client
248            .get(&url)
249            .send()
250            .await
251            .map_err(|e| e.to_string())?;
252
253        let response = self.handle_response_error(response).await?;
254        let value: Value = response.json().await.map_err(|e| e.to_string())?;
255
256        match serde_json::from_value::<ListRulebooksResponse>(value) {
257            Ok(response) => Ok(response.results),
258            Err(e) => Err(format!("Failed to deserialize rulebooks response: {}", e)),
259        }
260    }
261
262    /// Get a rulebook by URI
263    pub async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
264        let encoded_uri = urlencoding::encode(uri);
265        let url = format!("{}/v1/rules/{}", self.base_url, encoded_uri);
266        let response = self
267            .client
268            .get(&url)
269            .send()
270            .await
271            .map_err(|e| e.to_string())?;
272        self.handle_response(response).await
273    }
274
275    /// Create a new rulebook
276    pub async fn create_rulebook(
277        &self,
278        input: &CreateRuleBookInput,
279    ) -> Result<CreateRuleBookResponse, String> {
280        let url = format!("{}/v1/rules", self.base_url);
281        let response = self
282            .client
283            .post(&url)
284            .json(input)
285            .send()
286            .await
287            .map_err(|e| e.to_string())?;
288        self.handle_response(response).await
289    }
290
291    /// Delete a rulebook
292    pub async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
293        let encoded_uri = urlencoding::encode(uri);
294        let url = format!("{}/v1/rules/{}", self.base_url, encoded_uri);
295        let response = self
296            .client
297            .delete(&url)
298            .send()
299            .await
300            .map_err(|e| e.to_string())?;
301        self.handle_response_no_body(response).await
302    }
303
304    // =========================================================================
305    // MCP Tool APIs
306    // =========================================================================
307
308    /// Search documentation
309    pub async fn search_docs(&self, req: &SearchDocsRequest) -> Result<Vec<Content>, String> {
310        self.call_mcp_tool(&ToolsCallParams {
311            name: "search_docs".to_string(),
312            arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
313        })
314        .await
315    }
316
317    /// Search memory
318    pub async fn search_memory(&self, req: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
319        self.call_mcp_tool(&ToolsCallParams {
320            name: "search_memory".to_string(),
321            arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
322        })
323        .await
324    }
325
326    /// Memorize a session checkpoint (extract memory)
327    pub async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
328        let url = format!(
329            "{}/v1/agents/sessions/checkpoints/{}/extract-memory",
330            self.base_url, checkpoint_id
331        );
332        let response = self
333            .client
334            .post(&url)
335            .send()
336            .await
337            .map_err(|e| e.to_string())?;
338        self.handle_response_no_body(response).await
339    }
340
341    /// Read Slack messages from a channel
342    pub async fn slack_read_messages(
343        &self,
344        req: &SlackReadMessagesRequest,
345    ) -> Result<Vec<Content>, String> {
346        self.call_mcp_tool(&ToolsCallParams {
347            name: "slack_read_messages".to_string(),
348            arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
349        })
350        .await
351    }
352
353    /// Read Slack thread replies
354    pub async fn slack_read_replies(
355        &self,
356        req: &SlackReadRepliesRequest,
357    ) -> Result<Vec<Content>, String> {
358        self.call_mcp_tool(&ToolsCallParams {
359            name: "slack_read_replies".to_string(),
360            arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
361        })
362        .await
363    }
364
365    /// Send a Slack message
366    pub async fn slack_send_message(
367        &self,
368        req: &SlackSendMessageRequest,
369    ) -> Result<Vec<Content>, String> {
370        self.call_mcp_tool(&ToolsCallParams {
371            name: "slack_send_message".to_string(),
372            arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
373        })
374        .await
375    }
376
377    // =========================================================================
378    // Helper Methods
379    // =========================================================================
380
381    /// Call an MCP tool via JSON-RPC
382    async fn call_mcp_tool(&self, params: &ToolsCallParams) -> Result<Vec<Content>, String> {
383        let url = format!("{}/v1/mcp", self.base_url);
384        let body = json!({
385            "jsonrpc": "2.0",
386            "id": 1,
387            "method": "tools/call",
388            "params": params
389        });
390
391        let response = self
392            .client
393            .post(&url)
394            .json(&body)
395            .send()
396            .await
397            .map_err(|e| e.to_string())?;
398
399        let resp: Value = self.handle_response(response).await?;
400
401        // Extract result.content from JSON-RPC response
402        if let Some(result) = resp.get("result")
403            && let Some(content) = result.get("content")
404        {
405            let content: Vec<Content> =
406                serde_json::from_value(content.clone()).map_err(|e| e.to_string())?;
407            return Ok(content);
408        }
409
410        // Check for error
411        if let Some(error) = resp.get("error") {
412            let msg = error
413                .get("message")
414                .and_then(|m| m.as_str())
415                .unwrap_or("Unknown error");
416            return Err(msg.to_string());
417        }
418
419        Err("Invalid MCP response format".to_string())
420    }
421
422    /// Handle response and parse JSON
423    async fn handle_response<T: DeserializeOwned>(&self, response: Response) -> Result<T, String> {
424        let response = self.handle_response_error(response).await?;
425        response.json().await.map_err(|e| e.to_string())
426    }
427
428    /// Handle response without body
429    async fn handle_response_no_body(&self, response: Response) -> Result<(), String> {
430        self.handle_response_error(response).await?;
431        Ok(())
432    }
433
434    /// Handle response errors
435    async fn handle_response_error(&self, response: Response) -> Result<Response, String> {
436        if response.status().is_success() {
437            return Ok(response);
438        }
439
440        let status = response.status();
441        let error_body = response.text().await.unwrap_or_default();
442
443        // Try to parse as API error
444        if let Ok(api_error) = serde_json::from_str::<ApiError>(&error_body) {
445            // Special handling for API limit exceeded
446            if api_error.error.key == "EXCEEDED_API_LIMIT" {
447                return Err(format!(
448                    "{}. You can top up your billing at https://stakpak.dev/settings/billing",
449                    api_error.error.message
450                ));
451            }
452            return Err(api_error.error.message);
453        }
454
455        Err(format!("API error {}: {}", status, error_body))
456    }
457}
458
459// =============================================================================
460// Builder helpers for creating sessions and checkpoints
461// =============================================================================
462
463impl CreateSessionRequest {
464    /// Create a new session request with initial state
465    pub fn new(title: impl Into<String>, state: CheckpointState) -> Self {
466        Self {
467            title: title.into(),
468            visibility: Some(SessionVisibility::Private),
469            cwd: None,
470            state,
471        }
472    }
473
474    /// Set the working directory
475    pub fn with_cwd(mut self, cwd: impl Into<String>) -> Self {
476        self.cwd = Some(cwd.into());
477        self
478    }
479
480    /// Set visibility
481    pub fn with_visibility(mut self, visibility: SessionVisibility) -> Self {
482        self.visibility = Some(visibility);
483        self
484    }
485}
486
487impl CreateCheckpointRequest {
488    /// Create a new checkpoint request
489    pub fn new(state: CheckpointState) -> Self {
490        Self {
491            state,
492            parent_id: None,
493        }
494    }
495
496    /// Set the parent checkpoint ID (for branching)
497    pub fn with_parent(mut self, parent_id: Uuid) -> Self {
498        self.parent_id = Some(parent_id);
499        self
500    }
501}