// std_modrpc/role_impls/byte_stream_sender.rs

1use std::{cell::Cell, rc::Rc};
2
3use modrpc::RoleSetup;
4
5use crate::{
6    ByteStreamInitState, ByteStreamSenderConfig, ByteStreamSenderHooks, ByteStreamSenderStubs,
7};
8
9struct State {
10    hooks: ByteStreamSenderHooks,
11    send_cursor: Cell<u64>,
12}
13
14#[derive(Clone)]
15pub struct ByteStreamSender {
16    state: Rc<State>,
17}
18
19pub struct ByteStreamSenderBuilder {
20    state: Rc<State>,
21}
22
23impl ByteStreamSenderBuilder {
24    pub fn new(
25        _name: &'static str,
26        hooks: ByteStreamSenderHooks,
27        _stubs: ByteStreamSenderStubs,
28        _config: &ByteStreamSenderConfig,
29        _init: ByteStreamInitState,
30    ) -> Self {
31        let state = Rc::new(State {
32            hooks: hooks.clone(),
33            send_cursor: Cell::new(0),
34        });
35        Self { state }
36    }
37
38    pub fn create_handle(&self, _setup: &RoleSetup) -> ByteStreamSender {
39        ByteStreamSender {
40            state: self.state.clone(),
41        }
42    }
43
44    pub fn build(self, _setup: &RoleSetup) {}
45}
46
47impl ByteStreamSender {
48    pub async fn send(&self, bytes: &[u8]) -> u64 {
49        let start_index = self.state.send_cursor.get();
50        self.state.send_cursor.set(start_index + bytes.len() as u64);
51
52        self.state
53            .hooks
54            .blob
55            .send_raw(8 + bytes.len(), |write_buf| {
56                write_buf[..8].copy_from_slice(&start_index.to_le_bytes());
57                write_buf[8..].copy_from_slice(bytes);
58            })
59            .await;
60
61        start_index
62    }
63
64    /// SAFETY: You must have exclusive ownership of the buffer and there must be enough headroom
65    /// for modrpc::TransmitPacket::BASE_LEN + 8 bytes
66    pub async unsafe fn send_buffer(&self, buffer: modrpc::BufferPtr) -> u64 {
67        let headroom = <modrpc::TransmitPacket as mproto::BaseLen>::BASE_LEN;
68        let payload_len =
69            modrpc::WriterFlushSender::get_complete_buffer_len(buffer) as usize - headroom - 8;
70
71        let start_index = self.state.send_cursor.get();
72        self.state.send_cursor.set(start_index + payload_len as u64);
73
74        // Write the start index
75        let headroom = <modrpc::TransmitPacket as mproto::BaseLen>::BASE_LEN;
76        let start_index_buf = unsafe { buffer.slice_mut(headroom..headroom + 8) };
77        start_index_buf.copy_from_slice(&start_index.to_le_bytes());
78
79        unsafe {
80            self.state.hooks.blob.send_buffer(buffer).await;
81        }
82
83        start_index
84    }
85
86    pub async fn wait_consumed(&self, _cursor: u64) {
87        // TODO
88    }
89}