1use crate::persistence::Persistence;
2use crate::tools::{Tool, ToolResult};
3use crate::types::MessageRole;
4use anyhow::Result;
5use async_trait::async_trait;
6use futures::stream::{self, Stream, StreamExt};
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value};
9use std::pin::Pin;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::sync::Mutex;
13use tokio::time;
14
/// A single event emitted by the mock transcription stream.
///
/// Serialized as an internally tagged enum: the variant name becomes a
/// snake_case `"type"` field (e.g. `{"type": "speech", ...}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TranscriptionEvent {
    /// Recognized speech with a confidence score and an optional speaker label.
    Speech {
        text: String,
        // Recognition confidence; the built-in scenarios use 0.0..=1.0
        // and `format_event` renders it as a percentage.
        confidence: f32,
        speaker: Option<String>,
    },
    /// Non-speech background sound (built-in scenarios use intensities 0.3-0.7).
    Noise { description: String, intensity: f32 },
    /// Speech annotated with an emotional tone marker.
    Tone { emotion: String, text: String },
    /// Incremental recognition result; `is_final` marks the completed phrase.
    Partial { text: String, is_final: bool },
    /// Status message from the mock transcription system itself.
    System { message: String },
}
34
/// A named, scripted sequence of transcription events used to simulate
/// live audio input.
#[derive(Debug, Clone)]
pub struct MockScenario {
    // Unique identifier used for lookup (see `AudioTranscriptionTool::get_scenario`).
    pub name: String,
    // Human-readable summary of what the scenario simulates.
    pub description: String,
    // Each event paired with the delay to wait before emitting it.
    pub events: Vec<(TranscriptionEvent, Duration)>,
}
42
/// Built-in scripted scenarios. Each constructor returns a fixed sequence of
/// events with per-event delays; the data is intentionally hard-coded so the
/// mock output is deterministic.
impl MockScenario {
    /// Two speakers exchanging greetings plus a follow-up question.
    fn simple_conversation() -> Self {
        Self {
            name: "simple_conversation".to_string(),
            description: "A simple back-and-forth conversation".to_string(),
            events: vec![
                (
                    TranscriptionEvent::System {
                        message: "Audio transcription started".to_string(),
                    },
                    Duration::from_millis(500),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "Hello, how are you today?".to_string(),
                        confidence: 0.95,
                        speaker: Some("User".to_string()),
                    },
                    Duration::from_millis(1000),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "I'm doing well, thank you for asking.".to_string(),
                        confidence: 0.92,
                        speaker: Some("Assistant".to_string()),
                    },
                    Duration::from_millis(800),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "What's the weather like outside?".to_string(),
                        confidence: 0.88,
                        speaker: Some("User".to_string()),
                    },
                    Duration::from_millis(1200),
                ),
            ],
        }
    }

    /// A sequence of voice commands with no speaker attribution, bracketed by
    /// system status messages.
    fn command_sequence() -> Self {
        Self {
            name: "command_sequence".to_string(),
            description: "A series of voice commands".to_string(),
            events: vec![
                (
                    TranscriptionEvent::System {
                        message: "Voice command mode activated".to_string(),
                    },
                    Duration::from_millis(300),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "Create a new file called test.txt".to_string(),
                        confidence: 0.90,
                        speaker: None,
                    },
                    Duration::from_millis(1500),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "Write hello world to the file".to_string(),
                        confidence: 0.87,
                        speaker: None,
                    },
                    Duration::from_millis(1000),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "Save and close the file".to_string(),
                        confidence: 0.93,
                        speaker: None,
                    },
                    Duration::from_millis(800),
                ),
                (
                    TranscriptionEvent::System {
                        message: "Commands executed successfully".to_string(),
                    },
                    Duration::from_millis(200),
                ),
            ],
        }
    }

    /// Speech interleaved with noise events and a partial-then-final
    /// recognition pair; speech confidence is lower than in the clean scenarios.
    fn noisy_environment() -> Self {
        Self {
            name: "noisy_environment".to_string(),
            description: "Transcription with background noise".to_string(),
            events: vec![
                (
                    TranscriptionEvent::Noise {
                        description: "Background chatter".to_string(),
                        intensity: 0.3,
                    },
                    Duration::from_millis(500),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "Can you hear me clearly?".to_string(),
                        confidence: 0.75,
                        speaker: Some("User".to_string()),
                    },
                    Duration::from_millis(800),
                ),
                (
                    TranscriptionEvent::Noise {
                        description: "Door closing".to_string(),
                        intensity: 0.7,
                    },
                    Duration::from_millis(300),
                ),
                // Incomplete partial followed by its finalized version.
                (
                    TranscriptionEvent::Partial {
                        text: "I need to...".to_string(),
                        is_final: false,
                    },
                    Duration::from_millis(500),
                ),
                (
                    TranscriptionEvent::Partial {
                        text: "I need to schedule a meeting".to_string(),
                        is_final: true,
                    },
                    Duration::from_millis(1000),
                ),
            ],
        }
    }

    /// Mostly tone-annotated speech, with one plain speech event in between.
    fn emotional_context() -> Self {
        Self {
            name: "emotional_context".to_string(),
            description: "Transcription with emotional tone markers".to_string(),
            events: vec![
                (
                    TranscriptionEvent::Tone {
                        emotion: "excited".to_string(),
                        text: "That's amazing news!".to_string(),
                    },
                    Duration::from_millis(1000),
                ),
                (
                    TranscriptionEvent::Tone {
                        emotion: "concerned".to_string(),
                        text: "Are you sure that's the right approach?".to_string(),
                    },
                    Duration::from_millis(1200),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "Let me think about it.".to_string(),
                        confidence: 0.85,
                        speaker: None,
                    },
                    Duration::from_millis(800),
                ),
                (
                    TranscriptionEvent::Tone {
                        emotion: "confident".to_string(),
                        text: "Yes, I'm certain this will work.".to_string(),
                    },
                    Duration::from_millis(1000),
                ),
            ],
        }
    }

    /// A standup meeting with three named speakers (Alice, Bob, Charlie).
    fn multi_speaker() -> Self {
        Self {
            name: "multi_speaker".to_string(),
            description: "Multiple speakers in a meeting".to_string(),
            events: vec![
                (
                    TranscriptionEvent::System {
                        message: "Meeting transcription started".to_string(),
                    },
                    Duration::from_millis(500),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "Welcome everyone to today's standup.".to_string(),
                        confidence: 0.92,
                        speaker: Some("Alice".to_string()),
                    },
                    Duration::from_millis(1000),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "I finished the authentication module yesterday.".to_string(),
                        confidence: 0.88,
                        speaker: Some("Bob".to_string()),
                    },
                    Duration::from_millis(1200),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "Great work Bob. Charlie, how about you?".to_string(),
                        confidence: 0.90,
                        speaker: Some("Alice".to_string()),
                    },
                    Duration::from_millis(800),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "Still working on the database migrations.".to_string(),
                        confidence: 0.85,
                        speaker: Some("Charlie".to_string()),
                    },
                    Duration::from_millis(1000),
                ),
            ],
        }
    }
}
258
/// Runtime options controlling how a mock transcription stream is produced.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionConfig {
    // Name of the `MockScenario` to play back; unknown names fall back to
    // `simple_conversation` in `create_event_stream`.
    pub scenario: String,
    // Maximum listening time in seconds; `None` means no time limit.
    pub duration: Option<u64>,
    // Restart the scenario from the beginning when it runs out of events.
    pub loop_scenario: bool,
    // Playback-speed control applied to each event's scripted delay.
    pub speed_multiplier: f32,
    // Whether emitted events should be written to the persistence store.
    pub persist: bool,
    // Optional session identifier to associate with this run.
    pub session_id: Option<String>,
}
275
276impl Default for TranscriptionConfig {
277 fn default() -> Self {
278 Self {
279 scenario: "simple_conversation".to_string(),
280 duration: Some(30),
281 loop_scenario: false,
282 speed_multiplier: 1.0,
283 persist: true,
284 session_id: None,
285 }
286 }
287}
288
/// Mock audio-transcription tool: plays back scripted scenarios as if they
/// were live audio being transcribed to text.
pub struct AudioTranscriptionTool {
    // The built-in scenarios available for playback.
    scenarios: Vec<MockScenario>,
    // Session ids of runs started by `execute`; batch sessions are removed
    // on completion, stream sessions remain listed.
    active_sessions: Arc<Mutex<Vec<String>>>,
    // Optional store for transcriptions; `None` disables persistence.
    persistence: Option<Arc<Persistence>>,
}
295
296impl AudioTranscriptionTool {
297 pub fn new() -> Self {
298 Self {
299 scenarios: vec![
300 MockScenario::simple_conversation(),
301 MockScenario::command_sequence(),
302 MockScenario::noisy_environment(),
303 MockScenario::emotional_context(),
304 MockScenario::multi_speaker(),
305 ],
306 active_sessions: Arc::new(Mutex::new(Vec::new())),
307 persistence: None,
308 }
309 }
310
311 pub fn with_persistence(persistence: Arc<Persistence>) -> Self {
312 let mut tool = Self::new();
313 tool.persistence = Some(persistence);
314 tool
315 }
316
317 fn get_scenario(&self, name: &str) -> Option<&MockScenario> {
319 self.scenarios.iter().find(|s| s.name == name)
320 }
321
322 pub fn create_event_stream(
324 &self,
325 config: TranscriptionConfig,
326 ) -> Pin<Box<dyn Stream<Item = TranscriptionEvent> + Send>> {
327 let scenario = self
328 .get_scenario(&config.scenario)
329 .cloned()
330 .unwrap_or_else(MockScenario::simple_conversation);
331
332 let speed_multiplier = config.speed_multiplier;
333 let loop_scenario = config.loop_scenario;
334 let duration = config.duration;
335
336 Box::pin(stream::unfold(
337 (scenario, 0usize, time::Instant::now(), duration),
338 move |(scenario, mut index, start_time, duration)| async move {
339 if let Some(max_duration) = duration {
341 if start_time.elapsed() >= Duration::from_secs(max_duration) {
342 return None;
343 }
344 }
345
346 if index >= scenario.events.len() {
348 if loop_scenario {
349 index = 0;
350 } else {
351 return None;
352 }
353 }
354
355 let (event, delay) = scenario.events[index].clone();
356
357 let adjusted_delay =
359 Duration::from_millis((delay.as_millis() as f32 * speed_multiplier) as u64);
360
361 time::sleep(adjusted_delay).await;
363
364 Some((event, (scenario, index + 1, start_time, duration)))
365 },
366 ))
367 }
368
369 fn format_event(&self, event: &TranscriptionEvent) -> String {
371 match event {
372 TranscriptionEvent::Speech {
373 text,
374 confidence,
375 speaker,
376 } => {
377 if let Some(speaker) = speaker {
378 format!(
379 "[{}] {} (confidence: {:.1}%)",
380 speaker,
381 text,
382 confidence * 100.0
383 )
384 } else {
385 format!("{} (confidence: {:.1}%)", text, confidence * 100.0)
386 }
387 }
388 TranscriptionEvent::Noise {
389 description,
390 intensity,
391 } => {
392 format!("[NOISE: {} (intensity: {:.1})]", description, intensity)
393 }
394 TranscriptionEvent::Tone { emotion, text } => {
395 format!("[TONE: {}] {}", emotion, text)
396 }
397 TranscriptionEvent::Partial { text, is_final } => {
398 if *is_final {
399 format!("[FINAL] {}", text)
400 } else {
401 format!("[PARTIAL] {}...", text)
402 }
403 }
404 TranscriptionEvent::System { message } => {
405 format!("[SYSTEM] {}", message)
406 }
407 }
408 }
409
410 async fn persist_event(&self, session_id: &str, event: &TranscriptionEvent) -> Result<()> {
412 if let Some(persistence) = &self.persistence {
413 let formatted = self.format_event(event);
414
415 persistence.insert_message(session_id, MessageRole::User, &formatted)?;
417
418 if let TranscriptionEvent::Speech {
420 text: _,
421 speaker: _,
422 confidence: _,
423 } = event
424 {
425 }
429 }
430 Ok(())
431 }
432}
433
434impl Default for AudioTranscriptionTool {
435 fn default() -> Self {
436 Self::new()
437 }
438}
439
440#[async_trait]
441impl Tool for AudioTranscriptionTool {
442 fn name(&self) -> &str {
443 "audio_transcribe"
444 }
445
446 fn description(&self) -> &str {
447 "Mock audio transcription tool that simulates live audio input and converts it to text. \
448 Supports multiple scenarios including conversations, commands, noisy environments, \
449 and multi-speaker sessions."
450 }
451
452 fn parameters(&self) -> Value {
453 json!({
454 "type": "object",
455 "properties": {
456 "scenario": {
457 "type": "string",
458 "description": "Mock scenario to use",
459 "enum": [
460 "simple_conversation",
461 "command_sequence",
462 "noisy_environment",
463 "emotional_context",
464 "multi_speaker"
465 ]
466 },
467 "duration": {
468 "type": "integer",
469 "description": "Duration to listen in seconds (default: 30)"
470 },
471 "mode": {
472 "type": "string",
473 "description": "Transcription mode: 'stream' for real-time or 'batch' for all at once",
474 "enum": ["stream", "batch"],
475 "default": "stream"
476 },
477 "speed_multiplier": {
478 "type": "number",
479 "description": "Speed multiplier for event dispatch (1.0 = normal, 0.5 = half speed, 2.0 = double speed)",
480 "default": 1.0
481 },
482 "persist": {
483 "type": "boolean",
484 "description": "Whether to persist transcriptions to database",
485 "default": true
486 }
487 },
488 "required": []
489 })
490 }
491
492 async fn execute(&self, args: Value) -> Result<ToolResult> {
493 let scenario = args["scenario"]
495 .as_str()
496 .unwrap_or("simple_conversation")
497 .to_string();
498
499 let duration = args["duration"].as_u64().or(Some(30));
500
501 let mode = args["mode"].as_str().unwrap_or("stream").to_string();
502
503 let speed_multiplier = args["speed_multiplier"].as_f64().unwrap_or(1.0) as f32;
504
505 let persist = args["persist"].as_bool().unwrap_or(true);
506
507 let session_id = format!("audio_{}", chrono::Utc::now().timestamp_millis());
509
510 {
512 let mut sessions = self.active_sessions.lock().await;
513 sessions.push(session_id.clone());
514 }
515
516 let config = TranscriptionConfig {
517 scenario: scenario.clone(),
518 duration,
519 loop_scenario: false,
520 speed_multiplier,
521 persist,
522 session_id: Some(session_id.clone()),
523 };
524
525 let mut event_stream = self.create_event_stream(config);
527 let mut transcriptions = Vec::new();
528
529 if mode == "batch" {
531 while let Some(event) = event_stream.next().await {
533 let formatted = self.format_event(&event);
534 transcriptions.push(formatted.clone());
535
536 if persist {
537 let _ = self.persist_event(&session_id, &event).await;
538 }
539 }
540
541 {
543 let mut sessions = self.active_sessions.lock().await;
544 sessions.retain(|s| s != &session_id);
545 }
546
547 let result = json!({
548 "session_id": session_id,
549 "scenario": scenario,
550 "mode": "batch",
551 "transcriptions": transcriptions,
552 "count": transcriptions.len(),
553 "duration": duration,
554 });
555 Ok(ToolResult::success(result.to_string()))
556 } else {
557 let mut sample_transcriptions = Vec::new();
562 let mut count = 0;
563
564 while let Some(event) = event_stream.next().await {
565 let formatted = self.format_event(&event);
566 sample_transcriptions.push(formatted.clone());
567
568 if persist {
569 let _ = self.persist_event(&session_id, &event).await;
570 }
571
572 count += 1;
573 if count >= 3 {
574 break; }
576 }
577
578 let result = json!({
579 "session_id": session_id,
580 "scenario": scenario,
581 "mode": "stream",
582 "status": "listening",
583 "sample_transcriptions": sample_transcriptions,
584 "message": format!(
585 "Audio transcription session {} started. Listening for {} seconds...",
586 session_id,
587 duration.unwrap_or(0)
588 ),
589 });
590 Ok(ToolResult::success(result.to_string()))
591 }
592 }
593}
594
#[cfg(test)]
mod tests {
    use super::*;

