1use http;
2use crate::{
3 codec::{
4 FrameCodec, FrameConfig, FrameReadState, FrameRecv, FrameSend, FrameWriteState, Split,
5 },
6 errors::{ProtocolError, WsError},
7 frame::OpCode,
8 protocol::standard_handshake_resp_check,
9 Message,
10};
11use bytes::Buf;
12use std::borrow::Cow;
13use std::io::{Read, Write};
14
/// Generates the receive half of the string codec API (`receive_raw` and
/// `receive`).
///
/// The implementing type must have a `frame_codec` field whose `receive()`
/// returns a frame header plus a payload slice borrowed from the codec, and
/// (for `receive`) a `validate_utf8: bool` field.
macro_rules! impl_recv {
    () => {
        /// Receives the next frame and returns its payload as raw bytes.
        ///
        /// For a close frame with at least two payload bytes, the first two
        /// bytes are consumed as the big-endian close code and returned in
        /// `close_code`; `data` then holds only the remaining reason bytes.
        /// Every other frame (including a close frame shorter than two bytes)
        /// gets `close_code: None` and its full payload.
        ///
        /// # Errors
        /// Propagates any error returned by the underlying frame codec.
        pub fn receive_raw(&mut self) -> Result<Message<Cow<[u8]>>, WsError> {
            let (header, mut data) = self.frame_codec.receive()?;
            let close_code = if header.code == OpCode::Close && data.len() >= 2 {
                // `Buf::get_u16` reads big-endian and advances `data` past the code.
                Some(data.get_u16())
            } else {
                None
            };
            Ok(Message {
                data: Cow::Borrowed(data),
                close_code,
                code: header.code,
            })
        }

        /// Receives the next frame and returns its payload as text.
        ///
        /// Close-code extraction matches `receive_raw`. When `validate_utf8`
        /// is set, a text frame whose payload is not valid UTF-8 is rejected.
        /// Every other payload (binary and control frames, or any frame when
        /// validation is off) is decoded lossily. Valid UTF-8 is borrowed
        /// without copying; invalid sequences become U+FFFD in an owned string.
        ///
        /// # Errors
        /// * Any error returned by the underlying frame codec.
        /// * `WsError::ProtocolError` with `close_code: 1007` (invalid frame
        ///   payload data, RFC 6455 §7.4.1) and `ProtocolError::InvalidUtf8`
        ///   when validation is enabled and a text payload is not UTF-8.
        pub fn receive(&mut self) -> Result<Message<Cow<str>>, WsError> {
            let (header, mut data) = self.frame_codec.receive()?;
            let close_code = if header.code == OpCode::Close && data.len() >= 2 {
                // `Buf::get_u16` reads big-endian and advances `data` past the code.
                Some(data.get_u16())
            } else {
                None
            };
            let data = if self.validate_utf8 && header.code == OpCode::Text {
                let text = std::str::from_utf8(data).map_err(|_| WsError::ProtocolError {
                    close_code: 1007,
                    error: ProtocolError::InvalidUtf8,
                })?;
                Cow::Borrowed(text)
            } else {
                // These bytes come from the peer. Turning them into `&str`
                // without checking is undefined behaviour when they are not
                // UTF-8, so decode lossily; this stays zero-copy for valid input.
                String::from_utf8_lossy(data)
            };
            Ok(Message {
                data,
                close_code,
                code: header.code,
            })
        }
    };
}
68
/// Generates the send half of the string codec API (`ping`, `pong`, `close`,
/// `send` and `flush`).
///
/// The implementing type must have a `frame_codec` field exposing
/// `send(opcode, payload)` and `flush()`.
macro_rules! impl_send {
    () => {
        /// Sends a ping frame whose payload is `msg`.
        pub fn ping<'a>(&mut self, msg: &'a str) -> Result<(), WsError> {
            self.send((OpCode::Ping, msg))
        }

        /// Sends a pong frame whose payload is `msg`.
        pub fn pong<'a>(&mut self, msg: &'a str) -> Result<(), WsError> {
            self.send((OpCode::Pong, msg))
        }

        /// Sends a close message built from the status `code` and reason `msg`.
        pub fn close<'a>(&mut self, code: u16, msg: &'a str) -> Result<(), WsError> {
            self.send((code, msg))
        }

        /// Converts `msg` into a `Message` and writes it as a single frame.
        ///
        /// A close frame that carries a close code has that code written
        /// first, as two big-endian bytes, followed by the text. Any other
        /// message is sent with its text as the payload; a `close_code` on a
        /// non-close opcode is ignored.
        pub fn send<'a, T: Into<Message<Cow<'a, str>>>>(&mut self, msg: T) -> Result<(), WsError> {
            let message: Message<Cow<'a, str>> = msg.into();
            let text_bytes = message.data.as_bytes();
            match message.close_code {
                Some(status) if message.code == OpCode::Close => {
                    let mut framed = Vec::with_capacity(2 + text_bytes.len());
                    framed.extend_from_slice(&status.to_be_bytes());
                    framed.extend_from_slice(text_bytes);
                    self.frame_codec.send(message.code, &framed)
                }
                _ => self.frame_codec.send(message.code, text_bytes),
            }
        }

        /// Flushes any buffered frames to the underlying stream.
        pub fn flush(&mut self) -> Result<(), WsError> {
            self.frame_codec.flush()
        }
    };
}
108
109pub struct StringRecv<S: Read> {
111 frame_codec: FrameRecv<S>,
112 validate_utf8: bool,
113}
114
115impl<S: Read> StringRecv<S> {
116 pub fn new(stream: S, state: FrameReadState, validate_utf8: bool) -> Self {
118 Self {
119 frame_codec: FrameRecv::new(stream, state),
120 validate_utf8,
121 }
122 }
123
124 impl_recv! {}
125}
126
127pub struct StringSend<S: Write> {
129 frame_codec: FrameSend<S>,
130}
131
132impl<S: Write> StringSend<S> {
133 pub fn new(stream: S, state: FrameWriteState) -> Self {
135 Self {
136 frame_codec: FrameSend::new(stream, state),
137 }
138 }
139
140 impl_send! {}
141}
142
143pub struct StringCodec<S: Read + Write> {
145 frame_codec: FrameCodec<S>,
146 validate_utf8: bool,
147}
148
149impl<S: Read + Write> StringCodec<S> {
150 pub fn new(stream: S) -> Self {
152 Self {
153 frame_codec: FrameCodec::new(stream),
154 validate_utf8: false,
155 }
156 }
157
158 pub fn new_with(stream: S, config: FrameConfig, validate_utf8: bool) -> Self {
160 Self {
161 frame_codec: FrameCodec::new_with(stream, config),
162 validate_utf8,
163 }
164 }
165
166 pub fn stream_mut(&mut self) -> &mut S {
168 self.frame_codec.stream_mut()
169 }
170
171 pub fn factory(_req: http::Request<()>, stream: S) -> Result<Self, WsError> {
173 let config = FrameConfig {
174 mask_send_frame: false,
175 ..Default::default()
176 };
177 Ok(Self::new_with(stream, config, true))
178 }
179
180 pub fn check_fn(key: String, resp: http::Response<()>, stream: S) -> Result<Self, WsError> {
182 standard_handshake_resp_check(key.as_bytes(), &resp)?;
183 Ok(Self::new_with(stream, FrameConfig::default(), true))
184 }
185
186 impl_recv! {}
187
188 impl_send! {}
189}
190
191impl<R, W, S> StringCodec<S>
192where
193 R: Read,
194 W: Write,
195 S: Read + Write + Split<R = R, W = W>,
196{
197 pub fn split(self) -> (StringRecv<R>, StringSend<W>) {
199 let FrameCodec {
200 stream,
201 read_state,
202 write_state,
203 } = self.frame_codec;
204 let (read, write) = stream.split();
205 (
206 StringRecv::new(read, read_state, self.validate_utf8),
207 StringSend::new(write, write_state),
208 )
209 }
210}