1use serde::{Deserialize, Serialize};
5use sha2::{Digest, Sha256};
6use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8use std::process::Command;
9
10#[derive(Debug, Clone, Serialize, Deserialize, Default)]
12#[serde(rename_all = "snake_case")]
13pub enum EnvType {
14 #[default]
15 Venv,
16 Conda,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize, Default)]
21pub struct EnvLockfile {
22 pub packages: HashMap<String, String>, pub requirements_hash: String,
24 pub env_type: EnvType,
25 pub python_version: String,
26}
27
28pub struct EnvManager {
30 base_dir: PathBuf,
31 env_type: EnvType,
32}
33
34impl EnvManager {
35 pub fn new(base_dir: impl Into<PathBuf>, env_type: EnvType) -> Self {
36 let base = base_dir.into();
37 std::fs::create_dir_all(&base).ok();
38 Self {
39 base_dir: base,
40 env_type,
41 }
42 }
43
44 pub fn ensure_env(&self, pipeline_id: &str, requirements: &str) -> Result<PathBuf, String> {
47 let req_hash = Self::hash_requirements(requirements);
48 let env_dir = self.base_dir.join(format!("env-{pipeline_id}"));
49 let lockfile_path = env_dir.join("lockfile.json");
50
51 if env_dir.exists()
53 && let Ok(lockfile) = self.read_lockfile(&lockfile_path)
54 {
55 if lockfile.requirements_hash == req_hash {
56 tracing::info!("Reusing env for pipeline {pipeline_id} (hash match)");
58 return self.python_path(&env_dir);
59 }
60
61 tracing::info!("Updating env for pipeline {pipeline_id} (requirements changed)");
63 self.incremental_update(&env_dir, requirements, &lockfile)?;
64 self.write_lockfile(&lockfile_path, requirements, &req_hash)?;
65 return self.python_path(&env_dir);
66 }
67
68 tracing::info!("Creating new env for pipeline {pipeline_id}");
70 self.create_env(&env_dir)?;
71 self.install_requirements(&env_dir, requirements)?;
72 self.write_lockfile(&lockfile_path, requirements, &req_hash)?;
73
74 self.python_path(&env_dir)
75 }
76
77 pub fn cleanup(&self, max_age: std::time::Duration) -> usize {
79 let mut removed = 0;
80 if let Ok(entries) = std::fs::read_dir(&self.base_dir) {
81 for entry in entries.flatten() {
82 if let Ok(meta) = entry.metadata()
83 && let Ok(modified) = meta.modified()
84 && modified.elapsed().unwrap_or_default() > max_age
85 {
86 let _ = std::fs::remove_dir_all(entry.path());
87 removed += 1;
88 }
89 }
90 }
91 removed
92 }
93
94 fn create_env(&self, env_dir: &Path) -> Result<(), String> {
97 match self.env_type {
98 EnvType::Venv => {
99 let output = Command::new("python3")
100 .args(["-m", "venv", &env_dir.to_string_lossy()])
101 .output()
102 .map_err(|e| format!("Failed to create venv: {e}"))?;
103 if !output.status.success() {
104 return Err(format!(
105 "venv creation failed: {}",
106 String::from_utf8_lossy(&output.stderr)
107 ));
108 }
109 }
110 EnvType::Conda => {
111 let output = Command::new("conda")
112 .args([
113 "create",
114 "-p",
115 &env_dir.to_string_lossy(),
116 "python=3.11",
117 "-y",
118 "-q",
119 ])
120 .output()
121 .map_err(|e| format!("Failed to create conda env: {e}"))?;
122 if !output.status.success() {
123 return Err(format!(
124 "conda create failed: {}",
125 String::from_utf8_lossy(&output.stderr)
126 ));
127 }
128 }
129 }
130 Ok(())
131 }
132
133 fn install_requirements(&self, env_dir: &Path, requirements: &str) -> Result<(), String> {
134 let pip = self.pip_path(env_dir);
135
136 let req_file = env_dir.join("requirements.txt");
138 std::fs::write(&req_file, requirements)
139 .map_err(|e| format!("Failed to write requirements.txt: {e}"))?;
140
141 let _ = Command::new(&pip).args(["install", "soma"]).output();
143
144 let output = Command::new(&pip)
145 .args(["install", "-r", &req_file.to_string_lossy(), "-q"])
146 .output()
147 .map_err(|e| format!("pip install failed: {e}"))?;
148
149 if !output.status.success() {
150 return Err(format!(
151 "pip install failed:\n{}",
152 String::from_utf8_lossy(&output.stderr)
153 ));
154 }
155
156 Ok(())
157 }
158
159 fn incremental_update(
160 &self,
161 env_dir: &Path,
162 new_requirements: &str,
163 old_lockfile: &EnvLockfile,
164 ) -> Result<(), String> {
165 let new_packages = Self::parse_requirements(new_requirements);
166 let pip = self.pip_path(env_dir);
167
168 let mut to_install = Vec::new();
170 for (name, version) in &new_packages {
171 match old_lockfile.packages.get(name) {
172 None => {
173 tracing::info!(" + {name}=={version}");
175 to_install.push(format!("{name}=={version}"));
176 }
177 Some(old_ver) if old_ver != version => {
178 tracing::info!(" ↑ {name}: {old_ver} → {version}");
180 to_install.push(format!("{name}=={version}"));
181 }
182 _ => {} }
184 }
185
186 for name in old_lockfile.packages.keys() {
188 if !new_packages.contains_key(name) {
189 tracing::info!(" - {name}");
190 let _ = Command::new(&pip)
191 .args(["uninstall", name, "-y", "-q"])
192 .output();
193 }
194 }
195
196 if !to_install.is_empty() {
198 let output = Command::new(&pip)
199 .args(["install"])
200 .args(&to_install)
201 .arg("-q")
202 .output()
203 .map_err(|e| format!("pip install failed: {e}"))?;
204
205 if !output.status.success() {
206 return Err(format!(
207 "pip install failed:\n{}",
208 String::from_utf8_lossy(&output.stderr)
209 ));
210 }
211 }
212
213 Ok(())
214 }
215
216 fn python_path(&self, env_dir: &Path) -> Result<PathBuf, String> {
217 let path = env_dir.join("bin").join("python");
218 if path.exists() {
219 Ok(path)
220 } else {
221 Err(format!("Python not found at {}", path.display()))
222 }
223 }
224
225 fn pip_path(&self, env_dir: &Path) -> PathBuf {
226 env_dir.join("bin").join("pip")
227 }
228
229 fn hash_requirements(requirements: &str) -> String {
230 let mut hasher = Sha256::new();
231 let mut lines: Vec<&str> = requirements
233 .lines()
234 .map(|l| l.trim())
235 .filter(|l| !l.is_empty() && !l.starts_with('#'))
236 .collect();
237 lines.sort();
238 for line in &lines {
239 hasher.update(line.as_bytes());
240 hasher.update(b"\n");
241 }
242 hex::encode(hasher.finalize())
243 }
244
245 fn parse_requirements(requirements: &str) -> HashMap<String, String> {
246 let mut packages = HashMap::new();
247 for line in requirements.lines() {
248 let line = line.trim();
249 if line.is_empty() || line.starts_with('#') {
250 continue;
251 }
252 let (name, version) = if let Some((n, v)) = line.split_once("==") {
254 (n.trim().to_lowercase(), v.trim().to_string())
255 } else if let Some((n, v)) = line.split_once(">=") {
256 (n.trim().to_lowercase(), format!(">={v}"))
257 } else if let Some((n, v)) = line.split_once("<=") {
258 (n.trim().to_lowercase(), format!("<={v}"))
259 } else {
260 (line.to_lowercase(), "latest".to_string())
261 };
262 packages.insert(name, version);
263 }
264 packages
265 }
266
267 fn read_lockfile(&self, path: &Path) -> Result<EnvLockfile, String> {
268 let content = std::fs::read_to_string(path).map_err(|e| e.to_string())?;
269 serde_json::from_str(&content).map_err(|e| e.to_string())
270 }
271
272 fn write_lockfile(&self, path: &Path, requirements: &str, hash: &str) -> Result<(), String> {
273 let lockfile = EnvLockfile {
274 packages: Self::parse_requirements(requirements),
275 requirements_hash: hash.to_string(),
276 env_type: self.env_type.clone(),
277 python_version: "3.11".to_string(),
278 };
279 let json = serde_json::to_string_pretty(&lockfile).map_err(|e| e.to_string())?;
280 std::fs::write(path, json).map_err(|e| format!("Failed to write lockfile: {e}"))
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the fingerprint function under test.
    fn digest(requirements: &str) -> String {
        EnvManager::hash_requirements(requirements)
    }

    #[test]
    fn hash_requirements_stable() {
        // Same pins listed in opposite order must hash identically.
        let forward = digest("numpy==1.26\nscikit-learn==1.4\n");
        let reversed = digest("scikit-learn==1.4\nnumpy==1.26\n");
        assert_eq!(forward, reversed);
    }

    #[test]
    fn hash_requirements_ignores_comments() {
        // A comment line must not influence the fingerprint.
        let commented = digest("numpy==1.26\n# comment\nscikit-learn==1.4\n");
        let plain = digest("numpy==1.26\nscikit-learn==1.4\n");
        assert_eq!(commented, plain);
    }

    #[test]
    fn hash_changes_on_version_change() {
        // Bumping a pinned version must produce a different fingerprint.
        let before = digest("numpy==1.26\n");
        let after = digest("numpy==1.27\n");
        assert_ne!(before, after);
    }

    #[test]
    fn parse_requirements_formats() {
        // Exact pin, lower bound and unconstrained package.
        let parsed = EnvManager::parse_requirements("numpy==1.26\nsklearn>=1.4\npandas\n");
        let expected = [("numpy", "1.26"), ("sklearn", ">=1.4"), ("pandas", "latest")];
        for (name, version) in expected {
            assert_eq!(parsed[name], version);
        }
    }
}
325}