Skip to main content

systemprompt_agent/services/a2a_server/processing/message/
mod.rs

1mod message_handler;
2mod persistence;
3mod stream_processor;
4
5pub use stream_processor::StreamProcessor;
6
7use anyhow::{anyhow, Result};
8use std::sync::Arc;
9use tokio::sync::mpsc;
10
11use crate::models::a2a::{Artifact, Message, Task};
12use crate::models::AgentRuntimeInfo;
13use systemprompt_models::{AiProvider, CallToolResult, ToolCall};
14
15#[derive(Debug)]
16pub enum StreamEvent {
17    Text(String),
18    ToolCallStarted(ToolCall),
19    ToolResult {
20        call_id: String,
21        result: CallToolResult,
22    },
23    ExecutionStepUpdate {
24        step: crate::models::ExecutionStep,
25    },
26    Complete {
27        full_text: String,
28        artifacts: Vec<Artifact>,
29    },
30    Error(String),
31}
32use crate::repository::context::ContextRepository;
33use crate::repository::execution::ExecutionStepRepository;
34use crate::repository::task::TaskRepository;
35use crate::services::{ContextService, SkillService};
36use systemprompt_database::DbPool;
37use systemprompt_identifiers::TaskId;
38use systemprompt_models::RequestContext;
39
40pub struct MessageProcessor {
41    db_pool: DbPool,
42    ai_service: Arc<dyn AiProvider>,
43    task_repo: TaskRepository,
44    context_repo: ContextRepository,
45    context_service: ContextService,
46    skill_service: Arc<SkillService>,
47    execution_step_repo: Arc<ExecutionStepRepository>,
48}
49
50impl std::fmt::Debug for MessageProcessor {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        f.debug_struct("MessageProcessor")
53            .field("ai_service", &"<Arc<dyn AiProvider>>")
54            .finish()
55    }
56}
57
58impl MessageProcessor {
59    pub fn new(db_pool: DbPool, ai_service: Arc<dyn AiProvider>) -> Result<Self> {
60        let task_repo = TaskRepository::new(db_pool.clone());
61        let context_repo = ContextRepository::new(db_pool.clone());
62        let context_service = ContextService::new(db_pool.clone());
63        let skill_service = Arc::new(SkillService::new(db_pool.clone()));
64        let execution_step_repo = Arc::new(ExecutionStepRepository::new(&db_pool)?);
65
66        Ok(Self {
67            db_pool,
68            ai_service,
69            task_repo,
70            context_repo,
71            context_service,
72            skill_service,
73            execution_step_repo,
74        })
75    }
76
77    pub async fn load_agent_runtime(&self, agent_name: &str) -> Result<AgentRuntimeInfo> {
78        use crate::services::registry::AgentRegistry;
79
80        let registry = AgentRegistry::new().await?;
81        let agent_config = registry
82            .get_agent(agent_name)
83            .await
84            .map_err(|_| anyhow!("Agent not found"))?;
85
86        Ok(agent_config.into())
87    }
88
89    pub async fn persist_completed_task(
90        &self,
91        task: &Task,
92        user_message: &Message,
93        agent_message: &Message,
94        context: &RequestContext,
95        _agent_name: &str,
96        artifacts_already_published: bool,
97    ) -> Result<Task> {
98        persistence::persist_completed_task(
99            task,
100            user_message,
101            agent_message,
102            context,
103            &self.task_repo,
104            &self.db_pool,
105            artifacts_already_published,
106        )
107        .await
108    }
109
110    pub async fn process_message_stream(
111        &self,
112        a2a_message: &Message,
113        agent_runtime: &AgentRuntimeInfo,
114        agent_name: &str,
115        context: &RequestContext,
116        task_id: TaskId,
117    ) -> Result<mpsc::UnboundedReceiver<StreamEvent>> {
118        let stream_processor = StreamProcessor {
119            ai_service: self.ai_service.clone(),
120            context_service: self.context_service.clone(),
121            skill_service: self.skill_service.clone(),
122            execution_step_repo: self.execution_step_repo.clone(),
123        };
124
125        stream_processor
126            .process_message_stream(a2a_message, agent_runtime, agent_name, context, task_id)
127            .await
128    }
129}