// vibe_graph_layout_gpu/layout.rs

1//! High-level GPU layout interface.
2
3use crate::gpu::{GpuContext, LayoutBuffers, LayoutPipeline};
4use crate::quadtree::QuadTree;
5#[cfg(not(target_arch = "wasm32"))]
6use crate::shaders::FORCE_SHADER;
7use crate::shaders::SIMPLE_FORCE_SHADER;
8use crate::{Edge, LayoutError, LayoutParams, Position, Result};
9
10#[cfg(target_arch = "wasm32")]
11use std::cell::RefCell;
12#[cfg(target_arch = "wasm32")]
13use std::rc::Rc;
14
/// Configuration for the GPU layout.
///
/// All tuning knobs for the force-directed simulation. See
/// [`LayoutConfig::default`] for the baseline values.
#[derive(Debug, Clone)]
pub struct LayoutConfig {
    /// Time step per iteration (seconds of simulated time per `step()`).
    pub dt: f32,
    /// Damping factor (0-1). Higher values preserve more velocity between steps.
    pub damping: f32,
    /// Repulsion strength between all node pairs.
    pub repulsion: f32,
    /// Attraction strength along edges (spring constant).
    pub attraction: f32,
    /// Barnes-Hut theta (0.5-1.0). Larger values approximate more aggressively
    /// (faster, less accurate). Only used when `use_barnes_hut` is true.
    pub theta: f32,
    /// Center gravity strength pulling nodes toward the origin.
    pub gravity: f32,
    /// Ideal edge length (rest length of the edge springs).
    pub ideal_length: f32,
    /// Use Barnes-Hut (true) or simple O(n²) (false).
    /// On WASM the simple shader is always used regardless of this flag.
    pub use_barnes_hut: bool,
    /// Maximum quadtree depth when building the Barnes-Hut tree on the CPU.
    pub max_tree_depth: usize,
}
37
38impl Default for LayoutConfig {
39    fn default() -> Self {
40        Self {
41            dt: 0.016,
42            damping: 0.9,
43            repulsion: 1000.0,
44            attraction: 0.01,
45            theta: 0.8,
46            gravity: 0.1,
47            ideal_length: 50.0,
48            use_barnes_hut: true,
49            max_tree_depth: 12,
50        }
51    }
52}
53
/// Current state of the layout.
///
/// Transitions observed in this file: `Uninitialized` → `Paused` (via
/// `init`), then `Paused`/`Converged` ↔ `Running` (via `start`/`pause`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayoutState {
    /// Layout is not initialized (no graph data uploaded yet).
    Uninitialized,
    /// Layout is running; `step()` advances the simulation.
    Running,
    /// Layout is paused; `step()` returns current positions without computing.
    Paused,
    /// Layout has converged.
    // NOTE(review): nothing in this file ever sets `Converged`; presumably a
    // caller or future convergence check does — confirm.
    Converged,
}
66
/// Shared state for async position updates (native uses Mutex, WASM uses RefCell)
// NOTE(review): underscore-prefixed and never constructed anywhere in this
// file — the native path uses blocking readback instead. Dead code kept as a
// placeholder; consider removing or wiring up an async native path.
#[cfg(not(target_arch = "wasm32"))]
#[derive(Default)]
struct _AsyncPositions {
    // Latest positions copied back from the GPU staging buffer.
    data: Vec<Position>,
    // True when `data` holds an update not yet consumed by the simulation loop.
    pending_update: bool,
}
74
/// Shared state for async position updates (WASM).
///
/// Written by the `map_async` callback in `request_positions_async` and
/// drained by `step()`. A `RefCell` suffices because WASM is single-threaded.
#[cfg(target_arch = "wasm32")]
#[derive(Default)]
struct AsyncPositionsWasm {
    // Latest positions copied back from the GPU staging buffer.
    data: Vec<Position>,
    // True when `data` holds an update not yet consumed by `step()`.
    pending_update: bool,
}
82
/// GPU-accelerated force-directed graph layout.
///
/// Construct with [`GpuLayout::new`], upload a graph with `init`, then drive
/// the simulation by calling `step()` each frame. Native `step()` blocks on
/// GPU readback; WASM `step()` is non-blocking and returns positions that may
/// lag the simulation by up to `_readback_interval` frames.
pub struct GpuLayout {
    ctx: GpuContext,
    pipeline: LayoutPipeline,
    /// GPU-side buffers; `None` until `init` succeeds.
    buffers: Option<LayoutBuffers>,
    /// Bind group tying `buffers` to the pipeline; `None` until `init`.
    bind_group: Option<wgpu::BindGroup>,
    config: LayoutConfig,
    state: LayoutState,
    /// CPU-side copy of node positions (latest readback).
    positions: Vec<Position>,
    edges: Vec<Edge>,
    /// Number of simulation steps dispatched since `init`.
    iteration: u32,
    /// Frame counter for periodic readback (WASM path).
    frame_counter: u32,
    /// How often to read back positions (every N frames). Currently 1, i.e.
    /// every frame; only consulted by the WASM `step()`.
    _readback_interval: u32,
    /// Shared state for async updates (WASM) - uses RefCell since WASM is single-threaded
    #[cfg(target_arch = "wasm32")]
    async_positions: Rc<RefCell<AsyncPositionsWasm>>,
    /// Whether a readback is currently in progress (WASM)
    #[cfg(target_arch = "wasm32")]
    readback_pending: bool,
}
105
impl GpuLayout {
    /// Create a new GPU layout engine.
    ///
    /// Acquires a GPU context and compiles the force-computation pipeline.
    /// On WASM the simple O(n²) shader is always selected (the Barnes-Hut
    /// shader depends on a CPU-built quadtree, which WASM skips); on native
    /// the shader follows `config.use_barnes_hut`.
    ///
    /// # Errors
    /// Propagates failures from GPU context or pipeline creation.
    pub async fn new(config: LayoutConfig) -> Result<Self> {
        let ctx = GpuContext::new().await?;

        // For WASM, always use simple shader (no CPU quadtree dependency)
        #[cfg(target_arch = "wasm32")]
        let shader = SIMPLE_FORCE_SHADER;

        #[cfg(not(target_arch = "wasm32"))]
        let shader = if config.use_barnes_hut {
            FORCE_SHADER
        } else {
            SIMPLE_FORCE_SHADER
        };

        let pipeline = LayoutPipeline::new(&ctx, shader)?;

        Ok(Self {
            ctx,
            pipeline,
            buffers: None,
            bind_group: None,
            config,
            state: LayoutState::Uninitialized,
            positions: Vec::new(),
            edges: Vec::new(),
            iteration: 0,
            frame_counter: 0,
            // NOTE(review): interval 1 means positions are read back every
            // frame on WASM; the previous comment here claimed "every 30
            // frames (~0.5s at 60fps)", which did not match the value. Raise
            // this if per-frame readback proves too slow — confirm intent.
            _readback_interval: 1,
            #[cfg(target_arch = "wasm32")]
            async_positions: Rc::new(RefCell::new(AsyncPositionsWasm::default())),
            #[cfg(target_arch = "wasm32")]
            readback_pending: false,
        })
    }

