Skip to main content

scylla_proxy/
frame.rs

1use std::collections::HashMap;
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use scylla_cql::frame::frame_errors::FrameHeaderParseError;
5use scylla_cql::frame::protocol_features::ProtocolFeatures;
6pub use scylla_cql::frame::request::RequestOpcode;
7use scylla_cql::frame::request::{RequestDeserializationError, RequestV2};
8pub use scylla_cql::frame::response::ResponseOpcode;
9use scylla_cql::frame::{response::error::DbError, types};
10use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
11
12use tracing::warn;
13
14use crate::errors::ReadFrameError;
15use crate::proxy::CompressionReader;
16
/// Size in bytes of the fixed frame header written by `write_frame` and read by
/// `read_frame`: version (1) + flags (1) + stream (2) + opcode (1) + body length (4).
const HEADER_SIZE: usize = 1 + 1 + 2 + 1 + 4;
18
/// Parts of the frame header which are not determined by the request/response type.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct FrameParams {
    /// Version byte: protocol version in the low 7 bits, direction in bit `0x80`
    /// (set for responses, clear for requests).
    pub version: u8,
    /// Header flags byte.
    pub flags: u8,
    /// Stream id, encoded big-endian in the header.
    pub stream: i16,
}

impl FrameParams {
    /// Returns a copy of these params with the direction bit cleared (request direction).
    pub const fn for_request(&self) -> FrameParams {
        FrameParams {
            version: self.version & !0x80,
            flags: self.flags,
            stream: self.stream,
        }
    }

