1use std::{marker::PhantomData, ops::Deref, str::FromStr, time::Duration};
2
3use compact_str::{CompactString, ToCompactString};
4use time::OffsetDateTime;
5
6use super::{
7 ValidationError,
8 strings::{NameProps, PrefixProps, StartAfterProps, StrProps},
9};
10use crate::{
11 caps,
12 encryption::{EncryptionAlgorithm, EncryptionSpec},
13 read_extent::{ReadLimit, ReadUntil},
14 record::{
15 FencingToken, Metered, MeteredExt, MeteredSize, Record, RecordDecryptionError, SeqNum,
16 Sequenced, StoredRecord, StreamPosition, Timestamp, decrypt_stored_record, encrypt_record,
17 },
18 types::{
19 config::{OptionalStreamConfig, StreamReconfiguration},
20 resources::{ListItemsRequest, RequestToken},
21 },
22};
23
24#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
25#[cfg_attr(
26 feature = "rkyv",
27 derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
28)]
29pub struct StreamNameStr<T: StrProps>(CompactString, PhantomData<T>);
30
31impl<T: StrProps> StreamNameStr<T> {
32 fn validate_str(name: &str) -> Result<(), ValidationError> {
33 if !T::IS_PREFIX && name.is_empty() {
34 return Err(format!("stream {} must not be empty", T::FIELD_NAME).into());
35 }
36
37 if !T::IS_PREFIX && (name == "." || name == "..") {
38 return Err(format!("stream {} must not be \".\" or \"..\"", T::FIELD_NAME).into());
39 }
40
41 if name.len() > caps::MAX_STREAM_NAME_LEN {
42 return Err(format!(
43 "stream {} must not exceed {} bytes in length",
44 T::FIELD_NAME,
45 caps::MAX_STREAM_NAME_LEN
46 )
47 .into());
48 }
49
50 Ok(())
51 }
52}
53
#[cfg(feature = "utoipa")]
impl<T: StrProps> utoipa::PartialSchema for StreamNameStr<T> {
    /// OpenAPI schema: a string capped at [`caps::MAX_STREAM_NAME_LEN`].
    ///
    /// A minimum length of [`caps::MIN_STREAM_NAME_LEN`] is advertised only for
    /// non-prefix kinds, since prefix kinds accept the empty string.
    fn schema() -> utoipa::openapi::RefOr<utoipa::openapi::schema::Schema> {
        use utoipa::openapi::{Object, Type};

        let min_length = if T::IS_PREFIX {
            None
        } else {
            Some(caps::MIN_STREAM_NAME_LEN)
        };

        Object::builder()
            .schema_type(Type::String)
            .min_length(min_length)
            .max_length(Some(caps::MAX_STREAM_NAME_LEN))
            .into()
    }
}