    /// Initialize the layout with graph data.
    ///
    /// Uploads positions, edges, and (native + Barnes-Hut only) an initial
    /// quadtree to the GPU, then leaves the layout `Paused` so the caller can
    /// `start()` it. Resets iteration/frame counters and any stale WASM
    /// readback state.
    ///
    /// # Errors
    /// Returns `LayoutError::InvalidGraph` when `positions` is empty;
    /// propagates GPU buffer-creation failures.
    pub fn init(&mut self, positions: Vec<Position>, edges: Vec<Edge>) -> Result<()> {
        if positions.is_empty() {
            return Err(LayoutError::InvalidGraph("No nodes".into()));
        }

        self.positions = positions;
        self.edges = edges;

        // Build initial quadtree (only used on native with Barnes-Hut)
        #[cfg(not(target_arch = "wasm32"))]
        let tree = if self.config.use_barnes_hut {
            QuadTree::build(&self.positions, self.config.max_tree_depth)
        } else {
            // Minimal placeholder tree so the buffer layout stays uniform.
            QuadTree::build(&[], 1)
        };

        #[cfg(target_arch = "wasm32")]
        let tree = QuadTree::build(&[], 1); // Empty tree for WASM

        // Create GPU buffers
        let buffers = LayoutBuffers::new(&self.ctx, &self.positions, &self.edges, tree.nodes())?;

        // Update params
        let params = self.create_params(tree.nodes().len() as u32);
        buffers.update_params(&self.ctx, &params);

        // Create bind group
        let bind_group = self.pipeline.create_bind_group(&self.ctx, &buffers);

        self.buffers = Some(buffers);
        self.bind_group = Some(bind_group);
        self.state = LayoutState::Paused;
        self.iteration = 0;
        self.frame_counter = 0;

        #[cfg(target_arch = "wasm32")]
        {
            // Reset async readback state and seed the shared buffer with the
            // initial positions so a stale pending update from a previous
            // graph cannot leak into this one.
            self.readback_pending = false;
            let mut async_pos = self.async_positions.borrow_mut();
            async_pos.data = self.positions.clone();
            async_pos.pending_update = false;
        }

        tracing::info!(
            "GPU layout initialized: {} nodes, {} edges",
            self.positions.len(),
            self.edges.len()
        );

        Ok(())
    }

