1use std::{
7 num::NonZeroUsize,
8 sync::{
9 atomic::{AtomicU64, Ordering},
10 Arc,
11 },
12 time::Duration,
13};
14
15use lru::LruCache;
16use tokio::sync::oneshot;
17use tracing::{debug, error};
18use wasmtime::{
19 component::{Component, Linker, ResourceTable},
20 Config, Engine, InstanceAllocationStrategy, PoolingAllocationConfig, Store, StoreLimitsBuilder,
21};
22use wasmtime_wasi::WasiCtx;
23
/// Period, in milliseconds, at which each worker's background task advances its
/// engine's epoch. This is the granularity of execution-time enforcement: a
/// guest's time budget is converted into a number of these ticks.
const EPOCH_INTERVAL_MS: u64 = 100;
28
29use crate::{
30 config::WasmRuntimeConfig,
31 errors::{Result, WasmError, WasmRuntimeError},
32 module::{MiddlewareAttachPoint, WasmModuleAttachPoint},
33 spec::Smg,
34 types::{WasiState, WasmComponentInput, WasmComponentOutput},
35};
36
37pub struct WasmRuntime {
38 config: WasmRuntimeConfig,
39 thread_pool: Arc<WasmThreadPool>,
40 total_executions: AtomicU64,
42 successful_executions: AtomicU64,
43 failed_executions: AtomicU64,
44 total_execution_time_ms: AtomicU64,
45 max_execution_time_ms: AtomicU64,
46}
47
48pub struct WasmThreadPool {
49 sender: async_channel::Sender<WasmTask>,
50 receiver: async_channel::Receiver<WasmTask>,
51 workers: Vec<std::thread::JoinHandle<()>>,
52 total_tasks: AtomicU64,
54 completed_tasks: AtomicU64,
55 failed_tasks: AtomicU64,
56}
57
58pub enum WasmTask {
59 ExecuteComponent {
60 sha256_hash: [u8; 32],
63 wasm_bytes: Arc<Vec<u8>>,
66 attach_point: WasmModuleAttachPoint,
67 input: WasmComponentInput,
68 response: oneshot::Sender<Result<WasmComponentOutput>>,
69 },
70}
71
72impl WasmRuntime {
73 pub fn new(config: WasmRuntimeConfig) -> Self {
74 let thread_pool = Arc::new(WasmThreadPool::new(config.clone()));
75
76 Self {
77 config,
78 thread_pool,
79 total_executions: AtomicU64::new(0),
80 successful_executions: AtomicU64::new(0),
81 failed_executions: AtomicU64::new(0),
82 total_execution_time_ms: AtomicU64::new(0),
83 max_execution_time_ms: AtomicU64::new(0),
84 }
85 }
86
87 pub fn with_default_config() -> Self {
88 Self::new(WasmRuntimeConfig::default())
89 }
90
91 pub fn get_config(&self) -> &WasmRuntimeConfig {
92 &self.config
93 }
94
95 pub fn get_cpu_info() -> (usize, usize) {
97 let cpu_count = std::thread::available_parallelism()
98 .map(|n| n.get())
99 .unwrap_or(4);
100 let max_recommended = cpu_count.max(1);
101 (cpu_count, max_recommended)
102 }
103
104 pub fn get_thread_pool_info(&self) -> (usize, usize) {
106 let (_cpu_count, max_recommended) = Self::get_cpu_info();
107 let current_workers = self.thread_pool.workers.len();
108 (current_workers, max_recommended)
109 }
110
111 pub async fn execute_component_async(
113 &self,
114 sha256_hash: [u8; 32],
115 wasm_bytes: Arc<Vec<u8>>,
116 attach_point: WasmModuleAttachPoint,
117 input: WasmComponentInput,
118 ) -> Result<WasmComponentOutput> {
119 let start_time = std::time::Instant::now();
120 let (response_tx, response_rx) = oneshot::channel();
121
122 let task = WasmTask::ExecuteComponent {
123 sha256_hash,
124 wasm_bytes,
125 attach_point,
126 input,
127 response: response_tx,
128 };
129
130 self.thread_pool.sender.send(task).await.map_err(|e| {
131 WasmRuntimeError::CallFailed(format!("Failed to send task to thread pool: {e}"))
132 })?;
133
134 let result = response_rx.await.map_err(|e| {
135 WasmRuntimeError::CallFailed(format!(
136 "Failed to receive response from thread pool: {e}"
137 ))
138 })?;
139
140 let execution_time_ms = start_time.elapsed().as_millis() as u64;
141 self.total_executions.fetch_add(1, Ordering::Relaxed);
142 self.total_execution_time_ms
143 .fetch_add(execution_time_ms, Ordering::Relaxed);
144 self.max_execution_time_ms
146 .fetch_max(execution_time_ms, Ordering::Relaxed);
147
148 if result.is_ok() {
149 self.successful_executions.fetch_add(1, Ordering::Relaxed);
150 } else {
151 self.failed_executions.fetch_add(1, Ordering::Relaxed);
152 }
153
154 result
155 }
156
157 pub fn get_metrics(&self) -> (u64, u64, u64, u64, u64) {
159 (
160 self.total_executions.load(Ordering::Relaxed),
161 self.successful_executions.load(Ordering::Relaxed),
162 self.failed_executions.load(Ordering::Relaxed),
163 self.total_execution_time_ms.load(Ordering::Relaxed),
164 self.max_execution_time_ms.load(Ordering::Relaxed),
165 )
166 }
167}
168
169fn map_wasm_error(e: wasmtime::Error, timeout_ms: u64) -> WasmError {
171 if e.downcast_ref::<wasmtime::Trap>() == Some(&wasmtime::Trap::Interrupt) {
174 WasmError::from(WasmRuntimeError::Timeout(timeout_ms))
175 } else {
176 WasmError::from(WasmRuntimeError::CallFailed(e.to_string()))
177 }
178}
179
180impl WasmThreadPool {
181 pub fn new(config: WasmRuntimeConfig) -> Self {
182 let (sender, receiver) = async_channel::unbounded();
183
184 let mut workers = Vec::new();
185 let max_workers = std::thread::available_parallelism()
187 .map(|n| n.get())
188 .unwrap_or(4)
189 .max(1);
190 let num_workers = config.thread_pool_size.clamp(1, max_workers);
191
192 debug!(
193 target: "smg::wasm::runtime",
194 "Initializing WASM runtime with {} workers",
195 num_workers
196 );
197
198 for worker_id in 0..num_workers {
199 let receiver = receiver.clone();
200 let config = config.clone();
201
202 let worker = std::thread::spawn(move || {
203 let rt = match tokio::runtime::Runtime::new() {
205 Ok(rt) => rt,
206 Err(e) => {
207 error!(
208 target: "smg::wasm::runtime",
209 worker_id = worker_id,
210 "Failed to create tokio runtime: {}",
211 e
212 );
213 return;
214 }
215 };
216
217 rt.block_on(async {
218 Self::worker_loop(worker_id, receiver, config).await;
219 });
220 });
221
222 workers.push(worker);
223 }
224
225 Self {
226 sender,
227 receiver,
228 workers,
229 total_tasks: AtomicU64::new(0),
230 completed_tasks: AtomicU64::new(0),
231 failed_tasks: AtomicU64::new(0),
232 }
233 }
234
235 pub fn get_metrics(&self) -> (u64, u64, u64) {
237 (
238 self.total_tasks.load(Ordering::Relaxed),
239 self.completed_tasks.load(Ordering::Relaxed),
240 self.failed_tasks.load(Ordering::Relaxed),
241 )
242 }
243
244 async fn worker_loop(
245 worker_id: usize,
246 receiver: async_channel::Receiver<WasmTask>,
247 config: WasmRuntimeConfig,
248 ) {
249 debug!(
250 target: "smg::wasm::runtime",
251 worker_id = worker_id,
252 thread_id = ?std::thread::current().id(),
253 "Worker started"
254 );
255
256 let mut pool_config = PoolingAllocationConfig::default();
257 let max_memory_bytes = (config.max_memory_pages as usize) * 65536;
258
259 pool_config.total_core_instances(20);
262 pool_config.max_memory_size(max_memory_bytes);
263 pool_config.max_component_instance_size(max_memory_bytes);
264 pool_config.max_tables_per_component(5);
265
266 let mut wasmtime_config = Config::new();
267 wasmtime_config.allocation_strategy(InstanceAllocationStrategy::Pooling(pool_config));
268
269 wasmtime_config.async_stack_size(config.max_stack_size);
270 wasmtime_config.async_support(true);
271 wasmtime_config.wasm_component_model(true); wasmtime_config.epoch_interruption(true); let engine = match Engine::new(&wasmtime_config) {
275 Ok(engine) => engine,
276 Err(e) => {
277 error!(
278 target: "smg::wasm::runtime",
279 worker_id = worker_id,
280 "Failed to create engine: {}",
281 e
282 );
283 return;
284 }
285 };
286 let mut linker = Linker::<WasiState>::new(&engine);
287 if let Err(e) = wasmtime_wasi::p2::add_to_linker_async(&mut linker) {
288 error!(
289 target: "smg::wasm::runtime",
290 worker_id = worker_id,
291 "Failed to add WASI to linker: {}",
292 e
293 );
294 return;
295 }
296
297 let default_capacity = NonZeroUsize::new(10).unwrap_or(NonZeroUsize::MIN);
299 let cache_capacity =
300 NonZeroUsize::new(config.module_cache_size).unwrap_or(default_capacity);
301 let mut component_cache: LruCache<[u8; 32], Component> = LruCache::new(cache_capacity);
302
303 let engine_for_epoch = engine.clone();
308 #[expect(
309 clippy::disallowed_methods,
310 reason = "epoch interrupt handler must run as independent background task; abort on drop ensures cleanup"
311 )]
312 let epoch_handle = tokio::spawn(async move {
313 let mut interval = tokio::time::interval(Duration::from_millis(EPOCH_INTERVAL_MS));
314 loop {
315 interval.tick().await;
316 engine_for_epoch.increment_epoch();
317 }
318 });
319
320 debug!(
321 target: "smg::wasm::runtime",
322 worker_id = worker_id,
323 epoch_interval_ms = EPOCH_INTERVAL_MS,
324 "Epoch incrementer started for timeout enforcement"
325 );
326
327 loop {
328 let task = match receiver.recv().await {
329 Ok(task) => task,
330 Err(_) => {
331 debug!(
332 target: "smg::wasm::runtime",
333 worker_id = worker_id,
334 "Worker shutting down"
335 );
336 epoch_handle.abort(); break; }
339 };
340
341 match task {
342 WasmTask::ExecuteComponent {
343 sha256_hash,
344 wasm_bytes,
345 attach_point,
346 input,
347 response,
348 } => {
349 let result = Self::execute_component_in_worker(
350 &engine,
351 &linker,
352 &mut component_cache,
353 sha256_hash,
354 &wasm_bytes,
355 attach_point,
356 input,
357 &config,
358 )
359 .await;
360
361 let _ = response.send(result);
362 }
363 }
364 }
365 }
366
367 #[expect(clippy::too_many_arguments)]
368 async fn execute_component_in_worker(
369 engine: &Engine,
370 linker: &Linker<WasiState>,
371 cache: &mut LruCache<[u8; 32], Component>,
372 sha256_hash: [u8; 32],
373 wasm_bytes: &[u8],
374 attach_point: WasmModuleAttachPoint,
375 input: WasmComponentInput,
376 config: &WasmRuntimeConfig,
377 ) -> Result<WasmComponentOutput> {
378 let component = if let Some(comp) = cache.get(&sha256_hash) {
382 comp.clone() } else {
384 let comp = Component::new(engine, wasm_bytes).map_err(|e| {
386 WasmRuntimeError::CompileFailed(format!(
387 "failed to parse WebAssembly component: {e}. \
388 Hint: The WASM file must be in component format. \
389 If you're using wit-bindgen, use 'wasm-tools component new' to wrap the WASM module into a component."
390 ))
391 })?;
392
393 cache.push(sha256_hash, comp.clone());
394 comp
395 };
396
397 let mut builder = WasiCtx::builder();
398
399 let memory_limit_bytes =
402 usize::try_from(config.get_total_memory_bytes()).map_err(|_| {
403 WasmError::from(WasmRuntimeError::CallFailed(
404 "Configured WASM memory limit exceeds addressable space on this platform."
405 .to_string(),
406 ))
407 })?;
408 let limits = StoreLimitsBuilder::new()
409 .memory_size(memory_limit_bytes)
410 .trap_on_grow_failure(true) .build();
412
413 let mut store = Store::new(
414 engine,
415 WasiState {
416 ctx: builder.build(),
417 table: ResourceTable::new(),
418 limits,
419 },
420 );
421
422 store.limiter(|state| &mut state.limits);
425
426 let deadline_epochs = (config.max_execution_time_ms / EPOCH_INTERVAL_MS).max(1);
430 store.set_epoch_deadline(deadline_epochs);
431
432 store.epoch_deadline_callback(|_store| {
434 Err(wasmtime::Error::msg("execution time limit exceeded"))
435 });
436
437 let output = match attach_point {
438 WasmModuleAttachPoint::Middleware(MiddlewareAttachPoint::OnRequest) => {
439 let request = match input {
440 WasmComponentInput::MiddlewareRequest(req) => req,
441 WasmComponentInput::MiddlewareResponse(_) => {
442 return Err(WasmError::from(WasmRuntimeError::CallFailed(
443 "Expected MiddlewareRequest input for OnRequest attach point"
444 .to_string(),
445 )));
446 }
447 };
448
449 let bindings = Smg::instantiate_async(&mut store, &component, linker)
451 .await
452 .map_err(|e| {
453 WasmError::from(WasmRuntimeError::InstanceCreateFailed(e.to_string()))
454 })?;
455
456 let action_result = bindings
458 .smg_gateway_middleware_on_request()
459 .call_on_request(&mut store, &request)
460 .await
461 .map_err(|e| map_wasm_error(e, config.max_execution_time_ms))?;
462
463 WasmComponentOutput::MiddlewareAction(action_result)
464 }
465 WasmModuleAttachPoint::Middleware(MiddlewareAttachPoint::OnResponse) => {
466 let response = match input {
468 WasmComponentInput::MiddlewareResponse(resp) => resp,
469 WasmComponentInput::MiddlewareRequest(_) => {
470 return Err(WasmError::from(WasmRuntimeError::CallFailed(
471 "Expected MiddlewareResponse input for OnResponse attach point"
472 .to_string(),
473 )));
474 }
475 };
476
477 let bindings = Smg::instantiate_async(&mut store, &component, linker)
479 .await
480 .map_err(|e| {
481 WasmError::from(WasmRuntimeError::InstanceCreateFailed(e.to_string()))
482 })?;
483
484 let action_result = bindings
486 .smg_gateway_middleware_on_response()
487 .call_on_response(&mut store, &response)
488 .await
489 .map_err(|e| map_wasm_error(e, config.max_execution_time_ms))?;
490
491 WasmComponentOutput::MiddlewareAction(action_result)
492 }
493 WasmModuleAttachPoint::Middleware(MiddlewareAttachPoint::OnError) => {
494 return Err(WasmError::from(WasmRuntimeError::CallFailed(
495 "OnError attach point not yet implemented".to_string(),
496 )));
497 }
498 };
499
500 Ok(output)
501 }
502}
503
504impl Drop for WasmThreadPool {
505 fn drop(&mut self) {
506 self.sender.close();
508 self.receiver.close();
509
510 for worker in self.workers.drain(..) {
512 let _ = worker.join();
513 }
514 }
515}
516
#[cfg(test)]
mod tests {
    use std::{num::NonZeroUsize, time::Instant};

