Skip to main content

sift/
gpu_sift.rs

1// gpu_sift/mod.rs
2// WebGPU-based SIFT implementation
3
4use crate::keypoints::KeyPoint;
5use crate::utils::*;
6use std::sync::{Arc, Mutex};
7use wgpu;
8
9// ===== Configuration =====
/// Tunable parameters for the GPU SIFT detector.
///
/// `Debug` and `Clone` are derived so callers can log a configuration and
/// keep their own copy after handing one to `GpuSiftContext::new`.
#[derive(Debug, Clone)]
pub struct GpuSiftConfig {
    /// Upper bound on the number of pyramid octaves. `build_pyramid_cpu`
    /// stops earlier once an octave would be under 8 pixels on a side.
    pub octaves: u32,
    /// Gaussian images per octave, extra scales included (e.g., 5 total,
    /// which `Default` describes as yielding 4 DoG layers).
    pub scales: u32,
    /// Blur of the first scale of an octave (e.g., 1.6).
    pub base_sigma: f32,
    /// Contrast threshold (e.g., 0.03; `Default` uses 0.04).
    /// NOTE(review): presumably consumed by the extrema stage — not visible here.
    pub contrast_threshold: f32,
    /// Edge-response threshold (e.g., 10.0).
    /// NOTE(review): presumably consumed by the extrema stage — not visible here.
    pub edge_threshold: f32,
}
17
18impl Default for GpuSiftConfig {
19    fn default() -> Self {
20        Self {
21            octaves: 4,
22            scales: 5, // 5 scales → 4 DoG layers
23            base_sigma: 1.6,
24            contrast_threshold: 0.04, // Match CPU default (0.04 / num_intervals)
25            edge_threshold: 10.0,
26        }
27    }
28}
29
30// ===== GPU Resources =====
/// Owns the wgpu device and queue, the compute pipelines, and the GPU buffers
/// used by the SIFT detector. Built by `GpuSiftContext::new`; `detect` is the
/// entry point for processing an image.
pub struct GpuSiftContext {
    // Device obtained from the adapter in `new`.
    device: Arc<wgpu::Device>,
    // Queue returned together with `device`; used for buffer writes.
    queue: Arc<wgpu::Queue>,
    // Compute pipelines, created once by `create_pipelines`.
    pipelines: GpuPipelines,
    // Host-side Gaussian weights from `compute_kernels`; `new` also hands them
    // to `initialize_kernel_buffers`.
    #[allow(dead_code)]
    kernels: GpuKernels,
    // Buffer set behind a mutex. `detect` holds the lock only inside short
    // scoped blocks that end before any `.await`.
    buffers: Mutex<GpuSiftBuffers>,
    // NOTE(review): `detect` and `build_pyramid_cpu` do read this field, so the
    // `dead_code` allowance looks stale — confirm before removing it.
    #[allow(dead_code)]
    config: GpuSiftConfig,
}
41
/// One compute pipeline per shader stage, all created in `create_pipelines`.
/// NOTE(review): the `dead_code` allowance is presumably there because the
/// hybrid CPU path leaves some stages undispatched — confirm against
/// `execute_pipeline`.
#[allow(dead_code)]
struct GpuPipelines {
    // `upload_grayscale` entry point of shaders/upload.wgsl.
    upload: wgpu::ComputePipeline,
    // `gaussian_blur` entry point of shaders/gaussian_blur.wgsl.
    blur_h: wgpu::ComputePipeline,
    // Built from the same module, entry point and layout as `blur_h`.
    // NOTE(review): the blur direction is presumably selected through the
    // uniform params — confirm in the shader.
    blur_v: wgpu::ComputePipeline,
    // `downsample` entry point of shaders/downsample.wgsl; reuses the blur
    // pipeline layout.
    downsample: wgpu::ComputePipeline,
    // `compute_dog` entry point of shaders/dog.wgsl.
    dog: wgpu::ComputePipeline,
    // `detect_extrema` entry point of shaders/extrema_detect.wgsl.
    extrema: wgpu::ComputePipeline,
    // `compute_orientation` entry point of shaders/orientation.wgsl; created
    // with an automatic layout (`layout: None`).
    orientation: wgpu::ComputePipeline,
    // `compute_descriptor` entry point of shaders/descriptor.wgsl; created
    // with an automatic layout (`layout: None`).
    descriptor: wgpu::ComputePipeline,
}
53
/// Host-side Gaussian weights produced by `compute_kernels`, one weight
/// vector per scale index.
#[allow(dead_code)]
struct GpuKernels {
    // Precomputed Gaussian kernels for each scale
    // (each vector is divided by its sum in `compute_kernels`).
    kernels: Vec<Vec<f32>>, // kernels[scale_idx] = weights
}
59
/// GPU buffers for one detector instance. `GpuSiftContext` keeps this behind
/// a mutex; `detect` resizes it through `ensure_capacity` and clones the
/// handles it needs into a `GpuRunContext`.
struct GpuSiftBuffers {
    // Pyramid heap
    // (`upload_image` also stages raw u8 pixels at the end of this buffer).
    heap: wgpu::Buffer,
    // NOTE(review): presumably the allocated size of `heap` in bytes — the
    // code that sets it is not visible here.
    heap_capacity: u64,

    // Metadata
    // NOTE(review): the level_* buffers presumably hold per-pyramid-level
    // offsets and dimensions — confirm against the shaders.
    meta_buffer: wgpu::Buffer,
    level_offsets: wgpu::Buffer,
    level_widths: wgpu::Buffer,
    level_heights: wgpu::Buffer,

    // Gaussian kernel weights (one buffer per scale)
    #[allow(dead_code)]
    kernel_buffers: Vec<wgpu::Buffer>,

    // Keypoint buffers
    extrema_counter: wgpu::Buffer,
    keypoints_staging: wgpu::Buffer,
    orientation_counter: wgpu::Buffer,
    keypoints_final: wgpu::Buffer,
    descriptors: wgpu::Buffer,

    // Readback buffers
    readback_counters: wgpu::Buffer,
    readback_keypoints: wgpu::Buffer,
    readback_descriptors: wgpu::Buffer,

