1use std::{marker::PhantomData, ops::Deref, str::FromStr, time::Duration};
2
3use compact_str::{CompactString, ToCompactString};
4use time::OffsetDateTime;
5
6use super::{
7 ValidationError,
8 strings::{NameProps, PrefixProps, StartAfterProps, StrProps},
9};
10use crate::{
11 caps,
12 encryption::EncryptionSpec,
13 read_extent::{ReadLimit, ReadUntil},
14 record::{
15 FencingToken, Metered, MeteredExt, MeteredSize, Record, RecordDecryptionError, SeqNum,
16 Sequenced, StoredRecord, StreamPosition, Timestamp, decrypt_stored_record, encrypt_record,
17 },
18 types::resources::ListItemsRequest,
19};
20
/// A validated stream-name string, parameterized by `T: StrProps`, which
/// selects the validation rules (full name vs. prefix vs. start-after key).
///
/// Wraps a `CompactString`; the `PhantomData` only carries the props type
/// and stores nothing at runtime.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct StreamNameStr<T: StrProps>(CompactString, PhantomData<T>);
27
28impl<T: StrProps> StreamNameStr<T> {
29 fn validate_str(name: &str) -> Result<(), ValidationError> {
30 if !T::IS_PREFIX && name.is_empty() {
31 return Err(format!("stream {} must not be empty", T::FIELD_NAME).into());
32 }
33
34 if !T::IS_PREFIX && (name == "." || name == "..") {
35 return Err(format!("stream {} must not be \".\" or \"..\"", T::FIELD_NAME).into());
36 }
37
38 if name.len() > caps::MAX_STREAM_NAME_LEN {
39 return Err(format!(
40 "stream {} must not exceed {} bytes in length",
41 T::FIELD_NAME,
42 caps::MAX_STREAM_NAME_LEN
43 )
44 .into());
45 }
46
47 Ok(())
48 }
49}
50
#[cfg(feature = "utoipa")]
impl<T> utoipa::PartialSchema for StreamNameStr<T>
where
    T: StrProps,
{
    /// OpenAPI schema: a string capped at `MAX_STREAM_NAME_LEN` bytes.
    fn schema() -> utoipa::openapi::RefOr<utoipa::openapi::schema::Schema> {
        utoipa::openapi::Object::builder()
            .schema_type(utoipa::openapi::Type::String)
            // Prefix-like kinds may be empty, so only non-prefix kinds
            // advertise a minimum length.
            .min_length((!T::IS_PREFIX).then_some(caps::MIN_STREAM_NAME_LEN))
            .max_length(Some(caps::MAX_STREAM_NAME_LEN))
            .into()
    }
}

