1use std::{marker::PhantomData, ops::Deref, str::FromStr, time::Duration};
2
3use compact_str::{CompactString, ToCompactString};
4use time::OffsetDateTime;
5
6use super::{
7 ValidationError,
8 strings::{NameProps, PrefixProps, StartAfterProps, StrProps},
9};
10use crate::{
11 caps,
12 encryption::{EncryptionAlgorithm, EncryptionSpec},
13 read_extent::{ReadLimit, ReadUntil},
14 record::{
15 FencingToken, Metered, MeteredExt, MeteredSize, Record, RecordDecryptionError, SeqNum,
16 Sequenced, StoredRecord, StreamPosition, Timestamp, decrypt_stored_record, encrypt_record,
17 },
18 types::resources::ListItemsRequest,
19};
20
/// A validated stream-name string, parameterized over its validation
/// properties (`T: StrProps`): full name, prefix, or start-after marker.
///
/// The inner string is only ever constructed through `validate_str`, so a
/// value of this type is always valid for its kind. `PhantomData` carries
/// the props type without storing it.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct StreamNameStr<T: StrProps>(CompactString, PhantomData<T>);
27
28impl<T: StrProps> StreamNameStr<T> {
29 fn validate_str(name: &str) -> Result<(), ValidationError> {
30 if !T::IS_PREFIX && name.is_empty() {
31 return Err(format!("stream {} must not be empty", T::FIELD_NAME).into());
32 }
33
34 if !T::IS_PREFIX && (name == "." || name == "..") {
35 return Err(format!("stream {} must not be \".\" or \"..\"", T::FIELD_NAME).into());
36 }
37
38 if name.len() > caps::MAX_STREAM_NAME_LEN {
39 return Err(format!(
40 "stream {} must not exceed {} bytes in length",
41 T::FIELD_NAME,
42 caps::MAX_STREAM_NAME_LEN
43 )
44 .into());
45 }
46
47 Ok(())
48 }
49}
50
#[cfg(feature = "utoipa")]
impl<T> utoipa::PartialSchema for StreamNameStr<T>
where
    T: StrProps,
{
    /// OpenAPI schema: a string capped at `caps::MAX_STREAM_NAME_LEN`.
    /// Non-prefix kinds also advertise `caps::MIN_STREAM_NAME_LEN` as the
    /// minimum — NOTE(review): `validate_str` only checks non-emptiness, so
    /// this presumably equals 1; confirm the constant stays in sync.
    fn schema() -> utoipa::openapi::RefOr<utoipa::openapi::schema::Schema> {
        utoipa::openapi::Object::builder()
            .schema_type(utoipa::openapi::Type::String)
            .min_length((!T::IS_PREFIX).then_some(caps::MIN_STREAM_NAME_LEN))
            .max_length(Some(caps::MAX_STREAM_NAME_LEN))
            .into()
    }
}

