steiner_tree/
gen_lut.rs

1// SPDX-FileCopyrightText: 2022 Thomas Kramer <code@tkramer.ch>
2//
3// SPDX-License-Identifier: GPL-3.0-or-later
4
5//! Generate the lookup-table.
6
7//use rayon::prelude::*;
8use std::sync;
9use std::time;
10
11use super::lut::*;
12use super::marker_types::Canonical;
13use super::permutations;
14use super::pins::*;
15use super::point::*;
16use super::position_sequence::*;
17use super::rectangle::*;
18use super::tree::*;
19
20use super::compaction_expansion::*;
21use super::hanan_grid::{Boundary, UnitHananGrid};
22use super::wirelength_vector::*;
23
24use super::HananCoord;
25use super::MAX_DEGREE;
26use itertools::Itertools;
27use smallvec::SmallVec;
28use std::cmp::Ordering;
29use std::collections::HashMap;
30
31/// Generate the lookup-table for all nets with number of pins up to `max_degree`.
32pub fn gen_full_lut(max_degree: usize) -> LookupTable {
33    let mut cache = Cache::new();
34
35    let sub_lut_by_degree = (2..=max_degree)
36        .map(|num_pins| gen_sub_lut(num_pins, &mut cache))
37        .collect();
38
39    LookupTable { sub_lut_by_degree }
40}
41
#[test]
fn test_gen_full_lut() {
    // Degrees 2, 3, 4 and 5 yield four sub-tables.
    let max_degree = 5;
    let num_sub_tables = gen_full_lut(max_degree).sub_lut_by_degree.len();
    assert_eq!(num_sub_tables, max_degree - 1);
}
48
49/// Generate the sub lookup-table for all nets with `num_pins` pins.
50fn gen_sub_lut(num_pins: usize, cache: &mut Cache) -> SubLookupTable {
51    let pos_sequences = deduplicated_position_sequences(num_pins);
52
53    assert!(num_pins > 0);
54
55    // let mutex_cache = sync::Arc::new(sync::Mutex::new(cache));
56    let num_entries = pos_sequences.len();
57
58    let lut_entries: Vec<ValuesOrRedirect> = pos_sequences
59        // .into_par_iter()
60        .into_iter()
61        .enumerate()
62        // .map_with(cache, |mut cache, (i, position_sequence)| {
63        .map(|(i, position_sequence)| {
64            // if i % 1000 == 0 || i+1 == num_entries {
65            //     println!("progress: {:.1}%", ((100 * (i + 1)) as f64) / (num_entries as f64));
66            // }
67
68            match position_sequence {
69                PositionSequenceOrRedirect::PositionSequence(position_sequence) => {
70                    // Compute the potential optimal wirelength vectors and trees for this position sequence.
71                    let pins = Pins::from_vec(position_sequence.to_points().collect())
72                        .into_canonical_form();
73                    let grid = UnitHananGrid::new(
74                        pins.bounding_box()
75                            .unwrap_or(Rect::new((0, 0).into(), (0, 0).into())),
76                    );
77                    let compaction_stage = CompactionStage::new(grid, pins);
78
79                    let trees = gen_lut(compaction_stage, cache);
80                    let values = trees
81                        .into_iter()
82                        .map(|e| {
83                            let grid = e.grid();
84                            let mut wv = WirelenghtVector::zero(
85                                grid.upper_right().x as usize,
86                                grid.upper_right().y as usize,
87                            );
88                            e.tree().compute_wirelength_vector(&mut wv);
89                            (wv, e.into_tree())
90                        })
91                        .map(|(wv, tree)| LutValue {
92                            potential_optimal_wirelength_vector: wv,
93                            potential_optimal_steiner_tree: tree.into_canonical_form(),
94                        })
95                        .collect();
96                    ValuesOrRedirect::Values(values)
97                }
98                PositionSequenceOrRedirect::Redirect {
99                    group_index,
100                    transform,
101                } => ValuesOrRedirect::Redirect {
102                    group_index,
103                    transform,
104                },
105            }
106        })
107        .collect();
108
109    SubLookupTable {
110        values: lut_entries,
111    }
112}
113
#[test]
fn test_gen_sub_lut() {
    // One entry per position sequence: 6! entries for six pins.
    let num_pins = 6;
    let sub_lut = gen_sub_lut(num_pins, &mut Cache::new());
    let expected_len: usize = (1..=num_pins).product();
    assert_eq!(sub_lut.values.len(), expected_len);
}
119
120/// Get a list of all position sequences for the given number of pins.
121/// The position sequences are indexed by their 'group index' (which is the rank of the permutation).
122/// The position sequences are de-duplicated.
123/// Position sequences which are equivalent under rotation (by 90 degrees) and mirroring along an axis form an equivalence class.
124/// Only one position sequence is stored per equivalence class. All other members of the class are stored as
125/// a link to the first one together with the necessary geometrical transformation.
126fn deduplicated_position_sequences(num_pins: usize) -> Vec<PositionSequenceOrRedirect> {
127    let num_entries = permutations::factorial(num_pins);
128
129    let mut position_sequences: Vec<_> = (0..num_entries).map(|_| None).collect();
130
131    for positions_sequence in permutations::all_position_sequences(num_pins) {
132        let group_index = positions_sequence.group_index();
133
134        if position_sequences[group_index].is_none() {
135            let points = Pins::from_vec(positions_sequence.to_points().collect());
136
137            // Store the position sequence.
138            assert!(position_sequences[group_index].is_none());
139            position_sequences[group_index] = Some(PositionSequenceOrRedirect::PositionSequence(
140                positions_sequence,
141            ));
142
143            // Create links from equivalent groups (under rotation and mirroring) to this group.
144            {
145                let mut points = points;
146                for mirror in [false, true] {
147                    debug_assert_eq!(group_index, points.position_sequence().group_index());
148                    if mirror {
149                        points = points.mirror_at_y_axis()
150                    }
151                    for rot in 1..5 {
152                        if mirror && rot == 4 {
153                            // Skip last transformation. This is equal to the identity again.
154                            break;
155                        }
156                        points = points.rotate_90ccw();
157
158                        let redirected_group = points.position_sequence().group_index();
159
160                        let transform = Transform::new(rot, mirror).inverse();
161
162                        let redirect = PositionSequenceOrRedirect::Redirect {
163                            group_index,
164                            transform,
165                        };
166
167                        // Store the redirection.
168                        if position_sequences[redirected_group].is_none() {
169                            position_sequences[redirected_group] = Some(redirect)
170                        }
171                    }
172                }
173            }
174        }
175    }
176
177    // Convert to output format.
178    position_sequences
179        .into_iter()
180        .map(|p| p.expect("A group index is not present."))
181        .collect()
182}
183
#[test]
fn test_deduplicated_position_sequences() {
    for num_pins in 0..=6 {
        let expected_len: usize = (1..=num_pins).product();
        let sequences = deduplicated_position_sequences(num_pins);

        // Every group index is present.
        assert_eq!(sequences.len(), expected_len);

        let num_representatives = sequences.iter().filter(|s| !s.is_redirect()).count();
        assert!(num_representatives > 0);

        // The number of deduplicated entries is in the order of n!/4.
        let ratio = expected_len as f64 / num_representatives as f64;
        assert!(ratio <= 4.0);
        assert!(ratio >= 1.0);
    }
}
201
202#[derive(Debug, Clone, Hash, Eq, PartialEq)]
203enum PositionSequenceOrRedirect {
204    PositionSequence(PositionSequence),
205    Redirect {
206        /// Target group address.
207        group_index: usize,
208        /// Transformation that needs to be applied to get from the current group to the redirect.
209        transform: Transform,
210    },
211}
212
213impl PositionSequenceOrRedirect {
214    fn is_redirect(&self) -> bool {
215        match self {
216            PositionSequenceOrRedirect::PositionSequence(_) => false,
217            _ => true,
218        }
219    }
220}
221
222/// Generate the lookup-table entries for the given grid and set of pins.
223///
224/// Corresponds to algorithm 'Gen-LUT(G)' in the paper.
225///
226fn gen_lut(compaction_stage: CompactionStage, cache: &mut Cache) -> Vec<ExpansionStage> {
227    let grid = compaction_stage.grid();
228
229    // Use the cache only for smaller trees. Starting from 6x6 downwards seems to be a good choice.
230    // Caching only smaller trees reduces memory usage.
231    let use_cache = grid.rect().width() <= 6 && grid.rect().height() <= 6;
232
233    let cache_key = if use_cache {
234        let mut normalized_pins = compaction_stage.current_pins().clone();
235        normalized_pins.dedup();
236        Some(normalized_pins.move_to_origin())
237    } else {
238        None
239    };
240
241    // Check cache.
242    if let Some(cache_key) = &cache_key {
243        let cache_result: Option<Vec<ExpansionStage>> = cache.with(cache_key, |r| {
244            r.map(|trees| {
245                // Cache hit.
246                let lower_left = grid.lower_left();
247
248                trees
249                    .into_iter()
250                    .map(|tree| {
251                        let mut tree = tree.clone();
252                        tree.translate((lower_left.x, lower_left.y));
253                        tree
254                    })
255                    .map(|tree| ExpansionStage::new(compaction_stage.clone(), tree))
256                    .collect()
257            })
258        });
259
260        if let Some(mut cache_result) = cache_result {
261            let unpruned_length = cache_result.len();
262            debug_assert_eq!(
263                {
264                    prune_inplace(&mut cache_result);
265                    cache_result.len()
266                },
267                unpruned_length,
268                "cache content must be pruned already"
269            );
270
271            return cache_result;
272        }
273    }
274
275    let trees: Vec<ExpansionStage> =
276        if let Some(simple_trees) = create_simple_trees(grid, compaction_stage.current_pins()) {
277            // Found simple solutions.
278            // Convert them to LutValues.
279
280            simple_trees
281                .into_iter()
282                .map(|tree| ExpansionStage::new(compaction_stage.clone(), tree))
283                .collect()
284        } else {
285            // Compact boundaries and call recursively.
286            // Rule 1.
287            if let Some(boundary) = find_boundary_with_one_pin(&compaction_stage) {
288                // Compact
289                let (compacted, expansion_op) = compaction_stage.compact_boundary(boundary);
290                // Generate smaller trees.
291                let trees = gen_lut(compacted, cache);
292                // Expand
293                trees
294                    .into_iter()
295                    .map(|e| e.expand_boundary(&expansion_op))
296                    .collect()
297            }
298            // Rule 2.
299            else if let Some((b1, b2)) =
300                find_two_adjacent_boundaries_with_shared_pin_in_corner(&compaction_stage)
301            {
302                let (compacted_b1, expansion_op_b1) = compaction_stage.clone().compact_boundary(b1);
303                let (compacted_b2, expansion_op_b2) = compaction_stage.compact_boundary(b2);
304
305                let trees_b1 = gen_lut(compacted_b1, cache);
306                let trees_b2 = gen_lut(compacted_b2, cache);
307
308                let expanded_b1 = trees_b1
309                    .into_iter()
310                    .map(|t| t.expand_boundary(&expansion_op_b1));
311
312                let expanded_b2 = trees_b2
313                    .into_iter()
314                    .map(|t| t.expand_boundary(&expansion_op_b2));
315
316                let union = expanded_b1.chain(expanded_b2).collect();
317
318                prune(union)
319            }
320            // Rule 3
321            else {
322                let num_pins = compaction_stage
323                    .current_pins()
324                    .pin_locations()
325                    .dedup()
326                    .count();
327                let num_pins_on_boundary = compaction_stage
328                    .current_pins()
329                    .pin_locations()
330                    .filter(|p| grid.is_on_boundary(*p))
331                    .dedup()
332                    .count();
333
334                let s = if num_pins == 7 && num_pins_on_boundary == 7 {
335                    // Create trees with near-ring structure (i.e. a ring around the boundary with one edge removed).
336                    create_near_ring_trees(&compaction_stage)
337                        .map(|tree| ExpansionStage::new(compaction_stage.clone(), tree))
338                        .collect()
339                } else if num_pins >= 8 && num_pins_on_boundary >= 7 {
340                    // connect_adj_pins
341                    let d = num_pins - 3;
342                    connect_adj_pins(&compaction_stage, d as HananCoord, cache)
343                } else {
344                    vec![]
345                };
346
347                // Compact+expand on all boundaries.
348                let all_boundary_expansions = [
349                    Boundary::Left,
350                    Boundary::Right,
351                    Boundary::Top,
352                    Boundary::Bottom,
353                ]
354                .iter()
355                .flat_map(|b| {
356                    let (compacted, expansion_op) = compaction_stage.clone().compact_boundary(*b);
357                    let trees = gen_lut(compacted, cache);
358                    // Expand all.
359                    trees
360                        .into_iter()
361                        .map(move |t| t.expand_boundary(&expansion_op))
362                });
363
364                let union = s.into_iter().chain(all_boundary_expansions).collect();
365
366                prune(union)
367            }
368        };
369
370    if let Some(cache_key) = &cache_key {
371        // Store trees to cache.
372
373        let contains_key = cache.with(&cache_key, |r| r.is_some());
374        if !contains_key {
375            cache.insert(cache_key.clone(), vec![]);
376        }
377
378        cache.with_mut(&cache_key, |cache_entry| {
379            let cache_entry = cache_entry.unwrap();
380            for tree in &trees {
381                cache_entry.push(tree.tree().clone().translate_to_origin());
382            }
383        });
384    }
385
386    debug_assert!(
387        trees.iter().all(|t| t.tree().is_tree()),
388        "generated an invalid tree"
389    );
390
391    trees
392}
393
#[test]
fn test_gen_lut_simple() {
    let pins =
        Pins::from_vec(vec![(0, 0).into(), (2, 2).into(), (2, 3).into()]).into_canonical_form();
    let grid = UnitHananGrid::new(pins.bounding_box().unwrap());

    let compaction_stage = CompactionStage::new(grid, pins);

    let trees = gen_lut(compaction_stage, &mut Default::default());

    // At least one candidate must be produced, and every candidate must be a valid tree.
    assert!(!trees.is_empty(), "no trees generated");
    assert!(
        trees.iter().all(|t| t.tree().is_tree()),
        "generated an invalid tree"
    );
}
405
#[test]
fn test_gen_lut_rule_2() {
    // Corner pin (2, 2) shared by two boundaries holding one further pin each.
    let pins = Pins::from_vec(vec![
        (0, 0).into(),
        (2, 2).into(),
        (2, 1).into(),
        (1, 2).into(),
    ])
    .into_canonical_form();
    let grid = UnitHananGrid::new(pins.bounding_box().unwrap());

    let compaction_stage = CompactionStage::new(grid, pins);

    let trees = gen_lut(compaction_stage, &mut Default::default());

    // At least one candidate must be produced, and every candidate must be a valid tree.
    assert!(!trees.is_empty(), "no trees generated");
    assert!(
        trees.iter().all(|t| t.tree().is_tree()),
        "generated an invalid tree"
    );
}
424
425fn prune(mut trees: Vec<ExpansionStage>) -> Vec<ExpansionStage> {
426    prune_inplace(&mut trees);
427    trees
428}
429
430/// Deduplicate and remove all trees which are not on the pareto front regarding their wirelength vectors.
431fn prune_inplace(trees: &mut Vec<ExpansionStage>) {
432    debug_assert!(
433        trees.iter().map(|e| e.num_pins()).all_equal(),
434        "all trees must have the same number of pins"
435    );
436
437    // Compute all wirelength vectors.
438    let mut wirelength_vectors: Vec<_> = trees
439        .iter()
440        .map(|e| {
441            let grid = e.grid();
442            let mut wv = WirelenghtVector::zero(
443                grid.upper_right().x as usize,
444                grid.upper_right().y as usize,
445            );
446            e.tree().compute_wirelength_vector(&mut wv);
447            wv
448        })
449        .collect();
450
451    // Remove non-pareto optimal trees.
452    let mut i = 0;
453    while i < trees.len() {
454        assert_eq!(trees.len(), wirelength_vectors.len());
455
456        let mut j = i + 1;
457        while j < wirelength_vectors.len() {
458            let wv_i = &wirelength_vectors[i];
459            let wv_j = &wirelength_vectors[j];
460            let cmp = wv_j.partial_cmp(wv_i);
461            if cmp == Some(Ordering::Greater) || cmp == Some(Ordering::Equal) {
462                // tree_i makes tree_j redundant or even inferior.
463                wirelength_vectors.swap_remove(j);
464                trees.swap_remove(j);
465                // Don't increment j, because new element has been moved to index j.
466            } else if cmp == Some(Ordering::Less) {
467                // tree_i is dominated by tree_j.
468                trees.swap(i, j);
469                wirelength_vectors.swap(i, j);
470                trees.swap_remove(j);
471                wirelength_vectors.swap_remove(j);
472
473                // Better tree has been moved to i. Need to rewind to i+1 again.
474                j = i + 1;
475            } else {
476                j += 1;
477            }
478        }
479
480        i += 1;
481    }
482}
483
484/// Find a boundary which contains exactly one pin (when pins are deduplicated).
485/// Returns `None` if none is found.
486fn find_boundary_with_one_pin(compaction_stage: &CompactionStage) -> Option<Boundary> {
487    let grid = compaction_stage.grid();
488    [
489        Boundary::Right,
490        Boundary::Top,
491        Boundary::Left,
492        Boundary::Bottom,
493    ]
494    .into_iter()
495    .filter(|&b| {
496        let num_pins_on_boundary = compaction_stage
497            .current_pins()
498            .pin_locations()
499            .filter(|&p| grid.is_on_partial_boundary(p, b))
500            .dedup() // Can deduplicate because pins are sorted.
501            .count();
502        num_pins_on_boundary == 1
503    })
504    .next()
505}
506
507/// Find two adjacent boundaries which contain three pins in total and have one pin in the shared corner.
508fn find_two_adjacent_boundaries_with_shared_pin_in_corner(
509    compaction_stage: &CompactionStage,
510) -> Option<(Boundary, Boundary)> {
511    let grid = compaction_stage.grid();
512
513    let boundaries = [
514        Boundary::Right,
515        Boundary::Top,
516        Boundary::Left,
517        Boundary::Bottom,
518    ];
519    let corners = [
520        grid.upper_right(),
521        grid.upper_left(),
522        grid.lower_left(),
523        grid.lower_right(),
524    ];
525    let mut boundary_pin_count = [0usize; 4];
526    let mut corner_pin_count = [0usize; 4];
527
528    // Iterate over all points and keep track of number of points on boundaries and corners.
529    compaction_stage
530        .current_pins()
531        .pin_locations()
532        .dedup()
533        .for_each(|p| {
534            // Increment all pin counters of boundaries which contain `p`.
535            boundary_pin_count
536                .iter_mut()
537                .zip(&boundaries)
538                .filter(|(_, b)| grid.is_on_partial_boundary(p, **b))
539                .for_each(|(pin_count, _)| {
540                    *pin_count += 1;
541                });
542
543            // If `p` is a corner, increment the corner pin count.
544            if grid.is_corner(p) {
545                // Increment the correct corner counter.
546                corners
547                    .iter()
548                    .zip(corner_pin_count.iter_mut())
549                    .find(|(corner, _)| corner == &&p) // Find the correct corner counter.
550                    .into_iter()
551                    // Increment the counter, if any.
552                    .for_each(|(_, count)| {
553                        *count += 1;
554                    });
555            }
556        });
557
558    let corners_with_one_pin = corner_pin_count
559        .iter()
560        .enumerate()
561        .filter(|(i, count)| **count == 1)
562        .map(|(i, _)| i);
563
564    // Find corner with one pin where adjacent boundaries have two pins (the one in the corner plus another one).
565    corners_with_one_pin
566        .map(|i| (i, (i + 1) % 4)) // Compute indices of adjacent boundaries.
567        .find(|(i1, i2)| boundary_pin_count[*i1] == 2 && boundary_pin_count[*i2] == 2)
568        // Convert indices to boundaries.
569        .map(|(i1, i2)| (boundaries[i1], boundaries[i2]))
570}
571
572fn create_near_ring_trees(compaction_stage: &CompactionStage) -> impl Iterator<Item = Tree> + '_ {
573    let grid = compaction_stage.grid();
574
575    let pins = compaction_stage.current_pins();
576
577    debug_assert!(
578        pins.pin_locations().all(|p| grid.is_on_boundary(p)),
579        "all pins must be on the boundary"
580    );
581
582    // Create near-ring trees.
583    pins.pin_locations().dedup().map(move |start| {
584        let end = grid
585            .all_boundary_points(start)
586            .rev()
587            .skip(1)
588            .filter(|p| pins.contains(*p))
589            .next()
590            .unwrap(); // Unwrap is fine because we know that there's 7 points on the boundary.
591
592        let points_along_boundary = grid.all_boundary_points(start);
593
594        let tree_edges = points_along_boundary
595            .clone()
596            .take_while(|p| p != &end)
597            .zip(points_along_boundary.skip(1));
598
599        let mut tree = Tree::empty();
600
601        for (a, b) in tree_edges {
602            let edge = TreeEdge::from_points(a, b).unwrap();
603            tree.add_edge_unchecked(edge);
604            debug_assert!(tree.is_tree());
605        }
606        tree
607    })
608}
609
#[test]
fn test_create_near_ring_trees() {
    // Three pins, all on the boundary of their bounding box.
    let pins =
        Pins::from_vec(vec![(0, 0).into(), (2, 2).into(), (2, 1).into()]).into_canonical_form();
    let stage = CompactionStage::new(
        UnitHananGrid::new(pins.bounding_box().unwrap()),
        pins.clone(),
    );

    let trees: Vec<Tree> = create_near_ring_trees(&stage).collect();

    // One near-ring tree per pin.
    assert_eq!(trees.len(), 3);

    // Every tree must reach every pin.
    assert!(trees
        .iter()
        .all(|tree| pins.pin_locations().all(|p| tree.contains_node(p))));
}
625
626/// If the grid an pins are simple enough, create all potentially optimal steiner trees for them.
627/// If they are not simple enough, return `None`.
628fn create_simple_trees(grid: UnitHananGrid, pins: &Pins<Canonical>) -> Option<Vec<Tree>> {
629    let r = grid.rect();
630
631    debug_assert!(
632        pins.pin_locations().all(|p| grid.contains_point(p)),
633        "all pins must be contained on the grid"
634    );
635
636    if r.width() == 0 && r.height() == 0 {
637        // Trivial case.
638        Some(vec![Tree::empty()])
639    } else if r.width() == 0 && r.height() > 0 {
640        // Tree is a vertical line.
641        let x = r.lower_left().x;
642        let mut tree = Tree::empty_non_canonical();
643        for y in r.lower_left().y..r.upper_right().y {
644            let p = Point::new(x, y);
645            tree.add_edge_unchecked(TreeEdge::new(p, EdgeDirection::Up));
646        }
647        debug_assert!(tree.is_tree());
648        Some(vec![tree])
649    } else if r.width() > 0 && r.height() == 0 {
650        // Tree is a horizontal line.
651        let y = r.lower_left().y;
652        let mut tree = Tree::empty_non_canonical();
653        for x in r.lower_left().x..r.upper_right().x {
654            let p = Point::new(x, y);
655            tree.add_edge_unchecked(TreeEdge::new(p, EdgeDirection::Right));
656        }
657        debug_assert!(tree.is_tree());
658        Some(vec![tree])
659    } else {
660        None
661    }
662}
663
664fn connect_adj_pins(
665    compaction_stage: &CompactionStage,
666    d: HananCoord,
667    cache: &mut Cache,
668) -> Vec<ExpansionStage> {
669    let grid = compaction_stage.grid();
670
671    // Find pins on a boundary which have no more than `d` distance between eachother.
672    let boundaries = [
673        Boundary::Right,
674        Boundary::Top,
675        Boundary::Left,
676        Boundary::Bottom,
677    ];
678
679    // Container for results.
680    let mut connect_adj_trees = vec![];
681
682    for boundary in boundaries {
683        let pins_on_boundary: SmallVec<[Point<HananCoord>; MAX_DEGREE]> = compaction_stage
684            .current_pins()
685            .pin_locations()
686            .dedup()
687            .filter(|p| grid.is_on_partial_boundary(*p, boundary))
688            .collect();
689
690        // Find `(start_index, end_index)` tuples which mark consecutive lines of pins along the edge with span-length <= d.
691        let d_span_indices = pins_on_boundary
692            .iter()
693            .enumerate()
694            // Start from each point...
695            .filter_map(|(start_idx, start_p)| {
696                // ... and continue for as many points as the distance is <= d.
697                let end_idx = pins_on_boundary[start_idx + 1..]
698                    .iter()
699                    .enumerate()
700                    .take_while(|(end_idx, end)| start_p.manhattan_distance(end) <= d)
701                    .map(|(end_idx, _)| end_idx + start_idx) // Correct offset.
702                    .last();
703
704                // If endpoint was found, output a tuple of the indices.
705                end_idx.map(|end_idx| (start_idx, end_idx))
706            })
707            .filter(|(start_idx, end_idx)| end_idx > start_idx);
708
709        // Counter-clock-wise travel direction on the boundary.
710        // Used to iterate over points on the d-width spans.
711        let boundary_direction = match boundary {
712            Boundary::Right | Boundary::Left => (0, 1),
713            Boundary::Top | Boundary::Bottom => (1, 0),
714        };
715
716        for (start_idx, end_idx) in d_span_indices {
717            // All pins of this d-width line.
718            let pins_in_span = &pins_on_boundary[start_idx..end_idx + 1];
719
720            let start_p = pins_on_boundary[start_idx];
721            let end_p = pins_on_boundary[end_idx];
722
723            // Remove the selected pins on the edge.
724            let pins_with_removed_span = {
725                let mut pins_with_removed_span = compaction_stage.current_pins().clone();
726                pins_with_removed_span.dedup();
727
728                pins_with_removed_span.remove_pins(|p| pins_in_span.contains(p));
729                pins_with_removed_span
730            };
731
732            // Replace the removed line by a point somewhere that line.
733            // Get candidate replacement points.
734            let replacement_points = {
735                let end_p_inclusive = end_p.translate(boundary_direction); // Next point after the end point.
736
737                (0..)
738                    .map(|i| (boundary_direction.0 * i, boundary_direction.1 * i))
739                    .map(|off| start_p.translate(off))
740                    .take_while(move |p| *p != end_p_inclusive)
741            };
742
743            // Create the subtree which covers the removed nodes on the edge with a straight line.
744            // Will then be merged with the sub-steiner-tree.
745            let sub_tree_on_boundary = {
746                let mut sub_tree = Tree::empty_non_canonical();
747
748                let tree_edge_start_points = (0..)
749                    .map(|i| (boundary_direction.0 * i, boundary_direction.1 * i))
750                    .map(|off| start_p.translate(off))
751                    .take_while(|p| *p != end_p);
752
753                let edge_direction = match boundary {
754                    Boundary::Right | Boundary::Left => EdgeDirection::Up,
755                    Boundary::Top | Boundary::Bottom => EdgeDirection::Right,
756                };
757
758                for p in tree_edge_start_points {
759                    sub_tree.add_edge_unchecked(TreeEdge::new(p, edge_direction));
760                }
761                debug_assert!(sub_tree.is_tree());
762
763                sub_tree
764            };
765
766            // Substitute the removed pins by a single pin, compute the corresponding steiner trees (of smaller degree),
767            // add to them the removed pins. And connect them with a straight line.
768            for replacement in replacement_points {
769                let mut simplified_pins = pins_with_removed_span.clone();
770                if !simplified_pins.contains(replacement) {
771                    simplified_pins.add_pin(replacement);
772                }
773
774                assert!(sub_tree_on_boundary.contains_node(replacement));
775
776                let grid = UnitHananGrid::new(simplified_pins.bounding_box().unwrap());
777
778                // Generate steiner-trees of smaller degree.
779                let smaller_trees = gen_lut(CompactionStage::new(grid, simplified_pins), cache);
780
781                // Add the removed pins to the smaller trees. Connect the pins by a straight line.
782                let merged_trees =
783                    smaller_trees
784                        .into_iter()
785                        .map(|e| e.into_tree())
786                        .map(|mut tree| {
787                            // Construct tree edges which cover the removed pins.
788                            tree.merge(&sub_tree_on_boundary)
789                                .expect("trees cannot be merged");
790                            tree
791                        });
792
793                // Append merged trees to the result.
794                let expansion_stages = merged_trees
795                    // Convert trees into expansion stages.
796                    .map(|t| ExpansionStage::new(compaction_stage.clone(), t));
797
798                connect_adj_trees.extend(expansion_stages);
799            }
800        }
801    }
802
803    connect_adj_trees
804}
805
806#[derive(Default)]
807struct Cache {
808    data: HashMap<Pins<Canonical>, Vec<Tree>>,
809}
810
811impl Cache {
812    fn new() -> Self {
813        Self {
814            data: Default::default(),
815        }
816    }
817
818    fn with<F, R>(&self, key: &Pins<Canonical>, f: F) -> R
819    where
820        F: Fn(Option<&Vec<Tree>>) -> R,
821    {
822        f(self.data.get(key))
823    }
824
825    fn with_mut<F, R>(&mut self, key: &Pins<Canonical>, f: F) -> R
826    where
827        F: Fn(Option<&mut Vec<Tree>>) -> R,
828    {
829        f(self.data.get_mut(key))
830    }
831
832    fn insert(&mut self, key: Pins<Canonical>, data: Vec<Tree>) -> Option<Vec<Tree>> {
833        self.data.insert(key, data)
834    }
835}