Skip to main content

somatize_worker/
python_process.rs

1//! Python subprocess — persistent daemon for filter execution.
2//!
3//! Spawns a Python child process that loads filters via cloudpickle
4//! and executes fit/forward commands via stdin/stdout JSON Lines.
5//! The GIL is completely isolated from the Rust process — no segfaults.
6
7use base64::engine::{Engine, general_purpose::STANDARD};
8use somatize_core::cache::CacheKey;
9use somatize_core::error::{Result, SomaError};
10use somatize_core::filter::{Filter, FilterKind, FilterMeta, StreamMode};
11use somatize_core::value::Value;
12use std::collections::HashMap;
13use std::io::{BufRead, BufReader, BufWriter, Write};
14use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
15use std::sync::{Arc, Mutex};
16
/// The Python daemon script, embedded as a Rust string.
///
/// The script is handed to the interpreter with `-c`. It reads one JSON
/// command per line from stdin and answers each command with exactly one JSON
/// line on stdout. Every response carries an `"ok"` boolean; failures also
/// carry `"error"` and, for exceptions raised while handling a command,
/// `"traceback"`.
///
/// Supported `"cmd"` values: `LOAD`, `FIT`, `FORWARD`, `COMPOSITE_FORWARD`,
/// `COMPOSITE_FIT`, `GET_STATE`, `SET_STATE`, `GET_GRADIENTS`,
/// `APPLY_GRADIENTS` and `SHUTDOWN`.
///
/// Because stdout is the protocol channel, the script keeps a private handle
/// to it and rebinds `sys.stdout` to stderr, so a `print()` inside a user
/// filter ends up in the worker's stderr instead of corrupting a response.
const DAEMON_SCRIPT: &str = r#"
import json, sys, base64, cloudpickle, io, pickle, traceback

filters = {}

# Protocol channel: keep a private handle to the real stdout and point
# sys.stdout at stderr so print() calls made by filter code cannot inject
# non-JSON lines into the response stream.
_proto_out = sys.stdout
sys.stdout = sys.stderr

def _reply(payload):
    """Write exactly one JSON response line to the protocol channel."""
    _proto_out.write(json.dumps(payload) + "\n")
    _proto_out.flush()

def _encode(obj):
    """Encode a Python object to JSON-safe format."""
    if obj is None:
        return None
    if isinstance(obj, (list, int, float, str, bool)):
        return obj
    if isinstance(obj, dict):
        return {k: _encode(v) for k, v in obj.items()}
    # Fall back to pickle + base64
    return {"__pickle_b64__": base64.b64encode(pickle.dumps(obj)).decode()}

