1use actix::prelude::*;
2use futures::{
3 channel::{mpsc, oneshot},
4 prelude::*,
5 stream::SplitSink,
6};
7use semver::Version;
8use std::{
9 collections::{HashMap, VecDeque},
10 convert::TryInto,
11 pin::Pin,
12};
13
14use ya_sb_proto::codec::{GsbMessage, ProtocolError};
15use ya_sb_proto::{
16 BroadcastReplyCode, BroadcastRequest, CallReply, CallReplyCode, CallReplyType, CallRequest,
17 RegisterReplyCode, RegisterRequest, SubscribeReplyCode, SubscribeRequest, UnregisterReplyCode,
18 UnregisterRequest, UnsubscribeReplyCode, UnsubscribeRequest,
19};
20use ya_sb_util::writer::*;
21
22use crate::local_router::router;
23use crate::Error;
24use crate::{ResponseChunk, RpcRawCall, RpcRawStreamCall};
25
26fn gen_id() -> u64 {
27 use rand::Rng;
28
29 let mut rng = rand::thread_rng();
30
31 rng.gen::<u64>() & 0x001f_ffff_ffff_ffffu64
32}
33
34#[derive(Default, Clone)]
35#[non_exhaustive]
36pub struct ClientInfo {
37 pub name: String,
38 pub version: Option<Version>,
39 pub instance_id: Vec<u8>,
40}
41
42impl ClientInfo {
43 pub fn new(name: impl ToString) -> Self {
44 ClientInfo {
45 name: name.to_string(),
46 version: Some(Version::parse(env!("CARGO_PKG_VERSION")).unwrap()),
47 instance_id: uuid::Uuid::new_v4().as_bytes().to_vec(),
48 }
49 }
50}
51
52pub trait CallRequestHandler {
53 type Reply: Stream<Item = Result<ResponseChunk, Error>> + Unpin;
54
55 fn do_call(
56 &mut self,
57 request_id: String,
58 caller: String,
59 address: String,
60 data: Vec<u8>,
61 no_reply: bool,
62 ) -> Self::Reply;
63
64 fn handle_event(&mut self, caller: String, topic: String, data: Vec<u8>) {
65 log::warn!("unhandled gsb event from: {}, to: {}", caller, topic,);
66 log::trace!(
67 "unhandled gsb event data: {:?}",
68 String::from_utf8_lossy(data.as_ref())
69 )
70 }
71
72 fn on_disconnect(&mut self) {}
73}
74
75impl ResponseChunk {
76 #[inline]
77 fn reply_type(&self) -> CallReplyType {
78 match self {
79 ResponseChunk::Full(_) => CallReplyType::Full,
80 ResponseChunk::Part(_) => CallReplyType::Partial,
81 }
82 }
83
84 #[inline]
85 fn into_vec(self) -> Vec<u8> {
86 match self {
87 ResponseChunk::Full(v) => v,
88 ResponseChunk::Part(v) => v,
89 }
90 }
91}
92
/// Call handler that forwards incoming bus calls to the process-local router.
#[derive(Default)]
pub struct LocalRouterHandler {
    /// Callback taken and run (at most once) by `on_disconnect`.
    disconnect_h: Option<Box<dyn FnOnce()>>,
}

impl LocalRouterHandler {
    /// Creates a handler that runs `disconnect_fn` when the connection stops.
    pub fn new<F: FnOnce() + 'static>(disconnect_fn: F) -> Self {
        let callback: Box<dyn FnOnce()> = Box::new(disconnect_fn);
        Self {
            disconnect_h: Some(callback),
        }
    }
}
105
106impl CallRequestHandler for LocalRouterHandler {
107 type Reply = Pin<Box<dyn futures::Stream<Item = Result<ResponseChunk, Error>>>>;
108
109 fn do_call(
110 &mut self,
111 _request_id: String,
112 caller: String,
113 address: String,
114 data: Vec<u8>,
115 no_reply: bool,
116 ) -> Self::Reply {
117 router()
118 .lock()
119 .unwrap()
120 .forward_bytes_local(&address, &caller, data.as_ref(), no_reply)
121 .boxed_local()
122 }
123
124 fn on_disconnect(&mut self) {
125 if let Some(f) = self.disconnect_h.take() {
126 f()
127 };
128 }
129}
130
131impl<
132 R: futures::Stream<Item = Result<ResponseChunk, Error>> + Unpin,
133 F: FnMut(String, String, String, Vec<u8>) -> R,
134 > CallRequestHandler for F
135{
136 type Reply = R;
137
138 fn do_call(
139 &mut self,
140 request_id: String,
141 caller: String,
142 address: String,
143 data: Vec<u8>,
144 _no_reply: bool,
145 ) -> Self::Reply {
146 self(request_id, caller, address, data)
147 }
148}
149
150impl<
151 R: futures::Stream<Item = Result<ResponseChunk, Error>> + Unpin,
152 F1: FnMut(String, String, String, Vec<u8>) -> R,
153 F2: FnMut(String, String, Vec<u8>),
154 > CallRequestHandler for (F1, F2)
155{
156 type Reply = R;
157
158 fn do_call(
159 &mut self,
160 request_id: String,
161 caller: String,
162 address: String,
163 data: Vec<u8>,
164 _no_reply: bool,
165 ) -> Self::Reply {
166 (self.0)(request_id, caller, address, data)
167 }
168
169 fn handle_event(&mut self, caller: String, topic: String, data: Vec<u8>) {
170 (self.1)(caller, topic, data)
171 }
172}
173
/// Buffered writer half of the bus transport, owned by the connection actor.
type TransportWriter<W> = SinkWrite<GsbMessage, W>;
/// FIFO of callers awaiting a command reply: waiters are pushed at the back
/// when a command is sent and popped from the front when a reply arrives.
type ReplyQueue = VecDeque<oneshot::Sender<Result<(), Error>>>;

