Skip to main content

rusty_tpkt/
service.rs

1use std::{collections::VecDeque, net::SocketAddr};
2
3use bytes::{Buf, BytesMut};
4use tokio::{
5    io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf, split},
6    net::{TcpListener, TcpStream},
7};
8
9use crate::{
10    TpktConnection, TpktError, TpktReader, TpktRecvResult, TpktWriter,
11    parser::{TpktParser, TpktParserResult},
12    serialiser::TpktSerialiser,
13};
14
15/// A TPKT server implemented over a TCP connection.
16pub struct TcpTpktServer {
17    listener: TcpListener,
18}
19
20impl TcpTpktServer {
21    /// Start listening on the provided TCP port.
22    pub async fn listen(address: SocketAddr) -> Result<Self, TpktError> {
23        Ok(Self { listener: TcpListener::bind(address).await? })
24    }
25
26    /// Accept an incoming connection. This may be called multiple times.
27    pub async fn accept<'a>(&self) -> Result<(TcpTpktConnection, SocketAddr), TpktError> {
28        let (stream, remote_host) = self.listener.accept().await?;
29        let (reader, writer) = split(stream);
30        Ok((TcpTpktConnection::new(TcpTpktReader::new(reader), TcpTpktWriter::new(writer)), remote_host))
31    }
32}
33
34/// An established TPKT connection.
35pub struct TcpTpktConnection {
36    reader: TcpTpktReader,
37    writer: TcpTpktWriter,
38}
39
40impl TcpTpktConnection {
41    /// Initiates a client TPKT connection.
42    pub async fn connect<'a>(address: SocketAddr) -> Result<TcpTpktConnection, TpktError> {
43        let stream = TcpStream::connect(address).await?;
44        let (reader, writer) = split(stream);
45        return Ok(TcpTpktConnection::new(TcpTpktReader::new(reader), TcpTpktWriter::new(writer)));
46    }
47
48    fn new(reader: TcpTpktReader, writer: TcpTpktWriter) -> Self {
49        TcpTpktConnection { reader, writer }
50    }
51}
52
53impl TpktConnection for TcpTpktConnection {
54    async fn split(self) -> Result<(impl TpktReader, impl TpktWriter), TpktError> {
55        Ok((self.reader, self.writer))
56    }
57}
58
59/// The read half of a TPKT connection.
60pub struct TcpTpktReader {
61    parser: TpktParser,
62    receive_buffer: BytesMut,
63    reader: ReadHalf<TcpStream>,
64}
65
66impl TcpTpktReader {
67    fn new(reader: ReadHalf<TcpStream>) -> Self {
68        Self { reader, parser: TpktParser::new(), receive_buffer: BytesMut::new() }
69    }
70}
71
72impl TpktReader for TcpTpktReader {
73    async fn recv(&mut self) -> Result<TpktRecvResult, TpktError> {
74        loop {
75            let buffer = &mut self.receive_buffer;
76            match self.parser.parse(buffer) {
77                Ok(TpktParserResult::Data(x)) => return Ok(TpktRecvResult::Data(x)),
78                Ok(TpktParserResult::InProgress) => (),
79                Err(x) => return Err(x),
80            };
81            if self.reader.read_buf(buffer).await? == 0 {
82                return Ok(TpktRecvResult::Closed);
83            };
84        }
85    }
86}
87
88/// The write half of a TPKT connection.
89pub struct TcpTpktWriter {
90    write_buffer: BytesMut,
91    serialiser: TpktSerialiser,
92    writer: WriteHalf<TcpStream>,
93}
94
95impl TcpTpktWriter {
96    fn new(writer: WriteHalf<TcpStream>) -> Self {
97        Self { serialiser: TpktSerialiser::new(), writer, write_buffer: BytesMut::new() }
98    }
99}
100
101impl TpktWriter for TcpTpktWriter {
102    async fn send(&mut self, input: &mut VecDeque<Vec<u8>>) -> Result<(), TpktError> {
103        while let Some(packet) = input.pop_front() {
104            self.write_buffer.extend(self.serialiser.serialise(&packet)?);
105        }
106
107        while self.write_buffer.has_remaining() {
108            self.writer.write_buf(&mut self.write_buffer).await?;
109        }
110        Ok(())
111    }
112}