Skip to main content

somatize_worker/
python_process.rs

//! Python subprocess — persistent daemon for filter execution.
//!
//! Spawns a Python child process that loads filters via cloudpickle
//! and executes fit/forward commands over stdin/stdout using JSON Lines.
//! The Python interpreter (and its GIL) lives in a separate process, so
//! Python-side crashes (including segfaults) stay confined to the child.
6
7use base64::engine::{Engine, general_purpose::STANDARD};
8use somatize_core::cache::CacheKey;
9use somatize_core::error::{Result, SomaError};
10use somatize_core::filter::{Filter, FilterKind, FilterMeta, StreamMode};
11use somatize_core::value::Value;
12use std::collections::HashMap;
13use std::io::{BufRead, BufReader, BufWriter, Write};
14use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
15use std::sync::{Arc, Mutex};
16
/// Source of the Python daemon, passed to the interpreter via `python -c`.
///
/// Protocol: the daemon reads one JSON command object per line from stdin
/// (`{"cmd": "...", ...}`) and answers each with exactly one JSON line
/// (`{"ok": true|false, ...}`) on its original stdout. That stdout is reserved
/// for replies: file descriptor 1 and `sys.stdout` are redirected to stderr at
/// startup, so output from filters (`print()`, C extensions) lands in the logs
/// instead of corrupting the response stream.
const DAEMON_SCRIPT: &str = r#"
import base64, io, json, os, pickle, sys, traceback
import cloudpickle

# Reserve the real stdout for protocol replies. fd 1 is then pointed at stderr
# so that anything else writing to stdout (print() in a filter, C extensions)
# shows up in the logs instead of interleaving with JSON responses.
_out = os.fdopen(os.dup(sys.stdout.fileno()), "w")
os.dup2(sys.stderr.fileno(), sys.stdout.fileno())
sys.stdout = sys.stderr

filters = {}

def _reply(obj):
    """Send exactly one JSON response line on the protocol channel."""
    _out.write(json.dumps(obj) + "\n")
    _out.flush()

def _encode(obj):
    """Encode a Python object to JSON-safe format (recursing into lists/dicts)."""
    if obj is None:
        return None
    if isinstance(obj, (int, float, str, bool)):
        return obj
    if isinstance(obj, list):
        return [_encode(v) for v in obj]
    if isinstance(obj, dict):
        return {k: _encode(v) for k, v in obj.items()}
    # Fall back to pickle + base64
    return {"__pickle_b64__": base64.b64encode(pickle.dumps(obj)).decode()}

