1use crate::channel::TResult;
2use crate::connection::pool::ConnectionPools;
3use crate::connection::{ConnectionResult, FrameInput, FrameOutput};
4use crate::defragmentation::ResponseDefragmenter;
5use crate::errors::{CodecError, ConnectionError, HandlerError, TChannelError};
6use crate::fragmentation::RequestFragmenter;
7use crate::frames::TFrameStream;
8use crate::handler::{
9 HandlerResult, MessageArgsHandler, RequestHandler, RequestHandlerAdapter, RequestHandlerAsync,
10 RequestHandlerAsyncAdapter,
11};
12use crate::messages::args::{MessageArgs, MessageArgsResponse, ResponseCode};
13use crate::messages::Message;
14use futures::StreamExt;
15use futures::{future, TryStreamExt};
16use log::{debug, error};
17use std::collections::HashMap;
18use std::net::{SocketAddr, ToSocketAddrs};
19use std::sync::Arc;
20use tokio::sync::{Mutex, RwLock};
21
/// Shared, lockable handle to a handler registered for one endpoint.
///
/// The `Arc` lets a handler be cloned out of the registry so the registry lock can be
/// released before the handler runs; the `Mutex` is held for the duration of each call
/// into the handler, so calls to the same handler do not overlap.
type HandlerRef = Arc<Mutex<Box<dyn MessageArgsHandler>>>;

