wasmtime_cli/commands/
serve.rs

1use crate::common::{Profile, RunCommon, RunTarget};
2use anyhow::{Context as _, Result, bail};
3use bytes::Bytes;
4use clap::Parser;
5use futures::future::FutureExt;
6use http::{Response, StatusCode};
7use http_body_util::BodyExt as _;
8use http_body_util::combinators::UnsyncBoxBody;
9use std::convert::Infallible;
10use std::net::SocketAddr;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13use std::{
14    path::PathBuf,
15    sync::{
16        Arc, Mutex,
17        atomic::{AtomicBool, Ordering},
18    },
19    time::Duration,
20};
21use tokio::io::{self, AsyncWrite};
22use tokio::sync::Notify;
23use wasmtime::component::{Component, Linker, ResourceTable};
24use wasmtime::{Engine, Store, StoreContextMut, StoreLimits, UpdateDeadline};
25use wasmtime_cli_flags::opt::WasmtimeOptionValue;
26use wasmtime_wasi::p2::{StreamError, StreamResult};
27use wasmtime_wasi::{WasiCtx, WasiCtxBuilder, WasiCtxView, WasiView};
28#[cfg(feature = "component-model-async")]
29use wasmtime_wasi_http::handler::p2::bindings as p2;
30use wasmtime_wasi_http::handler::{HandlerState, Proxy, ProxyHandler, ProxyPre, StoreBundle};
31use wasmtime_wasi_http::io::TokioIo;
32use wasmtime_wasi_http::{
33    DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS, DEFAULT_OUTGOING_BODY_CHUNK_SIZE, WasiHttpCtx,
34    WasiHttpView,
35};
36
37#[cfg(feature = "wasi-config")]
38use wasmtime_wasi_config::{WasiConfig, WasiConfigVariables};
39#[cfg(feature = "wasi-keyvalue")]
40use wasmtime_wasi_keyvalue::{WasiKeyValue, WasiKeyValueCtx, WasiKeyValueCtxBuilder};
41#[cfg(feature = "wasi-nn")]
42use wasmtime_wasi_nn::wit::WasiNnCtx;
43
/// Default `--max-instance-reuse-count` applied when the component is
/// detected as a WASIp3 proxy and the flag is not given.
const DEFAULT_WASIP3_MAX_INSTANCE_REUSE_COUNT: usize = 128;
/// Default `--max-instance-reuse-count` applied when the component is a
/// WASIp2 proxy and the flag is not given.
const DEFAULT_WASIP2_MAX_INSTANCE_REUSE_COUNT: usize = 1;
/// Default `--max-instance-concurrent-reuse-count` applied to WASIp3 proxies
/// when the flag is not given (WASIp2 proxies always use 1).
const DEFAULT_WASIP3_MAX_INSTANCE_CONCURRENT_REUSE_COUNT: usize = 16;
47
/// Per-store host state: everything the WASI / wasi-http bindings linked in
/// by `ServeCommand::add_to_linker` need to reach through the store's data.
struct Host {
    /// Resource table shared by all of the WASI views implemented below.
    table: wasmtime::component::ResourceTable,
    /// Core WASI context (env, stdout/stderr, ...), built in `new_store`.
    ctx: WasiCtx,
    /// wasi-http (p2) context.
    http: WasiHttpCtx,
    /// Override for the outgoing-body buffer chunk count; `None` falls back
    /// to `DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS`.
    http_outgoing_body_buffer_chunks: Option<usize>,
    /// Override for the outgoing-body chunk size; `None` falls back to
    /// `DEFAULT_OUTGOING_BODY_CHUNK_SIZE`.
    http_outgoing_body_chunk_size: Option<usize>,

    /// wasi-http (p3) context.
    #[cfg(feature = "component-model-async")]
    p3_http: crate::common::DefaultP3Ctx,

    /// Resource limits installed as the store's limiter in `new_store`.
    limits: StoreLimits,

    /// wasi-nn context; only populated when wasi-nn is enabled on the CLI.
    #[cfg(feature = "wasi-nn")]
    nn: Option<WasiNnCtx>,

    /// wasi-config variables; only populated when wasi-config is enabled.
    #[cfg(feature = "wasi-config")]
    wasi_config: Option<WasiConfigVariables>,

    /// wasi-keyvalue context; only populated when wasi-keyvalue is enabled.
    #[cfg(feature = "wasi-keyvalue")]
    wasi_keyvalue: Option<WasiKeyValueCtx>,

    /// Guest profiler, installed by `setup_guest_profiler` and taken back out
    /// when the profile is written.
    #[cfg(feature = "profiling")]
    guest_profiler: Option<Arc<wasmtime::GuestProfiler>>,
}
72
73impl WasiView for Host {
74    fn ctx(&mut self) -> WasiCtxView<'_> {
75        WasiCtxView {
76            ctx: &mut self.ctx,
77            table: &mut self.table,
78        }
79    }
80}
81
82impl WasiHttpView for Host {
83    fn ctx(&mut self) -> &mut WasiHttpCtx {
84        &mut self.http
85    }
86    fn table(&mut self) -> &mut ResourceTable {
87        &mut self.table
88    }
89
90    fn outgoing_body_buffer_chunks(&mut self) -> usize {
91        self.http_outgoing_body_buffer_chunks
92            .unwrap_or_else(|| DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS)
93    }
94
95    fn outgoing_body_chunk_size(&mut self) -> usize {
96        self.http_outgoing_body_chunk_size
97            .unwrap_or_else(|| DEFAULT_OUTGOING_BODY_CHUNK_SIZE)
98    }
99}
100
#[cfg(feature = "component-model-async")]
impl wasmtime_wasi_http::p3::WasiHttpView for Host {
    /// Hands out the wasi-http (p3) context together with the shared
    /// resource table.
    fn http(&mut self) -> wasmtime_wasi_http::p3::WasiHttpCtxView<'_> {
        // Split the borrow of `self` into the two disjoint fields the view needs.
        let Self { table, p3_http, .. } = self;
        wasmtime_wasi_http::p3::WasiHttpCtxView {
            table,
            ctx: p3_http,
        }
    }
}
110
/// Address the HTTP server binds to when `--addr` is not given: every IPv4
/// interface (`0.0.0.0`) on port 8080.
const DEFAULT_ADDR: SocketAddr =
    SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), 8080);