#[cfg(feature = "utoipa")]
// Marker impl; the actual schema is provided by `PartialSchema` above.
impl<T> utoipa::ToSchema for StreamNameStr<T> where T: StrProps {}
67
68impl<T: StrProps> serde::Serialize for StreamNameStr<T> {
69 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
70 where
71 S: serde::Serializer,
72 {
73 serializer.serialize_str(&self.0)
74 }
75}
76
77impl<'de, T: StrProps> serde::Deserialize<'de> for StreamNameStr<T> {
78 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
79 where
80 D: serde::Deserializer<'de>,
81 {
82 let s = CompactString::deserialize(deserializer)?;
83 s.try_into().map_err(serde::de::Error::custom)
84 }
85}
86
87impl<T: StrProps> AsRef<str> for StreamNameStr<T> {
88 fn as_ref(&self) -> &str {
89 &self.0
90 }
91}
92
93impl<T: StrProps> Deref for StreamNameStr<T> {
94 type Target = str;
95
96 fn deref(&self) -> &Self::Target {
97 &self.0
98 }
99}
100
101impl<T: StrProps> TryFrom<CompactString> for StreamNameStr<T> {
102 type Error = ValidationError;
103
104 fn try_from(name: CompactString) -> Result<Self, Self::Error> {
105 Self::validate_str(&name)?;
106 Ok(Self(name, PhantomData))
107 }
108}
109
110impl<T: StrProps> FromStr for StreamNameStr<T> {
111 type Err = ValidationError;
112
113 fn from_str(s: &str) -> Result<Self, Self::Err> {
114 Self::validate_str(s)?;
115 Ok(Self(s.to_compact_string(), PhantomData))
116 }
117}
118
119impl<T: StrProps> std::fmt::Debug for StreamNameStr<T> {
120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 f.write_str(&self.0)
122 }
123}
124
125impl<T: StrProps> std::fmt::Display for StreamNameStr<T> {
126 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127 f.write_str(&self.0)
128 }
129}
130
131impl<T: StrProps> From<StreamNameStr<T>> for CompactString {
132 fn from(value: StreamNameStr<T>) -> Self {
133 value.0
134 }
135}
136
137pub type StreamName = StreamNameStr<NameProps>;
138
139pub type StreamNamePrefix = StreamNameStr<PrefixProps>;
140
141impl Default for StreamNamePrefix {
142 fn default() -> Self {
143 StreamNameStr(CompactString::default(), PhantomData)
144 }
145}
146
147impl From<StreamName> for StreamNamePrefix {
148 fn from(value: StreamName) -> Self {
149 Self(value.0, PhantomData)
150 }
151}
152
153pub type StreamNameStartAfter = StreamNameStr<StartAfterProps>;
154
155impl Default for StreamNameStartAfter {
156 fn default() -> Self {
157 StreamNameStr(CompactString::default(), PhantomData)
158 }
159}
160
161impl From<StreamName> for StreamNameStartAfter {
162 fn from(value: StreamName) -> Self {
163 Self(value.0, PhantomData)
164 }
165}
166
/// Metadata describing a single stream.
#[derive(Debug, Clone)]
pub struct StreamInfo {
    /// The stream's validated name.
    pub name: StreamName,
    /// When the stream was created.
    pub created_at: OffsetDateTime,
    /// Set once the stream has been deleted; `None` for live streams.
    pub deleted_at: Option<OffsetDateTime>,
}
173
174#[derive(Debug, Clone)]
175pub struct AppendRecord<T = Record>(AppendRecordParts<T>);
176
177impl<T> AppendRecord<T> {
178 pub fn parts(&self) -> &AppendRecordParts<T> {
179 let Self(parts) = self;
180 parts
181 }
182
183 pub fn into_parts(self) -> AppendRecordParts<T> {
184 let Self(parts) = self;
185 parts
186 }
187}
188
189impl<T> MeteredSize for AppendRecord<T> {
190 fn metered_size(&self) -> usize {
191 self.0.record.metered_size()
192 }
193}
194
195#[derive(Debug, Clone)]
196pub struct AppendRecordParts<T = Record> {
197 pub timestamp: Option<Timestamp>,
198 pub record: Metered<T>,
199}
200
201impl<T> MeteredSize for AppendRecordParts<T> {
202 fn metered_size(&self) -> usize {
203 self.record.metered_size()
204 }
205}
206
207impl<T> From<AppendRecord<T>> for AppendRecordParts<T> {
208 fn from(record: AppendRecord<T>) -> Self {
209 record.into_parts()
210 }
211}
212
213impl<T> TryFrom<AppendRecordParts<T>> for AppendRecord<T> {
214 type Error = &'static str;
215
216 fn try_from(parts: AppendRecordParts<T>) -> Result<Self, Self::Error> {
217 if parts.metered_size() > caps::RECORD_BATCH_MAX.bytes {
218 Err("record must have metered size less than 1 MiB")
219 } else {
220 Ok(Self(parts))
221 }
222 }
223}
224
225#[derive(Clone)]
226pub struct AppendRecordBatch<T = Record>(Metered<Vec<AppendRecord<T>>>);
227
228impl<T> std::fmt::Debug for AppendRecordBatch<T> {
229 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230 f.debug_struct("AppendRecordBatch")
231 .field("num_records", &self.0.len())
232 .field("metered_size", &self.0.metered_size())
233 .finish()
234 }
235}
236
237impl<T> MeteredSize for AppendRecordBatch<T> {
238 fn metered_size(&self) -> usize {
239 self.0.metered_size()
240 }
241}
242
243impl<T> std::ops::Deref for AppendRecordBatch<T> {
244 type Target = [AppendRecord<T>];
245
246 fn deref(&self) -> &Self::Target {
247 &self.0
248 }
249}
250
251impl<T> TryFrom<Metered<Vec<AppendRecord<T>>>> for AppendRecordBatch<T> {
252 type Error = &'static str;
253
254 fn try_from(records: Metered<Vec<AppendRecord<T>>>) -> Result<Self, Self::Error> {
255 if records.is_empty() {
256 return Err("record batch must not be empty");
257 }
258
259 if records.len() > caps::RECORD_BATCH_MAX.count {
260 return Err("record batch must not exceed 1000 records");
261 }
262
263 if records.metered_size() > caps::RECORD_BATCH_MAX.bytes {
264 return Err("record batch must not exceed a metered size of 1 MiB");
265 }
266
267 Ok(Self(records))
268 }
269}
270
271impl<T> TryFrom<Vec<AppendRecord<T>>> for AppendRecordBatch<T> {
272 type Error = &'static str;
273
274 fn try_from(records: Vec<AppendRecord<T>>) -> Result<Self, Self::Error> {
275 let records = Metered::from(records);
276 Self::try_from(records)
277 }
278}
279
impl<T> IntoIterator for AppendRecordBatch<T> {
    type Item = AppendRecord<T>;
    type IntoIter = std::vec::IntoIter<Self::Item>;

