Skip to main content

rust_pipe/transport/
wasm.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::path::PathBuf;
4use std::sync::Arc;
5use tokio::process::Command;
6use tokio::sync::{mpsc, RwLock};
7
8use super::{Message, Transport, TransportError};
9use crate::schema::{Task, TaskError, TaskResult, TaskStatus};
10
/// Configuration for a WASM module worker.
///
/// One `WasmWorkerConfig` describes one worker: which module to run, the id it
/// registers under, and the limits that apply to each task it executes.
#[derive(Debug, Clone)]
pub struct WasmWorkerConfig {
    /// Path to the WebAssembly module that is executed for every task.
    pub module_path: PathBuf,
    /// Identifier the worker registers under and tags its results with.
    pub worker_id: String,
    /// Task types advertised in the worker's registration message.
    pub supported_tasks: Vec<String>,
    /// Intended memory cap for the module, presumably in 64 KiB WebAssembly pages.
    ///
    /// NOTE(review): nothing in this transport reads this value or passes it to
    /// `wasmtime` — confirm whether it is meant to be enforced.
    pub max_memory_pages: u32,
    /// Wall-clock limit for a single task execution, in milliseconds.
    pub max_execution_time_ms: u64,
    /// Names of environment variables the module is meant to be allowed to read.
    ///
    /// NOTE(review): nothing in this transport reads this list or forwards these
    /// variables to the module — confirm intent.
    pub allowed_env: Vec<String>,
}
20
21/// Transport that executes tasks in sandboxed WebAssembly modules.
22pub struct WasmTransport {
23    configs: Vec<WasmWorkerConfig>,
24    workers: Arc<RwLock<HashMap<String, WasmWorker>>>,
25    on_message: Arc<dyn Fn(String, Message) + Send + Sync>,
26}
27
28struct WasmWorker {
29    #[allow(dead_code)]
30    config: WasmWorkerConfig,
31    task_tx: mpsc::UnboundedSender<Task>,
32}
33
34impl WasmTransport {
35    pub fn new(
36        configs: Vec<WasmWorkerConfig>,
37        on_message: impl Fn(String, Message) + Send + Sync + 'static,
38    ) -> Self {
39        Self {
40            configs,
41            workers: Arc::new(RwLock::new(HashMap::new())),
42            on_message: Arc::new(on_message),
43        }
44    }
45}
46
47#[async_trait]
48impl Transport for WasmTransport {
49    async fn start(&self) -> Result<(), TransportError> {
50        for config in &self.configs {
51            let worker_id = config.worker_id.clone();
52            let on_message = self.on_message.clone();
53            let workers = self.workers.clone();
54
55            // Verify module exists
56            if !config.module_path.exists() {
57                return Err(TransportError::ConnectionFailed(format!(
58                    "WASM module not found: {}",
59                    config.module_path.display()
60                )));
61            }
62
63            let (task_tx, mut task_rx) = mpsc::unbounded_channel::<Task>();
64
65            // Register worker
66            let reg_msg = Message::WorkerRegister {
67                registration: super::WorkerRegistration {
68                    worker_id: worker_id.clone(),
69                    supported_tasks: config.supported_tasks.clone(),
70                    max_concurrency: 1,
71                    language: super::WorkerLanguage::Other("wasm".to_string()),
72                },
73            };
74            on_message(worker_id.clone(), reg_msg);
75
76            // WASM execution loop
77            let wid = worker_id.clone();
78            let module_path = config.module_path.clone();
79            let max_time = config.max_execution_time_ms;
80
81            tokio::spawn(async move {
82                while let Some(task) = task_rx.recv().await {
83                    let start = std::time::Instant::now();
84                    let task_id = task.id;
85
86                    // Execute WASM module via wasmtime CLI
87                    // The WASM module reads task JSON from stdin, writes result JSON to stdout
88                    let task_json = serde_json::to_string(&task).unwrap_or_default();
89
90                    let result = tokio::time::timeout(
91                        std::time::Duration::from_millis(max_time),
92                        execute_wasm_module(&module_path, &task_json),
93                    )
94                    .await;
95
96                    let duration_ms = start.elapsed().as_millis() as u64;
97
98                    let task_result = match result {
99                        Ok(Ok(output)) => {
100                            match serde_json::from_str::<serde_json::Value>(&output) {
101                                Ok(payload) => TaskResult {
102                                    task_id,
103                                    status: TaskStatus::Completed,
104                                    payload: Some(payload),
105                                    error: None,
106                                    duration_ms,
107                                    worker_id: wid.clone(),
108                                },
109                                Err(e) => TaskResult {
110                                    task_id,
111                                    status: TaskStatus::Failed,
112                                    payload: None,
113                                    error: Some(TaskError {
114                                        code: "PARSE_ERROR".to_string(),
115                                        message: format!("Failed to parse WASM output: {}", e),
116                                        retryable: false,
117                                    }),
118                                    duration_ms,
119                                    worker_id: wid.clone(),
120                                },
121                            }
122                        }
123                        Ok(Err(e)) => TaskResult {
124                            task_id,
125                            status: TaskStatus::Failed,
126                            payload: None,
127                            error: Some(TaskError {
128                                code: "EXECUTION_ERROR".to_string(),
129                                message: e,
130                                retryable: true,
131                            }),
132                            duration_ms,
133                            worker_id: wid.clone(),
134                        },
135                        Err(_) => TaskResult {
136                            task_id,
137                            status: TaskStatus::TimedOut,
138                            payload: None,
139                            error: Some(TaskError {
140                                code: "TIMEOUT".to_string(),
141                                message: format!("WASM execution exceeded {}ms", max_time),
142                                retryable: true,
143                            }),
144                            duration_ms,
145                            worker_id: wid.clone(),
146                        },
147                    };
148
149                    on_message(
150                        wid.clone(),
151                        Message::TaskResult {
152                            result: task_result,
153                        },
154                    );
155                }
156            });
157
158            workers.write().await.insert(
159                worker_id.clone(),
160                WasmWorker {
161                    config: WasmWorkerConfig {
162                        module_path: config.module_path.clone(),
163                        worker_id: config.worker_id.clone(),
164                        supported_tasks: config.supported_tasks.clone(),
165                        max_memory_pages: config.max_memory_pages,
166                        max_execution_time_ms: config.max_execution_time_ms,
167                        allowed_env: config.allowed_env.clone(),
168                    },
169                    task_tx,
170                },
171            );
172
173            tracing::info!(
174                worker_id = %worker_id,
175                module = %config.module_path.display(),
176                "WASM worker registered"
177            );
178        }
179
180        Ok(())
181    }
182
183    async fn stop(&self) -> Result<(), TransportError> {
184        self.workers.write().await.clear();
185        Ok(())
186    }
187
188    async fn send(&self, worker_id: &str, message: Message) -> Result<(), TransportError> {
189        let workers = self.workers.read().await;
190        let worker = workers
191            .get(worker_id)
192            .ok_or_else(|| TransportError::WorkerNotFound(worker_id.to_string()))?;
193
194        if let Message::TaskDispatch { task } = message {
195            worker
196                .task_tx
197                .send(task)
198                .map_err(|e| TransportError::SendFailed(e.to_string()))?;
199        }
200
201        Ok(())
202    }
203
204    async fn broadcast(&self, message: Message) -> Result<(), TransportError> {
205        if let Message::TaskDispatch { ref task } = message {
206            let workers = self.workers.read().await;
207            for (_, worker) in workers.iter() {
208                let _ = worker.task_tx.send(task.clone());
209            }
210        }
211        Ok(())
212    }
213}
214
215async fn execute_wasm_module(
216    module_path: &std::path::Path,
217    input_json: &str,
218) -> Result<String, String> {
219    use tokio::io::AsyncWriteExt;
220
221    let mut child = Command::new("wasmtime")
222        .args(["run", &module_path.to_string_lossy()])
223        .stdin(std::process::Stdio::piped())
224        .stdout(std::process::Stdio::piped())
225        .stderr(std::process::Stdio::piped())
226        .spawn()
227        .map_err(|e| format!("Failed to spawn wasmtime: {}", e))?;
228
229    // Write task JSON to the module's stdin
230    if let Some(mut stdin) = child.stdin.take() {
231        stdin
232            .write_all(input_json.as_bytes())
233            .await
234            .map_err(|e| format!("Failed to write to wasmtime stdin: {}", e))?;
235        drop(stdin);
236    }
237
238    let output = child
239        .wait_with_output()
240        .await
241        .map_err(|e| format!("wasmtime execution failed: {}", e))?;
242
243    if !output.status.success() {
244        let stderr = String::from_utf8_lossy(&output.stderr);
245        return Err(format!("WASM module failed: {}", stderr));
246    }
247
248    Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
249}