Skip to main content

systemprompt_agent/services/a2a_server/streaming/messages.rs

1use std::sync::Arc;
2
3use axum::response::sse::Event;
4use systemprompt_models::RequestContext;
5use tokio_stream::wrappers::UnboundedReceiverStream;
6
7use crate::models::a2a::Message;
8use crate::models::a2a::jsonrpc::NumberOrString;
9use crate::models::a2a::protocol::PushNotificationConfig;
10use crate::services::a2a_server::handlers::AgentHandlerState;
11use crate::services::a2a_server::processing::message::ProcessMessageStreamParams;
12
13use super::event_loop::{ProcessEventsParams, handle_stream_creation_error, process_events};
14use super::initialization::setup_stream;
15use super::types::StreamInput;
16use super::webhook_client::WebhookContext;
17
18pub struct CreateSseStreamParams {
19    pub message: Message,
20    pub agent_name: String,
21    pub state: Arc<AgentHandlerState>,
22    pub request_id: NumberOrString,
23    pub context: RequestContext,
24    pub callback_config: Option<PushNotificationConfig>,
25}
26
27impl std::fmt::Debug for CreateSseStreamParams {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        f.debug_struct("CreateSseStreamParams")
30            .field("message", &self.message)
31            .field("agent_name", &self.agent_name)
32            .field("request_id", &self.request_id)
33            .field("context", &self.context)
34            .field("callback_config", &self.callback_config)
35            .finish_non_exhaustive()
36    }
37}
38
39pub async fn create_sse_stream(params: CreateSseStreamParams) -> UnboundedReceiverStream<Event> {
40    let CreateSseStreamParams {
41        message,
42        agent_name,
43        state,
44        request_id,
45        context,
46        callback_config,
47    } = params;
48    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
49
50    tracing::info!("create_sse_stream() called - spawning tokio task");
51
52    let input = StreamInput {
53        message,
54        agent_name,
55        state,
56        request_id,
57        context,
58        callback_config,
59    };
60
61    tokio::spawn(async move {
62        tracing::info!("Inside tokio::spawn - task execution started");
63
64        let Ok(setup) = setup_stream(input, &tx).await else {
65            return;
66        };
67
68        tracing::info!(agent = %setup.agent_name, "Starting message stream processing for agent");
69
70        match setup
71            .processor
72            .process_message_stream(ProcessMessageStreamParams {
73                a2a_message: &setup.message,
74                agent_runtime: &setup.agent_runtime,
75                agent_name: &setup.agent_name,
76                context: &setup.context,
77                task_id: setup.task_id.clone(),
78            })
79            .await
80        {
81            Ok(chunk_rx) => {
82                let params = ProcessEventsParams {
83                    tx,
84                    chunk_rx,
85                    task_id: setup.task_id,
86                    context_id: setup.context_id,
87                    message_id: setup.message_id,
88                    original_message: setup.message,
89                    agent_name: setup.agent_name,
90                    context: setup.context,
91                    task_repo: setup.task_repo,
92                    processor: setup.processor,
93                    request_id: setup.request_id,
94                };
95                process_events(params).await;
96            },
97            Err(e) => {
98                let webhook_context = WebhookContext::new(
99                    setup.context.user_id().clone(),
100                    setup.context.auth_token().as_str(),
101                );
102                handle_stream_creation_error(
103                    &webhook_context,
104                    e,
105                    &setup.task_id,
106                    &setup.context_id,
107                    &setup.task_repo,
108                )
109                .await;
110            },
111        }
112    });
113
114    UnboundedReceiverStream::new(rx)
115}