1use base64::engine::{Engine, general_purpose::STANDARD};
8use somatize_core::cache::CacheKey;
9use somatize_core::error::{Result, SomaError};
10use somatize_core::filter::{Filter, FilterKind, FilterMeta, StreamMode};
11use somatize_core::value::Value;
12use std::collections::HashMap;
13use std::io::{BufRead, BufReader, BufWriter, Write};
14use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
15use std::sync::{Arc, Mutex};
16
/// Python source for the long-lived filter daemon, executed via `python -c`.
///
/// The daemon reads one JSON command per line on stdin and answers each with
/// exactly one JSON line on its original stdout. Supported commands: `LOAD`,
/// `FIT`, `FORWARD`, `COMPOSITE_FORWARD`, `COMPOSITE_FIT`, `GET_STATE`,
/// `SET_STATE`, `GET_GRADIENTS`, `APPLY_GRADIENTS` and `SHUTDOWN`. Every reply
/// carries an `"ok"` boolean; failures add `"error"` and, when an exception
/// was caught, `"traceback"`.
///
/// `sys.stdout` is redirected to stderr inside the daemon so that `print`
/// calls made by loaded filter code cannot interleave with protocol replies.
const DAEMON_SCRIPT: &str = r#"
import json, sys, base64, cloudpickle, io, pickle, traceback

# Reserve the real stdout for protocol replies only. Anything user filter code
# prints goes to stderr instead, so it cannot corrupt the JSON line stream.
_PROTO_OUT = sys.stdout
sys.stdout = sys.stderr

filters = {}

def _reply(payload):
    """Write one JSON reply line on the protocol channel."""
    print(json.dumps(payload), file=_PROTO_OUT, flush=True)

def _encode(obj):
    """Encode a Python object to JSON-safe format."""
    if obj is None:
        return None
    if isinstance(obj, (list, int, float, str, bool)):
        return obj
    if isinstance(obj, dict):
        return {k: _encode(v) for k, v in obj.items()}
    # Fall back to pickle + base64
    return {"__pickle_b64__": base64.b64encode(pickle.dumps(obj)).decode()}

