Skip to main content

systemprompt_api/routes/stream/
mod.rs

1use axum::Router;
2use axum::extract::Extension;
3use axum::response::IntoResponse;
4use axum::response::sse::Sse;
5use axum::routing::get;
6use std::convert::Infallible;
7use std::sync::{Arc, LazyLock};
8use systemprompt_agent::services::ContextProviderService;
9use systemprompt_events::{
10    A2A_BROADCASTER, AGUI_BROADCASTER, Broadcaster, ConnectionGuard, GenericBroadcaster, ToSse,
11    standard_keep_alive,
12};
13use systemprompt_models::RequestContext;
14use systemprompt_runtime::AppContext;
15use tokio::sync::mpsc;
16use tokio_stream::wrappers::UnboundedReceiverStream;
17
18pub mod contexts;
19
20#[derive(Clone, Debug)]
21pub struct StreamState {
22    pub context_provider: Arc<ContextProviderService>,
23}
24
25pub fn stream_router(ctx: &AppContext) -> anyhow::Result<Router> {
26    let context_provider = ContextProviderService::new(ctx.db_pool())?;
27    let state = StreamState {
28        context_provider: Arc::new(context_provider),
29    };
30
31    Ok(Router::new()
32        .route("/contexts", get(contexts::stream_context_state))
33        .route("/agui", get(stream_agui_events))
34        .route("/a2a", get(stream_a2a_events))
35        .with_state(state))
36}
37
38pub async fn stream_a2a_events(
39    Extension(request_context): Extension<RequestContext>,
40) -> impl IntoResponse {
41    create_sse_stream(request_context, &A2A_BROADCASTER, "A2A").await
42}
43
44pub async fn stream_agui_events(
45    Extension(request_context): Extension<RequestContext>,
46) -> impl IntoResponse {
47    create_sse_stream(request_context, &AGUI_BROADCASTER, "AgUI").await
48}
49
50#[derive(Debug)]
51pub struct StreamWithGuard<E: ToSse + Clone + Send + Sync + 'static> {
52    stream: UnboundedReceiverStream<Result<axum::response::sse::Event, Infallible>>,
53    _cleanup_guard: ConnectionGuard<E>,
54}
55
56impl<E: ToSse + Clone + Send + Sync + 'static> StreamWithGuard<E> {
57    pub const fn new(
58        stream: UnboundedReceiverStream<Result<axum::response::sse::Event, Infallible>>,
59        cleanup_guard: ConnectionGuard<E>,
60    ) -> Self {
61        Self {
62            stream,
63            _cleanup_guard: cleanup_guard,
64        }
65    }
66}
67
68impl<E: ToSse + Clone + Send + Sync + 'static> futures_util::Stream for StreamWithGuard<E> {
69    type Item = Result<axum::response::sse::Event, Infallible>;
70
71    fn poll_next(
72        mut self: std::pin::Pin<&mut Self>,
73        cx: &mut std::task::Context<'_>,
74    ) -> std::task::Poll<Option<Self::Item>> {
75        std::pin::Pin::new(&mut self.stream).poll_next(cx)
76    }
77}
78
79pub async fn create_sse_stream<E: ToSse + Clone + Send + Sync + 'static>(
80    request_context: RequestContext,
81    broadcaster: &'static LazyLock<GenericBroadcaster<E>>,
82    stream_name: &str,
83) -> impl IntoResponse {
84    let user_id = request_context.user_id().clone();
85    let user_id_str = user_id.to_string();
86    let conn_id = uuid::Uuid::new_v4().to_string();
87
88    tracing::info!(user_id = %user_id_str, conn_id = %conn_id, stream = %stream_name, "SSE stream opened");
89
90    let (tx, rx) = mpsc::unbounded_channel();
91
92    broadcaster.register(&user_id, &conn_id, tx.clone()).await;
93
94    let cleanup_guard = ConnectionGuard::new(broadcaster, user_id, conn_id.clone());
95    let stream = UnboundedReceiverStream::new(rx);
96    let stream_with_guard = StreamWithGuard::<E>::new(stream, cleanup_guard);
97
98    tracing::info!(user_id = %user_id_str, conn_id = %conn_id, stream = %stream_name, "SSE stream ready");
99
100    Sse::new(stream_with_guard)
101        .keep_alive(standard_keep_alive())
102        .into_response()
103}