Skip to main content

systemprompt_agent/services/
message.rs

1use anyhow::{Result, anyhow};
2use serde_json::json;
3use uuid::Uuid;
4
5use crate::models::a2a::{Message, Part, TextPart};
6use crate::repository::context::message::PersistMessageWithTxParams;
7use crate::repository::task::TaskRepository;
8use systemprompt_database::{DatabaseProvider, DatabaseTransaction, DbPool};
9use systemprompt_identifiers::{ContextId, TaskId};
10use systemprompt_models::RequestContext;
11
12pub struct PersistMessageInTxParams<'a> {
13    pub tx: &'a mut dyn DatabaseTransaction,
14    pub message: &'a Message,
15    pub task_id: &'a TaskId,
16    pub context_id: &'a ContextId,
17    pub user_id: Option<&'a systemprompt_identifiers::UserId>,
18    pub session_id: &'a systemprompt_identifiers::SessionId,
19    pub trace_id: &'a systemprompt_identifiers::TraceId,
20}
21
22impl std::fmt::Debug for PersistMessageInTxParams<'_> {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        f.debug_struct("PersistMessageInTxParams")
25            .field("message", &self.message)
26            .field("task_id", &self.task_id)
27            .field("context_id", &self.context_id)
28            .field("user_id", &self.user_id)
29            .field("session_id", &self.session_id)
30            .field("trace_id", &self.trace_id)
31            .finish_non_exhaustive()
32    }
33}
34
35#[derive(Debug)]
36pub struct PersistMessagesParams<'a> {
37    pub task_id: &'a TaskId,
38    pub context_id: &'a ContextId,
39    pub messages: Vec<Message>,
40    pub user_id: Option<&'a systemprompt_identifiers::UserId>,
41    pub session_id: &'a systemprompt_identifiers::SessionId,
42    pub trace_id: &'a systemprompt_identifiers::TraceId,
43}
44
45#[derive(Debug)]
46pub struct CreateToolExecutionMessageParams<'a> {
47    pub task_id: &'a TaskId,
48    pub context_id: &'a ContextId,
49    pub tool_name: &'a str,
50    pub tool_args: &'a serde_json::Value,
51    pub request_context: &'a RequestContext,
52}
53
54pub struct MessageService {
55    task_repo: TaskRepository,
56}
57
58impl std::fmt::Debug for MessageService {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        f.debug_struct("MessageService").finish_non_exhaustive()
61    }
62}
63
64impl MessageService {
65    pub fn new(db_pool: &DbPool) -> Result<Self> {
66        Ok(Self {
67            task_repo: TaskRepository::new(db_pool)?,
68        })
69    }
70
71    pub async fn persist_message_in_tx(&self, params: PersistMessageInTxParams<'_>) -> Result<i32> {
72        let PersistMessageInTxParams {
73            tx,
74            message,
75            task_id,
76            context_id,
77            user_id,
78            session_id,
79            trace_id,
80        } = params;
81        let sequence_number = self
82            .task_repo
83            .get_next_sequence_number_in_tx(tx, task_id)
84            .await?;
85
86        self.task_repo
87            .persist_message_with_tx(PersistMessageWithTxParams {
88                tx,
89                message,
90                task_id,
91                context_id,
92                sequence_number,
93                user_id,
94                session_id,
95                trace_id,
96            })
97            .await
98            .map_err(|e| anyhow!("Failed to persist message: {}", e))?;
99
100        tracing::info!(
101            message_id = %message.id,
102            task_id = %task_id,
103            sequence_number = sequence_number,
104            "Message persisted"
105        );
106
107        Ok(sequence_number)
108    }
109
110    pub async fn persist_messages(&self, params: PersistMessagesParams<'_>) -> Result<Vec<i32>> {
111        let PersistMessagesParams {
112            task_id,
113            context_id,
114            messages,
115            user_id,
116            session_id,
117            trace_id,
118        } = params;
119
120        if messages.is_empty() {
121            return Ok(Vec::new());
122        }
123
124        let mut tx = self
125            .task_repo
126            .db_pool()
127            .as_ref()
128            .begin_transaction()
129            .await?;
130        let mut sequence_numbers = Vec::new();
131
132        tracing::info!(
133            task_id = %task_id,
134            message_count = messages.len(),
135            "Persisting multiple messages"
136        );
137
138        for message in messages {
139            let seq = self
140                .persist_message_in_tx(PersistMessageInTxParams {
141                    tx: &mut *tx,
142                    message: &message,
143                    task_id,
144                    context_id,
145                    user_id,
146                    session_id,
147                    trace_id,
148                })
149                .await?;
150            sequence_numbers.push(seq);
151        }
152
153        tx.commit().await?;
154
155        tracing::info!(
156            task_id = %task_id,
157            sequence_numbers = ?sequence_numbers,
158            "Messages persisted successfully"
159        );
160
161        Ok(sequence_numbers)
162    }
163
164    pub async fn create_tool_execution_message(
165        &self,
166        params: CreateToolExecutionMessageParams<'_>,
167    ) -> Result<(String, i32)> {
168        let CreateToolExecutionMessageParams {
169            task_id,
170            context_id,
171            tool_name,
172            tool_args,
173            request_context,
174        } = params;
175        let message_id = Uuid::new_v4().to_string();
176
177        let tool_args_display =
178            serde_json::to_string_pretty(tool_args).unwrap_or_else(|_| tool_args.to_string());
179
180        let timestamp = chrono::Utc::now().to_rfc3339();
181
182        let message = Message {
183            role: "user".to_string(),
184            id: message_id.clone().into(),
185            task_id: Some(task_id.clone()),
186            context_id: context_id.clone(),
187            kind: "message".to_string(),
188            parts: vec![Part::Text(TextPart {
189                text: format!(
190                    "Executed MCP tool: {} with arguments:\n{}\n\nExecution ID: {} at {}",
191                    tool_name,
192                    tool_args_display,
193                    task_id.as_str(),
194                    timestamp
195                ),
196            })],
197            metadata: Some(json!({
198                "source": "mcp_direct_call",
199                "tool_name": tool_name,
200                "is_synthetic": true,
201                "tool_args": tool_args,
202                "execution_timestamp": timestamp,
203            })),
204            extensions: None,
205            reference_task_ids: None,
206        };
207
208        let mut tx = self
209            .task_repo
210            .db_pool()
211            .as_ref()
212            .begin_transaction()
213            .await?;
214
215        let sequence_number = self
216            .persist_message_in_tx(PersistMessageInTxParams {
217                tx: &mut *tx,
218                message: &message,
219                task_id,
220                context_id,
221                user_id: Some(request_context.user_id()),
222                session_id: request_context.session_id(),
223                trace_id: request_context.trace_id(),
224            })
225            .await?;
226
227        tx.commit().await?;
228
229        tracing::info!(
230            message_id = %message_id,
231            task_id = %task_id,
232            tool_name = %tool_name,
233            sequence_number = sequence_number,
234            "Created synthetic tool execution message"
235        );
236
237        Ok((message_id, sequence_number))
238    }
239}