Skip to main content

s2_common/types/
stream.rs

1use std::{marker::PhantomData, ops::Deref, str::FromStr, time::Duration};
2
3use compact_str::{CompactString, ToCompactString};
4use time::OffsetDateTime;
5
6use super::{
7    ValidationError,
8    strings::{NameProps, PrefixProps, StartAfterProps, StrProps},
9};
10use crate::{
11    caps,
12    encryption::EncryptionSpec,
13    read_extent::{ReadLimit, ReadUntil},
14    record::{
15        FencingToken, Metered, MeteredExt, MeteredSize, Record, RecordDecryptionError, SeqNum,
16        Sequenced, StoredRecord, StreamPosition, Timestamp, decrypt_stored_record, encrypt_record,
17    },
18    types::resources::ListItemsRequest,
19};
20
21#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
22#[cfg_attr(
23    feature = "rkyv",
24    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
25)]
26pub struct StreamNameStr<T: StrProps>(CompactString, PhantomData<T>);
27
28impl<T: StrProps> StreamNameStr<T> {
29    fn validate_str(name: &str) -> Result<(), ValidationError> {
30        if !T::IS_PREFIX && name.is_empty() {
31            return Err(format!("stream {} must not be empty", T::FIELD_NAME).into());
32        }
33
34        if !T::IS_PREFIX && (name == "." || name == "..") {
35            return Err(format!("stream {} must not be \".\" or \"..\"", T::FIELD_NAME).into());
36        }
37
38        if name.len() > caps::MAX_STREAM_NAME_LEN {
39            return Err(format!(
40                "stream {} must not exceed {} bytes in length",
41                T::FIELD_NAME,
42                caps::MAX_STREAM_NAME_LEN
43            )
44            .into());
45        }
46
47        Ok(())
48    }
49}
50
/// OpenAPI schema for stream-name strings, using the length bounds from `caps`.
///
/// NOTE(review): validation measures length in bytes, whereas JSON Schema
/// `minLength`/`maxLength` count characters, so multi-byte names near the limit may satisfy
/// the published schema yet still be rejected by validation.
#[cfg(feature = "utoipa")]
impl<T> utoipa::PartialSchema for StreamNameStr<T>
where
    T: StrProps,
{
    fn schema() -> utoipa::openapi::RefOr<utoipa::openapi::schema::Schema> {
        // Prefix-style values may be empty, so only full names advertise a minimum length.
        let min_length = if T::IS_PREFIX {
            None
        } else {
            Some(caps::MIN_STREAM_NAME_LEN)
        };
        let max_length = Some(caps::MAX_STREAM_NAME_LEN);

        utoipa::openapi::Object::builder()
            .schema_type(utoipa::openapi::Type::String)
            .min_length(min_length)
            .max_length(max_length)
            .into()
    }
}

