streamkit_plugin_sdk_native/
conversions.rs

1// SPDX-FileCopyrightText: © 2025 StreamKit Contributors
2//
3// SPDX-License-Identifier: MPL-2.0
4
5//! Type conversions between C ABI types and Rust types
6//!
7//! These functions provide safe wrappers around unsafe FFI operations.
8
9use crate::types::{
10    CAudioFormat, CAudioFrame, CCustomEncoding, CCustomPacket, CPacket, CPacketMetadata,
11    CPacketType, CPacketTypeInfo, CSampleFormat,
12};
13use std::cell::RefCell;
14use std::ffi::{c_void, CStr, CString};
15use std::os::raw::c_char;
16use std::sync::Arc;
17use streamkit_core::types::{
18    AudioFormat, AudioFrame, CustomEncoding, CustomPacketData, Packet, PacketMetadata, PacketType,
19    SampleFormat, TranscriptionData,
20};
21
22/// Convert C packet type info to Rust PacketType
23///
24/// # Errors
25///
26/// Returns an error if:
27/// - `RawAudio` is missing its `audio_format`
28/// - `Custom` is missing its `custom_type_id`
29/// - `custom_type_id` is not valid UTF-8
30pub fn packet_type_from_c(cpt_info: CPacketTypeInfo) -> Result<PacketType, String> {
31    match cpt_info.type_discriminant {
32        CPacketType::RawAudio => {
33            if cpt_info.audio_format.is_null() {
34                return Err("RawAudio packet type missing audio_format".to_string());
35            }
36            // SAFETY: caller guarantees pointer validity for the duration of this call.
37            let c_format = unsafe { &*cpt_info.audio_format };
38            Ok(PacketType::RawAudio(audio_format_from_c(c_format)))
39        },
40        CPacketType::OpusAudio => Ok(PacketType::OpusAudio),
41        CPacketType::Text => Ok(PacketType::Text),
42        CPacketType::Transcription => Ok(PacketType::Transcription),
43        CPacketType::Custom => {
44            if cpt_info.custom_type_id.is_null() {
45                return Err("Custom packet type missing custom_type_id".to_string());
46            }
47            let type_id = unsafe { c_str_to_string(cpt_info.custom_type_id) }?;
48            Ok(PacketType::Custom { type_id })
49        },
50        CPacketType::Binary => Ok(PacketType::Binary),
51        CPacketType::Any => Ok(PacketType::Any),
52        CPacketType::Passthrough => Ok(PacketType::Passthrough),
53    }
54}
55
56/// Convert Rust SampleFormat to C
57pub const fn sample_format_to_c(sf: &SampleFormat) -> CSampleFormat {
58    match sf {
59        SampleFormat::F32 => CSampleFormat::F32,
60        SampleFormat::S16Le => CSampleFormat::S16Le,
61    }
62}
63
64/// Convert C sample format to Rust
65pub const fn sample_format_from_c(csf: CSampleFormat) -> SampleFormat {
66    match csf {
67        CSampleFormat::F32 => SampleFormat::F32,
68        CSampleFormat::S16Le => SampleFormat::S16Le,
69    }
70}
71
72/// Convert Rust AudioFormat to C
73pub const fn audio_format_to_c(af: &AudioFormat) -> CAudioFormat {
74    CAudioFormat {
75        sample_rate: af.sample_rate,
76        channels: af.channels,
77        sample_format: sample_format_to_c(&af.sample_format),
78    }
79}
80
81/// Convert C AudioFormat to Rust
82pub const fn audio_format_from_c(caf: &CAudioFormat) -> AudioFormat {
83    AudioFormat {
84        sample_rate: caf.sample_rate,
85        channels: caf.channels,
86        sample_format: sample_format_from_c(caf.sample_format),
87    }
88}
89
90/// Convert Rust PacketType to C representation
91/// Returns (CPacketTypeInfo, optional CAudioFormat that must be kept alive)
92/// For RawAudio types, the returned CAudioFormat must outlive the CPacketTypeInfo
93pub const fn packet_type_to_c(pt: &PacketType) -> (CPacketTypeInfo, Option<CAudioFormat>) {
94    match pt {
95        PacketType::RawAudio(format) => {
96            let c_format = audio_format_to_c(format);
97            (
98                CPacketTypeInfo {
99                    type_discriminant: CPacketType::RawAudio,
100                    audio_format: &raw const c_format,
101                    custom_type_id: std::ptr::null(),
102                },
103                Some(c_format),
104            )
105        },
106        PacketType::OpusAudio => (
107            CPacketTypeInfo {
108                type_discriminant: CPacketType::OpusAudio,
109                audio_format: std::ptr::null(),
110                custom_type_id: std::ptr::null(),
111            },
112            None,
113        ),
114        PacketType::Text => (
115            CPacketTypeInfo {
116                type_discriminant: CPacketType::Text,
117                audio_format: std::ptr::null(),
118                custom_type_id: std::ptr::null(),
119            },
120            None,
121        ),
122        PacketType::Transcription => (
123            CPacketTypeInfo {
124                type_discriminant: CPacketType::Transcription,
125                audio_format: std::ptr::null(),
126                custom_type_id: std::ptr::null(),
127            },
128            None,
129        ),
130        PacketType::Custom { .. } => (
131            CPacketTypeInfo {
132                type_discriminant: CPacketType::Custom,
133                audio_format: std::ptr::null(),
134                custom_type_id: std::ptr::null(), // provided by the caller where stable storage exists
135            },
136            None,
137        ),
138        PacketType::Binary => (
139            CPacketTypeInfo {
140                type_discriminant: CPacketType::Binary,
141                audio_format: std::ptr::null(),
142                custom_type_id: std::ptr::null(),
143            },
144            None,
145        ),
146        PacketType::Any => (
147            CPacketTypeInfo {
148                type_discriminant: CPacketType::Any,
149                audio_format: std::ptr::null(),
150                custom_type_id: std::ptr::null(),
151            },
152            None,
153        ),
154        PacketType::Passthrough => (
155            CPacketTypeInfo {
156                type_discriminant: CPacketType::Passthrough,
157                audio_format: std::ptr::null(),
158                custom_type_id: std::ptr::null(),
159            },
160            None,
161        ),
162    }
163}
164
165pub struct CPacketRepr {
166    pub packet: CPacket,
167    _owned: CPacketOwned,
168}
169
170#[allow(dead_code)] // Owned values are kept alive to support FFI pointers during callbacks.
171enum CPacketOwned {
172    None,
173    Audio(Box<CAudioFrame>),
174    Text(CString),
175    Bytes(Vec<u8>),
176    Custom(CustomOwned),
177}
178
179#[allow(dead_code)] // Owned values are kept alive to support FFI pointers during callbacks.
180struct CustomOwned {
181    type_id: CString,
182    data_json: Vec<u8>,
183    metadata: Option<Box<CPacketMetadata>>,
184    custom: Box<CCustomPacket>,
185}
186
187fn metadata_to_c(meta: &PacketMetadata) -> CPacketMetadata {
188    CPacketMetadata {
189        timestamp_us: meta.timestamp_us.unwrap_or_default(),
190        has_timestamp_us: meta.timestamp_us.is_some(),
191        duration_us: meta.duration_us.unwrap_or_default(),
192        has_duration_us: meta.duration_us.is_some(),
193        sequence: meta.sequence.unwrap_or_default(),
194        has_sequence: meta.sequence.is_some(),
195    }
196}
197
198fn metadata_from_c(meta: &CPacketMetadata) -> PacketMetadata {
199    PacketMetadata {
200        timestamp_us: meta.has_timestamp_us.then_some(meta.timestamp_us),
201        duration_us: meta.has_duration_us.then_some(meta.duration_us),
202        sequence: meta.has_sequence.then_some(meta.sequence),
203    }
204}
205
/// Build a `CString` from `s`, replacing every interior NUL byte with a space so the
/// conversion cannot fail. An empty string is the last-resort fallback.
fn cstring_sanitize(s: &str) -> CString {
    match CString::new(s) {
        Ok(clean) => clean,
        Err(_) => {
            let spaced = s.replace('\0', " ");
            CString::new(spaced).unwrap_or_default()
        },
    }
}
209
210/// Convert Rust Packet to C representation.
211///
212/// The returned representation owns any allocations needed for the duration of the C callback.
213pub fn packet_to_c(packet: &Packet) -> CPacketRepr {
214    match packet {
215        Packet::Audio(frame) => {
216            let c_frame = Box::new(CAudioFrame {
217                sample_rate: frame.sample_rate,
218                channels: frame.channels,
219                samples: frame.samples.as_ptr(),
220                sample_count: frame.samples.len(),
221            });
222            let packet = CPacket {
223                packet_type: CPacketType::RawAudio,
224                data: std::ptr::from_ref::<CAudioFrame>(&*c_frame).cast::<c_void>(),
225                len: std::mem::size_of::<CAudioFrame>(),
226            };
227            CPacketRepr { packet, _owned: CPacketOwned::Audio(c_frame) }
228        },
229        Packet::Text(text) => {
230            let s = text.as_ref();
231            let c_str = match CString::new(s) {
232                Ok(s) => s,
233                Err(e) => {
234                    tracing::warn!(
235                        "Text packet contains null bytes (position {}), data will be truncated",
236                        e.nul_position()
237                    );
238                    let truncated = &s[..e.nul_position()];
239                    CString::new(truncated).unwrap_or_default()
240                },
241            };
242            let packet = CPacket {
243                packet_type: CPacketType::Text,
244                data: c_str.as_ptr().cast::<c_void>(),
245                len: c_str.as_bytes_with_nul().len(),
246            };
247            CPacketRepr { packet, _owned: CPacketOwned::Text(c_str) }
248        },
249        Packet::Transcription(trans_data) => {
250            let json = serde_json::to_vec(trans_data).unwrap_or_else(|e| {
251                tracing::error!("Failed to serialize transcription data to JSON: {}", e);
252                b"{}".to_vec()
253            });
254            let packet = CPacket {
255                packet_type: CPacketType::Transcription,
256                data: json.as_ptr().cast::<c_void>(),
257                len: json.len(),
258            };
259            CPacketRepr { packet, _owned: CPacketOwned::Bytes(json) }
260        },
261        Packet::Custom(custom) => {
262            let type_id = cstring_sanitize(custom.type_id.as_str());
263            let data_json = serde_json::to_vec(&custom.data).unwrap_or_else(|e| {
264                tracing::error!("Failed to serialize custom packet data to JSON: {}", e);
265                b"{}".to_vec()
266            });
267
268            let metadata = custom.metadata.as_ref().map(|m| Box::new(metadata_to_c(m)));
269            let mut custom_packet = Box::new(CCustomPacket {
270                type_id: type_id.as_ptr(),
271                encoding: match custom.encoding {
272                    CustomEncoding::Json => CCustomEncoding::Json,
273                },
274                data_json: data_json.as_ptr(),
275                data_len: data_json.len(),
276                metadata: metadata.as_deref().map_or(std::ptr::null(), std::ptr::from_ref),
277            });
278
279            let packet = CPacket {
280                packet_type: CPacketType::Custom,
281                data: std::ptr::from_mut::<CCustomPacket>(&mut *custom_packet).cast::<c_void>(),
282                len: std::mem::size_of::<CCustomPacket>(),
283            };
284
285            CPacketRepr {
286                packet,
287                _owned: CPacketOwned::Custom(CustomOwned {
288                    type_id,
289                    data_json,
290                    metadata,
291                    custom: custom_packet,
292                }),
293            }
294        },
295        Packet::Binary { data, .. } => CPacketRepr {
296            packet: CPacket {
297                packet_type: CPacketType::Binary,
298                data: data.as_ref().as_ptr().cast::<c_void>(),
299                len: data.len(),
300            },
301            _owned: CPacketOwned::None,
302        },
303    }
304}
305
306/// Convert C packet to Rust Packet
307///
308/// # Safety
309///
310/// The caller must ensure:
311/// - The CPacket pointer is valid
312/// - The data pointer is valid and points to data of the specified length
313/// - The data remains valid for the duration of this call
314///
315/// # Errors
316///
317/// Returns an error if:
318/// - The packet pointer is null
319/// - The data pointer is null
320/// - The packet type is unsupported
321/// - The packet data is invalid (e.g., invalid UTF-8, malformed JSON)
322pub unsafe fn packet_from_c(c_packet: *const CPacket) -> Result<Packet, String> {
323    if c_packet.is_null() {
324        return Err("Null packet pointer".to_string());
325    }
326
327    let c_pkt = &*c_packet;
328
329    if c_pkt.data.is_null() {
330        return Err("Null packet data pointer".to_string());
331    }
332
333    match c_pkt.packet_type {
334        CPacketType::RawAudio => {
335            let c_frame = &*c_pkt.data.cast::<CAudioFrame>();
336            if c_frame.samples.is_null() {
337                return Err("Null samples pointer in audio frame".to_string());
338            }
339
340            let samples = std::slice::from_raw_parts(c_frame.samples, c_frame.sample_count);
341
342            Ok(Packet::Audio(AudioFrame::new(
343                c_frame.sample_rate,
344                c_frame.channels,
345                samples.to_vec(),
346            )))
347        },
348        CPacketType::Text => {
349            let c_str = CStr::from_ptr(c_pkt.data.cast::<c_char>());
350            let text = c_str
351                .to_str()
352                .map_err(|e| format!("Invalid UTF-8 in text packet: {e}"))?
353                .to_string();
354            Ok(Packet::Text(text.into()))
355        },
356        CPacketType::Transcription => {
357            // Deserialize JSON transcription data
358            let data = std::slice::from_raw_parts(c_pkt.data.cast::<u8>(), c_pkt.len);
359            let trans_data: TranscriptionData = serde_json::from_slice(data)
360                .map_err(|e| format!("Invalid transcription data: {e}"))?;
361            Ok(Packet::Transcription(Arc::new(trans_data)))
362        },
363        CPacketType::Custom => {
364            let c_custom = &*c_pkt.data.cast::<CCustomPacket>();
365            if c_custom.type_id.is_null() {
366                return Err("Custom packet missing type_id".to_string());
367            }
368            if c_custom.data_json.is_null() {
369                return Err("Custom packet missing data_json".to_string());
370            }
371
372            let type_id = c_str_to_string(c_custom.type_id)?;
373            let data_bytes = std::slice::from_raw_parts(c_custom.data_json, c_custom.data_len);
374            let data: serde_json::Value = serde_json::from_slice(data_bytes)
375                .map_err(|e| format!("Invalid custom JSON: {e}"))?;
376
377            let metadata = if c_custom.metadata.is_null() {
378                None
379            } else {
380                Some(metadata_from_c(&*c_custom.metadata))
381            };
382
383            let encoding = match c_custom.encoding {
384                CCustomEncoding::Json => CustomEncoding::Json,
385            };
386
387            Ok(Packet::Custom(Arc::new(CustomPacketData { type_id, encoding, data, metadata })))
388        },
389        CPacketType::Binary => {
390            let data = std::slice::from_raw_parts(c_pkt.data.cast::<u8>(), c_pkt.len);
391            Ok(Packet::Binary {
392                data: bytes::Bytes::copy_from_slice(data),
393                content_type: None,
394                metadata: None,
395            })
396        },
397        _ => Err(format!("Unsupported packet type: {:?}", c_pkt.packet_type)),
398    }
399}
400
/// Copy a C string into an owned Rust `String`.
///
/// A null `ptr` is accepted and yields an empty string.
///
/// # Safety
///
/// `ptr` must be null or a valid pointer to a NUL-terminated C string that stays
/// alive and unmodified for the duration of this call.
///
/// # Errors
///
/// Returns an error if the bytes before the terminating NUL are not valid UTF-8.
pub unsafe fn c_str_to_string(ptr: *const c_char) -> Result<String, String> {
    if ptr.is_null() {
        return Ok(String::new());
    }

    // SAFETY: `ptr` is non-null and the caller guarantees it is a valid
    // NUL-terminated string.
    let borrowed = unsafe { CStr::from_ptr(ptr) };
    match borrowed.to_str() {
        Ok(text) => Ok(text.to_owned()),
        Err(err) => Err(format!("Invalid UTF-8: {err}")),
    }
}
420
/// Allocate an owned, NUL-terminated copy of `s` and hand its raw pointer to C.
///
/// Ownership transfers to the caller, who must release it exactly once with
/// `free_c_string`. Otherwise the allocation leaks.
///
/// # Panics
///
/// Panics if `s` contains an interior NUL byte.
#[allow(clippy::expect_used)] // an interior NUL here is a programmer error, not bad input
pub fn string_to_c(s: &str) -> *const c_char {
    let owned = CString::new(s).expect("String should not contain null bytes");
    owned.into_raw().cast_const()
}
430
/// Convert an error message to a C string for returning across the C ABI.
///
/// Interior NUL bytes in `msg` are replaced with spaces so the whole message survives.
///
/// # Ownership and lifetime
///
/// The returned pointer is **borrowed** and **must not be freed** by the caller.
/// It remains valid until the next `error_to_c()` call on the same OS thread, or until
/// that thread exits.
///
/// This design:
/// - Prevents host-side leaks when the host copies the message into an owned string.
/// - Avoids cross-dylib allocator issues (freeing memory in a different module).
pub fn error_to_c(msg: impl AsRef<str>) -> *const c_char {
    thread_local! {
        // Starts as the empty C string; overwritten on every call.
        static LAST_ERROR: RefCell<CString> = RefCell::new(CString::default());
    }

    let msg = msg.as_ref();
    let sanitized = if msg.contains('\0') { msg.replace('\0', " ") } else { msg.to_owned() };

    // `CString::new` only fails on interior NULs, which were replaced above. Still, avoid
    // panicking at this FFI boundary: fall back to an empty (valid) C string. The original
    // `from_vec_unchecked(vec![0])` fallback would append a second NUL, creating an
    // interior NUL and breaking the CString invariant.
    let c_str = CString::new(sanitized).unwrap_or_default();

    LAST_ERROR.with(|slot| {
        let mut current = slot.borrow_mut();
        // Dropping the previous message here invalidates the pointer handed out last time.
        *current = c_str;
        current.as_ptr()
    })
}
462
/// Release a string previously handed out by `string_to_c`.
///
/// Passing a null pointer is a no-op.
///
/// # Safety
///
/// `ptr` must be null, or a pointer returned by `string_to_c` that has not been freed yet.
pub unsafe fn free_c_string(ptr: *const c_char) {
    if ptr.is_null() {
        return;
    }

    // SAFETY: the caller guarantees `ptr` came from `CString::into_raw` (via
    // `string_to_c`) and has not been freed, so reclaiming ownership is sound.
    let reclaimed = unsafe { CString::from_raw(ptr.cast_mut()) };
    drop(reclaimed);
}
471
#[cfg(test)]
mod tests {
    use super::*;

