1use crate::stakpak::{self as stakpak_api, StakpakApiClient, StakpakApiConfig};
6use crate::storage::{
7 BackendInfo, Checkpoint, CheckpointState, CheckpointSummary, CreateCheckpointRequest,
8 CreateSessionRequest, CreateSessionResult, ListCheckpointsQuery, ListCheckpointsResult,
9 ListSessionsQuery, ListSessionsResult, Session, SessionStatus, SessionStorage, SessionSummary,
10 SessionVisibility, StorageError, UpdateSessionRequest,
11};
12use async_trait::async_trait;
13use uuid::Uuid;
14
/// Session storage backend that performs its operations through a [`StakpakApiClient`].
#[derive(Clone)]
pub struct StakpakStorage {
    /// Client used for every remote call issued by this backend.
    client: StakpakApiClient,
    /// Backend descriptor built once by the constructor and cloned out on request.
    backend_info: BackendInfo,
}
21
22impl StakpakStorage {
23 pub fn new(api_key: &str, api_endpoint: &str) -> Result<Self, StorageError> {
25 Self::new_with_profile(api_key, api_endpoint, None)
26 }
27
28 pub fn new_with_profile(
29 api_key: &str,
30 api_endpoint: &str,
31 profile: Option<String>,
32 ) -> Result<Self, StorageError> {
33 let client = StakpakApiClient::new(&StakpakApiConfig {
34 api_key: api_key.to_string(),
35 api_endpoint: api_endpoint.to_string(),
36 })
37 .map_err(StorageError::Connection)?;
38
39 Ok(Self {
40 client,
41 backend_info: BackendInfo::stakpak_api(profile, api_endpoint.to_string()),
42 })
43 }
44
45 pub fn client(&self) -> &StakpakApiClient {
47 &self.client
48 }
49}
50
51#[async_trait]
52impl SessionStorage for StakpakStorage {
53 fn backend_info(&self) -> BackendInfo {
54 self.backend_info.clone()
55 }
56
57 async fn list_sessions(
58 &self,
59 query: &ListSessionsQuery,
60 ) -> Result<ListSessionsResult, StorageError> {
61 let api_query = stakpak_api::ListSessionsQuery {
62 limit: query.limit,
63 offset: query.offset,
64 search: query.search.clone(),
65 status: query.status.map(|s| match s {
66 SessionStatus::Active => "ACTIVE".to_string(),
67 SessionStatus::Deleted => "DELETED".to_string(),
68 }),
69 visibility: query.visibility.map(|v| match v {
70 SessionVisibility::Private => "PRIVATE".to_string(),
71 SessionVisibility::Public => "PUBLIC".to_string(),
72 }),
73 };
74
75 let response = self
76 .client
77 .list_sessions(&api_query)
78 .await
79 .map_err(map_api_error)?;
80
81 Ok(ListSessionsResult {
82 sessions: response
83 .sessions
84 .into_iter()
85 .map(|s| SessionSummary {
86 id: s.id,
87 title: s.title,
88 visibility: match s.visibility {
89 stakpak_api::SessionVisibility::Private => SessionVisibility::Private,
90 stakpak_api::SessionVisibility::Public => SessionVisibility::Public,
91 },
92 status: match s.status {
93 stakpak_api::SessionStatus::Active => SessionStatus::Active,
94 stakpak_api::SessionStatus::Deleted => SessionStatus::Deleted,
95 },
96 cwd: s.cwd,
97 created_at: s.created_at,
98 updated_at: s.updated_at,
99 message_count: s.message_count,
100 active_checkpoint_id: Some(s.active_checkpoint_id),
101 last_message_at: s.last_message_at,
102 })
103 .collect(),
104 total: None,
105 })
106 }
107
108 async fn get_session(&self, session_id: Uuid) -> Result<Session, StorageError> {
109 let response = self
110 .client
111 .get_session(session_id)
112 .await
113 .map_err(map_api_error)?;
114 let s = response.session;
115
116 Ok(Session {
117 id: s.id,
118 title: s.title,
119 visibility: match s.visibility {
120 stakpak_api::SessionVisibility::Private => SessionVisibility::Private,
121 stakpak_api::SessionVisibility::Public => SessionVisibility::Public,
122 },
123 status: match s.status {
124 stakpak_api::SessionStatus::Active => SessionStatus::Active,
125 stakpak_api::SessionStatus::Deleted => SessionStatus::Deleted,
126 },
127 cwd: s.cwd,
128 created_at: s.created_at,
129 updated_at: s.updated_at,
130 active_checkpoint: s.active_checkpoint.map(|c| Checkpoint {
131 id: c.id,
132 session_id: c.session_id,
133 parent_id: c.parent_id,
134 state: CheckpointState {
135 messages: c.state.messages,
136 metadata: c.state.metadata,
137 },
138 created_at: c.created_at,
139 updated_at: c.updated_at,
140 }),
141 })
142 }
143
144 async fn create_session(
145 &self,
146 request: &CreateSessionRequest,
147 ) -> Result<CreateSessionResult, StorageError> {
148 let api_request = stakpak_api::CreateSessionRequest {
149 title: request.title.clone(),
150 visibility: Some(match request.visibility {
151 SessionVisibility::Private => stakpak_api::SessionVisibility::Private,
152 SessionVisibility::Public => stakpak_api::SessionVisibility::Public,
153 }),
154 cwd: request.cwd.clone(),
155 state: stakpak_api::CheckpointState {
156 messages: request.initial_state.messages.clone(),
157 metadata: request.initial_state.metadata.clone(),
158 },
159 };
160
161 let response = self
162 .client
163 .create_session(&api_request)
164 .await
165 .map_err(map_api_error)?;
166
167 Ok(CreateSessionResult {
168 session_id: response.session_id,
169 checkpoint: Checkpoint {
170 id: response.checkpoint.id,
171 session_id: response.checkpoint.session_id,
172 parent_id: response.checkpoint.parent_id,
173 state: CheckpointState {
174 messages: response.checkpoint.state.messages,
175 metadata: response.checkpoint.state.metadata,
176 },
177 created_at: response.checkpoint.created_at,
178 updated_at: response.checkpoint.updated_at,
179 },
180 })
181 }
182
183 async fn update_session(
184 &self,
185 session_id: Uuid,
186 request: &UpdateSessionRequest,
187 ) -> Result<Session, StorageError> {
188 let api_request = stakpak_api::UpdateSessionRequest {
189 title: request.title.clone(),
190 visibility: request.visibility.map(|v| match v {
191 SessionVisibility::Private => stakpak_api::SessionVisibility::Private,
192 SessionVisibility::Public => stakpak_api::SessionVisibility::Public,
193 }),
194 };
195
196 let response = self
197 .client
198 .update_session(session_id, &api_request)
199 .await
200 .map_err(map_api_error)?;
201 let s = response.session;
202
203 Ok(Session {
204 id: s.id,
205 title: s.title,
206 visibility: match s.visibility {
207 stakpak_api::SessionVisibility::Private => SessionVisibility::Private,
208 stakpak_api::SessionVisibility::Public => SessionVisibility::Public,
209 },
210 status: match s.status {
211 stakpak_api::SessionStatus::Active => SessionStatus::Active,
212 stakpak_api::SessionStatus::Deleted => SessionStatus::Deleted,
213 },
214 cwd: s.cwd,
215 created_at: s.created_at,
216 updated_at: s.updated_at,
217 active_checkpoint: s.active_checkpoint.map(|c| Checkpoint {
218 id: c.id,
219 session_id: c.session_id,
220 parent_id: c.parent_id,
221 state: CheckpointState {
222 messages: c.state.messages,
223 metadata: c.state.metadata,
224 },
225 created_at: c.created_at,
226 updated_at: c.updated_at,
227 }),
228 })
229 }
230
231 async fn delete_session(&self, session_id: Uuid) -> Result<(), StorageError> {
232 self.client
233 .delete_session(session_id)
234 .await
235 .map_err(map_api_error)
236 }
237
238 async fn list_checkpoints(
239 &self,
240 session_id: Uuid,
241 query: &ListCheckpointsQuery,
242 ) -> Result<ListCheckpointsResult, StorageError> {
243 let api_query = stakpak_api::ListCheckpointsQuery {
244 limit: query.limit,
245 offset: query.offset,
246 include_state: query.include_state,
247 };
248
249 let response = self
250 .client
251 .list_checkpoints(session_id, &api_query)
252 .await
253 .map_err(map_api_error)?;
254
255 Ok(ListCheckpointsResult {
256 checkpoints: response
257 .checkpoints
258 .into_iter()
259 .map(|c| CheckpointSummary {
260 id: c.id,
261 session_id: c.session_id,
262 parent_id: c.parent_id,
263 message_count: c.message_count,
264 created_at: c.created_at,
265 updated_at: c.updated_at,
266 })
267 .collect(),
268 total: None,
269 })
270 }
271
272 async fn get_checkpoint(&self, checkpoint_id: Uuid) -> Result<Checkpoint, StorageError> {
273 let response = self
274 .client
275 .get_checkpoint(checkpoint_id)
276 .await
277 .map_err(map_api_error)?;
278 let c = response.checkpoint;
279
280 Ok(Checkpoint {
281 id: c.id,
282 session_id: c.session_id,
283 parent_id: c.parent_id,
284 state: CheckpointState {
285 messages: c.state.messages,
286 metadata: c.state.metadata,
287 },
288 created_at: c.created_at,
289 updated_at: c.updated_at,
290 })
291 }
292
293 async fn create_checkpoint(
294 &self,
295 session_id: Uuid,
296 request: &CreateCheckpointRequest,
297 ) -> Result<Checkpoint, StorageError> {
298 let api_request = stakpak_api::CreateCheckpointRequest {
299 state: stakpak_api::CheckpointState {
300 messages: request.state.messages.clone(),
301 metadata: request.state.metadata.clone(),
302 },
303 parent_id: request.parent_id,
304 };
305
306 let response = self
307 .client
308 .create_checkpoint(session_id, &api_request)
309 .await
310 .map_err(map_api_error)?;
311
312 Ok(Checkpoint {
313 id: response.checkpoint.id,
314 session_id: response.checkpoint.session_id,
315 parent_id: response.checkpoint.parent_id,
316 state: CheckpointState {
317 messages: response.checkpoint.state.messages,
318 metadata: response.checkpoint.state.metadata,
319 },
320 created_at: response.checkpoint.created_at,
321 updated_at: response.checkpoint.updated_at,
322 })
323 }
324}
325
326fn map_api_error(error: String) -> StorageError {
328 if error.contains("not found") || error.contains("Not found") {
329 StorageError::NotFound(error)
330 } else if error.contains("unauthorized")
331 || error.contains("Unauthorized")
332 || error.contains("401")
333 {
334 StorageError::Unauthorized(error)
335 } else if error.contains("rate limit") || error.contains("Rate limit") || error.contains("429")
336 {
337 StorageError::RateLimited(error)
338 } else if error.contains("invalid") || error.contains("Invalid") || error.contains("400") {
339 StorageError::InvalidRequest(error)
340 } else {
341 StorageError::Internal(error)
342 }
343}
344
#[cfg(test)]
mod tests {
    use axum::{Json, Router, extract::Path, routing::get};
    use chrono::Utc;
    use serde_json::json;
    use tokio::net::TcpListener;

