1use crate::codec::CodecState;
4#[cfg(feature = "cassandra")]
5use crate::frame::{cassandra, cassandra::CassandraMetadata};
6#[cfg(feature = "valkey")]
7use crate::frame::{valkey::valkey_query_type, ValkeyFrame};
8use crate::frame::{Frame, MessageType};
9use anyhow::{anyhow, Context, Result};
10use bytes::Bytes;
11use derivative::Derivative;
12use fnv::FnvBuildHasher;
13use nonzero_ext::nonzero;
14use serde::{Deserialize, Serialize};
15use std::collections::{HashMap, HashSet};
16use std::num::NonZeroU32;
17use std::time::Instant;
18
/// A map keyed by [`MessageId`], hashed with the FNV hasher instead of the std default.
pub type MessageIdMap<T> = HashMap<MessageId, T, FnvBuildHasher>;
/// A set of [`MessageId`]s, hashed with the FNV hasher instead of the std default.
pub type MessageIdSet = HashSet<MessageId, FnvBuildHasher>;
21
/// Protocol-specific information extracted from a message by `Message::metadata`,
/// used to build a protocol-appropriate error response for that message.
///
/// Each variant only exists when the corresponding protocol feature is enabled.
pub enum Metadata {
    /// Metadata extracted from a Cassandra message (raw bytes or parsed frame).
    #[cfg(feature = "cassandra")]
    Cassandra(CassandraMetadata),
    /// Valkey messages carry no extra metadata.
    #[cfg(feature = "valkey")]
    Valkey,
    /// Kafka messages carry no extra metadata.
    #[cfg(feature = "kafka")]
    Kafka,
    /// OpenSearch messages carry no extra metadata.
    #[cfg(feature = "opensearch")]
    OpenSearch,
}
32
33impl Metadata {
34 pub fn to_error_response(&self, error: String) -> Result<Message> {
38 #[allow(unreachable_code)]
39 Ok(Message::from_frame(match self {
40 #[cfg(feature = "valkey")]
41 Metadata::Valkey => {
42 let message = format!("ERR {error}")
44 .replace("\r\n", " ")
45 .replace('\n', " ");
46 Frame::Valkey(ValkeyFrame::Error(message.into()))
47 }
48 #[cfg(feature = "cassandra")]
49 Metadata::Cassandra(meta) => Frame::Cassandra(meta.to_error_response(error)),
50 #[cfg(feature = "kafka")]
55 Metadata::Kafka => return Err(anyhow!(error).context(
56 "A generic error cannot be formed because the kafka protocol does not support it",
57 )),
58 #[cfg(feature = "opensearch")]
59 Metadata::OpenSearch => unimplemented!(),
60 }))
61 }
62}
63
/// A batch of messages.
pub type Messages = Vec<Message>;

/// Identifier of a [`Message`]. Most constructors assign a fresh `rand::random()`
/// value; a response refers to its request's id through `Message::request_id`.
pub type MessageId = u128;
68
/// A single protocol message.
///
/// The contents are held lazily in one of three forms (see `MessageInner`):
/// raw bytes, raw bytes plus the frame parsed from them, or a constructed or
/// modified frame.
///
/// Equality compares only the contents and `codec_state`. The receive instant,
/// `id` and `request_id` are ignored.
#[derive(Derivative, Debug, Clone)]
#[derivative(PartialEq)]
pub struct Message {
    /// Current representation of the message. Methods `unwrap()` this, so it must be
    /// `Some` except transiently inside a method that has `take()`n it.
    inner: Option<MessageInner>,

    /// When the message was received from a source or sink, if known.
    #[derivative(PartialEq = "ignore")]
    pub(crate) received_from_source_or_sink_at: Option<Instant>,
    /// Codec state handed to the parser when the raw bytes are turned into a frame.
    pub(crate) codec_state: CodecState,