    /// Read the C string behind `ptr` back into an owned Rust string.
    fn read(ptr: *const c_char) -> String {
        // SAFETY: callers pass pointers from `error_to_c` / `string_to_c` and read them
        // before anything can invalidate them.
        unsafe { CStr::from_ptr(ptr) }.to_string_lossy().into_owned()
    }

    #[test]
    fn test_error_to_c_normal_string() {
        let msg = "Test error message";
        assert_eq!(read(error_to_c(msg)), msg);
    }

    #[test]
    fn test_error_to_c_with_null_bytes() {
        // Null bytes should be replaced with spaces
        assert_eq!(read(error_to_c("Error\0with\0null\0bytes")), "Error with null bytes");
    }

    #[test]
    fn test_error_to_c_format_string() {
        let msg = format!("Error code: {}", 42);
        assert_eq!(read(error_to_c(&msg)), "Error code: 42");
    }

    #[test]
    fn test_error_to_c_empty_string() {
        assert_eq!(read(error_to_c("")), "");
    }

    #[test]
    fn test_string_to_c_requires_free() {
        let c_msg = string_to_c("hello");
        assert_eq!(read(c_msg), "hello");
        // SAFETY: `c_msg` came from `string_to_c` and is freed exactly once.
        unsafe { free_c_string(c_msg) };
    }

    #[test]
    fn test_free_c_string_null_is_noop() {
        // SAFETY: null is explicitly accepted.
        unsafe { free_c_string(std::ptr::null()) };
    }

    #[test]
    fn test_c_str_to_string() {
        let owned = CString::new("hello").unwrap_or_default();
        // SAFETY: `owned` is a live NUL-terminated string; null is explicitly accepted.
        unsafe {
            assert_eq!(c_str_to_string(owned.as_ptr()), Ok("hello".to_string()));
            assert_eq!(c_str_to_string(std::ptr::null()), Ok(String::new()));
        }

        let invalid: [c_char; 2] = [0xff_u8 as c_char, 0];
        // SAFETY: `invalid` is NUL-terminated.
        assert!(unsafe { c_str_to_string(invalid.as_ptr()) }.is_err());
    }

    #[test]
    fn test_cstring_sanitize_replaces_null_bytes() {
        assert_eq!(cstring_sanitize("plain").to_string_lossy(), "plain");
        assert_eq!(cstring_sanitize("a\0b").to_string_lossy(), "a b");
    }
}