#[cfg(feature = "utoipa")]
impl<T> utoipa::ToSchema for StreamNameStr<T> where T: StrProps {}
67
68impl<T: StrProps> serde::Serialize for StreamNameStr<T> {
69    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
70    where
71        S: serde::Serializer,
72    {
73        serializer.serialize_str(&self.0)
74    }
75}
76
77impl<'de, T: StrProps> serde::Deserialize<'de> for StreamNameStr<T> {
78    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
79    where
80        D: serde::Deserializer<'de>,
81    {
82        let s = CompactString::deserialize(deserializer)?;
83        s.try_into().map_err(serde::de::Error::custom)
84    }
85}
86
87impl<T: StrProps> AsRef<str> for StreamNameStr<T> {
88    fn as_ref(&self) -> &str {
89        &self.0
90    }
91}
92
93impl<T: StrProps> Deref for StreamNameStr<T> {
94    type Target = str;
95
96    fn deref(&self) -> &Self::Target {
97        &self.0
98    }
99}
100
101impl<T: StrProps> TryFrom<CompactString> for StreamNameStr<T> {
102    type Error = ValidationError;
103
104    fn try_from(name: CompactString) -> Result<Self, Self::Error> {
105        Self::validate_str(&name)?;
106        Ok(Self(name, PhantomData))
107    }
108}
109
110impl<T: StrProps> FromStr for StreamNameStr<T> {
111    type Err = ValidationError;
112
113    fn from_str(s: &str) -> Result<Self, Self::Err> {
114        Self::validate_str(s)?;
115        Ok(Self(s.to_compact_string(), PhantomData))
116    }
117}
118
119impl<T: StrProps> std::fmt::Debug for StreamNameStr<T> {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        f.write_str(&self.0)
122    }
123}
124
125impl<T: StrProps> std::fmt::Display for StreamNameStr<T> {
126    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127        f.write_str(&self.0)
128    }
129}
130
131impl<T: StrProps> From<StreamNameStr<T>> for CompactString {
132    fn from(value: StreamNameStr<T>) -> Self {
133        value.0
134    }
135}
136
137pub type StreamName = StreamNameStr<NameProps>;
138
139pub type StreamNamePrefix = StreamNameStr<PrefixProps>;
140
141impl Default for StreamNamePrefix {
142    fn default() -> Self {
143        StreamNameStr(CompactString::default(), PhantomData)
144    }
145}
146
147impl From<StreamName> for StreamNamePrefix {
148    fn from(value: StreamName) -> Self {
149        Self(value.0, PhantomData)
150    }
151}
152
153pub type StreamNameStartAfter = StreamNameStr<StartAfterProps>;
154
155impl Default for StreamNameStartAfter {
156    fn default() -> Self {
157        StreamNameStr(CompactString::default(), PhantomData)
158    }
159}
160
161impl From<StreamName> for StreamNameStartAfter {
162    fn from(value: StreamName) -> Self {
163        Self(value.0, PhantomData)
164    }
165}
166
167#[derive(Debug, Clone)]
168pub struct StreamInfo {
169    pub name: StreamName,
170    pub created_at: OffsetDateTime,
171    pub deleted_at: Option<OffsetDateTime>,
172}
173
174#[derive(Debug, Clone)]
175pub struct AppendRecord<T = Record>(AppendRecordParts<T>);
176
177impl<T> AppendRecord<T> {
178    pub fn parts(&self) -> &AppendRecordParts<T> {
179        let Self(parts) = self;
180        parts
181    }
182
183    pub fn into_parts(self) -> AppendRecordParts<T> {
184        let Self(parts) = self;
185        parts
186    }
187}
188
189impl<T> MeteredSize for AppendRecord<T> {
190    fn metered_size(&self) -> usize {
191        self.0.record.metered_size()
192    }
193}
194
195#[derive(Debug, Clone)]
196pub struct AppendRecordParts<T = Record> {
197    pub timestamp: Option<Timestamp>,
198    pub record: Metered<T>,
199}
200
201impl<T> MeteredSize for AppendRecordParts<T> {
202    fn metered_size(&self) -> usize {
203        self.record.metered_size()
204    }
205}
206
207impl<T> From<AppendRecord<T>> for AppendRecordParts<T> {
208    fn from(record: AppendRecord<T>) -> Self {
209        record.into_parts()
210    }
211}
212
213impl<T> TryFrom<AppendRecordParts<T>> for AppendRecord<T> {
214    type Error = &'static str;
215
216    fn try_from(parts: AppendRecordParts<T>) -> Result<Self, Self::Error> {
217        if parts.metered_size() > caps::RECORD_BATCH_MAX.bytes {
218            Err("record must have metered size less than 1 MiB")
219        } else {
220            Ok(Self(parts))
221        }
222    }
223}
224
225#[derive(Clone)]
226pub struct AppendRecordBatch<T = Record>(Metered<Vec<AppendRecord<T>>>);
227
228impl<T> std::fmt::Debug for AppendRecordBatch<T> {
229    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230        f.debug_struct("AppendRecordBatch")
231            .field("num_records", &self.0.len())
232            .field("metered_size", &self.0.metered_size())
233            .finish()
234    }
235}
236
237impl<T> MeteredSize for AppendRecordBatch<T> {
238    fn metered_size(&self) -> usize {
239        self.0.metered_size()
240    }
241}
242
243impl<T> std::ops::Deref for AppendRecordBatch<T> {
244    type Target = [AppendRecord<T>];
245
246    fn deref(&self) -> &Self::Target {
247        &self.0
248    }
249}
250
251impl<T> TryFrom<Metered<Vec<AppendRecord<T>>>> for AppendRecordBatch<T> {
252    type Error = &'static str;
253
254    fn try_from(records: Metered<Vec<AppendRecord<T>>>) -> Result<Self, Self::Error> {
255        if records.is_empty() {
256            return Err("record batch must not be empty");
257        }
258
259        if records.len() > caps::RECORD_BATCH_MAX.count {
260            return Err("record batch must not exceed 1000 records");
261        }
262
263        if records.metered_size() > caps::RECORD_BATCH_MAX.bytes {
264            return Err("record batch must not exceed a metered size of 1 MiB");
265        }
266
267        Ok(Self(records))
268    }
269}
270
271impl<T> TryFrom<Vec<AppendRecord<T>>> for AppendRecordBatch<T> {
272    type Error = &'static str;
273
274    fn try_from(records: Vec<AppendRecord<T>>) -> Result<Self, Self::Error> {
275        let records = Metered::from(records);
276        Self::try_from(records)
277    }
278}
279
280impl<T> IntoIterator for AppendRecordBatch<T> {
281    type Item = AppendRecord<T>;
282    type IntoIter = std::vec::IntoIter<Self::Item>;
283
284    fn into_iter(self) -> Self::IntoIter {
285        self.0.into_iter()
286    }
287}
288
289pub type StoredAppendRecord = AppendRecord<StoredRecord>;
290pub type StoredAppendRecordParts = AppendRecordParts<StoredRecord>;
291pub type StoredAppendRecordBatch = AppendRecordBatch<StoredRecord>;
292
293impl From<AppendRecordParts<Record>> for AppendRecordParts<StoredRecord> {
294    fn from(
295        AppendRecordParts { timestamp, record }: AppendRecordParts<Record>,
296    ) -> AppendRecordParts<StoredRecord> {
297        AppendRecordParts {
298            timestamp,
299            record: StoredRecord::from(record.into_inner()).into(),
300        }
301    }
302}
303
304impl From<AppendRecord<Record>> for AppendRecord<StoredRecord> {
305    fn from(record: AppendRecord<Record>) -> Self {
306        Self(record.into_parts().into())
307    }
308}
309
310impl From<AppendRecordBatch<Record>> for AppendRecordBatch<StoredRecord> {
311    fn from(records: AppendRecordBatch<Record>) -> Self {
312        AppendRecordBatch(
313            records
314                .into_iter()
315                .map(|r| AppendRecord::<StoredRecord>::from(r).metered())
316                .collect(),
317        )
318    }
319}
320
321#[derive(Debug, Clone)]
322pub struct AppendInput<T = Record> {
323    pub records: AppendRecordBatch<T>,
324    pub match_seq_num: Option<SeqNum>,
325    pub fencing_token: Option<FencingToken>,
326}
327
328impl AppendInput<Record> {
329    pub fn encrypt(self, encryption: &EncryptionSpec, aad: &[u8]) -> AppendInput<StoredRecord> {
330        let AppendInput {
331            records,
332            match_seq_num,
333            fencing_token,
334        } = self;
335        let records = AppendRecordBatch(
336            records
337                .into_iter()
338                .map(|record| {
339                    let AppendRecordParts { timestamp, record } = record.into_parts();
340                    let record = encrypt_record(record, encryption, aad);
341                    AppendRecord(AppendRecordParts { timestamp, record }).metered()
342                })
343                .collect(),
344        );
345
346        AppendInput {
347            records,
348            match_seq_num,
349            fencing_token,
350        }
351    }
352}
353
354pub type StoredAppendInput = AppendInput<StoredRecord>;
355
356impl From<AppendInput<Record>> for AppendInput<StoredRecord> {
357    fn from(value: AppendInput<Record>) -> Self {
358        let AppendInput {
359            records,
360            match_seq_num,
361            fencing_token,
362        } = value;
363        let records = records.into();
364        AppendInput {
365            records,
366            match_seq_num,
367            fencing_token,
368        }
369    }
370}
371
372#[derive(Debug, Clone)]
373pub struct AppendAck {
374    pub start: StreamPosition,
375    pub end: StreamPosition,
376    pub tail: StreamPosition,
377}
378
379#[derive(Debug, Clone, Copy, PartialEq, Eq)]
380pub enum ReadPosition {
381    SeqNum(SeqNum),
382    Timestamp(Timestamp),
383}
384
385#[derive(Debug, Clone, Copy)]
386pub enum ReadFrom {
387    SeqNum(SeqNum),
388    Timestamp(Timestamp),
389    TailOffset(u64),
390}
391
392impl Default for ReadFrom {
393    fn default() -> Self {
394        Self::SeqNum(0)
395    }
396}
397
398#[derive(Debug, Default, Clone, Copy)]
399pub struct ReadStart {
400    pub from: ReadFrom,
401    pub clamp: bool,
402}
403
404#[derive(Debug, Default, Clone, Copy)]
405pub struct ReadEnd {
406    pub limit: ReadLimit,
407    pub until: ReadUntil,
408    pub wait: Option<Duration>,
409}
410
411impl ReadEnd {
412    pub fn may_follow(&self) -> bool {
413        (self.limit.is_unbounded() && self.until.is_unbounded())
414            || self.wait.is_some_and(|d| d > Duration::ZERO)
415    }
416}
417
418#[derive(Clone)]
419pub struct ReadBatch<T = Record> {
420    pub records: Metered<Vec<Sequenced<T>>>,
421    pub tail: Option<StreamPosition>,
422}
423
424impl<T> Default for ReadBatch<T>
425where
426    T: MeteredSize,
427{
428    fn default() -> Self {
429        Self {
430            records: Metered::default(),
431            tail: None,
432        }
433    }
434}
435
436impl<T> std::fmt::Debug for ReadBatch<T> {
437    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
438        f.debug_struct("ReadBatch")
439            .field("num_records", &self.records.len())
440            .field("metered_size", &self.records.metered_size())
441            .field("tail", &self.tail)
442            .finish()
443    }
444}
445
446impl ReadBatch<StoredRecord> {
447    pub fn decrypt(
448        self,
449        encryption: &EncryptionSpec,
450        aad: &[u8],
451    ) -> Result<ReadBatch, RecordDecryptionError> {
452        let records: Result<Metered<Vec<Sequenced<Record>>>, RecordDecryptionError> = self
453            .records
454            .into_inner()
455            .into_iter()
456            .map(|record| {
457                let (position, record) = record.into_parts();
458                decrypt_stored_record(record, encryption, aad)
459                    .map(|record| record.sequenced(position))
460            })
461            .collect();
462
463        Ok(ReadBatch {
464            records: records?,
465            tail: self.tail,
466        })
467    }
468}
469
470pub type StoredReadBatch = ReadBatch<StoredRecord>;
471
472#[derive(Debug, Clone)]
473pub enum ReadSessionOutput<T = Record> {
474    Heartbeat(StreamPosition),
475    Batch(ReadBatch<T>),
476}
477
478impl ReadSessionOutput<StoredRecord> {
479    pub fn decrypt(
480        self,
481        encryption: &EncryptionSpec,
482        aad: &[u8],
483    ) -> Result<ReadSessionOutput, RecordDecryptionError> {
484        match self {
485            Self::Heartbeat(tail) => Ok(ReadSessionOutput::Heartbeat(tail)),
486            Self::Batch(batch) => batch.decrypt(encryption, aad).map(ReadSessionOutput::Batch),
487        }
488    }
489}
490
491pub type StoredReadSessionOutput = ReadSessionOutput<StoredRecord>;
492
493pub type ListStreamsRequest = ListItemsRequest<StreamNamePrefix, StreamNameStartAfter>;
494
#[cfg(test)]
mod test {
    use bytes::Bytes;
    use rstest::rstest;