    /// Identifier of this message.
    #[derivative(PartialEq = "ignore")]
    pub(crate) id: MessageId,
    /// For a response, the id of the request it answers, if known.
    #[derivative(PartialEq = "ignore")]
    pub(crate) request_id: Option<MessageId>,
}
105
106impl Message {
108 pub fn from_bytes_at_instant(
112 bytes: Bytes,
113 codec_state: CodecState,
114 received_from_source_or_sink_at: Option<Instant>,
115 ) -> Self {
116 Message {
117 inner: Some(MessageInner::RawBytes {
118 bytes,
119 message_type: MessageType::from(&codec_state),
120 }),
121 codec_state,
122 received_from_source_or_sink_at,
123 id: rand::random(),
124 request_id: None,
125 }
126 }
127
128 pub fn from_bytes_and_frame_at_instant(
132 bytes: Bytes,
133 frame: Frame,
134 received_from_source_or_sink_at: Option<Instant>,
135 ) -> Self {
136 Message {
137 codec_state: frame.as_codec_state(),
138 inner: Some(MessageInner::Parsed { bytes, frame }),
139 received_from_source_or_sink_at,
140 id: rand::random(),
141 request_id: None,
142 }
143 }
144
145 pub fn from_frame_at_instant(
149 frame: Frame,
150 received_from_source_or_sink_at: Option<Instant>,
151 ) -> Self {
152 Message {
153 codec_state: frame.as_codec_state(),
154 inner: Some(MessageInner::Modified { frame }),
155 received_from_source_or_sink_at,
156 id: rand::random(),
157 request_id: None,
158 }
159 }
160
161 pub fn from_frame_diverged(frame: Frame, diverged_from: &Message) -> Self {
164 Message {
165 codec_state: frame.as_codec_state(),
166 inner: Some(MessageInner::Modified { frame }),
167 received_from_source_or_sink_at: diverged_from.received_from_source_or_sink_at,
168 id: diverged_from.id(),
169 request_id: None,
170 }
171 }
172
173 pub fn from_bytes(bytes: Bytes, codec_state: CodecState) -> Self {
175 Self::from_bytes_at_instant(bytes, codec_state, None)
176 }
177
178 pub fn from_frame(frame: Frame) -> Self {
180 Self::from_frame_at_instant(frame, None)
181 }
182}
183
184impl Message {
186 pub fn frame(&mut self) -> Option<&mut Frame> {
198 let (inner, result) = self.inner.take().unwrap().ensure_parsed(self.codec_state);
199 self.inner = Some(inner);
200 if let Err(err) = result {
201 tracing::error!("{:?}", err.context("Failed to parse frame"));
203 return None;
204 }
205
206 match self.inner.as_mut().unwrap() {
207 MessageInner::RawBytes { .. } => {
208 unreachable!("Cannot be RawBytes because ensure_parsed was called")
209 }
210 MessageInner::Parsed { frame, .. } => Some(frame),
211 MessageInner::Modified { frame } => Some(frame),
212 }
213 }
214
215 pub fn into_frame(mut self) -> Option<Frame> {
218 let (inner, result) = self.inner.take().unwrap().ensure_parsed(self.codec_state);
219 if let Err(err) = result {
220 tracing::error!("{:?}", err.context("Failed to parse frame"));
222 return None;
223 }
224
225 match inner {
226 MessageInner::RawBytes { .. } => {
227 unreachable!("Cannot be RawBytes because ensure_parsed was called")
228 }
229 MessageInner::Parsed { frame, .. } => Some(frame),
230 MessageInner::Modified { frame } => Some(frame),
231 }
232 }
233
234 pub fn id(&self) -> MessageId {
236 self.id
237 }
238
239 pub fn request_id(&self) -> Option<MessageId> {
244 self.request_id
245 }
246
247 pub fn set_request_id(&mut self, request_id: MessageId) {
248 self.request_id = Some(request_id);
249 }
250
251 pub fn clone_with_new_id(&self) -> Self {
252 Message {
253 inner: self.inner.clone(),
254 received_from_source_or_sink_at: None,
255 codec_state: self.codec_state,
256 id: rand::random(),
257 request_id: self.request_id,
258 }
259 }
260
261 pub fn message_type(&self) -> MessageType {
262 match self.inner.as_ref().unwrap() {
263 MessageInner::RawBytes { message_type, .. } => *message_type,
264 MessageInner::Parsed { frame, .. } | MessageInner::Modified { frame } => {
265 frame.get_type()
266 }
267 }
268 }
269
270 pub fn ensure_message_type(&self, expected_message_type: MessageType) -> Result<()> {
271 match self.inner.as_ref().unwrap() {
272 MessageInner::RawBytes { message_type, .. } => {
273 if *message_type == expected_message_type || *message_type == MessageType::Dummy {
274 Ok(())
275 } else {
276 Err(anyhow!(
277 "Expected message of type {:?} but was of type {:?}",
278 expected_message_type,
279 message_type
280 ))
281 }
282 }
283 MessageInner::Parsed { frame, .. } => {
284 let message_type = frame.get_type();
285 if message_type == expected_message_type || message_type == MessageType::Dummy {
286 Ok(())
287 } else {
288 Err(anyhow!(
289 "Expected message of type {:?} but was of type {:?}",
290 expected_message_type,
291 frame.name()
292 ))
293 }
294 }
295 MessageInner::Modified { frame } => {
296 let message_type = frame.get_type();
297 if message_type == expected_message_type || message_type == MessageType::Dummy {
298 Ok(())
299 } else {
300 Err(anyhow!(
301 "Expected message of type {:?} but was of type {:?}",
302 expected_message_type,
303 frame.name()
304 ))
305 }
306 }
307 }
308 }
309
310 pub fn into_encodable(self) -> Encodable {
311 match self.inner.unwrap() {
312 MessageInner::RawBytes { bytes, .. } => Encodable::Bytes(bytes),
313 MessageInner::Parsed { bytes, .. } => Encodable::Bytes(bytes),
314 MessageInner::Modified {
315 frame: Frame::Dummy,
316 } => Encodable::Bytes(Bytes::new()),
317 MessageInner::Modified { frame } => Encodable::Frame(frame),
318 }
319 }
320
321 pub fn cell_count(&self) -> Result<NonZeroU32> {
327 Ok(match self.inner.as_ref().unwrap() {
328 MessageInner::RawBytes {
329 #[cfg(feature = "cassandra")]
330 bytes,
331 message_type,
332 ..
333 } => match message_type {
334 #[cfg(feature = "valkey")]
335 MessageType::Valkey => nonzero!(1u32),
336 #[cfg(feature = "cassandra")]
337 MessageType::Cassandra => cassandra::raw_frame::cell_count(bytes)?,
338 #[cfg(feature = "kafka")]
339 MessageType::Kafka => todo!(),
340 MessageType::Dummy => nonzero!(1u32),
341 #[cfg(feature = "opensearch")]
342 MessageType::OpenSearch => todo!(),
343 },
344 MessageInner::Modified { frame } | MessageInner::Parsed { frame, .. } => match frame {
345 #[cfg(feature = "cassandra")]
346 Frame::Cassandra(frame) => frame.cell_count()?,
347 #[cfg(feature = "valkey")]
348 Frame::Valkey(_) => nonzero!(1u32),
349 #[cfg(feature = "kafka")]
350 Frame::Kafka(_) => todo!(),
351 Frame::Dummy => nonzero!(1u32),
352 #[cfg(feature = "opensearch")]
353 Frame::OpenSearch(_) => todo!(),
354 },
355 })
356 }
357
358 pub fn invalidate_cache(&mut self) {
369 self.inner = self.inner.take().map(|x| x.invalidate_cache());
372 }
373
374 pub fn get_query_type(&mut self) -> QueryType {
375 match self.frame() {
376 #[cfg(feature = "cassandra")]
377 Some(Frame::Cassandra(cassandra)) => cassandra.get_query_type(),
378 #[cfg(feature = "valkey")]
379 Some(Frame::Valkey(valkey)) => valkey_query_type(valkey), #[cfg(feature = "kafka")]
381 Some(Frame::Kafka(_)) => todo!(),
382 Some(Frame::Dummy) => todo!(),
383 #[cfg(feature = "opensearch")]
384 Some(Frame::OpenSearch(_)) => todo!(),
385 None => QueryType::ReadWrite,
386 }
387 }
388
389 pub fn from_response_to_error_response(&self, error: String) -> Result<Message> {
391 let mut response = self
392 .metadata()
393 .context("Failed to parse metadata of request or response when producing an error")?
394 .to_error_response(error)?;
395
396 if let Some(request_id) = self.request_id() {
397 response.set_request_id(request_id)
398 }
399
400 Ok(response)
401 }
402
403 pub fn from_request_to_error_response(&self, error: String) -> Result<Message> {
405 let mut request = self
406 .metadata()
407 .context("Failed to parse metadata of request or response when producing an error")?
408 .to_error_response(error)?;
409
410 request.set_request_id(self.id());
411 Ok(request)
412 }
413
414 pub fn metadata(&self) -> Result<Metadata> {
416 match self.inner.as_ref().unwrap() {
417 MessageInner::RawBytes {
418 #[cfg(feature = "cassandra")]
419 bytes,
420 message_type,
421 ..
422 } => match message_type {
423 #[cfg(feature = "cassandra")]
424 MessageType::Cassandra => {
425 Ok(Metadata::Cassandra(cassandra::raw_frame::metadata(bytes)?))
426 }
427 #[cfg(feature = "valkey")]
428 MessageType::Valkey => Ok(Metadata::Valkey),
429 #[cfg(feature = "kafka")]
430 MessageType::Kafka => Ok(Metadata::Kafka),
431 MessageType::Dummy => Err(anyhow!("Dummy has no metadata")),
432 #[cfg(feature = "opensearch")]
433 MessageType::OpenSearch => Err(anyhow!("OpenSearch has no metadata")),
434 },
435 MessageInner::Parsed { frame, .. } | MessageInner::Modified { frame } => match frame {
436 #[cfg(feature = "cassandra")]
437 Frame::Cassandra(frame) => Ok(Metadata::Cassandra(frame.metadata())),
438 #[cfg(feature = "kafka")]
439 Frame::Kafka(_) => Ok(Metadata::Kafka),
440 #[cfg(feature = "valkey")]
441 Frame::Valkey(_) => Ok(Metadata::Valkey),
442 Frame::Dummy => Err(anyhow!("dummy has no metadata")),
443 #[cfg(feature = "opensearch")]
444 Frame::OpenSearch(_) => Err(anyhow!("OpenSearch has no metadata")),
445 },
446 }
447 }
448
449 pub fn replace_with_dummy(&mut self) {
454 self.inner = Some(MessageInner::Modified {
455 frame: Frame::Dummy,
456 });
457 }
458
459 pub(crate) fn response_is_dummy(&mut self) -> bool {
461 match self.message_type() {
462 #[cfg(feature = "valkey")]
463 MessageType::Valkey => false,
464 #[cfg(feature = "cassandra")]
465 MessageType::Cassandra => false,
466 #[cfg(feature = "kafka")]
467 MessageType::Kafka => match self.frame() {
468 Some(Frame::Kafka(crate::frame::kafka::KafkaFrame::Request {
469 body: crate::frame::kafka::RequestBody::Produce(produce),
470 ..
471 })) => produce.acks == 0,
472 _ => false,
473 },
474 #[cfg(feature = "opensearch")]
475 MessageType::OpenSearch => false,
476 MessageType::Dummy => true,
477 }
478 }
479
480 pub fn is_dummy(&self) -> bool {
481 matches!(
482 self.inner,
483 Some(MessageInner::Modified {
484 frame: Frame::Dummy
485 })
486 )
487 }
488
489 pub fn to_backpressure(&mut self) -> Result<Message> {
491 let metadata = self.metadata()?;
492
493 Ok(Message::from_frame_at_instant(
494 match metadata {
495 #[cfg(feature = "cassandra")]
496 Metadata::Cassandra(metadata) => Frame::Cassandra(metadata.backpressure_response()),
497 #[cfg(feature = "valkey")]
498 Metadata::Valkey => unimplemented!(),
499 #[cfg(feature = "kafka")]
500 Metadata::Kafka => unimplemented!(),
501 #[cfg(feature = "opensearch")]
502 Metadata::OpenSearch => unimplemented!(),
503 },
504 #[allow(unreachable_code)]
506 self.received_from_source_or_sink_at,
507 ))
508 }
509
510 pub(crate) fn stream_id(&self) -> Option<i16> {
516 match &self.inner {
517 #[cfg(feature = "cassandra")]
518 Some(MessageInner::RawBytes {
519 bytes,
520 message_type: MessageType::Cassandra,
521 }) => {
522 use bytes::Buf;
523 const HEADER_LEN: usize = 9;
524 if bytes.len() >= HEADER_LEN {
525 Some((&bytes[2..4]).get_i16())
526 } else {
527 None
528 }
529 }
530 Some(MessageInner::RawBytes { .. }) => None,
531 Some(MessageInner::Parsed { frame, .. } | MessageInner::Modified { frame }) => {
532 match frame {
533 #[cfg(feature = "cassandra")]
534 Frame::Cassandra(cassandra) => Some(cassandra.stream_id),
535 #[cfg(feature = "valkey")]
536 Frame::Valkey(_) => None,
537 #[cfg(feature = "kafka")]
538 Frame::Kafka(_) => None,
539 Frame::Dummy => None,
540 #[cfg(feature = "opensearch")]
541 Frame::OpenSearch(_) => None,
542 }
543 }
544 None => None,
545 }
546 }
547
548 pub fn to_high_level_string(&mut self) -> String {
549 if let Some(response) = self.frame() {
550 format!("{}", response)
551 } else if let Some(MessageInner::RawBytes {
552 bytes,
553 message_type,
554 }) = &self.inner
555 {
556 format!("Unparseable {:?} message {:?}", message_type, bytes)
557 } else {
558 unreachable!("self.frame() failed so MessageInner must still be RawBytes")
559 }
560 }
561}
562
/// The form in which a `Message` currently holds its contents.
#[derive(PartialEq, Debug, Clone)]
enum MessageInner {
    /// Bytes as received, not yet parsed.
    ///
    /// `message_type` selects the protocol the bytes are parsed as.
    RawBytes {
        bytes: Bytes,
        message_type: MessageType,
    },
    /// Bytes together with the frame parsed from them.
    ///
    /// `Message::into_encodable` still writes `bytes` for this variant, so a caller
    /// that mutates `frame` must move the message to `Modified` (cache invalidation).
    Parsed {
        bytes: Bytes,
        frame: Frame,
    },
    /// A constructed or modified frame with no up-to-date bytes; it must be encoded
    /// before being written.
    Modified {
        frame: Frame,
    },
}
580
581impl MessageInner {
582 fn ensure_parsed(self, codec_state: CodecState) -> (Self, Result<()>) {
583 match self {
584 MessageInner::RawBytes {
585 bytes,
586 message_type,
587 } => match Frame::from_bytes(bytes.clone(), message_type, codec_state) {
588 Ok(frame) => (MessageInner::Parsed { bytes, frame }, Ok(())),
589 Err(err) => (
590 MessageInner::RawBytes {
591 bytes,
592 message_type,
593 },
594 Err(err),
595 ),
596 },
597 MessageInner::Parsed { .. } => (self, Ok(())),
598 MessageInner::Modified { .. } => (self, Ok(())),
599 }
600 }
601
602 fn invalidate_cache(self) -> Self {
603 match self {
604 MessageInner::RawBytes { .. } => {
605 tracing::error!("Invalidated cache but the frame was not parsed");
606 self
607 }
608 MessageInner::Parsed { frame, .. } => MessageInner::Modified { frame },
609 MessageInner::Modified { .. } => self,
610 }
611 }
612}
613
/// What a codec should write for a message, as produced by `Message::into_encodable`.
#[derive(Debug)]
pub enum Encodable {
    /// Bytes to write as-is (the original bytes, or empty for a dummy message).
    Bytes(Bytes),
    /// A frame that still has to be encoded before it can be written.
    Frame(Frame),
}
621
/// Classification of a message, as returned by `Message::get_query_type`.
///
/// Derives serde traits, so the variant names form part of its serialized form.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub enum QueryType {
    Read,
    Write,
    /// Also returned by `Message::get_query_type` when the message cannot be parsed.
    ReadWrite,
    SchemaChange,
    PubSubMessage,
}