Skip to main content

stakpak_api/stakpak/
client.rs

1//! StakpakApiClient implementation
2//!
3//! Provides access to Stakpak's non-inference APIs.
4
5use super::{
6    CheckpointState, CreateCheckpointRequest, CreateCheckpointResponse, CreateSessionRequest,
7    CreateSessionResponse, GetCheckpointResponse, GetSessionResponse, ListCheckpointsQuery,
8    ListCheckpointsResponse, ListSessionsQuery, ListSessionsResponse, SessionVisibility,
9    StakpakApiConfig, UpdateSessionRequest, UpdateSessionResponse, models::*,
10};
11use crate::models::{
12    CreateRuleBookInput, CreateRuleBookResponse, GetMyAccountResponse, ListRuleBook,
13    ListRulebooksResponse, RuleBook,
14};
15use reqwest::{Response, header};
16use rmcp::model::Content;
17use serde::de::DeserializeOwned;
18use serde_json::{Value, json};
19use stakpak_shared::models::billing::BillingResponse;
20use stakpak_shared::tls_client::{TlsClientConfig, create_tls_client};
21use uuid::Uuid;
22
23/// Client for Stakpak's non-inference APIs
24#[derive(Clone, Debug)]
25pub struct StakpakApiClient {
26    client: reqwest::Client,
27    base_url: String,
28}
29
30/// API error response format
31#[derive(Debug, serde::Deserialize)]
32struct ApiError {
33    error: ApiErrorDetail,
34}
35
36#[derive(Debug, serde::Deserialize)]
37struct ApiErrorDetail {
38    key: String,
39    message: String,
40}
41
42impl StakpakApiClient {
43    /// Create a new StakpakApiClient
44    pub fn new(config: &StakpakApiConfig) -> Result<Self, String> {
45        if config.api_key.is_empty() {
46            return Err("Stakpak API key is required".to_string());
47        }
48
49        let mut headers = header::HeaderMap::new();
50        headers.insert(
51            header::AUTHORIZATION,
52            header::HeaderValue::from_str(&format!("Bearer {}", config.api_key))
53                .map_err(|e| e.to_string())?,
54        );
55        headers.insert(
56            header::USER_AGENT,
57            header::HeaderValue::from_str(&format!("Stakpak/{}", env!("CARGO_PKG_VERSION")))
58                .map_err(|e| e.to_string())?,
59        );
60
61        let client = create_tls_client(
62            TlsClientConfig::default()
63                .with_headers(headers)
64                .with_timeout(std::time::Duration::from_secs(300)),
65        )?;
66
67        Ok(Self {
68            client,
69            base_url: config.api_endpoint.clone(),
70        })
71    }
72
73    // =========================================================================
74    // Session APIs - New /v1/sessions endpoints
75    // =========================================================================
76
77    /// Create a new session
78    pub async fn create_session(
79        &self,
80        req: &CreateSessionRequest,
81    ) -> Result<CreateSessionResponse, String> {
82        let url = format!("{}/v1/sessions", self.base_url);
83        let response = self
84            .client
85            .post(&url)
86            .json(req)
87            .send()
88            .await
89            .map_err(|e| e.to_string())?;
90        self.handle_response(response).await
91    }
92
93    /// Create a checkpoint for a session
94    pub async fn create_checkpoint(
95        &self,
96        session_id: Uuid,
97        req: &CreateCheckpointRequest,
98    ) -> Result<CreateCheckpointResponse, String> {
99        let url = format!("{}/v1/sessions/{}/checkpoints", self.base_url, session_id);
100        let response = self
101            .client
102            .post(&url)
103            .json(req)
104            .send()
105            .await
106            .map_err(|e| e.to_string())?;
107        self.handle_response(response).await
108    }
109
110    /// List sessions
111    pub async fn list_sessions(
112        &self,
113        query: &ListSessionsQuery,
114    ) -> Result<ListSessionsResponse, String> {
115        let url = format!("{}/v1/sessions", self.base_url);
116        let response = self
117            .client
118            .get(&url)
119            .query(query)
120            .send()
121            .await
122            .map_err(|e| e.to_string())?;
123        self.handle_response(response).await
124    }
125
126    /// Get a session by ID
127    pub async fn get_session(&self, id: Uuid) -> Result<GetSessionResponse, String> {
128        let url = format!("{}/v1/sessions/{}", self.base_url, id);
129        let response = self
130            .client
131            .get(&url)
132            .send()
133            .await
134            .map_err(|e| e.to_string())?;
135        self.handle_response(response).await
136    }
137
138    /// Update a session
139    pub async fn update_session(
140        &self,
141        id: Uuid,
142        req: &UpdateSessionRequest,
143    ) -> Result<UpdateSessionResponse, String> {
144        let url = format!("{}/v1/sessions/{}", self.base_url, id);
145        let response = self
146            .client
147            .patch(&url)
148            .json(req)
149            .send()
150            .await
151            .map_err(|e| e.to_string())?;
152        self.handle_response(response).await
153    }
154
155    /// Delete a session
156    pub async fn delete_session(&self, id: Uuid) -> Result<(), String> {
157        let url = format!("{}/v1/sessions/{}", self.base_url, id);
158        let response = self
159            .client
160            .delete(&url)
161            .send()
162            .await
163            .map_err(|e| e.to_string())?;
164        self.handle_response_no_body(response).await
165    }
166
167    /// List checkpoints for a session
168    pub async fn list_checkpoints(
169        &self,
170        session_id: Uuid,
171        query: &ListCheckpointsQuery,
172    ) -> Result<ListCheckpointsResponse, String> {
173        let url = format!("{}/v1/sessions/{}/checkpoints", self.base_url, session_id);
174        let response = self
175            .client
176            .get(&url)
177            .query(query)
178            .send()
179            .await
180            .map_err(|e| e.to_string())?;
181        self.handle_response(response).await
182    }
183
184    /// Get a checkpoint by ID
185    pub async fn get_checkpoint(&self, id: Uuid) -> Result<GetCheckpointResponse, String> {
186        let url = format!("{}/v1/sessions/checkpoints/{}", self.base_url, id);
187        let response = self
188            .client
189            .get(&url)
190            .send()
191            .await
192            .map_err(|e| e.to_string())?;
193        self.handle_response(response).await
194    }
195
196    // =========================================================================
197    // Cancel API
198    // =========================================================================
199
200    /// Cancel an active inference request
201    pub async fn cancel_request(&self, request_id: &str) -> Result<(), String> {
202        let url = format!("{}/v1/chat/requests/{}/cancel", self.base_url, request_id);
203        let response = self
204            .client
205            .post(&url)
206            .send()
207            .await
208            .map_err(|e| e.to_string())?;
209        self.handle_response_no_body(response).await
210    }
211
212    // =========================================================================
213    // Account APIs
214    // =========================================================================
215
216    /// Get the current user's account info
217    pub async fn get_account(&self) -> Result<GetMyAccountResponse, String> {
218        let url = format!("{}/v1/account", self.base_url);
219        let response = self
220            .client
221            .get(&url)
222            .send()
223            .await
224            .map_err(|e| e.to_string())?;
225        self.handle_response(response).await
226    }
227
228    /// Get billing info for a user
229    pub async fn get_billing(&self, username: &str) -> Result<BillingResponse, String> {
230        let url = format!("{}/v2/{}/billing", self.base_url, username);
231        let response = self
232            .client
233            .get(&url)
234            .send()
235            .await
236            .map_err(|e| e.to_string())?;
237        self.handle_response(response).await
238    }
239
240    // =========================================================================
241    // Rulebook APIs
242    // =========================================================================
243
244    /// List all rulebooks
245    pub async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
246        let url = format!("{}/v1/rules", self.base_url);
247        let response = self
248            .client
249            .get(&url)
250            .send()
251            .await
252            .map_err(|e| e.to_string())?;
253
254        let response = self.handle_response_error(response).await?;
255        let value: Value = response.json().await.map_err(|e| e.to_string())?;
256
257        match serde_json::from_value::<ListRulebooksResponse>(value) {
258            Ok(response) => Ok(response.results),
259            Err(e) => Err(format!("Failed to deserialize rulebooks response: {}", e)),
260        }
261    }
262
263    /// Get a rulebook by URI
264    pub async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
265        let encoded_uri = urlencoding::encode(uri);
266        let url = format!("{}/v1/rules/{}", self.base_url, encoded_uri);
267        let response = self
268            .client
269            .get(&url)
270            .send()
271            .await
272            .map_err(|e| e.to_string())?;
273        self.handle_response(response).await
274    }
275
276    /// Create a new rulebook
277    pub async fn create_rulebook(
278        &self,
279        input: &CreateRuleBookInput,
280    ) -> Result<CreateRuleBookResponse, String> {
281        let url = format!("{}/v1/rules", self.base_url);
282        let response = self
283            .client
284            .post(&url)
285            .json(input)
286            .send()
287            .await
288            .map_err(|e| e.to_string())?;
289        self.handle_response(response).await
290    }
291
292    /// Delete a rulebook
293    pub async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
294        let encoded_uri = urlencoding::encode(uri);
295        let url = format!("{}/v1/rules/{}", self.base_url, encoded_uri);
296        let response = self
297            .client
298            .delete(&url)
299            .send()
300            .await
301            .map_err(|e| e.to_string())?;
302        self.handle_response_no_body(response).await
303    }
304
305    // =========================================================================
306    // MCP Tool APIs
307    // =========================================================================
308
309    /// Search documentation
310    pub async fn search_docs(&self, req: &SearchDocsRequest) -> Result<Vec<Content>, String> {
311        self.call_mcp_tool(&ToolsCallParams {
312            name: "search_docs".to_string(),
313            arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
314        })
315        .await
316    }
317
318    /// Search memory
319    pub async fn search_memory(&self, req: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
320        self.call_mcp_tool(&ToolsCallParams {
321            name: "search_memory".to_string(),
322            arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
323        })
324        .await
325    }
326
327    /// Memorize a session checkpoint (extract memory)
328    pub async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
329        let url = format!(
330            "{}/v1/agents/sessions/checkpoints/{}/extract-memory",
331            self.base_url, checkpoint_id
332        );
333        let response = self
334            .client
335            .post(&url)
336            .send()
337            .await
338            .map_err(|e| e.to_string())?;
339        self.handle_response_no_body(response).await
340    }
341
342    /// Read Slack messages from a channel
343    pub async fn slack_read_messages(
344        &self,
345        req: &SlackReadMessagesRequest,
346    ) -> Result<Vec<Content>, String> {
347        self.call_mcp_tool(&ToolsCallParams {
348            name: "slack_read_messages".to_string(),
349            arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
350        })
351        .await
352    }
353
354    /// Read Slack thread replies
355    pub async fn slack_read_replies(
356        &self,
357        req: &SlackReadRepliesRequest,
358    ) -> Result<Vec<Content>, String> {
359        self.call_mcp_tool(&ToolsCallParams {
360            name: "slack_read_replies".to_string(),
361            arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
362        })
363        .await
364    }
365
366    /// Send a Slack message
367    pub async fn slack_send_message(
368        &self,
369        req: &SlackSendMessageRequest,
370    ) -> Result<Vec<Content>, String> {
371        self.call_mcp_tool(&ToolsCallParams {
372            name: "slack_send_message".to_string(),
373            arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
374        })
375        .await
376    }
377
378    // =========================================================================
379    // Helper Methods
380    // =========================================================================
381
382    /// Call an MCP tool via JSON-RPC
383    async fn call_mcp_tool(&self, params: &ToolsCallParams) -> Result<Vec<Content>, String> {
384        let url = format!("{}/v1/mcp", self.base_url);
385        let body = json!({
386            "jsonrpc": "2.0",
387            "id": 1,
388            "method": "tools/call",
389            "params": params
390        });
391
392        let response = self
393            .client
394            .post(&url)
395            .json(&body)
396            .send()
397            .await
398            .map_err(|e| e.to_string())?;
399
400        let resp: Value = self.handle_response(response).await?;
401
402        // Extract result.content from JSON-RPC response
403        if let Some(result) = resp.get("result")
404            && let Some(content) = result.get("content")
405        {
406            let content: Vec<Content> =
407                serde_json::from_value(content.clone()).map_err(|e| e.to_string())?;
408            return Ok(content);
409        }
410
411        // Check for error
412        if let Some(error) = resp.get("error") {
413            let msg = error
414                .get("message")
415                .and_then(|m| m.as_str())
416                .unwrap_or("Unknown error");
417            return Err(msg.to_string());
418        }
419
420        Err("Invalid MCP response format".to_string())
421    }
422
423    /// Handle response and parse JSON
424    async fn handle_response<T: DeserializeOwned>(&self, response: Response) -> Result<T, String> {
425        let response = self.handle_response_error(response).await?;
426        response.json().await.map_err(|e| e.to_string())
427    }
428
429    /// Handle response without body
430    async fn handle_response_no_body(&self, response: Response) -> Result<(), String> {
431        self.handle_response_error(response).await?;
432        Ok(())
433    }
434
435    /// Handle response errors
436    async fn handle_response_error(&self, response: Response) -> Result<Response, String> {
437        if response.status().is_success() {
438            return Ok(response);
439        }
440
441        let status = response.status();
442        let error_body = response.text().await.unwrap_or_default();
443
444        // Try to parse as API error
445        if let Ok(api_error) = serde_json::from_str::<ApiError>(&error_body) {
446            // Special handling for API limit exceeded
447            if api_error.error.key == "EXCEEDED_API_LIMIT" {
448                return Err(format!(
449                    "{}. You can top up your billing at https://stakpak.dev/settings/billing",
450                    api_error.error.message
451                ));
452            }
453            return Err(api_error.error.message);
454        }
455
456        Err(format!("API error {}: {}", status, error_body))
457    }
458}
459
460// =============================================================================
461// Builder helpers for creating sessions and checkpoints
462// =============================================================================
463
464impl CreateSessionRequest {
465    /// Create a new session request with initial state
466    pub fn new(title: impl Into<String>, state: CheckpointState) -> Self {
467        Self {
468            title: title.into(),
469            visibility: Some(SessionVisibility::Private),
470            cwd: None,
471            state,
472        }
473    }
474
475    /// Set the working directory
476    pub fn with_cwd(mut self, cwd: impl Into<String>) -> Self {
477        self.cwd = Some(cwd.into());
478        self
479    }
480
481    /// Set visibility
482    pub fn with_visibility(mut self, visibility: SessionVisibility) -> Self {
483        self.visibility = Some(visibility);
484        self
485    }
486}
487
488impl CreateCheckpointRequest {
489    /// Create a new checkpoint request
490    pub fn new(state: CheckpointState) -> Self {
491        Self {
492            state,
493            parent_id: None,
494        }
495    }
496
497    /// Set the parent checkpoint ID (for branching)
498    pub fn with_parent(mut self, parent_id: Uuid) -> Self {
499        self.parent_id = Some(parent_id);
500        self
501    }
502}