1use crate::types::{
10 CAudioFormat, CAudioFrame, CCustomEncoding, CCustomPacket, CPacket, CPacketMetadata,
11 CPacketType, CPacketTypeInfo, CSampleFormat,
12};
13use std::cell::RefCell;
14use std::ffi::{c_void, CStr, CString};
15use std::os::raw::c_char;
16use std::sync::Arc;
17use streamkit_core::types::{
18 AudioFormat, AudioFrame, CustomEncoding, CustomPacketData, Packet, PacketMetadata, PacketType,
19 SampleFormat, TranscriptionData,
20};
21
22pub fn packet_type_from_c(cpt_info: CPacketTypeInfo) -> Result<PacketType, String> {
31 match cpt_info.type_discriminant {
32 CPacketType::RawAudio => {
33 if cpt_info.audio_format.is_null() {
34 return Err("RawAudio packet type missing audio_format".to_string());
35 }
36 let c_format = unsafe { &*cpt_info.audio_format };
38 Ok(PacketType::RawAudio(audio_format_from_c(c_format)))
39 },
40 CPacketType::OpusAudio => Ok(PacketType::OpusAudio),
41 CPacketType::Text => Ok(PacketType::Text),
42 CPacketType::Transcription => Ok(PacketType::Transcription),
43 CPacketType::Custom => {
44 if cpt_info.custom_type_id.is_null() {
45 return Err("Custom packet type missing custom_type_id".to_string());
46 }
47 let type_id = unsafe { c_str_to_string(cpt_info.custom_type_id) }?;
48 Ok(PacketType::Custom { type_id })
49 },
50 CPacketType::Binary => Ok(PacketType::Binary),
51 CPacketType::Any => Ok(PacketType::Any),
52 CPacketType::Passthrough => Ok(PacketType::Passthrough),
53 }
54}
55
56pub const fn sample_format_to_c(sf: &SampleFormat) -> CSampleFormat {
58 match sf {
59 SampleFormat::F32 => CSampleFormat::F32,
60 SampleFormat::S16Le => CSampleFormat::S16Le,
61 }
62}
63
64pub const fn sample_format_from_c(csf: CSampleFormat) -> SampleFormat {
66 match csf {
67 CSampleFormat::F32 => SampleFormat::F32,
68 CSampleFormat::S16Le => SampleFormat::S16Le,
69 }
70}
71
72pub const fn audio_format_to_c(af: &AudioFormat) -> CAudioFormat {
74 CAudioFormat {
75 sample_rate: af.sample_rate,
76 channels: af.channels,
77 sample_format: sample_format_to_c(&af.sample_format),
78 }
79}
80
81pub const fn audio_format_from_c(caf: &CAudioFormat) -> AudioFormat {
83 AudioFormat {
84 sample_rate: caf.sample_rate,
85 channels: caf.channels,
86 sample_format: sample_format_from_c(caf.sample_format),
87 }
88}
89
90pub const fn packet_type_to_c(pt: &PacketType) -> (CPacketTypeInfo, Option<CAudioFormat>) {
94 match pt {
95 PacketType::RawAudio(format) => {
96 let c_format = audio_format_to_c(format);
97 (
98 CPacketTypeInfo {
99 type_discriminant: CPacketType::RawAudio,
100 audio_format: &raw const c_format,
101 custom_type_id: std::ptr::null(),
102 },
103 Some(c_format),
104 )
105 },
106 PacketType::OpusAudio => (
107 CPacketTypeInfo {
108 type_discriminant: CPacketType::OpusAudio,
109 audio_format: std::ptr::null(),
110 custom_type_id: std::ptr::null(),
111 },
112 None,
113 ),
114 PacketType::Text => (
115 CPacketTypeInfo {
116 type_discriminant: CPacketType::Text,
117 audio_format: std::ptr::null(),
118 custom_type_id: std::ptr::null(),
119 },
120 None,
121 ),
122 PacketType::Transcription => (
123 CPacketTypeInfo {
124 type_discriminant: CPacketType::Transcription,
125 audio_format: std::ptr::null(),
126 custom_type_id: std::ptr::null(),
127 },
128 None,
129 ),
130 PacketType::Custom { .. } => (
131 CPacketTypeInfo {
132 type_discriminant: CPacketType::Custom,
133 audio_format: std::ptr::null(),
134 custom_type_id: std::ptr::null(), },
136 None,
137 ),
138 PacketType::Binary => (
139 CPacketTypeInfo {
140 type_discriminant: CPacketType::Binary,
141 audio_format: std::ptr::null(),
142 custom_type_id: std::ptr::null(),
143 },
144 None,
145 ),
146 PacketType::Any => (
147 CPacketTypeInfo {
148 type_discriminant: CPacketType::Any,
149 audio_format: std::ptr::null(),
150 custom_type_id: std::ptr::null(),
151 },
152 None,
153 ),
154 PacketType::Passthrough => (
155 CPacketTypeInfo {
156 type_discriminant: CPacketType::Passthrough,
157 audio_format: std::ptr::null(),
158 custom_type_id: std::ptr::null(),
159 },
160 None,
161 ),
162 }
163}
164
165pub struct CPacketRepr {
166 pub packet: CPacket,
167 _owned: CPacketOwned,
168}
169
170#[allow(dead_code)] enum CPacketOwned {
172 None,
173 Audio(Box<CAudioFrame>),
174 Text(CString),
175 Bytes(Vec<u8>),
176 Custom(CustomOwned),
177}
178
179#[allow(dead_code)] struct CustomOwned {
181 type_id: CString,
182 data_json: Vec<u8>,
183 metadata: Option<Box<CPacketMetadata>>,
184 custom: Box<CCustomPacket>,
185}
186
187fn metadata_to_c(meta: &PacketMetadata) -> CPacketMetadata {
188 CPacketMetadata {
189 timestamp_us: meta.timestamp_us.unwrap_or_default(),
190 has_timestamp_us: meta.timestamp_us.is_some(),
191 duration_us: meta.duration_us.unwrap_or_default(),
192 has_duration_us: meta.duration_us.is_some(),
193 sequence: meta.sequence.unwrap_or_default(),
194 has_sequence: meta.sequence.is_some(),
195 }
196}
197
198fn metadata_from_c(meta: &CPacketMetadata) -> PacketMetadata {
199 PacketMetadata {
200 timestamp_us: meta.has_timestamp_us.then_some(meta.timestamp_us),
201 duration_us: meta.has_duration_us.then_some(meta.duration_us),
202 sequence: meta.has_sequence.then_some(meta.sequence),
203 }
204}
205
/// Builds a `CString` from `s`, replacing every interior NUL byte with a space.
///
/// Once the NULs are replaced the conversion cannot fail. The `unwrap_or_default`
/// fallback only exists to keep this function free of a panic path.
fn cstring_sanitize(s: &str) -> CString {
    match CString::new(s) {
        Ok(clean) => clean,
        Err(_) => {
            let spaced = s.replace('\0', " ");
            CString::new(spaced).unwrap_or_default()
        },
    }
}
209
210pub fn packet_to_c(packet: &Packet) -> CPacketRepr {
214 match packet {
215 Packet::Audio(frame) => {
216 let c_frame = Box::new(CAudioFrame {
217 sample_rate: frame.sample_rate,
218 channels: frame.channels,
219 samples: frame.samples.as_ptr(),
220 sample_count: frame.samples.len(),
221 });
222 let packet = CPacket {
223 packet_type: CPacketType::RawAudio,
224 data: std::ptr::from_ref::<CAudioFrame>(&*c_frame).cast::<c_void>(),
225 len: std::mem::size_of::<CAudioFrame>(),
226 };
227 CPacketRepr { packet, _owned: CPacketOwned::Audio(c_frame) }
228 },
229 Packet::Text(text) => {
230 let s = text.as_ref();
231 let c_str = match CString::new(s) {
232 Ok(s) => s,
233 Err(e) => {
234 tracing::warn!(
235 "Text packet contains null bytes (position {}), data will be truncated",
236 e.nul_position()
237 );
238 let truncated = &s[..e.nul_position()];
239 CString::new(truncated).unwrap_or_default()
240 },
241 };
242 let packet = CPacket {
243 packet_type: CPacketType::Text,
244 data: c_str.as_ptr().cast::<c_void>(),
245 len: c_str.as_bytes_with_nul().len(),
246 };
247 CPacketRepr { packet, _owned: CPacketOwned::Text(c_str) }
248 },
249 Packet::Transcription(trans_data) => {
250 let json = serde_json::to_vec(trans_data).unwrap_or_else(|e| {
251 tracing::error!("Failed to serialize transcription data to JSON: {}", e);
252 b"{}".to_vec()
253 });
254 let packet = CPacket {
255 packet_type: CPacketType::Transcription,
256 data: json.as_ptr().cast::<c_void>(),
257 len: json.len(),
258 };
259 CPacketRepr { packet, _owned: CPacketOwned::Bytes(json) }
260 },
261 Packet::Custom(custom) => {
262 let type_id = cstring_sanitize(custom.type_id.as_str());
263 let data_json = serde_json::to_vec(&custom.data).unwrap_or_else(|e| {
264 tracing::error!("Failed to serialize custom packet data to JSON: {}", e);
265 b"{}".to_vec()
266 });
267
268 let metadata = custom.metadata.as_ref().map(|m| Box::new(metadata_to_c(m)));
269 let mut custom_packet = Box::new(CCustomPacket {
270 type_id: type_id.as_ptr(),
271 encoding: match custom.encoding {
272 CustomEncoding::Json => CCustomEncoding::Json,
273 },
274 data_json: data_json.as_ptr(),
275 data_len: data_json.len(),
276 metadata: metadata.as_deref().map_or(std::ptr::null(), std::ptr::from_ref),
277 });
278
279 let packet = CPacket {
280 packet_type: CPacketType::Custom,
281 data: std::ptr::from_mut::<CCustomPacket>(&mut *custom_packet).cast::<c_void>(),
282 len: std::mem::size_of::<CCustomPacket>(),
283 };
284
285 CPacketRepr {
286 packet,
287 _owned: CPacketOwned::Custom(CustomOwned {
288 type_id,
289 data_json,
290 metadata,
291 custom: custom_packet,
292 }),
293 }
294 },
295 Packet::Binary { data, .. } => CPacketRepr {
296 packet: CPacket {
297 packet_type: CPacketType::Binary,
298 data: data.as_ref().as_ptr().cast::<c_void>(),
299 len: data.len(),
300 },
301 _owned: CPacketOwned::None,
302 },
303 }
304}
305
306pub unsafe fn packet_from_c(c_packet: *const CPacket) -> Result<Packet, String> {
323 if c_packet.is_null() {
324 return Err("Null packet pointer".to_string());
325 }
326
327 let c_pkt = &*c_packet;
328
329 if c_pkt.data.is_null() {
330 return Err("Null packet data pointer".to_string());
331 }
332
333 match c_pkt.packet_type {
334 CPacketType::RawAudio => {
335 let c_frame = &*c_pkt.data.cast::<CAudioFrame>();
336 if c_frame.samples.is_null() {
337 return Err("Null samples pointer in audio frame".to_string());
338 }
339
340 let samples = std::slice::from_raw_parts(c_frame.samples, c_frame.sample_count);
341
342 Ok(Packet::Audio(AudioFrame::new(
343 c_frame.sample_rate,
344 c_frame.channels,
345 samples.to_vec(),
346 )))
347 },
348 CPacketType::Text => {
349 let c_str = CStr::from_ptr(c_pkt.data.cast::<c_char>());
350 let text = c_str
351 .to_str()
352 .map_err(|e| format!("Invalid UTF-8 in text packet: {e}"))?
353 .to_string();
354 Ok(Packet::Text(text.into()))
355 },
356 CPacketType::Transcription => {
357 let data = std::slice::from_raw_parts(c_pkt.data.cast::<u8>(), c_pkt.len);
359 let trans_data: TranscriptionData = serde_json::from_slice(data)
360 .map_err(|e| format!("Invalid transcription data: {e}"))?;
361 Ok(Packet::Transcription(Arc::new(trans_data)))
362 },
363 CPacketType::Custom => {
364 let c_custom = &*c_pkt.data.cast::<CCustomPacket>();
365 if c_custom.type_id.is_null() {
366 return Err("Custom packet missing type_id".to_string());
367 }
368 if c_custom.data_json.is_null() {
369 return Err("Custom packet missing data_json".to_string());
370 }
371
372 let type_id = c_str_to_string(c_custom.type_id)?;
373 let data_bytes = std::slice::from_raw_parts(c_custom.data_json, c_custom.data_len);
374 let data: serde_json::Value = serde_json::from_slice(data_bytes)
375 .map_err(|e| format!("Invalid custom JSON: {e}"))?;
376
377 let metadata = if c_custom.metadata.is_null() {
378 None
379 } else {
380 Some(metadata_from_c(&*c_custom.metadata))
381 };
382
383 let encoding = match c_custom.encoding {
384 CCustomEncoding::Json => CustomEncoding::Json,
385 };
386
387 Ok(Packet::Custom(Arc::new(CustomPacketData { type_id, encoding, data, metadata })))
388 },
389 CPacketType::Binary => {
390 let data = std::slice::from_raw_parts(c_pkt.data.cast::<u8>(), c_pkt.len);
391 Ok(Packet::Binary {
392 data: bytes::Bytes::copy_from_slice(data),
393 content_type: None,
394 metadata: None,
395 })
396 },
397 _ => Err(format!("Unsupported packet type: {:?}", c_pkt.packet_type)),
398 }
399}
400
/// Copies a NUL-terminated C string into an owned Rust `String`.
///
/// A null pointer is treated as the empty string rather than as an error.
///
/// # Safety
///
/// `ptr` must be null or point to a NUL-terminated byte sequence that stays valid and
/// unmodified for the duration of the call.
///
/// # Errors
///
/// Returns `"Invalid UTF-8: …"` when the bytes before the terminator are not valid UTF-8.
pub unsafe fn c_str_to_string(ptr: *const c_char) -> Result<String, String> {
    if ptr.is_null() {
        Ok(String::new())
    } else {
        match CStr::from_ptr(ptr).to_str() {
            Ok(borrowed) => Ok(borrowed.to_owned()),
            Err(utf8_err) => Err(format!("Invalid UTF-8: {utf8_err}")),
        }
    }
}
420
/// Allocates a heap C-string copy of `s` and transfers ownership to the caller.
///
/// Interior NUL bytes are replaced with spaces, matching `error_to_c`. The previous
/// version panicked on such input, which is undesirable in a helper used at an FFI
/// boundary.
///
/// The returned pointer must be released exactly once with [`free_c_string`]. It must not
/// be freed with C `free`, because it was allocated by Rust via `CString::into_raw`.
pub fn string_to_c(s: &str) -> *const c_char {
    let c_string = match CString::new(s) {
        Ok(c_string) => c_string,
        // `NulError` only arises from interior NULs. Once those are replaced, the
        // conversion cannot fail, so the default is never actually used.
        Err(_) => CString::new(s.replace('\0', " ")).unwrap_or_default(),
    };
    c_string.into_raw()
}
430
/// Stores `msg` as this thread's "last error" and returns a pointer to it as a C string.
///
/// Interior NUL bytes are replaced with spaces, so the whole message survives.
///
/// The returned pointer is owned by a thread-local slot. The caller must NOT free it.
/// It stays valid only until the next `error_to_c` call on the same thread, which
/// overwrites the slot, or until the thread exits.
pub fn error_to_c(msg: impl AsRef<str>) -> *const c_char {
    thread_local! {
        // `CString::default()` is a valid empty C string, so no `unsafe` is needed here.
        static LAST_ERROR: RefCell<CString> = RefCell::new(CString::default());
    }

    let msg = msg.as_ref();
    let sanitized = if msg.contains('\0') { msg.replace('\0', " ") } else { msg.to_owned() };
    // `sanitized` has no NUL bytes, so this cannot fail. The empty-string fallback keeps the
    // function panic-free without constructing a `CString` that breaks its invariants.
    let c_str = CString::new(sanitized).unwrap_or_default();

    LAST_ERROR.with(|slot| {
        let mut current = slot.borrow_mut();
        *current = c_str;
        // The heap buffer stays put until the slot is overwritten, so the raw pointer
        // remains valid after this borrow guard is released.
        current.as_ptr()
    })
}
462
/// Releases a string previously returned by [`string_to_c`]. Passing null is a no-op.
///
/// # Safety
///
/// `ptr` must be null, or a pointer obtained from [`string_to_c`] (that is, from
/// `CString::into_raw`) that has not already been freed. It must never be a pointer
/// returned by `error_to_c`, whose storage belongs to a thread-local slot.
pub unsafe fn free_c_string(ptr: *const c_char) {
    if ptr.is_null() {
        return;
    }
    // Re-take ownership so `CString`'s destructor releases the allocation.
    let reclaimed = CString::from_raw(ptr.cast_mut());
    drop(reclaimed);
}
471
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a C string returned by one of the helpers under test.
    fn read_c(ptr: *const c_char) -> String {
        // SAFETY: callers only pass pointers freshly returned by `error_to_c` or
        // `string_to_c`, which are NUL-terminated and still live.
        unsafe { CStr::from_ptr(ptr) }.to_string_lossy().into_owned()
    }

    #[test]
    fn test_error_to_c_normal_string() {
        let expected = "Test error message";
        assert_eq!(read_c(error_to_c(expected)), expected);
    }

    #[test]
    fn test_error_to_c_with_null_bytes() {
        // Interior NULs become spaces instead of truncating the message.
        let raw = "Error\0with\0null\0bytes";
        assert_eq!(read_c(error_to_c(raw)), "Error with null bytes");
    }

    #[test]
    fn test_error_to_c_format_string() {
        let formatted = format!("Error code: {}", 42);
        assert_eq!(read_c(error_to_c(&formatted)), "Error code: 42");
    }

    #[test]
    fn test_string_to_c_requires_free() {
        let owned_ptr = string_to_c("hello");
        assert_eq!(read_c(owned_ptr), "hello");
        // SAFETY: `owned_ptr` came from `string_to_c` and is freed exactly once.
        unsafe { free_c_string(owned_ptr) };
    }
}