    /// Consumes the batch, yielding its records in order.
    fn into_iter(self) -> Self::IntoIter {
        self.0.into_iter()
    }
}

// Aliases for the stored (at-rest) record representation.
pub type StoredAppendRecord = AppendRecord<StoredRecord>;
pub type StoredAppendRecordParts = AppendRecordParts<StoredRecord>;
pub type StoredAppendRecordBatch = AppendRecordBatch<StoredRecord>;
292
293impl From<AppendRecordParts<Record>> for AppendRecordParts<StoredRecord> {
294 fn from(
295 AppendRecordParts { timestamp, record }: AppendRecordParts<Record>,
296 ) -> AppendRecordParts<StoredRecord> {
297 AppendRecordParts {
298 timestamp,
299 record: StoredRecord::from(record.into_inner()).into(),
300 }
301 }
302}
303
304impl From<AppendRecord<Record>> for AppendRecord<StoredRecord> {
305 fn from(record: AppendRecord<Record>) -> Self {
306 Self(record.into_parts().into())
307 }
308}
309
310impl From<AppendRecordBatch<Record>> for AppendRecordBatch<StoredRecord> {
311 fn from(records: AppendRecordBatch<Record>) -> Self {
312 AppendRecordBatch(
313 records
314 .into_iter()
315 .map(|r| AppendRecord::<StoredRecord>::from(r).metered())
316 .collect(),
317 )
318 }
319}
320
/// A batch append request, optionally conditioned on stream state.
#[derive(Debug, Clone)]
pub struct AppendInput<T = Record> {
    /// The validated batch of records to append.
    pub records: AppendRecordBatch<T>,
    /// Optional expected sequence number — NOTE(review): matching semantics
    /// are enforced outside this module; confirm against the append path.
    pub match_seq_num: Option<SeqNum>,
    /// Optional fencing token — enforcement happens outside this module.
    pub fencing_token: Option<FencingToken>,
}
327
328impl AppendInput<Record> {
329 pub fn encrypt(self, encryption: &EncryptionSpec, aad: &[u8]) -> AppendInput<StoredRecord> {
330 let AppendInput {
331 records,
332 match_seq_num,
333 fencing_token,
334 } = self;
335 let records = AppendRecordBatch(
336 records
337 .into_iter()
338 .map(|record| {
339 let AppendRecordParts { timestamp, record } = record.into_parts();
340 let record = encrypt_record(record, encryption, aad);
341 AppendRecord(AppendRecordParts { timestamp, record }).metered()
342 })
343 .collect(),
344 );
345
346 AppendInput {
347 records,
348 match_seq_num,
349 fencing_token,
350 }
351 }
352}
353
354pub type StoredAppendInput = AppendInput<StoredRecord>;
355
356impl From<AppendInput<Record>> for AppendInput<StoredRecord> {
357 fn from(value: AppendInput<Record>) -> Self {
358 let AppendInput {
359 records,
360 match_seq_num,
361 fencing_token,
362 } = value;
363 let records = records.into();
364 AppendInput {
365 records,
366 match_seq_num,
367 fencing_token,
368 }
369 }
370}
371
/// Acknowledgement of a successful append.
#[derive(Debug, Clone)]
pub struct AppendAck {
    /// Position where the appended range starts.
    pub start: StreamPosition,
    /// Position where the appended range ends.
    pub end: StreamPosition,
    /// Tail of the stream after the append.
    pub tail: StreamPosition,
}

