systemprompt_agent/repository/context/message/
mod.rs1mod parts;
2mod persistence;
3mod queries;
4
5use sqlx::PgPool;
6use std::sync::Arc;
7use systemprompt_database::DbPool;
8use systemprompt_identifiers::{ContextId, SessionId, TaskId, TraceId, UserId};
9use systemprompt_traits::RepositoryError;
10
11use crate::models::a2a::Message;
12
13pub use parts::{get_message_parts, FileUploadContext};
14pub use persistence::{persist_message_sqlx, persist_message_with_tx};
15pub use queries::{
16 get_messages_by_context, get_messages_by_task, get_next_sequence_number,
17 get_next_sequence_number_in_tx, get_next_sequence_number_sqlx,
18};
19
20#[derive(Debug, Clone)]
21pub struct MessageRepository {
22 pool: Arc<PgPool>,
23}
24
25impl MessageRepository {
26 pub fn new(db_pool: DbPool) -> Result<Self, RepositoryError> {
27 let pool = db_pool.as_ref().get_postgres_pool().ok_or_else(|| {
28 RepositoryError::InvalidData("PostgreSQL pool not available".to_string())
29 })?;
30 Ok(Self { pool })
31 }
32
33 pub async fn get_messages_by_task(
34 &self,
35 task_id: &TaskId,
36 ) -> Result<Vec<Message>, RepositoryError> {
37 get_messages_by_task(&self.pool, task_id).await
38 }
39
40 pub async fn get_messages_by_context(
41 &self,
42 context_id: &ContextId,
43 ) -> Result<Vec<Message>, RepositoryError> {
44 get_messages_by_context(&self.pool, context_id).await
45 }
46
47 pub async fn get_next_sequence_number(&self, task_id: &TaskId) -> Result<i32, RepositoryError> {
48 get_next_sequence_number(&self.pool, task_id).await
49 }
50
51 pub async fn persist_message_sqlx(
52 &self,
53 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
54 message: &Message,
55 task_id: &TaskId,
56 context_id: &ContextId,
57 sequence_number: i32,
58 user_id: Option<&UserId>,
59 session_id: &SessionId,
60 trace_id: &TraceId,
61 upload_ctx: Option<&FileUploadContext<'_>>,
62 ) -> Result<(), RepositoryError> {
63 persist_message_sqlx(
64 tx,
65 message,
66 task_id,
67 context_id,
68 sequence_number,
69 user_id,
70 session_id,
71 trace_id,
72 upload_ctx,
73 )
74 .await
75 }
76
77 pub async fn persist_message_with_tx(
78 &self,
79 tx: &mut dyn systemprompt_database::DatabaseTransaction,
80 message: &Message,
81 task_id: &TaskId,
82 context_id: &ContextId,
83 sequence_number: i32,
84 user_id: Option<&UserId>,
85 session_id: &SessionId,
86 trace_id: &TraceId,
87 ) -> Result<(), RepositoryError> {
88 persist_message_with_tx(
89 tx,
90 message,
91 task_id,
92 context_id,
93 sequence_number,
94 user_id,
95 session_id,
96 trace_id,
97 )
98 .await
99 }
100}