1use crate::frame::{
2 self, Frame, ParseFrameError, ARRAY_IDENT, BOOLEAN_IDENT, DOUBLE_IDENT, ERROR_IDENT,
3 INTEGER_IDENT, MAP_IDENT, STRING_IDENT,
4};
5use bytes::{Buf, BytesMut};
6use std::io::{self, Cursor};
7use thiserror::Error;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::net::TcpStream;
10
/// Where to connect: a host (name or IP literal) and a TCP port.
///
/// Plain value type; `Clone`/`Eq`/`Hash` are derived so callers can keep a copy
/// around (e.g. to reconnect later) or compare/store option sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ConnectionOptions {
    /// Host name or IP address, as passed to `ConnectionOptions::new`.
    host: String,
    /// TCP port number.
    port: u16,
}
17
18#[derive(Debug)]
19pub struct Connection {
21 stream: TcpStream,
22 buf: BytesMut,
23}
24
25#[derive(Debug, Error)]
26pub enum ConnectionError {
28 #[error(transparent)]
30 TCPError(#[from] io::Error),
31
32 #[error("server did not send any response")]
34 Eof,
35
36 #[error(transparent)]
38 FrameError(#[from] ParseFrameError),
39}
40
41impl Connection {
42 pub async fn connect(options: &ConnectionOptions) -> Result<Self, ConnectionError> {
44 let stream = TcpStream::connect(format!("{}:{}", options.host(), options.port())).await?;
45 Ok(Connection {
46 stream,
47 buf: BytesMut::with_capacity(4096),
48 })
49 }
50
51 pub async fn read_frame(&mut self) -> Result<Frame, ConnectionError> {
53 loop {
54 if let Some(frame) = self.parse_frame()? {
55 return Ok(frame);
56 }
57
58 if self.stream.read_buf(&mut self.buf).await? == 0 {
59 return Err(ConnectionError::Eof);
60 }
61 }
62 }
63
64 fn parse_frame(&mut self) -> Result<Option<Frame>, ConnectionError> {
65 let mut cursor = Cursor::new(&self.buf[..]);
66 match frame::parse(&mut cursor) {
67 Ok(frame) => {
68 self.buf.advance(cursor.position() as usize);
69 Ok(Some(frame))
70 }
71 Err(ParseFrameError::Incomplete) => Ok(None),
72 Err(e) => Err(e.into()),
73 }
74 }
75
76 pub async fn write_frame(&mut self, frame: &Frame) -> Result<(), ConnectionError> {
78 match frame {
79 Frame::Array(array) => {
80 self.stream.write_u8(ARRAY_IDENT).await?;
81 self.stream
82 .write_all(format!("{}\r\n", array.len()).as_bytes())
83 .await?;
84 for value in array {
85 self.write_value(value).await?;
86 }
87 }
88 Frame::Map(map) => {
89 self.stream.write_u8(MAP_IDENT).await?;
90 self.stream
91 .write_all(format!("{}\r\n", map.len() / 2).as_bytes())
92 .await?;
93 for value in map {
94 self.write_value(value).await?;
95 }
96 }
97 _ => self.write_value(frame).await?,
98 }
99
100 self.stream.flush().await?;
101 Ok(())
102 }
103
104 async fn write_value(&mut self, frame: &Frame) -> Result<(), ConnectionError> {
105 match frame {
106 Frame::String(data) => {
107 let len = data.len();
108 self.stream.write_u8(STRING_IDENT).await?;
109 self.stream
110 .write_all(format!("{}\r\n", len).as_bytes())
111 .await?;
112 self.stream.write_all(data).await?;
113 self.stream.write_all(b"\r\n").await?;
114 }
115 Frame::Integer(data) => {
116 self.stream.write_u8(INTEGER_IDENT).await?;
117 self.stream
118 .write_all(format!("{}\r\n", data).as_bytes())
119 .await?;
120 }
121 Frame::Boolean(data) => {
122 self.stream.write_u8(BOOLEAN_IDENT).await?;
123 if *data {
124 self.stream
125 .write_all(format!("{}\r\n", 1).as_bytes())
126 .await?;
127 } else {
128 self.stream
129 .write_all(format!("{}\r\n", 0).as_bytes())
130 .await?;
131 }
132 }
133 Frame::Null => {
134 self.stream.write_all(b"-\r\n").await?;
135 }
136 Frame::Double(data) => {
137 self.stream.write_u8(DOUBLE_IDENT).await?;
138 self.stream
139 .write_all(format!("{}\r\n", data).as_bytes())
140 .await?;
141 }
142 Frame::Error(data) => {
143 let len = data.len();
144 self.stream.write_u8(ERROR_IDENT).await?;
145 self.stream
146 .write_all(format!("{}\r\n", len).as_bytes())
147 .await?;
148 self.stream.write_all(data).await?;
149 self.stream.write_all(b"\r\n").await?;
150 }
151 _ => unreachable!(),
152 }
153
154 Ok(())
155 }
156}
157
158impl ConnectionOptions {
159 pub fn new(host: &str, port: u16) -> Self {
161 ConnectionOptions {
162 host: host.to_string(),
163 port,
164 }
165 }
166
167 pub fn host(&self) -> &str {
169 &self.host
170 }
171
172 pub fn port(&self) -> u16 {
174 self.port
175 }
176}