#[cfg(feature = "utoipa")]
impl<T: StrProps> utoipa::ToSchema for StreamNameStr<T> {}
70
71impl<T: StrProps> serde::Serialize for StreamNameStr<T> {
72 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
73 where
74 S: serde::Serializer,
75 {
76 serializer.serialize_str(&self.0)
77 }
78}
79
80impl<'de, T: StrProps> serde::Deserialize<'de> for StreamNameStr<T> {
81 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
82 where
83 D: serde::Deserializer<'de>,
84 {
85 let s = CompactString::deserialize(deserializer)?;
86 s.try_into().map_err(serde::de::Error::custom)
87 }
88}
89
90impl<T: StrProps> AsRef<str> for StreamNameStr<T> {
91 fn as_ref(&self) -> &str {
92 &self.0
93 }
94}
95
96impl<T: StrProps> Deref for StreamNameStr<T> {
97 type Target = str;
98
99 fn deref(&self) -> &Self::Target {
100 &self.0
101 }
102}
103
104impl<T: StrProps> TryFrom<CompactString> for StreamNameStr<T> {
105 type Error = ValidationError;
106
107 fn try_from(name: CompactString) -> Result<Self, Self::Error> {
108 Self::validate_str(&name)?;
109 Ok(Self(name, PhantomData))
110 }
111}
112
113impl<T: StrProps> FromStr for StreamNameStr<T> {
114 type Err = ValidationError;
115
116 fn from_str(s: &str) -> Result<Self, Self::Err> {
117 Self::validate_str(s)?;
118 Ok(Self(s.to_compact_string(), PhantomData))
119 }
120}
121
122impl<T: StrProps> std::fmt::Debug for StreamNameStr<T> {
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 f.write_str(&self.0)
125 }
126}
127
128impl<T: StrProps> std::fmt::Display for StreamNameStr<T> {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 f.write_str(&self.0)
131 }
132}
133
134impl<T: StrProps> From<StreamNameStr<T>> for CompactString {
135 fn from(value: StreamNameStr<T>) -> Self {
136 value.0
137 }
138}
139
140pub type StreamName = StreamNameStr<NameProps>;
141
142pub type StreamNamePrefix = StreamNameStr<PrefixProps>;
143
144impl Default for StreamNamePrefix {
145 fn default() -> Self {
146 StreamNameStr(CompactString::default(), PhantomData)
147 }
148}
149
150impl From<StreamName> for StreamNamePrefix {
151 fn from(value: StreamName) -> Self {
152 Self(value.0, PhantomData)
153 }
154}
155
156pub type StreamNameStartAfter = StreamNameStr<StartAfterProps>;
157
158impl Default for StreamNameStartAfter {
159 fn default() -> Self {
160 StreamNameStr(CompactString::default(), PhantomData)
161 }
162}
163
164impl From<StreamName> for StreamNameStartAfter {
165 fn from(value: StreamName) -> Self {
166 Self(value.0, PhantomData)
167 }
168}
169
170#[derive(Debug, Clone)]
171pub struct StreamInfo {
172 pub name: StreamName,
173 pub created_at: OffsetDateTime,
174 pub deleted_at: Option<OffsetDateTime>,
175 pub cipher: Option<EncryptionAlgorithm>,
176}
177
178#[derive(Debug)]
184pub enum CreateStreamIntent {
185 CreateOnly {
190 config: OptionalStreamConfig,
192 request_token: Option<RequestToken>,
194 },
195 CreateOrReconfigure {
200 reconfiguration: StreamReconfiguration,
202 },
203}
204
205#[derive(Debug, Clone)]
206pub struct AppendRecord<T = Record>(AppendRecordParts<T>);
207
208impl<T> AppendRecord<T> {
209 pub fn parts(&self) -> &AppendRecordParts<T> {
210 let Self(parts) = self;
211 parts
212 }
213
214 pub fn into_parts(self) -> AppendRecordParts<T> {
215 let Self(parts) = self;
216 parts
217 }
218}
219
220impl<T> MeteredSize for AppendRecord<T> {
221 fn metered_size(&self) -> usize {
222 self.0.record.metered_size()
223 }
224}
225
226#[derive(Debug, Clone)]
227pub struct AppendRecordParts<T = Record> {
228 pub timestamp: Option<Timestamp>,
229 pub record: Metered<T>,
230}
231
232impl<T> MeteredSize for AppendRecordParts<T> {
233 fn metered_size(&self) -> usize {
234 self.record.metered_size()
235 }
236}
237
238impl<T> From<AppendRecord<T>> for AppendRecordParts<T> {
239 fn from(record: AppendRecord<T>) -> Self {
240 record.into_parts()
241 }
242}
243
244impl<T> TryFrom<AppendRecordParts<T>> for AppendRecord<T> {
245 type Error = &'static str;
246
247 fn try_from(parts: AppendRecordParts<T>) -> Result<Self, Self::Error> {
248 if parts.metered_size() > caps::RECORD_BATCH_MAX.bytes {
249 Err("record must have metered size less than 1 MiB")
250 } else {
251 Ok(Self(parts))
252 }
253 }
254}
255
256#[derive(Clone)]
257pub struct AppendRecordBatch<T = Record>(Metered<Vec<AppendRecord<T>>>);
258
259impl<T> std::fmt::Debug for AppendRecordBatch<T> {
260 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261 f.debug_struct("AppendRecordBatch")
262 .field("num_records", &self.0.len())
263 .field("metered_size", &self.0.metered_size())
264 .finish()
265 }
266}
267
268impl<T> MeteredSize for AppendRecordBatch<T> {
269 fn metered_size(&self) -> usize {
270 self.0.metered_size()
271 }
272}
273
274impl<T> std::ops::Deref for AppendRecordBatch<T> {
275 type Target = [AppendRecord<T>];
276
277 fn deref(&self) -> &Self::Target {
278 &self.0
279 }
280}
281
282impl<T> TryFrom<Metered<Vec<AppendRecord<T>>>> for AppendRecordBatch<T> {
283 type Error = &'static str;
284
285 fn try_from(records: Metered<Vec<AppendRecord<T>>>) -> Result<Self, Self::Error> {
286 if records.is_empty() {
287 return Err("record batch must not be empty");
288 }
289
290 if records.len() > caps::RECORD_BATCH_MAX.count {
291 return Err("record batch must not exceed 1000 records");
292 }
293
294 if records.metered_size() > caps::RECORD_BATCH_MAX.bytes {
295 return Err("record batch must not exceed a metered size of 1 MiB");
296 }
297
298 Ok(Self(records))
299 }
300}
301
302impl<T> TryFrom<Vec<AppendRecord<T>>> for AppendRecordBatch<T> {
303 type Error = &'static str;
304
305 fn try_from(records: Vec<AppendRecord<T>>) -> Result<Self, Self::Error> {
306 let records = Metered::from(records);
307 Self::try_from(records)
308 }
309}
310
311impl<T> IntoIterator for AppendRecordBatch<T> {
312 type Item = AppendRecord<T>;
313 type IntoIter = std::vec::IntoIter<Self::Item>;
314
315 fn into_iter(self) -> Self::IntoIter {
316 self.0.into_iter()
317 }
318}
319
320pub type StoredAppendRecord = AppendRecord<StoredRecord>;
321pub type StoredAppendRecordParts = AppendRecordParts<StoredRecord>;
322pub type StoredAppendRecordBatch = AppendRecordBatch<StoredRecord>;
323
324impl From<AppendRecordParts<Record>> for AppendRecordParts<StoredRecord> {
325 fn from(
326 AppendRecordParts { timestamp, record }: AppendRecordParts<Record>,
327 ) -> AppendRecordParts<StoredRecord> {
328 AppendRecordParts {
329 timestamp,
330 record: StoredRecord::from(record.into_inner()).into(),
331 }
332 }
333}
334
335impl From<AppendRecord<Record>> for AppendRecord<StoredRecord> {
336 fn from(record: AppendRecord<Record>) -> Self {
337 Self(record.into_parts().into())
338 }
339}
340
341impl From<AppendRecordBatch<Record>> for AppendRecordBatch<StoredRecord> {
342 fn from(records: AppendRecordBatch<Record>) -> Self {
343 AppendRecordBatch(
344 records
345 .into_iter()
346 .map(|r| AppendRecord::<StoredRecord>::from(r).metered())
347 .collect(),
348 )
349 }
350}
351
352#[derive(Debug, Clone)]
353pub struct AppendInput<T = Record> {
354 pub records: AppendRecordBatch<T>,
355 pub match_seq_num: Option<SeqNum>,
356 pub fencing_token: Option<FencingToken>,
357}
358
359impl AppendInput<Record> {
360 pub fn encrypt(self, encryption: &EncryptionSpec, aad: &[u8]) -> AppendInput<StoredRecord> {
361 let AppendInput {
362 records,
363 match_seq_num,
364 fencing_token,
365 } = self;
366 let records = AppendRecordBatch(
367 records
368 .into_iter()
369 .map(|record| {
370 let AppendRecordParts { timestamp, record } = record.into_parts();
371 let record = encrypt_record(record, encryption, aad);
372 AppendRecord(AppendRecordParts { timestamp, record }).metered()
373 })
374 .collect(),
375 );
376
377 AppendInput {
378 records,
379 match_seq_num,
380 fencing_token,
381 }
382 }
383}
384
385pub type StoredAppendInput = AppendInput<StoredRecord>;
386
387impl From<AppendInput<Record>> for AppendInput<StoredRecord> {
388 fn from(value: AppendInput<Record>) -> Self {
389 let AppendInput {
390 records,
391 match_seq_num,
392 fencing_token,
393 } = value;
394 let records = records.into();
395 AppendInput {
396 records,
397 match_seq_num,
398 fencing_token,
399 }
400 }
401}
402
403#[derive(Debug, Clone)]
404pub struct AppendAck {
405 pub start: StreamPosition,
406 pub end: StreamPosition,
407 pub tail: StreamPosition,
408}
409
410#[derive(Debug, Clone, Copy, PartialEq, Eq)]
411pub enum ReadPosition {
412 SeqNum(SeqNum),
413 Timestamp(Timestamp),
414}
415
416#[derive(Debug, Clone, Copy)]
417pub enum ReadFrom {
418 SeqNum(SeqNum),
419 Timestamp(Timestamp),
420 TailOffset(u64),
421}
422
423impl Default for ReadFrom {
424 fn default() -> Self {
425 Self::SeqNum(0)
426 }
427}
428
429#[derive(Debug, Default, Clone, Copy)]
430pub struct ReadStart {
431 pub from: ReadFrom,
432 pub clamp: bool,
433}
434
435#[derive(Debug, Default, Clone, Copy)]
436pub struct ReadEnd {
437 pub limit: ReadLimit,
438 pub until: ReadUntil,
439 pub wait: Option<Duration>,
440}
441
442impl ReadEnd {
443 pub fn may_follow(&self) -> bool {
444 (self.limit.is_unbounded() && self.until.is_unbounded())
445 || self.wait.is_some_and(|d| d > Duration::ZERO)
446 }
447}
448
449#[derive(Clone)]
450pub struct ReadBatch<T = Record> {
451 pub records: Metered<Vec<Sequenced<T>>>,
452 pub tail: Option<StreamPosition>,
453}
454
455impl<T> Default for ReadBatch<T>
456where
457 T: MeteredSize,
458{
459 fn default() -> Self {
460 Self {
461 records: Metered::default(),
462 tail: None,
463 }
464 }
465}
466
467impl<T> std::fmt::Debug for ReadBatch<T> {
468 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
469 f.debug_struct("ReadBatch")
470 .field("num_records", &self.records.len())
471 .field("metered_size", &self.records.metered_size())
472 .field("tail", &self.tail)
473 .finish()
474 }
475}
476
477impl ReadBatch<StoredRecord> {
478 pub fn decrypt(
479 self,
480 encryption: &EncryptionSpec,
481 aad: &[u8],
482 ) -> Result<ReadBatch, RecordDecryptionError> {
483 let records: Result<Metered<Vec<Sequenced<Record>>>, RecordDecryptionError> = self
484 .records
485 .into_inner()
486 .into_iter()
487 .map(|record| {
488 let (position, record) = record.into_parts();
489 decrypt_stored_record(record, encryption, aad)
490 .map(|record| record.sequenced(position))
491 })
492 .collect();
493
494 Ok(ReadBatch {
495 records: records?,
496 tail: self.tail,
497 })
498 }
499}
500
501pub type StoredReadBatch = ReadBatch<StoredRecord>;
502
503#[derive(Debug, Clone)]
504pub enum ReadSessionOutput<T = Record> {
505 Heartbeat(StreamPosition),
506 Batch(ReadBatch<T>),
507}
508
509impl ReadSessionOutput<StoredRecord> {
510 pub fn decrypt(
511 self,
512 encryption: &EncryptionSpec,
513 aad: &[u8],
514 ) -> Result<ReadSessionOutput, RecordDecryptionError> {
515 match self {
516 Self::Heartbeat(tail) => Ok(ReadSessionOutput::Heartbeat(tail)),
517 Self::Batch(batch) => batch.decrypt(encryption, aad).map(ReadSessionOutput::Batch),
518 }
519 }
520}
521
522pub type StoredReadSessionOutput = ReadSessionOutput<StoredRecord>;
523
524pub type ListStreamsRequest = ListItemsRequest<StreamNamePrefix, StreamNameStartAfter>;
525
#[cfg(test)]
mod test {
    use bytes::Bytes;
    use rstest::rstest;

