Skip to main content

systemprompt_agent/services/a2a_server/processing/
message_validation.rs

1use anyhow::{anyhow, Result};
2use systemprompt_database::DbPool;
3use systemprompt_identifiers::{ContextId, TaskId};
4use systemprompt_models::RequestContext;
5
6use crate::models::{AgentRuntimeInfo, Message};
7use crate::repository::context::ContextRepository;
8use crate::services::registry::AgentRegistry;
9
10#[derive(Clone, Debug)]
11pub struct ValidatedMessageRequest {
12    pub message: Message,
13    pub agent_name: String,
14    pub context_id: ContextId,
15    pub task_id: TaskId,
16    pub agent_runtime: AgentRuntimeInfo,
17    pub has_tools: bool,
18}
19
20#[derive(Debug)]
21pub struct MessageValidationService {
22    db_pool: DbPool,
23}
24
25impl MessageValidationService {
26    pub const fn new(db_pool: DbPool) -> Self {
27        Self { db_pool }
28    }
29
30    pub async fn validate_message_request(
31        &self,
32        message: &Message,
33        agent_name: &str,
34        context: &RequestContext,
35    ) -> Result<ValidatedMessageRequest> {
36        self.validate_message_format(message)?;
37
38        let agent_runtime = self.load_agent_runtime(agent_name).await?;
39
40        self.validate_context_ownership(message, context).await?;
41
42        let task_id = self.determine_task_id(message);
43
44        let has_tools = !agent_runtime.mcp_servers.is_empty();
45
46        Ok(ValidatedMessageRequest {
47            message: message.clone(),
48            agent_name: agent_name.to_string(),
49            context_id: message.context_id.clone(),
50            task_id,
51            agent_runtime,
52            has_tools,
53        })
54    }
55
56    async fn load_agent_runtime(&self, agent_name: &str) -> Result<AgentRuntimeInfo> {
57        let registry = AgentRegistry::new().await?;
58        let agent_config = registry
59            .get_agent(agent_name)
60            .await
61            .map_err(|_| anyhow!("Agent not found: {}", agent_name))?;
62
63        Ok(agent_config.into())
64    }
65
66    async fn validate_context_ownership(
67        &self,
68        message: &Message,
69        context: &RequestContext,
70    ) -> Result<()> {
71        let context_repo = ContextRepository::new(self.db_pool.clone());
72
73        context_repo
74            .get_context(&message.context_id, context.user_id())
75            .await
76            .map_err(|e| {
77                anyhow!(
78                    "Context validation failed - context_id: {}, user_id: {}, error: {}",
79                    message.context_id,
80                    context.user_id(),
81                    e
82                )
83            })?;
84
85        Ok(())
86    }
87
88    fn validate_message_format(&self, message: &Message) -> Result<()> {
89        let has_text_part = message
90            .parts
91            .iter()
92            .any(|part| matches!(part, crate::models::Part::Text(_)));
93
94        if !has_text_part {
95            return Err(anyhow!("No text content found in message"));
96        }
97
98        Ok(())
99    }
100
101    fn determine_task_id(&self, message: &Message) -> TaskId {
102        message
103            .task_id
104            .clone()
105            .unwrap_or_else(|| TaskId::new(uuid::Uuid::new_v4().to_string()))
106    }
107}