// template_matching/lib.rs

1//! GPU-accelerated template matching.
2//!
3//! Faster alternative to [imageproc::template_matching](https://docs.rs/imageproc/latest/imageproc/template_matching/index.html).
4
5#![deny(clippy::all)]
6#![allow(dead_code)]
7#![allow(unused_variables)]
8
9use std::{borrow::Cow, mem::size_of};
10use wgpu::util::DeviceExt;
11
/// Scoring method used when comparing the template against each input window.
///
/// Both are distance measures, so smaller values indicate better matches; the
/// best match location is where the result image attains its minimum.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum MatchTemplateMethod {
    /// Sum of absolute differences between template and window pixels.
    SumOfAbsoluteDifferences,
    /// Sum of squared differences between template and window pixels.
    SumOfSquaredDifferences,
}
17
18/// Slides a template over the input and scores the match at each point using the requested method.
19///
20/// This is a shorthand for:
21/// ```ignore
22/// let mut matcher = TemplateMatcher::new();
23/// matcher.match_template(input, template, method);
24/// matcher.wait_for_result().unwrap()
25/// ```
26/// You can use  [find_extremes] to find minimum and maximum values, and their locations in the result image.
27pub fn match_template<'a>(
28    input: impl Into<Image<'a>>,
29    template: impl Into<Image<'a>>,
30    method: MatchTemplateMethod,
31) -> Image<'static> {
32    let mut matcher = TemplateMatcher::new();
33    matcher.match_template(input, template, method);
34    matcher.wait_for_result().unwrap()
35}
36
37/// Finds the smallest and largest values and their locations in an image.
38pub fn find_extremes(input: &Image<'_>) -> Extremes {
39    let mut min_value = f32::MAX;
40    let mut min_value_location = (0, 0);
41    let mut max_value = f32::MIN;
42    let mut max_value_location = (0, 0);
43
44    for y in 0..input.height {
45        for x in 0..input.width {
46            let idx = (y * input.width) + x;
47            let value = input.data[idx as usize];
48
49            if value < min_value {
50                min_value = value;
51                min_value_location = (x, y);
52            }
53
54            if value > max_value {
55                max_value = value;
56                max_value_location = (x, y);
57            }
58        }
59    }
60
61    Extremes {
62        min_value,
63        max_value,
64        min_value_location,
65        max_value_location,
66    }
67}
68
/// A single-channel `f32` image stored in row-major order.
pub struct Image<'a> {
    /// Pixel values, row by row; `width * height` elements are expected.
    pub data: Cow<'a, [f32]>,
    /// Width in pixels.
    pub width: u32,
    /// Height in pixels.
    pub height: u32,
}

impl<'a> Image<'a> {
    /// Creates an image from row-major pixel data, either borrowed or owned.
    pub fn new(data: impl Into<Cow<'a, [f32]>>, width: u32, height: u32) -> Self {
        let data = data.into();
        Self { data, width, height }
    }
}
84
#[cfg(feature = "image")]
impl<'a> From<&'a image::ImageBuffer<image::Luma<f32>, Vec<f32>>> for Image<'a> {
    /// Borrows a grayscale (`Luma<f32>`) `image` buffer as an [Image] without
    /// copying the pixel data.
    fn from(img: &'a image::ImageBuffer<image::Luma<f32>, Vec<f32>>) -> Self {
        Self {
            // `ImageBuffer` derefs to its flat `[f32]` sample slice, so this
            // borrow is zero-copy.
            data: Cow::Borrowed(img),
            width: img.width(),
            height: img.height(),
        }
    }
}
95
/// Smallest and largest values of an image and their pixel locations,
/// as produced by [find_extremes].
#[derive(Copy, Clone, Debug)]
pub struct Extremes {
    /// Smallest pixel value found.
    pub min_value: f32,
    /// Largest pixel value found.
    pub max_value: f32,
    /// `(x, y)` location of the minimum value.
    pub min_value_location: (u32, u32),
    /// `(x, y)` location of the maximum value.
    pub max_value_location: (u32, u32),
}
103
/// Uniform parameters uploaded to the matching compute shader.
///
/// `#[repr(C)]` plus the bytemuck derives give this a well-defined byte layout
/// (four consecutive `u32`s) suitable for direct upload via `write_buffer`.
/// NOTE(review): presumably mirrors a uniform struct in
/// `shaders/matching.wgsl` — confirm field order against the shader.
#[repr(C)]
#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)]
struct ShaderUniforms {
    input_width: u32,
    input_height: u32,
    template_width: u32,
    template_height: u32,
}
112
/// Reusable GPU template matcher.
///
/// Holds the wgpu context and caches pipeline, buffers and bind group between
/// calls, so repeated matching with same-sized inputs avoids reallocation.
pub struct TemplateMatcher {
    // Core wgpu context created once in [Self::new].
    instance: wgpu::Instance,
    adapter: wgpu::Adapter,
    device: wgpu::Device,
    queue: wgpu::Queue,
    shader: wgpu::ShaderModule,
    bind_group_layout: wgpu::BindGroupLayout,
    pipeline_layout: wgpu::PipelineLayout,

