1use std::collections::{BTreeSet, HashMap};
4
5use selium_switchboard_protocol::{
6 Backpressure, Cardinality, EndpointDirections, EndpointId, SchemaId,
7};
8use thiserror::Error;
9
10#[derive(Debug, Error)]
12pub enum SwitchboardError {
13 #[error("invalid endpoint")]
15 InvalidEndpoint,
16 #[error("directions cannot be connected")]
18 DirectionMismatch,
19 #[error("graph cannot be solved")]
21 Unsolveable,
22}
23
24#[derive(Clone, Copy, Debug, PartialEq, Eq)]
26pub struct Intent {
27 pub from: EndpointId,
29 pub to: EndpointId,
31}
32
33#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
35pub struct ChannelKey {
36 schema: SchemaId,
37 backpressure: Backpressure,
38 producers: Vec<EndpointId>,
39 consumers: Vec<EndpointId>,
40}
41
42#[derive(Clone, Debug, PartialEq, Eq)]
44pub struct ChannelSpec {
45 key: ChannelKey,
46}
47
48#[derive(Clone, Debug, PartialEq, Eq)]
50pub struct FlowRoute {
51 pub producer: EndpointId,
53 pub consumer: EndpointId,
55 pub channel: usize,
57}
58
59#[derive(Clone, Debug, PartialEq, Eq)]
61pub struct IntentRoute {
62 pub from: EndpointId,
64 pub to: EndpointId,
66 pub flows: Vec<FlowRoute>,
68}
69
70#[derive(Clone, Debug, PartialEq, Eq)]
72pub struct Solution {
73 pub channels: Vec<ChannelSpec>,
75 pub routes: Vec<IntentRoute>,
77}
78
79pub trait Solver {
81 fn solve(
83 &self,
84 endpoints: &HashMap<EndpointId, EndpointDirections>,
85 intents: &[Intent],
86 ) -> Result<Solution, SwitchboardError>;
87}
88
/// Built-in [`Solver`]: one flow per intent, with consumers that share an identical producer
/// set grouped onto a common channel.
///
/// It is a stateless unit type, so it is cheap to copy and prints as `DefaultSolver`.
#[derive(Clone, Copy, Debug, Default)]
pub struct DefaultSolver;
92
93pub struct SwitchboardCore<S: Solver = DefaultSolver> {
95 endpoints: HashMap<EndpointId, EndpointDirections>,
96 intents: Vec<Intent>,
97 solver: S,
98 next_id: EndpointId,
99}
100
101#[derive(Clone)]
102struct FlowSpec {
103 producer: EndpointId,
104 consumer: EndpointId,
105 schema: SchemaId,
106 backpressure: Backpressure,
107 producer_exclusive: bool,
108}
109
110#[derive(Clone)]
111struct ChannelGroup {
112 schema: SchemaId,
113 backpressure: Backpressure,
114 producers: BTreeSet<EndpointId>,
115 consumers: BTreeSet<EndpointId>,
116 exclusive_producers: Vec<EndpointId>,
117}
118
119type ChannelGroupKey = (SchemaId, Backpressure, Vec<EndpointId>, Vec<EndpointId>);
120type ChannelGroupMap = HashMap<ChannelGroupKey, ChannelGroup>;
121type MergeKey = (BTreeSet<EndpointId>, Vec<EndpointId>);
122type ChannelGroupsBySchema = HashMap<(SchemaId, Backpressure), HashMap<MergeKey, ChannelGroup>>;
123
124impl ChannelKey {
125 pub fn new(
127 schema: SchemaId,
128 backpressure: Backpressure,
129 producers: impl Iterator<Item = EndpointId>,
130 consumers: impl Iterator<Item = EndpointId>,
131 ) -> Self {
132 let mut producers: Vec<EndpointId> = producers.collect();
133 producers.sort();
134 producers.dedup();
135 let mut consumers: Vec<EndpointId> = consumers.collect();
136 consumers.sort();
137 consumers.dedup();
138 Self {
139 schema,
140 backpressure,
141 producers,
142 consumers,
143 }
144 }
145
146 pub fn contains(
148 &self,
149 schema: SchemaId,
150 backpressure: Backpressure,
151 producer: EndpointId,
152 consumer: EndpointId,
153 ) -> bool {
154 self.schema == schema
155 && self.backpressure == backpressure
156 && self.producers.binary_search(&producer).is_ok()
157 && self.consumers.binary_search(&consumer).is_ok()
158 }
159
160 pub fn is_compatible(&self, desired: &ChannelKey) -> bool {
162 if self.schema != desired.schema || self.backpressure != desired.backpressure {
163 return false;
164 }
165 (is_subset(&self.producers, desired.producers())
166 || is_subset(desired.producers(), &self.producers))
167 && (is_subset(&self.consumers, desired.consumers())
168 || is_subset(desired.consumers(), &self.consumers))
169 }
170
171 pub fn schema(&self) -> SchemaId {
173 self.schema
174 }
175
176 pub fn backpressure(&self) -> Backpressure {
178 self.backpressure
179 }
180
181 pub fn producers(&self) -> &[EndpointId] {
183 &self.producers
184 }
185
186 pub fn consumers(&self) -> &[EndpointId] {
188 &self.consumers
189 }
190}
191
192impl ChannelSpec {
193 pub fn new(
195 schema: SchemaId,
196 backpressure: Backpressure,
197 producers: impl Iterator<Item = EndpointId>,
198 consumers: impl Iterator<Item = EndpointId>,
199 ) -> Self {
200 Self {
201 key: ChannelKey::new(schema, backpressure, producers, consumers),
202 }
203 }
204
205 pub fn key(&self) -> &ChannelKey {
207 &self.key
208 }
209}
210
211impl DefaultSolver {
212 fn make_flow(
213 &self,
214 producer_id: EndpointId,
215 producer: &EndpointDirections,
216 consumer_id: EndpointId,
217 consumer: &EndpointDirections,
218 ) -> Result<FlowSpec, SwitchboardError> {
219 let output = producer.output();
220 let input = consumer.input();
221 if output.cardinality() == Cardinality::Zero || input.cardinality() == Cardinality::Zero {
222 return Err(SwitchboardError::DirectionMismatch);
223 }
224 if output.schema_id() != input.schema_id() {
225 return Err(SwitchboardError::DirectionMismatch);
226 }
227 Ok(FlowSpec {
228 producer: producer_id,
229 consumer: consumer_id,
230 schema: output.schema_id(),
231 backpressure: output.backpressure(),
232 producer_exclusive: output.exclusive(),
233 })
234 }
235}
236
237impl<S> SwitchboardCore<S>
238where
239 S: Solver,
240{
241 pub fn new(solver: S) -> Self {
243 Self {
244 endpoints: HashMap::new(),
245 intents: Vec::new(),
246 solver,
247 next_id: 1,
248 }
249 }
250
251 pub fn add_endpoint(&mut self, directions: EndpointDirections) -> EndpointId {
253 let id = self.next_id;
254 self.next_id = self.next_id.saturating_add(1);
255 self.endpoints.insert(id, directions);
256 id
257 }
258
259 pub fn remove_endpoint(&mut self, endpoint_id: EndpointId) {
261 self.endpoints.remove(&endpoint_id);
262 self.intents
263 .retain(|intent| intent.from != endpoint_id && intent.to != endpoint_id);
264 }
265
266 pub fn endpoints(&self) -> &HashMap<EndpointId, EndpointDirections> {
268 &self.endpoints
269 }
270
271 pub fn intents(&self) -> &[Intent] {
273 &self.intents
274 }
275
276 pub fn add_intent(&mut self, from: EndpointId, to: EndpointId) -> Result<(), SwitchboardError> {
278 if !self.endpoints.contains_key(&from) || !self.endpoints.contains_key(&to) {
279 return Err(SwitchboardError::InvalidEndpoint);
280 }
281 if self
282 .intents
283 .iter()
284 .any(|intent| intent.from == from && intent.to == to)
285 {
286 return Ok(());
287 }
288 self.intents.push(Intent { from, to });
289 Ok(())
290 }
291
292 pub fn remove_intent(&mut self, from: EndpointId, to: EndpointId) {
294 self.intents
295 .retain(|intent| intent.from != from || intent.to != to);
296 }
297
298 pub fn solve(&self) -> Result<Solution, SwitchboardError> {
300 self.solver.solve(&self.endpoints, &self.intents)
301 }
302}
303
304impl Default for SwitchboardCore {
305 fn default() -> Self {
306 Self::new(DefaultSolver)
307 }
308}
309
310impl Solver for DefaultSolver {
311 fn solve(
312 &self,
313 endpoints: &HashMap<EndpointId, EndpointDirections>,
314 intents: &[Intent],
315 ) -> Result<Solution, SwitchboardError> {
316 let mut flows: Vec<FlowSpec> = Vec::new();
317 let mut flows_by_intent: Vec<Vec<FlowSpec>> = Vec::with_capacity(intents.len());
318 for intent in intents {
319 let from_directions = endpoints
320 .get(&intent.from)
321 .ok_or(SwitchboardError::InvalidEndpoint)?;
322 let to_directions = endpoints
323 .get(&intent.to)
324 .ok_or(SwitchboardError::InvalidEndpoint)?;
325
326 let flow = self.make_flow(intent.from, from_directions, intent.to, to_directions)?;
327 flows_by_intent.push(vec![flow.clone()]);
328 flows.push(flow);
329 }
330
331 let mut consumer_map: HashMap<
332 (EndpointId, SchemaId, Backpressure, Option<EndpointId>),
333 BTreeSet<EndpointId>,
334 > = HashMap::new();
335 for flow in flows {
336 let exclusive_key = if flow.producer_exclusive {
337 Some(flow.producer)
338 } else {
339 None
340 };
341 consumer_map
342 .entry((flow.consumer, flow.schema, flow.backpressure, exclusive_key))
343 .or_default()
344 .insert(flow.producer);
345 }
346
347 let mut channel_groups: ChannelGroupMap = HashMap::new();
348 for ((consumer, schema, backpressure, exclusive_producer), producers) in
349 consumer_map.into_iter()
350 {
351 if producers.is_empty() {
352 continue;
353 }
354 let producers_vec: Vec<EndpointId> = producers.iter().copied().collect();
355 let mut exclusive_producers = Vec::new();
356 if let Some(exclusive) = exclusive_producer {
357 exclusive_producers.push(exclusive);
358 }
359 let key = (
360 schema,
361 backpressure,
362 producers_vec.clone(),
363 exclusive_producers.clone(),
364 );
365 let group = channel_groups.entry(key).or_insert_with(|| ChannelGroup {
366 schema,
367 backpressure,
368 producers: producers.clone(),
369 consumers: BTreeSet::new(),
370 exclusive_producers,
371 });
372 group.consumers.insert(consumer);
373 }
374
375 let mut merged: Vec<ChannelGroup> = Vec::new();
376 let mut by_schema: ChannelGroupsBySchema = HashMap::new();
377 for group in channel_groups.values() {
378 let schema_entry = by_schema
379 .entry((group.schema, group.backpressure))
380 .or_default();
381 let consumer_key = group.consumers.clone();
382 let merge_key: MergeKey = (consumer_key, group.exclusive_producers.clone());
383 let entry = schema_entry
384 .entry(merge_key)
385 .or_insert_with(|| group.clone());
386 entry.producers.extend(group.producers.iter().copied());
387 }
388 for schema_entry in by_schema.values() {
389 merged.extend(schema_entry.values().cloned());
390 }
391
392 let mut input_counts: HashMap<EndpointId, usize> = HashMap::new();
393 let mut output_counts: HashMap<EndpointId, usize> = HashMap::new();
394 for group in &merged {
395 for consumer in &group.consumers {
396 *input_counts.entry(*consumer).or_insert(0) += 1;
397 }
398 for producer in &group.producers {
399 *output_counts.entry(*producer).or_insert(0) += 1;
400 }
401 }
402
403 for (id, endpoint) in endpoints {
404 let inputs = *input_counts.get(id).unwrap_or(&0);
405 let outputs = *output_counts.get(id).unwrap_or(&0);
406 if !endpoint.input().cardinality().allows(inputs)
407 || !endpoint.output().cardinality().allows(outputs)
408 {
409 return Err(SwitchboardError::Unsolveable);
410 }
411 }
412
413 let mut channel_specs: Vec<ChannelSpec> = merged
414 .into_iter()
415 .map(|group| {
416 ChannelSpec::new(
417 group.schema,
418 group.backpressure,
419 group.producers.iter().copied(),
420 group.consumers.iter().copied(),
421 )
422 })
423 .collect();
424 channel_specs.sort_by(|a, b| a.key().cmp(b.key()));
425
426 let mut routes: Vec<IntentRoute> = Vec::with_capacity(intents.len());
427 for (idx, intent) in intents.iter().enumerate() {
428 let mut flow_routes = Vec::new();
429 for flow in &flows_by_intent[idx] {
430 let channel_index = channel_specs
431 .iter()
432 .position(|spec| {
433 spec.key().contains(
434 flow.schema,
435 flow.backpressure,
436 flow.producer,
437 flow.consumer,
438 )
439 })
440 .ok_or(SwitchboardError::Unsolveable)?;
441 flow_routes.push(FlowRoute {
442 producer: flow.producer,
443 consumer: flow.consumer,
444 channel: channel_index,
445 });
446 }
447 routes.push(IntentRoute {
448 from: intent.from,
449 to: intent.to,
450 flows: flow_routes,
451 });
452 }
453
454 Ok(Solution {
455 channels: channel_specs,
456 routes,
457 })
458 }
459}
460
461pub fn best_compatible_match(available: &[ChannelKey], desired: &ChannelKey) -> Option<usize> {
463 let mut best: Option<(usize, usize, usize)> = None;
464 for (idx, key) in available.iter().enumerate() {
465 if !key.is_compatible(desired) {
466 continue;
467 }
468 let (score, penalty) = compatibility_score(key, desired);
469 if best
470 .as_ref()
471 .map(|(s, p, _)| score > *s || (score == *s && penalty < *p))
472 .unwrap_or(true)
473 {
474 best = Some((score, penalty, idx));
475 }
476 }
477 best.map(|(_, _, idx)| idx)
478}
479
480fn compatibility_score(current: &ChannelKey, desired: &ChannelKey) -> (usize, usize) {
481 let shared_producers = intersection_len(current.producers(), desired.producers());
482 let shared_consumers = intersection_len(current.consumers(), desired.consumers());
483 let score = shared_producers + shared_consumers;
484
485 let missing_producers = desired.producers().len().saturating_sub(shared_producers);
486 let extra_producers = current.producers().len().saturating_sub(shared_producers);
487 let missing_consumers = desired.consumers().len().saturating_sub(shared_consumers);
488 let extra_consumers = current.consumers().len().saturating_sub(shared_consumers);
489 let penalty = missing_producers + extra_producers + missing_consumers + extra_consumers;
490
491 (score, penalty)
492}
493
/// Returns `true` if every element of `a` also appears in `b`.
///
/// Both slices must be sorted ascending (and de-duplicated) for the answer to be meaningful:
/// the walk advances through `b` once and gives up as soon as an element of `a` is smaller
/// than the next unvisited element of `b`.
fn is_subset<T: Ord + Copy>(a: &[T], b: &[T]) -> bool {
    let mut remaining = b.iter();
    'needles: for needle in a {
        for candidate in remaining.by_ref() {
            match needle.cmp(candidate) {
                std::cmp::Ordering::Equal => continue 'needles,
                std::cmp::Ordering::Less => return false,
                std::cmp::Ordering::Greater => {}
            }
        }
        // `b` ran out before `needle` was found.
        return false;
    }
    true
}
509
/// Counts the elements common to `a` and `b`.
///
/// Both slices must be sorted ascending (and de-duplicated) for the count to be exact; the
/// walk is a single merge pass over the two slices.
fn intersection_len<T: Ord + Copy>(a: &[T], b: &[T]) -> usize {
    let mut shared = 0;
    let (mut left, mut right) = (a, b);
    while let (Some((x, left_rest)), Some((y, right_rest))) =
        (left.split_first(), right.split_first())
    {
        match x.cmp(y) {
            std::cmp::Ordering::Less => left = left_rest,
            std::cmp::Ordering::Greater => right = right_rest,
            std::cmp::Ordering::Equal => {
                shared += 1;
                left = left_rest;
                right = right_rest;
            }
        }
    }
    shared
}