Skip to main content

somatize_worker/
remote_executor.rs

1//! WebSocket-based RemoteExecutor — sends plans to workers and collects results.
2
3use somatize_core::error::{Result, SomaError};
4use somatize_core::filter::RemoteTarget;
5use somatize_core::value::Value;
6use somatize_runtime::executor::RemoteExecutor;
7
8use crate::protocol::*;
9use std::sync::RwLock;
10
11/// A remote executor that dispatches work to workers via WebSocket.
12///
13/// Workers are registered by address + optional token.
14/// When `execute_remote` is called, it finds a matching worker,
15/// connects via WS, sends the plan, and waits for the result.
16pub struct WsRemoteExecutor {
17    /// Registered workers: (address, token, tags)
18    workers: RwLock<Vec<WorkerEntry>>,
19}
20
/// One registered worker endpoint.
#[derive(Clone)]
struct WorkerEntry {
    /// Base address of the worker; `/ws` is appended to reach its WebSocket endpoint.
    address: String,
    /// Optional token passed to the worker as the `token` query parameter.
    token: Option<String>,
    /// Tags this worker can be selected by.
    tags: Vec<String>,
}
27
28impl WsRemoteExecutor {
29    pub fn new() -> Self {
30        Self {
31            workers: RwLock::new(Vec::new()),
32        }
33    }
34
35    /// Register a worker endpoint.
36    pub fn add_worker(&self, address: impl Into<String>, token: Option<String>, tags: Vec<String>) {
37        let mut workers = self.workers.write().unwrap();
38        workers.push(WorkerEntry {
39            address: address.into(),
40            token,
41            tags,
42        });
43    }
44
45    /// Find a worker matching the given target.
46    fn find_worker(&self, target: &RemoteTarget) -> Option<WorkerEntry> {
47        let workers = self.workers.read().unwrap();
48        match target {
49            RemoteTarget::WorkerId(id) => workers.iter().find(|w| w.address.contains(id)).cloned(),
50            RemoteTarget::Tag(tag) => workers.iter().find(|w| w.tags.contains(tag)).cloned(),
51        }
52    }
53
54    /// Send a plan to a worker via WebSocket and wait for the result.
55    fn execute_on_worker(
56        &self,
57        worker: &WorkerEntry,
58        node_id: &str,
59        input: Option<&Value>,
60    ) -> Result<Value> {
61        let rt = tokio::runtime::Builder::new_current_thread()
62            .enable_all()
63            .build()
64            .map_err(|e| SomaError::Other(format!("tokio runtime: {e}")))?;
65
66        rt.block_on(async {
67            let url = if let Some(token) = &worker.token {
68                format!("{}/ws?token={}", worker.address, token)
69            } else {
70                format!("{}/ws", worker.address)
71            };
72
73            // No size limits — workers handle arbitrary payloads (datasets, model weights, etc.)
74            let mut ws_config =
75                tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default();
76            ws_config.max_message_size = None;
77            ws_config.max_frame_size = None;
78            let (mut ws, _) =
79                tokio_tungstenite::connect_async_with_config(&url, Some(ws_config), false)
80                    .await
81                    .map_err(|e| {
82                        SomaError::Other(format!("WS connect to {}: {e}", worker.address))
83                    })?;
84
85            use futures_util::{SinkExt, StreamExt};
86            use tokio_tungstenite::tungstenite::Message;
87
88            // Build a simple plan: execute this one node
89            let plan = SerializedPlan {
90                plan_id: format!("remote_{node_id}"),
91                plan: somatize_compiler::ExecutionPlan::Execute {
92                    node_id: node_id.to_string(),
93                },
94                input: input.map(|v| InputSource::Inline { value: v.clone() }),
95                filters: vec![],
96                mode: ExecutionMode::default(),
97                metadata: serde_json::json!({}),
98            };
99
100            let msg = CoordinatorToWorker::AssignPlan { plan };
101            let json = serde_json::to_string(&msg)
102                .map_err(|e| SomaError::Other(format!("serialize: {e}")))?;
103
104            ws.send(Message::Text(json.into()))
105                .await
106                .map_err(|e| SomaError::Other(format!("WS send: {e}")))?;
107
108            // Wait for result
109            while let Some(Ok(Message::Text(response))) = ws.next().await {
110                if let Ok(result) = serde_json::from_str::<WorkerToCoordinator>(&response) {
111                    match result {
112                        WorkerToCoordinator::PlanResult { result, .. } => match result {
113                            PlanResult::Success { output, .. } => {
114                                let _ = ws.close(None).await;
115                                let value = match output {
116                                    OutputDelivery::Inline { value } => value,
117                                    OutputDelivery::Reference { data_ref } => {
118                                        // Download from worker via HTTP
119                                        let http_addr = worker
120                                            .address
121                                            .replace("ws://", "http://")
122                                            .replace("wss://", "https://");
123                                        let url = format!("{http_addr}/download");
124                                        let ref_json =
125                                            serde_json::to_string(&data_ref).map_err(|e| {
126                                                SomaError::Other(format!("serialize data_ref: {e}"))
127                                            })?;
128                                        let client = reqwest::blocking::Client::new();
129                                        let mut req = client.get(&url).query(&[("ref", &ref_json)]);
130                                        if let Some(token) = &worker.token {
131                                            req = req.query(&[("token", token)]);
132                                        }
133                                        let resp = req.send().map_err(|e| {
134                                            SomaError::Other(format!("HTTP download: {e}"))
135                                        })?;
136                                        if !resp.status().is_success() {
137                                            return Err(SomaError::Other(format!(
138                                                "download failed: {}",
139                                                resp.status()
140                                            )));
141                                        }
142                                        let bytes = resp.bytes().map_err(|e| {
143                                            SomaError::Other(format!("read response: {e}"))
144                                        })?;
145                                        rmp_serde::from_slice(&bytes).or_else(|_| {
146                                            serde_json::from_slice(&bytes).map_err(|e| {
147                                                SomaError::Other(format!(
148                                                    "deserialize download: {e}"
149                                                ))
150                                            })
151                                        })?
152                                    }
153                                };
154                                return Ok(value);
155                            }
156                            PlanResult::Failed { error, .. } => {
157                                let _ = ws.close(None).await;
158                                return Err(SomaError::Execution {
159                                    node_id: node_id.to_string(),
160                                    message: error,
161                                });
162                            }
163                        },
164                        // Skip progress/event messages
165                        _ => continue,
166                    }
167                }
168            }
169
170            let _ = ws.close(None).await;
171            Err(SomaError::Other(format!(
172                "worker {} closed without result",
173                worker.address
174            )))
175        })
176    }
177
178    pub fn has_workers(&self) -> bool {
179        !self.workers.read().unwrap().is_empty()
180    }
181}
182
183impl Default for WsRemoteExecutor {
184    fn default() -> Self {
185        Self::new()
186    }
187}
188
189impl RemoteExecutor for WsRemoteExecutor {
190    fn execute_remote(
191        &self,
192        node_id: &str,
193        target: &RemoteTarget,
194        input: Option<&Value>,
195    ) -> Result<Value> {
196        let worker = self
197            .find_worker(target)
198            .ok_or_else(|| SomaError::Other(format!("no worker found for target {target:?}")))?;
199
200        tracing::info!(
201            "Dispatching node '{node_id}' to worker at {}",
202            worker.address
203        );
204
205        self.execute_on_worker(&worker, node_id, input)
206    }
207}