1use {
2 proc_macro2::{Span, TokenStream},
3 std::{collections::HashMap, path::PathBuf},
4 syn::{
5 Error, Ident, LitStr, Result, Token, braced, bracketed,
6 parse::{Parse, ParseStream},
7 punctuated::Punctuated,
8 },
9};
10
11#[derive(Debug, Clone)]
12pub struct Config {
13 pub witx: WitxConf,
14 pub errors: ErrorConf,
15 pub async_: AsyncConf,
16 pub wasmtime: bool,
17 pub tracing: TracingConf,
18 pub mutable: bool,
19}
20
21mod kw {
22 syn::custom_keyword!(witx);
23 syn::custom_keyword!(witx_literal);
24 syn::custom_keyword!(block_on);
25 syn::custom_keyword!(errors);
26 syn::custom_keyword!(target);
27 syn::custom_keyword!(wasmtime);
28 syn::custom_keyword!(mutable);
29 syn::custom_keyword!(tracing);
30 syn::custom_keyword!(disable_for);
31 syn::custom_keyword!(trappable);
32}
33
34#[derive(Debug, Clone)]
35pub enum ConfigField {
36 Witx(WitxConf),
37 Error(ErrorConf),
38 Async(AsyncConf),
39 Wasmtime(bool),
40 Tracing(TracingConf),
41 Mutable(bool),
42}
43
44impl Parse for ConfigField {
45 fn parse(input: ParseStream) -> Result<Self> {
46 let lookahead = input.lookahead1();
47 if lookahead.peek(kw::witx) {
48 input.parse::<kw::witx>()?;
49 input.parse::<Token![:]>()?;
50 Ok(ConfigField::Witx(WitxConf::Paths(input.parse()?)))
51 } else if lookahead.peek(kw::witx_literal) {
52 input.parse::<kw::witx_literal>()?;
53 input.parse::<Token![:]>()?;
54 Ok(ConfigField::Witx(WitxConf::Literal(input.parse()?)))
55 } else if lookahead.peek(kw::errors) {
56 input.parse::<kw::errors>()?;
57 input.parse::<Token![:]>()?;
58 Ok(ConfigField::Error(input.parse()?))
59 } else if lookahead.peek(Token![async]) {
60 input.parse::<Token![async]>()?;
61 input.parse::<Token![:]>()?;
62 Ok(ConfigField::Async(AsyncConf {
63 block_with: None,
64 functions: input.parse()?,
65 }))
66 } else if lookahead.peek(kw::block_on) {
67 input.parse::<kw::block_on>()?;
68 let block_with = if input.peek(syn::token::Bracket) {
69 let content;
70 let _ = bracketed!(content in input);
71 content.parse()?
72 } else {
73 quote::quote!(wiggle::run_in_dummy_executor)
74 };
75 input.parse::<Token![:]>()?;
76 Ok(ConfigField::Async(AsyncConf {
77 block_with: Some(block_with),
78 functions: input.parse()?,
79 }))
80 } else if lookahead.peek(kw::wasmtime) {
81 input.parse::<kw::wasmtime>()?;
82 input.parse::<Token![:]>()?;
83 Ok(ConfigField::Wasmtime(input.parse::<syn::LitBool>()?.value))
84 } else if lookahead.peek(kw::tracing) {
85 input.parse::<kw::tracing>()?;
86 input.parse::<Token![:]>()?;
87 Ok(ConfigField::Tracing(input.parse()?))
88 } else if lookahead.peek(kw::mutable) {
89 input.parse::<kw::mutable>()?;
90 input.parse::<Token![:]>()?;
91 Ok(ConfigField::Mutable(input.parse::<syn::LitBool>()?.value))
92 } else {
93 Err(lookahead.error())
94 }
95 }
96}
97
98impl Config {
99 pub fn build(fields: impl Iterator<Item = ConfigField>, err_loc: Span) -> Result<Self> {
100 let mut witx = None;
101 let mut errors = None;
102 let mut async_ = None;
103 let mut wasmtime = None;
104 let mut tracing = None;
105 let mut mutable = None;
106 for f in fields {
107 match f {
108 ConfigField::Witx(c) => {
109 if witx.is_some() {
110 return Err(Error::new(err_loc, "duplicate `witx` field"));
111 }
112 witx = Some(c);
113 }
114 ConfigField::Error(c) => {
115 if errors.is_some() {
116 return Err(Error::new(err_loc, "duplicate `errors` field"));
117 }
118 errors = Some(c);
119 }
120 ConfigField::Async(c) => {
121 if async_.is_some() {
122 return Err(Error::new(err_loc, "duplicate `async` field"));
123 }
124 async_ = Some(c);
125 }
126 ConfigField::Wasmtime(c) => {
127 if wasmtime.is_some() {
128 return Err(Error::new(err_loc, "duplicate `wasmtime` field"));
129 }
130 wasmtime = Some(c);
131 }
132 ConfigField::Tracing(c) => {
133 if tracing.is_some() {
134 return Err(Error::new(err_loc, "duplicate `tracing` field"));
135 }
136 tracing = Some(c);
137 }
138 ConfigField::Mutable(c) => {
139 if mutable.is_some() {
140 return Err(Error::new(err_loc, "duplicate `mutable` field"));
141 }
142 mutable = Some(c);
143 }
144 }
145 }
146 Ok(Config {
147 witx: witx
148 .take()
149 .ok_or_else(|| Error::new(err_loc, "`witx` field required"))?,
150 errors: errors.take().unwrap_or_default(),
151 async_: async_.take().unwrap_or_default(),
152 wasmtime: wasmtime.unwrap_or(true),
153 tracing: tracing.unwrap_or_default(),
154 mutable: mutable.unwrap_or(true),
155 })
156 }
157
158 pub fn load_document(&self) -> witx::Document {
164 self.witx.load_document()
165 }
166}
167
168impl Parse for Config {
169 fn parse(input: ParseStream) -> Result<Self> {
170 let contents;
171 let _lbrace = braced!(contents in input);
172 let fields: Punctuated<ConfigField, Token![,]> =
173 contents.parse_terminated(ConfigField::parse, Token![,])?;
174 Ok(Config::build(fields.into_iter(), input.span())?)
175 }
176}
177
178#[derive(Debug, Clone)]
184pub enum WitxConf {
185 Paths(Paths),
187 Literal(Literal),
189}
190
191impl WitxConf {
192 pub fn load_document(&self) -> witx::Document {
199 match self {
200 Self::Paths(paths) => witx::load(paths.as_ref()).expect("loading witx"),
201 Self::Literal(doc) => witx::parse(doc.as_ref()).expect("parsing witx"),
202 }
203 }
204}
205
/// Ordered list of witx file paths.
///
/// The `Default` impl is derived (an empty list) instead of hand-written.
#[derive(Clone, Debug, Default)]
pub struct Paths(Vec<PathBuf>);

