systemprompt_files/services/upload/
service.rs1use base64::Engine;
9use base64::engine::general_purpose::STANDARD;
10use sha2::{Digest, Sha256};
11use std::path::PathBuf;
12use systemprompt_database::DbPool;
13use systemprompt_identifiers::{ContextId, FileId, UserId};
14use tokio::fs;
15use tokio::io::AsyncWriteExt;
16use uuid::Uuid;
17
18use crate::config::{FilePersistenceMode, FilesConfig};
19use crate::models::{FileChecksums, FileMetadata};
20use crate::repository::{FileRepository, InsertFileRequest};
21
22use super::error::FileUploadError;
23use super::request::{FileUploadRequest, UploadedFile};
24use super::validator::{FileCategory, FileValidator};
25
26struct StoredArtifact<'a> {
27 file_id: &'a FileId,
28 storage_path: &'a std::path::Path,
29 public_url: &'a str,
30 size_bytes: u64,
31 sha256: String,
32}
33
34#[derive(Debug, Clone)]
35pub struct FileUploadService {
36 files_config: FilesConfig,
37 file_repository: FileRepository,
38 validator: FileValidator,
39}
40
41impl FileUploadService {
42 pub fn new(db_pool: &DbPool, files_config: FilesConfig) -> Result<Self, FileUploadError> {
43 let upload_config = *files_config.upload();
44 let file_repository =
45 FileRepository::new(db_pool).map_err(|e| FileUploadError::Database(e.to_string()))?;
46 let validator = FileValidator::new(upload_config);
47
48 Ok(Self {
49 files_config,
50 file_repository,
51 validator,
52 })
53 }
54
55 pub const fn validator(&self) -> &FileValidator {
56 &self.validator
57 }
58
59 pub fn is_enabled(&self) -> bool {
60 let cfg = self.files_config.upload();
61 cfg.enabled && cfg.persistence_mode != FilePersistenceMode::Disabled
62 }
63
64 pub async fn upload_file(
65 &self,
66 request: FileUploadRequest,
67 ) -> Result<UploadedFile, FileUploadError> {
68 let upload_config = self.files_config.upload();
69
70 if upload_config.persistence_mode == FilePersistenceMode::Disabled {
71 return Err(FileUploadError::PersistenceDisabled);
72 }
73
74 let max_encoded_size = (upload_config.max_file_size_bytes as f64 * 1.34) as usize + 100;
75 if request.bytes_base64.len() > max_encoded_size {
76 return Err(FileUploadError::Base64TooLarge {
77 encoded_size: request.bytes_base64.len(),
78 });
79 }
80
81 let bytes = STANDARD.decode(&request.bytes_base64)?;
82 let size_bytes = bytes.len() as u64;
83
84 let category = self.validator.validate(&request.mime_type, size_bytes)?;
85
86 let file_id = FileId::new(Uuid::new_v4().to_string());
87 let extension = FileValidator::get_extension(&request.mime_type, request.name.as_deref());
88 let filename = format!("{}.{}", file_id.as_str(), extension);
89
90 let (storage_path, relative_path) = self.determine_storage_path(
91 &category,
92 &filename,
93 &request.context_id,
94 request.user_id.as_ref(),
95 )?;
96
97 if let Some(parent) = storage_path.parent() {
98 fs::create_dir_all(parent).await?;
99 }
100
101 let mut file = fs::File::create(&storage_path).await?;
102 file.write_all(&bytes).await?;
103 file.flush().await?;
104
105 let sha256 = hex::encode(Sha256::digest(&bytes));
106
107 let public_url = self.files_config.upload_url(&relative_path);
108
109 self.record_file(
110 StoredArtifact {
111 file_id: &file_id,
112 storage_path: &storage_path,
113 public_url: &public_url,
114 size_bytes,
115 sha256,
116 },
117 &request,
118 )
119 .await?;
120
121 Ok(UploadedFile {
122 file_id,
123 path: relative_path,
124 public_url,
125 size_bytes: size_bytes as i64,
126 })
127 }
128
129 async fn record_file(
130 &self,
131 artifact: StoredArtifact<'_>,
132 request: &FileUploadRequest,
133 ) -> Result<(), FileUploadError> {
134 let StoredArtifact {
135 file_id,
136 storage_path,
137 public_url,
138 size_bytes,
139 sha256,
140 } = artifact;
141
142 let metadata = FileMetadata::new().with_checksums(FileChecksums::new().with_sha256(sha256));
143
144 let metadata_json = serde_json::to_value(&metadata)
145 .map_err(|e| FileUploadError::Database(format!("Failed to serialize metadata: {e}")))?;
146
147 let mut insert_request = InsertFileRequest::new(
148 file_id.clone(),
149 storage_path.to_string_lossy().to_string(),
150 public_url.to_owned(),
151 request.mime_type.clone(),
152 )
153 .with_size(size_bytes as i64)
154 .with_metadata(metadata_json)
155 .with_context_id(request.context_id.clone());
156
157 if let Some(user_id) = request.user_id.clone() {
158 insert_request = insert_request.with_user_id(user_id);
159 }
160
161 if let Some(session_id) = request.session_id.clone() {
162 insert_request = insert_request.with_session_id(session_id);
163 }
164
165 if let Some(trace_id) = request.trace_id.clone() {
166 insert_request = insert_request.with_trace_id(trace_id);
167 }
168
169 if let Err(e) = self.file_repository.insert(insert_request).await {
170 if let Err(cleanup_err) = fs::remove_file(storage_path).await {
171 tracing::warn!(
172 path = %storage_path.display(),
173 error = %cleanup_err,
174 "Failed to clean up uploaded file after database error"
175 );
176 }
177 return Err(FileUploadError::Database(e.to_string()));
178 }
179
180 Ok(())
181 }
182
183 fn determine_storage_path(
184 &self,
185 category: &FileCategory,
186 filename: &str,
187 context_id: &ContextId,
188 user_id: Option<&UserId>,
189 ) -> Result<(PathBuf, String), FileUploadError> {
190 let base = self.files_config.uploads();
191 let upload_config = self.files_config.upload();
192
193 let context_str = context_id.as_str();
194 Self::validate_path_inputs(context_str, filename, user_id)?;
195
196 let (full_path, relative) = match upload_config.persistence_mode {
197 FilePersistenceMode::ContextScoped => {
198 let rel = format!(
199 "contexts/{}/{}/{}",
200 context_str,
201 category.storage_subdir(),
202 filename
203 );
204 (base.join(&rel), rel)
205 },
206 FilePersistenceMode::UserLibrary => {
207 let user_dir =
208 user_id.map_or_else(|| "anonymous".to_owned(), |u| u.as_str().to_owned());
209 let rel = format!(
210 "users/{}/{}/{}",
211 user_dir,
212 category.storage_subdir(),
213 filename
214 );
215 (base.join(&rel), rel)
216 },
217 FilePersistenceMode::Disabled => {
218 return Err(FileUploadError::PersistenceDisabled);
219 },
220 };
221
222 for component in std::path::Path::new(&relative).components() {
223 use std::path::Component;
224 match component {
225 Component::Normal(_) | Component::CurDir => {},
226 Component::ParentDir | Component::RootDir | Component::Prefix(_) => {
227 return Err(FileUploadError::PathValidation(
228 "Resolved path contains traversal or absolute component".to_owned(),
229 ));
230 },
231 }
232 }
233
234 if !full_path.starts_with(&base) {
235 return Err(FileUploadError::PathValidation(
236 "Resolved path escapes upload directory".to_owned(),
237 ));
238 }
239
240 Ok((full_path, relative))
241 }
242
243 fn validate_path_inputs(
244 context_str: &str,
245 filename: &str,
246 user_id: Option<&UserId>,
247 ) -> Result<(), FileUploadError> {
248 if context_str.contains("..") || context_str.contains('\0') {
249 return Err(FileUploadError::PathValidation(
250 "Invalid context_id: contains path traversal sequence".to_owned(),
251 ));
252 }
253
254 if let Some(uid) = user_id {
255 let user_str = uid.as_str();
256 if user_str.contains("..") || user_str.contains('\0') {
257 return Err(FileUploadError::PathValidation(
258 "Invalid user_id: contains path traversal sequence".to_owned(),
259 ));
260 }
261 }
262
263 if filename.contains('\0')
264 || filename.contains('/')
265 || filename.contains('\\')
266 || filename == ".."
267 || filename == "."
268 || filename.is_empty()
269 {
270 return Err(FileUploadError::PathValidation(
271 "Invalid filename: must be a single path component".to_owned(),
272 ));
273 }
274
275 Ok(())
276 }
277}