1use crate::gpu::{GpuContext, LayoutBuffers, LayoutPipeline};
4use crate::quadtree::QuadTree;
5#[cfg(not(target_arch = "wasm32"))]
6use crate::shaders::FORCE_SHADER;
7use crate::shaders::SIMPLE_FORCE_SHADER;
8use crate::{Edge, LayoutError, LayoutParams, Position, Result};
9
10#[cfg(target_arch = "wasm32")]
11use std::cell::RefCell;
12#[cfg(target_arch = "wasm32")]
13use std::rc::Rc;
14
/// Tunable parameters for [`GpuLayout`].
///
/// Every `f32` field is copied verbatim into the shader's `LayoutParams`
/// whenever parameters are uploaded; `use_barnes_hut` and `max_tree_depth`
/// only steer CPU-side decisions (shader choice and quadtree construction).
#[derive(Debug, Clone)]
pub struct LayoutConfig {
    /// Step size forwarded to the shader as `dt` (presumably seconds per
    /// step; the default `0.016` is roughly 1/60).
    pub dt: f32,
    /// Forwarded to the shader as `damping`.
    /// NOTE(review): presumably a per-step velocity multiplier — confirm.
    pub damping: f32,
    /// Forwarded to the shader as `repulsion` (presumably node–node repulsive
    /// strength).
    pub repulsion: f32,
    /// Forwarded to the shader as `attraction` (presumably edge spring
    /// strength).
    pub attraction: f32,
    /// Forwarded to the shader as `theta`; presumably the Barnes–Hut opening
    /// threshold, so it likely only matters with the tree-based shader.
    pub theta: f32,
    /// Forwarded to the shader as `gravity` (presumably a pull toward the
    /// origin).
    pub gravity: f32,
    /// Forwarded to the shader as `ideal_length` (presumably the rest length
    /// of an edge).
    pub ideal_length: f32,
    /// On non-wasm32 targets this selects `FORCE_SHADER` when the layout is
    /// created and makes every step rebuild a `QuadTree`. wasm32 always uses
    /// `SIMPLE_FORCE_SHADER` and never builds a populated tree.
    pub use_barnes_hut: bool,
    /// Depth limit handed to `QuadTree::build`.
    pub max_tree_depth: usize,
}

