ya_service_bus/
connection.rs

1use actix::prelude::*;
2use futures::{
3    channel::{mpsc, oneshot},
4    prelude::*,
5    stream::SplitSink,
6};
7use semver::Version;
8use std::{
9    collections::{HashMap, VecDeque},
10    convert::TryInto,
11    pin::Pin,
12};
13
14use ya_sb_proto::codec::{GsbMessage, ProtocolError};
15use ya_sb_proto::{
16    BroadcastReplyCode, BroadcastRequest, CallReply, CallReplyCode, CallReplyType, CallRequest,
17    RegisterReplyCode, RegisterRequest, SubscribeReplyCode, SubscribeRequest, UnregisterReplyCode,
18    UnregisterRequest, UnsubscribeReplyCode, UnsubscribeRequest,
19};
20use ya_sb_util::writer::*;
21
22use crate::local_router::router;
23use crate::Error;
24use crate::{ResponseChunk, RpcRawCall, RpcRawStreamCall};
25
26fn gen_id() -> u64 {
27    use rand::Rng;
28
29    let mut rng = rand::thread_rng();
30
31    rng.gen::<u64>() & 0x001f_ffff_ffff_ffffu64
32}
33
34#[derive(Default, Clone)]
35#[non_exhaustive]
36pub struct ClientInfo {
37    pub name: String,
38    pub version: Option<Version>,
39    pub instance_id: Vec<u8>,
40}
41
42impl ClientInfo {
43    pub fn new(name: impl ToString) -> Self {
44        ClientInfo {
45            name: name.to_string(),
46            version: Some(Version::parse(env!("CARGO_PKG_VERSION")).unwrap()),
47            instance_id: uuid::Uuid::new_v4().as_bytes().to_vec(),
48        }
49    }
50}
51
52pub trait CallRequestHandler {
53    type Reply: Stream<Item = Result<ResponseChunk, Error>> + Unpin;
54
55    fn do_call(
56        &mut self,
57        request_id: String,
58        caller: String,
59        address: String,
60        data: Vec<u8>,
61        no_reply: bool,
62    ) -> Self::Reply;
63
64    fn handle_event(&mut self, caller: String, topic: String, data: Vec<u8>) {
65        log::warn!("unhandled gsb event from: {}, to: {}", caller, topic,);
66        log::trace!(
67            "unhandled gsb event data: {:?}",
68            String::from_utf8_lossy(data.as_ref())
69        )
70    }
71
72    fn on_disconnect(&mut self) {}
73}
74
75impl ResponseChunk {
76    #[inline]
77    fn reply_type(&self) -> CallReplyType {
78        match self {
79            ResponseChunk::Full(_) => CallReplyType::Full,
80            ResponseChunk::Part(_) => CallReplyType::Partial,
81        }
82    }
83
84    #[inline]
85    fn into_vec(self) -> Vec<u8> {
86        match self {
87            ResponseChunk::Full(v) => v,
88            ResponseChunk::Part(v) => v,
89        }
90    }
91}
92
/// Call handler that forwards incoming GSB calls to the in-process router.
#[derive(Default)]
pub struct LocalRouterHandler {
    /// Callback consumed by `on_disconnect`.
    ///
    /// It is `None` when the handler was built via `Default`, or after the
    /// callback has already run.
    disconnect_h: Option<Box<dyn FnOnce()>>,
}
97
98impl LocalRouterHandler {
99    pub fn new<F: FnOnce() + 'static>(disconnect_fn: F) -> Self {
100        Self {
101            disconnect_h: Some(Box::new(disconnect_fn)),
102        }
103    }
104}
105
106impl CallRequestHandler for LocalRouterHandler {
107    type Reply = Pin<Box<dyn futures::Stream<Item = Result<ResponseChunk, Error>>>>;
108
109    fn do_call(
110        &mut self,
111        _request_id: String,
112        caller: String,
113        address: String,
114        data: Vec<u8>,
115        no_reply: bool,
116    ) -> Self::Reply {
117        router()
118            .lock()
119            .unwrap()
120            .forward_bytes_local(&address, &caller, data.as_ref(), no_reply)
121            .boxed_local()
122    }
123
124    fn on_disconnect(&mut self) {
125        if let Some(f) = self.disconnect_h.take() {
126            f()
127        };
128    }
129}
130
131impl<
132        R: futures::Stream<Item = Result<ResponseChunk, Error>> + Unpin,
133        F: FnMut(String, String, String, Vec<u8>) -> R,
134    > CallRequestHandler for F
135{
136    type Reply = R;
137
138    fn do_call(
139        &mut self,
140        request_id: String,
141        caller: String,
142        address: String,
143        data: Vec<u8>,
144        _no_reply: bool,
145    ) -> Self::Reply {
146        self(request_id, caller, address, data)
147    }
148}
149
150impl<
151        R: futures::Stream<Item = Result<ResponseChunk, Error>> + Unpin,
152        F1: FnMut(String, String, String, Vec<u8>) -> R,
153        F2: FnMut(String, String, Vec<u8>),
154    > CallRequestHandler for (F1, F2)
155{
156    type Reply = R;
157
158    fn do_call(
159        &mut self,
160        request_id: String,
161        caller: String,
162        address: String,
163        data: Vec<u8>,
164        _no_reply: bool,
165    ) -> Self::Reply {
166        (self.0)(request_id, caller, address, data)
167    }
168
169    fn handle_event(&mut self, caller: String, topic: String, data: Vec<u8>) {
170        (self.1)(caller, topic, data)
171    }
172}
173
174type TransportWriter<W> = SinkWrite<GsbMessage, W>;
175type ReplyQueue = VecDeque<oneshot::Sender<Result<(), Error>>>;
176
177struct Connection<W, H>
178where
179    W: Sink<GsbMessage, Error = ProtocolError> + Unpin,
180    H: CallRequestHandler,
181{
182    writer: TransportWriter<W>,
183    register_reply: ReplyQueue,
184    unregister_reply: ReplyQueue,
185    subscribe_reply: ReplyQueue,
186    unsubscribe_reply: ReplyQueue,
187    call_reply: HashMap<String, mpsc::Sender<Result<ResponseChunk, Error>>>,
188    broadcast_reply: ReplyQueue,
189    handler: H,
190    client_info: ClientInfo,
191    server_info: Option<ya_sb_proto::Hello>,
192}
193
194impl<W, H> Unpin for Connection<W, H>
195where
196    W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
197    H: CallRequestHandler + 'static,
198{
199}
200
201fn handle_reply<Ctx: ActorContext, F: FnOnce() -> Result<(), Error>>(
202    cmd_type: &str,
203    queue: &mut ReplyQueue,
204    ctx: &mut Ctx,
205    reply_msg: F,
206) {
207    if let Some(r) = queue.pop_front() {
208        let _ = r.send(reply_msg());
209    } else {
210        log::error!("unmatched {} reply", cmd_type);
211        ctx.stop()
212    }
213}
214
215impl<W, H> EmptyBufferHandler for Connection<W, H>
216where
217    W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
218    H: CallRequestHandler + 'static,
219{
220}
221
222impl<W, H> Connection<W, H>
223where
224    W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
225    H: CallRequestHandler + 'static,
226{
227    fn new(client_info: ClientInfo, w: W, handler: H, ctx: &mut <Self as Actor>::Context) -> Self {
228        Connection {
229            writer: SinkWrite::new(w, ctx),
230            register_reply: Default::default(),
231            unregister_reply: Default::default(),
232            subscribe_reply: Default::default(),
233            unsubscribe_reply: Default::default(),
234            call_reply: Default::default(),
235            broadcast_reply: Default::default(),
236            handler,
237            client_info,
238            server_info: Default::default(),
239        }
240    }
241
242    fn handle_unregister_reply(
243        &mut self,
244        code: UnregisterReplyCode,
245        ctx: &mut <Self as Actor>::Context,
246    ) {
247        handle_reply(
248            "unregister",
249            &mut self.unregister_reply,
250            ctx,
251            || match code {
252                UnregisterReplyCode::UnregisteredOk => Ok(()),
253                UnregisterReplyCode::NotRegistered => {
254                    Err(Error::GsbBadRequest("unregister".to_string()))
255                }
256            },
257        )
258    }
259
260    fn handle_broadcast_reply(
261        &mut self,
262        code: BroadcastReplyCode,
263        msg: String,
264        ctx: &mut <Self as Actor>::Context,
265    ) {
266        handle_reply("broadcast", &mut self.broadcast_reply, ctx, || match code {
267            BroadcastReplyCode::BroadcastOk => Ok(()),
268            BroadcastReplyCode::BroadcastBadRequest => Err(Error::GsbBadRequest(msg)),
269        })
270    }
271
272    fn handle_register_reply(
273        &mut self,
274        code: RegisterReplyCode,
275        msg: String,
276        ctx: &mut <Self as Actor>::Context,
277    ) {
278        handle_reply("register", &mut self.register_reply, ctx, || match code {
279            RegisterReplyCode::RegisteredOk => Ok(()),
280            RegisterReplyCode::RegisterBadRequest => {
281                log::warn!("bad request: {}", msg);
282                Err(Error::GsbBadRequest(msg))
283            }
284            RegisterReplyCode::RegisterConflict => {
285                log::warn!("already registered: {}", msg);
286                Err(Error::GsbAlreadyRegistered(msg))
287            }
288        })
289    }
290
291    fn handle_subscribe_reply(
292        &mut self,
293        code: SubscribeReplyCode,
294        msg: String,
295        ctx: &mut <Self as Actor>::Context,
296    ) {
297        handle_reply("subscribe", &mut self.subscribe_reply, ctx, || match code {
298            SubscribeReplyCode::SubscribedOk => Ok(()),
299            SubscribeReplyCode::SubscribeBadRequest => {
300                log::warn!("bad request: {}", msg);
301                Err(Error::GsbBadRequest(msg))
302            }
303        })
304    }
305
306    fn handle_unsubscribe_reply(
307        &mut self,
308        code: UnsubscribeReplyCode,
309        ctx: &mut <Self as Actor>::Context,
310    ) {
311        handle_reply(
312            "unsubscribe",
313            &mut self.unsubscribe_reply,
314            ctx,
315            || match code {
316                UnsubscribeReplyCode::UnsubscribedOk => Ok(()),
317                UnsubscribeReplyCode::NotSubscribed => {
318                    Err(Error::GsbBadRequest("unsubscribed".to_string()))
319                }
320            },
321        )
322    }
323
324    fn handle_call_request(
325        &mut self,
326        request_id: String,
327        caller: String,
328        address: String,
329        data: Vec<u8>,
330        ctx: &mut <Self as Actor>::Context,
331    ) {
332        log::trace!(
333            "handling rpc call from = {}, to = {}, request_id={}, ",
334            caller,
335            address,
336            request_id
337        );
338        let eos_request_id = request_id.clone();
339        let do_call = self
340            .handler
341            .do_call(request_id.clone(), caller, address, data, false)
342            .into_actor(self)
343            .fold(false, move |_got_eos, r, act: &mut Self, _ctx| {
344                let request_id = request_id.clone();
345                let (got_eos, reply) = match r {
346                    Ok(data) => {
347                        let code = CallReplyCode::CallReplyOk as i32;
348                        let reply_type = data.reply_type() as i32;
349                        (
350                            reply_type == 0,
351                            CallReply {
352                                request_id,
353                                code,
354                                reply_type,
355                                data: data.into_vec(),
356                            },
357                        )
358                    }
359                    Err(e) => {
360                        let code = CallReplyCode::ServiceFailure as i32;
361                        let reply_type = Default::default();
362                        let data = format!("{}", e).into_bytes();
363                        (
364                            true,
365                            CallReply {
366                                request_id,
367                                code,
368                                reply_type,
369                                data,
370                            },
371                        )
372                    }
373                };
374                // TODO: handle write error
375                let _ = act.writer.write(GsbMessage::CallReply(reply));
376                fut::ready(got_eos)
377            })
378            .then(|got_eos, act, _ctx| {
379                if !got_eos {
380                    let _ = act.writer.write(GsbMessage::CallReply(CallReply {
381                        request_id: eos_request_id,
382                        code: 0,
383                        reply_type: 0,
384                        data: Default::default(),
385                    }));
386                }
387                fut::ready(())
388            });
389        //do_call.spawn(ctx);
390        ctx.spawn(do_call);
391    }
392
393    fn handle_push_request(
394        &mut self,
395        request_id: String,
396        caller: String,
397        address: String,
398        data: Vec<u8>,
399        ctx: &mut <Self as Actor>::Context,
400    ) {
401        log::trace!(
402            "handling push call from = {}, to = {}, request_id={}, ",
403            caller,
404            address,
405            request_id
406        );
407
408        self.handler
409            .do_call(request_id, caller, address, data, true)
410            .into_actor(self)
411            .fold((), move |_, _, _, _| fut::ready(()))
412            .spawn(ctx);
413    }
414
415    fn handle_reply(
416        &mut self,
417        request_id: String,
418        code: i32,
419        reply_type: i32,
420        data: Vec<u8>,
421        ctx: &mut <Self as Actor>::Context,
422    ) -> Result<(), Box<dyn std::error::Error>> {
423        log::trace!(
424            "handling reply for request_id={}, code={}, reply_type={}",
425            request_id,
426            code,
427            reply_type
428        );
429
430        let chunk = if reply_type == CallReplyType::Partial as i32 {
431            ResponseChunk::Part(data)
432        } else {
433            ResponseChunk::Full(data)
434        };
435
436        let is_full = chunk.is_full();
437
438        if let Some(r) = self.call_reply.get_mut(&request_id) {
439            // TODO: check error
440            let mut r = (*r).clone();
441            let code: CallReplyCode = code.try_into()?;
442            let item = match code {
443                CallReplyCode::CallReplyOk => Ok(chunk),
444                CallReplyCode::CallReplyBadRequest => {
445                    Err(Error::GsbBadRequest(String::from_utf8(chunk.into_bytes())?))
446                }
447                CallReplyCode::ServiceFailure => {
448                    Err(Error::GsbFailure(String::from_utf8(chunk.into_bytes())?))
449                }
450            };
451            ctx.wait(
452                async move {
453                    let s = r.send(item);
454                    s.await
455                        .unwrap_or_else(|e| log::warn!("undelivered reply: {}", e))
456                }
457                .into_actor(self),
458            );
459        } else {
460            log::debug!("unmatched call reply");
461        }
462
463        if is_full {
464            let _ = self.call_reply.remove(&request_id);
465        }
466
467        Ok(())
468    }
469}
470
471impl<W, H> Actor for Connection<W, H>
472where
473    W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
474    H: CallRequestHandler + 'static,
475{
476    type Context = Context<Self>;
477
478    fn started(&mut self, _ctx: &mut Self::Context) {
479        log::info!("started connection to gsb");
480        let hello: ya_sb_proto::Hello = ya_sb_proto::Hello {
481            name: self.client_info.name.clone(),
482            version: self
483                .client_info
484                .version
485                .as_ref()
486                .map(|v| v.to_string())
487                .unwrap_or_default(),
488            instance_id: self.client_info.instance_id.clone(),
489        };
490
491        let _ = self.writer.write(GsbMessage::Hello(hello));
492    }
493
494    fn stopped(&mut self, _ctx: &mut Self::Context) {
495        log::info!("stopped connection to gsb");
496        self.handler.on_disconnect();
497    }
498}
499
500fn register_reply_code(code: i32) -> Option<RegisterReplyCode> {
501    Some(match code {
502        0 => RegisterReplyCode::RegisteredOk,
503        400 => RegisterReplyCode::RegisterBadRequest,
504        409 => RegisterReplyCode::RegisterConflict,
505        _ => return None,
506    })
507}
508
509fn unregister_reply_code(code: i32) -> Option<UnregisterReplyCode> {
510    Some(match code {
511        0 => UnregisterReplyCode::UnregisteredOk,
512        404 => UnregisterReplyCode::NotRegistered,
513        _ => return None,
514    })
515}
516
517fn subscribe_reply_code(code: i32) -> Option<SubscribeReplyCode> {
518    Some(match code {
519        0 => SubscribeReplyCode::SubscribedOk,
520        400 => SubscribeReplyCode::SubscribeBadRequest,
521        _ => return None,
522    })
523}
524
525fn unsubscribe_reply_code(code: i32) -> Option<UnsubscribeReplyCode> {
526    Some(match code {
527        0 => UnsubscribeReplyCode::UnsubscribedOk,
528        404 => UnsubscribeReplyCode::NotSubscribed,
529        _ => return None,
530    })
531}
532
533fn broadcast_reply_code(code: i32) -> Option<BroadcastReplyCode> {
534    Some(match code {
535        0 => BroadcastReplyCode::BroadcastOk,
536        400 => BroadcastReplyCode::BroadcastBadRequest,
537        _ => return None,
538    })
539}
540
541impl<W, H> StreamHandler<Result<GsbMessage, ProtocolError>> for Connection<W, H>
542where
543    W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
544    H: CallRequestHandler + 'static,
545{
546    fn handle(&mut self, item: Result<GsbMessage, ProtocolError>, ctx: &mut Self::Context) {
547        if let Err(e) = item.as_ref() {
548            log::error!("protocol error {}", e);
549            ctx.stop();
550            return;
551        }
552
553        match item.unwrap() {
554            GsbMessage::RegisterReply(r) => {
555                if let Some(code) = register_reply_code(r.code) {
556                    self.handle_register_reply(code, r.message, ctx)
557                } else {
558                    log::error!("invalid reply code {}", r.code);
559                    ctx.stop();
560                }
561            }
562            GsbMessage::UnregisterReply(r) => {
563                if let Some(code) = unregister_reply_code(r.code) {
564                    self.handle_unregister_reply(code, ctx)
565                } else {
566                    log::error!("invalid unregister reply code {}", r.code);
567                    ctx.stop();
568                }
569            }
570            GsbMessage::SubscribeReply(r) => {
571                if let Some(code) = subscribe_reply_code(r.code) {
572                    self.handle_subscribe_reply(code, r.message, ctx)
573                } else {
574                    log::error!("invalid reply code {}", r.code);
575                    ctx.stop();
576                }
577            }
578            GsbMessage::UnsubscribeReply(r) => {
579                if let Some(code) = unsubscribe_reply_code(r.code) {
580                    self.handle_unsubscribe_reply(code, ctx)
581                } else {
582                    log::error!("invalid unsubscribe reply code {}", r.code);
583                    ctx.stop();
584                }
585            }
586            GsbMessage::BroadcastReply(r) => {
587                if let Some(code) = broadcast_reply_code(r.code) {
588                    self.handle_broadcast_reply(code, r.message, ctx)
589                } else {
590                    log::error!("invalid broadcast reply code {}", r.code);
591                    ctx.stop();
592                }
593            }
594            GsbMessage::CallRequest(r) => {
595                if r.no_reply {
596                    self.handle_push_request(r.request_id, r.caller, r.address, r.data, ctx)
597                } else {
598                    self.handle_call_request(r.request_id, r.caller, r.address, r.data, ctx)
599                }
600            }
601            GsbMessage::CallReply(r) => {
602                if let Err(e) = self.handle_reply(r.request_id, r.code, r.reply_type, r.data, ctx) {
603                    log::error!("error on call reply processing: {}", e);
604                    ctx.stop();
605                }
606            }
607            GsbMessage::BroadcastRequest(r) => {
608                self.handler.handle_event(r.caller, r.topic, r.data);
609            }
610            GsbMessage::Ping(_) => {
611                if self.writer.write(GsbMessage::pong()).is_some() {
612                    log::error!("error sending pong");
613                    ctx.stop();
614                }
615            }
616            GsbMessage::Hello(h) => {
617                log::debug!("connected with server: {}/{}", h.name, h.version);
618                if self.server_info.is_some() {
619                    log::error!("invalid packet: {:?}", h);
620                    ctx.stop();
621                } else {
622                    self.server_info = Some(h);
623                }
624            }
625            m => {
626                log::error!("unexpected gsb message: {:?}", m);
627                ctx.stop();
628            }
629        }
630    }
631}
632
633impl<W, H> io::WriteHandler<ProtocolError> for Connection<W, H>
634where
635    W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
636    H: CallRequestHandler + 'static,
637{
638    fn error(&mut self, err: ProtocolError, _ctx: &mut Self::Context) -> Running {
639        log::error!("protocol error: {}", err);
640        Running::Stop
641    }
642}
643
644impl<W, H> Handler<RpcRawCall> for Connection<W, H>
645where
646    W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
647    H: CallRequestHandler + 'static,
648{
649    type Result = ActorResponse<Self, Result<Vec<u8>, Error>>;
650
651    fn handle(&mut self, msg: RpcRawCall, _ctx: &mut Self::Context) -> Self::Result {
652        let request_id = format!("{}", gen_id());
653        let caller = msg.caller;
654        let address = msg.addr;
655        let data = msg.body;
656        let no_reply = msg.no_reply;
657
658        let rx = if no_reply {
659            None
660        } else {
661            let (tx, rx) = mpsc::channel(1);
662            let _ = self.call_reply.insert(request_id.clone(), tx);
663            Some(rx)
664        };
665
666        log::trace!("handling caller (rpc): {}, addr:{}", caller, address);
667        let _r = self.writer.write(GsbMessage::CallRequest(CallRequest {
668            request_id,
669            caller,
670            address,
671            data,
672            no_reply,
673        }));
674
675        match rx {
676            Some(mut rx) => {
677                let fetch_response = async move {
678                    match futures::StreamExt::next(&mut rx).await {
679                        Some(Ok(ResponseChunk::Full(data))) => Ok(data),
680                        Some(Err(e)) => Err(e),
681                        Some(Ok(ResponseChunk::Part(_))) => {
682                            Err(Error::GsbFailure("streaming response".to_string()))
683                        }
684                        None => Err(Error::GsbFailure("unexpected EOS".to_string())),
685                    }
686                };
687                ActorResponse::r#async(fetch_response.into_actor(self))
688            }
689            None => ActorResponse::reply(Ok(Vec::new())),
690        }
691    }
692}
693
694impl<W, H> Handler<RpcRawStreamCall> for Connection<W, H>
695where
696    W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
697    H: CallRequestHandler + 'static,
698{
699    type Result = ActorResponse<Self, Result<(), Error>>;
700
701    fn handle(&mut self, msg: RpcRawStreamCall, _ctx: &mut Self::Context) -> Self::Result {
702        let request_id = format!("{}", gen_id());
703        let rx = msg.reply;
704        let _ = self.call_reply.insert(request_id.clone(), rx);
705        let caller = msg.caller;
706        let address = msg.addr;
707        let data = msg.body;
708        log::trace!("handling caller (stream): {}, addr:{}", caller, address);
709        let _r = self.writer.write(GsbMessage::CallRequest(CallRequest {
710            request_id,
711            caller,
712            address,
713            data,
714            no_reply: false,
715        }));
716        ActorResponse::reply(Ok(()))
717    }
718}
719
720fn send_cmd_async<A: Actor, W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static>(
721    writer: &mut TransportWriter<W>,
722    queue: &mut VecDeque<oneshot::Sender<Result<(), Error>>>,
723    msg: GsbMessage,
724) -> ActorResponse<A, Result<(), Error>> {
725    let (tx, rx) = oneshot::channel();
726    queue.push_back(tx);
727
728    if writer.write(msg).is_some() {
729        ActorResponse::reply(Err(Error::GsbFailure("no connection".into())))
730    } else {
731        ActorResponse::r#async(fut::wrap_future(async move {
732            rx.await.map_err(|_| Error::Cancelled)??;
733            Ok(())
734        }))
735    }
736}
737
738struct Bind {
739    addr: String,
740}
741
742impl Message for Bind {
743    type Result = Result<(), Error>;
744}
745
746impl<W, H> Handler<Bind> for Connection<W, H>
747where
748    W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
749    H: CallRequestHandler + 'static,
750{
751    type Result = ActorResponse<Self, Result<(), Error>>;
752
753    fn handle(&mut self, msg: Bind, _ctx: &mut Self::Context) -> Self::Result {
754        let service_id = msg.addr;
755        send_cmd_async(
756            &mut self.writer,
757            &mut self.register_reply,
758            GsbMessage::RegisterRequest(RegisterRequest { service_id }),
759        )
760    }
761}
762
763struct Unbind {
764    addr: String,
765}
766
767impl Message for Unbind {
768    type Result = Result<(), Error>;
769}
770
771impl<W, H> Handler<Unbind> for Connection<W, H>
772where
773    W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
774    H: CallRequestHandler + 'static,
775{
776    type Result = ActorResponse<Self, Result<(), Error>>;
777
778    fn handle(&mut self, msg: Unbind, _ctx: &mut Self::Context) -> Self::Result {
779        let service_id = msg.addr;
780        send_cmd_async(
781            &mut self.writer,
782            &mut self.unregister_reply,
783            GsbMessage::UnregisterRequest(UnregisterRequest { service_id }),
784        )
785    }
786}
787
788struct Subscribe {
789    topic: String,
790}
791
792impl Message for Subscribe {
793    type Result = Result<(), Error>;
794}
795
796impl<W, H> Handler<Subscribe> for Connection<W, H>
797where
798    W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
799    H: CallRequestHandler + 'static,
800{
801    type Result = ActorResponse<Self, Result<(), Error>>;
802
803    fn handle(&mut self, msg: Subscribe, _ctx: &mut Self::Context) -> Self::Result {
804        let topic = msg.topic;
805        send_cmd_async(
806            &mut self.writer,
807            &mut self.subscribe_reply,
808            GsbMessage::SubscribeRequest(SubscribeRequest { topic }),
809        )
810    }
811}
812
813struct Unsubscribe {
814    topic: String,
815}
816
817impl Message for Unsubscribe {
818    type Result = Result<(), Error>;
819}
820
821impl<W, H> Handler<Unsubscribe> for Connection<W, H>
822where
823    W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
824    H: CallRequestHandler + 'static,
825{
826    type Result = ActorResponse<Self, Result<(), Error>>;
827
828    fn handle(&mut self, msg: Unsubscribe, _ctx: &mut Self::Context) -> Self::Result {
829        let topic = msg.topic;
830        send_cmd_async(
831            &mut self.writer,
832            &mut self.unsubscribe_reply,
833            GsbMessage::UnsubscribeRequest(UnsubscribeRequest { topic }),
834        )
835    }
836}
837
838pub struct BcastCall {
839    pub caller: String,
840    pub topic: String,
841    pub body: Vec<u8>,
842}
843
844impl Message for BcastCall {
845    type Result = Result<(), Error>;
846}
847
848impl<W, H> Handler<BcastCall> for Connection<W, H>
849where
850    W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
851    H: CallRequestHandler + 'static,
852{
853    type Result = ActorResponse<Self, Result<(), Error>>;
854
855    fn handle(&mut self, msg: BcastCall, _ctx: &mut Self::Context) -> Self::Result {
856        let caller = msg.caller;
857        let topic = msg.topic;
858        let data = msg.body;
859        send_cmd_async(
860            &mut self.writer,
861            &mut self.broadcast_reply,
862            GsbMessage::BroadcastRequest(BroadcastRequest {
863                caller,
864                topic,
865                data,
866            }),
867        )
868    }
869}
870
871pub struct ConnectionRef<
872    Transport: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
873    H: CallRequestHandler + 'static,
874>(Addr<Connection<SplitSink<Transport, GsbMessage>, H>>);
875
876impl<
877        Transport: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
878        H: CallRequestHandler + 'static,
879    > Unpin for ConnectionRef<Transport, H>
880{
881}
882
883impl<
884        Transport: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
885        H: CallRequestHandler + 'static,
886    > Clone for ConnectionRef<Transport, H>
887{
888    fn clone(&self) -> Self {
889        ConnectionRef(self.0.clone())
890    }
891}
892
893impl<
894        Transport: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
895        H: CallRequestHandler + Unpin + 'static,
896    > ConnectionRef<Transport, H>
897{
898    pub fn bind(
899        &self,
900        addr: impl Into<String>,
901    ) -> impl Future<Output = Result<(), Error>> + 'static {
902        let addr = addr.into();
903        log::trace!("Binding remote service '{}'", addr);
904        self.0.send(Bind { addr: addr.clone() }).then(|v| async {
905            log::trace!("send bind result: {:?}", v);
906            v.map_err(|e| Error::from_addr(addr, e))?
907        })
908    }
909
910    pub fn unbind(
911        &self,
912        addr: impl Into<String>,
913    ) -> impl Future<Output = Result<(), Error>> + 'static {
914        let addr = addr.into();
915        self.0.send(Unbind { addr: addr.clone() }).then(|v| async {
916            log::trace!("send unbind result: {:?}", v);
917            v.map_err(|e| Error::from_addr(addr, e))?
918        })
919    }
920
921    pub fn subscribe(
922        &self,
923        topic: impl Into<String>,
924    ) -> impl Future<Output = Result<(), Error>> + 'static {
925        let topic = topic.into();
926        let fut = self.0.send(Subscribe {
927            topic: topic.clone(),
928        });
929        async move {
930            fut.await
931                .map_err(|e| Error::from_addr(format!("subscribing {}", topic), e))?
932        }
933    }
934
935    pub fn unsubscribe(
936        &self,
937        topic: impl Into<String>,
938    ) -> impl Future<Output = Result<(), Error>> + 'static {
939        let topic = topic.into();
940        let fut = self.0.send(Unsubscribe {
941            topic: topic.clone(),
942        });
943        async move {
944            fut.await
945                .map_err(|e| Error::from_addr(format!("unsubscribing {}", topic), e))?
946        }
947    }
948
949    pub fn broadcast(
950        &self,
951        caller: impl Into<String>,
952        topic: impl Into<String>,
953        body: Vec<u8>,
954    ) -> impl Future<Output = Result<(), Error>> + 'static {
955        let topic = topic.into();
956        let fut = self.0.send(BcastCall {
957            caller: caller.into(),
958            topic: topic.clone(),
959            body,
960        });
961        async move {
962            fut.await
963                .map_err(|e| Error::from_addr(format!("broadcasting {}", topic), e))?
964        }
965    }
966
967    pub fn call(
968        &self,
969        caller: impl Into<String>,
970        addr: impl Into<String>,
971        body: impl Into<Vec<u8>>,
972        no_reply: bool,
973    ) -> impl Future<Output = Result<Vec<u8>, Error>> {
974        let addr = addr.into();
975        self.0
976            .send(RpcRawCall {
977                caller: caller.into(),
978                addr: addr.clone(),
979                body: body.into(),
980                no_reply,
981            })
982            .then(|v| async { v.map_err(|e| Error::from_addr(addr, e))? })
983    }
984
985    pub fn call_streaming(
986        &self,
987        caller: impl Into<String>,
988        addr: impl Into<String>,
989        body: impl Into<Vec<u8>>,
990    ) -> impl Stream<Item = Result<ResponseChunk, Error>> {
991        let addr = addr.into();
992        let (tx, rx) = futures::channel::mpsc::channel(16);
993
994        let args = RpcRawStreamCall {
995            caller: caller.into(),
996            addr: addr.clone(),
997            body: body.into(),
998            reply: tx.clone(),
999        };
1000        let connection = self.0.clone();
1001        let _ = Arbiter::current().spawn(async move {
1002            let mut tx = tx;
1003            match connection.send(args).await {
1004                Ok(Ok(())) => (),
1005                Ok(Err(e)) => {
1006                    tx.send(Err(e))
1007                        .await
1008                        .unwrap_or_else(|e| log::error!("fail: {}", e));
1009                }
1010                Err(e) => {
1011                    tx.send(Err(Error::from_addr(addr, e)))
1012                        .await
1013                        .unwrap_or_else(|e| log::error!("fail: {}", e));
1014                }
1015            }
1016        });
1017        rx
1018    }
1019
1020    pub fn connected(&self) -> bool {
1021        self.0.connected()
1022    }
1023}
1024
1025pub fn connect<Transport, H>(
1026    client_info: ClientInfo,
1027    transport: Transport,
1028) -> ConnectionRef<Transport, H>
1029where
1030    Transport: Sink<GsbMessage, Error = ProtocolError>
1031        + Stream<Item = Result<GsbMessage, ProtocolError>>
1032        + Unpin
1033        + 'static,
1034    H: CallRequestHandler + 'static + Default + Unpin,
1035{
1036    connect_with_handler(client_info, transport, Default::default())
1037}
1038
1039pub fn connect_with_handler<Transport, H>(
1040    client_info: ClientInfo,
1041    transport: Transport,
1042    handler: H,
1043) -> ConnectionRef<Transport, H>
1044where
1045    Transport: Sink<GsbMessage, Error = ProtocolError>
1046        + Stream<Item = Result<GsbMessage, ProtocolError>>
1047        + Unpin
1048        + 'static,
1049    H: CallRequestHandler + 'static,
1050{
1051    let (split_sink, split_stream) = transport.split();
1052    ConnectionRef(Connection::create(move |ctx| {
1053        let _h = Connection::add_stream(split_stream, ctx);
1054        Connection::new(client_info, split_sink, handler, ctx)
1055    }))
1056}
1057
/// GSB message framing (`GsbMessageCodec`) over a plain tokio TCP stream.
pub type TcpTransport =
    tokio_util::codec::Framed<tokio::net::TcpStream, ya_sb_proto::codec::GsbMessageCodec>;
