witx_bindgen/
lib.rs

1use heck::*;
2use std::io::{Read, Write};
3use std::path::Path;
4use std::process::{Command, Stdio};
5use witx::*;
6
7pub fn generate<P: AsRef<Path>>(witx_paths: &[P]) -> String {
8    let doc = witx::load(witx_paths).unwrap();
9
10    let mut raw = String::new();
11    raw.push_str(
12        "\
13// This file is automatically generated, DO NOT EDIT
14//
15// To regenerate this file run the `crates/witx-bindgen` command
16
17use core::mem::MaybeUninit;
18
19pub use crate::error::Error;
20pub type Result<T, E = Error> = core::result::Result<T, E>;
21",
22    );
23    for ty in doc.typenames() {
24        ty.render(&mut raw);
25        raw.push_str("\n");
26    }
27    for m in doc.modules() {
28        m.render(&mut raw);
29        raw.push_str("\n");
30    }
31
32    let mut rustfmt = Command::new("rustfmt")
33        .stdin(Stdio::piped())
34        .stdout(Stdio::piped())
35        .spawn()
36        .unwrap();
37    rustfmt
38        .stdin
39        .take()
40        .unwrap()
41        .write_all(raw.as_bytes())
42        .unwrap();
43    let mut ret = String::new();
44    rustfmt
45        .stdout
46        .take()
47        .unwrap()
48        .read_to_string(&mut ret)
49        .unwrap();
50    let status = rustfmt.wait().unwrap();
51    assert!(status.success());
52    return ret;
53}
54
/// Something that can be turned into generated Rust source text.
trait Render {
    /// Appends the Rust source-code form of `self` to the end of `out`.
    fn render(&self, out: &mut String);
}
58
59impl Render for NamedType {
60    fn render(&self, src: &mut String) {
61        let name = self.name.as_str();
62        match &self.tref {
63            TypeRef::Value(ty) => match &**ty {
64                Type::Enum(e) => render_enum(src, name, e),
65                Type::Flags(f) => render_flags(src, name, f),
66                Type::Int(c) => render_const(src, name, c),
67                Type::Struct(s) => render_struct(src, name, s),
68                Type::Union(u) => render_union(src, name, u),
69                Type::Handle(h) => render_handle(src, name, h),
70                Type::Array { .. }
71                | Type::Pointer { .. }
72                | Type::ConstPointer { .. }
73                | Type::Builtin { .. } => render_alias(src, name, &self.tref),
74            },
75            TypeRef::Name(_nt) => render_alias(src, name, &self.tref),
76        }
77    }
78}
79
80// TODO verify this is correct way of handling IntDatatype
81fn render_const(src: &mut String, name: &str, c: &IntDatatype) {
82    src.push_str(&format!("pub type {} = ", name.to_camel_case()));
83    c.repr.render(src);
84    src.push_str(";\n");
85    for r#const in c.consts.iter() {
86        rustdoc(&r#const.docs, src);
87        src.push_str(&format!(
88            "pub const {}_{}: {} = {};",
89            name.to_shouty_snake_case(),
90            r#const.name.as_str().to_shouty_snake_case(),
91            name.to_camel_case(),
92            r#const.value
93        ));
94    }
95}
96
97fn render_union(src: &mut String, name: &str, u: &UnionDatatype) {
98    src.push_str("#[repr(C)]\n");
99    src.push_str("#[derive(Copy, Clone)]\n");
100    src.push_str(&format!("pub union {}U {{\n", name.to_camel_case()));
101    for variant in u.variants.iter() {
102        if let Some(ref tref) = variant.tref {
103            rustdoc(&variant.docs, src);
104            src.push_str("pub ");
105            variant.name.render(src);
106            src.push_str(": ");
107            tref.render(src);
108            src.push_str(",\n");
109        }
110    }
111    src.push_str("}\n");
112    src.push_str("#[repr(C)]\n");
113    src.push_str("#[derive(Copy, Clone)]\n");
114    src.push_str(&format!("pub struct {} {{\n", name.to_camel_case()));
115    src.push_str(&format!(
116        "pub tag: {},\n",
117        u.tag.name.as_str().to_camel_case()
118    ));
119    src.push_str(&format!("pub u: {}U,\n", name.to_camel_case()));
120    src.push_str("}\n");
121}
122
123fn render_struct(src: &mut String, name: &str, s: &StructDatatype) {
124    src.push_str("#[repr(C)]\n");
125    if struct_contains_union(s) {
126        // Unions can't automatically derive `Debug`.
127        src.push_str("#[derive(Copy, Clone)]\n");
128    } else {
129        src.push_str("#[derive(Copy, Clone, Debug)]\n");
130    }
131    src.push_str(&format!("pub struct {} {{\n", name.to_camel_case()));
132    for member in s.members.iter() {
133        rustdoc(&member.docs, src);
134        src.push_str("pub ");
135        member.name.render(src);
136        src.push_str(": ");
137        member.tref.render(src);
138        src.push_str(",\n");
139    }
140    src.push_str("}");
141}
142
143fn render_flags(src: &mut String, name: &str, f: &FlagsDatatype) {
144    src.push_str(&format!("pub type {} = ", name.to_camel_case()));
145    f.repr.render(src);
146    src.push_str(";\n");
147    for (i, variant) in f.flags.iter().enumerate() {
148        rustdoc(&variant.docs, src);
149        src.push_str(&format!(
150            "pub const {}_{}: {} = 0x{:x};",
151            name.to_shouty_snake_case(),
152            variant.name.as_str().to_shouty_snake_case(),
153            name.to_camel_case(),
154            1 << i
155        ));
156    }
157}
158
159fn render_enum(src: &mut String, name: &str, e: &EnumDatatype) {
160    src.push_str(&format!("pub type {} = ", name.to_camel_case()));
161    e.repr.render(src);
162    src.push_str(";\n");
163    for (i, variant) in e.variants.iter().enumerate() {
164        rustdoc(&variant.docs, src);
165        src.push_str(&format!(
166            "pub const {}_{}: {} = {};",
167            name.to_shouty_snake_case(),
168            variant.name.as_str().to_shouty_snake_case(),
169            name.to_camel_case(),
170            i
171        ));
172    }
173
174    if name == "errno" {
175        src.push_str("pub(crate) fn strerror(code: u16) -> &'static str {");
176        src.push_str("match code {");
177        for variant in e.variants.iter() {
178            src.push_str(&name.to_shouty_snake_case());
179            src.push_str("_");
180            src.push_str(&variant.name.as_str().to_shouty_snake_case());
181            src.push_str(" => \"");
182            src.push_str(variant.docs.trim());
183            src.push_str("\",");
184        }
185        src.push_str("_ => \"Unknown error.\",");
186        src.push_str("}");
187        src.push_str("}");
188    }
189}
190
191impl Render for IntRepr {
192    fn render(&self, src: &mut String) {
193        match self {
194            IntRepr::U8 => src.push_str("u8"),
195            IntRepr::U16 => src.push_str("u16"),
196            IntRepr::U32 => src.push_str("u32"),
197            IntRepr::U64 => src.push_str("u64"),
198        }
199    }
200}
201
202fn render_alias(src: &mut String, name: &str, dest: &TypeRef) {
203    src.push_str(&format!("pub type {}", name.to_camel_case()));
204    if dest.type_().passed_by() == TypePassedBy::PointerLengthPair {
205        src.push_str("<'a>");
206    }
207    src.push_str(" = ");
208
209    // Give `size` special treatment to translate it to `usize` in Rust instead of `u32`. Makes
210    // things a bit nicer for client libraries. We can remove this hack once WASI moves to a
211    // snapshot that uses BuiltinType::Size.
212    if name == "size" {
213        src.push_str("usize");
214    } else {
215        dest.render(src);
216    }
217    src.push(';');
218}
219
220impl Render for TypeRef {
221    fn render(&self, src: &mut String) {
222        match self {
223            TypeRef::Name(t) => {
224                src.push_str(&t.name.as_str().to_camel_case());
225                if t.type_().passed_by() == TypePassedBy::PointerLengthPair {
226                    src.push_str("<'_>");
227                }
228            }
229            TypeRef::Value(v) => match &**v {
230                Type::Builtin(t) => t.render(src),
231                Type::Array(t) => {
232                    src.push_str("&'a [");
233                    t.render(src);
234                    src.push_str("]");
235                }
236                Type::Pointer(t) => {
237                    src.push_str("*mut ");
238                    t.render(src);
239                }
240                Type::ConstPointer(t) => {
241                    src.push_str("*const ");
242                    t.render(src);
243                }
244                t => panic!("reference to anonymous {} not possible!", t.kind()),
245            },
246        }
247    }
248}
249
250impl Render for BuiltinType {
251    fn render(&self, src: &mut String) {
252        match self {
253            BuiltinType::String => src.push_str("&str"),
254            BuiltinType::U8 => src.push_str("u8"),
255            BuiltinType::U16 => src.push_str("u16"),
256            BuiltinType::U32 => src.push_str("u32"),
257            BuiltinType::U64 => src.push_str("u64"),
258            BuiltinType::S8 => src.push_str("i8"),
259            BuiltinType::S16 => src.push_str("i16"),
260            BuiltinType::S32 => src.push_str("i32"),
261            BuiltinType::S64 => src.push_str("i64"),
262            BuiltinType::F32 => src.push_str("f32"),
263            BuiltinType::F64 => src.push_str("f64"),
264            BuiltinType::USize => src.push_str("usize"),
265            BuiltinType::Char8 => {
266                // Char8 represents a UTF8 code *unit* (`u8` in Rust, `char8_t` in C++20)
267                // rather than a code *point* (`char` in Rust which is multi-byte)
268                src.push_str("u8")
269            }
270        }
271    }
272}
273
274impl Render for Module {
275    fn render(&self, src: &mut String) {
276        let rust_name = self.name.as_str().to_snake_case();
277        // wrapper functions
278        for f in self.funcs() {
279            render_highlevel(&f, &rust_name, src);
280            src.push_str("\n\n");
281        }
282
283        // raw module
284        src.push_str("pub mod ");
285        src.push_str(&rust_name);
286        src.push_str("{\nuse super::*;");
287        src.push_str("#[link(wasm_import_module =\"");
288        src.push_str(self.name.as_str());
289        src.push_str("\")]\n");
290        src.push_str("extern \"C\" {\n");
291        for f in self.funcs() {
292            f.render(src);
293            src.push_str("\n");
294        }
295        src.push_str("}");
296        src.push_str("}");
297    }
298}
299
/// Renders the high-level `pub unsafe fn` wrapper for `func`. The wrapper
/// forwards to the raw import `<module>::<name>`.
///
/// What the generated wrapper does:
/// - takes each parameter by its witx type, and forwards pointer/length-pair
///   parameters as `.as_ptr()`, `.len()`;
/// - allocates a `MaybeUninit` out-slot for every result after the first and
///   passes those slots as trailing `*mut` arguments;
/// - converts the first result (which must be named `error`) into a `Result`
///   via `Error::from_raw_error`.
///
/// # Panics
///
/// - Panics if the first result is not named `error`.
/// - Panics if any parameter is passed by pointer; only by-value and
///   pointer/length-pair parameters are supported.
fn render_highlevel(func: &InterfaceFunc, module: &str, src: &mut String) {
    // Keyword-escaped, snake_case Rust name; also names the raw import.
    let mut rust_name = String::new();
    func.name.render(&mut rust_name);
    let rust_name = rust_name.to_snake_case();
    rustdoc(&func.docs, src);
    rustdoc_params(&func.params, "Parameters", src);
    rustdoc_params(&func.results, "Return", src);

    // Render the function and its arguments, and note that the arguments here
    // are the exact type name arguments as opposed to the pointer/length pair
    // ones. These functions are unsafe because they work with integer file
    // descriptors, which are effectively forgeable and danglable raw pointers
    // into the file descriptor address space.
    src.push_str("pub unsafe fn ");

    // TODO workout how to handle wasi-ephemeral which introduces multiple
    // WASI modules into the picture. For now, feature-gate it, and if we're
    // compiling ephmeral bindings, prefix wrapper syscall with module name.
    cfg_if::cfg_if! {
        if #[cfg(feature = "multi-module")] {
            src.push_str(&[module, &rust_name].join("_"));
        } else {
            src.push_str(&rust_name);
        }
    }

    src.push_str("(");
    for param in func.params.iter() {
        param.name.render(src);
        src.push_str(": ");
        param.tref.render(src);
        src.push_str(",");
    }
    src.push_str(")");

    // Render the result type of this function, if there is one.
    if let Some(first) = func.results.get(0) {
        // only know how to generate bindings for arguments where the first
        // results is an errno, so assert this here and if it ever changes we'll
        // need to update codegen below.
        assert_eq!(first.name.as_str(), "error");
        src.push_str(" -> Result<");
        // 1 == `Result<()>`, 2 == `Result<T>`, 3+ == `Result<(...)>`
        if func.results.len() != 2 {
            src.push_str("(");
        }
        for result in func.results.iter().skip(1) {
            result.tref.render(src);
            src.push_str(",");
        }
        if func.results.len() != 2 {
            src.push_str(")");
        }
        src.push_str(">");
    }

    // Body: one uninitialized out-slot per non-error result.
    src.push_str("{");
    for result in func.results.iter().skip(1) {
        src.push_str("let mut ");
        result.name.render(src);
        src.push_str(" = MaybeUninit::uninit();");
    }
    // Capture the raw return code only if the function declares results.
    if func.results.len() > 0 {
        src.push_str("let rc = ");
    }
    src.push_str(module);
    src.push_str("::");
    src.push_str(&rust_name);
    src.push_str("(");

    // Forward all parameters, fetching the pointer/length as appropriate
    for param in func.params.iter() {
        match param.tref.type_().passed_by() {
            TypePassedBy::Value(_) => param.name.render(src),
            TypePassedBy::Pointer => unreachable!(
                "unable to translate parameter `{}` of type `{}` in function `{}`",
                param.name.as_str(),
                param.tref.type_name(),
                func.name.as_str()
            ),
            TypePassedBy::PointerLengthPair => {
                param.name.render(src);
                src.push_str(".as_ptr(), ");
                param.name.render(src);
                src.push_str(".len()");
            }
        }
        src.push_str(",");
    }

    // Forward all out-pointers as trailing arguments
    for result in func.results.iter().skip(1) {
        result.name.render(src);
        src.push_str(".as_mut_ptr(),");
    }
    src.push_str(");");

    // Check the return value, and if successful load all of the out pointers
    // assuming they were initialized (part of the wasi contract).
    if func.results.len() > 0 {
        src.push_str("if let Some(err) = Error::from_raw_error(rc) { ");
        src.push_str("Err(err)");
        src.push_str("} else {");
        src.push_str("Ok(");
        // Mirror the tuple shape chosen for the return type above.
        if func.results.len() != 2 {
            src.push_str("(");
        }
        for result in func.results.iter().skip(1) {
            result.name.render(src);
            src.push_str(".assume_init(),");
        }
        if func.results.len() != 2 {
            src.push_str(")");
        }
        src.push_str(") }");
    }
    src.push_str("}");
}
418
419impl Render for InterfaceFunc {
420    fn render(&self, src: &mut String) {
421        rustdoc(&self.docs, src);
422        if self.name.as_str() != self.name.as_str().to_snake_case() {
423            src.push_str("#[link_name = \"");
424            src.push_str(self.name.as_str());
425            src.push_str("\"]\n");
426        }
427        src.push_str("pub fn ");
428        let mut name = String::new();
429        self.name.render(&mut name);
430        src.push_str(&name.to_snake_case());
431        src.push_str("(");
432        for param in self.params.iter() {
433            param.render(src);
434            src.push_str(",");
435        }
436        for result in self.results.iter().skip(1) {
437            result.name.render(src);
438            src.push_str(": *mut ");
439            result.tref.render(src);
440            src.push_str(",");
441        }
442        src.push_str(")");
443        if let Some(result) = self.results.get(0) {
444            src.push_str(" -> ");
445            result.render(src);
446        // special-case the `proc_exit` function for now to be "noreturn", and
447        // eventually we'll have an attribute in `*.witx` to specify this as
448        // well.
449        } else if self.name.as_str() == "proc_exit" {
450            src.push_str(" -> !");
451        }
452        src.push_str(";");
453    }
454}
455
456impl Render for InterfaceFuncParam {
457    fn render(&self, src: &mut String) {
458        let is_param = match self.position {
459            InterfaceFuncParamPosition::Param(_) => true,
460            _ => false,
461        };
462        match self.tref.type_().passed_by() {
463            // By-value arguments are passed as-is
464            TypePassedBy::Value(_) => {
465                if is_param {
466                    self.name.render(src);
467                    src.push_str(": ");
468                }
469                self.tref.render(src);
470            }
471            // Pointer arguments are passed with a `*mut` out in front
472            TypePassedBy::Pointer => {
473                if is_param {
474                    self.name.render(src);
475                    src.push_str(": ");
476                }
477                src.push_str("*mut ");
478                self.tref.render(src);
479            }
480            // ... and pointer/length arguments are passed with first their
481            // pointer and then their length, as the name would otherwise imply
482            TypePassedBy::PointerLengthPair => {
483                assert!(is_param);
484                src.push_str(self.name.as_str());
485                src.push_str("_ptr");
486                src.push_str(": ");
487                src.push_str("*const ");
488                match &*self.tref.type_() {
489                    Type::Array(x) => x.render(src),
490                    Type::Builtin(BuiltinType::String) => src.push_str("u8"),
491                    x => panic!("unexpected pointer length pair type {:?}", x),
492                }
493                src.push_str(", ");
494                src.push_str(self.name.as_str());
495                src.push_str("_len");
496                src.push_str(": ");
497                src.push_str("usize");
498            }
499        }
500    }
501}
502
503impl Render for Id {
504    fn render(&self, src: &mut String) {
505        match self.as_str() {
506            "in" => src.push_str("r#in"),
507            "type" => src.push_str("r#type"),
508            "yield" => src.push_str("r#yield"),
509            s => src.push_str(s),
510        }
511    }
512}
513
514fn render_handle(src: &mut String, name: &str, _h: &HandleDatatype) {
515    src.push_str(&format!("pub type {} = u32;", name.to_camel_case()));
516}
517
/// Appends `docs` to `dst` as Rust `///` doc-comment lines, one per input
/// line. Nothing is written when `docs` is empty or whitespace-only.
fn rustdoc(docs: &str, dst: &mut String) {
    let has_content = !docs.trim().is_empty();
    if has_content {
        for text in docs.lines() {
            dst.push_str("/// ");
            dst.push_str(text);
            dst.push('\n');
        }
    }
}
528
529fn rustdoc_params(docs: &[InterfaceFuncParam], header: &str, dst: &mut String) {
530    let docs = docs
531        .iter()
532        .filter(|param| param.docs.trim().len() > 0)
533        .collect::<Vec<_>>();
534    if docs.len() == 0 {
535        return;
536    }
537
538    dst.push_str("///\n");
539    dst.push_str("/// ## ");
540    dst.push_str(header);
541    dst.push_str("\n");
542    dst.push_str("///\n");
543
544    for param in docs {
545        for (i, line) in param.docs.lines().enumerate() {
546            dst.push_str("/// ");
547            if i == 0 {
548                dst.push_str("* `");
549                param.name.render(dst);
550                dst.push_str("` - ");
551            } else {
552                dst.push_str("  ");
553            }
554            dst.push_str(line);
555            dst.push_str("\n");
556        }
557    }
558}
559
560fn struct_contains_union(s: &StructDatatype) -> bool {
561    s.members
562        .iter()
563        .any(|member| type_contains_union(&member.tref.type_()))
564}
565
566fn type_contains_union(ty: &Type) -> bool {
567    match ty {
568        Type::Union(_) => true,
569        Type::Array(tref) => type_contains_union(&tref.type_()),
570        Type::Struct(st) => struct_contains_union(st),
571        _ => false,
572    }
573}