ss_rs/net/stream.rs

1//! Shadowsocks streams.
2
3use std::{
4    fmt::{self, Display, Formatter},
5    io,
6    pin::Pin,
7    sync::Arc,
8    task::{Context, Poll},
9    time::Duration,
10};
11
12use futures_core::{ready, Future};
13use pin_project_lite::pin_project;
14use tokio::{
15    io::{AsyncRead, AsyncWrite, ReadBuf},
16    time::{Instant, Sleep},
17};
18
19use crate::{
20    context::Ctx,
21    crypto::{
22        cipher::{Cipher, Method},
23        hkdf_sha1, Nonce,
24    },
25    net::{buf::OwnedReadBuf, constants::MAXIMUM_PAYLOAD_SIZE, poll_read_exact},
26};
27
28/// A shadowsocks tcp stream.
29pub struct TcpStream<T> {
30    inner_stream: T,
31
32    cipher_method: Method,
33    cipher_key: Vec<u8>,
34
35    enc_cipher: Option<Cipher>,
36    dec_cipher: Option<Cipher>,
37
38    enc_nonce: Nonce,
39    dec_nonce: Nonce,
40
41    incoming_salt: Option<Vec<u8>>, // for replay protection
42
43    read_state: ReadState,
44    write_state: WriteState,
45
46    in_payload: Vec<u8>,  // decrypted payload
47    out_payload: Vec<u8>, // encrypted payload
48
49    read_buf: OwnedReadBuf,
50
51    ctx: Arc<Ctx>,
52}
53
54impl<T> TcpStream<T> {
55    /// Creates a new shadowsocks tcp stream from a stream.
56    pub fn new(inner_stream: T, cipher_method: Method, cipher_key: &[u8], ctx: Arc<Ctx>) -> Self {
57        TcpStream {
58            inner_stream,
59            cipher_method,
60            cipher_key: cipher_key.to_owned(),
61            enc_cipher: None,
62            dec_cipher: None,
63            enc_nonce: Nonce::new(cipher_method.iv_size()),
64            dec_nonce: Nonce::new(cipher_method.iv_size()),
65            incoming_salt: None,
66            read_state: ReadState::ReadSalt,
67            write_state: WriteState::WriteSalt,
68            in_payload: Vec::new(),
69            out_payload: Vec::new(),
70            read_buf: OwnedReadBuf::new(),
71            ctx: ctx.clone(),
72        }
73    }
74}
75
76impl<T> TcpStream<T> {
77    fn encrypt(&mut self, plaintext: &[u8]) -> io::Result<Vec<u8>> {
78        match self
79            .enc_cipher
80            .as_ref()
81            .expect("no salt received")
82            .encrypt(&self.enc_nonce, plaintext)
83        {
84            Ok(data) => {
85                self.enc_nonce.increment();
86                Ok(data)
87            }
88            Err(_) => Err(io::Error::new(io::ErrorKind::Other, Error::Encryption)),
89        }
90    }
91
92    fn decrypt(&mut self, ciphertext: &[u8]) -> io::Result<Vec<u8>> {
93        match self
94            .dec_cipher
95            .as_ref()
96            .expect("no salt received")
97            .decrypt(&self.dec_nonce, ciphertext)
98        {
99            Ok(data) => {
100                self.dec_nonce.increment();
101                Ok(data)
102            }
103            Err(_) => Err(io::Error::new(io::ErrorKind::Other, Error::Decryption)),
104        }
105    }
106}
107
108impl<T> TcpStream<T>
109where
110    T: AsyncRead + Unpin,
111{
112    fn poll_read_decrypt_helper(
113        &mut self,
114        cx: &mut Context<'_>,
115        buf: &mut ReadBuf<'_>,
116    ) -> Poll<io::Result<()>> {
117        let res = ready!(self.poll_read_decrypt(cx, buf));
118
119        if let Err(e) = res {
120            if e.kind() != io::ErrorKind::UnexpectedEof {
121                return Err(e).into();
122            }
123        }
124
125        Ok(()).into()
126    }
127
128    fn poll_read_decrypt(
129        &mut self,
130        cx: &mut Context<'_>,
131        buf: &mut ReadBuf<'_>,
132    ) -> Poll<io::Result<()>> {
133        loop {
134            match self.read_state {
135                ReadState::ReadSalt => {
136                    ready!(self.poll_read_salt(cx))?;
137                    self.read_state = ReadState::ReadLength;
138                }
139                ReadState::ReadLength => {
140                    let len = ready!(self.poll_read_length(cx))?;
141                    self.read_state = ReadState::ReadPayload(len);
142                }
143                ReadState::ReadPayload(payload_len) => {
144                    self.in_payload = ready!(self.poll_read_payload(cx, payload_len))?;
145                    self.read_state = ReadState::ReadPayloadOut;
146                }
147                ReadState::ReadPayloadOut => {
148                    let buf_remaining = buf.remaining();
149                    let payload_len = self.in_payload.len();
150
151                    if buf_remaining >= payload_len {
152                        buf.put_slice(&self.in_payload);
153                        self.read_state = ReadState::ReadLength;
154                    } else {
155                        let (data, remaining) = self.in_payload.split_at(buf_remaining);
156                        buf.put_slice(data);
157                        self.in_payload = remaining.to_owned();
158                        self.read_state = ReadState::ReadPayloadOut;
159                    }
160
161                    return Ok(()).into();
162                }
163            }
164        }
165    }
166
167    fn poll_read_salt(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
168        if self.dec_cipher.is_none() {
169            let mut salt = vec![0u8; self.cipher_method.salt_size()];
170            ready!(poll_read_exact(
171                &mut self.inner_stream,
172                &mut self.read_buf,
173                cx,
174                &mut salt
175            ))?;
176
177            self.incoming_salt = Some(salt.clone());
178
179            let mut subkey = vec![0u8; self.cipher_method.key_size()];
180            hkdf_sha1(&self.cipher_key, &salt, &mut subkey);
181
182            let cipher = Cipher::new(self.cipher_method, &mut subkey);
183            self.dec_cipher.replace(cipher);
184        }
185
186        Ok(()).into()
187    }
188
189    fn poll_read_length(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
190        let mut buf = vec![0u8; 2 + self.cipher_method.tag_size()];
191        ready!(poll_read_exact(
192            &mut self.inner_stream,
193            &mut self.read_buf,
194            cx,
195            &mut buf
196        ))?;
197
198        let len = self.decrypt(&buf)?;
199        let len = [len[0], len[1]];
200        let payload_len = (u16::from_be_bytes(len) as usize) & MAXIMUM_PAYLOAD_SIZE;
201
202        if let Some(salt) = self.incoming_salt.take() {
203            if !self.ctx.check_replay(&salt) {
204                return Err(io::Error::new(io::ErrorKind::Other, Error::DuplicateSalt)).into();
205            }
206        }
207
208        Ok(payload_len).into()
209    }
210
211    fn poll_read_payload(
212        &mut self,
213        cx: &mut Context<'_>,
214        payload_len: usize,
215    ) -> Poll<io::Result<Vec<u8>>> {
216        let mut buf = vec![0u8; payload_len + self.cipher_method.tag_size()];
217        ready!(poll_read_exact(
218            &mut self.inner_stream,
219            &mut self.read_buf,
220            cx,
221            &mut buf
222        ))?;
223        let payload = self.decrypt(&buf)?;
224
225        Ok(payload).into()
226    }
227}
228
229impl<T> TcpStream<T>
230where
231    T: AsyncWrite + Unpin,
232{
233    fn poll_write_encrypt(
234        &mut self,
235        cx: &mut Context<'_>,
236        payload: &[u8],
237    ) -> Poll<io::Result<usize>> {
238        loop {
239            match self.write_state {
240                WriteState::WriteSalt => {
241                    ready!(self.poll_write_salt(cx))?;
242                    self.write_state = WriteState::WriteLength;
243                }
244                WriteState::WriteLength => {
245                    ready!(self.poll_write_length(cx, payload))?;
246                    self.write_state = WriteState::WritePayload;
247                }
248                WriteState::WritePayload => {
249                    ready!(self.poll_write_payload(cx, payload))?;
250                    self.write_state = WriteState::WritePayloadOut;
251                }
252                WriteState::WritePayloadOut => {
253                    while !self.out_payload.is_empty() {
254                        let nwrite = ready!(
255                            Pin::new(&mut self.inner_stream).poll_write(cx, &self.out_payload)
256                        )?;
257
258                        self.out_payload = self.out_payload[nwrite..].to_vec();
259                    }
260
261                    self.write_state = WriteState::WriteLength;
262
263                    let length = usize::min(payload.len(), MAXIMUM_PAYLOAD_SIZE);
264                    return Ok(length).into();
265                }
266            }
267        }
268    }
269
270    fn poll_write_salt(&mut self, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
271        use rand::prelude::*;
272
273        if self.enc_cipher.is_none() {
274            let mut salt = vec![0u8; self.cipher_method.salt_size()];
275            let mut rng = StdRng::from_entropy();
276            rng.fill_bytes(&mut salt);
277
278            let mut subkey = vec![0u8; self.cipher_method.key_size()];
279            hkdf_sha1(&self.cipher_key, &salt, &mut subkey);
280
281            let cipher = Cipher::new(self.cipher_method, &mut subkey);
282            self.enc_cipher.replace(cipher);
283
284            self.out_payload.append(&mut salt);
285        }
286
287        Ok(()).into()
288    }
289
290    fn poll_write_length(&mut self, _cx: &mut Context<'_>, payload: &[u8]) -> Poll<io::Result<()>> {
291        let length = usize::min(payload.len(), MAXIMUM_PAYLOAD_SIZE);
292        let len = (length as u16).to_be_bytes();
293
294        let mut buf = self.encrypt(&len)?;
295        self.out_payload.append(&mut buf);
296
297        Ok(()).into()
298    }
299
300    fn poll_write_payload(
301        &mut self,
302        _cx: &mut Context<'_>,
303        payload: &[u8],
304    ) -> Poll<io::Result<()>> {
305        let length = usize::min(payload.len(), MAXIMUM_PAYLOAD_SIZE);
306
307        let mut buf = self.encrypt(&payload[..length])?;
308        self.out_payload.append(&mut buf);
309
310        Ok(()).into()
311    }
312}
313
314impl<T> AsyncRead for TcpStream<T>
315where
316    T: AsyncRead + Unpin,
317{
318    fn poll_read(
319        self: Pin<&mut Self>,
320        cx: &mut Context<'_>,
321        buf: &mut ReadBuf<'_>,
322    ) -> Poll<io::Result<()>> {
323        self.get_mut().poll_read_decrypt_helper(cx, buf)
324    }
325}
326
327impl<T> AsyncWrite for TcpStream<T>
328where
329    T: AsyncWrite + Unpin,
330{
331    fn poll_write(
332        self: Pin<&mut Self>,
333        cx: &mut Context<'_>,
334        buf: &[u8],
335    ) -> Poll<io::Result<usize>> {
336        self.get_mut().poll_write_encrypt(cx, buf)
337    }
338
339    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
340        let inner_stream = &mut self.get_mut().inner_stream;
341        Pin::new(inner_stream).poll_flush(cx)
342    }
343
344    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
345        let inner_stream = &mut self.get_mut().inner_stream;
346        Pin::new(inner_stream).poll_shutdown(cx)
347    }
348}
349
350pin_project! {
351    /// A stream with timeout.
352    ///
353    /// A successful read or write on the stream will reset the timeout.
354    pub struct TimeoutStream<T> {
355        #[pin]
356        inner_stream: T,
357
358        duration: Duration,
359        sleep: Option<Sleep>,
360    }
361}
362
363impl<T> TimeoutStream<T> {
364    /// Creates a new timeout stream with the given duration.
365    pub fn new(inner_stream: T, duration: Duration) -> Self {
366        TimeoutStream {
367            inner_stream,
368            duration,
369            sleep: None,
370        }
371    }
372}
373
374impl<T> TimeoutStream<T> {
375    fn check_timeout(sleep: Pin<&mut Sleep>, cx: &mut Context<'_>) -> io::Result<()> {
376        match sleep.poll(cx) {
377            Poll::Ready(_) => Err(io::ErrorKind::TimedOut.into()),
378            Poll::Pending => Ok(()),
379        }
380    }
381
382    fn reset_timeout(sleep: Pin<&mut Sleep>, duration: Duration) {
383        sleep.reset(Instant::now() + duration);
384    }
385}
386
387impl<T> AsyncRead for TimeoutStream<T>
388where
389    T: AsyncRead,
390{
391    fn poll_read(
392        self: Pin<&mut Self>,
393        cx: &mut Context<'_>,
394        buf: &mut ReadBuf<'_>,
395    ) -> Poll<io::Result<()>> {
396        let this = self.project();
397        let ret = this.inner_stream.poll_read(cx, buf);
398
399        let sleep = unsafe {
400            Pin::new_unchecked(
401                this.sleep
402                    .get_or_insert(tokio::time::sleep_until(Instant::now() + *this.duration)),
403            )
404        };
405
406        match ret {
407            Poll::Ready(_) => Self::reset_timeout(sleep, *this.duration),
408            Poll::Pending => Self::check_timeout(sleep, cx)?,
409        }
410
411        ret
412    }
413}
414
415impl<T> AsyncWrite for TimeoutStream<T>
416where
417    T: AsyncWrite,
418{
419    fn poll_write(
420        self: Pin<&mut Self>,
421        cx: &mut Context<'_>,
422        buf: &[u8],
423    ) -> Poll<Result<usize, io::Error>> {
424        let this = self.project();
425        let ret = this.inner_stream.poll_write(cx, buf);
426
427        let sleep = unsafe {
428            Pin::new_unchecked(
429                this.sleep
430                    .get_or_insert(tokio::time::sleep_until(Instant::now() + *this.duration)),
431            )
432        };
433
434        match ret {
435            Poll::Ready(_) => Self::reset_timeout(sleep, *this.duration),
436            Poll::Pending => Self::check_timeout(sleep, cx)?,
437        }
438
439        ret
440    }
441
442    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
443        let this = self.project();
444        this.inner_stream.poll_flush(cx)
445    }
446
447    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
448        let this = self.project();
449        this.inner_stream.poll_shutdown(cx)
450    }
451}
452
/// Errors during shadowsocks communication.
///
/// `TcpStream` reports these wrapped in an `io::Error` of kind `Other`. Callers can
/// recover them with `io::Error::get_ref` followed by `downcast_ref::<Error>()` and
/// compare them directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// Encryption error.
    Encryption,

    /// Decryption error.
    Decryption,

    /// Duplicate salt received, possible replay attack.
    DuplicateSalt,
}

