// systemprompt_agent/services/a2a_server/processing/message/mod.rs

1mod message_handler;
2mod persistence;
3mod stream_processor;
4
5pub use stream_processor::StreamProcessor;
6
7use anyhow::{Result, anyhow};
8use std::sync::Arc;
9use tokio::sync::mpsc;
10
11use crate::models::AgentRuntimeInfo;
12use crate::models::a2a::{Artifact, Message, Task};
13use systemprompt_models::{AiProvider, CallToolResult, ToolCall};
14
15#[derive(Debug)]
16pub enum StreamEvent {
17    Text(String),
18    ToolCallStarted(ToolCall),
19    ToolResult {
20        call_id: String,
21        result: CallToolResult,
22    },
23    ExecutionStepUpdate {
24        step: crate::models::ExecutionStep,
25    },
26    Complete {
27        full_text: String,
28        artifacts: Vec<Artifact>,
29    },
30    Error(String),
31}
32use crate::repository::context::ContextRepository;
33use crate::repository::execution::ExecutionStepRepository;
34use crate::repository::task::TaskRepository;
35use crate::services::{ContextService, SkillService};
36use systemprompt_database::DbPool;
37use systemprompt_identifiers::TaskId;
38use systemprompt_models::RequestContext;
39
40#[derive(Debug)]
41pub struct PersistCompletedTaskOnProcessorParams<'a> {
42    pub task: &'a Task,
43    pub user_message: &'a Message,
44    pub agent_message: &'a Message,
45    pub context: &'a RequestContext,
46    pub agent_name: &'a str,
47    pub artifacts_already_published: bool,
48}
49
50#[derive(Debug)]
51pub struct ProcessMessageStreamParams<'a> {
52    pub a2a_message: &'a Message,
53    pub agent_runtime: &'a AgentRuntimeInfo,
54    pub agent_name: &'a str,
55    pub context: &'a RequestContext,
56    pub task_id: TaskId,
57}
58
59pub struct MessageProcessor {
60    db_pool: DbPool,
61    ai_service: Arc<dyn AiProvider>,
62    task_repo: TaskRepository,
63    context_repo: ContextRepository,
64    context_service: ContextService,
65    skill_service: Arc<SkillService>,
66    execution_step_repo: Arc<ExecutionStepRepository>,
67}
68
69impl std::fmt::Debug for MessageProcessor {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        f.debug_struct("MessageProcessor")
72            .field("ai_service", &"<Arc<dyn AiProvider>>")
73            .finish()
74    }
75}
76
77impl MessageProcessor {
78    pub fn new(db_pool: &DbPool, ai_service: Arc<dyn AiProvider>) -> Result<Self> {
79        let task_repo = TaskRepository::new(db_pool)?;
80        let context_repo = ContextRepository::new(db_pool)?;
81        let context_service = ContextService::new(db_pool)?;
82        let skill_service = Arc::new(SkillService::new(db_pool)?);
83        let execution_step_repo = Arc::new(ExecutionStepRepository::new(db_pool)?);
84
85        Ok(Self {
86            db_pool: Arc::clone(db_pool),
87            ai_service,
88            task_repo,
89            context_repo,
90            context_service,
91            skill_service,
92            execution_step_repo,
93        })
94    }
95
96    pub async fn load_agent_runtime(&self, agent_name: &str) -> Result<AgentRuntimeInfo> {
97        use crate::services::registry::AgentRegistry;
98
99        let registry = AgentRegistry::new()?;
100        let agent_config = registry
101            .get_agent(agent_name)
102            .await
103            .map_err(|_| anyhow!("Agent not found"))?;
104
105        Ok(agent_config.into())
106    }
107
108    pub async fn persist_completed_task(
109        &self,
110        params: PersistCompletedTaskOnProcessorParams<'_>,
111    ) -> Result<Task> {
112        persistence::persist_completed_task(persistence::PersistCompletedTaskParams {
113            task: params.task,
114            user_message: params.user_message,
115            agent_message: params.agent_message,
116            context: params.context,
117            task_repo: &self.task_repo,
118            db_pool: &self.db_pool,
119            artifacts_already_published: params.artifacts_already_published,
120        })
121        .await
122    }
123
124    pub async fn process_message_stream(
125        &self,
126        params: ProcessMessageStreamParams<'_>,
127    ) -> Result<mpsc::UnboundedReceiver<StreamEvent>> {
128        let stream_processor = StreamProcessor {
129            ai_service: Arc::clone(&self.ai_service),
130            context_service: self.context_service.clone(),
131            skill_service: Arc::clone(&self.skill_service),
132            execution_step_repo: Arc::clone(&self.execution_step_repo),
133        };
134
135        stream_processor.process_message_stream(params).await
136    }
137}