std_modrpc/role_impls/
byte_stream_receiver.rs1use std::{
2 cell::{Cell, RefCell},
3 collections::HashMap,
4 rc::Rc,
5};
6
7use crate::{
8 ByteStreamInitState, ByteStreamReceiverConfig, ByteStreamReceiverHooks, ByteStreamReceiverStubs,
9};
10
11struct State {
12 blobs: RefCell<HashMap<u64, modrpc::Packet>>,
14
15 current_blob_start: Cell<u64>,
16 consume_cursor: Cell<u64>,
17
18 waiters: localq::WaiterQueue,
23}
24
25#[derive(Clone)]
26pub struct ByteStreamReceiver {
27 state: Rc<State>,
28}
29
30pub struct ByteStreamReceiverBuilder {
31 stubs: ByteStreamReceiverStubs,
32 state: Rc<State>,
33}
34
35impl ByteStreamReceiverBuilder {
36 pub fn new(
37 _name: &'static str,
38 _hooks: ByteStreamReceiverHooks,
39 stubs: ByteStreamReceiverStubs,
40 _config: &ByteStreamReceiverConfig,
41 _init: ByteStreamInitState,
42 ) -> Self {
43 let state = Rc::new(State {
44 blobs: RefCell::new(HashMap::new()),
45 current_blob_start: Cell::new(0),
46 consume_cursor: Cell::new(0),
47 waiters: localq::WaiterQueue::new(),
48 });
49
50 Self { stubs, state }
51 }
52
53 pub fn create_handle(&self, _setup: &modrpc::RoleSetup) -> ByteStreamReceiver {
54 ByteStreamReceiver {
55 state: self.state.clone(),
56 }
57 }
58
59 pub fn build(self, setup: &modrpc::RoleSetup) {
60 let state = self.state.clone();
61 self.stubs
62 .blob
63 .inline_untyped(setup, move |_source, packet| {
64 use mproto::BaseLen;
65
66 if packet.len() < 8 {
67 return;
69 }
70
71 packet.advance(modrpc::TransmitPacket::BASE_LEN);
73
74 let start_index_bytes: [u8; 8] = packet[..8].try_into().unwrap();
76 let start_index = u64::from_le_bytes(start_index_bytes);
77 packet.advance(8);
79
80 if start_index < state.current_blob_start.get() {
81 state.current_blob_start.set(start_index);
82 }
83
84 let mut blobs = state.blobs.borrow_mut();
85 blobs.entry(start_index).or_insert(packet.clone());
86
87 state.waiters.notify(usize::MAX);
88 })
89 .subscribe();
90 }
91}
92
93impl ByteStreamReceiver {
94 pub fn cursor(&self) -> u64 {
95 self.state.consume_cursor.get()
96 }
97
98 pub fn peek(&self) -> Option<modrpc::Packet> {
99 let start = self.state.current_blob_start.get();
100 let cursor = self.state.consume_cursor.get();
101 let blobs = self.state.blobs.borrow();
102
103 if start > cursor {
104 return None;
106 }
107
108 let blob = blobs.get(&start)?.clone();
109 blob.advance((cursor - start) as usize);
110 Some(blob)
111 }
112
113 pub fn consume(&self, count: u64) -> Option<modrpc::Packet> {
114 use std::collections::hash_map::Entry;
115
116 let start = self.state.current_blob_start.get();
117 let cursor = self.state.consume_cursor.get();
118 let offset_in_blob = cursor - start;
119 let mut blobs = self.state.blobs.borrow_mut();
120
121 if start > cursor {
122 return None;
124 }
125
126 let Entry::Occupied(blob_entry) = blobs.entry(start) else {
127 return None;
128 };
129
130 let blob = if count >= blob_entry.get().len() as u64 - offset_in_blob {
131 let blob = blob_entry.remove();
133 self.state.current_blob_start.set(start + blob.len() as u64);
134 blob
135 } else {
136 blob_entry.get().clone()
137 };
138 blob.advance(offset_in_blob as usize);
139 blob.set_len(std::cmp::min(blob.len(), count as usize));
140
141 self.state.consume_cursor.set(cursor + blob.len() as u64);
142
143 Some(blob)
144 }
145
146 pub fn try_peek_ahead(&self, read_start: u64, read_len: u64) -> Option<modrpc::Packet> {
147 let start = self.state.current_blob_start.get();
148 let consume_cursor = self.state.consume_cursor.get();
149
150 if consume_cursor > read_start {
155 return None;
157 }
158 if start > read_start {
159 return None;
161 }
162
163 let mut cursor = start;
164 let blobs = self.state.blobs.borrow();
165 loop {
166 let Some(blob) = blobs.get(&cursor) else {
167 return None;
168 };
169
170 if cursor + blob.len() as u64 > read_start
171 || cursor + blob.len() as u64 == read_start && read_len == 0
173 {
174 let blob = blob.clone();
176 blob.advance((read_start - cursor) as usize);
177 blob.set_len(std::cmp::min(blob.len(), read_len as usize));
178 return Some(blob);
179 }
180
181 cursor += blob.len() as u64;
182 }
183 }
184
185 pub async fn peek_ahead(&self, read_start: u64, read_len: u64) -> modrpc::Packet {
186 self.state
187 .waiters
188 .wait_for(|| self.try_peek_ahead(read_start, read_len))
189 .await
190 }
191}