Skip to main content

supertonic_core/
engine.rs

1use async_trait::async_trait;
2
3use crate::config::Config;
4
5#[derive(Debug, Clone)]
6pub enum TensorValue {
7    F32(ndarray::Array<f32, ndarray::IxDyn>),
8    I64(ndarray::Array<i64, ndarray::IxDyn>),
9}
10
11impl TensorValue {
12    pub fn as_f32(&self) -> Option<&ndarray::Array<f32, ndarray::IxDyn>> {
13        match self {
14            TensorValue::F32(arr) => Some(arr),
15            _ => None,
16        }
17    }
18
19    pub fn into_f32(self) -> Option<ndarray::Array<f32, ndarray::IxDyn>> {
20        match self {
21            TensorValue::F32(arr) => Some(arr),
22            _ => None,
23        }
24    }
25
26    pub fn shape(&self) -> &[usize] {
27        match self {
28            TensorValue::F32(arr) => arr.shape(),
29            TensorValue::I64(arr) => arr.shape(),
30        }
31    }
32}
33
34impl From<ndarray::Array<f32, ndarray::IxDyn>> for TensorValue {
35    fn from(arr: ndarray::Array<f32, ndarray::IxDyn>) -> Self {
36        TensorValue::F32(arr)
37    }
38}
39
40impl From<ndarray::Array<i64, ndarray::IxDyn>> for TensorValue {
41    fn from(arr: ndarray::Array<i64, ndarray::IxDyn>) -> Self {
42        TensorValue::I64(arr)
43    }
44}
45
/// Tunable knobs for a synthesis request.
///
/// Defaults: `total_step = 8`, `speed = 1.05`, `silence_duration = 0.3`,
/// `rng_seed = None`.
#[derive(Clone, Debug)]
pub struct SynthesisParams {
    /// Total step count. NOTE(review): presumably the number of iterations fed to
    /// `InferenceEngine::estimate_vector` as `total_step` — confirm.
    pub total_step: usize,
    /// Speed factor. NOTE(review): confirm whether values > 1.0 mean faster speech.
    pub speed: f32,
    /// Silence length. NOTE(review): presumably seconds inserted between chunks — confirm.
    pub silence_duration: f32,
    /// Optional RNG seed. NOTE(review): presumably `None` means an unseeded RNG — confirm.
    pub rng_seed: Option<u64>,
}

impl Default for SynthesisParams {
    fn default() -> Self {
        Self {
            rng_seed: None,
            silence_duration: 0.3,
            speed: 1.05,
            total_step: 8,
        }
    }
}
64
/// Output of a complete synthesis request.
#[derive(Clone, Debug)]
pub struct SynthesisResult {
    /// Audio samples as `f32`.
    pub audio: Vec<f32>,
    /// Duration in seconds. NOTE(review): presumably `audio.len() / sample_rate` — confirm.
    pub duration_secs: f32,
    /// Sample rate of `audio`. NOTE(review): presumably in Hz — confirm.
    pub sample_rate: u32,
}
71
/// One piece of a chunked synthesis output.
#[derive(Clone, Debug)]
pub struct ChunkResult {
    /// Audio samples for this chunk.
    pub audio: Vec<f32>,
    /// Duration of this chunk in seconds.
    pub duration_secs: f32,
    /// Position of this chunk in the sequence. NOTE(review): presumably zero-based — confirm.
    pub chunk_index: usize,
    /// `true` when this is the final chunk of the sequence.
    pub is_last: bool,
}
79
80#[async_trait]
81pub trait InferenceEngine: Send + Sync {
82    async fn predict_duration(
83        &self,
84        text_ids: &TensorValue,
85        style_dp: &TensorValue,
86        text_mask: &TensorValue,
87    ) -> Result<TensorValue, anyhow::Error>;
88
89    async fn encode_text(
90        &self,
91        text_ids: &TensorValue,
92        style_ttl: &TensorValue,
93        text_mask: &TensorValue,
94    ) -> Result<TensorValue, anyhow::Error>;
95
96    async fn estimate_vector(
97        &self,
98        noisy_latent: &TensorValue,
99        text_emb: &TensorValue,
100        style_ttl: &TensorValue,
101        latent_mask: &TensorValue,
102        text_mask: &TensorValue,
103        current_step: &TensorValue,
104        total_step: &TensorValue,
105    ) -> Result<TensorValue, anyhow::Error>;
106
107    async fn vocode(
108        &self,
109        latent: &TensorValue,
110    ) -> Result<TensorValue, anyhow::Error>;
111
112    fn config(&self) -> &Config;
113}