// wiggle_generate/config.rs

use {
    proc_macro2::{Span, TokenStream},
    std::{collections::HashMap, path::PathBuf},
    syn::{
        braced, bracketed,
        parse::{Parse, ParseStream},
        punctuated::Punctuated,
        Error, Ident, LitStr, Result, Token,
    },
};

12#[derive(Debug, Clone)]
13pub struct Config {
14    pub witx: WitxConf,
15    pub errors: ErrorConf,
16    pub async_: AsyncConf,
17    pub wasmtime: bool,
18    pub tracing: TracingConf,
19    pub mutable: bool,
20}
21
22mod kw {
23    syn::custom_keyword!(witx);
24    syn::custom_keyword!(witx_literal);
25    syn::custom_keyword!(block_on);
26    syn::custom_keyword!(errors);
27    syn::custom_keyword!(target);
28    syn::custom_keyword!(wasmtime);
29    syn::custom_keyword!(mutable);
30    syn::custom_keyword!(tracing);
31    syn::custom_keyword!(disable_for);
32    syn::custom_keyword!(trappable);
33}
34
35#[derive(Debug, Clone)]
36pub enum ConfigField {
37    Witx(WitxConf),
38    Error(ErrorConf),
39    Async(AsyncConf),
40    Wasmtime(bool),
41    Tracing(TracingConf),
42    Mutable(bool),
43}
44
45impl Parse for ConfigField {
46    fn parse(input: ParseStream) -> Result<Self> {
47        let lookahead = input.lookahead1();
48        if lookahead.peek(kw::witx) {
49            input.parse::<kw::witx>()?;
50            input.parse::<Token![:]>()?;
51            Ok(ConfigField::Witx(WitxConf::Paths(input.parse()?)))
52        } else if lookahead.peek(kw::witx_literal) {
53            input.parse::<kw::witx_literal>()?;
54            input.parse::<Token![:]>()?;
55            Ok(ConfigField::Witx(WitxConf::Literal(input.parse()?)))
56        } else if lookahead.peek(kw::errors) {
57            input.parse::<kw::errors>()?;
58            input.parse::<Token![:]>()?;
59            Ok(ConfigField::Error(input.parse()?))
60        } else if lookahead.peek(Token![async]) {
61            input.parse::<Token![async]>()?;
62            input.parse::<Token![:]>()?;
63            Ok(ConfigField::Async(AsyncConf {
64                block_with: None,
65                functions: input.parse()?,
66            }))
67        } else if lookahead.peek(kw::block_on) {
68            input.parse::<kw::block_on>()?;
69            let block_with = if input.peek(syn::token::Bracket) {
70                let content;
71                let _ = bracketed!(content in input);
72                content.parse()?
73            } else {
74                quote::quote!(wiggle::run_in_dummy_executor)
75            };
76            input.parse::<Token![:]>()?;
77            Ok(ConfigField::Async(AsyncConf {
78                block_with: Some(block_with),
79                functions: input.parse()?,
80            }))
81        } else if lookahead.peek(kw::wasmtime) {
82            input.parse::<kw::wasmtime>()?;
83            input.parse::<Token![:]>()?;
84            Ok(ConfigField::Wasmtime(input.parse::<syn::LitBool>()?.value))
85        } else if lookahead.peek(kw::tracing) {
86            input.parse::<kw::tracing>()?;
87            input.parse::<Token![:]>()?;
88            Ok(ConfigField::Tracing(input.parse()?))
89        } else if lookahead.peek(kw::mutable) {
90            input.parse::<kw::mutable>()?;
91            input.parse::<Token![:]>()?;
92            Ok(ConfigField::Mutable(input.parse::<syn::LitBool>()?.value))
93        } else {
94            Err(lookahead.error())
95        }
96    }
97}
98
99impl Config {
100    pub fn build(fields: impl Iterator<Item = ConfigField>, err_loc: Span) -> Result<Self> {
101        let mut witx = None;
102        let mut errors = None;
103        let mut async_ = None;
104        let mut wasmtime = None;
105        let mut tracing = None;
106        let mut mutable = None;
107        for f in fields {
108            match f {
109                ConfigField::Witx(c) => {
110                    if witx.is_some() {
111                        return Err(Error::new(err_loc, "duplicate `witx` field"));
112                    }
113                    witx = Some(c);
114                }
115                ConfigField::Error(c) => {
116                    if errors.is_some() {
117                        return Err(Error::new(err_loc, "duplicate `errors` field"));
118                    }
119                    errors = Some(c);
120                }
121                ConfigField::Async(c) => {
122                    if async_.is_some() {
123                        return Err(Error::new(err_loc, "duplicate `async` field"));
124                    }
125                    async_ = Some(c);
126                }
127                ConfigField::Wasmtime(c) => {
128                    if wasmtime.is_some() {
129                        return Err(Error::new(err_loc, "duplicate `wasmtime` field"));
130                    }
131                    wasmtime = Some(c);
132                }
133                ConfigField::Tracing(c) => {
134                    if tracing.is_some() {
135                        return Err(Error::new(err_loc, "duplicate `tracing` field"));
136                    }
137                    tracing = Some(c);
138                }
139                ConfigField::Mutable(c) => {
140                    if mutable.is_some() {
141                        return Err(Error::new(err_loc, "duplicate `mutable` field"));
142                    }
143                    mutable = Some(c);
144                }
145            }
146        }
147        Ok(Config {
148            witx: witx
149                .take()
150                .ok_or_else(|| Error::new(err_loc, "`witx` field required"))?,
151            errors: errors.take().unwrap_or_default(),
152            async_: async_.take().unwrap_or_default(),
153            wasmtime: wasmtime.unwrap_or(true),
154            tracing: tracing.unwrap_or_default(),
155            mutable: mutable.unwrap_or(true),
156        })
157    }
158
159    /// Load the `witx` document for the configuration.
160    ///
161    /// # Panics
162    ///
163    /// This method will panic if the paths given in the `witx` field were not valid documents.
164    pub fn load_document(&self) -> witx::Document {
165        self.witx.load_document()
166    }
167}
168
169impl Parse for Config {
170    fn parse(input: ParseStream) -> Result<Self> {
171        let contents;
172        let _lbrace = braced!(contents in input);
173        let fields: Punctuated<ConfigField, Token![,]> =
174            contents.parse_terminated(ConfigField::parse, Token![,])?;
175        Ok(Config::build(fields.into_iter(), input.span())?)
176    }
177}
178
179/// The witx document(s) that will be loaded from a [`Config`](struct.Config.html).
180///
181/// A witx interface definition can be provided either as a collection of relative paths to
182/// documents, or as a single inlined string literal. Note that `(use ...)` directives are not
183/// permitted when providing a string literal.
184#[derive(Debug, Clone)]
185pub enum WitxConf {
186    /// A collection of paths pointing to witx files.
187    Paths(Paths),
188    /// A single witx document, provided as a string literal.
189    Literal(Literal),
190}
191
192impl WitxConf {
193    /// Load the `witx` document.
194    ///
195    /// # Panics
196    ///
197    /// This method will panic if the paths given in the `witx` field were not valid documents, or
198    /// if any of the given documents were not syntactically valid.
199    pub fn load_document(&self) -> witx::Document {
200        match self {
201            Self::Paths(paths) => witx::load(paths.as_ref()).expect("loading witx"),
202            Self::Literal(doc) => witx::parse(doc.as_ref()).expect("parsing witx"),
203        }
204    }
205}
206
/// An ordered collection of paths to witx documents.
#[derive(Clone, Debug, Default)]
pub struct Paths(Vec<PathBuf>);

