1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
use ::async_std::sync::Mutex;
use ::async_std::task;
use cfg_if::cfg_if;
use futures::channel::oneshot;
use futures::{AsyncRead, AsyncWrite};
/// This module implements the traits/methods that require the `async-std`
/// runtime for the RPC client. The module is enabled if either
/// `feature = "async_std_runtime"` or `feature = "http_tide"` is enabled.
use std::sync::Arc;

use super::*;

// Shared handle to the boxed client codec, guarded by an async mutex.
type Codec = Arc<Mutex<Box<dyn ClientCodec>>>;
// Maps the id of an in-flight request to the oneshot sender that delivers its
// response: `Ok(body)` for a normal reply, `Err(body)` when the response
// header is flagged as an error.
type ResponseMap = HashMap<u16, oneshot::Sender<Result<ResponseBody, ResponseBody>>>;

/// RPC Client. Unlike [`Server`](../../server/struct.Server.html), the `Client`
/// struct contains fields that use runtime-dependent synchronization primitives,
/// thus there is a separate `Client` struct defined for each of the `async-std`
/// and `tokio` runtimes.
///
/// The `Mode` type parameter is a compile-time marker only (it is stored as
/// `PhantomData`); the constructors are defined on `Client<NotConnected>` and
/// return a `Client<Connected>`, on which the call methods are defined.
pub struct Client<Mode> {
    // Counter the message id of each request is drawn from (incremented by one per call).
    count: AtomicMessageId,
    // Codec used to write requests and read responses, shared behind an async mutex.
    inner_codec: Codec,
    // Response senders of in-flight requests, keyed by message id.
    pending: Arc<Mutex<ResponseMap>>,

    // Zero-sized marker carrying the `Mode` type parameter.
    mode: PhantomData<Mode>,
}

