wayland_protocol_code_generator/
lib.rs

1#![recursion_limit = "256"]
2
3extern crate heck;
4extern crate proc_macro2;
5extern crate wayland_protocol_scanner;
6#[macro_use]
7extern crate quote;
8#[macro_use]
9extern crate lazy_static;
10
11use wayland_protocol_scanner::{Protocol, ProtocolChild, Interface, InterfaceChild, EventOrRequestField};
12use proc_macro2::{Ident, Span, TokenStream};
13use heck::{CamelCase, SnakeCase};
14
lazy_static! {
    /// The protocol description returned by
    /// `wayland_protocol_scanner::parse_wayland_protocol()`, evaluated lazily on
    /// first access and then shared by every generator function in this file.
    static ref PROTOCOL: Protocol = wayland_protocol_scanner::parse_wayland_protocol();
}
18
/// Rewrites a protocol-supplied name so that it can be used as a Rust identifier.
///
/// `move` is a reserved keyword in Rust, so it is mapped to `mv`; every other
/// name is returned unchanged.
///
/// Takes `&str` (rather than `&String`) so any string slice can be passed;
/// existing `&String` / `&&String` arguments still deref-coerce.
fn escape_name(name: &str) -> String {
    match name {
        "move" => String::from("mv"),
        other => other.to_owned(),
    }
}
26
27fn construct_indent_from_string(str: &String) -> Ident {
28    Ident::new(str, Span::call_site())
29}
30
/// Case conversion applied to a string before it is turned into an identifier
/// by `construct_ident_from_str_and_case`.
enum Case {
    /// Convert with `heck`'s `to_camel_case` (used here for type names).
    CamelCase,
    /// Convert with `heck`'s `to_snake_case` (used here for function/field names).
    SnakeCase,
}
35
36fn construct_ident_from_str_and_case(str: &String, case: Option<Case>) -> Ident {
37    match case {
38        Some(case) => match case {
39            Case::CamelCase => construct_indent_from_string(&str.to_camel_case()),
40            Case::SnakeCase => construct_indent_from_string(&str.to_snake_case()),
41        },
42        None => construct_indent_from_string(&str),
43    }
44}
45
/// `ident!(fmt, args...; case)`: passes each `args` element through
/// `escape_name`, formats them with the string literal `fmt`, and converts the
/// result into an `Ident` using the `Option<Case>` given after the `;`.
///
/// Escaping is applied to each argument individually, before formatting.
macro_rules! ident {
    ($fmt:expr, $( $part:expr ),*; $case:expr) => {{
        let text = format!($fmt, $(escape_name(&$part)),*);
        construct_ident_from_str_and_case(&text, $case)
    }};
}
49
/// Lazily yields one `name: Type` token fragment per `Arg` in `$re.items`
/// (a request or event), skipping any non-argument fields.
///
/// The argument name is used verbatim (after `escape_name`); the protocol
/// type is converted to CamelCase so it resolves to the generated type aliases.
macro_rules! generate_arguments {
    ($re:expr) => {
        $re.items.iter().filter_map(|field| {
            if let EventOrRequestField::Arg(arg) = field {
                let name = ident!("{}", arg.name; None);
                let ty = ident!("{}", arg.typ; Some(Case::CamelCase));
                Some(quote! {#name: #ty})
            } else {
                None
            }
        })
    };
}
62
63fn add_arg_size(arg: &wayland_protocol_scanner::Arg) -> Option<TokenStream> {
64    let arg_name = ident!("{}",arg.name; None);
65    let arg_typ = ident!("{}", &arg.typ; Some(Case::CamelCase));
66    match &arg.typ[..] {
67        "String" => Some(quote! {
68            raw_size += ((#arg_name.len() + 1) as f64 / 4.0).ceil() as usize * 4 + 4;
69        }),
70        "Fd" => None,
71        // TODO: Array and other types
72        _ => Some(quote! {raw_size += size_of::<#arg_typ>();}),
73    }
74}
75
76fn send_arg(arg: &wayland_protocol_scanner::Arg) -> Option<TokenStream> {
77    let arg_name = ident!("{}",arg.name; None);
78    let arg_typ = ident!("{}", &arg.typ; Some(Case::CamelCase));
79    match &arg.typ.to_camel_case()[..] {
80        "String" => Some(quote! {
81            let str_len = #arg_name.len();
82            let buf_len = ((#arg_name.len() + 1) as f64 / 4.0).ceil() as usize * 4;
83            unsafe {
84                std::ptr::copy(&buf_len as *const usize as *const u8, &mut send_buffer[written_len] as *mut u8, str_len + 1);
85                std::ptr::copy(&#arg_name.into_bytes()[0] as *const u8, &mut send_buffer[written_len + 4] as *mut u8, str_len);
86            }
87            #[allow(unused)]
88            written_len += buf_len + 4;
89        }),
90        "Fd" => Some(quote! {
91            info!("Send FD: {}", #arg_name);
92            send_fd[send_fd_num] = #arg_name;
93            send_fd_num += 1;
94        }),
95        // TODO: Array and other types
96        _ => Some(quote! {
97            unsafe {
98                std::ptr::copy(&#arg_name as *const #arg_typ, &mut send_buffer[written_len] as *mut u8 as *mut #arg_typ, 1);
99            }
100            #[allow(unused)]
101            written_len += size_of::<u32>();
102        }),
103    }
104}
105
106fn parse_args(arg: &wayland_protocol_scanner::Arg) -> Option<TokenStream> {
107    let arg_name = ident!("{}", arg.name; Some(Case::SnakeCase));
108    let arg_typ = ident!("{}", arg.typ; Some(Case::CamelCase));
109    match &arg.typ[..] {
110        "fixed" => Some(quote! {
111            let #arg_name: f32 = 0.0;
112            warn!("Fixed value has not been implemented");
113        }),
114        "string" => Some(quote! {
115            parsed_len += size_of::<u32>();
116            let start = parsed_len - size_of::<u32>();
117
118            let raw_ptr = msg_body[start..parsed_len].as_ptr() as *const u32;
119            let str_len = unsafe{
120                *raw_ptr
121            };
122            let str_len = (str_len as f64 / 4.0).ceil() as usize * 4;
123            parsed_len += str_len;
124
125            let src_ptr = msg_body[(start + size_of::<u32>())..parsed_len].as_ptr();
126            let mut tmp_ptr = Vec::with_capacity(str_len);
127            unsafe {
128                tmp_ptr.set_len(str_len);
129                std::ptr::copy(src_ptr, tmp_ptr.as_mut_ptr(), str_len);
130            };
131            let #arg_name = std::str::from_utf8(&tmp_ptr).unwrap().trim_matches('\0').to_string();
132        }),
133        "array" => Some(quote! {
134            let #arg_name: Vec<u32> = Vec::new();
135            warn!("Array value has not been implemented");
136        }),
137        _ => Some(quote! {
138            parsed_len += size_of::<#arg_typ>();
139            let start = parsed_len - size_of::<#arg_typ>();
140
141            let raw_ptr = msg_body[start..parsed_len].as_ptr() as *const #arg_typ;
142            let #arg_name = unsafe{
143                *raw_ptr
144            };
145        }),
146    }
147}
148
149fn generate_code_for_interface(interface: &Interface) -> TokenStream {
150    let struct_name = ident!("{}", interface.name; Some(Case::CamelCase));
151
152    let mut req_op_code: i32 = -1;
153    let send_req_functions = interface.items.iter().filter_map(|msg| match msg {
154        InterfaceChild::Request(req) => {
155            req_op_code += 1;
156            let args = generate_arguments!(req);
157            let function_name = ident!("{}", &req.name; None);
158
159            let add_raw_size = req.items.iter().filter_map(|child| {
160                match child {
161                    EventOrRequestField::Arg(arg) => {
162                        add_arg_size(arg)
163                    }
164                    _ => { None }
165                }
166            });
167            let send_args = req.items.iter().filter_map(|child| {
168                match child {
169                    EventOrRequestField::Arg(arg) => {
170                        send_arg(arg)
171                    }
172                    _ => { None }
173                }
174            });
175
176            Some(quote! {
177                pub fn #function_name(&self, #(#args),*) {
178                    #[allow(unused)]
179                    let mut raw_size = 8;
180                    #(#add_raw_size)*
181                    let mut send_buffer: Vec<u8> = vec![0; raw_size];
182                    let mut send_fd = vec![0; 16];
183
184                    #[allow(unused)]
185                    let mut send_fd_num = 0;
186                    unsafe {
187                        std::ptr::copy(&self.object_id as *const u32, &mut send_buffer[0] as *mut u8 as *mut u32, 1);
188                        let op_code_and_length: u32 = ((raw_size as u32) << 16) + (#req_op_code as u32);
189                        std::ptr::copy(&op_code_and_length as *const u32, &mut send_buffer[size_of::<u32>()] as *mut u8 as *mut u32, 1);
190                    }
191
192                    #[allow(unused)]
193                    let mut written_len: usize = 8;
194                    #(#send_args)*
195                    unsafe {
196                        send_fd.set_len(send_fd_num);
197                    }
198                    self.socket.send(&send_buffer, &send_fd);
199                }
200            })
201        }
202        _ => None
203    });
204
205    let mut ev_op_code: i32 = -1;
206    let parse_ev = interface.items.iter().filter_map(|msg| match msg {
207        InterfaceChild::Event(ev) => {
208            ev_op_code += 1;
209            let op_code = ev_op_code as u16;
210
211            let ev_name_str = format!(
212                "{}{}Event",
213                interface.name.to_camel_case(),
214                ev.name.to_camel_case()
215            );
216            let ev_interface_name = ident!("{}Event", interface.name; Some(Case::CamelCase));
217            let ev_name = ident!("{}{}Event", interface.name, ev.name; Some(Case::CamelCase));
218
219            let parse_args = ev.items.iter().filter_map(|field| match field {
220                EventOrRequestField::Arg(arg) => parse_args(arg),
221                _ => None,
222            });
223            let arg_names = ev.items.iter().filter_map(|field| match field {
224                EventOrRequestField::Arg(arg) => {
225                    let arg_name = ident!("{}", arg.name; Some(Case::SnakeCase));
226                    Some(quote! {#arg_name})
227                }
228                _ => None,
229            });
230            Some(quote! {
231                #op_code => {
232                    info!("Receive event {}", #ev_name_str);
233
234                    #[allow(unused)]
235                    let mut parsed_len: usize = 0;
236                    #(#parse_args)*
237                    Event::#ev_interface_name(#ev_interface_name::#ev_name(#ev_name {
238                        sender_id,
239                        #(#arg_names),*
240                    }))
241                }
242            })
243        }
244        _ => None
245    });
246    quote! {
247        #[derive(Clone)]
248        pub struct #struct_name {
249            #[allow(dead_code)]
250            pub object_id: u32,
251            #[allow(dead_code)]
252            pub socket: Arc<WaylandSocket>,
253        }
254        impl WlRawObject for #struct_name {
255            fn new(object_id: u32, socket: Arc<WaylandSocket>) -> #struct_name {
256                #struct_name { object_id, socket }
257            }
258            fn to_enum(self) -> WlObject {
259                WlObject::#struct_name(self)
260            }
261        }
262        impl #struct_name {
263            fn parse_event(sender_id: u32, op_code: u16, msg_body: Vec<u8>) -> Event {
264                match op_code {
265                    #(#parse_ev)*
266                    _ => panic!("Unknown event")
267                }
268            }
269            #(#send_req_functions)*
270        }
271    }
272}
273
274fn generate_code_for_wayland_enums() -> TokenStream {
275    let enum_interface_names = PROTOCOL.items.iter().filter_map(|item| match item {
276        ProtocolChild::Interface(interface) => {
277            let interface_name = ident!("{}", interface.name; Some(Case::CamelCase));
278            Some(quote! {#interface_name(#interface_name)})
279        }
280        _ => None,
281    });
282    let enum_event_names = PROTOCOL.items.iter().filter_map(|item| match item {
283        ProtocolChild::Interface(interface) => {
284            let event_name = ident!("{}Event", interface.name; Some(Case::CamelCase));
285            Some(quote! {#event_name(#event_name)})
286        }
287        _ => None,
288    });
289    let wl_event_enums = PROTOCOL.items.iter().filter_map(|item| match item {
290        ProtocolChild::Interface(interface) => {
291            let event_enum_name = ident!("{}Event", interface.name; Some(Case::CamelCase));
292            let event_structs = interface.items.iter().filter_map(|ev| match ev {
293                InterfaceChild::Event(ev) => {
294                    let ev_struct_enum_name = ident!("{}{}Event", interface.name, ev.name; Some(Case::CamelCase));
295                    let event_fields = generate_arguments!(ev);
296                    Some(quote! {
297                        pub struct #ev_struct_enum_name {
298                            #[allow(dead_code)]
299                            pub sender_id: u32,
300                            #(#[allow(dead_code)]pub #event_fields),*
301                        }
302                    })
303                }
304                _ => None
305            });
306            let event_struct_enum_names = interface.items.iter().filter_map(|ev| match ev {
307                InterfaceChild::Event(ev) => {
308                    let ev_struct_enum_name = ident!("{}{}Event", interface.name, ev.name; Some(Case::CamelCase));
309                    Some(quote! {
310                        #ev_struct_enum_name(#ev_struct_enum_name)
311                    })
312                }
313                _ => None
314            });
315            Some(quote! {
316                #(#event_structs)*
317                pub enum #event_enum_name {
318                    #(#event_struct_enum_names),*
319                }
320            })
321        }
322        _ => None,
323    });
324    let impl_wl_get_obj = PROTOCOL.items.iter().filter_map(|item| match item {
325        ProtocolChild::Interface(interface) => {
326            let interface_name = ident!("{}", interface.name; Some(Case::CamelCase));
327            let get_function_name = ident!("try_get_{}", interface.name; Some(Case::SnakeCase));
328            Some(quote! {
329                #[allow(dead_code)]
330                pub fn #get_function_name(&self) -> Option<#interface_name> {
331                    match self {
332                        WlObject::#interface_name(item) => Some(item.clone()),
333                        _ => None,
334                    }
335                }
336            })
337        }
338        _ => None,
339    });
340    quote! {
341        pub enum WlObject {
342            #(#enum_interface_names),*
343        }
344        pub enum Event {
345            #(#enum_event_names),*
346        }
347        #(#wl_event_enums)*
348        impl WlObject {
349            #(#impl_wl_get_obj)*
350        }
351    }
352}
353
354pub fn generate_wayland_protocol_code() -> String {
355    let codes_for_every_interface = PROTOCOL.items.iter().filter_map(|item| match item {
356        ProtocolChild::Interface(interface) => {
357            Some(generate_code_for_interface(interface))
358        }
359        _ => None
360    });
361    let code_for_wayland_enums = generate_code_for_wayland_enums();
362    let parse_event_for_interface = PROTOCOL.items.iter().filter_map(|item| match item {
363        ProtocolChild::Interface(interface) => {
364            let interface_name = ident!("{}", interface.name; Some(Case::CamelCase));
365            Some(quote! {
366                WlObject::#interface_name(_obj) => #interface_name::parse_event(sender_id, op_code, msg_body)
367            })
368        }
369        _ => None
370    });
371
372    let code = quote! {
373        use super::socket::*;
374        use crate::unix_socket::UnixSocket;
375        use std::sync::Arc;
376        use std::mem::transmute;
377        use std::mem::size_of;
378
379        type NewId=u32;
380        type Uint=u32;
381        type Int=i32;
382        type Fd=i32;
383        type Object=u32;
384        type Fixed=f32; // TODO: handle fixed value
385        type Array=Vec<u32>;
386
387
388        #(#codes_for_every_interface)*
389        #code_for_wayland_enums
390
391        #[repr(packed)]
392        struct EventHeaderPre {
393            pub sender_id: u32,
394            pub msg_size_and_op_code: u32,
395        }
396        #[repr(packed)]
397        pub struct EventHeader {
398            pub sender_id: u32,
399            pub msg_size: u16,
400            pub op_code: u16,
401        }
402        impl EventHeaderPre {
403            fn convert_to_event_header(self) -> EventHeader {
404                let msg_size = {
405                    let size = self.msg_size_and_op_code >> 16;
406                    if size % 4 == 0 {
407                        size as u16
408                    } else {
409                        (size as f64 / 4.0).ceil() as u16 * 4
410                    }
411                } - 8;
412                EventHeader {
413                    sender_id: self.sender_id,
414                    msg_size,
415                    op_code: self.msg_size_and_op_code as u16,
416                }
417            }
418        }
419        pub trait ReadEvent {
420            fn read_event(&mut self) -> Vec<(EventHeader, Vec<u8>)>;
421        }
422        impl ReadEvent for UnixSocket {
423            fn read_event(&mut self) -> Vec<(EventHeader, Vec<u8>)> {
424                let mut buffer: [u8; 1024] = [0; 1024];
425                let mut fds: [u8; 24] = [0; 24];
426                let (size, _) = self.read(&mut buffer, &mut fds);
427                if size == 1024 {
428                    warn!("Buffer is full");
429                }
430                let mut ret_value = Vec::new();
431                let mut read_size: usize = 0;
432                while read_size < size {
433                    let mut event_header: [u8; size_of::<EventHeaderPre>()] =
434                        [0; size_of::<EventHeaderPre>()];
435                    unsafe {
436                        std::ptr::copy(
437                            &buffer[read_size] as *const u8,
438                            event_header.as_mut_ptr(),
439                            size_of::<EventHeaderPre>(),
440                        );
441                    }
442                    let event_header = unsafe {
443                        transmute::<[u8; size_of::<EventHeaderPre>()], EventHeaderPre>(event_header)
444                            .convert_to_event_header()
445                    };
446                    let msg_size = event_header.msg_size as usize;
447                    let mut msg_body = vec![0; event_header.msg_size as usize];
448                    unsafe {
449                        std::ptr::copy(
450                            &buffer[read_size + size_of::<EventHeaderPre>()] as *const u8,
451                            msg_body.as_mut_ptr(),
452                            msg_size,
453                        );
454                    }
455                    ret_value.push((event_header, msg_body));
456                    read_size += size_of::<EventHeaderPre>() + msg_size;
457                }
458                return ret_value;
459            }
460        }
461        pub trait WlRawObject {
462            fn new(object_id: u32, socket: Arc<WaylandSocket>) -> Self;
463            fn to_enum(self) -> WlObject;
464        }
465        impl WlObject {
466            pub fn parse_event(&self, sender_id: u32, op_code: u16, msg_body: Vec<u8>) -> Event {
467                match self {
468                    #(#parse_event_for_interface),*
469                }
470            }
471        }
472    };
473
474    return code.to_string();
475}