    use super::{
        super::strings::{NameProps, PrefixProps, StartAfterProps},
        *,
    };
    use crate::record::{EnvelopeRecord, MeteredExt, Record, StoredRecord, StreamPosition};

    // ----- stream name validation -----

    #[rstest]
    #[case::normal("my-stream".to_owned())]
    #[case::max_len("a".repeat(crate::caps::MAX_STREAM_NAME_LEN))]
    fn validate_name_ok(#[case] name: String) {
        let outcome = StreamNameStr::<NameProps>::validate_str(&name);
        assert_eq!(outcome, Ok(()));
    }

    #[rstest]
    #[case::empty("".to_owned())]
    #[case::dot(".".to_owned())]
    #[case::dot_dot("..".to_owned())]
    #[case::too_long("a".repeat(crate::caps::MAX_STREAM_NAME_LEN + 1))]
    fn validate_name_err(#[case] name: String) {
        let outcome = StreamNameStr::<NameProps>::validate_str(&name);
        assert!(outcome.is_err(), "expected validation error");
    }

    #[rstest]
    #[case::empty("".to_owned())]
    #[case::dot(".".to_owned())]
    #[case::dot_dot("..".to_owned())]
    #[case::max_len("a".repeat(crate::caps::MAX_STREAM_NAME_LEN))]
    fn validate_prefix_ok(#[case] prefix: String) {
        let outcome = StreamNameStr::<PrefixProps>::validate_str(&prefix);
        assert_eq!(outcome, Ok(()));
    }

