Skip to main content

systemprompt_api/routes/stream/
mod.rs

//! Server-sent event (SSE) stream routes for live A2A, `AgUI`, and
//! context-state feeds.
//!
//! Each route opens a per-user SSE connection backed by a broadcaster from
//! `systemprompt_events`. [`create_sse_stream`] registers the connection
//! (subject to a per-user cap), wraps the receiver in a [`StreamWithGuard`]
//! so that its [`ConnectionGuard`] deregisters the connection on drop, and
//! enables keep-alive frames.

9use axum::Router;
10use axum::extract::Extension;
11use axum::response::IntoResponse;
12use axum::response::sse::Sse;
13use axum::routing::get;
14use std::convert::Infallible;
15use std::sync::{Arc, LazyLock};
16use systemprompt_agent::services::ContextProviderService;
17use systemprompt_events::{
18    A2A_BROADCASTER, AGUI_BROADCASTER, Broadcaster, ConnectionGuard, GenericBroadcaster, ToSse,
19    standard_keep_alive,
20};
21use systemprompt_models::RequestContext;
22use systemprompt_runtime::AppContext;
23use tokio::sync::mpsc;
24use tokio_stream::wrappers::ReceiverStream;
25
26pub mod contexts;
27
28#[derive(Clone, Debug)]
29pub struct StreamState {
30    pub context_provider: Arc<ContextProviderService>,
31}
32
33pub fn stream_router(ctx: &AppContext) -> anyhow::Result<Router> {
34    let context_provider = ContextProviderService::new(ctx.db_pool())?;
35    let state = StreamState {
36        context_provider: Arc::new(context_provider),
37    };
38
39    Ok(Router::new()
40        .route("/contexts", get(contexts::stream_context_state))
41        .route("/agui", get(stream_agui_events))
42        .route("/a2a", get(stream_a2a_events))
43        .with_state(state))
44}
45
46pub async fn stream_a2a_events(
47    Extension(request_context): Extension<RequestContext>,
48) -> impl IntoResponse {
49    create_sse_stream(request_context, &A2A_BROADCASTER, "A2A").await
50}
51
52pub async fn stream_agui_events(
53    Extension(request_context): Extension<RequestContext>,
54) -> impl IntoResponse {
55    create_sse_stream(request_context, &AGUI_BROADCASTER, "AgUI").await
56}
57
58#[derive(Debug)]
59pub struct StreamWithGuard<E: ToSse + Clone + Send + Sync + 'static> {
60    stream: ReceiverStream<Result<axum::response::sse::Event, Infallible>>,
61    _cleanup_guard: ConnectionGuard<E>,
62}
63
64impl<E: ToSse + Clone + Send + Sync + 'static> StreamWithGuard<E> {
65    pub const fn new(
66        stream: ReceiverStream<Result<axum::response::sse::Event, Infallible>>,
67        cleanup_guard: ConnectionGuard<E>,
68    ) -> Self {
69        Self {
70            stream,
71            _cleanup_guard: cleanup_guard,
72        }
73    }
74}
75
76impl<E: ToSse + Clone + Send + Sync + 'static> futures_util::Stream for StreamWithGuard<E> {
77    type Item = Result<axum::response::sse::Event, Infallible>;
78
79    fn poll_next(
80        mut self: std::pin::Pin<&mut Self>,
81        cx: &mut std::task::Context<'_>,
82    ) -> std::task::Poll<Option<Self::Item>> {
83        std::pin::Pin::new(&mut self.stream).poll_next(cx)
84    }
85}
86
87pub async fn create_sse_stream<E: ToSse + Clone + Send + Sync + 'static>(
88    request_context: RequestContext,
89    broadcaster: &'static LazyLock<GenericBroadcaster<E>>,
90    stream_name: &str,
91) -> impl IntoResponse {
92    let user_id = request_context.user_id().clone();
93    let user_id_str = user_id.to_string();
94    let conn_id = systemprompt_identifiers::ConnectionId::generate();
95    let conn_id_str = conn_id.as_str().to_owned();
96
97    tracing::info!(user_id = %user_id_str, conn_id = %conn_id_str, stream = %stream_name, "SSE stream opened");
98
99    let (tx, rx) = mpsc::channel(1024);
100
101    if !broadcaster.register(&user_id, &conn_id, tx.clone()).await {
102        tracing::warn!(user_id = %user_id_str, stream = %stream_name, "SSE stream rejected: per-user connection cap reached");
103        return http::StatusCode::TOO_MANY_REQUESTS.into_response();
104    }
105
106    let cleanup_guard = ConnectionGuard::new(broadcaster, user_id, conn_id);
107    let stream = ReceiverStream::new(rx);
108    let stream_with_guard = StreamWithGuard::<E>::new(stream, cleanup_guard);
109
110    tracing::info!(user_id = %user_id_str, conn_id = %conn_id_str, stream = %stream_name, "SSE stream ready");
111
112    Sse::new(stream_with_guard)
113        .keep_alive(standard_keep_alive())
114        .into_response()
115}