/// An exact read position, by sequence number or by timestamp.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadPosition {
    SeqNum(SeqNum),
    Timestamp(Timestamp),
}

/// Where a read should begin.
#[derive(Debug, Clone, Copy)]
pub enum ReadFrom {
    SeqNum(SeqNum),
    Timestamp(Timestamp),
    /// An offset relative to the stream tail.
    TailOffset(u64),
}

impl Default for ReadFrom {
    /// Defaults to reading from sequence number 0.
    fn default() -> Self {
        Self::SeqNum(0)
    }
}

/// Starting point of a read.
#[derive(Debug, Default, Clone, Copy)]
pub struct ReadStart {
    pub from: ReadFrom,
    /// NOTE(review): presumably clamps an out-of-range `from` to the
    /// stream's valid range — confirm against the read implementation.
    pub clamp: bool,
}

/// Conditions that bound a read.
#[derive(Debug, Default, Clone, Copy)]
pub struct ReadEnd {
    pub limit: ReadLimit,
    pub until: ReadUntil,
    /// Optional wait duration; a positive wait lets the read keep following
    /// the stream (see `may_follow`).
    pub wait: Option<Duration>,
}
410
411impl ReadEnd {
412 pub fn may_follow(&self) -> bool {
413 (self.limit.is_unbounded() && self.until.is_unbounded())
414 || self.wait.is_some_and(|d| d > Duration::ZERO)
415 }
416}
417
418#[derive(Clone)]
419pub struct ReadBatch<T = Record> {
420 pub records: Metered<Vec<Sequenced<T>>>,
421 pub tail: Option<StreamPosition>,
422}
423
424impl<T> Default for ReadBatch<T>
425where
426 T: MeteredSize,
427{
428 fn default() -> Self {
429 Self {
430 records: Metered::default(),
431 tail: None,
432 }
433 }
434}
435
436impl<T> std::fmt::Debug for ReadBatch<T> {
437 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
438 f.debug_struct("ReadBatch")
439 .field("num_records", &self.records.len())
440 .field("metered_size", &self.records.metered_size())
441 .field("tail", &self.tail)
442 .finish()
443 }
444}
445
446impl ReadBatch<StoredRecord> {
447 pub fn decrypt(
448 self,
449 encryption: &EncryptionSpec,
450 aad: &[u8],
451 ) -> Result<ReadBatch, RecordDecryptionError> {
452 let records: Result<Metered<Vec<Sequenced<Record>>>, RecordDecryptionError> = self
453 .records
454 .into_inner()
455 .into_iter()
456 .map(|record| {
457 let (position, record) = record.into_parts();
458 decrypt_stored_record(record, encryption, aad)
459 .map(|record| record.sequenced(position))
460 })
461 .collect();
462
463 Ok(ReadBatch {
464 records: records?,
465 tail: self.tail,
466 })
467 }
468}
469
470pub type StoredReadBatch = ReadBatch<StoredRecord>;
471
472#[derive(Debug, Clone)]
473pub enum ReadSessionOutput<T = Record> {
474 Heartbeat(StreamPosition),
475 Batch(ReadBatch<T>),
476}
477
478impl ReadSessionOutput<StoredRecord> {
479 pub fn decrypt(
480 self,
481 encryption: &EncryptionSpec,
482 aad: &[u8],
483 ) -> Result<ReadSessionOutput, RecordDecryptionError> {
484 match self {
485 Self::Heartbeat(tail) => Ok(ReadSessionOutput::Heartbeat(tail)),
486 Self::Batch(batch) => batch.decrypt(encryption, aad).map(ReadSessionOutput::Batch),
487 }
488 }
489}
490
491pub type StoredReadSessionOutput = ReadSessionOutput<StoredRecord>;
492
/// Request to list streams, filtered by prefix and paginated by start-after.
pub type ListStreamsRequest = ListItemsRequest<StreamNamePrefix, StreamNameStartAfter>;
494
#[cfg(test)]
mod test {
    use bytes::Bytes;
    use rstest::rstest;

