spec_ai_core/tools/builtin/audio_transcription.rs

use crate::persistence::Persistence;
use crate::tools::{Tool, ToolResult};
use crate::types::MessageRole;
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::{self, Stream, StreamExt};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::time;

/// Mock transcription event types
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TranscriptionEvent {
    /// Regular speech transcription
    Speech {
        text: String,
        confidence: f32,
        speaker: Option<String>,
    },
    /// Background noise or non-speech audio
    Noise { description: String, intensity: f32 },
    /// Emotional or tonal context
    Tone { emotion: String, text: String },
    /// Partial/incomplete transcription
    Partial { text: String, is_final: bool },
    /// System events (start/stop listening)
    System { message: String },
}
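
// With the serde attributes above, events serialize as internally tagged,
// snake_case JSON. For example, a Speech event becomes:
//   {"type":"speech","text":"Hello","confidence":0.95,"speaker":"User"}
// and a Noise event becomes:
//   {"type":"noise","description":"Door closing","intensity":0.7}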

/// Predefined mock scenarios for testing
#[derive(Debug, Clone)]
pub struct MockScenario {
    pub name: String,
    pub description: String,
    pub events: Vec<(TranscriptionEvent, Duration)>, // Event and the delay slept before it is emitted
}

impl MockScenario {
    fn simple_conversation() -> Self {
        Self {
            name: "simple_conversation".to_string(),
            description: "A simple back-and-forth conversation".to_string(),
            events: vec![
                (
                    TranscriptionEvent::System {
                        message: "Audio transcription started".to_string(),
                    },
                    Duration::from_millis(500),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "Hello, how are you today?".to_string(),
                        confidence: 0.95,
                        speaker: Some("User".to_string()),
                    },
                    Duration::from_millis(1000),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "I'm doing well, thank you for asking.".to_string(),
                        confidence: 0.92,
                        speaker: Some("Assistant".to_string()),
                    },
                    Duration::from_millis(800),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "What's the weather like outside?".to_string(),
                        confidence: 0.88,
                        speaker: Some("User".to_string()),
                    },
                    Duration::from_millis(1200),
                ),
            ],
        }
    }

    fn command_sequence() -> Self {
        Self {
            name: "command_sequence".to_string(),
            description: "A series of voice commands".to_string(),
            events: vec![
                (
                    TranscriptionEvent::System {
                        message: "Voice command mode activated".to_string(),
                    },
                    Duration::from_millis(300),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "Create a new file called test.txt".to_string(),
                        confidence: 0.90,
                        speaker: None,
                    },
                    Duration::from_millis(1500),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "Write hello world to the file".to_string(),
                        confidence: 0.87,
                        speaker: None,
                    },
                    Duration::from_millis(1000),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "Save and close the file".to_string(),
                        confidence: 0.93,
                        speaker: None,
                    },
                    Duration::from_millis(800),
                ),
                (
                    TranscriptionEvent::System {
                        message: "Commands executed successfully".to_string(),
                    },
                    Duration::from_millis(200),
                ),
            ],
        }
    }

    fn noisy_environment() -> Self {
        Self {
            name: "noisy_environment".to_string(),
            description: "Transcription with background noise".to_string(),
            events: vec![
                (
                    TranscriptionEvent::Noise {
                        description: "Background chatter".to_string(),
                        intensity: 0.3,
                    },
                    Duration::from_millis(500),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "Can you hear me clearly?".to_string(),
                        confidence: 0.75,
                        speaker: Some("User".to_string()),
                    },
                    Duration::from_millis(800),
                ),
                (
                    TranscriptionEvent::Noise {
                        description: "Door closing".to_string(),
                        intensity: 0.7,
                    },
                    Duration::from_millis(300),
                ),
                (
                    TranscriptionEvent::Partial {
                        text: "I need to...".to_string(),
                        is_final: false,
                    },
                    Duration::from_millis(500),
                ),
                (
                    TranscriptionEvent::Partial {
                        text: "I need to schedule a meeting".to_string(),
                        is_final: true,
                    },
                    Duration::from_millis(1000),
                ),
            ],
        }
    }

    fn emotional_context() -> Self {
        Self {
            name: "emotional_context".to_string(),
            description: "Transcription with emotional tone markers".to_string(),
            events: vec![
                (
                    TranscriptionEvent::Tone {
                        emotion: "excited".to_string(),
                        text: "That's amazing news!".to_string(),
                    },
                    Duration::from_millis(1000),
                ),
                (
                    TranscriptionEvent::Tone {
                        emotion: "concerned".to_string(),
                        text: "Are you sure that's the right approach?".to_string(),
                    },
                    Duration::from_millis(1200),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "Let me think about it.".to_string(),
                        confidence: 0.85,
                        speaker: None,
                    },
                    Duration::from_millis(800),
                ),
                (
                    TranscriptionEvent::Tone {
                        emotion: "confident".to_string(),
                        text: "Yes, I'm certain this will work.".to_string(),
                    },
                    Duration::from_millis(1000),
                ),
            ],
        }
    }

    fn multi_speaker() -> Self {
        Self {
            name: "multi_speaker".to_string(),
            description: "Multiple speakers in a meeting".to_string(),
            events: vec![
                (
                    TranscriptionEvent::System {
                        message: "Meeting transcription started".to_string(),
                    },
                    Duration::from_millis(500),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "Welcome everyone to today's standup.".to_string(),
                        confidence: 0.92,
                        speaker: Some("Alice".to_string()),
                    },
                    Duration::from_millis(1000),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "I finished the authentication module yesterday.".to_string(),
                        confidence: 0.88,
                        speaker: Some("Bob".to_string()),
                    },
                    Duration::from_millis(1200),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "Great work Bob. Charlie, how about you?".to_string(),
                        confidence: 0.90,
                        speaker: Some("Alice".to_string()),
                    },
                    Duration::from_millis(800),
                ),
                (
                    TranscriptionEvent::Speech {
                        text: "Still working on the database migrations.".to_string(),
                        confidence: 0.85,
                        speaker: Some("Charlie".to_string()),
                    },
                    Duration::from_millis(1000),
                ),
            ],
        }
    }
}

/// Configuration for audio transcription session
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionConfig {
    /// Mock scenario to use
    pub scenario: String,
    /// Duration to listen (in seconds), None for continuous
    pub duration: Option<u64>,
    /// Whether to loop the scenario
    pub loop_scenario: bool,
    /// Playback speed multiplier (1.0 = normal, 2.0 = double speed)
    pub speed_multiplier: f32,
    /// Whether to persist to database
    pub persist: bool,
    /// Session ID for persistence
    pub session_id: Option<String>,
}

impl Default for TranscriptionConfig {
    fn default() -> Self {
        Self {
            scenario: "simple_conversation".to_string(),
            duration: Some(30),
            loop_scenario: false,
            speed_multiplier: 1.0,
            persist: true,
            session_id: None,
        }
    }
}
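
// Callers usually override only the fields they care about via struct update
// syntax, e.g. (illustrative):
//   let config = TranscriptionConfig {
//       scenario: "multi_speaker".to_string(),
//       loop_scenario: true,
//       ..Default::default()
//   };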

/// Mock audio transcription tool
pub struct AudioTranscriptionTool {
    scenarios: Vec<MockScenario>,
    active_sessions: Arc<Mutex<Vec<String>>>,
    persistence: Option<Arc<Persistence>>,
}

impl AudioTranscriptionTool {
    pub fn new() -> Self {
        Self {
            scenarios: vec![
                MockScenario::simple_conversation(),
                MockScenario::command_sequence(),
                MockScenario::noisy_environment(),
                MockScenario::emotional_context(),
                MockScenario::multi_speaker(),
            ],
            active_sessions: Arc::new(Mutex::new(Vec::new())),
            persistence: None,
        }
    }

    pub fn with_persistence(persistence: Arc<Persistence>) -> Self {
        let mut tool = Self::new();
        tool.persistence = Some(persistence);
        tool
    }

    /// Get a scenario by name
    fn get_scenario(&self, name: &str) -> Option<&MockScenario> {
        self.scenarios.iter().find(|s| s.name == name)
    }

    /// Create a stream of transcription events
    pub fn create_event_stream(
        &self,
        config: TranscriptionConfig,
    ) -> Pin<Box<dyn Stream<Item = TranscriptionEvent> + Send>> {
        let scenario = self
            .get_scenario(&config.scenario)
            .cloned()
            .unwrap_or_else(MockScenario::simple_conversation);

        let speed_multiplier = config.speed_multiplier;
        let loop_scenario = config.loop_scenario;
        let duration = config.duration;

        Box::pin(stream::unfold(
            (scenario, 0usize, time::Instant::now(), duration),
            move |(scenario, mut index, start_time, duration)| async move {
                // Check duration limit
                if let Some(max_duration) = duration {
                    if start_time.elapsed() >= Duration::from_secs(max_duration) {
                        return None;
                    }
                }

                // Get current event or loop/end
                if index >= scenario.events.len() {
                    if loop_scenario {
                        index = 0;
                    } else {
                        return None;
                    }
                }

                let (event, delay) = scenario.events[index].clone();

                // Scale the delay: higher multipliers shorten it, so 2.0 plays
                // the scenario back twice as fast and 0.5 at half speed
                let adjusted_delay =
                    Duration::from_millis((delay.as_millis() as f32 / speed_multiplier) as u64);

                // Wait before emitting the event
                time::sleep(adjusted_delay).await;

                Some((event, (scenario, index + 1, start_time, duration)))
            },
        ))
    }
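
    // Minimal consumption sketch (hypothetical caller, not part of this tool):
    //
    //   let tool = AudioTranscriptionTool::new();
    //   let mut events = tool.create_event_stream(TranscriptionConfig::default());
    //   while let Some(event) = events.next().await {
    //       println!("{}", tool.format_event(&event));
    //   }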

    /// Format transcription event as string
    fn format_event(&self, event: &TranscriptionEvent) -> String {
        match event {
            TranscriptionEvent::Speech {
                text,
                confidence,
                speaker,
            } => {
                if let Some(speaker) = speaker {
                    format!(
                        "[{}] {} (confidence: {:.1}%)",
                        speaker,
                        text,
                        confidence * 100.0
                    )
                } else {
                    format!("{} (confidence: {:.1}%)", text, confidence * 100.0)
                }
            }
            TranscriptionEvent::Noise {
                description,
                intensity,
            } => {
                format!("[NOISE: {} (intensity: {:.1})]", description, intensity)
            }
            TranscriptionEvent::Tone { emotion, text } => {
                format!("[TONE: {}] {}", emotion, text)
            }
            TranscriptionEvent::Partial { text, is_final } => {
                if *is_final {
                    format!("[FINAL] {}", text)
                } else {
                    format!("[PARTIAL] {}...", text)
                }
            }
            TranscriptionEvent::System { message } => {
                format!("[SYSTEM] {}", message)
            }
        }
    }
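
    // For instance, the Speech arm above renders Alice saying "Hello" at
    // confidence 0.95 as: [Alice] Hello (confidence: 95.0%)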

    /// Store transcription event in database
    async fn persist_event(&self, session_id: &str, event: &TranscriptionEvent) -> Result<()> {
        if let Some(persistence) = &self.persistence {
            let formatted = self.format_event(event);

            // Store as a user message
            persistence.insert_message(session_id, MessageRole::User, &formatted)?;

            // Optionally store metadata as graph nodes
            if let TranscriptionEvent::Speech { .. } = event {
                // Could create graph nodes for entities, speakers, etc.
                // This is where we'd integrate with the knowledge graph.
                // For now, we just store the formatted transcription above.
            }
        }
        Ok(())
    }
}

impl Default for AudioTranscriptionTool {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl Tool for AudioTranscriptionTool {
    fn name(&self) -> &str {
        "audio_transcribe"
    }

    fn description(&self) -> &str {
        "Mock audio transcription tool that simulates live audio input and converts it to text. \
         Supports multiple scenarios including conversations, commands, noisy environments, \
         and multi-speaker sessions."
    }

    fn parameters(&self) -> Value {
        json!({
            "type": "object",
            "properties": {
                "scenario": {
                    "type": "string",
                    "description": "Mock scenario to use",
                    "enum": [
                        "simple_conversation",
                        "command_sequence",
                        "noisy_environment",
                        "emotional_context",
                        "multi_speaker"
                    ]
                },
                "duration": {
                    "type": "integer",
                    "description": "Duration to listen in seconds (default: 30)"
                },
                "mode": {
                    "type": "string",
                    "description": "Transcription mode: 'stream' for real-time or 'batch' for all at once",
                    "enum": ["stream", "batch"],
                    "default": "stream"
                },
                "speed_multiplier": {
                    "type": "number",
                    "description": "Speed multiplier for event dispatch (1.0 = normal, 0.5 = half speed, 2.0 = double speed)",
                    "default": 1.0
                },
                "persist": {
                    "type": "boolean",
                    "description": "Whether to persist transcriptions to database",
                    "default": true
                }
            },
            "required": []
        })
    }
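
    // Example arguments (illustrative); every field is optional and falls back
    // to the defaults parsed in execute below:
    //   {"scenario": "multi_speaker", "duration": 10, "mode": "batch", "speed_multiplier": 2.0}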

    async fn execute(&self, args: Value) -> Result<ToolResult> {
        // Parse configuration
        let scenario = args["scenario"]
            .as_str()
            .unwrap_or("simple_conversation")
            .to_string();

        let duration = Some(args["duration"].as_u64().unwrap_or(30));

        let mode = args["mode"].as_str().unwrap_or("stream").to_string();

        let speed_multiplier = args["speed_multiplier"].as_f64().unwrap_or(1.0) as f32;

        let persist = args["persist"].as_bool().unwrap_or(true);

        // Generate session ID
        let session_id = format!("audio_{}", chrono::Utc::now().timestamp_millis());

        // Track active session
        {
            let mut sessions = self.active_sessions.lock().await;
            sessions.push(session_id.clone());
        }

        let config = TranscriptionConfig {
            scenario: scenario.clone(),
            duration,
            loop_scenario: false,
            speed_multiplier,
            persist,
            session_id: Some(session_id.clone()),
        };

        // Create event stream
        let mut event_stream = self.create_event_stream(config);
        let mut transcriptions = Vec::new();

        // Process events based on mode
        if mode == "batch" {
            // Collect all events at once
            while let Some(event) = event_stream.next().await {
                let formatted = self.format_event(&event);
                transcriptions.push(formatted);

                if persist {
                    let _ = self.persist_event(&session_id, &event).await;
                }
            }

            // Remove from active sessions
            {
                let mut sessions = self.active_sessions.lock().await;
                sessions.retain(|s| s != &session_id);
            }

            let result = json!({
                "session_id": session_id,
                "scenario": scenario,
                "mode": "batch",
                "transcriptions": transcriptions,
                "count": transcriptions.len(),
                "duration": duration,
            });
            Ok(ToolResult::success(result.to_string()))
        } else {
            // Stream mode - return immediately with session info.
            // In a real implementation, this would set up a background task.

            // For the mock, process a few events to show it's working
            let mut sample_transcriptions = Vec::new();
            let mut count = 0;

            while let Some(event) = event_stream.next().await {
                let formatted = self.format_event(&event);
                sample_transcriptions.push(formatted);

                if persist {
                    let _ = self.persist_event(&session_id, &event).await;
                }

                count += 1;
                if count >= 3 {
                    break; // Just show the first 3 events as a sample
                }
            }

            let result = json!({
                "session_id": session_id,
                "scenario": scenario,
                "mode": "stream",
                "status": "listening",
                "sample_transcriptions": sample_transcriptions,
                "message": format!(
                    "Audio transcription session {} started. Listening for {} seconds...",
                    session_id,
                    duration.unwrap_or(0)
                ),
            });
            Ok(ToolResult::success(result.to_string()))
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_tool_metadata() {
        let tool = AudioTranscriptionTool::new();
        assert_eq!(tool.name(), "audio_transcribe");
        assert!(tool.description().contains("audio"));

        let params = tool.parameters();
        assert!(params["properties"]["scenario"].is_object());
        assert!(params["properties"]["duration"].is_object());
    }

    #[tokio::test]
    async fn test_scenario_loading() {
        let tool = AudioTranscriptionTool::new();

        assert!(tool.get_scenario("simple_conversation").is_some());
        assert!(tool.get_scenario("command_sequence").is_some());
        assert!(tool.get_scenario("noisy_environment").is_some());
        assert!(tool.get_scenario("emotional_context").is_some());
        assert!(tool.get_scenario("multi_speaker").is_some());
        assert!(tool.get_scenario("non_existent").is_none());
    }

    #[tokio::test]
    async fn test_event_stream_creation() {
        let tool = AudioTranscriptionTool::new();
        let config = TranscriptionConfig {
            scenario: "simple_conversation".to_string(),
            duration: Some(1), // 1 second for a quick test
            loop_scenario: false,
            speed_multiplier: 10.0, // Speed up for testing
            persist: false,
            session_id: None,
        };

        let mut stream = tool.create_event_stream(config);
        let mut count = 0;

        while let Some(_event) = stream.next().await {
            count += 1;
            if count >= 3 {
                break; // Just test a few events
            }
        }

        assert!(count > 0);
    }

    #[tokio::test]
    async fn test_batch_execution() {
        let tool = AudioTranscriptionTool::new();
        let args = json!({
            "scenario": "simple_conversation",
            "duration": 1,
            "mode": "batch",
            "speed_multiplier": 100.0, // Very fast for testing
            "persist": false
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);

        let output: Value = serde_json::from_str(&result.output).unwrap();
        assert_eq!(output["scenario"], "simple_conversation");
        assert_eq!(output["mode"], "batch");
        assert!(output["transcriptions"].is_array());
    }

    #[tokio::test]
    async fn test_stream_execution() {
        let tool = AudioTranscriptionTool::new();
        let args = json!({
            "scenario": "command_sequence",
            "duration": 5,
            "mode": "stream",
            "speed_multiplier": 10.0,
            "persist": false
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);

        let output: Value = serde_json::from_str(&result.output).unwrap();
        assert_eq!(output["scenario"], "command_sequence");
        assert_eq!(output["mode"], "stream");
        assert_eq!(output["status"], "listening");
        assert!(output["sample_transcriptions"].is_array());
    }

    #[tokio::test]
    async fn test_event_formatting() {
        let tool = AudioTranscriptionTool::new();

        let speech_event = TranscriptionEvent::Speech {
            text: "Hello".to_string(),
            confidence: 0.95,
            speaker: Some("Alice".to_string()),
        };
        let formatted = tool.format_event(&speech_event);
        assert!(formatted.contains("Alice"));
        assert!(formatted.contains("Hello"));
        assert!(formatted.contains("95.0%"));

        let noise_event = TranscriptionEvent::Noise {
            description: "Door closing".to_string(),
            intensity: 0.7,
        };
        let formatted = tool.format_event(&noise_event);
        assert!(formatted.contains("NOISE"));
        assert!(formatted.contains("Door closing"));
    }
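
    // Pins down the wire format implied by the serde attributes on
    // TranscriptionEvent (internally tagged, snake_case variant names).
    #[test]
    fn test_event_serialization() {
        let event = TranscriptionEvent::Speech {
            text: "Hello".to_string(),
            confidence: 0.95,
            speaker: Some("Alice".to_string()),
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"], "speech");
        assert_eq!(json["text"], "Hello");
        assert_eq!(json["speaker"], "Alice");
    }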

    #[tokio::test]
    async fn test_active_session_tracking() {
        let tool = AudioTranscriptionTool::new();

        let args = json!({
            "scenario": "simple_conversation",
            "duration": 1,
            "mode": "stream",
            "speed_multiplier": 100.0,
            "persist": false
        });

        let result = tool.execute(args).await.unwrap();
        let output: Value = serde_json::from_str(&result.output).unwrap();
        let session_id = output["session_id"].as_str().unwrap();

        {
            let sessions = tool.active_sessions.lock().await;
            assert!(sessions.iter().any(|s| s == session_id));
        }
    }
}