1use std::sync::Arc;
2
3use bitflags::bitflags;
4use spacetimedb_sats::buffer::{BufReader, BufWriter, DecodeError};
5use thiserror::Error;
6
7use crate::{
8 error,
9 varint::{decode_varint, encode_varint},
10 Encode, Varchar, DEFAULT_LOG_FORMAT_VERSION,
11};
12
13pub use spacetimedb_primitives::TableId;
15
/// Callbacks invoked while an encoded [`Txdata`] payload is read.
///
/// The row-level methods receive the reader positioned at the start of a row.
/// The callers in this file keep reading from the same reader right after the
/// method returns, so an implementation is expected to leave it positioned
/// after the row it was handed.
pub trait Visitor {
    /// Error returned by every visitor method.
    ///
    /// The `From<DecodeError>` bound lets the decoding functions in this file
    /// apply `?` to [`DecodeError`]s while returning this error type.
    type Error: From<DecodeError>;
    /// Value produced for each row by `visit_insert` and `visit_delete`.
    type Row;

    /// Called for each row in the inserts section of the mutations, with the
    /// id of the table the row belongs to.
    fn visit_insert<'a, R: BufReader<'a>>(
        &mut self,
        table_id: TableId,
        reader: &mut R,
    ) -> Result<Self::Row, Self::Error>;

    /// Called for each row in the deletes section of the mutations, with the
    /// id of the table the row belongs to.
    fn visit_delete<'a, R: BufReader<'a>>(
        &mut self,
        table_id: TableId,
        reader: &mut R,
    ) -> Result<Self::Row, Self::Error>;

    /// Called by `Mutations::skip` for each inserted or deleted row; no row
    /// value is returned.
    fn skip_row<'a, R: BufReader<'a>>(&mut self, table_id: TableId, reader: &mut R) -> Result<(), Self::Error>;

    /// Called for each table id in the truncates section when decoding or
    /// consuming mutations. The default implementation does nothing.
    fn visit_truncate(&mut self, _table_id: TableId) -> Result<(), Self::Error> {
        Ok(())
    }

    /// Called by `process_record` before the payload is read, with the
    /// transaction offset it was given. The default implementation does nothing.
    fn visit_tx_start(&mut self, _offset: u64) -> Result<(), Self::Error> {
        Ok(())
    }

    /// Called by `process_record` after the payload was read successfully.
    /// The default implementation does nothing.
    fn visit_tx_end(&mut self) -> Result<(), Self::Error> {
        Ok(())
    }

    /// Called with the decoded [`Inputs`] when the payload contains them
    /// (not called when skipping). The default implementation does nothing.
    fn visit_inputs(&mut self, _inputs: &Inputs) -> Result<(), Self::Error> {
        Ok(())
    }

