// s2_common/types/stream.rs

1use std::{marker::PhantomData, ops::Deref, str::FromStr, time::Duration};
2
3use compact_str::{CompactString, ToCompactString};
4use time::OffsetDateTime;
5
6use super::{
7    ValidationError,
8    strings::{NameProps, PrefixProps, StartAfterProps, StrProps},
9};
10use crate::{
11    caps,
12    encryption::{EncryptionAlgorithm, EncryptionSpec},
13    read_extent::{ReadLimit, ReadUntil},
14    record::{
15        FencingToken, Metered, MeteredExt, MeteredSize, Record, RecordDecryptionError, SeqNum,
16        Sequenced, StoredRecord, StreamPosition, Timestamp, decrypt_stored_record, encrypt_record,
17    },
18    types::resources::ListItemsRequest,
19};
20
21#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
22#[cfg_attr(
23    feature = "rkyv",
24    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
25)]
26pub struct StreamNameStr<T: StrProps>(CompactString, PhantomData<T>);
27
28impl<T: StrProps> StreamNameStr<T> {
29    fn validate_str(name: &str) -> Result<(), ValidationError> {
30        if !T::IS_PREFIX && name.is_empty() {
31            return Err(format!("stream {} must not be empty", T::FIELD_NAME).into());
32        }
33
34        if !T::IS_PREFIX && (name == "." || name == "..") {
35            return Err(format!("stream {} must not be \".\" or \"..\"", T::FIELD_NAME).into());
36        }
37
38        if name.len() > caps::MAX_STREAM_NAME_LEN {
39            return Err(format!(
40                "stream {} must not exceed {} bytes in length",
41                T::FIELD_NAME,
42                caps::MAX_STREAM_NAME_LEN
43            )
44            .into());
45        }
46
47        Ok(())
48    }
49}
50
#[cfg(feature = "utoipa")]
impl<T: StrProps> utoipa::PartialSchema for StreamNameStr<T> {
    /// OpenAPI schema: a string of at most `caps::MAX_STREAM_NAME_LEN`, with a minimum
    /// length of `caps::MIN_STREAM_NAME_LEN` unless `T::IS_PREFIX` is set.
    fn schema() -> utoipa::openapi::RefOr<utoipa::openapi::schema::Schema> {
        use utoipa::openapi::{Object, Type};

        // Prefix-like variants may be empty, so they advertise no minimum length.
        let min_len = if T::IS_PREFIX {
            None
        } else {
            Some(caps::MIN_STREAM_NAME_LEN)
        };

        Object::builder()
            .schema_type(Type::String)
            .min_length(min_len)
            .max_length(Some(caps::MAX_STREAM_NAME_LEN))
            .into()
    }
}

