// segment_rs/connection.rs

1use crate::frame::{
2    self, Frame, ParseFrameError, ARRAY_IDENT, BOOLEAN_IDENT, DOUBLE_IDENT, ERROR_IDENT,
3    INTEGER_IDENT, MAP_IDENT, STRING_IDENT,
4};
5use bytes::{Buf, BytesMut};
6use std::io::{self, Cursor};
7use thiserror::Error;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::net::TcpStream;
10
/// Where to connect: the host and TCP port of a Segment server.
///
/// The fields are private; build a value with `ConnectionOptions::new` and read it back
/// through the `host()` / `port()` accessors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectionOptions {
    /// Host to connect to; combined with `port` when opening the TCP connection.
    host: String,
    /// TCP port to connect to.
    port: u16,
}

18#[derive(Debug)]
19/// Represents a Segment connection
20pub struct Connection {
21    stream: TcpStream,
22    buf: BytesMut,
23}
24
25#[derive(Debug, Error)]
26/// Represents a connection error
27pub enum ConnectionError {
28    /// Represents a TCP connection error
29    #[error(transparent)]
30    TCPError(#[from] io::Error),
31
32    /// Occurs when the connection is prematurely closed by the server
33    #[error("server did not send any response")]
34    Eof,
35
36    /// Occurs when there is an error in parsing the frame
37    #[error(transparent)]
38    FrameError(#[from] ParseFrameError),
39}
40
41impl Connection {
42    /// Creates a new connection from a TcpStream
43    pub async fn connect(options: &ConnectionOptions) -> Result<Self, ConnectionError> {
44        let stream = TcpStream::connect(format!("{}:{}", options.host(), options.port())).await?;
45        Ok(Connection {
46            stream,
47            buf: BytesMut::with_capacity(4096),
48        })
49    }
50
51    /// Reads a frame from the connection and parses it
52    pub async fn read_frame(&mut self) -> Result<Frame, ConnectionError> {
53        loop {
54            if let Some(frame) = self.parse_frame()? {
55                return Ok(frame);
56            }
57
58            if self.stream.read_buf(&mut self.buf).await? == 0 {
59                return Err(ConnectionError::Eof);
60            }
61        }
62    }
63
64    fn parse_frame(&mut self) -> Result<Option<Frame>, ConnectionError> {
65        let mut cursor = Cursor::new(&self.buf[..]);
66        match frame::parse(&mut cursor) {
67            Ok(frame) => {
68                self.buf.advance(cursor.position() as usize);
69                Ok(Some(frame))
70            }
71            Err(ParseFrameError::Incomplete) => Ok(None),
72            Err(e) => Err(e.into()),
73        }
74    }
75
76    /// Writes a frame to the connection
77    pub async fn write_frame(&mut self, frame: &Frame) -> Result<(), ConnectionError> {
78        match frame {
79            Frame::Array(array) => {
80                self.stream.write_u8(ARRAY_IDENT).await?;
81                self.stream
82                    .write_all(format!("{}\r\n", array.len()).as_bytes())
83                    .await?;
84                for value in array {
85                    self.write_value(value).await?;
86                }
87            }
88            Frame::Map(map) => {
89                self.stream.write_u8(MAP_IDENT).await?;
90                self.stream
91                    .write_all(format!("{}\r\n", map.len() / 2).as_bytes())
92                    .await?;
93                for value in map {
94                    self.write_value(value).await?;
95                }
96            }
97            _ => self.write_value(frame).await?,
98        }
99
100        self.stream.flush().await?;
101        Ok(())
102    }
103
104    async fn write_value(&mut self, frame: &Frame) -> Result<(), ConnectionError> {
105        match frame {
106            Frame::String(data) => {
107                let len = data.len();
108                self.stream.write_u8(STRING_IDENT).await?;
109                self.stream
110                    .write_all(format!("{}\r\n", len).as_bytes())
111                    .await?;
112                self.stream.write_all(data).await?;
113                self.stream.write_all(b"\r\n").await?;
114            }
115            Frame::Integer(data) => {
116                self.stream.write_u8(INTEGER_IDENT).await?;
117                self.stream
118                    .write_all(format!("{}\r\n", data).as_bytes())
119                    .await?;
120            }
121            Frame::Boolean(data) => {
122                self.stream.write_u8(BOOLEAN_IDENT).await?;
123                if *data {
124                    self.stream
125                        .write_all(format!("{}\r\n", 1).as_bytes())
126                        .await?;
127                } else {
128                    self.stream
129                        .write_all(format!("{}\r\n", 0).as_bytes())
130                        .await?;
131                }
132            }
133            Frame::Null => {
134                self.stream.write_all(b"-\r\n").await?;
135            }
136            Frame::Double(data) => {
137                self.stream.write_u8(DOUBLE_IDENT).await?;
138                self.stream
139                    .write_all(format!("{}\r\n", data).as_bytes())
140                    .await?;
141            }
142            Frame::Error(data) => {
143                let len = data.len();
144                self.stream.write_u8(ERROR_IDENT).await?;
145                self.stream
146                    .write_all(format!("{}\r\n", len).as_bytes())
147                    .await?;
148                self.stream.write_all(data).await?;
149                self.stream.write_all(b"\r\n").await?;
150            }
151            _ => unreachable!(),
152        }
153
154        Ok(())
155    }
156}
157
158impl ConnectionOptions {
159    /// Creates a new connection option
160    pub fn new(host: &str, port: u16) -> Self {
161        ConnectionOptions {
162            host: host.to_string(),
163            port,
164        }
165    }
166
167    /// Returns the connection host
168    pub fn host(&self) -> &str {
169        &self.host
170    }
171
172    /// Returns the connection port
173    pub fn port(&self) -> u16 {
174        self.port
175    }
176}