1use anyhow::Result;
4use heck::{ToPascalCase, ToSnakeCase};
5use std::fmt::Write;
6use zlink::idl::{CustomEnum, CustomObject, CustomType, Field, Interface, Method, Type};
7
/// Accumulates generated Rust source text for one or more Varlink interfaces.
pub struct CodeGenerator {
    /// Source text emitted so far; handed back by `output()`.
    output: String,
    /// Current nesting depth; every written line is prefixed once per level.
    indent_level: usize,
}
13
14impl CodeGenerator {
15 pub fn new() -> Self {
17 Self {
18 output: String::new(),
19 indent_level: 0,
20 }
21 }
22
23 pub fn output(self) -> String {
25 self.output
26 }
27
28 pub fn write_module_header(&mut self) -> Result<()> {
30 writeln!(
31 &mut self.output,
32 "// Generated code from Varlink IDL files."
33 )?;
34 writeln!(&mut self.output)?;
35 writeln!(&mut self.output, "use serde::{{Deserialize, Serialize}};")?;
36 writeln!(&mut self.output, "use zlink::{{proxy, ReplyError}};")?;
37 writeln!(&mut self.output)?;
38 Ok(())
39 }
40
41 pub fn generate_interface(
43 &mut self,
44 interface: &Interface<'_>,
45 skip_module_header: bool,
46 ) -> Result<()> {
47 if skip_module_header {
48 self.write_interface_comment(interface)?;
49 } else {
50 self.write_header(interface)?;
51 self.writeln("use serde::{Deserialize, Serialize};")?;
52 self.writeln("use zlink::{proxy, ReplyError};")?;
54 self.writeln("")?;
55 }
56
57 self.generate_proxy_trait(interface)?;
59 self.writeln("")?;
60
61 self.generate_output_structs(interface)?;
63
64 for custom_type in interface.custom_types() {
66 self.generate_custom_type(custom_type)?;
67 self.writeln("")?;
68 }
69
70 if interface.errors().count() > 0 {
72 self.generate_errors(interface)?;
73 self.writeln("")?;
74 }
75
76 Ok(())
77 }
78
79 fn write_interface_comment(&mut self, interface: &Interface<'_>) -> Result<()> {
80 writeln!(
81 &mut self.output,
82 "// Generated code for Varlink interface `{}`.",
83 interface.name()
84 )?;
85 writeln!(&mut self.output)?;
86 Ok(())
87 }
88
89 fn write_header(&mut self, interface: &Interface<'_>) -> Result<()> {
90 writeln!(
91 &mut self.output,
92 "//! Generated code for Varlink interface `{}`.",
93 interface.name()
94 )?;
95 writeln!(&mut self.output, "//!",)?;
96 writeln!(
97 &mut self.output,
98 "//! This code was generated by `zlink-codegen` from Varlink IDL.",
99 )?;
100 writeln!(
101 &mut self.output,
102 "//! You may prefer to adapt it, instead of using it verbatim.",
103 )?;
104 writeln!(&mut self.output)?;
105
106 for comment in interface.comments() {
108 writeln!(&mut self.output, "//! {}", comment.text())?;
109 }
110 writeln!(&mut self.output)?;
111
112 Ok(())
113 }
114
115 fn generate_custom_type(&mut self, custom_type: &CustomType<'_>) -> Result<()> {
116 match custom_type {
117 CustomType::Object(obj) => self.generate_custom_object(obj),
118 CustomType::Enum(enum_type) => self.generate_custom_enum(enum_type),
119 }
120 }
121
122 fn generate_custom_object(&mut self, obj: &CustomObject<'_>) -> Result<()> {
123 for comment in obj.comments() {
125 self.writeln(&format!("/// {}", comment.text()))?;
126 }
127
128 self.writeln("#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]")?;
129 self.writeln(&format!("pub struct {} {{", obj.name().to_pascal_case()))?;
130 self.indent();
131
132 for field in obj.fields() {
133 self.generate_field(field)?;
134 }
135
136 self.dedent();
137 self.writeln("}")?;
138
139 Ok(())
140 }
141
142 fn generate_custom_enum(&mut self, enum_type: &CustomEnum<'_>) -> Result<()> {
143 for comment in enum_type.comments() {
145 self.writeln(&format!("/// {}", comment.text()))?;
146 }
147
148 self.writeln("#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]")?;
149 self.writeln("#[serde(rename_all = \"snake_case\")]")?;
150 self.writeln(&format!(
151 "pub enum {} {{",
152 enum_type.name().to_pascal_case()
153 ))?;
154 self.indent();
155
156 for variant in enum_type.variants() {
157 for comment in variant.comments() {
159 self.writeln(&format!("/// {}", comment.text()))?;
160 }
161
162 self.writeln(&format!("{},", variant.name().to_pascal_case()))?;
164 }
165
166 self.dedent();
167 self.writeln("}")?;
168
169 Ok(())
170 }
171
172 fn generate_field(&mut self, field: &Field<'_>) -> Result<()> {
173 for comment in field.comments() {
175 self.writeln(&format!("/// {}", comment.text()))?;
176 }
177
178 let field_name = field.name().to_snake_case();
179 let rust_type = self.type_to_rust(field.ty())?;
180
181 let rust_type = if matches!(field.ty(), Type::Optional(_)) {
183 rust_type
185 } else {
186 rust_type
187 };
188
189 let field_name_attr = if is_rust_keyword(&field_name) || field_name != field.name() {
191 format!("#[serde(rename = \"{}\")]", field.name())
192 } else {
193 String::new()
194 };
195
196 if !field_name_attr.is_empty() {
197 self.writeln(&field_name_attr)?;
198 }
199
200 let safe_field_name = if is_rust_keyword(&field_name) {
201 format!("r#{}", field_name)
202 } else {
203 field_name
204 };
205
206 self.writeln(&format!("pub {}: {},", safe_field_name, rust_type))?;
207
208 Ok(())
209 }
210
211 fn generate_errors(&mut self, interface: &Interface<'_>) -> Result<()> {
212 self.writeln("/// Errors that can occur in this interface.")?;
213 self.writeln("#[derive(Debug, Clone, PartialEq, ReplyError)]")?;
214 self.writeln(&format!("#[zlink(interface = \"{}\")]", interface.name()))?;
215 self.writeln(&format!(
216 "pub enum {}Error {{",
217 interface_name_to_rust(interface.name())
218 ))?;
219 self.indent();
220
221 for error in interface.errors() {
222 for comment in error.comments() {
224 self.writeln(&format!("/// {}", comment.text()))?;
225 }
226
227 let variant_name = error.name().to_pascal_case();
228 if error.fields().count() == 0 {
229 self.writeln(&format!("{},", variant_name))?;
230 } else {
231 self.writeln(&format!("{} {{", variant_name))?;
232 self.indent();
233 for field in error.fields() {
234 self.generate_error_field(field)?;
235 }
236 self.dedent();
237 self.writeln("},")?;
238 }
239 }
240
241 self.dedent();
242 self.writeln("}")?;
243
244 Ok(())
245 }
246
247 fn generate_output_structs(&mut self, interface: &Interface<'_>) -> Result<()> {
249 for method in interface.methods() {
250 if method.outputs().count() > 0 {
254 let struct_name = format!("{}Output", method.name().to_pascal_case());
255
256 self.writeln(&format!(
258 "/// Output parameters for the {} method.",
259 method.name()
260 ))?;
261
262 let needs_lifetime = method.outputs().any(|o| type_needs_lifetime(o.ty()));
264
265 self.writeln("#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]")?;
266 if needs_lifetime {
267 self.writeln(&format!("pub struct {}<'a> {{", struct_name))?;
268 } else {
269 self.writeln(&format!("pub struct {} {{", struct_name))?;
270 }
271 self.indent();
272
273 for output in method.outputs() {
274 let field_name = output.name().to_snake_case();
275 let rust_type = if needs_lifetime {
277 self.type_to_rust_output(output.ty())?
278 } else {
279 self.type_to_rust(output.ty())?
280 };
281
282 if needs_lifetime && type_needs_borrow(output.ty()) {
284 self.writeln("#[serde(borrow)]")?;
285 }
286
287 if field_name != output.name() {
288 self.writeln(&format!("#[serde(rename = \"{}\")]", output.name()))?;
289 }
290
291 let safe_field_name = if is_rust_keyword(&field_name) {
292 format!("r#{}", field_name)
293 } else {
294 field_name
295 };
296
297 self.writeln(&format!("pub {}: {},", safe_field_name, rust_type))?;
298 }
299
300 self.dedent();
301 self.writeln("}")?;
302 self.writeln("")?;
303 }
304 }
305
306 Ok(())
307 }
308
309 fn generate_proxy_trait(&mut self, interface: &Interface<'_>) -> Result<()> {
310 let trait_name = interface_name_to_rust(interface.name());
311
312 let error_type = if interface.errors().count() > 0 {
314 format!("{}Error", interface_name_to_rust(interface.name()))
315 } else {
316 let stub_error_name = format!("{}Error", interface_name_to_rust(interface.name()));
318
319 self.writeln("/// Stub error type for interface without errors.")?;
321 self.writeln("///")?;
322 self.writeln("/// This is an empty enum that can never be instantiated.")?;
323 self.writeln("/// It exists only to satisfy the proxy trait requirements.")?;
324 self.writeln("#[derive(Debug, Clone, PartialEq, ReplyError)]")?;
325 self.writeln(&format!("#[zlink(interface = \"{}\")]", interface.name()))?;
326 self.writeln(&format!("pub enum {} {{}}", stub_error_name))?;
327 self.writeln("")?;
328
329 stub_error_name
330 };
331
332 self.writeln("/// Proxy trait for calling methods on the interface.")?;
333 self.writeln(&format!("#[proxy(\"{}\")]", interface.name()))?;
334 self.writeln(&format!("pub trait {} {{", trait_name))?;
335 self.indent();
336
337 for method in interface.methods() {
338 self.generate_proxy_method_signature(method, &error_type)?;
339 }
340
341 self.dedent();
342 self.writeln("}")?;
343
344 Ok(())
345 }
346
347 fn generate_proxy_method_signature(
348 &mut self,
349 method: &Method<'_>,
350 error_type: &str,
351 ) -> Result<()> {
352 for comment in method.comments() {
354 self.writeln(&format!("/// {}", comment.text()))?;
355 }
356
357 let method_name = method.name().to_snake_case();
358 let safe_method_name = if is_rust_keyword(&method_name) {
359 format!("r#{}", method_name)
360 } else {
361 method_name
362 };
363
364 let mut signature = format!("async fn {}(&mut self", safe_method_name);
366
367 for param in method.inputs() {
369 let param_name = param.name().to_snake_case();
370 let safe_param_name = if is_rust_keyword(¶m_name) {
371 format!("r#{}", param_name)
372 } else {
373 param_name
374 };
375 let rust_type = self.type_to_rust_param(param.ty())?;
377 write!(&mut signature, ", {}: {}", safe_param_name, rust_type)?;
378 }
379
380 signature.push_str(") -> zlink::Result<Result<");
381
382 let output_count = method.outputs().count();
384 if output_count == 0 {
385 signature.push_str("()");
386 } else {
387 let struct_name = format!("{}Output", method.name().to_pascal_case());
391 let needs_lifetime = method.outputs().any(|o| type_needs_lifetime(o.ty()));
393 if needs_lifetime {
394 signature.push_str(&format!("{}<'_>", struct_name));
395 } else {
396 signature.push_str(&struct_name);
397 }
398 }
399
400 write!(&mut signature, ", {}>>", error_type)?;
401 signature.push(';');
402
403 self.writeln(&signature)?;
404
405 Ok(())
406 }
407
408 fn generate_error_field(&mut self, field: &Field<'_>) -> Result<()> {
409 for comment in field.comments() {
411 self.writeln(&format!("/// {}", comment.text()))?;
412 }
413
414 let field_name = field.name().to_snake_case();
415 let rust_type = self.type_to_rust(field.ty())?;
416
417 let field_name_attr = if is_rust_keyword(&field_name) || field_name != field.name() {
419 format!("#[serde(rename = \"{}\")]", field.name())
420 } else {
421 String::new()
422 };
423
424 if !field_name_attr.is_empty() {
425 self.writeln(&field_name_attr)?;
426 }
427
428 let safe_field_name = if is_rust_keyword(&field_name) {
429 format!("r#{}", field_name)
430 } else {
431 field_name
432 };
433
434 self.writeln(&format!("{}: {},", safe_field_name, rust_type))?;
435
436 Ok(())
437 }
438
439 fn type_to_rust(&self, ty: &Type) -> Result<String> {
440 type_to_rust(ty)
441 }
442
443 fn type_to_rust_param(&self, ty: &Type) -> Result<String> {
444 type_to_rust_param(ty)
445 }
446
447 fn type_to_rust_output(&self, ty: &Type) -> Result<String> {
448 type_to_rust_output(ty)
449 }
450
451 fn writeln(&mut self, s: &str) -> Result<()> {
452 self.write(s)?;
453 writeln!(&mut self.output)?;
454 Ok(())
455 }
456
457 fn write(&mut self, s: &str) -> Result<()> {
458 for _ in 0..self.indent_level {
459 write!(&mut self.output, " ")?;
460 }
461 write!(&mut self.output, "{}", s)?;
462 Ok(())
463 }
464
465 fn indent(&mut self) {
466 self.indent_level += 1;
467 }
468
469 fn dedent(&mut self) {
470 if self.indent_level > 0 {
471 self.indent_level -= 1;
472 }
473 }
474}
475
476impl Default for CodeGenerator {
477 fn default() -> Self {
478 Self::new()
479 }
480}
481
482fn type_to_rust(ty: &Type) -> Result<String> {
483 Ok(match ty {
484 Type::Bool => "bool".to_string(),
485 Type::Int => "i64".to_string(),
486 Type::Float => "f64".to_string(),
487 Type::String => "String".to_string(),
488 Type::Object(_fields) => {
489 "serde_json::Value".to_string()
493 }
494 Type::Enum(_variants) => {
495 "String".to_string()
497 }
498 Type::Array(elem_type) => {
499 let elem_rust = type_to_rust(elem_type.inner())?;
500 format!("Vec<{}>", elem_rust)
501 }
502 Type::Map(value_type) => {
503 let value_rust = type_to_rust(value_type.inner())?;
504 format!("std::collections::HashMap<String, {}>", value_rust)
505 }
506 Type::ForeignObject => "serde_json::Value".to_string(),
507 Type::Optional(inner_type) => {
508 let inner_rust = type_to_rust(inner_type.inner())?;
509 format!("Option<{}>", inner_rust)
510 }
511 Type::Custom(name) => name.to_pascal_case(),
512 })
513}
514
515fn type_to_rust_param(ty: &Type) -> Result<String> {
516 Ok(match ty {
517 Type::Bool => "bool".to_string(),
518 Type::Int => "i64".to_string(),
519 Type::Float => "f64".to_string(),
520 Type::String => "&str".to_string(),
521 Type::Object(_fields) => {
522 "&serde_json::Value".to_string()
524 }
525 Type::Enum(_variants) => {
526 "&str".to_string()
528 }
529 Type::Array(elem_type) => {
530 let elem_rust = type_to_rust_param_elem(elem_type.inner())?;
532 format!("&[{}]", elem_rust)
533 }
534 Type::Map(value_type) => {
535 let value_rust = type_to_rust_param_elem(value_type.inner())?;
537 format!("&std::collections::HashMap<&str, {}>", value_rust)
538 }
539 Type::ForeignObject => "&serde_json::Value".to_string(),
540 Type::Optional(inner_type) => {
541 let inner_rust = type_to_rust_param(inner_type.inner())?;
542 format!("Option<{}>", inner_rust)
544 }
545 Type::Custom(name) => format!("&{}", name.to_pascal_case()),
546 })
547}
548
549fn type_to_rust_param_elem(ty: &Type) -> Result<String> {
552 Ok(match ty {
553 Type::Bool => "bool".to_string(),
554 Type::Int => "i64".to_string(),
555 Type::Float => "f64".to_string(),
556 Type::String => "&str".to_string(),
557 Type::Object(_fields) => "serde_json::Value".to_string(),
558 Type::Enum(_variants) => "&str".to_string(),
559 Type::Array(elem_type) => {
560 let elem_rust = type_to_rust_param_elem(elem_type.inner())?;
561 format!("Vec<{}>", elem_rust)
562 }
563 Type::Map(value_type) => {
564 let value_rust = type_to_rust_param_elem(value_type.inner())?;
565 format!("std::collections::HashMap<&str, {}>", value_rust)
566 }
567 Type::ForeignObject => "serde_json::Value".to_string(),
568 Type::Optional(inner_type) => {
569 let inner_rust = type_to_rust_param_elem(inner_type.inner())?;
570 format!("Option<{}>", inner_rust)
571 }
572 Type::Custom(name) => name.to_pascal_case(),
573 })
574}
575
576fn type_to_rust_output(ty: &Type) -> Result<String> {
577 Ok(match ty {
578 Type::Bool => "bool".to_string(),
579 Type::Int => "i64".to_string(),
580 Type::Float => "f64".to_string(),
581 Type::String => "&'a str".to_string(),
582 Type::Object(_fields) => {
583 "serde_json::Value".to_string()
585 }
586 Type::Enum(_variants) => {
587 "&'a str".to_string()
589 }
590 Type::Array(elem_type) => {
591 let elem_rust = match elem_type.inner() {
593 Type::String => "&'a str".to_string(),
594 Type::Enum(_) => "&'a str".to_string(),
595 _ => type_to_rust(elem_type.inner())?,
596 };
597 format!("Vec<{}>", elem_rust)
598 }
599 Type::Map(value_type) => {
600 let value_rust = match value_type.inner() {
602 Type::String => "&'a str".to_string(),
603 Type::Enum(_) => "&'a str".to_string(),
604 _ => type_to_rust(value_type.inner())?,
605 };
606 format!("std::collections::HashMap<&'a str, {}>", value_rust)
607 }
608 Type::ForeignObject => "serde_json::Value".to_string(),
609 Type::Optional(inner_type) => {
610 let inner_rust = type_to_rust_output(inner_type.inner())?;
613 format!("Option<{}>", inner_rust)
614 }
615 Type::Custom(name) => name.to_pascal_case(),
616 })
617}
618
619fn interface_name_to_rust(name: &str) -> String {
620 name.split('.').next_back().unwrap_or(name).to_pascal_case()
622}
623
624fn type_needs_lifetime(ty: &Type) -> bool {
625 match ty {
626 Type::String => true,
627 Type::Enum(_) => true, Type::Array(inner) => type_needs_lifetime(inner.inner()),
629 Type::Map(_) => {
630 true
632 }
633 Type::Optional(inner) => type_needs_lifetime(inner.inner()),
634 _ => false,
635 }
636}
637
638fn type_needs_borrow(ty: &Type) -> bool {
639 match ty {
640 Type::String => true,
641 Type::Enum(_) => true, Type::Array(inner) => type_needs_borrow(inner.inner()),
643 Type::Map(_) => {
644 true
646 }
647 Type::Optional(inner) => type_needs_borrow(inner.inner()),
648 _ => false,
649 }
650}
651
/// Returns `true` if `s` is a Rust keyword and so cannot be used verbatim as an identifier.
///
/// Covers the strict keywords and the reserved-for-future-use keywords (`try`, `yield`,
/// `box`, `final`, ...), which rustc also rejects as plain identifiers but accepts with the
/// `r#` prefix. `gen` is only reserved from the 2024 edition on, but `r#gen` is valid in
/// every edition, so escaping it is always safe. Note that `self`, `Self`, `super` and
/// `crate` are reported as keywords too, yet cannot be written as raw identifiers at all.
fn is_rust_keyword(s: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        // Strict keywords.
        "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
        "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
        "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait",
        "true", "type", "unsafe", "use", "where", "while",
        // Reserved keywords: unused today but still rejected as plain identifiers.
        "abstract", "become", "box", "do", "final", "gen", "macro", "override", "priv", "try",
        "typeof", "unsized", "virtual", "yield",
    ];
    KEYWORDS.contains(&s)
}