Skip to main content

systemprompt_cli/commands/core/files/
upload.rs

1use std::path::{Path, PathBuf};
2
3use anyhow::{Result, anyhow};
4use base64::Engine;
5use base64::engine::general_purpose::STANDARD;
6use clap::Args;
7use sha2::{Digest, Sha256};
8use systemprompt_files::{FileUploadRequest, FileUploadService, FilesConfig};
9use systemprompt_identifiers::{ContextId, SessionId, UserId};
10use systemprompt_runtime::AppContext;
11use tokio::fs;
12
13use super::types::FileUploadOutput;
14use crate::CliConfig;
15use crate::shared::CommandOutput;
16
17#[derive(Debug, Clone, Args)]
18pub struct UploadArgs {
19    #[arg(help = "Path to file to upload")]
20    pub file_path: PathBuf,
21
22    #[arg(long, help = "Context ID (required)")]
23    pub context: String,
24
25    #[arg(long, help = "User ID")]
26    pub user: Option<String>,
27
28    #[arg(long, help = "Session ID")]
29    pub session: Option<String>,
30
31    #[arg(long, help = "Mark as AI-generated content")]
32    pub ai: bool,
33}
34
35pub async fn execute(args: UploadArgs, _config: &CliConfig) -> Result<CommandOutput> {
36    let ctx = AppContext::new().await?;
37    let files_config = FilesConfig::get()?;
38    let service = FileUploadService::new(ctx.db_pool(), files_config.clone())?;
39
40    if !service.is_enabled() {
41        return Err(anyhow!("File uploads are disabled in configuration"));
42    }
43
44    let file_path = args
45        .file_path
46        .canonicalize()
47        .map_err(|e| anyhow!("File not found: {} - {}", args.file_path.display(), e))?;
48
49    let bytes = fs::read(&file_path).await?;
50    let bytes_base64 = STANDARD.encode(&bytes);
51    let digest = Sha256::digest(&bytes);
52    let checksum_sha256 = digest.iter().fold(String::with_capacity(64), |mut acc, b| {
53        acc.push_str(&format!("{b:02x}"));
54        acc
55    });
56    let size_bytes = bytes.len() as i64;
57
58    let mime_type = detect_mime_type(&file_path);
59    let filename = file_path
60        .file_name()
61        .and_then(|n| n.to_str())
62        .map(String::from);
63
64    let context_id = ContextId::new(args.context);
65
66    let request = FileUploadRequest {
67        name: filename,
68        mime_type: mime_type.clone(),
69        bytes_base64,
70        context_id,
71        user_id: args.user.map(UserId::new),
72        session_id: args.session.map(SessionId::new),
73        trace_id: None,
74    };
75
76    let result = service.upload_file(request).await?;
77
78    let output = FileUploadOutput {
79        file_id: result.file_id,
80        path: result.path,
81        public_url: result.public_url,
82        size_bytes,
83        mime_type,
84        checksum_sha256,
85    };
86
87    Ok(CommandOutput::card_value("File Uploaded", &output))
88}
89
/// MIME type returned when a path has no usable extension or the extension is
/// not listed in [`EXTENSION_MIME_TABLE`].
const FALLBACK_MIME_TYPE: &str = "application/octet-stream";

/// Lookup table pairing groups of lowercase file extensions (without the dot)
/// with the MIME type reported for them.
const EXTENSION_MIME_TABLE: &[(&[&str], &str)] = &[
    (&["jpg", "jpeg"], "image/jpeg"),
    (&["png"], "image/png"),
    (&["gif"], "image/gif"),
    (&["webp"], "image/webp"),
    (&["svg"], "image/svg+xml"),
    (&["bmp"], "image/bmp"),
    (&["tiff", "tif"], "image/tiff"),
    (&["ico"], "image/x-icon"),
    (&["pdf"], "application/pdf"),
    (&["doc"], "application/msword"),
    (
        &["docx"],
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    ),
    (&["xls"], "application/vnd.ms-excel"),
    (
        &["xlsx"],
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ),
    (&["ppt"], "application/vnd.ms-powerpoint"),
    (
        &["pptx"],
        "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    ),
    (&["txt"], "text/plain"),
    (&["csv"], "text/csv"),
    (&["md"], "text/markdown"),
    (&["html", "htm"], "text/html"),
    (&["json"], "application/json"),
    (&["xml"], "application/xml"),
    (&["rtf"], "application/rtf"),
    (&["mp3"], "audio/mpeg"),
    (&["wav"], "audio/wav"),
    (&["ogg"], "audio/ogg"),
    (&["aac"], "audio/aac"),
    (&["flac"], "audio/flac"),
    (&["m4a"], "audio/mp4"),
    (&["mp4"], "video/mp4"),
    (&["webm"], "video/webm"),
    (&["mov"], "video/quicktime"),
    (&["avi"], "video/x-msvideo"),
    (&["mkv"], "video/x-matroska"),
];

/// Guesses a MIME type from the extension of `path`.
///
/// Only the path's extension is inspected; the file itself is never opened.
/// The extension is lowercased before the lookup, so `photo.JPG` and
/// `photo.jpg` give the same result. Paths without an extension, with a
/// non-UTF-8 extension, or with an extension missing from
/// [`EXTENSION_MIME_TABLE`] yield `"application/octet-stream"`.
pub fn detect_mime_type(path: &Path) -> String {
    let mime = match path.extension().and_then(|raw| raw.to_str()) {
        Some(raw) => {
            let wanted = raw.to_lowercase();
            EXTENSION_MIME_TABLE
                .iter()
                .find_map(|&(aliases, mime)| aliases.contains(&wanted.as_str()).then_some(mime))
                .unwrap_or(FALLBACK_MIME_TYPE)
        }
        None => FALLBACK_MIME_TYPE,
    };
    mime.to_owned()
}