1060
1061pub async fn tcp(addr: impl tokio::net::ToSocketAddrs) -> Result<TcpTransport, std::io::Error> {
1062    let s = tokio::net::TcpStream::connect(addr).await?;
1063    Ok(tokio_util::codec::Framed::new(
1064        s,
1065        ya_sb_proto::codec::GsbMessageCodec::default(),
1066    ))
1067}
1068
#[cfg(feature = "tls")]
mod tls {
    use rustls::pki_types::ServerName;
    use std::sync::Arc;
    use tokio::net::TcpStream;
    use tokio_rustls::client::TlsStream;
    use ya_sb_proto::codec::GsbMessageCodec;
    pub use ya_sb_util::tls::CertHash;
    use ya_sb_util::tls::HashVerifier;

    /// GSB message framing (`GsbMessageCodec`) over a client-side TLS stream on top of TCP.
    pub type TlsTransport = tokio_util::codec::Framed<TlsStream<TcpStream>, GsbMessageCodec>;

    /// Connects to `addr` over TCP, performs a TLS handshake and wraps the result in GSB
    /// framing.
    ///
    /// Server certificate checking is delegated to a `HashVerifier` built from `cert_hash`.
    /// It is installed through rustls' `dangerous()` custom-verifier hook in place of the
    /// default verification, and no client certificate is presented. The TLS server name
    /// is the connected peer's IP address.
    ///
    /// NOTE(review): `ClientConfig::builder()` relies on the process-default rustls crypto
    /// provider — confirm the enabled rustls features make one available.
    ///
    /// # Errors
    /// Returns an I/O error if the TCP connect, the peer-address lookup, or the TLS
    /// handshake fails.
    pub async fn tls(
        addr: impl tokio::net::ToSocketAddrs,
        cert_hash: CertHash,
    ) -> Result<TlsTransport, std::io::Error> {
        let verifier = Arc::new(HashVerifier::new(cert_hash));
        let client_config = rustls::ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(verifier)
            .with_no_client_auth();
        let connector = tokio_rustls::TlsConnector::from(Arc::new(client_config));

        let tcp_stream = TcpStream::connect(&addr).await?;
        let server_ip = tcp_stream.peer_addr()?.ip();
        let tls_stream = connector
            .connect(ServerName::IpAddress(server_ip.into()), tcp_stream)
            .await?;

        Ok(tokio_util::codec::Framed::new(
            tls_stream,
            GsbMessageCodec::default(),
        ))
    }
}
1102
1103#[cfg(feature = "tls")]
1104pub use tls::*;
1105
1106#[cfg(unix)]
1107mod unix {
1108
1109    use super::*;
1110    use std::task::Poll;
1111
1112    pub type UnixTransport =
1113        tokio_util::codec::Framed<tokio::net::UnixStream, ya_sb_proto::codec::GsbMessageCodec>;
1114
1115    pub async fn unix<P>(path: P) -> Result<UnixTransport, std::io::Error>
1116    where
1117        P: AsRef<std::path::Path>,
1118    {
1119        let s = tokio::net::UnixStream::connect(path).await?;
1120        Ok(tokio_util::codec::Framed::new(
1121            s,
1122            ya_sb_proto::codec::GsbMessageCodec::default(),
1123        ))
1124    }
1125
1126    /// This trait exists to annotate the return type of Transport::inner()
1127    trait ITransport:
1128        Sink<GsbMessage, Error = ProtocolError>
1129        + Stream<Item = Result<GsbMessage, ProtocolError>>
1130        + Unpin
1131        + 'static
1132    {
1133    }
1134
1135    impl ITransport for TcpTransport {}
1136    impl ITransport for UnixTransport {}
1137
1138    pub enum Transport {
1139        Tcp(TcpTransport),
1140        Unix(UnixTransport),
1141    }
1142
1143    impl Transport {
1144        fn inner(self: Pin<&mut Self>) -> Pin<&mut (dyn ITransport)> {
1145            match self.get_mut() {
1146                Transport::Tcp(tcp_transport) => Pin::new(tcp_transport),
1147                Transport::Unix(unix_transport) => Pin::new(unix_transport),
1148            }
1149        }
1150    }
1151
1152    impl Sink<GsbMessage> for Transport {
1153        type Error = ProtocolError;
1154
1155        fn poll_ready(
1156            self: Pin<&mut Self>,
1157            cx: &mut core::task::Context<'_>,
1158        ) -> Poll<Result<(), Self::Error>> {
1159            self.inner().poll_ready(cx)
1160        }
1161
1162        fn start_send(self: Pin<&mut Self>, item: GsbMessage) -> Result<(), Self::Error> {
1163            self.inner().start_send(item)
1164        }
1165
1166        fn poll_flush(
1167            self: Pin<&mut Self>,
1168            cx: &mut core::task::Context<'_>,
1169        ) -> Poll<Result<(), Self::Error>> {
1170            self.inner().poll_flush(cx)
1171        }
1172
1173        fn poll_close(
1174            self: Pin<&mut Self>,
1175            cx: &mut core::task::Context<'_>,
1176        ) -> Poll<Result<(), Self::Error>> {
1177            self.inner().poll_close(cx)
1178        }
1179    }
1180
1181    impl Stream for Transport {
1182        type Item = Result<GsbMessage, ProtocolError>;
1183
1184        fn poll_next(
1185            self: Pin<&mut Self>,
1186            cx: &mut core::task::Context<'_>,
1187        ) -> Poll<Option<Self::Item>> {
1188            self.inner().poll_next(cx)
1189        }
1190    }
1191
1192    impl Unpin for Transport {}
1193
1194    pub async fn transport(addr: ya_sb_proto::GsbAddr) -> Result<Transport, std::io::Error> {
1195        match addr {
1196            ya_sb_proto::GsbAddr::Tcp(addr) => Ok(Transport::Tcp(tcp(addr).await?)),
1197            ya_sb_proto::GsbAddr::Unix(path) => Ok(Transport::Unix(unix(path).await?)),
1198        }
1199    }
1200}
1201
1202#[cfg(unix)]
1203pub use unix::*;
1204
/// On non-Unix platforms only TCP is available, so the transport is plain `TcpTransport`.
#[cfg(not(unix))]
pub type Transport = TcpTransport;
1207
/// Opens the GSB transport described by `addr` on platforms without Unix domain sockets.
///
/// Only TCP addresses can be opened here.
///
/// # Errors
/// Returns the I/O error from the TCP connect. Returns an error of kind
/// `std::io::ErrorKind::Unsupported` when `addr` names a Unix socket; this is reported
/// instead of panicking, because the address usually comes from configuration.
#[cfg(not(unix))]
pub async fn transport(addr: ya_sb_proto::GsbAddr) -> Result<TcpTransport, std::io::Error> {
    match addr {
        ya_sb_proto::GsbAddr::Tcp(addr) => tcp(addr).await,
        ya_sb_proto::GsbAddr::Unix(_) => Err(std::io::Error::new(
            std::io::ErrorKind::Unsupported,
            "Unix sockets not supported on this OS",
        )),
    }
}