Skip to main content

stakpak_api/stakpak/
client.rs

1//! StakpakApiClient implementation
2//!
3//! Provides access to Stakpak's non-inference APIs.
4
5use super::{
6    CheckpointState, CreateCheckpointRequest, CreateCheckpointResponse, CreateSessionRequest,
7    CreateSessionResponse, GetCheckpointResponse, GetSessionResponse, ListCheckpointsQuery,
8    ListCheckpointsResponse, ListSessionsQuery, ListSessionsResponse, SessionVisibility,
9    StakpakApiConfig, UpdateSessionRequest, UpdateSessionResponse, models::*,
10};
11use crate::models::{
12    CreateRuleBookInput, CreateRuleBookResponse, GetMyAccountResponse, ListRuleBook,
13    ListRulebooksResponse, RuleBook,
14};
15use reqwest::{Client as ReqwestClient, Response, header};
16use rmcp::model::Content;
17use serde::de::DeserializeOwned;
18use serde_json::{Value, json};
19use stakpak_shared::models::billing::BillingResponse;
20use uuid::Uuid;
21
22/// Client for Stakpak's non-inference APIs
23#[derive(Clone, Debug)]
24pub struct StakpakApiClient {
25    client: ReqwestClient,
26    base_url: String,
27}
28
29/// API error response format
30#[derive(Debug, serde::Deserialize)]
31struct ApiError {
32    error: ApiErrorDetail,
33}
34
35#[derive(Debug, serde::Deserialize)]
36struct ApiErrorDetail {
37    key: String,
38    message: String,
39}
40
41impl StakpakApiClient {
42    /// Create a new StakpakApiClient
43    pub fn new(config: &StakpakApiConfig) -> Result<Self, String> {
44        if config.api_key.is_empty() {
45            return Err("Stakpak API key is required".to_string());
46        }
47
48        let mut headers = header::HeaderMap::new();
49        headers.insert(
50            header::AUTHORIZATION,
51            header::HeaderValue::from_str(&format!("Bearer {}", config.api_key))
52                .map_err(|e| e.to_string())?,
53        );
54        headers.insert(
55            header::USER_AGENT,
56            header::HeaderValue::from_str(&format!("Stakpak/{}", env!("CARGO_PKG_VERSION")))
57                .map_err(|e| e.to_string())?,
58        );
59
60        let client = ReqwestClient::builder()
61            .default_headers(headers)
62            .timeout(std::time::Duration::from_secs(300))
63            .build()
64            .map_err(|e| e.to_string())?;
65
66        Ok(Self {
67            client,
68            base_url: config.api_endpoint.clone(),
69        })
70    }
71
72    // =========================================================================
73    // Session APIs - New /v1/sessions endpoints
74    // =========================================================================
75
76    /// Create a new session
77    pub async fn create_session(
78        &self,
79        req: &CreateSessionRequest,
80    ) -> Result<CreateSessionResponse, String> {
81        let url = format!("{}/v1/sessions", self.base_url);
82        let response = self
83            .client
84            .post(&url)
85            .json(req)
86            .send()
87            .await
88            .map_err(|e| e.to_string())?;
89        self.handle_response(response).await
90    }
91
92    /// Create a checkpoint for a session
93    pub async fn create_checkpoint(
94        &self,
95        session_id: Uuid,
96        req: &CreateCheckpointRequest,
97    ) -> Result<CreateCheckpointResponse, String> {
98        let url = format!("{}/v1/sessions/{}/checkpoints", self.base_url, session_id);
99        let response = self
100            .client
101            .post(&url)
102            .json(req)
103            .send()
104            .await
105            .map_err(|e| e.to_string())?;
106        self.handle_response(response).await
107    }
108
109    /// List sessions
110    pub async fn list_sessions(
111        &self,
112        query: &ListSessionsQuery,
113    ) -> Result<ListSessionsResponse, String> {
114        let url = format!("{}/v1/sessions", self.base_url);
115        let response = self
116            .client
117            .get(&url)
118            .query(query)
119            .send()
120            .await
121            .map_err(|e| e.to_string())?;
122        self.handle_response(response).await
123    }
124
125    /// Get a session by ID
126    pub async fn get_session(&self, id: Uuid) -> Result<GetSessionResponse, String> {
127        let url = format!("{}/v1/sessions/{}", self.base_url, id);
128        let response = self
129            .client
130            .get(&url)
131            .send()
132            .await
133            .map_err(|e| e.to_string())?;
134        self.handle_response(response).await
135    }
136
137    /// Update a session
138    pub async fn update_session(
139        &self,
140        id: Uuid,
141        req: &UpdateSessionRequest,
142    ) -> Result<UpdateSessionResponse, String> {
143        let url = format!("{}/v1/sessions/{}", self.base_url, id);
144        let response = self
145            .client
146            .patch(&url)
147            .json(req)
148            .send()
149            .await
150            .map_err(|e| e.to_string())?;
151        self.handle_response(response).await
152    }
153
154    /// Delete a session
155    pub async fn delete_session(&self, id: Uuid) -> Result<(), String> {
156        let url = format!("{}/v1/sessions/{}", self.base_url, id);
157        let response = self
158            .client
159            .delete(&url)
160            .send()
161            .await
162            .map_err(|e| e.to_string())?;
163        self.handle_response_no_body(response).await
164    }
165
166    /// List checkpoints for a session
167    pub async fn list_checkpoints(
168        &self,
169        session_id: Uuid,
170        query: &ListCheckpointsQuery,
171    ) -> Result<ListCheckpointsResponse, String> {
172        let url = format!("{}/v1/sessions/{}/checkpoints", self.base_url, session_id);
173        let response = self
174            .client
175            .get(&url)
176            .query(query)
177            .send()
178            .await
179            .map_err(|e| e.to_string())?;
180        self.handle_response(response).await
181    }
182
183    /// Get a checkpoint by ID
184    pub async fn get_checkpoint(&self, id: Uuid) -> Result<GetCheckpointResponse, String> {
185        let url = format!("{}/v1/sessions/checkpoints/{}", self.base_url, id);
186        let response = self
187            .client
188            .get(&url)
189            .send()
190            .await
191            .map_err(|e| e.to_string())?;
192        self.handle_response(response).await
193    }
194
195    // =========================================================================
196    // Cancel API
197    // =========================================================================
198
199    /// Cancel an active inference request
200    pub async fn cancel_request(&self, request_id: &str) -> Result<(), String> {
201        let url = format!("{}/v1/chat/requests/{}/cancel", self.base_url, request_id);
202        let response = self
203            .client
204            .post(&url)
205            .send()
206            .await
207            .map_err(|e| e.to_string())?;
208        self.handle_response_no_body(response).await
209    }
210
211    // =========================================================================
212    // Account APIs
213    // =========================================================================
214
215    /// Get the current user's account info
216    pub async fn get_account(&self) -> Result<GetMyAccountResponse, String> {
217        let url = format!("{}/v1/account", self.base_url);
218        let response = self
219            .client
220            .get(&url)
221            .send()
222            .await
223            .map_err(|e| e.to_string())?;
224        self.handle_response(response).await
225    }
226
227    /// Get billing info for a user
228    pub async fn get_billing(&self, username: &str) -> Result<BillingResponse, String> {
229        let url = format!("{}/v2/{}/billing", self.base_url, username);
230        let response = self
231            .client
232            .get(&url)
233            .send()
234            .await
235            .map_err(|e| e.to_string())?;
236        self.handle_response(response).await
237    }
238
239    // =========================================================================
240    // Rulebook APIs
241    // =========================================================================
242
243    /// List all rulebooks
244    pub async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
245        let url = format!("{}/v1/rules", self.base_url);
246        let response = self
247            .client
248            .get(&url)
249            .send()
250            .await
251            .map_err(|e| e.to_string())?;
252
253        let response = self.handle_response_error(response).await?;
254        let value: Value = response.json().await.map_err(|e| e.to_string())?;
255
256        match serde_json::from_value::<ListRulebooksResponse>(value) {
257            Ok(response) => Ok(response.results),
258            Err(e) => Err(format!("Failed to deserialize rulebooks response: {}", e)),
259        }
260    }
261
262    /// Get a rulebook by URI
263    pub async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
264        let url = format!("{}/v1/rules/{}", self.base_url, uri);
265        let response = self
266            .client
267            .get(&url)
268            .send()
269            .await
270            .map_err(|e| e.to_string())?;
271        self.handle_response(response).await
272    }
273
274    /// Create a new rulebook
275    pub async fn create_rulebook(
276        &self,
277        input: &CreateRuleBookInput,
278    ) -> Result<CreateRuleBookResponse, String> {
279        let url = format!("{}/v1/rules", self.base_url);
280        let response = self
281            .client
282            .post(&url)
283            .json(input)
284            .send()
285            .await
286            .map_err(|e| e.to_string())?;
287        self.handle_response(response).await
288    }
289
290    /// Delete a rulebook
291    pub async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
292        let url = format!("{}/v1/rules/{}", self.base_url, uri);
293        let response = self
294            .client
295            .delete(&url)
296            .send()
297            .await
298            .map_err(|e| e.to_string())?;
299        self.handle_response_no_body(response).await
300    }
301
302    // =========================================================================
303    // MCP Tool APIs
304    // =========================================================================
305
306    /// Search documentation
307    pub async fn search_docs(&self, req: &SearchDocsRequest) -> Result<Vec<Content>, String> {
308        self.call_mcp_tool(&ToolsCallParams {
309            name: "search_docs".to_string(),
310            arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
311        })
312        .await
313    }
314
315    /// Search memory
316    pub async fn search_memory(&self, req: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
317        self.call_mcp_tool(&ToolsCallParams {
318            name: "search_memory".to_string(),
319            arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
320        })
321        .await
322    }
323
324    /// Memorize a session checkpoint (extract memory)
325    pub async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
326        let url = format!(
327            "{}/v1/agents/sessions/checkpoints/{}/extract-memory",
328            self.base_url, checkpoint_id
329        );
330        let response = self
331            .client
332            .post(&url)
333            .send()
334            .await
335            .map_err(|e| e.to_string())?;
336        self.handle_response_no_body(response).await
337    }
338
339    /// Read Slack messages from a channel
340    pub async fn slack_read_messages(
341        &self,
342        req: &SlackReadMessagesRequest,
343    ) -> Result<Vec<Content>, String> {
344        self.call_mcp_tool(&ToolsCallParams {
345            name: "slack_read_messages".to_string(),
346            arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
347        })
348        .await
349    }
350
351    /// Read Slack thread replies
352    pub async fn slack_read_replies(
353        &self,
354        req: &SlackReadRepliesRequest,
355    ) -> Result<Vec<Content>, String> {
356        self.call_mcp_tool(&ToolsCallParams {
357            name: "slack_read_replies".to_string(),
358            arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
359        })
360        .await
361    }
362
363    /// Send a Slack message
364    pub async fn slack_send_message(
365        &self,
366        req: &SlackSendMessageRequest,
367    ) -> Result<Vec<Content>, String> {
368        self.call_mcp_tool(&ToolsCallParams {
369            name: "slack_send_message".to_string(),
370            arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
371        })
372        .await
373    }
374
375    // =========================================================================
376    // Helper Methods
377    // =========================================================================
378
379    /// Call an MCP tool via JSON-RPC
380    async fn call_mcp_tool(&self, params: &ToolsCallParams) -> Result<Vec<Content>, String> {
381        let url = format!("{}/v1/mcp", self.base_url);
382        let body = json!({
383            "jsonrpc": "2.0",
384            "id": 1,
385            "method": "tools/call",
386            "params": params
387        });
388
389        let response = self
390            .client
391            .post(&url)
392            .json(&body)
393            .send()
394            .await
395            .map_err(|e| e.to_string())?;
396
397        let resp: Value = self.handle_response(response).await?;
398
399        // Extract result.content from JSON-RPC response
400        if let Some(result) = resp.get("result")
401            && let Some(content) = result.get("content")
402        {
403            let content: Vec<Content> =
404                serde_json::from_value(content.clone()).map_err(|e| e.to_string())?;
405            return Ok(content);
406        }
407
408        // Check for error
409        if let Some(error) = resp.get("error") {
410            let msg = error
411                .get("message")
412                .and_then(|m| m.as_str())
413                .unwrap_or("Unknown error");
414            return Err(msg.to_string());
415        }
416
417        Err("Invalid MCP response format".to_string())
418    }
419
420    /// Handle response and parse JSON
421    async fn handle_response<T: DeserializeOwned>(&self, response: Response) -> Result<T, String> {
422        let response = self.handle_response_error(response).await?;
423        response.json().await.map_err(|e| e.to_string())
424    }
425
426    /// Handle response without body
427    async fn handle_response_no_body(&self, response: Response) -> Result<(), String> {
428        self.handle_response_error(response).await?;
429        Ok(())
430    }
431
432    /// Handle response errors
433    async fn handle_response_error(&self, response: Response) -> Result<Response, String> {
434        if response.status().is_success() {
435            return Ok(response);
436        }
437
438        let status = response.status();
439        let error_body = response.text().await.unwrap_or_default();
440
441        // Try to parse as API error
442        if let Ok(api_error) = serde_json::from_str::<ApiError>(&error_body) {
443            // Special handling for API limit exceeded
444            if api_error.error.key == "EXCEEDED_API_LIMIT" {
445                return Err(format!(
446                    "{}. You can top up your billing at https://stakpak.dev/settings/billing",
447                    api_error.error.message
448                ));
449            }
450            return Err(api_error.error.message);
451        }
452
453        Err(format!("API error {}: {}", status, error_body))
454    }
455}
456
457// =============================================================================
458// Builder helpers for creating sessions and checkpoints
459// =============================================================================
460
461impl CreateSessionRequest {
462    /// Create a new session request with initial state
463    pub fn new(title: impl Into<String>, state: CheckpointState) -> Self {
464        Self {
465            title: title.into(),
466            visibility: Some(SessionVisibility::Private),
467            cwd: None,
468            state,
469        }
470    }
471
472    /// Set the working directory
473    pub fn with_cwd(mut self, cwd: impl Into<String>) -> Self {
474        self.cwd = Some(cwd.into());
475        self
476    }
477
478    /// Set visibility
479    pub fn with_visibility(mut self, visibility: SessionVisibility) -> Self {
480        self.visibility = Some(visibility);
481        self
482    }
483}
484
485impl CreateCheckpointRequest {
486    /// Create a new checkpoint request
487    pub fn new(state: CheckpointState) -> Self {
488        Self {
489            state,
490            parent_id: None,
491        }
492    }
493
494    /// Set the parent checkpoint ID (for branching)
495    pub fn with_parent(mut self, parent_id: Uuid) -> Self {
496        self.parent_id = Some(parent_id);
497        self
498    }
499}