// wit_bindgen_core/async_.rs

1use anyhow::{bail, Result};
2use std::collections::HashSet;
3use std::fmt;
4use wit_parser::{Function, FunctionKind, Resolve, WorldKey};
5
6/// Structure used to parse the command line argument `--async` consistently
7/// across guest generators.
8#[cfg_attr(feature = "clap", derive(clap::Parser))]
9#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
10#[derive(Clone, Default, Debug)]
11pub struct AsyncFilterSet {
12    /// Determines which functions to lift or lower `async`, if any.
13    ///
14    /// This option can be passed multiple times and additionally accepts
15    /// comma-separated values for each option passed. Each individual argument
16    /// passed here can be one of:
17    ///
18    /// - `all` - all imports and exports will be async
19    /// - `-all` - force all imports and exports to be sync
20    /// - `foo:bar/baz#method` - force this method to be async
21    /// - `import:foo:bar/baz#method` - force this method to be async, but only
22    ///   as an import
23    /// - `-export:foo:bar/baz#method` - force this export to be sync
24    ///
25    /// If a method is not listed in this option then the WIT's default bindings
26    /// mode will be used. If the WIT function is defined as `async` then async
27    /// bindings will be generated, otherwise sync bindings will be generated.
28    ///
29    /// Options are processed in the order they are passed here, so if a method
30    /// matches two directives passed the least-specific one should be last.
31    #[cfg_attr(
32        feature = "clap",
33        arg(
34            long = "async",
35            value_parser = parse_async,
36            value_delimiter =',',
37            value_name = "FILTER",
38        ),
39    )]
40    #[cfg_attr(feature = "serde", serde(rename = "async"))]
41    async_: Vec<Async>,
42
43    #[cfg_attr(feature = "clap", arg(skip))]
44    #[cfg_attr(feature = "serde", serde(skip))]
45    used_options: HashSet<usize>,
46}
47
/// clap value parser for a single `--async` value.
///
/// Parsing a directive cannot fail, so this always returns `Ok`; the `String`
/// error type only exists to satisfy clap's value-parser signature.
#[cfg(feature = "clap")]
fn parse_async(directive: &str) -> Result<Async, String> {
    let parsed = Async::parse(directive);
    Ok(parsed)
}
52
53impl AsyncFilterSet {
54    /// Returns a set where all functions should be async or not depending on
55    /// `async_` provided.
56    pub fn all(async_: bool) -> AsyncFilterSet {
57        AsyncFilterSet {
58            async_: vec![Async {
59                enabled: async_,
60                filter: AsyncFilter::All,
61            }],
62            used_options: HashSet::new(),
63        }
64    }
65
66    /// Returns whether the `func` provided is to be bound `async` or not.
67    pub fn is_async(
68        &mut self,
69        resolve: &Resolve,
70        interface: Option<&WorldKey>,
71        func: &Function,
72        is_import: bool,
73    ) -> bool {
74        let name_to_test = match interface {
75            Some(key) => format!("{}#{}", resolve.name_world_key(key), func.name),
76            None => func.name.clone(),
77        };
78        for (i, opt) in self.async_.iter().enumerate() {
79            let name = match &opt.filter {
80                AsyncFilter::All => {
81                    self.used_options.insert(i);
82                    return opt.enabled;
83                }
84                AsyncFilter::Function(s) => s,
85                AsyncFilter::Import(s) => {
86                    if !is_import {
87                        continue;
88                    }
89                    s
90                }
91                AsyncFilter::Export(s) => {
92                    if is_import {
93                        continue;
94                    }
95                    s
96                }
97            };
98            if *name == name_to_test {
99                self.used_options.insert(i);
100                return opt.enabled;
101            }
102        }
103
104        match &func.kind {
105            FunctionKind::Freestanding
106            | FunctionKind::Method(_)
107            | FunctionKind::Static(_)
108            | FunctionKind::Constructor(_) => false,
109            FunctionKind::AsyncFreestanding
110            | FunctionKind::AsyncMethod(_)
111            | FunctionKind::AsyncStatic(_) => true,
112        }
113    }
114
115    /// Intended to be used in the header comment of generated code to help
116    /// indicate what options were specified.
117    pub fn debug_opts(&self) -> impl Iterator<Item = String> + '_ {
118        self.async_.iter().map(|opt| opt.to_string())
119    }
120
121    /// Tests whether all `--async` options were used throughout bindings
122    /// generation, returning an error if any were unused.
123    pub fn ensure_all_used(&self) -> Result<()> {
124        for (i, opt) in self.async_.iter().enumerate() {
125            if self.used_options.contains(&i) {
126                continue;
127            }
128            if !matches!(opt.filter, AsyncFilter::All) {
129                bail!("unused async option: {opt}");
130            }
131        }
132        Ok(())
133    }
134
135    /// Returns whether any option explicitly requests that async is enabled.
136    pub fn any_enabled(&self) -> bool {
137        self.async_.iter().any(|o| o.enabled)
138    }
139
140    /// Pushes a new option into this set.
141    pub fn push(&mut self, directive: &str) {
142        self.async_.push(Async::parse(directive));
143    }
144}
145
146#[derive(Debug, Clone)]
147#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
148struct Async {
149    enabled: bool,
150    filter: AsyncFilter,
151}
152
153impl Async {
154    fn parse(s: &str) -> Async {
155        let (s, enabled) = match s.strip_prefix('-') {
156            Some(s) => (s, false),
157            None => (s, true),
158        };
159        let filter = match s {
160            "all" => AsyncFilter::All,
161            other => match other.strip_prefix("import:") {
162                Some(s) => AsyncFilter::Import(s.to_string()),
163                None => match other.strip_prefix("export:") {
164                    Some(s) => AsyncFilter::Export(s.to_string()),
165                    None => AsyncFilter::Function(s.to_string()),
166                },
167            },
168        };
169        Async { enabled, filter }
170    }
171}
172
173impl fmt::Display for Async {
174    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175        if !self.enabled {
176            write!(f, "-")?;
177        }
178        self.filter.fmt(f)
179    }
180}
181
/// Selects the functions an `--async` directive applies to.
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
#[derive(Debug, Clone)]
enum AsyncFilter {
    /// Every function, imported or exported (`all`).
    All,
    /// A function matched by name in either direction (`foo:bar/baz#method`).
    Function(String),
    /// A function matched by name, only when imported (`import:...`).
    Import(String),
    /// A function matched by name, only when exported (`export:...`).
    Export(String),
}
190
191impl fmt::Display for AsyncFilter {
192    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193        match self {
194            AsyncFilter::All => write!(f, "all"),
195            AsyncFilter::Function(s) => write!(f, "{s}"),
196            AsyncFilter::Import(s) => write!(f, "import:{s}"),
197            AsyncFilter::Export(s) => write!(f, "export:{s}"),
198        }
199    }
200}