thrift_pool/
lib.rs

//! This library provides a simple way to implement [`bb8`] and/or [`r2d2`] Connection Pools
2//! for any [`TThriftClient`](thrift::TThriftClient)
3//!
4//! <br>
5//!
6//! # Usage
7//!
8//! There are 2 possible use cases
9//!
10//! ## As a library
11//!
12//! If you're implementing a library that provides a (possibly generated) thrift client,
13//! you should implement the [`ThriftConnection`] and [`FromProtocol`] traits
14//! for that client
15//!
16//! ```
17//! # use thrift::protocol::{TInputProtocol, TOutputProtocol};
18//! # use thrift_pool::{FromProtocol, ThriftConnection};
19//! #
20//! // A typical generated client looks like this
21//! struct MyThriftClient<Ip: TInputProtocol, Op: TOutputProtocol> {
22//!     i_prot: Ip,
23//!     o_prot: Op,
24//! }
25//!
26//! impl<Ip: TInputProtocol, Op: TOutputProtocol> FromProtocol for MyThriftClient<Ip, Op> {
27//!     type InputProtocol = Ip;
28//!
29//!     type OutputProtocol = Op;
30//!
31//!     fn from_protocol(
32//!         input_protocol: Self::InputProtocol,
33//!         output_protocol: Self::OutputProtocol,
34//!     ) -> Self {
35//!         MyThriftClient {
36//!             i_prot: input_protocol,
37//!             o_prot: output_protocol,
38//!         }
39//!     }
40//! }
41//!
42//! impl<Ip: TInputProtocol, Op: TOutputProtocol> ThriftConnection for MyThriftClient<Ip, Op> {
43//!     type Error = thrift::Error;
44//!     fn is_valid(&mut self) -> Result<(), Self::Error> {
45//!        Ok(())
46//!     }
47//!     fn has_broken(&mut self) -> bool {
48//!        false
49//!    }
50//! }
51//!
52//! ```
53//!
54//! ## As an application
55//!
56//! If you're implementing an application that uses a (possibly generated) thrift client that
57//! implements [`FromProtocol`] and [`ThriftConnection`] (see previous section), you can use
//! [`r2d2`] or [`bb8`] (make sure to read their documentation) along with
59//! [`ThriftConnectionManager`] to create Connection Pools for the client
60//!
61//! ```
62//! # use thrift::protocol::{TInputProtocol, TOutputProtocol};
63//! # use thrift_pool::{FromProtocol, ThriftConnection};
64//! #
65//! # struct MyThriftClient<Ip: TInputProtocol, Op: TOutputProtocol> {
66//! #     i_prot: Ip,
67//! #     o_prot: Op,
68//! # }
69//! #
70//! # impl<Ip: TInputProtocol, Op: TOutputProtocol> FromProtocol for MyThriftClient<Ip, Op> {
71//! #     type InputProtocol = Ip;
72//! #
73//! #     type OutputProtocol = Op;
74//! #
75//! #     fn from_protocol(
76//! #         input_protocol: Self::InputProtocol,
77//! #         output_protocol: Self::OutputProtocol,
78//! #     ) -> Self {
79//! #         MyThriftClient {
80//! #             i_prot: input_protocol,
81//! #             o_prot: output_protocol,
82//! #         }
83//! #     }
84//! # }
85//! #
86//! # impl<Ip: TInputProtocol, Op: TOutputProtocol> ThriftConnection for MyThriftClient<Ip, Op> {
87//! #     type Error = thrift::Error;
88//! #     fn is_valid(&mut self) -> Result<(), Self::Error> {
89//! #        Ok(())
90//! #     }
91//! #     fn has_broken(&mut self) -> bool {
92//! #        false
93//! #    }
94//! # }
95//! # use thrift_pool::{MakeThriftConnectionFromAddrs};
96//! # use thrift::protocol::{TCompactInputProtocol, TCompactOutputProtocol};
97//! # use thrift::transport::{
98//! #    ReadHalf, TFramedReadTransport, TFramedWriteTransport, TTcpChannel, WriteHalf,
99//! # };
100//! #
101//! type Client = MyThriftClient<
102//!     TCompactInputProtocol<TFramedReadTransport<ReadHalf<TTcpChannel>>>,
103//!     TCompactOutputProtocol<TFramedWriteTransport<WriteHalf<TTcpChannel>>>,
104//! >;
105//! # use tokio::net::TcpListener;
106//! # #[tokio::main]
107//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
108//! # let listener = TcpListener::bind("127.0.0.1:9090").await?;
109//! # tokio::spawn(async move {
110//! #      loop {
111//! #         listener.accept().await.unwrap();
112//! #      }
113//! #  });
114//!   // create a connection manager
115//!   let manager = MakeThriftConnectionFromAddrs::<Client, _>::new("localhost:9090")
116//!                 .into_connection_manager();
117//!   
118//!   // we're able to create bb8 and r2d2 Connection Pools
119//!   let bb8 = bb8::Pool::builder().build(manager.clone()).await?;
120//!   let r2d2 = r2d2::Pool::builder().build(manager)?;
121//!
122//!   // get a connection
123//!   let conn1 = bb8.get().await?;
124//!   let conn2 = r2d2.get()?;
125//! #  Ok(())
126//! # }
127//! ```
128//!
129//! <br>
130//!
131//! # Examples
132//!
133//! - [hbase-thrift](https://github.com/midnightexigent/hbase-thrift-rs) -- the project from which this
//! library was extracted. It implements Connection Pools for the client generated from the
135//! [`HBase` Thrift Spec](https://github.com/apache/hbase/tree/master/hbase-thrift/src/main/resources/org/apache/hadoop/hbase/thrift)
136//! - [thrift-pool-tutorial](https://github.com/midnightexigent/thrift-pool-tutorial-rs) -- implements
137//! Connection Pools for the client used in the official
138//! [thrift tutorial](https://github.com/apache/thrift/tree/master/tutorial)
139
140use std::{
141    io::{self, Read, Write},
142    marker::PhantomData,
143    net::ToSocketAddrs,
144};
145
146use thrift::{
147    protocol::{
148        TBinaryInputProtocol, TBinaryOutputProtocol, TCompactInputProtocol, TCompactOutputProtocol,
149        TInputProtocol, TOutputProtocol,
150    },
151    transport::{
152        ReadHalf, TBufferedReadTransport, TBufferedWriteTransport, TFramedReadTransport,
153        TFramedWriteTransport, TIoChannel, TReadTransport, TTcpChannel, TWriteTransport, WriteHalf,
154    },
155};
156
157/// Create self from a [`Read`]
158pub trait FromRead: TReadTransport {
159    type Read: io::Read;
160    fn from_read(read: Self::Read) -> Self;
161}
162
163impl<R: Read> FromRead for TBufferedReadTransport<R> {
164    type Read = R;
165    fn from_read(read: R) -> Self {
166        Self::new(read)
167    }
168}
169impl<R: Read> FromRead for TFramedReadTransport<R> {
170    type Read = R;
171    fn from_read(read: R) -> Self {
172        Self::new(read)
173    }
174}
175
176/// Create self from a [`Write`]
177pub trait FromWrite: TWriteTransport {
178    type Write: io::Write;
179    fn from_write(write: Self::Write) -> Self;
180}
181
182impl<W: Write> FromWrite for TBufferedWriteTransport<W> {
183    type Write = W;
184    fn from_write(write: W) -> Self {
185        Self::new(write)
186    }
187}
188
189impl<W: Write> FromWrite for TFramedWriteTransport<W> {
190    type Write = W;
191
192    fn from_write(write: Self::Write) -> Self {
193        Self::new(write)
194    }
195}
196
197/// Create self from a [`TReadTransport`]
198pub trait FromReadTransport: TInputProtocol {
199    type ReadTransport: TReadTransport;
200    fn from_read_transport(r_tran: Self::ReadTransport) -> Self;
201}
202
203impl<RT: TReadTransport> FromReadTransport for TBinaryInputProtocol<RT> {
204    type ReadTransport = RT;
205
206    fn from_read_transport(r_tran: RT) -> Self {
207        Self::new(r_tran, true)
208    }
209}
210
211impl<RT: TReadTransport> FromReadTransport for TCompactInputProtocol<RT> {
212    type ReadTransport = RT;
213
214    fn from_read_transport(r_tran: RT) -> Self {
215        Self::new(r_tran)
216    }
217}
218
219/// Create self from a [`TWriteTransport`]
220pub trait FromWriteTransport: TOutputProtocol {
221    type WriteTransport: TWriteTransport;
222    fn from_write_transport(w_tran: Self::WriteTransport) -> Self;
223}
224
225impl<WT: TWriteTransport> FromWriteTransport for TBinaryOutputProtocol<WT> {
226    type WriteTransport = WT;
227    fn from_write_transport(w_tran: WT) -> Self {
228        Self::new(w_tran, true)
229    }
230}
231
232impl<WT: TWriteTransport> FromWriteTransport for TCompactOutputProtocol<WT> {
233    type WriteTransport = WT;
234    fn from_write_transport(w_tran: WT) -> Self {
235        Self::new(w_tran)
236    }
237}
238
239/// Create self from a [`TInputProtocol`] and a [`TOutputProtocol`]
240pub trait FromProtocol {
241    type InputProtocol: TInputProtocol;
242    type OutputProtocol: TOutputProtocol;
243
244    fn from_protocol(
245        input_protocol: Self::InputProtocol,
246        output_protocol: Self::OutputProtocol,
247    ) -> Self;
248}
249
/// Health checks for a pooled thrift connection.
///
/// [`ThriftConnectionManager`] forwards the pool's validity and breakage checks
/// to these methods when implementing [`bb8::ManageConnection`] and/or
/// [`r2d2::ManageConnection`].
pub trait ThriftConnection {
    /// Error type returned by [`ThriftConnection::is_valid`].
    ///
    /// Mirrors [`r2d2::ManageConnection::Error`] / [`bb8::ManageConnection::Error`].
    type Error;