cfg_if! {
    if #[cfg(any(
        all(
            feature = "serde_bincode",
            not(feature = "serde_json"),
            not(feature = "serde_cbor"),
            not(feature = "serde_rmp"),
        ),
        all(
            feature = "serde_cbor",
            not(feature = "serde_json"),
            not(feature = "serde_bincode"),
            not(feature = "serde_rmp"),
        ),
        all(
            feature = "serde_json",
            not(feature = "serde_bincode"),
            not(feature = "serde_cbor"),
            not(feature = "serde_rmp"),
        ),
        all(
            feature = "serde_rmp",
            not(feature = "serde_cbor"),
            not(feature = "serde_json"),
            not(feature = "serde_bincode"),
        )
    ))] {
        use ::async_std::net::{TcpStream, ToSocketAddrs};
        use async_tungstenite::async_std::connect_async;
        use crate::transport::ws::WebSocketConn;
        use crate::server::DEFAULT_RPC_PATH;

        /// The following impl block is controlled by feature flag. It is enabled
        /// if and only if **exactly one** of the the following feature flag is turned on
        /// - `serde_bincode`
        /// - `serde_json`
        /// - `serde_cbor`
        /// - `serde_rmp`
        impl Client<NotConnected> {
            /// Connects the an RPC server over socket at the specified network address
            ///
            /// This is enabled
            /// if and only if **exactly one** of the the following feature flag is turned on
            /// - `serde_bincode`
            /// - `serde_json`
            /// - `serde_cbor`
            /// - `serde_rmp`
            ///
            /// Example
            ///
            /// ```rust
            /// use toy_rpc::Client;
            ///
            /// #[async_std::main]
            /// async fn main() {
            ///     let addr = "127.0.0.1";
            ///     let client = Client::dial(addr).await;
            /// }
            ///
            /// ```
            pub async fn dial(addr: impl ToSocketAddrs) -> Result<Client<Connected>, Error> {
                let stream = TcpStream::connect(addr).await?;

                Ok(Self::with_stream(stream))
            }

            /// Similar to `dial`, this connects to an WebSocket RPC server at the specified network address using the defatul codec
            ///
            /// This is enabled
            /// if and only if **exactly one** of the the following feature flag is turned on
            /// - `serde_bincode`
            /// - `serde_json`
            /// - `serde_cbor`
            /// - `serde_rmp`
            ///
            /// # Example
            ///
            /// ```rust
            /// use toy_rpc::client::Client;
            ///
            /// #[async_std::main]
            /// async fn main() {
            ///     let addr = "ws://127.0.0.1:8080";
            ///     let client = Client::dial_http(addr).await.unwrap();
            /// }
            /// ```
            ///
            pub async fn dial_websocket(addr: &'static str) -> Result<Client<Connected>, Error> {
                let url = url::Url::parse(addr)?;
                Self::_dial_websocket(url).await
            }

            async fn _dial_websocket(url: url::Url) -> Result<Client<Connected>, Error> {
                let (ws_stream, _) = connect_async(&url).await?;

                let ws_stream = WebSocketConn::new(ws_stream);
                let codec = DefaultCodec::with_websocket(ws_stream);

                Ok(Self::with_codec(codec))
            }

            /// Connects to an HTTP RPC server at the specified network address using WebSocket and the defatul codec.
            ///
            /// It is recommended to use "ws://" as the url scheme as opposed to "http://"; however, internally the url scheme
            /// is changed to "ws://". Internally, `DEFAULT_RPC_PATH="_rpc"` is appended to the end of `addr`,
            /// and the rest is the same is calling `dial_websocket`.
            /// If a network path were to be supplpied, the network path must end with a slash "/".
            /// For example, a valid path could be "ws://127.0.0.1/rpc/".
            ///
            /// *Warning*: WebSocket is used as the underlying transport protocol starting from version "0.5.0-beta.0",
            /// and this will make client of versions later than "0.5.0-beta.0" incompatible with servers of versions
            /// earlier than "0.5.0-beta.0".
            ///
            /// This is enabled
            /// if and only if **only one** of the the following feature flag is turned on
            /// - `serde_bincode`
            /// - `serde_json`
            /// - `serde_cbor`
            /// - `serde_rmp`
            ///
            /// # Example
            ///
            /// ```rust
            /// use toy_rpc::Client;
            ///
            /// #[async_std::main]
            /// async fn main() {
            ///     let addr = "ws://127.0.0.1:8080/rpc/";
            ///     let client = Client::dial_http(addr).await.unwrap();
            /// }
            /// ```
            ///
            pub async fn dial_http(addr: &'static str) -> Result<Client<Connected>, Error> {
                let mut url = url::Url::parse(addr)?.join(DEFAULT_RPC_PATH)?;
                url.set_scheme("ws").expect("Failed to change scheme to ws");

                Self::_dial_websocket(url).await
            }

            /// Creates an RPC `Client` over socket with a specified `async_std::net::TcpStream` and the default codec
            ///
            /// This is enabled
            /// if and only if **exactly one** of the the following feature flag is turned on
            /// - `serde_bincode`
            /// - `serde_json`
            /// - `serde_cbor`
            /// - `serde_rmp`
            ///
            /// # Example
            /// ```
            /// use async_std::net::TcpStream;
            /// use toy_rpc::Client;
            ///
            /// #[async_std::main]
            /// async fn main() {
            ///     let stream = TcpStream::connect("127.0.0.1:8080").await.unwrap();
            ///     let client = Client::with_stream(stream);
            /// }
            /// ```
            pub fn with_stream<T>(stream: T) -> Client<Connected>
            where
                T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
            {
                let codec = DefaultCodec::new(stream);

                Self::with_codec(codec)
            }
        }

        impl Client<NotConnected> {
            /// Builds a connected RPC `Client` around the given codec. The codec must
            /// implement `ClientCodec` trait and `GracefulShutdown` trait.
            ///
            /// The returned client starts with a message id counter of zero and
            /// no pending requests.
            ///
            /// # Example
            ///
            /// ```rust
            /// use async_std::net::TcpStream;
            /// use toy_rpc::codec::bincode::Codec;
            /// use toy_rpc::Client;
            ///
            /// #[async_std::main]
            /// async fn main() {
            ///     let addr = "127.0.0.1:8080";
            ///     let stream = TcpStream::connect(addr).await.unwrap();
            ///     let codec = Codec::new(stream);
            ///     let client = Client::with_codec(codec);
            /// }
            /// ```
            pub fn with_codec<C>(codec: C) -> Client<Connected>
            where
                C: ClientCodec + Send + Sync + 'static,
            {
                // Erase the concrete codec type so the client is not generic over it.
                let inner_codec =
                    Arc::new(Mutex::new(Box::new(codec) as Box<dyn ClientCodec>));
                let pending = Arc::new(Mutex::new(HashMap::new()));
                let count = AtomicMessageId::new(0u16);

                Client::<Connected> {
                    count,
                    inner_codec,
                    pending,
                    mode: PhantomData,
                }
            }
        }

    }

}