#[cfg(feature = "utoipa")]
impl<T: StrProps> utoipa::ToSchema for StreamNameStr<T> {}
67
68impl<T: StrProps> serde::Serialize for StreamNameStr<T> {
69    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
70    where
71        S: serde::Serializer,
72    {
73        serializer.serialize_str(&self.0)
74    }
75}
76
77impl<'de, T: StrProps> serde::Deserialize<'de> for StreamNameStr<T> {
78    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
79    where
80        D: serde::Deserializer<'de>,
81    {
82        let s = CompactString::deserialize(deserializer)?;
83        s.try_into().map_err(serde::de::Error::custom)
84    }
85}
86
87impl<T: StrProps> AsRef<str> for StreamNameStr<T> {
88    fn as_ref(&self) -> &str {
89        &self.0
90    }
91}
92
93impl<T: StrProps> Deref for StreamNameStr<T> {
94    type Target = str;
95
96    fn deref(&self) -> &Self::Target {
97        &self.0
98    }
99}
100
101impl<T: StrProps> TryFrom<CompactString> for StreamNameStr<T> {
102    type Error = ValidationError;
103
104    fn try_from(name: CompactString) -> Result<Self, Self::Error> {
105        Self::validate_str(&name)?;
106        Ok(Self(name, PhantomData))
107    }
108}
109
110impl<T: StrProps> FromStr for StreamNameStr<T> {
111    type Err = ValidationError;
112
113    fn from_str(s: &str) -> Result<Self, Self::Err> {
114        Self::validate_str(s)?;
115        Ok(Self(s.to_compact_string(), PhantomData))
116    }
117}
118
119impl<T: StrProps> std::fmt::Debug for StreamNameStr<T> {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        f.write_str(&self.0)
122    }
123}
124
125impl<T: StrProps> std::fmt::Display for StreamNameStr<T> {
126    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127        f.write_str(&self.0)
128    }
129}
130
131impl<T: StrProps> From<StreamNameStr<T>> for CompactString {
132    fn from(value: StreamNameStr<T>) -> Self {
133        value.0
134    }
135}
136
/// A stream name: [`StreamNameStr`] validated with the [`NameProps`] rules.
pub type StreamName = StreamNameStr<NameProps>;
138
139pub type StreamNamePrefix = StreamNameStr<PrefixProps>;
140
141impl Default for StreamNamePrefix {
142    fn default() -> Self {
143        StreamNameStr(CompactString::default(), PhantomData)
144    }
145}
146
147impl From<StreamName> for StreamNamePrefix {
148    fn from(value: StreamName) -> Self {
149        Self(value.0, PhantomData)
150    }
151}
152
153pub type StreamNameStartAfter = StreamNameStr<StartAfterProps>;
154
155impl Default for StreamNameStartAfter {
156    fn default() -> Self {
157        StreamNameStr(CompactString::default(), PhantomData)
158    }
159}
160
161impl From<StreamName> for StreamNameStartAfter {
162    fn from(value: StreamName) -> Self {
163        Self(value.0, PhantomData)
164    }
165}
166
/// Metadata describing a stream.
#[derive(Debug, Clone)]
pub struct StreamInfo {
    /// Validated stream name.
    pub name: StreamName,
    /// Creation time.
    pub created_at: OffsetDateTime,
    /// Deletion time, if any.
    /// NOTE(review): presumably `None` while the stream is live — confirm.
    pub deleted_at: Option<OffsetDateTime>,
    /// Encryption algorithm, if any.
    /// NOTE(review): presumably `None` means the stream is not encrypted — confirm.
    pub cipher: Option<EncryptionAlgorithm>,
}
174
175#[derive(Debug, Clone)]
176pub struct AppendRecord<T = Record>(AppendRecordParts<T>);
177
178impl<T> AppendRecord<T> {
179    pub fn parts(&self) -> &AppendRecordParts<T> {
180        let Self(parts) = self;
181        parts
182    }
183
184    pub fn into_parts(self) -> AppendRecordParts<T> {
185        let Self(parts) = self;
186        parts
187    }
188}
189
190impl<T> MeteredSize for AppendRecord<T> {
191    fn metered_size(&self) -> usize {
192        self.0.record.metered_size()
193    }
194}
195
196#[derive(Debug, Clone)]
197pub struct AppendRecordParts<T = Record> {
198    pub timestamp: Option<Timestamp>,
199    pub record: Metered<T>,
200}
201
202impl<T> MeteredSize for AppendRecordParts<T> {
203    fn metered_size(&self) -> usize {
204        self.record.metered_size()
205    }
206}
207
208impl<T> From<AppendRecord<T>> for AppendRecordParts<T> {
209    fn from(record: AppendRecord<T>) -> Self {
210        record.into_parts()
211    }
212}
213
214impl<T> TryFrom<AppendRecordParts<T>> for AppendRecord<T> {
215    type Error = &'static str;
216
217    fn try_from(parts: AppendRecordParts<T>) -> Result<Self, Self::Error> {
218        if parts.metered_size() > caps::RECORD_BATCH_MAX.bytes {
219            Err("record must have metered size less than 1 MiB")
220        } else {
221            Ok(Self(parts))
222        }
223    }
224}
225
226#[derive(Clone)]
227pub struct AppendRecordBatch<T = Record>(Metered<Vec<AppendRecord<T>>>);
228
229impl<T> std::fmt::Debug for AppendRecordBatch<T> {
230    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231        f.debug_struct("AppendRecordBatch")
232            .field("num_records", &self.0.len())
233            .field("metered_size", &self.0.metered_size())
234            .finish()
235    }
236}
237
238impl<T> MeteredSize for AppendRecordBatch<T> {
239    fn metered_size(&self) -> usize {
240        self.0.metered_size()
241    }
242}
243
244impl<T> std::ops::Deref for AppendRecordBatch<T> {
245    type Target = [AppendRecord<T>];
246
247    fn deref(&self) -> &Self::Target {
248        &self.0
249    }
250}
251
252impl<T> TryFrom<Metered<Vec<AppendRecord<T>>>> for AppendRecordBatch<T> {
253    type Error = &'static str;
254
255    fn try_from(records: Metered<Vec<AppendRecord<T>>>) -> Result<Self, Self::Error> {
256        if records.is_empty() {
257            return Err("record batch must not be empty");
258        }
259
260        if records.len() > caps::RECORD_BATCH_MAX.count {
261            return Err("record batch must not exceed 1000 records");
262        }
263
264        if records.metered_size() > caps::RECORD_BATCH_MAX.bytes {
265            return Err("record batch must not exceed a metered size of 1 MiB");
266        }
267
268        Ok(Self(records))
269    }
270}
271
272impl<T> TryFrom<Vec<AppendRecord<T>>> for AppendRecordBatch<T> {
273    type Error = &'static str;
274
275    fn try_from(records: Vec<AppendRecord<T>>) -> Result<Self, Self::Error> {
276        let records = Metered::from(records);
277        Self::try_from(records)
278    }
279}
280
281impl<T> IntoIterator for AppendRecordBatch<T> {
282    type Item = AppendRecord<T>;
283    type IntoIter = std::vec::IntoIter<Self::Item>;
284
285    fn into_iter(self) -> Self::IntoIter {
286        self.0.into_iter()
287    }
288}
289
/// [`AppendRecord`] carrying a [`StoredRecord`].
pub type StoredAppendRecord = AppendRecord<StoredRecord>;
/// [`AppendRecordParts`] carrying a [`StoredRecord`].
pub type StoredAppendRecordParts = AppendRecordParts<StoredRecord>;
/// [`AppendRecordBatch`] carrying [`StoredRecord`]s.
pub type StoredAppendRecordBatch = AppendRecordBatch<StoredRecord>;
293
294impl From<AppendRecordParts<Record>> for AppendRecordParts<StoredRecord> {
295    fn from(
296        AppendRecordParts { timestamp, record }: AppendRecordParts<Record>,
297    ) -> AppendRecordParts<StoredRecord> {
298        AppendRecordParts {
299            timestamp,
300            record: StoredRecord::from(record.into_inner()).into(),
301        }
302    }
303}
304
305impl From<AppendRecord<Record>> for AppendRecord<StoredRecord> {
306    fn from(record: AppendRecord<Record>) -> Self {
307        Self(record.into_parts().into())
308    }
309}
310
311impl From<AppendRecordBatch<Record>> for AppendRecordBatch<StoredRecord> {
312    fn from(records: AppendRecordBatch<Record>) -> Self {
313        AppendRecordBatch(
314            records
315                .into_iter()
316                .map(|r| AppendRecord::<StoredRecord>::from(r).metered())
317                .collect(),
318        )
319    }
320}
321
322#[derive(Debug, Clone)]
323pub struct AppendInput<T = Record> {
324    pub records: AppendRecordBatch<T>,
325    pub match_seq_num: Option<SeqNum>,
326    pub fencing_token: Option<FencingToken>,
327}
328
329impl AppendInput<Record> {
330    pub fn encrypt(self, encryption: &EncryptionSpec, aad: &[u8]) -> AppendInput<StoredRecord> {
331        let AppendInput {
332            records,
333            match_seq_num,
334            fencing_token,
335        } = self;
336        let records = AppendRecordBatch(
337            records
338                .into_iter()
339                .map(|record| {
340                    let AppendRecordParts { timestamp, record } = record.into_parts();
341                    let record = encrypt_record(record, encryption, aad);
342                    AppendRecord(AppendRecordParts { timestamp, record }).metered()
343                })
344                .collect(),
345        );
346
347        AppendInput {
348            records,
349            match_seq_num,
350            fencing_token,
351        }
352    }
353}
354
355pub type StoredAppendInput = AppendInput<StoredRecord>;
356
357impl From<AppendInput<Record>> for AppendInput<StoredRecord> {
358    fn from(value: AppendInput<Record>) -> Self {
359        let AppendInput {
360            records,
361            match_seq_num,
362            fencing_token,
363        } = value;
364        let records = records.into();
365        AppendInput {
366            records,
367            match_seq_num,
368            fencing_token,
369        }
370    }
371}
372
/// Acknowledgement of an append, expressed as three stream positions.
#[derive(Debug, Clone)]
pub struct AppendAck {
    /// NOTE(review): presumably the position of the first appended record — confirm.
    pub start: StreamPosition,
    /// NOTE(review): presumably the end of the appended range; confirm whether it is
    /// inclusive or exclusive.
    pub end: StreamPosition,
    /// NOTE(review): presumably the stream tail observed after the append — confirm.
    pub tail: StreamPosition,
}
379
/// A read position given either as a sequence number or as a timestamp.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadPosition {
    /// Position identified by sequence number.
    SeqNum(SeqNum),
    /// Position identified by timestamp.
    Timestamp(Timestamp),
}
385
386#[derive(Debug, Clone, Copy)]
387pub enum ReadFrom {
388    SeqNum(SeqNum),
389    Timestamp(Timestamp),
390    TailOffset(u64),
391}
392
393impl Default for ReadFrom {
394    fn default() -> Self {
395        Self::SeqNum(0)
396    }
397}
398
/// Start parameters of a read.
#[derive(Debug, Default, Clone, Copy)]
pub struct ReadStart {
    /// Starting point; defaults to `ReadFrom::default()`.
    pub from: ReadFrom,
    /// Defaults to `false`.
    /// NOTE(review): presumably clamps an out-of-range start to the tail — confirm.
    pub clamp: bool,
}
404
405#[derive(Debug, Default, Clone, Copy)]
406pub struct ReadEnd {
407    pub limit: ReadLimit,
408    pub until: ReadUntil,
409    pub wait: Option<Duration>,
410}
411
412impl ReadEnd {
413    pub fn may_follow(&self) -> bool {
414        (self.limit.is_unbounded() && self.until.is_unbounded())
415            || self.wait.is_some_and(|d| d > Duration::ZERO)
416    }
417}
418
419#[derive(Clone)]
420pub struct ReadBatch<T = Record> {
421    pub records: Metered<Vec<Sequenced<T>>>,
422    pub tail: Option<StreamPosition>,
423}
424
425impl<T> Default for ReadBatch<T>
426where
427    T: MeteredSize,
428{
429    fn default() -> Self {
430        Self {
431            records: Metered::default(),
432            tail: None,
433        }
434    }
435}
436
437impl<T> std::fmt::Debug for ReadBatch<T> {
438    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
439        f.debug_struct("ReadBatch")
440            .field("num_records", &self.records.len())
441            .field("metered_size", &self.records.metered_size())
442            .field("tail", &self.tail)
443            .finish()
444    }
445}
446
447impl ReadBatch<StoredRecord> {
448    pub fn decrypt(
449        self,
450        encryption: &EncryptionSpec,
451        aad: &[u8],
452    ) -> Result<ReadBatch, RecordDecryptionError> {
453        let records: Result<Metered<Vec<Sequenced<Record>>>, RecordDecryptionError> = self
454            .records
455            .into_inner()
456            .into_iter()
457            .map(|record| {
458                let (position, record) = record.into_parts();
459                decrypt_stored_record(record, encryption, aad)
460                    .map(|record| record.sequenced(position))
461            })
462            .collect();
463
464        Ok(ReadBatch {
465            records: records?,
466            tail: self.tail,
467        })
468    }
469}
470
471pub type StoredReadBatch = ReadBatch<StoredRecord>;
472
473#[derive(Debug, Clone)]
474pub enum ReadSessionOutput<T = Record> {
475    Heartbeat(StreamPosition),
476    Batch(ReadBatch<T>),
477}
478
479impl ReadSessionOutput<StoredRecord> {
480    pub fn decrypt(
481        self,
482        encryption: &EncryptionSpec,
483        aad: &[u8],
484    ) -> Result<ReadSessionOutput, RecordDecryptionError> {
485        match self {
486            Self::Heartbeat(tail) => Ok(ReadSessionOutput::Heartbeat(tail)),
487            Self::Batch(batch) => batch.decrypt(encryption, aad).map(ReadSessionOutput::Batch),
488        }
489    }
490}
491
492pub type StoredReadSessionOutput = ReadSessionOutput<StoredRecord>;
493
/// Stream listing request: [`ListItemsRequest`] parameterized with a [`StreamNamePrefix`]
/// and a [`StreamNameStartAfter`].
pub type ListStreamsRequest = ListItemsRequest<StreamNamePrefix, StreamNameStartAfter>;
495
#[cfg(test)]
mod test {
    use bytes::Bytes;
    use rstest::rstest;