    use super::{
        super::strings::{NameProps, PrefixProps, StartAfterProps},
        *,
    };
    use crate::record::{EnvelopeRecord, MeteredExt, Record, StoredRecord, StreamPosition};

    const TEST_AAD: &[u8] = b"test-stream-aad";

    /// Builds a header-less envelope record with the given body.
    fn envelope(body: &'static [u8]) -> Record {
        Record::Envelope(
            EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(body)).unwrap(),
        )
    }

    #[rstest]
    #[case::normal("my-stream".to_owned())]
    #[case::max_len("a".repeat(crate::caps::MAX_STREAM_NAME_LEN))]
    fn validate_name_ok(#[case] name: String) {
        StreamNameStr::<NameProps>::validate_str(&name).expect("expected name to be valid");
    }

    #[rstest]
    #[case::empty("".to_owned())]
    #[case::dot(".".to_owned())]
    #[case::dot_dot("..".to_owned())]
    #[case::too_long("a".repeat(crate::caps::MAX_STREAM_NAME_LEN + 1))]
    fn validate_name_err(#[case] name: String) {
        StreamNameStr::<NameProps>::validate_str(&name).expect_err("expected validation error");
    }

    #[rstest]
    #[case::empty("".to_owned())]
    #[case::dot(".".to_owned())]
    #[case::dot_dot("..".to_owned())]
    #[case::max_len("a".repeat(crate::caps::MAX_STREAM_NAME_LEN))]
    fn validate_prefix_ok(#[case] prefix: String) {
        StreamNameStr::<PrefixProps>::validate_str(&prefix).expect("expected prefix to be valid");
    }

    #[rstest]
    #[case::too_long("a".repeat(crate::caps::MAX_STREAM_NAME_LEN + 1))]
    fn validate_prefix_err(#[case] prefix: String) {
        StreamNameStr::<PrefixProps>::validate_str(&prefix).expect_err("expected validation error");
    }

    #[rstest]
    #[case::empty("".to_owned())]
    #[case::dot(".".to_owned())]
    #[case::dot_dot("..".to_owned())]
    #[case::max_len("a".repeat(crate::caps::MAX_STREAM_NAME_LEN))]
    fn validate_start_after_ok(#[case] start_after: String) {
        StreamNameStr::<StartAfterProps>::validate_str(&start_after)
            .expect("expected start_after to be valid");
    }

    #[rstest]
    #[case::too_long("a".repeat(crate::caps::MAX_STREAM_NAME_LEN + 1))]
    fn validate_start_after_err(#[case] start_after: String) {
        StreamNameStr::<StartAfterProps>::validate_str(&start_after)
            .expect_err("expected validation error");
    }

    /// One-record append input with a timestamp, `match_seq_num`, and fencing token set.
    fn sample_append_input() -> AppendInput {
        let parts = AppendRecordParts {
            timestamp: Some(42),
            record: envelope(b"hello").metered(),
        };
        let record = AppendRecord::try_from(parts).unwrap();

        AppendInput {
            records: AppendRecordBatch::try_from(vec![record]).unwrap(),
            match_seq_num: Some(7),
            fencing_token: Some("fence".parse().unwrap()),
        }
    }

    #[test]
    fn append_record_batch_rejects_empty_batches() {
        let outcome = AppendRecordBatch::try_from(Vec::<AppendRecord>::new());

        assert_eq!(outcome.unwrap_err(), "record batch must not be empty");
    }

    #[rstest]
    #[case::encrypt(true)]
    #[case::into(false)]
    fn append_input_to_stored_preserves_metadata(#[case] encrypt: bool) {
        let encryption = EncryptionSpec::aegis256([0x42; 32]);
        let input = sample_append_input();
        let mapped: StoredAppendInput = if encrypt {
            input.encrypt(&encryption, TEST_AAD)
        } else {
            input.into()
        };

        let AppendInput {
            records,
            match_seq_num,
            fencing_token,
        } = mapped;
        assert_eq!(match_seq_num, Some(7));
        assert_eq!(
            fencing_token.as_ref().map(|token| token.as_ref()),
            Some("fence")
        );

        let parts = records
            .into_iter()
            .next()
            .expect("sample append input should contain a single record")
            .into_parts();
        assert_eq!(parts.timestamp, Some(42));

        let stored_record = parts.record.into_inner();
        let is_encrypted = matches!(&stored_record, StoredRecord::Encrypted { .. });
        assert_eq!(is_encrypted, encrypt);

        let decrypted = decrypt_stored_record(stored_record, &encryption, TEST_AAD).unwrap();
        match decrypted.into_inner() {
            Record::Envelope(record) => assert_eq!(record.body().as_ref(), b"hello"),
            #[allow(unreachable_patterns)] // Guards against future `Record` variants.
            _ => panic!("expected envelope record"),
        }
    }

    #[test]
    fn stored_read_batch_decrypt_preserves_positions_and_tail() {
        const FIRST: StreamPosition = StreamPosition {
            seq_num: 1,
            timestamp: 10,
        };
        const SECOND: StreamPosition = StreamPosition {
            seq_num: 2,
            timestamp: 20,
        };
        const TAIL: StreamPosition = StreamPosition {
            seq_num: 3,
            timestamp: 30,
        };

        /// A plaintext stored envelope record sequenced at `position`.
        fn stored(body: &'static [u8], position: StreamPosition) -> Sequenced<StoredRecord> {
            StoredRecord::Plaintext(envelope(body))
                .metered()
                .sequenced(position)
                .into_inner()
        }

        let batch = ReadBatch {
            records: Metered::from(vec![stored(b"one", FIRST), stored(b"two", SECOND)]),
            tail: Some(TAIL),
        };

        let mapped = batch.decrypt(&EncryptionSpec::Plain, &[]).unwrap();
        assert_eq!(mapped.tail, Some(TAIL));

        let records = mapped.records.into_inner();
        for (index, expected) in [FIRST, SECOND].iter().enumerate() {
            let record = &records[index];
            assert_eq!(record.position(), expected);
            assert!(matches!(record.inner(), Record::Envelope(_)));
        }
    }
}