    /// Verify that the connection is still usable.
    ///
    /// Mirrors [`r2d2::ManageConnection::is_valid`] / [`bb8::ManageConnection::is_valid`].
    ///
    /// # Errors
    ///
    /// Implementations should return `Err` when the connection is no longer valid.
    fn is_valid(&mut self) -> Result<(), Self::Error>;

    /// Report whether the connection is known to be broken.
    ///
    /// Mirrors [`r2d2::ManageConnection::has_broken`] / [`bb8::ManageConnection::has_broken`].
    /// The default implementation never reports a broken connection.
    fn has_broken(&mut self) -> bool {
        false
    }
}
270
/// Factory for new [`ThriftConnection`]s.
///
/// [`ThriftConnectionManager`] calls [`MakeThriftConnection::make_thrift_connection`]
/// to implement [`bb8::ManageConnection::connect`] and/or
/// [`r2d2::ManageConnection::connect`].
pub trait MakeThriftConnection {
    /// Error returned when a connection cannot be created.
    type Error;
    /// The connection type that this factory creates.
    type Output;

    /// Try to create a fresh connection.
    ///
    /// # Errors
    ///
    /// Implementations should return `Err` whenever a new connection cannot be
    /// created, for any reason.
    fn make_thrift_connection(&self) -> Result<Self::Output, Self::Error>;
}
290
/// A [`MakeThriftConnection`] that opens new connections to a [`ToSocketAddrs`]
/// target and builds the client through [`FromProtocol`].
///
/// Connections are assembled the same way as in the
/// [thrift rust tutorial](https://github.com/apache/thrift/tree/master/tutorial):
///
/// * open a [`TTcpChannel`] and split it into a [`ReadHalf`] and a [`WriteHalf`]
/// * wrap those halves in a [`TReadTransport`] and a [`TWriteTransport`]
/// * wrap the transports in a [`TInputProtocol`] and a [`TOutputProtocol`]
/// * hand both protocols to [`FromProtocol::from_protocol`] to obtain the client
///
/// For this to work, the transports chosen by `T` must implement [`FromRead`] /
/// [`FromWrite`] over the channel halves, and its protocols must implement
/// [`FromReadTransport`] / [`FromWriteTransport`]. This crate already provides
/// those impls for the buffered/framed transports and binary/compact protocols.
///
/// # Examples
///
/// ```
/// use thrift_pool::{FromProtocol, MakeThriftConnectionFromAddrs};
///
/// use thrift::{
///     protocol::{TCompactInputProtocol, TCompactOutputProtocol, TInputProtocol, TOutputProtocol},
///     transport::{ReadHalf, TFramedReadTransport, TFramedWriteTransport, TTcpChannel, WriteHalf},
/// };
///
/// // A typical generated client looks like this
/// struct MyThriftClient<Ip: TInputProtocol, Op: TOutputProtocol> {
///     i_prot: Ip,
///     o_prot: Op,
/// }
///
/// impl<Ip: TInputProtocol, Op: TOutputProtocol> FromProtocol for MyThriftClient<Ip, Op> {
///     type InputProtocol = Ip;
///     type OutputProtocol = Op;
///
///     fn from_protocol(
///         input_protocol: Self::InputProtocol,
///         output_protocol: Self::OutputProtocol,
///     ) -> Self {
///         MyThriftClient {
///             i_prot: input_protocol,
///             o_prot: output_protocol,
///         }
///     }
/// }
///
/// type Client = MyThriftClient<
///     TCompactInputProtocol<TFramedReadTransport<ReadHalf<TTcpChannel>>>,
///     TCompactOutputProtocol<TFramedWriteTransport<WriteHalf<TTcpChannel>>>,
/// >;
///
/// // The compact protocols and framed transports used by `Client` implement the
/// // required traits, so a connection manager can be built directly.
/// // No connection is opened until the pool asks for one.
/// let manager =
///     MakeThriftConnectionFromAddrs::<Client, _>::new("localhost:9090").into_connection_manager();
/// ```
pub struct MakeThriftConnectionFromAddrs<T, S> {
    /// Address(es) that every new connection is opened against.
    addrs: S,
    /// Marks the client type being produced; no `T` is actually stored.
    conn: PhantomData<T>,
}

