1use std::collections::HashMap;
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use scylla_cql::frame::frame_errors::FrameHeaderParseError;
5use scylla_cql::frame::protocol_features::ProtocolFeatures;
6pub use scylla_cql::frame::request::RequestOpcode;
7use scylla_cql::frame::request::{RequestDeserializationError, RequestV2};
8pub use scylla_cql::frame::response::ResponseOpcode;
9use scylla_cql::frame::{response::error::DbError, types};
10use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
11
12use tracing::warn;
13
14use crate::errors::ReadFrameError;
15use crate::proxy::CompressionReader;
16
/// Size in bytes of a CQL frame header:
/// version (1) + flags (1) + stream (2) + opcode (1) + body length (4).
const HEADER_SIZE: usize = 1 + 1 + 2 + 1 + 4;
18
/// Header fields of a frame other than the opcode and the body length.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct FrameParams {
    /// Version byte: the top bit encodes the frame direction (clear for requests,
    /// set for responses), the low 7 bits hold the protocol version.
    pub version: u8,
    /// Frame flags byte, copied verbatim from/to the header.
    pub flags: u8,
    /// Stream id used to match responses with requests.
    pub stream: i16,
}

impl FrameParams {
    /// Bit of `version` that is set on response frames and cleared on request frames.
    const RESPONSE_DIRECTION_BIT: u8 = 0x80;

    /// Returns a copy of these params marked as a request (direction bit cleared).
    pub const fn for_request(&self) -> FrameParams {
        FrameParams {
            version: self.version & !Self::RESPONSE_DIRECTION_BIT,
            flags: self.flags,
            stream: self.stream,
        }
    }