    use super::{
        super::strings::{NameProps, PrefixProps, StartAfterProps},
        *,
    };
    use crate::record::{EnvelopeRecord, MeteredExt, Record, StoredRecord, StreamPosition};

    /// Associated data used by the encrypt/decrypt round-trip test.
    const TEST_AAD: &[u8] = b"test-stream-aad";

    /// Builds a header-less envelope record carrying `body`.
    fn envelope(body: &'static [u8]) -> Record {
        let inner = EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(body)).unwrap();
        Record::Envelope(inner)
    }

    /// Shorthand for a stream position.
    fn at(seq_num: SeqNum, timestamp: Timestamp) -> StreamPosition {
        StreamPosition { seq_num, timestamp }
    }

    /// A one-record append input with every optional field populated.
    fn sample_append_input() -> AppendInput {
        let parts = AppendRecordParts {
            timestamp: Some(42),
            record: envelope(b"hello").metered(),
        };
        let record = AppendRecord::try_from(parts).unwrap();

        AppendInput {
            records: vec![record].try_into().unwrap(),
            match_seq_num: Some(7),
            fencing_token: Some("fence".parse().unwrap()),
        }
    }

    // Stream names: ordinary names and names of exactly the maximum length are accepted.
    #[rstest]
    #[case::normal("my-stream".to_owned())]
    #[case::max_len("a".repeat(crate::caps::MAX_STREAM_NAME_LEN))]
    fn validate_name_ok(#[case] name: String) {
        let outcome = StreamNameStr::<NameProps>::validate_str(&name);
        assert_eq!(outcome, Ok(()));
    }

    // Stream names: empty, dot names and over-long names are rejected.
    #[rstest]
    #[case::empty("".to_owned())]
    #[case::dot(".".to_owned())]
    #[case::dot_dot("..".to_owned())]
    #[case::too_long("a".repeat(crate::caps::MAX_STREAM_NAME_LEN + 1))]
    fn validate_name_err(#[case] name: String) {
        let outcome = StreamNameStr::<NameProps>::validate_str(&name);
        outcome.expect_err("expected validation error");
    }

    // Prefixes: empty and dot values are fine; only the length cap applies.
    #[rstest]
    #[case::empty("".to_owned())]
    #[case::dot(".".to_owned())]
    #[case::dot_dot("..".to_owned())]
    #[case::max_len("a".repeat(crate::caps::MAX_STREAM_NAME_LEN))]
    fn validate_prefix_ok(#[case] prefix: String) {
        let outcome = StreamNameStr::<PrefixProps>::validate_str(&prefix);
        assert_eq!(outcome, Ok(()));
    }

    #[rstest]
    #[case::too_long("a".repeat(crate::caps::MAX_STREAM_NAME_LEN + 1))]
    fn validate_prefix_err(#[case] prefix: String) {
        let outcome = StreamNameStr::<PrefixProps>::validate_str(&prefix);
        outcome.expect_err("expected validation error");
    }

    // Start-after markers follow the same rules as prefixes.
    #[rstest]
    #[case::empty("".to_owned())]
    #[case::dot(".".to_owned())]
    #[case::dot_dot("..".to_owned())]
    #[case::max_len("a".repeat(crate::caps::MAX_STREAM_NAME_LEN))]
    fn validate_start_after_ok(#[case] start_after: String) {
        let outcome = StreamNameStr::<StartAfterProps>::validate_str(&start_after);
        assert_eq!(outcome, Ok(()));
    }

    #[rstest]
    #[case::too_long("a".repeat(crate::caps::MAX_STREAM_NAME_LEN + 1))]
    fn validate_start_after_err(#[case] start_after: String) {
        let outcome = StreamNameStr::<StartAfterProps>::validate_str(&start_after);
        outcome.expect_err("expected validation error");
    }

    #[test]
    fn append_record_batch_rejects_empty_batches() {
        let no_records: Vec<AppendRecord> = Vec::new();
        let outcome: Result<AppendRecordBatch, _> = no_records.try_into();

        assert_eq!(outcome.unwrap_err(), "record batch must not be empty");
    }

    // Both routes to a stored input (encrypting and plain `into`) must keep the metadata
    // and round-trip the payload.
    #[rstest]
    #[case::encrypt(true)]
    #[case::into(false)]
    fn append_input_to_stored_preserves_metadata(#[case] encrypt: bool) {
        let encryption = EncryptionSpec::aegis256([0x42; 32]);
        let input = sample_append_input();
        let mapped: StoredAppendInput = if encrypt {
            input.encrypt(&encryption, TEST_AAD)
        } else {
            input.into()
        };

        assert_eq!(mapped.match_seq_num, Some(7));
        assert_eq!(
            mapped.fencing_token.as_ref().map(|token| token.as_ref()),
            Some("fence")
        );

        let first = mapped
            .records
            .into_iter()
            .next()
            .expect("sample append input should contain a single record");
        let AppendRecordParts { timestamp, record }: AppendRecordParts<StoredRecord> =
            first.into_parts();
        assert_eq!(timestamp, Some(42));

        let stored_record = record.into_inner();
        let is_encrypted = matches!(&stored_record, StoredRecord::Encrypted { .. });
        assert_eq!(is_encrypted, encrypt);

        let plain_spec = EncryptionSpec::Plain;
        let decryption = if encrypt { &encryption } else { &plain_spec };
        let decrypted = decrypt_stored_record(stored_record, decryption, TEST_AAD).unwrap();
        let Record::Envelope(record) = decrypted.into_inner() else {
            panic!("expected envelope record");
        };
        assert_eq!(record.body().as_ref(), b"hello");
    }

    #[test]
    fn stored_read_batch_decrypt_preserves_positions_and_tail() {
        // Plaintext stored record pinned to a position, as it would sit inside a batch.
        let plaintext_at = |record: Record, position: StreamPosition| {
            StoredRecord::Plaintext(record)
                .metered()
                .sequenced(position)
                .into_inner()
        };

        let batch = ReadBatch {
            records: Metered::from(vec![
                plaintext_at(envelope(b"one"), at(1, 10)),
                plaintext_at(envelope(b"two"), at(2, 20)),
            ]),
            tail: Some(at(3, 30)),
        };

        let ReadBatch { records, tail } = batch.decrypt(&EncryptionSpec::Plain, &[]).unwrap();
        let records = records.into_inner();

        assert_eq!(tail, Some(at(3, 30)));
        assert_eq!(records[0].position(), &at(1, 10));
        assert_eq!(records[1].position(), &at(2, 20));
        assert!(matches!(records[0].inner(), Record::Envelope(_)));
        assert!(matches!(records[1].inner(), Record::Envelope(_)));
    }
}