1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::path::PathBuf;
4use std::sync::Arc;
5use tokio::process::Command;
6use tokio::sync::{mpsc, RwLock};
7
8use super::{Message, Transport, TransportError};
9use crate::schema::{Task, TaskError, TaskResult, TaskStatus};
10
/// Static configuration for a single WASM-backed worker.
///
/// One worker loop is created per config when the transport is started.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WasmWorkerConfig {
    /// Path to the WASM module; it must exist on disk when the transport
    /// starts and is handed to `wasmtime run` for every task.
    pub module_path: PathBuf,
    /// Identifier the worker registers under; also the key used to route
    /// messages to this worker.
    pub worker_id: String,
    /// Task type names advertised in the worker's registration message.
    pub supported_tasks: Vec<String>,
    /// Memory limit in WASM pages.
    ///
    /// NOTE(review): this value is stored but not forwarded to `wasmtime`
    /// anywhere in this file — confirm whether it is enforced elsewhere.
    pub max_memory_pages: u32,
    /// Upper bound, in milliseconds, on a single task execution before it is
    /// reported as timed out.
    pub max_execution_time_ms: u64,
    /// Environment variable names the module may see.
    ///
    /// NOTE(review): this value is stored but not forwarded to `wasmtime`
    /// anywhere in this file — confirm whether it is enforced elsewhere.
    pub allowed_env: Vec<String>,
}
20
21pub struct WasmTransport {
23 configs: Vec<WasmWorkerConfig>,
24 workers: Arc<RwLock<HashMap<String, WasmWorker>>>,
25 on_message: Arc<dyn Fn(String, Message) + Send + Sync>,
26}
27
28struct WasmWorker {
29 #[allow(dead_code)]
30 config: WasmWorkerConfig,
31 task_tx: mpsc::UnboundedSender<Task>,
32}
33
34impl WasmTransport {
35 pub fn new(
36 configs: Vec<WasmWorkerConfig>,
37 on_message: impl Fn(String, Message) + Send + Sync + 'static,
38 ) -> Self {
39 Self {
40 configs,
41 workers: Arc::new(RwLock::new(HashMap::new())),
42 on_message: Arc::new(on_message),
43 }
44 }
45}
46
47#[async_trait]
48impl Transport for WasmTransport {
49 async fn start(&self) -> Result<(), TransportError> {
50 for config in &self.configs {
51 let worker_id = config.worker_id.clone();
52 let on_message = self.on_message.clone();
53 let workers = self.workers.clone();
54
55 if !config.module_path.exists() {
57 return Err(TransportError::ConnectionFailed(format!(
58 "WASM module not found: {}",
59 config.module_path.display()
60 )));
61 }
62
63 let (task_tx, mut task_rx) = mpsc::unbounded_channel::<Task>();
64
65 let reg_msg = Message::WorkerRegister {
67 registration: super::WorkerRegistration {
68 worker_id: worker_id.clone(),
69 supported_tasks: config.supported_tasks.clone(),
70 max_concurrency: 1,
71 language: super::WorkerLanguage::Other("wasm".to_string()),
72 },
73 };
74 on_message(worker_id.clone(), reg_msg);
75
76 let wid = worker_id.clone();
78 let module_path = config.module_path.clone();
79 let max_time = config.max_execution_time_ms;
80
81 tokio::spawn(async move {
82 while let Some(task) = task_rx.recv().await {
83 let start = std::time::Instant::now();
84 let task_id = task.id;
85
86 let task_json = serde_json::to_string(&task).unwrap_or_default();
89
90 let result = tokio::time::timeout(
91 std::time::Duration::from_millis(max_time),
92 execute_wasm_module(&module_path, &task_json),
93 )
94 .await;
95
96 let duration_ms = start.elapsed().as_millis() as u64;
97
98 let task_result = match result {
99 Ok(Ok(output)) => {
100 match serde_json::from_str::<serde_json::Value>(&output) {
101 Ok(payload) => TaskResult {
102 task_id,
103 status: TaskStatus::Completed,
104 payload: Some(payload),
105 error: None,
106 duration_ms,
107 worker_id: wid.clone(),
108 },
109 Err(e) => TaskResult {
110 task_id,
111 status: TaskStatus::Failed,
112 payload: None,
113 error: Some(TaskError {
114 code: "PARSE_ERROR".to_string(),
115 message: format!("Failed to parse WASM output: {}", e),
116 retryable: false,
117 }),
118 duration_ms,
119 worker_id: wid.clone(),
120 },
121 }
122 }
123 Ok(Err(e)) => TaskResult {
124 task_id,
125 status: TaskStatus::Failed,
126 payload: None,
127 error: Some(TaskError {
128 code: "EXECUTION_ERROR".to_string(),
129 message: e,
130 retryable: true,
131 }),
132 duration_ms,
133 worker_id: wid.clone(),
134 },
135 Err(_) => TaskResult {
136 task_id,
137 status: TaskStatus::TimedOut,
138 payload: None,
139 error: Some(TaskError {
140 code: "TIMEOUT".to_string(),
141 message: format!("WASM execution exceeded {}ms", max_time),
142 retryable: true,
143 }),
144 duration_ms,
145 worker_id: wid.clone(),
146 },
147 };
148
149 on_message(
150 wid.clone(),
151 Message::TaskResult {
152 result: task_result,
153 },
154 );
155 }
156 });
157
158 workers.write().await.insert(
159 worker_id.clone(),
160 WasmWorker {
161 config: WasmWorkerConfig {
162 module_path: config.module_path.clone(),
163 worker_id: config.worker_id.clone(),
164 supported_tasks: config.supported_tasks.clone(),
165 max_memory_pages: config.max_memory_pages,
166 max_execution_time_ms: config.max_execution_time_ms,
167 allowed_env: config.allowed_env.clone(),
168 },
169 task_tx,
170 },
171 );
172
173 tracing::info!(
174 worker_id = %worker_id,
175 module = %config.module_path.display(),
176 "WASM worker registered"
177 );
178 }
179
180 Ok(())
181 }
182
183 async fn stop(&self) -> Result<(), TransportError> {
184 self.workers.write().await.clear();
185 Ok(())
186 }
187
188 async fn send(&self, worker_id: &str, message: Message) -> Result<(), TransportError> {
189 let workers = self.workers.read().await;
190 let worker = workers
191 .get(worker_id)
192 .ok_or_else(|| TransportError::WorkerNotFound(worker_id.to_string()))?;
193
194 if let Message::TaskDispatch { task } = message {
195 worker
196 .task_tx
197 .send(task)
198 .map_err(|e| TransportError::SendFailed(e.to_string()))?;
199 }
200
201 Ok(())
202 }
203
204 async fn broadcast(&self, message: Message) -> Result<(), TransportError> {
205 if let Message::TaskDispatch { ref task } = message {
206 let workers = self.workers.read().await;
207 for (_, worker) in workers.iter() {
208 let _ = worker.task_tx.send(task.clone());
209 }
210 }
211 Ok(())
212 }
213}
214
215async fn execute_wasm_module(
216 module_path: &std::path::Path,
217 input_json: &str,
218) -> Result<String, String> {
219 use tokio::io::AsyncWriteExt;
220
221 let mut child = Command::new("wasmtime")
222 .args(["run", &module_path.to_string_lossy()])
223 .stdin(std::process::Stdio::piped())
224 .stdout(std::process::Stdio::piped())
225 .stderr(std::process::Stdio::piped())
226 .spawn()
227 .map_err(|e| format!("Failed to spawn wasmtime: {}", e))?;
228
229 if let Some(mut stdin) = child.stdin.take() {
231 stdin
232 .write_all(input_json.as_bytes())
233 .await
234 .map_err(|e| format!("Failed to write to wasmtime stdin: {}", e))?;
235 drop(stdin);
236 }
237
238 let output = child
239 .wait_with_output()
240 .await
241 .map_err(|e| format!("wasmtime execution failed: {}", e))?;
242
243 if !output.status.success() {
244 let stderr = String::from_utf8_lossy(&output.stderr);
245 return Err(format!("WASM module failed: {}", stderr));
246 }
247
248 Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
249}