impl<T, S: std::fmt::Debug> std::fmt::Debug for MakeThriftConnectionFromAddrs<T, S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { addrs, conn } = self;
        f.debug_struct("MakeThriftConnectionFromAddrs")
            .field("addrs", addrs)
            .field("conn", conn)
            .finish()
    }
}

// Written by hand: `#[derive(Clone)]` would add a `T: Clone` bound even though
// no `T` value is ever stored.
impl<T, S: Clone> Clone for MakeThriftConnectionFromAddrs<T, S> {
    fn clone(&self) -> Self {
        Self::new(self.addrs.clone())
    }
}

impl<T, S> MakeThriftConnectionFromAddrs<T, S> {
    /// Create a factory that will connect to `addrs`.
    ///
    /// Nothing is resolved or opened here; `addrs` is only used when a
    /// connection is actually made.
    pub fn new(addrs: S) -> Self {
        let conn = PhantomData;
        Self { addrs, conn }
    }
}
381
382impl<
383        S: ToSocketAddrs + Clone,
384        RT: FromRead<Read = ReadHalf<TTcpChannel>>,
385        IP: FromReadTransport<ReadTransport = RT>,
386        WT: FromWrite<Write = WriteHalf<TTcpChannel>>,
387        OP: FromWriteTransport<WriteTransport = WT>,
388        T: FromProtocol<InputProtocol = IP, OutputProtocol = OP>,
389    > MakeThriftConnectionFromAddrs<T, S>
390{
391    pub fn into_connection_manager(self) -> ThriftConnectionManager<Self> {
392        ThriftConnectionManager::new(self)
393    }
394}
395
396impl<
397        S: ToSocketAddrs + Clone,
398        RT: FromRead<Read = ReadHalf<TTcpChannel>>,
399        IP: FromReadTransport<ReadTransport = RT>,
400        WT: FromWrite<Write = WriteHalf<TTcpChannel>>,
401        OP: FromWriteTransport<WriteTransport = WT>,
402        T: FromProtocol<InputProtocol = IP, OutputProtocol = OP>,
403    > MakeThriftConnection for MakeThriftConnectionFromAddrs<T, S>
404{
405    type Error = thrift::Error;
406
407    type Output = T;
408
409    fn make_thrift_connection(&self) -> Result<Self::Output, Self::Error> {
410        let mut channel = TTcpChannel::new();
411        channel.open(self.addrs.clone())?;
412        let (read, write) = channel.split()?;
413
414        let read_transport = RT::from_read(read);
415        let input_protocol = IP::from_read_transport(read_transport);
416
417        let write_transport = WT::from_write(write);
418        let output_protocol = OP::from_write_transport(write_transport);
419
420        Ok(T::from_protocol(input_protocol, output_protocol))
421    }
422}
423
/// Connection manager for [`bb8`] and/or [`r2d2`] pools (the `ManageConnection`
/// impls are gated behind the `impl-bb8` / `impl-r2d2` features).
///
/// `T` should be a [`MakeThriftConnection`] whose `Output` is a
/// [`ThriftConnection`]. Those bounds are only required by the pool trait
/// implementations, not by construction.
// The derives produce the same `T: Clone` / `T: Debug` bounds and the same
// `ThriftConnectionManager(..)` debug output as a hand-written impl would.
#[derive(Clone, Debug)]
pub struct ThriftConnectionManager<T>(T);

