supertonic_core/
engine.rs1use async_trait::async_trait;
2
3use crate::config::Config;
4
5#[derive(Debug, Clone)]
6pub enum TensorValue {
7 F32(ndarray::Array<f32, ndarray::IxDyn>),
8 I64(ndarray::Array<i64, ndarray::IxDyn>),
9}
10
11impl TensorValue {
12 pub fn as_f32(&self) -> Option<&ndarray::Array<f32, ndarray::IxDyn>> {
13 match self {
14 TensorValue::F32(arr) => Some(arr),
15 _ => None,
16 }
17 }
18
19 pub fn into_f32(self) -> Option<ndarray::Array<f32, ndarray::IxDyn>> {
20 match self {
21 TensorValue::F32(arr) => Some(arr),
22 _ => None,
23 }
24 }
25
26 pub fn shape(&self) -> &[usize] {
27 match self {
28 TensorValue::F32(arr) => arr.shape(),
29 TensorValue::I64(arr) => arr.shape(),
30 }
31 }
32}
33
34impl From<ndarray::Array<f32, ndarray::IxDyn>> for TensorValue {
35 fn from(arr: ndarray::Array<f32, ndarray::IxDyn>) -> Self {
36 TensorValue::F32(arr)
37 }
38}
39
40impl From<ndarray::Array<i64, ndarray::IxDyn>> for TensorValue {
41 fn from(arr: ndarray::Array<i64, ndarray::IxDyn>) -> Self {
42 TensorValue::I64(arr)
43 }
44}
45
/// Tunable parameters for a synthesis run.
///
/// NOTE(review): how each field is consumed is not visible in this file; the
/// per-field notes below are inferred from names and should be confirmed
/// against the code that reads them.
#[derive(Clone, Debug)]
pub struct SynthesisParams {
    /// Step count. Presumably the number of iterative estimation steps
    /// (compare `total_step` on `InferenceEngine::estimate_vector`) — confirm.
    pub total_step: usize,

    /// Speed factor. Presumably a speaking-rate multiplier — confirm.
    pub speed: f32,

    /// Silence duration. Presumably in seconds — confirm units and where the
    /// silence is inserted.
    pub silence_duration: f32,

    /// Optional RNG seed. Presumably `None` means "do not fix the seed" —
    /// confirm.
    pub rng_seed: Option<u64>,
}
53
54impl Default for SynthesisParams {
55 fn default() -> Self {
56 SynthesisParams {
57 total_step: 8,
58 speed: 1.05,
59 silence_duration: 0.3,
60 rng_seed: None,
61 }
62 }
63}
64
/// Output of a complete synthesis run.
///
/// NOTE(review): the producer of this struct is not visible in this file;
/// the unit and layout notes below are inferred from field names — confirm.
#[derive(Clone, Debug)]
pub struct SynthesisResult {
    /// Audio samples as `f32` values (channel layout and value range are not
    /// shown here).
    pub audio: Vec<f32>,

    /// Duration of the audio, in seconds per the field name.
    pub duration_secs: f32,

    /// Sample rate of `audio`. Presumably in Hz — confirm.
    pub sample_rate: u32,
}
71
/// Output for a single chunk of a synthesis run.
///
/// NOTE(review): the chunking scheme and the producer of this struct are not
/// visible in this file; the notes below are inferred from field names —
/// confirm.
#[derive(Clone, Debug)]
pub struct ChunkResult {
    /// Audio samples for this chunk as `f32` values.
    pub audio: Vec<f32>,

    /// Duration of this chunk's audio, in seconds per the field name.
    pub duration_secs: f32,

    /// Index of this chunk. Presumably zero-based — confirm.
    pub chunk_index: usize,

    /// Whether this is the final chunk.
    pub is_last: bool,
}
79
80#[async_trait]
81pub trait InferenceEngine: Send + Sync {
82 async fn predict_duration(
83 &self,
84 text_ids: &TensorValue,
85 style_dp: &TensorValue,
86 text_mask: &TensorValue,
87 ) -> Result<TensorValue, anyhow::Error>;
88
89 async fn encode_text(
90 &self,
91 text_ids: &TensorValue,
92 style_ttl: &TensorValue,
93 text_mask: &TensorValue,
94 ) -> Result<TensorValue, anyhow::Error>;
95
96 async fn estimate_vector(
97 &self,
98 noisy_latent: &TensorValue,
99 text_emb: &TensorValue,
100 style_ttl: &TensorValue,
101 latent_mask: &TensorValue,
102 text_mask: &TensorValue,
103 current_step: &TensorValue,
104 total_step: &TensorValue,
105 ) -> Result<TensorValue, anyhow::Error>;
106
107 async fn vocode(
108 &self,
109 latent: &TensorValue,
110 ) -> Result<TensorValue, anyhow::Error>;
111
112 fn config(&self) -> &Config;
113}