    // Compute pipeline, rebuilt only when the matching method changes.
    last_pipeline: Option<wgpu::ComputePipeline>,
    last_method: Option<MatchTemplateMethod>,

    // Cached (width, height) pairs used to decide when GPU buffers
    // need to be reallocated versus merely rewritten.
    last_input_size: (u32, u32),
    last_template_size: (u32, u32),
    last_result_size: (u32, u32),

    // GPU-side resources; the Option fields are created lazily on the
    // first matching pass (or when sizes change).
    uniform_buffer: wgpu::Buffer,
    input_buffer: Option<wgpu::Buffer>,
    template_buffer: Option<wgpu::Buffer>,
    result_buffer: Option<wgpu::Buffer>,
    staging_buffer: Option<wgpu::Buffer>,
    bind_group: Option<wgpu::BindGroup>,

    // True while a submitted matching pass has not yet been collected
    // via [Self::wait_for_result].
    matching_ongoing: bool,
}
138
139impl Default for TemplateMatcher {
140    fn default() -> Self {
141        Self::new()
142    }
143}
144
145impl TemplateMatcher {
    /// Creates a matcher, eagerly initializing a GPU context.
    ///
    /// Blocks while requesting an adapter (high-performance preference, no
    /// surface — compute only) and a device, then compiles the matching
    /// shader and creates the bind-group/pipeline layouts and the fixed-size
    /// uniform buffer. Image-dependent buffers are created lazily on the
    /// first call to [Self::match_template].
    ///
    /// # Panics
    ///
    /// Panics if no compatible GPU adapter or device can be acquired.
    pub fn new() -> Self {
        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            dx12_shader_compiler: Default::default(),
        });

        // Block synchronously on the async adapter request.
        let adapter = pollster::block_on(async {
            instance
                .request_adapter(&wgpu::RequestAdapterOptions {
                    power_preference: wgpu::PowerPreference::HighPerformance,
                    compatible_surface: None,
                    force_fallback_adapter: false,
                })
                .await
                .expect("Adapter request failed")
        });

        // Default features/limits are sufficient for the storage-buffer
        // compute pipeline used here.
        let (device, queue) = pollster::block_on(async {
            adapter
                .request_device(
                    &wgpu::DeviceDescriptor {
                        label: None,
                        features: wgpu::Features::empty(),
                        limits: wgpu::Limits::default(),
                    },
                    None,
                )
                .await
                .expect("Device request failed")
        });

        // WGSL shader is embedded at compile time.
        let shader = device.create_shader_module(wgpu::include_wgsl!("../shaders/matching.wgsl"));

        // Bindings: 0 = input image (read-only storage), 1 = template
        // (read-only storage), 2 = result (read-write storage),
        // 3 = dimensions (uniform).
        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: None,
            entries: &[
                wgpu::BindGroupLayoutEntry {
                    binding: 0,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 1,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 2,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: false },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 3,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
            ],
        });

        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: None,
            bind_group_layouts: &[&bind_group_layout],
            push_constant_ranges: &[],
        });

        // The uniform buffer has a fixed size, so it can be created up front
        // and merely rewritten when dimensions change.
        let uniform_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("uniform_buffer"),
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            size: size_of::<ShaderUniforms>() as _,
            mapped_at_creation: false,
        });

        Self {
            instance,
            adapter,
            device,
            queue,
            shader,
            pipeline_layout,
            bind_group_layout,
            last_pipeline: None,
            last_method: None,
            last_input_size: (0, 0),
            last_template_size: (0, 0),
            last_result_size: (0, 0),
            uniform_buffer,
            input_buffer: None,
            template_buffer: None,
            result_buffer: None,
            staging_buffer: None,
            bind_group: None,
            matching_ongoing: false,
        }
    }
