// sfo_split/splittable.rs

1use std::io::Error;
2use std::ops::{Deref, DerefMut};
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6
/// A resource assembled from a read side `R` and a write side `W`.
///
/// Dereferences to the read side `R`. Use [`Splittable::split`] to obtain
/// independently owned halves, which can later be rejoined with `unsplit`.
#[derive(Debug)]
pub struct Splittable<R, W> {
    /// Read side.
    r: R,
    /// Write side.
    w: W,
}
11
12impl<R, W> Splittable<R, W> {
13    pub fn new(r: R, w: W) -> Self {
14        Self {
15            r,
16            w,
17        }
18    }
19
20    pub fn split(self) -> (RHalf<R, W>, WHalf<R, W>) {
21        let key = Arc::new(0);
22        (
23            RHalf::new(self.r, key.clone()),
24            WHalf::new(self.w, key),
25        )
26    }
27
28    pub fn get_r(&self) -> &R {
29        &self.r
30    }
31
32    pub fn get_w(&self) -> &W {
33        &self.w
34    }
35
36    pub fn get_r_mut(&mut self) -> &mut R {
37        &mut self.r
38    }
39
40    pub fn get_w_mut(&mut self) -> &mut W {
41        &mut self.w
42    }
43}
44
45impl <R, W> Deref for Splittable<R, W> {
46    type Target = R;
47    fn deref(&self) -> &Self::Target {
48        &self.r
49    }
50}
51
52impl <R, W> DerefMut for Splittable<R, W> {
53    fn deref_mut(&mut self) -> &mut Self::Target {
54        &mut self.r
55    }
56}
57
/// The read half produced by [`Splittable::split`].
///
/// Owns the read side `R` plus a pairing key shared with the matching
/// [`WHalf`]. The write type `W` is only tracked at the type level so the two
/// halves can be rejoined into a `Splittable<R, W>`.
#[derive(Debug)]
pub struct RHalf<R, W> {
    /// Pairing key; the matching write half holds a clone of the same `Arc`.
    k: Arc<u8>,
    /// Read side.
    r: R,
    /// `fn() -> W` rather than `W`: this half never owns a `W`, so the marker
    /// must not make `RHalf` inherit `W`'s `Send`/`Sync`/`Unpin` or drop-check
    /// obligations. It is still covariant in `W`.
    _p: std::marker::PhantomData<fn() -> W>,
}
63
64impl<R, W> RHalf<R, W> {
65    pub(crate) fn new(r: R, k: Arc<u8>) -> Self {
66        Self {
67            k,
68            r,
69            _p: Default::default(),
70        }
71    }
72
73    pub fn is_pair_of(&self, w: &WHalf<R, W>) -> bool {
74        Arc::ptr_eq(&self.k, &w.k)
75    }
76
77    pub fn unsplit(self, w: WHalf<R, W>) -> Splittable<R, W> {
78        if !self.is_pair_of(&w) {
79            panic!("not a pair");
80        }
81
82        Splittable::new(self.r, w.w)
83    }
84}
85
86impl<R, W> Deref for RHalf<R, W> {
87    type Target = R;
88    fn deref(&self) -> &Self::Target {
89        &self.r
90    }
91}
92
93impl<R, W> DerefMut for RHalf<R, W> {
94    fn deref_mut(&mut self) -> &mut Self::Target {
95        &mut self.r
96    }
97}
98
/// The write half produced by [`Splittable::split`].
///
/// Owns the write side `W` plus a pairing key shared with the matching
/// [`RHalf`]. The read type `R` is only tracked at the type level so the two
/// halves can be rejoined into a `Splittable<R, W>`.
#[derive(Debug)]
pub struct WHalf<R, W> {
    /// Pairing key; the matching read half holds a clone of the same `Arc`.
    k: Arc<u8>,
    /// Write side.
    w: W,
    /// `fn() -> R` rather than `R`: this half never owns an `R`, so the marker
    /// must not make `WHalf` inherit `R`'s `Send`/`Sync`/`Unpin` or drop-check
    /// obligations. It is still covariant in `R`.
    _p: std::marker::PhantomData<fn() -> R>,
}
104
105impl<R, W> WHalf<R, W> {
106    pub(crate) fn new(w: W, k: Arc<u8>) -> Self {
107        Self {
108            k,
109            w,
110            _p: Default::default(),
111        }
112    }
113
114    pub fn is_pair_of(&self, r: &RHalf<R, W>) -> bool {
115        r.is_pair_of(self)
116    }
117
118    pub fn unsplit(self, r: RHalf<R, W>) -> Splittable<R, W> {
119        r.unsplit(self)
120    }
121}
122
123impl<R, W> DerefMut for WHalf<R, W> {
124    fn deref_mut(&mut self) -> &mut Self::Target {
125        &mut self.w
126    }
127}
128
129impl<R, W> Deref for WHalf<R, W> {
130    type Target = W;
131    fn deref(&self) -> &Self::Target {
132        &self.w
133    }
134}
135
/// Forwards `tokio::io::AsyncWrite` to the write side `W`.
///
/// The read side only needs to be `Unpin`, so that `Pin<&mut Self>` can be
/// dereferenced to `&mut Self`. It does not have to implement `AsyncRead`.
#[cfg(feature = "io")]
impl<R: Unpin, W: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for Splittable<R, W> {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, Error>> {
        Pin::new(self.get_w_mut()).poll_write(cx, buf)
    }

