Skip to main content

somatize_worker/
env_manager.rs

1//! Python environment manager: creates and maintains isolated venvs/conda envs
2//! per pipeline, with incremental dependency updates.
3
4use serde::{Deserialize, Serialize};
5use sha2::{Digest, Sha256};
6use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8use std::process::Command;
9
10/// Environment type preference.
11#[derive(Debug, Clone, Serialize, Deserialize, Default)]
12#[serde(rename_all = "snake_case")]
13pub enum EnvType {
14    #[default]
15    Venv,
16    Conda,
17}
18
19/// Lockfile: tracks what's installed in an environment.
20#[derive(Debug, Clone, Serialize, Deserialize, Default)]
21pub struct EnvLockfile {
22    pub packages: HashMap<String, String>, // name → version
23    pub requirements_hash: String,
24    pub env_type: EnvType,
25    pub python_version: String,
26}
27
28/// Manages isolated Python environments for pipeline execution.
29pub struct EnvManager {
30    base_dir: PathBuf,
31    env_type: EnvType,
32}
33
34impl EnvManager {
35    pub fn new(base_dir: impl Into<PathBuf>, env_type: EnvType) -> Self {
36        let base = base_dir.into();
37        std::fs::create_dir_all(&base).ok();
38        Self {
39            base_dir: base,
40            env_type,
41        }
42    }
43
44    /// Get or create an environment for a pipeline.
45    /// Returns the path to the Python binary.
46    pub fn ensure_env(&self, pipeline_id: &str, requirements: &str) -> Result<PathBuf, String> {
47        let req_hash = Self::hash_requirements(requirements);
48        let env_dir = self.base_dir.join(format!("env-{pipeline_id}"));
49        let lockfile_path = env_dir.join("lockfile.json");
50
51        // Check if env exists and is up to date
52        if env_dir.exists()
53            && let Ok(lockfile) = self.read_lockfile(&lockfile_path)
54        {
55            if lockfile.requirements_hash == req_hash {
56                // Env is up to date, just return python path
57                tracing::info!("Reusing env for pipeline {pipeline_id} (hash match)");
58                return self.python_path(&env_dir);
59            }
60
61            // Requirements changed — do incremental update
62            tracing::info!("Updating env for pipeline {pipeline_id} (requirements changed)");
63            self.incremental_update(&env_dir, requirements, &lockfile)?;
64            self.write_lockfile(&lockfile_path, requirements, &req_hash)?;
65            return self.python_path(&env_dir);
66        }
67
68        // Create new environment
69        tracing::info!("Creating new env for pipeline {pipeline_id}");
70        self.create_env(&env_dir)?;
71        self.install_requirements(&env_dir, requirements)?;
72        self.write_lockfile(&lockfile_path, requirements, &req_hash)?;
73
74        self.python_path(&env_dir)
75    }
76
77    /// Remove unused environments older than max_age.
78    pub fn cleanup(&self, max_age: std::time::Duration) -> usize {
79        let mut removed = 0;
80        if let Ok(entries) = std::fs::read_dir(&self.base_dir) {
81            for entry in entries.flatten() {
82                if let Ok(meta) = entry.metadata()
83                    && let Ok(modified) = meta.modified()
84                    && modified.elapsed().unwrap_or_default() > max_age
85                {
86                    let _ = std::fs::remove_dir_all(entry.path());
87                    removed += 1;
88                }
89            }
90        }
91        removed
92    }
93
94    // ── Internal ──
95
96    fn create_env(&self, env_dir: &Path) -> Result<(), String> {
97        match self.env_type {
98            EnvType::Venv => {
99                let output = Command::new("python3")
100                    .args(["-m", "venv", &env_dir.to_string_lossy()])
101                    .output()
102                    .map_err(|e| format!("Failed to create venv: {e}"))?;
103                if !output.status.success() {
104                    return Err(format!(
105                        "venv creation failed: {}",
106                        String::from_utf8_lossy(&output.stderr)
107                    ));
108                }
109            }
110            EnvType::Conda => {
111                let output = Command::new("conda")
112                    .args([
113                        "create",
114                        "-p",
115                        &env_dir.to_string_lossy(),
116                        "python=3.11",
117                        "-y",
118                        "-q",
119                    ])
120                    .output()
121                    .map_err(|e| format!("Failed to create conda env: {e}"))?;
122                if !output.status.success() {
123                    return Err(format!(
124                        "conda create failed: {}",
125                        String::from_utf8_lossy(&output.stderr)
126                    ));
127                }
128            }
129        }
130        Ok(())
131    }
132
133    fn install_requirements(&self, env_dir: &Path, requirements: &str) -> Result<(), String> {
134        let pip = self.pip_path(env_dir);
135
136        // Write requirements to temp file
137        let req_file = env_dir.join("requirements.txt");
138        std::fs::write(&req_file, requirements)
139            .map_err(|e| format!("Failed to write requirements.txt: {e}"))?;
140
141        // Always ensure soma is installed
142        let _ = Command::new(&pip).args(["install", "soma"]).output();
143
144        let output = Command::new(&pip)
145            .args(["install", "-r", &req_file.to_string_lossy(), "-q"])
146            .output()
147            .map_err(|e| format!("pip install failed: {e}"))?;
148
149        if !output.status.success() {
150            return Err(format!(
151                "pip install failed:\n{}",
152                String::from_utf8_lossy(&output.stderr)
153            ));
154        }
155
156        Ok(())
157    }
158
159    fn incremental_update(
160        &self,
161        env_dir: &Path,
162        new_requirements: &str,
163        old_lockfile: &EnvLockfile,
164    ) -> Result<(), String> {
165        let new_packages = Self::parse_requirements(new_requirements);
166        let pip = self.pip_path(env_dir);
167
168        // Find packages to install/upgrade
169        let mut to_install = Vec::new();
170        for (name, version) in &new_packages {
171            match old_lockfile.packages.get(name) {
172                None => {
173                    // New package
174                    tracing::info!("  + {name}=={version}");
175                    to_install.push(format!("{name}=={version}"));
176                }
177                Some(old_ver) if old_ver != version => {
178                    // Version changed
179                    tracing::info!("  ↑ {name}: {old_ver} → {version}");
180                    to_install.push(format!("{name}=={version}"));
181                }
182                _ => {} // Same version, skip
183            }
184        }
185
186        // Find packages to remove
187        for name in old_lockfile.packages.keys() {
188            if !new_packages.contains_key(name) {
189                tracing::info!("  - {name}");
190                let _ = Command::new(&pip)
191                    .args(["uninstall", name, "-y", "-q"])
192                    .output();
193            }
194        }
195
196        // Install new/updated packages
197        if !to_install.is_empty() {
198            let output = Command::new(&pip)
199                .args(["install"])
200                .args(&to_install)
201                .arg("-q")
202                .output()
203                .map_err(|e| format!("pip install failed: {e}"))?;
204
205            if !output.status.success() {
206                return Err(format!(
207                    "pip install failed:\n{}",
208                    String::from_utf8_lossy(&output.stderr)
209                ));
210            }
211        }
212
213        Ok(())
214    }
215
216    fn python_path(&self, env_dir: &Path) -> Result<PathBuf, String> {
217        let path = env_dir.join("bin").join("python");
218        if path.exists() {
219            Ok(path)
220        } else {
221            Err(format!("Python not found at {}", path.display()))
222        }
223    }
224
225    fn pip_path(&self, env_dir: &Path) -> PathBuf {
226        env_dir.join("bin").join("pip")
227    }
228
229    fn hash_requirements(requirements: &str) -> String {
230        let mut hasher = Sha256::new();
231        // Normalize: sort lines, trim whitespace, ignore comments
232        let mut lines: Vec<&str> = requirements
233            .lines()
234            .map(|l| l.trim())
235            .filter(|l| !l.is_empty() && !l.starts_with('#'))
236            .collect();
237        lines.sort();
238        for line in &lines {
239            hasher.update(line.as_bytes());
240            hasher.update(b"\n");
241        }
242        hex::encode(hasher.finalize())
243    }
244
245    fn parse_requirements(requirements: &str) -> HashMap<String, String> {
246        let mut packages = HashMap::new();
247        for line in requirements.lines() {
248            let line = line.trim();
249            if line.is_empty() || line.starts_with('#') {
250                continue;
251            }
252            // Parse "package==version", "package>=version", "package"
253            let (name, version) = if let Some((n, v)) = line.split_once("==") {
254                (n.trim().to_lowercase(), v.trim().to_string())
255            } else if let Some((n, v)) = line.split_once(">=") {
256                (n.trim().to_lowercase(), format!(">={v}"))
257            } else if let Some((n, v)) = line.split_once("<=") {
258                (n.trim().to_lowercase(), format!("<={v}"))
259            } else {
260                (line.to_lowercase(), "latest".to_string())
261            };
262            packages.insert(name, version);
263        }
264        packages
265    }
266
267    fn read_lockfile(&self, path: &Path) -> Result<EnvLockfile, String> {
268        let content = std::fs::read_to_string(path).map_err(|e| e.to_string())?;
269        serde_json::from_str(&content).map_err(|e| e.to_string())
270    }
271
272    fn write_lockfile(&self, path: &Path, requirements: &str, hash: &str) -> Result<(), String> {
273        let lockfile = EnvLockfile {
274            packages: Self::parse_requirements(requirements),
275            requirements_hash: hash.to_string(),
276            env_type: self.env_type.clone(),
277            python_version: "3.11".to_string(),
278        };
279        let json = serde_json::to_string_pretty(&lockfile).map_err(|e| e.to_string())?;
280        std::fs::write(path, json).map_err(|e| format!("Failed to write lockfile: {e}"))
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Line order must not influence the hash.
    #[test]
    fn hash_requirements_stable() {
        let r1 = "numpy==1.26\nscikit-learn==1.4\n";
        let r2 = "scikit-learn==1.4\nnumpy==1.26\n"; // different order
        assert_eq!(
            EnvManager::hash_requirements(r1),
            EnvManager::hash_requirements(r2)
        );
    }

    /// Whole-line comments must not influence the hash.
    #[test]
    fn hash_requirements_ignores_comments() {
        let r1 = "numpy==1.26\n# comment\nscikit-learn==1.4\n";
        let r2 = "numpy==1.26\nscikit-learn==1.4\n";
        assert_eq!(
            EnvManager::hash_requirements(r1),
            EnvManager::hash_requirements(r2)
        );
    }

    /// Surrounding whitespace and blank lines must not influence the hash.
    #[test]
    fn hash_requirements_ignores_whitespace_and_blank_lines() {
        let r1 = "  numpy==1.26  \n\n";
        let r2 = "numpy==1.26\n";
        assert_eq!(
            EnvManager::hash_requirements(r1),
            EnvManager::hash_requirements(r2)
        );
    }

    /// A changed version pin must change the hash.
    #[test]
    fn hash_changes_on_version_change() {
        let r1 = "numpy==1.26\n";
        let r2 = "numpy==1.27\n";
        assert_ne!(
            EnvManager::hash_requirements(r1),
            EnvManager::hash_requirements(r2)
        );
    }

    /// Pinned, lower-bounded and unconstrained requirements are all recognised.
    #[test]
    fn parse_requirements_formats() {
        let pkgs = EnvManager::parse_requirements("numpy==1.26\nsklearn>=1.4\npandas\n");
        assert_eq!(pkgs["numpy"], "1.26");
        assert_eq!(pkgs["sklearn"], ">=1.4");
        assert_eq!(pkgs["pandas"], "latest");
    }

    /// Upper bounds keep their operator and package names are lowercased.
    #[test]
    fn parse_requirements_upper_bound_and_case() {
        let pkgs = EnvManager::parse_requirements("SciPy<=1.11\nNumPy==1.26\n");
        assert_eq!(pkgs["scipy"], "<=1.11");
        assert_eq!(pkgs["numpy"], "1.26");
        assert_eq!(pkgs.len(), 2);
    }

    /// Blank lines and whole-line comments produce no entries.
    #[test]
    fn parse_requirements_skips_comments_and_blanks() {
        let pkgs = EnvManager::parse_requirements("# header\n\nnumpy==1.26\n   \n");
        assert_eq!(pkgs.len(), 1);
        assert_eq!(pkgs["numpy"], "1.26");
    }
}
325}