// systemprompt_files/services/upload/service.rs

use std::path::{Component, Path, PathBuf};

use base64::Engine;
use base64::engine::general_purpose::STANDARD;
use sha2::{Digest, Sha256};
use systemprompt_database::DbPool;
use systemprompt_identifiers::{ContextId, FileId, UserId};
use tokio::fs;
use tokio::io::AsyncWriteExt;
use uuid::Uuid;

use crate::config::{FilePersistenceMode, FilesConfig};
use crate::models::{FileChecksums, FileMetadata};
use crate::repository::{FileRepository, InsertFileRequest};

use super::error::FileUploadError;
use super::request::{FileUploadRequest, UploadedFile};
use super::validator::{FileCategory, FileValidator};

19#[derive(Debug, Clone)]
20pub struct FileUploadService {
21    files_config: FilesConfig,
22    file_repository: FileRepository,
23    validator: FileValidator,
24}
25
26impl FileUploadService {
27    pub fn new(db_pool: &DbPool, files_config: FilesConfig) -> Result<Self, FileUploadError> {
28        let upload_config = *files_config.upload();
29        let file_repository =
30            FileRepository::new(db_pool).map_err(|e| FileUploadError::Database(e.to_string()))?;
31        let validator = FileValidator::new(upload_config);
32
33        Ok(Self {
34            files_config,
35            file_repository,
36            validator,
37        })
38    }
39
40    pub const fn validator(&self) -> &FileValidator {
41        &self.validator
42    }
43
44    pub fn is_enabled(&self) -> bool {
45        let cfg = self.files_config.upload();
46        cfg.enabled && cfg.persistence_mode != FilePersistenceMode::Disabled
47    }
48
49    pub async fn upload_file(
50        &self,
51        request: FileUploadRequest,
52    ) -> Result<UploadedFile, FileUploadError> {
53        let upload_config = self.files_config.upload();
54
55        if upload_config.persistence_mode == FilePersistenceMode::Disabled {
56            return Err(FileUploadError::PersistenceDisabled);
57        }
58
59        let max_encoded_size = (upload_config.max_file_size_bytes as f64 * 1.34) as usize + 100;
60        if request.bytes_base64.len() > max_encoded_size {
61            return Err(FileUploadError::Base64TooLarge {
62                encoded_size: request.bytes_base64.len(),
63            });
64        }
65
66        let bytes = STANDARD.decode(&request.bytes_base64)?;
67        let size_bytes = bytes.len() as u64;
68
69        let category = self.validator.validate(&request.mime_type, size_bytes)?;
70
71        let file_id = FileId::new(Uuid::new_v4().to_string());
72        let extension = FileValidator::get_extension(&request.mime_type, request.name.as_deref());
73        let filename = format!("{}.{}", file_id.as_str(), extension);
74
75        let (storage_path, relative_path) = self.determine_storage_path(
76            &category,
77            &filename,
78            &request.context_id,
79            request.user_id.as_ref(),
80        )?;
81
82        if let Some(parent) = storage_path.parent() {
83            fs::create_dir_all(parent).await?;
84        }
85
86        let mut file = fs::File::create(&storage_path).await?;
87        file.write_all(&bytes).await?;
88        file.flush().await?;
89
90        let sha256 = hex::encode(Sha256::digest(&bytes));
91
92        let public_url = self.files_config.upload_url(&relative_path);
93
94        let metadata = FileMetadata::new().with_checksums(FileChecksums::new().with_sha256(sha256));
95
96        let metadata_json = serde_json::to_value(&metadata)
97            .map_err(|e| FileUploadError::Database(format!("Failed to serialize metadata: {e}")))?;
98
99        let mut insert_request = InsertFileRequest::new(
100            file_id.clone(),
101            storage_path.to_string_lossy().to_string(),
102            public_url.clone(),
103            request.mime_type.clone(),
104        )
105        .with_size(size_bytes as i64)
106        .with_metadata(metadata_json)
107        .with_context_id(request.context_id.clone());
108
109        if let Some(user_id) = request.user_id {
110            insert_request = insert_request.with_user_id(user_id);
111        }
112
113        if let Some(session_id) = request.session_id {
114            insert_request = insert_request.with_session_id(session_id);
115        }
116
117        if let Some(trace_id) = request.trace_id {
118            insert_request = insert_request.with_trace_id(trace_id);
119        }
120
121        if let Err(e) = self.file_repository.insert(insert_request).await {
122            if let Err(cleanup_err) = fs::remove_file(&storage_path).await {
123                tracing::warn!(
124                    path = %storage_path.display(),
125                    error = %cleanup_err,
126                    "Failed to clean up uploaded file after database error"
127                );
128            }
129            return Err(FileUploadError::Database(e.to_string()));
130        }
131
132        Ok(UploadedFile {
133            file_id,
134            path: relative_path,
135            public_url,
136            size_bytes: size_bytes as i64,
137        })
138    }
139
140    fn determine_storage_path(
141        &self,
142        category: &FileCategory,
143        filename: &str,
144        context_id: &ContextId,
145        user_id: Option<&UserId>,
146    ) -> Result<(PathBuf, String), FileUploadError> {
147        let base = self.files_config.uploads();
148        let upload_config = self.files_config.upload();
149
150        let context_str = context_id.as_str();
151        if context_str.contains("..") || context_str.contains('\0') {
152            return Err(FileUploadError::PathValidation(
153                "Invalid context_id: contains path traversal sequence".to_string(),
154            ));
155        }
156
157        if let Some(uid) = user_id {
158            let user_str = uid.as_str();
159            if user_str.contains("..") || user_str.contains('\0') {
160                return Err(FileUploadError::PathValidation(
161                    "Invalid user_id: contains path traversal sequence".to_string(),
162                ));
163            }
164        }
165
166        let (full_path, relative) = match upload_config.persistence_mode {
167            FilePersistenceMode::ContextScoped => {
168                let rel = format!(
169                    "contexts/{}/{}/{}",
170                    context_str,
171                    category.storage_subdir(),
172                    filename
173                );
174                (base.join(&rel), rel)
175            },
176            FilePersistenceMode::UserLibrary => {
177                let user_dir =
178                    user_id.map_or_else(|| "anonymous".to_string(), |u| u.as_str().to_string());
179                let rel = format!(
180                    "users/{}/{}/{}",
181                    user_dir,
182                    category.storage_subdir(),
183                    filename
184                );
185                (base.join(&rel), rel)
186            },
187            FilePersistenceMode::Disabled => {
188                return Err(FileUploadError::PersistenceDisabled);
189            },
190        };
191
192        if !full_path.starts_with(&base) {
193            return Err(FileUploadError::PathValidation(
194                "Resolved path escapes upload directory".to_string(),
195            ));
196        }
197
198        Ok((full_path, relative))
199    }
200}