Skip to main content

rust_pipe/transport/wasm.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::path::PathBuf;
4use std::sync::Arc;
5use tokio::process::Command;
6use tokio::sync::{mpsc, RwLock};
7
8use super::{Message, Transport, TransportError};
9use crate::schema::{Task, TaskError, TaskResult, TaskStatus};
10
/// Configuration for a single WASM module worker.
///
/// NOTE(review): `max_memory_pages` and `allowed_env` are stored by the
/// transport but are not currently passed to `wasmtime`, so neither limit is
/// enforced at execution time.
#[derive(Debug, Clone)]
pub struct WasmWorkerConfig {
    /// Path to the `.wasm` module that is executed with `wasmtime run`.
    pub module_path: PathBuf,
    /// Identifier used in this worker's registration and task-result messages.
    pub worker_id: String,
    /// Task types advertised when the worker registers.
    pub supported_tasks: Vec<String>,
    /// Intended cap on linear memory, presumably in WebAssembly pages
    /// (64 KiB each) — not yet enforced, see the note above.
    pub max_memory_pages: u32,
    /// Per-task wall-clock limit in milliseconds; a run that exceeds it is
    /// reported as timed out.
    pub max_execution_time_ms: u64,
    /// Intended allowlist of host environment variable names exposed to the
    /// module — presumably; not yet enforced, see the note above.
    pub allowed_env: Vec<String>,
}
20
21/// Transport that executes tasks in sandboxed WebAssembly modules.
22pub struct WasmTransport {
23    configs: Vec<WasmWorkerConfig>,
24    workers: Arc<RwLock<HashMap<String, WasmWorker>>>,
25    on_message: Arc<dyn Fn(String, Message) + Send + Sync>,
26}
27
28struct WasmWorker {
29    #[allow(dead_code)]
30    config: WasmWorkerConfig,
31    task_tx: mpsc::UnboundedSender<Task>,
32}
33
34impl WasmTransport {
35    pub fn new(
36        configs: Vec<WasmWorkerConfig>,
37        on_message: impl Fn(String, Message) + Send + Sync + 'static,
38    ) -> Self {
39        Self {
40            configs,
41            workers: Arc::new(RwLock::new(HashMap::new())),
42            on_message: Arc::new(on_message),
43        }
44    }
45}
46
47#[async_trait]
48impl Transport for WasmTransport {
49    async fn start(&self) -> Result<(), TransportError> {
50        for config in &self.configs {
51            let worker_id = config.worker_id.clone();
52            let on_message = self.on_message.clone();
53            let workers = self.workers.clone();
54
55            // Verify module exists
56            if !config.module_path.exists() {
57                return Err(TransportError::ConnectionFailed(format!(
58                    "WASM module not found: {}",
59                    config.module_path.display()
60                )));
61            }
62
63            let (task_tx, mut task_rx) = mpsc::unbounded_channel::<Task>();
64
65            // Register worker
66            let reg_msg = Message::WorkerRegister {
67                registration: super::WorkerRegistration {
68                    worker_id: worker_id.clone(),
69                    supported_tasks: config.supported_tasks.clone(),
70                    max_concurrency: 1,
71                    language: super::WorkerLanguage::Other("wasm".to_string()),
72                    tags: None,
73                },
74            };
75            on_message(worker_id.clone(), reg_msg);
76
77            // WASM execution loop
78            let wid = worker_id.clone();
79            let module_path = config.module_path.clone();
80            let max_time = config.max_execution_time_ms;
81
82            tokio::spawn(async move {
83                while let Some(task) = task_rx.recv().await {
84                    let start = std::time::Instant::now();
85                    let task_id = task.id;
86
87                    // Execute WASM module via wasmtime CLI
88                    // The WASM module reads task JSON from stdin, writes result JSON to stdout
89                    let task_json = serde_json::to_string(&task).unwrap_or_default();
90
91                    let result = tokio::time::timeout(
92                        std::time::Duration::from_millis(max_time),
93                        execute_wasm_module(&module_path, &task_json),
94                    )
95                    .await;
96
97                    let duration_ms = start.elapsed().as_millis() as u64;
98
99                    let task_result = match result {
100                        Ok(Ok(output)) => {
101                            match serde_json::from_str::<serde_json::Value>(&output) {
102                                Ok(payload) => TaskResult {
103                                    task_id,
104                                    status: TaskStatus::Completed,
105                                    payload: Some(payload),
106                                    error: None,
107                                    duration_ms,
108                                    worker_id: wid.clone(),
109                                },
110                                Err(e) => TaskResult {
111                                    task_id,
112                                    status: TaskStatus::Failed,
113                                    payload: None,
114                                    error: Some(TaskError {
115                                        code: "PARSE_ERROR".to_string(),
116                                        message: format!("Failed to parse WASM output: {}", e),
117                                        retryable: false,
118                                    }),
119                                    duration_ms,
120                                    worker_id: wid.clone(),
121                                },
122                            }
123                        }
124                        Ok(Err(e)) => TaskResult {
125                            task_id,
126                            status: TaskStatus::Failed,
127                            payload: None,
128                            error: Some(TaskError {
129                                code: "EXECUTION_ERROR".to_string(),
130                                message: e,
131                                retryable: true,
132                            }),
133                            duration_ms,
134                            worker_id: wid.clone(),
135                        },
136                        Err(_) => TaskResult {
137                            task_id,
138                            status: TaskStatus::TimedOut,
139                            payload: None,
140                            error: Some(TaskError {
141                                code: "TIMEOUT".to_string(),
142                                message: format!("WASM execution exceeded {}ms", max_time),
143                                retryable: true,
144                            }),
145                            duration_ms,
146                            worker_id: wid.clone(),
147                        },
148                    };
149
150                    on_message(
151                        wid.clone(),
152                        Message::TaskResult {
153                            result: task_result,
154                        },
155                    );
156                }
157            });
158
159            workers.write().await.insert(
160                worker_id.clone(),
161                WasmWorker {
162                    config: WasmWorkerConfig {
163                        module_path: config.module_path.clone(),
164                        worker_id: config.worker_id.clone(),
165                        supported_tasks: config.supported_tasks.clone(),
166                        max_memory_pages: config.max_memory_pages,
167                        max_execution_time_ms: config.max_execution_time_ms,
168                        allowed_env: config.allowed_env.clone(),
169                    },
170                    task_tx,
171                },
172            );
173
174            tracing::info!(
175                worker_id = %worker_id,
176                module = %config.module_path.display(),
177                "WASM worker registered"
178            );
179        }
180
181        Ok(())
182    }
183
184    async fn stop(&self) -> Result<(), TransportError> {
185        self.workers.write().await.clear();
186        Ok(())
187    }
188
189    async fn send(&self, worker_id: &str, message: Message) -> Result<(), TransportError> {
190        let workers = self.workers.read().await;
191        let worker = workers
192            .get(worker_id)
193            .ok_or_else(|| TransportError::WorkerNotFound(worker_id.to_string()))?;
194
195        if let Message::TaskDispatch { task } = message {
196            worker
197                .task_tx
198                .send(task)
199                .map_err(|e| TransportError::SendFailed(e.to_string()))?;
200        }
201
202        Ok(())
203    }
204
205    async fn broadcast(&self, message: Message) -> Result<(), TransportError> {
206        if let Message::TaskDispatch { ref task } = message {
207            let workers = self.workers.read().await;
208            for (_, worker) in workers.iter() {
209                let _ = worker.task_tx.send(task.clone());
210            }
211        }
212        Ok(())
213    }
214}
215
216async fn execute_wasm_module(
217    module_path: &std::path::Path,
218    input_json: &str,
219) -> Result<String, String> {
220    use tokio::io::AsyncWriteExt;
221
222    let mut child = Command::new("wasmtime")
223        .args(["run", &module_path.to_string_lossy()])
224        .stdin(std::process::Stdio::piped())
225        .stdout(std::process::Stdio::piped())
226        .stderr(std::process::Stdio::piped())
227        .spawn()
228        .map_err(|e| format!("Failed to spawn wasmtime: {}", e))?;
229
230    // Write task JSON to the module's stdin
231    if let Some(mut stdin) = child.stdin.take() {
232        stdin
233            .write_all(input_json.as_bytes())
234            .await
235            .map_err(|e| format!("Failed to write to wasmtime stdin: {}", e))?;
236        drop(stdin);
237    }
238
239    let output = child
240        .wait_with_output()
241        .await
242        .map_err(|e| format!("wasmtime execution failed: {}", e))?;
243
244    if !output.status.success() {
245        let stderr = String::from_utf8_lossy(&output.stderr);
246        return Err(format!("WASM module failed: {}", stderr));
247    }
248
249    Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
250}