// stable_diffusion_trainer/environment/mod.rs
use std::path::{Path, PathBuf};

use directories::ProjectDirs;
use serde::{Deserialize, Serialize};

7#[derive(Debug, Serialize, Deserialize)]
9pub struct Environment {
10 kohya_ss: PathBuf,
11 #[serde(skip)]
12 previous_dir: PathBuf
13}
14
15impl Default for Environment {
16 fn default() -> Self {
17 Self::load()
18 }
19}
20
21impl Environment {
22 pub fn load() -> Self {
24 let kohya_ss = ProjectDirs::from("com", "sensorial-systems", "stable-diffusion-trainer")
25 .map(|dirs| dirs.config_dir().to_path_buf())
26 .map(|config_dir| config_dir.join("config.json"))
27 .and_then(|config_path| std::fs::read_to_string(config_path).ok())
28 .and_then(|config| serde_json::from_str::<Environment>(&config).ok())
29 .map(|env| env.kohya_ss().to_path_buf())
30 .unwrap_or_default();
31 let previous_dir = Default::default();
32 Environment { kohya_ss, previous_dir }
33 }
34
35 pub fn save(&self) -> Result<(), Box<dyn std::error::Error>> {
37 let config_dir = ProjectDirs::from("com", "sensorial-systems", "stable-diffusion-trainer")
38 .map(|dirs| dirs.config_dir().to_path_buf())
39 .expect("Failed to get config_dir");
40 let config_path = config_dir.join("config.json");
41 std::fs::create_dir_all(&config_dir)?;
42 let json = serde_json::to_string(self)?;
43 std::fs::write(config_path, json)?;
44 Ok(())
45 }
46
47 pub fn new() -> Self {
49 Default::default()
50 }
51
52 pub fn with_kohya_ss(mut self, kohya_ss: impl Into<PathBuf>) -> Self {
54 self.kohya_ss = kohya_ss.into();
55 self
56 }
57
58 pub fn kohya_ss(&self) -> &Path {
60 &self.kohya_ss
61 }
62
63 pub fn binary_path(&self) -> PathBuf {
65 #[cfg(target_os = "windows")]
66 let python_executable = self.kohya_ss.join("venv").join("Scripts");
67 #[cfg(not(target_os = "windows"))]
68 let python_executable = self.kohya_ss.join("venv").join("bin");
69 python_executable
70 }
71
72 pub fn python_executable_path(&self) -> PathBuf {
74 #[cfg(target_os = "windows")]
75 let python_executable = self.binary_path().join("python.exe");
76 #[cfg(not(target_os = "windows"))]
77 let python_executable = self.binary_path().join("python");
78 python_executable
79 }
80
81 pub fn activate(&mut self) {
83 std::env::set_var("PYTHONPATH", self.kohya_ss.join("venv").join("Lib").join("site-packages"));
84 #[cfg(target_os = "windows")]
85 std::env::set_var("PATH", format!("{};{}", self.binary_path().display(), std::env::var("PATH").unwrap()));
86 #[cfg(not(target_os = "windows"))]
87 std::env::set_var("PATH", format!("{}:{}", self.binary_path().display(), std::env::var("PATH").unwrap()));
88 self.previous_dir = std::env::current_dir().unwrap();
90 std::env::set_current_dir(&self.kohya_ss).unwrap();
91 }
92
93 pub fn deactivate(&mut self) {
95 std::env::set_current_dir(&self.previous_dir).unwrap();
96 }
97}