    use lru::LruCache;

    use super::*;
    use crate::config::WasmRuntimeConfig;

    #[test]
    fn test_get_cpu_info() {
        let (cpu_count, max_recommended) = WasmRuntime::get_cpu_info();
        assert!(cpu_count > 0);
        assert!(max_recommended > 0);
        assert!(max_recommended >= cpu_count);
    }

    #[test]
    fn test_config_default_values() {
        let config = WasmRuntimeConfig::default();

        assert_eq!(config.max_memory_pages, 1024);
        assert_eq!(config.max_execution_time_ms, 1000);
        assert_eq!(config.max_stack_size, 1024 * 1024);
        assert!(config.thread_pool_size > 0);
        assert_eq!(config.module_cache_size, 10);
    }

    #[test]
    fn test_config_clone() {
        let config = WasmRuntimeConfig::default();
        let cloned_config = config.clone();

        assert_eq!(config.max_memory_pages, cloned_config.max_memory_pages);
        assert_eq!(
            config.max_execution_time_ms,
            cloned_config.max_execution_time_ms
        );
        assert_eq!(config.max_stack_size, cloned_config.max_stack_size);
        assert_eq!(config.thread_pool_size, cloned_config.thread_pool_size);
        assert_eq!(config.module_cache_size, cloned_config.module_cache_size);
    }

