solvr/graph/traits/types.rs
1//! Graph data types and result structures.
2
3use numr::error::Result;
4use numr::runtime::Runtime;
5use numr::sparse::SparseTensor;
6use numr::tensor::Tensor;
7
8/// Graph representation using sparse adjacency matrix.
9///
10/// Wraps a CSR sparse adjacency matrix with metadata for type safety.
11/// Weights are stored as values in the sparse matrix.
12///
13/// # Construction
14///
15/// ```no_run
16/// # use numr::runtime::cpu::{CpuDevice, CpuRuntime};
17/// # use numr::sparse::SparseTensor;
18/// use solvr::graph::GraphData;
19/// # let device = CpuDevice::new();
20/// # let adjacency: SparseTensor<CpuRuntime> = unimplemented!();
21/// let graph = GraphData::new(adjacency, false);
22/// ```
23#[derive(Debug, Clone)]
24pub struct GraphData<R: Runtime> {
25 /// CSR sparse adjacency matrix `[n, n]`, weights as values
26 pub adjacency: SparseTensor<R>,
27 /// Number of nodes in the graph
28 pub num_nodes: usize,
29 /// Whether the graph is directed
30 pub directed: bool,
31}
32
33impl<R: Runtime> GraphData<R> {
34 /// Create a graph from a sparse adjacency matrix.
35 pub fn new(adjacency: SparseTensor<R>, directed: bool) -> Self {
36 let num_nodes = adjacency.nrows();
37 Self {
38 adjacency,
39 num_nodes,
40 directed,
41 }
42 }
43
44 /// Create a graph from an edge list.
45 ///
46 /// # Arguments
47 ///
48 /// * `sources` - Source node indices (i64 slice)
49 /// * `targets` - Target node indices (i64 slice)
50 /// * `weights` - Optional edge weights (f64 slice). If None, uses 1.0 for all edges.
51 /// * `num_nodes` - Number of nodes in the graph
52 /// * `directed` - Whether the graph is directed
53 /// * `device` - Device to create tensors on
54 pub fn from_edge_list<T: numr::dtype::Element>(
55 sources: &[i64],
56 targets: &[i64],
57 weights: Option<&[T]>,
58 num_nodes: usize,
59 directed: bool,
60 device: &R::Device,
61 ) -> Result<Self> {
62 let num_edges = sources.len();
63
64 if directed {
65 let vals: Vec<T> = if let Some(w) = weights {
66 w.to_vec()
67 } else {
68 vec![T::one(); num_edges]
69 };
70 let adjacency = SparseTensor::<R>::from_coo_slices(
71 sources,
72 targets,
73 &vals,
74 [num_nodes, num_nodes],
75 device,
76 )?;
77 let adjacency = adjacency.to_csr()?;
78 Ok(Self::new(adjacency, directed))
79 } else {
80 // Undirected: add both directions
81 let mut all_sources = Vec::with_capacity(num_edges * 2);
82 let mut all_targets = Vec::with_capacity(num_edges * 2);
83 let mut all_weights = Vec::with_capacity(num_edges * 2);
84
85 for i in 0..num_edges {
86 let w = if let Some(ws) = weights {
87 ws[i]
88 } else {
89 T::one()
90 };
91 all_sources.push(sources[i]);
92 all_targets.push(targets[i]);
93 all_weights.push(w);
94 // Reverse edge
95 all_sources.push(targets[i]);
96 all_targets.push(sources[i]);
97 all_weights.push(w);
98 }
99
100 let adjacency = SparseTensor::<R>::from_coo_slices(
101 &all_sources,
102 &all_targets,
103 &all_weights,
104 [num_nodes, num_nodes],
105 device,
106 )?;
107 let adjacency = adjacency.to_csr()?;
108 Ok(Self::new(adjacency, directed))
109 }
110 }
111}
112
113/// Result of single-source shortest path algorithms.
114#[derive(Debug, Clone)]
115pub struct ShortestPathResult<R: Runtime> {
116 /// Distance from source to each node `[n]`. Infinity for unreachable nodes.
117 pub distances: Tensor<R>,
118 /// Predecessor of each node on shortest path `[n]`. -1 for source/unreachable.
119 pub predecessors: Tensor<R>,
120}
121
122/// Result of all-pairs shortest path algorithms.
123#[derive(Debug, Clone)]
124pub struct AllPairsResult<R: Runtime> {
125 /// Distance matrix `[n, n]`. `distances[i][j]` = shortest path from i to j.
126 pub distances: Tensor<R>,
127 /// Predecessor matrix `[n, n]`. `predecessors[i][j]` = previous node on path from i to j.
128 pub predecessors: Tensor<R>,
129}
130
131/// Result of a specific path query (source to target).
132#[derive(Debug, Clone)]
133pub struct PathResult<R: Runtime> {
134 /// Total distance from source to target. Infinity if unreachable.
135 pub distance: f64,
136 /// Node indices along the path `[path_len]`. Empty if unreachable.
137 pub path: Tensor<R>,
138}
139
140/// Result of minimum spanning tree algorithms.
141#[derive(Debug, Clone)]
142pub struct MSTResult<R: Runtime> {
143 /// Edge sources in the MST `[num_mst_edges]`.
144 pub sources: Tensor<R>,
145 /// Edge targets in the MST `[num_mst_edges]`.
146 pub targets: Tensor<R>,
147 /// Edge weights in the MST `[num_mst_edges]`.
148 pub weights: Tensor<R>,
149 /// Total weight of the MST.
150 pub total_weight: f64,
151}
152
153/// Result of connected component algorithms.
154#[derive(Debug, Clone)]
155pub struct ComponentResult<R: Runtime> {
156 /// Component label for each node `[n]`.
157 pub labels: Tensor<R>,
158 /// Number of connected components.
159 pub num_components: usize,
160}
161
162/// Result of max-flow algorithms.
163#[derive(Debug, Clone)]
164pub struct FlowResult<R: Runtime> {
165 /// Maximum flow value.
166 pub max_flow: f64,
167 /// Flow on each edge as a sparse matrix [n, n].
168 pub flow: Tensor<R>,
169}
170
/// Tuning knobs for eigenvector centrality.
#[derive(Debug, Clone)]
pub struct EigCentralityOptions {
    /// Upper bound on the number of power-iteration steps.
    pub max_iter: usize,
    /// Tolerance used to decide convergence.
    pub tol: f64,
}

