Skip to main content

somatize_worker/
python_process.rs

1//! Python subprocess — persistent daemon for filter execution.
2//!
3//! Spawns a Python child process that loads filters via cloudpickle
4//! and executes fit/forward commands via stdin/stdout JSON Lines.
5//! The GIL is completely isolated from the Rust process — no segfaults.
6
7use base64::engine::{Engine, general_purpose::STANDARD};
8use somatize_core::cache::CacheKey;
9use somatize_core::error::{Result, SomaError};
10use somatize_core::filter::{Filter, FilterKind, FilterMeta, StreamMode};
11use somatize_core::value::Value;
12use std::collections::HashMap;
13use std::io::{BufRead, BufReader, BufWriter, Write};
14use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
15use std::sync::{Arc, Mutex};
16
/// The Python daemon script, embedded as a Rust string.
///
/// It is run with `python -c` and speaks JSON Lines: one command object per
/// stdin line (`{"cmd": "...", ...}`) and exactly one reply object per stdout
/// line (`{"ok": true, ...}` or `{"ok": false, "error": ..., "traceback": ...}`).
/// Replies are written to a private duplicate of the original stdout while fd 1
/// and `sys.stdout` are redirected to stderr, so anything user filters print
/// cannot desynchronise the request/reply pairing that `PythonProcess::send`
/// relies on.
const DAEMON_SCRIPT: &str = r#"
import base64, io, json, os, pickle, sys, traceback
import cloudpickle

# Protocol isolation. The Rust side reads exactly one JSON reply line per
# command from our stdout, so nothing else may ever be written there. Keep a
# private duplicate of the stdout descriptor for replies, then point fd 1
# (output from native code) and sys.stdout (print calls) at stderr.
_PROTO_OUT = os.fdopen(os.dup(sys.stdout.fileno()), "w")
os.dup2(sys.stderr.fileno(), sys.stdout.fileno())
sys.stdout = sys.stderr

filters = {}

def _reply(payload):
    """Write one JSON reply line to the protocol channel and flush it."""
    _PROTO_OUT.write(json.dumps(payload) + "\n")
    _PROTO_OUT.flush()

def _encode(obj):
    """Encode a Python object to JSON-safe format."""
    if obj is None:
        return None
    if isinstance(obj, (list, int, float, str, bool)):
        return obj
    if isinstance(obj, dict):
        return {k: _encode(v) for k, v in obj.items()}
    # Fall back to pickle + base64
    return {"__pickle_b64__": base64.b64encode(pickle.dumps(obj)).decode()}

