1#![deny(unsafe_code)]
5#![warn(
6 clippy::all,
7 clippy::await_holding_lock,
8 clippy::char_lit_as_u8,
9 clippy::checked_conversions,
10 clippy::dbg_macro,
11 clippy::debug_assert_with_mut_call,
12 clippy::doc_markdown,
13 clippy::empty_enum,
14 clippy::enum_glob_use,
15 clippy::exit,
16 clippy::expl_impl_clone_on_copy,
17 clippy::explicit_deref_methods,
18 clippy::explicit_into_iter_loop,
19 clippy::fallible_impl_from,
20 clippy::filter_map_next,
21 clippy::float_cmp_const,
22 clippy::fn_params_excessive_bools,
23 clippy::if_let_mutex,
24 clippy::implicit_clone,
25 clippy::imprecise_flops,
26 clippy::inefficient_to_string,
27 clippy::invalid_upcast_comparisons,
28 clippy::large_types_passed_by_value,
29 clippy::let_unit_value,
30 clippy::linkedlist,
31 clippy::lossy_float_literal,
32 clippy::macro_use_imports,
33 clippy::manual_ok_or,
34 clippy::map_err_ignore,
35 clippy::map_flatten,
36 clippy::map_unwrap_or,
37 clippy::match_on_vec_items,
38 clippy::match_same_arms,
39 clippy::match_wildcard_for_single_variants,
40 clippy::mem_forget,
41 clippy::mismatched_target_os,
42 clippy::mut_mut,
43 clippy::mutex_integer,
44 clippy::needless_borrow,
45 clippy::needless_continue,
46 clippy::option_option,
47 clippy::path_buf_push_overwrite,
48 clippy::ptr_as_ptr,
49 clippy::ref_option_ref,
50 clippy::rest_pat_in_fully_bound_structs,
51 clippy::same_functions_in_if_condition,
52 clippy::semicolon_if_nothing_returned,
53 clippy::string_add_assign,
54 clippy::string_add,
55 clippy::string_lit_as_bytes,
56 clippy::string_to_string,
57 clippy::todo,
58 clippy::trait_duplication_in_bounds,
59 clippy::unimplemented,
60 clippy::unnested_or_patterns,
61 clippy::unused_self,
62 clippy::useless_transmute,
63 clippy::verbose_file_reads,
64 clippy::zero_sized_map_values,
65 future_incompatible,
66 nonstandard_style,
67 rust_2018_idioms
68)]
69#![allow(unsafe_code)]
71
72mod errors;
115mod img_pyramid;
116use img_pyramid::*;
117mod utils;
118use utils::*;
119mod ms;
120use ms::*;
121pub mod session;
122mod unsync;
123
124pub use image;
125use std::path::Path;
126
127pub use errors::Error;
128pub use session::{Session, SessionBuilder};
129pub use utils::{load_dynamic_image, ChannelMask, ImageSource};
130
/// Width and height of an image, in pixels.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Dims {
    pub width: u32,
    pub height: u32,
}

impl Dims {
    /// Creates square dimensions where both width and height are `size`.
    pub fn square(size: u32) -> Self {
        Self::new(size, size)
    }

    /// Creates dimensions from an explicit `width` and `height`.
    pub fn new(width: u32, height: u32) -> Self {
        Self { width, height }
    }
}
150
/// Maps every pixel of an output image to a pixel `(x, y)` in one of several
/// source maps. [`CoordinateTransform::apply`] uses it to rebuild an output
/// image from a new set of source images.
pub struct CoordinateTransform {
    /// Three `u32` words per output pixel, in the order pixels are visited by
    /// `RgbaImage::pixels_mut`: the `x` coordinate, the `y` coordinate, and
    /// the index of the source map (into `original_maps`) to copy from.
    buffer: Vec<u32>,
    /// Size of the image produced by `apply`.
    pub output_size: Dims,
    /// Dimensions of each source map. Images passed to `apply` are loaded
    /// with these dimensions as their resize target, in this order.
    original_maps: Vec<Dims>,
}

