1use crate::stakpak::{self as stakpak_api, StakpakApiClient, StakpakApiConfig};
6use crate::storage::{
7 Checkpoint, CheckpointState, CheckpointSummary, CreateCheckpointRequest, CreateSessionRequest,
8 CreateSessionResult, ListCheckpointsQuery, ListCheckpointsResult, ListSessionsQuery,
9 ListSessionsResult, Session, SessionStatus, SessionStorage, SessionSummary, SessionVisibility,
10 StorageError, UpdateSessionRequest,
11};
12use async_trait::async_trait;
13use uuid::Uuid;
14
15#[derive(Clone)]
17pub struct StakpakStorage {
18 client: StakpakApiClient,
19}
20
21impl StakpakStorage {
22 pub fn new(api_key: &str, api_endpoint: &str) -> Result<Self, StorageError> {
24 let client = StakpakApiClient::new(&StakpakApiConfig {
25 api_key: api_key.to_string(),
26 api_endpoint: api_endpoint.to_string(),
27 })
28 .map_err(StorageError::Connection)?;
29
30 Ok(Self { client })
31 }
32
33 pub fn client(&self) -> &StakpakApiClient {
35 &self.client
36 }
37}
38
39#[async_trait]
40impl SessionStorage for StakpakStorage {
41 async fn list_sessions(
42 &self,
43 query: &ListSessionsQuery,
44 ) -> Result<ListSessionsResult, StorageError> {
45 let api_query = stakpak_api::ListSessionsQuery {
46 limit: query.limit,
47 offset: query.offset,
48 search: query.search.clone(),
49 status: query.status.map(|s| match s {
50 SessionStatus::Active => "ACTIVE".to_string(),
51 SessionStatus::Deleted => "DELETED".to_string(),
52 }),
53 visibility: query.visibility.map(|v| match v {
54 SessionVisibility::Private => "PRIVATE".to_string(),
55 SessionVisibility::Public => "PUBLIC".to_string(),
56 }),
57 };
58
59 let response = self
60 .client
61 .list_sessions(&api_query)
62 .await
63 .map_err(map_api_error)?;
64
65 Ok(ListSessionsResult {
66 sessions: response
67 .sessions
68 .into_iter()
69 .map(|s| SessionSummary {
70 id: s.id,
71 title: s.title,
72 visibility: match s.visibility {
73 stakpak_api::SessionVisibility::Private => SessionVisibility::Private,
74 stakpak_api::SessionVisibility::Public => SessionVisibility::Public,
75 },
76 status: match s.status {
77 stakpak_api::SessionStatus::Active => SessionStatus::Active,
78 stakpak_api::SessionStatus::Deleted => SessionStatus::Deleted,
79 },
80 cwd: s.cwd,
81 created_at: s.created_at,
82 updated_at: s.updated_at,
83 message_count: s.message_count,
84 active_checkpoint_id: Some(s.active_checkpoint_id),
85 last_message_at: s.last_message_at,
86 })
87 .collect(),
88 total: None,
89 })
90 }
91
92 async fn get_session(&self, session_id: Uuid) -> Result<Session, StorageError> {
93 let response = self
94 .client
95 .get_session(session_id)
96 .await
97 .map_err(map_api_error)?;
98 let s = response.session;
99
100 Ok(Session {
101 id: s.id,
102 title: s.title,
103 visibility: match s.visibility {
104 stakpak_api::SessionVisibility::Private => SessionVisibility::Private,
105 stakpak_api::SessionVisibility::Public => SessionVisibility::Public,
106 },
107 status: match s.status {
108 stakpak_api::SessionStatus::Active => SessionStatus::Active,
109 stakpak_api::SessionStatus::Deleted => SessionStatus::Deleted,
110 },
111 cwd: s.cwd,
112 created_at: s.created_at,
113 updated_at: s.updated_at,
114 active_checkpoint: s.active_checkpoint.map(|c| Checkpoint {
115 id: c.id,
116 session_id: c.session_id,
117 parent_id: c.parent_id,
118 state: CheckpointState {
119 messages: c.state.messages,
120 metadata: c.state.metadata,
121 },
122 created_at: c.created_at,
123 updated_at: c.updated_at,
124 }),
125 })
126 }
127
128 async fn create_session(
129 &self,
130 request: &CreateSessionRequest,
131 ) -> Result<CreateSessionResult, StorageError> {
132 let api_request = stakpak_api::CreateSessionRequest {
133 title: request.title.clone(),
134 visibility: Some(match request.visibility {
135 SessionVisibility::Private => stakpak_api::SessionVisibility::Private,
136 SessionVisibility::Public => stakpak_api::SessionVisibility::Public,
137 }),
138 cwd: request.cwd.clone(),
139 state: stakpak_api::CheckpointState {
140 messages: request.initial_state.messages.clone(),
141 metadata: request.initial_state.metadata.clone(),
142 },
143 };
144
145 let response = self
146 .client
147 .create_session(&api_request)
148 .await
149 .map_err(map_api_error)?;
150
151 Ok(CreateSessionResult {
152 session_id: response.session_id,
153 checkpoint: Checkpoint {
154 id: response.checkpoint.id,
155 session_id: response.checkpoint.session_id,
156 parent_id: response.checkpoint.parent_id,
157 state: CheckpointState {
158 messages: response.checkpoint.state.messages,
159 metadata: response.checkpoint.state.metadata,
160 },
161 created_at: response.checkpoint.created_at,
162 updated_at: response.checkpoint.updated_at,
163 },
164 })
165 }
166
167 async fn update_session(
168 &self,
169 session_id: Uuid,
170 request: &UpdateSessionRequest,
171 ) -> Result<Session, StorageError> {
172 let api_request = stakpak_api::UpdateSessionRequest {
173 title: request.title.clone(),
174 visibility: request.visibility.map(|v| match v {
175 SessionVisibility::Private => stakpak_api::SessionVisibility::Private,
176 SessionVisibility::Public => stakpak_api::SessionVisibility::Public,
177 }),
178 };
179
180 let response = self
181 .client
182 .update_session(session_id, &api_request)
183 .await
184 .map_err(map_api_error)?;
185 let s = response.session;
186
187 Ok(Session {
188 id: s.id,
189 title: s.title,
190 visibility: match s.visibility {
191 stakpak_api::SessionVisibility::Private => SessionVisibility::Private,
192 stakpak_api::SessionVisibility::Public => SessionVisibility::Public,
193 },
194 status: match s.status {
195 stakpak_api::SessionStatus::Active => SessionStatus::Active,
196 stakpak_api::SessionStatus::Deleted => SessionStatus::Deleted,
197 },
198 cwd: s.cwd,
199 created_at: s.created_at,
200 updated_at: s.updated_at,
201 active_checkpoint: s.active_checkpoint.map(|c| Checkpoint {
202 id: c.id,
203 session_id: c.session_id,
204 parent_id: c.parent_id,
205 state: CheckpointState {
206 messages: c.state.messages,
207 metadata: c.state.metadata,
208 },
209 created_at: c.created_at,
210 updated_at: c.updated_at,
211 }),
212 })
213 }
214
215 async fn delete_session(&self, session_id: Uuid) -> Result<(), StorageError> {
216 self.client
217 .delete_session(session_id)
218 .await
219 .map_err(map_api_error)
220 }
221
222 async fn list_checkpoints(
223 &self,
224 session_id: Uuid,
225 query: &ListCheckpointsQuery,
226 ) -> Result<ListCheckpointsResult, StorageError> {
227 let api_query = stakpak_api::ListCheckpointsQuery {
228 limit: query.limit,
229 offset: query.offset,
230 include_state: query.include_state,
231 };
232
233 let response = self
234 .client
235 .list_checkpoints(session_id, &api_query)
236 .await
237 .map_err(map_api_error)?;
238
239 Ok(ListCheckpointsResult {
240 checkpoints: response
241 .checkpoints
242 .into_iter()
243 .map(|c| CheckpointSummary {
244 id: c.id,
245 session_id: c.session_id,
246 parent_id: c.parent_id,
247 message_count: c.message_count,
248 created_at: c.created_at,
249 updated_at: c.updated_at,
250 })
251 .collect(),
252 total: None,
253 })
254 }
255
256 async fn get_checkpoint(&self, checkpoint_id: Uuid) -> Result<Checkpoint, StorageError> {
257 let response = self
258 .client
259 .get_checkpoint(checkpoint_id)
260 .await
261 .map_err(map_api_error)?;
262 let c = response.checkpoint;
263
264 Ok(Checkpoint {
265 id: c.id,
266 session_id: c.session_id,
267 parent_id: c.parent_id,
268 state: CheckpointState {
269 messages: c.state.messages,
270 metadata: c.state.metadata,
271 },
272 created_at: c.created_at,
273 updated_at: c.updated_at,
274 })
275 }
276
277 async fn create_checkpoint(
278 &self,
279 session_id: Uuid,
280 request: &CreateCheckpointRequest,
281 ) -> Result<Checkpoint, StorageError> {
282 let api_request = stakpak_api::CreateCheckpointRequest {
283 state: stakpak_api::CheckpointState {
284 messages: request.state.messages.clone(),
285 metadata: request.state.metadata.clone(),
286 },
287 parent_id: request.parent_id,
288 };
289
290 let response = self
291 .client
292 .create_checkpoint(session_id, &api_request)
293 .await
294 .map_err(map_api_error)?;
295
296 Ok(Checkpoint {
297 id: response.checkpoint.id,
298 session_id: response.checkpoint.session_id,
299 parent_id: response.checkpoint.parent_id,
300 state: CheckpointState {
301 messages: response.checkpoint.state.messages,
302 metadata: response.checkpoint.state.metadata,
303 },
304 created_at: response.checkpoint.created_at,
305 updated_at: response.checkpoint.updated_at,
306 })
307 }
308}
309
310fn map_api_error(error: String) -> StorageError {
312 if error.contains("not found") || error.contains("Not found") {
313 StorageError::NotFound(error)
314 } else if error.contains("unauthorized")
315 || error.contains("Unauthorized")
316 || error.contains("401")
317 {
318 StorageError::Unauthorized(error)
319 } else if error.contains("rate limit") || error.contains("Rate limit") || error.contains("429")
320 {
321 StorageError::RateLimited(error)
322 } else if error.contains("invalid") || error.contains("Invalid") || error.contains("400") {
323 StorageError::InvalidRequest(error)
324 } else {
325 StorageError::Internal(error)
326 }
327}