def _decode(obj):
    """Decode from JSON-safe format back to Python object."""
    if obj is None:
        return None
    if isinstance(obj, dict):
        if "__pickle_b64__" in obj:
            # SECURITY: unpickling executes arbitrary code. Only feed this
            # daemon commands and data that come from a trusted source.
            return pickle.loads(base64.b64decode(obj["__pickle_b64__"]))
        if "type" in obj and "data" in obj:
            # Soma Value format
            t, d = obj["type"], obj["data"]
            if t == "Tensor":
                return d.get("values", [])
            if t == "Json":
                return d
            if t == "Empty":
                return {}
            if t == "Bytes":
                return bytes(d)
        return {k: _decode(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_decode(v) for v in obj]
    return obj

for line in sys.stdin:
    line = line.strip()
    if not line:
        continue
    try:
        cmd = json.loads(line)
    except json.JSONDecodeError as e:
        _reply({"ok": False, "error": f"invalid JSON: {e}"})
        continue

    try:
        action = cmd.get("cmd", "")

        if action == "LOAD":
            for f in cmd["filters"]:
                obj = cloudpickle.loads(base64.b64decode(f["pickle_b64"]))
                filters[f["id"]] = {"obj": obj, "trainable": f.get("trainable", True)}
            _reply({"ok": True})

        elif action == "FIT":
            f = filters[cmd["node_id"]]["obj"]
            data = _decode(cmd.get("data"))
            y = _decode(cmd.get("y"))
            result = f.fit(data, y)
            _reply({"ok": True, "result": _encode(result)})

        elif action == "FORWARD":
            f = filters[cmd["node_id"]]["obj"]
            data = _decode(cmd.get("data"))
            state = _decode(cmd.get("state", {}))
            result = f.forward(data, state)
            _reply({"ok": True, "result": _encode(result)})

        elif action == "COMPOSITE_FORWARD":
            node_ids = cmd["node_ids"]
            data = _decode(cmd.get("data"))
            try:
                import torch
                if isinstance(data, list):
                    x = torch.tensor(data, dtype=torch.float32, requires_grad=True)
                else:
                    x = data
            except ImportError:
                x = data

            out = x
            for nid in node_ids:
                f = filters[nid]["obj"]
                state = _decode(cmd.get("states", {}).get(nid, {}))
                out = f.forward(out, state)

            result = out.detach().tolist() if hasattr(out, 'detach') else out
            _reply({"ok": True, "result": _encode(result)})

        elif action == "COMPOSITE_FIT":
            node_ids = cmd["node_ids"]
            data = _decode(cmd.get("data"))
            y = _decode(cmd.get("y"))

            try:
                import torch
                if isinstance(data, list):
                    x = torch.tensor(data, dtype=torch.float32, requires_grad=True)
                else:
                    x = data
            except ImportError:
                x = data

            out = x
            for nid in node_ids:
                f = filters[nid]["obj"]
                out = f.forward(out, {})

            # Backward
            if y is not None and hasattr(out, 'backward'):
                last = filters[node_ids[-1]]["obj"]
                try:
                    import torch
                    if isinstance(y, list):
                        y_t = torch.tensor(y, dtype=torch.float32)
                    else:
                        y_t = y
                    if hasattr(last, 'loss_fn'):
                        loss = last.loss_fn(out, y_t)
                    else:
                        loss = torch.nn.functional.mse_loss(out, y_t)
                    loss.backward()
                    for nid in node_ids:
                        f = filters[nid]["obj"]
                        if hasattr(f, 'optimizer'):
                            f.optimizer.step()
                            f.optimizer.zero_grad()
                except Exception:
                    # Still best-effort (the forward output is returned either
                    # way), but report the failure on stderr so a broken
                    # training step does not go unnoticed.
                    traceback.print_exc(file=sys.stderr)

            states = {}
            for nid in node_ids:
                f = filters[nid]["obj"]
                if hasattr(f, 'state_dict'):
                    buf = io.BytesIO()
                    try:
                        import torch
                        torch.save(f.state_dict(), buf)
                    except ImportError:
                        buf.write(cloudpickle.dumps(f))
                    states[nid] = base64.b64encode(buf.getvalue()).decode()

            result = out.detach().tolist() if hasattr(out, 'detach') else out
            _reply({"ok": True, "result": _encode(result), "states": states})

        elif action == "GET_STATE":
            nid = cmd["node_id"]
            f = filters[nid]["obj"]
            buf = io.BytesIO()
            if hasattr(f, 'state_dict'):
                try:
                    import torch
                    torch.save(f.state_dict(), buf)
                except ImportError:
                    buf.write(cloudpickle.dumps(f))
            else:
                buf.write(cloudpickle.dumps(f))
            state_b64 = base64.b64encode(buf.getvalue()).decode()
            _reply({"ok": True, "state_b64": state_b64})

        elif action == "SET_STATE":
            nid = cmd["node_id"]
            f = filters[nid]["obj"]
            state_bytes = base64.b64decode(cmd["state_b64"])
            buf = io.BytesIO(state_bytes)
            if hasattr(f, 'load_state_dict'):
                try:
                    import torch
                    f.load_state_dict(torch.load(buf, weights_only=True))
                except ImportError:
                    filters[nid]["obj"] = cloudpickle.loads(buf.read())
            else:
                filters[nid]["obj"] = cloudpickle.loads(buf.read())
            _reply({"ok": True})

        elif action == "GET_GRADIENTS":
            nid = cmd["node_id"]
            f = filters[nid]["obj"]
            buf = io.BytesIO()
            if hasattr(f, 'parameters'):
                try:
                    import torch
                    grads = {name: p.grad.clone() for name, p in f.named_parameters() if p.grad is not None}
                    torch.save(grads, buf)
                except ImportError:
                    pass
            grad_b64 = base64.b64encode(buf.getvalue()).decode()
            _reply({"ok": True, "gradients_b64": grad_b64})

        elif action == "APPLY_GRADIENTS":
            nid = cmd["node_id"]
            f = filters[nid]["obj"]
            grad_bytes = base64.b64decode(cmd["gradients_b64"])
            if hasattr(f, 'named_parameters'):
                try:
                    import torch
                    buf = io.BytesIO(grad_bytes)
                    grads = torch.load(buf, weights_only=True)
                    for name, p in f.named_parameters():
                        if name in grads:
                            p.grad = grads[name]
                    if hasattr(f, 'optimizer'):
                        f.optimizer.step()
                except ImportError:
                    pass
            _reply({"ok": True})

        elif action == "SHUTDOWN":
            _reply({"ok": True})
            break

        else:
            _reply({"ok": False, "error": f"unknown command: {action}"})

    except Exception as e:
        tb = traceback.format_exc()
        _reply({"ok": False, "error": str(e), "traceback": tb})
"#;
241
/// A persistent Python child process that executes filter commands.
///
/// Commands are written to the child's stdin and responses are read back from
/// its stdout, one line each.
#[derive(Debug)]
pub struct PythonProcess {
    /// Handle to the spawned interpreter; killed and reaped on drop.
    child: Child,
    /// Buffered writer over the child's stdin (command channel).
    stdin: BufWriter<ChildStdin>,
    /// Buffered reader over the child's stdout (response channel).
    stdout: BufReader<ChildStdout>,
    /// Ids of the filters that were sent to the child in the initial `LOAD`.
    node_ids: Vec<String>,
}
249
250impl PythonProcess {
251    /// Spawn a Python daemon and load filters into it.
252    pub fn spawn(
253        python_path: &str,
254        filters: &[(String, Vec<u8>, bool)], // (node_id, pickled_bytes, trainable)
255    ) -> Result<Self> {
256        let mut child = Command::new(python_path)
257            .args(["-c", DAEMON_SCRIPT])
258            .stdin(Stdio::piped())
259            .stdout(Stdio::piped())
260            .stderr(Stdio::inherit()) // Python stderr → worker stderr (for logs/tracing)
261            .spawn()
262            .map_err(|e| SomaError::Other(format!("failed to spawn python: {e}")))?;
263
264        let stdin = BufWriter::new(
265            child
266                .stdin
267                .take()
268                .ok_or_else(|| SomaError::Other("no stdin".into()))?,
269        );
270        let stdout = BufReader::new(
271            child
272                .stdout
273                .take()
274                .ok_or_else(|| SomaError::Other("no stdout".into()))?,
275        );
276
277        let node_ids: Vec<String> = filters.iter().map(|(id, _, _)| id.clone()).collect();
278
279        let mut proc = Self {
280            child,
281            stdin,
282            stdout,
283            node_ids,
284        };
285
286        // Send LOAD command with all filters
287        let filter_specs: Vec<serde_json::Value> = filters
288            .iter()
289            .map(|(id, pickled, trainable)| {
290                serde_json::json!({
291                    "id": id,
292                    "pickle_b64": STANDARD.encode(pickled),
293                    "trainable": trainable,
294                })
295            })
296            .collect();
297
298        let resp = proc.send(serde_json::json!({
299            "cmd": "LOAD",
300            "filters": filter_specs,
301        }))?;
302
303        if resp.get("ok") != Some(&serde_json::Value::Bool(true)) {
304            let error = resp
305                .get("error")
306                .and_then(|e| e.as_str())
307                .unwrap_or("unknown error");
308            return Err(SomaError::Other(format!("LOAD failed: {error}")));
309        }
310
311        Ok(proc)
312    }
313
314    /// Send a JSON command and read the JSON response.
315    fn send(&mut self, cmd: serde_json::Value) -> Result<serde_json::Value> {
316        let action = cmd
317            .get("cmd")
318            .and_then(|c| c.as_str())
319            .unwrap_or("?")
320            .to_string();
321        let node_id = cmd
322            .get("node_id")
323            .and_then(|n| n.as_str())
324            .unwrap_or("")
325            .to_string();
326
327        tracing::debug!(action = %action, node_id = %node_id, "→ Python");
328        let start = std::time::Instant::now();
329
330        let line = serde_json::to_string(&cmd)
331            .map_err(|e| SomaError::Other(format!("serialize cmd: {e}")))?;
332
333        writeln!(self.stdin, "{line}")
334            .map_err(|e| SomaError::Other(format!("write to python stdin: {e}")))?;
335        self.stdin
336            .flush()
337            .map_err(|e| SomaError::Other(format!("flush stdin: {e}")))?;
338
339        let mut response = String::new();
340        self.stdout
341            .read_line(&mut response)
342            .map_err(|e| SomaError::Other(format!("read from python stdout: {e}")))?;
343
344        let duration_ms = start.elapsed().as_millis();
345
346        if response.is_empty() {
347            tracing::error!(action = %action, "Python process closed stdout (crashed?)");
348            return Err(SomaError::Other(
349                "python process closed stdout (crashed?)".into(),
350            ));
351        }
352
353        let parsed: serde_json::Value = serde_json::from_str(&response).map_err(|e| {
354            SomaError::Other(format!("parse python response: {e}\nraw: {response}"))
355        })?;
356
357        let ok = parsed.get("ok") == Some(&serde_json::Value::Bool(true));
358        if ok {
359            tracing::debug!(action = %action, node_id = %node_id, duration_ms, "← Python OK");
360        } else {
361            let error = parsed.get("error").and_then(|e| e.as_str()).unwrap_or("?");
362            let traceback = parsed
363                .get("traceback")
364                .and_then(|t| t.as_str())
365                .unwrap_or("");
366            tracing::error!(action = %action, node_id = %node_id, error, "Python filter error");
367            if !traceback.is_empty() {
368                tracing::error!("Python traceback:\n{traceback}");
369            }
370        }
371
372        Ok(parsed)
373    }
374
375    /// Convert a response to a Value, handling errors.
376    fn response_to_value(resp: &serde_json::Value) -> Result<Value> {
377        if resp.get("ok") != Some(&serde_json::Value::Bool(true)) {
378            let error = resp
379                .get("error")
380                .and_then(|e| e.as_str())
381                .unwrap_or("unknown error");
382            let traceback = resp.get("traceback").and_then(|t| t.as_str()).unwrap_or("");
383            return Err(SomaError::Other(format!(
384                "Python error: {error}\n{traceback}"
385            )));
386        }
387
388        if let Some(result) = resp.get("result") {
389            return Self::json_to_value(result);
390        }
391
392        Ok(Value::Empty)
393    }
394
395    /// Convert a JSON value to a Soma Value.
396    fn json_to_value(v: &serde_json::Value) -> Result<Value> {
397        if v.is_null() {
398            return Ok(Value::Empty);
399        }
400        if let Some(arr) = v.as_array() {
401            let values: Vec<f64> = arr.iter().filter_map(|x| x.as_f64()).collect();
402            if values.len() == arr.len() && !values.is_empty() {
403                return Ok(Value::tensor(values.clone(), vec![values.len()]));
404            }
405            // Could be nested array
406            if let Some(first) = arr.first()
407                && first.is_array()
408            {
409                let rows = arr.len();
410                let cols = first.as_array().map(|a| a.len()).unwrap_or(0);
411                let flat: Vec<f64> = arr
412                    .iter()
413                    .filter_map(|row| row.as_array())
414                    .flat_map(|row| row.iter().filter_map(|x| x.as_f64()))
415                    .collect();
416                if flat.len() == rows * cols {
417                    return Ok(Value::tensor(flat, vec![rows, cols]));
418                }
419            }
420        }
421        Ok(Value::Json(v.clone()))
422    }
423
424    /// Encode a Value to JSON for the Python process.
425    fn value_to_json(v: &Value) -> serde_json::Value {
426        serde_json::to_value(v).unwrap_or(serde_json::Value::Null)
427    }
428
429    // ── Public API ──
430
431    pub fn fit(&mut self, node_id: &str, data: &Value, y: Option<&Value>) -> Result<Value> {
432        let mut cmd = serde_json::json!({
433            "cmd": "FIT",
434            "node_id": node_id,
435            "data": Self::value_to_json(data),
436        });
437        if let Some(y_val) = y {
438            cmd["y"] = Self::value_to_json(y_val);
439        }
440        let resp = self.send(cmd)?;
441        Self::response_to_value(&resp)
442    }
443
444    pub fn forward(&mut self, node_id: &str, data: &Value, state: &Value) -> Result<Value> {
445        let resp = self.send(serde_json::json!({
446            "cmd": "FORWARD",
447            "node_id": node_id,
448            "data": Self::value_to_json(data),
449            "state": Self::value_to_json(state),
450        }))?;
451        Self::response_to_value(&resp)
452    }
453
454    pub fn composite_fit(
455        &mut self,
456        node_ids: &[String],
457        data: &Value,
458        y: Option<&Value>,
459    ) -> Result<(Value, HashMap<String, Value>)> {
460        let mut cmd = serde_json::json!({
461            "cmd": "COMPOSITE_FIT",
462            "node_ids": node_ids,
463            "data": Self::value_to_json(data),
464        });
465        if let Some(y_val) = y {
466            cmd["y"] = Self::value_to_json(y_val);
467        }
468        let resp = self.send(cmd)?;
469        let output = Self::response_to_value(&resp)?;
470
471        let mut states = HashMap::new();
472        if let Some(state_map) = resp.get("states").and_then(|s| s.as_object()) {
473            for (id, b64) in state_map {
474                if let Some(s) = b64.as_str() {
475                    let bytes = STANDARD
476                        .decode(s)
477                        .map_err(|e| SomaError::Other(format!("decode state: {e}")))?;
478                    states.insert(id.clone(), Value::Bytes(bytes));
479                }
480            }
481        }
482        Ok((output, states))
483    }
484
485    pub fn composite_forward(&mut self, node_ids: &[String], data: &Value) -> Result<Value> {
486        let resp = self.send(serde_json::json!({
487            "cmd": "COMPOSITE_FORWARD",
488            "node_ids": node_ids,
489            "data": Self::value_to_json(data),
490        }))?;
491        Self::response_to_value(&resp)
492    }
493
494    pub fn get_state(&mut self, node_id: &str) -> Result<Value> {
495        let resp = self.send(serde_json::json!({"cmd": "GET_STATE", "node_id": node_id}))?;
496        if let Some(b64) = resp.get("state_b64").and_then(|s| s.as_str()) {
497            let bytes = STANDARD
498                .decode(b64)
499                .map_err(|e| SomaError::Other(format!("decode state: {e}")))?;
500            Ok(Value::Bytes(bytes))
501        } else {
502            Self::response_to_value(&resp)
503        }
504    }
505
506    pub fn set_state(&mut self, node_id: &str, state: &Value) -> Result<()> {
507        let b64 = match state {
508            Value::Bytes(b) => STANDARD.encode(b),
509            _ => return Err(SomaError::Other("set_state expects Value::Bytes".into())),
510        };
511        let resp = self
512            .send(serde_json::json!({"cmd": "SET_STATE", "node_id": node_id, "state_b64": b64}))?;
513        if resp.get("ok") != Some(&serde_json::Value::Bool(true)) {
514            let error = resp.get("error").and_then(|e| e.as_str()).unwrap_or("?");
515            return Err(SomaError::Other(format!("set_state: {error}")));
516        }
517        Ok(())
518    }
519
520    pub fn get_gradients(&mut self, node_id: &str) -> Result<Value> {
521        let resp = self.send(serde_json::json!({"cmd": "GET_GRADIENTS", "node_id": node_id}))?;
522        if let Some(b64) = resp.get("gradients_b64").and_then(|s| s.as_str()) {
523            let bytes = STANDARD
524                .decode(b64)
525                .map_err(|e| SomaError::Other(format!("decode gradients: {e}")))?;
526            Ok(Value::Bytes(bytes))
527        } else {
528            Ok(Value::Empty)
529        }
530    }
531
532    pub fn apply_gradients(&mut self, node_id: &str, gradients: &Value) -> Result<()> {
533        let b64 = match gradients {
534            Value::Bytes(b) => STANDARD.encode(b),
535            _ => {
536                return Err(SomaError::Other(
537                    "apply_gradients expects Value::Bytes".into(),
538                ));
539            }
540        };
541        self.send(
542            serde_json::json!({"cmd": "APPLY_GRADIENTS", "node_id": node_id, "gradients_b64": b64}),
543        )?;
544        Ok(())
545    }
546
547    pub fn shutdown(&mut self) {
548        let _ = self.send(serde_json::json!({"cmd": "SHUTDOWN"}));
549    }
550
551    pub fn node_ids(&self) -> &[String] {
552        &self.node_ids
553    }
554}
555
556impl Drop for PythonProcess {
557    fn drop(&mut self) {
558        self.shutdown();
559        let _ = self.child.kill();
560        let _ = self.child.wait();
561    }
562}
563
564// ── SubprocessFilter: implements Filter trait via PythonProcess ──
565
/// A filter that delegates to a shared PythonProcess via stdin/stdout.
/// Multiple SubprocessFilters can share the same process (Arc<Mutex>).
///
/// Every call locks the mutex for the duration of one command round-trip.
pub struct SubprocessFilter {
    /// Shared handle to the Python daemon that hosts the actual filter.
    process: Arc<Mutex<PythonProcess>>,
    /// Id sent as `node_id` with each command; also used as the filter name
    /// and as the input to `config_hash`.
    node_id: String,
    /// Selects `FilterKind::Trainable` vs `FilterKind::Stateless` and the
    /// `differentiable` flag in `meta()`.
    trainable: bool,
}
573
574impl SubprocessFilter {
575    pub fn new(process: Arc<Mutex<PythonProcess>>, node_id: String, trainable: bool) -> Self {
576        Self {
577            process,
578            node_id,
579            trainable,
580        }
581    }
582}
583
584impl Filter for SubprocessFilter {
585    fn config_hash(&self) -> CacheKey {
586        CacheKey::from_parts(&[self.node_id.as_bytes()])
587    }
588
589    fn fit(&self, x: &Value, y: Option<&Value>) -> Result<Value> {
590        self.process
591            .lock()
592            .map_err(|e| SomaError::Other(format!("process mutex poisoned: {e}")))?
593            .fit(&self.node_id, x, y)
594    }
595
596    fn forward(&self, x: &Value, state: &Value) -> Result<Value> {
597        self.process
598            .lock()
599            .map_err(|e| SomaError::Other(format!("process mutex poisoned: {e}")))?
600            .forward(&self.node_id, x, state)
601    }
602
603    fn meta(&self) -> FilterMeta {
604        FilterMeta {
605            name: self.node_id.clone(),
606            kind: if self.trainable {
607                FilterKind::Trainable
608            } else {
609                FilterKind::Stateless
610            },
611            cacheable: true,
612            differentiable: self.trainable,
613            stream_mode: StreamMode::FixedState,
614            distribution: somatize_core::filter::Distribution::Local,
615            input_schema: None,
616            output_schema: None,
617        }
618    }
619
620    fn as_any(&self) -> &dyn std::any::Any {
621        self
622    }
623
624    fn composite_fit(
625        &self,
626        node_ids: &[String],
627        x: &Value,
628        y: Option<&Value>,
629    ) -> Option<Result<(Value, HashMap<String, Value>)>> {
630        tracing::info!(nodes = ?node_ids, "Composite fit via subprocess");
631        Some(
632            self.process
633                .lock()
634                .map_err(|e| SomaError::Other(format!("process mutex poisoned: {e}")))
635                .and_then(|mut proc| proc.composite_fit(node_ids, x, y)),
636        )
637    }
638}