Skip to main content

systemprompt_files/services/upload/
service.rs

1//! [`FileUploadService`]: decode, validate, store, and record uploads.
2//!
3//! Decodes base64 payloads, enforces upload policy via [`FileValidator`],
4//! writes bytes to a traversal-checked storage path derived from the
5//! persistence mode, and records the file through [`FileRepository`], cleaning
6//! up the on-disk artefact if the database write fails.
7
8use base64::Engine;
9use base64::engine::general_purpose::STANDARD;
10use sha2::{Digest, Sha256};
11use std::path::PathBuf;
12use systemprompt_database::DbPool;
13use systemprompt_identifiers::{ContextId, FileId, UserId};
14use tokio::fs;
15use tokio::io::AsyncWriteExt;
16use uuid::Uuid;
17
18use crate::config::{FilePersistenceMode, FilesConfig};
19use crate::models::{FileChecksums, FileMetadata};
20use crate::repository::{FileRepository, InsertFileRequest};
21
22use super::error::FileUploadError;
23use super::request::{FileUploadRequest, UploadedFile};
24use super::validator::{FileCategory, FileValidator};
25
26struct StoredArtifact<'a> {
27    file_id: &'a FileId,
28    storage_path: &'a std::path::Path,
29    public_url: &'a str,
30    size_bytes: u64,
31    sha256: String,
32}
33
34#[derive(Debug, Clone)]
35pub struct FileUploadService {
36    files_config: FilesConfig,
37    file_repository: FileRepository,
38    validator: FileValidator,
39}
40
41impl FileUploadService {
42    pub fn new(db_pool: &DbPool, files_config: FilesConfig) -> Result<Self, FileUploadError> {
43        let upload_config = *files_config.upload();
44        let file_repository =
45            FileRepository::new(db_pool).map_err(|e| FileUploadError::Database(e.to_string()))?;
46        let validator = FileValidator::new(upload_config);
47
48        Ok(Self {
49            files_config,
50            file_repository,
51            validator,
52        })
53    }
54
55    pub const fn validator(&self) -> &FileValidator {
56        &self.validator
57    }
58
59    pub fn is_enabled(&self) -> bool {
60        let cfg = self.files_config.upload();
61        cfg.enabled && cfg.persistence_mode != FilePersistenceMode::Disabled
62    }
63
64    pub async fn upload_file(
65        &self,
66        request: FileUploadRequest,
67    ) -> Result<UploadedFile, FileUploadError> {
68        let upload_config = self.files_config.upload();
69
70        if upload_config.persistence_mode == FilePersistenceMode::Disabled {
71            return Err(FileUploadError::PersistenceDisabled);
72        }
73
74        let max_encoded_size = (upload_config.max_file_size_bytes as f64 * 1.34) as usize + 100;
75        if request.bytes_base64.len() > max_encoded_size {
76            return Err(FileUploadError::Base64TooLarge {
77                encoded_size: request.bytes_base64.len(),
78            });
79        }
80
81        let bytes = STANDARD.decode(&request.bytes_base64)?;
82        let size_bytes = bytes.len() as u64;
83
84        let category = self.validator.validate(&request.mime_type, size_bytes)?;
85
86        let file_id = FileId::new(Uuid::new_v4().to_string());
87        let extension = FileValidator::get_extension(&request.mime_type, request.name.as_deref());
88        let filename = format!("{}.{}", file_id.as_str(), extension);
89
90        let (storage_path, relative_path) = self.determine_storage_path(
91            &category,
92            &filename,
93            &request.context_id,
94            request.user_id.as_ref(),
95        )?;
96
97        if let Some(parent) = storage_path.parent() {
98            fs::create_dir_all(parent).await?;
99        }
100
101        let mut file = fs::File::create(&storage_path).await?;
102        file.write_all(&bytes).await?;
103        file.flush().await?;
104
105        let sha256 = hex::encode(Sha256::digest(&bytes));
106
107        let public_url = self.files_config.upload_url(&relative_path);
108
109        self.record_file(
110            StoredArtifact {
111                file_id: &file_id,
112                storage_path: &storage_path,
113                public_url: &public_url,
114                size_bytes,
115                sha256,
116            },
117            &request,
118        )
119        .await?;
120
121        Ok(UploadedFile {
122            file_id,
123            path: relative_path,
124            public_url,
125            size_bytes: size_bytes as i64,
126        })
127    }
128
129    async fn record_file(
130        &self,
131        artifact: StoredArtifact<'_>,
132        request: &FileUploadRequest,
133    ) -> Result<(), FileUploadError> {
134        let StoredArtifact {
135            file_id,
136            storage_path,
137            public_url,
138            size_bytes,
139            sha256,
140        } = artifact;
141
142        let metadata = FileMetadata::new().with_checksums(FileChecksums::new().with_sha256(sha256));
143
144        let metadata_json = serde_json::to_value(&metadata)
145            .map_err(|e| FileUploadError::Database(format!("Failed to serialize metadata: {e}")))?;
146
147        let mut insert_request = InsertFileRequest::new(
148            file_id.clone(),
149            storage_path.to_string_lossy().to_string(),
150            public_url.to_owned(),
151            request.mime_type.clone(),
152        )
153        .with_size(size_bytes as i64)
154        .with_metadata(metadata_json)
155        .with_context_id(request.context_id.clone());
156
157        if let Some(user_id) = request.user_id.clone() {
158            insert_request = insert_request.with_user_id(user_id);
159        }
160
161        if let Some(session_id) = request.session_id.clone() {
162            insert_request = insert_request.with_session_id(session_id);
163        }
164
165        if let Some(trace_id) = request.trace_id.clone() {
166            insert_request = insert_request.with_trace_id(trace_id);
167        }
168
169        if let Err(e) = self.file_repository.insert(insert_request).await {
170            if let Err(cleanup_err) = fs::remove_file(storage_path).await {
171                tracing::warn!(
172                    path = %storage_path.display(),
173                    error = %cleanup_err,
174                    "Failed to clean up uploaded file after database error"
175                );
176            }
177            return Err(FileUploadError::Database(e.to_string()));
178        }
179
180        Ok(())
181    }
182
183    fn determine_storage_path(
184        &self,
185        category: &FileCategory,
186        filename: &str,
187        context_id: &ContextId,
188        user_id: Option<&UserId>,
189    ) -> Result<(PathBuf, String), FileUploadError> {
190        let base = self.files_config.uploads();
191        let upload_config = self.files_config.upload();
192
193        let context_str = context_id.as_str();
194        Self::validate_path_inputs(context_str, filename, user_id)?;
195
196        let (full_path, relative) = match upload_config.persistence_mode {
197            FilePersistenceMode::ContextScoped => {
198                let rel = format!(
199                    "contexts/{}/{}/{}",
200                    context_str,
201                    category.storage_subdir(),
202                    filename
203                );
204                (base.join(&rel), rel)
205            },
206            FilePersistenceMode::UserLibrary => {
207                let user_dir =
208                    user_id.map_or_else(|| "anonymous".to_owned(), |u| u.as_str().to_owned());
209                let rel = format!(
210                    "users/{}/{}/{}",
211                    user_dir,
212                    category.storage_subdir(),
213                    filename
214                );
215                (base.join(&rel), rel)
216            },
217            FilePersistenceMode::Disabled => {
218                return Err(FileUploadError::PersistenceDisabled);
219            },
220        };
221
222        for component in std::path::Path::new(&relative).components() {
223            use std::path::Component;
224            match component {
225                Component::Normal(_) | Component::CurDir => {},
226                Component::ParentDir | Component::RootDir | Component::Prefix(_) => {
227                    return Err(FileUploadError::PathValidation(
228                        "Resolved path contains traversal or absolute component".to_owned(),
229                    ));
230                },
231            }
232        }
233
234        if !full_path.starts_with(&base) {
235            return Err(FileUploadError::PathValidation(
236                "Resolved path escapes upload directory".to_owned(),
237            ));
238        }
239
240        Ok((full_path, relative))
241    }
242
243    fn validate_path_inputs(
244        context_str: &str,
245        filename: &str,
246        user_id: Option<&UserId>,
247    ) -> Result<(), FileUploadError> {
248        if context_str.contains("..") || context_str.contains('\0') {
249            return Err(FileUploadError::PathValidation(
250                "Invalid context_id: contains path traversal sequence".to_owned(),
251            ));
252        }
253
254        if let Some(uid) = user_id {
255            let user_str = uid.as_str();
256            if user_str.contains("..") || user_str.contains('\0') {
257                return Err(FileUploadError::PathValidation(
258                    "Invalid user_id: contains path traversal sequence".to_owned(),
259                ));
260            }
261        }
262
263        if filename.contains('\0')
264            || filename.contains('/')
265            || filename.contains('\\')
266            || filename == ".."
267            || filename == "."
268            || filename.is_empty()
269        {
270            return Err(FileUploadError::PathValidation(
271                "Invalid filename: must be a single path component".to_owned(),
272            ));
273        }
274
275        Ok(())
276    }
277}