1mod request;
2mod stats;
3
4pub use request::InsertFileRequest;
5pub use stats::FileStats;
6
7use std::sync::Arc;
8
9use anyhow::{Context, Result};
10use chrono::Utc;
11use sqlx::PgPool;
12use systemprompt_database::DbPool;
13use systemprompt_identifiers::{ContextId, FileId, SessionId, TraceId, UserId};
14
15use crate::models::{File, FileMetadata};
16
17#[derive(Debug, Clone)]
18pub struct FileRepository {
19 pub(crate) pool: Arc<PgPool>,
20 write_pool: Arc<PgPool>,
21}
22
23impl FileRepository {
24 pub fn new(db: &DbPool) -> Result<Self> {
25 let pool = db.pool_arc()?;
26 let write_pool = db.write_pool_arc()?;
27 Ok(Self { pool, write_pool })
28 }
29
30 pub async fn insert(&self, request: InsertFileRequest) -> Result<FileId> {
31 let id_uuid = uuid::Uuid::parse_str(request.id.as_str())
32 .with_context(|| format!("Invalid UUID for file id: {}", request.id.as_str()))?;
33 let now = Utc::now();
34
35 let user_id_str = request.user_id.as_ref().map(UserId::as_str);
36 let session_id_str = request.session_id.as_ref().map(SessionId::as_str);
37 let trace_id_str = request.trace_id.as_ref().map(TraceId::as_str);
38 let context_id_str = request.context_id.as_ref().map(ContextId::as_str);
39
40 sqlx::query_as!(
41 File,
42 r#"
43 INSERT INTO files (id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id, session_id, trace_id, context_id, created_at, updated_at)
44 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $12)
45 ON CONFLICT (path) DO UPDATE SET
46 public_url = EXCLUDED.public_url,
47 mime_type = EXCLUDED.mime_type,
48 size_bytes = EXCLUDED.size_bytes,
49 ai_content = EXCLUDED.ai_content,
50 metadata = EXCLUDED.metadata,
51 updated_at = EXCLUDED.updated_at
52 RETURNING id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id as "user_id: UserId", session_id as "session_id: SessionId", trace_id as "trace_id: TraceId", context_id as "context_id: ContextId", created_at, updated_at, deleted_at
53 "#,
54 id_uuid,
55 request.path,
56 request.public_url,
57 request.mime_type,
58 request.size_bytes,
59 request.ai_content,
60 request.metadata,
61 user_id_str,
62 session_id_str,
63 trace_id_str,
64 context_id_str,
65 now
66 )
67 .fetch_one(&*self.write_pool)
68 .await
69 .with_context(|| {
70 format!(
71 "Failed to insert file (id: {}, path: {}, url: {})",
72 request.id.as_str(),
73 request.path,
74 request.public_url
75 )
76 })?;
77
78 Ok(request.id)
79 }
80
81 pub async fn insert_file(&self, file: &File) -> Result<FileId> {
82 let file_id = FileId::new(file.id.to_string());
83
84 let mut request = InsertFileRequest::new(
85 file_id.clone(),
86 file.path.clone(),
87 file.public_url.clone(),
88 file.mime_type.clone(),
89 )
90 .with_ai_content(file.ai_content)
91 .with_metadata(file.metadata.clone());
92
93 if let Some(size) = file.size_bytes {
94 request = request.with_size(size);
95 }
96
97 if let Some(ref user_id) = file.user_id {
98 request = request.with_user_id(user_id.clone());
99 }
100
101 if let Some(ref session_id) = file.session_id {
102 request = request.with_session_id(session_id.clone());
103 }
104
105 if let Some(ref trace_id) = file.trace_id {
106 request = request.with_trace_id(trace_id.clone());
107 }
108
109 if let Some(ref context_id) = file.context_id {
110 request = request.with_context_id(context_id.clone());
111 }
112
113 self.insert(request).await
114 }
115
116 pub async fn find_by_id(&self, id: &FileId) -> Result<Option<File>> {
117 let id_uuid = uuid::Uuid::parse_str(id.as_str()).context("Invalid UUID for file id")?;
118
119 sqlx::query_as!(
120 File,
121 r#"
122 SELECT id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id as "user_id: UserId", session_id as "session_id: SessionId", trace_id as "trace_id: TraceId", context_id as "context_id: ContextId", created_at, updated_at, deleted_at
123 FROM files
124 WHERE id = $1 AND deleted_at IS NULL
125 "#,
126 id_uuid
127 )
128 .fetch_optional(&*self.pool)
129 .await
130 .context(format!("Failed to find file by id: {id}"))
131 }
132
133 pub async fn find_by_path(&self, path: &str) -> Result<Option<File>> {
134 sqlx::query_as!(
135 File,
136 r#"
137 SELECT id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id as "user_id: UserId", session_id as "session_id: SessionId", trace_id as "trace_id: TraceId", context_id as "context_id: ContextId", created_at, updated_at, deleted_at
138 FROM files
139 WHERE path = $1 AND deleted_at IS NULL
140 "#,
141 path
142 )
143 .fetch_optional(&*self.pool)
144 .await
145 .context(format!("Failed to find file by path: {path}"))
146 }
147
148 pub async fn list_by_user(
149 &self,
150 user_id: &UserId,
151 limit: i64,
152 offset: i64,
153 ) -> Result<Vec<File>> {
154 let user_id_str = user_id.as_str();
155 sqlx::query_as!(
156 File,
157 r#"
158 SELECT id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id as "user_id: UserId", session_id as "session_id: SessionId", trace_id as "trace_id: TraceId", context_id as "context_id: ContextId", created_at, updated_at, deleted_at
159 FROM files
160 WHERE user_id = $1 AND deleted_at IS NULL
161 ORDER BY created_at DESC
162 LIMIT $2 OFFSET $3
163 "#,
164 user_id_str,
165 limit,
166 offset
167 )
168 .fetch_all(&*self.pool)
169 .await
170 .context(format!("Failed to list files for user: {user_id}"))
171 }
172
173 pub async fn list_all(&self, limit: i64, offset: i64) -> Result<Vec<File>> {
174 sqlx::query_as!(
175 File,
176 r#"
177 SELECT id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id as "user_id: UserId", session_id as "session_id: SessionId", trace_id as "trace_id: TraceId", context_id as "context_id: ContextId", created_at, updated_at, deleted_at
178 FROM files
179 WHERE deleted_at IS NULL
180 ORDER BY created_at DESC
181 LIMIT $1 OFFSET $2
182 "#,
183 limit,
184 offset
185 )
186 .fetch_all(&*self.pool)
187 .await
188 .context("Failed to list all files")
189 }
190
191 pub async fn delete(&self, id: &FileId) -> Result<()> {
192 let id_uuid = uuid::Uuid::parse_str(id.as_str()).context("Invalid UUID for file id")?;
193
194 sqlx::query!(
195 r#"
196 DELETE FROM files
197 WHERE id = $1
198 "#,
199 id_uuid
200 )
201 .execute(&*self.write_pool)
202 .await
203 .context(format!("Failed to delete file: {id}"))?;
204
205 Ok(())
206 }
207
208 pub async fn update_metadata(&self, id: &FileId, metadata: &FileMetadata) -> Result<()> {
209 let id_uuid = uuid::Uuid::parse_str(id.as_str()).context("Invalid UUID for file id")?;
210 let metadata_json = serde_json::to_value(metadata)?;
211 let now = Utc::now();
212
213 sqlx::query!(
214 r#"
215 UPDATE files
216 SET metadata = $1, updated_at = $2
217 WHERE id = $3
218 "#,
219 metadata_json,
220 now,
221 id_uuid
222 )
223 .execute(&*self.write_pool)
224 .await
225 .context(format!("Failed to update metadata for file: {id}"))?;
226
227 Ok(())
228 }
229
230 pub async fn search_by_path(&self, query: &str, limit: i64) -> Result<Vec<File>> {
231 let pattern = format!("%{query}%");
232 sqlx::query_as!(
233 File,
234 r#"
235 SELECT id, path, public_url, mime_type, size_bytes, ai_content, metadata,
236 user_id as "user_id: UserId", session_id as "session_id: SessionId",
237 trace_id as "trace_id: TraceId", context_id as "context_id: ContextId",
238 created_at, updated_at, deleted_at
239 FROM files
240 WHERE path ILIKE $1 AND deleted_at IS NULL
241 ORDER BY created_at DESC
242 LIMIT $2
243 "#,
244 pattern,
245 limit
246 )
247 .fetch_all(&*self.pool)
248 .await
249 .context(format!("Failed to search files by path: {query}"))
250 }
251}