impl Default for LayoutConfig {
    /// Stock tuning used when the caller supplies no configuration.
    fn default() -> Self {
        LayoutConfig {
            dt: 0.016,
            damping: 0.9,
            repulsion: 1000.0,
            attraction: 0.01,
            theta: 0.8,
            gravity: 0.1,
            ideal_length: 50.0,
            use_barnes_hut: true,
            max_tree_depth: 12,
        }
    }
}
53
/// Lifecycle stage of a [`GpuLayout`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayoutState {
    /// Freshly constructed; `init` has not completed successfully yet.
    Uninitialized,
    /// The only state in which `step` actually dispatches GPU work.
    Running,
    /// Entered after `init` or `pause`; `step` just returns the current
    /// positions.
    Paused,
    /// Not assigned anywhere in this module — presumably reserved for a
    /// convergence check (TODO confirm). `start` moves it back to `Running`.
    Converged,
}
66
67#[cfg(not(target_arch = "wasm32"))]
69#[derive(Default)]
70struct _AsyncPositions {
71 data: Vec<Position>,
72 pending_update: bool,
73}
74
/// Shared slot through which a wasm32 readback callback hands positions back
/// to `GpuLayout::step`.
#[cfg(target_arch = "wasm32")]
#[derive(Default)]
struct AsyncPositionsWasm {
    /// Latest snapshot copied out of the staging buffer.
    data: Vec<Position>,
    /// Set by the readback callback; cleared by `step` once consumed.
    pending_update: bool,
}
82
83pub struct GpuLayout {
85 ctx: GpuContext,
86 pipeline: LayoutPipeline,
87 buffers: Option<LayoutBuffers>,
88 bind_group: Option<wgpu::BindGroup>,
89 config: LayoutConfig,
90 state: LayoutState,
91 positions: Vec<Position>,
92 edges: Vec<Edge>,
93 iteration: u32,
94 frame_counter: u32,
96 _readback_interval: u32,
98 #[cfg(target_arch = "wasm32")]
100 async_positions: Rc<RefCell<AsyncPositionsWasm>>,
101 #[cfg(target_arch = "wasm32")]
103 readback_pending: bool,
104}
105
106impl GpuLayout {
107 pub async fn new(config: LayoutConfig) -> Result<Self> {
109 let ctx = GpuContext::new().await?;
110
111 #[cfg(target_arch = "wasm32")]
113 let shader = SIMPLE_FORCE_SHADER;
114
115 #[cfg(not(target_arch = "wasm32"))]
116 let shader = if config.use_barnes_hut {
117 FORCE_SHADER
118 } else {
119 SIMPLE_FORCE_SHADER
120 };
121
122 let pipeline = LayoutPipeline::new(&ctx, shader)?;
123
124 Ok(Self {
125 ctx,
126 pipeline,
127 buffers: None,
128 bind_group: None,
129 config,
130 state: LayoutState::Uninitialized,
131 positions: Vec::new(),
132 edges: Vec::new(),
133 iteration: 0,
134 frame_counter: 0,
135 _readback_interval: 1, #[cfg(target_arch = "wasm32")]
137 async_positions: Rc::new(RefCell::new(AsyncPositionsWasm::default())),
138 #[cfg(target_arch = "wasm32")]
139 readback_pending: false,
140 })
141 }
142
143 pub fn init(&mut self, positions: Vec<Position>, edges: Vec<Edge>) -> Result<()> {
145 if positions.is_empty() {
146 return Err(LayoutError::InvalidGraph("No nodes".into()));
147 }
148
149 self.positions = positions;
150 self.edges = edges;
151
152 #[cfg(not(target_arch = "wasm32"))]
154 let tree = if self.config.use_barnes_hut {
155 QuadTree::build(&self.positions, self.config.max_tree_depth)
156 } else {
157 QuadTree::build(&[], 1)
158 };
159
160 #[cfg(target_arch = "wasm32")]
161 let tree = QuadTree::build(&[], 1); let buffers = LayoutBuffers::new(&self.ctx, &self.positions, &self.edges, tree.nodes())?;
165
166 let params = self.create_params(tree.nodes().len() as u32);
168 buffers.update_params(&self.ctx, ¶ms);
169
170 let bind_group = self.pipeline.create_bind_group(&self.ctx, &buffers);
172
173 self.buffers = Some(buffers);
174 self.bind_group = Some(bind_group);
175 self.state = LayoutState::Paused;
176 self.iteration = 0;
177 self.frame_counter = 0;
178
179 #[cfg(target_arch = "wasm32")]
180 {
181 self.readback_pending = false;
182 let mut async_pos = self.async_positions.borrow_mut();
183 async_pos.data = self.positions.clone();
184 async_pos.pending_update = false;
185 }
186
187 tracing::info!(
188 "GPU layout initialized: {} nodes, {} edges",
189 self.positions.len(),
190 self.edges.len()
191 );
192
193 Ok(())
194 }
195
196 fn create_params(&self, tree_size: u32) -> LayoutParams {
198 LayoutParams {
199 node_count: self.positions.len() as u32,
200 edge_count: self.edges.len() as u32,
201 tree_size,
202 dt: self.config.dt,
203 damping: self.config.damping,
204 repulsion: self.config.repulsion,
205 attraction: self.config.attraction,
206 theta: self.config.theta,
207 gravity: self.config.gravity,
208 ideal_length: self.config.ideal_length,
209 }
210 }
211
212 pub fn start(&mut self) {
214 if self.state != LayoutState::Uninitialized {
215 self.state = LayoutState::Running;
216 }
217 }
218
219 pub fn pause(&mut self) {
221 if self.state == LayoutState::Running {
222 self.state = LayoutState::Paused;
223 }
224 }
225
226 pub fn state(&self) -> LayoutState {
228 self.state
229 }
230
231 pub fn iteration(&self) -> u32 {
233 self.iteration
234 }
235
236 #[cfg(not(target_arch = "wasm32"))]
239 pub fn step(&mut self) -> Result<&[Position]> {
240 if self.state != LayoutState::Running {
241 return Ok(&self.positions);
242 }
243
244 if self.buffers.is_none() || self.bind_group.is_none() {
245 return Err(LayoutError::NotInitialized);
246 }
247
248 if self.config.use_barnes_hut {
250 self.read_positions_blocking()?;
252
253 let tree = QuadTree::build(&self.positions, self.config.max_tree_depth);
254 let params = self.create_params(tree.nodes().len() as u32);
255
256 let buffers = self.buffers.as_ref().unwrap();
258 if !buffers.update_tree(&self.ctx, tree.nodes()) {
259 return Ok(&self.positions);
260 }
261 buffers.update_params(&self.ctx, ¶ms);
262 }
263
264 self.dispatch_compute();
266 self.iteration += 1;
267
268 self.read_positions_blocking()?;
270
271 Ok(&self.positions)
272 }
273
274 #[cfg(target_arch = "wasm32")]
277 pub fn step(&mut self) -> Result<&[Position]> {
278 if self.state != LayoutState::Running {
279 return Ok(&self.positions);
280 }
281
282 if self.buffers.is_none() || self.bind_group.is_none() {
283 return Err(LayoutError::NotInitialized);
284 }
285
286 {
288 let mut async_pos = self.async_positions.borrow_mut();
289 if async_pos.pending_update && async_pos.data.len() == self.positions.len() {
290 self.positions.copy_from_slice(&async_pos.data);
291 async_pos.pending_update = false;
292 self.readback_pending = false;
293 }
294 }
295
296 self.dispatch_compute();
298 self.iteration += 1;
299 self.frame_counter += 1;
300
301 self.ctx.device.poll(wgpu::Maintain::Poll);
303
304 if self.frame_counter >= self._readback_interval && !self.readback_pending {
306 self.frame_counter = 0;
307 self.request_positions_async();
308 }
309
310 Ok(&self.positions)
311 }
312
313 fn dispatch_compute(&self) {
315 let buffers = self.buffers.as_ref().unwrap();
316 let bind_group = self.bind_group.as_ref().unwrap();
317
318 let mut encoder = self
319 .ctx
320 .device
321 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
322 label: Some("Layout Encoder"),
323 });
324
325 {
326 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
327 label: Some("Layout Compute Pass"),
328 timestamp_writes: None,
329 });
330 compute_pass.set_pipeline(&self.pipeline.pipeline);
331 compute_pass.set_bind_group(0, bind_group, &[]);
332
333 let workgroup_count = buffers.node_count.div_ceil(256);
335 compute_pass.dispatch_workgroups(workgroup_count, 1, 1);
336 }
337
338 self.ctx.queue.submit(Some(encoder.finish()));
339 }
340
341 #[cfg(not(target_arch = "wasm32"))]
343 fn read_positions_blocking(&mut self) -> Result<()> {
344 let buffers = self.buffers.as_ref().ok_or(LayoutError::NotInitialized)?;
345
346 let mut encoder = self
347 .ctx
348 .device
349 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
350 label: Some("Readback Encoder"),
351 });
352
353 let size = (self.positions.len() * std::mem::size_of::<Position>()) as u64;
354 encoder.copy_buffer_to_buffer(&buffers.positions, 0, &buffers.staging, 0, size);
355
356 self.ctx.queue.submit(Some(encoder.finish()));
357
358 let buffer_slice = buffers.staging.slice(..);
360 let (tx, rx) = std::sync::mpsc::channel();
361
362 buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
363 tx.send(result).unwrap();
364 });
365
366 self.ctx.device.poll(wgpu::Maintain::Wait);
367
368 rx.recv()
369 .map_err(|_| LayoutError::Readback("Channel closed".into()))?
370 .map_err(|e| LayoutError::Readback(e.to_string()))?;
371
372 {
373 let data = buffer_slice.get_mapped_range();
374 let positions: &[Position] = bytemuck::cast_slice(&data);
375 self.positions.copy_from_slice(positions);
376 }
377
378 buffers.staging.unmap();
379
380 Ok(())
381 }
382
383 #[cfg(target_arch = "wasm32")]
386 fn request_positions_async(&mut self) {
387 let Some(buffers) = &self.buffers else {
388 return;
389 };
390
391 self.readback_pending = true;
392
393 let mut encoder = self
394 .ctx
395 .device
396 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
397 label: Some("Readback Encoder"),
398 });
399
400 let size = (self.positions.len() * std::mem::size_of::<Position>()) as u64;
401 encoder.copy_buffer_to_buffer(&buffers.positions, 0, &buffers.staging, 0, size);
402
403 self.ctx.queue.submit(Some(encoder.finish()));
404
405 let staging = buffers.staging.clone();
407 let async_positions = Rc::clone(&self.async_positions);
408 let positions_len = self.positions.len();
409
410 buffers
414 .staging
415 .slice(..)
416 .map_async(wgpu::MapMode::Read, move |result| {
417 if result.is_ok() {
418 let buffer_slice = staging.slice(..);
420 let data = buffer_slice.get_mapped_range();
421 let positions: &[Position] = bytemuck::cast_slice(&data);
422
423 let mut async_pos = async_positions.borrow_mut();
424 if async_pos.data.len() != positions_len {
425 async_pos
426 .data
427 .resize(positions_len, Position { x: 0.0, y: 0.0 });
428 }
429 async_pos.data.copy_from_slice(positions);
430 async_pos.pending_update = true;
431
432 drop(data);
433 staging.unmap();
434 }
435 });
436 }
437
438 pub fn positions(&self) -> &[Position] {
440 &self.positions
441 }
442
443 pub fn set_config(&mut self, config: LayoutConfig) {
445 self.config = config;
446
447 if let Some(buffers) = &self.buffers {
448 #[cfg(not(target_arch = "wasm32"))]
449 let tree_size = if self.config.use_barnes_hut {
450 let tree = QuadTree::build(&self.positions, self.config.max_tree_depth);
451 tree.nodes().len() as u32
452 } else {
453 1
454 };
455
456 #[cfg(target_arch = "wasm32")]
457 let tree_size = 1u32;
458
459 let params = self.create_params(tree_size);
460 buffers.update_params(&self.ctx, ¶ms);
461 }
462 }
463}
464
465#[cfg(not(target_arch = "wasm32"))]
467pub mod sync {
468 use super::*;
469
470 pub fn _new_layout(config: LayoutConfig) -> Result<GpuLayout> {
472 pollster::block_on(GpuLayout::new(config))
473 }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration enables Barnes–Hut with a positive `theta`.
    #[test]
    fn test_layout_config_default() {
        let LayoutConfig {
            use_barnes_hut,
            theta,
            ..
        } = LayoutConfig::default();

        assert!(use_barnes_hut);
        assert!(theta > 0.0);
    }
}