#[cfg(feature = "utoipa")]
impl<T> utoipa::ToSchema for StreamNameStr<T> where T: StrProps {}
67
68impl<T: StrProps> serde::Serialize for StreamNameStr<T> {
69 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
70 where
71 S: serde::Serializer,
72 {
73 serializer.serialize_str(&self.0)
74 }
75}
76
77impl<'de, T: StrProps> serde::Deserialize<'de> for StreamNameStr<T> {
78 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
79 where
80 D: serde::Deserializer<'de>,
81 {
82 let s = CompactString::deserialize(deserializer)?;
83 s.try_into().map_err(serde::de::Error::custom)
84 }
85}
86
87impl<T: StrProps> AsRef<str> for StreamNameStr<T> {
88 fn as_ref(&self) -> &str {
89 &self.0
90 }
91}
92
93impl<T: StrProps> Deref for StreamNameStr<T> {
94 type Target = str;
95
96 fn deref(&self) -> &Self::Target {
97 &self.0
98 }
99}
100
101impl<T: StrProps> TryFrom<CompactString> for StreamNameStr<T> {
102 type Error = ValidationError;
103
104 fn try_from(name: CompactString) -> Result<Self, Self::Error> {
105 Self::validate_str(&name)?;
106 Ok(Self(name, PhantomData))
107 }
108}
109
110impl<T: StrProps> FromStr for StreamNameStr<T> {
111 type Err = ValidationError;
112
113 fn from_str(s: &str) -> Result<Self, Self::Err> {
114 Self::validate_str(s)?;
115 Ok(Self(s.to_compact_string(), PhantomData))
116 }
117}
118
119impl<T: StrProps> std::fmt::Debug for StreamNameStr<T> {
120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 f.write_str(&self.0)
122 }
123}
124
125impl<T: StrProps> std::fmt::Display for StreamNameStr<T> {
126 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127 f.write_str(&self.0)
128 }
129}
130
131impl<T: StrProps> From<StreamNameStr<T>> for CompactString {
132 fn from(value: StreamNameStr<T>) -> Self {
133 value.0
134 }
135}
136
137pub type StreamName = StreamNameStr<NameProps>;
138
139pub type StreamNamePrefix = StreamNameStr<PrefixProps>;
140
141impl Default for StreamNamePrefix {
142 fn default() -> Self {
143 StreamNameStr(CompactString::default(), PhantomData)
144 }
145}
146
147impl From<StreamName> for StreamNamePrefix {
148 fn from(value: StreamName) -> Self {
149 Self(value.0, PhantomData)
150 }
151}
152
153pub type StreamNameStartAfter = StreamNameStr<StartAfterProps>;
154
155impl Default for StreamNameStartAfter {
156 fn default() -> Self {
157 StreamNameStr(CompactString::default(), PhantomData)
158 }
159}
160
161impl From<StreamName> for StreamNameStartAfter {
162 fn from(value: StreamName) -> Self {
163 Self(value.0, PhantomData)
164 }
165}
166
/// Metadata describing a single stream.
#[derive(Debug, Clone)]
pub struct StreamInfo {
    /// The stream's validated name.
    pub name: StreamName,
    /// When the stream was created.
    pub created_at: OffsetDateTime,
    /// Set if the stream has been deleted.
    pub deleted_at: Option<OffsetDateTime>,
    /// Encryption algorithm in use, if any.
    pub cipher: Option<EncryptionAlgorithm>,
}
174
/// A record accepted for appending. Constructed via `TryFrom`, which
/// enforces the per-record metered-size cap.
#[derive(Debug, Clone)]
pub struct AppendRecord<T = Record>(AppendRecordParts<T>);
177
178impl<T> AppendRecord<T> {
179 pub fn parts(&self) -> &AppendRecordParts<T> {
180 let Self(parts) = self;
181 parts
182 }
183
184 pub fn into_parts(self) -> AppendRecordParts<T> {
185 let Self(parts) = self;
186 parts
187 }
188}
189
190impl<T> MeteredSize for AppendRecord<T> {
191 fn metered_size(&self) -> usize {
192 self.0.record.metered_size()
193 }
194}
195
/// The constituent parts of an [`AppendRecord`].
#[derive(Debug, Clone)]
pub struct AppendRecordParts<T = Record> {
    /// Optional timestamp to associate with the record.
    pub timestamp: Option<Timestamp>,
    /// The record payload, wrapped for metered-size accounting.
    pub record: Metered<T>,
}
201
202impl<T> MeteredSize for AppendRecordParts<T> {
203 fn metered_size(&self) -> usize {
204 self.record.metered_size()
205 }
206}
207
208impl<T> From<AppendRecord<T>> for AppendRecordParts<T> {
209 fn from(record: AppendRecord<T>) -> Self {
210 record.into_parts()
211 }
212}
213
214impl<T> TryFrom<AppendRecordParts<T>> for AppendRecord<T> {
215 type Error = &'static str;
216
217 fn try_from(parts: AppendRecordParts<T>) -> Result<Self, Self::Error> {
218 if parts.metered_size() > caps::RECORD_BATCH_MAX.bytes {
219 Err("record must have metered size less than 1 MiB")
220 } else {
221 Ok(Self(parts))
222 }
223 }
224}
225
/// A non-empty batch of append records. Constructed via `TryFrom`, which
/// enforces the batch caps on record count and total metered size.
#[derive(Clone)]
pub struct AppendRecordBatch<T = Record>(Metered<Vec<AppendRecord<T>>>);
228
229impl<T> std::fmt::Debug for AppendRecordBatch<T> {
230 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231 f.debug_struct("AppendRecordBatch")
232 .field("num_records", &self.0.len())
233 .field("metered_size", &self.0.metered_size())
234 .finish()
235 }
236}
237
238impl<T> MeteredSize for AppendRecordBatch<T> {
239 fn metered_size(&self) -> usize {
240 self.0.metered_size()
241 }
242}
243
244impl<T> std::ops::Deref for AppendRecordBatch<T> {
245 type Target = [AppendRecord<T>];
246
247 fn deref(&self) -> &Self::Target {
248 &self.0
249 }
250}
251
252impl<T> TryFrom<Metered<Vec<AppendRecord<T>>>> for AppendRecordBatch<T> {
253 type Error = &'static str;
254
255 fn try_from(records: Metered<Vec<AppendRecord<T>>>) -> Result<Self, Self::Error> {
256 if records.is_empty() {
257 return Err("record batch must not be empty");
258 }
259
260 if records.len() > caps::RECORD_BATCH_MAX.count {
261 return Err("record batch must not exceed 1000 records");
262 }
263
264 if records.metered_size() > caps::RECORD_BATCH_MAX.bytes {
265 return Err("record batch must not exceed a metered size of 1 MiB");
266 }
267
268 Ok(Self(records))
269 }
270}
271
272impl<T> TryFrom<Vec<AppendRecord<T>>> for AppendRecordBatch<T> {
273 type Error = &'static str;
274
275 fn try_from(records: Vec<AppendRecord<T>>) -> Result<Self, Self::Error> {
276 let records = Metered::from(records);
277 Self::try_from(records)
278 }
279}
280
281impl<T> IntoIterator for AppendRecordBatch<T> {
282 type Item = AppendRecord<T>;
283 type IntoIter = std::vec::IntoIter<Self::Item>;
284
285 fn into_iter(self) -> Self::IntoIter {
286 self.0.into_iter()
287 }
288}
289
/// An append record holding a stored (possibly encrypted) payload.
pub type StoredAppendRecord = AppendRecord<StoredRecord>;
/// Parts of an append record holding a stored payload.
pub type StoredAppendRecordParts = AppendRecordParts<StoredRecord>;
/// A batch of append records holding stored payloads.
pub type StoredAppendRecordBatch = AppendRecordBatch<StoredRecord>;
293
294impl From<AppendRecordParts<Record>> for AppendRecordParts<StoredRecord> {
295 fn from(
296 AppendRecordParts { timestamp, record }: AppendRecordParts<Record>,
297 ) -> AppendRecordParts<StoredRecord> {
298 AppendRecordParts {
299 timestamp,
300 record: StoredRecord::from(record.into_inner()).into(),
301 }
302 }
303}
304
305impl From<AppendRecord<Record>> for AppendRecord<StoredRecord> {
306 fn from(record: AppendRecord<Record>) -> Self {
307 Self(record.into_parts().into())
308 }
309}
310
311impl From<AppendRecordBatch<Record>> for AppendRecordBatch<StoredRecord> {
312 fn from(records: AppendRecordBatch<Record>) -> Self {
313 AppendRecordBatch(
314 records
315 .into_iter()
316 .map(|r| AppendRecord::<StoredRecord>::from(r).metered())
317 .collect(),
318 )
319 }
320}
321
/// Input to an append operation: the record batch plus optional
/// concurrency controls.
#[derive(Debug, Clone)]
pub struct AppendInput<T = Record> {
    /// The records to append.
    pub records: AppendRecordBatch<T>,
    /// Optional sequence-number precondition for the append.
    pub match_seq_num: Option<SeqNum>,
    /// Optional fencing token guarding the append.
    pub fencing_token: Option<FencingToken>,
}
328
329impl AppendInput<Record> {
330 pub fn encrypt(self, encryption: &EncryptionSpec, aad: &[u8]) -> AppendInput<StoredRecord> {
331 let AppendInput {
332 records,
333 match_seq_num,
334 fencing_token,
335 } = self;
336 let records = AppendRecordBatch(
337 records
338 .into_iter()
339 .map(|record| {
340 let AppendRecordParts { timestamp, record } = record.into_parts();
341 let record = encrypt_record(record, encryption, aad);
342 AppendRecord(AppendRecordParts { timestamp, record }).metered()
343 })
344 .collect(),
345 );
346
347 AppendInput {
348 records,
349 match_seq_num,
350 fencing_token,
351 }
352 }
353}
354
355pub type StoredAppendInput = AppendInput<StoredRecord>;
356
357impl From<AppendInput<Record>> for AppendInput<StoredRecord> {
358 fn from(value: AppendInput<Record>) -> Self {
359 let AppendInput {
360 records,
361 match_seq_num,
362 fencing_token,
363 } = value;
364 let records = records.into();
365 AppendInput {
366 records,
367 match_seq_num,
368 fencing_token,
369 }
370 }
371}
372
/// Acknowledgement of a successful append.
#[derive(Debug, Clone)]
pub struct AppendAck {
    /// Position of the first appended record.
    pub start: StreamPosition,
    /// End position of the appended range.
    pub end: StreamPosition,
    /// The stream's tail position after the append.
    pub tail: StreamPosition,
}
379
/// A stream position addressed either by sequence number or by timestamp.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadPosition {
    SeqNum(SeqNum),
    Timestamp(Timestamp),
}