    /// Returns a copy of these params with the direction bit set (response direction).
    pub const fn for_response(&self) -> FrameParams {
        FrameParams {
            version: self.version | 0x80,
            flags: self.flags,
            stream: self.stream,
        }
    }
}
41
/// Which kind of frame a reader expects; `read_frame` uses it to pick the valid
/// direction bit and the opcode set to parse.
#[derive(Copy, Clone, Debug)]
pub(crate) enum FrameType {
    /// Direction bit clear; opcode parsed as a `RequestOpcode`.
    Request,
    /// Direction bit set; opcode parsed as a `ResponseOpcode`.
    Response,
}
47
48#[derive(Copy, Clone, Debug, PartialEq, Eq)]
49pub(crate) enum FrameOpcode {
50    Request(RequestOpcode),
51    Response(ResponseOpcode),
52}
53
54#[derive(Clone, Debug, PartialEq, Eq)]
55pub struct RequestFrame {
56    pub params: FrameParams,
57    pub opcode: RequestOpcode,
58    pub body: Bytes,
59}
60
61impl RequestFrame {
62    pub(crate) async fn write(
63        &self,
64        writer: &mut (impl AsyncWrite + Unpin),
65        compression: &CompressionReader,
66    ) -> Result<(), tokio::io::Error> {
67        write_frame(
68            self.params,
69            FrameOpcode::Request(self.opcode),
70            &self.body,
71            writer,
72            compression,
73        )
74        .await
75    }
76
77    pub fn deserialize(
78        &self,
79        features: &ProtocolFeatures,
80    ) -> Result<RequestV2<'_>, RequestDeserializationError> {
81        RequestV2::deserialize(&mut &self.body[..], self.opcode, features)
82    }
83}
84#[derive(Clone, Debug, PartialEq, Eq)]
85pub struct ResponseFrame {
86    pub params: FrameParams,
87    pub opcode: ResponseOpcode,
88    pub body: Bytes,
89}
90
91impl ResponseFrame {
92    /// Creates a response frame that signifies the given DbError type.
93    /// Useful for testing server-side error handling in drivers.
94    pub fn forged_error(
95        request_params: FrameParams,
96        error: DbError,
97        msg: Option<&str>,
98    ) -> Result<Self, std::num::TryFromIntError> {
99        let msg = msg.unwrap_or("Proxy-triggered error.");
100        let len_bytes = (msg.len() as u16).to_be_bytes(); // string len is a short in CQL protocol
101        let code_bytes = error.code(&ProtocolFeatures::default()).to_be_bytes(); // TODO: configurable features
102        let body_len = msg.len() + code_bytes.len() + len_bytes.len();
103        let mut buf = BytesMut::with_capacity(body_len);
104
105        buf.extend_from_slice(&code_bytes);
106        buf.extend_from_slice(&len_bytes);
107        buf.extend_from_slice(msg.as_bytes());
108
109        serialize_error_specific_fields(&mut buf, error)?;
110
111        Ok(ResponseFrame {
112            params: request_params.for_response(),
113            opcode: ResponseOpcode::Error,
114            body: buf.freeze(),
115        })
116    }
117
118    /// Creates a Supported response frame with given supported options.
119    pub fn forged_supported(
120        request_params: FrameParams,
121        options: &HashMap<String, Vec<String>>,
122    ) -> Result<Self, std::num::TryFromIntError> {
123        let mut buf = BytesMut::new();
124        types::write_string_multimap(options, &mut buf)?;
125
126        Ok(ResponseFrame {
127            params: request_params.for_response(),
128            opcode: ResponseOpcode::Supported,
129            body: buf.freeze(),
130        })
131    }
132
133    pub fn forged_ready(request_params: FrameParams) -> Self {
134        ResponseFrame {
135            params: request_params.for_response(),
136            opcode: ResponseOpcode::Ready,
137            body: Bytes::new(),
138        }
139    }
140
141    pub(crate) async fn write(
142        &self,
143        writer: &mut (impl AsyncWrite + Unpin),
144        compression: &CompressionReader,
145    ) -> Result<(), tokio::io::Error> {
146        write_frame(
147            self.params,
148            FrameOpcode::Response(self.opcode),
149            &self.body,
150            writer,
151            compression,
152        )
153        .await
154    }
155}
156
157fn serialize_error_specific_fields(
158    buf: &mut BytesMut,
159    error: DbError,
160) -> Result<(), std::num::TryFromIntError> {
161    match error {
162        DbError::Unavailable {
163            consistency,
164            required,
165            alive,
166        } => {
167            types::write_consistency(consistency, buf);
168            types::write_int(required, buf);
169            types::write_int(alive, buf);
170        }
171        DbError::WriteTimeout {
172            consistency,
173            received,
174            required,
175            write_type,
176        } => {
177            types::write_consistency(consistency, buf);
178            types::write_int(received, buf);
179            types::write_int(required, buf);
180            types::write_string(write_type.as_str(), buf)?;
181        }
182        DbError::ReadTimeout {
183            consistency,
184            received,
185            required,
186            data_present,
187        } => {
188            types::write_consistency(consistency, buf);
189            types::write_int(received, buf);
190            types::write_int(required, buf);
191            buf.put_u8(u8::from(data_present));
192        }
193        DbError::ReadFailure {
194            consistency,
195            received,
196            required,
197            numfailures,
198            data_present,
199        } => {
200            types::write_consistency(consistency, buf);
201            types::write_int(received, buf);
202            types::write_int(required, buf);
203            types::write_int(numfailures, buf);
204            buf.put_u8(u8::from(data_present));
205        }
206        DbError::WriteFailure {
207            consistency,
208            received,
209            required,
210            numfailures,
211            write_type,
212        } => {
213            types::write_consistency(consistency, buf);
214            types::write_int(received, buf);
215            types::write_int(required, buf);
216            types::write_int(numfailures, buf);
217            types::write_string(write_type.as_str(), buf)?;
218        }
219        DbError::FunctionFailure {
220            keyspace,
221            function,
222            arg_types,
223        } => {
224            types::write_string(keyspace.as_str(), buf)?;
225            types::write_string(function.as_str(), buf)?;
226            types::write_string_list(&arg_types, buf)?;
227        }
228        DbError::AlreadyExists { keyspace, table } => {
229            types::write_string(keyspace.as_str(), buf)?;
230            types::write_string(table.as_str(), buf)?;
231        }
232        DbError::Unprepared { statement_id } => {
233            types::write_short_bytes(statement_id.as_ref(), buf)?;
234        }
235        _ => (),
236    }
237    Ok(())
238}
239
240pub(crate) async fn write_frame(
241    params: FrameParams,
242    opcode: FrameOpcode,
243    body: &[u8],
244    writer: &mut (impl AsyncWrite + Unpin),
245    compression: &CompressionReader,
246) -> Result<(), tokio::io::Error> {
247    let compressed_body = compression
248        .maybe_compress_body(params.flags, body)
249        .map_err(tokio::io::Error::other)?;
250
251    let body = compressed_body.as_deref().unwrap_or(body);
252
253    let mut header = [0; HEADER_SIZE];
254
255    header[0] = params.version;
256    header[1] = params.flags;
257    header[2..=3].copy_from_slice(&params.stream.to_be_bytes());
258    header[4] = match opcode {
259        FrameOpcode::Request(op) => op as u8,
260        FrameOpcode::Response(op) => op as u8,
261    };
262    header[5..9].copy_from_slice(&(body.len() as u32).to_be_bytes());
263
264    writer.write_all(&header).await?;
265    writer.write_all(body).await?;
266    writer.flush().await?;
267    Ok(())
268}
269
270pub(crate) async fn read_frame(
271    reader: &mut (impl AsyncRead + Unpin),
272    frame_type: FrameType,
273    compression: &CompressionReader,
274) -> Result<(FrameParams, FrameOpcode, Bytes), ReadFrameError> {
275    let mut raw_header = [0u8; HEADER_SIZE];
276    reader
277        .read_exact(&mut raw_header[..])
278        .await
279        .map_err(FrameHeaderParseError::HeaderIoError)?;
280
281    let mut buf = &raw_header[..];
282
283    let version = buf.get_u8();
284    {
285        let (err, valid_direction, direction_str) = match frame_type {
286            FrameType::Request => (FrameHeaderParseError::FrameFromServer, 0x00, "request"),
287            FrameType::Response => (FrameHeaderParseError::FrameFromClient, 0x80, "response"),
288        };
289        if version & 0x80 != valid_direction {
290            return Err(err.into());
291        }
292        let protocol_version = version & 0x7F;
293        if protocol_version != 0x04 {
294            warn!(
295                "Received {} with protocol version {}.",
296                direction_str, protocol_version
297            );
298        }
299    }
300
301    let flags = buf.get_u8();
302    let stream = buf.get_i16();
303
304    let frame_params = FrameParams {
305        version,
306        flags,
307        stream,
308    };
309
310    let opcode = match frame_type {
311        FrameType::Request => FrameOpcode::Request(
312            RequestOpcode::try_from(buf.get_u8())
313                .map_err(|_| FrameHeaderParseError::FrameFromServer)?,
314        ),
315        FrameType::Response => FrameOpcode::Response(
316            ResponseOpcode::try_from(buf.get_u8())
317                .map_err(|_| FrameHeaderParseError::FrameFromClient)?,
318        ),
319    };
320
321    let length = buf.get_u32() as usize;
322
323    let mut body = Vec::with_capacity(length).limit(length);
324
325    while body.has_remaining_mut() {
326        let n = reader
327            .read_buf(&mut body)
328            .await
329            .map_err(|err| FrameHeaderParseError::BodyChunkIoError(body.remaining_mut(), err))?;
330        if n == 0 {
331            // EOF, too early
332            return Err(
333                FrameHeaderParseError::ConnectionClosed(body.remaining_mut(), length).into(),
334            );
335        }
336    }
337
338    let body = compression.maybe_decompress_body(flags, body.into_inner().into())?;
339
340    Ok((frame_params, opcode, body))
341}
342
343pub(crate) async fn read_request_frame(
344    reader: &mut (impl AsyncRead + Unpin),
345    compression: &CompressionReader,
346) -> Result<RequestFrame, ReadFrameError> {
347    read_frame(reader, FrameType::Request, compression)
348        .await
349        .map(|(params, opcode, body)| RequestFrame {
350            params,
351            opcode: match opcode {
352                FrameOpcode::Request(op) => op,
353                FrameOpcode::Response(_) => unreachable!(),
354            },
355            body,
356        })
357}
358
359pub(crate) async fn read_response_frame(
360    reader: &mut (impl AsyncRead + Unpin),
361    compression: &CompressionReader,
362) -> Result<ResponseFrame, ReadFrameError> {
363    read_frame(reader, FrameType::Response, compression)
364        .await
365        .map(|(params, opcode, body)| ResponseFrame {
366            params,
367            opcode: match opcode {
368                FrameOpcode::Request(_) => unreachable!(),
369                FrameOpcode::Response(op) => op,
370            },
371            body,
372        })
373}