Skip to main content

systemprompt_files/repository/file/mod.rs

1mod request;
2mod stats;
3
4pub use request::InsertFileRequest;
5pub use stats::FileStats;
6
7use std::sync::Arc;
8
9use anyhow::{Context, Result};
10use chrono::Utc;
11use sqlx::PgPool;
12use systemprompt_database::DbPool;
13use systemprompt_identifiers::{ContextId, FileId, SessionId, TraceId, UserId};
14
15use crate::models::{File, FileMetadata};
16
17#[derive(Debug, Clone)]
18pub struct FileRepository {
19    pub(crate) pool: Arc<PgPool>,
20    write_pool: Arc<PgPool>,
21}
22
23impl FileRepository {
24    pub fn new(db: &DbPool) -> Result<Self> {
25        let pool = db.pool_arc()?;
26        let write_pool = db.write_pool_arc()?;
27        Ok(Self { pool, write_pool })
28    }
29
30    pub async fn insert(&self, request: InsertFileRequest) -> Result<FileId> {
31        let id_uuid = uuid::Uuid::parse_str(request.id.as_str())
32            .with_context(|| format!("Invalid UUID for file id: {}", request.id.as_str()))?;
33        let now = Utc::now();
34
35        let user_id_str = request.user_id.as_ref().map(UserId::as_str);
36        let session_id_str = request.session_id.as_ref().map(SessionId::as_str);
37        let trace_id_str = request.trace_id.as_ref().map(TraceId::as_str);
38        let context_id_str = request.context_id.as_ref().map(ContextId::as_str);
39
40        sqlx::query_as!(
41            File,
42            r#"
43            INSERT INTO files (id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id, session_id, trace_id, context_id, created_at, updated_at)
44            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $12)
45            ON CONFLICT (path) DO UPDATE SET
46                public_url = EXCLUDED.public_url,
47                mime_type = EXCLUDED.mime_type,
48                size_bytes = EXCLUDED.size_bytes,
49                ai_content = EXCLUDED.ai_content,
50                metadata = EXCLUDED.metadata,
51                updated_at = EXCLUDED.updated_at
52            RETURNING id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id as "user_id: UserId", session_id as "session_id: SessionId", trace_id as "trace_id: TraceId", context_id as "context_id: ContextId", created_at, updated_at, deleted_at
53            "#,
54            id_uuid,
55            request.path,
56            request.public_url,
57            request.mime_type,
58            request.size_bytes,
59            request.ai_content,
60            request.metadata,
61            user_id_str,
62            session_id_str,
63            trace_id_str,
64            context_id_str,
65            now
66        )
67        .fetch_one(&*self.write_pool)
68        .await
69        .with_context(|| {
70            format!(
71                "Failed to insert file (id: {}, path: {}, url: {})",
72                request.id.as_str(),
73                request.path,
74                request.public_url
75            )
76        })?;
77
78        Ok(request.id)
79    }
80
81    pub async fn insert_file(&self, file: &File) -> Result<FileId> {
82        let file_id = FileId::new(file.id.to_string());
83
84        let mut request = InsertFileRequest::new(
85            file_id.clone(),
86            file.path.clone(),
87            file.public_url.clone(),
88            file.mime_type.clone(),
89        )
90        .with_ai_content(file.ai_content)
91        .with_metadata(file.metadata.clone());
92
93        if let Some(size) = file.size_bytes {
94            request = request.with_size(size);
95        }
96
97        if let Some(ref user_id) = file.user_id {
98            request = request.with_user_id(user_id.clone());
99        }
100
101        if let Some(ref session_id) = file.session_id {
102            request = request.with_session_id(session_id.clone());
103        }
104
105        if let Some(ref trace_id) = file.trace_id {
106            request = request.with_trace_id(trace_id.clone());
107        }
108
109        if let Some(ref context_id) = file.context_id {
110            request = request.with_context_id(context_id.clone());
111        }
112
113        self.insert(request).await
114    }
115
116    pub async fn find_by_id(&self, id: &FileId) -> Result<Option<File>> {
117        let id_uuid = uuid::Uuid::parse_str(id.as_str()).context("Invalid UUID for file id")?;
118
119        sqlx::query_as!(
120            File,
121            r#"
122            SELECT id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id as "user_id: UserId", session_id as "session_id: SessionId", trace_id as "trace_id: TraceId", context_id as "context_id: ContextId", created_at, updated_at, deleted_at
123            FROM files
124            WHERE id = $1 AND deleted_at IS NULL
125            "#,
126            id_uuid
127        )
128        .fetch_optional(&*self.pool)
129        .await
130        .context(format!("Failed to find file by id: {id}"))
131    }
132
133    pub async fn find_by_path(&self, path: &str) -> Result<Option<File>> {
134        sqlx::query_as!(
135            File,
136            r#"
137            SELECT id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id as "user_id: UserId", session_id as "session_id: SessionId", trace_id as "trace_id: TraceId", context_id as "context_id: ContextId", created_at, updated_at, deleted_at
138            FROM files
139            WHERE path = $1 AND deleted_at IS NULL
140            "#,
141            path
142        )
143        .fetch_optional(&*self.pool)
144        .await
145        .context(format!("Failed to find file by path: {path}"))
146    }
147
148    pub async fn list_by_user(
149        &self,
150        user_id: &UserId,
151        limit: i64,
152        offset: i64,
153    ) -> Result<Vec<File>> {
154        let user_id_str = user_id.as_str();
155        sqlx::query_as!(
156            File,
157            r#"
158            SELECT id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id as "user_id: UserId", session_id as "session_id: SessionId", trace_id as "trace_id: TraceId", context_id as "context_id: ContextId", created_at, updated_at, deleted_at
159            FROM files
160            WHERE user_id = $1 AND deleted_at IS NULL
161            ORDER BY created_at DESC
162            LIMIT $2 OFFSET $3
163            "#,
164            user_id_str,
165            limit,
166            offset
167        )
168        .fetch_all(&*self.pool)
169        .await
170        .context(format!("Failed to list files for user: {user_id}"))
171    }
172
173    pub async fn list_all(&self, limit: i64, offset: i64) -> Result<Vec<File>> {
174        sqlx::query_as!(
175            File,
176            r#"
177            SELECT id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id as "user_id: UserId", session_id as "session_id: SessionId", trace_id as "trace_id: TraceId", context_id as "context_id: ContextId", created_at, updated_at, deleted_at
178            FROM files
179            WHERE deleted_at IS NULL
180            ORDER BY created_at DESC
181            LIMIT $1 OFFSET $2
182            "#,
183            limit,
184            offset
185        )
186        .fetch_all(&*self.pool)
187        .await
188        .context("Failed to list all files")
189    }
190
191    pub async fn delete(&self, id: &FileId) -> Result<()> {
192        let id_uuid = uuid::Uuid::parse_str(id.as_str()).context("Invalid UUID for file id")?;
193
194        sqlx::query!(
195            r#"
196            DELETE FROM files
197            WHERE id = $1
198            "#,
199            id_uuid
200        )
201        .execute(&*self.write_pool)
202        .await
203        .context(format!("Failed to delete file: {id}"))?;
204
205        Ok(())
206    }
207
208    pub async fn update_metadata(&self, id: &FileId, metadata: &FileMetadata) -> Result<()> {
209        let id_uuid = uuid::Uuid::parse_str(id.as_str()).context("Invalid UUID for file id")?;
210        let metadata_json = serde_json::to_value(metadata)?;
211        let now = Utc::now();
212
213        sqlx::query!(
214            r#"
215            UPDATE files
216            SET metadata = $1, updated_at = $2
217            WHERE id = $3
218            "#,
219            metadata_json,
220            now,
221            id_uuid
222        )
223        .execute(&*self.write_pool)
224        .await
225        .context(format!("Failed to update metadata for file: {id}"))?;
226
227        Ok(())
228    }
229
230    pub async fn search_by_path(&self, query: &str, limit: i64) -> Result<Vec<File>> {
231        let pattern = format!("%{query}%");
232        sqlx::query_as!(
233            File,
234            r#"
235            SELECT id, path, public_url, mime_type, size_bytes, ai_content, metadata,
236                   user_id as "user_id: UserId", session_id as "session_id: SessionId",
237                   trace_id as "trace_id: TraceId", context_id as "context_id: ContextId",
238                   created_at, updated_at, deleted_at
239            FROM files
240            WHERE path ILIKE $1 AND deleted_at IS NULL
241            ORDER BY created_at DESC
242            LIMIT $2
243            "#,
244            pattern,
245            limit
246        )
247        .fetch_all(&*self.pool)
248        .await
249        .context(format!("Failed to search files by path: {query}"))
250    }
251}