impl Client<Connected> {
    /// Invokes the named function and waits synchronously in a blocking manner.
    ///
    /// This function internally calls `task::block_on` to wait for the response.
    /// Do NOT use this function inside another `task::block_on`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `async_call` returns.
    ///
    /// # Example
    ///
    /// ```rust
    /// use async_std::task;
    /// use toy_rpc::Client;
    /// use toy_rpc::error::Error;
    ///
    /// fn main() {
    ///     let addr = "127.0.0.1:8080";
    ///     let client = task::block_on(Client::dial(addr)).unwrap();
    ///
    ///     let args = "arguments";
    ///     let reply: Result<String, Error> = client.call("echo_service.echo", &args);
    ///     println!("{:?}", reply);
    /// }
    /// ```
    pub fn call<Req, Res>(&self, service_method: impl ToString, args: Req) -> Result<Res, Error>
    where
        Req: serde::Serialize + Send + Sync,
        Res: serde::de::DeserializeOwned,
    {
        task::block_on(self.async_call(service_method, args))
    }

    /// Invokes the named function asynchronously by spawning a new task and returns the `JoinHandle`
    ///
    /// The message id is reserved before the task is spawned, so ids follow the
    /// order in which `spawn_task` / `async_call` are invoked.
    ///
    /// # Example
    ///
    /// ```rust
    /// use async_std::task;
    ///
    /// use toy_rpc::client::Client;
    /// use toy_rpc::error::Error;
    ///
    /// #[async_std::main]
    /// async fn main() {
    ///     let addr = "127.0.0.1:8080";
    ///     let client = Client::dial(addr).await.unwrap();
    ///
    ///     let args = "arguments";
    ///     let handle: task::JoinHandle<Result<String, Error>> = client.spawn_task("echo_service.echo", args);
    ///     let reply: Result<String, Error> = handle.await;
    ///     println!("{:?}", reply);
    /// }
    /// ```
    pub fn spawn_task<Req, Res>(
        &self,
        service_method: impl ToString + Send + 'static,
        args: Req,
    ) -> task::JoinHandle<Result<Res, Error>>
    where
        Req: serde::Serialize + Send + Sync + 'static,
        Res: serde::de::DeserializeOwned + Send + 'static,
    {
        // The spawned task must own everything it touches, hence the Arc clones.
        let codec = Arc::clone(&self.inner_codec);
        let pending = Arc::clone(&self.pending);
        let id = self.count.fetch_add(1u16, Ordering::Relaxed);

        task::spawn(
            async move { Self::_async_call(service_method, &args, id, codec, pending).await },
        )
    }

    /// Invokes the named function asynchronously
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be written, if the response
    /// cannot be read or deserialized, or `Error::RpcError` if the response
    /// is flagged as an error.
    ///
    /// # Example
    ///
    /// ```rust
    /// use toy_rpc::Client;
    /// use toy_rpc::error::Error;
    ///
    /// #[async_std::main]
    /// async fn main() {
    ///     let addr = "127.0.0.1:8080";
    ///     let client = Client::dial(addr).await.unwrap();
    ///
    ///     let args = "arguments";
    ///     let reply: Result<String, Error> = client.async_call("echo_service.echo", &args).await;
    ///     println!("{:?}", reply);
    /// }
    /// ```
    pub async fn async_call<Req, Res>(
        &self,
        service_method: impl ToString,
        args: Req,
    ) -> Result<Res, Error>
    where
        Req: serde::Serialize + Send + Sync,
        Res: serde::de::DeserializeOwned,
    {
        let codec = Arc::clone(&self.inner_codec);
        let pending = Arc::clone(&self.pending);
        let id = self.count.fetch_add(1u16, Ordering::Relaxed);

        Self::_async_call(service_method, &args, id, codec, pending).await
    }