/// Actor state for one client connection to the service bus.
struct Connection<W, H>
where
    W: Sink<GsbMessage, Error = ProtocolError> + Unpin,
    H: CallRequestHandler,
{
    // Outgoing frames are queued through this writer.
    writer: TransportWriter<W>,
    // One reply FIFO per command kind; replies of a kind are matched in order.
    register_reply: ReplyQueue,
    unregister_reply: ReplyQueue,
    subscribe_reply: ReplyQueue,
    unsubscribe_reply: ReplyQueue,
    // Outstanding outgoing calls keyed by request id; an entry is removed once
    // a `Full` reply for it has been processed.
    call_reply: HashMap<String, mpsc::Sender<Result<ResponseChunk, Error>>>,
    broadcast_reply: ReplyQueue,
    // Serves calls and events arriving from the bus.
    handler: H,
    // Identity sent to the server in the initial `Hello`.
    client_info: ClientInfo,
    // The server's `Hello`, recorded the first time one is received.
    server_info: Option<ya_sb_proto::Hello>,
}

impl<W, H> Unpin for Connection<W, H>
where
    W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
    H: CallRequestHandler + 'static,
{
}
200
201fn handle_reply<Ctx: ActorContext, F: FnOnce() -> Result<(), Error>>(
202 cmd_type: &str,
203 queue: &mut ReplyQueue,
204 ctx: &mut Ctx,
205 reply_msg: F,
206) {
207 if let Some(r) = queue.pop_front() {
208 let _ = r.send(reply_msg());
209 } else {
210 log::error!("unmatched {} reply", cmd_type);
211 ctx.stop()
212 }
213}
214
215impl<W, H> EmptyBufferHandler for Connection<W, H>
216where
217 W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
218 H: CallRequestHandler + 'static,
219{
220}
221
222impl<W, H> Connection<W, H>
223where
224 W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
225 H: CallRequestHandler + 'static,
226{
227 fn new(client_info: ClientInfo, w: W, handler: H, ctx: &mut <Self as Actor>::Context) -> Self {
228 Connection {
229 writer: SinkWrite::new(w, ctx),
230 register_reply: Default::default(),
231 unregister_reply: Default::default(),
232 subscribe_reply: Default::default(),
233 unsubscribe_reply: Default::default(),
234 call_reply: Default::default(),
235 broadcast_reply: Default::default(),
236 handler,
237 client_info,
238 server_info: Default::default(),
239 }
240 }
241
242 fn handle_unregister_reply(
243 &mut self,
244 code: UnregisterReplyCode,
245 ctx: &mut <Self as Actor>::Context,
246 ) {
247 handle_reply(
248 "unregister",
249 &mut self.unregister_reply,
250 ctx,
251 || match code {
252 UnregisterReplyCode::UnregisteredOk => Ok(()),
253 UnregisterReplyCode::NotRegistered => {
254 Err(Error::GsbBadRequest("unregister".to_string()))
255 }
256 },
257 )
258 }
259
260 fn handle_broadcast_reply(
261 &mut self,
262 code: BroadcastReplyCode,
263 msg: String,
264 ctx: &mut <Self as Actor>::Context,
265 ) {
266 handle_reply("broadcast", &mut self.broadcast_reply, ctx, || match code {
267 BroadcastReplyCode::BroadcastOk => Ok(()),
268 BroadcastReplyCode::BroadcastBadRequest => Err(Error::GsbBadRequest(msg)),
269 })
270 }
271
272 fn handle_register_reply(
273 &mut self,
274 code: RegisterReplyCode,
275 msg: String,
276 ctx: &mut <Self as Actor>::Context,
277 ) {
278 handle_reply("register", &mut self.register_reply, ctx, || match code {
279 RegisterReplyCode::RegisteredOk => Ok(()),
280 RegisterReplyCode::RegisterBadRequest => {
281 log::warn!("bad request: {}", msg);
282 Err(Error::GsbBadRequest(msg))
283 }
284 RegisterReplyCode::RegisterConflict => {
285 log::warn!("already registered: {}", msg);
286 Err(Error::GsbAlreadyRegistered(msg))
287 }
288 })
289 }
290
291 fn handle_subscribe_reply(
292 &mut self,
293 code: SubscribeReplyCode,
294 msg: String,
295 ctx: &mut <Self as Actor>::Context,
296 ) {
297 handle_reply("subscribe", &mut self.subscribe_reply, ctx, || match code {
298 SubscribeReplyCode::SubscribedOk => Ok(()),
299 SubscribeReplyCode::SubscribeBadRequest => {
300 log::warn!("bad request: {}", msg);
301 Err(Error::GsbBadRequest(msg))
302 }
303 })
304 }
305
306 fn handle_unsubscribe_reply(
307 &mut self,
308 code: UnsubscribeReplyCode,
309 ctx: &mut <Self as Actor>::Context,
310 ) {
311 handle_reply(
312 "unsubscribe",
313 &mut self.unsubscribe_reply,
314 ctx,
315 || match code {
316 UnsubscribeReplyCode::UnsubscribedOk => Ok(()),
317 UnsubscribeReplyCode::NotSubscribed => {
318 Err(Error::GsbBadRequest("unsubscribed".to_string()))
319 }
320 },
321 )
322 }
323
324 fn handle_call_request(
325 &mut self,
326 request_id: String,
327 caller: String,
328 address: String,
329 data: Vec<u8>,
330 ctx: &mut <Self as Actor>::Context,
331 ) {
332 log::trace!(
333 "handling rpc call from = {}, to = {}, request_id={}, ",
334 caller,
335 address,
336 request_id
337 );
338 let eos_request_id = request_id.clone();
339 let do_call = self
340 .handler
341 .do_call(request_id.clone(), caller, address, data, false)
342 .into_actor(self)
343 .fold(false, move |_got_eos, r, act: &mut Self, _ctx| {
344 let request_id = request_id.clone();
345 let (got_eos, reply) = match r {
346 Ok(data) => {
347 let code = CallReplyCode::CallReplyOk as i32;
348 let reply_type = data.reply_type() as i32;
349 (
350 reply_type == 0,
351 CallReply {
352 request_id,
353 code,
354 reply_type,
355 data: data.into_vec(),
356 },
357 )
358 }
359 Err(e) => {
360 let code = CallReplyCode::ServiceFailure as i32;
361 let reply_type = Default::default();
362 let data = format!("{}", e).into_bytes();
363 (
364 true,
365 CallReply {
366 request_id,
367 code,
368 reply_type,
369 data,
370 },
371 )
372 }
373 };
374 let _ = act.writer.write(GsbMessage::CallReply(reply));
376 fut::ready(got_eos)
377 })
378 .then(|got_eos, act, _ctx| {
379 if !got_eos {
380 let _ = act.writer.write(GsbMessage::CallReply(CallReply {
381 request_id: eos_request_id,
382 code: 0,
383 reply_type: 0,
384 data: Default::default(),
385 }));
386 }
387 fut::ready(())
388 });
389 ctx.spawn(do_call);
391 }
392
393 fn handle_push_request(
394 &mut self,
395 request_id: String,
396 caller: String,
397 address: String,
398 data: Vec<u8>,
399 ctx: &mut <Self as Actor>::Context,
400 ) {
401 log::trace!(
402 "handling push call from = {}, to = {}, request_id={}, ",
403 caller,
404 address,
405 request_id
406 );
407
408 self.handler
409 .do_call(request_id, caller, address, data, true)
410 .into_actor(self)
411 .fold((), move |_, _, _, _| fut::ready(()))
412 .spawn(ctx);
413 }
414
415 fn handle_reply(
416 &mut self,
417 request_id: String,
418 code: i32,
419 reply_type: i32,
420 data: Vec<u8>,
421 ctx: &mut <Self as Actor>::Context,
422 ) -> Result<(), Box<dyn std::error::Error>> {
423 log::trace!(
424 "handling reply for request_id={}, code={}, reply_type={}",
425 request_id,
426 code,
427 reply_type
428 );
429
430 let chunk = if reply_type == CallReplyType::Partial as i32 {
431 ResponseChunk::Part(data)
432 } else {
433 ResponseChunk::Full(data)
434 };
435
436 let is_full = chunk.is_full();
437
438 if let Some(r) = self.call_reply.get_mut(&request_id) {
439 let mut r = (*r).clone();
441 let code: CallReplyCode = code.try_into()?;
442 let item = match code {
443 CallReplyCode::CallReplyOk => Ok(chunk),
444 CallReplyCode::CallReplyBadRequest => {
445 Err(Error::GsbBadRequest(String::from_utf8(chunk.into_bytes())?))
446 }
447 CallReplyCode::ServiceFailure => {
448 Err(Error::GsbFailure(String::from_utf8(chunk.into_bytes())?))
449 }
450 };
451 ctx.wait(
452 async move {
453 let s = r.send(item);
454 s.await
455 .unwrap_or_else(|e| log::warn!("undelivered reply: {}", e))
456 }
457 .into_actor(self),
458 );
459 } else {
460 log::debug!("unmatched call reply");
461 }
462
463 if is_full {
464 let _ = self.call_reply.remove(&request_id);
465 }
466
467 Ok(())
468 }
469}
470
471impl<W, H> Actor for Connection<W, H>
472where
473 W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
474 H: CallRequestHandler + 'static,
475{
476 type Context = Context<Self>;
477
478 fn started(&mut self, _ctx: &mut Self::Context) {
479 log::info!("started connection to gsb");
480 let hello: ya_sb_proto::Hello = ya_sb_proto::Hello {
481 name: self.client_info.name.clone(),
482 version: self
483 .client_info
484 .version
485 .as_ref()
486 .map(|v| v.to_string())
487 .unwrap_or_default(),
488 instance_id: self.client_info.instance_id.clone(),
489 };
490
491 let _ = self.writer.write(GsbMessage::Hello(hello));
492 }
493
494 fn stopped(&mut self, _ctx: &mut Self::Context) {
495 log::info!("stopped connection to gsb");
496 self.handler.on_disconnect();
497 }
498}
499
500fn register_reply_code(code: i32) -> Option<RegisterReplyCode> {
501 Some(match code {
502 0 => RegisterReplyCode::RegisteredOk,
503 400 => RegisterReplyCode::RegisterBadRequest,
504 409 => RegisterReplyCode::RegisterConflict,
505 _ => return None,
506 })
507}
508
509fn unregister_reply_code(code: i32) -> Option<UnregisterReplyCode> {
510 Some(match code {
511 0 => UnregisterReplyCode::UnregisteredOk,
512 404 => UnregisterReplyCode::NotRegistered,
513 _ => return None,
514 })
515}
516
517fn subscribe_reply_code(code: i32) -> Option<SubscribeReplyCode> {
518 Some(match code {
519 0 => SubscribeReplyCode::SubscribedOk,
520 400 => SubscribeReplyCode::SubscribeBadRequest,
521 _ => return None,
522 })
523}
524
525fn unsubscribe_reply_code(code: i32) -> Option<UnsubscribeReplyCode> {
526 Some(match code {
527 0 => UnsubscribeReplyCode::UnsubscribedOk,
528 404 => UnsubscribeReplyCode::NotSubscribed,
529 _ => return None,
530 })
531}
532
533fn broadcast_reply_code(code: i32) -> Option<BroadcastReplyCode> {
534 Some(match code {
535 0 => BroadcastReplyCode::BroadcastOk,
536 400 => BroadcastReplyCode::BroadcastBadRequest,
537 _ => return None,
538 })
539}
540
541impl<W, H> StreamHandler<Result<GsbMessage, ProtocolError>> for Connection<W, H>
542where
543 W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
544 H: CallRequestHandler + 'static,
545{
546 fn handle(&mut self, item: Result<GsbMessage, ProtocolError>, ctx: &mut Self::Context) {
547 if let Err(e) = item.as_ref() {
548 log::error!("protocol error {}", e);
549 ctx.stop();
550 return;
551 }
552
553 match item.unwrap() {
554 GsbMessage::RegisterReply(r) => {
555 if let Some(code) = register_reply_code(r.code) {
556 self.handle_register_reply(code, r.message, ctx)
557 } else {
558 log::error!("invalid reply code {}", r.code);
559 ctx.stop();
560 }
561 }
562 GsbMessage::UnregisterReply(r) => {
563 if let Some(code) = unregister_reply_code(r.code) {
564 self.handle_unregister_reply(code, ctx)
565 } else {
566 log::error!("invalid unregister reply code {}", r.code);
567 ctx.stop();
568 }
569 }
570 GsbMessage::SubscribeReply(r) => {
571 if let Some(code) = subscribe_reply_code(r.code) {
572 self.handle_subscribe_reply(code, r.message, ctx)
573 } else {
574 log::error!("invalid reply code {}", r.code);
575 ctx.stop();
576 }
577 }
578 GsbMessage::UnsubscribeReply(r) => {
579 if let Some(code) = unsubscribe_reply_code(r.code) {
580 self.handle_unsubscribe_reply(code, ctx)
581 } else {
582 log::error!("invalid unsubscribe reply code {}", r.code);
583 ctx.stop();
584 }
585 }
586 GsbMessage::BroadcastReply(r) => {
587 if let Some(code) = broadcast_reply_code(r.code) {
588 self.handle_broadcast_reply(code, r.message, ctx)
589 } else {
590 log::error!("invalid broadcast reply code {}", r.code);
591 ctx.stop();
592 }
593 }
594 GsbMessage::CallRequest(r) => {
595 if r.no_reply {
596 self.handle_push_request(r.request_id, r.caller, r.address, r.data, ctx)
597 } else {
598 self.handle_call_request(r.request_id, r.caller, r.address, r.data, ctx)
599 }
600 }
601 GsbMessage::CallReply(r) => {
602 if let Err(e) = self.handle_reply(r.request_id, r.code, r.reply_type, r.data, ctx) {
603 log::error!("error on call reply processing: {}", e);
604 ctx.stop();
605 }
606 }
607 GsbMessage::BroadcastRequest(r) => {
608 self.handler.handle_event(r.caller, r.topic, r.data);
609 }
610 GsbMessage::Ping(_) => {
611 if self.writer.write(GsbMessage::pong()).is_some() {
612 log::error!("error sending pong");
613 ctx.stop();
614 }
615 }
616 GsbMessage::Hello(h) => {
617 log::debug!("connected with server: {}/{}", h.name, h.version);
618 if self.server_info.is_some() {
619 log::error!("invalid packet: {:?}", h);
620 ctx.stop();
621 } else {
622 self.server_info = Some(h);
623 }
624 }
625 m => {
626 log::error!("unexpected gsb message: {:?}", m);
627 ctx.stop();
628 }
629 }
630 }
631}
632
633impl<W, H> io::WriteHandler<ProtocolError> for Connection<W, H>
634where
635 W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
636 H: CallRequestHandler + 'static,
637{
638 fn error(&mut self, err: ProtocolError, _ctx: &mut Self::Context) -> Running {
639 log::error!("protocol error: {}", err);
640 Running::Stop
641 }
642}
643
644impl<W, H> Handler<RpcRawCall> for Connection<W, H>
645where
646 W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
647 H: CallRequestHandler + 'static,
648{
649 type Result = ActorResponse<Self, Result<Vec<u8>, Error>>;
650
651 fn handle(&mut self, msg: RpcRawCall, _ctx: &mut Self::Context) -> Self::Result {
652 let request_id = format!("{}", gen_id());
653 let caller = msg.caller;
654 let address = msg.addr;
655 let data = msg.body;
656 let no_reply = msg.no_reply;
657
658 let rx = if no_reply {
659 None
660 } else {
661 let (tx, rx) = mpsc::channel(1);
662 let _ = self.call_reply.insert(request_id.clone(), tx);
663 Some(rx)
664 };
665
666 log::trace!("handling caller (rpc): {}, addr:{}", caller, address);
667 let _r = self.writer.write(GsbMessage::CallRequest(CallRequest {
668 request_id,
669 caller,
670 address,
671 data,
672 no_reply,
673 }));
674
675 match rx {
676 Some(mut rx) => {
677 let fetch_response = async move {
678 match futures::StreamExt::next(&mut rx).await {
679 Some(Ok(ResponseChunk::Full(data))) => Ok(data),
680 Some(Err(e)) => Err(e),
681 Some(Ok(ResponseChunk::Part(_))) => {
682 Err(Error::GsbFailure("streaming response".to_string()))
683 }
684 None => Err(Error::GsbFailure("unexpected EOS".to_string())),
685 }
686 };
687 ActorResponse::r#async(fetch_response.into_actor(self))
688 }
689 None => ActorResponse::reply(Ok(Vec::new())),
690 }
691 }
692}
693
694impl<W, H> Handler<RpcRawStreamCall> for Connection<W, H>
695where
696 W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
697 H: CallRequestHandler + 'static,
698{
699 type Result = ActorResponse<Self, Result<(), Error>>;
700
701 fn handle(&mut self, msg: RpcRawStreamCall, _ctx: &mut Self::Context) -> Self::Result {
702 let request_id = format!("{}", gen_id());
703 let rx = msg.reply;
704 let _ = self.call_reply.insert(request_id.clone(), rx);
705 let caller = msg.caller;
706 let address = msg.addr;
707 let data = msg.body;
708 log::trace!("handling caller (stream): {}, addr:{}", caller, address);
709 let _r = self.writer.write(GsbMessage::CallRequest(CallRequest {
710 request_id,
711 caller,
712 address,
713 data,
714 no_reply: false,
715 }));
716 ActorResponse::reply(Ok(()))
717 }
718}
719
720fn send_cmd_async<A: Actor, W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static>(
721 writer: &mut TransportWriter<W>,
722 queue: &mut VecDeque<oneshot::Sender<Result<(), Error>>>,
723 msg: GsbMessage,
724) -> ActorResponse<A, Result<(), Error>> {
725 let (tx, rx) = oneshot::channel();
726 queue.push_back(tx);
727
728 if writer.write(msg).is_some() {
729 ActorResponse::reply(Err(Error::GsbFailure("no connection".into())))
730 } else {
731 ActorResponse::r#async(fut::wrap_future(async move {
732 rx.await.map_err(|_| Error::Cancelled)??;
733 Ok(())
734 }))
735 }
736}
737
738struct Bind {
739 addr: String,
740}
741
742impl Message for Bind {
743 type Result = Result<(), Error>;
744}
745
746impl<W, H> Handler<Bind> for Connection<W, H>
747where
748 W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
749 H: CallRequestHandler + 'static,
750{
751 type Result = ActorResponse<Self, Result<(), Error>>;
752
753 fn handle(&mut self, msg: Bind, _ctx: &mut Self::Context) -> Self::Result {
754 let service_id = msg.addr;
755 send_cmd_async(
756 &mut self.writer,
757 &mut self.register_reply,
758 GsbMessage::RegisterRequest(RegisterRequest { service_id }),
759 )
760 }
761}
762
763struct Unbind {
764 addr: String,
765}
766
767impl Message for Unbind {
768 type Result = Result<(), Error>;
769}
770
771impl<W, H> Handler<Unbind> for Connection<W, H>
772where
773 W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
774 H: CallRequestHandler + 'static,
775{
776 type Result = ActorResponse<Self, Result<(), Error>>;
777
778 fn handle(&mut self, msg: Unbind, _ctx: &mut Self::Context) -> Self::Result {
779 let service_id = msg.addr;
780 send_cmd_async(
781 &mut self.writer,
782 &mut self.unregister_reply,
783 GsbMessage::UnregisterRequest(UnregisterRequest { service_id }),
784 )
785 }
786}
787
788struct Subscribe {
789 topic: String,
790}
791
792impl Message for Subscribe {
793 type Result = Result<(), Error>;
794}
795
796impl<W, H> Handler<Subscribe> for Connection<W, H>
797where
798 W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
799 H: CallRequestHandler + 'static,
800{
801 type Result = ActorResponse<Self, Result<(), Error>>;
802
803 fn handle(&mut self, msg: Subscribe, _ctx: &mut Self::Context) -> Self::Result {
804 let topic = msg.topic;
805 send_cmd_async(
806 &mut self.writer,
807 &mut self.subscribe_reply,
808 GsbMessage::SubscribeRequest(SubscribeRequest { topic }),
809 )
810 }
811}
812
813struct Unsubscribe {
814 topic: String,
815}
816
817impl Message for Unsubscribe {
818 type Result = Result<(), Error>;
819}
820
821impl<W, H> Handler<Unsubscribe> for Connection<W, H>
822where
823 W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
824 H: CallRequestHandler + 'static,
825{
826 type Result = ActorResponse<Self, Result<(), Error>>;
827
828 fn handle(&mut self, msg: Unsubscribe, _ctx: &mut Self::Context) -> Self::Result {
829 let topic = msg.topic;
830 send_cmd_async(
831 &mut self.writer,
832 &mut self.unsubscribe_reply,
833 GsbMessage::UnsubscribeRequest(UnsubscribeRequest { topic }),
834 )
835 }
836}
837
838pub struct BcastCall {
839 pub caller: String,
840 pub topic: String,
841 pub body: Vec<u8>,
842}
843
844impl Message for BcastCall {
845 type Result = Result<(), Error>;
846}
847
848impl<W, H> Handler<BcastCall> for Connection<W, H>
849where
850 W: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
851 H: CallRequestHandler + 'static,
852{
853 type Result = ActorResponse<Self, Result<(), Error>>;
854
855 fn handle(&mut self, msg: BcastCall, _ctx: &mut Self::Context) -> Self::Result {
856 let caller = msg.caller;
857 let topic = msg.topic;
858 let data = msg.body;
859 send_cmd_async(
860 &mut self.writer,
861 &mut self.broadcast_reply,
862 GsbMessage::BroadcastRequest(BroadcastRequest {
863 caller,
864 topic,
865 data,
866 }),
867 )
868 }
869}
870
871pub struct ConnectionRef<
872 Transport: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
873 H: CallRequestHandler + 'static,
874>(Addr<Connection<SplitSink<Transport, GsbMessage>, H>>);
875
876impl<
877 Transport: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
878 H: CallRequestHandler + 'static,
879 > Unpin for ConnectionRef<Transport, H>
880{
881}
882
883impl<
884 Transport: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
885 H: CallRequestHandler + 'static,
886 > Clone for ConnectionRef<Transport, H>
887{
888 fn clone(&self) -> Self {
889 ConnectionRef(self.0.clone())
890 }
891}
892
893impl<
894 Transport: Sink<GsbMessage, Error = ProtocolError> + Unpin + 'static,
895 H: CallRequestHandler + Unpin + 'static,
896 > ConnectionRef<Transport, H>
897{
898 pub fn bind(
899 &self,
900 addr: impl Into<String>,
901 ) -> impl Future<Output = Result<(), Error>> + 'static {
902 let addr = addr.into();
903 log::trace!("Binding remote service '{}'", addr);
904 self.0.send(Bind { addr: addr.clone() }).then(|v| async {
905 log::trace!("send bind result: {:?}", v);
906 v.map_err(|e| Error::from_addr(addr, e))?
907 })
908 }
909
910 pub fn unbind(
911 &self,
912 addr: impl Into<String>,
913 ) -> impl Future<Output = Result<(), Error>> + 'static {
914 let addr = addr.into();
915 self.0.send(Unbind { addr: addr.clone() }).then(|v| async {
916 log::trace!("send unbind result: {:?}", v);
917 v.map_err(|e| Error::from_addr(addr, e))?
918 })
919 }
920
921 pub fn subscribe(
922 &self,
923 topic: impl Into<String>,
924 ) -> impl Future<Output = Result<(), Error>> + 'static {
925 let topic = topic.into();
926 let fut = self.0.send(Subscribe {
927 topic: topic.clone(),
928 });
929 async move {
930 fut.await
931 .map_err(|e| Error::from_addr(format!("subscribing {}", topic), e))?
932 }
933 }
934
935 pub fn unsubscribe(
936 &self,
937 topic: impl Into<String>,
938 ) -> impl Future<Output = Result<(), Error>> + 'static {
939 let topic = topic.into();
940 let fut = self.0.send(Unsubscribe {
941 topic: topic.clone(),
942 });
943 async move {
944 fut.await
945 .map_err(|e| Error::from_addr(format!("unsubscribing {}", topic), e))?
946 }
947 }
948
949 pub fn broadcast(
950 &self,
951 caller: impl Into<String>,
952 topic: impl Into<String>,
953 body: Vec<u8>,
954 ) -> impl Future<Output = Result<(), Error>> + 'static {
955 let topic = topic.into();
956 let fut = self.0.send(BcastCall {
957 caller: caller.into(),
958 topic: topic.clone(),
959 body,
960 });
961 async move {
962 fut.await
963 .map_err(|e| Error::from_addr(format!("broadcasting {}", topic), e))?
964 }
965 }
966
967 pub fn call(
968 &self,
969 caller: impl Into<String>,
970 addr: impl Into<String>,
971 body: impl Into<Vec<u8>>,
972 no_reply: bool,
973 ) -> impl Future<Output = Result<Vec<u8>, Error>> {
974 let addr = addr.into();
975 self.0
976 .send(RpcRawCall {
977 caller: caller.into(),
978 addr: addr.clone(),
979 body: body.into(),
980 no_reply,
981 })
982 .then(|v| async { v.map_err(|e| Error::from_addr(addr, e))? })
983 }
984
985 pub fn call_streaming(
986 &self,
987 caller: impl Into<String>,
988 addr: impl Into<String>,
989 body: impl Into<Vec<u8>>,
990 ) -> impl Stream<Item = Result<ResponseChunk, Error>> {
991 let addr = addr.into();
992 let (tx, rx) = futures::channel::mpsc::channel(16);
993
994 let args = RpcRawStreamCall {
995 caller: caller.into(),
996 addr: addr.clone(),
997 body: body.into(),
998 reply: tx.clone(),
999 };
1000 let connection = self.0.clone();
1001 let _ = Arbiter::current().spawn(async move {
1002 let mut tx = tx;
1003 match connection.send(args).await {
1004 Ok(Ok(())) => (),
1005 Ok(Err(e)) => {
1006 tx.send(Err(e))
1007 .await
1008 .unwrap_or_else(|e| log::error!("fail: {}", e));
1009 }
1010 Err(e) => {
1011 tx.send(Err(Error::from_addr(addr, e)))
1012 .await
1013 .unwrap_or_else(|e| log::error!("fail: {}", e));
1014 }
1015 }
1016 });
1017 rx
1018 }
1019
1020 pub fn connected(&self) -> bool {
1021 self.0.connected()
1022 }
1023}
1024
1025pub fn connect<Transport, H>(
1026 client_info: ClientInfo,
1027 transport: Transport,
1028) -> ConnectionRef<Transport, H>
1029where
1030 Transport: Sink<GsbMessage, Error = ProtocolError>
1031 + Stream<Item = Result<GsbMessage, ProtocolError>>
1032 + Unpin
1033 + 'static,
1034 H: CallRequestHandler + 'static + Default + Unpin,
1035{
1036 connect_with_handler(client_info, transport, Default::default())
1037}
1038
1039pub fn connect_with_handler<Transport, H>(
1040 client_info: ClientInfo,
1041 transport: Transport,
1042 handler: H,
1043) -> ConnectionRef<Transport, H>
1044where
1045 Transport: Sink<GsbMessage, Error = ProtocolError>
1046 + Stream<Item = Result<GsbMessage, ProtocolError>>
1047 + Unpin
1048 + 'static,
1049 H: CallRequestHandler + 'static,
1050{
1051 let (split_sink, split_stream) = transport.split();
1052 ConnectionRef(Connection::create(move |ctx| {
1053 let _h = Connection::add_stream(split_stream, ctx);
1054 Connection::new(client_info, split_sink, handler, ctx)
1055 }))
1056}
1057
1058pub type TcpTransport =
1059 tokio_util::codec::Framed<tokio::net::TcpStream, ya_sb_proto::codec::GsbMessageCodec>;
1060
1061pub async fn tcp(addr: impl tokio::net::ToSocketAddrs) -> Result<TcpTransport, std::io::Error> {
1062 let s = tokio::net::TcpStream::connect(addr).await?;
1063 Ok(tokio_util::codec::Framed::new(
1064 s,
1065 ya_sb_proto::codec::GsbMessageCodec::default(),
1066 ))
1067}
1068
#[cfg(feature = "tls")]
mod tls {
    use rustls::pki_types::ServerName;
    use std::sync::Arc;
    use tokio::net::TcpStream;
    use tokio_rustls::client::TlsStream;
    use ya_sb_proto::codec::GsbMessageCodec;
    pub use ya_sb_util::tls::CertHash;
    use ya_sb_util::tls::HashVerifier;

