1use std::{fmt::Display, fs, io, path::Path};
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quick_xml::DeError;
6use quote::{ToTokens, quote};
7use serde::{Deserialize, Serialize};
8use tracing::debug;
9
10use crate::utils::make_ident;
11
12#[derive(Debug, thiserror::Error)]
13pub enum Error {
14 #[error("{0}")]
15 IoError(#[from] io::Error),
16 #[error("{0}")]
17 Decode(#[from] DeError),
18}
19
20#[derive(Debug, Deserialize, Serialize, Clone)]
21pub struct Pair {
22 pub protocol: Protocol,
23 pub module: String,
24}
25
26#[derive(Debug, Deserialize, Serialize, Clone)]
27#[serde(deny_unknown_fields)]
28pub struct Protocol {
29 #[serde(rename(deserialize = "@name"))]
30 pub name: String,
31 pub copyright: Option<String>,
32 pub description: Option<String>,
33 #[serde(default, rename(deserialize = "interface"))]
34 pub interfaces: Vec<Interface>,
35}
36
37#[derive(Debug, Deserialize, Serialize, Clone)]
38#[serde(deny_unknown_fields)]
39pub struct Interface {
40 #[serde(rename(deserialize = "@name"))]
41 pub name: String,
42 #[serde(rename(deserialize = "@version"))]
43 pub version: u32,
44 pub description: Option<String>,
45 #[serde(default, rename(deserialize = "request"))]
46 pub requests: Vec<Message>,
47 #[serde(default, rename(deserialize = "event"))]
48 pub events: Vec<Message>,
49 #[serde(default, rename(deserialize = "enum"))]
50 pub enums: Vec<Enum>,
51}
52
53#[derive(Debug, Deserialize, Serialize, Clone)]
54#[serde(deny_unknown_fields)]
55pub struct Message {
56 #[serde(rename(deserialize = "@name"))]
57 pub name: String,
58 #[serde(rename(deserialize = "@version"))]
59 pub version: Option<u32>,
60 #[serde(rename(deserialize = "@type"))]
61 pub ty: Option<MessageType>,
62 #[serde(rename(deserialize = "@since"))]
63 pub since: Option<usize>,
64 #[serde(rename(deserialize = "@deprecated-since"))]
65 pub deprecated_since: Option<usize>,
66 pub description: Option<String>,
67 #[serde(default, rename(deserialize = "arg"))]
68 pub args: Vec<Arg>,
69}
70
71#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
72pub enum MessageType {
73 #[serde(rename(deserialize = "destructor"))]
74 Destructor,
75}
76
77#[derive(Debug, Deserialize, Serialize, Clone)]
78#[serde(deny_unknown_fields)]
79pub struct Arg {
80 #[serde(rename(deserialize = "@name"))]
81 pub name: String,
82 #[serde(rename(deserialize = "@type"))]
83 pub ty: ArgType,
84 #[serde(rename(deserialize = "@interface"))]
85 pub interface: Option<String>,
86 #[serde(rename(deserialize = "@enum"))]
87 pub r#enum: Option<String>,
88 #[serde(default, rename(deserialize = "@allow-null"))]
89 pub allow_null: bool,
90 #[serde(rename(deserialize = "@summary"))]
91 pub summary: Option<String>,
92 pub description: Option<String>,
93}
94
95#[derive(Debug, Deserialize, Serialize, PartialEq, Eq, Clone)]
96pub enum ArgType {
97 #[serde(rename(deserialize = "int"))]
98 Int,
99 #[serde(rename(deserialize = "uint"))]
100 Uint,
101 #[serde(rename(deserialize = "fixed"))]
102 Fixed,
103 #[serde(rename(deserialize = "string"))]
104 String,
105 #[serde(rename(deserialize = "object"))]
106 Object,
107 #[serde(rename(deserialize = "new_id"))]
108 NewId,
109 #[serde(rename(deserialize = "array"))]
110 Array,
111 #[serde(rename(deserialize = "fd"))]
112 Fd,
113}
114
115impl ArgType {
116 pub const fn is_fd(&self) -> bool {
117 matches!(self, Self::Fd)
118 }
119
120 pub const fn is_array(&self) -> bool {
121 matches!(self, Self::Array)
122 }
123}
124
125#[derive(Debug, Deserialize, Serialize, Clone)]
126#[serde(deny_unknown_fields)]
127pub struct Enum {
128 #[serde(rename(deserialize = "@name"))]
129 pub name: String,
130 #[serde(default, rename(deserialize = "@bitfield"))]
131 pub bitfield: bool,
132 #[serde(rename(deserialize = "@since"))]
133 pub since: Option<usize>,
134 #[serde(rename(deserialize = "@deprecated-since"))]
135 pub deprecated_since: Option<usize>,
136 pub description: Option<String>,
137 #[serde(rename(deserialize = "entry"))]
138 pub entries: Vec<Entry>,
139}
140
141#[derive(Debug, Deserialize, Serialize, Clone)]
142#[serde(deny_unknown_fields)]
143pub struct Entry {
144 #[serde(rename(deserialize = "@name"))]
145 pub name: String,
146 #[serde(rename(deserialize = "@value"))]
147 pub value: String,
148 #[serde(rename(deserialize = "@summary"))]
149 pub summary: Option<String>,
150 #[serde(rename(deserialize = "@since"))]
151 pub since: Option<usize>,
152 #[serde(rename(deserialize = "@deprecated-since"))]
153 pub deprecated_since: Option<usize>,
154 pub description: Option<String>,
155}
156
157impl Pair {
158 pub fn from_path<D: Display, P: AsRef<Path>>(module: D, path: P) -> Result<Self, Error> {
159 debug!("Parsing protocol {}", path.as_ref().display());
160 Ok(Self {
161 protocol: quick_xml::de::from_str(&fs::read_to_string(path)?)?,
162 module: module.to_string(),
163 })
164 }
165}
166
167impl Arg {
168 pub fn to_enum_name(&self) -> Option<(Option<String>, String)> {
169 if let Some(e) = &self.r#enum {
170 if let Some((interface, name)) = e.split_once('.') {
171 return Some((Some(interface.to_string()), name.to_string()));
172 } else {
173 return Some((None, e.to_string()));
174 }
175 }
176
177 None
178 }
179
180 pub fn find_protocol(&self, pairs: &[Pair]) -> Option<Pair> {
181 if let Some((enum_interface, _name)) = self.to_enum_name()
182 && let Some(enum_interface) = enum_interface
183 {
184 return pairs
185 .iter()
186 .find(|pair| {
187 pair.protocol
188 .interfaces
189 .iter()
190 .any(|e| e.name == enum_interface)
191 })
192 .cloned();
193 }
194
195 None
196 }
197
198 pub fn to_rust_type_token(&self, pair: &Pair) -> TokenStream {
199 if let Some(e) = &self.r#enum {
200 if let Some((module, name)) = e.split_once('.') {
201 let interface_exists = pair
203 .protocol
204 .interfaces
205 .iter()
206 .any(|iface| iface.name == module);
207 if interface_exists {
208 let protocol_name = make_ident(&pair.protocol.name);
209 let name = make_ident(name.to_upper_camel_case());
210 let module = make_ident(module);
211 let protocol_module = make_ident(&pair.module);
212
213 return quote! {super::super::super::#protocol_module::#protocol_name::#module::#name};
214 } else {
215 return self.to_underlying_type_token();
217 }
218 } else {
219 return make_ident(e.to_upper_camel_case()).to_token_stream();
220 }
221 }
222
223 self.to_underlying_type_token()
224 }
225
226 pub fn to_underlying_type_token(&self) -> TokenStream {
227 match self.ty {
228 ArgType::Int => quote! { i32 },
229 ArgType::Uint => quote! { u32 },
230 ArgType::Fixed => quote! { crate::wire::Fixed },
231 ArgType::String => quote! { String },
232 ArgType::Object => quote! { crate::wire::ObjectId },
233 ArgType::NewId => {
234 if self.interface.is_some() {
235 quote! { crate::wire::ObjectId }
236 } else {
237 quote! { crate::wire::NewId }
238 }
239 }
240 ArgType::Array => quote! { Vec<u8> },
241 ArgType::Fd => quote! { rustix::fd::OwnedFd },
242 }
243 }
244
245 pub fn is_return_option(&self) -> bool {
246 match self.ty {
247 ArgType::String | ArgType::Object => true,
248 ArgType::NewId => self.interface.is_some(),
249 _ => false,
250 }
251 }
252
253 pub fn to_caller(&self) -> &str {
254 if self.r#enum.is_some() {
255 return "uint";
256 }
257
258 match self.ty {
259 ArgType::Int => "int",
260 ArgType::Uint => "uint",
261 ArgType::Fixed => "fixed",
262 ArgType::String => "string",
263 ArgType::Object => "object",
264 ArgType::NewId => {
265 if self.interface.is_some() {
266 "object"
267 } else {
268 "new_id"
269 }
270 }
271 ArgType::Array => "array",
272 ArgType::Fd => "fd",
273 }
274 }
275}