systemprompt_agent/services/
message.rs1use anyhow::{anyhow, Result};
2use serde_json::json;
3use uuid::Uuid;
4
5use crate::models::a2a::{Message, Part, TextPart};
6use crate::repository::task::TaskRepository;
7use systemprompt_database::{DatabaseProvider, DatabaseTransaction, DbPool};
8use systemprompt_identifiers::{ContextId, TaskId};
9use systemprompt_models::RequestContext;
10
11pub struct MessageService {
12 task_repo: TaskRepository,
13}
14
15impl std::fmt::Debug for MessageService {
16 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17 f.debug_struct("MessageService").finish_non_exhaustive()
18 }
19}
20
21impl MessageService {
22 pub fn new(db_pool: &DbPool) -> Result<Self> {
23 Ok(Self {
24 task_repo: TaskRepository::new(db_pool)?,
25 })
26 }
27
28 pub async fn persist_message_in_tx(
29 &self,
30 tx: &mut dyn DatabaseTransaction,
31 message: &Message,
32 task_id: &TaskId,
33 context_id: &ContextId,
34 user_id: Option<&systemprompt_identifiers::UserId>,
35 session_id: &systemprompt_identifiers::SessionId,
36 trace_id: &systemprompt_identifiers::TraceId,
37 ) -> Result<i32> {
38 let sequence_number = self
39 .task_repo
40 .get_next_sequence_number_in_tx(tx, task_id)
41 .await?;
42
43 self.task_repo
44 .persist_message_with_tx(
45 tx,
46 message,
47 task_id,
48 context_id,
49 sequence_number,
50 user_id,
51 session_id,
52 trace_id,
53 )
54 .await
55 .map_err(|e| anyhow!("Failed to persist message: {}", e))?;
56
57 tracing::info!(
58 message_id = %message.id,
59 task_id = %task_id,
60 sequence_number = sequence_number,
61 "Message persisted"
62 );
63
64 Ok(sequence_number)
65 }
66
67 pub async fn persist_messages(
68 &self,
69 task_id: &TaskId,
70 context_id: &ContextId,
71 messages: Vec<Message>,
72 user_id: Option<&systemprompt_identifiers::UserId>,
73 session_id: &systemprompt_identifiers::SessionId,
74 trace_id: &systemprompt_identifiers::TraceId,
75 ) -> Result<Vec<i32>> {
76 if messages.is_empty() {
77 return Ok(Vec::new());
78 }
79
80 let mut tx = self
81 .task_repo
82 .db_pool()
83 .as_ref()
84 .begin_transaction()
85 .await?;
86 let mut sequence_numbers = Vec::new();
87
88 tracing::info!(
89 task_id = %task_id,
90 message_count = messages.len(),
91 "Persisting multiple messages"
92 );
93
94 for message in messages {
95 let seq = self
96 .persist_message_in_tx(
97 &mut *tx, &message, task_id, context_id, user_id, session_id, trace_id,
98 )
99 .await?;
100 sequence_numbers.push(seq);
101 }
102
103 tx.commit().await?;
104
105 tracing::info!(
106 task_id = %task_id,
107 sequence_numbers = ?sequence_numbers,
108 "Messages persisted successfully"
109 );
110
111 Ok(sequence_numbers)
112 }
113
114 pub async fn create_tool_execution_message(
115 &self,
116 task_id: &TaskId,
117 context_id: &ContextId,
118 tool_name: &str,
119 tool_args: &serde_json::Value,
120 request_context: &RequestContext,
121 ) -> Result<(String, i32)> {
122 let message_id = Uuid::new_v4().to_string();
123
124 let tool_args_display =
125 serde_json::to_string_pretty(tool_args).unwrap_or_else(|_| tool_args.to_string());
126
127 let timestamp = chrono::Utc::now().to_rfc3339();
128
129 let message = Message {
130 role: "user".to_string(),
131 id: message_id.clone().into(),
132 task_id: Some(task_id.clone()),
133 context_id: context_id.clone(),
134 kind: "message".to_string(),
135 parts: vec![Part::Text(TextPart {
136 text: format!(
137 "Executed MCP tool: {} with arguments:\n{}\n\nExecution ID: {} at {}",
138 tool_name,
139 tool_args_display,
140 task_id.as_str(),
141 timestamp
142 ),
143 })],
144 metadata: Some(json!({
145 "source": "mcp_direct_call",
146 "tool_name": tool_name,
147 "is_synthetic": true,
148 "tool_args": tool_args,
149 "execution_timestamp": timestamp,
150 })),
151 extensions: None,
152 reference_task_ids: None,
153 };
154
155 let mut tx = self
156 .task_repo
157 .db_pool()
158 .as_ref()
159 .begin_transaction()
160 .await?;
161
162 let sequence_number = self
163 .persist_message_in_tx(
164 &mut *tx,
165 &message,
166 task_id,
167 context_id,
168 Some(request_context.user_id()),
169 request_context.session_id(),
170 request_context.trace_id(),
171 )
172 .await?;
173
174 tx.commit().await?;
175
176 tracing::info!(
177 message_id = %message_id,
178 task_id = %task_id,
179 tool_name = %tool_name,
180 sequence_number = sequence_number,
181 "Created synthetic tool execution message"
182 );
183
184 Ok((message_id, sequence_number))
185 }
186}