impl Paths {
    /// Create a collection containing no paths.
    pub fn new() -> Self {
        Paths(Vec::new())
    }
}

impl AsRef<[PathBuf]> for Paths {
    /// Borrow the paths as a slice.
    fn as_ref(&self) -> &[PathBuf] {
        &self.0
    }
}

impl AsMut<[PathBuf]> for Paths {
    /// Mutably borrow the paths as a slice.
    fn as_mut(&mut self) -> &mut [PathBuf] {
        &mut self.0
    }
}

impl FromIterator<PathBuf> for Paths {
    /// Collect paths, keeping iteration order.
    fn from_iter<I: IntoIterator<Item = PathBuf>>(iter: I) -> Self {
        Paths(Vec::from_iter(iter))
    }
}

245impl Parse for Paths {
246    fn parse(input: ParseStream) -> Result<Self> {
247        let content;
248        let _ = bracketed!(content in input);
249        let path_lits: Punctuated<LitStr, Token![,]> =
250            content.parse_terminated(Parse::parse, Token![,])?;
251
252        let expanded_paths = path_lits
253            .iter()
254            .map(|lit| {
255                PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()).join(lit.value())
256            })
257            .collect::<Vec<PathBuf>>();
258
259        Ok(Paths(expanded_paths))
260    }
261}
262
/// An inline witx document supplied as a string literal.
#[derive(Clone, Debug)]
pub struct Literal(String);

