// rustsim_spaces/graph.rs

//! Graph-based discrete space.
//!
//! [`GraphSpace`] represents a network where each node can hold agents.
//! Agents move along edges, and distance is measured in hops (via breadth-first search).
//! Both directed and undirected graphs are supported, with dynamic topology
//! (vertices and edges can be added or removed at runtime).
//!
//! Mirrors Julia Agents.jl `GraphSpace`.
9
10use rand::Rng;
11use rustsim_core::{
12    interaction::{PositionedAgent, SpaceInteraction},
13    space::Space,
14    types::{AgentId, NodeId},
15};
16use std::collections::{HashSet, VecDeque};
17use thiserror::Error;
18
19/// Position in a graph space - a node index (`0..num_vertices`).
20pub type GraphPos = NodeId;
21
22/// Errors returned by graph space operations.
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
24pub enum GraphSpaceError {
25    /// The node index is out of range.
26    #[error("invalid graph node index {0}")]
27    InvalidNode(GraphPos),
28}
29
/// Which edges to follow when collecting neighbors.
///
/// The distinction only matters for directed graphs; in an undirected graph
/// every variant reaches the same set of neighbors.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum NeighborType {
    /// Follow outgoing edges (the default).
    #[default]
    Out,
    /// Follow incoming edges.
    In,
    /// Follow edges in either direction.
    All,
}
41
42/// A graph-based discrete space mirroring Julia Agents.jl `GraphSpace`.
43///
44/// Each node can hold an arbitrary number of agents. Agents move between
45/// nodes along edges. Distance is measured in graph hops.
46///
47/// Supports both directed and undirected graphs. Edges and vertices can
48/// be added/removed dynamically at runtime.
49///
50/// # Example: building floor graph
51/// ```text
52///   [Floor1-RoomA] --edge-- [Floor1-Corridor] --edge-- [Floor1-RoomB]
53///                                |
54///                              (stairs)
55///                                |
56///   [Floor2-RoomA] --edge-- [Floor2-Corridor] --edge-- [Floor2-RoomB]
57/// ```
58/// Each room/corridor/stairwell is a node. Edges connect adjacent spaces.
59#[derive(Debug, Clone)]
60pub struct GraphSpace {
61    /// Adjacency list: out-neighbors for each node.
62    adj_out: Vec<Vec<GraphPos>>,
63    /// Adjacency list: in-neighbors for each node (for directed graphs).
64    adj_in: Vec<Vec<GraphPos>>,
65    /// Agent IDs stored at each node.
66    stored_ids: Vec<Vec<AgentId>>,
67    /// Whether the graph is directed.
68    directed: bool,
69}
70
71impl GraphSpace {
72    /// Create an undirected graph with `n` nodes and no edges.
73    pub fn new(n: usize) -> Self {
74        Self {
75            adj_out: vec![Vec::new(); n],
76            adj_in: vec![Vec::new(); n],
77            stored_ids: vec![Vec::new(); n],
78            directed: false,
79        }
80    }
81
82    /// Create a graph with `n` nodes. If `directed` is true, edges are one-way.
83    pub fn new_directed(n: usize, directed: bool) -> Self {
84        Self {
85            adj_out: vec![Vec::new(); n],
86            adj_in: vec![Vec::new(); n],
87            stored_ids: vec![Vec::new(); n],
88            directed,
89        }
90    }
91
92    /// Number of nodes in the graph.
93    pub fn num_vertices(&self) -> usize {
94        self.adj_out.len()
95    }
96
97    /// Number of edges in the graph.
98    ///
99    /// For undirected graphs, each edge is counted once (not twice).
100    pub fn num_edges(&self) -> usize {
101        let total: usize = self.adj_out.iter().map(|v| v.len()).sum();
102        if self.directed {
103            total
104        } else {
105            total / 2
106        }
107    }
108
109    /// Whether the graph is directed.
110    pub fn is_directed(&self) -> bool {
111        self.directed
112    }
113
114    /// Add a new vertex and return its index.
115    pub fn add_vertex(&mut self) -> GraphPos {
116        let idx = self.adj_out.len();
117        self.adj_out.push(Vec::new());
118        self.adj_in.push(Vec::new());
119        self.stored_ids.push(Vec::new());
120        idx
121    }
122
123    /// Remove a vertex by swapping with the last vertex (matches Graphs.jl behavior).
124    /// All agents at the removed vertex must be removed beforehand.
125    /// Returns true if successful.
126    pub fn rem_vertex(&mut self, n: GraphPos) -> bool {
127        let nv = self.num_vertices();
128        if n >= nv {
129            return false;
130        }
131
132        // Remove all edges involving node n
133        let out_neighbors: Vec<GraphPos> = self.adj_out[n].clone();
134        for &neighbor in &out_neighbors {
135            self.rem_edge(n, neighbor);
136        }
137        let in_neighbors: Vec<GraphPos> = self.adj_in[n].clone();
138        for &neighbor in &in_neighbors {
139            self.rem_edge(neighbor, n);
140        }
141
142        let last = nv - 1;
143        if n != last {
144            // Swap last node into position n
145            self.adj_out.swap(n, last);
146            self.adj_in.swap(n, last);
147            self.stored_ids.swap(n, last);
148
149            // Update all references from `last` to `n`
150            for neighbors in &mut self.adj_out {
151                for pos in neighbors.iter_mut() {
152                    if *pos == last {
153                        *pos = n;
154                    }
155                }
156            }
157            for neighbors in &mut self.adj_in {
158                for pos in neighbors.iter_mut() {
159                    if *pos == last {
160                        *pos = n;
161                    }
162                }
163            }
164        }
165
166        self.adj_out.pop();
167        self.adj_in.pop();
168        self.stored_ids.pop();
169        true
170    }
171
172    /// Add an edge from `a` to `b`. For undirected graphs, also adds `b` to `a`.
173    pub fn add_edge(&mut self, a: GraphPos, b: GraphPos) -> bool {
174        let nv = self.num_vertices();
175        if a >= nv || b >= nv {
176            return false;
177        }
178        if self.adj_out[a].contains(&b) {
179            return false; // already exists
180        }
181        self.adj_out[a].push(b);
182        self.adj_in[b].push(a);
183        if !self.directed {
184            self.adj_out[b].push(a);
185            self.adj_in[a].push(b);
186        }
187        true
188    }
189
190    /// Remove an edge from `a` to `b`.
191    pub fn rem_edge(&mut self, a: GraphPos, b: GraphPos) -> bool {
192        let nv = self.num_vertices();
193        if a >= nv || b >= nv {
194            return false;
195        }
196        let removed = remove_from_vec(&mut self.adj_out[a], b);
197        remove_from_vec(&mut self.adj_in[b], a);
198        if !self.directed {
199            remove_from_vec(&mut self.adj_out[b], a);
200            remove_from_vec(&mut self.adj_in[a], b);
201        }
202        removed
203    }
204
205    /// Outgoing neighbors of a node (slice reference, no allocation).
206    pub fn neighbors_out(&self, n: GraphPos) -> &[GraphPos] {
207        &self.adj_out[n]
208    }
209
210    /// Incoming neighbors of a node (slice reference, no allocation).
211    pub fn neighbors_in(&self, n: GraphPos) -> &[GraphPos] {
212        &self.adj_in[n]
213    }
214
215    /// All neighbors (union of incoming and outgoing), deduplicated.
216    pub fn neighbors_all(&self, n: GraphPos) -> Vec<GraphPos> {
217        let mut set: HashSet<GraphPos> = HashSet::new();
218        set.extend(&self.adj_out[n]);
219        set.extend(&self.adj_in[n]);
220        set.into_iter().collect()
221    }
222
223    /// Get neighbors according to a [`NeighborType`] selector.
224    pub fn neighbors(&self, n: GraphPos, kind: NeighborType) -> Vec<GraphPos> {
225        match kind {
226            NeighborType::Out => self.adj_out[n].clone(),
227            NeighborType::In => self.adj_in[n].clone(),
228            NeighborType::All => self.neighbors_all(n),
229        }
230    }
231
232    /// Agent IDs stored at a given node.
233    pub fn ids_in_position(&self, n: GraphPos) -> &[AgentId] {
234        &self.stored_ids[n]
235    }
236
237    /// All valid node indices.
238    pub fn positions(&self) -> std::ops::Range<usize> {
239        0..self.num_vertices()
240    }
241
242    /// Find all node indices reachable within `r` hops via BFS, excluding the origin.
243    pub fn nearby_positions(&self, pos: GraphPos, r: usize, kind: NeighborType) -> Vec<GraphPos> {
244        let mut visited = HashSet::new();
245        let mut queue = VecDeque::new();
246        visited.insert(pos);
247        queue.push_back((pos, 0usize));
248
249        let mut result = Vec::new();
250
251        while let Some((node, dist)) = queue.pop_front() {
252            if dist > 0 {
253                result.push(node);
254            }
255            if dist < r {
256                match kind {
257                    NeighborType::Out => {
258                        for &neighbor in &self.adj_out[node] {
259                            if visited.insert(neighbor) {
260                                queue.push_back((neighbor, dist + 1));
261                            }
262                        }
263                    }
264                    NeighborType::In => {
265                        for &neighbor in &self.adj_in[node] {
266                            if visited.insert(neighbor) {
267                                queue.push_back((neighbor, dist + 1));
268                            }
269                        }
270                    }
271                    NeighborType::All => {
272                        for &neighbor in &self.adj_out[node] {
273                            if visited.insert(neighbor) {
274                                queue.push_back((neighbor, dist + 1));
275                            }
276                        }
277                        for &neighbor in &self.adj_in[node] {
278                            if visited.insert(neighbor) {
279                                queue.push_back((neighbor, dist + 1));
280                            }
281                        }
282                    }
283                }
284            }
285        }
286        result
287    }
288
289    /// Find all agent IDs within `r` hops, including agents at the origin node.
290    pub fn nearby_agent_ids(&self, pos: GraphPos, r: usize, kind: NeighborType) -> Vec<AgentId> {
291        let mut ids = Vec::new();
292        // Include agents at origin
293        ids.extend_from_slice(&self.stored_ids[pos]);
294        // Include agents at nearby positions
295        for neighbor in self.nearby_positions(pos, r, kind) {
296            ids.extend_from_slice(&self.stored_ids[neighbor]);
297        }
298        ids
299    }
300}
301
302fn remove_from_vec(v: &mut Vec<GraphPos>, val: GraphPos) -> bool {
303    if let Some(i) = v.iter().position(|&x| x == val) {
304        v.swap_remove(i);
305        true
306    } else {
307        false
308    }
309}
310
311impl Space for GraphSpace {}
312
313impl<A> SpaceInteraction<A> for GraphSpace
314where
315    A: PositionedAgent<Position = GraphPos>,
316{
317    type Error = GraphSpaceError;
318
319    fn random_position<R: rand::RngCore>(&self, rng: &mut R) -> A::Position {
320        rng.gen_range(0..self.num_vertices())
321    }
322
323    fn add_agent(&mut self, agent: &A) -> Result<(), Self::Error> {
324        let pos = *agent.position();
325        if pos >= self.num_vertices() {
326            return Err(GraphSpaceError::InvalidNode(pos));
327        }
328        self.stored_ids[pos].push(agent.id());
329        Ok(())
330    }
331
332    fn remove_agent(&mut self, agent: &A) -> Result<(), Self::Error> {
333        let pos = *agent.position();
334        if pos >= self.num_vertices() {
335            return Err(GraphSpaceError::InvalidNode(pos));
336        }
337        if let Some(i) = self.stored_ids[pos].iter().position(|&id| id == agent.id()) {
338            self.stored_ids[pos].swap_remove(i);
339        }
340        Ok(())
341    }
342
343    fn nearby_ids(&self, position: &A::Position, radius: usize) -> Vec<AgentId> {
344        self.nearby_agent_ids(*position, radius, NeighborType::Out)
345    }
346}