impl Display for Error {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let message = match self {
            Error::Encryption => "encryption error",
            Error::Decryption => "decryption error",
            Error::DuplicateSalt => "duplicate salt received, possible replay attack",
        };
        f.write_str(message)
    }
}

impl std::error::Error for Error {}
477
/// Position of `TcpStream`'s read side in the AEAD framing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ReadState {
    /// Waiting for the peer's salt (start of stream).
    ReadSalt,
    /// Waiting for the next encrypted length header.
    ReadLength,
    /// Waiting for an encrypted payload with the given plaintext length.
    ReadPayload(usize),
    /// Copying buffered decrypted payload out to the reader.
    ReadPayloadOut,
}
484
/// Position of `TcpStream`'s write side in the AEAD framing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WriteState {
    /// The outgoing salt has not been generated yet.
    WriteSalt,
    /// Ready to encrypt the next chunk's length header.
    WriteLength,
    /// Ready to encrypt the next chunk's payload.
    WritePayload,
    /// Flushing queued encrypted bytes to the inner stream.
    WritePayloadOut,
}
491
492// #[cfg(test)]
493// mod tests {
494//     use std::{pin::Pin, time::Duration};
495
496//     use tokio::{
497//         io::{AsyncReadExt, AsyncWriteExt},
498//         net::TcpListener,
499//     };
500
501//     use super::TimeoutStream;
502
503//     #[tokio::test]
504//     async fn test_timeout() {
505//         let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();
506
507//         loop {
508//             let (stream, _) = listener.accept().await.unwrap();
509//             let mut stream = TimeoutStream::new(stream, Duration::from_secs(3));
510
511//             tokio::spawn(async move {
512//                 let mut stream = unsafe { Pin::new_unchecked(&mut stream) };
513
514//                 loop {
515//                     let mut buf = [0u8; 1024];
516//                     match stream.read(&mut buf).await {
517//                         Ok(0) => return,
518//                         Ok(n) => {
519//                             print!("{}", String::from_utf8(buf[..n].to_owned()).unwrap());
520//                             stream.write(&buf[..n]).await.unwrap();
521//                         }
522//                         Err(e) => {
523//                             println!("{}", e);
524//                             return;
525//                         }
526//                     }
527//                 }
528//             });
529//         }
530//     }
531// }