    #[rstest]
    #[case::too_long("a".repeat(crate::caps::MAX_STREAM_NAME_LEN + 1))]
    fn validate_prefix_err(#[case] prefix: String) {
        let outcome = StreamNameStr::<PrefixProps>::validate_str(&prefix);
        assert!(outcome.is_err(), "expected validation error");
    }

    #[rstest]
    #[case::empty("".to_owned())]
    #[case::dot(".".to_owned())]
    #[case::dot_dot("..".to_owned())]
    #[case::max_len("a".repeat(crate::caps::MAX_STREAM_NAME_LEN))]
    fn validate_start_after_ok(#[case] start_after: String) {
        let outcome = StreamNameStr::<StartAfterProps>::validate_str(&start_after);
        assert_eq!(outcome, Ok(()));
    }

    #[rstest]
    #[case::too_long("a".repeat(crate::caps::MAX_STREAM_NAME_LEN + 1))]
    fn validate_start_after_err(#[case] start_after: String) {
        let outcome = StreamNameStr::<StartAfterProps>::validate_str(&start_after);
        assert!(outcome.is_err(), "expected validation error");
    }

    // ----- append / read conversions -----

    const TEST_AAD: &[u8] = b"test-stream-aad";

    /// Builds an envelope record whose body is `body`.
    fn envelope(body: &'static [u8]) -> Record {
        Record::Envelope(EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(body)).unwrap())
    }

    /// Builds a plaintext stored record sequenced at `position`.
    fn stored_plaintext(body: &'static [u8], position: StreamPosition) -> Sequenced<StoredRecord> {
        StoredRecord::Plaintext(envelope(body))
            .metered()
            .sequenced(position)
            .into_inner()
    }

    /// A one-record append input with a timestamp, a seq-num match and a fence.
    fn sample_append_input() -> AppendInput {
        let parts = AppendRecordParts {
            timestamp: Some(42),
            record: envelope(b"hello").metered(),
        };
        let record = AppendRecord::try_from(parts).unwrap();
        AppendInput {
            records: AppendRecordBatch::try_from(vec![record]).unwrap(),
            match_seq_num: Some(7),
            fencing_token: Some("fence".parse().unwrap()),
        }
    }

    #[test]
    fn append_record_batch_rejects_empty_batches() {
        let no_records: Vec<AppendRecord> = Vec::new();
        let outcome = AppendRecordBatch::try_from(no_records);
        assert_eq!(outcome.unwrap_err(), "record batch must not be empty");
    }

    #[rstest]
    #[case::encrypt(true)]
    #[case::into(false)]
    fn append_input_to_stored_preserves_metadata(#[case] encrypt: bool) {
        let key_spec = EncryptionSpec::aegis256([0x42; 32]);
        let stored_input: StoredAppendInput = match encrypt {
            true => sample_append_input().encrypt(&key_spec, TEST_AAD),
            false => sample_append_input().into(),
        };

        assert_eq!(stored_input.match_seq_num, Some(7));
        assert_eq!(
            stored_input.fencing_token.as_ref().map(|token| token.as_ref()),
            Some("fence")
        );

        let mut records = stored_input.records.into_iter();
        let parts: AppendRecordParts<StoredRecord> = records
            .next()
            .expect("sample append input should contain a single record")
            .into_parts();
        assert_eq!(parts.timestamp, Some(42));

        let stored = parts.record.into_inner();
        let is_encrypted = matches!(&stored, StoredRecord::Encrypted { .. });
        assert_eq!(is_encrypted, encrypt);

        // Round-trip with the matching spec: the key for encrypted input, plain otherwise.
        let decryption = if encrypt {
            key_spec
        } else {
            EncryptionSpec::Plain
        };
        let decrypted = decrypt_stored_record(stored, &decryption, TEST_AAD).unwrap();
        let Record::Envelope(body_record) = decrypted.into_inner() else {
            panic!("expected envelope record");
        };
        assert_eq!(body_record.body().as_ref(), b"hello");
    }

    #[test]
    fn stored_read_batch_decrypt_preserves_positions_and_tail() {
        let pos = |seq_num, timestamp| StreamPosition { seq_num, timestamp };

        let batch = ReadBatch {
            records: Metered::from(vec![
                stored_plaintext(b"one", pos(1, 10)),
                stored_plaintext(b"two", pos(2, 20)),
            ]),
            tail: Some(pos(3, 30)),
        };

        let decrypted = batch.decrypt(&EncryptionSpec::Plain, &[]).unwrap();
        assert_eq!(decrypted.tail, Some(pos(3, 30)));

        let records = decrypted.records.into_inner();
        let expected = [pos(1, 10), pos(2, 20)];
        assert_eq!(records.len(), expected.len());
        for (record, want) in records.iter().zip(expected.iter()) {
            assert_eq!(record.position(), want);
            assert!(matches!(record.inner(), Record::Envelope(_)));
        }
    }
}