witx_codegen/rust/
function.rs1use std::io::Write;
2
3use super::*;
4
5impl RustGenerator {
6 pub fn define_func<T: Write>(
7 w: &mut PrettyWriter<T>,
8 module_name: &str,
9 func_witx: &witx::Function,
10 ) -> Result<(), Error> {
11 assert_eq!(func_witx.abi, witx::Abi::Preview1);
12 let name = func_witx.name.as_str().to_string();
13 let params_witx = &func_witx.params;
14 let mut params = vec![];
15 for param_witx in params_witx {
16 let param_name = param_witx.name.as_str();
17 let param_type = ASType::from(¶m_witx.tref);
18 params.push((param_name.to_string(), param_type));
19 }
20
21 let results_witx = &func_witx.results;
22 assert_eq!(results_witx.len(), 1);
23 let result_witx = &results_witx[0];
24 let result = ASType::from(&result_witx.tref);
25 let result = match result {
26 ASType::Result(result) => result,
27 _ => unreachable!(),
28 };
29
30 let ok_type = result.ok_type.clone();
31
32 let docs = &func_witx.docs;
33 if !docs.is_empty() {
34 Self::write_docs(w, docs)?;
35 }
36
37 let mut params_decomposed = vec![];
38
39 for param in ¶ms {
40 let mut decomposed = param.1.decompose(¶m.0, false);
41 params_decomposed.append(&mut decomposed);
42 }
43
44 let mut results = vec![];
45 if let ASType::Tuple(tuple_members) = ok_type.as_ref().leaf() {
48 for (i, tuple_member) in tuple_members.iter().enumerate() {
49 let name = format!("result{}_ptr", i);
50 results.push((name, tuple_member.type_.clone()));
51 }
52 } else {
53 let name = "result_ptr";
54 results.push((name.to_string(), ok_type));
55 }
56 let mut results_decomposed = vec![];
57 for result in &results {
58 let mut decomposed = result.1.decompose(&result.0, true);
59 results_decomposed.append(&mut decomposed);
60 }
61
62 Self::define_func_raw(
63 w,
64 module_name,
65 &name,
66 ¶ms_decomposed,
67 &results_decomposed,
68 &result,
69 )?;
70
71 let signature_witx = func_witx.wasm_signature(witx::CallMode::DefinedImport);
72 let params_count_witx = signature_witx.params.len() + signature_witx.results.len();
73 assert_eq!(
74 params_count_witx,
75 params_decomposed.len() + results_decomposed.len() + 1
76 );
77
78 Ok(())
79 }
80
81 fn define_func_raw<T: Write>(
82 w: &mut PrettyWriter<T>,
83 module_name: &str,
84 name: &str,
85 params_decomposed: &[ASTypeDecomposed],
86 results_decomposed: &[ASTypeDecomposed],
87 result: &ASResult,
88 ) -> Result<(), Error> {
89 let results_decomposed_deref = results_decomposed
90 .iter()
91 .map(|result_ptr_type| match result_ptr_type.type_.as_ref() {
92 ASType::MutPtr(result_type) => ASTypeDecomposed {
93 name: result_ptr_type.name.clone(),
94 type_: result_type.clone(),
95 },
96 _ => panic!("Result type is not a pointer"),
97 })
98 .collect::<Vec<_>>();
99 let results_set = results_decomposed_deref
100 .iter()
101 .map(|result| result.type_.as_lang())
102 .collect::<Vec<_>>();
103 let rust_fn_result_str = match results_set.len() {
104 0 => "()".to_string(),
105 1 => results_set[0].clone(),
106 _ => format!("({})", results_set.join(", ")),
107 };
108 w.indent()?.write(format!("pub fn {}(", name.as_fn()))?;
109 if !params_decomposed.is_empty() || !results_decomposed.is_empty() {
110 w.eol()?;
111 }
112 for param in params_decomposed {
113 w.write_line_continued(format!(
114 "{}: {},",
115 param.name.as_var(),
116 param.type_.as_lang(),
117 ))?;
118 }
119 w.write_line(format!(") -> Result<{}, Error> {{", rust_fn_result_str))?;
120 {
121 let mut w = w.new_block();
122
123 {
125 w.write_line(format!("#[link(wasm_import_module = \"{}\")]", module_name))?;
126 w.write_line("extern \"C\" {")?;
127 {
128 let mut w = w.new_block();
129 w.indent()?.write(format!("fn {}(", name.as_fn()))?;
130 if !params_decomposed.is_empty() {
131 w.eol()?;
132 }
133 for param in params_decomposed.iter().chain(results_decomposed.iter()) {
134 w.write_line_continued(format!(
135 "{}: {},",
136 param.name.as_var(),
137 param.type_.as_lang(),
138 ))?;
139 }
140 w.write_line(format!(") -> {};", result.error_type.as_lang()))?;
141 }
142 w.write_line("}")?;
143 }
144
145 for result in &results_decomposed_deref {
147 w.write_line(format!(
148 "let mut {} = std::mem::MaybeUninit::uninit();",
149 result.name.as_var()
150 ))?;
151 }
152
153 w.write_line(format!("let res = unsafe {{ {}(", name.as_fn()))?;
154 for param in params_decomposed {
155 w.write_line_continued(format!("{},", param.name.as_var()))?;
156 }
157 for result in results_decomposed_deref.iter() {
158 w.write_line_continued(format!("{}.as_mut_ptr(),", result.name.as_var()))?;
159 }
160 w.write_line(")};")?;
161 w.write_lines(
162 "if res != 0 {
163 return Err(Error::WasiError(res as _));
164}",
165 )?;
166 let res_str = match results_decomposed.len() {
167 0 => "()".to_string(),
168 1 => format!(
169 "unsafe {{ {}.assume_init() }}",
170 results_decomposed_deref[0].name.as_var()
171 ),
172 _ => format!(
173 "unsafe {{ ({}) }}",
174 results_decomposed_deref
175 .iter()
176 .map(|result| format!("{}.assume_init()", result.name.as_var()))
177 .collect::<Vec<_>>()
178 .join(", ")
179 ),
180 };
181 w.write_line(format!("Ok({})", res_str))?;
182 };
183 w.write_line("}")?;
184 w.eob()?;
185
186 Ok(())
187 }
188}