/// Where a read should start.
#[derive(Debug, Clone, Copy)]
pub enum ReadFrom {
    /// Start at a specific sequence number.
    SeqNum(SeqNum),
    /// Start at a timestamp.
    Timestamp(Timestamp),
    /// Start at an offset back from the stream's tail.
    TailOffset(u64),
}
392
393impl Default for ReadFrom {
394 fn default() -> Self {
395 Self::SeqNum(0)
396 }
397}
398
/// The starting point of a read.
#[derive(Debug, Default, Clone, Copy)]
pub struct ReadStart {
    /// Where to begin reading (defaults to sequence number 0).
    pub from: ReadFrom,
    /// Whether to clamp an out-of-range start position —
    /// NOTE(review): enforcement happens elsewhere; confirm semantics.
    pub clamp: bool,
}

/// The termination conditions of a read.
#[derive(Debug, Default, Clone, Copy)]
pub struct ReadEnd {
    /// Cap on how much to read.
    pub limit: ReadLimit,
    /// Position-based stopping condition.
    pub until: ReadUntil,
    /// How long to wait for new records once caught up.
    pub wait: Option<Duration>,
}
411
412impl ReadEnd {
413 pub fn may_follow(&self) -> bool {
414 (self.limit.is_unbounded() && self.until.is_unbounded())
415 || self.wait.is_some_and(|d| d > Duration::ZERO)
416 }
417}
418
/// A batch of records returned from a read.
#[derive(Clone)]
pub struct ReadBatch<T = Record> {
    /// The sequenced records in the batch.
    pub records: Metered<Vec<Sequenced<T>>>,
    /// The stream's tail position, if reported.
    pub tail: Option<StreamPosition>,
}
424
425impl<T> Default for ReadBatch<T>
426where
427 T: MeteredSize,
428{
429 fn default() -> Self {
430 Self {
431 records: Metered::default(),
432 tail: None,
433 }
434 }
435}
436
437impl<T> std::fmt::Debug for ReadBatch<T> {
438 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
439 f.debug_struct("ReadBatch")
440 .field("num_records", &self.records.len())
441 .field("metered_size", &self.records.metered_size())
442 .field("tail", &self.tail)
443 .finish()
444 }
445}
446
447impl ReadBatch<StoredRecord> {
448 pub fn decrypt(
449 self,
450 encryption: &EncryptionSpec,
451 aad: &[u8],
452 ) -> Result<ReadBatch, RecordDecryptionError> {
453 let records: Result<Metered<Vec<Sequenced<Record>>>, RecordDecryptionError> = self
454 .records
455 .into_inner()
456 .into_iter()
457 .map(|record| {
458 let (position, record) = record.into_parts();
459 decrypt_stored_record(record, encryption, aad)
460 .map(|record| record.sequenced(position))
461 })
462 .collect();
463
464 Ok(ReadBatch {
465 records: records?,
466 tail: self.tail,
467 })
468 }
469}
470
/// A read batch whose records are still in stored (possibly encrypted) form.
pub type StoredReadBatch = ReadBatch<StoredRecord>;
472
/// A single output item from a read session.
#[derive(Debug, Clone)]
pub enum ReadSessionOutput<T = Record> {
    /// No new records; reports the current tail position.
    Heartbeat(StreamPosition),
    /// A batch of records.
    Batch(ReadBatch<T>),
}
478
479impl ReadSessionOutput<StoredRecord> {
480 pub fn decrypt(
481 self,
482 encryption: &EncryptionSpec,
483 aad: &[u8],
484 ) -> Result<ReadSessionOutput, RecordDecryptionError> {
485 match self {
486 Self::Heartbeat(tail) => Ok(ReadSessionOutput::Heartbeat(tail)),
487 Self::Batch(batch) => batch.decrypt(encryption, aad).map(ReadSessionOutput::Batch),
488 }
489 }
490}
491
/// Read-session output whose batches hold stored records.
pub type StoredReadSessionOutput = ReadSessionOutput<StoredRecord>;

