texture_synthesis/
session.rs

1use crate::*;
2
3/// Texture synthesis session.
4///
5/// Calling `run()` will generate a new image and return it, consuming the
6/// session in the process. You can provide a `GeneratorProgress` implementation
7/// to periodically get updates with the currently generated image and the
8/// number of pixels that have been resolved both in the current stage and
9/// globally.
10///
11/// # Example
12/// ```no_run
13/// let tex_synth = texture_synthesis::Session::builder()
14///     .seed(10)
15///     .tiling_mode(true)
16///     .add_example(&"imgs/1.jpg")
17///     .build().expect("failed to build session");
18///
19/// let generated_img = tex_synth.run(None);
20/// generated_img.save("my_generated_img.jpg").expect("failed to save image");
21/// ```
22pub struct Session {
23    examples: Vec<ImagePyramid>,
24    guides: Option<GuidesPyramidStruct>,
25    sampling_methods: Vec<SamplingMethod>,
26    generator: Generator,
27    params: Parameters,
28}
29
30impl Session {
31    /// Creates a new session with default parameters.
32    pub fn builder<'a>() -> SessionBuilder<'a> {
33        SessionBuilder::default()
34    }
35
36    /// Runs the generator and outputs a generated image.
37    pub fn run(mut self, progress: Option<Box<dyn GeneratorProgress>>) -> GeneratedImage {
38        // random resolve
39        // TODO: Instead of consuming the generator, we could instead make the
40        // seed and random_resolve parameters, so that you could rerun the
41        // generator with the same inputs
42        if let Some(count) = self.params.random_resolve {
43            let lvl = self.examples[0].pyramid.len();
44            let imgs: Vec<_> = self
45                .examples
46                .iter()
47                .map(|a| ImageBuffer::from(&a.pyramid[lvl - 1])) //take the blurriest image
48                .collect();
49
50            self.generator
51                .resolve_random_batch(count as usize, &imgs, self.params.seed);
52        }
53
54        // run generator
55        self.generator.resolve(
56            &self.params.to_generator_params(),
57            &self.examples,
58            progress,
59            &self.guides,
60            &self.sampling_methods,
61        );
62
63        GeneratedImage {
64            inner: self.generator,
65        }
66    }
67}
68
69/// Builds a session by setting parameters and adding input images, calling
70/// `build` will check all of the provided inputs to verify that texture
71/// synthesis will provide valid output
72#[derive(Default)]
73pub struct SessionBuilder<'a> {
74    examples: Vec<Example<'a>>,
75    target_guide: Option<ImageSource<'a>>,
76    inpaint_mask: Option<InpaintMask<'a>>,
77    params: Parameters,
78}
79
80impl<'a> SessionBuilder<'a> {
81    /// Creates a new `SessionBuilder`, can also be created via
82    /// `Session::builder()`
83    pub fn new() -> Self {
84        Self::default()
85    }
86
87    /// Adds an `Example` from which a generator will synthesize a new image.
88    ///
89    /// See [`examples/01_single_example_synthesis`](https://github.com/EmbarkStudios/texture-synthesis/tree/main/lib/examples/01_single_example_synthesis.rs)
90    ///
91    /// # Examples
92    ///
93    /// ```no_run
94    /// let tex_synth = texture_synthesis::Session::builder()
95    ///     .add_example(&"imgs/1.jpg")
96    ///     .build().expect("failed to build session");
97    /// ```
98    pub fn add_example<E: Into<Example<'a>>>(mut self, example: E) -> Self {
99        self.examples.push(example.into());
100        self
101    }
102
103    /// Adds Examples from which a generator will synthesize a new image.
104    ///
105    /// See [`examples/02_multi_example_synthesis`](https://github.com/EmbarkStudios/texture-synthesis/tree/main/lib/examples/02_multi_example_synthesis.rs)
106    ///
107    /// # Examples
108    ///
109    /// ```no_run
110    /// let tex_synth = texture_synthesis::Session::builder()
111    ///     .add_examples(&[&"imgs/1.jpg", &"imgs/2.jpg"])
112    ///     .build().expect("failed to build session");
113    /// ```
114    pub fn add_examples<E: Into<Example<'a>>, I: IntoIterator<Item = E>>(
115        mut self,
116        examples: I,
117    ) -> Self {
118        self.examples.extend(examples.into_iter().map(|e| e.into()));
119        self
120    }
121
122    /// Inpaints an example. Due to how inpainting works, a size must also be
123    /// provided, as all examples, as well as the inpaint mask, must be the same
124    /// size as each other, as well as the final output image. Using
125    /// `resize_input` or `output_size` is ignored if this method is called.
126    ///
127    /// To prevent sampling from the example, you can specify
128    /// `SamplingMethod::Ignore` with `Example::set_sample_method`.
129    ///
130    /// See [`examples/05_inpaint`](https://github.com/EmbarkStudios/texture-synthesis/tree/main/lib/examples/05_inpaint.rs)
131    ///
132    /// # Examples
133    ///
134    /// ```no_run
135    /// let tex_synth = texture_synthesis::Session::builder()
136    ///     .add_examples(&[&"imgs/1.jpg", &"imgs/3.jpg"])
137    ///     .inpaint_example(
138    ///         &"masks/inpaint.jpg",
139    ///         // This will prevent sampling from the imgs/2.jpg, note that
140    ///         // we *MUST* provide at least one example to source from!
141    ///         texture_synthesis::Example::builder(&"imgs/2.jpg")
142    ///             .set_sample_method(texture_synthesis::SampleMethod::Ignore),
143    ///         texture_synthesis::Dims::square(400)
144    ///     )
145    ///     .build().expect("failed to build session");
146    /// ```
147    pub fn inpaint_example<I: Into<ImageSource<'a>>, E: Into<Example<'a>>>(
148        mut self,
149        inpaint_mask: I,
150        example: E,
151        size: Dims,
152    ) -> Self {
153        self.inpaint_mask = Some(InpaintMask {
154            src: MaskOrImg::ImageSource(inpaint_mask.into()),
155            example_index: self.examples.len(),
156            dims: size,
157        });
158        self.examples.push(example.into());
159        self
160    }
161
162    /// Inpaints an example, using a specific channel in the example image as
163    /// the inpaint mask
164    ///
165    /// # Examples
166    ///
167    /// ```no_run
168    /// let tex_synth = texture_synthesis::Session::builder()
169    ///     .inpaint_example_channel(
170    ///         // Let's use inpaint the alpha channel
171    ///         texture_synthesis::ChannelMask::A,
172    ///         &"imgs/bricks.png",
173    ///         texture_synthesis::Dims::square(400)
174    ///     )
175    ///     .build().expect("failed to build session");
176    /// ```
177    pub fn inpaint_example_channel<E: Into<Example<'a>>>(
178        mut self,
179        mask: utils::ChannelMask,
180        example: E,
181        size: Dims,
182    ) -> Self {
183        self.inpaint_mask = Some(InpaintMask {
184            src: MaskOrImg::Mask(mask),
185            example_index: self.examples.len(),
186            dims: size,
187        });
188        self.examples.push(example.into());
189        self
190    }
191
192    /// Loads a target guide map.
193    ///
194    /// If no `Example` guide maps are provided, this will produce a style
195    /// transfer effect, where the Examples are styles and the target guide is
196    /// content.
197    ///
198    /// See [`examples/03_guided_synthesis`](https://github.com/EmbarkStudios/texture-synthesis/tree/main/lib/examples/03_guided_synthesis.rs),
199    /// or [`examples/04_style_transfer`](https://github.com/EmbarkStudios/texture-synthesis/tree/main/lib/examples/04_style_transfer.rs),
200    pub fn load_target_guide<I: Into<ImageSource<'a>>>(mut self, guide: I) -> Self {
201        self.target_guide = Some(guide.into());
202        self
203    }
204
205    /// Overwrite incoming images sizes
206    pub fn resize_input(mut self, dims: Dims) -> Self {
207        self.params.resize_input = Some(dims);
208        self
209    }
210
211    /// Changes pseudo-deterministic seed.
212    ///
213    /// Global structures will stay same, if the same seed is provided, but
214    /// smaller details may change due to undeterministic nature of
215    /// multithreading.
216    pub fn seed(mut self, value: u64) -> Self {
217        self.params.seed = value;
218        self
219    }
220
221    /// Makes the generator output tiling image.
222    ///
223    /// Default: false.
224    pub fn tiling_mode(mut self, is_tiling: bool) -> Self {
225        self.params.tiling_mode = is_tiling;
226        self
227    }
228
229    /// How many neighboring pixels each pixel is aware of during generation.
230    ///
231    /// A larger number means more global structures are captured.
232    ///
233    /// Default: 50
234    pub fn nearest_neighbors(mut self, count: u32) -> Self {
235        self.params.nearest_neighbors = count;
236        self
237    }
238
239    /// The number of random locations that will be considered during a pixel
240    /// resolution apart from its immediate neighbors.
241    ///
242    /// If unsure, keep same as nearest neighbors.
243    ///
244    /// Default: 50
245    pub fn random_sample_locations(mut self, count: u64) -> Self {
246        self.params.random_sample_locations = count;
247        self
248    }
249
250    /// Forces the first `n` pixels to be randomly resolved, and prevents them
251    /// from being overwritten.
252    ///
253    /// Can be an enforcing factor of remixing multiple images together.
254    pub fn random_init(mut self, count: u64) -> Self {
255        self.params.random_resolve = Some(count);
256        self
257    }
258
259    /// The distribution dispersion used for picking best candidate (controls
260    /// the distribution 'tail flatness').
261    ///
262    /// Values close to 0.0 will produce 'harsh' borders between generated
263    /// 'chunks'. Values closer to 1.0 will produce a smoother gradient on those
264    /// borders.
265    ///
266    /// For futher reading, check out P.Harrison's "Image Texture Tools".
267    ///
268    /// Default: 1.0
269    pub fn cauchy_dispersion(mut self, value: f32) -> Self {
270        self.params.cauchy_dispersion = value;
271        self
272    }
273
274    /// Controls the trade-off between guide and example maps.
275    ///
276    /// If doing style transfer, set to about 0.8-0.6 to allow for more global
277    /// structures of the style.
278    ///
279    /// If you'd like the guide maps to be considered through all generation
280    /// stages, set to 1.0, which will prevent guide maps weight "decay" during
281    /// the score calculation.
282    ///
283    /// Default: 0.8
284    pub fn guide_alpha(mut self, value: f32) -> Self {
285        self.params.guide_alpha = value;
286        self
287    }
288
289    /// The percentage of pixels to be backtracked during each `p_stage`.
290    /// Range (0,1).
291    ///
292    /// Default: 0.5
293    pub fn backtrack_percent(mut self, value: f32) -> Self {
294        self.params.backtrack_percent = value;
295        self
296    }
297
298    /// Controls the number of backtracking stages.
299    ///
300    /// Backtracking prevents 'garbage' generation. Right now, the depth of the
301    /// image pyramid for multiresolution synthesis depends on this parameter as
302    /// well.
303    ///
304    /// Default: 5
305    pub fn backtrack_stages(mut self, stages: u32) -> Self {
306        self.params.backtrack_stages = stages;
307        self
308    }
309
310    /// Specify size of the generated image.
311    ///
312    /// Default: 500x500
313    pub fn output_size(mut self, dims: Dims) -> Self {
314        self.params.output_size = dims;
315        self
316    }
317
318    /// Controls the maximum number of threads that will be spawned at any one
319    /// time in parallel.
320    ///
321    /// This number is allowed to exceed the number of logical cores on the
322    /// system, but it should generally be kept at or below that number.
323    ///
324    /// Setting this number to `1` will result in completely deterministic
325    /// image generation, meaning that redoing generation with the same inputs
326    /// will always give you the same outputs.
327    ///
328    /// Default: The number of logical cores on this system.
329    pub fn max_thread_count(mut self, count: usize) -> Self {
330        self.params.max_thread_count = Some(count);
331        self
332    }
333
334    /// Creates a `Session`, or returns an error if invalid parameters or input
335    /// images were specified.
336    pub fn build(mut self) -> Result<Session, Error> {
337        self.check_parameters_validity()?;
338        self.check_images_validity()?;
339
340        struct InpaintExample {
341            inpaint_mask: image::RgbaImage,
342            color_map: image::RgbaImage,
343            example_index: usize,
344        }
345
346        let (inpaint, out_size, in_size) = match self.inpaint_mask {
347            Some(inpaint_mask) => {
348                let dims = inpaint_mask.dims;
349                let inpaint_img = match inpaint_mask.src {
350                    MaskOrImg::ImageSource(img) => load_image(img, Some(dims))?,
351                    MaskOrImg::Mask(mask) => {
352                        let example_img = &mut self.examples[inpaint_mask.example_index].img;
353
354                        let dynamic_img = utils::load_dynamic_image(example_img.clone())?;
355                        let inpaint_src = ImageSource::Image(dynamic_img.clone());
356
357                        // Replace the example image source so we don't load it twice
358                        *example_img = ImageSource::Image(dynamic_img);
359
360                        let inpaint_mask = load_image(inpaint_src, Some(dims))?;
361
362                        utils::apply_mask(inpaint_mask, mask)
363                    }
364                };
365
366                let color_map = load_image(
367                    self.examples[inpaint_mask.example_index].img.clone(),
368                    Some(dims),
369                )?;
370
371                (
372                    Some(InpaintExample {
373                        inpaint_mask: inpaint_img,
374                        color_map,
375                        example_index: inpaint_mask.example_index,
376                    }),
377                    dims,
378                    Some(dims),
379                )
380            }
381            None => (None, self.params.output_size, self.params.resize_input),
382        };
383
384        let target_guide = match self.target_guide {
385            Some(tg) => {
386                let tg_img = load_image(tg, Some(out_size))?;
387
388                let num_guides = self.examples.iter().filter(|ex| ex.guide.is_some()).count();
389                let tg_img = if num_guides == 0 {
390                    transform_to_guide_map(tg_img, None, 2.0)
391                } else {
392                    tg_img
393                };
394
395                Some(ImagePyramid::new(
396                    tg_img,
397                    Some(self.params.backtrack_stages as u32),
398                ))
399            }
400            None => None,
401        };
402
403        let example_len = self.examples.len();
404
405        let mut examples = Vec::with_capacity(example_len);
406        let mut guides = if target_guide.is_some() {
407            Vec::with_capacity(example_len)
408        } else {
409            Vec::new()
410        };
411        let mut methods = Vec::with_capacity(example_len);
412
413        for example in self.examples {
414            let resolved = example.resolve(self.params.backtrack_stages, in_size, &target_guide)?;
415
416            examples.push(resolved.image);
417
418            if let Some(guide) = resolved.guide {
419                guides.push(guide);
420            }
421
422            methods.push(resolved.method);
423        }
424
425        // Initialize generator based on availability of an inpaint_mask.
426        let generator = match inpaint {
427            None => Generator::new(out_size),
428            Some(inpaint) => Generator::new_from_inpaint(
429                out_size,
430                inpaint.inpaint_mask,
431                inpaint.color_map,
432                inpaint.example_index,
433            ),
434        };
435
436        let session = Session {
437            examples,
438            guides: target_guide.map(|tg| GuidesPyramidStruct {
439                target_guide: tg,
440                example_guides: guides,
441            }),
442            sampling_methods: methods,
443            params: self.params,
444            generator,
445        };
446
447        Ok(session)
448    }
449
450    fn check_parameters_validity(&self) -> Result<(), Error> {
451        if self.params.cauchy_dispersion < 0.0 || self.params.cauchy_dispersion > 1.0 {
452            return Err(Error::InvalidRange(errors::InvalidRange {
453                min: 0.0,
454                max: 1.0,
455                value: self.params.cauchy_dispersion,
456                name: "cauchy-dispersion",
457            }));
458        }
459
460        if self.params.backtrack_percent < 0.0 || self.params.backtrack_percent > 1.0 {
461            return Err(Error::InvalidRange(errors::InvalidRange {
462                min: 0.0,
463                max: 1.0,
464                value: self.params.backtrack_percent,
465                name: "backtrack-percent",
466            }));
467        }
468
469        if self.params.guide_alpha < 0.0 || self.params.guide_alpha > 1.0 {
470            return Err(Error::InvalidRange(errors::InvalidRange {
471                min: 0.0,
472                max: 1.0,
473                value: self.params.guide_alpha,
474                name: "guide-alpha",
475            }));
476        }
477
478        if let Some(max_count) = self.params.max_thread_count {
479            if max_count == 0 {
480                return Err(Error::InvalidRange(errors::InvalidRange {
481                    min: 1.0,
482                    max: 1024.0,
483                    value: max_count as f32,
484                    name: "max-thread-count",
485                }));
486            }
487        }
488
489        if self.params.random_sample_locations == 0 {
490            return Err(Error::InvalidRange(errors::InvalidRange {
491                min: 1.0,
492                max: 1024.0,
493                value: self.params.random_sample_locations as f32,
494                name: "m-rand",
495            }));
496        }
497
498        Ok(())
499    }
500
501    fn check_images_validity(&self) -> Result<(), Error> {
502        // We must have at least one example image to source pixels from
503        let input_count = self
504            .examples
505            .iter()
506            .filter(|ex| !ex.sample_method.is_ignore())
507            .count();
508
509        if input_count == 0 {
510            return Err(Error::NoExamples);
511        }
512
513        // If we have more than one example guide, then *every* example
514        // needs a guide
515        let num_guides = self.examples.iter().filter(|ex| ex.guide.is_some()).count();
516        if num_guides != 0 && self.examples.len() != num_guides {
517            return Err(Error::ExampleGuideMismatch(
518                self.examples.len() as u32,
519                num_guides as u32,
520            ));
521        }
522
523        Ok(())
524    }
525}
526
/// Helper struct for passing progress information to external callers.
///
/// A plain pair of counters, so it is cheap to copy, compare and print.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProgressStat {
    /// The current amount of work that has been done
    pub current: usize,
    /// The total amount of work to do
    pub total: usize,
}
534
535/// The current state of the image generator
536pub struct ProgressUpdate<'a> {
537    /// The currenty resolved image
538    pub image: &'a image::RgbaImage,
539    /// The total progress for the final image
540    pub total: ProgressStat,
541    /// The progress for the current stage
542    pub stage: ProgressStat,
543}
544
545/// Allows the generator to update external callers with the current
546/// progress of the image synthesis
547pub trait GeneratorProgress {
548    fn update(&mut self, info: ProgressUpdate<'_>);
549}
550
551impl<G> GeneratorProgress for G
552where
553    G: FnMut(ProgressUpdate<'_>) + Send,
554{
555    fn update(&mut self, info: ProgressUpdate<'_>) {
556        self(info);
557    }
558}