def _decode(obj):
    """Decode from JSON-safe format back to Python object."""
    if obj is None:
        return None
    if isinstance(obj, dict):
        if "__pickle_b64__" in obj:
            return pickle.loads(base64.b64decode(obj["__pickle_b64__"]))
        if "type" in obj and "data" in obj:
            # Soma Value format
            t, d = obj["type"], obj["data"]
            if t == "Tensor":
                return d.get("values", [])
            if t == "Json":
                return d
            if t == "Empty":
                return {}
            if t == "Bytes":
                return bytes(d)
        return {k: _decode(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_decode(v) for v in obj]
    return obj

for line in sys.stdin:
    line = line.strip()
    if not line:
        continue
    try:
        cmd = json.loads(line)
    except json.JSONDecodeError as e:
        _reply({"ok": False, "error": f"invalid JSON: {e}"})
        continue

    try:
        action = cmd.get("cmd", "")

        if action == "LOAD":
            for f in cmd["filters"]:
                obj = cloudpickle.loads(base64.b64decode(f["pickle_b64"]))
                filters[f["id"]] = {"obj": obj, "trainable": f.get("trainable", True)}
            _reply({"ok": True})

        elif action == "FIT":
            f = filters[cmd["node_id"]]["obj"]
            data = _decode(cmd.get("data"))
            y = _decode(cmd.get("y"))
            result = f.fit(data, y)
            _reply({"ok": True, "result": _encode(result)})

        elif action == "FORWARD":
            f = filters[cmd["node_id"]]["obj"]
            data = _decode(cmd.get("data"))
            state = _decode(cmd.get("state", {}))
            result = f.forward(data, state)
            _reply({"ok": True, "result": _encode(result)})

        elif action == "COMPOSITE_FORWARD":
            node_ids = cmd["node_ids"]
            data = _decode(cmd.get("data"))
            try:
                import torch
                if isinstance(data, list):
                    x = torch.tensor(data, dtype=torch.float32, requires_grad=True)
                else:
                    x = data
            except ImportError:
                x = data

            out = x
            for nid in node_ids:
                f = filters[nid]["obj"]
                state = _decode(cmd.get("states", {}).get(nid, {}))
                out = f.forward(out, state)

            result = out.detach().tolist() if hasattr(out, 'detach') else out
            _reply({"ok": True, "result": _encode(result)})

        elif action == "COMPOSITE_FIT":
            node_ids = cmd["node_ids"]
            data = _decode(cmd.get("data"))
            y = _decode(cmd.get("y"))

            # Step 1: fit each trainable filter to get states
            fit_states = {}
            fit_input = data
            for nid in node_ids:
                f = filters[nid]["obj"]
                if filters[nid].get("trainable", True):
                    state = f.fit(fit_input, y)
                    fit_states[nid] = state
                else:
                    fit_states[nid] = {}
                # Forward to propagate output to next filter
                fit_input = f.forward(fit_input, fit_states[nid])

            # Step 2: forward with autograd if torch available
            try:
                import torch
                if isinstance(data, list):
                    x = torch.tensor(data, dtype=torch.float32, requires_grad=True)
                else:
                    x = data
            except ImportError:
                x = data

            out = x
            for nid in node_ids:
                f = filters[nid]["obj"]
                out = f.forward(out, fit_states.get(nid, {}))

            # Backward
            if y is not None and hasattr(out, 'backward'):
                last = filters[node_ids[-1]]["obj"]
                try:
                    import torch
                    if isinstance(y, list):
                        y_t = torch.tensor(y, dtype=torch.float32)
                    else:
                        y_t = y
                    if hasattr(last, 'loss_fn'):
                        loss = last.loss_fn(out, y_t)
                    else:
                        loss = torch.nn.functional.mse_loss(out, y_t)
                    loss.backward()
                    for nid in node_ids:
                        f = filters[nid]["obj"]
                        if hasattr(f, 'optimizer'):
                            f.optimizer.step()
                            f.optimizer.zero_grad()
                except Exception:
                    # Best-effort: the fitted states below are still returned,
                    # but the failure is reported instead of being swallowed.
                    print("COMPOSITE_FIT: backward pass failed, continuing without it:\n"
                          + traceback.format_exc(), file=sys.stderr, flush=True)

            states = {}
            for nid in node_ids:
                f = filters[nid]["obj"]
                if hasattr(f, 'state_dict'):
                    buf = io.BytesIO()
                    try:
                        import torch
                        torch.save(f.state_dict(), buf)
                    except ImportError:
                        buf.write(cloudpickle.dumps(f))
                    states[nid] = base64.b64encode(buf.getvalue()).decode()

            result = out.detach().tolist() if hasattr(out, 'detach') else out
            _reply({"ok": True, "result": _encode(result), "states": states})

        elif action == "GET_STATE":
            nid = cmd["node_id"]
            f = filters[nid]["obj"]
            buf = io.BytesIO()
            if hasattr(f, 'state_dict'):
                try:
                    import torch
                    torch.save(f.state_dict(), buf)
                except ImportError:
                    buf.write(cloudpickle.dumps(f))
            else:
                buf.write(cloudpickle.dumps(f))
            state_b64 = base64.b64encode(buf.getvalue()).decode()
            _reply({"ok": True, "state_b64": state_b64})

        elif action == "SET_STATE":
            nid = cmd["node_id"]
            f = filters[nid]["obj"]
            state_bytes = base64.b64decode(cmd["state_b64"])
            buf = io.BytesIO(state_bytes)
            if hasattr(f, 'load_state_dict'):
                try:
                    import torch
                    f.load_state_dict(torch.load(buf, weights_only=True))
                except ImportError:
                    filters[nid]["obj"] = cloudpickle.loads(buf.read())
            else:
                filters[nid]["obj"] = cloudpickle.loads(buf.read())
            _reply({"ok": True})

        elif action == "GET_GRADIENTS":
            nid = cmd["node_id"]
            f = filters[nid]["obj"]
            buf = io.BytesIO()
            if hasattr(f, 'parameters'):
                try:
                    import torch
                    grads = {name: p.grad.clone() for name, p in f.named_parameters() if p.grad is not None}
                    torch.save(grads, buf)
                except ImportError:
                    pass
            grad_b64 = base64.b64encode(buf.getvalue()).decode()
            _reply({"ok": True, "gradients_b64": grad_b64})

        elif action == "APPLY_GRADIENTS":
            nid = cmd["node_id"]
            f = filters[nid]["obj"]
            grad_bytes = base64.b64decode(cmd["gradients_b64"])
            if hasattr(f, 'named_parameters'):
                try:
                    import torch
                    buf = io.BytesIO(grad_bytes)
                    grads = torch.load(buf, weights_only=True)
                    for name, p in f.named_parameters():
                        if name in grads:
                            p.grad = grads[name]
                    if hasattr(f, 'optimizer'):
                        f.optimizer.step()
                except ImportError:
                    pass
            _reply({"ok": True})

        elif action == "BATCHED_FIT":
            # Process dataset in batches — model loaded ONCE, batches processed in loop
            node_ids = cmd["node_ids"]
            data = _decode(cmd.get("data"))
            y = _decode(cmd.get("y"))
            batch_size = cmd.get("batch_size", 32)
            # Reject 0/negative sizes up front instead of failing with an
            # opaque ZeroDivisionError (or an empty loop) further down.
            if not isinstance(batch_size, int) or batch_size <= 0:
                raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

            # Find the list dimension to batch on
            if isinstance(data, dict):
                list_keys = [k for k, v in data.items() if isinstance(v, list)]
                total = len(data[list_keys[0]]) if list_keys else 0
            elif isinstance(data, list):
                total = len(data)
            else:
                total = 0

            all_states = {}
            n_batches = (total + batch_size - 1) // batch_size if total > 0 else 1

            for b in range(n_batches):
                start = b * batch_size
                end = min(start + batch_size, total)

                # Slice the batch
                if isinstance(data, dict):
                    batch = {}
                    for k, v in data.items():
                        if isinstance(v, list):
                            batch[k] = v[start:end]
                        else:
                            batch[k] = v
                elif isinstance(data, list):
                    batch = data[start:end]
                else:
                    batch = data

                y_batch = None
                if y is not None:
                    if isinstance(y, list):
                        y_batch = y[start:end]
                    elif isinstance(y, dict):
                        y_batch = {k: (v[start:end] if isinstance(v, list) else v) for k, v in y.items()}
                    else:
                        y_batch = y

                # Fit + forward for this batch through all filters
                batch_input = batch
                for nid in node_ids:
                    f = filters[nid]["obj"]
                    if filters[nid].get("trainable", True):
                        state = f.fit(batch_input, y_batch)
                        all_states[nid] = state
                    else:
                        if nid not in all_states:
                            all_states[nid] = {}
                    batch_input = f.forward(batch_input, all_states.get(nid, {}))

                print(f"    Batch {b+1}/{n_batches} complete", file=sys.stderr, flush=True)

            # Encode final states
            encoded_states = {}
            for nid, state in all_states.items():
                encoded_states[nid] = _encode(state)

            result = _encode(batch_input) if batch_input is not None else None
            _reply({"ok": True, "result": result, "states": encoded_states})

        elif action == "SHUTDOWN":
            _reply({"ok": True})
            break

        else:
            _reply({"ok": False, "error": f"unknown command: {action}"})

    except Exception as e:
        _reply({"ok": False, "error": str(e), "traceback": traceback.format_exc()})
"#;
323
/// A persistent Python child process that executes filter commands.
///
/// Commands and replies are exchanged as JSON Lines over the child's stdin and
/// stdout (one request line, one reply line); the command set is defined by the
/// embedded `DAEMON_SCRIPT`.
pub struct PythonProcess {
    /// The spawned interpreter; `Drop` sends `SHUTDOWN`, then kills and waits on it.
    child: Child,
    /// Buffered write end of the child's stdin; commands are written here.
    stdin: BufWriter<ChildStdin>,
    /// Buffered read end of the child's stdout; one JSON reply is read per command.
    stdout: BufReader<ChildStdout>,
    /// Ids of the filters sent in the `LOAD` command at spawn time, in input order.
    node_ids: Vec<String>,
}
331
332impl PythonProcess {
333    /// Spawn a Python daemon and load filters into it.
334    pub fn spawn(
335        python_path: &str,
336        filters: &[(String, Vec<u8>, bool)], // (node_id, pickled_bytes, trainable)
337    ) -> Result<Self> {
338        let mut child = Command::new(python_path)
339            .args(["-c", DAEMON_SCRIPT])
340            .stdin(Stdio::piped())
341            .stdout(Stdio::piped())
342            .stderr(Stdio::inherit()) // Python stderr → worker stderr (for logs/tracing)
343            .spawn()
344            .map_err(|e| SomaError::Other(format!("failed to spawn python: {e}")))?;
345
346        let stdin = BufWriter::new(
347            child
348                .stdin
349                .take()
350                .ok_or_else(|| SomaError::Other("no stdin".into()))?,
351        );
352        let stdout = BufReader::new(
353            child
354                .stdout
355                .take()
356                .ok_or_else(|| SomaError::Other("no stdout".into()))?,
357        );
358
359        let node_ids: Vec<String> = filters.iter().map(|(id, _, _)| id.clone()).collect();
360
361        let mut proc = Self {
362            child,
363            stdin,
364            stdout,
365            node_ids,
366        };
367
368        // Send LOAD command with all filters
369        let filter_specs: Vec<serde_json::Value> = filters
370            .iter()
371            .map(|(id, pickled, trainable)| {
372                serde_json::json!({
373                    "id": id,
374                    "pickle_b64": STANDARD.encode(pickled),
375                    "trainable": trainable,
376                })
377            })
378            .collect();
379
380        let resp = proc.send(serde_json::json!({
381            "cmd": "LOAD",
382            "filters": filter_specs,
383        }))?;
384
385        if resp.get("ok") != Some(&serde_json::Value::Bool(true)) {
386            let error = resp
387                .get("error")
388                .and_then(|e| e.as_str())
389                .unwrap_or("unknown error");
390            return Err(SomaError::Other(format!("LOAD failed: {error}")));
391        }
392
393        Ok(proc)
394    }
395
396    /// Send a JSON command and read the JSON response.
397    fn send(&mut self, cmd: serde_json::Value) -> Result<serde_json::Value> {
398        let action = cmd
399            .get("cmd")
400            .and_then(|c| c.as_str())
401            .unwrap_or("?")
402            .to_string();
403        let node_id = cmd
404            .get("node_id")
405            .and_then(|n| n.as_str())
406            .unwrap_or("")
407            .to_string();
408
409        tracing::debug!(action = %action, node_id = %node_id, "→ Python");
410        let start = std::time::Instant::now();
411
412        let line = serde_json::to_string(&cmd)
413            .map_err(|e| SomaError::Other(format!("serialize cmd: {e}")))?;
414
415        writeln!(self.stdin, "{line}")
416            .map_err(|e| SomaError::Other(format!("write to python stdin: {e}")))?;
417        self.stdin
418            .flush()
419            .map_err(|e| SomaError::Other(format!("flush stdin: {e}")))?;
420
421        let mut response = String::new();
422        self.stdout
423            .read_line(&mut response)
424            .map_err(|e| SomaError::Other(format!("read from python stdout: {e}")))?;
425
426        let duration_ms = start.elapsed().as_millis();
427
428        if response.is_empty() {
429            tracing::error!(action = %action, "Python process closed stdout (crashed?)");
430            return Err(SomaError::Other(
431                "python process closed stdout (crashed?)".into(),
432            ));
433        }
434
435        let parsed: serde_json::Value = serde_json::from_str(&response).map_err(|e| {
436            SomaError::Other(format!("parse python response: {e}\nraw: {response}"))
437        })?;
438
439        let ok = parsed.get("ok") == Some(&serde_json::Value::Bool(true));
440        if ok {
441            tracing::debug!(action = %action, node_id = %node_id, duration_ms, "← Python OK");
442        } else {
443            let error = parsed.get("error").and_then(|e| e.as_str()).unwrap_or("?");
444            let traceback = parsed
445                .get("traceback")
446                .and_then(|t| t.as_str())
447                .unwrap_or("");
448            tracing::error!(action = %action, node_id = %node_id, error, "Python filter error");
449            if !traceback.is_empty() {
450                tracing::error!("Python traceback:\n{traceback}");
451            }
452        }
453
454        Ok(parsed)
455    }
456
457    /// Convert a response to a Value, handling errors.
458    fn response_to_value(resp: &serde_json::Value) -> Result<Value> {
459        if resp.get("ok") != Some(&serde_json::Value::Bool(true)) {
460            let error = resp
461                .get("error")
462                .and_then(|e| e.as_str())
463                .unwrap_or("unknown error");
464            let traceback = resp.get("traceback").and_then(|t| t.as_str()).unwrap_or("");
465            return Err(SomaError::Other(format!(
466                "Python error: {error}\n{traceback}"
467            )));
468        }
469
470        if let Some(result) = resp.get("result") {
471            return Self::json_to_value(result);
472        }
473
474        Ok(Value::Empty)
475    }
476
477    /// Convert a JSON value to a Soma Value.
478    fn json_to_value(v: &serde_json::Value) -> Result<Value> {
479        if v.is_null() {
480            return Ok(Value::Empty);
481        }
482        if let Some(arr) = v.as_array() {
483            let values: Vec<f64> = arr.iter().filter_map(|x| x.as_f64()).collect();
484            if values.len() == arr.len() && !values.is_empty() {
485                return Ok(Value::tensor(values.clone(), vec![values.len()]));
486            }
487            // Could be nested array
488            if let Some(first) = arr.first()
489                && first.is_array()
490            {
491                let rows = arr.len();
492                let cols = first.as_array().map(|a| a.len()).unwrap_or(0);
493                let flat: Vec<f64> = arr
494                    .iter()
495                    .filter_map(|row| row.as_array())
496                    .flat_map(|row| row.iter().filter_map(|x| x.as_f64()))
497                    .collect();
498                if flat.len() == rows * cols {
499                    return Ok(Value::tensor(flat, vec![rows, cols]));
500                }
501            }
502        }
503        Ok(Value::json(v.clone()))
504    }
505
506    /// Encode a Value to JSON for the Python process.
507    fn value_to_json(v: &Value) -> serde_json::Value {
508        serde_json::to_value(v).unwrap_or(serde_json::Value::Null)
509    }
510
511    // ── Public API ──
512
513    pub fn fit(&mut self, node_id: &str, data: &Value, y: Option<&Value>) -> Result<Value> {
514        let mut cmd = serde_json::json!({
515            "cmd": "FIT",
516            "node_id": node_id,
517            "data": Self::value_to_json(data),
518        });
519        if let Some(y_val) = y {
520            cmd["y"] = Self::value_to_json(y_val);
521        }
522        let resp = self.send(cmd)?;
523        Self::response_to_value(&resp)
524    }
525
526    pub fn forward(&mut self, node_id: &str, data: &Value, state: &Value) -> Result<Value> {
527        let resp = self.send(serde_json::json!({
528            "cmd": "FORWARD",
529            "node_id": node_id,
530            "data": Self::value_to_json(data),
531            "state": Self::value_to_json(state),
532        }))?;
533        Self::response_to_value(&resp)
534    }
535
536    pub fn composite_fit(
537        &mut self,
538        node_ids: &[String],
539        data: &Value,
540        y: Option<&Value>,
541    ) -> Result<(Value, HashMap<String, Value>)> {
542        let mut cmd = serde_json::json!({
543            "cmd": "COMPOSITE_FIT",
544            "node_ids": node_ids,
545            "data": Self::value_to_json(data),
546        });
547        if let Some(y_val) = y {
548            cmd["y"] = Self::value_to_json(y_val);
549        }
550        let resp = self.send(cmd)?;
551        let output = Self::response_to_value(&resp)?;
552
553        let mut states = HashMap::new();
554        if let Some(state_map) = resp.get("states").and_then(|s| s.as_object()) {
555            for (id, b64) in state_map {
556                if let Some(s) = b64.as_str() {
557                    let bytes = STANDARD
558                        .decode(s)
559                        .map_err(|e| SomaError::Other(format!("decode state: {e}")))?;
560                    states.insert(id.clone(), Value::bytes(bytes));
561                }
562            }
563        }
564        Ok((output, states))
565    }
566
567    /// Batched fit: send full dataset + batch_size, daemon splits internally.
568    /// Model loaded ONCE, batches processed in a loop.
569    pub fn batched_fit(
570        &mut self,
571        node_ids: &[String],
572        data: &Value,
573        y: Option<&Value>,
574        batch_size: usize,
575    ) -> Result<(Value, HashMap<String, Value>)> {
576        let mut cmd = serde_json::json!({
577            "cmd": "BATCHED_FIT",
578            "node_ids": node_ids,
579            "data": Self::value_to_json(data),
580            "batch_size": batch_size,
581        });
582        if let Some(y_val) = y {
583            cmd["y"] = Self::value_to_json(y_val);
584        }
585        let resp = self.send(cmd)?;
586        let output = Self::response_to_value(&resp)?;
587
588        let mut states = HashMap::new();
589        if let Some(state_map) = resp.get("states").and_then(|s| s.as_object()) {
590            for (id, val) in state_map {
591                if let Ok(v) = Self::json_to_value(val) {
592                    states.insert(id.clone(), v);
593                }
594            }
595        }
596        Ok((output, states))
597    }
598
599    pub fn composite_forward(&mut self, node_ids: &[String], data: &Value) -> Result<Value> {
600        let resp = self.send(serde_json::json!({
601            "cmd": "COMPOSITE_FORWARD",
602            "node_ids": node_ids,
603            "data": Self::value_to_json(data),
604        }))?;
605        Self::response_to_value(&resp)
606    }
607
608    pub fn get_state(&mut self, node_id: &str) -> Result<Value> {
609        let resp = self.send(serde_json::json!({"cmd": "GET_STATE", "node_id": node_id}))?;
610        if let Some(b64) = resp.get("state_b64").and_then(|s| s.as_str()) {
611            let bytes = STANDARD
612                .decode(b64)
613                .map_err(|e| SomaError::Other(format!("decode state: {e}")))?;
614            Ok(Value::bytes(bytes))
615        } else {
616            Self::response_to_value(&resp)
617        }
618    }
619
620    pub fn set_state(&mut self, node_id: &str, state: &Value) -> Result<()> {
621        let b64 = match state {
622            Value::Bytes(b) => STANDARD.encode(b.as_slice()),
623            _ => return Err(SomaError::Other("set_state expects Value::Bytes".into())),
624        };
625        let resp = self
626            .send(serde_json::json!({"cmd": "SET_STATE", "node_id": node_id, "state_b64": b64}))?;
627        if resp.get("ok") != Some(&serde_json::Value::Bool(true)) {
628            let error = resp.get("error").and_then(|e| e.as_str()).unwrap_or("?");
629            return Err(SomaError::Other(format!("set_state: {error}")));
630        }
631        Ok(())
632    }
633
634    pub fn get_gradients(&mut self, node_id: &str) -> Result<Value> {
635        let resp = self.send(serde_json::json!({"cmd": "GET_GRADIENTS", "node_id": node_id}))?;
636        if let Some(b64) = resp.get("gradients_b64").and_then(|s| s.as_str()) {
637            let bytes = STANDARD
638                .decode(b64)
639                .map_err(|e| SomaError::Other(format!("decode gradients: {e}")))?;
640            Ok(Value::bytes(bytes))
641        } else {
642            Ok(Value::Empty)
643        }
644    }
645
646    pub fn apply_gradients(&mut self, node_id: &str, gradients: &Value) -> Result<()> {
647        let b64 = match gradients {
648            Value::Bytes(b) => STANDARD.encode(b.as_slice()),
649            _ => {
650                return Err(SomaError::Other(
651                    "apply_gradients expects Value::Bytes".into(),
652                ));
653            }
654        };
655        self.send(
656            serde_json::json!({"cmd": "APPLY_GRADIENTS", "node_id": node_id, "gradients_b64": b64}),
657        )?;
658        Ok(())
659    }
660
661    pub fn shutdown(&mut self) {
662        let _ = self.send(serde_json::json!({"cmd": "SHUTDOWN"}));
663    }
664
665    pub fn node_ids(&self) -> &[String] {
666        &self.node_ids
667    }
668}
669
670impl Drop for PythonProcess {
671    fn drop(&mut self) {
672        self.shutdown();
673        let _ = self.child.kill();
674        let _ = self.child.wait();
675    }
676}
677
678// ── SubprocessFilter: implements Filter trait via PythonProcess ──
679
/// A filter that delegates to a shared PythonProcess via stdin/stdout.
/// Multiple SubprocessFilters can share the same process (Arc<Mutex>).
///
/// Every call locks the shared process for the duration of one request/reply
/// exchange, so calls from filters sharing a process are serialized.
pub struct SubprocessFilter {
    /// Shared daemon handle; the mutex serializes access to its stdin/stdout pipes.
    pub(crate) process: Arc<Mutex<PythonProcess>>,
    /// Id under which this filter was loaded into the daemon.
    node_id: String,
    /// Reported through `meta()`: `Trainable` + differentiable when true,
    /// `Stateless` otherwise.
    trainable: bool,
}
687
688impl SubprocessFilter {
689    pub fn new(process: Arc<Mutex<PythonProcess>>, node_id: String, trainable: bool) -> Self {
690        Self {
691            process,
692            node_id,
693            trainable,
694        }
695    }
696}
697
698impl Filter for SubprocessFilter {
699    fn config_hash(&self) -> CacheKey {
700        CacheKey::from_parts(&[self.node_id.as_bytes()])
701    }
702
703    fn fit(&self, x: &Value, y: Option<&Value>) -> Result<Value> {
704        self.process
705            .lock()
706            .map_err(|e| SomaError::Other(format!("process mutex poisoned: {e}")))?
707            .fit(&self.node_id, x, y)
708    }
709
710    fn forward(&self, x: &Value, state: &Value) -> Result<Value> {
711        self.process
712            .lock()
713            .map_err(|e| SomaError::Other(format!("process mutex poisoned: {e}")))?
714            .forward(&self.node_id, x, state)
715    }
716
717    fn meta(&self) -> FilterMeta {
718        FilterMeta {
719            name: self.node_id.clone(),
720            kind: if self.trainable {
721                FilterKind::Trainable
722            } else {
723                FilterKind::Stateless
724            },
725            cacheable: true,
726            differentiable: self.trainable,
727            stream_mode: StreamMode::FixedState,
728            distribution: somatize_core::filter::Distribution::Local,
729            input_schema: None,
730            output_schema: None,
731        }
732    }
733
734    fn as_any(&self) -> &dyn std::any::Any {
735        self
736    }
737
738    fn composite_fit(
739        &self,
740        node_ids: &[String],
741        x: &Value,
742        y: Option<&Value>,
743    ) -> Option<Result<(Value, HashMap<String, Value>)>> {
744        tracing::info!(nodes = ?node_ids, "Composite fit via subprocess");
745        Some(
746            self.process
747                .lock()
748                .map_err(|e| SomaError::Other(format!("process mutex poisoned: {e}")))
749                .and_then(|mut proc| proc.composite_fit(node_ids, x, y)),
750        )
751    }
752}