1use std::collections::HashMap;
2use std::convert::Infallible;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::Arc;
5use std::time::{Duration, SystemTime, UNIX_EPOCH};
6
7use async_stream::stream;
8use axum::extract::{Path, State};
9use axum::http::StatusCode;
10use axum::response::sse::{Event, KeepAlive, Sse};
11use axum::response::IntoResponse;
12use axum::routing::{get, post};
13use axum::{Json, Router};
14use runtime::{ConversationMessage, Session as RuntimeSession};
15use serde::{Deserialize, Serialize};
16use tokio::sync::{broadcast, RwLock};
17
18pub type SessionId = String;
19pub type SessionStore = Arc<RwLock<HashMap<SessionId, Session>>>;
20
/// Number of events each session's broadcast channel retains; a subscriber
/// that falls further behind than this observes a lag on its receiver.
const BROADCAST_CAPACITY: usize = 64;
22
23#[derive(Clone)]
24pub struct AppState {
25 sessions: SessionStore,
26 next_session_id: Arc<AtomicU64>,
27}
28
29impl AppState {
30 #[must_use]
31 pub fn new() -> Self {
32 Self {
33 sessions: Arc::new(RwLock::new(HashMap::new())),
34 next_session_id: Arc::new(AtomicU64::new(1)),
35 }
36 }
37
38 fn allocate_session_id(&self) -> SessionId {
39 let id = self.next_session_id.fetch_add(1, Ordering::Relaxed);
40 format!("session-{id}")
41 }
42}
43
44impl Default for AppState {
45 fn default() -> Self {
46 Self::new()
47 }
48}
49
50#[derive(Clone)]
51pub struct Session {
52 pub id: SessionId,
53 pub created_at: u64,
54 pub conversation: RuntimeSession,
55 events: broadcast::Sender<SessionEvent>,
56}
57
58impl Session {
59 fn new(id: SessionId) -> Self {
60 let (events, _) = broadcast::channel(BROADCAST_CAPACITY);
61 Self {
62 id,
63 created_at: unix_timestamp_millis(),
64 conversation: RuntimeSession::new(),
65 events,
66 }
67 }
68
69 fn subscribe(&self) -> broadcast::Receiver<SessionEvent> {
70 self.events.subscribe()
71 }
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
75#[serde(tag = "type", rename_all = "snake_case")]
76enum SessionEvent {
77 Snapshot {
78 session_id: SessionId,
79 session: RuntimeSession,
80 },
81 Message {
82 session_id: SessionId,
83 message: ConversationMessage,
84 },
85}
86
87impl SessionEvent {
88 fn event_name(&self) -> &'static str {
89 match self {
90 Self::Snapshot { .. } => "snapshot",
91 Self::Message { .. } => "message",
92 }
93 }
94
95 fn to_sse_event(&self) -> Result<Event, serde_json::Error> {
96 Ok(Event::default()
97 .event(self.event_name())
98 .data(serde_json::to_string(self)?))
99 }
100}
101
102#[derive(Debug, Serialize)]
103struct ErrorResponse {
104 error: String,
105}
106
107type ApiError = (StatusCode, Json<ErrorResponse>);
108type ApiResult<T> = Result<T, ApiError>;
109
110#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
111pub struct CreateSessionResponse {
112 pub session_id: SessionId,
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
116pub struct SessionSummary {
117 pub id: SessionId,
118 pub created_at: u64,
119 pub message_count: usize,
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
123pub struct ListSessionsResponse {
124 pub sessions: Vec<SessionSummary>,
125}
126
127#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
128pub struct SessionDetailsResponse {
129 pub id: SessionId,
130 pub created_at: u64,
131 pub session: RuntimeSession,
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
135pub struct SendMessageRequest {
136 pub message: String,
137}
138
139pub fn app(state: AppState) -> Router {
140 Router::new()
141 .route("/sessions", post(create_session).get(list_sessions))
142 .route("/sessions/{id}", get(get_session))
143 .route("/sessions/{id}/events", get(stream_session_events))
144 .route("/sessions/{id}/message", post(send_message))
145 .with_state(state)
146}
147
148async fn create_session(
149 State(state): State<AppState>,
150) -> (StatusCode, Json<CreateSessionResponse>) {
151 let session_id = state.allocate_session_id();
152 let session = Session::new(session_id.clone());
153
154 state
155 .sessions
156 .write()
157 .await
158 .insert(session_id.clone(), session);
159
160 (
161 StatusCode::CREATED,
162 Json(CreateSessionResponse { session_id }),
163 )
164}
165
166async fn list_sessions(State(state): State<AppState>) -> Json<ListSessionsResponse> {
167 let sessions = state.sessions.read().await;
168 let mut summaries = sessions
169 .values()
170 .map(|session| SessionSummary {
171 id: session.id.clone(),
172 created_at: session.created_at,
173 message_count: session.conversation.messages.len(),
174 })
175 .collect::<Vec<_>>();
176 summaries.sort_by(|left, right| left.id.cmp(&right.id));
177
178 Json(ListSessionsResponse {
179 sessions: summaries,
180 })
181}
182
183async fn get_session(
184 State(state): State<AppState>,
185 Path(id): Path<SessionId>,
186) -> ApiResult<Json<SessionDetailsResponse>> {
187 let sessions = state.sessions.read().await;
188 let session = sessions
189 .get(&id)
190 .ok_or_else(|| not_found(format!("session `{id}` not found")))?;
191
192 Ok(Json(SessionDetailsResponse {
193 id: session.id.clone(),
194 created_at: session.created_at,
195 session: session.conversation.clone(),
196 }))
197}
198
199async fn send_message(
200 State(state): State<AppState>,
201 Path(id): Path<SessionId>,
202 Json(payload): Json<SendMessageRequest>,
203) -> ApiResult<StatusCode> {
204 let message = ConversationMessage::user_text(payload.message);
205 let broadcaster = {
206 let mut sessions = state.sessions.write().await;
207 let session = sessions
208 .get_mut(&id)
209 .ok_or_else(|| not_found(format!("session `{id}` not found")))?;
210 session.conversation.messages.push(message.clone());
211 session.events.clone()
212 };
213
214 let _ = broadcaster.send(SessionEvent::Message {
215 session_id: id,
216 message,
217 });
218
219 Ok(StatusCode::NO_CONTENT)
220}
221
222async fn stream_session_events(
223 State(state): State<AppState>,
224 Path(id): Path<SessionId>,
225) -> ApiResult<impl IntoResponse> {
226 let (snapshot, mut receiver) = {
227 let sessions = state.sessions.read().await;
228 let session = sessions
229 .get(&id)
230 .ok_or_else(|| not_found(format!("session `{id}` not found")))?;
231 (
232 SessionEvent::Snapshot {
233 session_id: session.id.clone(),
234 session: session.conversation.clone(),
235 },
236 session.subscribe(),
237 )
238 };
239
240 let stream = stream! {
241 if let Ok(event) = snapshot.to_sse_event() {
242 yield Ok::<Event, Infallible>(event);
243 }
244
245 loop {
246 match receiver.recv().await {
247 Ok(event) => {
248 if let Ok(sse_event) = event.to_sse_event() {
249 yield Ok::<Event, Infallible>(sse_event);
250 }
251 }
252 Err(broadcast::error::RecvError::Lagged(_)) => {},
253 Err(broadcast::error::RecvError::Closed) => break,
254 }
255 }
256 };
257
258 Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15))))
259}
260
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// This is called while creating a session inside a request handler, so it
/// never panics:
/// - a system clock set before 1970 yields `0`;
/// - a millisecond count too large for `u64` saturates to `u64::MAX`.
fn unix_timestamp_millis() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| {
            u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
        })
}
270
271fn not_found(message: String) -> ApiError {
272 (
273 StatusCode::NOT_FOUND,
274 Json(ErrorResponse { error: message }),
275 )
276}
277
#[cfg(test)]
mod tests {
    use super::{
        app, AppState, CreateSessionResponse, ListSessionsResponse, SessionDetailsResponse,
    };
    use reqwest::Client;
    use std::net::SocketAddr;
    use std::time::Duration;
    use tokio::net::TcpListener;
    use tokio::task::JoinHandle;
    use tokio::time::timeout;

