systemprompt_api/routes/stream/
mod.rs1use axum::Router;
10use axum::extract::Extension;
11use axum::response::IntoResponse;
12use axum::response::sse::Sse;
13use axum::routing::get;
14use std::convert::Infallible;
15use std::sync::{Arc, LazyLock};
16use systemprompt_agent::services::ContextProviderService;
17use systemprompt_events::{
18 A2A_BROADCASTER, AGUI_BROADCASTER, Broadcaster, ConnectionGuard, GenericBroadcaster, ToSse,
19 standard_keep_alive,
20};
21use systemprompt_models::RequestContext;
22use systemprompt_runtime::AppContext;
23use tokio::sync::mpsc;
24use tokio_stream::wrappers::ReceiverStream;
25
26pub mod contexts;
27
28#[derive(Clone, Debug)]
29pub struct StreamState {
30 pub context_provider: Arc<ContextProviderService>,
31}
32
33pub fn stream_router(ctx: &AppContext) -> anyhow::Result<Router> {
34 let context_provider = ContextProviderService::new(ctx.db_pool())?;
35 let state = StreamState {
36 context_provider: Arc::new(context_provider),
37 };
38
39 Ok(Router::new()
40 .route("/contexts", get(contexts::stream_context_state))
41 .route("/agui", get(stream_agui_events))
42 .route("/a2a", get(stream_a2a_events))
43 .with_state(state))
44}
45
46pub async fn stream_a2a_events(
47 Extension(request_context): Extension<RequestContext>,
48) -> impl IntoResponse {
49 create_sse_stream(request_context, &A2A_BROADCASTER, "A2A").await
50}
51
52pub async fn stream_agui_events(
53 Extension(request_context): Extension<RequestContext>,
54) -> impl IntoResponse {
55 create_sse_stream(request_context, &AGUI_BROADCASTER, "AgUI").await
56}
57
58#[derive(Debug)]
59pub struct StreamWithGuard<E: ToSse + Clone + Send + Sync + 'static> {
60 stream: ReceiverStream<Result<axum::response::sse::Event, Infallible>>,
61 _cleanup_guard: ConnectionGuard<E>,
62}
63
64impl<E: ToSse + Clone + Send + Sync + 'static> StreamWithGuard<E> {
65 pub const fn new(
66 stream: ReceiverStream<Result<axum::response::sse::Event, Infallible>>,
67 cleanup_guard: ConnectionGuard<E>,
68 ) -> Self {
69 Self {
70 stream,
71 _cleanup_guard: cleanup_guard,
72 }
73 }
74}
75
76impl<E: ToSse + Clone + Send + Sync + 'static> futures_util::Stream for StreamWithGuard<E> {
77 type Item = Result<axum::response::sse::Event, Infallible>;
78
79 fn poll_next(
80 mut self: std::pin::Pin<&mut Self>,
81 cx: &mut std::task::Context<'_>,
82 ) -> std::task::Poll<Option<Self::Item>> {
83 std::pin::Pin::new(&mut self.stream).poll_next(cx)
84 }
85}
86
87pub async fn create_sse_stream<E: ToSse + Clone + Send + Sync + 'static>(
88 request_context: RequestContext,
89 broadcaster: &'static LazyLock<GenericBroadcaster<E>>,
90 stream_name: &str,
91) -> impl IntoResponse {
92 let user_id = request_context.user_id().clone();
93 let user_id_str = user_id.to_string();
94 let conn_id = systemprompt_identifiers::ConnectionId::generate();
95 let conn_id_str = conn_id.as_str().to_owned();
96
97 tracing::info!(user_id = %user_id_str, conn_id = %conn_id_str, stream = %stream_name, "SSE stream opened");
98
99 let (tx, rx) = mpsc::channel(1024);
100
101 if !broadcaster.register(&user_id, &conn_id, tx.clone()).await {
102 tracing::warn!(user_id = %user_id_str, stream = %stream_name, "SSE stream rejected: per-user connection cap reached");
103 return http::StatusCode::TOO_MANY_REQUESTS.into_response();
104 }
105
106 let cleanup_guard = ConnectionGuard::new(broadcaster, user_id, conn_id);
107 let stream = ReceiverStream::new(rx);
108 let stream_with_guard = StreamWithGuard::<E>::new(stream, cleanup_guard);
109
110 tracing::info!(user_id = %user_id_str, conn_id = %conn_id_str, stream = %stream_name, "SSE stream ready");
111
112 Sse::new(stream_with_guard)
113 .keep_alive(standard_keep_alive())
114 .into_response()
115}