tui_nodes/
connection.rs

1use ratatui::{
2	buffer::Buffer,
3	layout::{Position, Rect},
4	style::Style,
5	symbols::line,
6	widgets::BorderType,
7};
8use std::collections::BTreeMap as Map;
9
/// Maximum number of frontier pops the path search performs for a single
/// connection before it gives up on routing it.
const SEARCH_TIMEOUT: usize = 5_000;
11
/// Glyph family used to draw a connection line.
///
/// The default is [`LineType::Plain`].
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LineType {
	/// Thin lines (ratatui `line::NORMAL`).
	#[default]
	Plain,
	/// Thin lines with rounded corners (ratatui `line::ROUNDED`).
	Rounded,
	/// Double lines (ratatui `line::DOUBLE`).
	Double,
	/// Heavy lines (ratatui `line::THICK`).
	Thick,
}
20
21impl LineType {
22	fn to_line_set(&self) -> line::Set {
23		match self {
24			LineType::Plain => line::NORMAL,
25			LineType::Rounded => line::ROUNDED,
26			LineType::Double => line::DOUBLE,
27			LineType::Thick => line::THICK,
28		}
29	}
30}
31
32impl From<BorderType> for LineType {
33	fn from(value: BorderType) -> Self {
34		match value {
35			BorderType::Plain => LineType::Plain,
36			BorderType::Rounded => LineType::Rounded,
37			BorderType::Double => LineType::Double,
38			BorderType::Thick => LineType::Thick,
39			_ => unimplemented!(),
40		}
41	}
42}
43
/// Side of a grid cell on which an edge lies.
///
/// The two sides whose edges live in the `vertical` edge grid (`North`,
/// `South`) have the two lowest discriminants.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Direction {
	North = 0,
	South = 1,
	East = 2,
	West = 3,
}

impl Direction {
	/// `true` for `North` and `South`.
	fn is_vertical(&self) -> bool {
		matches!(self, Direction::North | Direction::South)
	}
}