    /// Byte sequence that terminates one SSE frame (a blank line).
    const SSE_FRAME_DELIMITER: &[u8] = b"\n\n";

    /// In-process server on an ephemeral localhost port, serving a fresh
    /// `AppState`. The server task is aborted when this value is dropped.
    struct TestServer {
        address: SocketAddr,
        handle: JoinHandle<()>,
    }

    impl TestServer {
        /// Binds `127.0.0.1:0` and runs the app in a background task.
        async fn spawn() -> Self {
            let listener = TcpListener::bind("127.0.0.1:0")
                .await
                .expect("test listener should bind");
            let address = listener
                .local_addr()
                .expect("listener should report local address");
            let handle = tokio::spawn(async move {
                axum::serve(listener, app(AppState::default()))
                    .await
                    .expect("server should run");
            });

            Self { address, handle }
        }

        /// Absolute URL of `path` on this server.
        fn url(&self, path: &str) -> String {
            format!("http://{}{}", self.address, path)
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.handle.abort();
        }
    }

    /// Creates a session through the API and returns the parsed response.
    async fn create_session(client: &Client, server: &TestServer) -> CreateSessionResponse {
        client
            .post(server.url("/sessions"))
            .send()
            .await
            .expect("create request should succeed")
            .error_for_status()
            .expect("create request should return success")
            .json::<CreateSessionResponse>()
            .await
            .expect("create response should parse")
    }

