Skip to main content

systemprompt_files/services/upload/
service.rs

1use anyhow::Result;
2use base64::Engine;
3use base64::engine::general_purpose::STANDARD;
4use sha2::{Digest, Sha256};
5use std::path::PathBuf;
6use systemprompt_database::DbPool;
7use systemprompt_identifiers::{ContextId, FileId, UserId};
8use tokio::fs;
9use tokio::io::AsyncWriteExt;
10use uuid::Uuid;
11
12use crate::config::{FilePersistenceMode, FilesConfig};
13use crate::models::{FileChecksums, FileMetadata};
14use crate::repository::InsertFileRequest;
15use crate::services::FileService;
16
17use super::error::FileUploadError;
18use super::request::{FileUploadRequest, UploadedFile};
19use super::validator::{FileCategory, FileValidator};
20
21#[derive(Debug, Clone)]
22pub struct FileUploadService {
23    files_config: FilesConfig,
24    file_service: FileService,
25    validator: FileValidator,
26}
27
28impl FileUploadService {
29    pub fn new(db_pool: &DbPool, files_config: FilesConfig) -> Result<Self, FileUploadError> {
30        let upload_config = *files_config.upload();
31        let file_service =
32            FileService::new(db_pool).map_err(|e| FileUploadError::Database(e.to_string()))?;
33        let validator = FileValidator::new(upload_config);
34
35        Ok(Self {
36            files_config,
37            file_service,
38            validator,
39        })
40    }
41
42    pub const fn validator(&self) -> &FileValidator {
43        &self.validator
44    }
45
46    pub fn is_enabled(&self) -> bool {
47        let cfg = self.files_config.upload();
48        cfg.enabled && cfg.persistence_mode != FilePersistenceMode::Disabled
49    }
50
51    pub async fn upload_file(
52        &self,
53        request: FileUploadRequest,
54    ) -> Result<UploadedFile, FileUploadError> {
55        let upload_config = self.files_config.upload();
56
57        if upload_config.persistence_mode == FilePersistenceMode::Disabled {
58            return Err(FileUploadError::PersistenceDisabled);
59        }
60
61        let max_encoded_size = (upload_config.max_file_size_bytes as f64 * 1.34) as usize + 100;
62        if request.bytes_base64.len() > max_encoded_size {
63            return Err(FileUploadError::Base64TooLarge {
64                encoded_size: request.bytes_base64.len(),
65            });
66        }
67
68        let bytes = STANDARD.decode(&request.bytes_base64)?;
69        let size_bytes = bytes.len() as u64;
70
71        let category = self.validator.validate(&request.mime_type, size_bytes)?;
72
73        let file_id = FileId::new(Uuid::new_v4().to_string());
74        let extension = FileValidator::get_extension(&request.mime_type, request.name.as_deref());
75        let filename = format!("{}.{}", file_id.as_str(), extension);
76
77        let (storage_path, relative_path) = self.determine_storage_path(
78            &category,
79            &filename,
80            &request.context_id,
81            request.user_id.as_ref(),
82        )?;
83
84        if let Some(parent) = storage_path.parent() {
85            fs::create_dir_all(parent).await?;
86        }
87
88        let mut file = fs::File::create(&storage_path).await?;
89        file.write_all(&bytes).await?;
90        file.flush().await?;
91
92        let sha256 = hex::encode(Sha256::digest(&bytes));
93
94        let public_url = self.files_config.upload_url(&relative_path);
95
96        let metadata = FileMetadata::new().with_checksums(FileChecksums::new().with_sha256(sha256));
97
98        let metadata_json = serde_json::to_value(&metadata)
99            .map_err(|e| FileUploadError::Database(format!("Failed to serialize metadata: {e}")))?;
100
101        let mut insert_request = InsertFileRequest::new(
102            file_id.clone(),
103            storage_path.to_string_lossy().to_string(),
104            public_url.clone(),
105            request.mime_type.clone(),
106        )
107        .with_size(size_bytes as i64)
108        .with_metadata(metadata_json)
109        .with_context_id(request.context_id.clone());
110
111        if let Some(user_id) = request.user_id {
112            insert_request = insert_request.with_user_id(user_id);
113        }
114
115        if let Some(session_id) = request.session_id {
116            insert_request = insert_request.with_session_id(session_id);
117        }
118
119        if let Some(trace_id) = request.trace_id {
120            insert_request = insert_request.with_trace_id(trace_id);
121        }
122
123        if let Err(e) = self.file_service.insert(insert_request).await {
124            if let Err(cleanup_err) = fs::remove_file(&storage_path).await {
125                tracing::warn!(
126                    path = %storage_path.display(),
127                    error = %cleanup_err,
128                    "Failed to clean up uploaded file after database error"
129                );
130            }
131            return Err(FileUploadError::Database(e.to_string()));
132        }
133
134        Ok(UploadedFile {
135            file_id,
136            path: relative_path,
137            public_url,
138            size_bytes: size_bytes as i64,
139        })
140    }
141
142    fn determine_storage_path(
143        &self,
144        category: &FileCategory,
145        filename: &str,
146        context_id: &ContextId,
147        user_id: Option<&UserId>,
148    ) -> Result<(PathBuf, String), FileUploadError> {
149        let base = self.files_config.uploads();
150        let upload_config = self.files_config.upload();
151
152        let context_str = context_id.as_str();
153        if context_str.contains("..") || context_str.contains('\0') {
154            return Err(FileUploadError::PathValidation(
155                "Invalid context_id: contains path traversal sequence".to_string(),
156            ));
157        }
158
159        if let Some(uid) = user_id {
160            let user_str = uid.as_str();
161            if user_str.contains("..") || user_str.contains('\0') {
162                return Err(FileUploadError::PathValidation(
163                    "Invalid user_id: contains path traversal sequence".to_string(),
164                ));
165            }
166        }
167
168        let (full_path, relative) = match upload_config.persistence_mode {
169            FilePersistenceMode::ContextScoped => {
170                let rel = format!(
171                    "contexts/{}/{}/{}",
172                    context_str,
173                    category.storage_subdir(),
174                    filename
175                );
176                (base.join(&rel), rel)
177            },
178            FilePersistenceMode::UserLibrary => {
179                let user_dir =
180                    user_id.map_or_else(|| "anonymous".to_string(), |u| u.as_str().to_string());
181                let rel = format!(
182                    "users/{}/{}/{}",
183                    user_dir,
184                    category.storage_subdir(),
185                    filename
186                );
187                (base.join(&rel), rel)
188            },
189            FilePersistenceMode::Disabled => {
190                return Err(FileUploadError::PersistenceDisabled);
191            },
192        };
193
194        if !full_path.starts_with(&base) {
195            return Err(FileUploadError::PathValidation(
196                "Resolved path escapes upload directory".to_string(),
197            ));
198        }
199
200        Ok((full_path, relative))
201    }
202}