Skip to main content

stakpak_api/stakpak/
client.rs

1//! StakpakApiClient implementation
2//!
3//! Provides access to Stakpak's non-inference APIs.
4
5use super::{
6    CheckpointState, CreateCheckpointRequest, CreateCheckpointResponse, CreateSessionRequest,
7    CreateSessionResponse, GetCheckpointResponse, GetSessionResponse, ListCheckpointsQuery,
8    ListCheckpointsResponse, ListSessionsQuery, ListSessionsResponse, SessionVisibility,
9    StakpakApiConfig, UpdateSessionRequest, UpdateSessionResponse, knowledge::AccountCacheState,
10    models::*,
11};
12use crate::models::{
13    CreateRuleBookInput, CreateRuleBookResponse, GetMyAccountResponse, ListRuleBook,
14    ListRulebooksResponse, RuleBook,
15};
16use reqwest::{Response, header};
17use rmcp::model::Content;
18use serde::de::DeserializeOwned;
19use serde_json::{Value, json};
20use stakpak_shared::models::billing::BillingResponse;
21use stakpak_shared::tls_client::{TlsClientConfig, create_tls_client};
22use std::sync::Arc;
23use tokio::sync::Mutex;
24use uuid::Uuid;
25
26/// Client for Stakpak's non-inference APIs
27#[derive(Clone, Debug)]
28pub struct StakpakApiClient {
29    pub(super) client: reqwest::Client,
30    pub(super) base_url: String,
31    pub(super) account_name: Arc<Mutex<AccountCacheState>>,
32}
33
34/// API error response format
35#[derive(Debug, serde::Deserialize)]
36pub(super) struct ApiError {
37    pub(super) error: ApiErrorDetail,
38}
39
40#[derive(Debug, serde::Deserialize)]
41pub(super) struct ApiErrorDetail {
42    pub(super) key: String,
43    pub(super) message: String,
44}
45
46impl StakpakApiClient {
47    /// Create a new StakpakApiClient
48    pub fn new(config: &StakpakApiConfig) -> Result<Self, String> {
49        if config.api_key.is_empty() {
50            return Err("Stakpak API key is required".to_string());
51        }
52
53        let mut headers = header::HeaderMap::new();
54        headers.insert(
55            header::AUTHORIZATION,
56            header::HeaderValue::from_str(&format!("Bearer {}", config.api_key))
57                .map_err(|e| e.to_string())?,
58        );
59        headers.insert(
60            header::USER_AGENT,
61            header::HeaderValue::from_str(&format!("Stakpak/{}", env!("CARGO_PKG_VERSION")))
62                .map_err(|e| e.to_string())?,
63        );
64
65        let client = create_tls_client(
66            TlsClientConfig::default()
67                .with_headers(headers)
68                .with_timeout(std::time::Duration::from_secs(300)),
69        )?;
70
71        Ok(Self {
72            client,
73            base_url: config.api_endpoint.clone(),
74            account_name: Arc::new(Mutex::new(AccountCacheState::Unknown)),
75        })
76    }
77
78    // =========================================================================
79    // Session APIs - New /v1/sessions endpoints
80    // =========================================================================
81
82    /// Create a new session
83    pub async fn create_session(
84        &self,
85        req: &CreateSessionRequest,
86    ) -> Result<CreateSessionResponse, String> {
87        let url = format!("{}/v1/sessions", self.base_url);
88        let response = self
89            .client
90            .post(&url)
91            .json(req)
92            .send()
93            .await
94            .map_err(|e| e.to_string())?;
95        self.handle_response(response).await
96    }
97
98    /// Create a checkpoint for a session
99    pub async fn create_checkpoint(
100        &self,
101        session_id: Uuid,
102        req: &CreateCheckpointRequest,
103    ) -> Result<CreateCheckpointResponse, String> {
104        let url = format!("{}/v1/sessions/{}/checkpoints", self.base_url, session_id);
105        let response = self
106            .client
107            .post(&url)
108            .json(req)
109            .send()
110            .await
111            .map_err(|e| e.to_string())?;
112        self.handle_response(response).await
113    }
114
115    /// List sessions
116    pub async fn list_sessions(
117        &self,
118        query: &ListSessionsQuery,
119    ) -> Result<ListSessionsResponse, String> {
120        let url = format!("{}/v1/sessions", self.base_url);
121        let response = self
122            .client
123            .get(&url)
124            .query(query)
125            .send()
126            .await
127            .map_err(|e| e.to_string())?;
128        self.handle_response(response).await
129    }
130
131    /// Get a session by ID
132    pub async fn get_session(&self, id: Uuid) -> Result<GetSessionResponse, String> {
133        let url = format!("{}/v1/sessions/{}", self.base_url, id);
134        let response = self
135            .client
136            .get(&url)
137            .send()
138            .await
139            .map_err(|e| e.to_string())?;
140        self.handle_response(response).await
141    }
142
143    /// Update a session
144    pub async fn update_session(
145        &self,
146        id: Uuid,
147        req: &UpdateSessionRequest,
148    ) -> Result<UpdateSessionResponse, String> {
149        let url = format!("{}/v1/sessions/{}", self.base_url, id);
150        let response = self
151            .client
152            .patch(&url)
153            .json(req)
154            .send()
155            .await
156            .map_err(|e| e.to_string())?;
157        self.handle_response(response).await
158    }
159
160    /// Delete a session
161    pub async fn delete_session(&self, id: Uuid) -> Result<(), String> {
162        let url = format!("{}/v1/sessions/{}", self.base_url, id);
163        let response = self
164            .client
165            .delete(&url)
166            .send()
167            .await
168            .map_err(|e| e.to_string())?;
169        self.handle_response_no_body(response).await
170    }
171
172    /// List checkpoints for a session
173    pub async fn list_checkpoints(
174        &self,
175        session_id: Uuid,
176        query: &ListCheckpointsQuery,
177    ) -> Result<ListCheckpointsResponse, String> {
178        let url = format!("{}/v1/sessions/{}/checkpoints", self.base_url, session_id);
179        let response = self
180            .client
181            .get(&url)
182            .query(query)
183            .send()
184            .await
185            .map_err(|e| e.to_string())?;
186        self.handle_response(response).await
187    }
188
189    /// Get a checkpoint by ID
190    pub async fn get_checkpoint(&self, id: Uuid) -> Result<GetCheckpointResponse, String> {
191        let url = format!("{}/v1/sessions/checkpoints/{}", self.base_url, id);
192        let response = self
193            .client
194            .get(&url)
195            .send()
196            .await
197            .map_err(|e| e.to_string())?;
198        self.handle_response(response).await
199    }
200
201    // =========================================================================
202    // Cancel API
203    // =========================================================================
204
205    /// Cancel an active inference request
206    pub async fn cancel_request(&self, request_id: &str) -> Result<(), String> {
207        let url = format!("{}/v1/chat/requests/{}/cancel", self.base_url, request_id);
208        let response = self
209            .client
210            .post(&url)
211            .send()
212            .await
213            .map_err(|e| e.to_string())?;
214        self.handle_response_no_body(response).await
215    }
216
217    // =========================================================================
218    // Account APIs
219    // =========================================================================
220
221    /// Get the current user's account info
222    pub async fn get_account(&self) -> Result<GetMyAccountResponse, String> {
223        let url = format!("{}/v1/account", self.base_url);
224        let response = self
225            .client
226            .get(&url)
227            .send()
228            .await
229            .map_err(|e| e.to_string())?;
230        self.handle_response(response).await
231    }
232
233    /// Get billing info for a user
234    pub async fn get_billing(&self, username: &str) -> Result<BillingResponse, String> {
235        let url = format!("{}/v2/{}/billing", self.base_url, username);
236        let response = self
237            .client
238            .get(&url)
239            .send()
240            .await
241            .map_err(|e| e.to_string())?;
242        self.handle_response(response).await
243    }
244
245    // =========================================================================
246    // Rulebook APIs
247    // =========================================================================
248
249    /// List all rulebooks
250    pub async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
251        let url = format!("{}/v1/rules", self.base_url);
252        let response = self
253            .client
254            .get(&url)
255            .send()
256            .await
257            .map_err(|e| e.to_string())?;
258
259        let response = self.handle_response_error(response).await?;
260        let value: Value = response.json().await.map_err(|e| e.to_string())?;
261
262        match serde_json::from_value::<ListRulebooksResponse>(value) {
263            Ok(response) => Ok(response.results),
264            Err(e) => Err(format!("Failed to deserialize rulebooks response: {}", e)),
265        }
266    }
267
268    /// Get a rulebook by URI
269    pub async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
270        let encoded_uri = urlencoding::encode(uri);
271        let url = format!("{}/v1/rules/{}", self.base_url, encoded_uri);
272        let response = self
273            .client
274            .get(&url)
275            .send()
276            .await
277            .map_err(|e| e.to_string())?;
278        self.handle_response(response).await
279    }
280
281    /// Create a new rulebook
282    pub async fn create_rulebook(
283        &self,
284        input: &CreateRuleBookInput,
285    ) -> Result<CreateRuleBookResponse, String> {
286        let url = format!("{}/v1/rules", self.base_url);
287        let response = self
288            .client
289            .post(&url)
290            .json(input)
291            .send()
292            .await
293            .map_err(|e| e.to_string())?;
294        self.handle_response(response).await
295    }
296
297    /// Delete a rulebook
298    pub async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
299        let encoded_uri = urlencoding::encode(uri);
300        let url = format!("{}/v1/rules/{}", self.base_url, encoded_uri);
301        let response = self
302            .client
303            .delete(&url)
304            .send()
305            .await
306            .map_err(|e| e.to_string())?;
307        self.handle_response_no_body(response).await
308    }
309
310    // =========================================================================
311    // MCP Tool APIs
312    // =========================================================================
313
314    /// Search documentation
315    pub async fn search_docs(&self, req: &SearchDocsRequest) -> Result<Vec<Content>, String> {
316        self.call_mcp_tool(&ToolsCallParams {
317            name: "search_docs".to_string(),
318            arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
319        })
320        .await
321    }
322
323    /// Search memory
324    pub async fn search_memory(&self, req: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
325        self.call_mcp_tool(&ToolsCallParams {
326            name: "search_memory".to_string(),
327            arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
328        })
329        .await
330    }
331
332    /// Memorize a session checkpoint (extract memory)
333    pub async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
334        let url = format!(
335            "{}/v1/agents/sessions/checkpoints/{}/extract-memory",
336            self.base_url, checkpoint_id
337        );
338        let response = self
339            .client
340            .post(&url)
341            .send()
342            .await
343            .map_err(|e| e.to_string())?;
344        self.handle_response_no_body(response).await
345    }
346
347    /// Read Slack messages from a channel
348    pub async fn slack_read_messages(
349        &self,
350        req: &SlackReadMessagesRequest,
351    ) -> Result<Vec<Content>, String> {
352        self.call_mcp_tool(&ToolsCallParams {
353            name: "slack_read_messages".to_string(),
354            arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
355        })
356        .await
357    }
358
359    /// Read Slack thread replies
360    pub async fn slack_read_replies(
361        &self,
362        req: &SlackReadRepliesRequest,
363    ) -> Result<Vec<Content>, String> {
364        self.call_mcp_tool(&ToolsCallParams {
365            name: "slack_read_replies".to_string(),
366            arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
367        })
368        .await
369    }
370
371    /// Send a Slack message
372    pub async fn slack_send_message(
373        &self,
374        req: &SlackSendMessageRequest,
375    ) -> Result<Vec<Content>, String> {
376        self.call_mcp_tool(&ToolsCallParams {
377            name: "slack_send_message".to_string(),
378            arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
379        })
380        .await
381    }
382
383    // =========================================================================
384    // Helper Methods
385    // =========================================================================
386
387    /// Call an MCP tool via JSON-RPC
388    async fn call_mcp_tool(&self, params: &ToolsCallParams) -> Result<Vec<Content>, String> {
389        let url = format!("{}/v1/mcp", self.base_url);
390        let body = json!({
391            "jsonrpc": "2.0",
392            "id": 1,
393            "method": "tools/call",
394            "params": params
395        });
396
397        let response = self
398            .client
399            .post(&url)
400            .json(&body)
401            .send()
402            .await
403            .map_err(|e| e.to_string())?;
404
405        let resp: Value = self.handle_response(response).await?;
406
407        // Extract result.content from JSON-RPC response
408        if let Some(result) = resp.get("result")
409            && let Some(content) = result.get("content")
410        {
411            let content: Vec<Content> =
412                serde_json::from_value(content.clone()).map_err(|e| e.to_string())?;
413            return Ok(content);
414        }
415
416        // Check for error
417        if let Some(error) = resp.get("error") {
418            let msg = error
419                .get("message")
420                .and_then(|m| m.as_str())
421                .unwrap_or("Unknown error");
422            return Err(msg.to_string());
423        }
424
425        Err("Invalid MCP response format".to_string())
426    }
427
428    /// Handle response and parse JSON
429    async fn handle_response<T: DeserializeOwned>(&self, response: Response) -> Result<T, String> {
430        let response = self.handle_response_error(response).await?;
431        let url = response.url().to_string();
432        let status = response.status();
433        let body = response.text().await.map_err(|e| {
434            format!(
435                "Failed to read response body from {} (status {}): {}",
436                url, status, e
437            )
438        })?;
439        serde_json::from_str(&body).map_err(|e| {
440            // Truncate body to avoid flooding the error message
441            let truncated_body: String = body.chars().take(500).collect();
442            format!(
443                "Failed to decode response from {} (status {}): {} | body: {}",
444                url, status, e, truncated_body
445            )
446        })
447    }
448
449    /// Handle response without body
450    async fn handle_response_no_body(&self, response: Response) -> Result<(), String> {
451        self.handle_response_error(response).await?;
452        Ok(())
453    }
454
455    /// Handle response errors
456    async fn handle_response_error(&self, response: Response) -> Result<Response, String> {
457        if response.status().is_success() {
458            return Ok(response);
459        }
460
461        let status = response.status();
462        let error_body = response.text().await.unwrap_or_default();
463
464        // Try to parse as API error
465        if let Ok(api_error) = serde_json::from_str::<ApiError>(&error_body) {
466            // Special handling for API limit exceeded
467            if api_error.error.key == "EXCEEDED_API_LIMIT" {
468                return Err(format!(
469                    "{}. You can top up your billing at https://stakpak.dev/settings/billing",
470                    api_error.error.message
471                ));
472            }
473            return Err(api_error.error.message);
474        }
475
476        Err(format!("API error {}: {}", status, error_body))
477    }
478}
479
480// =============================================================================
481// Builder helpers for creating sessions and checkpoints
482// =============================================================================
483
484impl CreateSessionRequest {
485    /// Create a new session request with initial state
486    pub fn new(title: impl Into<String>, state: CheckpointState) -> Self {
487        Self {
488            title: title.into(),
489            visibility: Some(SessionVisibility::Private),
490            cwd: None,
491            state,
492        }
493    }
494
495    /// Set the working directory
496    pub fn with_cwd(mut self, cwd: impl Into<String>) -> Self {
497        self.cwd = Some(cwd.into());
498        self
499    }
500
501    /// Set visibility
502    pub fn with_visibility(mut self, visibility: SessionVisibility) -> Self {
503        self.visibility = Some(visibility);
504        self
505    }
506}
507
508impl CreateCheckpointRequest {
509    /// Create a new checkpoint request
510    pub fn new(state: CheckpointState) -> Self {
511        Self {
512            state,
513            parent_id: None,
514        }
515    }
516
517    /// Set the parent checkpoint ID (for branching)
518    pub fn with_parent(mut self, parent_id: Uuid) -> Self {
519        self.parent_id = Some(parent_id);
520        self
521    }
522}