impl<T> ThriftConnectionManager<T> {
    /// Wrap `make_thrift_connection`, the factory used to open new connections.
    pub fn new(make_thrift_connection: T) -> Self {
        Self(make_thrift_connection)
    }
}
446
/// Pool integration for [`bb8`], enabled by the `impl-bb8` feature.
///
/// `connect` delegates to [`MakeThriftConnection::make_thrift_connection`].
/// `has_broken` and `is_valid` delegate to the matching [`ThriftConnection`]
/// methods.
///
/// NOTE(review): the delegated calls are synchronous and are invoked without any
/// `.await`, so `connect`/`is_valid` run to completion on the executor thread.
/// With a TCP-based factory, this blocks the thread while connecting.
#[cfg(feature = "impl-bb8")]
#[async_trait::async_trait]
impl<E, C, T> bb8::ManageConnection for ThriftConnectionManager<T>
where
    E: Send + std::fmt::Debug + 'static,
    C: ThriftConnection<Error = E> + Send + 'static,
    T: MakeThriftConnection<Output = C, Error = E> + Send + Sync + 'static,
{
    type Connection = C;
    type Error = E;

    async fn connect(&self) -> Result<C, E> {
        let factory = &self.0;
        factory.make_thrift_connection()
    }

    fn has_broken(&self, conn: &mut C) -> bool {
        ThriftConnection::has_broken(conn)
    }

    async fn is_valid(&self, conn: &mut C) -> Result<(), E> {
        ThriftConnection::is_valid(conn)
    }
}
471
/// Pool integration for [`r2d2`], enabled by the `impl-r2d2` feature.
///
/// Every method forwards either to the wrapped [`MakeThriftConnection`]
/// (`connect`) or to the pooled [`ThriftConnection`] (`has_broken`, `is_valid`).
#[cfg(feature = "impl-r2d2")]
impl<E, C, T> r2d2::ManageConnection for ThriftConnectionManager<T>
where
    E: std::error::Error + 'static,
    C: ThriftConnection<Error = E> + Send + 'static,
    T: MakeThriftConnection<Output = C, Error = E> + Send + Sync + 'static,
{
    type Connection = C;
    type Error = E;

    fn connect(&self) -> Result<C, E> {
        let factory = &self.0;
        factory.make_thrift_connection()
    }

    fn has_broken(&self, conn: &mut C) -> bool {
        ThriftConnection::has_broken(conn)
    }

    fn is_valid(&self, conn: &mut C) -> Result<(), E> {
        ThriftConnection::is_valid(conn)
    }
}