1mod request;
2mod stats;
3
4pub use request::InsertFileRequest;
5pub use stats::FileStats;
6
7use std::sync::Arc;
8
9use anyhow::{Context, Result};
10use chrono::Utc;
11use sqlx::PgPool;
12use systemprompt_database::DbPool;
13use systemprompt_identifiers::{ContextId, FileId, SessionId, TraceId, UserId};
14
15use crate::models::{File, FileMetadata};
16
17#[derive(Debug, Clone)]
18pub struct FileRepository {
19 pub(crate) pool: Arc<PgPool>,
20}
21
22impl FileRepository {
23 pub fn new(db: &DbPool) -> Result<Self> {
24 let pool = db.pool_arc()?;
25 Ok(Self { pool })
26 }
27
28 pub async fn insert(&self, request: InsertFileRequest) -> Result<FileId> {
29 let id_uuid = uuid::Uuid::parse_str(request.id.as_str())
30 .with_context(|| format!("Invalid UUID for file id: {}", request.id.as_str()))?;
31 let now = Utc::now();
32
33 let user_id_str = request.user_id.as_ref().map(UserId::as_str);
34 let session_id_str = request.session_id.as_ref().map(SessionId::as_str);
35 let trace_id_str = request.trace_id.as_ref().map(TraceId::as_str);
36 let context_id_str = request.context_id.as_ref().map(ContextId::as_str);
37
38 sqlx::query_as!(
39 File,
40 r#"
41 INSERT INTO files (id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id, session_id, trace_id, context_id, created_at, updated_at)
42 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $12)
43 ON CONFLICT (path) DO UPDATE SET
44 public_url = EXCLUDED.public_url,
45 mime_type = EXCLUDED.mime_type,
46 size_bytes = EXCLUDED.size_bytes,
47 ai_content = EXCLUDED.ai_content,
48 metadata = EXCLUDED.metadata,
49 updated_at = EXCLUDED.updated_at
50 RETURNING id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id as "user_id: UserId", session_id as "session_id: SessionId", trace_id as "trace_id: TraceId", context_id as "context_id: ContextId", created_at, updated_at, deleted_at
51 "#,
52 id_uuid,
53 request.path,
54 request.public_url,
55 request.mime_type,
56 request.size_bytes,
57 request.ai_content,
58 request.metadata,
59 user_id_str,
60 session_id_str,
61 trace_id_str,
62 context_id_str,
63 now
64 )
65 .fetch_one(self.pool.as_ref())
66 .await
67 .with_context(|| {
68 format!(
69 "Failed to insert file (id: {}, path: {}, url: {})",
70 request.id.as_str(),
71 request.path,
72 request.public_url
73 )
74 })?;
75
76 Ok(request.id)
77 }
78
79 pub async fn insert_file(&self, file: &File) -> Result<FileId> {
80 let file_id = FileId::new(file.id.to_string());
81
82 let mut request = InsertFileRequest::new(
83 file_id.clone(),
84 file.path.clone(),
85 file.public_url.clone(),
86 file.mime_type.clone(),
87 )
88 .with_ai_content(file.ai_content)
89 .with_metadata(file.metadata.clone());
90
91 if let Some(size) = file.size_bytes {
92 request = request.with_size(size);
93 }
94
95 if let Some(ref user_id) = file.user_id {
96 request = request.with_user_id(user_id.clone());
97 }
98
99 if let Some(ref session_id) = file.session_id {
100 request = request.with_session_id(session_id.clone());
101 }
102
103 if let Some(ref trace_id) = file.trace_id {
104 request = request.with_trace_id(trace_id.clone());
105 }
106
107 if let Some(ref context_id) = file.context_id {
108 request = request.with_context_id(context_id.clone());
109 }
110
111 self.insert(request).await
112 }
113
114 pub async fn find_by_id(&self, id: &FileId) -> Result<Option<File>> {
115 let id_uuid = uuid::Uuid::parse_str(id.as_str()).context("Invalid UUID for file id")?;
116
117 sqlx::query_as!(
118 File,
119 r#"
120 SELECT id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id as "user_id: UserId", session_id as "session_id: SessionId", trace_id as "trace_id: TraceId", context_id as "context_id: ContextId", created_at, updated_at, deleted_at
121 FROM files
122 WHERE id = $1 AND deleted_at IS NULL
123 "#,
124 id_uuid
125 )
126 .fetch_optional(self.pool.as_ref())
127 .await
128 .context(format!("Failed to find file by id: {id}"))
129 }
130
131 pub async fn find_by_path(&self, path: &str) -> Result<Option<File>> {
132 sqlx::query_as!(
133 File,
134 r#"
135 SELECT id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id as "user_id: UserId", session_id as "session_id: SessionId", trace_id as "trace_id: TraceId", context_id as "context_id: ContextId", created_at, updated_at, deleted_at
136 FROM files
137 WHERE path = $1 AND deleted_at IS NULL
138 "#,
139 path
140 )
141 .fetch_optional(self.pool.as_ref())
142 .await
143 .context(format!("Failed to find file by path: {path}"))
144 }
145
146 pub async fn list_by_user(
147 &self,
148 user_id: &UserId,
149 limit: i64,
150 offset: i64,
151 ) -> Result<Vec<File>> {
152 let user_id_str = user_id.as_str();
153 sqlx::query_as!(
154 File,
155 r#"
156 SELECT id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id as "user_id: UserId", session_id as "session_id: SessionId", trace_id as "trace_id: TraceId", context_id as "context_id: ContextId", created_at, updated_at, deleted_at
157 FROM files
158 WHERE user_id = $1 AND deleted_at IS NULL
159 ORDER BY created_at DESC
160 LIMIT $2 OFFSET $3
161 "#,
162 user_id_str,
163 limit,
164 offset
165 )
166 .fetch_all(self.pool.as_ref())
167 .await
168 .context(format!("Failed to list files for user: {user_id}"))
169 }
170
171 pub async fn list_all(&self, limit: i64, offset: i64) -> Result<Vec<File>> {
172 sqlx::query_as!(
173 File,
174 r#"
175 SELECT id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id as "user_id: UserId", session_id as "session_id: SessionId", trace_id as "trace_id: TraceId", context_id as "context_id: ContextId", created_at, updated_at, deleted_at
176 FROM files
177 WHERE deleted_at IS NULL
178 ORDER BY created_at DESC
179 LIMIT $1 OFFSET $2
180 "#,
181 limit,
182 offset
183 )
184 .fetch_all(self.pool.as_ref())
185 .await
186 .context("Failed to list all files")
187 }
188
189 pub async fn delete(&self, id: &FileId) -> Result<()> {
190 let id_uuid = uuid::Uuid::parse_str(id.as_str()).context("Invalid UUID for file id")?;
191
192 sqlx::query!(
193 r#"
194 DELETE FROM files
195 WHERE id = $1
196 "#,
197 id_uuid
198 )
199 .execute(self.pool.as_ref())
200 .await
201 .context(format!("Failed to delete file: {id}"))?;
202
203 Ok(())
204 }
205
206 pub async fn update_metadata(&self, id: &FileId, metadata: &FileMetadata) -> Result<()> {
207 let id_uuid = uuid::Uuid::parse_str(id.as_str()).context("Invalid UUID for file id")?;
208 let metadata_json = serde_json::to_value(metadata)?;
209 let now = Utc::now();
210
211 sqlx::query!(
212 r#"
213 UPDATE files
214 SET metadata = $1, updated_at = $2
215 WHERE id = $3
216 "#,
217 metadata_json,
218 now,
219 id_uuid
220 )
221 .execute(self.pool.as_ref())
222 .await
223 .context(format!("Failed to update metadata for file: {id}"))?;
224
225 Ok(())
226 }
227
228 pub async fn search_by_path(&self, query: &str, limit: i64) -> Result<Vec<File>> {
229 let pattern = format!("%{query}%");
230 sqlx::query_as!(
231 File,
232 r#"
233 SELECT id, path, public_url, mime_type, size_bytes, ai_content, metadata,
234 user_id as "user_id: UserId", session_id as "session_id: SessionId",
235 trace_id as "trace_id: TraceId", context_id as "context_id: ContextId",
236 created_at, updated_at, deleted_at
237 FROM files
238 WHERE path ILIKE $1 AND deleted_at IS NULL
239 ORDER BY created_at DESC
240 LIMIT $2
241 "#,
242 pattern,
243 limit
244 )
245 .fetch_all(self.pool.as_ref())
246 .await
247 .context(format!("Failed to search files by path: {query}"))
248 }
249}