waynest_gen/
client.rs

1use heck::{ToSnekCase, ToUpperCamelCase};
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use tracing::debug;
5
6use crate::{
7    common::write_dispatchers,
8    parser::{Interface, Pair},
9    utils::{description_to_docs, find_enum, make_ident, write_enums},
10};
11
12pub fn generate_client_code(current: &[Pair], pairs: &[Pair]) -> TokenStream {
13    let mut modules = Vec::new();
14
15    for pair in current {
16        let protocol = &pair.protocol;
17        debug!("Generating client code for \"{}\"", &protocol.name);
18
19        let mut inner_modules = Vec::new();
20
21        for interface in &protocol.interfaces {
22            let docs = description_to_docs(interface.description.as_ref());
23            let module_name = make_ident(&interface.name);
24            let trait_name = make_ident(interface.name.to_upper_camel_case());
25            let trait_docs = format!(
26                "Trait to implement the {} interface. See the module level documentation for more info",
27                interface.name
28            );
29
30            let name = &interface.name;
31            let version = &interface.version;
32
33            let dispatchers = write_dispatchers(interface, interface.events.clone().into_iter());
34            let enums = write_enums(interface);
35            let requests = write_requests(pairs, pair, interface);
36            let events = write_events(pairs, pair, interface);
37
38            let imports = if requests.is_empty() {
39                quote! {}
40            } else {
41                quote! {use futures_util::SinkExt;}
42            };
43
44            let handler_args = if dispatchers.is_empty() {
45                quote! {
46                    _client: &mut crate::server::Client,
47                    _sender_id: crate::wire::ObjectId,
48                }
49            } else {
50                quote! {
51                    client: &mut crate::server::Client,
52                    sender_id: crate::wire::ObjectId,
53                }
54            };
55
56            inner_modules.push(quote! {
57                #(#docs)*
58                #[allow(clippy::too_many_arguments)]
59                pub mod #module_name {
60                    #[allow(unused)]
61                    use std::os::fd::AsRawFd;
62                    #imports
63
64                    #(#enums)*
65
66                    #[doc = #trait_docs]
67                    pub trait #trait_name {
68                        const INTERFACE: &'static str = #name;
69                        const VERSION: u32 = #version;
70
71                        async fn handle_event(
72                            &self,
73                            #handler_args
74                            message: &mut crate::wire::Message,
75                        ) -> crate::client::Result<()> {
76                            #[allow(clippy::match_single_binding)]
77                            match message.opcode() {
78                                #(#dispatchers),*
79                                _ => Err(crate::client::Error::UnknownOpcode),
80                            }
81                        }
82
83                        #(#requests)*
84                        #(#events)*
85                    }
86                }
87            })
88        }
89
90        let docs = description_to_docs(protocol.description.as_ref());
91        let module_name = make_ident(&protocol.name);
92
93        modules.push(quote! {
94            #(#docs)*
95            #[allow(clippy::module_inception)]
96            pub mod #module_name {
97                #(#inner_modules)*
98            }
99        })
100    }
101
102    quote! {
103        #![allow(async_fn_in_trait)]
104        #(#modules)*
105    }
106}
107
108fn write_requests(pairs: &[Pair], pair: &Pair, interface: &Interface) -> Vec<TokenStream> {
109    let mut requests = Vec::new();
110
111    for (opcode, request) in interface.requests.iter().enumerate() {
112        let opcode = opcode as u16;
113
114        let docs = description_to_docs(request.description.as_ref());
115        let name = make_ident(request.name.to_snek_case());
116        let tracing_inner = format!(
117            "-> {}#{{}}.{}()",
118            interface.name,
119            request.name.to_snek_case()
120        );
121
122        let mut args = vec![
123            quote! { &self },
124            quote! { client: &mut crate::server::Client },
125            quote! { sender_id: crate::wire::ObjectId },
126        ];
127
128        for arg in &request.args {
129            let mut ty = arg.to_rust_type_token(arg.find_protocol(pairs).as_ref().unwrap_or(pair));
130
131            if arg.allow_null {
132                ty = quote! {Option<#ty>};
133            }
134
135            let name = make_ident(arg.name.to_snek_case());
136
137            args.push(quote! {#name: #ty})
138        }
139
140        let mut build_args = Vec::new();
141
142        for arg in &request.args {
143            let build_ty = arg.to_caller();
144            let build_ty = format_ident!("put_{build_ty}");
145
146            let mut build_convert = quote! {};
147
148            if let Some((enum_interface, name)) = arg.to_enum_name() {
149                let e = if let Some(enum_interface) = enum_interface {
150                    pairs.iter().find_map(|pair| {
151                        pair.protocol
152                            .interfaces
153                            .iter()
154                            .find(|e| e.name == enum_interface)
155                            .and_then(|interface| interface.enums.iter().find(|e| e.name == name))
156                    })
157                } else {
158                    find_enum(&pair.protocol, &name)
159                };
160
161                if let Some(e) = e {
162                    if e.bitfield {
163                        build_convert = quote! { .bits() };
164                    } else {
165                        build_convert = quote! {  as u32 };
166                    }
167                }
168            }
169
170            let build_name = make_ident(arg.name.to_snek_case());
171            let mut build_name = quote! { #build_name };
172
173            if arg.is_return_option() && !arg.allow_null {
174                build_name = quote! { Some(#build_name) }
175            }
176
177            build_args.push(quote! { .#build_ty(#build_name #build_convert) })
178        }
179
180        requests.push(quote! {
181            #(#docs)*
182            async fn #name(#(#args),*) -> crate::client::Result<()> {
183                tracing::debug!(#tracing_inner, sender_id);
184
185                let (payload,fds) = crate::wire::PayloadBuilder::new()
186                    #(#build_args)*
187                    .build();
188
189                client
190                    .send_message(crate::wire::Message::new(sender_id, #opcode, payload, fds))
191                    .await
192                    .map_err(crate::client::Error::IoError)
193            }
194        });
195    }
196
197    requests
198}
199
200fn write_events(pairs: &[Pair], pair: &Pair, interface: &Interface) -> Vec<TokenStream> {
201    let mut requests = Vec::new();
202
203    for request in &interface.events {
204        let docs = description_to_docs(request.description.as_ref());
205        let name = make_ident(request.name.to_snek_case());
206        let mut args = vec![
207            quote! {&self },
208            quote! { client: &mut crate::server::Client },
209            quote! { sender_id: crate::wire::ObjectId },
210        ];
211
212        for arg in &request.args {
213            let mut ty = arg.to_rust_type_token(arg.find_protocol(pairs).as_ref().unwrap_or(pair));
214
215            if arg.allow_null {
216                ty = quote! {Option<#ty>};
217            }
218
219            let name = make_ident(arg.name.to_snek_case());
220
221            args.push(quote! {#name: #ty})
222        }
223
224        requests.push(quote! {
225            #(#docs)*
226            async fn #name(#(#args),*) -> crate::client::Result<()>;
227        });
228    }
229
230    requests
231}