systemprompt_files/services/upload/
service.rs1use anyhow::Result;
2use base64::engine::general_purpose::STANDARD;
3use base64::Engine;
4use sha2::{Digest, Sha256};
5use std::path::PathBuf;
6use systemprompt_database::DbPool;
7use systemprompt_identifiers::{ContextId, FileId, UserId};
8use tokio::fs;
9use tokio::io::AsyncWriteExt;
10use uuid::Uuid;
11
12use crate::config::{FilePersistenceMode, FilesConfig};
13use crate::models::{FileChecksums, FileMetadata};
14use crate::repository::InsertFileRequest;
15use crate::services::FileService;
16
17use super::error::FileUploadError;
18use super::request::{FileUploadRequest, UploadedFile};
19use super::validator::{FileCategory, FileValidator};
20
21#[derive(Debug, Clone)]
22pub struct FileUploadService {
23 files_config: FilesConfig,
24 file_service: FileService,
25 validator: FileValidator,
26}
27
28impl FileUploadService {
29 pub fn new(db_pool: &DbPool, files_config: FilesConfig) -> Result<Self, FileUploadError> {
30 let upload_config = *files_config.upload();
31 let file_service =
32 FileService::new(db_pool).map_err(|e| FileUploadError::Database(e.to_string()))?;
33 let validator = FileValidator::new(upload_config);
34
35 Ok(Self {
36 files_config,
37 file_service,
38 validator,
39 })
40 }
41
42 pub const fn validator(&self) -> &FileValidator {
43 &self.validator
44 }
45
46 pub fn is_enabled(&self) -> bool {
47 let cfg = self.files_config.upload();
48 cfg.enabled && cfg.persistence_mode != FilePersistenceMode::Disabled
49 }
50
51 pub async fn upload_file(
52 &self,
53 request: FileUploadRequest,
54 ) -> Result<UploadedFile, FileUploadError> {
55 let upload_config = self.files_config.upload();
56
57 if upload_config.persistence_mode == FilePersistenceMode::Disabled {
58 return Err(FileUploadError::PersistenceDisabled);
59 }
60
61 let max_encoded_size = (upload_config.max_file_size_bytes as f64 * 1.34) as usize + 100;
62 if request.bytes_base64.len() > max_encoded_size {
63 return Err(FileUploadError::Base64TooLarge {
64 encoded_size: request.bytes_base64.len(),
65 });
66 }
67
68 let bytes = STANDARD.decode(&request.bytes_base64)?;
69 let size_bytes = bytes.len() as u64;
70
71 let category = self.validator.validate(&request.mime_type, size_bytes)?;
72
73 let file_id = FileId::new(Uuid::new_v4().to_string());
74 let extension = FileValidator::get_extension(&request.mime_type, request.name.as_deref());
75 let filename = format!("{}.{}", file_id.as_str(), extension);
76
77 let (storage_path, relative_path) = self.determine_storage_path(
78 &category,
79 &filename,
80 &request.context_id,
81 request.user_id.as_ref(),
82 )?;
83
84 if let Some(parent) = storage_path.parent() {
85 fs::create_dir_all(parent).await?;
86 }
87
88 let mut file = fs::File::create(&storage_path).await?;
89 file.write_all(&bytes).await?;
90 file.flush().await?;
91
92 let sha256 = hex::encode(Sha256::digest(&bytes));
93
94 let public_url = self.files_config.upload_url(&relative_path);
95
96 let metadata = FileMetadata::new().with_checksums(FileChecksums::new().with_sha256(sha256));
97
98 let metadata_json = serde_json::to_value(&metadata)
99 .map_err(|e| FileUploadError::Database(format!("Failed to serialize metadata: {e}")))?;
100
101 let mut insert_request = InsertFileRequest::new(
102 file_id.clone(),
103 storage_path.to_string_lossy().to_string(),
104 public_url.clone(),
105 request.mime_type.clone(),
106 )
107 .with_size(size_bytes as i64)
108 .with_metadata(metadata_json)
109 .with_context_id(request.context_id.clone());
110
111 if let Some(user_id) = request.user_id {
112 insert_request = insert_request.with_user_id(user_id);
113 }
114
115 if let Some(session_id) = request.session_id {
116 insert_request = insert_request.with_session_id(session_id);
117 }
118
119 if let Some(trace_id) = request.trace_id {
120 insert_request = insert_request.with_trace_id(trace_id);
121 }
122
123 if let Err(e) = self.file_service.insert(insert_request).await {
124 let _ = fs::remove_file(&storage_path).await;
125 return Err(FileUploadError::Database(e.to_string()));
126 }
127
128 Ok(UploadedFile {
129 file_id,
130 path: relative_path,
131 public_url,
132 size_bytes: size_bytes as i64,
133 })
134 }
135
136 fn determine_storage_path(
137 &self,
138 category: &FileCategory,
139 filename: &str,
140 context_id: &ContextId,
141 user_id: Option<&UserId>,
142 ) -> Result<(PathBuf, String), FileUploadError> {
143 let base = self.files_config.uploads();
144 let upload_config = self.files_config.upload();
145
146 let context_str = context_id.as_str();
147 if context_str.contains("..") || context_str.contains('\0') {
148 return Err(FileUploadError::PathValidation(
149 "Invalid context_id: contains path traversal sequence".to_string(),
150 ));
151 }
152
153 if let Some(uid) = user_id {
154 let user_str = uid.as_str();
155 if user_str.contains("..") || user_str.contains('\0') {
156 return Err(FileUploadError::PathValidation(
157 "Invalid user_id: contains path traversal sequence".to_string(),
158 ));
159 }
160 }
161
162 let (full_path, relative) = match upload_config.persistence_mode {
163 FilePersistenceMode::ContextScoped => {
164 let rel = format!(
165 "contexts/{}/{}/{}",
166 context_str,
167 category.storage_subdir(),
168 filename
169 );
170 (base.join(&rel), rel)
171 },
172 FilePersistenceMode::UserLibrary => {
173 let user_dir =
174 user_id.map_or_else(|| "anonymous".to_string(), |u| u.as_str().to_string());
175 let rel = format!(
176 "users/{}/{}/{}",
177 user_dir,
178 category.storage_subdir(),
179 filename
180 );
181 (base.join(&rel), rel)
182 },
183 FilePersistenceMode::Disabled => {
184 return Err(FileUploadError::PersistenceDisabled);
185 },
186 };
187
188 if !full_path.starts_with(&base) {
189 return Err(FileUploadError::PathValidation(
190 "Resolved path escapes upload directory".to_string(),
191 ));
192 }
193
194 Ok((full_path, relative))
195 }
196}