1use convert_case::{Case, Casing};
2use proc_macro::{self, TokenStream};
3use proc_macro2::Span;
4use quote::{format_ident, quote};
5use syn::{punctuated::Punctuated, token::Comma, Token};
6use uuid::Uuid;
7
8#[cfg(test)]
9mod tests;
10
11fn skip_self(
12    arguments: &Punctuated<syn::FnArg, syn::token::Comma>,
13) -> Punctuated<syn::FnArg, syn::token::Comma> {
14    let mut output = Punctuated::new();
15    for arg in arguments {
16        if let syn::FnArg::Typed(pat_type) = &arg {
17            if let syn::Pat::Ident(syn::PatIdent { ident, .. }) = &*pat_type.pat {
18                if ident != "self" {
19                    output.push(arg.clone());
20                }
21            }
22        }
23    }
24
25    output
26}
27
28fn request(item: &syn::ItemTrait) -> proc_macro2::TokenStream {
29    let items = item.items.iter().filter_map(|item| match item {
30        syn::TraitItem::Fn(method) => {
31            let ident = format_ident!("{}", method.sig.ident.to_string().to_case(Case::Pascal));
32            let args = skip_self(&method.sig.inputs);
33            let item = quote! {
34                #ident { #args },
35            };
36            Some(item)
37        }
38        _ => None,
39    });
40
41    let output = quote! {
42        #[derive(Debug, serde::Serialize, serde::Deserialize)]
43        pub enum Request {
44           #(#items)*
45        }
46    };
47    output
48}
49
50fn extract_return_type(return_type: &syn::ReturnType) -> proc_macro2::TokenStream {
51    match return_type {
52        syn::ReturnType::Default => {
53            quote! {
54                ()
55            }
56        }
57        syn::ReturnType::Type(_, return_type) => match *return_type.to_owned() {
58            syn::Type::ImplTrait(impl_trait) => {
59                let return_type = extract_stream_item_type(&impl_trait);
60                quote! {
61                    #return_type
62                }
63            }
64            _ => {
65                quote! {
66                    #return_type
67                }
68            }
69        },
70    }
71}
72
73fn is_stream(return_type: &syn::ReturnType) -> bool {
74    match return_type {
75        syn::ReturnType::Default => false,
76        syn::ReturnType::Type(_, return_type) => {
77            matches!(*return_type.to_owned(), syn::Type::ImplTrait(_))
78        }
79    }
80}
81
82fn response(item: &syn::ItemTrait) -> proc_macro2::TokenStream {
83    let items = item.items.iter().filter_map(|item| match item {
84        syn::TraitItem::Fn(method) => {
85            let ident = format_ident!("{}", method.sig.ident.to_string().to_case(Case::Pascal));
86            let return_type = extract_return_type(&method.sig.output);
87            let item = if is_stream(&method.sig.output) {
88                quote! { #ident(zzrpc::producer::StreamResponse<#return_type>), }
89            } else {
90                quote! { #ident(#return_type), }
91            };
92            Some(item)
93        }
94        _ => None,
95    });
96
97    let output = quote! {
98        #[derive(Debug, serde::Serialize, serde::Deserialize)]
99        pub enum Response {
100            #(#items)*
101        }
102    };
103    output
104}
105
106fn consumer_senders(item: &syn::ItemTrait) -> proc_macro2::TokenStream {
107    let items = item.items.iter().filter_map(|item| match item {
108        syn::TraitItem::Fn(method) => {
109            let ident = format_ident!(
110                "{}__sender",
111                method.sig.ident.to_string().to_case(Case::Snake)
112            );
113            let return_type = extract_return_type(&method.sig.output);
114            let item = quote! {
115                #ident: std::sync::Arc<
116                    zzrpc::futures::channel::mpsc::UnboundedSender<(
117                        zzrpc::consumer::Message<Request>,
118                        zzrpc::consumer::ResultSender<#return_type, Error>,
119                    )>,
120                >,
121            };
122            Some(item)
123        }
124        _ => None,
125    });
126
127    let output = quote! {
128        #[derive(Debug)]
129        struct Senders<Error> {
130            #(#items)*
131        }
132    };
133    output
134}
135
136fn impl_consumer_state(item: &syn::ItemTrait) -> proc_macro2::TokenStream {
137    let mut channels = vec![];
138    let mut senders = vec![];
139    let mut items = vec![];
140    let mut handlers = vec![];
141    let mut drainers = vec![];
142
143    for method in item.items.iter().filter_map(|item| match item {
144        syn::TraitItem::Fn(method) => Some(method),
145        _ => None,
146    }) {
147        let ident = format_ident!("{}", method.sig.ident.to_string().to_case(Case::Pascal));
148        let ident_requests = format_ident!(
149            "{}__requests",
150            method.sig.ident.to_string().to_case(Case::Snake)
151        );
152        let ident_sender = format_ident!(
153            "{}__sender",
154            method.sig.ident.to_string().to_case(Case::Snake)
155        );
156        let ident_receiver = format_ident!(
157            "{}__receiver",
158            method.sig.ident.to_string().to_case(Case::Snake)
159        );
160        let return_type = extract_return_type(&method.sig.output);
161
162        channels.push(quote! {
163            let (#ident_sender, #ident_receiver) = zzrpc::futures::channel::mpsc::unbounded::<(
164                zzrpc::consumer::Message<Request>,
165                zzrpc::consumer::ResultSender<#return_type, Error>,
166            )>();
167        });
168
169        senders.push(quote! {
170            #ident_sender: std::sync::Arc::new(#ident_sender),
171        });
172
173        items.push(quote! {
174            #ident_requests: std::collections::HashMap::new(),
175            #ident_receiver,
176        });
177
178        handlers.push(if is_stream(&method.sig.output) {
179            quote! {
180                Response::#ident(result) => {
181                    match result {
182                        zzrpc::producer::StreamResponse::Open => {
183                            if let Some((sender, _)) = self.#ident_requests.get_mut(&id) {
184                                if let Some(sender) = sender.take() {
185                                    let _ = sender.send(Ok(()));
186                                    self.pending -= 1;
187                                }
188                            }
189                        },
190                        zzrpc::producer::StreamResponse::Item(item) => {
191                            if let Some((_, sender)) = self.#ident_requests.get_mut(&id) {
192                                let _ = sender.unbounded_send(item);
193                            }
194                        },
195                        zzrpc::producer::StreamResponse::Closed => {
196                            self.#ident_requests.remove(&id);
197                        },
198                    }
199                }
200            }
201        } else {
202            quote! {
203                Response::#ident(result) => {
204                    if let Some(sender) = self.#ident_requests.remove(&id) {
205                        let _ = sender.send(Ok(result));
206                        self.pending -= 1;
207                    }
208                }
209            }
210        });
211
212        if is_stream(&method.sig.output) {
213            drainers.push(quote! {
214                for (_id, (mut sender, _)) in self.#ident_requests.drain() {
215                    if let Some(sender) = sender.take() {
216                        let _ = sender.send(Err(shutdown_type.into()));
217                    }
218                }
219            });
220        } else {
221            drainers.push(quote! {
222                for (_id, sender) in self.#ident_requests.drain() {
223                    let _ = sender.send(Err(shutdown_type.into()));
224                }
225            });
226        }
227    }
228
229    let output = quote! {
230        impl<Error> ConsumerState<Error> {
231            fn new() -> (Senders<Error>, Self) {
232                #(#channels)*
233
234                let senders = Senders {
235                    #(#senders)*
236                };
237
238                let state = ConsumerState {
239                    pending: 0,
240                    #(#items)*
241                };
242
243                (senders, state)
244            }
245
246            fn handle_message(
247                &mut self,
248                message: zzrpc::producer::Message<Response>,
249            ) -> Option<zzrpc::ShutdownType> {
250                match message {
251                    zzrpc::producer::Message::Response { id, response } => {
252                        match response {
253                            #(#handlers)*
254                        }
255                        None
256                    }
257                    zzrpc::producer::Message::Aborted => Some(zzrpc::ShutdownType::Aborted),
258                    zzrpc::producer::Message::Shutdown => Some(zzrpc::ShutdownType::Shutdown),
259                }
260            }
261
262            fn idle(&self) -> bool {
263                self.pending == 0
264            }
265
266            fn shutdown(&mut self, shutdown_type: zzrpc::ShutdownType) {
267                #(#drainers)*
268            }
269        }
270    };
271    output
272}
273
274fn consumer_state(item: &syn::ItemTrait) -> proc_macro2::TokenStream {
275    let items = item.items.iter().filter_map(|item| match item {
276        syn::TraitItem::Fn(method) => {
277            let ident_requests = format_ident!("{}__requests", method.sig.ident.to_string().to_case(Case::Snake));
278            let ident_receiver = format_ident!("{}__receiver", method.sig.ident.to_string().to_case(Case::Snake));
279            let return_type = extract_return_type(&method.sig.output);
280
281            let sender = if is_stream(&method.sig.output) {
282                quote! {
283                    (Option<zzrpc::futures::channel::oneshot::Sender<zzrpc::consumer::Result<(), Error>>>,
284                    zzrpc::futures::channel::mpsc::UnboundedSender<#return_type>)
285                }
286            } else {
287                quote! {
288                    zzrpc::futures::channel::oneshot::Sender<zzrpc::consumer::Result<#return_type, Error>>
289                }
290            };
291
292            let item = quote! {
293                #ident_requests: std::collections::HashMap<usize, #sender>,
294                #ident_receiver: zzrpc::futures::channel::mpsc::UnboundedReceiver<(
295                    zzrpc::consumer::Message<Request>,
296                    zzrpc::consumer::ResultSender<#return_type, Error>,
297                )>,
298            };
299            Some(item)
300        }
301        _ => None,
302    });
303
304    let impl_consumer_state = impl_consumer_state(item);
305
306    let output = quote! {
307        #[derive(Debug)]
308        struct ConsumerState<Error> {
309            pending: usize,
310            #(#items)*
311        }
312
313        #impl_consumer_state
314    };
315    output
316}
317
/// Generates `impl zzrpc::consumer::Consume<Consumer<Error>, Error> for Consumer<Error>`.
///
/// The generated `consume_unreliable` function does the following:
/// * splits the transport,
/// * creates `ConsumerState` and `Senders` via `ConsumerState::new()`,
/// * spawns a background task with `zzrpc::spawn`,
/// * returns a `Consumer` that owns the senders and a drop signal.
///
/// The background task loops over a `select!` of these arms:
/// * incoming producer messages, handed to `ConsumerState::handle_message`;
///   receive errors go through `receive_error_callback`;
/// * one arm per trait method, which forwards queued requests to the transport
///   and records their result senders in `ConsumerState`;
/// * the timeout, which only shuts down when `state.idle()` is false;
/// * the user-supplied `shutdown` future;
/// * `drop_receiver`, which fires when the `Consumer` is dropped.
///
/// The implementation body is emitted twice, under `cfg(not(target_arch = "wasm32"))`
/// and `cfg(target_arch = "wasm32")`. The two differ only in that the non-wasm
/// variant additionally requires `Send` on the transport, shutdown future,
/// callback and error type.
///
/// NOTE(review): despite the name `consume_unreliable`, both variants require
/// `mezzenger::Reliable + mezzenger::Order` on the transport. Confirm this
/// matches the `Consume` trait's intent.
fn impl_consume(item: &syn::ItemTrait) -> proc_macro2::TokenStream {
    // One `select!` arm per trait method, polling that method's request channel.
    let items = item.items.iter().filter_map(|item| match item {
        syn::TraitItem::Fn(method) => {
            let ident_option = format_ident!(
                "{}__option",
                method.sig.ident.to_string().to_case(Case::Snake)
            );
            let ident_receiver = format_ident!(
                "{}__receiver",
                method.sig.ident.to_string().to_case(Case::Snake)
            );
            let ident_requests = format_ident!(
                "{}__requests",
                method.sig.ident.to_string().to_case(Case::Snake)
            );

            // Match arms on the `ResultSender` that accompanied the request.
            // If sending to the transport failed, the error goes straight back
            // to the caller. Otherwise the sender is stored until the producer
            // answers, and `pending` is bumped. `Abort` removes the entry and
            // decrements `pending` only if it was still counted.
            let patterns = if is_stream(&method.sig.output) {
                quote! {
                    zzrpc::consumer::ResultSender::Stream { result_sender, values_sender } => {
                        if let Err(error) = result {
                            let _ = result_sender.send(Err(error));
                        } else {
                            state.#ident_requests.insert(id, (Some(result_sender), values_sender));
                            state.pending += 1;
                        }
                    },
                    zzrpc::consumer::ResultSender::Abort => {
                        if let Some((sender, _)) = state.#ident_requests.remove(&id) {
                            // `sender` is `None` once `Open` was received, which already decremented `pending`.
                            if sender.is_some() {
                                state.pending -= 1;
                            }
                        }
                    },
                    _ => unreachable!("value sender got when stream sender expected"),
                }
            } else {
                quote! {
                    zzrpc::consumer::ResultSender::Value(sender) => {
                        if let Err(error) = result {
                            let _ = sender.send(Err(error));
                        } else {
                            state.#ident_requests.insert(id, sender);
                            state.pending += 1;
                        }
                    },
                    zzrpc::consumer::ResultSender::Abort => {
                        if state.#ident_requests.remove(&id).is_some() {
                            state.pending -= 1;
                        }
                    },
                    _ => unreachable!("stream sender got when value sender expected"),
                }
            };

            let item = quote! {
                #ident_option = zzrpc::futures::StreamExt::next(&mut state.#ident_receiver) => {
                    // `None` here means every sender for this channel was dropped; nothing to do.
                    if let Some((message, result_sender)) = #ident_option {
                        if timeout.is_some() {
                            timeout_future.reset();
                        }

                        let id = message.id;
                        let result = sender.send(message).await;
                        // Map transport send errors onto zzrpc's error type.
                        let result = result.map_err(|error| {
                            match error {
                                mezzenger::Error::Closed => zzrpc::Error::Closed,
                                mezzenger::Error::Other(error) => zzrpc::Error::Transport(error),
                            }
                        });
                        match result_sender {
                            #patterns
                        }
                    }
                },
            };
            Some(item)
        }
        _ => None,
    });

    // Shared body of both `consume_unreliable` variants below.
    let implementation = quote! {
        use zzrpc::futures::{FutureExt, SinkExt, StreamExt};

        let zzrpc::consumer::Configuration {
            shutdown,
            mut receive_error_callback,
            timeout,
            ..
        } = configuration;

        // Signals the background task to stop when the returned `Consumer` is dropped.
        let (drop_sender, mut drop_receiver) = zzrpc::futures::channel::oneshot::channel::<()>();
        let drop_sender = Some(drop_sender);

        let (senders, mut state) = ConsumerState::new();

        zzrpc::spawn(async move {
            let (mut sender, receiver) = transport.split();
            let receiver = zzrpc::futures::StreamExt::fuse(receiver);

            let timeout_future = if let Some(duration) = timeout {
                zzrpc::Timeout::new(duration)
            } else {
                zzrpc::Timeout::never()
            };
            let shutdown = zzrpc::futures::FutureExt::fuse(shutdown);
            zzrpc::futures::pin_mut!(receiver, timeout_future, shutdown);

            loop {
                zzrpc::futures::select! {
                    receive_option = zzrpc::futures::StreamExt::next(&mut receiver) => {
                        if let Some(receive_result) = receive_option {
                            if timeout.is_some() {
                                timeout_future.reset();
                            }

                            match receive_result {
                                Ok(message) =>  {
                                    if let Some(shutdown_type) = state.handle_message(message) {
                                        state.shutdown(shutdown_type);
                                        break;
                                    }
                                },
                                Err(error) => {
                                    if let zzrpc::HandlingStrategy::Stop(shutdown_type) = receive_error_callback.on_receive_error(error) {
                                        state.shutdown(shutdown_type);
                                        break;
                                    }
                                },
                            }
                        } else {
                            // The transport's receiving side ended.
                            state.shutdown(zzrpc::ShutdownType::Closed);
                            break;
                        }
                    },
                    #(#items)*
                    _ = &mut timeout_future => {
                        // Only time out when requests are actually waiting for replies.
                        if timeout.is_some() && !state.idle() {
                            state.shutdown(zzrpc::ShutdownType::Timeout);
                            break;
                        }
                    },
                    shutdown_type = &mut shutdown => {
                        state.shutdown(shutdown_type);
                        break;
                    }
                    _ = &mut drop_receiver => { break; }
                }
            }
        });

        Consumer {
            id_counter: zzrpc::atomic_counter::ConsistentCounter::new(0),
            senders,
            drop_sender,
        }
    };

    let output = quote! {
        impl<Error> zzrpc::consumer::Consume<Consumer<Error>, Error> for Consumer<Error> {
            type Request = Request;
            type Response = Response;

            #[cfg(not(target_arch = "wasm32"))]
            fn consume_unreliable<Transport, Shutdown, ReceiveErrorCallback>(
                transport: Transport,
                configuration: zzrpc::consumer::Configuration<Shutdown, Error, ReceiveErrorCallback>,
            ) -> Consumer<Error>
            where
                Transport: mezzenger::Transport<
                        zzrpc::producer::Message<Self::Response>,
                        zzrpc::consumer::Message<Self::Request>,
                        Error,
                    > + mezzenger::Reliable
                    + mezzenger::Order
                    + Send
                    + 'static,
                Shutdown: zzrpc::futures::Future<Output = zzrpc::ShutdownType> + Send + 'static,
                ReceiveErrorCallback: zzrpc::ReceiveErrorCallback<Error> + Send + 'static,
                Error: Send + 'static, {
                #implementation
            }

            // Same as above without the `Send` bounds.
            #[cfg(target_arch = "wasm32")]
            fn consume_unreliable<Transport, Shutdown, ReceiveErrorCallback>(
                transport: Transport,
                configuration: zzrpc::consumer::Configuration<Shutdown, Error, ReceiveErrorCallback>,
            ) -> Consumer<Error>
            where
                Transport: mezzenger::Transport<
                        zzrpc::producer::Message<Self::Response>,
                        zzrpc::consumer::Message<Self::Request>,
                        Error,
                    > + mezzenger::Reliable
                    + mezzenger::Order
                    + 'static,
                Shutdown: zzrpc::futures::Future<Output = zzrpc::ShutdownType> + 'static,
                ReceiveErrorCallback: zzrpc::ReceiveErrorCallback<Error> + 'static,
                Error: 'static, {
                #implementation
            }
        }
    };
    output
}
522
523fn pattern_arguments(
524    arguments: &Punctuated<syn::FnArg, syn::token::Comma>,
525) -> Punctuated<syn::Ident, Comma> {
526    let mut output = Punctuated::new();
527    for arg in arguments {
528        if let syn::FnArg::Typed(pat_type) = &arg {
529            if let syn::Pat::Ident(syn::PatIdent { ident, .. }) = &*pat_type.pat {
530                if ident != "self" {
531                    output.push(ident.clone());
532                }
533            }
534        }
535    }
536
537    output
538}
539
540fn impl_api(item: &syn::ItemTrait) -> proc_macro2::TokenStream {
541    let ident = &item.ident;
542
543    let items = item.items.iter().filter_map(|item| match item {
544        syn::TraitItem::Fn(method) => {
545            let ident_request =
546                format_ident!("{}", method.sig.ident.to_string().to_case(Case::Pascal));
547            let ident_sender = format_ident!(
548                "{}__sender",
549                method.sig.ident.to_string().to_case(Case::Snake)
550            );
551
552            let mut signature = method.sig.clone();
553            signature.asyncness = None;
554            match &signature.output {
555                syn::ReturnType::Default => {
556                    signature.output = syn::parse2::<syn::ReturnType>(
557                        quote!(-> zzrpc::ValueRequest<(), Request, Self::Error>),
558                    )
559                    .unwrap();
560                }
561                syn::ReturnType::Type(_, return_type) => match *return_type.to_owned() {
562                    syn::Type::ImplTrait(impl_trait) => {
563                        let return_type = extract_stream_item_type(&impl_trait);
564                        signature.output = syn::parse2::<syn::ReturnType>(
565                            quote!(-> zzrpc::StreamRequest<#return_type, Request, Self::Error>),
566                        )
567                        .unwrap();
568                    }
569                    _ => {
570                        signature.output = syn::parse2::<syn::ReturnType>(
571                            quote!(-> zzrpc::ValueRequest<#return_type, Request, Self::Error>),
572                        )
573                        .unwrap();
574                    }
575                },
576            };
577
578            let ident_request_future = if is_stream(&method.sig.output) {
579                quote!(zzrpc::StreamRequest)
580            } else {
581                quote!(zzrpc::ValueRequest)
582            };
583
584            let arguments = pattern_arguments(&method.sig.inputs);
585
586            let item = quote! {
587                #signature {
588                    use zzrpc::atomic_counter::AtomicCounter;
589                    let request = Request::#ident_request { #arguments };
590                    #ident_request_future::new(
591                        self.senders.#ident_sender.clone(),
592                        self.id_counter.inc(),
593                        request,
594                    )
595                }
596            };
597            Some(item)
598        }
599        _ => None,
600    });
601
602    let output = quote! {
603        impl<Error> #ident for Consumer<Error> {
604            #(#items)*
605
606            type Request = Request;
607            type Error = Error;
608        }
609    };
610    output
611}
612
613fn consumer(item: &syn::ItemTrait) -> proc_macro2::TokenStream {
614    let senders = consumer_senders(item);
615    let state = consumer_state(item);
616
617    let impl_consume = impl_consume(item);
618    let impl_api = impl_api(item);
619
620    let output = quote! {
621        #senders
622
623        #state
624
625        #[derive(Debug)]
626        pub struct Consumer<Error> {
627            id_counter: zzrpc::atomic_counter::ConsistentCounter,
628            senders: Senders<Error>,
629            drop_sender: Option<zzrpc::futures::channel::oneshot::Sender<()>>,
630        }
631
632        #impl_consume
633
634        #impl_api
635
636        impl<Error> Drop for Consumer<Error> {
637            fn drop(&mut self) {
638                if let Some(sender) = self.drop_sender.take() {
639                    let _ = sender.send(());
640                }
641            }
642        }
643    };
644    output
645}
646
647fn impl_produce(item: &syn::ItemTrait) -> proc_macro2::TokenStream {
648    let items = item.items.iter().filter_map(|item| match item {
649        syn::TraitItem::Fn(method) => {
650            let ident = method.sig.ident.clone();
651            let ident_request =
652                format_ident!("{}", method.sig.ident.to_string().to_case(Case::Pascal));
653
654            let arguments = pattern_arguments(&method.sig.inputs);
655
656            let item = if is_stream(&method.sig.output) {
657                quote! {
658                    Request::#ident_request { #arguments } => {
659                        let me = me.clone();
660                        let reply_sender = reply_sender.clone();
661                        let remove_aborter_sender = remove_aborter_sender.clone();
662                        spawn(async move {
663                            let mut stream = zzrpc::futures::StreamExt::fuse(me.#ident(#arguments).await);
664
665                            let response = zzrpc::producer::StreamResponse::Open;
666                            let response = Response::#ident_request(response);
667                            let message = zzrpc::producer::Message::Response { id, response };
668                            let _ = reply_sender.unbounded_send(message);
669                            
670                            loop {
671                                select! {
672                                    result = zzrpc::futures::StreamExt::next(&mut stream) => {
673                                        if let Some(result) = result {
674                                            let response = zzrpc::producer::StreamResponse::Item(result);
675                                            let response = Response::#ident_request(response);
676                                            let message = zzrpc::producer::Message::Response { id, response };
677                                            let _ = reply_sender.unbounded_send(message);
678                                        } else {
679                                            let response = zzrpc::producer::StreamResponse::Closed;
680                                            let response = Response::#ident_request(response);
681                                            let message = zzrpc::producer::Message::Response { id, response };
682                                            let _ = reply_sender.unbounded_send(message);
683
684                                            let _ = remove_aborter_sender.unbounded_send(id);
685                                            break;
686                                        }
687                                    },
688                                    _ = abort_receiver => {
689                                        break;
690                                    },
691                                };
692                            }
693                        });
694                    },
695                }
696            } else {
697                quote! {
698                    Request::#ident_request { #arguments } => {
699                        let me = me.clone();
700                        let reply_sender = reply_sender.clone();
701                        let remove_aborter_sender = remove_aborter_sender.clone();
702                        spawn(async move {
703                            let task = zzrpc::futures::FutureExt::fuse(me.#ident(#arguments));
704                            pin_mut!(task);
705                            select! {
706                                result = task => {
707                                    let response = Response::#ident_request(result);
708                                    let message = zzrpc::producer::Message::Response { id, response };
709                                    let _ = reply_sender.unbounded_send(message);
710                                    let _ = remove_aborter_sender.unbounded_send(id);
711                                },
712                                _ = abort_receiver => (),
713                            };
714                        });
715                    },
716                }
717            };
718
719            Some(item)
720        }
721        _ => None,
722    });
723
724    let uuid = Uuid::new_v4();
725    let ident = format_ident!("__impl_produce_{}", uuid.simple().to_string());
726
727    let output = quote! {
728        #[macro_export]
729        macro_rules! #ident {
730            ($self:ident, $transport:ident, $configuration:ident) => {
731                zzrpc::spawn(async move {
732                    use std::sync::Arc;
733                    use std::collections::HashMap;
734                    use futures::{
735                        channel::{
736                            mpsc::unbounded,
737                            oneshot,
738                        },
739                        pin_mut, select, SinkExt, StreamExt,
740                    };
741                    use zzrpc::{ShutdownType, Timeout};
742
743                    use zzrpc::spawn;
744
745                    let me = Arc::new($self);
746
747                    let (mut sender, receiver) = $transport.split();
748                    let mut receiver = zzrpc::futures::StreamExt::fuse(receiver);
749
750                    let (reply_sender, mut reply_receiver) = unbounded::<zzrpc::producer::Message<Response>>();
751
752                    let mut aborters: HashMap<usize, oneshot::Sender<()>> = HashMap::new();
753                    let (remove_aborter_sender, mut remove_aborter_receiver) = unbounded::<usize>();
754
755                    let (stop_sender, stop_receiver) = oneshot::channel::<ShutdownType>();
756                    let mut stop_sender = Some(stop_sender);
757
758                    let zzrpc::producer::Configuration {
759                        shutdown,
760                        mut send_error_callback,
761                        mut receive_error_callback,
762                        timeout,
763                        ..
764                    } = $configuration;
765
766                    spawn(async move {
767                        while let Some(message) = zzrpc::futures::StreamExt::next(&mut reply_receiver).await {
768                            match message {
769                                zzrpc::producer::Message::Response { id, .. } => {
770                                    let result = sender.send(message).await;
771                                    if let Err(error) = result {
772                                        if let zzrpc::HandlingStrategy::Stop(shutdown_type) =
773                                            send_error_callback.on_send_error(id, error)
774                                        {
775                                            if let Some(stop_sender) = stop_sender.take() {
776                                                let _ = stop_sender.send(shutdown_type);
777                                            }
778                                        }
779                                    }
780                                }
781                                _ => {
782                                    let _ = sender.send(message).await;
783                                }
784                            }
785                        }
786                    });
787
788                    let handle_message = |aborters: &mut HashMap<usize, oneshot::Sender<()>>, message: zzrpc::consumer::Message<Request>| {
789                        let zzrpc::consumer::Message { id, payload } = message;
790                        match payload {
791                            zzrpc::consumer::Payload::Request(request) => {
792                                let (abort_sender, mut abort_receiver) = oneshot::channel::<()>();
793                                aborters.insert(id, abort_sender);
794                                match request {
795                                     #(#items)*
796                                }
797                            }
798                            zzrpc::consumer::Payload::Abort => {
799                                if let Some(abort_sender) = aborters.remove(&id) {
800                                    let _ = abort_sender.send(());
801                                }
802                            }
803                        }
804                    };
805
806                    let mut return_value = ShutdownType::Closed;
807                    let mut handle_shutdown = |shutdown_type: ShutdownType| {
808                        return_value = shutdown_type;
809                        let message = match shutdown_type {
810                            ShutdownType::Shutdown => zzrpc::producer::Message::Shutdown,
811                            ShutdownType::Aborted => zzrpc::producer::Message::Aborted,
812                            _ => {
813                                return;
814                            }
815                        };
816                        let _ = reply_sender.unbounded_send(message);
817                    };
818
819                    let timeout_future = if let Some(duration) = timeout {
820                        Timeout::new(duration)
821                    } else {
822                        Timeout::never()
823                    };
824                    let shutdown = zzrpc::futures::FutureExt::fuse(shutdown);
825                    pin_mut!(shutdown, timeout_future, stop_receiver);
826                    loop {
827                        select! {
828                            receive_option = zzrpc::futures::StreamExt::next(&mut receiver) => {
829                                if timeout.is_some() {
830                                    timeout_future.reset();
831                                }
832
833                                if let Some(receive_result) = receive_option {
834                                    match receive_result {
835                                        Ok(message) => handle_message(&mut aborters, message),
836                                        Err(error) => {
837                                            if let zzrpc::HandlingStrategy::Stop(shutdown_type) = receive_error_callback.on_receive_error(error) {
838                                                handle_shutdown(shutdown_type);
839                                                break;
840                                            }
841                                        },
842                                    }
843                                } else {
844                                    handle_shutdown(ShutdownType::Closed);
845                                    break;
846                                }
847                            },
848                            id_option = zzrpc::futures::StreamExt::next(&mut remove_aborter_receiver) => {
849                                if let Some(id) = id_option {
850                                    aborters.remove(&id);
851                                }
852                            },
853                            _ = &mut timeout_future => {
854                                if timeout.is_some() {
855                                    if aborters.is_empty() {
856                                        handle_shutdown(ShutdownType::Timeout);
857                                        break;
858                                    } else {
859                                        timeout_future.reset();
860                                    }
861                                }
862                            },
863                            shutdown_type = &mut stop_receiver => {
864                                if let Ok(shutdown_type) = shutdown_type {
865                                    handle_shutdown(shutdown_type);
866                                    break;
867                                }
868                            },
869                            shutdown_type = &mut shutdown => {
870                                handle_shutdown(shutdown_type);
871                                break;
872                            },
873                        }
874                    }
875
876                    for (_, aborter) in aborters.drain() {
877                        let _ = aborter.send(());
878                    }
879
880                    return_value
881                })
882            }
883        }
884
885        pub use #ident as impl_produce;
886    };
887    output
888}
889
890fn extract_stream_item_type(impl_trait: &syn::TypeImplTrait) -> syn::Type {
891    if impl_trait.bounds.len() == 1 {
892        if let syn::TypeParamBound::Trait(bound) = &impl_trait.bounds[0] {
893            if bound.path.segments.len() == 1 {
894                let stream = &bound.path.segments[0];
895                if stream.ident == "Stream" {
896                    if let syn::PathArguments::AngleBracketed(arguments) = &stream.arguments {
897                        if arguments.args.len() == 1 {
898                            let argument = &arguments.args[0];
899                            if let syn::GenericArgument::AssocType(binding) = argument {
900                                if binding.ident == "Item" {
901                                    return binding.ty.clone();
902                                }
903                            }
904                        }
905                    }
906                }
907            }
908        }
909    }
910    panic!("invalid stream request method return type");
911}
912
913fn modify_trait(mut item: syn::ItemTrait) -> syn::ItemTrait {
914    if item.generics.lt_token.is_some() {
915        panic!("generic traits are not supported");
916    }
917
918    item.items.push(
919        syn::parse2::<syn::TraitItem>(quote! {
920            type Request;
922        })
923        .unwrap(),
924    );
925
926    item.items.push(
927        syn::parse2::<syn::TraitItem>(quote! {
928            type Error;
930        })
931        .unwrap(),
932    );
933
934    let must_use = format_ident!("must_use");
935    let must_use = syn::Attribute {
936        pound_token: Token!(#)(Span::call_site()),
937        style: syn::AttrStyle::Outer,
938        bracket_token: syn::token::Bracket(Span::call_site()),
939        meta: syn::Meta::Path(syn::Path {
940            leading_colon: None,
941            segments: syn::punctuated::Punctuated::from_iter([syn::PathSegment::from(must_use)]),
942        }),
943    };
944
945    for method in item.items.iter_mut().filter_map(|item| {
946        if let syn::TraitItem::Fn(func) = item {
947            Some(func)
948        } else {
949            None
950        }
951    }) {
952        if method.sig.asyncness.is_none() {
953            panic!("all api methods should be marked as \"async\"")
954        }
955
956        if method.sig.generics.lt_token.is_some() {
957            panic!("generic methods are not supported")
958        }
959
960        method.sig.asyncness = None;
961        method.attrs.push(must_use.clone());
962
963        match &method.sig.output {
964            syn::ReturnType::Default => {
965                method.sig.output = syn::parse2::<syn::ReturnType>(
966                    quote!(-> zzrpc::ValueRequest<(), Request, Self::Error>),
967                )
968                .unwrap();
969            }
970            syn::ReturnType::Type(_, return_type) => match *return_type.to_owned() {
971                syn::Type::ImplTrait(impl_trait) => {
972                    let return_type = extract_stream_item_type(&impl_trait);
973                    method.sig.output = syn::parse2::<syn::ReturnType>(
974                        quote!(-> zzrpc::StreamRequest<#return_type, Request, Self::Error>),
975                    )
976                    .unwrap();
977                }
978                _ => {
979                    method.sig.output = syn::parse2::<syn::ReturnType>(
980                        quote!(-> zzrpc::ValueRequest<#return_type, Request, Self::Error>),
981                    )
982                    .unwrap();
983                }
984            },
985        };
986    }
987
988    item
989}
990
991#[proc_macro_attribute]
992pub fn api(_attr: TokenStream, item: TokenStream) -> TokenStream {
993    if let Ok(item) = syn::parse2::<syn::ItemTrait>(item.into()) {
994        let item_modified = modify_trait(item.clone());
995
996        let request = request(&item);
997        let response = response(&item);
998        let consumer = consumer(&item);
999        let impl_produce = impl_produce(&item);
1000
1001        let output = quote! {
1002            #item_modified
1003
1004            #request
1005
1006            #response
1007
1008            #consumer
1009
1010            #impl_produce
1011        };
1012
1013        output.into()
1014    } else {
1015        panic!("expected a trait")
1016    }
1017}
1018
1019#[proc_macro_derive(Produce)]
1020pub fn produce(input: TokenStream) -> TokenStream {
1021    let syn::DeriveInput { ident, .. } = syn::parse_macro_input!(input);
1022    let output = quote! {
1023        impl zzrpc::producer::Produce for #ident {
1024            type Request = Request;
1025            type Response = Response;
1026
1027            #[cfg(not(target_arch = "wasm32"))]
1028            fn produce_unreliable<Transport, Error, Shutdown, SendErrorCallback, ReceiveErrorCallback>(
1029                self,
1030                transport: Transport,
1031                configuration: zzrpc::producer::Configuration<
1032                    Shutdown,
1033                    Error,
1034                    SendErrorCallback,
1035                    ReceiveErrorCallback,
1036                >,
1037            ) -> zzrpc::JoinHandle<zzrpc::ShutdownType> where
1038                Transport: mezzenger::Transport<
1039                        zzrpc::consumer::Message<Self::Request>,
1040                        zzrpc::producer::Message<Self::Response>,
1041                        Error,
1042                    > + mezzenger::Reliable
1043                    + mezzenger::Order
1044                    + Send
1045                    + 'static,
1046                Shutdown: zzrpc::futures::Future<Output = zzrpc::ShutdownType> + Send + 'static,
1047                SendErrorCallback: zzrpc::SendErrorCallback<Error> + Send + 'static,
1048                ReceiveErrorCallback: zzrpc::ReceiveErrorCallback<Error> + Send + 'static {
1049                impl_produce!(self, transport, configuration)
1050            }
1051
1052            #[cfg(target_arch = "wasm32")]
1053            fn produce_unreliable<Transport, Error, Shutdown, SendErrorCallback, ReceiveErrorCallback>(
1054                self,
1055                transport: Transport,
1056                configuration: zzrpc::producer::Configuration<
1057                    Shutdown,
1058                    Error,
1059                    SendErrorCallback,
1060                    ReceiveErrorCallback,
1061                >,
1062            ) -> zzrpc::JoinHandle<zzrpc::ShutdownType> where
1063                Transport: mezzenger::Transport<
1064                        zzrpc::consumer::Message<Self::Request>,
1065                        zzrpc::producer::Message<Self::Response>,
1066                        Error,
1067                    > + mezzenger::Reliable
1068                    + mezzenger::Order
1069                    + 'static,
1070                Shutdown: zzrpc::futures::Future<Output = zzrpc::ShutdownType> + 'static,
1071                SendErrorCallback: zzrpc::SendErrorCallback<Error> + 'static,
1072                ReceiveErrorCallback: zzrpc::ReceiveErrorCallback<Error> + 'static {
1073                impl_produce!(self, transport, configuration)
1074            }
1075        }
1076    };
1077    output.into()
1078}