def _decode(obj):
    """Decode from JSON-safe format back to Python object."""
    if obj is None:
        return None
    if isinstance(obj, dict):
        if "__pickle_b64__" in obj:
            return pickle.loads(base64.b64decode(obj["__pickle_b64__"]))
        if "type" in obj and "data" in obj:
            # Soma Value format
            t, d = obj["type"], obj["data"]
            if t == "Tensor":
                return d.get("values", [])
            if t == "Json":
                return d
            if t == "Empty":
                return {}
            if t == "Bytes":
                return bytes(d)
        return {k: _decode(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_decode(v) for v in obj]
    return obj

for line in sys.stdin:
    line = line.strip()
    if not line:
        continue
    try:
        cmd = json.loads(line)
    except json.JSONDecodeError as e:
        _reply({"ok": False, "error": f"invalid JSON: {e}"})
        continue

    try:
        action = cmd.get("cmd", "")

        if action == "LOAD":
            # NOTE: unpickling executes arbitrary code; payloads must come from
            # the trusted parent process only.
            for f in cmd["filters"]:
                obj = cloudpickle.loads(base64.b64decode(f["pickle_b64"]))
                filters[f["id"]] = {"obj": obj, "trainable": f.get("trainable", True)}
            _reply({"ok": True})

        elif action == "FIT":
            f = filters[cmd["node_id"]]["obj"]
            data = _decode(cmd.get("data"))
            y = _decode(cmd.get("y"))
            result = f.fit(data, y)
            _reply({"ok": True, "result": _encode(result)})

        elif action == "FORWARD":
            f = filters[cmd["node_id"]]["obj"]
            data = _decode(cmd.get("data"))
            state = _decode(cmd.get("state", {}))
            result = f.forward(data, state)
            _reply({"ok": True, "result": _encode(result)})

        elif action == "COMPOSITE_FORWARD":
            node_ids = cmd["node_ids"]
            data = _decode(cmd.get("data"))
            try:
                import torch
                if isinstance(data, list):
                    x = torch.tensor(data, dtype=torch.float32, requires_grad=True)
                else:
                    x = data
            except ImportError:
                x = data

            out = x
            for nid in node_ids:
                f = filters[nid]["obj"]
                state = _decode(cmd.get("states", {}).get(nid, {}))
                out = f.forward(out, state)

            result = out.detach().tolist() if hasattr(out, 'detach') else out
            _reply({"ok": True, "result": _encode(result)})

        elif action == "COMPOSITE_FIT":
            node_ids = cmd["node_ids"]
            data = _decode(cmd.get("data"))
            y = _decode(cmd.get("y"))

            try:
                import torch
                if isinstance(data, list):
                    x = torch.tensor(data, dtype=torch.float32, requires_grad=True)
                else:
                    x = data
            except ImportError:
                x = data

            out = x
            for nid in node_ids:
                f = filters[nid]["obj"]
                out = f.forward(out, {})

            # Backward
            if y is not None and hasattr(out, 'backward'):
                last = filters[node_ids[-1]]["obj"]
                try:
                    import torch
                    if isinstance(y, list):
                        y_t = torch.tensor(y, dtype=torch.float32)
                    else:
                        y_t = y
                    if hasattr(last, 'loss_fn'):
                        loss = last.loss_fn(out, y_t)
                    else:
                        loss = torch.nn.functional.mse_loss(out, y_t)
                    loss.backward()
                    for nid in node_ids:
                        f = filters[nid]["obj"]
                        if hasattr(f, 'optimizer'):
                            f.optimizer.step()
                            f.optimizer.zero_grad()
                except Exception:
                    # The backward pass stays best-effort (the forward result
                    # is still returned), but the failure is reported on
                    # stderr instead of vanishing silently.
                    traceback.print_exc()

            states = {}
            for nid in node_ids:
                f = filters[nid]["obj"]
                if hasattr(f, 'state_dict'):
                    buf = io.BytesIO()
                    try:
                        import torch
                        torch.save(f.state_dict(), buf)
                    except ImportError:
                        buf.write(cloudpickle.dumps(f))
                    states[nid] = base64.b64encode(buf.getvalue()).decode()

            result = out.detach().tolist() if hasattr(out, 'detach') else out
            _reply({"ok": True, "result": _encode(result), "states": states})

        elif action == "GET_STATE":
            nid = cmd["node_id"]
            f = filters[nid]["obj"]
            buf = io.BytesIO()
            if hasattr(f, 'state_dict'):
                try:
                    import torch
                    torch.save(f.state_dict(), buf)
                except ImportError:
                    buf.write(cloudpickle.dumps(f))
            else:
                buf.write(cloudpickle.dumps(f))
            state_b64 = base64.b64encode(buf.getvalue()).decode()
            _reply({"ok": True, "state_b64": state_b64})

        elif action == "SET_STATE":
            nid = cmd["node_id"]
            f = filters[nid]["obj"]
            state_bytes = base64.b64decode(cmd["state_b64"])
            buf = io.BytesIO(state_bytes)
            # NOTE: the cloudpickle fallbacks below execute arbitrary code from
            # the state blob; only feed state produced by a trusted daemon.
            if hasattr(f, 'load_state_dict'):
                try:
                    import torch
                    f.load_state_dict(torch.load(buf, weights_only=True))
                except ImportError:
                    filters[nid]["obj"] = cloudpickle.loads(buf.read())
            else:
                filters[nid]["obj"] = cloudpickle.loads(buf.read())
            _reply({"ok": True})

        elif action == "GET_GRADIENTS":
            nid = cmd["node_id"]
            f = filters[nid]["obj"]
            buf = io.BytesIO()
            # Guard on the attribute that is actually called below.
            if hasattr(f, 'named_parameters'):
                try:
                    import torch
                    grads = {name: p.grad.clone() for name, p in f.named_parameters() if p.grad is not None}
                    torch.save(grads, buf)
                except ImportError:
                    pass
            grad_b64 = base64.b64encode(buf.getvalue()).decode()
            _reply({"ok": True, "gradients_b64": grad_b64})

        elif action == "APPLY_GRADIENTS":
            nid = cmd["node_id"]
            f = filters[nid]["obj"]
            grad_bytes = base64.b64decode(cmd["gradients_b64"])
            if hasattr(f, 'named_parameters'):
                try:
                    import torch
                    buf = io.BytesIO(grad_bytes)
                    grads = torch.load(buf, weights_only=True)
                    for name, p in f.named_parameters():
                        if name in grads:
                            p.grad = grads[name]
                    if hasattr(f, 'optimizer'):
                        f.optimizer.step()
                except ImportError:
                    pass
            _reply({"ok": True})

        elif action == "SHUTDOWN":
            _reply({"ok": True})
            break

        else:
            _reply({"ok": False, "error": f"unknown command: {action}"})

    except Exception as e:
        tb = traceback.format_exc()
        _reply({"ok": False, "error": str(e), "traceback": tb})
"#;
241
/// Handle to a spawned Python daemon together with the pipes used to talk to it.
///
/// Commands are written to the child's stdin as one JSON document per line and
/// each reply is read back as one line from the child's stdout.
pub struct PythonProcess {
    /// The spawned interpreter process.
    child: Child,
    /// Buffered writer over the child's stdin (command channel).
    stdin: BufWriter<ChildStdin>,
    /// Buffered reader over the child's stdout (reply channel).
    stdout: BufReader<ChildStdout>,
    /// Ids of the filters handed to `spawn`, in the order they were given.
    node_ids: Vec<String>,
}
249
250impl PythonProcess {
251 pub fn spawn(
253 python_path: &str,
254 filters: &[(String, Vec<u8>, bool)], ) -> Result<Self> {
256 let mut child = Command::new(python_path)
257 .args(["-c", DAEMON_SCRIPT])
258 .stdin(Stdio::piped())
259 .stdout(Stdio::piped())
260 .stderr(Stdio::inherit()) .spawn()
262 .map_err(|e| SomaError::Other(format!("failed to spawn python: {e}")))?;
263
264 let stdin = BufWriter::new(
265 child
266 .stdin
267 .take()
268 .ok_or_else(|| SomaError::Other("no stdin".into()))?,
269 );
270 let stdout = BufReader::new(
271 child
272 .stdout
273 .take()
274 .ok_or_else(|| SomaError::Other("no stdout".into()))?,
275 );
276
277 let node_ids: Vec<String> = filters.iter().map(|(id, _, _)| id.clone()).collect();
278
279 let mut proc = Self {
280 child,
281 stdin,
282 stdout,
283 node_ids,
284 };
285
286 let filter_specs: Vec<serde_json::Value> = filters
288 .iter()
289 .map(|(id, pickled, trainable)| {
290 serde_json::json!({
291 "id": id,
292 "pickle_b64": STANDARD.encode(pickled),
293 "trainable": trainable,
294 })
295 })
296 .collect();
297
298 let resp = proc.send(serde_json::json!({
299 "cmd": "LOAD",
300 "filters": filter_specs,
301 }))?;
302
303 if resp.get("ok") != Some(&serde_json::Value::Bool(true)) {
304 let error = resp
305 .get("error")
306 .and_then(|e| e.as_str())
307 .unwrap_or("unknown error");
308 return Err(SomaError::Other(format!("LOAD failed: {error}")));
309 }
310
311 Ok(proc)
312 }
313
314 fn send(&mut self, cmd: serde_json::Value) -> Result<serde_json::Value> {
316 let action = cmd
317 .get("cmd")
318 .and_then(|c| c.as_str())
319 .unwrap_or("?")
320 .to_string();
321 let node_id = cmd
322 .get("node_id")
323 .and_then(|n| n.as_str())
324 .unwrap_or("")
325 .to_string();
326
327 tracing::debug!(action = %action, node_id = %node_id, "→ Python");
328 let start = std::time::Instant::now();
329
330 let line = serde_json::to_string(&cmd)
331 .map_err(|e| SomaError::Other(format!("serialize cmd: {e}")))?;
332
333 writeln!(self.stdin, "{line}")
334 .map_err(|e| SomaError::Other(format!("write to python stdin: {e}")))?;
335 self.stdin
336 .flush()
337 .map_err(|e| SomaError::Other(format!("flush stdin: {e}")))?;
338
339 let mut response = String::new();
340 self.stdout
341 .read_line(&mut response)
342 .map_err(|e| SomaError::Other(format!("read from python stdout: {e}")))?;
343
344 let duration_ms = start.elapsed().as_millis();
345
346 if response.is_empty() {
347 tracing::error!(action = %action, "Python process closed stdout (crashed?)");
348 return Err(SomaError::Other(
349 "python process closed stdout (crashed?)".into(),
350 ));
351 }
352
353 let parsed: serde_json::Value = serde_json::from_str(&response).map_err(|e| {
354 SomaError::Other(format!("parse python response: {e}\nraw: {response}"))
355 })?;
356
357 let ok = parsed.get("ok") == Some(&serde_json::Value::Bool(true));
358 if ok {
359 tracing::debug!(action = %action, node_id = %node_id, duration_ms, "← Python OK");
360 } else {
361 let error = parsed.get("error").and_then(|e| e.as_str()).unwrap_or("?");
362 let traceback = parsed
363 .get("traceback")
364 .and_then(|t| t.as_str())
365 .unwrap_or("");
366 tracing::error!(action = %action, node_id = %node_id, error, "Python filter error");
367 if !traceback.is_empty() {
368 tracing::error!("Python traceback:\n{traceback}");
369 }
370 }
371
372 Ok(parsed)
373 }
374
375 fn response_to_value(resp: &serde_json::Value) -> Result<Value> {
377 if resp.get("ok") != Some(&serde_json::Value::Bool(true)) {
378 let error = resp
379 .get("error")
380 .and_then(|e| e.as_str())
381 .unwrap_or("unknown error");
382 let traceback = resp.get("traceback").and_then(|t| t.as_str()).unwrap_or("");
383 return Err(SomaError::Other(format!(
384 "Python error: {error}\n{traceback}"
385 )));
386 }
387
388 if let Some(result) = resp.get("result") {
389 return Self::json_to_value(result);
390 }
391
392 Ok(Value::Empty)
393 }
394
395 fn json_to_value(v: &serde_json::Value) -> Result<Value> {
397 if v.is_null() {
398 return Ok(Value::Empty);
399 }
400 if let Some(arr) = v.as_array() {
401 let values: Vec<f64> = arr.iter().filter_map(|x| x.as_f64()).collect();
402 if values.len() == arr.len() && !values.is_empty() {
403 return Ok(Value::tensor(values.clone(), vec![values.len()]));
404 }
405 if let Some(first) = arr.first()
407 && first.is_array()
408 {
409 let rows = arr.len();
410 let cols = first.as_array().map(|a| a.len()).unwrap_or(0);
411 let flat: Vec<f64> = arr
412 .iter()
413 .filter_map(|row| row.as_array())
414 .flat_map(|row| row.iter().filter_map(|x| x.as_f64()))
415 .collect();
416 if flat.len() == rows * cols {
417 return Ok(Value::tensor(flat, vec![rows, cols]));
418 }
419 }
420 }
421 Ok(Value::Json(v.clone()))
422 }
423
424 fn value_to_json(v: &Value) -> serde_json::Value {
426 serde_json::to_value(v).unwrap_or(serde_json::Value::Null)
427 }
428
429 pub fn fit(&mut self, node_id: &str, data: &Value, y: Option<&Value>) -> Result<Value> {
432 let mut cmd = serde_json::json!({
433 "cmd": "FIT",
434 "node_id": node_id,
435 "data": Self::value_to_json(data),
436 });
437 if let Some(y_val) = y {
438 cmd["y"] = Self::value_to_json(y_val);
439 }
440 let resp = self.send(cmd)?;
441 Self::response_to_value(&resp)
442 }
443
444 pub fn forward(&mut self, node_id: &str, data: &Value, state: &Value) -> Result<Value> {
445 let resp = self.send(serde_json::json!({
446 "cmd": "FORWARD",
447 "node_id": node_id,
448 "data": Self::value_to_json(data),
449 "state": Self::value_to_json(state),
450 }))?;
451 Self::response_to_value(&resp)
452 }
453
454 pub fn composite_fit(
455 &mut self,
456 node_ids: &[String],
457 data: &Value,
458 y: Option<&Value>,
459 ) -> Result<(Value, HashMap<String, Value>)> {
460 let mut cmd = serde_json::json!({
461 "cmd": "COMPOSITE_FIT",
462 "node_ids": node_ids,
463 "data": Self::value_to_json(data),
464 });
465 if let Some(y_val) = y {
466 cmd["y"] = Self::value_to_json(y_val);
467 }
468 let resp = self.send(cmd)?;
469 let output = Self::response_to_value(&resp)?;
470
471 let mut states = HashMap::new();
472 if let Some(state_map) = resp.get("states").and_then(|s| s.as_object()) {
473 for (id, b64) in state_map {
474 if let Some(s) = b64.as_str() {
475 let bytes = STANDARD
476 .decode(s)
477 .map_err(|e| SomaError::Other(format!("decode state: {e}")))?;
478 states.insert(id.clone(), Value::Bytes(bytes));
479 }
480 }
481 }
482 Ok((output, states))
483 }
484
485 pub fn composite_forward(&mut self, node_ids: &[String], data: &Value) -> Result<Value> {
486 let resp = self.send(serde_json::json!({
487 "cmd": "COMPOSITE_FORWARD",
488 "node_ids": node_ids,
489 "data": Self::value_to_json(data),
490 }))?;
491 Self::response_to_value(&resp)
492 }
493
494 pub fn get_state(&mut self, node_id: &str) -> Result<Value> {
495 let resp = self.send(serde_json::json!({"cmd": "GET_STATE", "node_id": node_id}))?;
496 if let Some(b64) = resp.get("state_b64").and_then(|s| s.as_str()) {
497 let bytes = STANDARD
498 .decode(b64)
499 .map_err(|e| SomaError::Other(format!("decode state: {e}")))?;
500 Ok(Value::Bytes(bytes))
501 } else {
502 Self::response_to_value(&resp)
503 }
504 }
505
506 pub fn set_state(&mut self, node_id: &str, state: &Value) -> Result<()> {
507 let b64 = match state {
508 Value::Bytes(b) => STANDARD.encode(b),
509 _ => return Err(SomaError::Other("set_state expects Value::Bytes".into())),
510 };
511 let resp = self
512 .send(serde_json::json!({"cmd": "SET_STATE", "node_id": node_id, "state_b64": b64}))?;
513 if resp.get("ok") != Some(&serde_json::Value::Bool(true)) {
514 let error = resp.get("error").and_then(|e| e.as_str()).unwrap_or("?");
515 return Err(SomaError::Other(format!("set_state: {error}")));
516 }
517 Ok(())
518 }
519
520 pub fn get_gradients(&mut self, node_id: &str) -> Result<Value> {
521 let resp = self.send(serde_json::json!({"cmd": "GET_GRADIENTS", "node_id": node_id}))?;
522 if let Some(b64) = resp.get("gradients_b64").and_then(|s| s.as_str()) {
523 let bytes = STANDARD
524 .decode(b64)
525 .map_err(|e| SomaError::Other(format!("decode gradients: {e}")))?;
526 Ok(Value::Bytes(bytes))
527 } else {
528 Ok(Value::Empty)
529 }
530 }
531
532 pub fn apply_gradients(&mut self, node_id: &str, gradients: &Value) -> Result<()> {
533 let b64 = match gradients {
534 Value::Bytes(b) => STANDARD.encode(b),
535 _ => {
536 return Err(SomaError::Other(
537 "apply_gradients expects Value::Bytes".into(),
538 ));
539 }
540 };
541 self.send(
542 serde_json::json!({"cmd": "APPLY_GRADIENTS", "node_id": node_id, "gradients_b64": b64}),
543 )?;
544 Ok(())
545 }
546
547 pub fn shutdown(&mut self) {
548 let _ = self.send(serde_json::json!({"cmd": "SHUTDOWN"}));
549 }
550
551 pub fn node_ids(&self) -> &[String] {
552 &self.node_ids
553 }
554}
555
/// Tears the daemon down when the handle goes out of scope.
impl Drop for PythonProcess {
    fn drop(&mut self) {
        // Ask the daemon to exit on its own first; errors are ignored inside
        // `shutdown`.
        self.shutdown();
        // Then kill the child regardless and reap it. Both results are
        // ignored because nothing can be reported from `drop`.
        let _ = self.child.kill();
        let _ = self.child.wait();
    }
}
563
/// A `Filter` whose work is delegated to a filter living in a Python daemon.
pub struct SubprocessFilter {
    /// Shared, mutex-guarded handle to the daemon process.
    process: Arc<Mutex<PythonProcess>>,
    /// Id under which this filter is addressed in daemon commands.
    node_id: String,
    /// Whether this filter is reported as trainable in its metadata.
    trainable: bool,
}
573
impl SubprocessFilter {
    /// Creates a filter that forwards its calls to `node_id` inside `process`.
    ///
    /// `trainable` is stored as given and later used to build the filter's
    /// metadata.
    pub fn new(process: Arc<Mutex<PythonProcess>>, node_id: String, trainable: bool) -> Self {
        Self {
            process,
            node_id,
            trainable,
        }
    }
}
583
584impl Filter for SubprocessFilter {
585 fn config_hash(&self) -> CacheKey {
586 CacheKey::from_parts(&[self.node_id.as_bytes()])
587 }
588
589 fn fit(&self, x: &Value, y: Option<&Value>) -> Result<Value> {
590 self.process
591 .lock()
592 .map_err(|e| SomaError::Other(format!("process mutex poisoned: {e}")))?
593 .fit(&self.node_id, x, y)
594 }
595
596 fn forward(&self, x: &Value, state: &Value) -> Result<Value> {
597 self.process
598 .lock()
599 .map_err(|e| SomaError::Other(format!("process mutex poisoned: {e}")))?
600 .forward(&self.node_id, x, state)
601 }
602
603 fn meta(&self) -> FilterMeta {
604 FilterMeta {
605 name: self.node_id.clone(),
606 kind: if self.trainable {
607 FilterKind::Trainable
608 } else {
609 FilterKind::Stateless
610 },
611 cacheable: true,
612 differentiable: self.trainable,
613 stream_mode: StreamMode::FixedState,
614 distribution: somatize_core::filter::Distribution::Local,
615 input_schema: None,
616 output_schema: None,
617 }
618 }
619
620 fn as_any(&self) -> &dyn std::any::Any {
621 self
622 }
623
624 fn composite_fit(
625 &self,
626 node_ids: &[String],
627 x: &Value,
628 y: Option<&Value>,
629 ) -> Option<Result<(Value, HashMap<String, Value>)>> {
630 tracing::info!(nodes = ?node_ids, "Composite fit via subprocess");
631 Some(
632 self.process
633 .lock()
634 .map_err(|e| SomaError::Other(format!("process mutex poisoned: {e}")))
635 .and_then(|mut proc| proc.composite_fit(node_ids, x, y)),
636 )
637 }
638}