1use std::convert::TryFrom;
4
5use flatbuffers::{FlatBufferBuilder, InvalidFlatbuffer};
6use thiserror::Error;
7
// FlatBuffers bindings for the switchboard schema; the rest of this file
// refers to them through the `fb` alias.
// NOTE(review): presumably generated by `flatc`, which would explain why
// lints and rustfmt are switched off for the module — confirm with the build.
#[allow(missing_docs)]
#[allow(warnings)]
#[rustfmt::skip]
pub mod fbs;
13
14use crate::fbs::selium::switchboard as fb;
15
/// Numeric identifier of a switchboard endpoint.
pub type EndpointId = u32;
/// Fixed-width, 16-byte schema identifier.
pub type SchemaId = [u8; 16];
20
/// Limit on a count for one side of an endpoint, checked by
/// [`Cardinality::allows`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Cardinality {
    /// Only a count of zero is allowed.
    Zero,
    /// A count of zero or one is allowed.
    One,
    /// Any count is allowed.
    Many,
}
31
/// Backpressure policy attached to a [`Direction`].
///
/// NOTE(review): the behaviour behind each variant is not defined in this
/// file; the descriptions below are inferred from the names — confirm against
/// the switchboard implementation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Backpressure {
    /// Presumably: the sender waits until capacity is available.
    Park,
    /// Presumably: data is discarded when capacity is exhausted.
    Drop,
}
40
/// Mode carried by [`Message::AdoptRequest`].
///
/// NOTE(review): the semantic difference between the variants is not visible
/// in this file — confirm against the switchboard implementation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AdoptMode {
    /// Encoded on the wire as `fb::AdoptMode::Alias`.
    Alias,
    /// Encoded on the wire as `fb::AdoptMode::Tap`.
    Tap,
}
49
/// Description of one side (input or output) of an endpoint.
///
/// Fields are private; build a value with [`Direction::new`] and read it
/// through the accessor methods.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Direction {
    /// Identifier of the schema associated with this direction.
    schema_id: SchemaId,
    /// Count limit for this direction.
    cardinality: Cardinality,
    /// Backpressure policy for this direction.
    backpressure: Backpressure,
    /// Exclusivity flag; `Direction::new` initialises it to `false`.
    exclusive: bool,
}
58
/// Pair of [`Direction`]s describing an endpoint's input and output sides.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct EndpointDirections {
    /// Direction describing the input side.
    input: Direction,
    /// Direction describing the output side.
    output: Direction,
}
65
/// One inbound entry of a [`Message::WiringUpdate`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct WiringIngress {
    /// Endpoint named as the source of this entry.
    pub from: EndpointId,
    /// Channel identifier for this entry.
    /// NOTE(review): what the identifier refers to is not visible here — confirm.
    pub channel: u64,
}
74
/// One outbound entry of a [`Message::WiringUpdate`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct WiringEgress {
    /// Endpoint named as the destination of this entry.
    pub to: EndpointId,
    /// Channel identifier for this entry.
    /// NOTE(review): what the identifier refers to is not visible here — confirm.
    pub channel: u64,
}
83
/// Decoded form of a switchboard protocol message.
///
/// Values are converted to and from FlatBuffers bytes by [`encode_message`]
/// and [`decode_message`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Message {
    /// Request carrying an endpoint's directions and an updates channel.
    RegisterRequest {
        /// Identifier stored in the envelope's `request_id` field.
        request_id: u64,
        /// Input and output descriptions of the endpoint.
        directions: EndpointDirections,
        /// Channel identifier for updates.
        /// NOTE(review): presumably where wiring updates are delivered — confirm.
        updates_channel: u64,
    },
    /// Request carrying directions, an updates channel, a channel and an
    /// [`AdoptMode`].
    AdoptRequest {
        /// Identifier stored in the envelope's `request_id` field.
        request_id: u64,
        /// Input and output descriptions of the endpoint.
        directions: EndpointDirections,
        /// Channel identifier for updates.
        /// NOTE(review): presumably where wiring updates are delivered — confirm.
        updates_channel: u64,
        /// Channel identifier the adopt request refers to.
        channel: u64,
        /// Adopt mode requested.
        mode: AdoptMode,
    },
    /// Request naming a `from` endpoint, a `to` endpoint and a reply channel.
    ConnectRequest {
        /// Identifier stored in the envelope's `request_id` field.
        request_id: u64,
        /// Endpoint on the `from` side of the connection.
        from: EndpointId,
        /// Endpoint on the `to` side of the connection.
        to: EndpointId,
        /// Channel identifier for the reply.
        reply_channel: u64,
    },
    /// Response carrying an endpoint identifier; encoded as the
    /// `RegisterResponse` payload.
    ResponseRegister {
        /// Identifier stored in the envelope's `request_id` field.
        request_id: u64,
        /// Endpoint identifier carried by the response.
        endpoint_id: EndpointId,
    },
    /// Response without any payload fields; encoded as the `OkResponse` payload.
    ResponseOk {
        /// Identifier stored in the envelope's `request_id` field.
        request_id: u64,
    },
    /// Response carrying an error message; encoded as the `ErrorResponse`
    /// payload.
    ResponseError {
        /// Identifier stored in the envelope's `request_id` field.
        request_id: u64,
        /// Error text; decoded as an empty string when absent from the buffer.
        message: String,
    },
    /// Inbound and outbound wiring entries for one endpoint.
    ///
    /// This variant has no request id: `encode_message` writes `0` into the
    /// envelope and `decode_message` does not read the field back.
    WiringUpdate {
        /// Endpoint the update applies to.
        endpoint_id: EndpointId,
        /// Inbound entries; decoded as empty when absent from the buffer.
        inbound: Vec<WiringIngress>,
        /// Outbound entries; decoded as empty when absent from the buffer.
        outbound: Vec<WiringEgress>,
    },
}
149
/// Errors produced while encoding or decoding switchboard messages.
#[derive(Debug, Error)]
pub enum ProtocolError {
    /// `flatbuffers::root` rejected the buffer; wraps the underlying error.
    #[error("invalid flatbuffer: {0:?}")]
    InvalidFlatbuffer(InvalidFlatbuffer),
    /// A payload table, or a nested directions/direction table, was absent.
    #[error("switchboard message missing payload")]
    MissingPayload,
    /// The payload type is not one that `decode_message` handles.
    #[error("unknown switchboard payload type")]
    UnknownPayload,
    /// A direction did not carry a schema identifier.
    #[error("missing schema identifier")]
    MissingSchemaId,
    /// A schema identifier was present but not exactly 16 bytes long.
    #[error("schema identifier length mismatch")]
    InvalidSchemaId,
    /// The wire cardinality value has no [`Cardinality`] counterpart.
    #[error("unknown cardinality variant")]
    UnknownCardinality,
    /// The wire backpressure value has no [`Backpressure`] counterpart.
    #[error("unknown backpressure variant")]
    UnknownBackpressure,
    /// The wire adopt-mode value has no [`AdoptMode`] counterpart.
    #[error("unknown adopt mode variant")]
    UnknownAdoptMode,
    /// The buffer does not carry the switchboard file identifier.
    #[error("invalid switchboard message identifier")]
    InvalidIdentifier,
}
181
/// File identifier passed to `FlatBufferBuilder::finish` when a message is
/// encoded.
const SWITCHBOARD_IDENTIFIER: &str = "SBSW";
183
184impl Cardinality {
185 pub fn allows(self, count: usize) -> bool {
187 match self {
188 Cardinality::Zero => count == 0,
189 Cardinality::One => count <= 1,
190 Cardinality::Many => true,
191 }
192 }
193}
194
195impl Direction {
196 pub fn new(schema_id: SchemaId, cardinality: Cardinality, backpressure: Backpressure) -> Self {
198 Self {
199 schema_id,
200 cardinality,
201 backpressure,
202 exclusive: false,
203 }
204 }
205
206 pub fn schema_id(&self) -> SchemaId {
208 self.schema_id
209 }
210
211 pub fn cardinality(&self) -> Cardinality {
213 self.cardinality
214 }
215
216 pub fn backpressure(&self) -> Backpressure {
218 self.backpressure
219 }
220
221 pub fn exclusive(&self) -> bool {
223 self.exclusive
224 }
225
226 pub fn with_exclusive(mut self, exclusive: bool) -> Self {
228 self.exclusive = exclusive;
229 self
230 }
231}
232
233impl EndpointDirections {
234 pub fn new(input: Direction, output: Direction) -> Self {
236 Self { input, output }
237 }
238
239 pub fn input(&self) -> &Direction {
241 &self.input
242 }
243
244 pub fn output(&self) -> &Direction {
246 &self.output
247 }
248}
249
250impl TryFrom<fb::Cardinality> for Cardinality {
251 type Error = ProtocolError;
252
253 fn try_from(value: fb::Cardinality) -> Result<Self, Self::Error> {
254 match value {
255 fb::Cardinality::Zero => Ok(Cardinality::Zero),
256 fb::Cardinality::One => Ok(Cardinality::One),
257 fb::Cardinality::Many => Ok(Cardinality::Many),
258 _ => Err(ProtocolError::UnknownCardinality),
259 }
260 }
261}
262
263impl From<Cardinality> for fb::Cardinality {
264 fn from(value: Cardinality) -> Self {
265 match value {
266 Cardinality::Zero => fb::Cardinality::Zero,
267 Cardinality::One => fb::Cardinality::One,
268 Cardinality::Many => fb::Cardinality::Many,
269 }
270 }
271}
272
273impl TryFrom<fb::Backpressure> for Backpressure {
274 type Error = ProtocolError;
275
276 fn try_from(value: fb::Backpressure) -> Result<Self, Self::Error> {
277 match value {
278 fb::Backpressure::Park => Ok(Backpressure::Park),
279 fb::Backpressure::Drop => Ok(Backpressure::Drop),
280 _ => Err(ProtocolError::UnknownBackpressure),
281 }
282 }
283}
284
285impl From<Backpressure> for fb::Backpressure {
286 fn from(value: Backpressure) -> Self {
287 match value {
288 Backpressure::Park => fb::Backpressure::Park,
289 Backpressure::Drop => fb::Backpressure::Drop,
290 }
291 }
292}
293
294impl TryFrom<fb::AdoptMode> for AdoptMode {
295 type Error = ProtocolError;
296
297 fn try_from(value: fb::AdoptMode) -> Result<Self, Self::Error> {
298 match value {
299 fb::AdoptMode::Alias => Ok(AdoptMode::Alias),
300 fb::AdoptMode::Tap => Ok(AdoptMode::Tap),
301 _ => Err(ProtocolError::UnknownAdoptMode),
302 }
303 }
304}
305
306impl From<AdoptMode> for fb::AdoptMode {
307 fn from(value: AdoptMode) -> Self {
308 match value {
309 AdoptMode::Alias => fb::AdoptMode::Alias,
310 AdoptMode::Tap => fb::AdoptMode::Tap,
311 }
312 }
313}
314
315impl From<InvalidFlatbuffer> for ProtocolError {
316 fn from(value: InvalidFlatbuffer) -> Self {
317 ProtocolError::InvalidFlatbuffer(value)
318 }
319}
320
321pub fn encode_message(message: &Message) -> Result<Vec<u8>, ProtocolError> {
323 let mut builder = FlatBufferBuilder::new();
324 let (request_id, payload_type, payload) = match message {
325 Message::RegisterRequest {
326 request_id,
327 directions,
328 updates_channel,
329 } => {
330 let directions = encode_directions(&mut builder, directions);
331 let payload = fb::RegisterRequest::create(
332 &mut builder,
333 &fb::RegisterRequestArgs {
334 directions: Some(directions),
335 updates_channel: *updates_channel,
336 },
337 );
338 (
339 *request_id,
340 fb::SwitchboardPayload::RegisterRequest,
341 Some(payload.as_union_value()),
342 )
343 }
344 Message::AdoptRequest {
345 request_id,
346 directions,
347 updates_channel,
348 channel,
349 mode,
350 } => {
351 let directions = encode_directions(&mut builder, directions);
352 let payload = fb::AdoptRequest::create(
353 &mut builder,
354 &fb::AdoptRequestArgs {
355 directions: Some(directions),
356 updates_channel: *updates_channel,
357 channel: *channel,
358 mode: (*mode).into(),
359 },
360 );
361 (
362 *request_id,
363 fb::SwitchboardPayload::AdoptRequest,
364 Some(payload.as_union_value()),
365 )
366 }
367 Message::ConnectRequest {
368 request_id,
369 from,
370 to,
371 reply_channel,
372 } => {
373 let payload = fb::ConnectRequest::create(
374 &mut builder,
375 &fb::ConnectRequestArgs {
376 from: *from,
377 to: *to,
378 reply_channel: *reply_channel,
379 },
380 );
381 (
382 *request_id,
383 fb::SwitchboardPayload::ConnectRequest,
384 Some(payload.as_union_value()),
385 )
386 }
387 Message::ResponseRegister {
388 request_id,
389 endpoint_id,
390 } => {
391 let payload = fb::RegisterResponse::create(
392 &mut builder,
393 &fb::RegisterResponseArgs {
394 endpoint_id: *endpoint_id,
395 },
396 );
397 (
398 *request_id,
399 fb::SwitchboardPayload::RegisterResponse,
400 Some(payload.as_union_value()),
401 )
402 }
403 Message::ResponseOk { request_id } => {
404 let payload = fb::OkResponse::create(&mut builder, &fb::OkResponseArgs {});
405 (
406 *request_id,
407 fb::SwitchboardPayload::OkResponse,
408 Some(payload.as_union_value()),
409 )
410 }
411 Message::ResponseError {
412 request_id,
413 message,
414 } => {
415 let msg = builder.create_string(message);
416 let payload = fb::ErrorResponse::create(
417 &mut builder,
418 &fb::ErrorResponseArgs { message: Some(msg) },
419 );
420 (
421 *request_id,
422 fb::SwitchboardPayload::ErrorResponse,
423 Some(payload.as_union_value()),
424 )
425 }
426 Message::WiringUpdate {
427 endpoint_id,
428 inbound,
429 outbound,
430 } => {
431 let inbound_vec = encode_ingress(&mut builder, inbound);
432 let outbound_vec = encode_egress(&mut builder, outbound);
433 let payload = fb::WiringUpdate::create(
434 &mut builder,
435 &fb::WiringUpdateArgs {
436 endpoint_id: *endpoint_id,
437 inbound: Some(inbound_vec),
438 outbound: Some(outbound_vec),
439 },
440 );
441 (
442 0,
443 fb::SwitchboardPayload::WiringUpdate,
444 Some(payload.as_union_value()),
445 )
446 }
447 };
448
449 let message = fb::SwitchboardMessage::create(
450 &mut builder,
451 &fb::SwitchboardMessageArgs {
452 request_id,
453 payload_type,
454 payload,
455 },
456 );
457 builder.finish(message, Some(SWITCHBOARD_IDENTIFIER));
458 Ok(builder.finished_data().to_vec())
459}
460
461pub fn decode_message(bytes: &[u8]) -> Result<Message, ProtocolError> {
463 if !fb::switchboard_message_buffer_has_identifier(bytes) {
464 return Err(ProtocolError::InvalidIdentifier);
465 }
466 let message = flatbuffers::root::<fb::SwitchboardMessage>(bytes)?;
467
468 match message.payload_type() {
469 fb::SwitchboardPayload::RegisterRequest => {
470 let req = message
471 .payload_as_register_request()
472 .ok_or(ProtocolError::MissingPayload)?;
473 let directions =
474 decode_directions(req.directions().ok_or(ProtocolError::MissingPayload)?)?;
475 Ok(Message::RegisterRequest {
476 request_id: message.request_id(),
477 directions,
478 updates_channel: req.updates_channel(),
479 })
480 }
481 fb::SwitchboardPayload::AdoptRequest => {
482 let req = message
483 .payload_as_adopt_request()
484 .ok_or(ProtocolError::MissingPayload)?;
485 let directions =
486 decode_directions(req.directions().ok_or(ProtocolError::MissingPayload)?)?;
487 let mode = AdoptMode::try_from(req.mode())?;
488 Ok(Message::AdoptRequest {
489 request_id: message.request_id(),
490 directions,
491 updates_channel: req.updates_channel(),
492 channel: req.channel(),
493 mode,
494 })
495 }
496 fb::SwitchboardPayload::ConnectRequest => {
497 let req = message
498 .payload_as_connect_request()
499 .ok_or(ProtocolError::MissingPayload)?;
500 Ok(Message::ConnectRequest {
501 request_id: message.request_id(),
502 from: req.from(),
503 to: req.to(),
504 reply_channel: req.reply_channel(),
505 })
506 }
507 fb::SwitchboardPayload::RegisterResponse => {
508 let resp = message
509 .payload_as_register_response()
510 .ok_or(ProtocolError::MissingPayload)?;
511 Ok(Message::ResponseRegister {
512 request_id: message.request_id(),
513 endpoint_id: resp.endpoint_id(),
514 })
515 }
516 fb::SwitchboardPayload::OkResponse => Ok(Message::ResponseOk {
517 request_id: message.request_id(),
518 }),
519 fb::SwitchboardPayload::ErrorResponse => {
520 let resp = message
521 .payload_as_error_response()
522 .ok_or(ProtocolError::MissingPayload)?;
523 Ok(Message::ResponseError {
524 request_id: message.request_id(),
525 message: resp.message().unwrap_or_default().to_string(),
526 })
527 }
528 fb::SwitchboardPayload::WiringUpdate => {
529 let update = message
530 .payload_as_wiring_update()
531 .ok_or(ProtocolError::MissingPayload)?;
532 Ok(Message::WiringUpdate {
533 endpoint_id: update.endpoint_id(),
534 inbound: decode_ingress(update.inbound())?,
535 outbound: decode_egress(update.outbound())?,
536 })
537 }
538 _ => Err(ProtocolError::UnknownPayload),
539 }
540}
541
542fn encode_directions<'bldr>(
543 builder: &mut FlatBufferBuilder<'bldr>,
544 directions: &EndpointDirections,
545) -> flatbuffers::WIPOffset<fb::EndpointDirections<'bldr>> {
546 let input = encode_direction(builder, directions.input());
547 let output = encode_direction(builder, directions.output());
548 fb::EndpointDirections::create(
549 builder,
550 &fb::EndpointDirectionsArgs {
551 input: Some(input),
552 output: Some(output),
553 },
554 )
555}
556
557fn encode_direction<'bldr>(
558 builder: &mut FlatBufferBuilder<'bldr>,
559 direction: &Direction,
560) -> flatbuffers::WIPOffset<fb::Direction<'bldr>> {
561 let schema_id = builder.create_vector(&direction.schema_id());
562 fb::Direction::create(
563 builder,
564 &fb::DirectionArgs {
565 schema_id: Some(schema_id),
566 cardinality: direction.cardinality().into(),
567 backpressure: direction.backpressure().into(),
568 exclusive: direction.exclusive(),
569 },
570 )
571}
572
573fn encode_ingress<'bldr>(
574 builder: &mut FlatBufferBuilder<'bldr>,
575 inbound: &[WiringIngress],
576) -> flatbuffers::WIPOffset<
577 flatbuffers::Vector<'bldr, flatbuffers::ForwardsUOffset<fb::WiringIngress<'bldr>>>,
578> {
579 let items: Vec<_> = inbound
580 .iter()
581 .map(|ingress| {
582 fb::WiringIngress::create(
583 builder,
584 &fb::WiringIngressArgs {
585 from: ingress.from,
586 channel: ingress.channel,
587 },
588 )
589 })
590 .collect();
591 builder.create_vector(&items)
592}
593
594fn encode_egress<'bldr>(
595 builder: &mut FlatBufferBuilder<'bldr>,
596 outbound: &[WiringEgress],
597) -> flatbuffers::WIPOffset<
598 flatbuffers::Vector<'bldr, flatbuffers::ForwardsUOffset<fb::WiringEgress<'bldr>>>,
599> {
600 let items: Vec<_> = outbound
601 .iter()
602 .map(|egress| {
603 fb::WiringEgress::create(
604 builder,
605 &fb::WiringEgressArgs {
606 to: egress.to,
607 channel: egress.channel,
608 },
609 )
610 })
611 .collect();
612 builder.create_vector(&items)
613}
614
615fn decode_directions(
616 directions: fb::EndpointDirections<'_>,
617) -> Result<EndpointDirections, ProtocolError> {
618 let input = decode_direction(directions.input().ok_or(ProtocolError::MissingPayload)?)?;
619 let output = decode_direction(directions.output().ok_or(ProtocolError::MissingPayload)?)?;
620 Ok(EndpointDirections::new(input, output))
621}
622
623fn decode_direction(direction: fb::Direction<'_>) -> Result<Direction, ProtocolError> {
624 let schema_id = decode_schema_id(direction.schema_id())?;
625 let cardinality = Cardinality::try_from(direction.cardinality())?;
626 let backpressure = Backpressure::try_from(direction.backpressure())?;
627 let exclusive = direction.exclusive();
628 Ok(Direction::new(schema_id, cardinality, backpressure).with_exclusive(exclusive))
629}
630
631fn decode_schema_id(
632 schema_id: Option<flatbuffers::Vector<'_, u8>>,
633) -> Result<SchemaId, ProtocolError> {
634 let vec = schema_id.ok_or(ProtocolError::MissingSchemaId)?;
635 if vec.len() != 16 {
636 return Err(ProtocolError::InvalidSchemaId);
637 }
638 let mut out = [0u8; 16];
639 for (idx, value) in vec.iter().enumerate() {
640 if idx >= out.len() {
641 break;
642 }
643 out[idx] = value;
644 }
645 Ok(out)
646}
647
648fn decode_ingress(
649 inbound: Option<flatbuffers::Vector<'_, flatbuffers::ForwardsUOffset<fb::WiringIngress<'_>>>>,
650) -> Result<Vec<WiringIngress>, ProtocolError> {
651 let mut items = Vec::new();
652 if let Some(vec) = inbound {
653 for ingress in vec {
654 items.push(WiringIngress {
655 from: ingress.from(),
656 channel: ingress.channel(),
657 });
658 }
659 }
660 Ok(items)
661}
662
663fn decode_egress(
664 outbound: Option<flatbuffers::Vector<'_, flatbuffers::ForwardsUOffset<fb::WiringEgress<'_>>>>,
665) -> Result<Vec<WiringEgress>, ProtocolError> {
666 let mut items = Vec::new();
667 if let Some(vec) = outbound {
668 for egress in vec {
669 items.push(WiringEgress {
670 to: egress.to(),
671 channel: egress.channel(),
672 });
673 }
674 }
675 Ok(items)
676}