1use crate::error::{Result, VoxRtcError};
2use crate::socket::RawSocketChannel;
3use crate::types::*;
4use serde_json::Value;
5use tokio::task::JoinHandle;
6use tokio::time::{Duration, timeout};
7
8#[derive(Clone)]
9pub struct VoxRtcControlSession {
10 channel: RawSocketChannel,
11 session_id: String,
12 channel_name: String,
13 join_timeout: Duration,
14}
15
16pub struct Listener {
17 handle: JoinHandle<()>,
18}
19
20impl Drop for Listener {
21 fn drop(&mut self) {
22 self.handle.abort();
23 }
24}
25
26impl VoxRtcControlSession {
27 pub(crate) fn new(
28 channel: RawSocketChannel,
29 session_id: String,
30 join_timeout: Duration,
31 ) -> Self {
32 let channel_name = format!("/rtc/{session_id}");
33 Self {
34 channel,
35 session_id,
36 channel_name,
37 join_timeout,
38 }
39 }
40
41 pub fn session_id(&self) -> &str {
42 &self.session_id
43 }
44
45 pub fn channel_name(&self) -> &str {
46 &self.channel_name
47 }
48
49 pub async fn join(&self) -> Result<()> {
50 let mut states = self.channel.subscribe_state();
51 self.channel.join().await?;
52 let channel_name = self.channel.name().to_owned();
53 timeout(self.join_timeout, async move {
54 loop {
55 let state = *states.borrow_and_update();
56 match state {
57 ChannelState::Joined => return Ok(()),
58 ChannelState::Closed | ChannelState::Declined => {
59 return Err(VoxRtcError::JoinFailed {
60 channel: channel_name,
61 state: format!("{state:?}"),
62 });
63 }
64 _ => {}
65 }
66 if states.changed().await.is_err() {
67 return Err(VoxRtcError::Disconnected);
68 }
69 }
70 })
71 .await
72 .map_err(|_| VoxRtcError::JoinTimeout(self.channel_name.clone()))?
73 }
74
75 pub async fn close(&self) -> Result<()> {
76 self.channel.leave().await
77 }
78
79 pub fn on_event<F>(&self, handler: F) -> Listener
80 where
81 F: Fn(WireEvent) + Send + Sync + 'static,
82 {
83 let mut messages = self.channel.subscribe_messages();
84 let session_id = self.session_id.clone();
85 let channel_name = self.channel_name.clone();
86 Listener {
87 handle: tokio::spawn(async move {
88 while let Ok((event, payload)) = messages.recv().await {
89 handler(WireEvent {
90 r#type: event,
91 data: payload,
92 session_id: session_id.clone(),
93 channel_name: channel_name.clone(),
94 });
95 }
96 }),
97 }
98 }
99
100 pub fn on<F>(&self, event_name: impl Into<String>, handler: F) -> Listener
101 where
102 F: Fn(EventData) + Send + Sync + 'static,
103 {
104 let event_name = event_name.into();
105 let mut messages = self.channel.subscribe_messages();
106 Listener {
107 handle: tokio::spawn(async move {
108 while let Ok((event, payload)) = messages.recv().await {
109 if event == event_name {
110 handler(payload);
111 }
112 }
113 }),
114 }
115 }
116
117 pub fn on_session_attached<F>(&self, handler: F) -> Listener
118 where
119 F: Fn(SessionAttachedEvent) + Send + Sync + 'static,
120 {
121 let session_id = self.session_id.clone();
122 let channel_name = self.channel_name.clone();
123 self.on(EVENT_RTC_SESSION_ATTACHED, move |payload| {
124 handler(SessionAttachedEvent {
125 session_id: base_session_id(&payload, &session_id),
126 channel_name: channel_name.clone(),
127 data: payload,
128 })
129 })
130 }
131
132 pub fn on_session_created<F>(&self, handler: F) -> Listener
133 where
134 F: Fn(SessionCreatedEvent) + Send + Sync + 'static,
135 {
136 let session_id = self.session_id.clone();
137 let channel_name = self.channel_name.clone();
138 self.on(EVENT_SESSION_CREATED, move |payload| {
139 let session = payload.get("session").and_then(Value::as_object).cloned();
140 handler(SessionCreatedEvent {
141 session_id: base_session_id(&payload, &session_id),
142 channel_name: channel_name.clone(),
143 data: payload,
144 session,
145 });
146 })
147 }
148
149 pub fn on_transcript<F>(&self, handler: F) -> Listener
150 where
151 F: Fn(TranscriptEvent) + Send + Sync + 'static,
152 {
153 let session_id = self.session_id.clone();
154 let channel_name = self.channel_name.clone();
155 self.on(EVENT_TRANSCRIPT_COMPLETED, move |payload| {
156 handler(TranscriptEvent {
157 session_id: base_session_id(&payload, &session_id),
158 channel_name: channel_name.clone(),
159 transcript: required_string(&payload, "transcript", ""),
160 language: optional_string(&payload, "language"),
161 start_ms: optional_number(&payload, "start_ms"),
162 end_ms: optional_number(&payload, "end_ms"),
163 eou_probability: optional_number(&payload, "eou_probability"),
164 topics: optional_string_vec(&payload, "topics"),
165 data: payload,
166 });
167 })
168 }
169
170 pub fn on_turn_state_changed<F>(&self, handler: F) -> Listener
171 where
172 F: Fn(TurnStateEvent) + Send + Sync + 'static,
173 {
174 let session_id = self.session_id.clone();
175 let channel_name = self.channel_name.clone();
176 self.on(EVENT_TURN_STATE_CHANGED, move |payload| {
177 handler(TurnStateEvent {
178 session_id: base_session_id(&payload, &session_id),
179 channel_name: channel_name.clone(),
180 state: required_string(&payload, "state", "unknown"),
181 previous_state: optional_string(&payload, "previous_state"),
182 data: payload,
183 });
184 })
185 }
186
187 pub fn on_response_created<F>(&self, handler: F) -> Listener
188 where
189 F: Fn(ResponseEvent) + Send + Sync + 'static,
190 {
191 self.on_response_event(EVENT_RESPONSE_CREATED, handler)
192 }
193
194 pub fn on_response_committed<F>(&self, handler: F) -> Listener
195 where
196 F: Fn(ResponseEvent) + Send + Sync + 'static,
197 {
198 self.on_response_event(EVENT_RESPONSE_COMMITTED, handler)
199 }
200
201 pub fn on_response_done<F>(&self, handler: F) -> Listener
202 where
203 F: Fn(ResponseEvent) + Send + Sync + 'static,
204 {
205 self.on_response_event(EVENT_RESPONSE_DONE, handler)
206 }
207
208 pub fn on_response_cancelled<F>(&self, handler: F) -> Listener
209 where
210 F: Fn(ResponseEvent) + Send + Sync + 'static,
211 {
212 self.on_response_event(EVENT_RESPONSE_CANCELLED, handler)
213 }
214
215 pub fn on_response_audio_clear<F>(&self, handler: F) -> Listener
216 where
217 F: Fn(ResponseEvent) + Send + Sync + 'static,
218 {
219 self.on_response_event(EVENT_RESPONSE_AUDIO_CLEAR, handler)
220 }
221
222 fn on_response_event<F>(&self, event_name: &'static str, handler: F) -> Listener
223 where
224 F: Fn(ResponseEvent) + Send + Sync + 'static,
225 {
226 let session_id = self.session_id.clone();
227 let channel_name = self.channel_name.clone();
228 self.on(event_name, move |payload| {
229 handler(response_event(payload, &session_id, &channel_name));
230 })
231 }
232
233 pub fn on_interruption_detected<F>(&self, handler: F) -> Listener
234 where
235 F: Fn(InterruptionEvent) + Send + Sync + 'static,
236 {
237 self.on_interruption_event(EVENT_INTERRUPTION_DETECTED, handler)
238 }
239
240 pub fn on_interruption_false_positive<F>(&self, handler: F) -> Listener
241 where
242 F: Fn(InterruptionEvent) + Send + Sync + 'static,
243 {
244 self.on_interruption_event(EVENT_INTERRUPTION_FALSE_POSITIVE, handler)
245 }
246
247 fn on_interruption_event<F>(&self, event_name: &'static str, handler: F) -> Listener
248 where
249 F: Fn(InterruptionEvent) + Send + Sync + 'static,
250 {
251 let session_id = self.session_id.clone();
252 let channel_name = self.channel_name.clone();
253 self.on(event_name, move |payload| {
254 handler(InterruptionEvent {
255 response: response_event(payload.clone(), &session_id, &channel_name),
256 vad_active_ms: optional_number(&payload, "vad_active_ms"),
257 partial_transcript: optional_string(&payload, "partial_transcript"),
258 });
259 })
260 }
261
262 pub fn on_browser_event<F>(&self, handler: F) -> Listener
263 where
264 F: Fn(BrowserEvent) + Send + Sync + 'static,
265 {
266 let session_id = self.session_id.clone();
267 let channel_name = self.channel_name.clone();
268 self.on(EVENT_BROWSER_EVENT, move |payload| {
269 handler(BrowserEvent {
270 session_id: base_session_id(&payload, &session_id),
271 channel_name: channel_name.clone(),
272 event: required_string(&payload, "event", ""),
273 payload: payload.get("payload").cloned().unwrap_or(Value::Null),
274 data: payload,
275 });
276 })
277 }
278
279 pub fn on_close<F>(&self, handler: F) -> Listener
280 where
281 F: Fn(CloseEvent) + Send + Sync + 'static,
282 {
283 let session_id = self.session_id.clone();
284 let channel_name = self.channel_name.clone();
285 self.on(EVENT_RTC_CLIENT_DISCONNECTED, move |payload| {
286 handler(CloseEvent {
287 session_id: base_session_id(&payload, &session_id),
288 channel_name: channel_name.clone(),
289 reason: required_string(&payload, "reason", "unknown"),
290 connection_state: optional_string(&payload, "connection_state"),
291 ice_connection_state: optional_string(&payload, "ice_connection_state"),
292 data_channel_state: optional_string(&payload, "data_channel_state"),
293 data: payload,
294 });
295 })
296 }
297
298 pub fn on_error<F>(&self, handler: F) -> Listener
299 where
300 F: Fn(ErrorEvent) + Send + Sync + 'static,
301 {
302 let session_id = self.session_id.clone();
303 let channel_name = self.channel_name.clone();
304 self.on(EVENT_ERROR, move |payload| {
305 handler(ErrorEvent {
306 session_id: base_session_id(&payload, &session_id),
307 channel_name: channel_name.clone(),
308 message: optional_string(&payload, "message"),
309 code: optional_string(&payload, "code"),
310 data: payload,
311 });
312 })
313 }
314
315 pub async fn send_control(&self, event: &str, payload: EventData) -> Result<()> {
316 self.channel.send_message(event, payload).await
317 }
318
319 pub async fn configure(&self, config: SessionConfig) -> Result<()> {
320 let mut session = config.extra;
321 insert_opt(&mut session, "stt_model", config.stt_model);
322 insert_opt(&mut session, "tts_model", config.tts_model);
323 insert_opt(&mut session, "voice", config.voice);
324 insert_opt(&mut session, "turn_profile", config.turn_profile);
325 insert_opt(&mut session, "vad_backend", config.vad_backend);
326 insert_opt(&mut session, "turn_detector", config.turn_detector);
327
328 let mut payload = EventData::new();
329 payload.insert("session".to_owned(), Value::Object(session));
330 self.send_control("session.update", payload).await
331 }
332
333 pub async fn start_response(&self, options: Option<ResponseOptions>) -> Result<()> {
334 self.send_control("response.start", response_options_payload(options))
335 .await
336 }
337
338 pub async fn append_response_text(
339 &self,
340 delta: impl Into<String>,
341 options: Option<ResponseOptions>,
342 ) -> Result<()> {
343 let mut payload = response_options_payload(options);
344 payload.insert("delta".to_owned(), Value::String(delta.into()));
345 self.send_control("response.delta", payload).await
346 }
347
348 pub async fn commit_response(&self) -> Result<()> {
349 self.send_control("response.commit", EventData::new()).await
350 }
351
352 pub async fn cancel_response(&self) -> Result<()> {
353 self.send_control("response.cancel", EventData::new()).await
354 }
355
356 pub async fn replace_response_text(
357 &self,
358 text: impl Into<String>,
359 options: Option<ResponseOptions>,
360 ) -> Result<()> {
361 let mut payload = response_options_payload(options);
362 payload.insert("text".to_owned(), Value::String(text.into()));
363 self.send_control("response.replace_text", payload).await
364 }
365
366 pub async fn send_text_response(
367 &self,
368 text: impl Into<String>,
369 options: Option<ResponseOptions>,
370 cancel_first: bool,
371 ) -> Result<()> {
372 let text = text.into();
373 if cancel_first {
374 return self.replace_response_text(text, options).await;
375 }
376 self.start_response(options.clone()).await?;
377 self.append_response_text(text, options).await?;
378 self.commit_response().await
379 }
380
381 pub async fn send_client_event(&self, envelope: ClientEventEnvelope) -> Result<()> {
382 let mut payload = EventData::new();
383 payload.insert("event".to_owned(), Value::String(envelope.event));
384 payload.insert("payload".to_owned(), envelope.payload);
385 self.send_control(EVENT_CLIENT_EVENT, payload).await
386 }
387}
388
389fn insert_opt(session: &mut EventData, key: &str, value: Option<String>) {
390 if let Some(value) = value {
391 session.insert(key.to_owned(), Value::String(value));
392 }
393}
394
395fn response_options_payload(options: Option<ResponseOptions>) -> EventData {
396 let mut payload = EventData::new();
397 if let Some(options) = options
398 && let Some(allow) = options.allow_interruptions
399 {
400 payload.insert("allow_interruptions".to_owned(), Value::Bool(allow));
401 }
402 payload
403}
404
405fn base_session_id(payload: &EventData, fallback: &str) -> String {
406 required_string(payload, "session_id", fallback)
407}
408
409fn response_event(payload: EventData, session_id: &str, channel_name: &str) -> ResponseEvent {
410 ResponseEvent {
411 session_id: base_session_id(&payload, session_id),
412 channel_name: channel_name.to_owned(),
413 response_id: optional_string(&payload, "response_id"),
414 data: payload,
415 }
416}