1use crate::stakpak::{self as stakpak_api, StakpakApiClient, StakpakApiConfig};
6use crate::storage::{
7 Checkpoint, CheckpointState, CheckpointSummary, CreateCheckpointRequest, CreateSessionRequest,
8 CreateSessionResult, ListCheckpointsQuery, ListCheckpointsResult, ListSessionsQuery,
9 ListSessionsResult, Session, SessionStatus, SessionStorage, SessionSummary, SessionVisibility,
10 StorageError, UpdateSessionRequest,
11};
12use async_trait::async_trait;
13use uuid::Uuid;
14
15#[derive(Clone)]
17pub struct StakpakStorage {
18 client: StakpakApiClient,
19}
20
21impl StakpakStorage {
22 pub fn new(api_key: &str, api_endpoint: &str) -> Result<Self, StorageError> {
24 let client = StakpakApiClient::new(&StakpakApiConfig {
25 api_key: api_key.to_string(),
26 api_endpoint: api_endpoint.to_string(),
27 })
28 .map_err(StorageError::Connection)?;
29
30 Ok(Self { client })
31 }
32
33 pub fn client(&self) -> &StakpakApiClient {
35 &self.client
36 }
37}
38
39#[async_trait]
40impl SessionStorage for StakpakStorage {
41 async fn list_sessions(
42 &self,
43 query: &ListSessionsQuery,
44 ) -> Result<ListSessionsResult, StorageError> {
45 let api_query = stakpak_api::ListSessionsQuery {
46 limit: query.limit,
47 offset: query.offset,
48 search: query.search.clone(),
49 status: query.status.map(|s| match s {
50 SessionStatus::Active => "ACTIVE".to_string(),
51 SessionStatus::Deleted => "DELETED".to_string(),
52 }),
53 visibility: query.visibility.map(|v| match v {
54 SessionVisibility::Private => "PRIVATE".to_string(),
55 SessionVisibility::Public => "PUBLIC".to_string(),
56 }),
57 };
58
59 let response = self
60 .client
61 .list_sessions(&api_query)
62 .await
63 .map_err(map_api_error)?;
64
65 Ok(ListSessionsResult {
66 sessions: response
67 .sessions
68 .into_iter()
69 .map(|s| SessionSummary {
70 id: s.id,
71 title: s.title,
72 visibility: match s.visibility {
73 stakpak_api::SessionVisibility::Private => SessionVisibility::Private,
74 stakpak_api::SessionVisibility::Public => SessionVisibility::Public,
75 },
76 status: match s.status {
77 stakpak_api::SessionStatus::Active => SessionStatus::Active,
78 stakpak_api::SessionStatus::Deleted => SessionStatus::Deleted,
79 },
80 cwd: s.cwd,
81 created_at: s.created_at,
82 updated_at: s.updated_at,
83 message_count: s.message_count,
84 active_checkpoint_id: Some(s.active_checkpoint_id),
85 last_message_at: s.last_message_at,
86 })
87 .collect(),
88 total: None,
89 })
90 }
91
92 async fn get_session(&self, session_id: Uuid) -> Result<Session, StorageError> {
93 let response = self
94 .client
95 .get_session(session_id)
96 .await
97 .map_err(map_api_error)?;
98 let s = response.session;
99
100 Ok(Session {
101 id: s.id,
102 title: s.title,
103 visibility: match s.visibility {
104 stakpak_api::SessionVisibility::Private => SessionVisibility::Private,
105 stakpak_api::SessionVisibility::Public => SessionVisibility::Public,
106 },
107 status: match s.status {
108 stakpak_api::SessionStatus::Active => SessionStatus::Active,
109 stakpak_api::SessionStatus::Deleted => SessionStatus::Deleted,
110 },
111 cwd: s.cwd,
112 created_at: s.created_at,
113 updated_at: s.updated_at,
114 active_checkpoint: s.active_checkpoint.map(|c| Checkpoint {
115 id: c.id,
116 session_id: c.session_id,
117 parent_id: c.parent_id,
118 state: CheckpointState {
119 messages: c.state.messages,
120 },
121 created_at: c.created_at,
122 updated_at: c.updated_at,
123 }),
124 })
125 }
126
127 async fn create_session(
128 &self,
129 request: &CreateSessionRequest,
130 ) -> Result<CreateSessionResult, StorageError> {
131 let api_request = stakpak_api::CreateSessionRequest {
132 title: request.title.clone(),
133 visibility: Some(match request.visibility {
134 SessionVisibility::Private => stakpak_api::SessionVisibility::Private,
135 SessionVisibility::Public => stakpak_api::SessionVisibility::Public,
136 }),
137 cwd: request.cwd.clone(),
138 state: stakpak_api::CheckpointState {
139 messages: request.initial_state.messages.clone(),
140 },
141 };
142
143 let response = self
144 .client
145 .create_session(&api_request)
146 .await
147 .map_err(map_api_error)?;
148
149 Ok(CreateSessionResult {
150 session_id: response.session_id,
151 checkpoint: Checkpoint {
152 id: response.checkpoint.id,
153 session_id: response.checkpoint.session_id,
154 parent_id: response.checkpoint.parent_id,
155 state: CheckpointState {
156 messages: response.checkpoint.state.messages,
157 },
158 created_at: response.checkpoint.created_at,
159 updated_at: response.checkpoint.updated_at,
160 },
161 })
162 }
163
164 async fn update_session(
165 &self,
166 session_id: Uuid,
167 request: &UpdateSessionRequest,
168 ) -> Result<Session, StorageError> {
169 let api_request = stakpak_api::UpdateSessionRequest {
170 title: request.title.clone(),
171 visibility: request.visibility.map(|v| match v {
172 SessionVisibility::Private => stakpak_api::SessionVisibility::Private,
173 SessionVisibility::Public => stakpak_api::SessionVisibility::Public,
174 }),
175 };
176
177 let response = self
178 .client
179 .update_session(session_id, &api_request)
180 .await
181 .map_err(map_api_error)?;
182 let s = response.session;
183
184 Ok(Session {
185 id: s.id,
186 title: s.title,
187 visibility: match s.visibility {
188 stakpak_api::SessionVisibility::Private => SessionVisibility::Private,
189 stakpak_api::SessionVisibility::Public => SessionVisibility::Public,
190 },
191 status: match s.status {
192 stakpak_api::SessionStatus::Active => SessionStatus::Active,
193 stakpak_api::SessionStatus::Deleted => SessionStatus::Deleted,
194 },
195 cwd: s.cwd,
196 created_at: s.created_at,
197 updated_at: s.updated_at,
198 active_checkpoint: s.active_checkpoint.map(|c| Checkpoint {
199 id: c.id,
200 session_id: c.session_id,
201 parent_id: c.parent_id,
202 state: CheckpointState {
203 messages: c.state.messages,
204 },
205 created_at: c.created_at,
206 updated_at: c.updated_at,
207 }),
208 })
209 }
210
211 async fn delete_session(&self, session_id: Uuid) -> Result<(), StorageError> {
212 self.client
213 .delete_session(session_id)
214 .await
215 .map_err(map_api_error)
216 }
217
218 async fn list_checkpoints(
219 &self,
220 session_id: Uuid,
221 query: &ListCheckpointsQuery,
222 ) -> Result<ListCheckpointsResult, StorageError> {
223 let api_query = stakpak_api::ListCheckpointsQuery {
224 limit: query.limit,
225 offset: query.offset,
226 include_state: query.include_state,
227 };
228
229 let response = self
230 .client
231 .list_checkpoints(session_id, &api_query)
232 .await
233 .map_err(map_api_error)?;
234
235 Ok(ListCheckpointsResult {
236 checkpoints: response
237 .checkpoints
238 .into_iter()
239 .map(|c| CheckpointSummary {
240 id: c.id,
241 session_id: c.session_id,
242 parent_id: c.parent_id,
243 message_count: c.message_count,
244 created_at: c.created_at,
245 updated_at: c.updated_at,
246 })
247 .collect(),
248 total: None,
249 })
250 }
251
252 async fn get_checkpoint(&self, checkpoint_id: Uuid) -> Result<Checkpoint, StorageError> {
253 let response = self
254 .client
255 .get_checkpoint(checkpoint_id)
256 .await
257 .map_err(map_api_error)?;
258 let c = response.checkpoint;
259
260 Ok(Checkpoint {
261 id: c.id,
262 session_id: c.session_id,
263 parent_id: c.parent_id,
264 state: CheckpointState {
265 messages: c.state.messages,
266 },
267 created_at: c.created_at,
268 updated_at: c.updated_at,
269 })
270 }
271
272 async fn create_checkpoint(
273 &self,
274 session_id: Uuid,
275 request: &CreateCheckpointRequest,
276 ) -> Result<Checkpoint, StorageError> {
277 let api_request = stakpak_api::CreateCheckpointRequest {
278 state: stakpak_api::CheckpointState {
279 messages: request.state.messages.clone(),
280 },
281 parent_id: request.parent_id,
282 };
283
284 let response = self
285 .client
286 .create_checkpoint(session_id, &api_request)
287 .await
288 .map_err(map_api_error)?;
289
290 Ok(Checkpoint {
291 id: response.checkpoint.id,
292 session_id: response.checkpoint.session_id,
293 parent_id: response.checkpoint.parent_id,
294 state: CheckpointState {
295 messages: response.checkpoint.state.messages,
296 },
297 created_at: response.checkpoint.created_at,
298 updated_at: response.checkpoint.updated_at,
299 })
300 }
301}
302
303fn map_api_error(error: String) -> StorageError {
305 if error.contains("not found") || error.contains("Not found") {
306 StorageError::NotFound(error)
307 } else if error.contains("unauthorized")
308 || error.contains("Unauthorized")
309 || error.contains("401")
310 {
311 StorageError::Unauthorized(error)
312 } else if error.contains("rate limit") || error.contains("Rate limit") || error.contains("429")
313 {
314 StorageError::RateLimited(error)
315 } else if error.contains("invalid") || error.contains("Invalid") || error.contains("400") {
316 StorageError::InvalidRequest(error)
317 } else {
318 StorageError::Internal(error)
319 }
320}