Skip to main content

speakers_core/
profile.rs

1use std::collections::{BTreeSet, HashMap};
2use std::path::PathBuf;
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use anyhow::{Context, Result};
6use candle_core::Device;
7use qwen3_tts::VoiceClonePrompt;
8use serde::{Deserialize, Serialize};
9
10use crate::paths;
11
12#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
13#[serde(rename_all = "snake_case")]
14pub enum ProfileMode {
15    Icl,
16    Xvector,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ProfileMeta {
21    pub version: u32,
22    pub name: String,
23    pub mode: ProfileMode,
24    pub created_at_unix: u64,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub ref_text_ids: Option<Vec<u32>>,
27}
28
29fn validate_profile_name(name: &str) -> Result<()> {
30    if name.is_empty() {
31        anyhow::bail!("profile name cannot be empty");
32    }
33
34    let ok = name
35        .chars()
36        .all(|ch| ch.is_ascii_alphanumeric() || ch == '-' || ch == '_');
37    if !ok {
38        anyhow::bail!("profile name can only contain [A-Za-z0-9_-]");
39    }
40
41    Ok(())
42}
43
44fn profile_dir(name: &str) -> PathBuf {
45    paths::profiles_dir().join(name)
46}
47
48fn legacy_profile_dir(name: &str) -> PathBuf {
49    paths::legacy_profiles_dir().join(name)
50}
51
52fn existing_profile_dir(name: &str) -> Option<PathBuf> {
53    let current = profile_dir(name);
54    if current.is_dir() {
55        return Some(current);
56    }
57
58    let legacy = legacy_profile_dir(name);
59    legacy.is_dir().then_some(legacy)
60}
61
62pub fn save_profile(name: &str, prompt: &VoiceClonePrompt) -> Result<()> {
63    validate_profile_name(name)?;
64
65    let dir = profile_dir(name);
66    std::fs::create_dir_all(&dir)
67        .with_context(|| format!("failed to create profile dir: {}", dir.display()))?;
68
69    let mut tensors = HashMap::new();
70    tensors.insert(
71        "speaker_embedding".to_string(),
72        prompt.speaker_embedding.clone(),
73    );
74
75    let (mode, ref_text_ids) =
76        if let (Some(codes), Some(ids)) = (&prompt.ref_codes, &prompt.ref_text_ids) {
77            tensors.insert("ref_codes".to_string(), codes.clone());
78            (ProfileMode::Icl, Some(ids.clone()))
79        } else {
80            (ProfileMode::Xvector, None)
81        };
82
83    let tensor_path = dir.join("tensors.safetensors");
84    candle_core::safetensors::save(&tensors, &tensor_path)
85        .with_context(|| format!("failed to write tensors: {}", tensor_path.display()))?;
86
87    let created_at_unix = SystemTime::now()
88        .duration_since(UNIX_EPOCH)
89        .unwrap_or_default()
90        .as_secs();
91
92    let meta = ProfileMeta {
93        version: 1,
94        name: name.to_string(),
95        mode,
96        created_at_unix,
97        ref_text_ids,
98    };
99
100    let meta_path = dir.join("profile.json");
101    let meta_body =
102        serde_json::to_string_pretty(&meta).context("failed to serialize profile meta")?;
103    std::fs::write(&meta_path, meta_body)
104        .with_context(|| format!("failed to write profile meta: {}", meta_path.display()))?;
105
106    Ok(())
107}
108
109pub fn load_profile(name: &str, device: &Device) -> Result<VoiceClonePrompt> {
110    validate_profile_name(name)?;
111
112    let dir =
113        existing_profile_dir(name).with_context(|| format!("profile '{name}' does not exist"))?;
114
115    let meta = read_profile_meta(name)?;
116
117    let tensor_path = dir.join("tensors.safetensors");
118    let tensors = candle_core::safetensors::load(&tensor_path, device)
119        .with_context(|| format!("failed to read tensors: {}", tensor_path.display()))?;
120
121    let speaker_embedding = tensors
122        .get("speaker_embedding")
123        .context("missing speaker_embedding in profile tensors")?
124        .clone();
125
126    let (ref_codes, ref_text_ids) = if meta.mode == ProfileMode::Icl {
127        let codes = tensors
128            .get("ref_codes")
129            .context("missing ref_codes in ICL profile tensors")?
130            .clone();
131        (Some(codes), meta.ref_text_ids)
132    } else {
133        (None, None)
134    };
135
136    Ok(VoiceClonePrompt {
137        speaker_embedding,
138        ref_codes,
139        ref_text_ids,
140    })
141}
142
143pub fn read_profile_meta(name: &str) -> Result<ProfileMeta> {
144    validate_profile_name(name)?;
145
146    let meta_path = existing_profile_dir(name)
147        .with_context(|| format!("profile '{name}' does not exist"))?
148        .join("profile.json");
149    let body = std::fs::read_to_string(&meta_path)
150        .with_context(|| format!("failed to read profile meta: {}", meta_path.display()))?;
151    let meta: ProfileMeta = serde_json::from_str(&body)
152        .with_context(|| format!("failed to parse profile meta: {}", meta_path.display()))?;
153    Ok(meta)
154}
155
156pub fn list_profiles() -> Result<Vec<String>> {
157    paths::ensure_profiles_dir()?;
158    let mut names = BTreeSet::new();
159
160    for dir in [paths::profiles_dir(), paths::legacy_profiles_dir()] {
161        if !dir.exists() {
162            continue;
163        }
164
165        for entry in std::fs::read_dir(&dir)
166            .with_context(|| format!("failed to read profiles dir: {}", dir.display()))?
167        {
168            let entry = entry?;
169            let path = entry.path();
170            if !path.is_dir() || !path.join("profile.json").exists() {
171                continue;
172            }
173
174            if let Some(name) = entry.file_name().to_str() {
175                names.insert(name.to_string());
176            }
177        }
178    }
179
180    Ok(names.into_iter().collect())
181}