Skip to main content

stakpak_api/stakpak/
storage.rs

1//! Stakpak API storage implementation
2//!
3//! Implements SessionStorage using Stakpak's /v1/sessions API.
4
5use crate::stakpak::{self as stakpak_api, StakpakApiClient, StakpakApiConfig};
6use crate::storage::{
7    Checkpoint, CheckpointState, CheckpointSummary, CreateCheckpointRequest, CreateSessionRequest,
8    CreateSessionResult, ListCheckpointsQuery, ListCheckpointsResult, ListSessionsQuery,
9    ListSessionsResult, Session, SessionStatus, SessionStorage, SessionSummary, SessionVisibility,
10    StorageError, UpdateSessionRequest,
11};
12use async_trait::async_trait;
13use uuid::Uuid;
14
15/// Stakpak API storage implementation
16#[derive(Clone)]
17pub struct StakpakStorage {
18    client: StakpakApiClient,
19}
20
21impl StakpakStorage {
22    /// Create a new Stakpak storage client
23    pub fn new(api_key: &str, api_endpoint: &str) -> Result<Self, StorageError> {
24        let client = StakpakApiClient::new(&StakpakApiConfig {
25            api_key: api_key.to_string(),
26            api_endpoint: api_endpoint.to_string(),
27        })
28        .map_err(StorageError::Connection)?;
29
30        Ok(Self { client })
31    }
32
33    /// Get the underlying API client
34    pub fn client(&self) -> &StakpakApiClient {
35        &self.client
36    }
37}
38
39#[async_trait]
40impl SessionStorage for StakpakStorage {
41    async fn list_sessions(
42        &self,
43        query: &ListSessionsQuery,
44    ) -> Result<ListSessionsResult, StorageError> {
45        let api_query = stakpak_api::ListSessionsQuery {
46            limit: query.limit,
47            offset: query.offset,
48            search: query.search.clone(),
49            status: query.status.map(|s| match s {
50                SessionStatus::Active => "ACTIVE".to_string(),
51                SessionStatus::Deleted => "DELETED".to_string(),
52            }),
53            visibility: query.visibility.map(|v| match v {
54                SessionVisibility::Private => "PRIVATE".to_string(),
55                SessionVisibility::Public => "PUBLIC".to_string(),
56            }),
57        };
58
59        let response = self
60            .client
61            .list_sessions(&api_query)
62            .await
63            .map_err(map_api_error)?;
64
65        Ok(ListSessionsResult {
66            sessions: response
67                .sessions
68                .into_iter()
69                .map(|s| SessionSummary {
70                    id: s.id,
71                    title: s.title,
72                    visibility: match s.visibility {
73                        stakpak_api::SessionVisibility::Private => SessionVisibility::Private,
74                        stakpak_api::SessionVisibility::Public => SessionVisibility::Public,
75                    },
76                    status: match s.status {
77                        stakpak_api::SessionStatus::Active => SessionStatus::Active,
78                        stakpak_api::SessionStatus::Deleted => SessionStatus::Deleted,
79                    },
80                    cwd: s.cwd,
81                    created_at: s.created_at,
82                    updated_at: s.updated_at,
83                    message_count: s.message_count,
84                    active_checkpoint_id: Some(s.active_checkpoint_id),
85                    last_message_at: s.last_message_at,
86                })
87                .collect(),
88            total: None,
89        })
90    }
91
92    async fn get_session(&self, session_id: Uuid) -> Result<Session, StorageError> {
93        let response = self
94            .client
95            .get_session(session_id)
96            .await
97            .map_err(map_api_error)?;
98        let s = response.session;
99
100        Ok(Session {
101            id: s.id,
102            title: s.title,
103            visibility: match s.visibility {
104                stakpak_api::SessionVisibility::Private => SessionVisibility::Private,
105                stakpak_api::SessionVisibility::Public => SessionVisibility::Public,
106            },
107            status: match s.status {
108                stakpak_api::SessionStatus::Active => SessionStatus::Active,
109                stakpak_api::SessionStatus::Deleted => SessionStatus::Deleted,
110            },
111            cwd: s.cwd,
112            created_at: s.created_at,
113            updated_at: s.updated_at,
114            active_checkpoint: s.active_checkpoint.map(|c| Checkpoint {
115                id: c.id,
116                session_id: c.session_id,
117                parent_id: c.parent_id,
118                state: CheckpointState {
119                    messages: c.state.messages,
120                    metadata: c.state.metadata,
121                },
122                created_at: c.created_at,
123                updated_at: c.updated_at,
124            }),
125        })
126    }
127
128    async fn create_session(
129        &self,
130        request: &CreateSessionRequest,
131    ) -> Result<CreateSessionResult, StorageError> {
132        let api_request = stakpak_api::CreateSessionRequest {
133            title: request.title.clone(),
134            visibility: Some(match request.visibility {
135                SessionVisibility::Private => stakpak_api::SessionVisibility::Private,
136                SessionVisibility::Public => stakpak_api::SessionVisibility::Public,
137            }),
138            cwd: request.cwd.clone(),
139            state: stakpak_api::CheckpointState {
140                messages: request.initial_state.messages.clone(),
141                metadata: request.initial_state.metadata.clone(),
142            },
143        };
144
145        let response = self
146            .client
147            .create_session(&api_request)
148            .await
149            .map_err(map_api_error)?;
150
151        Ok(CreateSessionResult {
152            session_id: response.session_id,
153            checkpoint: Checkpoint {
154                id: response.checkpoint.id,
155                session_id: response.checkpoint.session_id,
156                parent_id: response.checkpoint.parent_id,
157                state: CheckpointState {
158                    messages: response.checkpoint.state.messages,
159                    metadata: response.checkpoint.state.metadata,
160                },
161                created_at: response.checkpoint.created_at,
162                updated_at: response.checkpoint.updated_at,
163            },
164        })
165    }
166
167    async fn update_session(
168        &self,
169        session_id: Uuid,
170        request: &UpdateSessionRequest,
171    ) -> Result<Session, StorageError> {
172        let api_request = stakpak_api::UpdateSessionRequest {
173            title: request.title.clone(),
174            visibility: request.visibility.map(|v| match v {
175                SessionVisibility::Private => stakpak_api::SessionVisibility::Private,
176                SessionVisibility::Public => stakpak_api::SessionVisibility::Public,
177            }),
178        };
179
180        let response = self
181            .client
182            .update_session(session_id, &api_request)
183            .await
184            .map_err(map_api_error)?;
185        let s = response.session;
186
187        Ok(Session {
188            id: s.id,
189            title: s.title,
190            visibility: match s.visibility {
191                stakpak_api::SessionVisibility::Private => SessionVisibility::Private,
192                stakpak_api::SessionVisibility::Public => SessionVisibility::Public,
193            },
194            status: match s.status {
195                stakpak_api::SessionStatus::Active => SessionStatus::Active,
196                stakpak_api::SessionStatus::Deleted => SessionStatus::Deleted,
197            },
198            cwd: s.cwd,
199            created_at: s.created_at,
200            updated_at: s.updated_at,
201            active_checkpoint: s.active_checkpoint.map(|c| Checkpoint {
202                id: c.id,
203                session_id: c.session_id,
204                parent_id: c.parent_id,
205                state: CheckpointState {
206                    messages: c.state.messages,
207                    metadata: c.state.metadata,
208                },
209                created_at: c.created_at,
210                updated_at: c.updated_at,
211            }),
212        })
213    }
214
215    async fn delete_session(&self, session_id: Uuid) -> Result<(), StorageError> {
216        self.client
217            .delete_session(session_id)
218            .await
219            .map_err(map_api_error)
220    }
221
222    async fn list_checkpoints(
223        &self,
224        session_id: Uuid,
225        query: &ListCheckpointsQuery,
226    ) -> Result<ListCheckpointsResult, StorageError> {
227        let api_query = stakpak_api::ListCheckpointsQuery {
228            limit: query.limit,
229            offset: query.offset,
230            include_state: query.include_state,
231        };
232
233        let response = self
234            .client
235            .list_checkpoints(session_id, &api_query)
236            .await
237            .map_err(map_api_error)?;
238
239        Ok(ListCheckpointsResult {
240            checkpoints: response
241                .checkpoints
242                .into_iter()
243                .map(|c| CheckpointSummary {
244                    id: c.id,
245                    session_id: c.session_id,
246                    parent_id: c.parent_id,
247                    message_count: c.message_count,
248                    created_at: c.created_at,
249                    updated_at: c.updated_at,
250                })
251                .collect(),
252            total: None,
253        })
254    }
255
256    async fn get_checkpoint(&self, checkpoint_id: Uuid) -> Result<Checkpoint, StorageError> {
257        let response = self
258            .client
259            .get_checkpoint(checkpoint_id)
260            .await
261            .map_err(map_api_error)?;
262        let c = response.checkpoint;
263
264        Ok(Checkpoint {
265            id: c.id,
266            session_id: c.session_id,
267            parent_id: c.parent_id,
268            state: CheckpointState {
269                messages: c.state.messages,
270                metadata: c.state.metadata,
271            },
272            created_at: c.created_at,
273            updated_at: c.updated_at,
274        })
275    }
276
277    async fn create_checkpoint(
278        &self,
279        session_id: Uuid,
280        request: &CreateCheckpointRequest,
281    ) -> Result<Checkpoint, StorageError> {
282        let api_request = stakpak_api::CreateCheckpointRequest {
283            state: stakpak_api::CheckpointState {
284                messages: request.state.messages.clone(),
285                metadata: request.state.metadata.clone(),
286            },
287            parent_id: request.parent_id,
288        };
289
290        let response = self
291            .client
292            .create_checkpoint(session_id, &api_request)
293            .await
294            .map_err(map_api_error)?;
295
296        Ok(Checkpoint {
297            id: response.checkpoint.id,
298            session_id: response.checkpoint.session_id,
299            parent_id: response.checkpoint.parent_id,
300            state: CheckpointState {
301                messages: response.checkpoint.state.messages,
302                metadata: response.checkpoint.state.metadata,
303            },
304            created_at: response.checkpoint.created_at,
305            updated_at: response.checkpoint.updated_at,
306        })
307    }
308}
309
310/// Map API error strings to StorageError
311fn map_api_error(error: String) -> StorageError {
312    if error.contains("not found") || error.contains("Not found") {
313        StorageError::NotFound(error)
314    } else if error.contains("unauthorized")
315        || error.contains("Unauthorized")
316        || error.contains("401")
317    {
318        StorageError::Unauthorized(error)
319    } else if error.contains("rate limit") || error.contains("Rate limit") || error.contains("429")
320    {
321        StorageError::RateLimited(error)
322    } else if error.contains("invalid") || error.contains("Invalid") || error.contains("400") {
323        StorageError::InvalidRequest(error)
324    } else {
325        StorageError::Internal(error)
326    }
327}