260
    /// Waits for the latest [match_template] execution and returns the result.
    /// Returns [None] if no matching was started.
    ///
    /// NOTE(review): if mapping the staging buffer fails, this silently
    /// returns a zero-filled image instead of reporting the error — consider
    /// surfacing that failure to the caller.
    pub fn wait_for_result(&mut self) -> Option<Image<'static>> {
        if !self.matching_ongoing {
            return None;
        }
        self.matching_ongoing = false;

        let (result_width, result_height) = self.last_result_size;

        // Request a read mapping of the staging buffer; the callback fires
        // once the GPU work completes (driven by the poll below).
        let buffer_slice = self.staging_buffer.as_ref().unwrap().slice(..);
        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());

        // Block until all submitted work, including the map request, is done.
        self.device.poll(wgpu::Maintain::Wait);

        pollster::block_on(async {
            let result;

            if let Some(Ok(())) = receiver.receive().await {
                let data = buffer_slice.get_mapped_range();
                // Reinterpret the raw bytes as f32 pixels and copy them out.
                result = bytemuck::cast_slice(&data).to_vec();
                // The mapped view must be dropped before unmapping the buffer.
                drop(data);
                self.staging_buffer.as_ref().unwrap().unmap();
            } else {
                // Mapping failed: fall back to an all-zero result image.
                result = vec![0.0; (result_width * result_height) as usize]
            };

            Some(Image::new(result, result_width as _, result_height as _))
        })
    }
