// systemprompt_agent/services/a2a_server/processing/message_validation.rs
use anyhow::{anyhow, Context, Result};
use systemprompt_database::DbPool;
use systemprompt_identifiers::{ContextId, TaskId};
use systemprompt_models::RequestContext;

use crate::models::{AgentRuntimeInfo, Message};
use crate::repository::context::ContextRepository;
use crate::services::registry::AgentRegistry;

10#[derive(Clone, Debug)]
11pub struct ValidatedMessageRequest {
12 pub message: Message,
13 pub agent_name: String,
14 pub context_id: ContextId,
15 pub task_id: TaskId,
16 pub agent_runtime: AgentRuntimeInfo,
17 pub has_tools: bool,
18}
19
20#[derive(Debug)]
21pub struct MessageValidationService {
22 db_pool: DbPool,
23}
24
25impl MessageValidationService {
26 pub const fn new(db_pool: DbPool) -> Self {
27 Self { db_pool }
28 }
29
30 pub async fn validate_message_request(
31 &self,
32 message: &Message,
33 agent_name: &str,
34 context: &RequestContext,
35 ) -> Result<ValidatedMessageRequest> {
36 self.validate_message_format(message)?;
37
38 let agent_runtime = self.load_agent_runtime(agent_name).await?;
39
40 self.validate_context_ownership(message, context).await?;
41
42 let task_id = self.determine_task_id(message);
43
44 let has_tools = !agent_runtime.mcp_servers.is_empty();
45
46 Ok(ValidatedMessageRequest {
47 message: message.clone(),
48 agent_name: agent_name.to_string(),
49 context_id: message.context_id.clone(),
50 task_id,
51 agent_runtime,
52 has_tools,
53 })
54 }
55
56 async fn load_agent_runtime(&self, agent_name: &str) -> Result<AgentRuntimeInfo> {
57 let registry = AgentRegistry::new().await?;
58 let agent_config = registry
59 .get_agent(agent_name)
60 .await
61 .map_err(|_| anyhow!("Agent not found: {}", agent_name))?;
62
63 Ok(agent_config.into())
64 }
65
66 async fn validate_context_ownership(
67 &self,
68 message: &Message,
69 context: &RequestContext,
70 ) -> Result<()> {
71 let context_repo = ContextRepository::new(self.db_pool.clone());
72
73 context_repo
74 .get_context(&message.context_id, context.user_id())
75 .await
76 .map_err(|e| {
77 anyhow!(
78 "Context validation failed - context_id: {}, user_id: {}, error: {}",
79 message.context_id,
80 context.user_id(),
81 e
82 )
83 })?;
84
85 Ok(())
86 }
87
88 fn validate_message_format(&self, message: &Message) -> Result<()> {
89 let has_text_part = message
90 .parts
91 .iter()
92 .any(|part| matches!(part, crate::models::Part::Text(_)));
93
94 if !has_text_part {
95 return Err(anyhow!("No text content found in message"));
96 }
97
98 Ok(())
99 }
100
101 fn determine_task_id(&self, message: &Message) -> TaskId {
102 message
103 .task_id
104 .clone()
105 .unwrap_or_else(|| TaskId::new(uuid::Uuid::new_v4().to_string()))
106 }
107}