1use heck::{ToSnekCase, ToUpperCamelCase};
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use tracing::debug;
5
6use crate::{
7 common::write_dispatchers,
8 parser::{Interface, Pair},
9 utils::{description_to_docs, find_enum, make_ident, write_enums},
10};
11
12pub fn generate_client_code(current: &[Pair], pairs: &[Pair]) -> TokenStream {
13 let mut modules = Vec::new();
14
15 for pair in current {
16 let protocol = &pair.protocol;
17 debug!("Generating client code for \"{}\"", &protocol.name);
18
19 let mut inner_modules = Vec::new();
20
21 for interface in &protocol.interfaces {
22 let docs = description_to_docs(interface.description.as_ref());
23 let module_name = make_ident(&interface.name);
24 let trait_name = make_ident(interface.name.to_upper_camel_case());
25 let trait_docs = format!(
26 "Trait to implement the {} interface. See the module level documentation for more info",
27 interface.name
28 );
29
30 let name = &interface.name;
31 let version = &interface.version;
32
33 let dispatchers = write_dispatchers(interface, interface.events.clone().into_iter());
34 let enums = write_enums(interface);
35 let requests = write_requests(pairs, pair, interface);
36 let events = write_events(pairs, pair, interface);
37
38 let imports = if requests.is_empty() {
39 quote! {}
40 } else {
41 quote! {use futures_util::SinkExt;}
42 };
43
44 let handler_args = if dispatchers.is_empty() {
45 quote! {
46 _client: &mut crate::server::Client,
47 _sender_id: crate::wire::ObjectId,
48 }
49 } else {
50 quote! {
51 client: &mut crate::server::Client,
52 sender_id: crate::wire::ObjectId,
53 }
54 };
55
56 inner_modules.push(quote! {
57 #(#docs)*
58 #[allow(clippy::too_many_arguments)]
59 pub mod #module_name {
60 #[allow(unused)]
61 use std::os::fd::AsRawFd;
62 #imports
63
64 #(#enums)*
65
66 #[doc = #trait_docs]
67 pub trait #trait_name {
68 const INTERFACE: &'static str = #name;
69 const VERSION: u32 = #version;
70
71 async fn handle_event(
72 &self,
73 #handler_args
74 message: &mut crate::wire::Message,
75 ) -> crate::client::Result<()> {
76 #[allow(clippy::match_single_binding)]
77 match message.opcode() {
78 #(#dispatchers),*
79 _ => Err(crate::client::Error::UnknownOpcode),
80 }
81 }
82
83 #(#requests)*
84 #(#events)*
85 }
86 }
87 })
88 }
89
90 let docs = description_to_docs(protocol.description.as_ref());
91 let module_name = make_ident(&protocol.name);
92
93 modules.push(quote! {
94 #(#docs)*
95 #[allow(clippy::module_inception)]
96 pub mod #module_name {
97 #(#inner_modules)*
98 }
99 })
100 }
101
102 quote! {
103 #![allow(async_fn_in_trait)]
104 #(#modules)*
105 }
106}
107
108fn write_requests(pairs: &[Pair], pair: &Pair, interface: &Interface) -> Vec<TokenStream> {
109 let mut requests = Vec::new();
110
111 for (opcode, request) in interface.requests.iter().enumerate() {
112 let opcode = opcode as u16;
113
114 let docs = description_to_docs(request.description.as_ref());
115 let name = make_ident(request.name.to_snek_case());
116 let tracing_inner = format!(
117 "-> {}#{{}}.{}()",
118 interface.name,
119 request.name.to_snek_case()
120 );
121
122 let mut args = vec![
123 quote! { &self },
124 quote! { client: &mut crate::server::Client },
125 quote! { sender_id: crate::wire::ObjectId },
126 ];
127
128 for arg in &request.args {
129 let mut ty = arg.to_rust_type_token(arg.find_protocol(pairs).as_ref().unwrap_or(pair));
130
131 if arg.allow_null {
132 ty = quote! {Option<#ty>};
133 }
134
135 let name = make_ident(arg.name.to_snek_case());
136
137 args.push(quote! {#name: #ty})
138 }
139
140 let mut build_args = Vec::new();
141
142 for arg in &request.args {
143 let build_ty = arg.to_caller();
144 let build_ty = format_ident!("put_{build_ty}");
145
146 let mut build_convert = quote! {};
147
148 if let Some((enum_interface, name)) = arg.to_enum_name() {
149 let e = if let Some(enum_interface) = enum_interface {
150 pairs.iter().find_map(|pair| {
151 pair.protocol
152 .interfaces
153 .iter()
154 .find(|e| e.name == enum_interface)
155 .and_then(|interface| interface.enums.iter().find(|e| e.name == name))
156 })
157 } else {
158 find_enum(&pair.protocol, &name)
159 };
160
161 if let Some(e) = e {
162 if e.bitfield {
163 build_convert = quote! { .bits() };
164 } else {
165 build_convert = quote! { as u32 };
166 }
167 }
168 }
169
170 let build_name = make_ident(arg.name.to_snek_case());
171 let mut build_name = quote! { #build_name };
172
173 if arg.is_return_option() && !arg.allow_null {
174 build_name = quote! { Some(#build_name) }
175 }
176
177 build_args.push(quote! { .#build_ty(#build_name #build_convert) })
178 }
179
180 requests.push(quote! {
181 #(#docs)*
182 async fn #name(#(#args),*) -> crate::client::Result<()> {
183 tracing::debug!(#tracing_inner, sender_id);
184
185 let (payload,fds) = crate::wire::PayloadBuilder::new()
186 #(#build_args)*
187 .build();
188
189 client
190 .send_message(crate::wire::Message::new(sender_id, #opcode, payload, fds))
191 .await
192 .map_err(crate::client::Error::IoError)
193 }
194 });
195 }
196
197 requests
198}
199
200fn write_events(pairs: &[Pair], pair: &Pair, interface: &Interface) -> Vec<TokenStream> {
201 let mut requests = Vec::new();
202
203 for request in &interface.events {
204 let docs = description_to_docs(request.description.as_ref());
205 let name = make_ident(request.name.to_snek_case());
206 let mut args = vec![
207 quote! {&self },
208 quote! { client: &mut crate::server::Client },
209 quote! { sender_id: crate::wire::ObjectId },
210 ];
211
212 for arg in &request.args {
213 let mut ty = arg.to_rust_type_token(arg.find_protocol(pairs).as_ref().unwrap_or(pair));
214
215 if arg.allow_null {
216 ty = quote! {Option<#ty>};
217 }
218
219 let name = make_ident(arg.name.to_snek_case());
220
221 args.push(quote! {#name: #ty})
222 }
223
224 requests.push(quote! {
225 #(#docs)*
226 async fn #name(#(#args),*) -> crate::client::Result<()>;
227 });
228 }
229
230 requests
231}