rwkv_tts_rs/
lib.rs

1//! RWKV TTS Core Library
2//!
3//! This library provides the core functionality for text-to-speech generation using RWKV models.
4
5// Core modules
// pub mod batch_manager; // moved to the backup directory
7pub mod properties_util;
8pub mod ref_audio_utilities;
9pub mod rwkv_sampler;
10pub mod tts_state_manager;
// pub mod tts_pipeline; // moved to the backup directory
12pub mod tts_pipeline_fixes;
13
14// New concurrent architecture modules
15pub mod global_sampler_manager;
16pub mod onnx_session_pool;
// pub mod batch_request_scheduler; // moved to the backup directory
18pub mod dynamic_batch_manager;
19pub mod lightweight_tts_pipeline;
20pub mod voice_feature;
21pub mod voice_feature_manager;
22
23// Refactored batch manager modules
24pub mod batch_types;
25pub mod feature_extractor;
26pub mod sampler_manager;
27pub mod shared_runtime;
28
29// Inference modules
30pub mod normal_mode_inference;
31pub mod zero_shot_inference;
32
33// Performance optimization modules
34pub mod inference_state_manager;
35pub mod streaming_inference;
36
// New state-management architecture
38pub use tts_state_manager::{
39    TtsInferContext, TtsInferOptions, TtsStateId, TtsStateManager, TtsStateStats,
40};
41
42// Re-export key components
// pub use batch_manager::{BatchManager, BatchConfig, BatchStats}; // moved to the backup directory
44pub use properties_util::*;
45pub use ref_audio_utilities::RefAudioUtilities;
46pub use rwkv_sampler::{RwkvSampler, SamplerArgs, TtsBatchRequest};
// pub use tts_pipeline::{TtsPipeline, TtsPipelineArgs}; // moved to the backup directory
48
49/// TTS Generator module
50pub mod tts_generator {
51    // TTS Generator implementation
52
53    use crate::{RefAudioUtilities, RwkvSampler};
54
55    use anyhow::Result;
56
57    /// Args结构体定义
58    #[derive(Debug)]
59    pub struct Args {
60        pub text: String,
61        pub model_path: String,
62        pub vocab_path: String,
63        pub output_path: String,
64        pub temperature: f32,
65        pub top_p: f32,
66        pub top_k: usize,
67        pub max_tokens: usize,
68        pub age: String,
69        pub gender: String,
70        pub emotion: String,
71        pub pitch: String,
72        pub speed: String,
73        pub validate: bool,
74        pub zero_shot: bool,
75        pub ref_audio_path: String,
76        pub prompt_text: String,
77    }
78
79    /// TTS生成器结构体
80    pub struct TTSGenerator {
81        /// RWKV采样器
82        pub rwkv_sampler: Option<RwkvSampler>,
83        /// 参考音频处理工具
84        pub ref_audio_utilities: Option<RefAudioUtilities>,
85    }
86
87    impl TTSGenerator {
88        /// 创建新的TTS生成器
89        pub fn new() -> Self {
90            Self {
91                rwkv_sampler: None,
92                ref_audio_utilities: None,
93            }
94        }
95
96        /// 异步创建新的TTS生成器
97        pub async fn new_async(model_path: String, vocab_path: String) -> Result<Self> {
98            // 创建RWKV采样器,不使用量化配置
99            let quant_config = None;
100            let rwkv_sampler =
101                RwkvSampler::new(&model_path, &vocab_path, quant_config, 256).await?;
102
103            Ok(Self {
104                rwkv_sampler: Some(rwkv_sampler),
105                ref_audio_utilities: None,
106            })
107        }
108
109        /// 设置RWKV采样器
110        pub fn with_rwkv_sampler(mut self, sampler: RwkvSampler) -> Self {
111            self.rwkv_sampler = Some(sampler);
112            self
113        }
114
115        /// 设置参考音频处理工具
116        pub fn with_ref_audio_utilities(mut self, utilities: RefAudioUtilities) -> Self {
117            self.ref_audio_utilities = Some(utilities);
118            self
119        }
120
121        /// 生成语音 (暂时禁用,因为TtsPipeline已移动到备份目录)
122        pub async fn generate(&self, _text: &str, _args: &Args) -> Result<Vec<f32>> {
123            Err(anyhow::anyhow!(
124                "TtsPipeline已移动到备份目录,请使用lightweight_tts_pipeline"
125            ))
126        }
127
128        /// 保存音频到WAV文件
129        pub fn save_audio(
130            &self,
131            audio_samples: &[f32],
132            output_path: &str,
133            sample_rate: u32,
134        ) -> Result<()> {
135            // 保存音频到WAV文件
136            // 保存音频到指定路径
137
138            // 使用hound库保存WAV文件
139            let spec = hound::WavSpec {
140                channels: 1,
141                sample_rate,
142                bits_per_sample: 32,
143                sample_format: hound::SampleFormat::Float,
144            };
145
146            let mut writer = hound::WavWriter::create(output_path, spec)?;
147            for &sample in audio_samples {
148                writer.write_sample(sample)?;
149            }
150            writer.finalize()?;
151
152            Ok(())
153        }
154    }
155
156    impl Default for TTSGenerator {
157        fn default() -> Self {
158            Self::new()
159        }
160    }
161}
162
// Test module
#[cfg(test)]
mod tests {
    /// Placeholder sanity check; it does not exercise any library code.
    #[test]
    fn test_basic_functionality() {
        let sum = 2 + 2;
        assert_eq!(sum, 4);
    }
}