def _decode(obj):
    """Decode from JSON-safe format back to Python object."""
    if obj is None:
        return None
    if isinstance(obj, dict):
        if "__pickle_b64__" in obj:
            return pickle.loads(base64.b64decode(obj["__pickle_b64__"]))
        if "type" in obj and "data" in obj:
            # Soma Value format
            t, d = obj["type"], obj["data"]
            if t == "Tensor":
                return d.get("values", [])
            if t == "Json":
                return d
            if t == "Empty":
                return {}
            if t == "Bytes":
                return bytes(d)
            if t == "Object":
                return pickle.loads(bytes(d))
        return {k: _decode(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_decode(v) for v in obj]
    return obj

for line in sys.stdin:
    line = line.strip()
    if not line:
        continue
    try:
        cmd = json.loads(line)
    except json.JSONDecodeError as e:
        _reply({"ok": False, "error": f"invalid JSON: {e}"})
        continue

    try:
        action = cmd.get("cmd", "")

        if action == "LOAD":
            for f in cmd["filters"]:
                obj = cloudpickle.loads(base64.b64decode(f["pickle_b64"]))
                filters[f["id"]] = {"obj": obj, "trainable": f.get("trainable", True)}
            _reply({"ok": True})

        elif action == "FIT":
            f = filters[cmd["node_id"]]["obj"]
            data = _decode(cmd.get("data"))
            y = _decode(cmd.get("y"))
            result = f.fit(data, y)
            _reply({"ok": True, "result": _encode(result)})

        elif action == "FORWARD":
            f = filters[cmd["node_id"]]["obj"]
            data = _decode(cmd.get("data"))
            state = _decode(cmd.get("state", {}))
            result = f.forward(data, state)
            _reply({"ok": True, "result": _encode(result)})

        elif action == "COMPOSITE_FORWARD":
            node_ids = cmd["node_ids"]
            data = _decode(cmd.get("data"))
            try:
                import torch
                if isinstance(data, list):
                    x = torch.tensor(data, dtype=torch.float32, requires_grad=True)
                else:
                    x = data
            except ImportError:
                x = data

            out = x
            for nid in node_ids:
                f = filters[nid]["obj"]
                state = _decode(cmd.get("states", {}).get(nid, {}))
                out = f.forward(out, state)

            result = out.detach().tolist() if hasattr(out, 'detach') else out
            _reply({"ok": True, "result": _encode(result)})

        elif action == "COMPOSITE_FIT":
            node_ids = cmd["node_ids"]
            data = _decode(cmd.get("data"))
            y = _decode(cmd.get("y"))

            # Step 1: fit each trainable filter to get states
            fit_states = {}
            fit_input = data
            for nid in node_ids:
                f = filters[nid]["obj"]
                if filters[nid].get("trainable", True):
                    state = f.fit(fit_input, y)
                    fit_states[nid] = state
                else:
                    fit_states[nid] = {}
                # Forward to propagate output to next filter
                fit_input = f.forward(fit_input, fit_states[nid])

            # Step 2: forward with autograd if torch available
            try:
                import torch
                if isinstance(data, list):
                    x = torch.tensor(data, dtype=torch.float32, requires_grad=True)
                else:
                    x = data
            except ImportError:
                x = data

            out = x
            for nid in node_ids:
                f = filters[nid]["obj"]
                out = f.forward(out, fit_states.get(nid, {}))

            # Backward
            if y is not None and hasattr(out, 'backward'):
                last = filters[node_ids[-1]]["obj"]
                try:
                    import torch
                    if isinstance(y, list):
                        y_t = torch.tensor(y, dtype=torch.float32)
                    else:
                        y_t = y
                    if hasattr(last, 'loss_fn'):
                        loss = last.loss_fn(out, y_t)
                    else:
                        loss = torch.nn.functional.mse_loss(out, y_t)
                    loss.backward()
                    for nid in node_ids:
                        f = filters[nid]["obj"]
                        if hasattr(f, 'optimizer'):
                            f.optimizer.step()
                            f.optimizer.zero_grad()
                except Exception:
                    # Backward is best-effort, but report why it was skipped
                    # instead of failing silently.
                    print("COMPOSITE_FIT: backward pass skipped:", file=sys.stderr, flush=True)
                    traceback.print_exc(file=sys.stderr)

            states = {}
            for nid in node_ids:
                f = filters[nid]["obj"]
                if hasattr(f, 'state_dict'):
                    buf = io.BytesIO()
                    try:
                        import torch
                        torch.save(f.state_dict(), buf)
                    except ImportError:
                        buf.write(cloudpickle.dumps(f))
                    states[nid] = base64.b64encode(buf.getvalue()).decode()

            result = out.detach().tolist() if hasattr(out, 'detach') else out
            _reply({"ok": True, "result": _encode(result), "states": states})

        elif action == "GET_STATE":
            nid = cmd["node_id"]
            f = filters[nid]["obj"]
            buf = io.BytesIO()
            if hasattr(f, 'state_dict'):
                try:
                    import torch
                    torch.save(f.state_dict(), buf)
                except ImportError:
                    buf.write(cloudpickle.dumps(f))
            else:
                buf.write(cloudpickle.dumps(f))
            state_b64 = base64.b64encode(buf.getvalue()).decode()
            _reply({"ok": True, "state_b64": state_b64})

        elif action == "SET_STATE":
            nid = cmd["node_id"]
            f = filters[nid]["obj"]
            state_bytes = base64.b64decode(cmd["state_b64"])
            buf = io.BytesIO(state_bytes)
            if hasattr(f, 'load_state_dict'):
                try:
                    import torch
                    f.load_state_dict(torch.load(buf, weights_only=True))
                except ImportError:
                    filters[nid]["obj"] = cloudpickle.loads(buf.read())
            else:
                filters[nid]["obj"] = cloudpickle.loads(buf.read())
            _reply({"ok": True})

        elif action == "GET_GRADIENTS":
            nid = cmd["node_id"]
            f = filters[nid]["obj"]
            buf = io.BytesIO()
            # Check for the method that is actually called (same as APPLY_GRADIENTS).
            if hasattr(f, 'named_parameters'):
                try:
                    import torch
                    grads = {name: p.grad.clone() for name, p in f.named_parameters() if p.grad is not None}
                    torch.save(grads, buf)
                except ImportError:
                    pass
            grad_b64 = base64.b64encode(buf.getvalue()).decode()
            _reply({"ok": True, "gradients_b64": grad_b64})

        elif action == "APPLY_GRADIENTS":
            nid = cmd["node_id"]
            f = filters[nid]["obj"]
            grad_bytes = base64.b64decode(cmd["gradients_b64"])
            if hasattr(f, 'named_parameters'):
                try:
                    import torch
                    buf = io.BytesIO(grad_bytes)
                    grads = torch.load(buf, weights_only=True)
                    for name, p in f.named_parameters():
                        if name in grads:
                            p.grad = grads[name]
                    if hasattr(f, 'optimizer'):
                        f.optimizer.step()
                except ImportError:
                    pass
            _reply({"ok": True})

        elif action == "BATCHED_FIT":
            # Process dataset in batches — model loaded ONCE, batches processed in loop
            node_ids = cmd["node_ids"]
            data = _decode(cmd.get("data"))
            y = _decode(cmd.get("y"))
            batch_size = cmd.get("batch_size", 32)

            # Find the list dimension to batch on
            if isinstance(data, dict):
                list_keys = [k for k, v in data.items() if isinstance(v, list)]
                total = len(data[list_keys[0]]) if list_keys else 0
            elif isinstance(data, list):
                total = len(data)
            else:
                total = 0

            all_states = {}
            n_batches = (total + batch_size - 1) // batch_size if total > 0 else 1

            for b in range(n_batches):
                start = b * batch_size
                end = min(start + batch_size, total)

                # Slice the batch
                if isinstance(data, dict):
                    batch = {}
                    for k, v in data.items():
                        if isinstance(v, list):
                            batch[k] = v[start:end]
                        else:
                            batch[k] = v
                elif isinstance(data, list):
                    batch = data[start:end]
                else:
                    batch = data

                y_batch = None
                if y is not None:
                    if isinstance(y, list):
                        y_batch = y[start:end]
                    elif isinstance(y, dict):
                        y_batch = {k: (v[start:end] if isinstance(v, list) else v) for k, v in y.items()}
                    else:
                        y_batch = y

                # Fit + forward for this batch through all filters
                batch_input = batch
                for nid in node_ids:
                    f = filters[nid]["obj"]
                    if filters[nid].get("trainable", True):
                        state = f.fit(batch_input, y_batch)
                        all_states[nid] = state
                    else:
                        if nid not in all_states:
                            all_states[nid] = {}
                    batch_input = f.forward(batch_input, all_states.get(nid, {}))

                print(f"    Batch {b+1}/{n_batches} complete", file=sys.stderr, flush=True)

            # Encode final states
            encoded_states = {}
            for nid, state in all_states.items():
                encoded_states[nid] = _encode(state)

            result = _encode(batch_input) if batch_input is not None else None
            _reply({"ok": True, "result": result, "states": encoded_states})

        elif action == "SHUTDOWN":
            _reply({"ok": True})
            break

        else:
            _reply({"ok": False, "error": f"unknown command: {action}"})

    except Exception as e:
        _reply({"ok": False, "error": str(e), "traceback": traceback.format_exc()})
"#;
325
/// Handle to a long-lived Python child process that runs filter commands.
///
/// Commands go to the child's stdin and replies come back on its stdout,
/// one JSON object per line (the protocol implemented by `DAEMON_SCRIPT`).
pub struct PythonProcess {
    /// The spawned interpreter; killed and reaped when the handle is dropped.
    child: Child,
    /// Buffered command channel (the child's stdin).
    stdin: BufWriter<ChildStdin>,
    /// Buffered reply channel (the child's stdout).
    stdout: BufReader<ChildStdout>,
    /// IDs of the filters handed to `spawn`, in the order they were given.
    node_ids: Vec<String>,
}
333
334impl PythonProcess {
335    /// Spawn a Python daemon and load filters into it.
336    pub fn spawn(
337        python_path: &str,
338        filters: &[(String, Vec<u8>, bool)], // (node_id, pickled_bytes, trainable)
339    ) -> Result<Self> {
340        let mut child = Command::new(python_path)
341            .args(["-c", DAEMON_SCRIPT])
342            .stdin(Stdio::piped())
343            .stdout(Stdio::piped())
344            .stderr(Stdio::inherit()) // Python stderr → worker stderr (for logs/tracing)
345            .spawn()
346            .map_err(|e| SomaError::Other(format!("failed to spawn python: {e}")))?;
347
348        let stdin = BufWriter::new(
349            child
350                .stdin
351                .take()
352                .ok_or_else(|| SomaError::Other("no stdin".into()))?,
353        );
354        let stdout = BufReader::new(
355            child
356                .stdout
357                .take()
358                .ok_or_else(|| SomaError::Other("no stdout".into()))?,
359        );
360
361        let node_ids: Vec<String> = filters.iter().map(|(id, _, _)| id.clone()).collect();
362
363        let mut proc = Self {
364            child,
365            stdin,
366            stdout,
367            node_ids,
368        };
369
370        // Send LOAD command with all filters
371        let filter_specs: Vec<serde_json::Value> = filters
372            .iter()
373            .map(|(id, pickled, trainable)| {
374                serde_json::json!({
375                    "id": id,
376                    "pickle_b64": STANDARD.encode(pickled),
377                    "trainable": trainable,
378                })
379            })
380            .collect();
381
382        let resp = proc.send(serde_json::json!({
383            "cmd": "LOAD",
384            "filters": filter_specs,
385        }))?;
386
387        if resp.get("ok") != Some(&serde_json::Value::Bool(true)) {
388            let error = resp
389                .get("error")
390                .and_then(|e| e.as_str())
391                .unwrap_or("unknown error");
392            return Err(SomaError::Other(format!("LOAD failed: {error}")));
393        }
394
395        Ok(proc)
396    }
397
398    /// Send a JSON command and read the JSON response.
399    fn send(&mut self, cmd: serde_json::Value) -> Result<serde_json::Value> {
400        let action = cmd
401            .get("cmd")
402            .and_then(|c| c.as_str())
403            .unwrap_or("?")
404            .to_string();
405        let node_id = cmd
406            .get("node_id")
407            .and_then(|n| n.as_str())
408            .unwrap_or("")
409            .to_string();
410
411        tracing::debug!(action = %action, node_id = %node_id, "→ Python");
412        let start = std::time::Instant::now();
413
414        let line = serde_json::to_string(&cmd)
415            .map_err(|e| SomaError::Other(format!("serialize cmd: {e}")))?;
416
417        writeln!(self.stdin, "{line}")
418            .map_err(|e| SomaError::Other(format!("write to python stdin: {e}")))?;
419        self.stdin
420            .flush()
421            .map_err(|e| SomaError::Other(format!("flush stdin: {e}")))?;
422
423        let mut response = String::new();
424        self.stdout
425            .read_line(&mut response)
426            .map_err(|e| SomaError::Other(format!("read from python stdout: {e}")))?;
427
428        let duration_ms = start.elapsed().as_millis();
429
430        if response.is_empty() {
431            tracing::error!(action = %action, "Python process closed stdout (crashed?)");
432            return Err(SomaError::Other(
433                "python process closed stdout (crashed?)".into(),
434            ));
435        }
436
437        let parsed: serde_json::Value = serde_json::from_str(&response).map_err(|e| {
438            SomaError::Other(format!("parse python response: {e}\nraw: {response}"))
439        })?;
440
441        let ok = parsed.get("ok") == Some(&serde_json::Value::Bool(true));
442        if ok {
443            tracing::debug!(action = %action, node_id = %node_id, duration_ms, "← Python OK");
444        } else {
445            let error = parsed.get("error").and_then(|e| e.as_str()).unwrap_or("?");
446            let traceback = parsed
447                .get("traceback")
448                .and_then(|t| t.as_str())
449                .unwrap_or("");
450            tracing::error!(action = %action, node_id = %node_id, error, "Python filter error");
451            if !traceback.is_empty() {
452                tracing::error!("Python traceback:\n{traceback}");
453            }
454        }
455
456        Ok(parsed)
457    }
458
459    /// Convert a response to a Value, handling errors.
460    fn response_to_value(resp: &serde_json::Value) -> Result<Value> {
461        if resp.get("ok") != Some(&serde_json::Value::Bool(true)) {
462            let error = resp
463                .get("error")
464                .and_then(|e| e.as_str())
465                .unwrap_or("unknown error");
466            let traceback = resp.get("traceback").and_then(|t| t.as_str()).unwrap_or("");
467            return Err(SomaError::Other(format!(
468                "Python error: {error}\n{traceback}"
469            )));
470        }
471
472        if let Some(result) = resp.get("result") {
473            return Self::json_to_value(result);
474        }
475
476        Ok(Value::Empty)
477    }
478
479    /// Convert a JSON value to a Soma Value.
480    fn json_to_value(v: &serde_json::Value) -> Result<Value> {
481        if v.is_null() {
482            return Ok(Value::Empty);
483        }
484        if let Some(arr) = v.as_array() {
485            let values: Vec<f64> = arr.iter().filter_map(|x| x.as_f64()).collect();
486            if values.len() == arr.len() && !values.is_empty() {
487                return Ok(Value::tensor(values.clone(), vec![values.len()]));
488            }
489            // Could be nested array
490            if let Some(first) = arr.first()
491                && first.is_array()
492            {
493                let rows = arr.len();
494                let cols = first.as_array().map(|a| a.len()).unwrap_or(0);
495                let flat: Vec<f64> = arr
496                    .iter()
497                    .filter_map(|row| row.as_array())
498                    .flat_map(|row| row.iter().filter_map(|x| x.as_f64()))
499                    .collect();
500                if flat.len() == rows * cols {
501                    return Ok(Value::tensor(flat, vec![rows, cols]));
502                }
503            }
504        }
505        Ok(Value::json(v.clone()))
506    }
507
508    /// Encode a Value to JSON for the Python process.
509    fn value_to_json(v: &Value) -> serde_json::Value {
510        serde_json::to_value(v).unwrap_or(serde_json::Value::Null)
511    }
512
513    // ── Public API ──
514
515    pub fn fit(&mut self, node_id: &str, data: &Value, y: Option<&Value>) -> Result<Value> {
516        let mut cmd = serde_json::json!({
517            "cmd": "FIT",
518            "node_id": node_id,
519            "data": Self::value_to_json(data),
520        });
521        if let Some(y_val) = y {
522            cmd["y"] = Self::value_to_json(y_val);
523        }
524        let resp = self.send(cmd)?;
525        Self::response_to_value(&resp)
526    }
527
528    pub fn forward(&mut self, node_id: &str, data: &Value, state: &Value) -> Result<Value> {
529        let resp = self.send(serde_json::json!({
530            "cmd": "FORWARD",
531            "node_id": node_id,
532            "data": Self::value_to_json(data),
533            "state": Self::value_to_json(state),
534        }))?;
535        Self::response_to_value(&resp)
536    }
537
538    pub fn composite_fit(
539        &mut self,
540        node_ids: &[String],
541        data: &Value,
542        y: Option<&Value>,
543    ) -> Result<(Value, HashMap<String, Value>)> {
544        let mut cmd = serde_json::json!({
545            "cmd": "COMPOSITE_FIT",
546            "node_ids": node_ids,
547            "data": Self::value_to_json(data),
548        });
549        if let Some(y_val) = y {
550            cmd["y"] = Self::value_to_json(y_val);
551        }
552        let resp = self.send(cmd)?;
553        let output = Self::response_to_value(&resp)?;
554
555        let mut states = HashMap::new();
556        if let Some(state_map) = resp.get("states").and_then(|s| s.as_object()) {
557            for (id, b64) in state_map {
558                if let Some(s) = b64.as_str() {
559                    let bytes = STANDARD
560                        .decode(s)
561                        .map_err(|e| SomaError::Other(format!("decode state: {e}")))?;
562                    states.insert(id.clone(), Value::bytes(bytes));
563                }
564            }
565        }
566        Ok((output, states))
567    }
568
569    /// Batched fit: send full dataset + batch_size, daemon splits internally.
570    /// Model loaded ONCE, batches processed in a loop.
571    pub fn batched_fit(
572        &mut self,
573        node_ids: &[String],
574        data: &Value,
575        y: Option<&Value>,
576        batch_size: usize,
577    ) -> Result<(Value, HashMap<String, Value>)> {
578        let mut cmd = serde_json::json!({
579            "cmd": "BATCHED_FIT",
580            "node_ids": node_ids,
581            "data": Self::value_to_json(data),
582            "batch_size": batch_size,
583        });
584        if let Some(y_val) = y {
585            cmd["y"] = Self::value_to_json(y_val);
586        }
587        let resp = self.send(cmd)?;
588        let output = Self::response_to_value(&resp)?;
589
590        let mut states = HashMap::new();
591        if let Some(state_map) = resp.get("states").and_then(|s| s.as_object()) {
592            for (id, val) in state_map {
593                if let Ok(v) = Self::json_to_value(val) {
594                    states.insert(id.clone(), v);
595                }
596            }
597        }
598        Ok((output, states))
599    }
600
601    pub fn composite_forward(&mut self, node_ids: &[String], data: &Value) -> Result<Value> {
602        let resp = self.send(serde_json::json!({
603            "cmd": "COMPOSITE_FORWARD",
604            "node_ids": node_ids,
605            "data": Self::value_to_json(data),
606        }))?;
607        Self::response_to_value(&resp)
608    }
609
610    pub fn get_state(&mut self, node_id: &str) -> Result<Value> {
611        let resp = self.send(serde_json::json!({"cmd": "GET_STATE", "node_id": node_id}))?;
612        if let Some(b64) = resp.get("state_b64").and_then(|s| s.as_str()) {
613            let bytes = STANDARD
614                .decode(b64)
615                .map_err(|e| SomaError::Other(format!("decode state: {e}")))?;
616            Ok(Value::bytes(bytes))
617        } else {
618            Self::response_to_value(&resp)
619        }
620    }
621
622    pub fn set_state(&mut self, node_id: &str, state: &Value) -> Result<()> {
623        let b64 = match state {
624            Value::Bytes(b) => STANDARD.encode(b.as_slice()),
625            _ => return Err(SomaError::Other("set_state expects Value::Bytes".into())),
626        };
627        let resp = self
628            .send(serde_json::json!({"cmd": "SET_STATE", "node_id": node_id, "state_b64": b64}))?;
629        if resp.get("ok") != Some(&serde_json::Value::Bool(true)) {
630            let error = resp.get("error").and_then(|e| e.as_str()).unwrap_or("?");
631            return Err(SomaError::Other(format!("set_state: {error}")));
632        }
633        Ok(())
634    }
635
636    pub fn get_gradients(&mut self, node_id: &str) -> Result<Value> {
637        let resp = self.send(serde_json::json!({"cmd": "GET_GRADIENTS", "node_id": node_id}))?;
638        if let Some(b64) = resp.get("gradients_b64").and_then(|s| s.as_str()) {
639            let bytes = STANDARD
640                .decode(b64)
641                .map_err(|e| SomaError::Other(format!("decode gradients: {e}")))?;
642            Ok(Value::bytes(bytes))
643        } else {
644            Ok(Value::Empty)
645        }
646    }
647
648    pub fn apply_gradients(&mut self, node_id: &str, gradients: &Value) -> Result<()> {
649        let b64 = match gradients {
650            Value::Bytes(b) => STANDARD.encode(b.as_slice()),
651            _ => {
652                return Err(SomaError::Other(
653                    "apply_gradients expects Value::Bytes".into(),
654                ));
655            }
656        };
657        self.send(
658            serde_json::json!({"cmd": "APPLY_GRADIENTS", "node_id": node_id, "gradients_b64": b64}),
659        )?;
660        Ok(())
661    }
662
663    pub fn shutdown(&mut self) {
664        let _ = self.send(serde_json::json!({"cmd": "SHUTDOWN"}));
665    }
666
667    pub fn node_ids(&self) -> &[String] {
668        &self.node_ids
669    }
670}
671
672impl Drop for PythonProcess {
673    fn drop(&mut self) {
674        self.shutdown();
675        let _ = self.child.kill();
676        let _ = self.child.wait();
677    }
678}
679
// ── SubprocessFilter: implements the Filter trait via PythonProcess ──
681
682/// A filter that delegates to a shared PythonProcess via stdin/stdout.
683/// Multiple SubprocessFilters can share the same process (Arc<Mutex>).
684pub struct SubprocessFilter {
685    pub(crate) process: Arc<Mutex<PythonProcess>>,
686    node_id: String,
687    trainable: bool,
688}
689
690impl SubprocessFilter {
691    pub fn new(process: Arc<Mutex<PythonProcess>>, node_id: String, trainable: bool) -> Self {
692        Self {
693            process,
694            node_id,
695            trainable,
696        }
697    }
698}
699
700impl Filter for SubprocessFilter {
701    fn config_hash(&self) -> CacheKey {
702        CacheKey::from_parts(&[self.node_id.as_bytes()])
703    }
704
705    fn fit(&self, x: &Value, y: Option<&Value>) -> Result<Value> {
706        self.process
707            .lock()
708            .map_err(|e| SomaError::Other(format!("process mutex poisoned: {e}")))?
709            .fit(&self.node_id, x, y)
710    }
711
712    fn forward(&self, x: &Value, state: &Value) -> Result<Value> {
713        self.process
714            .lock()
715            .map_err(|e| SomaError::Other(format!("process mutex poisoned: {e}")))?
716            .forward(&self.node_id, x, state)
717    }
718
719    fn meta(&self) -> FilterMeta {
720        FilterMeta {
721            name: self.node_id.clone(),
722            kind: if self.trainable {
723                FilterKind::Trainable
724            } else {
725                FilterKind::Stateless
726            },
727            cacheable: true,
728            differentiable: self.trainable,
729            stream_mode: StreamMode::FixedState,
730            distribution: somatize_core::filter::Distribution::Local,
731            input_schema: None,
732            output_schema: None,
733        }
734    }
735
736    fn as_any(&self) -> &dyn std::any::Any {
737        self
738    }
739
740    fn composite_fit(
741        &self,
742        peers: &[(String, std::sync::Arc<dyn somatize_core::filter::Filter>)],
743        x: &Value,
744        y: Option<&Value>,
745    ) -> Option<Result<(Value, HashMap<String, Value>)>> {
746        // Subprocess transport serialises the node_ids only — other filters
747        // aren't shipped; the worker already has them deserialised from the
748        // preceding prepare step.
749        let node_ids: Vec<String> = peers.iter().map(|(id, _)| id.clone()).collect();
750        tracing::info!(nodes = ?node_ids, "Composite fit via subprocess");
751        Some(
752            self.process
753                .lock()
754                .map_err(|e| SomaError::Other(format!("process mutex poisoned: {e}")))
755                .and_then(|mut proc| proc.composite_fit(&node_ids, x, y)),
756        )
757    }
758}