impl AsRef<str> for Literal {
    /// Borrow the document text.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}

273impl Parse for Literal {
274    fn parse(input: ParseStream) -> Result<Self> {
275        Ok(Self(input.parse::<syn::LitStr>()?.value()))
276    }
277}
278
279#[derive(Clone, Default, Debug)]
280/// Map from abi error type to rich error type
281pub struct ErrorConf(HashMap<Ident, ErrorConfField>);
282
283impl ErrorConf {
284    pub fn iter(&self) -> impl Iterator<Item = (&Ident, &ErrorConfField)> {
285        self.0.iter()
286    }
287}
288
289impl Parse for ErrorConf {
290    fn parse(input: ParseStream) -> Result<Self> {
291        let content;
292        let _ = braced!(content in input);
293        let items: Punctuated<ErrorConfField, Token![,]> =
294            content.parse_terminated(Parse::parse, Token![,])?;
295        let mut m = HashMap::new();
296        for i in items {
297            match m.insert(i.abi_error().clone(), i.clone()) {
298                None => {}
299                Some(prev_def) => {
300                    return Err(Error::new(
301                        *i.err_loc(),
302                        format!(
303                        "duplicate definition of rich error type for {:?}: previously defined at {:?}",
304                        i.abi_error(), prev_def.err_loc(),
305                    ),
306                    ))
307                }
308            }
309        }
310        Ok(ErrorConf(m))
311    }
312}
313
314#[derive(Debug, Clone)]
315pub enum ErrorConfField {
316    Trappable(TrappableErrorConfField),
317    User(UserErrorConfField),
318}
319impl ErrorConfField {
320    pub fn abi_error(&self) -> &Ident {
321        match self {
322            Self::Trappable(t) => &t.abi_error,
323            Self::User(u) => &u.abi_error,
324        }
325    }
326    pub fn err_loc(&self) -> &Span {
327        match self {
328            Self::Trappable(t) => &t.err_loc,
329            Self::User(u) => &u.err_loc,
330        }
331    }
332}
333
334impl Parse for ErrorConfField {
335    fn parse(input: ParseStream) -> Result<Self> {
336        let err_loc = input.span();
337        let abi_error = input.parse::<Ident>()?;
338        let _arrow: Token![=>] = input.parse()?;
339
340        let lookahead = input.lookahead1();
341        if lookahead.peek(kw::trappable) {
342            let _ = input.parse::<kw::trappable>()?;
343            let rich_error = input.parse()?;
344            Ok(ErrorConfField::Trappable(TrappableErrorConfField {
345                abi_error,
346                rich_error,
347                err_loc,
348            }))
349        } else {
350            let rich_error = input.parse::<syn::Path>()?;
351            Ok(ErrorConfField::User(UserErrorConfField {
352                abi_error,
353                rich_error,
354                err_loc,
355            }))
356        }
357    }
358}
359
360#[derive(Clone, Debug)]
361pub struct TrappableErrorConfField {
362    pub abi_error: Ident,
363    pub rich_error: Ident,
364    pub err_loc: Span,
365}
366
367#[derive(Clone)]
368pub struct UserErrorConfField {
369    pub abi_error: Ident,
370    pub rich_error: syn::Path,
371    pub err_loc: Span,
372}
373
374impl std::fmt::Debug for UserErrorConfField {
375    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376        f.debug_struct("ErrorConfField")
377            .field("abi_error", &self.abi_error)
378            .field("rich_error", &"(...)")
379            .field("err_loc", &self.err_loc)
380            .finish()
381    }
382}
383
384#[derive(Clone, Default, Debug)]
385/// Modules and funcs that have async signatures
386pub struct AsyncConf {
387    block_with: Option<TokenStream>,
388    functions: AsyncFunctions,
389}
390
391#[derive(Clone, Debug)]
392pub enum Asyncness {
393    /// Wiggle function is synchronous, wasmtime Func is synchronous
394    Sync,
395    /// Wiggle function is asynchronous, but wasmtime Func is synchronous
396    Blocking { block_with: TokenStream },
397    /// Wiggle function and wasmtime Func are asynchronous.
398    Async,
399}
400
401impl Asyncness {
402    pub fn is_async(&self) -> bool {
403        match self {
404            Self::Async => true,
405            _ => false,
406        }
407    }
408    pub fn blocking(&self) -> Option<&TokenStream> {
409        match self {
410            Self::Blocking { block_with } => Some(block_with),
411            _ => None,
412        }
413    }
414    pub fn is_sync(&self) -> bool {
415        match self {
416            Self::Sync => true,
417            _ => false,
418        }
419    }
420}
421
/// The functions selected by an `async:` / `block_on:` configuration.
#[derive(Clone, Debug)]
pub enum AsyncFunctions {
    /// Only the listed functions, as module name -> function names.
    Some(HashMap<String, Vec<String>>),
    /// Every function in every module (written `*`).
    All,
}

