Skip to main content

stakpak_api/stakpak/
storage.rs

//! Stakpak API storage implementation.
//!
//! Implements `SessionStorage` on top of Stakpak's `/v1/sessions` API.
4
5use crate::stakpak::{self as stakpak_api, StakpakApiClient, StakpakApiConfig};
6use crate::storage::{
7    BackendInfo, Checkpoint, CheckpointState, CheckpointSummary, CreateCheckpointRequest,
8    CreateSessionRequest, CreateSessionResult, ListCheckpointsQuery, ListCheckpointsResult,
9    ListSessionsQuery, ListSessionsResult, Session, SessionStatus, SessionStorage, SessionSummary,
10    SessionVisibility, StorageError, UpdateSessionRequest,
11};
12use async_trait::async_trait;
13use uuid::Uuid;
14
15/// Stakpak API storage implementation
16#[derive(Clone)]
17pub struct StakpakStorage {
18    client: StakpakApiClient,
19    backend_info: BackendInfo,
20}
21
22impl StakpakStorage {
23    /// Create a new Stakpak storage client
24    pub fn new(api_key: &str, api_endpoint: &str) -> Result<Self, StorageError> {
25        Self::new_with_profile(api_key, api_endpoint, None)
26    }
27
28    pub fn new_with_profile(
29        api_key: &str,
30        api_endpoint: &str,
31        profile: Option<String>,
32    ) -> Result<Self, StorageError> {
33        let client = StakpakApiClient::new(&StakpakApiConfig {
34            api_key: api_key.to_string(),
35            api_endpoint: api_endpoint.to_string(),
36        })
37        .map_err(StorageError::Connection)?;
38
39        Ok(Self {
40            client,
41            backend_info: BackendInfo::stakpak_api(profile, api_endpoint.to_string()),
42        })
43    }
44
45    /// Get the underlying API client
46    pub fn client(&self) -> &StakpakApiClient {
47        &self.client
48    }
49}
50
51#[async_trait]
52impl SessionStorage for StakpakStorage {
53    fn backend_info(&self) -> BackendInfo {
54        self.backend_info.clone()
55    }
56
57    async fn list_sessions(
58        &self,
59        query: &ListSessionsQuery,
60    ) -> Result<ListSessionsResult, StorageError> {
61        let api_query = stakpak_api::ListSessionsQuery {
62            limit: query.limit,
63            offset: query.offset,
64            search: query.search.clone(),
65            status: query.status.map(|s| match s {
66                SessionStatus::Active => "ACTIVE".to_string(),
67                SessionStatus::Deleted => "DELETED".to_string(),
68            }),
69            visibility: query.visibility.map(|v| match v {
70                SessionVisibility::Private => "PRIVATE".to_string(),
71                SessionVisibility::Public => "PUBLIC".to_string(),
72            }),
73        };
74
75        let response = self
76            .client
77            .list_sessions(&api_query)
78            .await
79            .map_err(map_api_error)?;
80
81        Ok(ListSessionsResult {
82            sessions: response
83                .sessions
84                .into_iter()
85                .map(|s| SessionSummary {
86                    id: s.id,
87                    title: s.title,
88                    visibility: match s.visibility {
89                        stakpak_api::SessionVisibility::Private => SessionVisibility::Private,
90                        stakpak_api::SessionVisibility::Public => SessionVisibility::Public,
91                    },
92                    status: match s.status {
93                        stakpak_api::SessionStatus::Active => SessionStatus::Active,
94                        stakpak_api::SessionStatus::Deleted => SessionStatus::Deleted,
95                    },
96                    cwd: s.cwd,
97                    created_at: s.created_at,
98                    updated_at: s.updated_at,
99                    message_count: s.message_count,
100                    active_checkpoint_id: Some(s.active_checkpoint_id),
101                    last_message_at: s.last_message_at,
102                })
103                .collect(),
104            total: None,
105        })
106    }
107
108    async fn get_session(&self, session_id: Uuid) -> Result<Session, StorageError> {
109        let response = self
110            .client
111            .get_session(session_id)
112            .await
113            .map_err(map_api_error)?;
114        let s = response.session;
115
116        Ok(Session {
117            id: s.id,
118            title: s.title,
119            visibility: match s.visibility {
120                stakpak_api::SessionVisibility::Private => SessionVisibility::Private,
121                stakpak_api::SessionVisibility::Public => SessionVisibility::Public,
122            },
123            status: match s.status {
124                stakpak_api::SessionStatus::Active => SessionStatus::Active,
125                stakpak_api::SessionStatus::Deleted => SessionStatus::Deleted,
126            },
127            cwd: s.cwd,
128            created_at: s.created_at,
129            updated_at: s.updated_at,
130            active_checkpoint: s.active_checkpoint.map(|c| Checkpoint {
131                id: c.id,
132                session_id: c.session_id,
133                parent_id: c.parent_id,
134                state: CheckpointState {
135                    messages: c.state.messages,
136                    metadata: c.state.metadata,
137                },
138                created_at: c.created_at,
139                updated_at: c.updated_at,
140            }),
141        })
142    }
143
144    async fn create_session(
145        &self,
146        request: &CreateSessionRequest,
147    ) -> Result<CreateSessionResult, StorageError> {
148        let api_request = stakpak_api::CreateSessionRequest {
149            title: request.title.clone(),
150            visibility: Some(match request.visibility {
151                SessionVisibility::Private => stakpak_api::SessionVisibility::Private,
152                SessionVisibility::Public => stakpak_api::SessionVisibility::Public,
153            }),
154            cwd: request.cwd.clone(),
155            state: stakpak_api::CheckpointState {
156                messages: request.initial_state.messages.clone(),
157                metadata: request.initial_state.metadata.clone(),
158            },
159        };
160
161        let response = self
162            .client
163            .create_session(&api_request)
164            .await
165            .map_err(map_api_error)?;
166
167        Ok(CreateSessionResult {
168            session_id: response.session_id,
169            checkpoint: Checkpoint {
170                id: response.checkpoint.id,
171                session_id: response.checkpoint.session_id,
172                parent_id: response.checkpoint.parent_id,
173                state: CheckpointState {
174                    messages: response.checkpoint.state.messages,
175                    metadata: response.checkpoint.state.metadata,
176                },
177                created_at: response.checkpoint.created_at,
178                updated_at: response.checkpoint.updated_at,
179            },
180        })
181    }
182
183    async fn update_session(
184        &self,
185        session_id: Uuid,
186        request: &UpdateSessionRequest,
187    ) -> Result<Session, StorageError> {
188        let api_request = stakpak_api::UpdateSessionRequest {
189            title: request.title.clone(),
190            visibility: request.visibility.map(|v| match v {
191                SessionVisibility::Private => stakpak_api::SessionVisibility::Private,
192                SessionVisibility::Public => stakpak_api::SessionVisibility::Public,
193            }),
194        };
195
196        let response = self
197            .client
198            .update_session(session_id, &api_request)
199            .await
200            .map_err(map_api_error)?;
201        let s = response.session;
202
203        Ok(Session {
204            id: s.id,
205            title: s.title,
206            visibility: match s.visibility {
207                stakpak_api::SessionVisibility::Private => SessionVisibility::Private,
208                stakpak_api::SessionVisibility::Public => SessionVisibility::Public,
209            },
210            status: match s.status {
211                stakpak_api::SessionStatus::Active => SessionStatus::Active,
212                stakpak_api::SessionStatus::Deleted => SessionStatus::Deleted,
213            },
214            cwd: s.cwd,
215            created_at: s.created_at,
216            updated_at: s.updated_at,
217            active_checkpoint: s.active_checkpoint.map(|c| Checkpoint {
218                id: c.id,
219                session_id: c.session_id,
220                parent_id: c.parent_id,
221                state: CheckpointState {
222                    messages: c.state.messages,
223                    metadata: c.state.metadata,
224                },
225                created_at: c.created_at,
226                updated_at: c.updated_at,
227            }),
228        })
229    }
230
231    async fn delete_session(&self, session_id: Uuid) -> Result<(), StorageError> {
232        self.client
233            .delete_session(session_id)
234            .await
235            .map_err(map_api_error)
236    }
237
238    async fn list_checkpoints(
239        &self,
240        session_id: Uuid,
241        query: &ListCheckpointsQuery,
242    ) -> Result<ListCheckpointsResult, StorageError> {
243        let api_query = stakpak_api::ListCheckpointsQuery {
244            limit: query.limit,
245            offset: query.offset,
246            include_state: query.include_state,
247        };
248
249        let response = self
250            .client
251            .list_checkpoints(session_id, &api_query)
252            .await
253            .map_err(map_api_error)?;
254
255        Ok(ListCheckpointsResult {
256            checkpoints: response
257                .checkpoints
258                .into_iter()
259                .map(|c| CheckpointSummary {
260                    id: c.id,
261                    session_id: c.session_id,
262                    parent_id: c.parent_id,
263                    message_count: c.message_count,
264                    created_at: c.created_at,
265                    updated_at: c.updated_at,
266                })
267                .collect(),
268            total: None,
269        })
270    }
271
272    async fn get_checkpoint(&self, checkpoint_id: Uuid) -> Result<Checkpoint, StorageError> {
273        let response = self
274            .client
275            .get_checkpoint(checkpoint_id)
276            .await
277            .map_err(map_api_error)?;
278        let c = response.checkpoint;
279
280        Ok(Checkpoint {
281            id: c.id,
282            session_id: c.session_id,
283            parent_id: c.parent_id,
284            state: CheckpointState {
285                messages: c.state.messages,
286                metadata: c.state.metadata,
287            },
288            created_at: c.created_at,
289            updated_at: c.updated_at,
290        })
291    }
292
293    async fn create_checkpoint(
294        &self,
295        session_id: Uuid,
296        request: &CreateCheckpointRequest,
297    ) -> Result<Checkpoint, StorageError> {
298        let api_request = stakpak_api::CreateCheckpointRequest {
299            state: stakpak_api::CheckpointState {
300                messages: request.state.messages.clone(),
301                metadata: request.state.metadata.clone(),
302            },
303            parent_id: request.parent_id,
304        };
305
306        let response = self
307            .client
308            .create_checkpoint(session_id, &api_request)
309            .await
310            .map_err(map_api_error)?;
311
312        Ok(Checkpoint {
313            id: response.checkpoint.id,
314            session_id: response.checkpoint.session_id,
315            parent_id: response.checkpoint.parent_id,
316            state: CheckpointState {
317                messages: response.checkpoint.state.messages,
318                metadata: response.checkpoint.state.metadata,
319            },
320            created_at: response.checkpoint.created_at,
321            updated_at: response.checkpoint.updated_at,
322        })
323    }
324}
325
326/// Map API error strings to StorageError
327fn map_api_error(error: String) -> StorageError {
328    if error.contains("not found") || error.contains("Not found") {
329        StorageError::NotFound(error)
330    } else if error.contains("unauthorized")
331        || error.contains("Unauthorized")
332        || error.contains("401")
333    {
334        StorageError::Unauthorized(error)
335    } else if error.contains("rate limit") || error.contains("Rate limit") || error.contains("429")
336    {
337        StorageError::RateLimited(error)
338    } else if error.contains("invalid") || error.contains("Invalid") || error.contains("400") {
339        StorageError::InvalidRequest(error)
340    } else {
341        StorageError::Internal(error)
342    }
343}
344
345#[cfg(test)]
346mod tests {
347    use axum::{Json, Router, extract::Path, routing::get};
348    use chrono::Utc;
349    use serde_json::json;
350    use tokio::net::TcpListener;
351
352    use super::*;
353
354    #[tokio::test]
355    async fn listed_remote_session_id_is_fetchable_via_get_session() {
356        let session_id = Uuid::new_v4();
357        let checkpoint_id = Uuid::new_v4();
358        let now = Utc::now().to_rfc3339();
359        let list_body = json!({
360            "sessions": [
361                {
362                    "id": session_id,
363                    "title": "Round Trip",
364                    "visibility": "PRIVATE",
365                    "status": "ACTIVE",
366                    "cwd": "/tmp/project",
367                    "created_at": now,
368                    "updated_at": now,
369                    "message_count": 1,
370                    "active_checkpoint_id": checkpoint_id,
371                    "last_message_at": now
372                }
373            ]
374        });
375        let show_body = json!({
376            "session": {
377                "id": session_id,
378                "title": "Round Trip",
379                "visibility": "PRIVATE",
380                "status": "ACTIVE",
381                "cwd": "/tmp/project",
382                "created_at": now,
383                "updated_at": now,
384                "deleted_at": null,
385                "active_checkpoint": {
386                    "id": checkpoint_id,
387                    "session_id": session_id,
388                    "parent_id": null,
389                    "state": {
390                        "messages": [
391                            {
392                                "role": "user",
393                                "content": "hi"
394                            }
395                        ],
396                        "metadata": null
397                    },
398                    "created_at": now,
399                    "updated_at": now
400                }
401            }
402        });
403
404        let app = Router::new()
405            .route(
406                "/v1/sessions",
407                get({
408                    let list_body = list_body.clone();
409                    move || {
410                        let list_body = list_body.clone();
411                        async move { Json(list_body) }
412                    }
413                }),
414            )
415            .route(
416                "/v1/sessions/{id}",
417                get({
418                    let show_body = show_body.clone();
419                    move |Path(id): Path<Uuid>| {
420                        let show_body = show_body.clone();
421                        async move {
422                            assert_eq!(id, session_id, "show should request listed session id");
423                            Json(show_body)
424                        }
425                    }
426                }),
427            );
428
429        let listener = TcpListener::bind("127.0.0.1:0")
430            .await
431            .expect("bind test listener");
432        let addr = listener.local_addr().expect("local addr");
433        let server = tokio::spawn(async move {
434            axum::serve(listener, app).await.expect("serve test app");
435        });
436
437        let storage = StakpakStorage::new("test-key", &format!("http://{addr}"))
438            .expect("storage should build");
439        let listed = storage
440            .list_sessions(&ListSessionsQuery::new().with_limit(10))
441            .await
442            .expect("list sessions should succeed");
443        let first_id = listed.sessions.first().expect("session from list").id;
444        let fetched = storage
445            .get_session(first_id)
446            .await
447            .expect("get_session should accept listed id");
448
449        assert_eq!(fetched.id, session_id);
450        assert_eq!(fetched.title, "Round Trip");
451        assert_eq!(fetched.cwd.as_deref(), Some("/tmp/project"));
452
453        server.abort();
454        if let Err(join_err) = server.await
455            && !join_err.is_cancelled()
456        {
457            panic!("server task failed: {join_err}");
458        }
459    }
460}