#![deny(clippy::all)]
#![allow(dead_code)]
#![allow(unused_variables)]

use std::{borrow::Cow, mem::size_of};
use wgpu::util::DeviceExt;

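/// Similarity metric used to score the template against each input window.
/// For both metrics, lower scores indicate better matches.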
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum MatchTemplateMethod {
    SumOfAbsoluteDifferences,
    SumOfSquaredDifferences,
}

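/// One-shot convenience wrapper: creates a fresh `TemplateMatcher` (including
/// full GPU setup) for a single match, then blocks on the result. Reuse a
/// `TemplateMatcher` instead to amortize the setup cost across calls.
///
/// A usage sketch (data and dimensions are illustrative only):
///
/// ```ignore
/// let input = Image::new(vec![0.0f32; 64 * 64], 64, 64);
/// let template = Image::new(vec![0.0f32; 8 * 8], 8, 8);
/// let result = match_template(input, template, MatchTemplateMethod::SumOfAbsoluteDifferences);
/// // Lower scores are better for SAD/SSD, so the best match is the minimum.
/// let best = find_extremes(&result).min_value_location;
/// ```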
pub fn match_template<'a>(
    input: impl Into<Image<'a>>,
    template: impl Into<Image<'a>>,
    method: MatchTemplateMethod,
) -> Image<'static> {
    let mut matcher = TemplateMatcher::new();
    matcher.match_template(input, template, method);
    matcher.wait_for_result().unwrap()
}

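/// Scans a result image and returns its minimum and maximum values along with
/// their `(x, y)` locations.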
pub fn find_extremes(input: &Image<'_>) -> Extremes {
    let mut min_value = f32::MAX;
    let mut min_value_location = (0, 0);
    let mut max_value = f32::MIN;
    let mut max_value_location = (0, 0);

    for y in 0..input.height {
        for x in 0..input.width {
            let idx = (y * input.width) + x;
            let value = input.data[idx as usize];

            if value < min_value {
                min_value = value;
                min_value_location = (x, y);
            }

            if value > max_value {
                max_value = value;
                max_value_location = (x, y);
            }
        }
    }

    Extremes {
        min_value,
        max_value,
        min_value_location,
        max_value_location,
    }
}

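/// A single-channel `f32` image in row-major layout. `data` is a [`Cow`], so
/// an `Image` can either borrow existing pixels (zero-copy) or own them.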
pub struct Image<'a> {
    pub data: Cow<'a, [f32]>,
    pub width: u32,
    pub height: u32,
}

impl<'a> Image<'a> {
    pub fn new(data: impl Into<Cow<'a, [f32]>>, width: u32, height: u32) -> Self {
        Self {
            data: data.into(),
            width,
            height,
        }
    }
}

#[cfg(feature = "image")]
impl<'a> From<&'a image::ImageBuffer<image::Luma<f32>, Vec<f32>>> for Image<'a> {
    fn from(img: &'a image::ImageBuffer<image::Luma<f32>, Vec<f32>>) -> Self {
        Self {
            // `ImageBuffer` derefs to its raw `[f32]` pixel slice, so this
            // borrows the buffer's data without copying.
            data: Cow::Borrowed(img),
            width: img.width(),
            height: img.height(),
        }
    }
}

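/// Minimum and maximum values of a result image, with their pixel locations.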
#[derive(Copy, Clone, Debug)]
pub struct Extremes {
    pub min_value: f32,
    pub max_value: f32,
    pub min_value_location: (u32, u32),
    pub max_value_location: (u32, u32),
}

// Dimensions passed to the compute shader. `#[repr(C)]` plus the bytemuck
// derives allow uploading the struct to the uniform buffer as raw bytes.
#[repr(C)]
#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)]
struct ShaderUniforms {
    input_width: u32,
    input_height: u32,
    template_width: u32,
    template_height: u32,
}

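/// A reusable GPU template matcher. Construction is expensive (it initializes
/// a wgpu device), and the pipeline, buffers, and bind group are cached, so
/// keep one instance around when matching repeatedly with unchanged sizes.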
pub struct TemplateMatcher {
    instance: wgpu::Instance,
    adapter: wgpu::Adapter,
    device: wgpu::Device,
    queue: wgpu::Queue,
    shader: wgpu::ShaderModule,
    bind_group_layout: wgpu::BindGroupLayout,
    pipeline_layout: wgpu::PipelineLayout,

    last_pipeline: Option<wgpu::ComputePipeline>,
    last_method: Option<MatchTemplateMethod>,

    last_input_size: (u32, u32),
    last_template_size: (u32, u32),
    last_result_size: (u32, u32),

    uniform_buffer: wgpu::Buffer,
    input_buffer: Option<wgpu::Buffer>,
    template_buffer: Option<wgpu::Buffer>,
    result_buffer: Option<wgpu::Buffer>,
    staging_buffer: Option<wgpu::Buffer>,
    bind_group: Option<wgpu::BindGroup>,

    matching_ongoing: bool,
}

impl Default for TemplateMatcher {
    fn default() -> Self {
        Self::new()
    }
}

impl TemplateMatcher {
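    /// Sets up the GPU context (instance, adapter, device, queue), compiles the
    /// matching shader, and creates the bind group and pipeline layouts.
    /// Blocks during setup and panics if no suitable adapter or device is found.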
    pub fn new() -> Self {
        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            dx12_shader_compiler: Default::default(),
        });

        let adapter = pollster::block_on(async {
            instance
                .request_adapter(&wgpu::RequestAdapterOptions {
                    power_preference: wgpu::PowerPreference::HighPerformance,
                    compatible_surface: None,
                    force_fallback_adapter: false,
                })
                .await
                .expect("Adapter request failed")
        });

        let (device, queue) = pollster::block_on(async {
            adapter
                .request_device(
                    &wgpu::DeviceDescriptor {
                        label: None,
                        features: wgpu::Features::empty(),
                        limits: wgpu::Limits::default(),
                    },
                    None,
                )
                .await
                .expect("Device request failed")
        });

        let shader = device.create_shader_module(wgpu::include_wgsl!("../shaders/matching.wgsl"));

        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: None,
            entries: &[
                // binding 0: input image (read-only storage)
                wgpu::BindGroupLayoutEntry {
                    binding: 0,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                // binding 1: template image (read-only storage)
                wgpu::BindGroupLayoutEntry {
                    binding: 1,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                // binding 2: result scores (read-write storage)
                wgpu::BindGroupLayoutEntry {
                    binding: 2,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: false },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                // binding 3: image/template dimensions (uniform)
                wgpu::BindGroupLayoutEntry {
                    binding: 3,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
            ],
        });

        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: None,
            bind_group_layouts: &[&bind_group_layout],
            push_constant_ranges: &[],
        });

        let uniform_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("uniform_buffer"),
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            size: size_of::<ShaderUniforms>() as _,
            mapped_at_creation: false,
        });

        Self {
            instance,
            adapter,
            device,
            queue,
            shader,
            pipeline_layout,
            bind_group_layout,
            last_pipeline: None,
            last_method: None,
            last_input_size: (0, 0),
            last_template_size: (0, 0),
            last_result_size: (0, 0),
            uniform_buffer,
            input_buffer: None,
            template_buffer: None,
            result_buffer: None,
            staging_buffer: None,
            bind_group: None,
            matching_ongoing: false,
        }
    }

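    /// Blocks until the pending match (if any) finishes and returns the score
    /// image, or `None` when no matching is ongoing. If mapping the readback
    /// buffer fails, a zero-filled image of the expected size is returned.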
    pub fn wait_for_result(&mut self) -> Option<Image<'static>> {
        if !self.matching_ongoing {
            return None;
        }
        self.matching_ongoing = false;

        let (result_width, result_height) = self.last_result_size;

        let buffer_slice = self.staging_buffer.as_ref().unwrap().slice(..);
        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());

        // Block until the queue is empty, which also resolves the map request.
        self.device.poll(wgpu::Maintain::Wait);

        pollster::block_on(async {
            let result = if let Some(Ok(())) = receiver.receive().await {
                let data = buffer_slice.get_mapped_range();
                let scores = bytemuck::cast_slice(&data).to_vec();
                // The mapped view must be dropped before the buffer is unmapped.
                drop(data);
                self.staging_buffer.as_ref().unwrap().unmap();
                scores
            } else {
                // Mapping failed: fall back to a zero-filled image.
                vec![0.0; (result_width * result_height) as usize]
            };

            Some(Image::new(result, result_width as _, result_height as _))
        })
    }

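    /// Queues template matching on the GPU and returns without blocking; call
    /// [`Self::wait_for_result`] to retrieve the scores. The compute pipeline is
    /// rebuilt only when `method` changes, and GPU buffers are reallocated only
    /// when the input or template dimensions change.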
    pub fn match_template<'a>(
        &mut self,
        input: impl Into<Image<'a>>,
        template: impl Into<Image<'a>>,
        method: MatchTemplateMethod,
    ) {
        if self.matching_ongoing {
            // Finish (and discard) the previous match before queuing a new one.
            self.wait_for_result();
        }

        let input = input.into();
        let template = template.into();

        if self.last_pipeline.is_none() || self.last_method != Some(method) {
            self.last_method = Some(method);

            // Each matching method has its own entry point in the shader.
            let entry_point = match method {
                MatchTemplateMethod::SumOfAbsoluteDifferences => "main_sad",
                MatchTemplateMethod::SumOfSquaredDifferences => "main_ssd",
            };

            self.last_pipeline = Some(self.device.create_compute_pipeline(
                &wgpu::ComputePipelineDescriptor {
                    label: None,
                    layout: Some(&self.pipeline_layout),
                    module: &self.shader,
                    entry_point,
                },
            ));
        }

        let mut buffers_changed = false;

        let input_size = (input.width, input.height);
        if self.input_buffer.is_none() || self.last_input_size != input_size {
            buffers_changed = true;

            self.last_input_size = input_size;

            self.input_buffer = Some(self.device.create_buffer_init(
                &wgpu::util::BufferInitDescriptor {
                    label: Some("input_buffer"),
                    contents: bytemuck::cast_slice(&input.data),
                    usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
                },
            ));
        } else {
            // Same dimensions as last time: just upload the new pixels.
            self.queue.write_buffer(
                self.input_buffer.as_ref().unwrap(),
                0,
                bytemuck::cast_slice(&input.data),
            );
        }

        let template_size = (template.width, template.height);
        if self.template_buffer.is_none() || self.last_template_size != template_size {
            buffers_changed = true;

            self.last_template_size = template_size;

            self.template_buffer = Some(self.device.create_buffer_init(
                &wgpu::util::BufferInitDescriptor {
                    label: Some("template_buffer"),
                    contents: bytemuck::cast_slice(&template.data),
                    usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
                },
            ));
        } else {
            self.queue.write_buffer(
                self.template_buffer.as_ref().unwrap(),
                0,
                bytemuck::cast_slice(&template.data),
            );
        }

        // One score per valid template placement. Note that this underflows if
        // the template is larger than the input, so callers must ensure it fits.
        let result_width = input.width - template.width + 1;
        let result_height = input.height - template.height + 1;
        let result_buf_size = (result_width * result_height) as u64 * size_of::<f32>() as u64;

        if buffers_changed {
            self.last_result_size = (result_width, result_height);

            // Keep the shader's view of the dimensions in sync whenever either
            // the input or the template buffer was recreated.
            self.queue.write_buffer(
                &self.uniform_buffer,
                0,
                bytemuck::cast_slice(&[ShaderUniforms {
                    input_width: input.width,
                    input_height: input.height,
                    template_width: template.width,
                    template_height: template.height,
                }]),
            );

            self.result_buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("result_buffer"),
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_SRC
                    | wgpu::BufferUsages::COPY_DST,
                size: result_buf_size,
                mapped_at_creation: false,
            }));

            // CPU-mappable destination for reading the scores back.
            self.staging_buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("staging_buffer"),
                usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
                size: result_buf_size,
                mapped_at_creation: false,
            }));

            self.bind_group = Some(self.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: None,
                layout: &self.bind_group_layout,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: self.input_buffer.as_ref().unwrap().as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: self.template_buffer.as_ref().unwrap().as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: self.result_buffer.as_ref().unwrap().as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 3,
                        resource: self.uniform_buffer.as_entire_binding(),
                    },
                ],
            }));
        }

        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("encoder"),
            });

        {
            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("compute_pass"),
            });
            compute_pass.set_pipeline(self.last_pipeline.as_ref().unwrap());
            compute_pass.set_bind_group(0, self.bind_group.as_ref().unwrap(), &[]);
            // One invocation per result pixel, assuming the shader declares
            // `@workgroup_size(16, 16)`.
            compute_pass.dispatch_workgroups(
                (result_width as f32 / 16.0).ceil() as u32,
                (result_height as f32 / 16.0).ceil() as u32,
                1,
            );
        }

        // Copy the scores into the mappable staging buffer for readback.
        encoder.copy_buffer_to_buffer(
            self.result_buffer.as_ref().unwrap(),
            0,
            self.staging_buffer.as_ref().unwrap(),
            0,
            result_buf_size,
        );

        self.queue.submit(std::iter::once(encoder.finish()));
        self.matching_ongoing = true;
    }
}