systemprompt_agent/services/a2a_server/streaming/
messages.rs1use std::sync::Arc;
2
3use axum::response::sse::Event;
4use systemprompt_models::RequestContext;
5use tokio_stream::wrappers::UnboundedReceiverStream;
6
7use crate::models::a2a::Message;
8use crate::models::a2a::jsonrpc::NumberOrString;
9use crate::models::a2a::protocol::PushNotificationConfig;
10use crate::services::a2a_server::handlers::AgentHandlerState;
11use crate::services::a2a_server::processing::message::ProcessMessageStreamParams;
12
13use super::event_loop::{ProcessEventsParams, handle_stream_creation_error, process_events};
14use super::initialization::setup_stream;
15use super::types::StreamInput;
16use super::webhook_client::WebhookContext;
17
18pub struct CreateSseStreamParams {
19 pub message: Message,
20 pub agent_name: String,
21 pub state: Arc<AgentHandlerState>,
22 pub request_id: NumberOrString,
23 pub context: RequestContext,
24 pub callback_config: Option<PushNotificationConfig>,
25}
26
27impl std::fmt::Debug for CreateSseStreamParams {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("CreateSseStreamParams")
30 .field("message", &self.message)
31 .field("agent_name", &self.agent_name)
32 .field("request_id", &self.request_id)
33 .field("context", &self.context)
34 .field("callback_config", &self.callback_config)
35 .finish_non_exhaustive()
36 }
37}
38
39pub async fn create_sse_stream(params: CreateSseStreamParams) -> UnboundedReceiverStream<Event> {
40 let CreateSseStreamParams {
41 message,
42 agent_name,
43 state,
44 request_id,
45 context,
46 callback_config,
47 } = params;
48 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
49
50 tracing::info!("create_sse_stream() called - spawning tokio task");
51
52 let input = StreamInput {
53 message,
54 agent_name,
55 state,
56 request_id,
57 context,
58 callback_config,
59 };
60
61 tokio::spawn(async move {
62 tracing::info!("Inside tokio::spawn - task execution started");
63
64 let Ok(setup) = setup_stream(input, &tx).await else {
65 return;
66 };
67
68 tracing::info!(agent = %setup.agent_name, "Starting message stream processing for agent");
69
70 match setup
71 .processor
72 .process_message_stream(ProcessMessageStreamParams {
73 a2a_message: &setup.message,
74 agent_runtime: &setup.agent_runtime,
75 agent_name: &setup.agent_name,
76 context: &setup.context,
77 task_id: setup.task_id.clone(),
78 })
79 .await
80 {
81 Ok(chunk_rx) => {
82 let params = ProcessEventsParams {
83 tx,
84 chunk_rx,
85 task_id: setup.task_id,
86 context_id: setup.context_id,
87 message_id: setup.message_id,
88 original_message: setup.message,
89 agent_name: setup.agent_name,
90 context: setup.context,
91 task_repo: setup.task_repo,
92 processor: setup.processor,
93 request_id: setup.request_id,
94 };
95 process_events(params).await;
96 },
97 Err(e) => {
98 let webhook_context = WebhookContext::new(
99 setup.context.user_id().clone(),
100 setup.context.auth_token().as_str(),
101 );
102 handle_stream_creation_error(
103 &webhook_context,
104 e,
105 &setup.task_id,
106 &setup.context_id,
107 &setup.task_repo,
108 )
109 .await;
110 },
111 }
112 });
113
114 UnboundedReceiverStream::new(rx)
115}