std_modrpc/role_impls/
byte_stream_receiver.rs1use std::{
2 cell::{Cell, RefCell},
3 collections::HashMap,
4 rc::Rc,
5};
6
7use crate::{
8 ByteStreamInitState,
9 ByteStreamReceiverConfig,
10 ByteStreamReceiverHooks,
11 ByteStreamReceiverStubs,
12};
13
14struct State {
15 blobs: RefCell<HashMap<u64, modrpc::Packet>>,
17
18 current_blob_start: Cell<u64>,
19 consume_cursor: Cell<u64>,
20
21 waiters: localq::WaiterQueue,
26}
27
28#[derive(Clone)]
29pub struct ByteStreamReceiver {
30 state: Rc<State>,
31}
32
33pub struct ByteStreamReceiverBuilder {
34 stubs: ByteStreamReceiverStubs,
35 state: Rc<State>,
36}
37
38impl ByteStreamReceiverBuilder {
39 pub fn new(
40 _name: &'static str,
41 _hooks: ByteStreamReceiverHooks,
42 stubs: ByteStreamReceiverStubs,
43 _config: &ByteStreamReceiverConfig,
44 _init: ByteStreamInitState,
45 ) -> Self {
46 let state = Rc::new(State {
47 blobs: RefCell::new(HashMap::new()),
48 current_blob_start: Cell::new(0),
49 consume_cursor: Cell::new(0),
50 waiters: localq::WaiterQueue::new(),
51 });
52
53 Self { stubs, state }
54 }
55
56 pub fn create_handle(
57 &self,
58 _setup: &modrpc::RoleSetup,
59 ) -> ByteStreamReceiver {
60 ByteStreamReceiver {
61 state: self.state.clone(),
62 }
63 }
64
65 pub fn build(
66 self,
67 setup: &modrpc::RoleSetup,
68 ) {
69 let state = self.state.clone();
70 self.stubs.blob
71 .inline_untyped(setup, move |_source, packet| {
72 use mproto::BaseLen;
73
74 if packet.len() < 8 {
75 return;
77 }
78
79 packet.advance(modrpc::TransmitPacket::BASE_LEN);
81
82 let start_index_bytes: [u8; 8] = packet[..8].try_into().unwrap();
84 let start_index = u64::from_le_bytes(start_index_bytes);
85 packet.advance(8);
87
88 if start_index < state.current_blob_start.get() {
89 state.current_blob_start.set(start_index);
90 }
91
92 let mut blobs = state.blobs.borrow_mut();
93 blobs.entry(start_index).or_insert(packet.clone());
94
95 state.waiters.notify(usize::MAX);
96 })
97 .subscribe();
98 }
99}
100
101impl ByteStreamReceiver {
102 pub fn cursor(&self) -> u64 {
103 self.state.consume_cursor.get()
104 }
105
106 pub fn peek(&self) -> Option<modrpc::Packet> {
107 let start = self.state.current_blob_start.get();
108 let cursor = self.state.consume_cursor.get();
109 let blobs = self.state.blobs.borrow();
110
111 if start > cursor {
112 return None;
114 }
115
116 let blob = blobs.get(&start)?.clone();
117 blob.advance((cursor - start) as usize);
118 Some(blob)
119 }
120
121 pub fn consume(&self, count: u64) -> Option<modrpc::Packet> {
122 use std::collections::hash_map::Entry;
123
124 let start = self.state.current_blob_start.get();
125 let cursor = self.state.consume_cursor.get();
126 let offset_in_blob = cursor - start;
127 let mut blobs = self.state.blobs.borrow_mut();
128
129 if start > cursor {
130 return None;
132 }
133
134 let Entry::Occupied(blob_entry) = blobs.entry(start) else {
135 return None;
136 };
137
138 let blob =
139 if count >= blob_entry.get().len() as u64 - offset_in_blob {
140 let blob = blob_entry.remove();
142 self.state.current_blob_start.set(start + blob.len() as u64);
143 blob
144 } else {
145 blob_entry.get().clone()
146 };
147 blob.advance(offset_in_blob as usize);
148 blob.set_len(std::cmp::min(blob.len(), count as usize));
149
150 self.state.consume_cursor.set(cursor + blob.len() as u64);
151
152 Some(blob)
153 }
154
155 pub fn try_peek_ahead(&self, read_start: u64, read_len: u64) -> Option<modrpc::Packet> {
156 let start = self.state.current_blob_start.get();
157 let consume_cursor = self.state.consume_cursor.get();
158
159 if consume_cursor > read_start {
164 return None;
166 }
167 if start > read_start {
168 return None;
170 }
171
172 let mut cursor = start;
173 let blobs = self.state.blobs.borrow();
174 loop {
175 let Some(blob) = blobs.get(&cursor) else {
176 return None;
177 };
178
179 if cursor + blob.len() as u64 > read_start
180 || cursor + blob.len() as u64 == read_start && read_len == 0
182 {
183 let blob = blob.clone();
185 blob.advance((read_start - cursor) as usize);
186 blob.set_len(std::cmp::min(blob.len(), read_len as usize));
187 return Some(blob);
188 }
189
190 cursor += blob.len() as u64;
191 }
192 }
193
194 pub async fn peek_ahead(&self, read_start: u64, read_len: u64) -> modrpc::Packet {
195 self.state.waiters.wait_for(|| self.try_peek_ahead(read_start, read_len)).await
196 }
197}