/// Sub-channel bound to a single service name.
///
/// It sends requests to peers over pooled connections and dispatches incoming requests
/// to the handler registered for the request's endpoint name (arg1).
#[derive(Debug, new)]
pub struct SubChannel {
    /// Service name handed to the `RequestFragmenter` when building outgoing request frames.
    service_name: String,
    /// Connection pools used to reach peer addresses.
    connection_pools: Arc<ConnectionPools>,
    /// Registered handlers keyed by endpoint name; starts empty (`#[new(default)]`).
    #[new(default)]
    handlers: RwLock<HashMap<String, HandlerRef>>,
}
34
35impl SubChannel {
36 pub(super) async fn send<REQ: Message, RES: Message, ADDR: ToSocketAddrs>(
37 &self,
38 request: REQ,
39 host: ADDR,
40 ) -> HandlerResult<RES> {
41 let (frames_in, frames_out) = self.create_frame_io(host).await?;
42 let response_res = self.send_internal(request, frames_in, &frames_out).await;
43 frames_out.close().await; match response_res {
45 Ok((code, response)) => match code {
46 ResponseCode::Ok => Ok(response),
47 ResponseCode::Error => Err(HandlerError::MessageError(response)),
48 },
49 Err(err) => Err(HandlerError::InternalError(err)),
50 }
51 }
52
53 async fn create_frame_io<ADDR: ToSocketAddrs>(
54 &self,
55 host: ADDR,
56 ) -> TResult<(FrameInput, FrameOutput)> {
57 let host = first_addr(host)?;
58 self.connect(host).await
59 }
60
61 pub(super) async fn send_internal<REQ: Message, RES: Message>(
62 &self,
63 request: REQ,
64 frames_in: FrameInput,
65 frames_out: &FrameOutput,
66 ) -> TResult<(ResponseCode, RES)> {
67 let frames = self.create_frames(request).await?;
68 send_frames(frames, frames_out).await?;
69 let response = ResponseDefragmenter::new(frames_in)
70 .read_response_msg()
71 .await;
72 frames_out.close().await; response
74 }
75
76 pub async fn register<REQ, RES, HANDLER>(
78 &self,
79 endpoint: impl AsRef<str>,
80 request_handler: HANDLER,
81 ) -> TResult<()>
82 where
83 REQ: Message + 'static,
84 RES: Message + 'static,
85 HANDLER: RequestHandler<REQ = REQ, RES = RES> + 'static,
86 {
87 let handler_adapter = RequestHandlerAdapter::new(request_handler);
88 self.register_handler(endpoint, Arc::new(Mutex::new(Box::new(handler_adapter))))
89 .await
90 }
91
92 pub async fn register_async<REQ, RES, HANDLER>(
94 &self,
95 endpoint: impl AsRef<str>,
96 request_handler: HANDLER,
97 ) -> TResult<()>
98 where
99 REQ: Message + 'static,
100 RES: Message + 'static,
101 HANDLER: RequestHandlerAsync<REQ = REQ, RES = RES> + 'static,
102 {
103 let handler_adapter = RequestHandlerAsyncAdapter::new(request_handler);
104 self.register_handler(endpoint, Arc::new(Mutex::new(Box::new(handler_adapter))))
105 .await
106 }
107
108 pub async fn unregister(&mut self, endpoint: impl AsRef<str>) -> TResult<()> {
110 let mut handlers = self.handlers.write().await;
111 match handlers.remove(endpoint.as_ref()) {
112 Some(_) => Ok(()),
113 None => Err(TChannelError::Error(format!(
114 "Handler '{}' is missing.",
115 endpoint.as_ref()
116 ))),
117 }
118 }
119
120 async fn register_handler(
121 &self,
122 endpoint: impl AsRef<str>,
123 request_handler: HandlerRef,
124 ) -> TResult<()> {
125 let mut handlers = self.handlers.write().await;
126 if handlers.contains_key(endpoint.as_ref()) {
127 return Err(TChannelError::Error(format!(
128 "Handler already registered for '{}'",
129 endpoint.as_ref()
130 )));
131 }
132 handlers.insert(endpoint.as_ref().to_string(), request_handler);
133 Ok(()) }
135
136 async fn connect(&self, host: SocketAddr) -> TResult<(FrameInput, FrameOutput)> {
137 let pool = self.connection_pools.get(host).await?;
138 let connection = pool.get().await?;
139 Ok(connection.new_frames_io().await?)
140 }
141
142 async fn create_frames<REQ: Message>(&self, request: REQ) -> TResult<TFrameStream> {
143 let message_args = request.try_into()?;
144 RequestFragmenter::new(self.service_name.clone(), message_args).create_frames()
145 }
146
147 pub(crate) async fn handle(&self, request: MessageArgs) -> MessageArgsResponse {
148 let endpoint = Self::read_endpoint_name(&request)?;
149 let handler_locked = self.get_handler(endpoint).await?;
150 let mut handler = handler_locked.lock().await; handler.handle(request).await
152 }
153
154 async fn get_handler(&self, endpoint: String) -> TResult<HandlerRef> {
155 let handlers = self.handlers.read().await;
156 match handlers.get(&endpoint) {
157 Some(handler) => Ok(handler.clone()),
158 None => Err(TChannelError::Error(format!(
159 "No handler with name '{}'.",
160 endpoint
161 ))),
162 }
163 }
164
165 fn read_endpoint_name(request: &MessageArgs) -> Result<String, CodecError> {
166 match request.args.get(0) {
167 Some(arg) => Ok(String::from_utf8(arg.to_vec())?),
168 None => Err(CodecError::Error("Missing arg1/endpoint name".to_string())),
169 }
170 }
171}
172
173fn first_addr<ADDR: ToSocketAddrs>(addr: ADDR) -> ConnectionResult<SocketAddr> {
174 let mut addrs = addr.to_socket_addrs()?;
175 if let Some(addr) = addrs.next() {
176 return Ok(addr);
177 }
178 Err(ConnectionError::Error(
179 "Unable to get host addr".to_string(),
180 ))
181}
182
183async fn send_frames(frames: TFrameStream, frames_out: &FrameOutput) -> ConnectionResult<()> {
184 debug!("Sending frames");
185 frames
186 .then(|frame| frames_out.send(frame))
187 .inspect_err(|err| error!("Failed to send frame {:?}", err))
188 .try_for_each(|_res| future::ready(Ok(())))
189 .await
190}