1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::path::PathBuf;
4use std::sync::Arc;
5use tokio::process::Command;
6use tokio::sync::{mpsc, RwLock};
7
8use super::{Message, Transport, TransportError};
9use crate::schema::{Task, TaskError, TaskResult, TaskStatus};
10
/// Static description of a single WASM worker.
///
/// `Debug` and `Clone` are derived so a configuration can be logged and
/// duplicated without copying every field by hand.
#[derive(Debug, Clone)]
pub struct WasmWorkerConfig {
    /// Location of the module on disk; it is passed to `wasmtime run` for
    /// every task the worker executes.
    pub module_path: PathBuf,
    /// Identifier announced when the worker registers and used to address it
    /// afterwards.
    pub worker_id: String,
    /// Task kinds advertised in the worker's registration message.
    pub supported_tasks: Vec<String>,
    /// NOTE(review): presumably a memory cap in WASM pages. The value is
    /// carried in the configuration but not read anywhere in this file —
    /// confirm where (or whether) it is enforced.
    pub max_memory_pages: u32,
    /// Time budget for one task, in milliseconds; a task that exceeds it is
    /// reported as timed out.
    pub max_execution_time_ms: u64,
    /// NOTE(review): presumably the environment variable names a module may
    /// see. Not consumed in this file — confirm against the runtime setup.
    pub allowed_env: Vec<String>,
}
20
21pub struct WasmTransport {
23 configs: Vec<WasmWorkerConfig>,
24 workers: Arc<RwLock<HashMap<String, WasmWorker>>>,
25 on_message: Arc<dyn Fn(String, Message) + Send + Sync>,
26}
27
28struct WasmWorker {
29 #[allow(dead_code)]
30 config: WasmWorkerConfig,
31 task_tx: mpsc::UnboundedSender<Task>,
32}
33
34impl WasmTransport {
35 pub fn new(
36 configs: Vec<WasmWorkerConfig>,
37 on_message: impl Fn(String, Message) + Send + Sync + 'static,
38 ) -> Self {
39 Self {
40 configs,
41 workers: Arc::new(RwLock::new(HashMap::new())),
42 on_message: Arc::new(on_message),
43 }
44 }
45}
46
47#[async_trait]
48impl Transport for WasmTransport {
49 async fn start(&self) -> Result<(), TransportError> {
50 for config in &self.configs {
51 let worker_id = config.worker_id.clone();
52 let on_message = self.on_message.clone();
53 let workers = self.workers.clone();
54
55 if !config.module_path.exists() {
57 return Err(TransportError::ConnectionFailed(format!(
58 "WASM module not found: {}",
59 config.module_path.display()
60 )));
61 }
62
63 let (task_tx, mut task_rx) = mpsc::unbounded_channel::<Task>();
64
65 let reg_msg = Message::WorkerRegister {
67 registration: super::WorkerRegistration {
68 worker_id: worker_id.clone(),
69 supported_tasks: config.supported_tasks.clone(),
70 max_concurrency: 1,
71 language: super::WorkerLanguage::Other("wasm".to_string()),
72 tags: None,
73 },
74 };
75 on_message(worker_id.clone(), reg_msg);
76
77 let wid = worker_id.clone();
79 let module_path = config.module_path.clone();
80 let max_time = config.max_execution_time_ms;
81
82 tokio::spawn(async move {
83 while let Some(task) = task_rx.recv().await {
84 let start = std::time::Instant::now();
85 let task_id = task.id;
86
87 let task_json = serde_json::to_string(&task).unwrap_or_default();
90
91 let result = tokio::time::timeout(
92 std::time::Duration::from_millis(max_time),
93 execute_wasm_module(&module_path, &task_json),
94 )
95 .await;
96
97 let duration_ms = start.elapsed().as_millis() as u64;
98
99 let task_result = match result {
100 Ok(Ok(output)) => {
101 match serde_json::from_str::<serde_json::Value>(&output) {
102 Ok(payload) => TaskResult {
103 task_id,
104 status: TaskStatus::Completed,
105 payload: Some(payload),
106 error: None,
107 duration_ms,
108 worker_id: wid.clone(),
109 },
110 Err(e) => TaskResult {
111 task_id,
112 status: TaskStatus::Failed,
113 payload: None,
114 error: Some(TaskError {
115 code: "PARSE_ERROR".to_string(),
116 message: format!("Failed to parse WASM output: {}", e),
117 retryable: false,
118 }),
119 duration_ms,
120 worker_id: wid.clone(),
121 },
122 }
123 }
124 Ok(Err(e)) => TaskResult {
125 task_id,
126 status: TaskStatus::Failed,
127 payload: None,
128 error: Some(TaskError {
129 code: "EXECUTION_ERROR".to_string(),
130 message: e,
131 retryable: true,
132 }),
133 duration_ms,
134 worker_id: wid.clone(),
135 },
136 Err(_) => TaskResult {
137 task_id,
138 status: TaskStatus::TimedOut,
139 payload: None,
140 error: Some(TaskError {
141 code: "TIMEOUT".to_string(),
142 message: format!("WASM execution exceeded {}ms", max_time),
143 retryable: true,
144 }),
145 duration_ms,
146 worker_id: wid.clone(),
147 },
148 };
149
150 on_message(
151 wid.clone(),
152 Message::TaskResult {
153 result: task_result,
154 },
155 );
156 }
157 });
158
159 workers.write().await.insert(
160 worker_id.clone(),
161 WasmWorker {
162 config: WasmWorkerConfig {
163 module_path: config.module_path.clone(),
164 worker_id: config.worker_id.clone(),
165 supported_tasks: config.supported_tasks.clone(),
166 max_memory_pages: config.max_memory_pages,
167 max_execution_time_ms: config.max_execution_time_ms,
168 allowed_env: config.allowed_env.clone(),
169 },
170 task_tx,
171 },
172 );
173
174 tracing::info!(
175 worker_id = %worker_id,
176 module = %config.module_path.display(),
177 "WASM worker registered"
178 );
179 }
180
181 Ok(())
182 }
183
184 async fn stop(&self) -> Result<(), TransportError> {
185 self.workers.write().await.clear();
186 Ok(())
187 }
188
189 async fn send(&self, worker_id: &str, message: Message) -> Result<(), TransportError> {
190 let workers = self.workers.read().await;
191 let worker = workers
192 .get(worker_id)
193 .ok_or_else(|| TransportError::WorkerNotFound(worker_id.to_string()))?;
194
195 if let Message::TaskDispatch { task } = message {
196 worker
197 .task_tx
198 .send(task)
199 .map_err(|e| TransportError::SendFailed(e.to_string()))?;
200 }
201
202 Ok(())
203 }
204
205 async fn broadcast(&self, message: Message) -> Result<(), TransportError> {
206 if let Message::TaskDispatch { ref task } = message {
207 let workers = self.workers.read().await;
208 for (_, worker) in workers.iter() {
209 let _ = worker.task_tx.send(task.clone());
210 }
211 }
212 Ok(())
213 }
214}
215
216async fn execute_wasm_module(
217 module_path: &std::path::Path,
218 input_json: &str,
219) -> Result<String, String> {
220 use tokio::io::AsyncWriteExt;
221
222 let mut child = Command::new("wasmtime")
223 .args(["run", &module_path.to_string_lossy()])
224 .stdin(std::process::Stdio::piped())
225 .stdout(std::process::Stdio::piped())
226 .stderr(std::process::Stdio::piped())
227 .spawn()
228 .map_err(|e| format!("Failed to spawn wasmtime: {}", e))?;
229
230 if let Some(mut stdin) = child.stdin.take() {
232 stdin
233 .write_all(input_json.as_bytes())
234 .await
235 .map_err(|e| format!("Failed to write to wasmtime stdin: {}", e))?;
236 drop(stdin);
237 }
238
239 let output = child
240 .wait_with_output()
241 .await
242 .map_err(|e| format!("wasmtime execution failed: {}", e))?;
243
244 if !output.status.success() {
245 let stderr = String::from_utf8_lossy(&output.stderr);
246 return Err(format!("WASM module failed: {}", stderr));
247 }
248
249 Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
250}