impl std::fmt::Debug for Direction {
	/// Print the direction as a single arrow glyph, which keeps debug dumps compact.
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		let arrow = match self {
			Direction::North => "↑",
			Direction::South => "↓",
			Direction::East => "→",
			Direction::West => "←",
		};
		f.write_str(arrow)
	}
}
89
90#[derive(Debug, Clone, Copy)]
91pub struct Connection {
92	pub from_node: usize,
93	pub from_port: usize,
94	pub to_node: usize,
95	pub to_port: usize,
96	line_type: LineType,
97	line_style: Style,
98}
99
100impl Connection {
101	pub fn new(
102		from_node: usize,
103		from_port: usize,
104		to_node: usize,
105		to_port: usize,
106	) -> Self {
107		Self {
108			from_node,
109			from_port,
110			to_node,
111			to_port,
112			line_type: LineType::Rounded,
113			line_style: Style::default(),
114		}
115	}
116
117	pub fn with_line_type(mut self, line_type: LineType) -> Self {
118		self.line_type = line_type;
119		self
120	}
121
122	pub fn line_type(&self) -> LineType {
123		self.line_type
124	}
125
126	pub fn with_line_style(mut self, line_style: Style) -> Self {
127		self.line_style = line_style;
128		self
129	}
130
131	pub fn line_style(&self) -> Style {
132		self.line_style
133	}
134}
135
136/// Generate the correct connection symbol for this node
137pub fn conn_symbol(
138	is_input: bool,
139	block_style: BorderType,
140	conn_style: LineType,
141) -> &'static str {
142	let out = match (block_style, conn_style) {
143		(BorderType::Plain | BorderType::Rounded, LineType::Thick) => ("┥", "┝"),
144		(BorderType::Plain | BorderType::Rounded, LineType::Double) => ("╡", "╞"),
145		(
146			BorderType::Plain | BorderType::Rounded,
147			LineType::Plain | LineType::Rounded,
148		) => ("┤", "├"),
149
150		(BorderType::Thick, LineType::Thick) => ("┫", "┣"),
151		(BorderType::Thick, LineType::Double) => ("╡", "╞"), // fallback
152		(BorderType::Thick, LineType::Plain | LineType::Rounded) => ("┨", "┠"),
153
154		(BorderType::Double, LineType::Thick) => ("╢", "╟"), // fallback
155		(BorderType::Double, LineType::Double) => ("╣", "╠"),
156		(BorderType::Double, LineType::Plain | LineType::Rounded) => ("╢", "╟"),
157		(BorderType::QuadrantInside | BorderType::QuadrantOutside, _) => ("u", "u"),
158	};
159	if is_input {
160		out.0
161	} else {
162		out.1
163	}
164}
165
/// Labels (the lowercase Greek alphabet) for connections that could not be
/// routed. Both ports of such a connection are assigned the same label.
pub const ALIAS_CHARS: [&str; 24] = [
	"α", "β", "γ", "δ", "ε", "ζ", "η", "θ",
	"ι", "κ", "λ", "μ", "ν", "ξ", "ο", "π",
	"ρ", "σ", "τ", "υ", "φ", "χ", "ψ", "ω",
];
170
/// State of one edge between two neighbouring cells of the routing grid.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum Edge {
	/// Free for routing.
	#[default]
	Empty,
	/// Not usable for routing (set for node interiors and around ports).
	Blocked,
	/// Occupied by the connection with this class id.
	Connection(usize),
}
/// Short alias for [`Edge::Empty`], used in the render match table.
const E: Edge = Edge::Empty;
/// Short alias for [`Edge::Blocked`], used in the render match table.
const B: Edge = Edge::Blocked;
180
181#[derive(Debug)]
182pub struct ConnectionsLayout {
183	ports: Map<(bool, usize, usize), (usize, usize)>, // (x,y)
184	connections: Vec<(Connection, usize)>,            // ((from, to), class)
185	edge_field: Betweens<Edge>,
186	width: usize,
187	height: usize,
188	pub alias_connections: Map<(bool, usize, usize), &'static str>,
189	line_types: Map<usize, LineType>,
190	line_styles: Map<usize, Style>,
191}
192
193impl ConnectionsLayout {
194	pub fn new(width: usize, height: usize) -> Self {
195		Self {
196			ports: Map::new(),
197			connections: Vec::new(),
198			edge_field: Betweens::new(width, height),
199			width,
200			height,
201			alias_connections: Map::new(),
202			line_types: Map::new(),
203			line_styles: Map::new(),
204		}
205	}
206
207	pub fn push_connection(&mut self, connection: (Connection, usize)) {
208		self.connections.push(connection)
209	}
210
211	pub fn insert_port(
212		&mut self,
213		is_input: bool,
214		node: usize,
215		port: usize,
216		pos: (usize, usize),
217	) {
218		self.ports.insert((is_input, node, port), pos);
219	}
220
221	pub fn block_zone(&mut self, area: Rect) {
222		for x in 0..area.width {
223			for y in 0..area.height {
224				if x != area.width - 1 {
225					self.edge_field[(
226						((x + area.x) as usize, (y + area.y) as usize),
227						Direction::East,
228					)
229						.into()] = Edge::Blocked;
230				}
231				if y != area.height - 1 {
232					self.edge_field[(
233						((x + area.x) as usize, (y + area.y) as usize),
234						Direction::South,
235					)
236						.into()] = Edge::Blocked;
237				}
238			}
239		}
240	}
241
242	pub fn block_port(&mut self, coord: (usize, usize)) {
243		self.edge_field[(coord, Direction::North).into()] = Edge::Blocked;
244		self.edge_field[(coord, Direction::South).into()] = Edge::Blocked;
245	}
246
247	pub fn calculate(&mut self) {
248		let mut idx_next_alias = 0;
249		'outer: for ea_conn in &self.connections {
250			self.line_types.insert(ea_conn.1, ea_conn.0.line_type());
251			self.line_styles.insert(ea_conn.1, ea_conn.0.line_style());
252			let start = (
253				self.ports[&(false, ea_conn.0.from_node, ea_conn.0.from_port)],
254				Direction::West,
255			);
256			let goal = (
257				self.ports[&(true, ea_conn.0.to_node, ea_conn.0.to_port)],
258				Direction::East,
259			);
260			if start.0 .0 > self.edge_field.width || start.0 .1 > self.edge_field.height {
261				continue;
262			}
263			if goal.0 .0 > self.edge_field.width || goal.0 .1 > self.edge_field.height {
264				continue;
265			}
266			//println!("drawing connection {start:?} to {goal:?}");
267			let mut frontier = sorted_vec::SortedVec::new();
268			let mut came_from = Betweens::<Option<_>>::new(self.width, self.height);
269			let mut cost = Betweens::<isize>::new(self.width, self.height);
270			frontier.push(((0, 0), start));
271			let mut count = 0;
272			while let Some((_, current)) = frontier.pop() {
273				count += 1;
274				if count > SEARCH_TIMEOUT {
275					break;
276				}
277				if current == goal {
278					break;
279				}
280				for ea_nei in neighbors(current.0, self.width, self.height) {
281					let ea_edge = ea_nei.into();
282					let current_cost = cost[current.into()];
283					//println!("{current_cost}");
284					let new_cost = current_cost.saturating_add(
285						self.calc_cost(current, ea_nei, start.0, goal.0, ea_conn.1),
286					);
287					if came_from[ea_edge].is_none() || new_cost < cost[ea_edge] {
288						let prio = (-new_cost, -Self::heuristic(ea_nei.0, goal.0));
289						if new_cost != isize::MAX {
290							frontier.push((prio, ea_nei));
291						}
292						came_from[ea_edge] = Some(current);
293						cost[ea_edge] = new_cost;
294					}
295				}
296				/*
297				print!("\x1b[2J\x1b[1;1H");
298				println!("{frontier:?}");
299				let mut prio = Betweens::new(self.width, self.height);
300				for ea_front in frontier.iter() {
301					prio[ea_front.1.into()] = ea_front.0;
302				}
303				println!("prio\n");
304				prio.print_with(4, |ea| print!("{:>4} ", ea.0));
305				prio.print_with(4, |ea| print!("{:>4} ", ea.1));
306				println!("cost\n");
307				cost.print_with(4, |ea| print!("{:>4} ", ea));
308				println!("from\n");
309				came_from.print_with(1, |ea| {
310					if let Some(inner) = ea {
311						print!("{:?} ", inner.1);
312					}
313					else {
314						print!("_ ");
315					}
316				});
317				std::io::stdin().read_line(&mut String::new()).unwrap();
318				*/
319			}
320			// first pass: mark connections that didnt reach the goal
321			let mut next = goal;
322			loop {
323				if next == start {
324					break;
325				}
326				if let Some(from) = came_from[next.into()] {
327					next = from;
328				} else {
329					// register alias character
330					if !self.alias_connections.contains_key(&(
331						false,
332						ea_conn.0.from_node,
333						ea_conn.0.from_port,
334					)) {
335						self.alias_connections.insert(
336							(false, ea_conn.0.from_node, ea_conn.0.from_port),
337							ALIAS_CHARS[idx_next_alias],
338						);
339						idx_next_alias += 1;
340					}
341					let alias = self.alias_connections
342						[&(false, ea_conn.0.from_node, ea_conn.0.from_port)];
343					self.alias_connections
344						.insert((true, ea_conn.0.to_node, ea_conn.0.to_port), alias);
345					continue 'outer;
346				}
347			}
348
349			// second pass: draw edges
350			let mut next = goal;
351			loop {
352				if next == start {
353					break;
354				}
355				self.edge_field[next.into()] = Edge::Connection(ea_conn.1);
356				next = came_from[next.into()].unwrap();
357			}
358		}
359	}
360
361	pub fn render(&self, area: Rect, buf: &mut Buffer) {
362		let bor = |idx: Edge| -> line::Set {
363			if let Edge::Connection(idx) = idx {
364				self.line_types[&idx].to_line_set()
365			} else if idx == Edge::Blocked {
366				line::THICK
367			} else {
368				line::Set {
369					vertical: " ",
370					horizontal: " ",
371					top_right: " ",
372					top_left: " ",
373					bottom_right: " ",
374					bottom_left: " ",
375					vertical_left: " ",
376					vertical_right: " ",
377					horizontal_down: " ",
378					horizontal_up: " ",
379					cross: " ",
380				}
381			}
382		};
383
384		let get_line_style = |idx: Edge| -> Style {
385			if let Edge::Connection(idx) = idx {
386				self.line_styles[&idx]
387			} else {
388				Style::default()
389			}
390		};
391		for y in 0..self.height {
392			for x in 0..self.width {
393				let pos = (x, y);
394				let north = self.edge_field[(pos, Direction::North).into()];
395				let south = self.edge_field[(pos, Direction::South).into()];
396				let east = self.edge_field[(pos, Direction::East).into()];
397				let west = self.edge_field[(pos, Direction::West).into()];
398				#[rustfmt::skip]
399				let (symbol, line_style) = match (north, south, east, west) {
400					(B | E, B | E, B | E, B | E) => continue,
401					(n, s, e, w) if n == B || s == B || e == B || w == B => {
402						if n == B && s == B && e != E || w != E && e == w {
403							(bor(e).horizontal, get_line_style(e))
404						} else if e == B && w == B && n != E && s != E && n == s {
405							(bor(n).vertical, get_line_style(n))
406						} else {
407							("*", Style::default())
408						}
409					}
410					(n, E, E, E) => (bor(n).vertical, get_line_style(n)),
411					(E, s, E, E) => (bor(s).vertical, get_line_style(s)),
412					(E, E, e, E) => (bor(e).horizontal, get_line_style(e)),
413					(E, E, E, w) => (bor(w).horizontal, get_line_style(w)),
414
415					(n, s, E, w) if n == s && n == w => (bor(n).vertical_left, get_line_style(n)),
416					(n, E, e, w) if n == e && n == w => (bor(n).horizontal_up, get_line_style(n)),
417					(n, s, e, E) if n == s && n == e => (bor(n).vertical_right, get_line_style(n)),
418					(E, s, e, w) if s == e && s == w => (bor(s).horizontal_down, get_line_style(s)),
419					(E, s, E, w) if s == w => (bor(s).top_right, get_line_style(s)),
420					(n, E, E, w) if n == w => (bor(n).bottom_right, get_line_style(n)),
421					(n, E, e, E) if n == e => (bor(n).bottom_left, get_line_style(n)),
422					(E, s, e, E) if s == e => (bor(s).top_left, get_line_style(s)),
423
424					(n, s, E, E) if n == s => (bor(n).vertical, get_line_style(n)),
425					(E, E, e, w) if e == w => (bor(e).horizontal, get_line_style(e)),
426
427					(n, s, e, w) if n == s && n == e && n == w => (bor(n).cross, get_line_style(n)),
428					// intersections should just be verticals
429					(n, s, e, w) if n == s && e == w && n != E && e != E => (bor(n).vertical, get_line_style(n)),
430					(_, _, _, _) => ("?", Style::default()),
431				};
432
433				buf.cell_mut(Position::new(
434					x as u16 + area.left(),
435					y as u16 + area.top(),
436				))
437				.unwrap()
438				.set_symbol(symbol)
439				.set_style(line_style);
440			}
441		}
442	}
443
444	fn heuristic(from: (usize, usize), to: (usize, usize)) -> isize {
445		(from.0 as isize - to.0 as isize).pow(2)
446			+ (from.1 as isize - to.1 as isize).pow(2)
447	}
448
449	fn calc_cost(
450		&self,
451		current: ((usize, usize), Direction),
452		neigh: ((usize, usize), Direction),
453		start: (usize, usize),
454		end: (usize, usize),
455		conn_t: usize,
456	) -> isize {
457		let conn_t = Edge::Connection(conn_t);
458		let north = self.edge_field[(current.0, Direction::North).into()];
459		let south = self.edge_field[(current.0, Direction::South).into()];
460		let east = self.edge_field[(current.0, Direction::East).into()];
461		let west = self.edge_field[(current.0, Direction::West).into()];
462
463		let in_dir = self.edge_field[current.into()];
464		// TODO: fix
465		if !(in_dir == Edge::Empty || in_dir == conn_t) {
466			return isize::MAX;
467		}
468		//	assert!(in_dir == 0 || in_dir == conn_t); // should only calculate cost if its possible
469		let out_dir = self.edge_field[neigh.into()];
470		if out_dir == conn_t {
471			// already exists
472			1
473		} else if out_dir == Edge::Empty {
474			if north == conn_t || south == conn_t || east == conn_t || west == conn_t {
475				// intersecting with an existing connection
476				2 // maybe multiply with distances?
477			} else {
478				let in_is_vert = current.1.is_vertical();
479				let out_is_vert = neigh.1.is_vertical();
480				let straight = in_is_vert == out_is_vert;
481				if straight {
482					if north == Edge::Empty
483						&& south == Edge::Empty && east == Edge::Empty
484						&& west == Edge::Empty
485					{
486						2
487					} else {
488						4
489					}
490				} else {
491					// curved
492					if north != Edge::Empty
493						|| south != Edge::Empty || east != Edge::Empty
494						|| west != Edge::Empty
495					{
496						isize::MAX
497					} else {
498						let ax = current.0 .0 as isize;
499						let ay = current.0 .1 as isize;
500						let sx = start.0 as isize;
501						let sy = start.1 as isize;
502						let ex = end.0 as isize;
503						let ey = end.1 as isize;
504						4 + ((ax - sx).pow(2)
505							+ (ay - sy).pow(2) + (ax - ex).pow(2)
506							+ (ay - ey).pow(2))
507					}
508				}
509			}
510		} else {
511			isize::MAX
512		}
513	}
514}
515
516fn neighbors(
517	pos: (usize, usize),
518	width: usize,
519	height: usize,
520) -> Vec<((usize, usize), Direction)> {
521	let mut out = Vec::new();
522	if pos.0 < width - 1 {
523		out.push(((pos.0 + 1, pos.1), Direction::West));
524	}
525	if pos.1 < height - 1 {
526		out.push(((pos.0, pos.1 + 1), Direction::North));
527	}
528	if pos.0 > 0 {
529		out.push(((pos.0 - 1, pos.1), Direction::East));
530	}
531	if pos.1 > 0 {
532		out.push(((pos.0, pos.1 - 1), Direction::South));
533	}
534	out
535}
536
537use core::ops::{Index, IndexMut};
538
/// Position of one edge inside a `Betweens` grid.
///
/// `is_vertical` selects the grid: `vertical` holds the north/south edges of
/// cells, `horizontal` holds the east/west edges. `(x, y)` indexes into that grid.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct EdgeIdx {
	x: usize,
	y: usize,
	is_vertical: bool,
}
545/*
546impl EdgeIdx {
547	fn pos(self) -> (usize, usize) {
548		(self.0, self.1)
549	}
550}
551*/
552impl From<((usize, usize), Direction)> for EdgeIdx {
553	fn from(value: ((usize, usize), Direction)) -> Self {
554		match value.1 {
555			Direction::North => Self { x: value.0 .0, y: value.0 .1, is_vertical: true },
556			Direction::South => Self {
557				x: value.0 .0,
558				y: value.0 .1 + 1,
559				is_vertical: true,
560			},
561			Direction::East => Self {
562				x: value.0 .0 + 1,
563				y: value.0 .1,
564				is_vertical: false,
565			},
566			Direction::West => Self { x: value.0 .0, y: value.0 .1, is_vertical: false },
567		}
568	}
569}
570
/// Per-edge storage for a `width` × `height` cell grid.
///
/// `horizontal[y][x]` is the west edge of cell `(x, y)`: `height` rows of
/// `width + 1` entries. `vertical[y][x]` is the north edge of cell `(x, y)`:
/// `height + 1` rows of `width` entries. The edges along the grid's outer
/// border are stored as well.
#[derive(Debug)]
struct Betweens<T: Default> {
	horizontal: Vec<Vec<T>>,
	vertical: Vec<Vec<T>>,
	width: usize,
	height: usize,
}
579impl<T: Default> Index<EdgeIdx> for Betweens<T> {
580	type Output = T;
581	fn index(&self, index: EdgeIdx) -> &Self::Output {
582		if index.is_vertical {
583			&self.vertical[index.y][index.x]
584		} else {
585			&self.horizontal[index.y][index.x]
586		}
587	}
588}
589impl<T: Default> IndexMut<EdgeIdx> for Betweens<T> {
590	fn index_mut(&mut self, index: EdgeIdx) -> &mut T {
591		if index.is_vertical {
592			&mut self.vertical[index.y][index.x]
593		} else {
594			&mut self.horizontal[index.y][index.x]
595		}
596	}
597}
598
599impl<T: Default> Betweens<T> {
600	fn new(x: usize, y: usize) -> Self {
601		let mut out = Self {
602			horizontal: Vec::new(),
603			vertical: Vec::new(),
604			width: 0,
605			height: 0,
606		};
607		out.set_size(x, y);
608		out
609	}
610
611	fn set_size(&mut self, x: usize, y: usize) {
612		self.horizontal.resize_with(y, || {
613			let mut inner = Vec::new();
614			inner.resize_with(x + 1, Default::default);
615			inner
616		});
617		self.vertical.resize_with(y + 1, || {
618			let mut inner = Vec::new();
619			inner.resize_with(x, Default::default);
620			inner
621		});
622		self.width = x;
623		self.height = y;
624	}
625
626	#[allow(unused)]
627	fn print_with(&self, width: usize, f: impl Fn(&T)) {
628		for y in 0..(self.height + 1) {
629			for x in 0..self.width {
630				print!("{} ", "-".repeat(width));
631				f(&self.vertical[y][x]);
632			}
633			println!("{}", "-".repeat(width));
634			if y < self.height {
635				for x in 0..(self.width + 1) {
636					f(&self.horizontal[y][x]);
637					if x < self.width {
638						print!("{} ", "-".repeat(width));
639					}
640				}
641			}
642			println!();
643		}
644	}
645}