    /// Regression guard for the pooling-allocator + module-cache optimisation.
    ///
    /// The test runs 1000 compile+instantiate+call rounds on a default engine and
    /// 1000 cached instantiate+call rounds on a pooling engine. It expects the
    /// second loop to be more than 5x faster.
    ///
    /// Ignored by default. It compiles the module 1000 times, which can take a long
    /// time in unoptimised builds, and its assertion depends on wall-clock timing,
    /// which is flaky on loaded machines. Run it explicitly with
    /// `cargo test -- --ignored`.
    #[test]
    #[ignore = "wall-clock benchmark: slow without optimisations and timing-sensitive; run with --ignored"]
    fn test_wasm_instantiation_performance_threshold() {
        const WASM_WAT: &str = r#"
            (module
                (memory (export "memory") 1)
                (func (export "run") (param i32 i32) (result i32)
                    local.get 0
                    local.get 1
                    i32.add)
            )
        "#;

        let iterations = 1000;

        // Baseline: compile and instantiate from scratch on every iteration.
        let engine_standard = Engine::default();
        let start_standard = Instant::now();
        for _ in 0..iterations {
            let module = wasmtime::Module::new(&engine_standard, WASM_WAT).unwrap();
            let mut store = Store::new(&engine_standard, ());
            let instance = wasmtime::Instance::new(&mut store, &module, &[]).unwrap();
            let run_func = instance
                .get_typed_func::<(i32, i32), i32>(&mut store, "run")
                .unwrap();
            let _ = run_func.call(&mut store, (10, 20)).unwrap();
        }
        let duration_standard = start_standard.elapsed();

        // Optimised path: pooling allocator plus a pre-populated module cache.
        let mut pool_config = PoolingAllocationConfig::default();
        pool_config.total_core_instances(100);

        let mut config = Config::new();
        config.allocation_strategy(InstanceAllocationStrategy::Pooling(pool_config));

        let engine_pooled = Engine::new(&config).unwrap();

        let cache_capacity = NonZeroUsize::new(100).unwrap();
        let mut cache: LruCache<Vec<u8>, wasmtime::Module> = LruCache::new(cache_capacity);

        let key = WASM_WAT.as_bytes().to_vec();
        let module_compiled = wasmtime::Module::new(&engine_pooled, WASM_WAT).unwrap();
        cache.push(key.clone(), module_compiled);

        let start_pooled = Instant::now();
        for _ in 0..iterations {
            let module = cache.get(&key).unwrap().clone();
            let mut store = Store::new(&engine_pooled, ());
            let instance = wasmtime::Instance::new(&mut store, &module, &[]).unwrap();
            let run_func = instance
                .get_typed_func::<(i32, i32), i32>(&mut store, "run")
                .unwrap();
            let _ = run_func.call(&mut store, (10, 20)).unwrap();
        }
        let duration_pooled = start_pooled.elapsed();

        let standard_secs = duration_standard.as_secs_f64();
        let pooled_secs = duration_pooled.as_secs_f64();

        // Guard against a zero-length measurement before dividing.
        if pooled_secs > 0.0 {
            let speedup = standard_secs / pooled_secs;

            assert!(
                speedup > 5.0,
                "Optimization regression: Pooling+Caching was only {speedup:.2}x faster",
            );
        }
    }
}