    /// Name, description, and parameter schema are exposed as expected.
    #[tokio::test]
    async fn test_tool_metadata() {
        let tool = AudioTranscriptionTool::new();
        assert_eq!(tool.name(), "audio_transcribe");
        assert!(tool.description().contains("audio"));

        let params = tool.parameters();
        assert!(params["properties"]["scenario"].is_object());
        assert!(params["properties"]["duration"].is_object());
    }

    /// All five built-in scenarios resolve by name; unknown names do not.
    #[tokio::test]
    async fn test_scenario_loading() {
        let tool = AudioTranscriptionTool::new();

        assert!(tool.get_scenario("simple_conversation").is_some());
        assert!(tool.get_scenario("command_sequence").is_some());
        assert!(tool.get_scenario("noisy_environment").is_some());
        assert!(tool.get_scenario("emotional_context").is_some());
        assert!(tool.get_scenario("multi_speaker").is_some());
        assert!(tool.get_scenario("non_existent").is_none());
    }

    /// The event stream yields at least one event before the 1-second cap.
    /// The high multiplier assumes faster-than-real-time dispatch to keep
    /// the test quick.
    #[tokio::test]
    async fn test_event_stream_creation() {
        let tool = AudioTranscriptionTool::new();
        let config = TranscriptionConfig {
            scenario: "simple_conversation".to_string(),
            duration: Some(1),
            loop_scenario: false,
            speed_multiplier: 10.0,
            persist: false,
            session_id: None,
        };

        let mut stream = tool.create_event_stream(config);
        let mut count = 0;

        while let Some(_event) = stream.next().await {
            count += 1;
            if count >= 3 {
                break;
            }
        }

        assert!(count > 0);
    }