/// Request for listing streams: filter by prefix, paginate via start-after.
pub type ListStreamsRequest = ListItemsRequest<StreamNamePrefix, StreamNameStartAfter>;
495
#[cfg(test)]
mod test {
    //! Tests for stream-name validation and for the conversions between
    //! plaintext and stored append records / read batches.

    use bytes::Bytes;
    use rstest::rstest;

    use super::{
        super::strings::{NameProps, PrefixProps, StartAfterProps},
        *,
    };
    use crate::record::{EnvelopeRecord, MeteredExt, Record, StoredRecord, StreamPosition};

    // Full names: non-empty, not "." or "..", within the byte cap.
    #[rstest]
    #[case::normal("my-stream".to_owned())]
    #[case::max_len("a".repeat(crate::caps::MAX_STREAM_NAME_LEN))]
    fn validate_name_ok(#[case] name: String) {
        assert_eq!(StreamNameStr::<NameProps>::validate_str(&name), Ok(()));
    }

    #[rstest]
    #[case::empty("".to_owned())]
    #[case::dot(".".to_owned())]
    #[case::dot_dot("..".to_owned())]
    #[case::too_long("a".repeat(crate::caps::MAX_STREAM_NAME_LEN + 1))]
    fn validate_name_err(#[case] name: String) {
        StreamNameStr::<NameProps>::validate_str(&name).expect_err("expected validation error");
    }