    // Current image dimensions
    current_width: u32,
    current_height: u32,
}
91
// Lightweight clone for async operations (no mutex held)
//
// `detect` fills this from the locked `GpuSiftBuffers` by cloning each buffer
// handle, then passes it to the upload, pipeline and readback steps so the
// mutex is not held across any `.await`. Field names match the
// `GpuSiftBuffers` fields they are cloned from.
struct GpuRunContext {
    heap: wgpu::Buffer,
    meta_buffer: wgpu::Buffer,
    level_offsets: wgpu::Buffer,
    level_widths: wgpu::Buffer,
    level_heights: wgpu::Buffer,
    #[allow(dead_code)]
    kernel_buffers: Vec<wgpu::Buffer>,
    extrema_counter: wgpu::Buffer,
    keypoints_staging: wgpu::Buffer,
    orientation_counter: wgpu::Buffer,
    keypoints_final: wgpu::Buffer,
    descriptors: wgpu::Buffer,
}
107
108// ===== Public API =====
109impl GpuSiftContext {
110    pub async fn new(config: GpuSiftConfig) -> Result<Self, Box<dyn std::error::Error>> {
111        // Request WebGPU device/queue
112        let instance = wgpu::Instance::default();
113        let adapter = instance
114            .request_adapter(&wgpu::RequestAdapterOptions::default())
115            .await;
116
117        let adapter = match adapter {
118            Ok(a) => a,
119            Err(_) => return Err("No suitable GPU adapter found".into()),
120        };
121
122        let (device, queue) = adapter
123            .request_device(&wgpu::DeviceDescriptor {
124                label: Some("SIFT GPU Device"),
125                required_features: wgpu::Features::empty(),
126                required_limits: wgpu::Limits::default(),
127                memory_hints: Default::default(),
128                trace: Default::default(),
129            })
130            .await?;
131
132        let device = Arc::new(device);
133        let queue = Arc::new(queue);
134
135        // Precompute Gaussian kernels
136        let kernels = Self::compute_kernels(&config);
137
138        // Create compute pipelines
139        let pipelines = Self::create_pipelines(&device)?;
140
141        // Initialize empty buffers
142        let mut buffers = GpuSiftBuffers::new(&device, 0, 0);
143
144        // Initialize kernel weight buffers
145        buffers.initialize_kernel_buffers(&device, &queue, &kernels);
146
147        let buffers = Mutex::new(buffers);
148
149        Ok(Self {
150            device,
151            queue,
152            pipelines,
153            kernels,
154            buffers,
155            config,
156        })
157    }
158
159    pub async fn detect(
160        &self,
161        image: &[u8],
162        width: u32,
163        height: u32,
164    ) -> Result<(Vec<KeyPoint>, Vec<[u8; 128]>), Box<dyn std::error::Error>> {
165        let profile = std::env::var("SIFT_PROFILE").is_ok();
166        let total_start = web_time::Instant::now();
167
168        // 1. Ensure buffers are sized correctly
169        let t0 = web_time::Instant::now();
170        {
171            let mut buffers = self.buffers.lock().unwrap();
172            buffers.ensure_capacity(&self.device, width, height, &self.config);
173        }
174        if profile {
175            eprintln!("  [GPU] Buffer setup: {:?}", t0.elapsed());
176        }
177
178        // 2. Clone buffer handles (release lock before async)
179        let run_ctx = {
180            let buffers = self.buffers.lock().unwrap();
181            GpuRunContext {
182                heap: buffers.heap.clone(),
183                meta_buffer: buffers.meta_buffer.clone(),
184                level_offsets: buffers.level_offsets.clone(),
185                level_widths: buffers.level_widths.clone(),
186                level_heights: buffers.level_heights.clone(),
187                kernel_buffers: buffers.kernel_buffers.clone(),
188                extrema_counter: buffers.extrema_counter.clone(),
189                keypoints_staging: buffers.keypoints_staging.clone(),
190                orientation_counter: buffers.orientation_counter.clone(),
191                keypoints_final: buffers.keypoints_final.clone(),
192                descriptors: buffers.descriptors.clone(),
193            }
194        };
195
196        // 3. Build DoG pyramid on CPU (hybrid approach for now)
197        let t1 = web_time::Instant::now();
198        let gaussian_pyramid = self.build_pyramid_cpu(image, width, height);
199        if profile {
200            eprintln!("  [GPU] Gaussian pyramid (CPU): {:?}", t1.elapsed());
201        }
202
203        let t2 = web_time::Instant::now();
204        let dog_pyramid = self.compute_dog_cpu(&gaussian_pyramid, width, height);
205        if profile {
206            eprintln!("  [GPU] DoG computation (CPU): {:?}", t2.elapsed());
207        }
208
209        // 4. Upload DoG pyramid to GPU
210        let t3 = web_time::Instant::now();
211        self.upload_dog_pyramid(&dog_pyramid, &run_ctx);
212        if profile {
213            eprintln!("  [GPU] Upload to GPU: {:?}", t3.elapsed());
214        }
215
216        // 5. Execute GPU pipeline (extrema detection, orientation, descriptors)
217        let t4 = web_time::Instant::now();
218        self.execute_pipeline(width, height, &run_ctx).await?;
219        if profile {
220            eprintln!("  [GPU] GPU pipeline: {:?}", t4.elapsed());
221        }
222
223        // 6. Readback results
224        let t5 = web_time::Instant::now();
225        let (keypoints, descriptors) = self.readback_results(&run_ctx).await?;
226        if profile {
227            eprintln!("  [GPU] Readback: {:?}", t5.elapsed());
228            eprintln!("  [GPU] Total: {:?}", total_start.elapsed());
229        }
230
231        Ok((keypoints, descriptors))
232    }
233
234    fn compute_kernels(config: &GpuSiftConfig) -> GpuKernels {
235        let mut kernels = Vec::new();
236        let k = 2.0_f32.powf(1.0 / (config.scales as f32 - 2.0));
237
238        for s in 0..config.scales {
239            let sigma = config.base_sigma * k.powi(s as i32);
240            let radius = (4.0 * sigma).ceil() as usize;
241            let size = 2 * radius + 1;
242
243            let mut weights = vec![0.0; size];
244            let two_sigma_sq = 2.0 * sigma * sigma;
245            let mut sum = 0.0;
246
247            for (i, weight) in weights.iter_mut().enumerate() {
248                let x = (i as f32) - (radius as f32);
249                *weight = (-x * x / two_sigma_sq).exp();
250                sum += *weight;
251            }
252
253            // Normalize
254            for weight in weights.iter_mut() {
255                *weight /= sum;
256            }
257
258            kernels.push(weights);
259        }
260
261        GpuKernels { kernels }
262    }
263
264    fn create_pipelines(device: &wgpu::Device) -> Result<GpuPipelines, Box<dyn std::error::Error>> {
265        // Load shaders
266        let upload_src = include_str!("shaders/upload.wgsl");
267        let blur_src = include_str!("shaders/gaussian_blur.wgsl");
268        let downsample_src = include_str!("shaders/downsample.wgsl");
269        let dog_src = include_str!("shaders/dog.wgsl");
270        let extrema_src = include_str!("shaders/extrema_detect.wgsl");
271        let orientation_src = include_str!("shaders/orientation.wgsl");
272        let descriptor_src = include_str!("shaders/descriptor.wgsl");
273
274        let upload_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
275            label: Some("Upload Shader"),
276            source: wgpu::ShaderSource::Wgsl(upload_src.into()),
277        });
278
279        let blur_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
280            label: Some("Blur Shader"),
281            source: wgpu::ShaderSource::Wgsl(blur_src.into()),
282        });
283
284        let downsample_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
285            label: Some("Downsample Shader"),
286            source: wgpu::ShaderSource::Wgsl(downsample_src.into()),
287        });
288
289        let dog_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
290            label: Some("DoG Shader"),
291            source: wgpu::ShaderSource::Wgsl(dog_src.into()),
292        });
293
294        let extrema_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
295            label: Some("Extrema Shader"),
296            source: wgpu::ShaderSource::Wgsl(extrema_src.into()),
297        });
298
299        let orientation_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
300            label: Some("Orientation Shader"),
301            source: wgpu::ShaderSource::Wgsl(orientation_src.into()),
302        });
303
304        let descriptor_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
305            label: Some("Descriptor Shader"),
306            source: wgpu::ShaderSource::Wgsl(descriptor_src.into()),
307        });
308
309        // Create bind group layouts
310        // Upload pipeline: @group(0) params (uniform), input_u8, heap
311        let upload_bgl0 = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
312            label: Some("Upload BGL 0"),
313            entries: &[
314                wgpu::BindGroupLayoutEntry {
315                    binding: 0,
316                    visibility: wgpu::ShaderStages::COMPUTE,
317                    ty: wgpu::BindingType::Buffer {
318                        ty: wgpu::BufferBindingType::Uniform,
319                        has_dynamic_offset: false,
320                        min_binding_size: None,
321                    },
322                    count: None,
323                },
324                wgpu::BindGroupLayoutEntry {
325                    binding: 1,
326                    visibility: wgpu::ShaderStages::COMPUTE,
327                    ty: wgpu::BindingType::Buffer {
328                        ty: wgpu::BufferBindingType::Storage { read_only: true },
329                        has_dynamic_offset: false,
330                        min_binding_size: None,
331                    },
332                    count: None,
333                },
334                wgpu::BindGroupLayoutEntry {
335                    binding: 2,
336                    visibility: wgpu::ShaderStages::COMPUTE,
337                    ty: wgpu::BindingType::Buffer {
338                        ty: wgpu::BufferBindingType::Storage { read_only: false },
339                        has_dynamic_offset: false,
340                        min_binding_size: None,
341                    },
342                    count: None,
343                },
344            ],
345        });
346
347        // Blur/downsample: @group(0) heap_in/out, @group(1) params + weights
348        let blur_bgl0 = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
349            label: Some("Blur BGL 0"),
350            entries: &[
351                wgpu::BindGroupLayoutEntry {
352                    binding: 0,
353                    visibility: wgpu::ShaderStages::COMPUTE,
354                    ty: wgpu::BindingType::Buffer {
355                        ty: wgpu::BufferBindingType::Storage { read_only: true },
356                        has_dynamic_offset: false,
357                        min_binding_size: None,
358                    },
359                    count: None,
360                },
361                wgpu::BindGroupLayoutEntry {
362                    binding: 1,
363                    visibility: wgpu::ShaderStages::COMPUTE,
364                    ty: wgpu::BindingType::Buffer {
365                        ty: wgpu::BufferBindingType::Storage { read_only: false },
366                        has_dynamic_offset: false,
367                        min_binding_size: None,
368                    },
369                    count: None,
370                },
371            ],
372        });
373
374        let blur_bgl1 = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
375            label: Some("Blur BGL 1"),
376            entries: &[
377                wgpu::BindGroupLayoutEntry {
378                    binding: 0,
379                    visibility: wgpu::ShaderStages::COMPUTE,
380                    ty: wgpu::BindingType::Buffer {
381                        ty: wgpu::BufferBindingType::Uniform,
382                        has_dynamic_offset: false,
383                        min_binding_size: None,
384                    },
385                    count: None,
386                },
387                wgpu::BindGroupLayoutEntry {
388                    binding: 1,
389                    visibility: wgpu::ShaderStages::COMPUTE,
390                    ty: wgpu::BindingType::Buffer {
391                        ty: wgpu::BufferBindingType::Storage { read_only: true },
392                        has_dynamic_offset: false,
393                        min_binding_size: None,
394                    },
395                    count: None,
396                },
397            ],
398        });
399
400        // DoG: @group(0) header, params, level_offsets, level_widths, level_heights, heap_in
401        //      @group(1) heap_out
402        let dog_bgl0 = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
403            label: Some("DoG BGL 0"),
404            entries: &[
405                wgpu::BindGroupLayoutEntry {
406                    binding: 0,
407                    visibility: wgpu::ShaderStages::COMPUTE,
408                    ty: wgpu::BindingType::Buffer {
409                        ty: wgpu::BufferBindingType::Uniform,
410                        has_dynamic_offset: false,
411                        min_binding_size: None,
412                    },
413                    count: None,
414                },
415                wgpu::BindGroupLayoutEntry {
416                    binding: 1,
417                    visibility: wgpu::ShaderStages::COMPUTE,
418                    ty: wgpu::BindingType::Buffer {
419                        ty: wgpu::BufferBindingType::Uniform,
420                        has_dynamic_offset: false,
421                        min_binding_size: None,
422                    },
423                    count: None,
424                },
425                wgpu::BindGroupLayoutEntry {
426                    binding: 2,
427                    visibility: wgpu::ShaderStages::COMPUTE,
428                    ty: wgpu::BindingType::Buffer {
429                        ty: wgpu::BufferBindingType::Storage { read_only: true },
430                        has_dynamic_offset: false,
431                        min_binding_size: None,
432                    },
433                    count: None,
434                },
435                wgpu::BindGroupLayoutEntry {
436                    binding: 3,
437                    visibility: wgpu::ShaderStages::COMPUTE,
438                    ty: wgpu::BindingType::Buffer {
439                        ty: wgpu::BufferBindingType::Storage { read_only: true },
440                        has_dynamic_offset: false,
441                        min_binding_size: None,
442                    },
443                    count: None,
444                },
445                wgpu::BindGroupLayoutEntry {
446                    binding: 4,
447                    visibility: wgpu::ShaderStages::COMPUTE,
448                    ty: wgpu::BindingType::Buffer {
449                        ty: wgpu::BufferBindingType::Storage { read_only: true },
450                        has_dynamic_offset: false,
451                        min_binding_size: None,
452                    },
453                    count: None,
454                },
455                wgpu::BindGroupLayoutEntry {
456                    binding: 5,
457                    visibility: wgpu::ShaderStages::COMPUTE,
458                    ty: wgpu::BindingType::Buffer {
459                        ty: wgpu::BufferBindingType::Storage { read_only: true },
460                        has_dynamic_offset: false,
461                        min_binding_size: None,
462                    },
463                    count: None,
464                },
465            ],
466        });
467
468        let dog_bgl1 = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
469            label: Some("DoG BGL 1"),
470            entries: &[wgpu::BindGroupLayoutEntry {
471                binding: 0,
472                visibility: wgpu::ShaderStages::COMPUTE,
473                ty: wgpu::BindingType::Buffer {
474                    ty: wgpu::BufferBindingType::Storage { read_only: false },
475                    has_dynamic_offset: false,
476                    min_binding_size: None,
477                },
478                count: None,
479            }],
480        });
481
482        // Extrema: @group(0) meta, @group(1) heap, @group(2) output
483        let extrema_bgl0 = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
484            label: Some("Extrema BGL 0"),
485            entries: &[
486                wgpu::BindGroupLayoutEntry {
487                    binding: 0,
488                    visibility: wgpu::ShaderStages::COMPUTE,
489                    ty: wgpu::BindingType::Buffer {
490                        ty: wgpu::BufferBindingType::Storage { read_only: true },
491                        has_dynamic_offset: false,
492                        min_binding_size: None,
493                    },
494                    count: None,
495                },
496                wgpu::BindGroupLayoutEntry {
497                    binding: 1,
498                    visibility: wgpu::ShaderStages::COMPUTE,
499                    ty: wgpu::BindingType::Buffer {
500                        ty: wgpu::BufferBindingType::Storage { read_only: true },
501                        has_dynamic_offset: false,
502                        min_binding_size: None,
503                    },
504                    count: None,
505                },
506                wgpu::BindGroupLayoutEntry {
507                    binding: 2,
508                    visibility: wgpu::ShaderStages::COMPUTE,
509                    ty: wgpu::BindingType::Buffer {
510                        ty: wgpu::BufferBindingType::Storage { read_only: true },
511                        has_dynamic_offset: false,
512                        min_binding_size: None,
513                    },
514                    count: None,
515                },
516                wgpu::BindGroupLayoutEntry {
517                    binding: 3,
518                    visibility: wgpu::ShaderStages::COMPUTE,
519                    ty: wgpu::BindingType::Buffer {
520                        ty: wgpu::BufferBindingType::Storage { read_only: true },
521                        has_dynamic_offset: false,
522                        min_binding_size: None,
523                    },
524                    count: None,
525                },
526            ],
527        });
528
529        let extrema_bgl1 = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
530            label: Some("Extrema BGL 1"),
531            entries: &[wgpu::BindGroupLayoutEntry {
532                binding: 0,
533                visibility: wgpu::ShaderStages::COMPUTE,
534                ty: wgpu::BindingType::Buffer {
535                    ty: wgpu::BufferBindingType::Storage { read_only: true },
536                    has_dynamic_offset: false,
537                    min_binding_size: None,
538                },
539                count: None,
540            }],
541        });
542
543        let extrema_bgl2 = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
544            label: Some("Extrema BGL 2"),
545            entries: &[
546                wgpu::BindGroupLayoutEntry {
547                    binding: 0,
548                    visibility: wgpu::ShaderStages::COMPUTE,
549                    ty: wgpu::BindingType::Buffer {
550                        ty: wgpu::BufferBindingType::Storage { read_only: false },
551                        has_dynamic_offset: false,
552                        min_binding_size: None,
553                    },
554                    count: None,
555                },
556                wgpu::BindGroupLayoutEntry {
557                    binding: 1,
558                    visibility: wgpu::ShaderStages::COMPUTE,
559                    ty: wgpu::BindingType::Buffer {
560                        ty: wgpu::BufferBindingType::Storage { read_only: false },
561                        has_dynamic_offset: false,
562                        min_binding_size: None,
563                    },
564                    count: None,
565                },
566            ],
567        });
568
569        // Create compute pipelines
570        let upload_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
571            label: Some("Upload Layout"),
572            bind_group_layouts: &[&upload_bgl0],
573            push_constant_ranges: &[],
574        });
575
576        let blur_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
577            label: Some("Blur Layout"),
578            bind_group_layouts: &[&blur_bgl0, &blur_bgl1],
579            push_constant_ranges: &[],
580        });
581
582        let dog_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
583            label: Some("DoG Layout"),
584            bind_group_layouts: &[&dog_bgl0, &dog_bgl1],
585            push_constant_ranges: &[],
586        });
587
588        let extrema_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
589            label: Some("Extrema Layout"),
590            bind_group_layouts: &[&extrema_bgl0, &extrema_bgl1, &extrema_bgl2],
591            push_constant_ranges: &[],
592        });
593
594        let upload = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
595            label: Some("Upload Pipeline"),
596            layout: Some(&upload_layout),
597            module: &upload_module,
598            entry_point: Some("upload_grayscale"),
599            compilation_options: Default::default(),
600            cache: None,
601        });
602
603        let blur_h = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
604            label: Some("Blur H Pipeline"),
605            layout: Some(&blur_layout),
606            module: &blur_module,
607            entry_point: Some("gaussian_blur"),
608            compilation_options: Default::default(),
609            cache: None,
610        });
611
612        let blur_v = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
613            label: Some("Blur V Pipeline"),
614            layout: Some(&blur_layout),
615            module: &blur_module,
616            entry_point: Some("gaussian_blur"),
617            compilation_options: Default::default(),
618            cache: None,
619        });
620
621        let downsample = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
622            label: Some("Downsample Pipeline"),
623            layout: Some(&blur_layout),
624            module: &downsample_module,
625            entry_point: Some("downsample"),
626            compilation_options: Default::default(),
627            cache: None,
628        });
629
630        let dog = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
631            label: Some("DoG Pipeline"),
632            layout: Some(&dog_layout),
633            module: &dog_module,
634            entry_point: Some("compute_dog"),
635            compilation_options: Default::default(),
636            cache: None,
637        });
638
639        let extrema = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
640            label: Some("Extrema Pipeline"),
641            layout: Some(&extrema_layout),
642            module: &extrema_module,
643            entry_point: Some("detect_extrema"),
644            compilation_options: Default::default(),
645            cache: None,
646        });
647
648        // Orientation and descriptor pipelines (simplified, need proper layouts)
649        let orientation = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
650            label: Some("Orientation Pipeline"),
651            layout: None, // auto-layout for now
652            module: &orientation_module,
653            entry_point: Some("compute_orientation"),
654            compilation_options: Default::default(),
655            cache: None,
656        });
657
658        let descriptor = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
659            label: Some("Descriptor Pipeline"),
660            layout: None, // auto-layout for now
661            module: &descriptor_module,
662            entry_point: Some("compute_descriptor"),
663            compilation_options: Default::default(),
664            cache: None,
665        });
666
667        Ok(GpuPipelines {
668            upload,
669            blur_h,
670            blur_v,
671            downsample,
672            dog,
673            extrema,
674            orientation,
675            descriptor,
676        })
677    }
678
679    #[allow(dead_code)]
680    fn upload_image(
681        &self,
682        image: &[u8],
683        width: u32,
684        height: u32,
685        ctx: &GpuRunContext,
686    ) -> Result<(), Box<dyn std::error::Error>> {
687        // Write image data to a temporary staging area at the END of heap
688        // The upload shader will convert u8->f16 and write to the beginning
689        let image_size = (width * height) as usize;
690        let staging_offset = ctx.heap.size() as usize - ((image_size + 3) / 4) * 4; // Aligned
691
692        // Create a padded buffer for u8 data (4-byte aligned)
693        let mut padded_image = vec![0u8; ((image_size + 3) / 4) * 4];
694        padded_image[..image_size].copy_from_slice(image);
695
696        self.queue
697            .write_buffer(&ctx.heap, staging_offset as u64, &padded_image);
698        Ok(())
699    }
700
701    /// Build Gaussian scale space on CPU and upload to GPU
702    /// This is a hybrid approach: CPU builds pyramid, GPU does extrema detection
703    /// Uses INCREMENTAL blurring for efficiency (each scale from previous)
704    fn build_pyramid_cpu(&self, image: &[u8], width: u32, height: u32) -> Vec<f32> {
705        // k is the scale multiplier between adjacent scales
706        // Standard SIFT uses scales-3 intervals
707        let intervals = (self.config.scales as f32 - 3.0).max(1.0);
708        let k = 2.0_f32.powf(1.0 / intervals);
709
710        // Precompute differential sigmas for incremental blurring
711        // sigma_total[s] = base_sigma * k^s
712        // sigma_diff[s] = sqrt(sigma_total[s]^2 - sigma_total[s-1]^2)
713        let mut diff_sigmas = vec![0.0f32; self.config.scales as usize];
714        let assumed_blur = 0.5f32; // assumed initial blur of input image
715
716        for s in 0..self.config.scales as usize {
717            if s == 0 {
718                // First scale: blur from assumed_blur to base_sigma
719                let sigma_target = self.config.base_sigma;
720                if sigma_target > assumed_blur {
721                    diff_sigmas[s] =
722                        (sigma_target * sigma_target - assumed_blur * assumed_blur).sqrt();
723                } else {
724                    diff_sigmas[s] = 0.0;
725                }
726            } else {
727                // Incremental blur from previous scale
728                let sigma_prev = self.config.base_sigma * k.powi((s - 1) as i32);
729                let sigma_curr = self.config.base_sigma * k.powi(s as i32);
730                diff_sigmas[s] = (sigma_curr * sigma_curr - sigma_prev * sigma_prev).sqrt();
731            }
732        }
733
734        let mut pyramid_data = Vec::new();
735        let mut current_img: Vec<f32> = image.iter().map(|&p| p as f32 / 255.0).collect();
736        let mut w = width as usize;
737        let mut h = height as usize;
738
739        for octave in 0..self.config.octaves {
740            if w < 8 || h < 8 {
741                break;
742            }
743
744            // Build scales for this octave using incremental blur
745            let mut octave_images: Vec<Vec<f32>> = Vec::with_capacity(self.config.scales as usize);
746
747            for s in 0..self.config.scales as usize {
748                let blurred = if s == 0 {
749                    if octave == 0 && diff_sigmas[0] > 0.01 {
750                        // First octave, first scale: blur from input
751                        self.gaussian_blur_cpu(&current_img, w, h, diff_sigmas[0])
752                    } else {
753                        // Other octaves: first scale comes from downsampling (already at correct blur)
754                        current_img.clone()
755                    }
756                } else {
757                    // Incremental blur from previous scale in this octave
758                    let prev_scale = &octave_images[s - 1];
759                    if diff_sigmas[s] > 0.01 {
760                        self.gaussian_blur_cpu(prev_scale, w, h, diff_sigmas[s])
761                    } else {
762                        prev_scale.clone()
763                    }
764                };
765
766                pyramid_data.extend_from_slice(&blurred);
767                octave_images.push(blurred);
768            }
769
770            // Downsample from scale (scales-3) for next octave
771            // This is the scale with 2x the base blur
772            let downsample_idx = (self.config.scales as usize).saturating_sub(3);
773            current_img = self.downsample_2x(&octave_images[downsample_idx], w, h);
774            w /= 2;
775            h /= 2;
776        }
777
778        pyramid_data
779    }
780
781    /// Fast 2x downsample by taking every other pixel
782    fn downsample_2x(&self, img: &[f32], width: usize, height: usize) -> Vec<f32> {
783        let new_w = width / 2;
784        let new_h = height / 2;
785        let mut result = vec![0.0f32; new_w * new_h];
786
787        result
788            .par_chunks_mut(new_w)
789            .enumerate()
790            .for_each(|(y, row)| {
791                for x in 0..new_w {
792                    row[x] = img[(y * 2) * width + (x * 2)];
793                }
794            });
795
796        result
797    }
798
799    fn gaussian_blur_cpu(&self, img: &[f32], width: usize, height: usize, sigma: f32) -> Vec<f32> {
800        // Handle edge case of very small sigma
801        if sigma < 0.1 {
802            return img.to_vec();
803        }
804
805        // Use smaller radius (sigma * 2.5 is sufficient for SIFT, saves computation)
806        let radius = (sigma * 2.5).ceil() as i32;
807        let size = (2 * radius + 1).max(1) as usize;
808
809        // Build symmetric kernel (only store half + center)
810        let mut kernel = vec![0.0f32; size];
811        let mut sum = 0.0f32;
812        let two_sigma_sq = 2.0 * sigma * sigma;
813        for i in 0..size {
814            let x = (i as i32 - radius) as f32;
815            kernel[i] = (-x * x / two_sigma_sq).exp();
816            sum += kernel[i];
817        }
818        let norm = 1.0 / sum;
819        for k in kernel.iter_mut() {
820            *k *= norm;
821        }
822
823        // For small kernels, use simple approach
824        // For larger kernels, use optimized symmetric approach
825        if size <= 5 {
826            return self.gaussian_blur_simple(img, width, height, &kernel, radius);
827        }
828
829        // Optimized: exploit symmetry - kernel[i] == kernel[size-1-i]
830        // Horizontal pass - parallel over rows
831        let mut temp = vec![0.0f32; width * height];
832        temp.par_chunks_mut(width).enumerate().for_each(|(y, row)| {
833            let row_start = y * width;
834            for x in 0..width {
835                // Center weight
836                let mut val = img[row_start + x] * kernel[radius as usize];
837
838                // Symmetric pairs
839                for i in 1..=radius as usize {
840                    let left = if x >= i { x - i } else { 0 };
841                    let right = (x + i).min(width - 1);
842                    val += (img[row_start + left] + img[row_start + right])
843                        * kernel[radius as usize + i];
844                }
845                row[x] = val;
846            }
847        });
848
849        // Vertical pass - parallel over rows (better cache locality)
850        let mut result = vec![0.0f32; width * height];
851
852        // Process in chunks for better cache utilization
853        let chunk_height = 64.min(height);
854        result
855            .par_chunks_mut(chunk_height * width)
856            .enumerate()
857            .for_each(|(chunk_idx, chunk)| {
858                let y_start = chunk_idx * chunk_height;
859                let y_end: usize = (y_start + chunk_height).min(height); // Fix E0282
860
861                for local_y in 0..(y_end - y_start) {
862                    let y = y_start + local_y;
863                    let row_offset = local_y * width;
864
865                    for x in 0..width {
866                        // Center weight
867                        let mut val = temp[y * width + x] * kernel[radius as usize];
868
869                        // Symmetric pairs
870                        for i in 1..=radius as usize {
871                            let top = if y >= i { y - i } else { 0 };
872                            let bottom: usize = (y + i).min(height - 1); // Fix E0282
873                            val += (temp[top * width + x] + temp[bottom * width + x])
874                                * kernel[radius as usize + i];
875                        }
876                        chunk[row_offset + x] = val;
877                    }
878                }
879            });
880
881        result
882    }
883
884    /// Simple blur for small kernels (avoids overhead of symmetric optimization)
885    fn gaussian_blur_simple(
886        &self,
887        img: &[f32],
888        width: usize,
889        height: usize,
890        kernel: &[f32],
891        radius: i32,
892    ) -> Vec<f32> {
893        let size = kernel.len();
894
895        // Horizontal pass
896        let mut temp = vec![0.0f32; width * height];
897        temp.par_chunks_mut(width).enumerate().for_each(|(y, row)| {
898            for x in 0..width {
899                let mut val = 0.0f32;
900                for i in 0..size {
901                    let sx = (x as i32 + i as i32 - radius).clamp(0, width as i32 - 1) as usize;
902                    val += img[y * width + sx] * kernel[i];
903                }
904                row[x] = val;
905            }
906        });
907
908        // Vertical pass
909        let mut result = vec![0.0f32; width * height];
910        result
911            .par_chunks_mut(width)
912            .enumerate()
913            .for_each(|(y, row)| {
914                for x in 0..width {
915                    let mut val = 0.0f32;
916                    for i in 0..size {
917                        let sy =
918                            (y as i32 + i as i32 - radius).clamp(0, height as i32 - 1) as usize;
919                        val += temp[sy * width + x] * kernel[i];
920                    }
921                    row[x] = val;
922                }
923            });
924
925        result
926    }
927
928    /// Compute DoG from Gaussian pyramid (parallelized)
929    fn compute_dog_cpu(&self, gaussian_pyramid: &[f32], width: u32, height: u32) -> Vec<f32> {
930        let scales = self.config.scales as usize;
931        let dog_scales = scales - 1;
932
933        // First pass: collect octave info
934        let mut octave_info = Vec::new();
935        let mut w = width as usize;
936        let mut h = height as usize;
937        let mut offset = 0usize;
938
939        for _ in 0..self.config.octaves {
940            if w < 8 || h < 8 {
941                break;
942            }
943            let level_size = w * h;
944            octave_info.push((offset, level_size, w, h));
945            offset += scales * level_size;
946            w /= 2;
947            h /= 2;
948        }
949
950        // Compute total DoG size
951        let total_dog_size: usize = octave_info
952            .iter()
953            .map(|(_, level_size, _, _)| level_size * dog_scales)
954            .sum();
955
956        let mut dog_data = vec![0.0f32; total_dog_size];
957
958        // Parallel computation of DoG for each octave
959        let mut dog_offset = 0usize;
960        for (gauss_offset, level_size, _, _) in &octave_info {
961            for d in 0..dog_scales {
962                let scale1_start = gauss_offset + d * level_size;
963                let scale2_start = gauss_offset + (d + 1) * level_size;
964                let dog_start = dog_offset + d * level_size;
965
966                dog_data[dog_start..dog_start + level_size]
967                    .par_iter_mut()
968                    .enumerate()
969                    .for_each(|(i, dog_val)| {
970                        *dog_val =
971                            gaussian_pyramid[scale2_start + i] - gaussian_pyramid[scale1_start + i];
972                    });
973            }
974            dog_offset += dog_scales * level_size;
975        }
976
977        dog_data
978    }
979
980    /// Upload DoG pyramid to GPU in f16 format (parallelized conversion)
981    fn upload_dog_pyramid(&self, dog_data: &[f32], ctx: &GpuRunContext) {
982        // Convert f32 to f16 packed as u32 (parallel)
983        let packed_data: Vec<u32> = dog_data
984            .par_chunks(2)
985            .map(|chunk| {
986                let v0 = chunk[0];
987                let v1 = if chunk.len() > 1 { chunk[1] } else { 0.0 };
988                half::f16::from_f32(v0).to_bits() as u32
989                    | ((half::f16::from_f32(v1).to_bits() as u32) << 16)
990            })
991            .collect();
992
993        let bytes: Vec<u8> = packed_data
994            .iter()
995            .flat_map(|v: &u32| v.to_le_bytes())
996            .collect(); // Fix E0282
997        self.queue.write_buffer(&ctx.heap, 0, &bytes);
998    }
999
1000    async fn execute_pipeline(
1001        &self,
1002        width: u32,
1003        height: u32,
1004        ctx: &GpuRunContext,
1005    ) -> Result<(), Box<dyn std::error::Error>> {
1006        // DoG pyramid has (scales-1) layers per octave
1007        let dog_scales = self.config.scales - 1;
1008
1009        // Compute metadata: level offsets, widths, heights for DoG pyramid
1010        let mut level_offsets_data = Vec::new();
1011        let mut level_widths_data = Vec::new();
1012        let mut level_heights_data = Vec::new();
1013
1014        let mut offset = 0u32;
1015        let mut w = width;
1016        let mut h = height;
1017        let mut actual_octaves = 0u32;
1018
1019        for octave in 0..self.config.octaves {
1020            if w < 8 || h < 8 {
1021                break;
1022            }
1023            actual_octaves = octave + 1;
1024
1025            // DoG has (scales-1) layers per octave
1026            for _scale in 0..dog_scales {
1027                level_offsets_data.push(offset);
1028                level_widths_data.push(w);
1029                level_heights_data.push(h);
1030
1031                let pixels = w * h;
1032                offset += (pixels + 1) / 2; // f16 packed as u32
1033            }
1034
1035            w /= 2;
1036            h /= 2;
1037        }
1038
1039        // Write metadata to GPU
1040        let offsets_bytes: Vec<u8> = level_offsets_data
1041            .iter()
1042            .flat_map(|v| v.to_le_bytes())
1043            .collect();
1044        let widths_bytes: Vec<u8> = level_widths_data
1045            .iter()
1046            .flat_map(|v| v.to_le_bytes())
1047            .collect();
1048        let heights_bytes: Vec<u8> = level_heights_data
1049            .iter()
1050            .flat_map(|v| v.to_le_bytes())
1051            .collect();
1052
1053        self.queue
1054            .write_buffer(&ctx.level_offsets, 0, &offsets_bytes);
1055        self.queue.write_buffer(&ctx.level_widths, 0, &widths_bytes);
1056        self.queue
1057            .write_buffer(&ctx.level_heights, 0, &heights_bytes);
1058
1059        // Write pyramid metadata
1060        // Note: extrema shader expects dog_scales (scales-1), not scales
1061        let meta_data = [
1062            actual_octaves,
1063            dog_scales,     // Number of DoG scales per octave
1064            dog_scales - 2, // Usable DoG layers for extrema (need 3 adjacent)
1065            width,
1066            height,
1067            self.config.base_sigma.to_bits(),
1068            self.config.contrast_threshold.to_bits(),
1069            self.config.edge_threshold.to_bits(),
1070        ];
1071        let meta_bytes: Vec<u8> = meta_data.iter().flat_map(|v| v.to_le_bytes()).collect();
1072        self.queue.write_buffer(&ctx.meta_buffer, 0, &meta_bytes);
1073
1074        // Clear counters
1075        self.queue.write_buffer(&ctx.extrema_counter, 0, &[0u8; 4]);
1076        self.queue
1077            .write_buffer(&ctx.orientation_counter, 0, &[0u8; 4]);
1078
1079        // Create all bind groups upfront
1080        // ===== Extrema Detection Bind Groups =====
1081        let extrema_bg0 = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1082            label: Some("Extrema BG0"),
1083            layout: &self.pipelines.extrema.get_bind_group_layout(0),
1084            entries: &[
1085                wgpu::BindGroupEntry {
1086                    binding: 0,
1087                    resource: ctx.meta_buffer.as_entire_binding(),
1088                },
1089                wgpu::BindGroupEntry {
1090                    binding: 1,
1091                    resource: ctx.level_offsets.as_entire_binding(),
1092                },
1093                wgpu::BindGroupEntry {
1094                    binding: 2,
1095                    resource: ctx.level_widths.as_entire_binding(),
1096                },
1097                wgpu::BindGroupEntry {
1098                    binding: 3,
1099                    resource: ctx.level_heights.as_entire_binding(),
1100                },
1101            ],
1102        });
1103
1104        let extrema_bg1 = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1105            label: Some("Extrema BG1"),
1106            layout: &self.pipelines.extrema.get_bind_group_layout(1),
1107            entries: &[wgpu::BindGroupEntry {
1108                binding: 0,
1109                resource: ctx.heap.as_entire_binding(),
1110            }],
1111        });
1112
1113        let extrema_bg2 = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1114            label: Some("Extrema BG2"),
1115            layout: &self.pipelines.extrema.get_bind_group_layout(2),
1116            entries: &[
1117                wgpu::BindGroupEntry {
1118                    binding: 0,
1119                    resource: ctx.extrema_counter.as_entire_binding(),
1120                },
1121                wgpu::BindGroupEntry {
1122                    binding: 1,
1123                    resource: ctx.keypoints_staging.as_entire_binding(),
1124                },
1125            ],
1126        });
1127
1128        // ===== Orientation Bind Groups =====
1129        let orient_bg0 = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1130            label: Some("Orientation BG0"),
1131            layout: &self.pipelines.orientation.get_bind_group_layout(0),
1132            entries: &[
1133                wgpu::BindGroupEntry {
1134                    binding: 0,
1135                    resource: ctx.meta_buffer.as_entire_binding(),
1136                },
1137                wgpu::BindGroupEntry {
1138                    binding: 1,
1139                    resource: ctx.level_offsets.as_entire_binding(),
1140                },
1141                wgpu::BindGroupEntry {
1142                    binding: 2,
1143                    resource: ctx.level_widths.as_entire_binding(),
1144                },
1145                wgpu::BindGroupEntry {
1146                    binding: 3,
1147                    resource: ctx.level_heights.as_entire_binding(),
1148                },
1149            ],
1150        });
1151
1152        let orient_bg1 = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1153            label: Some("Orientation BG1"),
1154            layout: &self.pipelines.orientation.get_bind_group_layout(1),
1155            entries: &[wgpu::BindGroupEntry {
1156                binding: 0,
1157                resource: ctx.heap.as_entire_binding(),
1158            }],
1159        });
1160
1161        let orient_bg2 = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1162            label: Some("Orientation BG2"),
1163            layout: &self.pipelines.orientation.get_bind_group_layout(2),
1164            entries: &[
1165                wgpu::BindGroupEntry {
1166                    binding: 0,
1167                    resource: ctx.keypoints_staging.as_entire_binding(),
1168                },
1169                wgpu::BindGroupEntry {
1170                    binding: 1,
1171                    resource: ctx.extrema_counter.as_entire_binding(),
1172                },
1173            ],
1174        });
1175
1176        let orient_bg3 = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1177            label: Some("Orientation BG3"),
1178            layout: &self.pipelines.orientation.get_bind_group_layout(3),
1179            entries: &[
1180                wgpu::BindGroupEntry {
1181                    binding: 0,
1182                    resource: ctx.orientation_counter.as_entire_binding(),
1183                },
1184                wgpu::BindGroupEntry {
1185                    binding: 1,
1186                    resource: ctx.keypoints_final.as_entire_binding(),
1187                },
1188            ],
1189        });
1190
1191        // ===== Descriptor Bind Groups =====
1192        let desc_bg0 = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1193            label: Some("Descriptor BG0"),
1194            layout: &self.pipelines.descriptor.get_bind_group_layout(0),
1195            entries: &[
1196                wgpu::BindGroupEntry {
1197                    binding: 0,
1198                    resource: ctx.meta_buffer.as_entire_binding(),
1199                },
1200                wgpu::BindGroupEntry {
1201                    binding: 1,
1202                    resource: ctx.level_offsets.as_entire_binding(),
1203                },
1204                wgpu::BindGroupEntry {
1205                    binding: 2,
1206                    resource: ctx.level_widths.as_entire_binding(),
1207                },
1208                wgpu::BindGroupEntry {
1209                    binding: 3,
1210                    resource: ctx.level_heights.as_entire_binding(),
1211                },
1212            ],
1213        });
1214
1215        let desc_bg1 = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1216            label: Some("Descriptor BG1"),
1217            layout: &self.pipelines.descriptor.get_bind_group_layout(1),
1218            entries: &[wgpu::BindGroupEntry {
1219                binding: 0,
1220                resource: ctx.heap.as_entire_binding(),
1221            }],
1222        });
1223
1224        let desc_bg2 = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1225            label: Some("Descriptor BG2"),
1226            layout: &self.pipelines.descriptor.get_bind_group_layout(2),
1227            entries: &[
1228                wgpu::BindGroupEntry {
1229                    binding: 0,
1230                    resource: ctx.keypoints_final.as_entire_binding(),
1231                },
1232                wgpu::BindGroupEntry {
1233                    binding: 1,
1234                    resource: ctx.orientation_counter.as_entire_binding(),
1235                },
1236            ],
1237        });
1238
1239        let desc_bg3 = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1240            label: Some("Descriptor BG3"),
1241            layout: &self.pipelines.descriptor.get_bind_group_layout(3),
1242            entries: &[wgpu::BindGroupEntry {
1243                binding: 0,
1244                resource: ctx.descriptors.as_entire_binding(),
1245            }],
1246        });
1247
1248        // ===== SINGLE ENCODER - All stages in one command buffer =====
1249        let mut encoder = self
1250            .device
1251            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1252                label: Some("SIFT Full Pipeline Encoder"),
1253            });
1254
1255        // ===== STAGE 1: Extrema Detection =====
1256        {
1257            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1258                label: Some("Extrema Pass"),
1259                timestamp_writes: None,
1260            });
1261
1262            compute_pass.set_pipeline(&self.pipelines.extrema);
1263            compute_pass.set_bind_group(0, &extrema_bg0, &[]);
1264            compute_pass.set_bind_group(1, &extrema_bg1, &[]);
1265            compute_pass.set_bind_group(2, &extrema_bg2, &[]);
1266
1267            // Calculate total z workgroups: octaves * (dog_scales - 2)
1268            let usable_dog_scales = dog_scales.saturating_sub(2).max(1);
1269            let total_z = actual_octaves * usable_dog_scales;
1270
1271            let workgroups_x = (width + 15) / 16;
1272            let workgroups_y = (height + 15) / 16;
1273
1274            compute_pass.dispatch_workgroups(workgroups_x, workgroups_y, total_z);
1275        }
1276        // Pass ends here - implicit barrier between compute passes
1277
1278        // ===== STAGE 2: Orientation Assignment =====
1279        {
1280            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1281                label: Some("Orientation Pass"),
1282                timestamp_writes: None,
1283            });
1284
1285            compute_pass.set_pipeline(&self.pipelines.orientation);
1286            compute_pass.set_bind_group(0, &orient_bg0, &[]);
1287            compute_pass.set_bind_group(1, &orient_bg1, &[]);
1288            compute_pass.set_bind_group(2, &orient_bg2, &[]);
1289            compute_pass.set_bind_group(3, &orient_bg3, &[]);
1290
1291            // Dispatch orientation computation
1292            let max_keypoints = 1024;
1293            let workgroups = (max_keypoints * 36 + 35) / 36;
1294            compute_pass.dispatch_workgroups(workgroups, 1, 1);
1295        }
1296        // Pass ends here - implicit barrier
1297
1298        // ===== STAGE 3: Descriptor Computation =====
1299        {
1300            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1301                label: Some("Descriptor Pass"),
1302                timestamp_writes: None,
1303            });
1304
1305            compute_pass.set_pipeline(&self.pipelines.descriptor);
1306            compute_pass.set_bind_group(0, &desc_bg0, &[]);
1307            compute_pass.set_bind_group(1, &desc_bg1, &[]);
1308            compute_pass.set_bind_group(2, &desc_bg2, &[]);
1309            compute_pass.set_bind_group(3, &desc_bg3, &[]);
1310
1311            // Each descriptor needs 4 threads
1312            let max_final_keypoints = 2048;
1313            let workgroups = (max_final_keypoints * 4 + 3) / 4;
1314            compute_pass.dispatch_workgroups(workgroups, 1, 1);
1315        }
1316
1317        // SINGLE SUBMIT - all stages in one command buffer
1318        self.queue.submit(Some(encoder.finish()));
1319
1320        // Wait for all GPU work to complete (only ONE wait)
1321        let _ = self.device.poll(wgpu::MaintainBase::Wait);
1322
1323        Ok(())
1324    }
1325
1326    async fn readback_results(
1327        &self,
1328        ctx: &GpuRunContext,
1329    ) -> Result<(Vec<KeyPoint>, Vec<[u8; 128]>), Box<dyn std::error::Error>> {
1330        let mut encoder = self
1331            .device
1332            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1333                label: Some("Readback Encoder"),
1334            });
1335
1336        let buffers = self.buffers.lock().unwrap();
1337
1338        // Copy counters
1339        encoder.copy_buffer_to_buffer(&ctx.extrema_counter, 0, &buffers.readback_counters, 0, 4);
1340        encoder.copy_buffer_to_buffer(
1341            &ctx.orientation_counter,
1342            0,
1343            &buffers.readback_counters,
1344            4,
1345            4,
1346        );
1347
1348        self.queue.submit(Some(encoder.finish()));
1349
1350        // Map and read counters
1351        let counters_slice = buffers.readback_counters.slice(..);
1352        let (tx, rx) = std::sync::mpsc::channel();
1353        counters_slice.map_async(wgpu::MapMode::Read, move |result| {
1354            tx.send(result).unwrap();
1355        });
1356
1357        let _ = self.device.poll(wgpu::MaintainBase::Wait);
1358        rx.recv()??;
1359
1360        let counters_data = counters_slice.get_mapped_range();
1361        let orientation_count = u32::from_le_bytes([
1362            counters_data[4],
1363            counters_data[5],
1364            counters_data[6],
1365            counters_data[7],
1366        ]);
1367        drop(counters_data);
1368        buffers.readback_counters.unmap();
1369
1370        let num_keypoints = orientation_count.min(65536);
1371
1372        // Early return if no keypoints found
1373        if num_keypoints == 0 {
1374            return Ok((Vec::new(), Vec::new()));
1375        }
1376
1377        // Copy keypoints and descriptors
1378        let mut encoder = self
1379            .device
1380            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1381                label: Some("Readback KP Encoder"),
1382            });
1383
1384        encoder.copy_buffer_to_buffer(
1385            &ctx.keypoints_final,
1386            0,
1387            &buffers.readback_keypoints,
1388            0,
1389            (num_keypoints as u64) * 16,
1390        );
1391
1392        encoder.copy_buffer_to_buffer(
1393            &ctx.descriptors,
1394            0,
1395            &buffers.readback_descriptors,
1396            0,
1397            (num_keypoints as u64) * 128,
1398        );
1399
1400        self.queue.submit(Some(encoder.finish()));
1401
1402        // Map keypoints
1403        let kp_slice = buffers
1404            .readback_keypoints
1405            .slice(..(num_keypoints as u64 * 16));
1406        let (tx, rx) = std::sync::mpsc::channel();
1407        kp_slice.map_async(wgpu::MapMode::Read, move |result| {
1408            tx.send(result).unwrap();
1409        });
1410
1411        let _ = self.device.poll(wgpu::MaintainBase::Wait);
1412        rx.recv()??;
1413
1414        let kp_data = kp_slice.get_mapped_range();
1415        let mut keypoints = Vec::with_capacity(num_keypoints as usize);
1416
1417        for i in 0..num_keypoints as usize {
1418            let offset = i * 16;
1419            let x = f32::from_le_bytes([
1420                kp_data[offset],
1421                kp_data[offset + 1],
1422                kp_data[offset + 2],
1423                kp_data[offset + 3],
1424            ]);
1425            let y = f32::from_le_bytes([
1426                kp_data[offset + 4],
1427                kp_data[offset + 5],
1428                kp_data[offset + 6],
1429                kp_data[offset + 7],
1430            ]);
1431            let size = f32::from_le_bytes([
1432                kp_data[offset + 8],
1433                kp_data[offset + 9],
1434                kp_data[offset + 10],
1435                kp_data[offset + 11],
1436            ]);
1437            let angle = f32::from_le_bytes([
1438                kp_data[offset + 12],
1439                kp_data[offset + 13],
1440                kp_data[offset + 14],
1441                kp_data[offset + 15],
1442            ]);
1443
1444            keypoints.push(KeyPoint {
1445                x,
1446                y,
1447                size,
1448                angle,
1449                response: 0.0,
1450                octave: 0,
1451                layer: 0,
1452            });
1453        }
1454        drop(kp_data);
1455        buffers.readback_keypoints.unmap();
1456
1457        // Map descriptors
1458        let desc_slice = buffers
1459            .readback_descriptors
1460            .slice(..(num_keypoints as u64 * 128));
1461        let (tx, rx) = std::sync::mpsc::channel();
1462        desc_slice.map_async(wgpu::MapMode::Read, move |result| {
1463            tx.send(result).unwrap();
1464        });
1465
1466        let _ = self.device.poll(wgpu::MaintainBase::Wait);
1467        rx.recv()??;
1468
1469        let desc_data = desc_slice.get_mapped_range();
1470        let mut descriptors = Vec::with_capacity(num_keypoints as usize);
1471
1472        for i in 0..num_keypoints as usize {
1473            let offset = i * 128;
1474            let mut desc = [0u8; 128];
1475            desc.copy_from_slice(&desc_data[offset..offset + 128]);
1476            descriptors.push(desc);
1477        }
1478        drop(desc_data);
1479        buffers.readback_descriptors.unmap();
1480
1481        Ok((keypoints, descriptors))
1482    }
1483}
1484
1485impl GpuSiftBuffers {
1486    fn new(device: &wgpu::Device, _width: u32, _height: u32) -> Self {
1487        // Initialize with minimal buffers
1488        let heap = device.create_buffer(&wgpu::BufferDescriptor {
1489            label: Some("Pyramid Heap"),
1490            size: 1024,
1491            usage: wgpu::BufferUsages::STORAGE
1492                | wgpu::BufferUsages::COPY_DST
1493                | wgpu::BufferUsages::COPY_SRC,
1494            mapped_at_creation: false,
1495        });
1496
1497        let meta_buffer = device.create_buffer(&wgpu::BufferDescriptor {
1498            label: Some("Metadata"),
1499            size: 64,
1500            usage: wgpu::BufferUsages::STORAGE
1501                | wgpu::BufferUsages::UNIFORM
1502                | wgpu::BufferUsages::COPY_DST,
1503            mapped_at_creation: false,
1504        });
1505
1506        let level_offsets = device.create_buffer(&wgpu::BufferDescriptor {
1507            label: Some("Level Offsets"),
1508            size: 256,
1509            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
1510            mapped_at_creation: false,
1511        });
1512
1513        let level_widths = device.create_buffer(&wgpu::BufferDescriptor {
1514            label: Some("Level Widths"),
1515            size: 256,
1516            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
1517            mapped_at_creation: false,
1518        });
1519
1520        let level_heights = device.create_buffer(&wgpu::BufferDescriptor {
1521            label: Some("Level Heights"),
1522            size: 256,
1523            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
1524            mapped_at_creation: false,
1525        });
1526
1527        let extrema_counter = device.create_buffer(&wgpu::BufferDescriptor {
1528            label: Some("Extrema Counter"),
1529            size: 4,
1530            usage: wgpu::BufferUsages::STORAGE
1531                | wgpu::BufferUsages::COPY_DST
1532                | wgpu::BufferUsages::COPY_SRC,
1533            mapped_at_creation: false,
1534        });
1535
1536        let keypoints_staging = device.create_buffer(&wgpu::BufferDescriptor {
1537            label: Some("Keypoints Staging"),
1538            size: 32768 * 16, // 32768 keypoints × 4 f32
1539            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
1540            mapped_at_creation: false,
1541        });
1542
1543        let orientation_counter = device.create_buffer(&wgpu::BufferDescriptor {
1544            label: Some("Orientation Counter"),
1545            size: 4,
1546            usage: wgpu::BufferUsages::STORAGE
1547                | wgpu::BufferUsages::COPY_DST
1548                | wgpu::BufferUsages::COPY_SRC,
1549            mapped_at_creation: false,
1550        });
1551
1552        let keypoints_final = device.create_buffer(&wgpu::BufferDescriptor {
1553            label: Some("Keypoints Final"),
1554            size: 65536 * 16, // 65536 keypoints × 4 f32
1555            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
1556            mapped_at_creation: false,
1557        });
1558
1559        let descriptors = device.create_buffer(&wgpu::BufferDescriptor {
1560            label: Some("Descriptors"),
1561            size: 65536 * 128, // 65536 descriptors × 128 bytes
1562            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
1563            mapped_at_creation: false,
1564        });
1565
1566        let readback_counters = device.create_buffer(&wgpu::BufferDescriptor {
1567            label: Some("Readback Counters"),
1568            size: 8, // 2 u32 counters
1569            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
1570            mapped_at_creation: false,
1571        });
1572
1573        let readback_keypoints = device.create_buffer(&wgpu::BufferDescriptor {
1574            label: Some("Readback Keypoints"),
1575            size: 65536 * 16,
1576            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
1577            mapped_at_creation: false,
1578        });
1579
1580        let readback_descriptors = device.create_buffer(&wgpu::BufferDescriptor {
1581            label: Some("Readback Descriptors"),
1582            size: 65536 * 128,
1583            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
1584            mapped_at_creation: false,
1585        });
1586
1587        Self {
1588            heap,
1589            heap_capacity: 1024,
1590            meta_buffer,
1591            level_offsets,
1592            level_widths,
1593            level_heights,
1594            kernel_buffers: Vec::new(), // Will be initialized later with actual kernels
1595            extrema_counter,
1596            keypoints_staging,
1597            orientation_counter,
1598            keypoints_final,
1599            descriptors,
1600            readback_counters,
1601            readback_keypoints,
1602            readback_descriptors,
1603            current_width: 0,
1604            current_height: 0,
1605        }
1606    }
1607
1608    fn ensure_capacity(
1609        &mut self,
1610        device: &wgpu::Device,
1611        width: u32,
1612        height: u32,
1613        config: &GpuSiftConfig,
1614    ) {
1615        // Check if we need to reallocate
1616        if width == self.current_width && height == self.current_height {
1617            return;
1618        }
1619
1620        // Compute required heap size
1621        let mut total_pixels = 0u64;
1622        let mut w = width;
1623        let mut h = height;
1624
1625        for _ in 0..config.octaves {
1626            for _ in 0..config.scales {
1627                total_pixels += (w * h) as u64;
1628            }
1629            w /= 2;
1630            h /= 2;
1631            if w < 8 || h < 8 {
1632                break;
1633            }
1634        }
1635
1636        let heap_size = total_pixels * 2; // 2 bytes per f16 pixel
1637
1638        if heap_size > self.heap_capacity {
1639            // Reallocate heap
1640            self.heap = device.create_buffer(&wgpu::BufferDescriptor {
1641                label: Some("Pyramid Heap"),
1642                size: heap_size,
1643                usage: wgpu::BufferUsages::STORAGE
1644                    | wgpu::BufferUsages::COPY_DST
1645                    | wgpu::BufferUsages::COPY_SRC,
1646                mapped_at_creation: false,
1647            });
1648            self.heap_capacity = heap_size;
1649        }
1650
1651        self.current_width = width;
1652        self.current_height = height;
1653    }
1654
1655    fn initialize_kernel_buffers(
1656        &mut self,
1657        device: &wgpu::Device,
1658        queue: &wgpu::Queue,
1659        kernels: &GpuKernels,
1660    ) {
1661        self.kernel_buffers = kernels
1662            .kernels
1663            .iter()
1664            .enumerate()
1665            .map(|(i, weights)| {
1666                let weights_bytes: Vec<u8> = weights.iter().flat_map(|w| w.to_le_bytes()).collect();
1667
1668                let buffer = device.create_buffer(&wgpu::BufferDescriptor {
1669                    label: Some(&format!("Kernel Weights {}", i)),
1670                    size: weights_bytes.len() as u64,
1671                    usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
1672                    mapped_at_creation: false,
1673                });
1674
1675                queue.write_buffer(&buffer, 0, &weights_bytes);
1676                buffer
1677            })
1678            .collect();
1679    }
1680}