1use rand::{Rng, SeedableRng, rngs::StdRng};
2use serde_json::json;
3
4use crate::{GraphEdge, GraphEntity};
5
6#[derive(Clone, Debug)]
7pub struct GraphDataset {
8 pub entities: Vec<GraphEntity>,
9 pub edges: Vec<GraphEdge>,
10}
11
12impl GraphDataset {
13 pub fn nodes(&self) -> usize {
14 self.entities.len()
15 }
16
17 pub fn edges(&self) -> usize {
18 self.edges.len()
19 }
20
21 pub fn degrees(&self) -> Vec<usize> {
22 let mut counts = vec![0usize; self.entities.len()];
23 for edge in &self.edges {
24 let from = edge.from_id as usize;
25 let to = edge.to_id as usize;
26 counts[from] += 1;
27 counts[to] += 1;
28 }
29 counts
30 }
31
32 pub fn hub_index(&self) -> usize {
33 let mut best = (0usize, 0usize);
34 for (idx, deg) in self.degrees().into_iter().enumerate() {
35 if deg > best.0 {
36 best = (deg, idx);
37 }
38 }
39 best.1
40 }
41
42 pub fn mapped_edge(edge: &GraphEdge, id_map: &[i64]) -> GraphEdge {
43 GraphEdge {
44 id: 0,
45 from_id: id_map[edge.from_id as usize],
46 to_id: id_map[edge.to_id as usize],
47 edge_type: edge.edge_type.clone(),
48 data: edge.data.clone(),
49 }
50 }
51}
52
/// Topology selector for [`generate_graph`].
#[derive(Debug, Clone)]
pub enum GraphShape {
    /// Path graph `0 - 1 - ... - (n-1)`.
    Line,
    /// Node 0 connected to every other node.
    Star,
    /// `width` x `height` lattice joining horizontal and vertical neighbours;
    /// `width * height` must equal the node count.
    Grid2D { width: usize, height: usize },
    /// Seeded random graph over distinct unordered node pairs; `edges` is the
    /// requested edge count and must not exceed `n * (n - 1) / 2`.
    RandomErdosRenyi { edges: usize },
    /// Preferential attachment: a clique on the first `m + 1` nodes, then every
    /// later node links to `m` distinct earlier nodes picked with probability
    /// proportional to their current degree. Requires `m > 0` and `n > m + 1`.
    ScaleFree { m: usize },
}
61
62pub fn generate_graph(shape: GraphShape, node_count: usize, seed: u64) -> GraphDataset {
63 assert!(node_count > 1, "node_count must exceed 1");
64 let entities = build_entities(node_count);
65 let mut edges = match shape {
66 GraphShape::Line => generate_line_edges(node_count),
67 GraphShape::Star => generate_star_edges(node_count),
68 GraphShape::Grid2D { width, height } => generate_grid_edges(width, height, node_count),
69 GraphShape::RandomErdosRenyi { edges } => generate_random_edges(node_count, edges, seed),
70 GraphShape::ScaleFree { m } => generate_scale_free_edges(node_count, m, seed),
71 };
72 edges.sort_by(|a, b| {
73 a.from_id
74 .cmp(&b.from_id)
75 .then_with(|| a.to_id.cmp(&b.to_id))
76 .then_with(|| a.edge_type.cmp(&b.edge_type))
77 });
78 GraphDataset { entities, edges }
79}
80
81fn build_entities(count: usize) -> Vec<GraphEntity> {
82 (0..count)
83 .map(|idx| GraphEntity {
84 id: idx as i64,
85 kind: "Node".to_string(),
86 name: format!("Node{idx}"),
87 file_path: None,
88 data: json!({ "idx": idx }),
89 })
90 .collect()
91}
92
93fn generate_line_edges(count: usize) -> Vec<GraphEdge> {
94 (0..count - 1)
95 .map(|idx| new_edge(idx, idx + 1, "LINE"))
96 .collect()
97}
98
99fn generate_star_edges(count: usize) -> Vec<GraphEdge> {
100 (1..count).map(|leaf| new_edge(0, leaf, "STAR")).collect()
101}
102
103fn generate_grid_edges(width: usize, height: usize, node_count: usize) -> Vec<GraphEdge> {
104 assert_eq!(
105 width * height,
106 node_count,
107 "grid dimensions must match node count"
108 );
109 let mut edges = Vec::with_capacity(width * height * 2);
110 for y in 0..height {
111 for x in 0..width {
112 let base = grid_index(x, y, width);
113 if x + 1 < width {
114 edges.push(new_edge(base, grid_index(x + 1, y, width), "GRID"));
115 }
116 if y + 1 < height {
117 edges.push(new_edge(base, grid_index(x, y + 1, width), "GRID"));
118 }
119 }
120 }
121 edges
122}
123
124fn generate_random_edges(node_count: usize, edge_count: usize, seed: u64) -> Vec<GraphEdge> {
125 let total_pairs = pair_count(node_count);
126 assert!(
127 edge_count as u128 <= total_pairs,
128 "edge_count exceeds possible pairs"
129 );
130 let mut rng = StdRng::seed_from_u64(seed);
131 let mut edges = Vec::with_capacity(edge_count);
132 let mut idx = 0u64;
133 let mut remaining_edges = edge_count as u64;
134 while remaining_edges > 0 && idx < total_pairs as u64 {
135 let remaining_pairs = total_pairs as u64 - idx;
136 let p = remaining_edges as f64 / remaining_pairs as f64;
137 let skip = sample_geometric(&mut rng, p);
138 idx += skip;
139 if idx >= total_pairs as u64 {
140 break;
141 }
142 let (from, to) = pair_from_index(idx, node_count as u64);
143 edges.push(new_edge(from as usize, to as usize, "ER"));
144 idx += 1;
145 remaining_edges -= 1;
146 }
147 edges
148}
149
150fn generate_scale_free_edges(node_count: usize, m: usize, seed: u64) -> Vec<GraphEdge> {
151 assert!(m > 0, "m must be positive");
152 assert!(node_count > m + 1, "node_count must exceed m + 1");
153 let mut rng = StdRng::seed_from_u64(seed);
154 let mut degrees = vec![0usize; node_count];
155 let mut edges = Vec::new();
156 let seed_nodes = m + 1;
157 for u in 0..seed_nodes {
158 for v in (u + 1)..seed_nodes {
159 edges.push(new_edge(u, v, "SF"));
160 degrees[u] += 1;
161 degrees[v] += 1;
162 }
163 }
164 let mut total_degree: usize = degrees.iter().sum();
165 for new_node in seed_nodes..node_count {
166 let mut targets = Vec::new();
167 while targets.len() < m {
168 let pick = rng.gen_range(0..total_degree);
169 let mut cumulative = 0usize;
170 for (candidate, degree) in degrees.iter().take(new_node).enumerate() {
171 cumulative += *degree;
172 if pick < cumulative {
173 if !targets.contains(&candidate) {
174 targets.push(candidate);
175 }
176 break;
177 }
178 }
179 }
180 targets.sort_unstable();
181 targets.dedup();
182 while targets.len() < m {
183 targets.push(targets.len() % new_node);
184 targets.sort_unstable();
185 targets.dedup();
186 }
187 for target in targets {
188 edges.push(new_edge(target, new_node, "SF"));
189 degrees[target] += 1;
190 degrees[new_node] += 1;
191 total_degree += 2;
192 }
193 }
194 edges
195}
196
197fn new_edge(from: usize, to: usize, label: &str) -> GraphEdge {
198 GraphEdge {
199 id: 0,
200 from_id: from as i64,
201 to_id: to as i64,
202 edge_type: label.to_string(),
203 data: json!({ "label": label }),
204 }
205}
206
/// Row-major position of cell `(x, y)` in a grid that is `width` cells wide.
fn grid_index(x: usize, y: usize, width: usize) -> usize {
    let row_offset = y * width;
    x + row_offset
}
210
/// Number of unordered pairs among `nodes` nodes, `n * (n - 1) / 2`.
///
/// Computed in `u128` so it cannot overflow for any `usize`. It is 0 for zero
/// or one node; `saturating_sub` avoids the `0 - 1` underflow.
fn pair_count(nodes: usize) -> u128 {
    let n = nodes as u128;
    n * n.saturating_sub(1) / 2
}
215
216fn sample_geometric(rng: &mut StdRng, p: f64) -> u64 {
217 let u = rng.r#gen::<f64>().max(f64::MIN_POSITIVE);
218 ((u.ln() / (1.0 - p).ln()).floor().max(0.0)) as u64
219}
220
/// Maps `idx` to the `idx`-th unordered pair `(a, b)` with `a < b < nodes`.
///
/// Pairs are in lexicographic order: `(0,1), (0,2), ..., (0,n-1), (1,2), ...`.
/// Any `idx` at or beyond the pair count returns the last pair
/// `(nodes - 2, nodes - 1)`.
///
/// Runs in O(log nodes): row starts have a closed form and are binary-searched.
///
/// # Panics
/// If `nodes < 2`.
fn pair_from_index(idx: u64, nodes: u64) -> (u64, u64) {
    assert!(nodes >= 2, "pair_from_index needs at least two nodes");
    let n = u128::from(nodes);
    // Number of pairs whose first element is < `left`:
    // sum_{i < left} (n - 1 - i) = left * (2n - left - 1) / 2.
    let row_start = |left: u64| -> u128 {
        let l = u128::from(left);
        l * (2 * n - l - 1) / 2
    };
    let target = u128::from(idx);
    if target >= row_start(nodes - 1) {
        return (nodes - 2, nodes - 1);
    }
    // Largest `left` in [0, nodes - 2] whose row starts at or before `target`.
    let (mut lo, mut hi) = (0u64, nodes - 2);
    while lo < hi {
        let mid = lo + (hi - lo + 1) / 2;
        if row_start(mid) <= target {
            lo = mid;
        } else {
            hi = mid - 1;
        }
    }
    // offset < nodes - 1 - lo, so it always fits in u64.
    let offset = (target - row_start(lo)) as u64;
    (lo, lo + 1 + offset)
}