zlink_codegen/
codegen.rs

1//! Code generation implementation.
2
3use anyhow::Result;
4use heck::{ToPascalCase, ToSnakeCase};
5use std::fmt::Write;
6use zlink::idl::{CustomEnum, CustomObject, CustomType, Field, Interface, Method, Type};
7
/// Code generator for Varlink interfaces.
///
/// Generated Rust source accumulates in an in-memory buffer, which is handed back by
/// `CodeGenerator::output`.
pub struct CodeGenerator {
    /// Current nesting depth; every level indents emitted lines by four spaces.
    indent_level: usize,
    /// Source text generated so far.
    output: String,
}
14impl CodeGenerator {
15    /// Create a new code generator.
16    pub fn new() -> Self {
17        Self {
18            output: String::new(),
19            indent_level: 0,
20        }
21    }
22
23    /// Get the generated output.
24    pub fn output(self) -> String {
25        self.output
26    }
27
28    /// Write module-level header for multiple interfaces.
29    pub fn write_module_header(&mut self) -> Result<()> {
30        writeln!(
31            &mut self.output,
32            "// Generated code from Varlink IDL files."
33        )?;
34        writeln!(&mut self.output)?;
35        writeln!(&mut self.output, "use serde::{{Deserialize, Serialize}};")?;
36        writeln!(&mut self.output, "use zlink::{{proxy, ReplyError}};")?;
37        writeln!(&mut self.output)?;
38        Ok(())
39    }
40
41    /// Generate code for an interface.
42    pub fn generate_interface(
43        &mut self,
44        interface: &Interface<'_>,
45        skip_module_header: bool,
46    ) -> Result<()> {
47        if skip_module_header {
48            self.write_interface_comment(interface)?;
49        } else {
50            self.write_header(interface)?;
51            self.writeln("use serde::{Deserialize, Serialize};")?;
52            // Always import ReplyError since we generate a stub error type when there are no errors
53            self.writeln("use zlink::{proxy, ReplyError};")?;
54            self.writeln("")?;
55        }
56
57        // Generate proxy trait using the proxy macro.
58        self.generate_proxy_trait(interface)?;
59        self.writeln("")?;
60
61        // Generate output structs for methods.
62        self.generate_output_structs(interface)?;
63
64        // Generate custom types.
65        for custom_type in interface.custom_types() {
66            self.generate_custom_type(custom_type)?;
67            self.writeln("")?;
68        }
69
70        // Generate errors.
71        if interface.errors().count() > 0 {
72            self.generate_errors(interface)?;
73            self.writeln("")?;
74        }
75
76        Ok(())
77    }
78
79    fn write_interface_comment(&mut self, interface: &Interface<'_>) -> Result<()> {
80        writeln!(
81            &mut self.output,
82            "// Generated code for Varlink interface `{}`.",
83            interface.name()
84        )?;
85        writeln!(&mut self.output)?;
86        Ok(())
87    }
88
89    fn write_header(&mut self, interface: &Interface<'_>) -> Result<()> {
90        writeln!(
91            &mut self.output,
92            "//! Generated code for Varlink interface `{}`.",
93            interface.name()
94        )?;
95        writeln!(&mut self.output, "//!",)?;
96        writeln!(
97            &mut self.output,
98            "//! This code was generated by `zlink-codegen` from Varlink IDL.",
99        )?;
100        writeln!(
101            &mut self.output,
102            "//! You may prefer to adapt it, instead of using it verbatim.",
103        )?;
104        writeln!(&mut self.output)?;
105
106        // Add interface comments if any.
107        for comment in interface.comments() {
108            writeln!(&mut self.output, "//! {}", comment.text())?;
109        }
110        writeln!(&mut self.output)?;
111
112        Ok(())
113    }
114
115    fn generate_custom_type(&mut self, custom_type: &CustomType<'_>) -> Result<()> {
116        match custom_type {
117            CustomType::Object(obj) => self.generate_custom_object(obj),
118            CustomType::Enum(enum_type) => self.generate_custom_enum(enum_type),
119        }
120    }
121
122    fn generate_custom_object(&mut self, obj: &CustomObject<'_>) -> Result<()> {
123        // Add comments.
124        for comment in obj.comments() {
125            self.writeln(&format!("/// {}", comment.text()))?;
126        }
127
128        self.writeln("#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]")?;
129        self.writeln(&format!("pub struct {} {{", obj.name().to_pascal_case()))?;
130        self.indent();
131
132        for field in obj.fields() {
133            self.generate_field(field)?;
134        }
135
136        self.dedent();
137        self.writeln("}")?;
138
139        Ok(())
140    }
141
142    fn generate_custom_enum(&mut self, enum_type: &CustomEnum<'_>) -> Result<()> {
143        // Add comments.
144        for comment in enum_type.comments() {
145            self.writeln(&format!("/// {}", comment.text()))?;
146        }
147
148        self.writeln("#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]")?;
149        self.writeln("#[serde(rename_all = \"snake_case\")]")?;
150        self.writeln(&format!(
151            "pub enum {} {{",
152            enum_type.name().to_pascal_case()
153        ))?;
154        self.indent();
155
156        for variant in enum_type.variants() {
157            // Add variant comments.
158            for comment in variant.comments() {
159                self.writeln(&format!("/// {}", comment.text()))?;
160            }
161
162            // Varlink enum variants don't have explicit values, just names.
163            self.writeln(&format!("{},", variant.name().to_pascal_case()))?;
164        }
165
166        self.dedent();
167        self.writeln("}")?;
168
169        Ok(())
170    }
171
172    fn generate_field(&mut self, field: &Field<'_>) -> Result<()> {
173        // Add field comments.
174        for comment in field.comments() {
175            self.writeln(&format!("/// {}", comment.text()))?;
176        }
177
178        let field_name = field.name().to_snake_case();
179        let rust_type = self.type_to_rust(field.ty())?;
180
181        // Check if the field type is optional.
182        let rust_type = if matches!(field.ty(), Type::Optional(_)) {
183            // The type_to_rust will already wrap in Option
184            rust_type
185        } else {
186            rust_type
187        };
188
189        // Handle field name if it's a Rust keyword.
190        let field_name_attr = if is_rust_keyword(&field_name) || field_name != field.name() {
191            format!("#[serde(rename = \"{}\")]", field.name())
192        } else {
193            String::new()
194        };
195
196        if !field_name_attr.is_empty() {
197            self.writeln(&field_name_attr)?;
198        }
199
200        let safe_field_name = if is_rust_keyword(&field_name) {
201            format!("r#{}", field_name)
202        } else {
203            field_name
204        };
205
206        self.writeln(&format!("pub {}: {},", safe_field_name, rust_type))?;
207
208        Ok(())
209    }
210
211    fn generate_errors(&mut self, interface: &Interface<'_>) -> Result<()> {
212        self.writeln("/// Errors that can occur in this interface.")?;
213        self.writeln("#[derive(Debug, Clone, PartialEq, ReplyError)]")?;
214        self.writeln(&format!("#[zlink(interface = \"{}\")]", interface.name()))?;
215        self.writeln(&format!(
216            "pub enum {}Error {{",
217            interface_name_to_rust(interface.name())
218        ))?;
219        self.indent();
220
221        for error in interface.errors() {
222            // Add error comments.
223            for comment in error.comments() {
224                self.writeln(&format!("/// {}", comment.text()))?;
225            }
226
227            let variant_name = error.name().to_pascal_case();
228            if error.fields().count() == 0 {
229                self.writeln(&format!("{},", variant_name))?;
230            } else {
231                self.writeln(&format!("{} {{", variant_name))?;
232                self.indent();
233                for field in error.fields() {
234                    self.generate_error_field(field)?;
235                }
236                self.dedent();
237                self.writeln("},")?;
238            }
239        }
240
241        self.dedent();
242        self.writeln("}")?;
243
244        Ok(())
245    }
246
247    /// Generate output structs for all methods in the `interface`.
248    fn generate_output_structs(&mut self, interface: &Interface<'_>) -> Result<()> {
249        for method in interface.methods() {
250            // Generate output struct for any method with at least one output parameter.
251            // Varlink output parameters are always named, so we need a struct even for single
252            // outputs.
253            if method.outputs().count() > 0 {
254                let struct_name = format!("{}Output", method.name().to_pascal_case());
255
256                // Add method comments if available
257                self.writeln(&format!(
258                    "/// Output parameters for the {} method.",
259                    method.name()
260                ))?;
261
262                // Add lifetime parameter for output structs that need it
263                let needs_lifetime = method.outputs().any(|o| type_needs_lifetime(o.ty()));
264
265                self.writeln("#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]")?;
266                if needs_lifetime {
267                    self.writeln(&format!("pub struct {}<'a> {{", struct_name))?;
268                } else {
269                    self.writeln(&format!("pub struct {} {{", struct_name))?;
270                }
271                self.indent();
272
273                for output in method.outputs() {
274                    let field_name = output.name().to_snake_case();
275                    // Use reference types for output parameters where appropriate
276                    let rust_type = if needs_lifetime {
277                        self.type_to_rust_output(output.ty())?
278                    } else {
279                        self.type_to_rust(output.ty())?
280                    };
281
282                    // Add #[serde(borrow)] for fields that need it
283                    if needs_lifetime && type_needs_borrow(output.ty()) {
284                        self.writeln("#[serde(borrow)]")?;
285                    }
286
287                    if field_name != output.name() {
288                        self.writeln(&format!("#[serde(rename = \"{}\")]", output.name()))?;
289                    }
290
291                    let safe_field_name = if is_rust_keyword(&field_name) {
292                        format!("r#{}", field_name)
293                    } else {
294                        field_name
295                    };
296
297                    self.writeln(&format!("pub {}: {},", safe_field_name, rust_type))?;
298                }
299
300                self.dedent();
301                self.writeln("}")?;
302                self.writeln("")?;
303            }
304        }
305
306        Ok(())
307    }
308
309    fn generate_proxy_trait(&mut self, interface: &Interface<'_>) -> Result<()> {
310        let trait_name = interface_name_to_rust(interface.name());
311
312        // Generate a stub error type if there are no errors in the interface
313        let error_type = if interface.errors().count() > 0 {
314            format!("{}Error", interface_name_to_rust(interface.name()))
315        } else {
316            // Generate a stub error type for interfaces without errors
317            let stub_error_name = format!("{}Error", interface_name_to_rust(interface.name()));
318
319            // Generate the stub error type before the proxy trait
320            self.writeln("/// Stub error type for interface without errors.")?;
321            self.writeln("///")?;
322            self.writeln("/// This is an empty enum that can never be instantiated.")?;
323            self.writeln("/// It exists only to satisfy the proxy trait requirements.")?;
324            self.writeln("#[derive(Debug, Clone, PartialEq, ReplyError)]")?;
325            self.writeln(&format!("#[zlink(interface = \"{}\")]", interface.name()))?;
326            self.writeln(&format!("pub enum {} {{}}", stub_error_name))?;
327            self.writeln("")?;
328
329            stub_error_name
330        };
331
332        self.writeln("/// Proxy trait for calling methods on the interface.")?;
333        self.writeln(&format!("#[proxy(\"{}\")]", interface.name()))?;
334        self.writeln(&format!("pub trait {} {{", trait_name))?;
335        self.indent();
336
337        for method in interface.methods() {
338            self.generate_proxy_method_signature(method, &error_type)?;
339        }
340
341        self.dedent();
342        self.writeln("}")?;
343
344        Ok(())
345    }
346
347    fn generate_proxy_method_signature(
348        &mut self,
349        method: &Method<'_>,
350        error_type: &str,
351    ) -> Result<()> {
352        // Add method comments.
353        for comment in method.comments() {
354            self.writeln(&format!("/// {}", comment.text()))?;
355        }
356
357        let method_name = method.name().to_snake_case();
358        let safe_method_name = if is_rust_keyword(&method_name) {
359            format!("r#{}", method_name)
360        } else {
361            method_name
362        };
363
364        // Generate method signature.
365        let mut signature = format!("async fn {}(&mut self", safe_method_name);
366
367        // Add input parameters.
368        for param in method.inputs() {
369            let param_name = param.name().to_snake_case();
370            let safe_param_name = if is_rust_keyword(&param_name) {
371                format!("r#{}", param_name)
372            } else {
373                param_name
374            };
375            // Use references for parameters that can be borrowed
376            let rust_type = self.type_to_rust_param(param.ty())?;
377            write!(&mut signature, ", {}: {}", safe_param_name, rust_type)?;
378        }
379
380        signature.push_str(") -> zlink::Result<Result<");
381
382        // Handle output parameters.
383        let output_count = method.outputs().count();
384        if output_count == 0 {
385            signature.push_str("()");
386        } else {
387            // Always use the generated output struct for any outputs.
388            // Varlink output parameters are always named, so we need a struct even for single
389            // outputs.
390            let struct_name = format!("{}Output", method.name().to_pascal_case());
391            // Add lifetime parameter if the struct needs one
392            let needs_lifetime = method.outputs().any(|o| type_needs_lifetime(o.ty()));
393            if needs_lifetime {
394                signature.push_str(&format!("{}<'_>", struct_name));
395            } else {
396                signature.push_str(&struct_name);
397            }
398        }
399
400        write!(&mut signature, ", {}>>", error_type)?;
401        signature.push(';');
402
403        self.writeln(&signature)?;
404
405        Ok(())
406    }
407
408    fn generate_error_field(&mut self, field: &Field<'_>) -> Result<()> {
409        // Add field comments.
410        for comment in field.comments() {
411            self.writeln(&format!("/// {}", comment.text()))?;
412        }
413
414        let field_name = field.name().to_snake_case();
415        let rust_type = self.type_to_rust(field.ty())?;
416
417        // Handle field name if it's a Rust keyword.
418        let field_name_attr = if is_rust_keyword(&field_name) || field_name != field.name() {
419            format!("#[serde(rename = \"{}\")]", field.name())
420        } else {
421            String::new()
422        };
423
424        if !field_name_attr.is_empty() {
425            self.writeln(&field_name_attr)?;
426        }
427
428        let safe_field_name = if is_rust_keyword(&field_name) {
429            format!("r#{}", field_name)
430        } else {
431            field_name
432        };
433
434        self.writeln(&format!("{}: {},", safe_field_name, rust_type))?;
435
436        Ok(())
437    }
438
439    fn type_to_rust(&self, ty: &Type) -> Result<String> {
440        type_to_rust(ty)
441    }
442
443    fn type_to_rust_param(&self, ty: &Type) -> Result<String> {
444        type_to_rust_param(ty)
445    }
446
447    fn type_to_rust_output(&self, ty: &Type) -> Result<String> {
448        type_to_rust_output(ty)
449    }
450
451    fn writeln(&mut self, s: &str) -> Result<()> {
452        self.write(s)?;
453        writeln!(&mut self.output)?;
454        Ok(())
455    }
456
457    fn write(&mut self, s: &str) -> Result<()> {
458        for _ in 0..self.indent_level {
459            write!(&mut self.output, "    ")?;
460        }
461        write!(&mut self.output, "{}", s)?;
462        Ok(())
463    }
464
465    fn indent(&mut self) {
466        self.indent_level += 1;
467    }
468
469    fn dedent(&mut self) {
470        if self.indent_level > 0 {
471            self.indent_level -= 1;
472        }
473    }
474}
475
476impl Default for CodeGenerator {
477    fn default() -> Self {
478        Self::new()
479    }
480}
481
482fn type_to_rust(ty: &Type) -> Result<String> {
483    Ok(match ty {
484        Type::Bool => "bool".to_string(),
485        Type::Int => "i64".to_string(),
486        Type::Float => "f64".to_string(),
487        Type::String => "String".to_string(),
488        Type::Object(_fields) => {
489            // Anonymous struct - generate inline.
490            // For now, use serde_json::Value for anonymous objects.
491            // In the future, we could generate anonymous structs.
492            "serde_json::Value".to_string()
493        }
494        Type::Enum(_variants) => {
495            // Anonymous enum - use String for now.
496            "String".to_string()
497        }
498        Type::Array(elem_type) => {
499            let elem_rust = type_to_rust(elem_type.inner())?;
500            format!("Vec<{}>", elem_rust)
501        }
502        Type::Map(value_type) => {
503            let value_rust = type_to_rust(value_type.inner())?;
504            format!("std::collections::HashMap<String, {}>", value_rust)
505        }
506        Type::ForeignObject => "serde_json::Value".to_string(),
507        Type::Optional(inner_type) => {
508            let inner_rust = type_to_rust(inner_type.inner())?;
509            format!("Option<{}>", inner_rust)
510        }
511        Type::Custom(name) => name.to_pascal_case(),
512    })
513}
514
515fn type_to_rust_param(ty: &Type) -> Result<String> {
516    Ok(match ty {
517        Type::Bool => "bool".to_string(),
518        Type::Int => "i64".to_string(),
519        Type::Float => "f64".to_string(),
520        Type::String => "&str".to_string(),
521        Type::Object(_fields) => {
522            // For parameters, use reference to avoid clone
523            "&serde_json::Value".to_string()
524        }
525        Type::Enum(_variants) => {
526            // Anonymous enum - use &str for parameters
527            "&str".to_string()
528        }
529        Type::Array(elem_type) => {
530            // Use slice for array parameters with proper string handling
531            let elem_rust = type_to_rust_param_elem(elem_type.inner())?;
532            format!("&[{}]", elem_rust)
533        }
534        Type::Map(value_type) => {
535            // Use reference for map parameters with proper string handling
536            let value_rust = type_to_rust_param_elem(value_type.inner())?;
537            format!("&std::collections::HashMap<&str, {}>", value_rust)
538        }
539        Type::ForeignObject => "&serde_json::Value".to_string(),
540        Type::Optional(inner_type) => {
541            let inner_rust = type_to_rust_param(inner_type.inner())?;
542            // For optional parameters, always wrap in Option
543            format!("Option<{}>", inner_rust)
544        }
545        Type::Custom(name) => format!("&{}", name.to_pascal_case()),
546    })
547}
548
549// Helper function to get the proper type for collection elements in parameters.
550// Ensures strings always use &str instead of String.
551fn type_to_rust_param_elem(ty: &Type) -> Result<String> {
552    Ok(match ty {
553        Type::Bool => "bool".to_string(),
554        Type::Int => "i64".to_string(),
555        Type::Float => "f64".to_string(),
556        Type::String => "&str".to_string(),
557        Type::Object(_fields) => "serde_json::Value".to_string(),
558        Type::Enum(_variants) => "&str".to_string(),
559        Type::Array(elem_type) => {
560            let elem_rust = type_to_rust_param_elem(elem_type.inner())?;
561            format!("Vec<{}>", elem_rust)
562        }
563        Type::Map(value_type) => {
564            let value_rust = type_to_rust_param_elem(value_type.inner())?;
565            format!("std::collections::HashMap<&str, {}>", value_rust)
566        }
567        Type::ForeignObject => "serde_json::Value".to_string(),
568        Type::Optional(inner_type) => {
569            let inner_rust = type_to_rust_param_elem(inner_type.inner())?;
570            format!("Option<{}>", inner_rust)
571        }
572        Type::Custom(name) => name.to_pascal_case(),
573    })
574}
575
576fn type_to_rust_output(ty: &Type) -> Result<String> {
577    Ok(match ty {
578        Type::Bool => "bool".to_string(),
579        Type::Int => "i64".to_string(),
580        Type::Float => "f64".to_string(),
581        Type::String => "&'a str".to_string(),
582        Type::Object(_fields) => {
583            // Use owned type for objects - serde can't deserialize to &Value
584            "serde_json::Value".to_string()
585        }
586        Type::Enum(_variants) => {
587            // Anonymous enum - use &str for outputs
588            "&'a str".to_string()
589        }
590        Type::Array(elem_type) => {
591            // Use Vec for array outputs with owned inner types (except strings stay as &'a str)
592            let elem_rust = match elem_type.inner() {
593                Type::String => "&'a str".to_string(),
594                Type::Enum(_) => "&'a str".to_string(),
595                _ => type_to_rust(elem_type.inner())?,
596            };
597            format!("Vec<{}>", elem_rust)
598        }
599        Type::Map(value_type) => {
600            // Use HashMap for map outputs with borrowed types for efficiency
601            let value_rust = match value_type.inner() {
602                Type::String => "&'a str".to_string(),
603                Type::Enum(_) => "&'a str".to_string(),
604                _ => type_to_rust(value_type.inner())?,
605            };
606            format!("std::collections::HashMap<&'a str, {}>", value_rust)
607        }
608        Type::ForeignObject => "serde_json::Value".to_string(),
609        Type::Optional(inner_type) => {
610            // For optional outputs, recursively apply type_to_rust_output to maintain
611            // correct reference types for strings within collections
612            let inner_rust = type_to_rust_output(inner_type.inner())?;
613            format!("Option<{}>", inner_rust)
614        }
615        Type::Custom(name) => name.to_pascal_case(),
616    })
617}
618
619fn interface_name_to_rust(name: &str) -> String {
620    // Convert interface name like "org.example.Interface" to "Interface".
621    name.split('.').next_back().unwrap_or(name).to_pascal_case()
622}
623
624fn type_needs_lifetime(ty: &Type) -> bool {
625    match ty {
626        Type::String => true,
627        Type::Enum(_) => true, // Anonymous enums use &'a str
628        Type::Array(inner) => type_needs_lifetime(inner.inner()),
629        Type::Map(_) => {
630            // Maps always need lifetime because keys are &'a str
631            true
632        }
633        Type::Optional(inner) => type_needs_lifetime(inner.inner()),
634        _ => false,
635    }
636}
637
638fn type_needs_borrow(ty: &Type) -> bool {
639    match ty {
640        Type::String => true,
641        Type::Enum(_) => true, // Anonymous enums use &'a str
642        Type::Array(inner) => type_needs_borrow(inner.inner()),
643        Type::Map(_) => {
644            // Maps always need borrow because keys are &'a str
645            true
646        }
647        Type::Optional(inner) => type_needs_borrow(inner.inner()),
648        _ => false,
649    }
650}
651
/// Report whether `s` is a Rust keyword that cannot be used as a plain identifier.
///
/// Covers the strict keywords (including those added in the 2018 edition) and the reserved
/// keywords such as `yield`, `try`, `box` or `abstract`, which are rejected as identifiers
/// even though they currently have no meaning. Callers escape such names.
fn is_rust_keyword(s: &str) -> bool {
    matches!(
        s,
        // Strict keywords.
        "as" | "async"
            | "await"
            | "break"
            | "const"
            | "continue"
            | "crate"
            | "dyn"
            | "else"
            | "enum"
            | "extern"
            | "false"
            | "fn"
            | "for"
            | "if"
            | "impl"
            | "in"
            | "let"
            | "loop"
            | "match"
            | "mod"
            | "move"
            | "mut"
            | "pub"
            | "ref"
            | "return"
            | "self"
            | "Self"
            | "static"
            | "struct"
            | "super"
            | "trait"
            | "true"
            | "type"
            | "unsafe"
            | "use"
            | "where"
            | "while"
            // Reserved keywords: unused today but still rejected as identifiers.
            | "abstract"
            | "become"
            | "box"
            | "do"
            | "final"
            | "gen"
            | "macro"
            | "override"
            | "priv"
            | "try"
            | "typeof"
            | "unsized"
            | "virtual"
            | "yield"
    )
}