    /// Returns the next complete SSE frame, without its terminating blank line.
    ///
    /// `buffer` holds raw bytes that arrived after the previous frame. Bytes
    /// are decoded only once a whole frame is present. Decoding each network
    /// chunk on its own (e.g. `from_utf8_lossy` per chunk) would corrupt any
    /// multi-byte UTF-8 character that is split across two chunks.
    async fn next_sse_frame(response: &mut reqwest::Response, buffer: &mut Vec<u8>) -> String {
        loop {
            if let Some(index) = buffer
                .windows(SSE_FRAME_DELIMITER.len())
                .position(|window| window == SSE_FRAME_DELIMITER)
            {
                // Remove the frame and its delimiter, keeping any following bytes.
                let mut frame: Vec<u8> = buffer
                    .drain(..index + SSE_FRAME_DELIMITER.len())
                    .collect();
                frame.truncate(index);
                return String::from_utf8(frame).expect("SSE frame should be valid UTF-8");
            }

            let next_chunk = timeout(Duration::from_secs(5), response.chunk())
                .await
                .expect("SSE stream should yield within timeout")
                .expect("SSE stream should remain readable")
                .expect("SSE stream should stay open");
            buffer.extend_from_slice(&next_chunk);
        }
    }

    #[tokio::test]
    async fn creates_and_lists_sessions() {
        let server = TestServer::spawn().await;
        let client = Client::new();

        let created = create_session(&client, &server).await;

        let sessions = client
            .get(server.url("/sessions"))
            .send()
            .await
            .expect("list request should succeed")
            .error_for_status()
            .expect("list request should return success")
            .json::<ListSessionsResponse>()
            .await
            .expect("list response should parse");
        let details = client
            .get(server.url(&format!("/sessions/{}", created.session_id)))
            .send()
            .await
            .expect("details request should succeed")
            .error_for_status()
            .expect("details request should return success")
            .json::<SessionDetailsResponse>()
            .await
            .expect("details response should parse");

        assert_eq!(created.session_id, "session-1");
        assert_eq!(sessions.sessions.len(), 1);
        assert_eq!(sessions.sessions[0].id, created.session_id);
        assert_eq!(sessions.sessions[0].message_count, 0);
        assert_eq!(details.id, "session-1");
        assert!(details.session.messages.is_empty());
    }

    #[tokio::test]
    async fn streams_message_events_and_persists_message_flow() {
        let server = TestServer::spawn().await;
        let client = Client::new();

        let created = create_session(&client, &server).await;
        let mut response = client
            .get(server.url(&format!("/sessions/{}/events", created.session_id)))
            .send()
            .await
            .expect("events request should succeed")
            .error_for_status()
            .expect("events request should return success");
        let mut buffer = Vec::new();
        let snapshot_frame = next_sse_frame(&mut response, &mut buffer).await;

        let send_status = client
            .post(server.url(&format!("/sessions/{}/message", created.session_id)))
            .json(&super::SendMessageRequest {
                message: "hello from test".to_string(),
            })
            .send()
            .await
            .expect("message request should succeed")
            .status();
        let message_frame = next_sse_frame(&mut response, &mut buffer).await;
        let details = client
            .get(server.url(&format!("/sessions/{}", created.session_id)))
            .send()
            .await
            .expect("details request should succeed")
            .error_for_status()
            .expect("details request should return success")
            .json::<SessionDetailsResponse>()
            .await
            .expect("details response should parse");

        assert_eq!(send_status, reqwest::StatusCode::NO_CONTENT);
        assert!(snapshot_frame.contains("event: snapshot"));
        assert!(snapshot_frame.contains("\"session_id\":\"session-1\""));
        assert!(message_frame.contains("event: message"));
        assert!(message_frame.contains("hello from test"));
        assert_eq!(details.session.messages.len(), 1);
        assert_eq!(
            details.session.messages[0],
            runtime::ConversationMessage::user_text("hello from test")
        );
    }
}