Skip to main content

vox_rtc_server/
session.rs

1use crate::error::{Result, VoxRtcError};
2use crate::socket::RawSocketChannel;
3use crate::types::*;
4use serde_json::Value;
5use tokio::task::JoinHandle;
6use tokio::time::{Duration, timeout};
7
8#[derive(Clone)]
9pub struct VoxRtcControlSession {
10    channel: RawSocketChannel,
11    session_id: String,
12    channel_name: String,
13    join_timeout: Duration,
14}
15
16pub struct Listener {
17    handle: JoinHandle<()>,
18}
19
20impl Drop for Listener {
21    fn drop(&mut self) {
22        self.handle.abort();
23    }
24}
25
26impl VoxRtcControlSession {
27    pub(crate) fn new(
28        channel: RawSocketChannel,
29        session_id: String,
30        join_timeout: Duration,
31    ) -> Self {
32        let channel_name = format!("/rtc/{session_id}");
33        Self {
34            channel,
35            session_id,
36            channel_name,
37            join_timeout,
38        }
39    }
40
41    pub fn session_id(&self) -> &str {
42        &self.session_id
43    }
44
45    pub fn channel_name(&self) -> &str {
46        &self.channel_name
47    }
48
49    pub async fn join(&self) -> Result<()> {
50        let mut states = self.channel.subscribe_state();
51        self.channel.join().await?;
52        let channel_name = self.channel.name().to_owned();
53        timeout(self.join_timeout, async move {
54            loop {
55                let state = *states.borrow_and_update();
56                match state {
57                    ChannelState::Joined => return Ok(()),
58                    ChannelState::Closed | ChannelState::Declined => {
59                        return Err(VoxRtcError::JoinFailed {
60                            channel: channel_name,
61                            state: format!("{state:?}"),
62                        });
63                    }
64                    _ => {}
65                }
66                if states.changed().await.is_err() {
67                    return Err(VoxRtcError::Disconnected);
68                }
69            }
70        })
71        .await
72        .map_err(|_| VoxRtcError::JoinTimeout(self.channel_name.clone()))?
73    }
74
75    pub async fn close(&self) -> Result<()> {
76        self.channel.leave().await
77    }
78
79    pub fn on_event<F>(&self, handler: F) -> Listener
80    where
81        F: Fn(WireEvent) + Send + Sync + 'static,
82    {
83        let mut messages = self.channel.subscribe_messages();
84        let session_id = self.session_id.clone();
85        let channel_name = self.channel_name.clone();
86        Listener {
87            handle: tokio::spawn(async move {
88                while let Ok((event, payload)) = messages.recv().await {
89                    handler(WireEvent {
90                        r#type: event,
91                        data: payload,
92                        session_id: session_id.clone(),
93                        channel_name: channel_name.clone(),
94                    });
95                }
96            }),
97        }
98    }
99
100    pub fn on<F>(&self, event_name: impl Into<String>, handler: F) -> Listener
101    where
102        F: Fn(EventData) + Send + Sync + 'static,
103    {
104        let event_name = event_name.into();
105        let mut messages = self.channel.subscribe_messages();
106        Listener {
107            handle: tokio::spawn(async move {
108                while let Ok((event, payload)) = messages.recv().await {
109                    if event == event_name {
110                        handler(payload);
111                    }
112                }
113            }),
114        }
115    }
116
117    pub fn on_session_attached<F>(&self, handler: F) -> Listener
118    where
119        F: Fn(SessionAttachedEvent) + Send + Sync + 'static,
120    {
121        let session_id = self.session_id.clone();
122        let channel_name = self.channel_name.clone();
123        self.on(EVENT_RTC_SESSION_ATTACHED, move |payload| {
124            handler(SessionAttachedEvent {
125                session_id: base_session_id(&payload, &session_id),
126                channel_name: channel_name.clone(),
127                data: payload,
128            })
129        })
130    }
131
132    pub fn on_session_created<F>(&self, handler: F) -> Listener
133    where
134        F: Fn(SessionCreatedEvent) + Send + Sync + 'static,
135    {
136        let session_id = self.session_id.clone();
137        let channel_name = self.channel_name.clone();
138        self.on(EVENT_SESSION_CREATED, move |payload| {
139            let session = payload.get("session").and_then(Value::as_object).cloned();
140            handler(SessionCreatedEvent {
141                session_id: base_session_id(&payload, &session_id),
142                channel_name: channel_name.clone(),
143                data: payload,
144                session,
145            });
146        })
147    }
148
149    pub fn on_transcript<F>(&self, handler: F) -> Listener
150    where
151        F: Fn(TranscriptEvent) + Send + Sync + 'static,
152    {
153        let session_id = self.session_id.clone();
154        let channel_name = self.channel_name.clone();
155        self.on(EVENT_TRANSCRIPT_COMPLETED, move |payload| {
156            handler(TranscriptEvent {
157                session_id: base_session_id(&payload, &session_id),
158                channel_name: channel_name.clone(),
159                transcript: required_string(&payload, "transcript", ""),
160                language: optional_string(&payload, "language"),
161                start_ms: optional_number(&payload, "start_ms"),
162                end_ms: optional_number(&payload, "end_ms"),
163                eou_probability: optional_number(&payload, "eou_probability"),
164                topics: optional_string_vec(&payload, "topics"),
165                data: payload,
166            });
167        })
168    }
169
170    pub fn on_turn_state_changed<F>(&self, handler: F) -> Listener
171    where
172        F: Fn(TurnStateEvent) + Send + Sync + 'static,
173    {
174        let session_id = self.session_id.clone();
175        let channel_name = self.channel_name.clone();
176        self.on(EVENT_TURN_STATE_CHANGED, move |payload| {
177            handler(TurnStateEvent {
178                session_id: base_session_id(&payload, &session_id),
179                channel_name: channel_name.clone(),
180                state: required_string(&payload, "state", "unknown"),
181                previous_state: optional_string(&payload, "previous_state"),
182                data: payload,
183            });
184        })
185    }
186
187    pub fn on_response_created<F>(&self, handler: F) -> Listener
188    where
189        F: Fn(ResponseEvent) + Send + Sync + 'static,
190    {
191        self.on_response_event(EVENT_RESPONSE_CREATED, handler)
192    }
193
194    pub fn on_response_committed<F>(&self, handler: F) -> Listener
195    where
196        F: Fn(ResponseEvent) + Send + Sync + 'static,
197    {
198        self.on_response_event(EVENT_RESPONSE_COMMITTED, handler)
199    }
200
201    pub fn on_response_done<F>(&self, handler: F) -> Listener
202    where
203        F: Fn(ResponseEvent) + Send + Sync + 'static,
204    {
205        self.on_response_event(EVENT_RESPONSE_DONE, handler)
206    }
207
208    pub fn on_response_cancelled<F>(&self, handler: F) -> Listener
209    where
210        F: Fn(ResponseEvent) + Send + Sync + 'static,
211    {
212        self.on_response_event(EVENT_RESPONSE_CANCELLED, handler)
213    }
214
215    pub fn on_response_audio_clear<F>(&self, handler: F) -> Listener
216    where
217        F: Fn(ResponseEvent) + Send + Sync + 'static,
218    {
219        self.on_response_event(EVENT_RESPONSE_AUDIO_CLEAR, handler)
220    }
221
222    fn on_response_event<F>(&self, event_name: &'static str, handler: F) -> Listener
223    where
224        F: Fn(ResponseEvent) + Send + Sync + 'static,
225    {
226        let session_id = self.session_id.clone();
227        let channel_name = self.channel_name.clone();
228        self.on(event_name, move |payload| {
229            handler(response_event(payload, &session_id, &channel_name));
230        })
231    }
232
233    pub fn on_interruption_detected<F>(&self, handler: F) -> Listener
234    where
235        F: Fn(InterruptionEvent) + Send + Sync + 'static,
236    {
237        self.on_interruption_event(EVENT_INTERRUPTION_DETECTED, handler)
238    }
239
240    pub fn on_interruption_false_positive<F>(&self, handler: F) -> Listener
241    where
242        F: Fn(InterruptionEvent) + Send + Sync + 'static,
243    {
244        self.on_interruption_event(EVENT_INTERRUPTION_FALSE_POSITIVE, handler)
245    }
246
247    fn on_interruption_event<F>(&self, event_name: &'static str, handler: F) -> Listener
248    where
249        F: Fn(InterruptionEvent) + Send + Sync + 'static,
250    {
251        let session_id = self.session_id.clone();
252        let channel_name = self.channel_name.clone();
253        self.on(event_name, move |payload| {
254            handler(InterruptionEvent {
255                response: response_event(payload.clone(), &session_id, &channel_name),
256                vad_active_ms: optional_number(&payload, "vad_active_ms"),
257                partial_transcript: optional_string(&payload, "partial_transcript"),
258            });
259        })
260    }
261
262    pub fn on_browser_event<F>(&self, handler: F) -> Listener
263    where
264        F: Fn(BrowserEvent) + Send + Sync + 'static,
265    {
266        let session_id = self.session_id.clone();
267        let channel_name = self.channel_name.clone();
268        self.on(EVENT_BROWSER_EVENT, move |payload| {
269            handler(BrowserEvent {
270                session_id: base_session_id(&payload, &session_id),
271                channel_name: channel_name.clone(),
272                event: required_string(&payload, "event", ""),
273                payload: payload.get("payload").cloned().unwrap_or(Value::Null),
274                data: payload,
275            });
276        })
277    }
278
279    pub fn on_close<F>(&self, handler: F) -> Listener
280    where
281        F: Fn(CloseEvent) + Send + Sync + 'static,
282    {
283        let session_id = self.session_id.clone();
284        let channel_name = self.channel_name.clone();
285        self.on(EVENT_RTC_CLIENT_DISCONNECTED, move |payload| {
286            handler(CloseEvent {
287                session_id: base_session_id(&payload, &session_id),
288                channel_name: channel_name.clone(),
289                reason: required_string(&payload, "reason", "unknown"),
290                connection_state: optional_string(&payload, "connection_state"),
291                ice_connection_state: optional_string(&payload, "ice_connection_state"),
292                data_channel_state: optional_string(&payload, "data_channel_state"),
293                data: payload,
294            });
295        })
296    }
297
298    pub fn on_error<F>(&self, handler: F) -> Listener
299    where
300        F: Fn(ErrorEvent) + Send + Sync + 'static,
301    {
302        let session_id = self.session_id.clone();
303        let channel_name = self.channel_name.clone();
304        self.on(EVENT_ERROR, move |payload| {
305            handler(ErrorEvent {
306                session_id: base_session_id(&payload, &session_id),
307                channel_name: channel_name.clone(),
308                message: optional_string(&payload, "message"),
309                code: optional_string(&payload, "code"),
310                data: payload,
311            });
312        })
313    }
314
315    pub async fn send_control(&self, event: &str, payload: EventData) -> Result<()> {
316        self.channel.send_message(event, payload).await
317    }
318
319    pub async fn configure(&self, config: SessionConfig) -> Result<()> {
320        let mut session = config.extra;
321        insert_opt(&mut session, "stt_model", config.stt_model);
322        insert_opt(&mut session, "tts_model", config.tts_model);
323        insert_opt(&mut session, "voice", config.voice);
324        insert_opt(&mut session, "turn_profile", config.turn_profile);
325        insert_opt(&mut session, "vad_backend", config.vad_backend);
326        insert_opt(&mut session, "turn_detector", config.turn_detector);
327
328        let mut payload = EventData::new();
329        payload.insert("session".to_owned(), Value::Object(session));
330        self.send_control("session.update", payload).await
331    }
332
333    pub async fn start_response(&self, options: Option<ResponseOptions>) -> Result<()> {
334        self.send_control("response.start", response_options_payload(options))
335            .await
336    }
337
338    pub async fn append_response_text(
339        &self,
340        delta: impl Into<String>,
341        options: Option<ResponseOptions>,
342    ) -> Result<()> {
343        let mut payload = response_options_payload(options);
344        payload.insert("delta".to_owned(), Value::String(delta.into()));
345        self.send_control("response.delta", payload).await
346    }
347
348    pub async fn commit_response(&self) -> Result<()> {
349        self.send_control("response.commit", EventData::new()).await
350    }
351
352    pub async fn cancel_response(&self) -> Result<()> {
353        self.send_control("response.cancel", EventData::new()).await
354    }
355
356    pub async fn replace_response_text(
357        &self,
358        text: impl Into<String>,
359        options: Option<ResponseOptions>,
360    ) -> Result<()> {
361        let mut payload = response_options_payload(options);
362        payload.insert("text".to_owned(), Value::String(text.into()));
363        self.send_control("response.replace_text", payload).await
364    }
365
366    pub async fn send_text_response(
367        &self,
368        text: impl Into<String>,
369        options: Option<ResponseOptions>,
370        cancel_first: bool,
371    ) -> Result<()> {
372        let text = text.into();
373        if cancel_first {
374            return self.replace_response_text(text, options).await;
375        }
376        self.start_response(options.clone()).await?;
377        self.append_response_text(text, options).await?;
378        self.commit_response().await
379    }
380
381    pub async fn send_client_event(&self, envelope: ClientEventEnvelope) -> Result<()> {
382        let mut payload = EventData::new();
383        payload.insert("event".to_owned(), Value::String(envelope.event));
384        payload.insert("payload".to_owned(), envelope.payload);
385        self.send_control(EVENT_CLIENT_EVENT, payload).await
386    }
387}
388
389fn insert_opt(session: &mut EventData, key: &str, value: Option<String>) {
390    if let Some(value) = value {
391        session.insert(key.to_owned(), Value::String(value));
392    }
393}
394
395fn response_options_payload(options: Option<ResponseOptions>) -> EventData {
396    let mut payload = EventData::new();
397    if let Some(options) = options
398        && let Some(allow) = options.allow_interruptions
399    {
400        payload.insert("allow_interruptions".to_owned(), Value::Bool(allow));
401    }
402    payload
403}
404
405fn base_session_id(payload: &EventData, fallback: &str) -> String {
406    required_string(payload, "session_id", fallback)
407}
408
409fn response_event(payload: EventData, session_id: &str, channel_name: &str) -> ResponseEvent {
410    ResponseEvent {
411        session_id: base_session_id(&payload, session_id),
412        channel_name: channel_name.to_owned(),
413        response_id: optional_string(&payload, "response_id"),
414        data: payload,
415    }
416}