impl Default for EigCentralityOptions {
    /// Defaults: at most 100 iterations, tolerance `1e-6`.
    fn default() -> Self {
        let max_iter = 100;
        let tol = 1e-6;
        Self { max_iter, tol }
    }
}
188
/// Tuning knobs for PageRank.
#[derive(Debug, Clone)]
pub struct PageRankOptions {
    /// Damping factor; 0.85 is the typical choice.
    pub damping: f64,
    /// Upper bound on the number of iterations.
    pub max_iter: usize,
    /// Tolerance used to decide convergence.
    pub tol: f64,
}

impl Default for PageRankOptions {
    /// Defaults: damping `0.85`, at most 100 iterations, tolerance `1e-6`.
    fn default() -> Self {
        let damping = 0.85;
        let max_iter = 100;
        let tol = 1e-6;
        Self {
            damping,
            max_iter,
            tol,
        }
    }
}
209
/// Options for min-cost flow.
///
/// The derived `Default` leaves both fields as `None`.
#[derive(Debug, Clone, Default)]
pub struct MinCostFlowOptions {
    /// Cost per unit flow on each edge, as a flat list of values.
    /// If None, all costs are 1.
    ///
    /// NOTE(review): this was previously described as a sparse `[n, n]`
    /// matrix, but the field is a plain `Vec<f64>`; the expected length and
    /// ordering (e.g. one entry per stored edge vs. a flattened matrix) are
    /// not visible here — confirm against the consumer.
    pub costs: Option<Vec<f64>>,
    /// Maximum flow to push. If None, finds min-cost max flow.
    pub max_flow: Option<f64>,
}