    use super::{
        super::strings::{NameProps, PrefixProps, StartAfterProps},
        *,
    };
    use crate::record::{EnvelopeRecord, MeteredExt, Record, StoredRecord, StreamPosition};

    // Full stream names: non-empty, not "."/"..", up to the byte cap.
    #[rstest]
    #[case::normal("my-stream".to_owned())]
    #[case::max_len("a".repeat(crate::caps::MAX_STREAM_NAME_LEN))]
    fn validate_name_ok(#[case] name: String) {
        assert_eq!(StreamNameStr::<NameProps>::validate_str(&name), Ok(()));
    }

    // Full stream names reject empty, dot names, and over-long inputs.
    #[rstest]
    #[case::empty("".to_owned())]
    #[case::dot(".".to_owned())]
    #[case::dot_dot("..".to_owned())]
    #[case::too_long("a".repeat(crate::caps::MAX_STREAM_NAME_LEN + 1))]
    fn validate_name_err(#[case] name: String) {
        StreamNameStr::<NameProps>::validate_str(&name).expect_err("expected validation error");
    }

    // Prefixes are laxer: empty and dot strings are allowed.
    #[rstest]
    #[case::empty("".to_owned())]
    #[case::dot(".".to_owned())]
    #[case::dot_dot("..".to_owned())]
    #[case::max_len("a".repeat(crate::caps::MAX_STREAM_NAME_LEN))]
    fn validate_prefix_ok(#[case] prefix: String) {
        assert_eq!(StreamNameStr::<PrefixProps>::validate_str(&prefix), Ok(()));
    }

    // Only the byte cap applies to prefixes.
    #[rstest]
    #[case::too_long("a".repeat(crate::caps::MAX_STREAM_NAME_LEN + 1))]
    fn validate_prefix_err(#[case] prefix: String) {
        StreamNameStr::<PrefixProps>::validate_str(&prefix).expect_err("expected validation error");
    }

    // Start-after keys follow the same lax rules as prefixes.
    #[rstest]
    #[case::empty("".to_owned())]
    #[case::dot(".".to_owned())]
    #[case::dot_dot("..".to_owned())]
    #[case::max_len("a".repeat(crate::caps::MAX_STREAM_NAME_LEN))]
    fn validate_start_after_ok(#[case] start_after: String) {
        assert_eq!(
            StreamNameStr::<StartAfterProps>::validate_str(&start_after),
            Ok(())
        );
    }

    #[rstest]
    #[case::too_long("a".repeat(crate::caps::MAX_STREAM_NAME_LEN + 1))]
    fn validate_start_after_err(#[case] start_after: String) {
        StreamNameStr::<StartAfterProps>::validate_str(&start_after)
            .expect_err("expected validation error");
    }

    // Associated data bound into encryption for the round-trip tests below.
    const TEST_AAD: &[u8] = b"test-stream-aad";

