texture_synthesis/
lib.rs

1// BEGIN - Embark standard lints v0.4
2// do not change or add/remove here, but one can add exceptions after this section
3// for more info see: <https://github.com/EmbarkStudios/rust-ecosystem/issues/59>
4#![deny(unsafe_code)]
5#![warn(
6    clippy::all,
7    clippy::await_holding_lock,
8    clippy::char_lit_as_u8,
9    clippy::checked_conversions,
10    clippy::dbg_macro,
11    clippy::debug_assert_with_mut_call,
12    clippy::doc_markdown,
13    clippy::empty_enum,
14    clippy::enum_glob_use,
15    clippy::exit,
16    clippy::expl_impl_clone_on_copy,
17    clippy::explicit_deref_methods,
18    clippy::explicit_into_iter_loop,
19    clippy::fallible_impl_from,
20    clippy::filter_map_next,
21    clippy::float_cmp_const,
22    clippy::fn_params_excessive_bools,
23    clippy::if_let_mutex,
24    clippy::implicit_clone,
25    clippy::imprecise_flops,
26    clippy::inefficient_to_string,
27    clippy::invalid_upcast_comparisons,
28    clippy::large_types_passed_by_value,
29    clippy::let_unit_value,
30    clippy::linkedlist,
31    clippy::lossy_float_literal,
32    clippy::macro_use_imports,
33    clippy::manual_ok_or,
34    clippy::map_err_ignore,
35    clippy::map_flatten,
36    clippy::map_unwrap_or,
37    clippy::match_on_vec_items,
38    clippy::match_same_arms,
39    clippy::match_wildcard_for_single_variants,
40    clippy::mem_forget,
41    clippy::mismatched_target_os,
42    clippy::mut_mut,
43    clippy::mutex_integer,
44    clippy::needless_borrow,
45    clippy::needless_continue,
46    clippy::option_option,
47    clippy::path_buf_push_overwrite,
48    clippy::ptr_as_ptr,
49    clippy::ref_option_ref,
50    clippy::rest_pat_in_fully_bound_structs,
51    clippy::same_functions_in_if_condition,
52    clippy::semicolon_if_nothing_returned,
53    clippy::string_add_assign,
54    clippy::string_add,
55    clippy::string_lit_as_bytes,
56    clippy::string_to_string,
57    clippy::todo,
58    clippy::trait_duplication_in_bounds,
59    clippy::unimplemented,
60    clippy::unnested_or_patterns,
61    clippy::unused_self,
62    clippy::useless_transmute,
63    clippy::verbose_file_reads,
64    clippy::zero_sized_map_values,
65    future_incompatible,
66    nonstandard_style,
67    rust_2018_idioms
68)]
69// END - Embark standard lints v0.4
70#![allow(unsafe_code)]
71
72//! `texture-synthesis` is a light API for Multiresolution Stochastic Texture Synthesis,
73//! a non-parametric example-based algorithm for image generation.
74//!
75//! First, you build a `Session` via a `SessionBuilder`, which follows the builder pattern. Calling
76//! `build` on the `SessionBuilder` loads all of the input images and checks for various errors.
77//!
78//! `Session` has a `run()` method that takes all of the parameters and inputs added in the session
//! builder to generate an image, which is returned as a `GeneratedImage`.
80//!
81//! You can save, stream, or inspect the image from `GeneratedImage`.
82//!
83//! ## Features
84//!
85//! 1. Single example generation
86//! 2. Multi example generation
87//! 3. Guided synthesis
88//! 4. Style transfer
89//! 5. Inpainting
90//! 6. Tiling textures
91//!
//! Please refer to the examples folder in the [repository](https://github.com/EmbarkStudios/texture-synthesis) for usage examples of each feature.
93//!
94//! ## Usage
95//! Session follows a "builder pattern" for defining parameters, meaning you chain functions together.
96//!
97//! ```no_run
98//! // Create a new session with default parameters
99//! let session = texture_synthesis::Session::builder()
100//!     // Set some parameters
101//!     .seed(10)
102//!     .nearest_neighbors(20)
103//!     // Specify example images
104//!     .add_example(&"imgs/1.jpg")
105//!     // Build the session
106//!     .build().expect("failed to build session");
107//!
108//! // Generate a new image
109//! let generated_img = session.run(None);
110//!
111//! // Save the generated image to disk
112//! generated_img.save("my_generated_img.jpg").expect("failed to save generated image");
113//! ```
114mod errors;
115mod img_pyramid;
116use img_pyramid::*;
117mod utils;
118use utils::*;
119mod ms;
120use ms::*;
121pub mod session;
122mod unsync;
123
124pub use image;
125use std::path::Path;
126
127pub use errors::Error;
128pub use session::{Session, SessionBuilder};
129pub use utils::{load_dynamic_image, ChannelMask, ImageSource};
130
/// Simple width/height pair used to describe image dimensions.
///
/// `Debug`, `PartialEq` and `Eq` are derived unconditionally (not only in
/// test builds) so that dimensions can be logged and compared by any user of
/// the crate.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Dims {
    /// Horizontal extent.
    pub width: u32,
    /// Vertical extent.
    pub height: u32,
}

