stable_diffusion_trainer/environment/
mod.rs

1//! Environment module.
2use std::path::{Path, PathBuf};
3
4use directories::ProjectDirs;
5use serde::{Serialize, Deserialize};
6
7/// The environment structure.
8#[derive(Debug, Serialize, Deserialize)]
9pub struct Environment {
10    kohya_ss: PathBuf,
11    #[serde(skip)]
12    previous_dir: PathBuf
13}
14
15impl Default for Environment {
16    fn default() -> Self {
17        Self::load()
18    }
19}
20
21impl Environment {
22    /// Load the environment from the configuration file.
23    pub fn load() -> Self {
24        let kohya_ss = ProjectDirs::from("com", "sensorial-systems", "stable-diffusion-trainer")
25            .map(|dirs| dirs.config_dir().to_path_buf())
26            .map(|config_dir| config_dir.join("config.json"))
27            .and_then(|config_path| std::fs::read_to_string(config_path).ok())
28            .and_then(|config| serde_json::from_str::<Environment>(&config).ok())
29            .map(|env| env.kohya_ss().to_path_buf())
30            .unwrap_or_default();
31        let previous_dir = Default::default();
32        Environment { kohya_ss, previous_dir }
33    }
34
35    /// Save the environment to the configuration file.
36    pub fn save(&self) -> Result<(), Box<dyn std::error::Error>> {
37        let config_dir = ProjectDirs::from("com", "sensorial-systems", "stable-diffusion-trainer")
38            .map(|dirs| dirs.config_dir().to_path_buf())
39            .expect("Failed to get config_dir");
40        let config_path = config_dir.join("config.json");
41        std::fs::create_dir_all(&config_dir)?;
42        let json = serde_json::to_string(self)?;
43        std::fs::write(config_path, json)?;
44        Ok(())
45    }
46
47    /// Create a new environment structure.
48    pub fn new() -> Self {
49        Default::default()
50    }
51
52    /// Set the kohya_ss path.
53    pub fn with_kohya_ss(mut self, kohya_ss: impl Into<PathBuf>) -> Self {
54        self.kohya_ss = kohya_ss.into();
55        self
56    }
57
58    /// Get the kohya_ss path.
59    pub fn kohya_ss(&self) -> &Path {
60        &self.kohya_ss
61    }
62
63    /// Get the kohya_ss path.
64    pub fn binary_path(&self) -> PathBuf {
65        #[cfg(target_os = "windows")]
66        let python_executable = self.kohya_ss.join("venv").join("Scripts");
67        #[cfg(not(target_os = "windows"))]
68        let python_executable = self.kohya_ss.join("venv").join("bin");
69        python_executable
70    }
71
72    /// Get the kohya_ss path.
73    pub fn python_executable_path(&self) -> PathBuf {
74        #[cfg(target_os = "windows")]
75        let python_executable = self.binary_path().join("python.exe");
76        #[cfg(not(target_os = "windows"))]
77        let python_executable = self.binary_path().join("python");
78        python_executable
79    }
80
81    /// Activate the environment.
82    pub fn activate(&mut self) {
83        std::env::set_var("PYTHONPATH", self.kohya_ss.join("venv").join("Lib").join("site-packages"));
84        #[cfg(target_os = "windows")]
85        std::env::set_var("PATH", format!("{};{}", self.binary_path().display(), std::env::var("PATH").unwrap()));
86        #[cfg(not(target_os = "windows"))]
87        std::env::set_var("PATH", format!("{}:{}", self.binary_path().display(), std::env::var("PATH").unwrap()));
88        // FIXME: This is too invasive. It should be done in a more controlled way.
89        self.previous_dir = std::env::current_dir().unwrap();
90        std::env::set_current_dir(&self.kohya_ss).unwrap();
91    }
92
93    /// Deactivate the environment.
94    pub fn deactivate(&mut self) {
95        std::env::set_current_dir(&self.previous_dir).unwrap();
96    }
97}