yrs_tokio_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{ToTokens, format_ident, quote};
3use syn::parse::{Parse, ParseStream};
4use syn::{
5    Data, DeriveInput, Expr, Fields, Generics, Ident, ImplItem, ImplItemFn, ImplItemType, ItemImpl,
6    ItemStruct, LitBool, LitStr, Token, Type, parse_macro_input,
7};
8
/// Generates a shared integration-test suite for a yrs WebSocket broadcast server.
///
/// Apply the attribute to the `async fn` that starts the server under test. The function
/// itself is emitted unchanged. Next to it the macro emits helper items and four
/// `#[tokio::test]` functions that drive the server through real WebSocket clients:
/// - helper types: `TungsteniteSink`, `TungsteniteStream`;
/// - helper functions: `client`, `create_notifier`, `setup_client_update_propagation`,
///   `check_client_text_content`, `apply_text_change`;
/// - helper constant: `TIMEOUT`.
///
/// # Requirements on the annotated function
///
/// The generated tests call it as `fn_name(addr, Arc::new(bcast)).await.unwrap()`, so it must:
///
/// - take `(addr: &str, bcast: Arc<yrs_tokio::broadcast::BroadcastGroup>)` and return a
///   `Result` whose error type implements `Debug`;
/// - serve WebSocket connections on `addr` at the path `/my-room`, subscribing each peer
///   to `bcast`;
/// - already accept connections when it returns, because the tests connect right away.
///
/// The tests listen on the fixed addresses `0.0.0.0:6600` to `0.0.0.0:6603`, so apply the
/// macro at most once per test binary.
///
/// The expansion refers to `tokio`, `tokio_tungstenite`, `futures_util`, `yrs` and
/// `yrs_tokio` by absolute path, so the using crate needs them as (dev-)dependencies. It
/// brings every trait it calls methods of into scope itself, so no trait imports are needed
/// at the call site.
///
/// # Examples
/// ```rust
/// use std::net::SocketAddr;
/// use std::str::FromStr;
/// use std::sync::Arc;
/// use axum::Router;
/// use axum::extract::ws::WebSocket;
/// use axum::extract::{State, WebSocketUpgrade};
/// use axum::response::Response;
/// use axum::routing::any;
/// use futures_util::StreamExt;
/// use tokio::sync::Mutex;
/// use tokio::task::JoinHandle;
/// use yrs_axum_ws::{YrsSink, YrsStream};
/// use yrs_tokio::broadcast::BroadcastGroup;
/// use yrs_tokio::yrs_common_test;
///
/// #[yrs_common_test]
/// async fn start_server(
///     addr: &str,
///     bcast: Arc<BroadcastGroup>,
/// ) -> Result<JoinHandle<()>, Box<dyn std::error::Error>> {
///     let addr = SocketAddr::from_str(addr)?;
///
///     let app = Router::new()
///         .route("/my-room", any(ws_handler))
///         .with_state(bcast);
///
///     // Bind before spawning so the port already accepts connections when the
///     // generated tests start their clients.
///     let listener = tokio::net::TcpListener::bind(addr).await?;
///     Ok(tokio::spawn(async move {
///         axum::serve(listener, app).await.unwrap();
///     }))
/// }
///
/// async fn ws_handler(
///     ws: WebSocketUpgrade,
///     State(bcast): State<Arc<BroadcastGroup>>,
/// ) -> Response {
///     ws.on_upgrade(move |socket| peer(socket, bcast))
/// }
///
/// async fn peer(ws: WebSocket, bcast: Arc<BroadcastGroup>) {
///     let (sink, stream) = ws.split();
///     let sink = Arc::new(Mutex::new(YrsSink::from(sink)));
///     let stream = YrsStream::from(stream);
///
///     let sub = bcast.subscribe(sink, stream);
///     match sub.completed().await {
///         Ok(_) => println!("broadcasting for channel finished successfully"),
///         Err(e) => eprintln!("broadcasting for channel finished abruptly: {}", e),
///     }
/// }
/// ```
#[cfg(feature = "test-utils")]
#[proc_macro_attribute]
pub fn yrs_common_test(_: TokenStream, item: TokenStream) -> TokenStream {
    let input_fn = parse_macro_input!(item as syn::ItemFn);
    let fn_name = &input_fn.sig.ident;
    let original_fn_def = input_fn.to_token_stream();

    quote! {
        // ===============================================================
        // Helper macros and types shared by the generated tests.
        // Each helper brings the traits it calls methods of into a block-local scope
        // (`use Trait as _`), so the call site does not have to import them.
        // ===============================================================

        // Forwards a `Sink` poll method to the wrapped sink, mapping its error into
        // `yrs::sync::Error::Other`.
        macro_rules! delegate_poll_method {
            ($self:expr, $cx:expr, $method:ident $(($($arg:expr),*))?) => {{
                use ::futures_util::Sink as _;
                use ::std::task::{ready, Poll};
                // The wrapped split halves are `Unpin`, so safe pinning is sufficient.
                let sink = ::std::pin::Pin::new(&mut $self.0);
                match ready!(sink.$method($($($arg),*)?)) {
                    Ok(_) => Poll::Ready(Ok(())),
                    Err(e) => Poll::Ready(Err(::yrs::sync::Error::Other(Box::new(e)))),
                }
            }};
        }

        // Forwards a `Result`-returning `Sink` method (e.g. `start_send`) to the wrapped
        // sink, mapping its error into `yrs::sync::Error::Other`.
        macro_rules! delegate_result_method {
            ($self:expr, $method:ident $(($($arg:expr),*))?) => {{
                use ::futures_util::Sink as _;
                let sink = ::std::pin::Pin::new(&mut $self.0);
                match sink.$method($($($arg),*)?) {
                    Ok(_) => Ok(()),
                    Err(e) => Err(::yrs::sync::Error::Other(Box::new(e))),
                }
            }};
        }

        // Forwards `Stream::poll_next` to the wrapped stream, turning each message into its
        // payload bytes and mapping errors into `yrs::sync::Error::Other`.
        macro_rules! delegate_stream_poll_next {
            ($self:expr, $cx:expr) => {{
                use ::futures_util::Stream as _;
                use ::std::task::{ready, Poll};
                let stream = ::std::pin::Pin::new(&mut $self.0);
                match ready!(stream.poll_next($cx)) {
                    None => Poll::Ready(None),
                    Some(Ok(msg)) => Poll::Ready(Some(Ok(msg.into_data().into()))),
                    Some(Err(e)) => Poll::Ready(Some(Err(::yrs::sync::Error::Other(Box::new(e))))),
                }
            }};
        }

        // Client-side sink: the write half of a tokio-tungstenite connection.
        struct TungsteniteSink(::futures_util::stream::SplitSink<::tokio_tungstenite::WebSocketStream<::tokio_tungstenite::MaybeTlsStream<::tokio::net::TcpStream>>, ::tokio_tungstenite::tungstenite::Message>);

        impl ::futures_util::Sink<Vec<u8>> for TungsteniteSink {
            type Error = ::yrs::sync::Error;

            fn poll_ready(
                mut self: ::std::pin::Pin<&mut Self>,
                cx: &mut ::std::task::Context<'_>,
            ) -> ::std::task::Poll<Result<(), Self::Error>> {
                delegate_poll_method!(self, cx, poll_ready(cx))
            }

            fn start_send(mut self: ::std::pin::Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
                delegate_result_method!(self, start_send(::tokio_tungstenite::tungstenite::Message::binary(item)))
            }

            fn poll_flush(
                mut self: ::std::pin::Pin<&mut Self>,
                cx: &mut ::std::task::Context<'_>,
            ) -> ::std::task::Poll<Result<(), Self::Error>> {
                delegate_poll_method!(self, cx, poll_flush(cx))
            }

            fn poll_close(
                mut self: ::std::pin::Pin<&mut Self>,
                cx: &mut ::std::task::Context<'_>,
            ) -> ::std::task::Poll<Result<(), Self::Error>> {
                delegate_poll_method!(self, cx, poll_close(cx))
            }
        }

        // Client-side stream: the read half of a tokio-tungstenite connection.
        struct TungsteniteStream(::futures_util::stream::SplitStream<::tokio_tungstenite::WebSocketStream<::tokio_tungstenite::MaybeTlsStream<::tokio::net::TcpStream>>>);
        impl ::futures_util::Stream for TungsteniteStream {
            type Item = Result<Vec<u8>, ::yrs::sync::Error>;

            fn poll_next(mut self: ::std::pin::Pin<&mut Self>, cx: &mut ::std::task::Context<'_>) -> ::std::task::Poll<Option<Self::Item>> {
                delegate_stream_poll_next!(self, cx)
            }
        }

        // Connects to `addr` and wraps the socket in a yrs_tokio `Connection` around `doc`.
        async fn client(
            addr: &str,
            doc: ::yrs::Doc,
        ) -> Result<::yrs_tokio::connection::Connection<TungsteniteSink, TungsteniteStream>, Box<dyn ::std::error::Error>> {
            let (stream, _) = ::tokio_tungstenite::connect_async(addr).await?;
            let (sink, stream) = ::futures_util::stream::StreamExt::split(stream);
            let sink = TungsteniteSink(sink);
            let stream = TungsteniteStream(stream);
            Ok(::yrs_tokio::connection::Connection::new(
                ::std::sync::Arc::new(::tokio::sync::RwLock::new(::yrs::sync::Awareness::new(doc))),
                sink,
                stream,
            ))
        }

        // Returns a `Notify` that is signalled (via `notify_waiters`) on every v1 update of
        // `doc`, plus the subscription that keeps the observer alive.
        fn create_notifier(doc: &::yrs::Doc) -> (::std::sync::Arc<::tokio::sync::Notify>, ::yrs::Subscription) {
            let n = ::std::sync::Arc::new(::tokio::sync::Notify::new());
            let sub = {
                let n = n.clone();
                doc.observe_update_v1(move |_, _| n.notify_waiters())
                    .unwrap()
            };
            (n, sub)
        }

        // Upper bound for every wait in the generated tests.
        const TIMEOUT: ::std::time::Duration = ::std::time::Duration::from_secs(5);

        // Makes every local update of the client's document get sent to the server as a
        // sync `Update` message. The returned subscription must be kept alive.
        async fn setup_client_update_propagation(
            conn: &::yrs_tokio::connection::Connection<TungsteniteSink, TungsteniteStream>
        ) -> ::yrs::Subscription {
            let sink = conn.sink();
            // An async read lock is enough: registering the observer is a short synchronous
            // call, so no blocking thread is needed.
            let awareness = conn.awareness().read().await;
            awareness
                .doc()
                .observe_update_v1(move |_, e| {
                    let update = e.update.to_owned();
                    if let Some(sink) = sink.upgrade() {
                        // The observer is synchronous; hand the actual send to the runtime.
                        ::tokio::task::spawn(async move {
                            use ::futures_util::SinkExt as _;
                            use ::yrs::updates::encoder::Encode as _;
                            let msg = ::yrs::sync::Message::Sync(::yrs::sync::SyncMessage::Update(update))
                                .encode_v1();
                            let mut sink_guard = sink.lock().await;
                            sink_guard.send(msg).await.unwrap();
                        });
                    }
                })
                .unwrap()
        }

        // Asserts that the client's "test" text equals `expected_str`.
        async fn check_client_text_content(
            conn: &::yrs_tokio::connection::Connection<TungsteniteSink, TungsteniteStream>,
            expected_str: &str,
        ) {
            use ::yrs::{GetString as _, Transact as _};
            let awareness = conn.awareness().read().await;
            let doc = awareness.doc();
            let text = doc.get_or_insert_text("test");
            let actual = text.get_string(&doc.transact());
            ::core::assert_eq!(actual, expected_str.to_string());
        }

        // Appends `change_str` to the "test" text of the given awareness' document.
        async fn apply_text_change(
            awareness: &::std::sync::Arc<::tokio::sync::RwLock<::yrs::sync::Awareness>>,
            change_str: &str,
        ) {
            use ::yrs::{Text as _, Transact as _};
            let lock = awareness.write().await;
            let doc = lock.doc();
            let text = doc.get_or_insert_text("test");
            text.push(&mut doc.transact_mut(), change_str);
        }

        // ===============================================================
        // The annotated server start-up function, emitted unchanged.
        // ===============================================================
        #original_fn_def

        // ===============================================================
        // Generated tests. Every wait creates its `Notified` future *before* the action
        // that triggers the update: `notify_waiters` only wakes futures that already
        // exist, so creating it afterwards could miss a fast update and time out.
        // ===============================================================

        #[::tokio::test]
        async fn change_introduced_by_server_reaches_subscribed_clients() {
            let doc = ::yrs::Doc::with_client_id(1);
            let _text = doc.get_or_insert_text("test");
            let awareness = ::std::sync::Arc::new(::tokio::sync::RwLock::new(::yrs::sync::Awareness::new(doc)));
            let bcast = ::yrs_tokio::broadcast::BroadcastGroup::new(awareness.clone(), 10).await;
            let _server = #fn_name("0.0.0.0:6600", ::std::sync::Arc::new(bcast)).await.unwrap();

            let doc = ::yrs::Doc::new();
            let (n, _sub) = create_notifier(&doc);
            let c1 = client("ws://localhost:6600/my-room", doc).await.unwrap();

            let updated = n.notified();
            apply_text_change(&awareness, "abc").await;

            ::tokio::time::timeout(TIMEOUT, updated).await.unwrap();

            check_client_text_content(&c1, "abc").await;
        }

        #[::tokio::test]
        async fn subscribed_client_fetches_initial_state() {
            use ::yrs::{Text as _, Transact as _};
            let doc = ::yrs::Doc::with_client_id(1);
            let text = doc.get_or_insert_text("test");

            text.push(&mut doc.transact_mut(), "abc");

            let awareness = ::std::sync::Arc::new(::tokio::sync::RwLock::new(::yrs::sync::Awareness::new(doc)));
            let bcast = ::yrs_tokio::broadcast::BroadcastGroup::new(awareness.clone(), 10).await;
            let _server = #fn_name("0.0.0.0:6601", ::std::sync::Arc::new(bcast)).await.unwrap();

            let doc = ::yrs::Doc::new();
            let (n, _sub) = create_notifier(&doc);
            // The initial sync starts as soon as the client connects.
            let synced = n.notified();
            let c1 = client("ws://localhost:6601/my-room", doc).await.unwrap();

            ::tokio::time::timeout(TIMEOUT, synced).await.unwrap();

            check_client_text_content(&c1, "abc").await;
        }

        #[::tokio::test]
        async fn changes_from_one_client_reach_others() {
            let doc = ::yrs::Doc::with_client_id(1);
            let _ = doc.get_or_insert_text("test");

            let awareness = ::std::sync::Arc::new(::tokio::sync::RwLock::new(::yrs::sync::Awareness::new(doc)));
            let bcast = ::yrs_tokio::broadcast::BroadcastGroup::new(awareness.clone(), 10).await;
            let _server = #fn_name("0.0.0.0:6602", ::std::sync::Arc::new(bcast)).await.unwrap();

            let d1 = ::yrs::Doc::with_client_id(2);
            let c1 = client("ws://localhost:6602/my-room", d1).await.unwrap();
            // Forward c1's local edits to the server.
            let _sub11 = setup_client_update_propagation(&c1).await;

            let d2 = ::yrs::Doc::with_client_id(3);
            let (n2, _sub2) = create_notifier(&d2);
            let c2 = client("ws://localhost:6602/my-room", d2).await.unwrap();

            let c2_updated = n2.notified();
            apply_text_change(&c1.awareness(), "def").await;

            ::tokio::time::timeout(TIMEOUT, c2_updated).await.unwrap();

            check_client_text_content(&c2, "def").await;
        }

        #[::tokio::test]
        async fn client_failure_doesnt_affect_others() {
            let doc = ::yrs::Doc::with_client_id(1);
            let _text = doc.get_or_insert_text("test");

            let awareness = ::std::sync::Arc::new(::tokio::sync::RwLock::new(::yrs::sync::Awareness::new(doc)));
            let bcast = ::yrs_tokio::broadcast::BroadcastGroup::new(awareness.clone(), 10).await;
            let _server = #fn_name("0.0.0.0:6603", ::std::sync::Arc::new(bcast)).await.unwrap();

            let d1 = ::yrs::Doc::with_client_id(2);
            let c1 = client("ws://localhost:6603/my-room", d1).await.unwrap();
            // Forward c1's local edits to the server.
            let _sub11 = setup_client_update_propagation(&c1).await;

            let d2 = ::yrs::Doc::with_client_id(3);
            let (n2, sub2) = create_notifier(&d2);
            let c2 = client("ws://localhost:6603/my-room", d2).await.unwrap();

            let d3 = ::yrs::Doc::with_client_id(4);
            let (n3, sub3) = create_notifier(&d3);
            let c3 = client("ws://localhost:6603/my-room", d3).await.unwrap();

            {
                // Wait for both receivers instead of sleeping for the whole timeout.
                let c2_updated = n2.notified();
                let c3_updated = n3.notified();
                apply_text_change(&c1.awareness(), "abc").await;
                ::tokio::time::timeout(TIMEOUT, async move {
                    c2_updated.await;
                    c3_updated.await;
                })
                .await
                .unwrap();
            }

            check_client_text_content(&c2, "abc").await;
            check_client_text_content(&c3, "abc").await;

            // Simulate c3 going away and reset c2's observer.
            drop(c3);
            drop(n3);
            drop(sub3);
            drop(n2);
            drop(sub2);

            let (n2, _sub2) = {
                let a = c2.awareness().read().await;
                let doc = a.doc();
                create_notifier(doc)
            };

            let c2_updated = n2.notified();
            apply_text_change(&c1.awareness(), "def").await;

            ::tokio::time::timeout(TIMEOUT, c2_updated).await.unwrap();

            check_client_text_content(&c2, "abcdef").await;
        }
    }.into()
}
376/// Yrs tokio Into/From generator
377#[proc_macro_derive(YrsExchange)]
378pub fn derive_yrs_exchange(input: TokenStream) -> TokenStream {
379    // 传递 generics 到闭包
380    derive_impl(input, "YrsExchange", |name, field_type, generics| {
381        // 调用修改后的 quote_from_into
382        TokenStream::from(quote_from_into(name, field_type, generics))
383    })
384}
385
386/// Yrs tokio stream generator, use `into` argument defined convert method, use for not has message
387/// # Examples
388/// ```rust
389/// use yrs_tokio_macros::yrs_stream;
390/// use tokio_tungstenite::WebSocketStream;
391/// use tokio::io::{AsyncRead, AsyncWrite};
392/// use std::marker::Unpin;
393/// use futures_util::stream::SplitStream;
394///
395/// #[yrs_stream(into=into_data().into(), exchange=false)]
396/// pub struct YrsStream<S>(SplitStream<WebSocketStream<S>>)
397/// where
398///     S: AsyncRead + AsyncWrite + Unpin;
399/// ```
400#[proc_macro_attribute]
401pub fn yrs_stream(attr: TokenStream, item: TokenStream) -> TokenStream {
402    let args = parse_macro_input!(attr as MethodCallAttributeArgs);
403    let call_target = args.into_target;
404    let gen_exchange = args.exchange;
405
406    let item_for_parsing = item.clone();
407    let item_for_codegen = item;
408
409    let input_struct = parse_macro_input!(item_for_parsing as ItemStruct);
410
411    // 获取泛型信息
412    let generics = input_struct.generics.clone();
413
414    let original_struct_def = input_struct.to_token_stream();
415
416    yrs_stream_code_gen(
417        call_target,
418        item_for_codegen,
419        gen_exchange,
420        Some(original_struct_def),
421        generics,
422    )
423}
424
425/// Yrs tokio stream generator, use convert message
426/// # Examples
427/// ```rust
428/// use yrs_tokio_macros::YrsStream;
429/// use std::marker::Unpin;
430/// use axum::extract::ws::WebSocket;
431/// use futures_util::stream::SplitStream;
432///
433///#[derive(YrsStream)]
434///pub struct YrsStream(SplitStream<WebSocket>);
435/// ```
436#[proc_macro_derive(YrsStream)]
437pub fn derive_yrs_stream(input: TokenStream) -> TokenStream {
438    let default_target = CallTarget::SingleMethod(format_ident!("into"));
439    let gen_exchange = true;
440
441    let input_struct = parse_macro_input!(input as DeriveInput);
442    let generics = input_struct.generics.clone();
443    let input_for_derive_impl = input_struct.into_token_stream().into(); // 转换回 TokenStream 传递给 derive_impl
444
445    yrs_stream_code_gen(
446        default_target,
447        input_for_derive_impl,
448        gen_exchange,
449        None,
450        generics,
451    )
452}
453
454/// Yrs tokio stream generator, use convert message, without exchange
455/// # Examples
456/// ```rust
457/// use yrs_tokio_macros::YrsStreamOnly;
458/// use std::marker::Unpin;
459/// use axum::extract::ws::WebSocket;
460/// use futures_util::stream::SplitStream;
461///
462///#[derive(YrsStreamOnly)]
463///pub struct YrsStream(SplitStream<WebSocket>);
464/// ```
465#[proc_macro_derive(YrsStreamOnly)]
466pub fn derive_yrs_stream_only(input: TokenStream) -> TokenStream {
467    let default_target = CallTarget::SingleMethod(format_ident!("into"));
468    let gen_exchange = false;
469
470    let input_struct = parse_macro_input!(input as DeriveInput);
471    let generics = input_struct.generics.clone();
472    let input_for_derive_impl = input_struct.into_token_stream().into();
473
474    yrs_stream_code_gen(
475        default_target,
476        input_for_derive_impl,
477        gen_exchange,
478        None,
479        generics,
480    )
481}
482
483#[derive(Clone)]
484enum CallTarget {
485    SingleMethod(Ident), // 存储单一方法的标识符 (如 `into_data`)
486    MethodChain(Expr),   // 存储方法链的表达式 (如 `into_data().into()`)
487}
488struct MethodCallAttributeArgs {
489    into_target: CallTarget,
490    exchange: bool,
491}
492
493impl Parse for MethodCallAttributeArgs {
494    fn parse(input: ParseStream) -> syn::Result<Self> {
495        let mut into_target: Option<CallTarget> = None;
496        let mut exchange: bool = true;
497
498        while !input.is_empty() {
499            let lookahead = input.lookahead1();
500
501            if lookahead.peek(Ident) && input.peek2(Token![=]) {
502                let key: Ident = input.parse()?;
503                let _eq_token: Token![=] = input.parse()?;
504
505                if key == "into" {
506                    let target_expr: Expr = input.parse()?;
507                    let parsed_target = match target_expr {
508                        Expr::Path(expr_path) => {
509                            if expr_path.qself.is_none() && expr_path.path.segments.len() == 1 {
510                                let segment =
511                                    expr_path.clone().path.segments.into_iter().next().unwrap();
512                                if segment.arguments.is_empty() {
513                                    CallTarget::SingleMethod(segment.ident)
514                                } else {
515                                    CallTarget::MethodChain(Expr::Path(expr_path))
516                                }
517                            } else {
518                                CallTarget::MethodChain(Expr::Path(expr_path))
519                            }
520                        }
521                        _ => CallTarget::MethodChain(target_expr),
522                    };
523                    into_target = Some(parsed_target);
524                } else if key == "exchange" {
525                    let lit_bool: LitBool = input.parse()?;
526                    exchange = lit_bool.value();
527                } else {
528                    return Err(input.error(format!("未知属性参数: `{}`", key)));
529                }
530            } else {
531                return Err(input.error("期望 `key = value` 形式的属性参数"));
532            }
533
534            if !input.is_empty() {
535                let lookahead = input.lookahead1();
536                if lookahead.peek(Token![,]) {
537                    let _: Token![,] = input.parse()?;
538                } else {
539                    return Err(input.error("属性参数之间期望用 `,` 分隔"));
540                }
541            }
542        }
543
544        let into_target = into_target.ok_or_else(|| input.error("属性中期望关键字 `into`"))?;
545
546        Ok(MethodCallAttributeArgs {
547            into_target,
548            exchange,
549        })
550    }
551}
552
553fn yrs_stream_code_gen(
554    call_target: CallTarget,
555    input: TokenStream,
556    gen_exchange: bool,
557    preppend: Option<proc_macro2::TokenStream>,
558    generics: Generics,
559) -> TokenStream {
560    derive_impl(input, "YrsStream", move |name, field_type, _| {
561        let call_target = call_target.clone();
562        let generics = generics.clone();
563
564        let item_call_code = match &call_target {
565            CallTarget::SingleMethod(method_ident) => {
566                quote! { item.#method_ident() }
567            }
568            CallTarget::MethodChain(chain_expr) => {
569                quote! { item.#chain_expr }
570            }
571        };
572
573        let from_into: Option<proc_macro2::TokenStream> = if gen_exchange {
574            Some(quote_from_into(name, field_type, &generics)) // 传递 generics 引用
575        } else {
576            None
577        };
578
579        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
580
581        quote! {
582            #preppend
583
584            #from_into
585
586            impl #impl_generics ::futures_core::Stream for #name #ty_generics #where_clause {
587                type Item = Result<Vec<u8>, ::yrs::sync::Error>;
588
589                fn poll_next(
590                    mut self: ::core::pin::Pin<&mut Self>,
591                    cx: &mut ::core::task::Context<'_>
592                ) -> ::core::task::Poll<Option<Self::Item>> {
593                    match ::core::pin::Pin::new(&mut self.0).poll_next(cx) {
594                        ::core::task::Poll::Pending => ::core::task::Poll::Pending,
595                        ::core::task::Poll::Ready(None) => ::core::task::Poll::Ready(None),
596                        ::core::task::Poll::Ready(Some(res)) => match res {
597                            Ok(item) => ::core::task::Poll::Ready(Some(Ok(#item_call_code))),
598                            Err(e) => ::core::task::Poll::Ready(Some(Err(::yrs::sync::Error::Other(e.into())))),
599                        },
600                    }
601                }
602            }
603        }.into()
604    })
605}
606
607/// Yrs tokio sink generator
608/// # Examples
609/// ```rust
610/// use yrs_tokio_macros::YrsSink;
611/// use tokio_tungstenite::WebSocketStream;
612/// use tokio::io::{AsyncRead, AsyncWrite};
613/// use std::marker::Unpin;
614/// use futures_util::stream::SplitSink;
615/// use tokio_tungstenite::tungstenite::Message;
616///
617/// #[derive(YrsSink)]
618/// pub struct YrsSink<S>(SplitSink<WebSocketStream<S>, Message>)
619/// where
620///     S: AsyncRead + AsyncWrite + Unpin + Send;
621/// ```
622#[proc_macro_derive(YrsSink)]
623pub fn derive_yrs_sink(input: TokenStream) -> TokenStream {
624    derive_yrs_sink_gen(input, true)
625}
626
627/// Yrs tokio sink generator without exchange
628/// # Examples
629/// ```rust
630/// use yrs_tokio_macros::YrsSinkOnly;
631/// use tokio_tungstenite::WebSocketStream;
632/// use tokio::io::{AsyncRead, AsyncWrite};
633/// use std::marker::Unpin;
634/// use futures_util::stream::SplitSink;
635/// use tokio_tungstenite::tungstenite::Message;
636///
637/// #[derive(YrsSinkOnly)]
638/// pub struct YrsSink<S>(SplitSink<WebSocketStream<S>, Message>)
639/// where
640///     S: AsyncRead + AsyncWrite + Unpin + Send;
641/// ```
642#[proc_macro_derive(YrsSinkOnly)]
643pub fn derive_yrs_sink_only(input: TokenStream) -> TokenStream {
644    derive_yrs_sink_gen(input, false)
645}
646
647fn generate_poll_method_item(method_name: &str) -> ImplItemFn {
648    let ident = format_ident!("{}", method_name);
649    let inner_field_access = quote::quote!(self.0);
650
651    syn::parse_quote! {
652        fn #ident(
653            mut self: ::core::pin::Pin<&mut Self>,
654            cx: &mut ::core::task::Context<'_>,
655        ) -> ::core::task::Poll<Result<(), Self::Error>> {
656            match ::core::pin::Pin::new(&mut #inner_field_access).#ident(cx) {
657                ::core::task::Poll::Pending => ::core::task::Poll::Pending,
658                ::core::task::Poll::Ready(Err(e)) => ::core::task::Poll::Ready(Err(::yrs::sync::Error::Other(e.into()))),
659                ::core::task::Poll::Ready(_) => ::core::task::Poll::Ready(Ok(())),
660            }
661        }
662    }
663}
664
665fn generate_start_send_method_item() -> ImplItemFn {
666    let inner_field_access = quote::quote!(self.0);
667
668    syn::parse_quote! {
669        fn start_send(
670            mut self: ::core::pin::Pin<&mut Self>,
671            item: Vec<u8>,
672        ) -> Result<(), Self::Error> {
673            ::core::pin::Pin::new(&mut #inner_field_access).start_send(item.into())
674                .map_err(|e| ::yrs::sync::Error::Other(e.into()))
675        }
676    }
677}
678
679fn derive_yrs_sink_gen(input: TokenStream, gen_exchange: bool) -> TokenStream {
680    derive_impl(input, "YrsSink", move |name, field_type, generics| {
681        let from_into: Option<proc_macro2::TokenStream> = if gen_exchange {
682            Some(quote_from_into(name, field_type, generics))
683        } else {
684            None
685        };
686
687        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
688
689        let poll_ready_fn = generate_poll_method_item("poll_ready");
690        let poll_flush_fn = generate_poll_method_item("poll_flush");
691        let poll_close_fn = generate_poll_method_item("poll_close");
692        let start_send_fn = generate_start_send_method_item();
693
694        quote! {
695            #from_into
696
697            impl #impl_generics ::futures_util::Sink<Vec<u8>> for #name #ty_generics #where_clause {
698                type Error = ::yrs::sync::Error;
699
700                #poll_ready_fn
701                #start_send_fn
702                #poll_flush_fn
703                #poll_close_fn
704            }
705        }
706        .into()
707    })
708}
709
710struct CommonSinkArgs {
711    inner_field: LitStr,
712}
713impl Parse for CommonSinkArgs {
714    fn parse(input: ParseStream) -> syn::Result<Self> {
715        let mut inner_field: Option<LitStr> = None;
716
717        while !input.is_empty() {
718            let lookahead = input.lookahead1();
719
720            if !lookahead.peek(Ident) || !input.peek2(Token![=]) {
721                return Err(input.error("Expected attribute argument `key = value`"));
722            }
723
724            let key: Ident = input.parse()?;
725            let _eq_token: Token![=] = input.parse()?;
726
727            if key == "inner" {
728                if inner_field.is_some() {
729                    return Err(input.error("Duplicate `inner` argument"));
730                }
731                inner_field = Some(input.parse()?);
732            } else {
733                return Err(input.error(format!("Unknown attribute argument: `{}`", key)));
734            }
735
736            if !input.is_empty() {
737                let lookahead = input.lookahead1();
738                if lookahead.peek(Token![,]) {
739                    let _: Token![,] = input.parse()?;
740                } else {
741                    return Err(input.error("Attribute arguments must be separated by commas"));
742                }
743            }
744        }
745
746        let inner_field =
747            inner_field.unwrap_or_else(|| LitStr::new("self.0", proc_macro2::Span::call_site()));
748
749        Ok(CommonSinkArgs { inner_field })
750    }
751}
752
753// 辅助函数:从 impl Sink<Item> for ... 头部解析出 Item 类型
754fn get_sink_item_type(input: &ItemImpl) -> Result<Type, TokenStream> {
755    // <--- 修改返回类型为 Type
756    let trait_option = input.trait_.as_ref();
757    let trait_ref =
758        match trait_option {
759            Some(tr) => tr,
760            None => return Err(syn::Error::new_spanned(
761                input,
762                "yrs_common_sink must be applied to a Sink impl (e.g., `impl Sink<Item> for Type`)",
763            )
764            .to_compile_error()
765            .into()),
766        };
767    let trait_path = trait_ref.1.clone();
768
769    let last_segment_option = trait_path.segments.last();
770    let last_segment = match last_segment_option {
771        Some(seg) => seg,
772        None => {
773            return Err(syn::Error::new_spanned(&trait_path, "Invalid trait path")
774                .to_compile_error()
775                .into());
776        }
777    };
778
779    if last_segment.ident != "Sink" {
780        return Err(syn::Error::new_spanned(
781            &last_segment,
782            "yrs_common_sink must be applied to a Sink impl",
783        )
784        .to_compile_error()
785        .into());
786    }
787
788    let args = match &last_segment.arguments {
789        syn::PathArguments::AngleBracketed(args) => args,
790        _ => {
791            return Err(syn::Error::new_spanned(
792                &last_segment.arguments,
793                "Sink trait must have angle bracketed arguments",
794            )
795            .to_compile_error()
796            .into());
797        }
798    };
799
800    if args.args.len() != 1 {
801        return Err(syn::Error::new_spanned(
802            &args.args,
803            "Sink trait must have exactly one generic argument (Item)",
804        )
805        .to_compile_error()
806        .into());
807    }
808
809    let generic_arg = args.args.first().unwrap();
810    match generic_arg {
811        syn::GenericArgument::Type(ty) => Ok(ty.clone()),
812        arg => Err(
813            syn::Error::new_spanned(arg, "Sink trait argument must be a type")
814                .to_compile_error()
815                .into(),
816        ),
817    }
818}
819
820/// Simplifies implementing `futures_util::Sink` trait by generating boilerplate
821/// methods and managing member order.
822///
823/// This attribute macro should be applied to an `impl Sink<Item> for Type` block.
824/// It finds and reorders members within the impl block to match the
825/// `futures_util::Sink` trait definition order: `type Error`, `fn poll_ready`,
826/// `fn start_send`, `fn poll_flush`, `fn poll_close`, followed by any other
827/// user-defined members.
828///
829/// # Attributes
830///
831/// - `inner = "expr"`: Required string literal specifying the expression to
832///   access the inner Sink instance (e.g., `"self.0"` or `"self.my_field"`).
833///   The default value is `"self.0"`.
834///
835/// # Required Impl Members
836///
837/// The macro looks for these members in the `impl` block:
838///
839/// - `type Error`: The associated error type. If not manually defined in the
840///   impl block, the macro will default to `type Error = ::yrs::sync::Error;`.
841/// - `fn start_send(mut self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error>`:
842///   The method for sending an item into the sink. If not manually defined in the
843///   impl block, the macro will attempt to generate a default implementation.
844///   The default implementation depends on the `Item` type of the Sink being implemented:
845///   - If `Item` is `SignalingMessage` (checked by name token), a specific `match`
846///     based conversion body is generated, using unqualified names (`Message`, `Bytes`),
847///     relying on user-provided `use` statements.
848///   - For any other `Item` type, a generic `item.into().map_err(...)` body is generated.
849///     This requires the Item type to implement `Into<InnerSinkItem>`.
850///
851/// # Generated Members
852///
853/// The macro generates these members if they are not manually defined in the impl block
854/// (for `type Error` and `fn start_send`) or always generates them (`poll_*` methods):
855///
856/// - `type Error`: Generated if not manually defined. Defaults to `::yrs::sync::Error`.
857/// - `fn poll_ready(...)`: Always generated. Delegates to the inner sink.
858/// - `fn poll_flush(...)`: Always generated. Delegates to the inner sink.
859/// - `fn poll_close(...)`: Always generated. Delegates to the inner sink.
860/// - `fn start_send(...)`: Generated if not manually defined. Implementation depends on Item type.
861///
862/// # Member Order
863///
864/// All members in the `impl` block will be automatically reordered to match the
865/// `futures_util::Sink` trait definition order:
866/// `type Error`, `fn poll_ready`, `fn start_send`, `fn poll_flush`, `fn poll_close`,
867/// followed by any other user-defined members (helper functions, consts, etc.).
868#[proc_macro_attribute]
869pub fn yrs_common_sink(attr: TokenStream, item: TokenStream) -> TokenStream {
870    let mut impl_block = parse_macro_input!(item as ItemImpl);
871
872    let args = parse_macro_input!(attr as CommonSinkArgs);
873    let inner_expr_str = args.inner_field.value();
874
875    let inner_expr: proc_macro2::TokenStream = match inner_expr_str.parse() {
876        Ok(expr) => expr,
877        Err(_) => {
878            return syn::Error::new_spanned(args.inner_field, "Cannot parse `inner` expression")
879                .to_compile_error()
880                .into();
881        }
882    };
883
884    let outer_sink_item_type = match get_sink_item_type(&impl_block) {
885        Ok(ty) => ty.clone(),
886        Err(e) => return e,
887    };
888
889    let mut user_start_send: Option<ImplItemFn> = None;
890    let mut user_error_type_item: Option<ImplItemType> = None;
891    let mut other_user_items: Vec<ImplItem> = Vec::new();
892
893    let original_items = std::mem::take(&mut impl_block.items);
894    for item in original_items {
895        match item {
896            ImplItem::Fn(f) if f.sig.ident == "start_send" => {
897                if user_start_send.is_some() {
898                    return syn::Error::new_spanned(&f.sig.ident, "Duplicate `start_send` method")
899                        .to_compile_error()
900                        .into();
901                }
902                user_start_send = Some(f);
903            }
904            ImplItem::Type(t) if t.ident == "Error" => {
905                if user_error_type_item.is_some() {
906                    return syn::Error::new_spanned(&t.ident, "Duplicate `Error` type")
907                        .to_compile_error()
908                        .into();
909                }
910                user_error_type_item = Some(t.clone());
911            }
912            _ => other_user_items.push(item),
913        }
914    }
915
916    let final_error_item: ImplItemType = user_error_type_item.unwrap_or_else(|| {
917        syn::parse_quote! {
918            type Error = ::yrs::sync::Error;
919        }
920    });
921
922    let start_send_method_item: ImplItemFn = match user_start_send {
923        Some(method) => method,
924        None => {
925            let target_simple_signaling: Type = syn::parse_quote!(SignalingMessage);
926            let target_fully_qualified_signaling: Type =
927                syn::parse_quote!(::yrs_tokio::signaling::Message);
928            let target_qualified_signaling: Type = syn::parse_quote!(yrs_tokio::signaling::Message);
929
930            // Convert syn::Type to string for comparison since Type doesn't implement PartialEq
931            let outer_type_str = quote::quote!(#outer_sink_item_type).to_string();
932            let simple_type_str = quote::quote!(#target_simple_signaling).to_string();
933            let fully_qualified_type_str =
934                quote::quote!(#target_fully_qualified_signaling).to_string();
935            let qualified_type_str = quote::quote!(#target_qualified_signaling).to_string();
936
937            let is_targeted_signaling_message_type = outer_type_str == simple_type_str
938                || outer_type_str == fully_qualified_type_str
939                || outer_type_str == qualified_type_str;
940
941            let default_body = if is_targeted_signaling_message_type {
942                quote! {
943                    let msg = match item {
944                        ::yrs_tokio::signaling::Message::Text(txt) => Message::text(txt),
945                        ::yrs_tokio::signaling::Message::Binary(bytes) => Message::binary(bytes),
946                        ::yrs_tokio::signaling::Message::Ping => Message::Ping(Vec::default().into()),
947                        ::yrs_tokio::signaling::Message::Pong => Message::Pong(Vec::default().into()),
948                        ::yrs_tokio::signaling::Message::Close => Message::Close(None.into()),
949                    };
950                    if let Err(e) = ::core::pin::Pin::new(&mut #inner_expr).start_send(msg) {
951                         Err(::yrs::sync::Error::Other(e.into()))
952                    } else {
953                        Ok(())
954                    }
955                }
956            } else {
957                // 生成更通用的默认 body (使用全限定名,但 Item 参数用短名)
958                quote! {
959                     use ::core::pin::Pin;
960                     use ::futures_util::Sink;
961                     Pin::new(&mut #inner_expr).start_send(item.into())
962                         .map_err(|e| ::yrs::sync::Error::Other(e.into()))
963                }
964            };
965
966            syn::parse_quote! {
967                fn start_send(mut self: ::core::pin::Pin<&mut Self>, item: #outer_sink_item_type) -> Result<(), Self::Error> {
968                    #default_body
969                }
970            }
971        }
972    };
973
974    let generated_poll_ready = generate_poll_fn("poll_ready", &inner_expr);
975
976    let generated_poll_flush = generate_poll_fn("poll_flush", &inner_expr);
977
978    let generated_poll_close = generate_poll_fn("poll_close", &inner_expr);
979
980    let mut final_items: Vec<ImplItem> = Vec::new();
981
982    final_items.push(ImplItem::Type(final_error_item)); // 1. type Error
983    final_items.push(ImplItem::Fn(generated_poll_ready)); // 2. poll_ready
984    final_items.push(ImplItem::Fn(start_send_method_item)); // 3. start_send
985    final_items.push(ImplItem::Fn(generated_poll_flush)); // 4. poll_flush
986    final_items.push(ImplItem::Fn(generated_poll_close)); // 5. poll_close
987
988    final_items.extend(other_user_items);
989
990    impl_block.items = final_items;
991
992    quote!(#impl_block).into()
993}
994
995fn generate_poll_fn(method: &str, inner_expr: &proc_macro2::TokenStream) -> ImplItemFn {
996    let ident = format_ident!("{}", method);
997
998    syn::parse_quote! {
999        fn #ident(
1000            mut self: ::core::pin::Pin<&mut Self>,
1001            cx: &mut std::task::Context<'_>,
1002        ) -> std::task::Poll<Result<(), Self::Error>> {
1003            match ::core::pin::Pin::new(&mut #inner_expr).#ident(cx) {
1004                std::task::Poll::Pending => std::task::Poll::Pending,
1005                std::task::Poll::Ready(Err(e)) => std::task::Poll::Ready(Err(::yrs::sync::Error::Other(e.into()))),
1006                std::task::Poll::Ready(_) => std::task::Poll::Ready(Ok(())),
1007            }
1008        }
1009    }
1010}
1011
1012fn quote_from_into(
1013    name: &Ident,
1014    field_type: &Type,
1015    generics: &Generics,
1016) -> proc_macro2::TokenStream {
1017    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
1018
1019    quote! {
1020        impl #impl_generics ::core::convert::From<#field_type> for #name #ty_generics #where_clause {
1021            fn from(stream: #field_type) -> Self {
1022                #name(stream)
1023            }
1024        }
1025
1026        impl #impl_generics ::core::convert::Into<#field_type> for #name #ty_generics #where_clause {
1027            fn into(self) -> #field_type {
1028                self.0
1029            }
1030        }
1031    }
1032}
1033
1034/// 通用 derive 宏逻辑
1035fn derive_impl<F>(input: TokenStream, _macro_name: &str, f: F) -> TokenStream
1036where
1037    F: FnOnce(&Ident, &Type, &Generics) -> TokenStream, // 闭包签名修改
1038{
1039    let input = parse_macro_input!(input as DeriveInput);
1040    let name = &input.ident;
1041    let generics = &input.generics; // 获取泛型信息
1042
1043    let field_type = match get_field_type(&input) {
1044        Ok(ty) => ty,
1045        Err(err) => return err,
1046    };
1047
1048    // 调用闭包时传递泛型信息
1049    let expanded = f(name, field_type, generics);
1050    TokenStream::from(expanded)
1051}
1052
1053/// 获取单字段 tuple struct 的字段类型
1054fn get_field_type(input: &DeriveInput) -> Result<&Type, TokenStream> {
1055    Ok(match &input.data {
1056        Data::Struct(data_struct) => {
1057            if let Fields::Unnamed(fields_unnamed) = &data_struct.fields {
1058                if fields_unnamed.unnamed.len() == 1 {
1059                    &fields_unnamed.unnamed.first().unwrap().ty
1060                } else {
1061                    return Err(syn::Error::new_spanned(
1062                        &input.ident,
1063                        "can only be derived for tuple structs with one field",
1064                    )
1065                    .to_compile_error()
1066                    .into());
1067                }
1068            } else {
1069                return Err(syn::Error::new_spanned(
1070                    &input.ident,
1071                    "can only be derived for tuple structs with unnamed fields",
1072                )
1073                .to_compile_error()
1074                .into());
1075            }
1076        }
1077        _ => {
1078            return Err(
1079                syn::Error::new_spanned(&input.ident, "can only be derived for structs")
1080                    .to_compile_error()
1081                    .into(),
1082            );
1083        }
1084    })
1085}