/// Magic word written at the start of a serialized `CoordinateTransform`.
/// The high 16 bits identify the format (`0x1234`); the low 16 bits are the
/// format version (`1`).
const TRANSFORM_MAGIC: u32 = 0x1234_0001;
161
162impl<'a> CoordinateTransform {
163 pub fn apply<E, I>(&self, source: I) -> Result<image::RgbaImage, Error>
171 where
172 I: IntoIterator<Item = E>,
173 E: Into<ImageSource<'a>>,
174 {
175 let ref_maps: Vec<image::RgbaImage> = source
176 .into_iter()
177 .zip(self.original_maps.iter())
178 .map(|(is, dims)| load_image(is.into(), Some(*dims)))
179 .collect::<Result<Vec<_>, Error>>()?;
180
181 if ref_maps.len() != self.original_maps.len() {
184 return Err(Error::MapsCountMismatch(
185 ref_maps.len() as u32,
186 self.original_maps.len() as u32,
187 ));
188 }
189
190 let mut img = image::RgbaImage::new(self.output_size.width, self.output_size.height);
191
192 for (i, pix) in img.pixels_mut().enumerate() {
194 let x = self.buffer[i * 3];
195 let y = self.buffer[i * 3 + 1];
196 let map = self.buffer[i * 3 + 2];
197
198 *pix = *ref_maps[map as usize].get_pixel(x, y);
199 }
200
201 Ok(img)
202 }
203
204 pub fn write<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<usize> {
205 use std::mem;
206 let mut written = 0;
207
208 if self.buffer.len()
211 != self.output_size.width as usize * self.output_size.height as usize * 3
212 {
213 return Err(std::io::Error::new(
214 std::io::ErrorKind::InvalidInput,
215 "buffer length doesn't match dimensions",
216 ));
217 }
218
219 let header = [
220 TRANSFORM_MAGIC,
221 self.output_size.width,
222 self.output_size.height,
223 self.original_maps.len() as u32,
224 ];
225
226 fn cast(ina: &[u32]) -> &[u8] {
227 unsafe {
228 let p = ina.as_ptr();
229 let len = ina.len();
230
231 std::slice::from_raw_parts(p.cast::<u8>(), len * mem::size_of::<u32>())
232 }
233 }
234
235 w.write_all(cast(&header))?;
236 written += mem::size_of_val(&header);
237
238 for om in &self.original_maps {
239 let dims = [om.width, om.height];
240 w.write_all(cast(&dims))?;
241 written += mem::size_of_val(&dims);
242 }
243
244 w.write_all(cast(&self.buffer))?;
245 written += 4 * self.buffer.len();
246
247 Ok(written)
248 }
249
250 pub fn read<R: std::io::Read>(r: &mut R) -> std::io::Result<Self> {
251 use std::{
252 io::{Error, ErrorKind, Read},
253 mem,
254 };
255
256 fn do_read<R: Read>(r: &mut R, buf: &mut [u32]) -> std::io::Result<()> {
257 unsafe {
258 let p = buf.as_mut_ptr();
259 let len = buf.len();
260
261 let mut slice =
262 std::slice::from_raw_parts_mut(p.cast::<u8>(), len * mem::size_of::<u32>());
263
264 r.read(&mut slice).map(|_| ())
265 }
266 }
267
268 let mut magic = [0u32];
269 do_read(r, &mut magic)?;
270
271 if magic[0] >> 16 != 0x1234 {
272 return Err(Error::new(ErrorKind::InvalidData, "invalid magic"));
273 }
274
275 let (output_size, original_maps) = match magic[0] & 0x0000_ffff {
276 0x1 => {
277 let mut header = [0u32; 3];
278 do_read(r, &mut header)?;
279
280 let mut omaps = Vec::with_capacity(header[2] as usize);
281 for _ in 0..header[2] {
282 let mut dims = [0u32; 2];
283 do_read(r, &mut dims)?;
284 omaps.push(Dims {
285 width: dims[0],
286 height: dims[1],
287 });
288 }
289
290 (
291 Dims {
292 width: header[0],
293 height: header[1],
294 },
295 omaps,
296 )
297 }
298 _ => return Err(Error::new(ErrorKind::InvalidData, "invalid version")),
299 };
300
301 let buffer = unsafe {
302 let len = output_size.width as usize * output_size.height as usize * 3;
303 let mut buffer = Vec::with_capacity(len);
304 buffer.set_len(len);
305
306 do_read(r, &mut buffer)?;
307 buffer
308 };
309
310 Ok(Self {
311 buffer,
312 output_size,
313 original_maps,
314 })
315 }
316}
317
/// Synthesis settings. Defaults come from the `Default` impl, and the
/// generator-facing subset is extracted by `to_generator_params`.
struct Parameters {
    /// Forwarded unchanged as `GeneratorParams::tiling_mode`.
    tiling_mode: bool,
    /// Forwarded unchanged as `GeneratorParams::nearest_neighbors`.
    nearest_neighbors: u32,
    /// Forwarded unchanged as `GeneratorParams::random_sample_locations`.
    random_sample_locations: u64,
    /// Forwarded unchanged as `GeneratorParams::cauchy_dispersion`.
    cauchy_dispersion: f32,
    /// Forwarded as `GeneratorParams::p`.
    backtrack_percent: f32,
    /// Forwarded as `GeneratorParams::p_stages` (cast to `i32`).
    backtrack_stages: u32,
    /// NOTE(review): not consumed in this file; presumably the size inputs are
    /// resized to before synthesis — confirm at the use site.
    resize_input: Option<Dims>,
    /// NOTE(review): not consumed in this file; presumably the size of the
    /// generated image — confirm at the use site.
    output_size: Dims,
    /// Forwarded as `GeneratorParams::alpha`.
    guide_alpha: f32,
    /// NOTE(review): not consumed in this file — purpose to be confirmed.
    random_resolve: Option<u64>,
    /// Thread limit. `None` falls back to `num_cpus::get()` in `to_generator_params`.
    max_thread_count: Option<usize>,
    /// Forwarded unchanged as `GeneratorParams::seed`.
    seed: u64,
}
332
333impl Default for Parameters {
334 fn default() -> Self {
335 Self {
336 tiling_mode: false,
337 nearest_neighbors: 50,
338 random_sample_locations: 50,
339 cauchy_dispersion: 1.0,
340 backtrack_percent: 0.5,
341 backtrack_stages: 5,
342 resize_input: None,
343 output_size: Dims::square(500),
344 guide_alpha: 0.8,
345 random_resolve: None,
346 max_thread_count: None,
347 seed: 0,
348 }
349 }
350}
351
352impl Parameters {
353 fn to_generator_params(&self) -> GeneratorParams {
354 GeneratorParams {
355 nearest_neighbors: self.nearest_neighbors,
356 random_sample_locations: self.random_sample_locations,
357 cauchy_dispersion: self.cauchy_dispersion,
358 p: self.backtrack_percent,
359 p_stages: self.backtrack_stages as i32,
360 seed: self.seed,
361 alpha: self.guide_alpha,
362 max_thread_count: self.max_thread_count.unwrap_or_else(num_cpus::get),
363 tiling_mode: self.tiling_mode,
364 }
365 }
366}
367
/// Output of a synthesis run. Wraps the `ms::Generator` that produced it, so
/// the color map, debug maps and coordinate transform can all be retrieved.
pub struct GeneratedImage {
    inner: ms::Generator,
}
372
373impl GeneratedImage {
374 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Error> {
376 let path = path.as_ref();
377 if let Some(parent_path) = path.parent() {
378 std::fs::create_dir_all(&parent_path)?;
379 }
380
381 self.inner.color_map.as_ref().save(&path)?;
382 Ok(())
383 }
384
385 pub fn write<W: std::io::Write>(
387 self,
388 writer: &mut W,
389 fmt: image::ImageOutputFormat,
390 ) -> Result<(), Error> {
391 let dyn_img = self.into_image();
392 Ok(dyn_img.write_to(writer, fmt)?)
393 }
394
395 pub fn save_debug<P: AsRef<Path>>(&self, dir: P) -> Result<(), Error> {
399 let dir = dir.as_ref();
400 std::fs::create_dir_all(&dir)?;
401
402 self.inner
403 .get_uncertainty_map()
404 .save(&dir.join("uncertainty.png"))?;
405 let id_maps = self.inner.get_id_maps();
406 id_maps[0].save(&dir.join("patch_id.png"))?;
407 id_maps[1].save(&dir.join("map_id.png"))?;
408
409 Ok(())
410 }
411
412 pub fn get_coordinate_transform(&self) -> CoordinateTransform {
433 self.inner.get_coord_transform()
434 }
435
436 pub fn into_image(self) -> image::DynamicImage {
438 image::DynamicImage::ImageRgba8(self.inner.color_map.into_inner())
439 }
440}
441
impl AsRef<image::RgbaImage> for GeneratedImage {
    /// Borrows the generated RGBA image without consuming `self`.
    fn as_ref(&self) -> &image::RgbaImage {
        self.inner.color_map.as_ref()
    }
}
447
/// The sampling method attached to an [`Example`]. It is generic over the
/// payload of the `Image` variant: an unloaded [`ImageSource`] in
/// [`SampleMethod`], or a loaded `RgbaImage` in [`SamplingMethod`].
pub enum GenericSampleMethod<Img> {
    /// NOTE(review): presumably every pixel of the example may be sampled — confirm.
    All,
    /// NOTE(review): presumably the example is never sampled from — confirm.
    Ignore,
    /// Sampling is governed by an image (presumably a mask; confirm how it is interpreted).
    Image(Img),
}
457
/// Sample method as accepted by the public API. The `Image` variant holds an
/// [`ImageSource`] that has not been loaded yet.
pub type SampleMethod<'a> = GenericSampleMethod<ImageSource<'a>>;
/// Sample method after `Example::resolve`. The `Image` variant holds the
/// loaded `RgbaImage`.
pub type SamplingMethod = GenericSampleMethod<image::RgbaImage>;
460
461impl<Img> GenericSampleMethod<Img> {
462 #[inline]
463 fn is_ignore(&self) -> bool {
464 matches!(self, Self::Ignore)
465 }
466}
467
468impl<'a, IS> From<IS> for SampleMethod<'a>
469where
470 IS: Into<ImageSource<'a>>,
471{
472 fn from(is: IS) -> Self {
473 SampleMethod::Image(is.into())
474 }
475}
476
/// Builder for an [`Example`]. It holds the example image, an optional guide
/// and a sampling method. Convert it with `Example::from` or `.into()`.
pub struct ExampleBuilder<'a> {
    img: ImageSource<'a>,
    guide: Option<ImageSource<'a>>,
    sample_method: SampleMethod<'a>,
}
483
484impl<'a> ExampleBuilder<'a> {
485 pub fn new<I: Into<ImageSource<'a>>>(img: I) -> Self {
487 Self {
488 img: img.into(),
489 guide: None,
490 sample_method: SampleMethod::All,
491 }
492 }
493
494 pub fn with_guide<G: Into<ImageSource<'a>>>(mut self, guide: G) -> Self {
499 self.guide = Some(guide.into());
500 self
501 }
502
503 pub fn set_sample_method<M: Into<SampleMethod<'a>>>(mut self, method: M) -> Self {
507 self.sample_method = method.into();
508 self
509 }
510}
511
/// An example image used as input to synthesis. It carries an optional guide
/// image and a sampling method. Sources are only loaded in `resolve`.
pub struct Example<'a> {
    img: ImageSource<'a>,
    guide: Option<ImageSource<'a>>,
    sample_method: SampleMethod<'a>,
}
518
519impl<'a> Example<'a> {
520 pub fn builder<I: Into<ImageSource<'a>>>(img: I) -> ExampleBuilder<'a> {
522 ExampleBuilder::new(img)
523 }
524
525 pub fn image_source(&self) -> &ImageSource<'a> {
526 &self.img
527 }
528
529 pub fn new<I: Into<ImageSource<'a>>>(img: I) -> Self {
531 Self {
532 img: img.into(),
533 guide: None,
534 sample_method: SampleMethod::All,
535 }
536 }
537
538 pub fn with_guide<G: Into<ImageSource<'a>>>(&mut self, guide: G) -> &mut Self {
543 self.guide = Some(guide.into());
544 self
545 }
546
547 pub fn set_sample_method<M: Into<SampleMethod<'a>>>(&mut self, method: M) -> &mut Self {
551 self.sample_method = method.into();
552 self
553 }
554
555 fn resolve(
556 self,
557 backtracks: u32,
558 resize: Option<Dims>,
559 target_guide: &Option<ImagePyramid>,
560 ) -> Result<ResolvedExample, Error> {
561 let image = ImagePyramid::new(load_image(self.img, resize)?, Some(backtracks));
562
563 let guide = match target_guide {
564 Some(tg) => {
565 Some(match self.guide {
566 Some(exguide) => {
567 let exguide = load_image(exguide, resize)?;
568 ImagePyramid::new(exguide, Some(backtracks))
569 }
570 None => {
571 let mut gm = transform_to_guide_map(image.bottom().clone(), resize, 2.0);
573 match_histograms(&mut gm, tg.bottom());
574
575 ImagePyramid::new(gm, Some(backtracks))
576 }
577 })
578 }
579 None => None,
580 };
581
582 let method = match self.sample_method {
583 SampleMethod::All => SamplingMethod::All,
584 SampleMethod::Ignore => SamplingMethod::Ignore,
585 SampleMethod::Image(src) => {
586 let img = load_image(src, resize)?;
587 SamplingMethod::Image(img)
588 }
589 };
590
591 Ok(ResolvedExample {
592 image,
593 guide,
594 method,
595 })
596 }
597}
598
599impl<'a> From<ExampleBuilder<'a>> for Example<'a> {
600 fn from(eb: ExampleBuilder<'a>) -> Self {
601 Self {
602 img: eb.img,
603 guide: eb.guide,
604 sample_method: eb.sample_method,
605 }
606 }
607}
608
609impl<'a, IS> From<IS> for Example<'a>
610where
611 IS: Into<ImageSource<'a>>,
612{
613 fn from(is: IS) -> Self {
614 Example::new(is)
615 }
616}
617
/// Source of an inpaint mask: either a ready `ChannelMask` or an image source
/// still to be loaded. NOTE(review): how each variant becomes a mask is
/// decided by code outside this file — confirm there.
enum MaskOrImg<'a> {
    Mask(utils::ChannelMask),
    ImageSource(ImageSource<'a>),
}
622
/// An inpainting mask together with the example it is associated with.
struct InpaintMask<'a> {
    /// Where the mask comes from.
    src: MaskOrImg<'a>,
    /// Index of the associated example (presumably into the session's example
    /// list — confirm at the use site).
    example_index: usize,
    /// Dimensions associated with the mask (assumed to be the size it is
    /// loaded at — confirm at the use site).
    dims: Dims,
}
628
/// An [`Example`] whose sources have all been loaded. Produced by
/// `Example::resolve`.
struct ResolvedExample {
    /// Image pyramid built from the example image.
    image: ImagePyramid,
    /// Guide pyramid. It is `Some` only when `resolve` was given a target guide.
    guide: Option<ImagePyramid>,
    /// Sampling method whose `Image` payload (if any) is already loaded.
    method: SamplingMethod,
}
634
#[cfg(test)]
mod test {
    use super::{CoordinateTransform as CT, Dims};

