witx_codegen/rust/
function.rs

1use std::io::Write;
2
3use super::*;
4
5impl RustGenerator {
6    pub fn define_func<T: Write>(
7        w: &mut PrettyWriter<T>,
8        module_name: &str,
9        func_witx: &witx::Function,
10    ) -> Result<(), Error> {
11        assert_eq!(func_witx.abi, witx::Abi::Preview1);
12        let name = func_witx.name.as_str().to_string();
13        let params_witx = &func_witx.params;
14        let mut params = vec![];
15        for param_witx in params_witx {
16            let param_name = param_witx.name.as_str();
17            let param_type = ASType::from(&param_witx.tref);
18            params.push((param_name.to_string(), param_type));
19        }
20
21        let results_witx = &func_witx.results;
22        assert_eq!(results_witx.len(), 1);
23        let result_witx = &results_witx[0];
24        let result = ASType::from(&result_witx.tref);
25        let result = match result {
26            ASType::Result(result) => result,
27            _ => unreachable!(),
28        };
29
30        let ok_type = result.ok_type.clone();
31
32        let docs = &func_witx.docs;
33        if !docs.is_empty() {
34            Self::write_docs(w, docs)?;
35        }
36
37        let mut params_decomposed = vec![];
38
39        for param in &params {
40            let mut decomposed = param.1.decompose(&param.0, false);
41            params_decomposed.append(&mut decomposed);
42        }
43
44        let mut results = vec![];
45        // A tuple in a result is expanded into additional parameters, transformed to
46        // pointers
47        if let ASType::Tuple(tuple_members) = ok_type.as_ref().leaf() {
48            for (i, tuple_member) in tuple_members.iter().enumerate() {
49                let name = format!("result{}_ptr", i);
50                results.push((name, tuple_member.type_.clone()));
51            }
52        } else {
53            let name = "result_ptr";
54            results.push((name.to_string(), ok_type));
55        }
56        let mut results_decomposed = vec![];
57        for result in &results {
58            let mut decomposed = result.1.decompose(&result.0, true);
59            results_decomposed.append(&mut decomposed);
60        }
61
62        Self::define_func_raw(
63            w,
64            module_name,
65            &name,
66            &params_decomposed,
67            &results_decomposed,
68            &result,
69        )?;
70
71        let signature_witx = func_witx.wasm_signature(witx::CallMode::DefinedImport);
72        let params_count_witx = signature_witx.params.len() + signature_witx.results.len();
73        assert_eq!(
74            params_count_witx,
75            params_decomposed.len() + results_decomposed.len() + 1
76        );
77
78        Ok(())
79    }
80
81    fn define_func_raw<T: Write>(
82        w: &mut PrettyWriter<T>,
83        module_name: &str,
84        name: &str,
85        params_decomposed: &[ASTypeDecomposed],
86        results_decomposed: &[ASTypeDecomposed],
87        result: &ASResult,
88    ) -> Result<(), Error> {
89        let results_decomposed_deref = results_decomposed
90            .iter()
91            .map(|result_ptr_type| match result_ptr_type.type_.as_ref() {
92                ASType::MutPtr(result_type) => ASTypeDecomposed {
93                    name: result_ptr_type.name.clone(),
94                    type_: result_type.clone(),
95                },
96                _ => panic!("Result type is not a pointer"),
97            })
98            .collect::<Vec<_>>();
99        let results_set = results_decomposed_deref
100            .iter()
101            .map(|result| result.type_.as_lang())
102            .collect::<Vec<_>>();
103        let rust_fn_result_str = match results_set.len() {
104            0 => "()".to_string(),
105            1 => results_set[0].clone(),
106            _ => format!("({})", results_set.join(", ")),
107        };
108        w.indent()?.write(format!("pub fn {}(", name.as_fn()))?;
109        if !params_decomposed.is_empty() || !results_decomposed.is_empty() {
110            w.eol()?;
111        }
112        for param in params_decomposed {
113            w.write_line_continued(format!(
114                "{}: {},",
115                param.name.as_var(),
116                param.type_.as_lang(),
117            ))?;
118        }
119        w.write_line(format!(") -> Result<{}, Error> {{", rust_fn_result_str))?;
120        {
121            let mut w = w.new_block();
122
123            // Inner (raw) definition
124            {
125                w.write_line(format!("#[link(wasm_import_module = \"{}\")]", module_name))?;
126                w.write_line("extern \"C\" {")?;
127                {
128                    let mut w = w.new_block();
129                    w.indent()?.write(format!("fn {}(", name.as_fn()))?;
130                    if !params_decomposed.is_empty() {
131                        w.eol()?;
132                    }
133                    for param in params_decomposed.iter().chain(results_decomposed.iter()) {
134                        w.write_line_continued(format!(
135                            "{}: {},",
136                            param.name.as_var(),
137                            param.type_.as_lang(),
138                        ))?;
139                    }
140                    w.write_line(format!(") -> {};", result.error_type.as_lang()))?;
141                }
142                w.write_line("}")?;
143            }
144
145            // Wrapper
146            for result in &results_decomposed_deref {
147                w.write_line(format!(
148                    "let mut {} = std::mem::MaybeUninit::uninit();",
149                    result.name.as_var()
150                ))?;
151            }
152
153            w.write_line(format!("let res = unsafe {{ {}(", name.as_fn()))?;
154            for param in params_decomposed {
155                w.write_line_continued(format!("{},", param.name.as_var()))?;
156            }
157            for result in results_decomposed_deref.iter() {
158                w.write_line_continued(format!("{}.as_mut_ptr(),", result.name.as_var()))?;
159            }
160            w.write_line(")};")?;
161            w.write_lines(
162                "if res != 0 {
163    return Err(Error::WasiError(res as _));
164}",
165            )?;
166            let res_str = match results_decomposed.len() {
167                0 => "()".to_string(),
168                1 => format!(
169                    "unsafe {{ {}.assume_init() }}",
170                    results_decomposed_deref[0].name.as_var()
171                ),
172                _ => format!(
173                    "unsafe {{ ({}) }}",
174                    results_decomposed_deref
175                        .iter()
176                        .map(|result| format!("{}.assume_init()", result.name.as_var()))
177                        .collect::<Vec<_>>()
178                        .join(", ")
179                ),
180            };
181            w.write_line(format!("Ok({})", res_str))?;
182        };
183        w.write_line("}")?;
184        w.eob()?;
185
186        Ok(())
187    }
188}