impl Default for AsyncFunctions {
    /// Selects no functions at all.
    fn default() -> Self {
        Self::Some(HashMap::new())
    }
}

433impl AsyncConf {
434    pub fn get(&self, module: &str, function: &str) -> Asyncness {
435        let a = match &self.block_with {
436            Some(block_with) => Asyncness::Blocking {
437                block_with: block_with.clone(),
438            },
439            None => Asyncness::Async,
440        };
441        match &self.functions {
442            AsyncFunctions::Some(fs) => {
443                if fs
444                    .get(module)
445                    .and_then(|fs| fs.iter().find(|f| *f == function))
446                    .is_some()
447                {
448                    a
449                } else {
450                    Asyncness::Sync
451                }
452            }
453            AsyncFunctions::All => a,
454        }
455    }
456
457    pub fn contains_async(&self, module: &witx::Module) -> bool {
458        for f in module.funcs() {
459            if self.get(module.name.as_str(), f.name.as_str()).is_async() {
460                return true;
461            }
462        }
463        false
464    }
465}
466
467impl Parse for AsyncFunctions {
468    fn parse(input: ParseStream) -> Result<Self> {
469        let content;
470        let lookahead = input.lookahead1();
471        if lookahead.peek(syn::token::Brace) {
472            let _ = braced!(content in input);
473            let items: Punctuated<FunctionField, Token![,]> =
474                content.parse_terminated(Parse::parse, Token![,])?;
475            let mut functions: HashMap<String, Vec<String>> = HashMap::new();
476            use std::collections::hash_map::Entry;
477            for i in items {
478                let function_names = i
479                    .function_names
480                    .iter()
481                    .map(|i| i.to_string())
482                    .collect::<Vec<String>>();
483                match functions.entry(i.module_name.to_string()) {
484                    Entry::Occupied(o) => o.into_mut().extend(function_names),
485                    Entry::Vacant(v) => {
486                        v.insert(function_names);
487                    }
488                }
489            }
490            Ok(AsyncFunctions::Some(functions))
491        } else if lookahead.peek(Token![*]) {
492            let _: Token![*] = input.parse().unwrap();
493            Ok(AsyncFunctions::All)
494        } else {
495            Err(lookahead.error())
496        }
497    }
498}
499
500#[derive(Clone)]
501pub struct FunctionField {
502    pub module_name: Ident,
503    pub function_names: Vec<Ident>,
504    pub err_loc: Span,
505}
506
507impl Parse for FunctionField {
508    fn parse(input: ParseStream) -> Result<Self> {
509        let err_loc = input.span();
510        let module_name = input.parse::<Ident>()?;
511        let _doublecolon: Token![::] = input.parse()?;
512        let lookahead = input.lookahead1();
513        if lookahead.peek(syn::token::Brace) {
514            let content;
515            let _ = braced!(content in input);
516            let function_names: Punctuated<Ident, Token![,]> =
517                content.parse_terminated(Parse::parse, Token![,])?;
518            Ok(FunctionField {
519                module_name,
520                function_names: function_names.iter().cloned().collect(),
521                err_loc,
522            })
523        } else if lookahead.peek(Ident) {
524            let name = input.parse()?;
525            Ok(FunctionField {
526                module_name,
527                function_names: vec![name],
528                err_loc,
529            })
530        } else {
531            Err(lookahead.error())
532        }
533    }
534}
535
536#[derive(Clone)]
537pub struct WasmtimeConfig {
538    pub c: Config,
539    pub target: syn::Path,
540}
541
542#[derive(Clone)]
543pub enum WasmtimeConfigField {
544    Core(ConfigField),
545    Target(syn::Path),
546}
547impl WasmtimeConfig {
548    pub fn build(fields: impl Iterator<Item = WasmtimeConfigField>, err_loc: Span) -> Result<Self> {
549        let mut target = None;
550        let mut cs = Vec::new();
551        for f in fields {
552            match f {
553                WasmtimeConfigField::Target(c) => {
554                    if target.is_some() {
555                        return Err(Error::new(err_loc, "duplicate `target` field"));
556                    }
557                    target = Some(c);
558                }
559                WasmtimeConfigField::Core(c) => cs.push(c),
560            }
561        }
562        let c = Config::build(cs.into_iter(), err_loc)?;
563        Ok(WasmtimeConfig {
564            c,
565            target: target
566                .take()
567                .ok_or_else(|| Error::new(err_loc, "`target` field required"))?,
568        })
569    }
570}
571
572impl Parse for WasmtimeConfig {
573    fn parse(input: ParseStream) -> Result<Self> {
574        let contents;
575        let _lbrace = braced!(contents in input);
576        let fields: Punctuated<WasmtimeConfigField, Token![,]> =
577            contents.parse_terminated(WasmtimeConfigField::parse, Token![,])?;
578        Ok(WasmtimeConfig::build(fields.into_iter(), input.span())?)
579    }
580}
581
582impl Parse for WasmtimeConfigField {
583    fn parse(input: ParseStream) -> Result<Self> {
584        if input.peek(kw::target) {
585            input.parse::<kw::target>()?;
586            input.parse::<Token![:]>()?;
587            Ok(WasmtimeConfigField::Target(input.parse()?))
588        } else {
589            Ok(WasmtimeConfigField::Core(input.parse()?))
590        }
591    }
592}
593
/// Tracing configuration: a global switch plus per-function opt-outs.
#[derive(Clone, Debug)]
pub struct TracingConf {
    /// Global switch; when `false`, no function is enabled.
    enabled: bool,
    /// Module name -> function names listed under `disable_for`.
    excluded_functions: HashMap<String, Vec<String>>,
}

