Skip to main content

wit_dylib/
async_.rs

1use anyhow::{Result, bail};
2use std::collections::HashSet;
3use std::fmt;
4use wit_parser::{Function, FunctionKind, Resolve, WorldKey};
5
/// Structure used to parse the command line argument `--async` consistently
/// across guest generators.
// NOTE(review): with the `clap` feature enabled, clap's derive presumably
// uses the `///` text on this struct and on `async_` as CLI help/about text,
// so rewording it changes user-visible output — confirm before editing.
#[cfg_attr(feature = "clap", derive(clap::Parser))]
#[derive(Clone, Default, Debug)]
pub struct AsyncFilterSet {
    /// Determines which functions to lift or lower `async`, if any.
    ///
    /// This option can be passed multiple times and additionally accepts
    /// comma-separated values for each option passed. Each individual argument
    /// passed here can be one of:
    ///
    /// - `all` - all imports and exports will be async
    /// - `-all` - force all imports and exports to be sync
    /// - `foo:bar/baz#method` - force this method to be async
    /// - `import:foo:bar/baz#method` - force this method to be async, but only
    ///   as an import
    /// - `-export:foo:bar/baz#method` - force this export to be sync
    ///
    /// If a method is not listed in this option then the WIT's default bindings
    /// mode will be used. If the WIT function is defined as `async` then async
    /// bindings will be generated, otherwise sync bindings will be generated.
    ///
    /// Options are processed in the order they are passed here, so if a method
    /// matches two directives passed the least-specific one should be last.
    // Parsed directives in the order they were supplied. `is_async` walks
    // them front to back and returns on the first one that matches, so the
    // order of this Vec is significant.
    #[cfg_attr(
        feature = "clap",
        arg(
            long = "async",
            value_parser = parse_async,
            value_delimiter =',',
            value_name = "FILTER",
        ),
    )]
    async_: Vec<Async>,

    // Indices into `async_` of directives that have matched at least one
    // function in `is_async`. `ensure_all_used` consults it to report any
    // non-`all` directive that never matched. Not exposed as a CLI argument
    // (`arg(skip)` when the `clap` feature is on).
    #[cfg_attr(feature = "clap", arg(skip))]
    used_options: HashSet<usize>,
}
44
/// `clap` value parser for `--async`.
///
/// Every string is accepted as some directive, so this never returns `Err`.
#[cfg(feature = "clap")]
fn parse_async(s: &str) -> Result<Async, String> {
    let directive = Async::parse(s);
    Ok(directive)
}
49
50impl AsyncFilterSet {
51    /// Returns a set where all functions should be async or not depending on
52    /// `async_` provided.
53    pub fn all(async_: bool) -> AsyncFilterSet {
54        AsyncFilterSet {
55            async_: vec![Async {
56                enabled: async_,
57                filter: AsyncFilter::All,
58            }],
59            used_options: HashSet::new(),
60        }
61    }
62
63    /// Returns whether the `func` provided is to be bound `async` or not.
64    pub fn is_async(
65        &mut self,
66        resolve: &Resolve,
67        interface: Option<&WorldKey>,
68        func: &Function,
69        is_import: bool,
70    ) -> bool {
71        let name_to_test = match interface {
72            Some(key) => format!("{}#{}", resolve.name_world_key(key), func.name),
73            None => func.name.clone(),
74        };
75        for (i, opt) in self.async_.iter().enumerate() {
76            let name = match &opt.filter {
77                AsyncFilter::All => {
78                    self.used_options.insert(i);
79                    return opt.enabled;
80                }
81                AsyncFilter::Function(s) => s,
82                AsyncFilter::Import(s) => {
83                    if !is_import {
84                        continue;
85                    }
86                    s
87                }
88                AsyncFilter::Export(s) => {
89                    if is_import {
90                        continue;
91                    }
92                    s
93                }
94            };
95            if *name == name_to_test {
96                self.used_options.insert(i);
97                return opt.enabled;
98            }
99        }
100
101        match &func.kind {
102            FunctionKind::Freestanding
103            | FunctionKind::Method(_)
104            | FunctionKind::Static(_)
105            | FunctionKind::Constructor(_) => false,
106            FunctionKind::AsyncFreestanding
107            | FunctionKind::AsyncMethod(_)
108            | FunctionKind::AsyncStatic(_) => true,
109        }
110    }
111
112    /// Intended to be used in the header comment of generated code to help
113    /// indicate what options were specified.
114    pub fn debug_opts(&self) -> impl Iterator<Item = String> + '_ {
115        self.async_.iter().map(|opt| opt.to_string())
116    }
117
118    /// Tests whether all `--async` options were used throughout bindings
119    /// generation, returning an error if any were unused.
120    pub fn ensure_all_used(&self) -> Result<()> {
121        for (i, opt) in self.async_.iter().enumerate() {
122            if self.used_options.contains(&i) {
123                continue;
124            }
125            if !matches!(opt.filter, AsyncFilter::All) {
126                bail!("unused async option: {opt}");
127            }
128        }
129        Ok(())
130    }
131
132    /// Returns whether any option explicitly requests that async is enabled.
133    pub fn any_enabled(&self) -> bool {
134        self.async_.iter().any(|o| o.enabled)
135    }
136
137    /// Pushes a new option into this set.
138    pub fn push(&mut self, directive: &str) {
139        self.async_.push(Async::parse(directive));
140    }
141}
142
143#[derive(Debug, Clone)]
144struct Async {
145    enabled: bool,
146    filter: AsyncFilter,
147}
148
149impl Async {
150    fn parse(s: &str) -> Async {
151        let (s, enabled) = match s.strip_prefix('-') {
152            Some(s) => (s, false),
153            None => (s, true),
154        };
155        let filter = match s {
156            "all" => AsyncFilter::All,
157            other => match other.strip_prefix("import:") {
158                Some(s) => AsyncFilter::Import(s.to_string()),
159                None => match other.strip_prefix("export:") {
160                    Some(s) => AsyncFilter::Export(s.to_string()),
161                    None => AsyncFilter::Function(s.to_string()),
162                },
163            },
164        };
165        Async { enabled, filter }
166    }
167}
168
169impl fmt::Display for Async {
170    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171        if !self.enabled {
172            write!(f, "-")?;
173        }
174        self.filter.fmt(f)
175    }
176}
177
/// The set of functions a `--async` directive applies to.
#[derive(Clone, Debug)]
enum AsyncFilter {
    /// Every function, imported or exported (`all`).
    All,
    /// The named function, whether imported or exported.
    Function(String),
    /// The named function, only when it is imported (`import:<name>`).
    Import(String),
    /// The named function, only when it is exported (`export:<name>`).
    Export(String),
}
185
186impl fmt::Display for AsyncFilter {
187    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188        match self {
189            AsyncFilter::All => write!(f, "all"),
190            AsyncFilter::Function(s) => write!(f, "{s}"),
191            AsyncFilter::Import(s) => write!(f, "import:{s}"),
192            AsyncFilter::Export(s) => write!(f, "export:{s}"),
193        }
194    }
195}