    /// Returns a copy of these params marked as a response (direction bit set).
    pub const fn for_response(&self) -> FrameParams {
        FrameParams {
            version: self.version | Self::RESPONSE_DIRECTION_BIT,
            flags: self.flags,
            stream: self.stream,
        }
    }
}
41
/// Expected kind of a frame being read, selecting how its header is validated.
#[derive(Clone, Copy, Debug)]
pub(crate) enum FrameType {
    /// A request frame: the version byte's direction bit must be clear.
    Request,
    /// A response frame: the version byte's direction bit must be set.
    Response,
}
47
48#[derive(Copy, Clone, Debug, PartialEq, Eq)]
49pub(crate) enum FrameOpcode {
50 Request(RequestOpcode),
51 Response(ResponseOpcode),
52}
53
54#[derive(Clone, Debug, PartialEq, Eq)]
55pub struct RequestFrame {
56 pub params: FrameParams,
57 pub opcode: RequestOpcode,
58 pub body: Bytes,
59}
60
61impl RequestFrame {
62 pub(crate) async fn write(
63 &self,
64 writer: &mut (impl AsyncWrite + Unpin),
65 compression: &CompressionReader,
66 ) -> Result<(), tokio::io::Error> {
67 write_frame(
68 self.params,
69 FrameOpcode::Request(self.opcode),
70 &self.body,
71 writer,
72 compression,
73 )
74 .await
75 }
76
77 pub fn deserialize(
78 &self,
79 features: &ProtocolFeatures,
80 ) -> Result<RequestV2<'_>, RequestDeserializationError> {
81 RequestV2::deserialize(&mut &self.body[..], self.opcode, features)
82 }
83}
84#[derive(Clone, Debug, PartialEq, Eq)]
85pub struct ResponseFrame {
86 pub params: FrameParams,
87 pub opcode: ResponseOpcode,
88 pub body: Bytes,
89}
90
91impl ResponseFrame {
92 pub fn forged_error(
95 request_params: FrameParams,
96 error: DbError,
97 msg: Option<&str>,
98 ) -> Result<Self, std::num::TryFromIntError> {
99 let msg = msg.unwrap_or("Proxy-triggered error.");
100 let len_bytes = (msg.len() as u16).to_be_bytes(); let code_bytes = error.code(&ProtocolFeatures::default()).to_be_bytes(); let body_len = msg.len() + code_bytes.len() + len_bytes.len();
103 let mut buf = BytesMut::with_capacity(body_len);
104
105 buf.extend_from_slice(&code_bytes);
106 buf.extend_from_slice(&len_bytes);
107 buf.extend_from_slice(msg.as_bytes());
108
109 serialize_error_specific_fields(&mut buf, error)?;
110
111 Ok(ResponseFrame {
112 params: request_params.for_response(),
113 opcode: ResponseOpcode::Error,
114 body: buf.freeze(),
115 })
116 }
117
118 pub fn forged_supported(
120 request_params: FrameParams,
121 options: &HashMap<String, Vec<String>>,
122 ) -> Result<Self, std::num::TryFromIntError> {
123 let mut buf = BytesMut::new();
124 types::write_string_multimap(options, &mut buf)?;
125
126 Ok(ResponseFrame {
127 params: request_params.for_response(),
128 opcode: ResponseOpcode::Supported,
129 body: buf.freeze(),
130 })
131 }
132
133 pub fn forged_ready(request_params: FrameParams) -> Self {
134 ResponseFrame {
135 params: request_params.for_response(),
136 opcode: ResponseOpcode::Ready,
137 body: Bytes::new(),
138 }
139 }
140
141 pub(crate) async fn write(
142 &self,
143 writer: &mut (impl AsyncWrite + Unpin),
144 compression: &CompressionReader,
145 ) -> Result<(), tokio::io::Error> {
146 write_frame(
147 self.params,
148 FrameOpcode::Response(self.opcode),
149 &self.body,
150 writer,
151 compression,
152 )
153 .await
154 }
155}
156
/// Appends the error-kind-specific part of an `ERROR` response body for `error` to `buf`.
///
/// Only the variants matched below contribute extra fields. Every other variant adds
/// nothing beyond what was already written into `buf` (code and message).
///
/// # Errors
/// Propagates the `TryFromIntError` returned by the fallible `types::write_*` helpers
/// (presumably raised when a string/list/bytes value is too long for its length
/// prefix — confirm against scylla-cql).
fn serialize_error_specific_fields(
    buf: &mut BytesMut,
    error: DbError,
) -> Result<(), std::num::TryFromIntError> {
    // Within each arm, the write order is the order the fields appear in the body;
    // do not reorder.
    match error {
        DbError::Unavailable {
            consistency,
            required,
            alive,
        } => {
            types::write_consistency(consistency, buf);
            types::write_int(required, buf);
            types::write_int(alive, buf);
        }
        DbError::WriteTimeout {
            consistency,
            received,
            required,
            write_type,
        } => {
            types::write_consistency(consistency, buf);
            types::write_int(received, buf);
            types::write_int(required, buf);
            types::write_string(write_type.as_str(), buf)?;
        }
        DbError::ReadTimeout {
            consistency,
            received,
            required,
            data_present,
        } => {
            types::write_consistency(consistency, buf);
            types::write_int(received, buf);
            types::write_int(required, buf);
            // The boolean is encoded as a single byte (0 or 1).
            buf.put_u8(u8::from(data_present));
        }
        DbError::ReadFailure {
            consistency,
            received,
            required,
            numfailures,
            data_present,
        } => {
            types::write_consistency(consistency, buf);
            types::write_int(received, buf);
            types::write_int(required, buf);
            types::write_int(numfailures, buf);
            buf.put_u8(u8::from(data_present));
        }
        DbError::WriteFailure {
            consistency,
            received,
            required,
            numfailures,
            write_type,
        } => {
            types::write_consistency(consistency, buf);
            types::write_int(received, buf);
            types::write_int(required, buf);
            types::write_int(numfailures, buf);
            types::write_string(write_type.as_str(), buf)?;
        }
        DbError::FunctionFailure {
            keyspace,
            function,
            arg_types,
        } => {
            types::write_string(keyspace.as_str(), buf)?;
            types::write_string(function.as_str(), buf)?;
            types::write_string_list(&arg_types, buf)?;
        }
        DbError::AlreadyExists { keyspace, table } => {
            types::write_string(keyspace.as_str(), buf)?;
            types::write_string(table.as_str(), buf)?;
        }
        DbError::Unprepared { statement_id } => {
            types::write_short_bytes(statement_id.as_ref(), buf)?;
        }
        // NOTE(review): every other variant is sent without extra fields, including any
        // that may carry a payload in the protocol — confirm that is acceptable for
        // forged errors.
        _ => (),
    }
    Ok(())
}
239
240pub(crate) async fn write_frame(
241 params: FrameParams,
242 opcode: FrameOpcode,
243 body: &[u8],
244 writer: &mut (impl AsyncWrite + Unpin),
245 compression: &CompressionReader,
246) -> Result<(), tokio::io::Error> {
247 let compressed_body = compression
248 .maybe_compress_body(params.flags, body)
249 .map_err(tokio::io::Error::other)?;
250
251 let body = compressed_body.as_deref().unwrap_or(body);
252
253 let mut header = [0; HEADER_SIZE];
254
255 header[0] = params.version;
256 header[1] = params.flags;
257 header[2..=3].copy_from_slice(¶ms.stream.to_be_bytes());
258 header[4] = match opcode {
259 FrameOpcode::Request(op) => op as u8,
260 FrameOpcode::Response(op) => op as u8,
261 };
262 header[5..9].copy_from_slice(&(body.len() as u32).to_be_bytes());
263
264 writer.write_all(&header).await?;
265 writer.write_all(body).await?;
266 writer.flush().await?;
267 Ok(())
268}
269
270pub(crate) async fn read_frame(
271 reader: &mut (impl AsyncRead + Unpin),
272 frame_type: FrameType,
273 compression: &CompressionReader,
274) -> Result<(FrameParams, FrameOpcode, Bytes), ReadFrameError> {
275 let mut raw_header = [0u8; HEADER_SIZE];
276 reader
277 .read_exact(&mut raw_header[..])
278 .await
279 .map_err(FrameHeaderParseError::HeaderIoError)?;
280
281 let mut buf = &raw_header[..];
282
283 let version = buf.get_u8();
284 {
285 let (err, valid_direction, direction_str) = match frame_type {
286 FrameType::Request => (FrameHeaderParseError::FrameFromServer, 0x00, "request"),
287 FrameType::Response => (FrameHeaderParseError::FrameFromClient, 0x80, "response"),
288 };
289 if version & 0x80 != valid_direction {
290 return Err(err.into());
291 }
292 let protocol_version = version & 0x7F;
293 if protocol_version != 0x04 {
294 warn!(
295 "Received {} with protocol version {}.",
296 direction_str, protocol_version
297 );
298 }
299 }
300
301 let flags = buf.get_u8();
302 let stream = buf.get_i16();
303
304 let frame_params = FrameParams {
305 version,
306 flags,
307 stream,
308 };
309
310 let opcode = match frame_type {
311 FrameType::Request => FrameOpcode::Request(
312 RequestOpcode::try_from(buf.get_u8())
313 .map_err(|_| FrameHeaderParseError::FrameFromServer)?,
314 ),
315 FrameType::Response => FrameOpcode::Response(
316 ResponseOpcode::try_from(buf.get_u8())
317 .map_err(|_| FrameHeaderParseError::FrameFromClient)?,
318 ),
319 };
320
321 let length = buf.get_u32() as usize;
322
323 let mut body = Vec::with_capacity(length).limit(length);
324
325 while body.has_remaining_mut() {
326 let n = reader
327 .read_buf(&mut body)
328 .await
329 .map_err(|err| FrameHeaderParseError::BodyChunkIoError(body.remaining_mut(), err))?;
330 if n == 0 {
331 return Err(
333 FrameHeaderParseError::ConnectionClosed(body.remaining_mut(), length).into(),
334 );
335 }
336 }
337
338 let body = compression.maybe_decompress_body(flags, body.into_inner().into())?;
339
340 Ok((frame_params, opcode, body))
341}
342
343pub(crate) async fn read_request_frame(
344 reader: &mut (impl AsyncRead + Unpin),
345 compression: &CompressionReader,
346) -> Result<RequestFrame, ReadFrameError> {
347 read_frame(reader, FrameType::Request, compression)
348 .await
349 .map(|(params, opcode, body)| RequestFrame {
350 params,
351 opcode: match opcode {
352 FrameOpcode::Request(op) => op,
353 FrameOpcode::Response(_) => unreachable!(),
354 },
355 body,
356 })
357}
358
359pub(crate) async fn read_response_frame(
360 reader: &mut (impl AsyncRead + Unpin),
361 compression: &CompressionReader,
362) -> Result<ResponseFrame, ReadFrameError> {
363 read_frame(reader, FrameType::Response, compression)
364 .await
365 .map(|(params, opcode, body)| ResponseFrame {
366 params,
367 opcode: match opcode {
368 FrameOpcode::Request(_) => unreachable!(),
369 FrameOpcode::Response(op) => op,
370 },
371 body,
372 })
373}