impl TracingConf {
    /// Whether tracing applies to `module::function`: the global switch must be on and
    /// the function must not be listed as excluded for that module.
    pub fn enabled_for(&self, module: &str, function: &str) -> bool {
        self.enabled
            && !self
                .excluded_functions
                .get(module)
                .into_iter()
                .flatten()
                .any(|name| name == function)
    }
}

impl Default for TracingConf {
    /// Enabled for every function, with no exclusions.
    fn default() -> Self {
        TracingConf {
            enabled: true,
            excluded_functions: HashMap::default(),
        }
    }
}

621impl Parse for TracingConf {
622    fn parse(input: ParseStream) -> Result<Self> {
623        let enabled = input.parse::<syn::LitBool>()?.value;
624
625        let lookahead = input.lookahead1();
626        if lookahead.peek(kw::disable_for) {
627            input.parse::<kw::disable_for>()?;
628            let content;
629            let _ = braced!(content in input);
630            let items: Punctuated<FunctionField, Token![,]> =
631                content.parse_terminated(Parse::parse, Token![,])?;
632            let mut functions: HashMap<String, Vec<String>> = HashMap::new();
633            use std::collections::hash_map::Entry;
634            for i in items {
635                let function_names = i
636                    .function_names
637                    .iter()
638                    .map(|i| i.to_string())
639                    .collect::<Vec<String>>();
640                match functions.entry(i.module_name.to_string()) {
641                    Entry::Occupied(o) => o.into_mut().extend(function_names),
642                    Entry::Vacant(v) => {
643                        v.insert(function_names);
644                    }
645                }
646            }
647
648            Ok(TracingConf {
649                enabled,
650                excluded_functions: functions,
651            })
652        } else {
653            Ok(TracingConf {
654                enabled,
655                excluded_functions: HashMap::new(),
656            })
657        }
658    }
659}