Skip to main content

sim_lib_server/transport/
backends.rs

1use std::{
2    collections::BTreeMap,
3    sync::{Arc, Mutex, OnceLock},
4};
5
6use sim_kernel::{Cx, Error, Result};
7use sim_wasm_abi::{Frame as WasmFrame, WasmFrameLimits};
8
9use crate::{EvalSite, ServerAddress, ServerFrame, wasm::lookup_wasm_region};
10
11use super::framing::endpoint_key;
12use super::{
13    ConnectionTransport, ServerTransport, decode_transport_frame, encode_transport_frame,
14    route_frame_bytes,
15};
16
17#[derive(Clone)]
18pub(crate) struct TransportEndpoint {
19    pub(crate) address: ServerAddress,
20    pub(crate) site: Arc<dyn EvalSite>,
21}
22
23#[derive(Default)]
24struct EndpointRegistry {
25    endpoints: BTreeMap<String, TransportEndpoint>,
26}
27
28fn endpoint_registry() -> &'static Mutex<EndpointRegistry> {
29    static REGISTRY: OnceLock<Mutex<EndpointRegistry>> = OnceLock::new();
30    REGISTRY.get_or_init(|| Mutex::new(EndpointRegistry::default()))
31}
32
33pub(crate) fn register_endpoint(endpoint: TransportEndpoint) -> Result<()> {
34    let mut registry = endpoint_registry()
35        .lock()
36        .map_err(|_| Error::HostError("endpoint registry mutex poisoned".to_owned()))?;
37    registry
38        .endpoints
39        .insert(endpoint_key(&endpoint.address), endpoint);
40    Ok(())
41}
42
43pub(crate) fn has_registered_endpoint(address: &ServerAddress) -> Result<bool> {
44    let registry = endpoint_registry()
45        .lock()
46        .map_err(|_| Error::HostError("endpoint registry mutex poisoned".to_owned()))?;
47    Ok(registry.endpoints.contains_key(&endpoint_key(address)))
48}
49
50pub(crate) fn unregister_endpoint(address: &ServerAddress) -> Result<()> {
51    let mut registry = endpoint_registry()
52        .lock()
53        .map_err(|_| Error::HostError("endpoint registry mutex poisoned".to_owned()))?;
54    registry.endpoints.remove(&endpoint_key(address));
55    Ok(())
56}
57
58pub(crate) fn lookup_endpoint(address: &ServerAddress) -> Result<TransportEndpoint> {
59    let registry = endpoint_registry()
60        .lock()
61        .map_err(|_| Error::HostError("endpoint registry mutex poisoned".to_owned()))?;
62    registry
63        .endpoints
64        .get(&endpoint_key(address))
65        .cloned()
66        .ok_or_else(|| {
67            Error::Eval(format!(
68                "no endpoint registered for {}",
69                address.kind_symbol()
70            ))
71        })
72}
73
74/// RAII guard for a registered loopback endpoint.
75///
76/// Dropping it, or calling [`close`](Self::close), unregisters the endpoint.
77pub struct LoopbackTransportEndpoint {
78    address: ServerAddress,
79}
80
81impl LoopbackTransportEndpoint {
82    /// Returns the registered endpoint address.
83    pub fn address(&self) -> &ServerAddress {
84        &self.address
85    }
86
87    /// Unregisters the endpoint explicitly.
88    pub fn close(&self) -> Result<()> {
89        unregister_endpoint(&self.address)
90    }
91}
92
93impl Drop for LoopbackTransportEndpoint {
94    fn drop(&mut self) {
95        let _ = unregister_endpoint(&self.address);
96    }
97}
98
99pub(crate) fn register_loopback_endpoint(
100    address: ServerAddress,
101    site: Arc<dyn EvalSite>,
102) -> Result<LoopbackTransportEndpoint> {
103    register_endpoint(TransportEndpoint {
104        address: address.clone(),
105        site,
106    })?;
107    Ok(LoopbackTransportEndpoint { address })
108}
109
110#[derive(Clone)]
111/// In-process transport that routes frames directly to a local eval site.
112///
113/// Sending a frame evaluates it against the site and buffers the reply for the
114/// next receive.
115pub struct LocalTransport {
116    address: ServerAddress,
117    site: Arc<dyn EvalSite>,
118    pending: Arc<Mutex<Option<ServerFrame>>>,
119}
120
121impl LocalTransport {
122    /// Creates a local transport for `address` backed by `site`.
123    pub fn new(address: ServerAddress, site: Arc<dyn EvalSite>) -> Self {
124        Self {
125            address,
126            site,
127            pending: Arc::new(Mutex::new(None)),
128        }
129    }
130}
131
132impl ServerTransport for LocalTransport {
133    fn address(&self) -> &ServerAddress {
134        &self.address
135    }
136
137    fn accept(&self, _cx: &mut Cx) -> Result<Box<dyn ConnectionTransport>> {
138        Ok(Box::new(self.clone()))
139    }
140
141    fn shutdown(&self, _cx: &mut Cx) -> Result<()> {
142        let mut pending = self
143            .pending
144            .lock()
145            .map_err(|_| Error::HostError("local transport mutex poisoned".to_owned()))?;
146        *pending = None;
147        Ok(())
148    }
149
150    fn accept_timeout(
151        &self,
152        _cx: &mut Cx,
153        _timeout: std::time::Duration,
154    ) -> Result<Option<Box<dyn ConnectionTransport>>> {
155        Ok(None)
156    }
157}
158
159impl ConnectionTransport for LocalTransport {
160    fn send_frame(&mut self, cx: &mut Cx, frame: ServerFrame) -> Result<()> {
161        let reply = route_frame_bytes(cx, &self.site, &encode_transport_frame(&frame)?)?;
162        let reply = decode_transport_frame(&reply)?;
163        let mut pending = self
164            .pending
165            .lock()
166            .map_err(|_| Error::HostError("local transport mutex poisoned".to_owned()))?;
167        *pending = Some(reply);
168        Ok(())
169    }
170
171    fn recv_frame(
172        &mut self,
173        _cx: &mut Cx,
174        _timeout: Option<std::time::Duration>,
175    ) -> Result<Option<ServerFrame>> {
176        let mut pending = self
177            .pending
178            .lock()
179            .map_err(|_| Error::HostError("local transport mutex poisoned".to_owned()))?;
180        Ok(pending.take())
181    }
182
183    fn close(&mut self, _cx: &mut Cx) -> Result<()> {
184        let mut pending = self
185            .pending
186            .lock()
187            .map_err(|_| Error::HostError("local transport mutex poisoned".to_owned()))?;
188        *pending = None;
189        Ok(())
190    }
191
192    fn as_any(&self) -> &dyn std::any::Any {
193        self
194    }
195}
196
197#[derive(Clone)]
198pub struct RegistryTransport {
199    address: ServerAddress,
200    pending: Arc<Mutex<Option<ServerFrame>>>,
201}
202
203impl RegistryTransport {
204    pub fn new(address: ServerAddress) -> Self {
205        Self {
206            address,
207            pending: Arc::new(Mutex::new(None)),
208        }
209    }
210}
211
212impl ServerTransport for RegistryTransport {
213    fn address(&self) -> &ServerAddress {
214        &self.address
215    }
216
217    fn accept(&self, _cx: &mut Cx) -> Result<Box<dyn ConnectionTransport>> {
218        let _ = lookup_endpoint(self.address())?;
219        Ok(Box::new(Self::new(self.address.clone())))
220    }
221
222    fn shutdown(&self, _cx: &mut Cx) -> Result<()> {
223        unregister_endpoint(&self.address)
224    }
225
226    fn accept_timeout(
227        &self,
228        _cx: &mut Cx,
229        _timeout: std::time::Duration,
230    ) -> Result<Option<Box<dyn ConnectionTransport>>> {
231        Ok(None)
232    }
233}
234
235impl ConnectionTransport for RegistryTransport {
236    fn send_frame(&mut self, cx: &mut Cx, frame: ServerFrame) -> Result<()> {
237        let endpoint = lookup_endpoint(&self.address)?;
238        let bytes = encode_transport_frame(&frame)?;
239        let reply = route_frame_bytes(cx, &endpoint.site, &bytes)?;
240        let reply = decode_transport_frame(&reply)?;
241        let mut pending = self
242            .pending
243            .lock()
244            .map_err(|_| Error::HostError("registry transport mutex poisoned".to_owned()))?;
245        *pending = Some(reply);
246        Ok(())
247    }
248
249    fn recv_frame(
250        &mut self,
251        _cx: &mut Cx,
252        _timeout: Option<std::time::Duration>,
253    ) -> Result<Option<ServerFrame>> {
254        let mut pending = self
255            .pending
256            .lock()
257            .map_err(|_| Error::HostError("registry transport mutex poisoned".to_owned()))?;
258        Ok(pending.take())
259    }
260
261    fn close(&mut self, _cx: &mut Cx) -> Result<()> {
262        let mut pending = self
263            .pending
264            .lock()
265            .map_err(|_| Error::HostError("registry transport mutex poisoned".to_owned()))?;
266        *pending = None;
267        Ok(())
268    }
269
270    fn as_any(&self) -> &dyn std::any::Any {
271        self
272    }
273}
274
275pub struct WasmConnectionTransport {
276    region: String,
277    pending: Option<ServerFrame>,
278}
279
280impl WasmConnectionTransport {
281    pub fn connect(address: &ServerAddress) -> Result<Self> {
282        let ServerAddress::Wasm { region } = address else {
283            return Err(Error::Eval(
284                "wasm connection transport requires a wasm address".to_owned(),
285            ));
286        };
287        let _ = lookup_wasm_region(region)?;
288        Ok(Self {
289            region: region.clone(),
290            pending: None,
291        })
292    }
293}
294
295impl ConnectionTransport for WasmConnectionTransport {
296    fn send_frame(&mut self, _cx: &mut Cx, frame: ServerFrame) -> Result<()> {
297        let region = lookup_wasm_region(&self.region)?;
298        let request = encode_transport_frame(&frame)?;
299        enforce_wasm_transport_limit(&request, "wasm frame exceeds transport limit")?;
300        let reply = region.runtime.call(
301            region.module,
302            &sim_kernel::Symbol::qualified("server", "answer"),
303            WasmFrame::new(request),
304        )?;
305        enforce_wasm_frame_limit(&reply, "wasm reply exceeds transport limit")?;
306        self.pending = Some(decode_transport_frame(reply.bytes())?);
307        Ok(())
308    }
309
310    fn recv_frame(
311        &mut self,
312        _cx: &mut Cx,
313        _timeout: Option<std::time::Duration>,
314    ) -> Result<Option<ServerFrame>> {
315        Ok(self.pending.take())
316    }
317
318    fn close(&mut self, _cx: &mut Cx) -> Result<()> {
319        self.pending = None;
320        Ok(())
321    }
322
323    fn as_any(&self) -> &dyn std::any::Any {
324        self
325    }
326}
327
328pub(super) fn enforce_wasm_transport_limit(bytes: &[u8], message: &str) -> Result<()> {
329    if bytes.len() > WasmFrameLimits::default().max_frame_bytes {
330        return Err(Error::HostError(message.to_owned()));
331    }
332    Ok(())
333}
334
335pub(super) fn enforce_wasm_frame_limit(frame: &WasmFrame, message: &str) -> Result<()> {
336    let frame_ref = frame.as_ref()?;
337    if usize::try_from(frame_ref.len).unwrap_or(usize::MAX)
338        > WasmFrameLimits::default().max_frame_bytes
339    {
340        return Err(Error::HostError(message.to_owned()));
341    }
342    Ok(())
343}