    /// A small transform whose buffer is arbitrary: the map index 3 is not
    /// even valid. It is only meant to exercise serialization.
    fn sample_transform() -> CT {
        CT {
            buffer: vec![1, 2, 3, 4, 5, 6],
            output_size: Dims::new(2, 1),
            original_maps: vec![Dims::new(9001, 9002), Dims::new(20, 5)],
        }
    }

    #[test]
    fn coord_tx_serde() {
        let input = sample_transform();

        let mut bytes: Vec<u8> = Vec::new();
        let written = input.write(&mut bytes).unwrap();
        // 4 header words + 2 words per map + 3 words per pixel, 4 bytes each.
        assert_eq!(written, (4 + 2 * 2 + 2 * 3) * 4);
        assert_eq!(written, bytes.len());

        let mut cursor = std::io::Cursor::new(&bytes);
        let deserialized = CT::read(&mut cursor).unwrap();
        // The reader must consume exactly what the writer produced.
        assert_eq!(cursor.position() as usize, bytes.len());

        assert_eq!(deserialized.buffer, input.buffer);
        assert_eq!(deserialized.output_size, Dims::new(2, 1));
        assert_eq!(
            deserialized.original_maps,
            vec![Dims::new(9001, 9002), Dims::new(20, 5)]
        );
    }

    #[test]
    fn coord_tx_rejects_bad_magic() {
        let bytes = [0u8; 16];
        let err = CT::read(&mut std::io::Cursor::new(&bytes[..]))
            .err()
            .expect("a zeroed header must be rejected");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn coord_tx_write_rejects_inconsistent_buffer() {
        let mut tx = sample_transform();
        tx.buffer.pop();
        let mut sink: Vec<u8> = Vec::new();
        let err = tx.write(&mut sink).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
        assert!(sink.is_empty());
    }
}