// std_modrpc/role_impls/byte_stream_receiver.rs

1use std::{
2    cell::{Cell, RefCell},
3    collections::HashMap,
4    rc::Rc,
5};
6
7use crate::{
8    ByteStreamInitState,
9    ByteStreamReceiverConfig,
10    ByteStreamReceiverHooks,
11    ByteStreamReceiverStubs,
12};
13
14struct State {
15    // Blobs by their start index
16    blobs: RefCell<HashMap<u64, modrpc::Packet>>,
17
18    current_blob_start: Cell<u64>,
19    consume_cursor: Cell<u64>,
20
21    // A bit clunky, but rather than precisely track and wake waiting tasks by the byte range they
22    // are waiting for, just wake every waiting task whenever new bytes come in. We don't expect
23    // there to be a lot of concurrent waiters, so doing this seems cheaper than managing another
24    // datastructure.
25    waiters: localq::WaiterQueue,
26}
27
28#[derive(Clone)]
29pub struct ByteStreamReceiver {
30    state: Rc<State>,
31}
32
33pub struct ByteStreamReceiverBuilder {
34    stubs: ByteStreamReceiverStubs,
35    state: Rc<State>,
36}
37
38impl ByteStreamReceiverBuilder {
39    pub fn new(
40        _name: &'static str,
41        _hooks: ByteStreamReceiverHooks,
42        stubs: ByteStreamReceiverStubs,
43        _config: &ByteStreamReceiverConfig,
44        _init: ByteStreamInitState,
45    ) -> Self {
46        let state = Rc::new(State {
47            blobs: RefCell::new(HashMap::new()),
48            current_blob_start: Cell::new(0),
49            consume_cursor: Cell::new(0),
50            waiters: localq::WaiterQueue::new(),
51        });
52
53        Self { stubs, state }
54    }
55
56    pub fn create_handle(
57        &self,
58        _setup: &modrpc::RoleSetup,
59    ) -> ByteStreamReceiver {
60        ByteStreamReceiver {
61            state: self.state.clone(),
62        }
63    }
64
65    pub fn build(
66        self,
67        setup: &modrpc::RoleSetup,
68    ) {
69        let state = self.state.clone();
70        self.stubs.blob
71            .inline_untyped(setup, move |_source, packet| {
72                use mproto::BaseLen;
73
74                if packet.len() < 8 {
75                    // Invalid packet
76                    return;
77                }
78
79                // Skip the packet header
80                packet.advance(modrpc::TransmitPacket::BASE_LEN);
81
82                // Read start index
83                let start_index_bytes: [u8; 8] = packet[..8].try_into().unwrap();
84                let start_index = u64::from_le_bytes(start_index_bytes);
85                // Remove start index header
86                packet.advance(8);
87
88                if start_index < state.current_blob_start.get() {
89                    state.current_blob_start.set(start_index);
90                }
91
92                let mut blobs = state.blobs.borrow_mut();
93                blobs.entry(start_index).or_insert(packet.clone());
94
95                state.waiters.notify(usize::MAX);
96            })
97            .subscribe();
98    }
99}
100
101impl ByteStreamReceiver {
102    pub fn cursor(&self) -> u64 {
103        self.state.consume_cursor.get()
104    }
105
106    pub fn peek(&self) -> Option<modrpc::Packet> {
107        let start = self.state.current_blob_start.get();
108        let cursor = self.state.consume_cursor.get();
109        let blobs = self.state.blobs.borrow();
110
111        if start > cursor {
112            // Blobs arrived out of order and we don't have the next blob to peek yet.
113            return None;
114        }
115
116        let blob = blobs.get(&start)?.clone();
117        blob.advance((cursor - start) as usize);
118        Some(blob)
119    }
120
121    pub fn consume(&self, count: u64) -> Option<modrpc::Packet> {
122        use std::collections::hash_map::Entry;
123
124        let start = self.state.current_blob_start.get();
125        let cursor = self.state.consume_cursor.get();
126        let offset_in_blob = cursor - start;
127        let mut blobs = self.state.blobs.borrow_mut();
128
129        if start > cursor {
130            // Blobs arrived out of order and we don't have the next blob to peek yet.
131            return None;
132        }
133
134        let Entry::Occupied(blob_entry) = blobs.entry(start) else {
135            return None;
136        };
137
138        let blob =
139            if count >= blob_entry.get().len() as u64 - offset_in_blob {
140                // Finish consuming the current blob
141                let blob = blob_entry.remove();
142                self.state.current_blob_start.set(start + blob.len() as u64);
143                blob
144            } else {
145                blob_entry.get().clone()
146            };
147        blob.advance(offset_in_blob as usize);
148        blob.set_len(std::cmp::min(blob.len(), count as usize));
149
150        self.state.consume_cursor.set(cursor + blob.len() as u64);
151
152        Some(blob)
153    }
154
155    pub fn try_peek_ahead(&self, read_start: u64, read_len: u64) -> Option<modrpc::Packet> {
156        let start = self.state.current_blob_start.get();
157        let consume_cursor = self.state.consume_cursor.get();
158
159        // Another clunk - we allow bytes to be peeked out-of-order, but wait for all bytes up
160        // through the end of the read to be present. We could lift this restriction by storing
161        // received blobs in a BinaryHeap instead of a HashMap.
162
163        if consume_cursor > read_start {
164            // The bytes to read have already been consumed.
165            return None;
166        }
167        if start > read_start {
168            // Blobs arrived out of order and we don't have the next blob to peek yet.
169            return None;
170        }
171
172        let mut cursor = start;
173        let blobs = self.state.blobs.borrow();
174        loop {
175            let Some(blob) = blobs.get(&cursor) else {
176                return None;
177            };
178
179            if cursor + blob.len() as u64 > read_start
180                // Special handling for empty reads
181                || cursor + blob.len() as u64 == read_start && read_len == 0
182            {
183                // Found the blob to read
184                let blob = blob.clone();
185                blob.advance((read_start - cursor) as usize);
186                blob.set_len(std::cmp::min(blob.len(), read_len as usize));
187                return Some(blob);
188            }
189
190            cursor += blob.len() as u64;
191        }
192    }
193
194    pub async fn peek_ahead(&self, read_start: u64, read_len: u64) -> modrpc::Packet {
195        self.state.waiters.wait_for(|| self.try_peek_ahead(read_start, read_len)).await
196    }
197}