impl Dims {
    /// Creates dimensions whose width and height are both `size`.
    pub const fn square(size: u32) -> Self {
        Self::new(size, size)
    }

    /// Creates dimensions from an explicit `width` and `height`.
    pub const fn new(width: u32, height: u32) -> Self {
        Self { width, height }
    }
}
150
151/// A buffer of transforms that were used to generate an image from a set of
152/// examples, which can be applied to a different set of input images to get
153/// a different output image.
154pub struct CoordinateTransform {
155    buffer: Vec<u32>,
156    pub output_size: Dims,
157    original_maps: Vec<Dims>,
158}
159
160const TRANSFORM_MAGIC: u32 = 0x1234_0001;
161
162impl<'a> CoordinateTransform {
163    /// Applies the coordinate transformation from new source images. This
164    /// method will fail if the the provided source images aren't the same
165    /// number of example images that generated the transform.
166    ///
167    /// The input images are automatically resized to the dimensions of the
168    /// original example images used in the generation of this coordinate
169    /// transform
170    pub fn apply<E, I>(&self, source: I) -> Result<image::RgbaImage, Error>
171    where
172        I: IntoIterator<Item = E>,
173        E: Into<ImageSource<'a>>,
174    {
175        let ref_maps: Vec<image::RgbaImage> = source
176            .into_iter()
177            .zip(self.original_maps.iter())
178            .map(|(is, dims)| load_image(is.into(), Some(*dims)))
179            .collect::<Result<Vec<_>, Error>>()?;
180
181        // Ensure the number of inputs match the number in that generated this
182        // transform, otherwise we would get weird results
183        if ref_maps.len() != self.original_maps.len() {
184            return Err(Error::MapsCountMismatch(
185                ref_maps.len() as u32,
186                self.original_maps.len() as u32,
187            ));
188        }
189
190        let mut img = image::RgbaImage::new(self.output_size.width, self.output_size.height);
191
192        // Populate with pixels from ref maps
193        for (i, pix) in img.pixels_mut().enumerate() {
194            let x = self.buffer[i * 3];
195            let y = self.buffer[i * 3 + 1];
196            let map = self.buffer[i * 3 + 2];
197
198            *pix = *ref_maps[map as usize].get_pixel(x, y);
199        }
200
201        Ok(img)
202    }
203
204    pub fn write<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<usize> {
205        use std::mem;
206        let mut written = 0;
207
208        // Sanity check that that buffer length corresponds correctly with the
209        // supposed dimensions
210        if self.buffer.len()
211            != self.output_size.width as usize * self.output_size.height as usize * 3
212        {
213            return Err(std::io::Error::new(
214                std::io::ErrorKind::InvalidInput,
215                "buffer length doesn't match dimensions",
216            ));
217        }
218
219        let header = [
220            TRANSFORM_MAGIC,
221            self.output_size.width,
222            self.output_size.height,
223            self.original_maps.len() as u32,
224        ];
225
226        fn cast(ina: &[u32]) -> &[u8] {
227            unsafe {
228                let p = ina.as_ptr();
229                let len = ina.len();
230
231                std::slice::from_raw_parts(p.cast::<u8>(), len * mem::size_of::<u32>())
232            }
233        }
234
235        w.write_all(cast(&header))?;
236        written += mem::size_of_val(&header);
237
238        for om in &self.original_maps {
239            let dims = [om.width, om.height];
240            w.write_all(cast(&dims))?;
241            written += mem::size_of_val(&dims);
242        }
243
244        w.write_all(cast(&self.buffer))?;
245        written += 4 * self.buffer.len();
246
247        Ok(written)
248    }
249
250    pub fn read<R: std::io::Read>(r: &mut R) -> std::io::Result<Self> {
251        use std::{
252            io::{Error, ErrorKind, Read},
253            mem,
254        };
255
256        fn do_read<R: Read>(r: &mut R, buf: &mut [u32]) -> std::io::Result<()> {
257            unsafe {
258                let p = buf.as_mut_ptr();
259                let len = buf.len();
260
261                let mut slice =
262                    std::slice::from_raw_parts_mut(p.cast::<u8>(), len * mem::size_of::<u32>());
263
264                r.read(&mut slice).map(|_| ())
265            }
266        }
267
268        let mut magic = [0u32];
269        do_read(r, &mut magic)?;
270
271        if magic[0] >> 16 != 0x1234 {
272            return Err(Error::new(ErrorKind::InvalidData, "invalid magic"));
273        }
274
275        let (output_size, original_maps) = match magic[0] & 0x0000_ffff {
276            0x1 => {
277                let mut header = [0u32; 3];
278                do_read(r, &mut header)?;
279
280                let mut omaps = Vec::with_capacity(header[2] as usize);
281                for _ in 0..header[2] {
282                    let mut dims = [0u32; 2];
283                    do_read(r, &mut dims)?;
284                    omaps.push(Dims {
285                        width: dims[0],
286                        height: dims[1],
287                    });
288                }
289
290                (
291                    Dims {
292                        width: header[0],
293                        height: header[1],
294                    },
295                    omaps,
296                )
297            }
298            _ => return Err(Error::new(ErrorKind::InvalidData, "invalid version")),
299        };
300
301        let buffer = unsafe {
302            let len = output_size.width as usize * output_size.height as usize * 3;
303            let mut buffer = Vec::with_capacity(len);
304            buffer.set_len(len);
305
306            do_read(r, &mut buffer)?;
307            buffer
308        };
309
310        Ok(Self {
311            buffer,
312            output_size,
313            original_maps,
314        })
315    }
316}
317
318struct Parameters {
319    tiling_mode: bool,
320    nearest_neighbors: u32,
321    random_sample_locations: u64,
322    cauchy_dispersion: f32,
323    backtrack_percent: f32,
324    backtrack_stages: u32,
325    resize_input: Option<Dims>,
326    output_size: Dims,
327    guide_alpha: f32,
328    random_resolve: Option<u64>,
329    max_thread_count: Option<usize>,
330    seed: u64,
331}
332
333impl Default for Parameters {
334    fn default() -> Self {
335        Self {
336            tiling_mode: false,
337            nearest_neighbors: 50,
338            random_sample_locations: 50,
339            cauchy_dispersion: 1.0,
340            backtrack_percent: 0.5,
341            backtrack_stages: 5,
342            resize_input: None,
343            output_size: Dims::square(500),
344            guide_alpha: 0.8,
345            random_resolve: None,
346            max_thread_count: None,
347            seed: 0,
348        }
349    }
350}
351
352impl Parameters {
353    fn to_generator_params(&self) -> GeneratorParams {
354        GeneratorParams {
355            nearest_neighbors: self.nearest_neighbors,
356            random_sample_locations: self.random_sample_locations,
357            cauchy_dispersion: self.cauchy_dispersion,
358            p: self.backtrack_percent,
359            p_stages: self.backtrack_stages as i32,
360            seed: self.seed,
361            alpha: self.guide_alpha,
362            max_thread_count: self.max_thread_count.unwrap_or_else(num_cpus::get),
363            tiling_mode: self.tiling_mode,
364        }
365    }
366}
367
368/// An image generated by a `Session::run()`
369pub struct GeneratedImage {
370    inner: ms::Generator,
371}
372
373impl GeneratedImage {
374    /// Saves the generated image to the specified path
375    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Error> {
376        let path = path.as_ref();
377        if let Some(parent_path) = path.parent() {
378            std::fs::create_dir_all(&parent_path)?;
379        }
380
381        self.inner.color_map.as_ref().save(&path)?;
382        Ok(())
383    }
384
385    /// Writes the generated image to the specified stream
386    pub fn write<W: std::io::Write>(
387        self,
388        writer: &mut W,
389        fmt: image::ImageOutputFormat,
390    ) -> Result<(), Error> {
391        let dyn_img = self.into_image();
392        Ok(dyn_img.write_to(writer, fmt)?)
393    }
394
395    /// Saves debug information such as copied patches ids, map ids (if you have
396    /// multi example generation) and a map indicating generated pixels the
397    /// generator was "uncertain" of.
398    pub fn save_debug<P: AsRef<Path>>(&self, dir: P) -> Result<(), Error> {
399        let dir = dir.as_ref();
400        std::fs::create_dir_all(&dir)?;
401
402        self.inner
403            .get_uncertainty_map()
404            .save(&dir.join("uncertainty.png"))?;
405        let id_maps = self.inner.get_id_maps();
406        id_maps[0].save(&dir.join("patch_id.png"))?;
407        id_maps[1].save(&dir.join("map_id.png"))?;
408
409        Ok(())
410    }
411
412    /// Get the coordinate transform of this generated image, which can be
413    /// applied to new example images to get a different output image.
414    ///
415    /// ```no_run
416    /// use texture_synthesis as ts;
417    ///
418    /// // create a new session
419    /// let texsynth = ts::Session::builder()
420    ///     //load a single example image
421    ///     .add_example(&"imgs/1.jpg")
422    ///     .build().unwrap();
423    ///
424    /// // generate an image
425    /// let generated = texsynth.run(None);
426    ///
427    /// // now we can repeat the same transformation on a different image
428    /// let repeated_transform_image = generated
429    ///     .get_coordinate_transform()
430    ///     .apply(&["imgs/2.jpg"]);
431    /// ```
432    pub fn get_coordinate_transform(&self) -> CoordinateTransform {
433        self.inner.get_coord_transform()
434    }
435
436    /// Returns the generated output image
437    pub fn into_image(self) -> image::DynamicImage {
438        image::DynamicImage::ImageRgba8(self.inner.color_map.into_inner())
439    }
440}
441
442impl AsRef<image::RgbaImage> for GeneratedImage {
443    fn as_ref(&self) -> &image::RgbaImage {
444        self.inner.color_map.as_ref()
445    }
446}
447
448/// Method used for sampling an example image.
449pub enum GenericSampleMethod<Img> {
450    /// All pixels in the example image can be sampled.
451    All,
452    /// No pixels in the example image will be sampled.
453    Ignore,
454    /// Pixels are selectively sampled based on an image.
455    Image(Img),
456}
457
458pub type SampleMethod<'a> = GenericSampleMethod<ImageSource<'a>>;
459pub type SamplingMethod = GenericSampleMethod<image::RgbaImage>;
460
461impl<Img> GenericSampleMethod<Img> {
462    #[inline]
463    fn is_ignore(&self) -> bool {
464        matches!(self, Self::Ignore)
465    }
466}
467
468impl<'a, IS> From<IS> for SampleMethod<'a>
469where
470    IS: Into<ImageSource<'a>>,
471{
472    fn from(is: IS) -> Self {
473        SampleMethod::Image(is.into())
474    }
475}
476
477/// A builder for an `Example`
478pub struct ExampleBuilder<'a> {
479    img: ImageSource<'a>,
480    guide: Option<ImageSource<'a>>,
481    sample_method: SampleMethod<'a>,
482}
483
484impl<'a> ExampleBuilder<'a> {
485    /// Creates a new example builder from the specified image source
486    pub fn new<I: Into<ImageSource<'a>>>(img: I) -> Self {
487        Self {
488            img: img.into(),
489            guide: None,
490            sample_method: SampleMethod::All,
491        }
492    }
493
494    /// Use a guide map that describe a 'FROM' transformation.
495    ///
496    /// Note: If any one example has a guide, then they **all** must have
497    /// a guide, otherwise a session will not be created.
498    pub fn with_guide<G: Into<ImageSource<'a>>>(mut self, guide: G) -> Self {
499        self.guide = Some(guide.into());
500        self
501    }
502
503    /// Specify how the example image is sampled during texture generation.
504    ///
505    /// By default, all pixels in the example can be sampled.
506    pub fn set_sample_method<M: Into<SampleMethod<'a>>>(mut self, method: M) -> Self {
507        self.sample_method = method.into();
508        self
509    }
510}
511
512/// An example to be used in texture generation
513pub struct Example<'a> {
514    img: ImageSource<'a>,
515    guide: Option<ImageSource<'a>>,
516    sample_method: SampleMethod<'a>,
517}
518
519impl<'a> Example<'a> {
520    /// Creates a new example builder from the specified image source
521    pub fn builder<I: Into<ImageSource<'a>>>(img: I) -> ExampleBuilder<'a> {
522        ExampleBuilder::new(img)
523    }
524
525    pub fn image_source(&self) -> &ImageSource<'a> {
526        &self.img
527    }
528
529    /// Creates a new example input from the specified image source
530    pub fn new<I: Into<ImageSource<'a>>>(img: I) -> Self {
531        Self {
532            img: img.into(),
533            guide: None,
534            sample_method: SampleMethod::All,
535        }
536    }
537
538    /// Use a guide map that describe a 'FROM' transformation.
539    ///
540    /// Note: If any one example has a guide, then they **all** must have
541    /// a guide, otherwise a session will not be created.
542    pub fn with_guide<G: Into<ImageSource<'a>>>(&mut self, guide: G) -> &mut Self {
543        self.guide = Some(guide.into());
544        self
545    }
546
547    /// Specify how the example image is sampled during texture generation.
548    ///
549    /// By default, all pixels in the example can be sampled.
550    pub fn set_sample_method<M: Into<SampleMethod<'a>>>(&mut self, method: M) -> &mut Self {
551        self.sample_method = method.into();
552        self
553    }
554
555    fn resolve(
556        self,
557        backtracks: u32,
558        resize: Option<Dims>,
559        target_guide: &Option<ImagePyramid>,
560    ) -> Result<ResolvedExample, Error> {
561        let image = ImagePyramid::new(load_image(self.img, resize)?, Some(backtracks));
562
563        let guide = match target_guide {
564            Some(tg) => {
565                Some(match self.guide {
566                    Some(exguide) => {
567                        let exguide = load_image(exguide, resize)?;
568                        ImagePyramid::new(exguide, Some(backtracks))
569                    }
570                    None => {
571                        // if we do not have an example guide, create it as a b/w maps of the example
572                        let mut gm = transform_to_guide_map(image.bottom().clone(), resize, 2.0);
573                        match_histograms(&mut gm, tg.bottom());
574
575                        ImagePyramid::new(gm, Some(backtracks))
576                    }
577                })
578            }
579            None => None,
580        };
581
582        let method = match self.sample_method {
583            SampleMethod::All => SamplingMethod::All,
584            SampleMethod::Ignore => SamplingMethod::Ignore,
585            SampleMethod::Image(src) => {
586                let img = load_image(src, resize)?;
587                SamplingMethod::Image(img)
588            }
589        };
590
591        Ok(ResolvedExample {
592            image,
593            guide,
594            method,
595        })
596    }
597}
598
599impl<'a> From<ExampleBuilder<'a>> for Example<'a> {
600    fn from(eb: ExampleBuilder<'a>) -> Self {
601        Self {
602            img: eb.img,
603            guide: eb.guide,
604            sample_method: eb.sample_method,
605        }
606    }
607}
608
609impl<'a, IS> From<IS> for Example<'a>
610where
611    IS: Into<ImageSource<'a>>,
612{
613    fn from(is: IS) -> Self {
614        Example::new(is)
615    }
616}
617
618enum MaskOrImg<'a> {
619    Mask(utils::ChannelMask),
620    ImageSource(ImageSource<'a>),
621}
622
623struct InpaintMask<'a> {
624    src: MaskOrImg<'a>,
625    example_index: usize,
626    dims: Dims,
627}
628
629struct ResolvedExample {
630    image: ImagePyramid,
631    guide: Option<ImagePyramid>,
632    method: SamplingMethod,
633}
634
#[cfg(test)]
mod test {
    use super::{CoordinateTransform, Dims};

    /// Serializes a transform into memory and checks that reading it back
    /// yields the same buffer, output size and source image dimensions.
    #[test]
    fn coord_tx_serde() {
        let fake_buffer = vec![1, 2, 3, 4, 5, 6];
        let fake_maps = [
            Dims {
                width: 9001,
                height: 9002,
            },
            Dims {
                width: 20,
                height: 5,
            },
        ];

        let input = CoordinateTransform {
            buffer: fake_buffer.clone(),
            output_size: Dims {
                width: 2,
                height: 1,
            },
            original_maps: fake_maps.to_vec(),
        };

        let mut serialized = Vec::new();
        input.write(&mut serialized).unwrap();

        let deserialized =
            CoordinateTransform::read(&mut std::io::Cursor::new(&serialized)).unwrap();

        assert_eq!(deserialized.buffer, fake_buffer);
        assert_eq!(deserialized.output_size.width, 2);
        assert_eq!(deserialized.output_size.height, 1);

        assert_eq!(fake_maps[0], deserialized.original_maps[0]);
        assert_eq!(fake_maps[1], deserialized.original_maps[1]);
    }
}