292
293    /// Slides a template over the input and scores the match at each point using the requested method.
294    /// To get the result of the matching, call [wait_for_result].
295    pub fn match_template<'a>(
296        &mut self,
297        input: impl Into<Image<'a>>,
298        template: impl Into<Image<'a>>,
299        method: MatchTemplateMethod,
300    ) {
301        if self.matching_ongoing {
302            // Discard previous result if not collected.
303            self.wait_for_result();
304        }
305
306        let input = input.into();
307        let template = template.into();
308
309        if self.last_pipeline.is_none() || self.last_method != Some(method) {
310            self.last_method = Some(method);
311
312            let entry_point = match method {
313                MatchTemplateMethod::SumOfAbsoluteDifferences => "main_sad",
314                MatchTemplateMethod::SumOfSquaredDifferences => "main_ssd",
315            };
316
317            self.last_pipeline = Some(self.device.create_compute_pipeline(
318                &wgpu::ComputePipelineDescriptor {
319                    label: None,
320                    layout: Some(&self.pipeline_layout),
321                    module: &self.shader,
322                    entry_point,
323                },
324            ));
325        }
326
327        let mut buffers_changed = false;
328
329        let input_size = (input.width, input.height);
330        if self.input_buffer.is_none() || self.last_input_size != input_size {
331            buffers_changed = true;
332
333            self.last_input_size = input_size;
334
335            self.input_buffer = Some(self.device.create_buffer_init(
336                &wgpu::util::BufferInitDescriptor {
337                    label: Some("input_buffer"),
338                    contents: bytemuck::cast_slice(&input.data),
339                    usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
340                },
341            ));
342        } else {
343            self.queue.write_buffer(
344                self.input_buffer.as_ref().unwrap(),
345                0,
346                bytemuck::cast_slice(&input.data),
347            );
348        }
349
350        let template_size = (template.width, template.height);
351        if self.template_buffer.is_none() || self.last_template_size != template_size {
352            self.queue.write_buffer(
353                &self.uniform_buffer,
354                0,
355                bytemuck::cast_slice(&[ShaderUniforms {
356                    input_width: input.width,
357                    input_height: input.height,
358                    template_width: template.width,
359                    template_height: template.height,
360                }]),
361            );
362            buffers_changed = true;
363
364            self.last_template_size = template_size;
365
366            self.template_buffer = Some(self.device.create_buffer_init(
367                &wgpu::util::BufferInitDescriptor {
368                    label: Some("template_buffer"),
369                    contents: bytemuck::cast_slice(&template.data),
370                    usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
371                },
372            ));
373        } else {
374            self.queue.write_buffer(
375                self.template_buffer.as_ref().unwrap(),
376                0,
377                bytemuck::cast_slice(&template.data),
378            );
379        }
380
381        let result_width = input.width - template.width + 1;
382        let result_height = input.height - template.height + 1;
383        let result_buf_size = (result_width * result_height) as u64 * size_of::<f32>() as u64;
384
385        if buffers_changed {
386            self.last_result_size = (result_width, result_height);
387
388            self.result_buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
389                label: Some("result_buffer"),
390                usage: wgpu::BufferUsages::STORAGE
391                    | wgpu::BufferUsages::COPY_SRC
392                    | wgpu::BufferUsages::COPY_DST,
393                size: result_buf_size,
394                mapped_at_creation: false,
395            }));
396
397            self.staging_buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
398                label: Some("staging_buffer"),
399                usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
400                size: result_buf_size,
401                mapped_at_creation: false,
402            }));
403
404            self.bind_group = Some(self.device.create_bind_group(&wgpu::BindGroupDescriptor {
405                label: None,
406                layout: &self.bind_group_layout,
407                entries: &[
408                    wgpu::BindGroupEntry {
409                        binding: 0,
410                        resource: self.input_buffer.as_ref().unwrap().as_entire_binding(),
411                    },
412                    wgpu::BindGroupEntry {
413                        binding: 1,
414                        resource: self.template_buffer.as_ref().unwrap().as_entire_binding(),
415                    },
416                    wgpu::BindGroupEntry {
417                        binding: 2,
418                        resource: self.result_buffer.as_ref().unwrap().as_entire_binding(),
419                    },
420                    wgpu::BindGroupEntry {
421                        binding: 3,
422                        resource: self.uniform_buffer.as_entire_binding(),
423                    },
424                ],
425            }));
426        }
427
428        let mut encoder = self
429            .device
430            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
431                label: Some("encoder"),
432            });
433
434        {
435            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
436                label: Some("compute_pass"),
437            });
438            compute_pass.set_pipeline(self.last_pipeline.as_ref().unwrap());
439            compute_pass.set_bind_group(0, self.bind_group.as_ref().unwrap(), &[]);
440            compute_pass.dispatch_workgroups(
441                (result_width as f32 / 16.0).ceil() as u32,
442                (result_height as f32 / 16.0).ceil() as u32,
443                1,
444            );
445        }
446
447        encoder.copy_buffer_to_buffer(
448            self.result_buffer.as_ref().unwrap(),
449            0,
450            self.staging_buffer.as_ref().unwrap(),
451            0,
452            result_buf_size,
453        );
454
455        self.queue.submit(std::iter::once(encoder.finish()));
456        self.matching_ongoing = true;
457    }
458}