    // Prefixes are laxer: empty / "." / ".." are allowed; only the byte cap applies.
    #[rstest]
    #[case::empty("".to_owned())]
    #[case::dot(".".to_owned())]
    #[case::dot_dot("..".to_owned())]
    #[case::max_len("a".repeat(crate::caps::MAX_STREAM_NAME_LEN))]
    fn validate_prefix_ok(#[case] prefix: String) {
        assert_eq!(StreamNameStr::<PrefixProps>::validate_str(&prefix), Ok(()));
    }

    #[rstest]
    #[case::too_long("a".repeat(crate::caps::MAX_STREAM_NAME_LEN + 1))]
    fn validate_prefix_err(#[case] prefix: String) {
        StreamNameStr::<PrefixProps>::validate_str(&prefix).expect_err("expected validation error");
    }

    // Start-after markers follow the same lax rules as prefixes.
    #[rstest]
    #[case::empty("".to_owned())]
    #[case::dot(".".to_owned())]
    #[case::dot_dot("..".to_owned())]
    #[case::max_len("a".repeat(crate::caps::MAX_STREAM_NAME_LEN))]
    fn validate_start_after_ok(#[case] start_after: String) {
        assert_eq!(
            StreamNameStr::<StartAfterProps>::validate_str(&start_after),
            Ok(())
        );
    }

