// street_engine/core/container/path_network.rs

1use std::collections::BTreeMap;
2
3use rstar::RTree;
4
5use crate::core::geometry::{line_segment::LineSegment, site::Site};
6
7use super::{
8    index_object::{NodeTreeObject, PathTreeObject},
9    undirected::UndirectedGraph,
10};
11
12pub trait PathNetworkNodeTrait: Into<Site> + Copy + Eq {}
13impl<T> PathNetworkNodeTrait for T where T: Into<Site> + Copy + Eq {}
14
/// Opaque identifier of a node inside a [`PathNetwork`].
///
/// Ids are totally ordered, so they can key ordered collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct NodeId(usize);

impl NodeId {
    /// Wrap a raw `usize` as a node id.
    pub fn new(id: usize) -> Self {
        NodeId(id)
    }
}

25/// Path network.
26/// This struct is used to manage nodes and paths between nodes in 2D space.
27///
28/// This struct provides:
29///  - functions to add, remove, and search nodes and paths.
30///  - functions to search nodes around a site or a line segment.
31#[derive(Debug, Clone)]
32pub struct PathNetwork<N>
33where
34    N: PathNetworkNodeTrait,
35{
36    nodes: BTreeMap<NodeId, N>,
37    path_tree: RTree<PathTreeObject<NodeId>>,
38    node_tree: RTree<NodeTreeObject<NodeId>>,
39    path_connection: UndirectedGraph<NodeId>,
40    last_node_id: NodeId,
41}
42
43impl<N> Default for PathNetwork<N>
44where
45    N: PathNetworkNodeTrait,
46{
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl<N> PathNetwork<N>
53where
54    N: PathNetworkNodeTrait,
55{
56    /// Create a new path network.
57    pub fn new() -> Self {
58        Self {
59            nodes: BTreeMap::new(),
60            path_tree: RTree::new(),
61            node_tree: RTree::new(),
62            path_connection: UndirectedGraph::new(),
63            last_node_id: NodeId::new(0),
64        }
65    }
66    /// Get nodes in the network.
67    pub fn nodes_iter(&self) -> impl Iterator<Item = (NodeId, &N)> {
68        self.nodes.iter().map(|(node_id, node)| (*node_id, node))
69    }
70
71    /// Get neighbors of a node.
72    pub fn neighbors_iter(&self, node_id: NodeId) -> Option<impl Iterator<Item = (NodeId, &N)>> {
73        self.path_connection
74            .neighbors_iter(node_id)
75            .map(|neighbors| {
76                neighbors.filter_map(move |neighbor| Some((*neighbor, self.nodes.get(neighbor)?)))
77            })
78    }
79
80    /// Add a node to the network.
81    pub(crate) fn add_node(&mut self, node: N) -> NodeId {
82        let node_id = self.last_node_id;
83        self.nodes.insert(node_id, node);
84        self.node_tree
85            .insert(NodeTreeObject::new(node.into(), node_id));
86        self.last_node_id = NodeId::new(node_id.0 + 1);
87        node_id
88    }
89
90    /// Remove a node from the network.
91    /// This function can be never used, but it is kept for future use.
92    #[allow(dead_code)]
93    fn remove_node(&mut self, node_id: NodeId) -> Option<NodeId> {
94        let neighbors = if let Some(neighbors) = self.path_connection.neighbors_iter(node_id) {
95            neighbors.copied().collect::<Vec<_>>()
96        } else {
97            return None;
98        };
99
100        let site = if let Some(node) = self.nodes.get(&node_id) {
101            (*node).into()
102        } else {
103            return None;
104        };
105
106        neighbors.iter().for_each(|neighbor| {
107            self.remove_path(node_id, *neighbor);
108        });
109
110        self.node_tree.remove(&NodeTreeObject::new(site, node_id));
111
112        self.nodes.remove(&node_id);
113        Some(node_id)
114    }
115
116    /// Add a path to the network.
117    pub(crate) fn add_path(&mut self, start: NodeId, end: NodeId) -> Option<(NodeId, NodeId)> {
118        if start == end {
119            return None;
120        }
121        if self.path_connection.has_edge(start, end) {
122            return None;
123        }
124
125        let (start_site, end_site) = if let (Some(start_node), Some(end_node)) =
126            (self.nodes.get(&start), self.nodes.get(&end))
127        {
128            (*start_node, *end_node)
129        } else {
130            return None;
131        };
132
133        self.path_connection.add_edge(start, end);
134
135        let (start_site, end_site) = (start_site.into(), end_site.into());
136
137        self.path_tree.insert(PathTreeObject::new(
138            LineSegment::new(start_site, end_site),
139            (start, end),
140        ));
141
142        Some((start, end))
143    }
144
145    /// Remove a path from the network.
146    pub(crate) fn remove_path(&mut self, start: NodeId, end: NodeId) -> Option<(NodeId, NodeId)> {
147        let (start_site, end_site) = if let (Some(start_node), Some(end_node)) =
148            (self.nodes.get(&start), self.nodes.get(&end))
149        {
150            (*start_node, *end_node)
151        } else {
152            return None;
153        };
154
155        self.path_connection.remove_edge(start, end);
156
157        self.path_tree.remove(&PathTreeObject::new(
158            LineSegment::new(start_site.into(), end_site.into()),
159            (start, end),
160        ));
161
162        Some((start, end))
163    }
164
165    /// Get a node by its NodeId.
166    pub fn get_node(&self, node_id: NodeId) -> Option<&N> {
167        self.nodes.get(&node_id)
168    }
169
170    /// Check if there is a path between two nodes.
171    pub fn has_path(&self, start: NodeId, to: NodeId) -> bool {
172        self.path_connection.has_edge(start, to)
173    }
174
175    /// Search nodes around a site within a radius.
176    pub fn nodes_around_site_iter(&self, site: Site, radius: f64) -> impl Iterator<Item = &NodeId> {
177        self.nodes.iter().filter_map(move |(node_id, &node)| {
178            if site.distance(&node.into()) <= radius {
179                Some(node_id)
180            } else {
181                None
182            }
183        })
184    }
185
186    /// Search nodes around a line segment within a radius.
187    pub fn nodes_around_line_iter(
188        &self,
189        line: LineSegment,
190        radius: f64,
191    ) -> impl Iterator<Item = &NodeId> {
192        let envelope = rstar::AABB::from_corners(
193            [
194                line.0.x.min(line.1.x) - radius,
195                line.0.y.min(line.1.y) - radius,
196            ],
197            [
198                line.0.x.max(line.1.x) + radius,
199                line.0.y.max(line.1.y) + radius,
200            ],
201        );
202        self.node_tree
203            .locate_in_envelope(&envelope)
204            .filter(move |object| line.get_distance(object.site()) <= radius)
205            .map(|object| object.node_id())
206    }
207
208    /// Search paths touching a rectangle.
209    pub fn paths_touching_rect_iter(
210        &self,
211        corner_0: Site,
212        corner_1: Site,
213    ) -> impl Iterator<Item = &(NodeId, NodeId)> {
214        let search_rect =
215            rstar::AABB::from_corners([corner_0.x, corner_0.y], [corner_1.x, corner_1.y]);
216
217        self.path_tree
218            .locate_in_envelope_intersecting(&search_rect)
219            .map(|object| object.node_ids())
220    }
221
222    /// Get the optimized path network.
223    pub fn into_optimized(self) -> Self {
224        // TODO: optimize the path network
225        self
226    }
227
228    /// This function is only for testing
229    #[allow(dead_code)]
230    fn check_path_state_is_consistent(&self) -> bool {
231        self.path_tree.size() == self.path_connection.size()
232            && self.nodes.len() == self.node_tree.size()
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_path_network() {
        let mut network = PathNetwork::new();
        let node0 = network.add_node(Site::new(0.0, 0.0));
        let node1 = network.add_node(Site::new(1.0, 1.0));
        let node2 = network.add_node(Site::new(2.0, 2.0));
        let node3 = network.add_node(Site::new(3.0, 3.0));
        let node4 = network.add_node(Site::new(1.0, 4.0));

        assert_eq!(network.add_path(node0, node1), Some((node0, node1)));
        network.add_path(node1, node2);
        network.add_path(node2, node3);
        network.add_path(node3, node4);
        network.add_path(node4, node2);

        // Self-loops and duplicate paths (in either direction) are rejected.
        assert_eq!(network.add_path(node0, node0), None);
        assert_eq!(network.add_path(node1, node0), None);

        assert!(network.has_path(node0, node1));
        assert!(network.has_path(node1, node2));
        assert!(network.has_path(node2, node3));
        assert!(network.has_path(node3, node4));
        assert!(!network.has_path(node0, node2));

        assert!(network.check_path_state_is_consistent());

        assert_eq!(network.remove_path(node1, node2), Some((node1, node2)));
        assert!(!network.has_path(node1, node2));
        assert!(network.has_path(node2, node3));

        assert!(network.check_path_state_is_consistent());

        assert_eq!(network.remove_node(node1), Some(node1));
        assert!(network.get_node(node1).is_none());
        assert!(!network.has_path(node0, node1));

        assert!(network.check_path_state_is_consistent());
    }

    #[test]
    fn test_path_crossing_no_crosses() {
        let mut network = PathNetwork::new();
        let node0 = network.add_node(Site::new(0.0, 1.0));
        let node1 = network.add_node(Site::new(2.0, 3.0));
        let node2 = network.add_node(Site::new(4.0, 5.0));

        network.add_path(node0, node1);
        network.add_path(node1, node2);

        let paths = network
            .paths_touching_rect_iter(Site::new(0.0, 0.0), Site::new(1.0, 1.0))
            .collect::<Vec<_>>();
        assert_eq!(paths.len(), 1);

        assert!(network.check_path_state_is_consistent());
    }

    #[test]
    fn test_path_crossing_all_cross() {
        let mut network = PathNetwork::new();

        let sites = vec![
            Site::new(0.0, 2.0),
            Site::new(2.0, 2.0),
            Site::new(2.0, 0.0),
            Site::new(0.0, 0.0),
        ];

        let nodes = sites
            .iter()
            .map(|site| network.add_node(*site))
            .collect::<Vec<_>>();

        for i in 0..nodes.len() {
            // Add all paths between sites
            // When i == j, the path is expected to be ignored
            for j in i..nodes.len() {
                network.add_path(nodes[i], nodes[j]);
            }
        }

        // Use the ids returned by `add_node` rather than assuming they equal
        // insertion indices.
        for i in 0..nodes.len() {
            for j in 0..nodes.len() {
                if i != j {
                    assert!(network.has_path(nodes[i], nodes[j]));
                }
            }
        }

        let paths = network
            .paths_touching_rect_iter(Site::new(0.0, 0.0), Site::new(1.0, 2.0))
            .collect::<Vec<_>>();
        assert_eq!(paths.len(), 5);

        assert!(network.check_path_state_is_consistent());
    }

    #[test]
    fn test_nodes_around_site() {
        let mut network = PathNetwork::new();
        let node0 = network.add_node(Site::new(0.0, 0.0));
        let node1 = network.add_node(Site::new(1.0, 1.0));
        let node2 = network.add_node(Site::new(2.0, 2.0));
        let node3 = network.add_node(Site::new(3.0, 3.0));
        let node4 = network.add_node(Site::new(1.0, 4.0));

        network.add_path(node0, node1);
        network.add_path(node1, node2);
        network.add_path(node2, node3);
        network.add_path(node3, node4);
        network.add_path(node4, node2);

        let site = Site::new(1.0, 1.0);
        let nodes = network
            .nodes_around_site_iter(site, 1.0)
            .collect::<Vec<_>>();
        assert_eq!(nodes.len(), 1);

        let site = Site::new(2.0, 1.0);
        let nodes = network
            .nodes_around_site_iter(site, 2.0)
            .collect::<Vec<_>>();
        assert_eq!(nodes.len(), 2);

        let site = Site::new(2.0, 3.0);
        let nodes = network
            .nodes_around_site_iter(site, 2.0)
            .collect::<Vec<_>>();
        assert_eq!(nodes.len(), 3);

        let line = LineSegment::new(Site::new(1.0, 3.0), Site::new(3.0, 2.0));
        let nodes = network
            .nodes_around_line_iter(line, 1.0)
            .collect::<Vec<_>>();
        assert_eq!(nodes.len(), 3);

        let line = LineSegment::new(Site::new(1.0, 0.0), Site::new(0.0, 1.0));
        let nodes = network
            .nodes_around_line_iter(line, 2.5)
            .collect::<Vec<_>>();
        assert_eq!(nodes.len(), 3);

        network.remove_path(node3, node4);
        network.remove_node(node1);

        let site = Site::new(2.0, 1.0);
        let nodes = network
            .nodes_around_site_iter(site, 2.0)
            .collect::<Vec<_>>();
        assert_eq!(nodes.len(), 1);

        let line = LineSegment::new(Site::new(1.0, 0.0), Site::new(0.0, 1.0));
        let nodes = network
            .nodes_around_line_iter(line, 2.5)
            .collect::<Vec<_>>();
        assert_eq!(nodes.len(), 2);

        assert!(network.check_path_state_is_consistent());
    }

    #[test]
    fn test_complex_network() {
        let xorshift = |x: usize| -> usize {
            let mut x = x;
            x ^= x << 13;
            x ^= x >> 17;
            x ^= x << 5;
            x
        };

        let sites = (0..100)
            .map(|i| Site::new(xorshift(i * 2) as f64, xorshift(i * 2 + 1) as f64))
            .collect::<Vec<_>>();

        let loop_count = 10;

        let mut network = PathNetwork::new();

        let nodeids = sites
            .iter()
            .map(|site| network.add_node(*site))
            .collect::<Vec<_>>();

        for l in 0..loop_count {
            let seed_start = l * sites.len() * sites.len();
            (0..sites.len()).for_each(|i| {
                (0..sites.len()).for_each(|j| {
                    let id = i * sites.len() + j;
                    if xorshift(id + seed_start) % 2 == 0 {
                        network.add_path(nodeids[i], nodeids[j]);
                    }
                });
            });

            assert!(network.check_path_state_is_consistent());

            (0..sites.len()).for_each(|i| {
                (0..sites.len()).for_each(|j| {
                    let id = i * sites.len() + j;
                    if xorshift(id + seed_start) % 3 == 0 {
                        network.remove_path(nodeids[i], nodeids[j]);
                    }
                });
            });

            assert!(network.check_path_state_is_consistent());
        }
    }
}