1use proc_macro::TokenStream;
2use quote::{ToTokens, format_ident, quote};
3use syn::parse::{Parse, ParseStream};
4use syn::{
5 Data, DeriveInput, Expr, Fields, Generics, Ident, ImplItem, ImplItemFn, ImplItemType, ItemImpl,
6 ItemStruct, LitBool, LitStr, Token, Type, parse_macro_input,
7};
8
/// Attribute macro that turns a WebSocket server constructor into a shared
/// end-to-end test suite.
///
/// The annotated function is re-emitted unchanged. Next to it the macro
/// generates tokio-tungstenite client helpers and four `#[tokio::test]` cases.
/// Each test:
/// 1. builds a `yrs_tokio::broadcast::BroadcastGroup`;
/// 2. starts the server via
///    `<annotated fn>("0.0.0.0:<port>", Arc::new(bcast)).await.unwrap()`;
/// 3. connects clients to `ws://localhost:<port>/my-room`;
/// 4. asserts that text edits on the `"test"` text propagate.
///
/// NOTE(review): the generated tests use the fixed ports 6600-6603 and emit
/// module-level items (`TungsteniteSink`, `client`, `TIMEOUT`, ...). The
/// attribute can presumably be applied at most once per module — confirm.
#[cfg(feature = "test-utils")]
#[proc_macro_attribute]
pub fn yrs_common_test(_: TokenStream, item: TokenStream) -> TokenStream {
    let input_fn = parse_macro_input!(item as syn::ItemFn);
    let fn_name = &input_fn.sig.ident;
    let original_fn_def = input_fn.to_token_stream();

    quote! {
        // Forwards a `poll_*` call to the wrapped split sink and boxes its error
        // into `yrs::sync::Error::Other`.
        macro_rules! delegate_poll_method {
            ($self:expr, $cx:expr, $method:ident $(($($arg:expr),*))?) => {{
                use ::std::task::Poll;
                use ::std::task::ready;
                // `&mut $self.0` through `Pin<&mut Self>` already requires the
                // wrapped half to be `Unpin`, so safe `Pin::new` is sufficient.
                let sink = ::std::pin::Pin::new(&mut $self.0);
                let result = ready!(sink.$method($($($arg),*)?));
                match result {
                    Ok(_) => Poll::Ready(Ok(())),
                    Err(e) => Poll::Ready(Err(::yrs::sync::Error::Other(Box::new(e)))),
                }
            }};
        }

        // Forwards a non-polling call (e.g. `start_send`) to the wrapped split
        // sink and boxes its error into `yrs::sync::Error::Other`.
        macro_rules! delegate_result_method {
            ($self:expr, $method:ident $(($($arg:expr),*))?) => {{
                let sink = ::std::pin::Pin::new(&mut $self.0);
                let result = sink.$method($($($arg),*)?);
                match result {
                    Ok(_) => Ok(()),
                    Err(e) => Err(::yrs::sync::Error::Other(Box::new(e))),
                }
            }};
        }

        // Polls the wrapped split stream, turning each tungstenite message into
        // its raw payload bytes and boxing errors into `yrs::sync::Error::Other`.
        macro_rules! delegate_stream_poll_next {
            ($self:expr, $cx:expr) => {{
                use ::std::task::Poll;
                use ::std::task::ready;
                let stream = ::std::pin::Pin::new(&mut $self.0);
                let result = ready!(stream.poll_next($cx));
                match result {
                    None => Poll::Ready(None),
                    Some(Ok(msg)) => Poll::Ready(Some(Ok(msg.into_data().into()))),
                    Some(Err(e)) => Poll::Ready(Some(Err(::yrs::sync::Error::Other(Box::new(e))))),
                }
            }};
        }

        // Write half of a client WebSocket; accepts raw bytes and sends them as
        // binary frames.
        struct TungsteniteSink(::futures_util::stream::SplitSink<::tokio_tungstenite::WebSocketStream<::tokio_tungstenite::MaybeTlsStream<::tokio::net::TcpStream>>, ::tokio_tungstenite::tungstenite::Message>);

        impl ::futures_util::Sink<Vec<u8>> for TungsteniteSink {
            type Error = ::yrs::sync::Error;

            fn poll_ready(
                mut self: ::std::pin::Pin<&mut Self>,
                cx: &mut ::std::task::Context<'_>,
            ) -> ::std::task::Poll<Result<(), Self::Error>> {
                delegate_poll_method!(self, cx, poll_ready(cx))
            }

            fn start_send(mut self: ::std::pin::Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
                delegate_result_method!(self, start_send(::tokio_tungstenite::tungstenite::Message::binary(item)))
            }

            fn poll_flush(
                mut self: ::std::pin::Pin<&mut Self>,
                cx: &mut ::std::task::Context<'_>,
            ) -> ::std::task::Poll<Result<(), Self::Error>> {
                delegate_poll_method!(self, cx, poll_flush(cx))
            }

            fn poll_close(
                mut self: ::std::pin::Pin<&mut Self>,
                cx: &mut ::std::task::Context<'_>,
            ) -> ::std::task::Poll<Result<(), Self::Error>> {
                delegate_poll_method!(self, cx, poll_close(cx))
            }
        }

        // Read half of a client WebSocket; yields each message's payload bytes.
        struct TungsteniteStream(::futures_util::stream::SplitStream<::tokio_tungstenite::WebSocketStream<::tokio_tungstenite::MaybeTlsStream<::tokio::net::TcpStream>>>);

        impl ::futures_util::Stream for TungsteniteStream {
            type Item = Result<Vec<u8>, ::yrs::sync::Error>;

            fn poll_next(mut self: ::std::pin::Pin<&mut Self>, cx: &mut ::std::task::Context<'_>) -> ::std::task::Poll<Option<Self::Item>> {
                delegate_stream_poll_next!(self, cx)
            }
        }

        // Connects to `addr` and wraps the socket halves plus `doc` into a
        // `yrs_tokio` connection.
        async fn client(
            addr: &str,
            doc: ::yrs::Doc,
        ) -> Result<::yrs_tokio::connection::Connection<TungsteniteSink, TungsteniteStream>, Box<dyn ::std::error::Error>> {
            let (stream, _) = ::tokio_tungstenite::connect_async(addr).await?;
            let (sink, stream) = ::futures_util::stream::StreamExt::split(stream);
            let sink = TungsteniteSink(sink);
            let stream = TungsteniteStream(stream);
            Ok(::yrs_tokio::connection::Connection::new(
                ::std::sync::Arc::new(::tokio::sync::RwLock::new(::yrs::sync::Awareness::new(doc))),
                sink,
                stream,
            ))
        }

        // Returns a `Notify` that wakes all current waiters on every v1 update
        // of `doc`, together with the subscription keeping the observer alive.
        fn create_notifier(doc: &::yrs::Doc) -> (::std::sync::Arc<::tokio::sync::Notify>, ::yrs::Subscription) {
            let n = ::std::sync::Arc::new(::tokio::sync::Notify::new());
            let sub = {
                let n = n.clone();
                doc.observe_update_v1(move |_, _| n.notify_waiters())
                    .unwrap()
            };
            (n, sub)
        }

        const TIMEOUT: ::std::time::Duration = ::std::time::Duration::from_secs(5);

        // Forwards every local v1 update of the connection's document to the
        // server as a `SyncMessage::Update`.
        async fn setup_client_update_propagation(
            conn: &::yrs_tokio::connection::Connection<TungsteniteSink, TungsteniteStream>,
        ) -> ::yrs::Subscription {
            let sink = conn.sink();
            let awareness_arc = conn.awareness().clone();
            let sub_handle: ::tokio::task::JoinHandle<::yrs::Subscription> =
                ::tokio::task::spawn_blocking(move || {
                    let awareness_guard = awareness_arc.blocking_read();
                    let doc = awareness_guard.doc();
                    let inner_sink = sink.clone();
                    doc.observe_update_v1(move |_, e| {
                        let update = e.update.to_owned();
                        if let Some(sink) = inner_sink.upgrade() {
                            ::tokio::task::spawn(async move {
                                let msg = ::yrs::sync::Message::Sync(::yrs::sync::SyncMessage::Update(update))
                                    .encode_v1();
                                let mut sink_guard = sink.lock().await;
                                // Fully-qualified so the expansion does not rely on
                                // `SinkExt` being imported at the call site.
                                ::futures_util::SinkExt::send(&mut *sink_guard, msg).await.unwrap();
                            });
                        }
                    })
                    .unwrap()
                });

            sub_handle.await.unwrap()
        }

        // Asserts that the connection's `"test"` text equals `expected_str`.
        async fn check_client_text_content(
            conn: &::yrs_tokio::connection::Connection<TungsteniteSink, TungsteniteStream>,
            expected_str: &str,
        ) {
            let awareness = conn.awareness().read().await;
            let doc = awareness.doc();
            let text = doc.get_or_insert_text("test");
            let str = text.get_string(&doc.transact());
            ::core::assert_eq!(str, expected_str.to_string());
        }

        // Appends `change_str` to the `"test"` text while holding the write lock.
        async fn apply_text_change(
            awareness: &::std::sync::Arc<::tokio::sync::RwLock<::yrs::sync::Awareness>>,
            change_str: &str,
        ) {
            // `doc()` only needs `&self`, so the guard binding need not be `mut`.
            let lock = awareness.write().await;
            let doc = lock.doc();
            let text = doc.get_or_insert_text("test");
            text.push(&mut doc.transact_mut(), change_str);
        }

        #original_fn_def

        #[::tokio::test]
        async fn change_introduced_by_server_reaches_subscribed_clients() {
            let doc = ::yrs::Doc::with_client_id(1);
            // Called for its side effect of creating the shared text.
            let _text = doc.get_or_insert_text("test");
            let awareness = ::std::sync::Arc::new(::tokio::sync::RwLock::new(::yrs::sync::Awareness::new(doc)));
            let bcast = ::yrs_tokio::broadcast::BroadcastGroup::new(awareness.clone(), 10).await;
            let _server = #fn_name("0.0.0.0:6600", ::std::sync::Arc::new(bcast)).await.unwrap();

            let doc = ::yrs::Doc::new();
            let (n, _sub) = create_notifier(&doc);
            let c1 = client("ws://localhost:6600/my-room", doc).await.unwrap();

            apply_text_change(&awareness, "abc").await;

            ::tokio::time::timeout(TIMEOUT, n.notified()).await.unwrap();

            check_client_text_content(&c1, "abc").await;
        }

        #[::tokio::test]
        async fn subscribed_client_fetches_initial_state() {
            let doc = ::yrs::Doc::with_client_id(1);
            let text = doc.get_or_insert_text("test");

            text.push(&mut doc.transact_mut(), "abc");

            let awareness = ::std::sync::Arc::new(::tokio::sync::RwLock::new(::yrs::sync::Awareness::new(doc)));
            let bcast = ::yrs_tokio::broadcast::BroadcastGroup::new(awareness.clone(), 10).await;
            let _server = #fn_name("0.0.0.0:6601", ::std::sync::Arc::new(bcast)).await.unwrap();

            let doc = ::yrs::Doc::new();
            let (n, _sub) = create_notifier(&doc);
            let c1 = client("ws://localhost:6601/my-room", doc).await.unwrap();

            ::tokio::time::timeout(TIMEOUT, n.notified()).await.unwrap();

            check_client_text_content(&c1, "abc").await;
        }

        #[::tokio::test]
        async fn changes_from_one_client_reach_others() {
            let doc = ::yrs::Doc::with_client_id(1);
            let _ = doc.get_or_insert_text("test");

            let awareness = ::std::sync::Arc::new(::tokio::sync::RwLock::new(::yrs::sync::Awareness::new(doc)));
            let bcast = ::yrs_tokio::broadcast::BroadcastGroup::new(awareness.clone(), 10).await;
            let _server = #fn_name("0.0.0.0:6602", ::std::sync::Arc::new(bcast)).await.unwrap();

            let d1 = ::yrs::Doc::with_client_id(2);
            let c1 = client("ws://localhost:6602/my-room", d1).await.unwrap();
            let _sub11 = setup_client_update_propagation(&c1).await;

            let d2 = ::yrs::Doc::with_client_id(3);
            let (n2, _sub2) = create_notifier(&d2);
            let c2 = client("ws://localhost:6602/my-room", d2).await.unwrap();

            apply_text_change(&c1.awareness(), "def").await;

            ::tokio::time::timeout(TIMEOUT, n2.notified()).await.unwrap();

            check_client_text_content(&c2, "def").await;
        }

        #[::tokio::test]
        async fn client_failure_doesnt_affect_others() {
            let doc = ::yrs::Doc::with_client_id(1);
            let _text = doc.get_or_insert_text("test");

            let awareness = ::std::sync::Arc::new(::tokio::sync::RwLock::new(::yrs::sync::Awareness::new(doc)));
            let bcast = ::yrs_tokio::broadcast::BroadcastGroup::new(awareness.clone(), 10).await;
            let _server = #fn_name("0.0.0.0:6603", ::std::sync::Arc::new(bcast)).await.unwrap();

            let d1 = ::yrs::Doc::with_client_id(2);
            let c1 = client("ws://localhost:6603/my-room", d1).await.unwrap();
            let _sub11 = setup_client_update_propagation(&c1).await;

            let d2 = ::yrs::Doc::with_client_id(3);
            let (n2, sub2) = create_notifier(&d2);
            let c2 = client("ws://localhost:6603/my-room", d2).await.unwrap();

            let d3 = ::yrs::Doc::with_client_id(4);
            let (n3, sub3) = create_notifier(&d3);
            let c3 = client("ws://localhost:6603/my-room", d3).await.unwrap();

            apply_text_change(&c1.awareness(), "abc").await;

            ::tokio::time::sleep(TIMEOUT).await;

            check_client_text_content(&c2, "abc").await;
            check_client_text_content(&c3, "abc").await;

            // Simulate client 3 going away; client 2 must keep receiving updates.
            drop(c3);
            drop(n3);
            drop(sub3);
            drop(n2);
            drop(sub2);

            let (n2, _sub2) = {
                let a = c2.awareness().read().await;
                let doc = a.doc();
                create_notifier(doc)
            };

            apply_text_change(&c1.awareness(), "def").await;

            ::tokio::time::timeout(TIMEOUT, n2.notified()).await.unwrap();

            check_client_text_content(&c2, "abcdef").await;
        }
    }
    .into()
}
376#[proc_macro_derive(YrsExchange)]
378pub fn derive_yrs_exchange(input: TokenStream) -> TokenStream {
379 derive_impl(input, "YrsExchange", |name, field_type, generics| {
381 TokenStream::from(quote_from_into(name, field_type, generics))
383 })
384}
385
386#[proc_macro_attribute]
401pub fn yrs_stream(attr: TokenStream, item: TokenStream) -> TokenStream {
402 let args = parse_macro_input!(attr as MethodCallAttributeArgs);
403 let call_target = args.into_target;
404 let gen_exchange = args.exchange;
405
406 let item_for_parsing = item.clone();
407 let item_for_codegen = item;
408
409 let input_struct = parse_macro_input!(item_for_parsing as ItemStruct);
410
411 let generics = input_struct.generics.clone();
413
414 let original_struct_def = input_struct.to_token_stream();
415
416 yrs_stream_code_gen(
417 call_target,
418 item_for_codegen,
419 gen_exchange,
420 Some(original_struct_def),
421 generics,
422 )
423}
424
425#[proc_macro_derive(YrsStream)]
437pub fn derive_yrs_stream(input: TokenStream) -> TokenStream {
438 let default_target = CallTarget::SingleMethod(format_ident!("into"));
439 let gen_exchange = true;
440
441 let input_struct = parse_macro_input!(input as DeriveInput);
442 let generics = input_struct.generics.clone();
443 let input_for_derive_impl = input_struct.into_token_stream().into(); yrs_stream_code_gen(
446 default_target,
447 input_for_derive_impl,
448 gen_exchange,
449 None,
450 generics,
451 )
452}
453
454#[proc_macro_derive(YrsStreamOnly)]
466pub fn derive_yrs_stream_only(input: TokenStream) -> TokenStream {
467 let default_target = CallTarget::SingleMethod(format_ident!("into"));
468 let gen_exchange = false;
469
470 let input_struct = parse_macro_input!(input as DeriveInput);
471 let generics = input_struct.generics.clone();
472 let input_for_derive_impl = input_struct.into_token_stream().into();
473
474 yrs_stream_code_gen(
475 default_target,
476 input_for_derive_impl,
477 gen_exchange,
478 None,
479 generics,
480 )
481}
482
483#[derive(Clone)]
484enum CallTarget {
485 SingleMethod(Ident), MethodChain(Expr), }
488struct MethodCallAttributeArgs {
489 into_target: CallTarget,
490 exchange: bool,
491}
492
493impl Parse for MethodCallAttributeArgs {
494 fn parse(input: ParseStream) -> syn::Result<Self> {
495 let mut into_target: Option<CallTarget> = None;
496 let mut exchange: bool = true;
497
498 while !input.is_empty() {
499 let lookahead = input.lookahead1();
500
501 if lookahead.peek(Ident) && input.peek2(Token![=]) {
502 let key: Ident = input.parse()?;
503 let _eq_token: Token![=] = input.parse()?;
504
505 if key == "into" {
506 let target_expr: Expr = input.parse()?;
507 let parsed_target = match target_expr {
508 Expr::Path(expr_path) => {
509 if expr_path.qself.is_none() && expr_path.path.segments.len() == 1 {
510 let segment =
511 expr_path.clone().path.segments.into_iter().next().unwrap();
512 if segment.arguments.is_empty() {
513 CallTarget::SingleMethod(segment.ident)
514 } else {
515 CallTarget::MethodChain(Expr::Path(expr_path))
516 }
517 } else {
518 CallTarget::MethodChain(Expr::Path(expr_path))
519 }
520 }
521 _ => CallTarget::MethodChain(target_expr),
522 };
523 into_target = Some(parsed_target);
524 } else if key == "exchange" {
525 let lit_bool: LitBool = input.parse()?;
526 exchange = lit_bool.value();
527 } else {
528 return Err(input.error(format!("未知属性参数: `{}`", key)));
529 }
530 } else {
531 return Err(input.error("期望 `key = value` 形式的属性参数"));
532 }
533
534 if !input.is_empty() {
535 let lookahead = input.lookahead1();
536 if lookahead.peek(Token![,]) {
537 let _: Token![,] = input.parse()?;
538 } else {
539 return Err(input.error("属性参数之间期望用 `,` 分隔"));
540 }
541 }
542 }
543
544 let into_target = into_target.ok_or_else(|| input.error("属性中期望关键字 `into`"))?;
545
546 Ok(MethodCallAttributeArgs {
547 into_target,
548 exchange,
549 })
550 }
551}
552
553fn yrs_stream_code_gen(
554 call_target: CallTarget,
555 input: TokenStream,
556 gen_exchange: bool,
557 preppend: Option<proc_macro2::TokenStream>,
558 generics: Generics,
559) -> TokenStream {
560 derive_impl(input, "YrsStream", move |name, field_type, _| {
561 let call_target = call_target.clone();
562 let generics = generics.clone();
563
564 let item_call_code = match &call_target {
565 CallTarget::SingleMethod(method_ident) => {
566 quote! { item.#method_ident() }
567 }
568 CallTarget::MethodChain(chain_expr) => {
569 quote! { item.#chain_expr }
570 }
571 };
572
573 let from_into: Option<proc_macro2::TokenStream> = if gen_exchange {
574 Some(quote_from_into(name, field_type, &generics)) } else {
576 None
577 };
578
579 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
580
581 quote! {
582 #preppend
583
584 #from_into
585
586 impl #impl_generics ::futures_core::Stream for #name #ty_generics #where_clause {
587 type Item = Result<Vec<u8>, ::yrs::sync::Error>;
588
589 fn poll_next(
590 mut self: ::core::pin::Pin<&mut Self>,
591 cx: &mut ::core::task::Context<'_>
592 ) -> ::core::task::Poll<Option<Self::Item>> {
593 match ::core::pin::Pin::new(&mut self.0).poll_next(cx) {
594 ::core::task::Poll::Pending => ::core::task::Poll::Pending,
595 ::core::task::Poll::Ready(None) => ::core::task::Poll::Ready(None),
596 ::core::task::Poll::Ready(Some(res)) => match res {
597 Ok(item) => ::core::task::Poll::Ready(Some(Ok(#item_call_code))),
598 Err(e) => ::core::task::Poll::Ready(Some(Err(::yrs::sync::Error::Other(e.into())))),
599 },
600 }
601 }
602 }
603 }.into()
604 })
605}
606
607#[proc_macro_derive(YrsSink)]
623pub fn derive_yrs_sink(input: TokenStream) -> TokenStream {
624 derive_yrs_sink_gen(input, true)
625}
626
627#[proc_macro_derive(YrsSinkOnly)]
643pub fn derive_yrs_sink_only(input: TokenStream) -> TokenStream {
644 derive_yrs_sink_gen(input, false)
645}
646
647fn generate_poll_method_item(method_name: &str) -> ImplItemFn {
648 let ident = format_ident!("{}", method_name);
649 let inner_field_access = quote::quote!(self.0);
650
651 syn::parse_quote! {
652 fn #ident(
653 mut self: ::core::pin::Pin<&mut Self>,
654 cx: &mut ::core::task::Context<'_>,
655 ) -> ::core::task::Poll<Result<(), Self::Error>> {
656 match ::core::pin::Pin::new(&mut #inner_field_access).#ident(cx) {
657 ::core::task::Poll::Pending => ::core::task::Poll::Pending,
658 ::core::task::Poll::Ready(Err(e)) => ::core::task::Poll::Ready(Err(::yrs::sync::Error::Other(e.into()))),
659 ::core::task::Poll::Ready(_) => ::core::task::Poll::Ready(Ok(())),
660 }
661 }
662 }
663}
664
665fn generate_start_send_method_item() -> ImplItemFn {
666 let inner_field_access = quote::quote!(self.0);
667
668 syn::parse_quote! {
669 fn start_send(
670 mut self: ::core::pin::Pin<&mut Self>,
671 item: Vec<u8>,
672 ) -> Result<(), Self::Error> {
673 ::core::pin::Pin::new(&mut #inner_field_access).start_send(item.into())
674 .map_err(|e| ::yrs::sync::Error::Other(e.into()))
675 }
676 }
677}
678
679fn derive_yrs_sink_gen(input: TokenStream, gen_exchange: bool) -> TokenStream {
680 derive_impl(input, "YrsSink", move |name, field_type, generics| {
681 let from_into: Option<proc_macro2::TokenStream> = if gen_exchange {
682 Some(quote_from_into(name, field_type, generics))
683 } else {
684 None
685 };
686
687 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
688
689 let poll_ready_fn = generate_poll_method_item("poll_ready");
690 let poll_flush_fn = generate_poll_method_item("poll_flush");
691 let poll_close_fn = generate_poll_method_item("poll_close");
692 let start_send_fn = generate_start_send_method_item();
693
694 quote! {
695 #from_into
696
697 impl #impl_generics ::futures_util::Sink<Vec<u8>> for #name #ty_generics #where_clause {
698 type Error = ::yrs::sync::Error;
699
700 #poll_ready_fn
701 #start_send_fn
702 #poll_flush_fn
703 #poll_close_fn
704 }
705 }
706 .into()
707 })
708}
709
710struct CommonSinkArgs {
711 inner_field: LitStr,
712}
713impl Parse for CommonSinkArgs {
714 fn parse(input: ParseStream) -> syn::Result<Self> {
715 let mut inner_field: Option<LitStr> = None;
716
717 while !input.is_empty() {
718 let lookahead = input.lookahead1();
719
720 if !lookahead.peek(Ident) || !input.peek2(Token![=]) {
721 return Err(input.error("Expected attribute argument `key = value`"));
722 }
723
724 let key: Ident = input.parse()?;
725 let _eq_token: Token![=] = input.parse()?;
726
727 if key == "inner" {
728 if inner_field.is_some() {
729 return Err(input.error("Duplicate `inner` argument"));
730 }
731 inner_field = Some(input.parse()?);
732 } else {
733 return Err(input.error(format!("Unknown attribute argument: `{}`", key)));
734 }
735
736 if !input.is_empty() {
737 let lookahead = input.lookahead1();
738 if lookahead.peek(Token![,]) {
739 let _: Token![,] = input.parse()?;
740 } else {
741 return Err(input.error("Attribute arguments must be separated by commas"));
742 }
743 }
744 }
745
746 let inner_field =
747 inner_field.unwrap_or_else(|| LitStr::new("self.0", proc_macro2::Span::call_site()));
748
749 Ok(CommonSinkArgs { inner_field })
750 }
751}
752
753fn get_sink_item_type(input: &ItemImpl) -> Result<Type, TokenStream> {
755 let trait_option = input.trait_.as_ref();
757 let trait_ref =
758 match trait_option {
759 Some(tr) => tr,
760 None => return Err(syn::Error::new_spanned(
761 input,
762 "yrs_common_sink must be applied to a Sink impl (e.g., `impl Sink<Item> for Type`)",
763 )
764 .to_compile_error()
765 .into()),
766 };
767 let trait_path = trait_ref.1.clone();
768
769 let last_segment_option = trait_path.segments.last();
770 let last_segment = match last_segment_option {
771 Some(seg) => seg,
772 None => {
773 return Err(syn::Error::new_spanned(&trait_path, "Invalid trait path")
774 .to_compile_error()
775 .into());
776 }
777 };
778
779 if last_segment.ident != "Sink" {
780 return Err(syn::Error::new_spanned(
781 &last_segment,
782 "yrs_common_sink must be applied to a Sink impl",
783 )
784 .to_compile_error()
785 .into());
786 }
787
788 let args = match &last_segment.arguments {
789 syn::PathArguments::AngleBracketed(args) => args,
790 _ => {
791 return Err(syn::Error::new_spanned(
792 &last_segment.arguments,
793 "Sink trait must have angle bracketed arguments",
794 )
795 .to_compile_error()
796 .into());
797 }
798 };
799
800 if args.args.len() != 1 {
801 return Err(syn::Error::new_spanned(
802 &args.args,
803 "Sink trait must have exactly one generic argument (Item)",
804 )
805 .to_compile_error()
806 .into());
807 }
808
809 let generic_arg = args.args.first().unwrap();
810 match generic_arg {
811 syn::GenericArgument::Type(ty) => Ok(ty.clone()),
812 arg => Err(
813 syn::Error::new_spanned(arg, "Sink trait argument must be a type")
814 .to_compile_error()
815 .into(),
816 ),
817 }
818}
819
820#[proc_macro_attribute]
869pub fn yrs_common_sink(attr: TokenStream, item: TokenStream) -> TokenStream {
870 let mut impl_block = parse_macro_input!(item as ItemImpl);
871
872 let args = parse_macro_input!(attr as CommonSinkArgs);
873 let inner_expr_str = args.inner_field.value();
874
875 let inner_expr: proc_macro2::TokenStream = match inner_expr_str.parse() {
876 Ok(expr) => expr,
877 Err(_) => {
878 return syn::Error::new_spanned(args.inner_field, "Cannot parse `inner` expression")
879 .to_compile_error()
880 .into();
881 }
882 };
883
884 let outer_sink_item_type = match get_sink_item_type(&impl_block) {
885 Ok(ty) => ty.clone(),
886 Err(e) => return e,
887 };
888
889 let mut user_start_send: Option<ImplItemFn> = None;
890 let mut user_error_type_item: Option<ImplItemType> = None;
891 let mut other_user_items: Vec<ImplItem> = Vec::new();
892
893 let original_items = std::mem::take(&mut impl_block.items);
894 for item in original_items {
895 match item {
896 ImplItem::Fn(f) if f.sig.ident == "start_send" => {
897 if user_start_send.is_some() {
898 return syn::Error::new_spanned(&f.sig.ident, "Duplicate `start_send` method")
899 .to_compile_error()
900 .into();
901 }
902 user_start_send = Some(f);
903 }
904 ImplItem::Type(t) if t.ident == "Error" => {
905 if user_error_type_item.is_some() {
906 return syn::Error::new_spanned(&t.ident, "Duplicate `Error` type")
907 .to_compile_error()
908 .into();
909 }
910 user_error_type_item = Some(t.clone());
911 }
912 _ => other_user_items.push(item),
913 }
914 }
915
916 let final_error_item: ImplItemType = user_error_type_item.unwrap_or_else(|| {
917 syn::parse_quote! {
918 type Error = ::yrs::sync::Error;
919 }
920 });
921
922 let start_send_method_item: ImplItemFn = match user_start_send {
923 Some(method) => method,
924 None => {
925 let target_simple_signaling: Type = syn::parse_quote!(SignalingMessage);
926 let target_fully_qualified_signaling: Type =
927 syn::parse_quote!(::yrs_tokio::signaling::Message);
928 let target_qualified_signaling: Type = syn::parse_quote!(yrs_tokio::signaling::Message);
929
930 let outer_type_str = quote::quote!(#outer_sink_item_type).to_string();
932 let simple_type_str = quote::quote!(#target_simple_signaling).to_string();
933 let fully_qualified_type_str =
934 quote::quote!(#target_fully_qualified_signaling).to_string();
935 let qualified_type_str = quote::quote!(#target_qualified_signaling).to_string();
936
937 let is_targeted_signaling_message_type = outer_type_str == simple_type_str
938 || outer_type_str == fully_qualified_type_str
939 || outer_type_str == qualified_type_str;
940
941 let default_body = if is_targeted_signaling_message_type {
942 quote! {
943 let msg = match item {
944 ::yrs_tokio::signaling::Message::Text(txt) => Message::text(txt),
945 ::yrs_tokio::signaling::Message::Binary(bytes) => Message::binary(bytes),
946 ::yrs_tokio::signaling::Message::Ping => Message::Ping(Vec::default().into()),
947 ::yrs_tokio::signaling::Message::Pong => Message::Pong(Vec::default().into()),
948 ::yrs_tokio::signaling::Message::Close => Message::Close(None.into()),
949 };
950 if let Err(e) = ::core::pin::Pin::new(&mut #inner_expr).start_send(msg) {
951 Err(::yrs::sync::Error::Other(e.into()))
952 } else {
953 Ok(())
954 }
955 }
956 } else {
957 quote! {
959 use ::core::pin::Pin;
960 use ::futures_util::Sink;
961 Pin::new(&mut #inner_expr).start_send(item.into())
962 .map_err(|e| ::yrs::sync::Error::Other(e.into()))
963 }
964 };
965
966 syn::parse_quote! {
967 fn start_send(mut self: ::core::pin::Pin<&mut Self>, item: #outer_sink_item_type) -> Result<(), Self::Error> {
968 #default_body
969 }
970 }
971 }
972 };
973
974 let generated_poll_ready = generate_poll_fn("poll_ready", &inner_expr);
975
976 let generated_poll_flush = generate_poll_fn("poll_flush", &inner_expr);
977
978 let generated_poll_close = generate_poll_fn("poll_close", &inner_expr);
979
980 let mut final_items: Vec<ImplItem> = Vec::new();
981
982 final_items.push(ImplItem::Type(final_error_item)); final_items.push(ImplItem::Fn(generated_poll_ready)); final_items.push(ImplItem::Fn(start_send_method_item)); final_items.push(ImplItem::Fn(generated_poll_flush)); final_items.push(ImplItem::Fn(generated_poll_close)); final_items.extend(other_user_items);
989
990 impl_block.items = final_items;
991
992 quote!(#impl_block).into()
993}
994
995fn generate_poll_fn(method: &str, inner_expr: &proc_macro2::TokenStream) -> ImplItemFn {
996 let ident = format_ident!("{}", method);
997
998 syn::parse_quote! {
999 fn #ident(
1000 mut self: ::core::pin::Pin<&mut Self>,
1001 cx: &mut std::task::Context<'_>,
1002 ) -> std::task::Poll<Result<(), Self::Error>> {
1003 match ::core::pin::Pin::new(&mut #inner_expr).#ident(cx) {
1004 std::task::Poll::Pending => std::task::Poll::Pending,
1005 std::task::Poll::Ready(Err(e)) => std::task::Poll::Ready(Err(::yrs::sync::Error::Other(e.into()))),
1006 std::task::Poll::Ready(_) => std::task::Poll::Ready(Ok(())),
1007 }
1008 }
1009 }
1010}
1011
1012fn quote_from_into(
1013 name: &Ident,
1014 field_type: &Type,
1015 generics: &Generics,
1016) -> proc_macro2::TokenStream {
1017 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
1018
1019 quote! {
1020 impl #impl_generics ::core::convert::From<#field_type> for #name #ty_generics #where_clause {
1021 fn from(stream: #field_type) -> Self {
1022 #name(stream)
1023 }
1024 }
1025
1026 impl #impl_generics ::core::convert::Into<#field_type> for #name #ty_generics #where_clause {
1027 fn into(self) -> #field_type {
1028 self.0
1029 }
1030 }
1031 }
1032}
1033
1034fn derive_impl<F>(input: TokenStream, _macro_name: &str, f: F) -> TokenStream
1036where
1037 F: FnOnce(&Ident, &Type, &Generics) -> TokenStream, {
1039 let input = parse_macro_input!(input as DeriveInput);
1040 let name = &input.ident;
1041 let generics = &input.generics; let field_type = match get_field_type(&input) {
1044 Ok(ty) => ty,
1045 Err(err) => return err,
1046 };
1047
1048 let expanded = f(name, field_type, generics);
1050 TokenStream::from(expanded)
1051}
1052
1053fn get_field_type(input: &DeriveInput) -> Result<&Type, TokenStream> {
1055 Ok(match &input.data {
1056 Data::Struct(data_struct) => {
1057 if let Fields::Unnamed(fields_unnamed) = &data_struct.fields {
1058 if fields_unnamed.unnamed.len() == 1 {
1059 &fields_unnamed.unnamed.first().unwrap().ty
1060 } else {
1061 return Err(syn::Error::new_spanned(
1062 &input.ident,
1063 "can only be derived for tuple structs with one field",
1064 )
1065 .to_compile_error()
1066 .into());
1067 }
1068 } else {
1069 return Err(syn::Error::new_spanned(
1070 &input.ident,
1071 "can only be derived for tuple structs with unnamed fields",
1072 )
1073 .to_compile_error()
1074 .into());
1075 }
1076 }
1077 _ => {
1078 return Err(
1079 syn::Error::new_spanned(&input.ident, "can only be derived for structs")
1080 .to_compile_error()
1081 .into(),
1082 );
1083 }
1084 })
1085}