1use std::collections::{BTreeSet, HashMap};
2use std::path::PathBuf;
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use anyhow::{Context, Result};
6use candle_core::Device;
7use qwen3_tts::VoiceClonePrompt;
8use serde::{Deserialize, Serialize};
9
10use crate::paths;
11
12#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
13#[serde(rename_all = "snake_case")]
14pub enum ProfileMode {
15 Icl,
16 Xvector,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ProfileMeta {
21 pub version: u32,
22 pub name: String,
23 pub mode: ProfileMode,
24 pub created_at_unix: u64,
25 #[serde(skip_serializing_if = "Option::is_none")]
26 pub ref_text_ids: Option<Vec<u32>>,
27}
28
29fn validate_profile_name(name: &str) -> Result<()> {
30 if name.is_empty() {
31 anyhow::bail!("profile name cannot be empty");
32 }
33
34 let ok = name
35 .chars()
36 .all(|ch| ch.is_ascii_alphanumeric() || ch == '-' || ch == '_');
37 if !ok {
38 anyhow::bail!("profile name can only contain [A-Za-z0-9_-]");
39 }
40
41 Ok(())
42}
43
44fn profile_dir(name: &str) -> PathBuf {
45 paths::profiles_dir().join(name)
46}
47
48fn legacy_profile_dir(name: &str) -> PathBuf {
49 paths::legacy_profiles_dir().join(name)
50}
51
52fn existing_profile_dir(name: &str) -> Option<PathBuf> {
53 let current = profile_dir(name);
54 if current.is_dir() {
55 return Some(current);
56 }
57
58 let legacy = legacy_profile_dir(name);
59 legacy.is_dir().then_some(legacy)
60}
61
62pub fn save_profile(name: &str, prompt: &VoiceClonePrompt) -> Result<()> {
63 validate_profile_name(name)?;
64
65 let dir = profile_dir(name);
66 std::fs::create_dir_all(&dir)
67 .with_context(|| format!("failed to create profile dir: {}", dir.display()))?;
68
69 let mut tensors = HashMap::new();
70 tensors.insert(
71 "speaker_embedding".to_string(),
72 prompt.speaker_embedding.clone(),
73 );
74
75 let (mode, ref_text_ids) =
76 if let (Some(codes), Some(ids)) = (&prompt.ref_codes, &prompt.ref_text_ids) {
77 tensors.insert("ref_codes".to_string(), codes.clone());
78 (ProfileMode::Icl, Some(ids.clone()))
79 } else {
80 (ProfileMode::Xvector, None)
81 };
82
83 let tensor_path = dir.join("tensors.safetensors");
84 candle_core::safetensors::save(&tensors, &tensor_path)
85 .with_context(|| format!("failed to write tensors: {}", tensor_path.display()))?;
86
87 let created_at_unix = SystemTime::now()
88 .duration_since(UNIX_EPOCH)
89 .unwrap_or_default()
90 .as_secs();
91
92 let meta = ProfileMeta {
93 version: 1,
94 name: name.to_string(),
95 mode,
96 created_at_unix,
97 ref_text_ids,
98 };
99
100 let meta_path = dir.join("profile.json");
101 let meta_body =
102 serde_json::to_string_pretty(&meta).context("failed to serialize profile meta")?;
103 std::fs::write(&meta_path, meta_body)
104 .with_context(|| format!("failed to write profile meta: {}", meta_path.display()))?;
105
106 Ok(())
107}
108
109pub fn load_profile(name: &str, device: &Device) -> Result<VoiceClonePrompt> {
110 validate_profile_name(name)?;
111
112 let dir =
113 existing_profile_dir(name).with_context(|| format!("profile '{name}' does not exist"))?;
114
115 let meta = read_profile_meta(name)?;
116
117 let tensor_path = dir.join("tensors.safetensors");
118 let tensors = candle_core::safetensors::load(&tensor_path, device)
119 .with_context(|| format!("failed to read tensors: {}", tensor_path.display()))?;
120
121 let speaker_embedding = tensors
122 .get("speaker_embedding")
123 .context("missing speaker_embedding in profile tensors")?
124 .clone();
125
126 let (ref_codes, ref_text_ids) = if meta.mode == ProfileMode::Icl {
127 let codes = tensors
128 .get("ref_codes")
129 .context("missing ref_codes in ICL profile tensors")?
130 .clone();
131 (Some(codes), meta.ref_text_ids)
132 } else {
133 (None, None)
134 };
135
136 Ok(VoiceClonePrompt {
137 speaker_embedding,
138 ref_codes,
139 ref_text_ids,
140 })
141}
142
143pub fn read_profile_meta(name: &str) -> Result<ProfileMeta> {
144 validate_profile_name(name)?;
145
146 let meta_path = existing_profile_dir(name)
147 .with_context(|| format!("profile '{name}' does not exist"))?
148 .join("profile.json");
149 let body = std::fs::read_to_string(&meta_path)
150 .with_context(|| format!("failed to read profile meta: {}", meta_path.display()))?;
151 let meta: ProfileMeta = serde_json::from_str(&body)
152 .with_context(|| format!("failed to parse profile meta: {}", meta_path.display()))?;
153 Ok(meta)
154}
155
156pub fn list_profiles() -> Result<Vec<String>> {
157 paths::ensure_profiles_dir()?;
158 let mut names = BTreeSet::new();
159
160 for dir in [paths::profiles_dir(), paths::legacy_profiles_dir()] {
161 if !dir.exists() {
162 continue;
163 }
164
165 for entry in std::fs::read_dir(&dir)
166 .with_context(|| format!("failed to read profiles dir: {}", dir.display()))?
167 {
168 let entry = entry?;
169 let path = entry.path();
170 if !path.is_dir() || !path.join("profile.json").exists() {
171 continue;
172 }
173
174 if let Some(name) = entry.file_name().to_str() {
175 names.insert(name.to_string());
176 }
177 }
178 }
179
180 Ok(names.into_iter().collect())
181}