    use super::*;

    /// Serves canned list/show responses from a local HTTP server and checks that
    /// the id returned by `list_sessions` can be passed straight to `get_session`.
    #[tokio::test]
    async fn listed_remote_session_id_is_fetchable_via_get_session() {
        let session_id = Uuid::new_v4();
        let checkpoint_id = Uuid::new_v4();
        let now = Utc::now().to_rfc3339();

        // Canned body for the list endpoint: a single session summary.
        let list_body = json!({
            "sessions": [
                {
                    "id": session_id,
                    "title": "Round Trip",
                    "visibility": "PRIVATE",
                    "status": "ACTIVE",
                    "cwd": "/tmp/project",
                    "created_at": now,
                    "updated_at": now,
                    "message_count": 1,
                    "active_checkpoint_id": checkpoint_id,
                    "last_message_at": now
                }
            ]
        });

        // Canned body for the show endpoint: the same session with its checkpoint.
        let checkpoint_json = json!({
            "id": checkpoint_id,
            "session_id": session_id,
            "parent_id": null,
            "state": {
                "messages": [
                    {
                        "role": "user",
                        "content": "hi"
                    }
                ],
                "metadata": null
            },
            "created_at": now,
            "updated_at": now
        });
        let show_body = json!({
            "session": {
                "id": session_id,
                "title": "Round Trip",
                "visibility": "PRIVATE",
                "status": "ACTIVE",
                "cwd": "/tmp/project",
                "created_at": now,
                "updated_at": now,
                "deleted_at": null,
                "active_checkpoint": checkpoint_json
            }
        });

        let list_handler = move || {
            let body = list_body.clone();
            async move { Json(body) }
        };
        let show_handler = move |Path(id): Path<Uuid>| {
            let body = show_body.clone();
            async move {
                assert_eq!(id, session_id, "show should request listed session id");
                Json(body)
            }
        };

        let app = Router::new()
            .route("/v1/sessions", get(list_handler))
            .route("/v1/sessions/{id}", get(show_handler));

        // Port 0 lets the OS pick a free port; the real address is read back below.
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind test listener");
        let addr = listener.local_addr().expect("local addr");
        let server = tokio::spawn(async move {
            axum::serve(listener, app).await.expect("serve test app");
        });

        let endpoint = format!("http://{addr}");
        let storage = StakpakStorage::new("test-key", &endpoint).expect("storage should build");

        let list_query = ListSessionsQuery::new().with_limit(10);
        let listed = storage
            .list_sessions(&list_query)
            .await
            .expect("list sessions should succeed");
        let first_id = listed.sessions.first().expect("session from list").id;

        let fetched = storage
            .get_session(first_id)
            .await
            .expect("get_session should accept listed id");

        assert_eq!(fetched.id, session_id);
        assert_eq!(fetched.title, "Round Trip");
        assert_eq!(fetched.cwd.as_deref(), Some("/tmp/project"));

        // Stop the server; cancellation is the expected way for the task to end,
        // any other join error (e.g. a handler panic) fails the test.
        server.abort();
        match server.await {
            Err(join_err) if !join_err.is_cancelled() => {
                panic!("server task failed: {join_err}");
            }
            _ => {}
        }
    }
}