    /// Called with the decoded [`Outputs`] when the payload contains them
    /// (not called when skipping). The default implementation does nothing.
    fn visit_outputs(&mut self, _outputs: &Outputs) -> Result<(), Self::Error> {
        Ok(())
    }
}
89
bitflags! {
    /// Flags written as the first byte of an encoded [`Txdata`], recording
    /// which of its optional sections follow.
    #[derive(Clone, Copy)]
    pub struct Flags: u8 {
        /// The payload contains an [`Inputs`] section.
        const HAVE_INPUTS = 0b10000000;
        /// The payload contains an [`Outputs`] section.
        const HAVE_OUTPUTS = 0b01000000;
        /// The payload contains a [`Mutations`] section.
        const HAVE_MUTATIONS = 0b00100000;
    }
}
98
/// A transaction payload made of three optional sections.
///
/// `T` is the type of a single row stored in the mutations.
#[derive(Clone, Debug, PartialEq)]
pub struct Txdata<T> {
    /// Reducer name and arguments, if recorded.
    pub inputs: Option<Inputs>,
    /// Reducer output, if recorded.
    pub outputs: Option<Outputs>,
    /// Row inserts, row deletes and table truncations, if recorded.
    pub mutations: Option<Mutations<T>>,
}
108
109impl<T> Txdata<T> {
110 pub fn is_empty(&self) -> bool {
112 self.inputs.is_none()
113 && self.outputs.is_none()
114 && self.mutations.as_ref().map(Mutations::is_empty).unwrap_or(true)
115 }
116}
117
118impl<T: Encode> Txdata<T> {
119 pub const VERSION: u8 = DEFAULT_LOG_FORMAT_VERSION;
120
121 pub fn encode(&self, buf: &mut impl BufWriter) {
122 let mut flags = Flags::empty();
123 flags.set(Flags::HAVE_INPUTS, self.inputs.is_some());
124 flags.set(Flags::HAVE_OUTPUTS, self.outputs.is_some());
125 flags.set(Flags::HAVE_MUTATIONS, self.mutations.is_some());
126
127 buf.put_u8(flags.bits());
128 if let Some(inputs) = &self.inputs {
129 inputs.encode(buf);
130 }
131 if let Some(outputs) = &self.outputs {
132 outputs.encode(buf);
133 }
134 if let Some(mutations) = &self.mutations {
135 mutations.encode(buf)
136 }
137 }
138
139 pub fn decode<'a, V, R>(visitor: &mut V, reader: &mut R) -> Result<Self, V::Error>
141 where
142 V: Visitor<Row = T>,
143 R: BufReader<'a>,
144 {
145 let flags = Flags::from_bits_retain(reader.get_u8()?);
146
147 let inputs = flags
150 .contains(Flags::HAVE_INPUTS)
151 .then(|| Inputs::decode(reader))
152 .transpose()?;
153 if let Some(inputs) = &inputs {
154 visitor.visit_inputs(inputs)?;
155 }
156
157 let outputs = flags
160 .contains(Flags::HAVE_OUTPUTS)
161 .then(|| Outputs::decode(reader))
162 .transpose()?;
163 if let Some(outputs) = &outputs {
164 visitor.visit_outputs(outputs)?;
165 }
166
167 let mutations = flags
170 .contains(Flags::HAVE_MUTATIONS)
171 .then(|| Mutations::decode(visitor, reader))
172 .transpose()?;
173
174 Ok(Self {
175 inputs,
176 outputs,
177 mutations,
178 })
179 }
180
181 pub fn consume<'a, V, R>(visitor: &mut V, reader: &mut R) -> Result<(), V::Error>
187 where
188 V: Visitor<Row = T>,
189 R: BufReader<'a>,
190 {
191 let flags = Flags::from_bits_retain(reader.get_u8()?);
192
193 if flags.contains(Flags::HAVE_INPUTS) {
196 let inputs = Inputs::decode(reader)?;
197 visitor.visit_inputs(&inputs)?;
198 }
199
200 if flags.contains(Flags::HAVE_OUTPUTS) {
203 let outputs = Outputs::decode(reader)?;
204 visitor.visit_outputs(&outputs)?;
205 }
206
207 if flags.contains(Flags::HAVE_MUTATIONS) {
210 Mutations::consume(visitor, reader)?;
211 }
212
213 Ok(())
214 }
215
216 pub fn skip<'a, V, R>(visitor: &mut V, reader: &mut R) -> Result<(), V::Error>
217 where
218 V: Visitor<Row = T>,
219 R: BufReader<'a>,
220 {
221 let flags = Flags::from_bits_retain(reader.get_u8()?);
222
223 if flags.contains(Flags::HAVE_INPUTS) {
226 Inputs::decode(reader)?;
227 }
228
229 if flags.contains(Flags::HAVE_OUTPUTS) {
232 Outputs::decode(reader)?;
233 }
234
235 if flags.contains(Flags::HAVE_MUTATIONS) {
238 Mutations::skip(visitor, reader)?;
239 }
240
241 Ok(())
242 }
243}
244
/// The inputs section of a [`Txdata`]: a reducer name and its arguments.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct Inputs {
    /// Name of the reducer; its length is encoded in a single byte.
    pub reducer_name: Arc<Varchar<255>>,
    /// Reducer arguments as raw bytes; written and read back verbatim.
    pub reducer_args: Arc<[u8]>,
}
262
263impl Inputs {
264 pub fn encode(&self, buf: &mut impl BufWriter) {
265 let slen = self.reducer_name.len() as u8;
266 let len = 1 + slen as usize + self.reducer_args.len();
267 buf.put_u32(len as u32);
268 buf.put_u8(slen);
269 buf.put_slice(self.reducer_name.as_bytes());
270 buf.put_slice(&self.reducer_args);
271 }
272
273 pub fn decode<'a, R: BufReader<'a>>(reader: &mut R) -> Result<Self, DecodeError> {
274 let len = reader.get_u32()?;
275 let slen = reader.get_u8()?;
276 let reducer_name = {
277 let bytes = reader.get_slice(slen as usize)?;
278 Varchar::from_str(std::str::from_utf8(bytes)?)
279 .expect("slice len cannot be > 255")
280 .into()
281 };
282 let reducer_args = reader.get_slice(len as usize - 1 - slen as usize)?.into();
283
284 Ok(Self {
285 reducer_name,
286 reducer_args,
287 })
288 }
289}
290
/// The outputs section of a [`Txdata`].
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct Outputs {
    /// Output of the reducer as a string; its length is encoded in a single byte.
    pub reducer_output: Arc<Varchar<255>>,
}
301
302impl Outputs {
303 pub fn encode(&self, buf: &mut impl BufWriter) {
304 let slen = self.reducer_output.len() as u8;
305 buf.put_u8(slen);
306 buf.put_slice(self.reducer_output.as_bytes());
307 }
308
309 pub fn decode<'a, R: BufReader<'a>>(reader: &mut R) -> Result<Self, DecodeError> {
310 let slen = reader.get_u8()?;
311 let reducer_output = {
312 let bytes = reader.get_slice(slen as usize)?;
313 Varchar::from_str(std::str::from_utf8(bytes)?).unwrap().into()
314 };
315
316 Ok(Self { reducer_output })
317 }
318}
319
/// The mutations section of a [`Txdata`]: row inserts, row deletes and table
/// truncations, encoded in that order.
#[derive(Clone, Debug, PartialEq)]
pub struct Mutations<T> {
    /// Inserted rows, grouped per table.
    pub inserts: Box<[Ops<T>]>,
    /// Deleted rows, grouped per table.
    pub deletes: Box<[Ops<T>]>,
    /// Ids of truncated tables.
    pub truncates: Box<[TableId]>,
}
335
336impl<T> Mutations<T> {
337 pub fn is_empty(&self) -> bool {
339 self.inserts.is_empty() && self.deletes.is_empty() && self.truncates.is_empty()
340 }
341}
342
343impl<T: Encode> Mutations<T> {
344 pub fn encode(&self, buf: &mut impl BufWriter) {
346 encode_varint(self.inserts.len(), buf);
347 for ops in self.inserts.iter() {
348 ops.encode(buf);
349 }
350 encode_varint(self.deletes.len(), buf);
351 for ops in self.deletes.iter() {
352 ops.encode(buf);
353 }
354 encode_varint(self.truncates.len(), buf);
355 for TableId(table_id) in self.truncates.iter() {
356 buf.put_u32(*table_id);
357 }
358 }
359
360 pub fn skip<'a, V, R>(visitor: &mut V, reader: &mut R) -> Result<(), V::Error>
361 where
362 V: Visitor<Row = T>,
363 R: BufReader<'a>,
364 {
365 let n = decode_varint(reader)?;
368 for _ in 0..n {
369 let table_id = reader.get_u32().map(TableId)?;
370 let m = decode_varint(reader)?;
371 for _ in 0..m {
372 visitor.skip_row(table_id, reader)?;
373 }
374 }
375
376 let n = decode_varint(reader)?;
378 for _ in 0..n {
379 let table_id = reader.get_u32().map(TableId)?;
380 let m = decode_varint(reader)?;
381 for _ in 0..m {
382 visitor.skip_row(table_id, reader)?;
383 }
384 }
385
386 let n = decode_varint(reader)?;
388 for _ in 0..n {
389 reader.get_u32()?;
390 }
391
392 Ok(())
393 }
394
395 pub fn decode<'a, V, R>(visitor: &mut V, reader: &mut R) -> Result<Self, V::Error>
397 where
398 V: Visitor<Row = T>,
399 R: BufReader<'a>,
400 {
401 let n = decode_varint(reader)?;
404 let mut inserts = Vec::with_capacity(n);
405 for _ in 0..n {
406 let table_id = reader.get_u32().map(TableId)?;
407 let m = decode_varint(reader)?;
408 let mut rowdata = Vec::with_capacity(m);
409 for _ in 0..m {
410 let row = visitor.visit_insert(table_id, reader)?;
411 rowdata.push(row);
412 }
413 inserts.push(Ops {
414 table_id,
415 rowdata: rowdata.into(),
416 });
417 }
418
419 let n = decode_varint(reader)?;
422 let mut deletes = Vec::with_capacity(n);
423 for _ in 0..n {
424 let table_id = reader.get_u32().map(TableId)?;
425 let m = decode_varint(reader)?;
426 let mut rowdata = Vec::with_capacity(m);
427 for _ in 0..m {
428 let row = visitor.visit_delete(table_id, reader)?;
429 rowdata.push(row);
430 }
431 deletes.push(Ops {
432 table_id,
433 rowdata: rowdata.into(),
434 });
435 }
436
437 let n = decode_varint(reader)?;
440 let mut truncates = Vec::with_capacity(n);
441 for _ in 0..n {
442 let table_id = reader.get_u32().map(TableId)?;
443 visitor.visit_truncate(table_id)?;
444 truncates.push(table_id);
445 }
446
447 Ok(Self {
448 inserts: inserts.into(),
449 deletes: deletes.into(),
450 truncates: truncates.into(),
451 })
452 }
453
454 pub fn consume<'a, V, R>(visitor: &mut V, reader: &mut R) -> Result<(), V::Error>
458 where
459 V: Visitor<Row = T>,
460 R: BufReader<'a>,
461 {
462 let n = decode_varint(reader)?;
465 for _ in 0..n {
466 let table_id = reader.get_u32().map(TableId)?;
467 let m = decode_varint(reader)?;
468 for _ in 0..m {
469 visitor.visit_insert(table_id, reader)?;
470 }
471 }
472
473 let n = decode_varint(reader)?;
476 for _ in 0..n {
477 let table_id = reader.get_u32().map(TableId)?;
478 let m = decode_varint(reader)?;
479 for _ in 0..m {
480 visitor.visit_delete(table_id, reader)?;
481 }
482 }
483
484 let n = decode_varint(reader)?;
486 for _ in 0..n {
487 let table_id = reader.get_u32().map(TableId)?;
488 visitor.visit_truncate(table_id)?;
489 }
490
491 Ok(())
492 }
493}
494
495impl<T: Encode> Encode for Txdata<T> {
496 fn encode_record<W: BufWriter>(&self, writer: &mut W) {
497 self.encode(writer)
498 }
499}
500
/// A group of rows that all belong to the same table.
#[derive(Clone, Debug, PartialEq)]
pub struct Ops<T> {
    /// The table the rows belong to.
    pub table_id: TableId,
    /// The rows themselves.
    pub rowdata: Arc<[T]>,
}
509
510impl<T: Encode> Ops<T> {
511 pub fn encode(&self, buf: &mut impl BufWriter) {
513 buf.put_u32(self.table_id.0);
514 encode_varint(self.rowdata.len(), buf);
515 for row in self.rowdata.iter() {
516 Encode::encode_record(row, buf);
517 }
518 }
519}
520
/// Error returned by the record functions in this file.
///
/// `V` is the error type of the [`Visitor`] in use.
#[derive(Debug, Error)]
pub enum DecoderError<V> {
    /// The record's version is greater than the one this code supports.
    #[error("unsupported version: {given} supported={supported}")]
    UnsupportedVersion { supported: u8, given: u8 },
    /// A [`DecodeError`]; convertible via `From`.
    #[error(transparent)]
    Decode(#[from] DecodeError),
    /// An error returned by the visitor or by the [`Txdata`] functions driving it.
    #[error(transparent)]
    Visitor(V),
    /// An `error::Traversal`; convertible via `From`. Not constructed by the
    /// code in this file.
    #[error(transparent)]
    Traverse(#[from] error::Traversal),
}
532
533pub fn skip_record_fn<'a, V, R>(visitor: &mut V, version: u8, reader: &mut R) -> Result<(), DecoderError<V::Error>>
536where
537 V: Visitor,
538 V::Row: Encode,
539 R: BufReader<'a>,
540{
541 if version > Txdata::<V::Row>::VERSION {
542 return Err(DecoderError::UnsupportedVersion {
543 supported: Txdata::<V::Row>::VERSION,
544 given: version,
545 });
546 }
547 Txdata::skip(visitor, reader).map_err(DecoderError::Visitor)?;
548
549 Ok(())
550}
551
552pub fn decode_record_fn<'a, V, R>(
559 visitor: &mut V,
560 version: u8,
561 tx_offset: u64,
562 reader: &mut R,
563) -> Result<Txdata<V::Row>, DecoderError<V::Error>>
564where
565 V: Visitor,
566 V::Row: Encode,
567 R: BufReader<'a>,
568{
569 process_record(visitor, version, tx_offset, reader, Txdata::decode)
570}
571
572pub fn consume_record_fn<'a, V, R>(
575 visitor: &mut V,
576 version: u8,
577 tx_offset: u64,
578 reader: &mut R,
579) -> Result<(), DecoderError<V::Error>>
580where
581 V: Visitor,
582 V::Row: Encode,
583 R: BufReader<'a>,
584{
585 process_record(visitor, version, tx_offset, reader, Txdata::consume)
586}
587
588fn process_record<'a, V, R, F, T>(
589 visitor: &mut V,
590 version: u8,
591 tx_offset: u64,
592 reader: &mut R,
593 decode_txdata: F,
594) -> Result<T, DecoderError<V::Error>>
595where
596 V: Visitor,
597 V::Row: Encode,
598 R: BufReader<'a>,
599 F: FnOnce(&mut V, &mut R) -> Result<T, V::Error>,
600{
601 if version > Txdata::<V::Row>::VERSION {
602 return Err(DecoderError::UnsupportedVersion {
603 supported: Txdata::<V::Row>::VERSION,
604 given: version,
605 });
606 }
607 visitor.visit_tx_start(tx_offset).map_err(DecoderError::Visitor)?;
608 let record = decode_txdata(visitor, reader).map_err(DecoderError::Visitor)?;
609 visitor.visit_tx_end().map_err(DecoderError::Visitor)?;
610
611 Ok(record)
612}
613
// Property tests checking that encoding followed by decoding yields the
// original value for each payload type.
#[cfg(test)]
mod tests {
    use super::*;
    use once_cell::sync::Lazy;
    use proptest::prelude::*;
    use spacetimedb_sats::{product, AlgebraicType, ProductType, ProductValue};