    #[rstest]
    #[case::too_long("a".repeat(crate::caps::MAX_STREAM_NAME_LEN + 1))]
    fn validate_start_after_err(#[case] start_after: String) {
        StreamNameStr::<StartAfterProps>::validate_str(&start_after)
            .expect_err("expected validation error");
    }

    const TEST_AAD: &[u8] = b"test-stream-aad";

    // Single-record input with every piece of metadata set, so the
    // conversion tests can assert each is preserved.
    fn sample_append_input() -> AppendInput {
        let record = Record::Envelope(
            EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap(),
        );
        AppendInput {
            records: vec![
                AppendRecord::try_from(AppendRecordParts {
                    timestamp: Some(42),
                    record: record.metered(),
                })
                .unwrap(),
            ]
            .try_into()
            .unwrap(),
            match_seq_num: Some(7),
            fencing_token: Some("fence".parse().unwrap()),
        }
    }

    #[test]
    fn append_record_batch_rejects_empty_batches() {
        let empty_batch: Result<AppendRecordBatch, _> = Vec::<AppendRecord>::new().try_into();

        assert_eq!(empty_batch.unwrap_err(), "record batch must not be empty");
    }

    // Exercises both paths to stored form: `encrypt` and plain `From`.
    #[rstest]
    #[case::encrypt(true)]
    #[case::into(false)]
    fn append_input_to_stored_preserves_metadata(#[case] encrypt: bool) {
        let encryption = EncryptionSpec::aegis256([0x42; 32]);
        let mapped = if encrypt {
            sample_append_input().encrypt(&encryption, TEST_AAD)
        } else {
            sample_append_input().into()
        };

        assert_eq!(mapped.match_seq_num, Some(7));
        assert_eq!(
            mapped.fencing_token.as_ref().map(|token| token.as_ref()),
            Some("fence")
        );

        let append_record: AppendRecordParts<StoredRecord> = mapped
            .records
            .into_iter()
            .next()
            .expect("sample append input should contain a single record")
            .into_parts();
        assert_eq!(append_record.timestamp, Some(42));

        let stored_record = append_record.record.into_inner();
        // Only the encrypt path should produce an Encrypted stored record.
        assert_eq!(
            matches!(&stored_record, StoredRecord::Encrypted { .. }),
            encrypt
        );

        let decryption = if encrypt {
            &encryption
        } else {
            &EncryptionSpec::Plain
        };
        let decrypted = decrypt_stored_record(stored_record, decryption, TEST_AAD).unwrap();
        let Record::Envelope(record) = decrypted.into_inner() else {
            panic!("expected envelope record");
        };
        assert_eq!(record.body().as_ref(), b"hello");
    }

    #[test]
    fn stored_read_batch_decrypt_preserves_positions_and_tail() {
        let batch = ReadBatch {
            records: Metered::from(vec![
                StoredRecord::Plaintext(Record::Envelope(
                    EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(b"one")).unwrap(),
                ))
                .metered()
                .sequenced(StreamPosition {
                    seq_num: 1,
                    timestamp: 10,
                })
                .into_inner(),
                StoredRecord::Plaintext(Record::Envelope(
                    EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(b"two")).unwrap(),
                ))
                .metered()
                .sequenced(StreamPosition {
                    seq_num: 2,
                    timestamp: 20,
                })
                .into_inner(),
            ]),
            tail: Some(StreamPosition {
                seq_num: 3,
                timestamp: 30,
            }),
        };

        let mapped = batch
            .decrypt(&crate::encryption::EncryptionSpec::Plain, &[])
            .unwrap();
        let records = mapped.records.into_inner();

        assert_eq!(
            mapped.tail,
            Some(StreamPosition {
                seq_num: 3,
                timestamp: 30
            })
        );
        assert_eq!(
            records[0].position(),
            &StreamPosition {
                seq_num: 1,
                timestamp: 10
            }
        );
        assert_eq!(
            records[1].position(),
            &StreamPosition {
                seq_num: 2,
                timestamp: 20
            }
        );
        assert!(matches!(records[0].inner(), Record::Envelope(_)));
        assert!(matches!(records[1].inner(), Record::Envelope(_)));
    }
}