    /// Create layout params struct from the current config and graph sizes.
    fn create_params(&self, tree_size: u32) -> LayoutParams {
        LayoutParams {
            node_count: self.positions.len() as u32,
            edge_count: self.edges.len() as u32,
            tree_size,
            dt: self.config.dt,
            damping: self.config.damping,
            repulsion: self.config.repulsion,
            attraction: self.config.attraction,
            theta: self.config.theta,
            gravity: self.config.gravity,
            ideal_length: self.config.ideal_length,
        }
    }

    /// Start the layout.
    ///
    /// No-op before `init` (state `Uninitialized`); otherwise transitions to
    /// `Running`, including resuming from `Paused` or restarting after
    /// `Converged`.
    pub fn start(&mut self) {
        if self.state != LayoutState::Uninitialized {
            self.state = LayoutState::Running;
        }
    }

    /// Pause the layout. Only has an effect while `Running`.
    pub fn pause(&mut self) {
        if self.state == LayoutState::Running {
            self.state = LayoutState::Paused;
        }
    }

    /// Get current state.
    pub fn state(&self) -> LayoutState {
        self.state
    }

    /// Get current iteration count.
    pub fn iteration(&self) -> u32 {
        self.iteration
    }

    /// Run one iteration of the layout algorithm (native - blocking).
    /// Returns the updated positions.
    ///
    /// With Barnes-Hut enabled this reads positions back, rebuilds the
    /// quadtree on the CPU, and re-uploads it before dispatching; this costs
    /// two blocking readbacks per step.
    ///
    /// # Errors
    /// `LayoutError::NotInitialized` if `init` has not succeeded; readback
    /// errors from the staging-buffer map.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn step(&mut self) -> Result<&[Position]> {
        // Not running: return current positions without simulating.
        if self.state != LayoutState::Running {
            return Ok(&self.positions);
        }

        if self.buffers.is_none() || self.bind_group.is_none() {
            return Err(LayoutError::NotInitialized);
        }

        // Rebuild quadtree with current positions (CPU side)
        if self.config.use_barnes_hut {
            // Read back positions from GPU first
            self.read_positions_blocking()?;

            let tree = QuadTree::build(&self.positions, self.config.max_tree_depth);
            let params = self.create_params(tree.nodes().len() as u32);

            // Update buffers
            let buffers = self.buffers.as_ref().unwrap();
            // NOTE(review): `update_tree` returning false (presumably the
            // tree outgrew the GPU buffer) silently skips this step — the
            // caller cannot distinguish this from a normal step. Confirm
            // whether this should surface as an error or trigger a resize.
            if !buffers.update_tree(&self.ctx, tree.nodes()) {
                return Ok(&self.positions);
            }
            buffers.update_params(&self.ctx, &params);
        }

        // Run compute shader
        self.dispatch_compute();
        self.iteration += 1;

        // Read back positions
        self.read_positions_blocking()?;

