1use bytes::Bytes;
4use futures::SinkExt;
5use std::io;
6use tokio_stream::{self as stream, StreamExt};
7
8use crate::codec::{BlockMessage, ChannelMessage};
9use crate::connection::Connection;
10
/// Protocol handle wrapping an underlying transport `C`.
///
/// The struct itself places no bounds on `C`; the block/channel messaging
/// helpers are only available when `C: Connection`.
#[derive(Debug)]
pub struct Protocol<C> {
    /// The wrapped transport.
    conn: C,
}
19
20impl<C> Protocol<C> {
21 pub fn new(conn: C) -> Self {
23 Protocol { conn }
24 }
25
26 pub fn into_connection(self) -> C {
28 self.conn
29 }
30}
31
32#[cfg(unix)]
33mod unix {
34 use super::*;
35 use std::os::unix::io::{AsRawFd, RawFd};
36
37 impl<C> AsRawFd for Protocol<C>
38 where
39 C: AsRawFd,
40 {
41 fn as_raw_fd(&self) -> RawFd {
42 self.conn.as_raw_fd()
43 }
44 }
45}
46
47impl<C> Protocol<C>
48where
49 C: Connection,
50{
51 pub async fn send_command(&mut self, cmd: impl Into<Bytes>) -> io::Result<()> {
56 self.conn
57 .get_tx_mut()
58 .send(BlockMessage::Command(cmd.into()))
59 .await?;
60 Ok(())
61 }
62
63 pub async fn send_command_with_args(
68 &mut self,
69 cmd: impl Into<Bytes>,
70 packed_args: impl Into<Bytes>,
71 ) -> io::Result<()> {
72 let blocks = vec![
73 Ok(BlockMessage::Command(cmd.into())),
74 Ok(BlockMessage::Data(packed_args.into())),
75 ];
76 self.conn
77 .get_tx_mut()
78 .send_all(&mut stream::iter(blocks))
79 .await?;
80 Ok(())
81 }
82
83 pub async fn send_data(&mut self, data: impl Into<Bytes>) -> io::Result<()> {
88 self.conn
89 .get_tx_mut()
90 .send(BlockMessage::Data(data.into()))
91 .await?;
92 Ok(())
93 }
94
95 pub async fn query(&mut self, cmd: impl Into<Bytes>) -> io::Result<Bytes> {
99 self.send_command(cmd).await?;
100 self.fetch_result().await
101 }
102
103 pub async fn query_with_args(
107 &mut self,
108 cmd: impl Into<Bytes>,
109 packed_args: impl Into<Bytes>,
110 ) -> io::Result<Bytes> {
111 self.send_command_with_args(cmd, packed_args).await?;
112 self.fetch_result().await
113 }
114
115 pub async fn fetch_response(&mut self) -> io::Result<ChannelMessage> {
119 let v = self.conn.get_rx_mut().try_next().await?;
120 expect_msg(v)
121 }
122
123 async fn fetch_result(&mut self) -> io::Result<Bytes> {
124 loop {
125 match self.fetch_response().await? {
126 ChannelMessage::Data(b'r', data) => {
127 return Ok(data);
128 }
129 ChannelMessage::Data(..) => {
130 }
132 ChannelMessage::InputRequest(..)
133 | ChannelMessage::LineRequest(..)
134 | ChannelMessage::SystemRequest(..) => {
135 return Err(io::Error::new(
136 io::ErrorKind::InvalidData,
137 "unsupported request while querying",
138 ));
139 }
140 }
141 }
142 }
143}
144
/// Unwraps an optional message, turning `None` into an error.
///
/// Returns the contained value when `v` is `Some`; otherwise returns an
/// `io::ErrorKind::UnexpectedEof` error with the message
/// "no result code received".
///
/// Generic over the message type; the existing caller passing
/// `Option<ChannelMessage>` is unaffected (the type is inferred).
fn expect_msg<T>(v: Option<T>) -> Result<T, io::Error> {
    // `ok_or_else` builds the error lazily: `io::Error::new` allocates, and the
    // common `Some` path should not pay for an error it never returns.
    v.ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "no result code received"))
}