115
116fn parse_duration(s: &str) -> Result<Duration, String> {
117    Duration::parse(Some(s)).map_err(|e| e.to_string())
118}
119
120/// Runs a WebAssembly module
121#[derive(Parser)]
122pub struct ServeCommand {
123    #[command(flatten)]
124    run: RunCommon,
125
126    /// Socket address for the web server to bind to.
127    #[arg(long , value_name = "SOCKADDR", default_value_t = DEFAULT_ADDR)]
128    addr: SocketAddr,
129
130    /// Socket address where, when connected to, will initiate a graceful
131    /// shutdown.
132    ///
133    /// Note that graceful shutdown is also supported on ctrl-c.
134    #[arg(long, value_name = "SOCKADDR")]
135    shutdown_addr: Option<SocketAddr>,
136
137    /// Disable log prefixes of wasi-http handlers.
138    /// if unspecified, logs will be prefixed with 'stdout|stderr [{req_id}] :: '
139    #[arg(long)]
140    no_logging_prefix: bool,
141
142    /// The WebAssembly component to run.
143    #[arg(value_name = "WASM", required = true)]
144    component: PathBuf,
145
146    /// Maximum number of requests to send to a single component instance before
147    /// dropping it.
148    ///
149    /// This defaults to 1 for WASIp2 components and 128 for WASIp3 components.
150    #[arg(long)]
151    max_instance_reuse_count: Option<usize>,
152
153    /// Maximum number of concurrent requests to send to a single component
154    /// instance.
155    ///
156    /// This defaults to 1 for WASIp2 components and 16 for WASIp3 components.
157    /// Note that setting it to more than 1 will have no effect for WASIp2
158    /// components since they cannot be called concurrently.
159    #[arg(long)]
160    max_instance_concurrent_reuse_count: Option<usize>,
161
162    /// Time to hold an idle component instance for possible reuse before
163    /// dropping it.
164    ///
165    /// A number with no suffix or with an `s` suffix is interpreted as seconds;
166    /// other accepted suffixes include `ms` (milliseconds), `us` or `μs`
167    /// (microseconds), and `ns` (nanoseconds).
168    #[arg(long, default_value = "1s", value_parser = parse_duration)]
169    idle_instance_timeout: Duration,
170}
171
172impl ServeCommand {
173    /// Start a server to run the given wasi-http proxy component
174    pub fn execute(mut self) -> Result<()> {
175        self.run.common.init_logging()?;
176
177        // We force cli errors before starting to listen for connections so then
178        // we don't accidentally delay them to the first request.
179
180        if self.run.common.wasi.nn == Some(true) {
181            #[cfg(not(feature = "wasi-nn"))]
182            {
183                bail!("Cannot enable wasi-nn when the binary is not compiled with this feature.");
184            }
185        }
186
187        if self.run.common.wasi.threads == Some(true) {
188            bail!("wasi-threads does not support components yet")
189        }
190
191        // The serve command requires both wasi-http and the component model, so
192        // we enable those by default here.
193        if self.run.common.wasi.http.replace(true) == Some(false) {
194            bail!("wasi-http is required for the serve command, and must not be disabled");
195        }
196        if self.run.common.wasm.component_model.replace(true) == Some(false) {
197            bail!("components are required for the serve command, and must not be disabled");
198        }
199
200        let runtime = tokio::runtime::Builder::new_multi_thread()
201            .enable_time()
202            .enable_io()
203            .build()?;
204
205        runtime.block_on(self.serve())?;
206
207        Ok(())
208    }
209
210    fn new_store(&self, engine: &Engine, req_id: Option<u64>) -> Result<Store<Host>> {
211        let mut builder = WasiCtxBuilder::new();
212        self.run.configure_wasip2(&mut builder)?;
213
214        if let Some(req_id) = req_id {
215            builder.env("REQUEST_ID", req_id.to_string());
216        }
217
218        let stdout_prefix: String;
219        let stderr_prefix: String;
220        match req_id {
221            Some(req_id) if !self.no_logging_prefix => {
222                stdout_prefix = format!("stdout [{req_id}] :: ");
223                stderr_prefix = format!("stderr [{req_id}] :: ");
224            }
225            _ => {
226                stdout_prefix = "".to_string();
227                stderr_prefix = "".to_string();
228            }
229        }
230        builder.stdout(LogStream::new(stdout_prefix, Output::Stdout));
231        builder.stderr(LogStream::new(stderr_prefix, Output::Stderr));
232
233        let mut host = Host {
234            table: wasmtime::component::ResourceTable::new(),
235            ctx: builder.build(),
236            http: WasiHttpCtx::new(),
237            http_outgoing_body_buffer_chunks: self.run.common.wasi.http_outgoing_body_buffer_chunks,
238            http_outgoing_body_chunk_size: self.run.common.wasi.http_outgoing_body_chunk_size,
239
240            limits: StoreLimits::default(),
241
242            #[cfg(feature = "wasi-nn")]
243            nn: None,
244            #[cfg(feature = "wasi-config")]
245            wasi_config: None,
246            #[cfg(feature = "wasi-keyvalue")]
247            wasi_keyvalue: None,
248            #[cfg(feature = "profiling")]
249            guest_profiler: None,
250            #[cfg(feature = "component-model-async")]
251            p3_http: crate::common::DefaultP3Ctx,
252        };
253
254        if self.run.common.wasi.nn == Some(true) {
255            #[cfg(feature = "wasi-nn")]
256            {
257                let graphs = self
258                    .run
259                    .common
260                    .wasi
261                    .nn_graph
262                    .iter()
263                    .map(|g| (g.format.clone(), g.dir.clone()))
264                    .collect::<Vec<_>>();
265                let (backends, registry) = wasmtime_wasi_nn::preload(&graphs)?;
266                host.nn.replace(WasiNnCtx::new(backends, registry));
267            }
268        }
269
270        if self.run.common.wasi.config == Some(true) {
271            #[cfg(feature = "wasi-config")]
272            {
273                let vars = WasiConfigVariables::from_iter(
274                    self.run
275                        .common
276                        .wasi
277                        .config_var
278                        .iter()
279                        .map(|v| (v.key.clone(), v.value.clone())),
280                );
281                host.wasi_config.replace(vars);
282            }
283        }
284
285        if self.run.common.wasi.keyvalue == Some(true) {
286            #[cfg(feature = "wasi-keyvalue")]
287            {
288                let ctx = WasiKeyValueCtxBuilder::new()
289                    .in_memory_data(
290                        self.run
291                            .common
292                            .wasi
293                            .keyvalue_in_memory_data
294                            .iter()
295                            .map(|v| (v.key.clone(), v.value.clone())),
296                    )
297                    .build();
298                host.wasi_keyvalue.replace(ctx);
299            }
300        }
301
302        let mut store = Store::new(engine, host);
303
304        store.data_mut().limits = self.run.store_limits();
305        store.limiter(|t| &mut t.limits);
306
307        // If fuel has been configured, we want to add the configured
308        // fuel amount to this store.
309        if let Some(fuel) = self.run.common.wasm.fuel {
310            store.set_fuel(fuel)?;
311        }
312
313        Ok(store)
314    }
315
316    fn add_to_linker(&self, linker: &mut Linker<Host>) -> Result<()> {
317        self.run.validate_p3_option()?;
318        let cli = self.run.validate_cli_enabled()?;
319
320        // Repurpose the `-Scli` flag of `wasmtime run` for `wasmtime serve`
321        // to serve as a signal to enable all WASI interfaces instead of just
322        // those in the `proxy` world. If `-Scli` is present then add all
323        // `command` APIs and then additionally add in the required HTTP APIs.
324        //
325        // If `-Scli` isn't passed then use the `add_to_linker_async`
326        // bindings which adds just those interfaces that the proxy interface
327        // uses.
328        if cli == Some(true) {
329            self.run.add_wasmtime_wasi_to_linker(linker)?;
330            wasmtime_wasi_http::add_only_http_to_linker_async(linker)?;
331            #[cfg(feature = "component-model-async")]
332            if self.run.common.wasi.p3.unwrap_or(crate::common::P3_DEFAULT) {
333                wasmtime_wasi_http::p3::add_to_linker(linker)?;
334            }
335        } else {
336            wasmtime_wasi_http::add_to_linker_async(linker)?;
337            #[cfg(feature = "component-model-async")]
338            if self.run.common.wasi.p3.unwrap_or(crate::common::P3_DEFAULT) {
339                wasmtime_wasi_http::p3::add_to_linker(linker)?;
340                wasmtime_wasi::p3::clocks::add_to_linker(linker)?;
341                wasmtime_wasi::p3::random::add_to_linker(linker)?;
342                wasmtime_wasi::p3::cli::add_to_linker(linker)?;
343            }
344        }
345
346        if self.run.common.wasi.nn == Some(true) {
347            #[cfg(not(feature = "wasi-nn"))]
348            {
349                bail!("support for wasi-nn was disabled at compile time");
350            }
351            #[cfg(feature = "wasi-nn")]
352            {
353                wasmtime_wasi_nn::wit::add_to_linker(linker, |h: &mut Host| {
354                    let ctx = h.nn.as_mut().unwrap();
355                    wasmtime_wasi_nn::wit::WasiNnView::new(&mut h.table, ctx)
356                })?;
357            }
358        }
359
360        if self.run.common.wasi.config == Some(true) {
361            #[cfg(not(feature = "wasi-config"))]
362            {
363                bail!("support for wasi-config was disabled at compile time");
364            }
365            #[cfg(feature = "wasi-config")]
366            {
367                wasmtime_wasi_config::add_to_linker(linker, |h| {
368                    WasiConfig::from(h.wasi_config.as_ref().unwrap())
369                })?;
370            }
371        }
372
373        if self.run.common.wasi.keyvalue == Some(true) {
374            #[cfg(not(feature = "wasi-keyvalue"))]
375            {
376                bail!("support for wasi-keyvalue was disabled at compile time");
377            }
378            #[cfg(feature = "wasi-keyvalue")]
379            {
380                wasmtime_wasi_keyvalue::add_to_linker(linker, |h: &mut Host| {
381                    WasiKeyValue::new(h.wasi_keyvalue.as_ref().unwrap(), &mut h.table)
382                })?;
383            }
384        }
385
386        if self.run.common.wasi.threads == Some(true) {
387            bail!("support for wasi-threads is not available with components");
388        }
389
390        if self.run.common.wasi.http == Some(false) {
391            bail!("support for wasi-http must be enabled for `serve` subcommand");
392        }
393
394        Ok(())
395    }
396
397    async fn serve(mut self) -> Result<()> {
398        use hyper::server::conn::http1;
399
400        let mut config = self
401            .run
402            .common
403            .config(use_pooling_allocator_by_default().unwrap_or(None))?;
404        config.wasm_component_model(true);
405        config.async_support(true);
406
407        if self.run.common.wasm.timeout.is_some() {
408            config.epoch_interruption(true);
409        }
410
411        match self.run.profile {
412            Some(Profile::Native(s)) => {
413                config.profiler(s);
414            }
415            Some(Profile::Guest { .. }) => {
416                config.epoch_interruption(true);
417            }
418            None => {}
419        }
420
421        let engine = Engine::new(&config)?;
422        let mut linker = Linker::new(&engine);
423
424        self.add_to_linker(&mut linker)?;
425
426        let component = match self.run.load_module(&engine, &self.component)? {
427            RunTarget::Core(_) => bail!("The serve command currently requires a component"),
428            RunTarget::Component(c) => c,
429        };
430
431        let instance = linker.instantiate_pre(&component)?;
432        #[cfg(feature = "component-model-async")]
433        let instance = match wasmtime_wasi_http::p3::bindings::ProxyPre::new(instance.clone()) {
434            Ok(pre) => ProxyPre::P3(pre),
435            Err(_) => ProxyPre::P2(p2::ProxyPre::new(instance)?),
436        };
437        #[cfg(not(feature = "component-model-async"))]
438        let instance = ProxyPre::P2(p2::ProxyPre::new(instance)?);
439
440        // Spawn background task(s) waiting for graceful shutdown signals. This
441        // always listens for ctrl-c but additionally can listen for a TCP
442        // connection to the specified address.
443        let shutdown = Arc::new(GracefulShutdown::default());
444        tokio::task::spawn({
445            let shutdown = shutdown.clone();
446            async move {
447                tokio::signal::ctrl_c().await.unwrap();
448                shutdown.requested.notify_one();
449            }
450        });
451        if let Some(addr) = self.shutdown_addr {
452            let listener = tokio::net::TcpListener::bind(addr).await?;
453            eprintln!(
454                "Listening for shutdown on tcp://{}/",
455                listener.local_addr()?
456            );
457            let shutdown = shutdown.clone();
458            tokio::task::spawn(async move {
459                let _ = listener.accept().await;
460                shutdown.requested.notify_one();
461            });
462        }
463
464        let socket = match &self.addr {
465            SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4()?,
466            SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6()?,
467        };
468        // Conditionally enable `SO_REUSEADDR` depending on the current
469        // platform. On Unix we want this to be able to rebind an address in
470        // the `TIME_WAIT` state which can happen then a server is killed with
471        // active TCP connections and then restarted. On Windows though if
472        // `SO_REUSEADDR` is specified then it enables multiple applications to
473        // bind the port at the same time which is not something we want. Hence
474        // this is conditionally set based on the platform (and deviates from
475        // Tokio's default from always-on).
476        socket.set_reuseaddr(!cfg!(windows))?;
477        socket.bind(self.addr)?;
478        let listener = socket.listen(100)?;
479
480        eprintln!("Serving HTTP on http://{}/", listener.local_addr()?);
481
482        log::info!("Listening on {}", self.addr);
483
484        let epoch_interval = if let Some(Profile::Guest { interval, .. }) = self.run.profile {
485            Some(interval)
486        } else if let Some(t) = self.run.common.wasm.timeout {
487            Some(EPOCH_INTERRUPT_PERIOD.min(t))
488        } else {
489            None
490        };
491        let _epoch_thread = epoch_interval.map(|t| EpochThread::spawn(t, engine.clone()));
492
493        let max_instance_reuse_count = self.max_instance_reuse_count.unwrap_or_else(|| {
494            if let ProxyPre::P3(_) = &instance {
495                DEFAULT_WASIP3_MAX_INSTANCE_REUSE_COUNT
496            } else {
497                DEFAULT_WASIP2_MAX_INSTANCE_REUSE_COUNT
498            }
499        });
500
501        let max_instance_concurrent_reuse_count = if let ProxyPre::P3(_) = &instance {
502            self.max_instance_concurrent_reuse_count
503                .unwrap_or(DEFAULT_WASIP3_MAX_INSTANCE_CONCURRENT_REUSE_COUNT)
504        } else {
505            1
506        };
507
508        let handler = ProxyHandler::new(
509            HostHandlerState {
510                cmd: self,
511                engine,
512                component,
513                max_instance_reuse_count,
514                max_instance_concurrent_reuse_count,
515            },
516            instance,
517        );
518
519        loop {
520            // Wait for a socket, but also "race" against shutdown to break out
521            // of this loop. Once the graceful shutdown signal is received then
522            // this loop exits immediately.
523            let (stream, _) = tokio::select! {
524                _ = shutdown.requested.notified() => break,
525                v = listener.accept() => v?,
526            };
527
528            // The Nagle algorithm can impose a significant latency penalty
529            // (e.g. 40ms on Linux) on guests which write small, intermittent
530            // response body chunks (e.g. SSE streams).  Here we disable that
531            // algorithm and rely on the guest to buffer if appropriate to avoid
532            // TCP fragmentation.
533            stream.set_nodelay(true)?;
534
535            let stream = TokioIo::new(stream);
536            let h = handler.clone();
537            let shutdown_guard = shutdown.clone().increment();
538            tokio::task::spawn(async move {
539                if let Err(e) = http1::Builder::new()
540                    .keep_alive(true)
541                    .serve_connection(
542                        stream,
543                        hyper::service::service_fn(move |req| {
544                            let h = h.clone();
545                            async move {
546                                use http_body_util::{BodyExt, Full};
547                                match handle_request(h, req).await {
548                                    Ok(r) => Ok::<_, Infallible>(r),
549                                    Err(e) => {
550                                        eprintln!("error: {e:?}");
551                                        let error_html = "\
552<!doctype html>
553<html>
554<head>
555    <title>500 Internal Server Error</title>
556</head>
557<body>
558    <center>
559        <h1>500 Internal Server Error</h1>
560        <hr>
561        wasmtime
562    </center>
563</body>
564</html>";
565                                        Ok(Response::builder()
566                                            .status(StatusCode::INTERNAL_SERVER_ERROR)
567                                            .header("Content-Type", "text/html; charset=UTF-8")
568                                            .body(
569                                                Full::new(bytes::Bytes::from(error_html))
570                                                    .map_err(|_| unreachable!())
571                                                    .boxed_unsync(),
572                                            )
573                                            .unwrap())
574                                    }
575                                }
576                            }
577                        }),
578                    )
579                    .await
580                {
581                    eprintln!("error: {e:?}");
582                }
583                drop(shutdown_guard);
584            });
585        }
586
587        // Upon exiting the loop we'll no longer process any more incoming
588        // connections but there may still be outstanding connections
589        // processing in child tasks. If there are wait for those to complete
590        // before shutting down completely. Also enable short-circuiting this
591        // wait with a second ctrl-c signal.
592        if shutdown.close() {
593            return Ok(());
594        }
595        eprintln!("Waiting for child tasks to exit, ctrl-c again to quit sooner...");
596        tokio::select! {
597            _ = tokio::signal::ctrl_c() => {}
598            _ = shutdown.complete.notified() => {}
599        }
600
601        Ok(())
602    }
603}
604
/// State handed to `ProxyHandler`: the parsed command plus everything needed
/// to create a new store for a request.
struct HostHandlerState {
    /// The parsed `serve` command (CLI options).
    cmd: ServeCommand,
    /// Engine used to create every store.
    engine: Engine,
    /// The loaded component; passed to the epoch/profiler setup per store.
    component: Component,
    /// Resolved instance reuse limit (CLI value or WASIp2/WASIp3 default).
    max_instance_reuse_count: usize,
    /// Resolved concurrent reuse limit (always 1 for WASIp2 components).
    max_instance_concurrent_reuse_count: usize,
}
612
613impl HandlerState for HostHandlerState {
614    type StoreData = Host;
615
616    fn new_store(&self, req_id: Option<u64>) -> Result<StoreBundle<Host>> {
617        let mut store = self.cmd.new_store(&self.engine, req_id)?;
618        let write_profile = setup_epoch_handler(&self.cmd, &mut store, self.component.clone())?;
619
620        Ok(StoreBundle {
621            store,
622            write_profile,
623        })
624    }
625
626    fn request_timeout(&self) -> Duration {
627        self.cmd.run.common.wasm.timeout.unwrap_or(Duration::MAX)
628    }
629
630    fn idle_instance_timeout(&self) -> Duration {
631        self.cmd.idle_instance_timeout
632    }
633
634    fn max_instance_reuse_count(&self) -> usize {
635        self.max_instance_reuse_count
636    }
637
638    fn max_instance_concurrent_reuse_count(&self) -> usize {
639        self.max_instance_concurrent_reuse_count
640    }
641
642    fn handle_worker_error(&self, error: anyhow::Error) {
643        eprintln!("worker error: {error}");
644    }
645}
646
/// Helper structure to manage graceful shutdown in the accept loop above.
#[derive(Default)]
struct GracefulShutdown {
    /// Async notification that shutdown has been requested.
    requested: Notify,
    /// Async notification that shutdown has completed, signaled when
    /// `notify_when_done` is `true` and `active_tasks` reaches 0.
    complete: Notify,
    /// Internal state related to what's in progress when shutdown is requested.
    state: Mutex<GracefulShutdownState>,
}
658
/// Mutable bookkeeping guarded by `GracefulShutdown::state`.
#[derive(Default)]
struct GracefulShutdownState {
    /// Number of guards returned by `GracefulShutdown::increment` that have
    /// not been dropped yet.
    active_tasks: u32,
    /// Set by `GracefulShutdown::close`; once set, the guard that brings
    /// `active_tasks` to 0 signals `GracefulShutdown::complete`.
    notify_when_done: bool,
}
664
665impl GracefulShutdown {
666    /// Increments the number of active tasks and returns a guard indicating
667    fn increment(self: Arc<Self>) -> impl Drop {
668        struct Guard(Arc<GracefulShutdown>);
669
670        let mut state = self.state.lock().unwrap();
671        assert!(!state.notify_when_done);
672        state.active_tasks += 1;
673        drop(state);
674
675        return Guard(self);
676
677        impl Drop for Guard {
678            fn drop(&mut self) {
679                let mut state = self.0.state.lock().unwrap();
680                state.active_tasks -= 1;
681                if state.notify_when_done && state.active_tasks == 0 {
682                    self.0.complete.notify_one();
683                }
684            }
685        }
686    }
687
688    /// Flags this state as done spawning tasks and returns whether there are no
689    /// more child tasks remaining.
690    fn close(&self) -> bool {
691        let mut state = self.state.lock().unwrap();
692        state.notify_when_done = true;
693        state.active_tasks == 0
694    }
695}
696
/// When executing with a timeout enabled, this is how frequently epoch
/// interrupts will be executed to check for timeouts (or the timeout itself,
/// if that is shorter). If guest profiling is enabled, the guest epoch period
/// will be used instead.
const EPOCH_INTERRUPT_PERIOD: Duration = Duration::from_millis(50);
701
/// Background OS thread that periodically increments the engine's epoch;
/// stopped and joined when this value is dropped.
struct EpochThread {
    /// Flag the thread polls; set to `true` on drop to make it exit.
    shutdown: Arc<AtomicBool>,
    /// Join handle for the thread; taken (leaving `None`) when joined.
    handle: Option<std::thread::JoinHandle<()>>,
}
706
707impl EpochThread {
708    fn spawn(interval: std::time::Duration, engine: Engine) -> Self {
709        let shutdown = Arc::new(AtomicBool::new(false));
710        let handle = {
711            let shutdown = Arc::clone(&shutdown);
712            let handle = std::thread::spawn(move || {
713                while !shutdown.load(Ordering::Relaxed) {
714                    std::thread::sleep(interval);
715                    engine.increment_epoch();
716                }
717            });
718            Some(handle)
719        };
720
721        EpochThread { shutdown, handle }
722    }
723}
724
725impl Drop for EpochThread {
726    fn drop(&mut self) {
727        if let Some(handle) = self.handle.take() {
728            self.shutdown.store(true, Ordering::Relaxed);
729            handle.join().unwrap();
730        }
731    }
732}
733
/// One-shot callback returned alongside a new store: writes out the guest
/// profile when guest profiling is enabled, and does nothing otherwise.
type WriteProfile = Box<dyn FnOnce(StoreContextMut<Host>) + Send>;
735
736fn setup_epoch_handler(
737    cmd: &ServeCommand,
738    store: &mut Store<Host>,
739    component: Component,
740) -> Result<WriteProfile> {
741    // Profiling Enabled
742    if let Some(Profile::Guest { interval, path }) = &cmd.run.profile {
743        #[cfg(feature = "profiling")]
744        return setup_guest_profiler(store, path.clone(), *interval, component.clone());
745        #[cfg(not(feature = "profiling"))]
746        {
747            let _ = (path, interval);
748            bail!("support for profiling disabled at compile time!");
749        }
750    }
751
752    // Profiling disabled but there's a global request timeout
753    if cmd.run.common.wasm.timeout.is_some() {
754        store.epoch_deadline_async_yield_and_update(1);
755    }
756
757    Ok(Box::new(|_store| {}))
758}
759
/// Installs a guest profiler on `store` and returns the callback that writes
/// the finished profile to `path`.
///
/// The profiler is stored in the host state, fed from the store's call hook
/// and from an epoch-deadline callback that re-arms itself one tick ahead.
/// The returned callback takes the profiler back out of the host state,
/// writes it to `path`, and reports success or failure on stderr.
#[cfg(feature = "profiling")]
fn setup_guest_profiler(
    store: &mut Store<Host>,
    path: String,
    interval: Duration,
    component: Component,
) -> Result<WriteProfile> {
    use wasmtime::{AsContext, GuestProfiler, StoreContext, StoreContextMut};

    let profiler =
        GuestProfiler::new_component("<main>", interval, component, std::iter::empty());
    store.data_mut().guest_profiler = Some(Arc::new(profiler));

    /// Temporarily removes the profiler from the host state so it can be
    /// borrowed mutably next to a shared view of the store, then puts it
    /// back.
    fn with_profiler(
        mut store: StoreContextMut<Host>,
        f: impl FnOnce(&mut GuestProfiler, StoreContext<Host>),
    ) {
        let mut shared = store.data_mut().guest_profiler.take().unwrap();
        let exclusive =
            Arc::get_mut(&mut shared).expect("profiling doesn't support threads yet");
        f(exclusive, store.as_context());
        store.data_mut().guest_profiler = Some(shared);
    }

    // Hostcall entry/exit, etc.
    store.call_hook(|cx, kind| {
        with_profiler(cx, |profiler, cx| profiler.call_hook(cx, kind));
        Ok(())
    });

    // Take a sample on every epoch tick and keep the deadline one tick ahead.
    store.epoch_deadline_callback(move |cx| {
        with_profiler(cx, |profiler, cx| {
            profiler.sample(cx, std::time::Duration::ZERO)
        });

        Ok(UpdateDeadline::Continue(1))
    });

    store.set_epoch_deadline(1);

    let write_profile = Box::new(move |mut store: StoreContextMut<Host>| {
        let shared = store.data_mut().guest_profiler.take().unwrap();
        let profiler = Arc::try_unwrap(shared).expect("profiling doesn't support threads yet");
        let outcome = std::fs::File::create(&path)
            .map_err(anyhow::Error::new)
            .and_then(|output| profiler.finish(std::io::BufWriter::new(output)));
        match outcome {
            Err(e) => {
                eprintln!("failed writing profile at {path}: {e:#}");
            }
            Ok(_) => {
                eprintln!();
                eprintln!("Profile written to: {path}");
                eprintln!("View this profile at https://profiler.firefox.com/.");
            }
        }
    });

    Ok(write_profile)
}
823
/// Incoming HTTP request as received from hyper, with a streaming body.
type Request = hyper::Request<hyper::body::Incoming>;
825
826async fn handle_request(
827    handler: ProxyHandler<HostHandlerState>,
828    req: Request,
829) -> Result<hyper::Response<UnsyncBoxBody<Bytes, anyhow::Error>>> {
830    use tokio::sync::oneshot;
831
832    let req_id = handler.next_req_id();
833
834    log::info!(
835        "Request {req_id} handling {} to {}",
836        req.method(),
837        req.uri()
838    );
839
840    // Here we must declare different channel types for p2 and p3 since p2's
841    // `WasiHttpView::new_response_outparam` expects a specific kind of sender
842    // that uses `p2::http::types::ErrorCode`, and we don't want to have to
843    // convert from the p3 `ErrorCode` to the p2 one, only to convert again to
844    // `anyhow::Error`.
845
846    type P2Response = Result<
847        hyper::Response<wasmtime_wasi_http::body::HyperOutgoingBody>,
848        p2::http::types::ErrorCode,
849    >;
850    type P3Response = hyper::Response<UnsyncBoxBody<Bytes, anyhow::Error>>;
851
852    enum Sender {
853        P2(oneshot::Sender<P2Response>),
854        P3(oneshot::Sender<P3Response>),
855    }
856
857    enum Receiver {
858        P2(oneshot::Receiver<P2Response>),
859        P3(oneshot::Receiver<P3Response>),
860    }
861
862    let (tx, rx) = match handler.instance_pre() {
863        ProxyPre::P2(_) => {
864            let (tx, rx) = oneshot::channel();
865            (Sender::P2(tx), Receiver::P2(rx))
866        }
867        ProxyPre::P3(_) => {
868            let (tx, rx) = oneshot::channel();
869            (Sender::P3(tx), Receiver::P3(rx))
870        }
871    };
872
873    handler.spawn(
874        if handler.state().max_instance_reuse_count() == 1 {
875            Some(req_id)
876        } else {
877            None
878        },
879        Box::new(move |store, proxy| {
880            Box::pin(
881                async move {
882                    match proxy {
883                        Proxy::P2(proxy) => {
884                            let Sender::P2(tx) = tx else { unreachable!() };
885                            let (req, out) = store.with(move |mut store| {
886                                let req = store
887                                    .data_mut()
888                                    .new_incoming_request(p2::http::types::Scheme::Http, req)?;
889                                let out = store.data_mut().new_response_outparam(tx)?;
890                                anyhow::Ok((req, out))
891                            })?;
892
893                            proxy
894                                .wasi_http_incoming_handler()
895                                .call_handle(store, req, out)
896                                .await
897                        }
898                        Proxy::P3(proxy) => {
899                            use wasmtime_wasi_http::p3::bindings::http::types::{
900                                ErrorCode, Request,
901                            };
902
903                            let Sender::P3(tx) = tx else { unreachable!() };
904                            let (req, body) = req.into_parts();
905                            let body = body.map_err(ErrorCode::from_hyper_request_error);
906                            let req = http::Request::from_parts(req, body);
907                            let (request, request_io_result) = Request::from_http(req);
908                            let (res, task) = proxy.handle(store, request).await??;
909                            let res = store
910                                .with(|mut store| res.into_http(&mut store, request_io_result))?;
911                            _ = tx.send(res.map(|body| body.map_err(|e| e.into()).boxed_unsync()));
912
913                            // Wait for the task to finish.
914                            task.block(store).await;
915                            Ok(())
916                        }
917                    }
918                }
919                .map(move |result| {
920                    if let Err(error) = result {
921                        eprintln!("[{req_id}] :: {error:?}");
922                    }
923                }),
924            )
925        }),
926    );
927
928    Ok(match rx {
929        Receiver::P2(rx) => rx
930            .await
931            .context("guest never invoked `response-outparam::set` method")?
932            .map_err(|e| anyhow::Error::from(e))?
933            .map(|body| body.map_err(|e| e.into()).boxed_unsync()),
934        Receiver::P3(rx) => rx.await?,
935    })
936}
937
/// Selects which of the host process's standard streams output is sent to.
#[derive(Clone)]
enum Output {
    /// The host's standard output.
    Stdout,
    /// The host's standard error.
    Stderr,
}