impl Paths {
    /// Creates an empty path list.
    pub fn new() -> Self {
        Self::default()
    }
}

impl AsRef<[PathBuf]> for Paths {
    fn as_ref(&self) -> &[PathBuf] {
        self.0.as_slice()
    }
}

impl AsMut<[PathBuf]> for Paths {
    fn as_mut(&mut self) -> &mut [PathBuf] {
        self.0.as_mut_slice()
    }
}

impl FromIterator<PathBuf> for Paths {
    fn from_iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = PathBuf>,
    {
        Self(iter.into_iter().collect())
    }
}
243
244impl Parse for Paths {
245 fn parse(input: ParseStream) -> Result<Self> {
246 let content;
247 let _ = bracketed!(content in input);
248 let path_lits: Punctuated<LitStr, Token![,]> =
249 content.parse_terminated(Parse::parse, Token![,])?;
250
251 let expanded_paths = path_lits
252 .iter()
253 .map(|lit| {
254 PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()).join(lit.value())
255 })
256 .collect::<Vec<PathBuf>>();
257
258 Ok(Paths(expanded_paths))
259 }
260}
261
/// Inline witx source text from a `witx_literal: "..."` field.
#[derive(Clone, Debug)]
pub struct Literal(String);

impl AsRef<str> for Literal {
    fn as_ref(&self) -> &str {
        let Literal(text) = self;
        text.as_str()
    }
}
271
272impl Parse for Literal {
273 fn parse(input: ParseStream) -> Result<Self> {
274 Ok(Self(input.parse::<syn::LitStr>()?.value()))
275 }
276}
277
278#[derive(Clone, Default, Debug)]
279pub struct ErrorConf(HashMap<Ident, ErrorConfField>);
281
282impl ErrorConf {
283 pub fn iter(&self) -> impl Iterator<Item = (&Ident, &ErrorConfField)> {
284 self.0.iter()
285 }
286}
287
288impl Parse for ErrorConf {
289 fn parse(input: ParseStream) -> Result<Self> {
290 let content;
291 let _ = braced!(content in input);
292 let items: Punctuated<ErrorConfField, Token![,]> =
293 content.parse_terminated(Parse::parse, Token![,])?;
294 let mut m = HashMap::new();
295 for i in items {
296 match m.insert(i.abi_error().clone(), i.clone()) {
297 None => {}
298 Some(prev_def) => {
299 return Err(Error::new(
300 *i.err_loc(),
301 format!(
302 "duplicate definition of rich error type for {:?}: previously defined at {:?}",
303 i.abi_error(),
304 prev_def.err_loc(),
305 ),
306 ));
307 }
308 }
309 }
310 Ok(ErrorConf(m))
311 }
312}
313
314#[derive(Debug, Clone)]
315pub enum ErrorConfField {
316 Trappable(TrappableErrorConfField),
317 User(UserErrorConfField),
318}
319impl ErrorConfField {
320 pub fn abi_error(&self) -> &Ident {
321 match self {
322 Self::Trappable(t) => &t.abi_error,
323 Self::User(u) => &u.abi_error,
324 }
325 }
326 pub fn err_loc(&self) -> &Span {
327 match self {
328 Self::Trappable(t) => &t.err_loc,
329 Self::User(u) => &u.err_loc,
330 }
331 }
332}
333
334impl Parse for ErrorConfField {
335 fn parse(input: ParseStream) -> Result<Self> {
336 let err_loc = input.span();
337 let abi_error = input.parse::<Ident>()?;
338 let _arrow: Token![=>] = input.parse()?;
339
340 let lookahead = input.lookahead1();
341 if lookahead.peek(kw::trappable) {
342 let _ = input.parse::<kw::trappable>()?;
343 let rich_error = input.parse()?;
344 Ok(ErrorConfField::Trappable(TrappableErrorConfField {
345 abi_error,
346 rich_error,
347 err_loc,
348 }))
349 } else {
350 let rich_error = input.parse::<syn::Path>()?;
351 Ok(ErrorConfField::User(UserErrorConfField {
352 abi_error,
353 rich_error,
354 err_loc,
355 }))
356 }
357 }
358}
359
360#[derive(Clone, Debug)]
361pub struct TrappableErrorConfField {
362 pub abi_error: Ident,
363 pub rich_error: Ident,
364 pub err_loc: Span,
365}
366
367#[derive(Clone)]
368pub struct UserErrorConfField {
369 pub abi_error: Ident,
370 pub rich_error: syn::Path,
371 pub err_loc: Span,
372}
373
374impl std::fmt::Debug for UserErrorConfField {
375 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376 f.debug_struct("ErrorConfField")
377 .field("abi_error", &self.abi_error)
378 .field("rich_error", &"(...)")
379 .field("err_loc", &self.err_loc)
380 .finish()
381 }
382}
383
384#[derive(Clone, Default, Debug)]
385pub struct AsyncConf {
387 block_with: Option<TokenStream>,
388 functions: AsyncFunctions,
389}
390
391#[derive(Clone, Debug)]
392pub enum Asyncness {
393 Sync,
395 Blocking { block_with: TokenStream },
397 Async,
399}
400
401impl Asyncness {
402 pub fn is_async(&self) -> bool {
403 match self {
404 Self::Async => true,
405 _ => false,
406 }
407 }
408 pub fn blocking(&self) -> Option<&TokenStream> {
409 match self {
410 Self::Blocking { block_with } => Some(block_with),
411 _ => None,
412 }
413 }
414 pub fn is_sync(&self) -> bool {
415 match self {
416 Self::Sync => true,
417 _ => false,
418 }
419 }
420}
421
/// Which functions an `async:` / `block_on:` field applies to.
#[derive(Clone, Debug)]
pub enum AsyncFunctions {
    /// Explicit selection: module name → function names.
    Some(HashMap<String, Vec<String>>),
    /// `*`: every function.
    All,
}