    /// GSB message framing over a TLS-wrapped TCP stream.
    pub type TlsTransport = tokio_util::codec::Framed<TlsStream<TcpStream>, GsbMessageCodec>;

    /// Connects to the bus over TLS at `addr`.
    ///
    /// The server certificate is checked by `HashVerifier` against `cert_hash`
    /// (installed through rustls' `dangerous()` custom-verifier hook); the TLS
    /// server name is the peer's IP address. No client certificate is sent.
    pub async fn tls(
        addr: impl tokio::net::ToSocketAddrs,
        cert_hash: CertHash,
    ) -> Result<TlsTransport, std::io::Error> {
        let verifier = Arc::new(HashVerifier::new(cert_hash));
        let config = rustls::ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(verifier)
            .with_no_client_auth();
        let connector = tokio_rustls::TlsConnector::from(Arc::new(config));

        let sock = TcpStream::connect(&addr).await?;
        let peer_ip = sock.peer_addr()?.ip();
        let stream = connector
            .connect(ServerName::IpAddress(peer_ip.into()), sock)
            .await?;

        Ok(tokio_util::codec::Framed::new(
            stream,
            GsbMessageCodec::default(),
        ))
    }
}

#[cfg(feature = "tls")]
pub use tls::*;
1105
1106#[cfg(unix)]
1107mod unix {
1108
1109 use super::*;
1110 use std::task::Poll;
1111
1112 pub type UnixTransport =
1113 tokio_util::codec::Framed<tokio::net::UnixStream, ya_sb_proto::codec::GsbMessageCodec>;
1114
1115 pub async fn unix<P>(path: P) -> Result<UnixTransport, std::io::Error>
1116 where
1117 P: AsRef<std::path::Path>,
1118 {
1119 let s = tokio::net::UnixStream::connect(path).await?;
1120 Ok(tokio_util::codec::Framed::new(
1121 s,
1122 ya_sb_proto::codec::GsbMessageCodec::default(),
1123 ))
1124 }
1125
1126 trait ITransport:
1128 Sink<GsbMessage, Error = ProtocolError>
1129 + Stream<Item = Result<GsbMessage, ProtocolError>>
1130 + Unpin
1131 + 'static
1132 {
1133 }
1134
1135 impl ITransport for TcpTransport {}
1136 impl ITransport for UnixTransport {}
1137
1138 pub enum Transport {
1139 Tcp(TcpTransport),
1140 Unix(UnixTransport),
1141 }
1142
1143 impl Transport {
1144 fn inner(self: Pin<&mut Self>) -> Pin<&mut (dyn ITransport)> {
1145 match self.get_mut() {
1146 Transport::Tcp(tcp_transport) => Pin::new(tcp_transport),
1147 Transport::Unix(unix_transport) => Pin::new(unix_transport),
1148 }
1149 }
1150 }
1151
1152 impl Sink<GsbMessage> for Transport {
1153 type Error = ProtocolError;
1154
1155 fn poll_ready(
1156 self: Pin<&mut Self>,
1157 cx: &mut core::task::Context<'_>,
1158 ) -> Poll<Result<(), Self::Error>> {
1159 self.inner().poll_ready(cx)
1160 }
1161
1162 fn start_send(self: Pin<&mut Self>, item: GsbMessage) -> Result<(), Self::Error> {
1163 self.inner().start_send(item)
1164 }
1165
1166 fn poll_flush(
1167 self: Pin<&mut Self>,
1168 cx: &mut core::task::Context<'_>,
1169 ) -> Poll<Result<(), Self::Error>> {
1170 self.inner().poll_flush(cx)
1171 }
1172
1173 fn poll_close(
1174 self: Pin<&mut Self>,
1175 cx: &mut core::task::Context<'_>,
1176 ) -> Poll<Result<(), Self::Error>> {
1177 self.inner().poll_close(cx)
1178 }
1179 }
1180
1181 impl Stream for Transport {
1182 type Item = Result<GsbMessage, ProtocolError>;
1183
1184 fn poll_next(
1185 self: Pin<&mut Self>,
1186 cx: &mut core::task::Context<'_>,
1187 ) -> Poll<Option<Self::Item>> {
1188 self.inner().poll_next(cx)
1189 }
1190 }
1191
1192 impl Unpin for Transport {}
1193
1194 pub async fn transport(addr: ya_sb_proto::GsbAddr) -> Result<Transport, std::io::Error> {
1195 match addr {
1196 ya_sb_proto::GsbAddr::Tcp(addr) => Ok(Transport::Tcp(tcp(addr).await?)),
1197 ya_sb_proto::GsbAddr::Unix(path) => Ok(Transport::Unix(unix(path).await?)),
1198 }
1199 }
1200}
1201
1202#[cfg(unix)]
1203pub use unix::*;
1204
#[cfg(not(unix))]
pub type Transport = TcpTransport;

/// Opens a bus transport for `addr` on platforms without Unix sockets.
///
/// # Errors
/// Returns an `ErrorKind::Unsupported` I/O error for Unix-socket addresses
/// (instead of panicking, as this is a recoverable caller input), and
/// propagates any TCP connection error.
#[cfg(not(unix))]
pub async fn transport(addr: ya_sb_proto::GsbAddr) -> Result<TcpTransport, std::io::Error> {
    match addr {
        ya_sb_proto::GsbAddr::Tcp(addr) => tcp(addr).await,
        ya_sb_proto::GsbAddr::Unix(_) => Err(std::io::Error::new(
            std::io::ErrorKind::Unsupported,
            "Unix sockets not supported on this OS",
        )),
    }
}