impl Output {
    /// Writes the whole of `buf` to the selected host stream, propagating any
    /// I/O error reported by the underlying handle.
    fn write_all(&self, buf: &[u8]) -> io::Result<()> {
        use std::io::Write as _;

        match self {
            Self::Stdout => {
                let mut handle = std::io::stdout();
                handle.write_all(buf)
            }
            Self::Stderr => {
                let mut handle = std::io::stderr();
                handle.write_all(buf)
            }
        }
    }
}
954
955#[derive(Clone)]
956struct LogStream {
957    output: Output,
958    state: Arc<LogStreamState>,
959}
960
961struct LogStreamState {
962    prefix: String,
963    needs_prefix_on_next_write: AtomicBool,
964}
965
966impl LogStream {
967    fn new(prefix: String, output: Output) -> LogStream {
968        LogStream {
969            output,
970            state: Arc::new(LogStreamState {
971                prefix,
972                needs_prefix_on_next_write: AtomicBool::new(true),
973            }),
974        }
975    }
976
977    fn write_all(&mut self, mut bytes: &[u8]) -> io::Result<()> {
978        while !bytes.is_empty() {
979            if self
980                .state
981                .needs_prefix_on_next_write
982                .load(Ordering::Relaxed)
983            {
984                self.output.write_all(self.state.prefix.as_bytes())?;
985                self.state
986                    .needs_prefix_on_next_write
987                    .store(false, Ordering::Relaxed);
988            }
989            match bytes.iter().position(|b| *b == b'\n') {
990                Some(i) => {
991                    let (a, b) = bytes.split_at(i + 1);
992                    bytes = b;
993                    self.output.write_all(a)?;
994                    self.state
995                        .needs_prefix_on_next_write
996                        .store(true, Ordering::Relaxed);
997                }
998                None => {
999                    self.output.write_all(bytes)?;
1000                    break;
1001                }
1002            }
1003        }
1004
1005        Ok(())
1006    }
1007}
1008
1009impl wasmtime_wasi::cli::StdoutStream for LogStream {
1010    fn p2_stream(&self) -> Box<dyn wasmtime_wasi::p2::OutputStream> {
1011        Box::new(self.clone())
1012    }
1013    fn async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync> {
1014        Box::new(self.clone())
1015    }
1016}
1017
1018impl wasmtime_wasi::cli::IsTerminal for LogStream {
1019    fn is_terminal(&self) -> bool {
1020        match &self.output {
1021            Output::Stdout => std::io::stdout().is_terminal(),
1022            Output::Stderr => std::io::stderr().is_terminal(),
1023        }
1024    }
1025}
1026
1027impl wasmtime_wasi::p2::OutputStream for LogStream {
1028    fn write(&mut self, bytes: bytes::Bytes) -> StreamResult<()> {
1029        self.write_all(&bytes)
1030            .map_err(|e| StreamError::LastOperationFailed(e.into()))?;
1031        Ok(())
1032    }
1033
1034    fn flush(&mut self) -> StreamResult<()> {
1035        Ok(())
1036    }
1037
1038    fn check_write(&mut self) -> StreamResult<usize> {
1039        Ok(1024 * 1024)
1040    }
1041}
1042
1043#[async_trait::async_trait]
1044impl wasmtime_wasi::p2::Pollable for LogStream {
1045    async fn ready(&mut self) {}
1046}
1047
1048impl AsyncWrite for LogStream {
1049    fn poll_write(
1050        mut self: Pin<&mut Self>,
1051        _cx: &mut Context<'_>,
1052        buf: &[u8],
1053    ) -> Poll<io::Result<usize>> {
1054        Poll::Ready(self.write_all(buf).map(|_| buf.len()))
1055    }
1056    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
1057        Poll::Ready(Ok(()))
1058    }
1059    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
1060        Poll::Ready(Ok(()))
1061    }
1062}
1063
1064/// The pooling allocator is tailor made for the `wasmtime serve` use case, so
1065/// try to use it when we can. The main cost of the pooling allocator, however,
1066/// is the virtual memory required to run it. Not all systems support the same
1067/// amount of virtual memory, for example some aarch64 and riscv64 configuration
1068/// only support 39 bits of virtual address space.
1069///
1070/// The pooling allocator, by default, will request 1000 linear memories each
1071/// sized at 6G per linear memory. This is 6T of virtual memory which ends up
1072/// being about 42 bits of the address space. This exceeds the 39 bit limit of
1073/// some systems, so there the pooling allocator will fail by default.
1074///
1075/// This function attempts to dynamically determine the hint for the pooling
1076/// allocator. This returns `Some(true)` if the pooling allocator should be used
1077/// by default, or `None` or an error otherwise.
1078///
1079/// The method for testing this is to allocate a 0-sized 64-bit linear memory
1080/// with a maximum size that's N bits large where we force all memories to be
1081/// static. This should attempt to acquire N bits of the virtual address space.
1082/// If successful that should mean that the pooling allocator is OK to use, but
1083/// if it fails then the pooling allocator is not used and the normal mmap-based
1084/// implementation is used instead.
1085fn use_pooling_allocator_by_default() -> Result<Option<bool>> {
1086    use wasmtime::{Config, Memory, MemoryType};
1087    const BITS_TO_TEST: u32 = 42;
1088    let mut config = Config::new();
1089    config.wasm_memory64(true);
1090    config.memory_reservation(1 << BITS_TO_TEST);
1091    let engine = Engine::new(&config)?;
1092    let mut store = Store::new(&engine, ());
1093    // NB: the maximum size is in wasm pages to take out the 16-bits of wasm
1094    // page size here from the maximum size.
1095    let ty = MemoryType::new64(0, Some(1 << (BITS_TO_TEST - 16)));
1096    if Memory::new(&mut store, ty).is_ok() {
1097        Ok(Some(true))
1098    } else {
1099        Ok(None)
1100    }
1101}