    /// Strategy producing arbitrary table ids.
    fn gen_table_id() -> impl Strategy<Value = TableId> {
        any::<u32>().prop_map(TableId)
    }

    /// Strategy producing `Ops` with an arbitrary table id and 1 to 9 copies of `pv` as rows.
    fn gen_ops(pv: ProductValue) -> impl Strategy<Value = Ops<ProductValue>> {
        (gen_table_id(), prop::collection::vec(Just(pv), 1..10)).prop_map(|(table_id, rowdata)| Ops {
            table_id,
            rowdata: rowdata.into(),
        })
    }

    /// Strategy producing `Mutations` with up to 9 entries in each section,
    /// every row being a copy of `pv`.
    fn gen_mutations(pv: ProductValue) -> impl Strategy<Value = Mutations<ProductValue>> {
        (
            prop::collection::vec(gen_ops(pv.clone()), 0..10),
            prop::collection::vec(gen_ops(pv.clone()), 0..10),
            prop::collection::vec(gen_table_id(), 0..10),
        )
            .prop_map(|(inserts, deletes, truncates)| Mutations {
                inserts: inserts.into(),
                deletes: deletes.into(),
                truncates: truncates.into(),
            })
    }

    /// Strategy producing `Txdata` in which each section is independently present or absent.
    fn gen_txdata(pv: ProductValue) -> impl Strategy<Value = Txdata<ProductValue>> {
        (
            prop::option::of(any::<Inputs>()),
            prop::option::of(any::<Outputs>()),
            prop::option::of(gen_mutations(pv)),
        )
            .prop_map(|(inputs, outputs, mutations)| Txdata {
                inputs,
                outputs,
                mutations,
            })
    }