        Ok(&self.positions)
    }

    /// Run one iteration of the layout algorithm (WASM - non-blocking).
    /// Returns the updated positions (may be stale by up to readback_interval frames).
    ///
    /// # Errors
    /// `LayoutError::NotInitialized` if `init` has not succeeded.
    #[cfg(target_arch = "wasm32")]
    pub fn step(&mut self) -> Result<&[Position]> {
        if self.state != LayoutState::Running {
            return Ok(&self.positions);
        }

        if self.buffers.is_none() || self.bind_group.is_none() {
            return Err(LayoutError::NotInitialized);
        }

        // Check for async position updates (using RefCell for WASM).
        // Length must match to guard against an in-flight readback from a
        // previous graph after `init` was called again.
        {
            let mut async_pos = self.async_positions.borrow_mut();
            if async_pos.pending_update && async_pos.data.len() == self.positions.len() {
                self.positions.copy_from_slice(&async_pos.data);
                async_pos.pending_update = false;
                self.readback_pending = false;
            }
        }

        // Run compute shader (non-blocking on WASM)
        self.dispatch_compute();
        self.iteration += 1;
        self.frame_counter += 1;

        // Poll GPU (non-blocking) - always poll to process async callbacks
        self.ctx.device.poll(wgpu::Maintain::Poll);

        // Periodically request position readback
        if self.frame_counter >= self._readback_interval && !self.readback_pending {
            self.frame_counter = 0;
            self.request_positions_async();
        }

        Ok(&self.positions)
    }

    /// Dispatch the compute shader for one simulation step.
    ///
    /// # Panics
    /// Panics if called before `init` (unwraps `buffers`/`bind_group`);
    /// both `step` implementations check this first.
    fn dispatch_compute(&self) {
        let buffers = self.buffers.as_ref().unwrap();
        let bind_group = self.bind_group.as_ref().unwrap();

        let mut encoder = self
            .ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Layout Encoder"),
            });

        {
            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("Layout Compute Pass"),
                timestamp_writes: None,
            });
            compute_pass.set_pipeline(&self.pipeline.pipeline);
            compute_pass.set_bind_group(0, bind_group, &[]);

            // Dispatch enough workgroups to cover all nodes.
            // NOTE(review): 256 must match the shader's @workgroup_size —
            // confirm against FORCE_SHADER/SIMPLE_FORCE_SHADER and consider
            // a shared constant.
            let workgroup_count = buffers.node_count.div_ceil(256);
            compute_pass.dispatch_workgroups(workgroup_count, 1, 1);
        }

        self.ctx.queue.submit(Some(encoder.finish()));
    }

    /// Read positions back from GPU (blocking - native only).
    ///
    /// Copies the positions buffer into the staging buffer, maps it, waits
    /// for the device, and copies the bytes into `self.positions`.
    ///
    /// # Errors
    /// `LayoutError::NotInitialized` before `init`; `LayoutError::Readback`
    /// if the map callback channel closes or the map itself fails.
    #[cfg(not(target_arch = "wasm32"))]
    fn read_positions_blocking(&mut self) -> Result<()> {
        let buffers = self.buffers.as_ref().ok_or(LayoutError::NotInitialized)?;

        let mut encoder = self
            .ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Readback Encoder"),
            });

        let size = (self.positions.len() * std::mem::size_of::<Position>()) as u64;
        encoder.copy_buffer_to_buffer(&buffers.positions, 0, &buffers.staging, 0, size);

        self.ctx.queue.submit(Some(encoder.finish()));

        // Map the staging buffer; the channel turns the async map callback
        // into a blocking wait.
        let buffer_slice = buffers.staging.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();

        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
            tx.send(result).unwrap();
        });

        // Block until the GPU finishes and the map callback has fired.
        self.ctx.device.poll(wgpu::Maintain::Wait);

        rx.recv()
            .map_err(|_| LayoutError::Readback("Channel closed".into()))?
            .map_err(|e| LayoutError::Readback(e.to_string()))?;

        {
            // Scope the mapped range so it is dropped before unmap().
            let data = buffer_slice.get_mapped_range();
            let positions: &[Position] = bytemuck::cast_slice(&data);
            self.positions.copy_from_slice(positions);
        }

        buffers.staging.unmap();

        Ok(())
    }

    /// Request positions asynchronously (WASM - non-blocking).
    /// The positions will be updated in `self.positions` when ready.
    ///
    /// Sets `readback_pending` so `step()` does not issue overlapping
    /// readbacks; the flag is cleared when `step()` consumes the update.
    #[cfg(target_arch = "wasm32")]
    fn request_positions_async(&mut self) {
        let Some(buffers) = &self.buffers else {
            return;
        };

        self.readback_pending = true;

        let mut encoder = self
            .ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Readback Encoder"),
            });

        let size = (self.positions.len() * std::mem::size_of::<Position>()) as u64;
        encoder.copy_buffer_to_buffer(&buffers.positions, 0, &buffers.staging, 0, size);

        self.ctx.queue.submit(Some(encoder.finish()));

        // Clone staging buffer handle for callback
        let staging = buffers.staging.clone();
        let async_positions = Rc::clone(&self.async_positions);
        let positions_len = self.positions.len();

        // Map the staging buffer with async callback
        // wgpu::Buffer is internally Arc-wrapped, so we can safely move it into the closure
        // We use RefCell for WASM since it's single-threaded
        // NOTE(review): if `result` is an Err, `readback_pending` is never
        // cleared (it is only reset when step() consumes a pending update),
        // so all future readbacks would be suppressed. Consider routing an
        // error flag through AsyncPositionsWasm — confirm and fix.
        buffers
            .staging
            .slice(..)
            .map_async(wgpu::MapMode::Read, move |result| {
                if result.is_ok() {
                    // Create a new slice from the cloned buffer inside the callback
                    let buffer_slice = staging.slice(..);
                    let data = buffer_slice.get_mapped_range();
                    let positions: &[Position] = bytemuck::cast_slice(&data);

                    let mut async_pos = async_positions.borrow_mut();
                    if async_pos.data.len() != positions_len {
                        async_pos
                            .data
                            .resize(positions_len, Position { x: 0.0, y: 0.0 });
                    }
                    async_pos.data.copy_from_slice(positions);
                    async_pos.pending_update = true;

                    // Drop the mapped view before unmapping the buffer.
                    drop(data);
                    staging.unmap();
                }
            });
    }

    /// Get current positions (without GPU readback).
    pub fn positions(&self) -> &[Position] {
        &self.positions
    }

    /// Update configuration.
    ///
    /// Re-uploads the params uniform if buffers exist. On native with
    /// Barnes-Hut the quadtree is rebuilt here only to compute `tree_size`
    /// for the params; the nodes themselves are re-uploaded by the next
    /// `step()`.
    // NOTE(review): toggling `use_barnes_hut` here does NOT swap the compute
    // shader chosen in `new()` — the pipeline keeps whichever shader it was
    // built with. Confirm whether set_config should rebuild the pipeline.
    pub fn set_config(&mut self, config: LayoutConfig) {
        self.config = config;

        if let Some(buffers) = &self.buffers {
            #[cfg(not(target_arch = "wasm32"))]
            let tree_size = if self.config.use_barnes_hut {
                let tree = QuadTree::build(&self.positions, self.config.max_tree_depth);
                tree.nodes().len() as u32
            } else {
                1
            };

            #[cfg(target_arch = "wasm32")]
            let tree_size = 1u32;

            let params = self.create_params(tree_size);
            buffers.update_params(&self.ctx, &params);
        }
    }
}
464
465/// Synchronous wrapper for environments without async runtime (native only).
466#[cfg(not(target_arch = "wasm32"))]
467pub mod sync {
468    use super::*;
469
470    /// Create a new GPU layout synchronously.
471    pub fn _new_layout(config: LayoutConfig) -> Result<GpuLayout> {
472        pollster::block_on(GpuLayout::new(config))
473    }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Pin every default so accidental tuning changes are caught by CI,
    /// not just the two loosest properties.
    #[test]
    fn test_layout_config_default() {
        let config = LayoutConfig::default();
        assert!(config.use_barnes_hut);
        assert!(config.theta > 0.0);
        // Exact baseline values from `LayoutConfig::default`.
        assert_eq!(config.dt, 0.016);
        assert_eq!(config.damping, 0.9);
        assert_eq!(config.repulsion, 1000.0);
        assert_eq!(config.attraction, 0.01);
        assert_eq!(config.theta, 0.8);
        assert_eq!(config.gravity, 0.1);
        assert_eq!(config.ideal_length, 50.0);
        assert_eq!(config.max_tree_depth, 12);
    }
}