Skip to main content

selium_switchboard/
switchboard.rs

1//! Guest-side switchboard client for intent-based wiring.
2
3use core::{
4    marker::PhantomData,
5    pin::Pin,
6    task::{Context, Poll},
7};
8use std::{
9    collections::{HashMap, HashSet},
10    sync::{
11        Arc, Mutex,
12        atomic::{AtomicU64, Ordering},
13    },
14};
15
16use futures::{Future, SinkExt, Stream, StreamExt};
17pub use selium_switchboard_protocol::{AdoptMode, Backpressure, Cardinality, EndpointId};
18use selium_switchboard_protocol::{
19    Direction, EndpointDirections, Message, ProtocolError, WiringEgress, WiringIngress,
20    decode_message, encode_message,
21};
22use thiserror::Error;
23use tracing::{debug, warn};
24
25use selium_userland::{
26    Dependency, DependencyDescriptor, dependency_id,
27    encoding::{FlatMsg, HasSchema},
28    io::{Channel, ChannelHandle, DriverError, Reader, SharedChannel, Writer},
29};
30
31type PendingWiring<In, Out> =
32    Pin<Box<dyn Future<Output = Result<EndpointWiring<In, Out>, SwitchboardError>>>>;
33
34struct PendingUpdate<In, Out> {
35    future: PendingWiring<In, Out>,
36    reuse_inbound: bool,
37    reuse_outbound: bool,
38    next_inbound: Vec<WiringIngress>,
39    next_outbound: Vec<WiringEgress>,
40}
41
/// Chunk size used when subscribing to control-plane channels (request responses,
/// the shared updates channel, and per-endpoint internal channels).
const CONTROL_CHUNK_SIZE: u32 = 64 * 1024;
/// Capacity requested for the shared updates channel created in `Switchboard::attach`.
const UPDATE_CHANNEL_CAPACITY: u32 = 64 * 1024;
/// Capacity requested for per-request response channels and per-endpoint internal channels.
const INTERNAL_CHANNEL_CAPACITY: u32 = 16 * 1024;
/// Chunk size used when subscribing to inbound data channels.
const DATA_CHUNK_SIZE: u32 = 64 * 1024;
46
47/// Errors produced by the switchboard client.
48#[derive(Debug, Error)]
49pub enum SwitchboardError {
50    /// Driver returned an error while creating or destroying channels.
51    #[error("driver error: {0}")]
52    Driver(#[from] DriverError),
53    /// Control-plane protocol could not be decoded.
54    #[error("protocol error: {0}")]
55    Protocol(#[from] ProtocolError),
56    /// The switchboard service returned an error.
57    #[error("switchboard error: {0}")]
58    Remote(String),
59    /// The switchboard update stream was closed.
60    #[error("endpoint closed")]
61    EndpointClosed,
62    /// No route is available yet for this endpoint.
63    #[error("no route available")]
64    NoRoute,
65    /// Internal switchboard state was unavailable.
66    #[error("switchboard state unavailable")]
67    StateUnavailable,
68}
69
70/// Switchboard front-end that guests use to register endpoints.
71#[derive(Clone)]
72pub struct Switchboard {
73    request_channel: Channel,
74    _updates_channel: Channel,
75    updates_shared: SharedChannel,
76    state: Arc<Mutex<ClientState>>,
77    next_request_id: Arc<AtomicU64>,
78}
79
80struct ClientState {
81    pending: HashMap<u64, Channel>,
82    endpoints: HashMap<EndpointId, Channel>,
83    queued_updates: HashMap<EndpointId, Vec<Vec<u8>>>,
84    ignored: HashSet<EndpointId>,
85}
86
87/// Builder for a new endpoint.
88pub struct EndpointBuilder<In, Out> {
89    switchboard: Switchboard,
90    input: Cardinality,
91    output: Cardinality,
92    output_backpressure: Backpressure,
93    output_exclusive: bool,
94    _in: PhantomData<In>,
95    _out: PhantomData<Out>,
96}
97
98/// Registered endpoint that implements `Stream` for inbound frames and `Sink` for outbound frames.
99pub struct EndpointHandle<In, Out> {
100    id: EndpointId,
101    updates: Reader,
102    pending: Option<PendingUpdate<In, Out>>,
103    last_inbound: Vec<WiringIngress>,
104    last_outbound: Vec<WiringEgress>,
105    backpressure: Backpressure,
106    /// Inbound/outbound channel handles for this endpoint.
107    pub io: EndpointIo<In, Out>,
108}
109
110/// Data-plane handles for an endpoint.
111pub struct EndpointIo<In, Out> {
112    /// Inbound links connected to this endpoint.
113    pub inbound: Vec<InboundLink<In>>,
114    /// Outbound publishers for this endpoint.
115    pub outbound: Vec<RawPublisher<Out>>,
116    outbound_map: HashMap<EndpointId, ChannelHandle>,
117}
118
119/// Single inbound link description.
120pub struct InboundLink<In> {
121    /// Producer endpoint identifier.
122    pub from: EndpointId,
123    /// Subscriber receiving the inbound data.
124    pub subscriber: RawSubscriber<In>,
125}
126
127struct EndpointWiring<In, Out> {
128    inbound: Vec<InboundLink<In>>,
129    outbound: Vec<RawPublisher<Out>>,
130    outbound_map: HashMap<EndpointId, ChannelHandle>,
131}
132
133/// Publisher that encodes typed payloads onto a channel.
134pub struct RawPublisher<T> {
135    writer: Writer,
136    _marker: PhantomData<T>,
137}
138
139/// Subscriber that decodes typed payloads from a channel.
140pub struct RawSubscriber<T> {
141    reader: Reader,
142    _marker: PhantomData<T>,
143}
144
145impl Switchboard {
146    /// Connect to the switchboard service using the shared request channel handle.
147    pub async fn attach(request_channel: SharedChannel) -> Result<Self, SwitchboardError> {
148        let request_channel = Channel::attach_shared(request_channel).await?;
149        let updates_channel = Channel::create(UPDATE_CHANNEL_CAPACITY).await?;
150        let updates_shared = updates_channel.share().await?;
151        let updates_reader = updates_channel.subscribe(CONTROL_CHUNK_SIZE).await?;
152
153        let state = Arc::new(Mutex::new(ClientState {
154            pending: HashMap::new(),
155            endpoints: HashMap::new(),
156            queued_updates: HashMap::new(),
157            ignored: HashSet::new(),
158        }));
159
160        let dispatcher_state = Arc::clone(&state);
161        spawn_dispatcher(updates_reader, dispatcher_state);
162        debug!(
163            updates_channel = updates_shared.raw(),
164            "switchboard: dispatcher spawned"
165        );
166
167        Ok(Self {
168            request_channel,
169            _updates_channel: updates_channel,
170            updates_shared,
171            state,
172            next_request_id: Arc::new(AtomicU64::new(1)),
173        })
174    }
175
176    /// Begin building a new endpoint carrying `In` inbound frames and `Out` outbound frames.
177    pub fn endpoint<In, Out>(&self) -> EndpointBuilder<In, Out>
178    where
179        In: FlatMsg + HasSchema + Send + Unpin + 'static,
180        Out: FlatMsg + HasSchema + Send + Unpin + 'static,
181    {
182        EndpointBuilder {
183            switchboard: self.clone(),
184            input: Cardinality::One,
185            output: Cardinality::One,
186            output_backpressure: Backpressure::Park,
187            output_exclusive: false,
188            _in: PhantomData,
189            _out: PhantomData,
190        }
191    }
192
193    /// Connect two endpoints by their identifiers, triggering a reconcile.
194    pub async fn connect_ids(
195        &self,
196        from: EndpointId,
197        to: EndpointId,
198    ) -> Result<(), SwitchboardError> {
199        let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
200        let response = self
201            .send_request(Message::ConnectRequest {
202                request_id,
203                from,
204                to,
205                reply_channel: self.updates_shared.raw(),
206            })
207            .await?;
208
209        match response {
210            Message::ResponseOk { .. } => Ok(()),
211            Message::ResponseError { message, .. } => Err(SwitchboardError::Remote(message)),
212            _ => Err(SwitchboardError::Protocol(ProtocolError::UnknownPayload)),
213        }
214    }
215
216    /// Connect two endpoints by handle, triggering a reconcile.
217    pub async fn connect<In, Out, Common>(
218        &self,
219        from: &EndpointHandle<In, Common>,
220        to: &EndpointHandle<Common, Out>,
221    ) -> Result<(), SwitchboardError>
222    where
223        In: FlatMsg + Send + Unpin + 'static,
224        Out: FlatMsg + Send + Unpin + 'static,
225        Common: FlatMsg + Send + Unpin + 'static,
226    {
227        self.connect_ids(from.id, to.id).await
228    }
229
230    async fn register_endpoint(
231        &self,
232        directions: EndpointDirections,
233    ) -> Result<EndpointId, SwitchboardError> {
234        let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
235        debug!(request_id, "switchboard: registering endpoint");
236        let response = self
237            .send_request(Message::RegisterRequest {
238                request_id,
239                directions,
240                updates_channel: self.updates_shared.raw(),
241            })
242            .await?;
243
244        match response {
245            Message::ResponseRegister { endpoint_id, .. } => Ok(endpoint_id),
246            Message::ResponseError { message, .. } => Err(SwitchboardError::Remote(message)),
247            _ => Err(SwitchboardError::Protocol(ProtocolError::UnknownPayload)),
248        }
249    }
250
251    /// Adopt a shared channel as the outbound flow for a new endpoint.
252    pub async fn adopt_output_channel<Out>(
253        &self,
254        channel: SharedChannel,
255        backpressure: Backpressure,
256        mode: AdoptMode,
257    ) -> Result<EndpointId, SwitchboardError>
258    where
259        Out: FlatMsg + HasSchema + Send + Unpin + 'static,
260    {
261        let directions = EndpointDirections::new(
262            Direction::new(Out::SCHEMA.hash, Cardinality::Zero, Backpressure::Park),
263            Direction::new(Out::SCHEMA.hash, Cardinality::One, backpressure).with_exclusive(true),
264        );
265        self.adopt_endpoint(directions, channel, mode).await
266    }
267
268    /// Adopt an existing shared channel as a new endpoint with the supplied directions.
269    pub async fn adopt_endpoint(
270        &self,
271        directions: EndpointDirections,
272        channel: SharedChannel,
273        mode: AdoptMode,
274    ) -> Result<EndpointId, SwitchboardError> {
275        let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
276        debug!(request_id, "switchboard: adopting endpoint");
277        let response = self
278            .send_request(Message::AdoptRequest {
279                request_id,
280                directions,
281                updates_channel: self.updates_shared.raw(),
282                channel: channel.raw(),
283                mode,
284            })
285            .await?;
286
287        match response {
288            Message::ResponseRegister { endpoint_id, .. } => {
289                let mut guard = self.state()?;
290                guard.ignored.insert(endpoint_id);
291                Ok(endpoint_id)
292            }
293            Message::ResponseError { message, .. } => Err(SwitchboardError::Remote(message)),
294            _ => Err(SwitchboardError::Protocol(ProtocolError::UnknownPayload)),
295        }
296    }
297
298    async fn send_request(&self, message: Message) -> Result<Message, SwitchboardError> {
299        let request_id = request_id_for(&message);
300        debug!(request_id, "switchboard: sending request");
301        let response_channel = Channel::create(INTERNAL_CHANNEL_CAPACITY).await?;
302        let mut response_reader = response_channel.subscribe(CONTROL_CHUNK_SIZE).await?;
303
304        {
305            let mut guard = self.state()?;
306            guard.pending.insert(request_id, response_channel.clone());
307        }
308
309        let mut writer = self.request_channel.publish_weak().await?;
310        debug!(request_id, "switchboard: request channel opened");
311        let bytes = encode_message(&message)?;
312        writer.send(bytes).await?;
313        debug!(request_id, "switchboard: request sent");
314
315        let response = match response_reader.next().await {
316            Some(Ok(frame)) => decode_message(&frame.payload)?,
317            Some(Err(err)) => return Err(SwitchboardError::Driver(err)),
318            None => return Err(SwitchboardError::EndpointClosed),
319        };
320        debug!(request_id, "switchboard: received response");
321
322        {
323            let mut guard = self.state()?;
324            guard.pending.remove(&request_id);
325        }
326
327        response_channel.delete().await?;
328
329        Ok(response)
330    }
331
332    async fn register_endpoint_channel(
333        &self,
334        endpoint_id: EndpointId,
335        channel: Channel,
336    ) -> Result<(), SwitchboardError> {
337        let queued = {
338            let mut guard = self.state()?;
339            guard.ignored.remove(&endpoint_id);
340            guard.endpoints.insert(endpoint_id, channel.clone());
341            guard.queued_updates.remove(&endpoint_id)
342        };
343
344        if let Some(pending) = queued {
345            for payload in pending {
346                forward_bytes(channel.clone(), payload).await?;
347            }
348        }
349
350        Ok(())
351    }
352
353    fn state(&self) -> Result<std::sync::MutexGuard<'_, ClientState>, SwitchboardError> {
354        self.state
355            .lock()
356            .map_err(|_| SwitchboardError::StateUnavailable)
357    }
358}
359
360impl Dependency for Switchboard {
361    type Handle = SharedChannel;
362    type Error = SwitchboardError;
363
364    const DESCRIPTOR: DependencyDescriptor = DependencyDescriptor {
365        name: "selium::Switchboard",
366        id: dependency_id!("selium.switchboard.singleton"),
367    };
368
369    async fn from_handle(handle: Self::Handle) -> Result<Self, Self::Error> {
370        Switchboard::attach(handle).await
371    }
372}
373
374impl<In, Out> EndpointBuilder<In, Out>
375where
376    In: FlatMsg + HasSchema + Send + Unpin + 'static,
377    Out: FlatMsg + HasSchema + Send + Unpin + 'static,
378{
379    /// Set the inbound cardinality constraint. Defaults to `Cardinality::One`.
380    pub fn inputs(mut self, cardinality: Cardinality) -> Self {
381        self.input = cardinality;
382        self
383    }
384
385    /// Set the outbound cardinality constraint. Defaults to `Cardinality::One`.
386    pub fn outputs(mut self, cardinality: Cardinality) -> Self {
387        self.output = cardinality;
388        self
389    }
390
391    /// Set the outbound backpressure behaviour. Defaults to [`Backpressure::Park`].
392    pub fn output_backpressure(mut self, backpressure: Backpressure) -> Self {
393        self.output_backpressure = backpressure;
394        self
395    }
396
397    /// Set whether outbound flows should remain isolated to a single channel.
398    pub fn output_exclusive(mut self, exclusive: bool) -> Self {
399        self.output_exclusive = exclusive;
400        self
401    }
402
403    /// Register the endpoint with the switchboard.
404    pub async fn register(self) -> Result<EndpointHandle<In, Out>, SwitchboardError> {
405        let directions = EndpointDirections::new(
406            Direction::new(In::SCHEMA.hash, self.input, Backpressure::Park),
407            Direction::new(Out::SCHEMA.hash, self.output, self.output_backpressure)
408                .with_exclusive(self.output_exclusive),
409        );
410
411        let internal_channel = Channel::create(INTERNAL_CHANNEL_CAPACITY).await?;
412        let updates = internal_channel.subscribe(CONTROL_CHUNK_SIZE).await?;
413
414        let endpoint_id = self.switchboard.register_endpoint(directions).await?;
415        self.switchboard
416            .register_endpoint_channel(endpoint_id, internal_channel)
417            .await?;
418
419        Ok(EndpointHandle {
420            id: endpoint_id,
421            updates,
422            pending: None,
423            last_inbound: Vec::new(),
424            last_outbound: Vec::new(),
425            backpressure: self.output_backpressure,
426            io: EndpointIo::new(),
427        })
428    }
429}
430
431impl<In, Out> EndpointHandle<In, Out> {
432    /// Return the endpoint identifier assigned by the switchboard.
433    pub fn id(&self) -> EndpointId {
434        self.id
435    }
436
437    /// Backwards-compatible alias for the endpoint identifier.
438    pub fn get_id(&self) -> EndpointId {
439        self.id
440    }
441
442    /// Lookup the outbound channel handle targeting a specific endpoint.
443    pub fn outbound_handle(&self, target: EndpointId) -> Option<ChannelHandle> {
444        self.io.outbound_handle(target)
445    }
446
447    /// Backpressure behaviour for outbound writers when no route exists.
448    pub fn backpressure(&self) -> Backpressure {
449        self.backpressure
450    }
451}
452
453impl<In, Out> EndpointHandle<In, Out>
454where
455    In: FlatMsg + Send + Unpin + 'static,
456    Out: FlatMsg + Send + Unpin + 'static,
457{
458    pub(crate) fn poll_updates(&mut self, cx: &mut Context<'_>) -> Result<(), SwitchboardError> {
459        loop {
460            let mut completed = None;
461            if let Some(pending) = self.pending.as_mut() {
462                match pending.future.as_mut().poll(cx) {
463                    Poll::Ready(Ok(wiring)) => completed = Some(Ok(wiring)),
464                    Poll::Ready(Err(err)) => completed = Some(Err(err)),
465                    Poll::Pending => {}
466                }
467            }
468
469            if let Some(result) = completed {
470                let pending = self.pending.take().expect("pending wiring");
471                match result {
472                    Ok(mut wiring) => {
473                        if pending.reuse_inbound {
474                            wiring.inbound = std::mem::take(&mut self.io.inbound);
475                        }
476                        if pending.reuse_outbound {
477                            wiring.outbound = std::mem::take(&mut self.io.outbound);
478                            wiring.outbound_map = std::mem::take(&mut self.io.outbound_map);
479                        }
480                        self.io.apply_wiring(wiring);
481                        self.last_inbound = pending.next_inbound;
482                        self.last_outbound = pending.next_outbound;
483                    }
484                    Err(err) => return Err(err),
485                }
486            }
487
488            match Pin::new(&mut self.updates).poll_next(cx) {
489                Poll::Ready(Some(Ok(frame))) => {
490                    if frame.payload.is_empty() {
491                        continue;
492                    }
493                    let message = decode_message(&frame.payload)?;
494                    if let Message::WiringUpdate {
495                        inbound, outbound, ..
496                    } = message
497                    {
498                        let reuse_inbound = inbound == self.last_inbound;
499                        let reuse_outbound = outbound == self.last_outbound;
500                        if reuse_inbound && reuse_outbound {
501                            continue;
502                        }
503
504                        let next_inbound = inbound.clone();
505                        let next_outbound = outbound.clone();
506
507                        let future = Box::pin(async move {
508                            let inbound_links = if reuse_inbound {
509                                Vec::new()
510                            } else {
511                                build_inbound(inbound).await?
512                            };
513                            let (outbound_links, outbound_map) = if reuse_outbound {
514                                (Vec::new(), HashMap::new())
515                            } else {
516                                build_outbound(outbound).await?
517                            };
518
519                            Ok(EndpointWiring {
520                                inbound: inbound_links,
521                                outbound: outbound_links,
522                                outbound_map,
523                            })
524                        });
525
526                        self.pending = Some(PendingUpdate {
527                            future,
528                            reuse_inbound,
529                            reuse_outbound,
530                            next_inbound,
531                            next_outbound,
532                        });
533                    }
534                }
535                Poll::Ready(Some(Err(err))) => return Err(SwitchboardError::Driver(err)),
536                Poll::Ready(None) => return Err(SwitchboardError::EndpointClosed),
537                Poll::Pending => break,
538            }
539        }
540
541        Ok(())
542    }
543}
544
545impl<In, Out> EndpointIo<In, Out> {
546    /// Lookup the outbound channel handle targeting a specific endpoint.
547    pub fn outbound_handle(&self, target: EndpointId) -> Option<ChannelHandle> {
548        self.outbound_map.get(&target).cloned()
549    }
550}
551
552impl<In, Out> EndpointIo<In, Out>
553where
554    In: FlatMsg + Send + Unpin + 'static,
555    Out: FlatMsg + Send + Unpin + 'static,
556{
557    fn new() -> Self {
558        Self {
559            inbound: Vec::new(),
560            outbound: Vec::new(),
561            outbound_map: HashMap::new(),
562        }
563    }
564
565    fn apply_wiring(&mut self, wiring: EndpointWiring<In, Out>) {
566        self.inbound = wiring.inbound;
567        self.outbound = wiring.outbound;
568        self.outbound_map = wiring.outbound_map;
569    }
570}
571
572impl<T> RawPublisher<T> {
573    fn new(writer: Writer) -> Self {
574        Self {
575            writer,
576            _marker: PhantomData,
577        }
578    }
579
580    pub(crate) async fn from_channel_handle(
581        handle: ChannelHandle,
582    ) -> Result<Self, SwitchboardError> {
583        let channel = unsafe { Channel::from_raw(handle) };
584        let writer = channel.publish_weak().await?;
585        Ok(Self::new(writer))
586    }
587}
588
589impl<T> RawSubscriber<T> {
590    fn new(reader: Reader) -> Self {
591        Self {
592            reader,
593            _marker: PhantomData,
594        }
595    }
596}
597
598impl<T> futures::Sink<T> for RawPublisher<T>
599where
600    T: FlatMsg + Send + Unpin + 'static,
601{
602    type Error = SwitchboardError;
603
604    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
605        match Pin::new(&mut self.get_mut().writer).poll_ready(cx) {
606            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
607            Poll::Ready(Err(err)) => Poll::Ready(Err(SwitchboardError::Driver(err))),
608            Poll::Pending => Poll::Pending,
609        }
610    }
611
612    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
613        let encoded = FlatMsg::encode(&item);
614        Pin::new(&mut self.get_mut().writer)
615            .start_send(encoded)
616            .map_err(SwitchboardError::Driver)
617    }
618
619    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
620        match Pin::new(&mut self.get_mut().writer).poll_flush(cx) {
621            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
622            Poll::Ready(Err(err)) => Poll::Ready(Err(SwitchboardError::Driver(err))),
623            Poll::Pending => Poll::Pending,
624        }
625    }
626
627    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
628        match Pin::new(&mut self.get_mut().writer).poll_close(cx) {
629            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
630            Poll::Ready(Err(err)) => Poll::Ready(Err(SwitchboardError::Driver(err))),
631            Poll::Pending => Poll::Pending,
632        }
633    }
634}
635
636impl<T> futures::Stream for RawSubscriber<T>
637where
638    T: FlatMsg + Send + Unpin + 'static,
639{
640    type Item = Result<T, SwitchboardError>;
641
642    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
643        match Pin::new(&mut self.get_mut().reader).poll_next(cx) {
644            Poll::Ready(Some(Ok(frame))) => {
645                let decoded =
646                    T::decode(&frame.payload).map_err(|err| SwitchboardError::Protocol(err.into()));
647                Poll::Ready(Some(decoded))
648            }
649            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(SwitchboardError::Driver(err)))),
650            Poll::Ready(None) => Poll::Ready(None),
651            Poll::Pending => Poll::Pending,
652        }
653    }
654}
655
656async fn build_inbound<In>(
657    inbound: Vec<WiringIngress>,
658) -> Result<Vec<InboundLink<In>>, SwitchboardError>
659where
660    In: FlatMsg + Send + Unpin + 'static,
661{
662    let mut inbound_links = Vec::with_capacity(inbound.len());
663    for ingress in inbound {
664        let shared = unsafe { SharedChannel::from_raw(ingress.channel) };
665        let channel = Channel::attach_shared(shared).await?;
666        let reader = channel.subscribe(DATA_CHUNK_SIZE).await?;
667        inbound_links.push(InboundLink {
668            from: ingress.from,
669            subscriber: RawSubscriber::new(reader),
670        });
671    }
672    Ok(inbound_links)
673}
674
675async fn build_outbound<Out>(
676    outbound: Vec<WiringEgress>,
677) -> Result<(Vec<RawPublisher<Out>>, HashMap<EndpointId, ChannelHandle>), SwitchboardError>
678where
679    Out: FlatMsg + Send + Unpin + 'static,
680{
681    let mut outbound_links = Vec::with_capacity(outbound.len());
682    let mut outbound_map = HashMap::with_capacity(outbound.len());
683    for egress in outbound {
684        let shared = unsafe { SharedChannel::from_raw(egress.channel) };
685        let channel = Channel::attach_shared(shared).await?;
686        let writer = channel.publish_weak().await?;
687        outbound_map.insert(egress.to, channel.handle());
688        outbound_links.push(RawPublisher::new(writer));
689    }
690    Ok((outbound_links, outbound_map))
691}
692
693fn request_id_for(message: &Message) -> u64 {
694    match message {
695        Message::RegisterRequest { request_id, .. }
696        | Message::AdoptRequest { request_id, .. }
697        | Message::ConnectRequest { request_id, .. }
698        | Message::ResponseRegister { request_id, .. }
699        | Message::ResponseOk { request_id, .. }
700        | Message::ResponseError { request_id, .. } => *request_id,
701        Message::WiringUpdate { .. } => 0,
702    }
703}
704
705fn spawn_dispatcher(reader: Reader, state: Arc<Mutex<ClientState>>) {
706    #[cfg(test)]
707    {
708        tokio::spawn(async move {
709            drive_updates(reader, state).await;
710        });
711    }
712    #[cfg(not(test))]
713    {
714        selium_userland::spawn(async move {
715            drive_updates(reader, state).await;
716        });
717    }
718}
719
720async fn drive_updates(mut reader: Reader, state: Arc<Mutex<ClientState>>) {
721    while let Some(frame) = reader.next().await {
722        let payload = match frame {
723            Ok(frame) if frame.payload.is_empty() => continue,
724            Ok(frame) => frame.payload,
725            Err(err) => {
726                warn!(?err, "switchboard: update stream failed");
727                break;
728            }
729        };
730
731        let message = match decode_message(&payload) {
732            Ok(message) => message,
733            Err(err) => {
734                warn!(?err, "switchboard: failed to decode update");
735                continue;
736            }
737        };
738
739        match message {
740            Message::ResponseOk { request_id }
741            | Message::ResponseRegister { request_id, .. }
742            | Message::ResponseError { request_id, .. } => {
743                debug!(request_id, "switchboard: dispatching response");
744                let channel = match state.lock() {
745                    Ok(mut guard) => guard.pending.remove(&request_id),
746                    Err(_) => None,
747                };
748                if let Some(channel) = channel
749                    && let Err(err) = forward_bytes(channel, payload).await
750                {
751                    warn!(?err, "switchboard: failed to deliver response");
752                }
753            }
754            Message::WiringUpdate { endpoint_id, .. } => {
755                debug!(endpoint_id, "switchboard: dispatching wiring update");
756                let (channel, queued, ignored) = match state.lock() {
757                    Ok(mut guard) => {
758                        if guard.ignored.contains(&endpoint_id) {
759                            (None, None, true)
760                        } else if let Some(channel) = guard.endpoints.get(&endpoint_id).cloned() {
761                            (Some(channel), None, false)
762                        } else {
763                            let entry = guard.queued_updates.entry(endpoint_id).or_default();
764                            entry.push(payload.clone());
765                            (None, Some(()), false)
766                        }
767                    }
768                    Err(_) => (None, Some(()), false),
769                };
770
771                if ignored {
772                    continue;
773                }
774
775                if let Some(channel) = channel
776                    && let Err(err) = forward_bytes(channel, payload).await
777                {
778                    warn!(?err, "switchboard: failed to deliver update");
779                }
780
781                if queued.is_some() {
782                    continue;
783                }
784            }
785            _ => {}
786        }
787    }
788}
789
790async fn forward_bytes(channel: Channel, payload: Vec<u8>) -> Result<(), DriverError> {
791    let mut writer = channel.publish_weak().await?;
792    writer.send(payload).await?;
793    Ok(())
794}
795
#[cfg(test)]
impl Switchboard {
    /// Create a local switchboard instance for tests.
    ///
    /// Creates a request channel, serves it with an in-process service task, and attaches a
    /// client to it.
    pub async fn new() -> Result<Self, SwitchboardError> {
        let requests = Channel::create(UPDATE_CHANNEL_CAPACITY).await?;
        let shared_requests = requests.share().await?;
        spawn_local_service(shared_requests);
        Self::attach(shared_requests).await
    }
}
806
/// Run the in-process test service on tokio, logging if it terminates with an error.
#[cfg(test)]
fn spawn_local_service(request_channel: SharedChannel) {
    tokio::spawn(async move {
        let Err(err) = run_local_service(request_channel).await else {
            return;
        };
        warn!(?err, "switchboard: local service terminated");
    });
}
815
/// In-process stand-in for the switchboard service, used only by tests.
///
/// Reads control frames from `request_channel`, decodes each into a [`Message`] and hands it
/// to a `LocalService`. That service keeps a `SwitchboardCore`, re-solves the wiring after
/// every registration or connection, reuses or creates data channels to match the solution,
/// and pushes a `Message::WiringUpdate` to every registered endpoint.
///
/// Returns `Ok(())` once the request stream ends.
///
/// # Errors
///
/// Fails if the request channel cannot be attached or subscribed to, if the request stream
/// yields a [`DriverError`], or if handling a request fails (for example when a reply or
/// updates channel cannot be attached or written to). Frames that cannot be decoded are
/// logged and skipped rather than terminating the service.
#[cfg(test)]
async fn run_local_service(request_channel: SharedChannel) -> Result<(), SwitchboardError> {
    use selium_switchboard_core::{
        ChannelKey, SwitchboardCore, SwitchboardError as CoreError, best_compatible_match,
    };

    let request_channel = Channel::attach_shared(request_channel).await?;
    let mut reader = request_channel.subscribe(CONTROL_CHUNK_SIZE).await?;

    /// A data channel owned by the local service.
    struct LocalChannelEntry {
        /// Owning handle; used to drain and delete the channel once it is retired.
        channel: Channel,
        /// Shareable handle whose raw id is sent to endpoints in wiring updates.
        shared: SharedChannel,
        /// Key of the solver channel spec this channel was last matched to.
        key: ChannelKey,
    }

    /// Wiring computed for a single endpoint during one reconcile.
    #[derive(Default)]
    struct EndpointUpdate {
        inbound: Vec<WiringIngress>,
        outbound: Vec<WiringEgress>,
    }

    /// State of the local service: the solver, each endpoint's updates writer, and the
    /// data channels currently backing the solved wiring.
    struct LocalService {
        core: SwitchboardCore,
        endpoints: HashMap<EndpointId, Writer>,
        channels: Vec<LocalChannelEntry>,
    }

    impl LocalService {
        fn new() -> Self {
            Self {
                core: SwitchboardCore::default(),
                endpoints: HashMap::new(),
                channels: Vec::new(),
            }
        }

        /// Dispatches a decoded control message. Only register and connect requests are
        /// acted on; any other message kind is ignored.
        async fn handle_message(&mut self, message: Message) -> Result<(), SwitchboardError> {
            match message {
                Message::RegisterRequest {
                    request_id,
                    directions,
                    updates_channel,
                } => {
                    self.handle_register(request_id, directions, updates_channel)
                        .await
                }
                Message::ConnectRequest {
                    request_id,
                    from,
                    to,
                    reply_channel,
                } => {
                    self.handle_connect(request_id, from, to, reply_channel)
                        .await
                }
                _ => Ok(()),
            }
        }

        /// Registers a new endpoint and replies on its updates channel.
        ///
        /// The endpoint and its updates writer are recorded before reconciling, so it
        /// receives the resulting `WiringUpdate` together with every other endpoint. If
        /// reconciling fails, the endpoint is removed again and a `ResponseError` is sent
        /// instead of `ResponseRegister`.
        async fn handle_register(
            &mut self,
            request_id: u64,
            directions: EndpointDirections,
            updates_channel: u64,
        ) -> Result<(), SwitchboardError> {
            // SAFETY: assumes the requester put a valid raw shared-channel handle in its
            // `RegisterRequest`. NOTE(review): the value is not validated here.
            let updates_channel = unsafe { SharedChannel::from_raw(updates_channel) };
            let updates_writer = Channel::attach_shared(updates_channel)
                .await?
                .publish_weak()
                .await?;

            let endpoint_id = self.core.add_endpoint(directions);
            self.endpoints.insert(endpoint_id, updates_writer);

            if let Err(err) = self.reconcile().await {
                // Roll the registration back so the failed endpoint is not left half-wired.
                self.core.remove_endpoint(endpoint_id);
                self.endpoints.remove(&endpoint_id);
                self.send_error(updates_channel, request_id, err).await?;
                return Ok(());
            }

            let response = Message::ResponseRegister {
                request_id,
                endpoint_id,
            };
            self.send_response(updates_channel, response).await?;
            Ok(())
        }

        /// Records a `from -> to` intent, reconciles, and answers on `reply_channel` with
        /// `ResponseOk` or `ResponseError`. A failed reconcile removes the intent again.
        async fn handle_connect(
            &mut self,
            request_id: u64,
            from: EndpointId,
            to: EndpointId,
            reply_channel: u64,
        ) -> Result<(), SwitchboardError> {
            // SAFETY: assumes the requester put a valid raw shared-channel handle in its
            // `ConnectRequest`. NOTE(review): the value is not validated here.
            let reply_channel = unsafe { SharedChannel::from_raw(reply_channel) };
            if let Err(err) = self.core.add_intent(from, to) {
                self.send_error(reply_channel, request_id, err.into())
                    .await?;
                return Ok(());
            }

            if let Err(err) = self.reconcile().await {
                self.core.remove_intent(from, to);
                self.send_error(reply_channel, request_id, err).await?;
                return Ok(());
            }

            let response = Message::ResponseOk { request_id };
            self.send_response(reply_channel, response).await?;
            Ok(())
        }

        /// Attaches to `channel`, opens a weak publisher on it and writes one encoded message.
        async fn send_response(
            &self,
            channel: SharedChannel,
            message: Message,
        ) -> Result<(), SwitchboardError> {
            let channel = Channel::attach_shared(channel).await?;
            let mut writer = channel.publish_weak().await?;
            let bytes = encode_message(&message)?;
            writer.send(bytes).await?;
            Ok(())
        }

        /// Sends a `ResponseError` carrying `err`'s display text for `request_id`.
        async fn send_error(
            &self,
            channel: SharedChannel,
            request_id: u64,
            err: SwitchboardError,
        ) -> Result<(), SwitchboardError> {
            let response = Message::ResponseError {
                request_id,
                message: err.to_string(),
            };
            self.send_response(channel, response).await
        }

        /// Re-solves the wiring, applies it to the owned channels, and notifies endpoints.
        async fn reconcile(&mut self) -> Result<(), SwitchboardError> {
            let solution = self.core.solve()?;
            let wiring = self.apply_solution(solution).await?;
            self.send_updates(wiring).await?;
            Ok(())
        }

        /// Makes `self.channels` match `solution.channels` and derives per-endpoint wiring.
        ///
        /// Each solved channel reuses an existing channel with an identical key, or else the
        /// best compatible one. Only if neither exists is a new channel created. Channels
        /// left unclaimed are drained and deleted (failures are only logged). If creating a
        /// channel fails, every channel still owned is handed back to `self.channels` before
        /// the error is returned, so nothing is leaked or forgotten.
        async fn apply_solution(
            &mut self,
            solution: selium_switchboard_core::Solution,
        ) -> Result<HashMap<EndpointId, EndpointUpdate>, SwitchboardError> {
            let mut available = std::mem::take(&mut self.channels);
            let mut retained = Vec::with_capacity(solution.channels.len());
            // `resolved[i]` is the channel backing `solution.channels[i]`; flows index into it.
            let mut resolved: Vec<SharedChannel> = Vec::with_capacity(solution.channels.len());

            for spec in &solution.channels {
                let desired_key = spec.key().clone();
                let position = available
                    .iter()
                    .position(|entry| entry.key == desired_key)
                    .or_else(|| {
                        let keys: Vec<ChannelKey> =
                            available.iter().map(|entry| entry.key.clone()).collect();
                        best_compatible_match(&keys, &desired_key)
                    });

                let entry = match position {
                    Some(pos) => {
                        let mut entry = available.swap_remove(pos);
                        entry.key = desired_key;
                        entry
                    }
                    None => match create_local_channel(desired_key).await {
                        Ok(entry) => entry,
                        Err(err) => {
                            // Give back everything we took out of `self.channels`; dropping
                            // these here would skip drain/delete and lose track of them.
                            retained.extend(available);
                            self.channels = retained;
                            return Err(err);
                        }
                    },
                };
                resolved.push(entry.shared);
                retained.push(entry);
            }

            // Channels the new solution did not claim are retired.
            for stale in available {
                if let Err(err) = stale.channel.drain().await {
                    warn!(?err, "switchboard: failed to drain channel");
                }
                if let Err(err) = stale.channel.delete().await {
                    warn!(?err, "switchboard: failed to delete channel");
                }
            }

            self.channels = retained;

            let mut wiring: HashMap<EndpointId, EndpointUpdate> = HashMap::new();
            for route in &solution.routes {
                for flow in &route.flows {
                    let handle = resolved
                        .get(flow.channel)
                        .ok_or_else(|| SwitchboardError::from(CoreError::Unsolveable))?;
                    wiring
                        .entry(flow.producer)
                        .or_default()
                        .outbound
                        .push(WiringEgress {
                            to: flow.consumer,
                            channel: handle.raw(),
                        });
                    wiring
                        .entry(flow.consumer)
                        .or_default()
                        .inbound
                        .push(WiringIngress {
                            from: flow.producer,
                            channel: handle.raw(),
                        });
                }
            }

            // Keep at most one ingress/egress per channel. Entries are sorted by
            // (channel, peer) and deduped on the channel alone, so the lowest peer id wins.
            for entry in wiring.values_mut() {
                entry
                    .inbound
                    .sort_unstable_by_key(|ingress| (ingress.channel, ingress.from));
                entry.inbound.dedup_by_key(|ingress| ingress.channel);
                entry
                    .outbound
                    .sort_unstable_by_key(|egress| (egress.channel, egress.to));
                entry.outbound.dedup_by_key(|egress| egress.channel);
            }

            Ok(wiring)
        }

        /// Sends every registered endpoint its `WiringUpdate`. Endpoints absent from `wiring`
        /// get empty inbound/outbound lists. Per-endpoint send failures are logged, not
        /// propagated; only an encoding failure aborts.
        async fn send_updates(
            &mut self,
            mut wiring: HashMap<EndpointId, EndpointUpdate>,
        ) -> Result<(), SwitchboardError> {
            for (endpoint_id, writer) in &mut self.endpoints {
                let update = wiring.remove(endpoint_id).unwrap_or_default();
                let message = Message::WiringUpdate {
                    endpoint_id: *endpoint_id,
                    inbound: update.inbound,
                    outbound: update.outbound,
                };
                let bytes = encode_message(&message)?;
                if let Err(err) = writer.send(bytes).await {
                    warn!(?err, endpoint_id, "switchboard: failed to send update");
                }
            }
            Ok(())
        }
    }

    /// Creates and shares a fresh data channel for `key`.
    async fn create_local_channel(key: ChannelKey) -> Result<LocalChannelEntry, SwitchboardError> {
        let channel = Channel::create(UPDATE_CHANNEL_CAPACITY).await?;
        let shared = channel.share().await?;
        Ok(LocalChannelEntry {
            channel,
            shared,
            key,
        })
    }

    let mut service = LocalService::new();

    while let Some(frame) = reader.next().await {
        let frame = frame?;
        // Empty frames carry no message.
        if frame.payload.is_empty() {
            continue;
        }
        let message = match decode_message(&frame.payload) {
            Ok(message) => message,
            Err(err) => {
                // A single malformed request must not take the service down for every other
                // endpoint; drop it and keep serving.
                warn!(?err, "switchboard: dropping undecodable control frame");
                continue;
            }
        };
        service.handle_message(message).await?;
    }

    Ok(())
}
1092
/// Converts a core solver error into [`SwitchboardError::Remote`], keeping only its
/// `Display` text.
#[cfg(test)]
impl From<selium_switchboard_core::SwitchboardError> for SwitchboardError {
    fn from(err: selium_switchboard_core::SwitchboardError) -> Self {
        let message = err.to_string();
        Self::Remote(message)
    }
}
1098}