    /// Performs one request/response round trip for message `id`.
    ///
    /// The codec lock is held from before the request is written until after
    /// one response has been read. Whatever the outcome, the entry registered
    /// for `id` is removed from `pending` before returning, so a failed or
    /// unanswered call does not leave a stale sender behind.
    async fn _async_call<Req, Res>(
        service_method: impl ToString,
        args: &Req,
        id: MessageId,
        codec: Arc<Mutex<Box<dyn ClientCodec>>>,
        pending: Arc<Mutex<ResponseMap>>,
    ) -> Result<Res, Error>
    where
        Req: serde::Serialize + Send + Sync,
        Res: serde::de::DeserializeOwned,
    {
        let _codec = &mut *codec.lock().await;
        let header = RequestHeader {
            id,
            service_method: service_method.to_string(),
        };
        let req = &args as &(dyn erased::Serialize + Send + Sync);

        // send request
        _codec.write_request(header, req).await?;

        // creates channel for receiving response
        let (done_sender, done) = oneshot::channel::<Result<ResponseBody, ResponseBody>>();

        // insert sender to pending map
        {
            let mut _pending = pending.lock().await;
            _pending.insert(id, done_sender);
        }

        let read_result =
            Client::<Connected>::_read_response(_codec.as_mut(), Arc::clone(&pending)).await;

        // If the read failed, hit EOF, or delivered a response for another id,
        // our sender is still registered: take it out so the map does not grow.
        // It is kept alive until the end of this function so that
        // `_handle_response` still observes an open (but empty) channel.
        let _stale_sender = pending.lock().await.remove(&id);

        read_result?;

        Client::<Connected>::_handle_response(done, &id)
    }

    /// Gracefully shutdown the connection.
    ///
    /// For a WebSocket connection, a Close message will be sent.
    /// For a raw TCP connection, the client will simply drop the connection
    pub async fn close(self) {
        let mut codec = self.inner_codec.lock().await;
        codec.close().await;
    }
}

impl Client<Connected> {
    async fn _read_response(
        codec: &mut dyn ClientCodec,
        pending: Arc<Mutex<ResponseMap>>,
    ) -> Result<(), Error> {
        // wait for response
        if let Some(header) = codec.read_response_header().await {
            let ResponseHeader { id, is_error } = header?;
            let deserializer =
                codec
                    .read_response_body()
                    .await
                    .ok_or(Error::IoError(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "Unexpected EOF reading response body",
                    )))?;
            let deserializer = deserializer?;

            let res = match is_error {
                false => Ok(deserializer),
                true => Err(deserializer),
            };

            // send back response
            let mut _pending = pending.lock().await;
            if let Some(done_sender) = _pending.remove(&id) {
                done_sender.send(res).map_err(|_| {
                    Error::TransportError(format!(
                        "Failed to send ResponseBody over oneshot channel {}",
                        &id
                    ))
                })?;
            }
        }

        Ok(())
    }

    fn _handle_response<Res>(
        mut done: oneshot::Receiver<Result<ResponseBody, ResponseBody>>,
        id: &MessageId,
    ) -> Result<Res, Error>
    where
        Res: serde::de::DeserializeOwned,
    {
        // wait for result from oneshot channel
        let res = match done.try_recv() {
            Ok(o) => match o {
                Some(r) => r,
                None => {
                    return Err(Error::TransportError(format!(
                        "Done channel for id {} is out of date",
                        &id
                    )))
                }
            },
            _ => {
                return Err(Error::TransportError(format!(
                    "Done channel for id {} is canceled",
                    &id
                )))
            }
        };

        // deserialize Ok message and Err message
        match res {
            Ok(mut resp_body) => {
                let resp = erased::deserialize(&mut resp_body)
                    .map_err(|e| Error::ParseError(Box::new(e)))?;

                // upon successful deserializing, an Ok() must be returned
                Ok(resp)
            }
            Err(mut err_body) => {
                let err = erased::deserialize(&mut err_body)
                    .map_err(|e| Error::ParseError(Box::new(e)))?;

                // upon successful deserializing, an Err() must be returned
                Err(Error::RpcError(err))
            }
        }
    }
}