impl Default for AsyncFunctions {
    /// Selects no functions.
    fn default() -> Self {
        Self::Some(HashMap::new())
    }
}
432
433impl AsyncConf {
434 pub fn get(&self, module: &str, function: &str) -> Asyncness {
435 let a = match &self.block_with {
436 Some(block_with) => Asyncness::Blocking {
437 block_with: block_with.clone(),
438 },
439 None => Asyncness::Async,
440 };
441 match &self.functions {
442 AsyncFunctions::Some(fs) => {
443 if fs
444 .get(module)
445 .and_then(|fs| fs.iter().find(|f| *f == function))
446 .is_some()
447 {
448 a
449 } else {
450 Asyncness::Sync
451 }
452 }
453 AsyncFunctions::All => a,
454 }
455 }
456
457 pub fn contains_async(&self, module: &witx::Module) -> bool {
458 for f in module.funcs() {
459 if self.get(module.name.as_str(), f.name.as_str()).is_async() {
460 return true;
461 }
462 }
463 false
464 }
465}
466
467impl Parse for AsyncFunctions {
468 fn parse(input: ParseStream) -> Result<Self> {
469 let content;
470 let lookahead = input.lookahead1();
471 if lookahead.peek(syn::token::Brace) {
472 let _ = braced!(content in input);
473 let items: Punctuated<FunctionField, Token![,]> =
474 content.parse_terminated(Parse::parse, Token![,])?;
475 let mut functions: HashMap<String, Vec<String>> = HashMap::new();
476 use std::collections::hash_map::Entry;
477 for i in items {
478 let function_names = i
479 .function_names
480 .iter()
481 .map(|i| i.to_string())
482 .collect::<Vec<String>>();
483 match functions.entry(i.module_name.to_string()) {
484 Entry::Occupied(o) => o.into_mut().extend(function_names),
485 Entry::Vacant(v) => {
486 v.insert(function_names);
487 }
488 }
489 }
490 Ok(AsyncFunctions::Some(functions))
491 } else if lookahead.peek(Token![*]) {
492 let _: Token![*] = input.parse().unwrap();
493 Ok(AsyncFunctions::All)
494 } else {
495 Err(lookahead.error())
496 }
497 }
498}
499
500#[derive(Clone)]
501pub struct FunctionField {
502 pub module_name: Ident,
503 pub function_names: Vec<Ident>,
504 pub err_loc: Span,
505}
506
507impl Parse for FunctionField {
508 fn parse(input: ParseStream) -> Result<Self> {
509 let err_loc = input.span();
510 let module_name = input.parse::<Ident>()?;
511 let _doublecolon: Token![::] = input.parse()?;
512 let lookahead = input.lookahead1();
513 if lookahead.peek(syn::token::Brace) {
514 let content;
515 let _ = braced!(content in input);
516 let function_names: Punctuated<Ident, Token![,]> =
517 content.parse_terminated(Parse::parse, Token![,])?;
518 Ok(FunctionField {
519 module_name,
520 function_names: function_names.iter().cloned().collect(),
521 err_loc,
522 })
523 } else if lookahead.peek(Ident) {
524 let name = input.parse()?;
525 Ok(FunctionField {
526 module_name,
527 function_names: vec![name],
528 err_loc,
529 })
530 } else {
531 Err(lookahead.error())
532 }
533 }
534}
535
536#[derive(Clone)]
537pub struct WasmtimeConfig {
538 pub c: Config,
539 pub target: syn::Path,
540}
541
542#[derive(Clone)]
543pub enum WasmtimeConfigField {
544 Core(ConfigField),
545 Target(syn::Path),
546}
547impl WasmtimeConfig {
548 pub fn build(fields: impl Iterator<Item = WasmtimeConfigField>, err_loc: Span) -> Result<Self> {
549 let mut target = None;
550 let mut cs = Vec::new();
551 for f in fields {
552 match f {
553 WasmtimeConfigField::Target(c) => {
554 if target.is_some() {
555 return Err(Error::new(err_loc, "duplicate `target` field"));
556 }
557 target = Some(c);
558 }
559 WasmtimeConfigField::Core(c) => cs.push(c),
560 }
561 }
562 let c = Config::build(cs.into_iter(), err_loc)?;
563 Ok(WasmtimeConfig {
564 c,
565 target: target
566 .take()
567 .ok_or_else(|| Error::new(err_loc, "`target` field required"))?,
568 })
569 }
570}
571
572impl Parse for WasmtimeConfig {
573 fn parse(input: ParseStream) -> Result<Self> {
574 let contents;
575 let _lbrace = braced!(contents in input);
576 let fields: Punctuated<WasmtimeConfigField, Token![,]> =
577 contents.parse_terminated(WasmtimeConfigField::parse, Token![,])?;
578 Ok(WasmtimeConfig::build(fields.into_iter(), input.span())?)
579 }
580}
581
582impl Parse for WasmtimeConfigField {
583 fn parse(input: ParseStream) -> Result<Self> {
584 if input.peek(kw::target) {
585 input.parse::<kw::target>()?;
586 input.parse::<Token![:]>()?;
587 Ok(WasmtimeConfigField::Target(input.parse()?))
588 } else {
589 Ok(WasmtimeConfigField::Core(input.parse()?))
590 }
591 }
592}
593
/// Controls for which functions tracing is enabled.
#[derive(Clone, Debug)]
pub struct TracingConf {
    /// Global switch; when `false`, tracing is off for every function.
    enabled: bool,
    /// Module name → function names for which tracing is disabled.
    excluded_functions: HashMap<String, Vec<String>>,
}