    // A one-record append input with all optional metadata populated, so
    // conversions can be checked for metadata preservation.
    fn sample_append_input() -> AppendInput {
        let record = Record::Envelope(
            EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap(),
        );
        AppendInput {
            records: vec![
                AppendRecord::try_from(AppendRecordParts {
                    timestamp: Some(42),
                    record: record.metered(),
                })
                .unwrap(),
            ]
            .try_into()
            .unwrap(),
            match_seq_num: Some(7),
            fencing_token: Some("fence".parse().unwrap()),
        }
    }

    #[test]
    fn append_record_batch_rejects_empty_batches() {
        let empty_batch: Result<AppendRecordBatch, _> = Vec::<AppendRecord>::new().try_into();

        assert_eq!(empty_batch.unwrap_err(), "record batch must not be empty");
    }

    // Both paths to AppendInput<StoredRecord> — `encrypt` and plain `into` —
    // must preserve match_seq_num, fencing_token, and record timestamps,
    // and the payload must survive a decrypt round-trip.
    #[rstest]
    #[case::encrypt(true)]
    #[case::into(false)]
    fn append_input_to_stored_preserves_metadata(#[case] encrypt: bool) {
        let encryption = EncryptionSpec::aegis256([0x42; 32]);
        let mapped = if encrypt {
            sample_append_input().encrypt(&encryption, TEST_AAD)
        } else {
            sample_append_input().into()
        };

        assert_eq!(mapped.match_seq_num, Some(7));
        assert_eq!(
            mapped.fencing_token.as_ref().map(|token| token.as_ref()),
            Some("fence")
        );

        let append_record: AppendRecordParts<StoredRecord> = mapped
            .records
            .into_iter()
            .next()
            .expect("sample append input should contain a single record")
            .into_parts();
        assert_eq!(append_record.timestamp, Some(42));

        // Only the encrypt path should produce an encrypted stored record.
        let stored_record = append_record.record.into_inner();
        assert_eq!(
            matches!(&stored_record, StoredRecord::Encrypted { .. }),
            encrypt
        );

        // Decrypting a plaintext stored record is a no-op, so this works
        // for both cases.
        let decrypted = decrypt_stored_record(stored_record, &encryption, TEST_AAD).unwrap();
        let Record::Envelope(record) = decrypted.into_inner() else {
            panic!("expected envelope record");
        };
        assert_eq!(record.body().as_ref(), b"hello");
    }

    // Decrypting a read batch must keep per-record positions and the tail.
    #[test]
    fn stored_read_batch_decrypt_preserves_positions_and_tail() {
        let batch = ReadBatch {
            records: Metered::from(vec![
                StoredRecord::Plaintext(Record::Envelope(
                    EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(b"one")).unwrap(),
                ))
                .metered()
                .sequenced(StreamPosition {
                    seq_num: 1,
                    timestamp: 10,
                })
                .into_inner(),
                StoredRecord::Plaintext(Record::Envelope(
                    EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(b"two")).unwrap(),
                ))
                .metered()
                .sequenced(StreamPosition {
                    seq_num: 2,
                    timestamp: 20,
                })
                .into_inner(),
            ]),
            tail: Some(StreamPosition {
                seq_num: 3,
                timestamp: 30,
            }),
        };

        let mapped = batch
            .decrypt(&crate::encryption::EncryptionSpec::Plain, &[])
            .unwrap();
        let records = mapped.records.into_inner();

        assert_eq!(
            mapped.tail,
            Some(StreamPosition {
                seq_num: 3,
                timestamp: 30
            })
        );
        assert_eq!(
            records[0].position(),
            &StreamPosition {
                seq_num: 1,
                timestamp: 10
            }
        );
        assert_eq!(
            records[1].position(),
            &StreamPosition {
                seq_num: 2,
                timestamp: 20
            }
        );
        assert!(matches!(records[0].inner(), Record::Envelope(_)));
        assert!(matches!(records[1].inner(), Record::Envelope(_)));
    }
}