    // The single row value used by all generated mutations, and the type
    // `MockVisitor` uses to decode it again.
    static SOME_PV: Lazy<ProductValue> = Lazy::new(|| product![42u64, "kermit", 4u32, 2u32, 18u32]);
    static SOME_PV_TY: Lazy<ProductType> = Lazy::new(|| {
        ProductType::from([
            ("id", AlgebraicType::U64),
            ("name", AlgebraicType::String),
            ("x", AlgebraicType::U32),
            ("y", AlgebraicType::U32),
            ("z", AlgebraicType::U32),
        ])
    });

    /// Visitor that decodes every row as a `SOME_PV_TY` product value,
    /// ignoring the table id.
    struct MockVisitor;

    impl Visitor for MockVisitor {
        type Error = DecodeError;
        type Row = ProductValue;

        fn visit_insert<'a, R: BufReader<'a>>(
            &mut self,
            _table_id: TableId,
            reader: &mut R,
        ) -> Result<Self::Row, Self::Error> {
            ProductValue::decode(&SOME_PV_TY, reader)
        }

        fn visit_delete<'a, R: BufReader<'a>>(
            &mut self,
            _table_id: TableId,
            reader: &mut R,
        ) -> Result<Self::Row, Self::Error> {
            ProductValue::decode(&SOME_PV_TY, reader)
        }

