Skip to main content

selium_switchboard_core/
lib.rs

1//! Switchboard solver and planning helpers.
2
3use std::collections::{BTreeSet, HashMap};
4
5use selium_switchboard_protocol::{
6    Backpressure, Cardinality, EndpointDirections, EndpointId, SchemaId,
7};
8use thiserror::Error;
9
10/// Errors produced while planning switchboard wiring.
11#[derive(Debug, Error)]
12pub enum SwitchboardError {
13    /// An invalid endpoint identifier was referenced.
14    #[error("invalid endpoint")]
15    InvalidEndpoint,
16    /// The requested directions are incompatible.
17    #[error("directions cannot be connected")]
18    DirectionMismatch,
19    /// The solver could not find a valid wiring.
20    #[error("graph cannot be solved")]
21    Unsolveable,
22}
23
24/// Pair of endpoints the solver should consider when planning flows.
25#[derive(Clone, Copy, Debug, PartialEq, Eq)]
26pub struct Intent {
27    /// Producer endpoint identifier.
28    pub from: EndpointId,
29    /// Consumer endpoint identifier.
30    pub to: EndpointId,
31}
32
33/// Canonical key describing a channel's schema and attachments.
34#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
35pub struct ChannelKey {
36    schema: SchemaId,
37    backpressure: Backpressure,
38    producers: Vec<EndpointId>,
39    consumers: Vec<EndpointId>,
40}
41
42/// A channel that should exist once the solution is applied.
43#[derive(Clone, Debug, PartialEq, Eq)]
44pub struct ChannelSpec {
45    key: ChannelKey,
46}
47
48/// A flow for a single intent mapped to a channel in the solution.
49#[derive(Clone, Debug, PartialEq, Eq)]
50pub struct FlowRoute {
51    /// Producer endpoint for the flow.
52    pub producer: EndpointId,
53    /// Consumer endpoint for the flow.
54    pub consumer: EndpointId,
55    /// Index into the solution channel list.
56    pub channel: usize,
57}
58
59/// The set of flows required for a given intent.
60#[derive(Clone, Debug, PartialEq, Eq)]
61pub struct IntentRoute {
62    /// Producer endpoint in the intent.
63    pub from: EndpointId,
64    /// Consumer endpoint in the intent.
65    pub to: EndpointId,
66    /// Flow list for the intent.
67    pub flows: Vec<FlowRoute>,
68}
69
70/// The solver output, describing the channels and per-intent routing.
71#[derive(Clone, Debug, PartialEq, Eq)]
72pub struct Solution {
73    /// Channels required for the solution.
74    pub channels: Vec<ChannelSpec>,
75    /// Routes for each intent.
76    pub routes: Vec<IntentRoute>,
77}
78
79/// Solver interface for producing wiring plans.
80pub trait Solver {
81    /// Build a wiring solution for the supplied endpoints and intents.
82    fn solve(
83        &self,
84        endpoints: &HashMap<EndpointId, EndpointDirections>,
85        intents: &[Intent],
86    ) -> Result<Solution, SwitchboardError>;
87}
88
89/// Default `Solver` implementation used by `SwitchboardCore`.
90#[derive(Default)]
91pub struct DefaultSolver;
92
93/// In-memory state tracker for endpoints and intents.
94pub struct SwitchboardCore<S: Solver = DefaultSolver> {
95    endpoints: HashMap<EndpointId, EndpointDirections>,
96    intents: Vec<Intent>,
97    solver: S,
98    next_id: EndpointId,
99}
100
101#[derive(Clone)]
102struct FlowSpec {
103    producer: EndpointId,
104    consumer: EndpointId,
105    schema: SchemaId,
106    backpressure: Backpressure,
107    producer_exclusive: bool,
108}
109
110#[derive(Clone)]
111struct ChannelGroup {
112    schema: SchemaId,
113    backpressure: Backpressure,
114    producers: BTreeSet<EndpointId>,
115    consumers: BTreeSet<EndpointId>,
116    exclusive_producers: Vec<EndpointId>,
117}
118
119type ChannelGroupKey = (SchemaId, Backpressure, Vec<EndpointId>, Vec<EndpointId>);
120type ChannelGroupMap = HashMap<ChannelGroupKey, ChannelGroup>;
121type MergeKey = (BTreeSet<EndpointId>, Vec<EndpointId>);
122type ChannelGroupsBySchema = HashMap<(SchemaId, Backpressure), HashMap<MergeKey, ChannelGroup>>;
123
124impl ChannelKey {
125    /// Create a new channel key from producers and consumers.
126    pub fn new(
127        schema: SchemaId,
128        backpressure: Backpressure,
129        producers: impl Iterator<Item = EndpointId>,
130        consumers: impl Iterator<Item = EndpointId>,
131    ) -> Self {
132        let mut producers: Vec<EndpointId> = producers.collect();
133        producers.sort();
134        producers.dedup();
135        let mut consumers: Vec<EndpointId> = consumers.collect();
136        consumers.sort();
137        consumers.dedup();
138        Self {
139            schema,
140            backpressure,
141            producers,
142            consumers,
143        }
144    }
145
146    /// Check whether this key matches the supplied endpoints and schema.
147    pub fn contains(
148        &self,
149        schema: SchemaId,
150        backpressure: Backpressure,
151        producer: EndpointId,
152        consumer: EndpointId,
153    ) -> bool {
154        self.schema == schema
155            && self.backpressure == backpressure
156            && self.producers.binary_search(&producer).is_ok()
157            && self.consumers.binary_search(&consumer).is_ok()
158    }
159
160    /// Check whether the key is compatible with another key.
161    pub fn is_compatible(&self, desired: &ChannelKey) -> bool {
162        if self.schema != desired.schema || self.backpressure != desired.backpressure {
163            return false;
164        }
165        (is_subset(&self.producers, desired.producers())
166            || is_subset(desired.producers(), &self.producers))
167            && (is_subset(&self.consumers, desired.consumers())
168                || is_subset(desired.consumers(), &self.consumers))
169    }
170
171    /// Return the schema identifier for the key.
172    pub fn schema(&self) -> SchemaId {
173        self.schema
174    }
175
176    /// Return the backpressure behaviour for this channel.
177    pub fn backpressure(&self) -> Backpressure {
178        self.backpressure
179    }
180
181    /// Return producers on this channel.
182    pub fn producers(&self) -> &[EndpointId] {
183        &self.producers
184    }
185
186    /// Return consumers on this channel.
187    pub fn consumers(&self) -> &[EndpointId] {
188        &self.consumers
189    }
190}
191
192impl ChannelSpec {
193    /// Create a new channel specification.
194    pub fn new(
195        schema: SchemaId,
196        backpressure: Backpressure,
197        producers: impl Iterator<Item = EndpointId>,
198        consumers: impl Iterator<Item = EndpointId>,
199    ) -> Self {
200        Self {
201            key: ChannelKey::new(schema, backpressure, producers, consumers),
202        }
203    }
204
205    /// Access the channel key.
206    pub fn key(&self) -> &ChannelKey {
207        &self.key
208    }
209}
210
211impl DefaultSolver {
212    fn make_flow(
213        &self,
214        producer_id: EndpointId,
215        producer: &EndpointDirections,
216        consumer_id: EndpointId,
217        consumer: &EndpointDirections,
218    ) -> Result<FlowSpec, SwitchboardError> {
219        let output = producer.output();
220        let input = consumer.input();
221        if output.cardinality() == Cardinality::Zero || input.cardinality() == Cardinality::Zero {
222            return Err(SwitchboardError::DirectionMismatch);
223        }
224        if output.schema_id() != input.schema_id() {
225            return Err(SwitchboardError::DirectionMismatch);
226        }
227        Ok(FlowSpec {
228            producer: producer_id,
229            consumer: consumer_id,
230            schema: output.schema_id(),
231            backpressure: output.backpressure(),
232            producer_exclusive: output.exclusive(),
233        })
234    }
235}
236
237impl<S> SwitchboardCore<S>
238where
239    S: Solver,
240{
241    /// Create a new core with the supplied solver.
242    pub fn new(solver: S) -> Self {
243        Self {
244            endpoints: HashMap::new(),
245            intents: Vec::new(),
246            solver,
247            next_id: 1,
248        }
249    }
250
251    /// Register a new endpoint and return its identifier.
252    pub fn add_endpoint(&mut self, directions: EndpointDirections) -> EndpointId {
253        let id = self.next_id;
254        self.next_id = self.next_id.saturating_add(1);
255        self.endpoints.insert(id, directions);
256        id
257    }
258
259    /// Remove an endpoint and any intents referencing it.
260    pub fn remove_endpoint(&mut self, endpoint_id: EndpointId) {
261        self.endpoints.remove(&endpoint_id);
262        self.intents
263            .retain(|intent| intent.from != endpoint_id && intent.to != endpoint_id);
264    }
265
266    /// Access the registered endpoints.
267    pub fn endpoints(&self) -> &HashMap<EndpointId, EndpointDirections> {
268        &self.endpoints
269    }
270
271    /// Access the registered intents.
272    pub fn intents(&self) -> &[Intent] {
273        &self.intents
274    }
275
276    /// Add a new intent between endpoints.
277    pub fn add_intent(&mut self, from: EndpointId, to: EndpointId) -> Result<(), SwitchboardError> {
278        if !self.endpoints.contains_key(&from) || !self.endpoints.contains_key(&to) {
279            return Err(SwitchboardError::InvalidEndpoint);
280        }
281        if self
282            .intents
283            .iter()
284            .any(|intent| intent.from == from && intent.to == to)
285        {
286            return Ok(());
287        }
288        self.intents.push(Intent { from, to });
289        Ok(())
290    }
291
292    /// Remove a previously registered intent.
293    pub fn remove_intent(&mut self, from: EndpointId, to: EndpointId) {
294        self.intents
295            .retain(|intent| intent.from != from || intent.to != to);
296    }
297
298    /// Solve the current wiring plan.
299    pub fn solve(&self) -> Result<Solution, SwitchboardError> {
300        self.solver.solve(&self.endpoints, &self.intents)
301    }
302}
303
304impl Default for SwitchboardCore {
305    fn default() -> Self {
306        Self::new(DefaultSolver)
307    }
308}
309
310impl Solver for DefaultSolver {
311    fn solve(
312        &self,
313        endpoints: &HashMap<EndpointId, EndpointDirections>,
314        intents: &[Intent],
315    ) -> Result<Solution, SwitchboardError> {
316        let mut flows: Vec<FlowSpec> = Vec::new();
317        let mut flows_by_intent: Vec<Vec<FlowSpec>> = Vec::with_capacity(intents.len());
318        for intent in intents {
319            let from_directions = endpoints
320                .get(&intent.from)
321                .ok_or(SwitchboardError::InvalidEndpoint)?;
322            let to_directions = endpoints
323                .get(&intent.to)
324                .ok_or(SwitchboardError::InvalidEndpoint)?;
325
326            let flow = self.make_flow(intent.from, from_directions, intent.to, to_directions)?;
327            flows_by_intent.push(vec![flow.clone()]);
328            flows.push(flow);
329        }
330
331        let mut consumer_map: HashMap<
332            (EndpointId, SchemaId, Backpressure, Option<EndpointId>),
333            BTreeSet<EndpointId>,
334        > = HashMap::new();
335        for flow in flows {
336            let exclusive_key = if flow.producer_exclusive {
337                Some(flow.producer)
338            } else {
339                None
340            };
341            consumer_map
342                .entry((flow.consumer, flow.schema, flow.backpressure, exclusive_key))
343                .or_default()
344                .insert(flow.producer);
345        }
346
347        let mut channel_groups: ChannelGroupMap = HashMap::new();
348        for ((consumer, schema, backpressure, exclusive_producer), producers) in
349            consumer_map.into_iter()
350        {
351            if producers.is_empty() {
352                continue;
353            }
354            let producers_vec: Vec<EndpointId> = producers.iter().copied().collect();
355            let mut exclusive_producers = Vec::new();
356            if let Some(exclusive) = exclusive_producer {
357                exclusive_producers.push(exclusive);
358            }
359            let key = (
360                schema,
361                backpressure,
362                producers_vec.clone(),
363                exclusive_producers.clone(),
364            );
365            let group = channel_groups.entry(key).or_insert_with(|| ChannelGroup {
366                schema,
367                backpressure,
368                producers: producers.clone(),
369                consumers: BTreeSet::new(),
370                exclusive_producers,
371            });
372            group.consumers.insert(consumer);
373        }
374
375        let mut merged: Vec<ChannelGroup> = Vec::new();
376        let mut by_schema: ChannelGroupsBySchema = HashMap::new();
377        for group in channel_groups.values() {
378            let schema_entry = by_schema
379                .entry((group.schema, group.backpressure))
380                .or_default();
381            let consumer_key = group.consumers.clone();
382            let merge_key: MergeKey = (consumer_key, group.exclusive_producers.clone());
383            let entry = schema_entry
384                .entry(merge_key)
385                .or_insert_with(|| group.clone());
386            entry.producers.extend(group.producers.iter().copied());
387        }
388        for schema_entry in by_schema.values() {
389            merged.extend(schema_entry.values().cloned());
390        }
391
392        let mut input_counts: HashMap<EndpointId, usize> = HashMap::new();
393        let mut output_counts: HashMap<EndpointId, usize> = HashMap::new();
394        for group in &merged {
395            for consumer in &group.consumers {
396                *input_counts.entry(*consumer).or_insert(0) += 1;
397            }
398            for producer in &group.producers {
399                *output_counts.entry(*producer).or_insert(0) += 1;
400            }
401        }
402
403        for (id, endpoint) in endpoints {
404            let inputs = *input_counts.get(id).unwrap_or(&0);
405            let outputs = *output_counts.get(id).unwrap_or(&0);
406            if !endpoint.input().cardinality().allows(inputs)
407                || !endpoint.output().cardinality().allows(outputs)
408            {
409                return Err(SwitchboardError::Unsolveable);
410            }
411        }
412
413        let mut channel_specs: Vec<ChannelSpec> = merged
414            .into_iter()
415            .map(|group| {
416                ChannelSpec::new(
417                    group.schema,
418                    group.backpressure,
419                    group.producers.iter().copied(),
420                    group.consumers.iter().copied(),
421                )
422            })
423            .collect();
424        channel_specs.sort_by(|a, b| a.key().cmp(b.key()));
425
426        let mut routes: Vec<IntentRoute> = Vec::with_capacity(intents.len());
427        for (idx, intent) in intents.iter().enumerate() {
428            let mut flow_routes = Vec::new();
429            for flow in &flows_by_intent[idx] {
430                let channel_index = channel_specs
431                    .iter()
432                    .position(|spec| {
433                        spec.key().contains(
434                            flow.schema,
435                            flow.backpressure,
436                            flow.producer,
437                            flow.consumer,
438                        )
439                    })
440                    .ok_or(SwitchboardError::Unsolveable)?;
441                flow_routes.push(FlowRoute {
442                    producer: flow.producer,
443                    consumer: flow.consumer,
444                    channel: channel_index,
445                });
446            }
447            routes.push(IntentRoute {
448                from: intent.from,
449                to: intent.to,
450                flows: flow_routes,
451            });
452        }
453
454        Ok(Solution {
455            channels: channel_specs,
456            routes,
457        })
458    }
459}
460
461/// Return the position of the best compatible channel, if any.
462pub fn best_compatible_match(available: &[ChannelKey], desired: &ChannelKey) -> Option<usize> {
463    let mut best: Option<(usize, usize, usize)> = None;
464    for (idx, key) in available.iter().enumerate() {
465        if !key.is_compatible(desired) {
466            continue;
467        }
468        let (score, penalty) = compatibility_score(key, desired);
469        if best
470            .as_ref()
471            .map(|(s, p, _)| score > *s || (score == *s && penalty < *p))
472            .unwrap_or(true)
473        {
474            best = Some((score, penalty, idx));
475        }
476    }
477    best.map(|(_, _, idx)| idx)
478}
479
480fn compatibility_score(current: &ChannelKey, desired: &ChannelKey) -> (usize, usize) {
481    let shared_producers = intersection_len(current.producers(), desired.producers());
482    let shared_consumers = intersection_len(current.consumers(), desired.consumers());
483    let score = shared_producers + shared_consumers;
484
485    let missing_producers = desired.producers().len().saturating_sub(shared_producers);
486    let extra_producers = current.producers().len().saturating_sub(shared_producers);
487    let missing_consumers = desired.consumers().len().saturating_sub(shared_consumers);
488    let extra_consumers = current.consumers().len().saturating_sub(shared_consumers);
489    let penalty = missing_producers + extra_producers + missing_consumers + extra_consumers;
490
491    (score, penalty)
492}
493
/// Return `true` if every element of `a` appears in `b`.
///
/// Both slices must be sorted ascending. Elements are matched one-for-one,
/// so a value repeated in `a` must be repeated at least as often in `b`.
fn is_subset<T: Ord + Copy>(a: &[T], b: &[T]) -> bool {
    use std::cmp::Ordering;

    let mut remaining = b.iter();
    a.iter().all(|needle| {
        // Advance through `b` until the needle is found or passed.
        for candidate in remaining.by_ref() {
            match candidate.cmp(needle) {
                Ordering::Less => {}
                Ordering::Equal => return true,
                Ordering::Greater => return false,
            }
        }
        false
    })
}
509
/// Count the elements shared by two ascending-sorted slices.
///
/// Elements are paired one-for-one, so repeated values count only as many
/// times as they appear in both slices.
fn intersection_len<T: Ord + Copy>(a: &[T], b: &[T]) -> usize {
    use std::cmp::Ordering;

    let (mut left, mut right) = (a, b);
    let mut shared = 0;
    while let ([x, left_tail @ ..], [y, right_tail @ ..]) = (left, right) {
        match x.cmp(y) {
            Ordering::Less => left = left_tail,
            Ordering::Greater => right = right_tail,
            Ordering::Equal => {
                shared += 1;
                left = left_tail;
                right = right_tail;
            }
        }
    }
    shared
}