// waynest_gen/parser.rs

1use std::{fmt::Display, fs, io, path::Path};
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quick_xml::DeError;
6use quote::{ToTokens, quote};
7use serde::{Deserialize, Serialize};
8use tracing::debug;
9
10use crate::utils::make_ident;
11
12#[derive(Debug, thiserror::Error)]
13pub enum Error {
14    #[error("{0}")]
15    IoError(#[from] io::Error),
16    #[error("{0}")]
17    Decode(#[from] DeError),
18}
19
20#[derive(Debug, Deserialize, Serialize, Clone)]
21pub struct Pair {
22    pub protocol: Protocol,
23    pub module: String,
24}
25
26#[derive(Debug, Deserialize, Serialize, Clone)]
27#[serde(deny_unknown_fields)]
28pub struct Protocol {
29    #[serde(rename(deserialize = "@name"))]
30    pub name: String,
31    pub copyright: Option<String>,
32    pub description: Option<String>,
33    #[serde(default, rename(deserialize = "interface"))]
34    pub interfaces: Vec<Interface>,
35}
36
37#[derive(Debug, Deserialize, Serialize, Clone)]
38#[serde(deny_unknown_fields)]
39pub struct Interface {
40    #[serde(rename(deserialize = "@name"))]
41    pub name: String,
42    #[serde(rename(deserialize = "@version"))]
43    pub version: u32,
44    pub description: Option<String>,
45    #[serde(default, rename(deserialize = "request"))]
46    pub requests: Vec<Message>,
47    #[serde(default, rename(deserialize = "event"))]
48    pub events: Vec<Message>,
49    #[serde(default, rename(deserialize = "enum"))]
50    pub enums: Vec<Enum>,
51}
52
53#[derive(Debug, Deserialize, Serialize, Clone)]
54#[serde(deny_unknown_fields)]
55pub struct Message {
56    #[serde(rename(deserialize = "@name"))]
57    pub name: String,
58    #[serde(rename(deserialize = "@version"))]
59    pub version: Option<u32>,
60    #[serde(rename(deserialize = "@type"))]
61    pub ty: Option<MessageType>,
62    #[serde(rename(deserialize = "@since"))]
63    pub since: Option<usize>,
64    #[serde(rename(deserialize = "@deprecated-since"))]
65    pub deprecated_since: Option<usize>,
66    pub description: Option<String>,
67    #[serde(default, rename(deserialize = "arg"))]
68    pub args: Vec<Arg>,
69}
70
71#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
72pub enum MessageType {
73    #[serde(rename(deserialize = "destructor"))]
74    Destructor,
75}
76
77#[derive(Debug, Deserialize, Serialize, Clone)]
78#[serde(deny_unknown_fields)]
79pub struct Arg {
80    #[serde(rename(deserialize = "@name"))]
81    pub name: String,
82    #[serde(rename(deserialize = "@type"))]
83    pub ty: ArgType,
84    #[serde(rename(deserialize = "@interface"))]
85    pub interface: Option<String>,
86    #[serde(rename(deserialize = "@enum"))]
87    pub r#enum: Option<String>,
88    #[serde(default, rename(deserialize = "@allow-null"))]
89    pub allow_null: bool,
90    #[serde(rename(deserialize = "@summary"))]
91    pub summary: Option<String>,
92    pub description: Option<String>,
93}
94
95#[derive(Debug, Deserialize, Serialize, PartialEq, Eq, Clone)]
96pub enum ArgType {
97    #[serde(rename(deserialize = "int"))]
98    Int,
99    #[serde(rename(deserialize = "uint"))]
100    Uint,
101    #[serde(rename(deserialize = "fixed"))]
102    Fixed,
103    #[serde(rename(deserialize = "string"))]
104    String,
105    #[serde(rename(deserialize = "object"))]
106    Object,
107    #[serde(rename(deserialize = "new_id"))]
108    NewId,
109    #[serde(rename(deserialize = "array"))]
110    Array,
111    #[serde(rename(deserialize = "fd"))]
112    Fd,
113}
114
115impl ArgType {
116    pub const fn is_fd(&self) -> bool {
117        matches!(self, Self::Fd)
118    }
119
120    pub const fn is_array(&self) -> bool {
121        matches!(self, Self::Array)
122    }
123}
124
125#[derive(Debug, Deserialize, Serialize, Clone)]
126#[serde(deny_unknown_fields)]
127pub struct Enum {
128    #[serde(rename(deserialize = "@name"))]
129    pub name: String,
130    #[serde(default, rename(deserialize = "@bitfield"))]
131    pub bitfield: bool,
132    #[serde(rename(deserialize = "@since"))]
133    pub since: Option<usize>,
134    #[serde(rename(deserialize = "@deprecated-since"))]
135    pub deprecated_since: Option<usize>,
136    pub description: Option<String>,
137    #[serde(rename(deserialize = "entry"))]
138    pub entries: Vec<Entry>,
139}
140
141#[derive(Debug, Deserialize, Serialize, Clone)]
142#[serde(deny_unknown_fields)]
143pub struct Entry {
144    #[serde(rename(deserialize = "@name"))]
145    pub name: String,
146    #[serde(rename(deserialize = "@value"))]
147    pub value: String,
148    #[serde(rename(deserialize = "@summary"))]
149    pub summary: Option<String>,
150    #[serde(rename(deserialize = "@since"))]
151    pub since: Option<usize>,
152    #[serde(rename(deserialize = "@deprecated-since"))]
153    pub deprecated_since: Option<usize>,
154    pub description: Option<String>,
155}
156
157impl Pair {
158    pub fn from_path<D: Display, P: AsRef<Path>>(module: D, path: P) -> Result<Self, Error> {
159        debug!("Parsing protocol {}", path.as_ref().display());
160        Ok(Self {
161            protocol: quick_xml::de::from_str(&fs::read_to_string(path)?)?,
162            module: module.to_string(),
163        })
164    }
165}
166
167impl Arg {
168    pub fn to_enum_name(&self) -> Option<(Option<String>, String)> {
169        if let Some(e) = &self.r#enum {
170            if let Some((interface, name)) = e.split_once('.') {
171                return Some((Some(interface.to_string()), name.to_string()));
172            } else {
173                return Some((None, e.to_string()));
174            }
175        }
176
177        None
178    }
179
180    pub fn find_protocol(&self, pairs: &[Pair]) -> Option<Pair> {
181        if let Some((enum_interface, _name)) = self.to_enum_name()
182            && let Some(enum_interface) = enum_interface
183        {
184            return pairs
185                .iter()
186                .find(|pair| {
187                    pair.protocol
188                        .interfaces
189                        .iter()
190                        .any(|e| e.name == enum_interface)
191                })
192                .cloned();
193        }
194
195        None
196    }
197
198    pub fn to_rust_type_token(&self, pair: &Pair) -> TokenStream {
199        if let Some(e) = &self.r#enum {
200            if let Some((module, name)) = e.split_once('.') {
201                // Check if the referenced interface actually exists in the current pair
202                let interface_exists = pair
203                    .protocol
204                    .interfaces
205                    .iter()
206                    .any(|iface| iface.name == module);
207                if interface_exists {
208                    let protocol_name = make_ident(&pair.protocol.name);
209                    let name = make_ident(name.to_upper_camel_case());
210                    let module = make_ident(module);
211                    let protocol_module = make_ident(&pair.module);
212
213                    return quote! {super::super::super::#protocol_module::#protocol_name::#module::#name};
214                } else {
215                    // Invalid cross-protocol reference, fall back to the underlying type
216                    return self.to_underlying_type_token();
217                }
218            } else {
219                return make_ident(e.to_upper_camel_case()).to_token_stream();
220            }
221        }
222
223        self.to_underlying_type_token()
224    }
225
226    pub fn to_underlying_type_token(&self) -> TokenStream {
227        match self.ty {
228            ArgType::Int => quote! { i32 },
229            ArgType::Uint => quote! { u32 },
230            ArgType::Fixed => quote! { crate::wire::Fixed },
231            ArgType::String => quote! { String },
232            ArgType::Object => quote! { crate::wire::ObjectId },
233            ArgType::NewId => {
234                if self.interface.is_some() {
235                    quote! { crate::wire::ObjectId }
236                } else {
237                    quote! { crate::wire::NewId }
238                }
239            }
240            ArgType::Array => quote! { Vec<u8> },
241            ArgType::Fd => quote! { rustix::fd::OwnedFd },
242        }
243    }
244
245    pub fn is_return_option(&self) -> bool {
246        match self.ty {
247            ArgType::String | ArgType::Object => true,
248            ArgType::NewId => self.interface.is_some(),
249            _ => false,
250        }
251    }
252
253    pub fn to_caller(&self) -> &str {
254        if self.r#enum.is_some() {
255            return "uint";
256        }
257
258        match self.ty {
259            ArgType::Int => "int",
260            ArgType::Uint => "uint",
261            ArgType::Fixed => "fixed",
262            ArgType::String => "string",
263            ArgType::Object => "object",
264            ArgType::NewId => {
265                if self.interface.is_some() {
266                    "object"
267                } else {
268                    "new_id"
269                }
270            }
271            ArgType::Array => "array",
272            ArgType::Fd => "fd",
273        }
274    }
275}