std_modrpc/role_impls/byte_stream_receiver.rs

1use std::{
2    cell::{Cell, RefCell},
3    collections::HashMap,
4    rc::Rc,
5};
6
7use crate::{
8    ByteStreamInitState, ByteStreamReceiverConfig, ByteStreamReceiverHooks, ByteStreamReceiverStubs,
9};
10
11struct State {
12    // Blobs by their start index
13    blobs: RefCell<HashMap<u64, modrpc::Packet>>,
14
15    current_blob_start: Cell<u64>,
16    consume_cursor: Cell<u64>,
17
18    // A bit clunky, but rather than precisely track and wake waiting tasks by the byte range they
19    // are waiting for, just wake every waiting task whenever new bytes come in. We don't expect
20    // there to be a lot of concurrent waiters, so doing this seems cheaper than managing another
21    // datastructure.
22    waiters: localq::WaiterQueue,
23}
24
/// Cloneable handle for reading a byte stream in order.
///
/// Clones share the same underlying state through an `Rc`, so consuming bytes through one clone
/// advances the cursor observed by every other clone.
#[derive(Clone)]
pub struct ByteStreamReceiver {
    state: Rc<State>,
}
29
30pub struct ByteStreamReceiverBuilder {
31    stubs: ByteStreamReceiverStubs,
32    state: Rc<State>,
33}
34
35impl ByteStreamReceiverBuilder {
36    pub fn new(
37        _name: &'static str,
38        _hooks: ByteStreamReceiverHooks,
39        stubs: ByteStreamReceiverStubs,
40        _config: &ByteStreamReceiverConfig,
41        _init: ByteStreamInitState,
42    ) -> Self {
43        let state = Rc::new(State {
44            blobs: RefCell::new(HashMap::new()),
45            current_blob_start: Cell::new(0),
46            consume_cursor: Cell::new(0),
47            waiters: localq::WaiterQueue::new(),
48        });
49
50        Self { stubs, state }
51    }
52
53    pub fn create_handle(&self, _setup: &modrpc::RoleSetup) -> ByteStreamReceiver {
54        ByteStreamReceiver {
55            state: self.state.clone(),
56        }
57    }
58
59    pub fn build(self, setup: &modrpc::RoleSetup) {
60        let state = self.state.clone();
61        self.stubs
62            .blob
63            .inline_untyped(setup, move |_source, packet| {
64                use mproto::BaseLen;
65
66                if packet.len() < 8 {
67                    // Invalid packet
68                    return;
69                }
70
71                // Skip the packet header
72                packet.advance(modrpc::TransmitPacket::BASE_LEN);
73
74                // Read start index
75                let start_index_bytes: [u8; 8] = packet[..8].try_into().unwrap();
76                let start_index = u64::from_le_bytes(start_index_bytes);
77                // Remove start index header
78                packet.advance(8);
79
80                if start_index < state.current_blob_start.get() {
81                    state.current_blob_start.set(start_index);
82                }
83
84                let mut blobs = state.blobs.borrow_mut();
85                blobs.entry(start_index).or_insert(packet.clone());
86
87                state.waiters.notify(usize::MAX);
88            })
89            .subscribe();
90    }
91}
92
93impl ByteStreamReceiver {
94    pub fn cursor(&self) -> u64 {
95        self.state.consume_cursor.get()
96    }
97
98    pub fn peek(&self) -> Option<modrpc::Packet> {
99        let start = self.state.current_blob_start.get();
100        let cursor = self.state.consume_cursor.get();
101        let blobs = self.state.blobs.borrow();
102
103        if start > cursor {
104            // Blobs arrived out of order and we don't have the next blob to peek yet.
105            return None;
106        }
107
108        let blob = blobs.get(&start)?.clone();
109        blob.advance((cursor - start) as usize);
110        Some(blob)
111    }
112
113    pub fn consume(&self, count: u64) -> Option<modrpc::Packet> {
114        use std::collections::hash_map::Entry;
115
116        let start = self.state.current_blob_start.get();
117        let cursor = self.state.consume_cursor.get();
118        let offset_in_blob = cursor - start;
119        let mut blobs = self.state.blobs.borrow_mut();
120
121        if start > cursor {
122            // Blobs arrived out of order and we don't have the next blob to peek yet.
123            return None;
124        }
125
126        let Entry::Occupied(blob_entry) = blobs.entry(start) else {
127            return None;
128        };
129
130        let blob = if count >= blob_entry.get().len() as u64 - offset_in_blob {
131            // Finish consuming the current blob
132            let blob = blob_entry.remove();
133            self.state.current_blob_start.set(start + blob.len() as u64);
134            blob
135        } else {
136            blob_entry.get().clone()
137        };
138        blob.advance(offset_in_blob as usize);
139        blob.set_len(std::cmp::min(blob.len(), count as usize));
140
141        self.state.consume_cursor.set(cursor + blob.len() as u64);
142
143        Some(blob)
144    }
145
146    pub fn try_peek_ahead(&self, read_start: u64, read_len: u64) -> Option<modrpc::Packet> {
147        let start = self.state.current_blob_start.get();
148        let consume_cursor = self.state.consume_cursor.get();
149
150        // Another clunk - we allow bytes to be peeked out-of-order, but wait for all bytes up
151        // through the end of the read to be present. We could lift this restriction by storing
152        // received blobs in a BinaryHeap instead of a HashMap.
153
154        if consume_cursor > read_start {
155            // The bytes to read have already been consumed.
156            return None;
157        }
158        if start > read_start {
159            // Blobs arrived out of order and we don't have the next blob to peek yet.
160            return None;
161        }
162
163        let mut cursor = start;
164        let blobs = self.state.blobs.borrow();
165        loop {
166            let Some(blob) = blobs.get(&cursor) else {
167                return None;
168            };
169
170            if cursor + blob.len() as u64 > read_start
171                // Special handling for empty reads
172                || cursor + blob.len() as u64 == read_start && read_len == 0
173            {
174                // Found the blob to read
175                let blob = blob.clone();
176                blob.advance((read_start - cursor) as usize);
177                blob.set_len(std::cmp::min(blob.len(), read_len as usize));
178                return Some(blob);
179            }
180
181            cursor += blob.len() as u64;
182        }
183    }
184
185    pub async fn peek_ahead(&self, read_start: u64, read_len: u64) -> modrpc::Packet {
186        self.state
187            .waiters
188            .wait_for(|| self.try_peek_ahead(read_start, read_len))
189            .await
190    }
191}