impl TracingConf {
    /// Returns `true` when tracing is globally enabled and `module::function`
    /// is not in the exclusion list.
    pub fn enabled_for(&self, module: &str, function: &str) -> bool {
        if !self.enabled {
            return false;
        }
        let excluded = match self.excluded_functions.get(module) {
            Some(names) => names.iter().any(|name| name == function),
            None => false,
        };
        !excluded
    }
}

impl Default for TracingConf {
    /// Tracing enabled for every function.
    fn default() -> Self {
        Self {
            enabled: true,
            excluded_functions: HashMap::new(),
        }
    }
}
620
621impl Parse for TracingConf {
622 fn parse(input: ParseStream) -> Result<Self> {
623 let enabled = input.parse::<syn::LitBool>()?.value;
624
625 let lookahead = input.lookahead1();
626 if lookahead.peek(kw::disable_for) {
627 input.parse::<kw::disable_for>()?;
628 let content;
629 let _ = braced!(content in input);
630 let items: Punctuated<FunctionField, Token![,]> =
631 content.parse_terminated(Parse::parse, Token![,])?;
632 let mut functions: HashMap<String, Vec<String>> = HashMap::new();
633 use std::collections::hash_map::Entry;
634 for i in items {
635 let function_names = i
636 .function_names
637 .iter()
638 .map(|i| i.to_string())
639 .collect::<Vec<String>>();
640 match functions.entry(i.module_name.to_string()) {
641 Entry::Occupied(o) => o.into_mut().extend(function_names),
642 Entry::Vacant(v) => {
643 v.insert(function_names);
644 }
645 }
646 }
647
648 Ok(TracingConf {
649 enabled,
650 excluded_functions: functions,
651 })
652 } else {
653 Ok(TracingConf {
654 enabled,
655 excluded_functions: HashMap::new(),
656 })
657 }
658 }
659}