    /// Batch mode drains the scenario and returns all transcript lines.
    #[tokio::test]
    async fn test_batch_execution() {
        let tool = AudioTranscriptionTool::new();
        let args = json!({
            "scenario": "simple_conversation",
            "duration": 1,
            "mode": "batch",
            "speed_multiplier": 100.0,
            "persist": false
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);

        let output: Value = serde_json::from_str(&result.output).unwrap();
        assert_eq!(output["scenario"], "simple_conversation");
        assert_eq!(output["mode"], "batch");
        assert!(output["transcriptions"].is_array());
    }

    /// Stream mode returns sample transcriptions and a "listening" status.
    #[tokio::test]
    async fn test_stream_execution() {
        let tool = AudioTranscriptionTool::new();
        let args = json!({
            "scenario": "command_sequence",
            "duration": 5,
            "mode": "stream",
            "speed_multiplier": 10.0,
            "persist": false
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);

        let output: Value = serde_json::from_str(&result.output).unwrap();
        assert_eq!(output["scenario"], "command_sequence");
        assert_eq!(output["mode"], "stream");
        assert_eq!(output["status"], "listening");
        assert!(output["sample_transcriptions"].is_array());
    }

    /// Spot-checks `format_event` output for speech and noise variants.
    #[tokio::test]
    async fn test_event_formatting() {
        let tool = AudioTranscriptionTool::new();

        let speech_event = TranscriptionEvent::Speech {
            text: "Hello".to_string(),
            confidence: 0.95,
            speaker: Some("Alice".to_string()),
        };
        let formatted = tool.format_event(&speech_event);
        assert!(formatted.contains("Alice"));
        assert!(formatted.contains("Hello"));
        assert!(formatted.contains("95.0%"));

        let noise_event = TranscriptionEvent::Noise {
            description: "Door closing".to_string(),
            intensity: 0.7,
        };
        let formatted = tool.format_event(&noise_event);
        assert!(formatted.contains("NOISE"));
        assert!(formatted.contains("Door closing"));
    }

    /// Stream-mode sessions remain in `active_sessions` after `execute`
    /// returns (only batch sessions are removed).
    #[tokio::test]
    async fn test_active_session_tracking() {
        let tool = AudioTranscriptionTool::new();

        let args = json!({
            "scenario": "simple_conversation",
            "duration": 1,
            "mode": "stream",
            "speed_multiplier": 100.0,
            "persist": false
        });

        let result = tool.execute(args).await.unwrap();
        let output: Value = serde_json::from_str(&result.output).unwrap();
        let session_id = output["session_id"].as_str().unwrap();

        {
            let sessions = tool.active_sessions.lock().await;
            assert!(sessions.iter().any(|s| s == session_id));
        }
    }
}