1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
use std::error::Error;

use burn::tensor::backend::Backend;

use super::*;
use crate::model::{
    autoencoder::load::load_autoencoder, clip::load::load_clip, load::*, unet::load::load_unet,
};

/// Loads a [`StableDiffusion`] model rooted at `path` onto `device`.
///
/// `n_steps` (via `load_usize`) and `alphas_cumprod` (via `load_tensor`, then converted
/// with `Into` into the field's type) are read from `path` itself. The autoencoder, UNet
/// and CLIP components are loaded from `"{path}/autoencoder"`, `"{path}/unet"` and
/// `"{path}/clip"` respectively, in that order.
///
/// # Errors
///
/// Returns the error from the first loader that fails; later components are not loaded.
pub fn load_stable_diffusion<B: Backend>(
    path: &str,
    device: &B::Device,
) -> Result<StableDiffusion<B>, Box<dyn Error>> {
    // Builds the location of a named component below `path`, using a `/` separator.
    let component_path = |component: &str| format!("{}/{}", path, component);

    // Scalar / tensor parameters stored directly under `path`.
    let n_steps = load_usize::<B>("n_steps", path, device)?;
    let alpha_cumulative_products = load_tensor::<B, 1>("alphas_cumprod", path, device)?.into();

    // Sub-models, loaded in a fixed order so the first failure is the one reported.
    let autoencoder = load_autoencoder(&component_path("autoencoder"), device)?;
    let diffusion = load_unet(&component_path("unet"), device)?;
    let clip = load_clip(&component_path("clip"), device)?;

    Ok(StableDiffusion {
        n_steps,
        alpha_cumulative_products,
        autoencoder,
        diffusion,
        clip,
    })
}