    /// Forwarded so a vectored-capable `W` keeps gathering writes. The trait's
    /// default implementation would only write the first non-empty buffer.
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<Result<usize, Error>> {
        Pin::new(self.get_w_mut()).poll_write_vectored(cx, bufs)
    }

    /// Reports the write side's own vectored-write support.
    fn is_write_vectored(&self) -> bool {
        self.get_w().is_write_vectored()
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        Pin::new(self.get_w_mut()).poll_flush(cx)
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        Pin::new(self.get_w_mut()).poll_shutdown(cx)
    }
}
150
/// Forwards `tokio::io::AsyncRead` to the read side `R`.
///
/// The write side only needs to be `Unpin`, so that `Pin<&mut Self>` can be
/// dereferenced to `&mut Self`. It does not have to implement `AsyncWrite`.
#[cfg(feature = "io")]
impl<R, W> tokio::io::AsyncRead for Splittable<R, W>
where
    R: tokio::io::AsyncRead + Unpin,
    W: Unpin,
{
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<Result<(), Error>> {
        Pin::new(self.get_r_mut()).poll_read(cx, buf)
    }
}
157
#[cfg(test)]
mod tests {
    use super::Splittable;

    /// Placeholder read side; carries no data.
    struct TestRead;

    /// Placeholder write side; carries no data.
    struct TestWrite;

    /// Halves from the same split pair up; halves from different splits do not.
    #[test]
    fn test() {
        let s1 = Splittable::new(TestRead, TestWrite);
        let (r1, w1) = s1.split();
        let s2 = Splittable::new(TestRead, TestWrite);
        let (r2, w2) = s2.split();
        assert!(r1.is_pair_of(&w1));
        assert!(w1.is_pair_of(&r1));
        assert!(r2.is_pair_of(&w2));
        assert!(w2.is_pair_of(&r2));
        assert!(!w1.is_pair_of(&r2));
        assert!(!w2.is_pair_of(&r1));
        assert!(!r1.is_pair_of(&w2));

        let _s1 = r1.unsplit(w1);
        let _s2 = r2.unsplit(w2);
    }

    /// Values survive a split/unsplit round trip and stay reachable via `Deref`.
    #[test]
    fn round_trip_preserves_values() {
        let (r, w) = Splittable::new(1u8, 2u16).split();
        assert_eq!(*r, 1);
        assert_eq!(*w, 2);
        let s = w.unsplit(r);
        assert_eq!(*s.get_r(), 1);
        assert_eq!(*s.get_w(), 2);
    }

    /// Rejoining halves that came from different splits panics.
    #[test]
    #[should_panic(expected = "not a pair")]
    fn unsplit_rejects_foreign_half() {
        let (r1, _w1) = Splittable::new(TestRead, TestWrite).split();
        let (_r2, w2) = Splittable::new(TestRead, TestWrite).split();
        let _ = r1.unsplit(w2);
    }
}