        // Skipping decodes the row and drops it.
        fn skip_row<'a, R: BufReader<'a>>(&mut self, _table_id: TableId, reader: &mut R) -> Result<(), Self::Error> {
            ProductValue::decode(&SOME_PV_TY, reader)?;
            Ok(())
        }
    }

    proptest! {
        // Inputs survive an encode/decode roundtrip unchanged.
        #[test]
        fn prop_inputs_roundtrip(inputs in any::<Inputs>()) {
            let mut buf = Vec::new();
            inputs.encode(&mut buf);
            assert_eq!(inputs, Inputs::decode(&mut buf.as_slice()).unwrap());
        }

        // Outputs survive an encode/decode roundtrip unchanged.
        #[test]
        fn prop_outputs_roundtrip(outputs in any::<Outputs>()) {
            let mut buf = Vec::new();
            outputs.encode(&mut buf);
            assert_eq!(outputs, Outputs::decode(&mut buf.as_slice()).unwrap());
        }

        // Mutations survive an encode/decode roundtrip unchanged.
        #[test]
        fn prop_mutations_roundtrip(muts in gen_mutations(SOME_PV.clone())) {
            let mut buf = Vec::new();
            muts.encode(&mut buf);
            assert_eq!(muts, Mutations::decode(&mut MockVisitor, &mut buf.as_slice()).unwrap());
        }

        // A whole Txdata survives an encode/decode roundtrip unchanged.
        #[test]
        fn prop_txdata_roundtrip(txdata in gen_txdata(SOME_PV.clone())) {
            let mut buf = Vec::new();
            txdata.encode(&mut buf);
            assert_eq!(txdata, Txdata::decode(&mut MockVisitor, &mut buf.as_slice()).unwrap());
        }
    }
}