systemprompt_api/routes/stream/
mod.rs1use axum::extract::Extension;
2use axum::response::sse::Sse;
3use axum::response::IntoResponse;
4use axum::routing::get;
5use axum::Router;
6use std::convert::Infallible;
7use std::sync::{Arc, LazyLock};
8use systemprompt_agent::services::ContextProviderService;
9use systemprompt_events::{
10 standard_keep_alive, Broadcaster, ConnectionGuard, GenericBroadcaster, ToSse, A2A_BROADCASTER,
11 AGUI_BROADCASTER,
12};
13use systemprompt_models::RequestContext;
14use systemprompt_runtime::AppContext;
15use tokio::sync::mpsc;
16use tokio_stream::wrappers::UnboundedReceiverStream;
17
/// Handler module for the `/contexts` SSE route; `stream_router` mounts
/// `contexts::stream_context_state` there.
pub mod contexts;
19
20#[derive(Clone, Debug)]
21pub struct StreamState {
22 pub context_provider: Arc<ContextProviderService>,
23}
24
25pub fn stream_router(ctx: &AppContext) -> Router {
26 let state = StreamState {
27 context_provider: Arc::new(
28 ContextProviderService::new(ctx.db_pool())
29 .expect("Failed to create ContextProviderService"),
30 ),
31 };
32
33 Router::new()
34 .route("/contexts", get(contexts::stream_context_state))
35 .route("/agui", get(stream_agui_events))
36 .route("/a2a", get(stream_a2a_events))
37 .with_state(state)
38}
39
40pub async fn stream_a2a_events(
41 Extension(request_context): Extension<RequestContext>,
42) -> impl IntoResponse {
43 create_sse_stream(&request_context, &A2A_BROADCASTER, "A2A").await
44}
45
46pub async fn stream_agui_events(
47 Extension(request_context): Extension<RequestContext>,
48) -> impl IntoResponse {
49 create_sse_stream(&request_context, &AGUI_BROADCASTER, "AgUI").await
50}
51
52#[derive(Debug)]
53pub struct StreamWithGuard<E: ToSse + Clone + Send + Sync + 'static> {
54 stream: UnboundedReceiverStream<Result<axum::response::sse::Event, Infallible>>,
55 _cleanup_guard: ConnectionGuard<E>,
56}
57
58impl<E: ToSse + Clone + Send + Sync + 'static> StreamWithGuard<E> {
59 pub fn new(
60 stream: UnboundedReceiverStream<Result<axum::response::sse::Event, Infallible>>,
61 cleanup_guard: ConnectionGuard<E>,
62 ) -> Self {
63 Self {
64 stream,
65 _cleanup_guard: cleanup_guard,
66 }
67 }
68}
69
70impl<E: ToSse + Clone + Send + Sync + 'static> futures_util::Stream for StreamWithGuard<E> {
71 type Item = Result<axum::response::sse::Event, Infallible>;
72
73 fn poll_next(
74 mut self: std::pin::Pin<&mut Self>,
75 cx: &mut std::task::Context<'_>,
76 ) -> std::task::Poll<Option<Self::Item>> {
77 std::pin::Pin::new(&mut self.stream).poll_next(cx)
78 }
79}
80
81pub async fn create_sse_stream<E: ToSse + Clone + Send + Sync + 'static>(
82 request_context: &RequestContext,
83 broadcaster: &'static LazyLock<GenericBroadcaster<E>>,
84 stream_name: &str,
85) -> impl IntoResponse {
86 let user_id = request_context.user_id().clone();
87 let user_id_str = user_id.to_string();
88 let conn_id = uuid::Uuid::new_v4().to_string();
89
90 tracing::info!(user_id = %user_id_str, conn_id = %conn_id, stream = %stream_name, "SSE stream opened");
91
92 let (tx, rx) = mpsc::unbounded_channel();
93
94 broadcaster.register(&user_id, &conn_id, tx.clone()).await;
95
96 let cleanup_guard = ConnectionGuard::new(broadcaster, user_id, conn_id.clone());
97 let stream = UnboundedReceiverStream::new(rx);
98 let stream_with_guard = StreamWithGuard::<E>::new(stream, cleanup_guard);
99
100 tracing::info!(user_id = %user_id_str, conn_id = %conn_id, stream = %stream_name, "SSE stream ready");
101
102 Sse::new(stream_with_guard)
103 .keep_alive(standard_keep_alive())
104 .into_response()
105}