Skip to main content

soli_proxy/app/
deployment.rs

1use anyhow::Result;
2use std::collections::HashSet;
3use std::path::PathBuf;
4use std::sync::{Arc, Mutex};
5use std::time::Duration;
6use tokio::time::sleep;
7
8use super::AppInfo;
9
10/// Validate that a name is safe for use in filesystem paths.
11/// Rejects empty strings, path separators, "..", and control characters.
12fn validate_path_component(name: &str, label: &str) -> Result<()> {
13    if name.is_empty() {
14        anyhow::bail!("{} cannot be empty", label);
15    }
16    if name.contains('/') || name.contains('\\') || name.contains('\0') {
17        anyhow::bail!("{} contains invalid path characters: {:?}", label, name);
18    }
19    if name == "." || name == ".." || name.contains("..") {
20        anyhow::bail!("{} contains path traversal: {:?}", label, name);
21    }
22    if name.chars().any(|c| c.is_control()) {
23        anyhow::bail!("{} contains control characters: {:?}", label, name);
24    }
25    Ok(())
26}
27
28/// Parse a start script into a program and arguments without using a shell.
29/// Performs variable substitution for $PORT and $WORKERS.
30/// This avoids shell injection by never passing the script through `sh -c`.
31fn parse_start_command(script: &str, port: u16, workers: u16) -> Result<(String, Vec<String>)> {
32    let tokens: Vec<&str> = script.split_whitespace().collect();
33    if tokens.is_empty() {
34        anyhow::bail!("Start script is empty");
35    }
36
37    let port_str = port.to_string();
38    let workers_str = workers.to_string();
39
40    let program = tokens[0]
41        .replace("$PORT", &port_str)
42        .replace("$WORKERS", &workers_str);
43
44    let args: Vec<String> = tokens[1..]
45        .iter()
46        .map(|t| {
47            t.replace("$PORT", &port_str)
48                .replace("$WORKERS", &workers_str)
49        })
50        .collect();
51
52    Ok((program, args))
53}
54
/// State of a deployment.
///
/// NOTE(review): this enum is not referenced elsewhere in this module; the variant notes
/// below follow the variant names and should be confirmed against the call sites.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeploymentStatus {
    /// No deployment activity (presumably).
    Idle,
    /// A deployment is in progress (presumably).
    Deploying,
    /// A rollback is in progress (presumably).
    RollingBack,
    /// A failure, carrying a message (presumably the error text).
    Failed(String),
}
62
63pub struct DeploymentManager {
64    /// Per-app deployment locks: contains app names currently being deployed
65    deploying_apps: Arc<Mutex<HashSet<String>>>,
66    dev_mode: bool,
67    http_client: reqwest::Client,
68}
69
70impl Default for DeploymentManager {
71    fn default() -> Self {
72        Self::new(false)
73    }
74}
75
76impl DeploymentManager {
77    pub fn new(dev_mode: bool) -> Self {
78        let http_client = reqwest::Client::builder()
79            .timeout(Duration::from_secs(5))
80            .build()
81            .unwrap_or_else(|_| reqwest::Client::new());
82
83        Self {
84            deploying_apps: Arc::new(Mutex::new(HashSet::new())),
85            dev_mode,
86            http_client,
87        }
88    }
89
90    pub fn is_deploying(&self, app_name: &str) -> bool {
91        self.deploying_apps.lock().unwrap().contains(app_name)
92    }
93
94    /// Deploy an app to a slot. Returns the PID of the started process.
95    pub async fn deploy(&self, app: &AppInfo, slot: &str) -> Result<u32> {
96        {
97            let mut deploying = self.deploying_apps.lock().unwrap();
98            if deploying.contains(&app.config.name) {
99                anyhow::bail!("Deployment already in progress for {}", app.config.name);
100            }
101            deploying.insert(app.config.name.clone());
102        }
103
104        let deploying_apps = self.deploying_apps.clone();
105        let app_name = app.config.name.clone();
106        let _guard = scopeguard::guard((), move |_| {
107            deploying_apps.lock().unwrap().remove(&app_name);
108        });
109
110        tracing::info!(
111            "Starting deployment of {} to slot {}",
112            app.config.name,
113            slot
114        );
115
116        let pid = self.start_instance(app, slot).await?;
117
118        let healthy = self.wait_for_health(app, slot).await?;
119
120        if !healthy {
121            self.stop_instance(app, slot).await?;
122            anyhow::bail!("Health check failed for {} slot", slot);
123        }
124
125        tracing::info!("Health check passed for {} slot {}", app.config.name, slot);
126        Ok(pid)
127    }
128
129    async fn start_instance(&self, app: &AppInfo, slot: &str) -> Result<u32> {
130        // Validate slot and app name to prevent path traversal in log paths
131        if slot != "blue" && slot != "green" {
132            anyhow::bail!("Invalid slot name: {:?}", slot);
133        }
134        validate_path_component(&app.config.name, "App name")?;
135
136        let port = if slot == "blue" {
137            app.blue.port
138        } else {
139            app.green.port
140        };
141
142        let base_script = if let Some(ref script) = app.config.start_script {
143            script.clone()
144        } else if app.path.join("app").exists() && app.path.join("app/models").exists() {
145            "soli serve .".to_string()
146        } else {
147            anyhow::bail!("No start script configured for {}", app.config.name)
148        };
149
150        let script = if self.dev_mode && base_script.starts_with("soli ") {
151            format!("{} --dev", base_script)
152        } else {
153            base_script.clone()
154        };
155
156        let output_file = PathBuf::from(format!("run/logs/{}/{}.log", app.config.name, slot));
157        std::fs::create_dir_all(output_file.parent().unwrap())?;
158
159        let output = std::fs::File::create(&output_file)?;
160
161        // Parse the script into program + args instead of using `sh -c`
162        // to prevent shell injection from malicious app.infos files.
163        let (program, args) = parse_start_command(&script, port, app.config.workers)?;
164
165        let mut cmd = tokio::process::Command::new(&program);
166        cmd.args(&args)
167            .current_dir(&app.path)
168            .env("PATH", std::env::var("PATH").unwrap_or_default())
169            .env("PORT", port.to_string())
170            .env("WORKERS", app.config.workers.to_string())
171            .stdout(std::process::Stdio::from(output.try_clone()?))
172            .stderr(std::process::Stdio::from(output));
173
174        if let (Some(ref user), Some(ref group)) = (&app.config.user, &app.config.group) {
175            let uid = resolve_user(user)?;
176            let gid = resolve_group(group)?;
177            cmd.uid(uid).gid(gid);
178            tracing::info!(
179                "Running {} as user {} (uid: {}, gid: {})",
180                app.config.name,
181                user,
182                uid,
183                gid
184            );
185        } else if let Some(ref user) = &app.config.user {
186            let uid = resolve_user(user)?;
187            let gid = resolve_group(user)?;
188            cmd.uid(uid).gid(gid);
189            tracing::info!(
190                "Running {} as user {} (uid: {}, gid: {})",
191                app.config.name,
192                user,
193                uid,
194                gid
195            );
196        }
197
198        let cmd = unsafe {
199            cmd.pre_exec(|| {
200                libc::setsid();
201                Ok(())
202            })
203            .spawn()?
204        };
205
206        let pid = cmd.id().unwrap_or(0);
207        tracing::info!("Started {} slot {} with PID {}", app.config.name, slot, pid);
208
209        Ok(pid)
210    }
211
212    pub async fn stop_instance(&self, app: &AppInfo, slot: &str) -> Result<()> {
213        let pid = if slot == "blue" {
214            app.blue.pid
215        } else {
216            app.green.pid
217        };
218
219        if let Some(pid) = pid {
220            tracing::info!("Stopping {} slot {} (PID: {})", app.config.name, slot, pid);
221
222            #[cfg(unix)]
223            {
224                // Kill the entire process group (negative PID) so child processes are included
225                let pgid = format!("-{}", pid);
226
227                tokio::process::Command::new("kill")
228                    .arg("-TERM")
229                    .arg("--")
230                    .arg(&pgid)
231                    .output()
232                    .await?;
233
234                let timeout = app.config.graceful_timeout as u64;
235                let mut waited_ms = 0u64;
236                while waited_ms < timeout * 1000 {
237                    let output = tokio::process::Command::new("kill")
238                        .arg("-0")
239                        .arg(pid.to_string())
240                        .output()
241                        .await?;
242
243                    if !output.status.success() {
244                        tracing::info!("Process {} terminated gracefully", pid);
245                        return Ok(());
246                    }
247                    let delay = if waited_ms < 500 { 50 } else { 200 };
248                    sleep(Duration::from_millis(delay)).await;
249                    waited_ms += delay;
250                }
251
252                tracing::warn!("Force killing process group {}", pid);
253                tokio::process::Command::new("kill")
254                    .arg("-9")
255                    .arg("--")
256                    .arg(&pgid)
257                    .output()
258                    .await?;
259            }
260        }
261
262        Ok(())
263    }
264
265    async fn wait_for_health(&self, app: &AppInfo, slot: &str) -> Result<bool> {
266        let port = if slot == "blue" {
267            app.blue.port
268        } else {
269            app.green.port
270        };
271        let health_path = app.config.health_check.as_deref().unwrap_or("/health");
272
273        let url = format!("http://localhost:{}{}", port, health_path);
274        let timeout_secs = 30;
275
276        for i in 0..timeout_secs {
277            sleep(Duration::from_secs(1)).await;
278
279            match self.http_client.get(&url).send().await {
280                Ok(resp) if resp.status().is_success() => {
281                    tracing::info!(
282                        "Health check passed for {} slot {} after {}s",
283                        app.config.name,
284                        slot,
285                        i + 1
286                    );
287                    return Ok(true);
288                }
289                Ok(_) => {
290                    tracing::debug!(
291                        "Health check response for {} slot {}: {}",
292                        app.config.name,
293                        slot,
294                        i + 1
295                    );
296                }
297                Err(e) => {
298                    tracing::debug!(
299                        "Health check failed for {} slot {}: {} ({})",
300                        app.config.name,
301                        slot,
302                        e,
303                        i + 1
304                    );
305                }
306            }
307        }
308
309        Ok(false)
310    }
311
312    pub async fn switch_traffic(&self, app: &AppInfo, new_slot: &str) -> Result<()> {
313        tracing::info!(
314            "Switching traffic for {} to slot {}",
315            app.config.name,
316            new_slot
317        );
318
319        let old_slot = if new_slot == "blue" { "green" } else { "blue" };
320        self.stop_instance(app, old_slot).await?;
321
322        Ok(())
323    }
324
325    pub async fn rollback(&self, app: &AppInfo) -> Result<()> {
326        let target_slot = if app.current_slot == "blue" {
327            "green"
328        } else {
329            "blue"
330        };
331        self.deploy(app, target_slot).await?;
332        Ok(())
333    }
334
335    pub async fn get_deployment_log(&self, app_name: &str, slot: &str) -> Result<String> {
336        validate_path_component(app_name, "App name")?;
337        if slot != "blue" && slot != "green" {
338            anyhow::bail!("Invalid slot name: {:?}", slot);
339        }
340        let log_path = PathBuf::from(format!("run/logs/{}/{}.log", app_name, slot));
341        if log_path.exists() {
342            Ok(std::fs::read_to_string(&log_path)?)
343        } else {
344            Ok(String::new())
345        }
346    }
347}
348
349fn resolve_user(user: &str) -> Result<u32> {
350    use std::ffi::CString;
351    let c_user = CString::new(user)?;
352    let passwd = unsafe { libc::getpwnam(c_user.as_ptr()) };
353    if passwd.is_null() {
354        anyhow::bail!("User '{}' not found", user);
355    }
356    Ok(unsafe { (*passwd).pw_uid })
357}
358
359fn resolve_group(group: &str) -> Result<u32> {
360    use std::ffi::CString;
361    let c_group = CString::new(group)?;
362    let grp = unsafe { libc::getgrnam(c_group.as_ptr()) };
363    if grp.is_null() {
364        anyhow::bail!("Group '{}' not found", group);
365    }
366    Ok(unsafe { (*grp).gr_gid })
367}