sbz_switch/media/
mod.rs

1//! Provides a Rust layer over the Windows IMMDevice API.
2
3#![allow(unknown_lints)]
4
5mod event;
6
7use std::alloc;
8use std::error::Error;
9use std::ffi::{OsStr, OsString};
10use std::fmt;
11use std::isize;
12use std::mem::MaybeUninit;
13use std::os::windows::ffi::{OsStrExt, OsStringExt};
14use std::ptr;
15use std::slice;
16use std::sync::atomic::{self, AtomicUsize, Ordering};
17
18use regex::Regex;
19use slog::Logger;
20use winapi::ctypes::c_void;
21use winapi::shared::guiddef::GUID;
22use winapi::shared::guiddef::{IsEqualIID, REFIID};
23use winapi::shared::minwindef::ULONG;
24use winapi::shared::ntdef::HRESULT;
25use winapi::shared::winerror::{E_INVALIDARG, E_NOINTERFACE};
26use winapi::shared::wtypes::{PROPERTYKEY, VARTYPE};
27use winapi::um::combaseapi::CLSCTX_ALL;
28use winapi::um::combaseapi::{CoCreateInstance, CoTaskMemFree, PropVariantClear};
29use winapi::um::coml2api::STGM_READ;
30use winapi::um::endpointvolume::{
31    IAudioEndpointVolume, IAudioEndpointVolumeCallback, IAudioEndpointVolumeCallbackVtbl,
32    AUDIO_VOLUME_NOTIFICATION_DATA,
33};
34use winapi::um::mmdeviceapi::{
35    eConsole, eRender, CLSID_MMDeviceEnumerator, IMMDevice, IMMDeviceEnumerator,
36    DEVICE_STATE_ACTIVE,
37};
38use winapi::um::propidl::PROPVARIANT;
39use winapi::um::propsys::IPropertyStore;
40use winapi::um::unknwnbase::{IUnknown, IUnknownVtbl};
41use winapi::Interface;
42
43pub(crate) use self::event::VolumeEvents;
44pub use self::event::VolumeNotification;
45use crate::com::{ComObject, ComScope};
46use crate::hresult::{check, Win32Error};
47use crate::lazy::Lazy;
48use crate::soundcore::{SoundCoreError, PKEY_SOUNDCORECTL_CLSID_AE5, PKEY_SOUNDCORECTL_CLSID_Z};
49use crate::winapiext::{PKEY_DeviceInterface_FriendlyName, PKEY_Device_DeviceDesc};
50
51fn parse_guid(src: &str) -> Result<GUID, Box<dyn Error>> {
52    let re1 = Regex::new(
53        "^\\{([0-9a-fA-F]{8})-([0-9a-fA-F]{4})-\
54         ([0-9a-fA-F]{4})-([0-9a-fA-F]{2})([0-9a-fA-F]{2})-\
55         ([0-9a-fA-F]{2})([0-9a-fA-F]{2})([0-9a-fA-F]{2})\
56         ([0-9a-fA-F]{2})([0-9a-fA-F]{2})([0-9a-fA-F]{2})\\}$",
57    )
58    .unwrap();
59    let re2 = Regex::new(
60        "^([0-9a-fA-F]{8})-([0-9a-fA-F]{4})-\
61         ([0-9a-fA-F]{4})-([0-9a-fA-F]{2})([0-9a-fA-F]{2})-\
62         ([0-9a-fA-F]{2})([0-9a-fA-F]{2})([0-9a-fA-F]{2})\
63         ([0-9a-fA-F]{2})([0-9a-fA-F]{2})([0-9a-fA-F]{2})$",
64    )
65    .unwrap();
66    let re3 = Regex::new(
67        "^([0-9a-fA-F]{8})([0-9a-fA-F]{4})\
68         ([0-9a-fA-F]{4})([0-9a-fA-F]{2})([0-9a-fA-F]{2})\
69         ([0-9a-fA-F]{2})([0-9a-fA-F]{2})([0-9a-fA-F]{2})\
70         ([0-9a-fA-F]{2})([0-9a-fA-F]{2})([0-9a-fA-F]{2})$",
71    )
72    .unwrap();
73
74    let caps = re1
75        .captures(src)
76        .or_else(|| re2.captures(src))
77        .or_else(|| re3.captures(src))
78        .ok_or(SoundCoreError::NotSupported)?;
79
80    let mut iter = caps.iter().skip(1).map(|c| c.unwrap().as_str());
81    let l = u32::from_str_radix(iter.next().unwrap(), 16).unwrap();
82    let w1 = u16::from_str_radix(iter.next().unwrap(), 16).unwrap();
83    let w2 = u16::from_str_radix(iter.next().unwrap(), 16).unwrap();
84    let mut array = [0; 8];
85    for b in iter.enumerate() {
86        array[b.0] = u8::from_str_radix(b.1, 16).unwrap();
87    }
88
89    Ok(GUID {
90        Data1: l,
91        Data2: w1,
92        Data3: w2,
93        Data4: array,
94    })
95}
96
97/// Represents an audio device.
98pub struct Endpoint {
99    device: ComObject<IMMDevice>,
100    logger: Logger,
101    volume: Lazy<Result<ComObject<IAudioEndpointVolume>, Win32Error>>,
102    properties: Lazy<Result<PropertyStore, Win32Error>>,
103}
104
105impl Endpoint {
106    fn new(device: ComObject<IMMDevice>, logger: Logger) -> Self {
107        Self {
108            device,
109            logger,
110            volume: Lazy::new(),
111            properties: Lazy::new(),
112        }
113    }
114    /// Gets the ID of the endpoint.
115    ///
116    /// See [Endpoint ID Strings](https://docs.microsoft.com/en-us/windows/desktop/CoreAudio/endpoint-id-strings).
117    pub fn id(&self) -> Result<String, Win32Error> {
118        unsafe {
119            trace!(self.logger, "Getting device ID...");
120            let mut raw_id = MaybeUninit::uninit();
121            check(self.device.GetId(raw_id.as_mut_ptr()))?;
122            let raw_id = raw_id.assume_init();
123            let length = (0..isize::MAX)
124                .position(|i| *raw_id.offset(i) == 0)
125                .unwrap();
126            let str: OsString = OsStringExt::from_wide(slice::from_raw_parts(raw_id, length));
127            CoTaskMemFree(raw_id as *mut _);
128            Ok(str.to_string_lossy().into_owned())
129        }
130    }
131    fn property_store(&self) -> Result<&PropertyStore, Win32Error> {
132        self.properties
133            .get_or_create(|| unsafe {
134                trace!(self.logger, "Opening PropertyStore...");
135                let mut property_store = MaybeUninit::uninit();
136                check(
137                    self.device
138                        .OpenPropertyStore(STGM_READ, property_store.as_mut_ptr()),
139                )?;
140                Ok(PropertyStore(
141                    ComObject::take(property_store.assume_init()),
142                    self.logger.clone(),
143                ))
144            })
145            .as_ref()
146            .map_err(|e| e.clone())
147    }
148    /// Gets the CLSID of the class implementing Creative's APIs.
149    ///
150    /// This allows discovery of a SoundCore implementation for devices that support it.
151    pub fn clsid(&self) -> Result<GUID, SoundCoreError> {
152        let store = self.property_store()?;
153        let value = match store.get_string_value(&PKEY_SOUNDCORECTL_CLSID_AE5)? {
154            Some(value) => value,
155            None => store
156                .get_string_value(&PKEY_SOUNDCORECTL_CLSID_Z)?
157                .ok_or(SoundCoreError::NotSupported)?,
158        };
159        parse_guid(&value).or(Err(SoundCoreError::NotSupported))
160    }
161
162    /// Gets the friendly name of the audio interface (sound adapter).
163    ///
164    /// See [Core Audio Properties: Device Properties](https://docs.microsoft.com/en-us/windows/desktop/coreaudio/core-audio-properties#device-properties).
165    pub fn interface(&self) -> Result<String, GetPropertyError> {
166        self.property_store()?
167            .get_string_value(&PKEY_DeviceInterface_FriendlyName)?
168            .ok_or(GetPropertyError::NOT_FOUND)
169    }
170    /// Gets a description of the audio endpoint (speakers, headphones, etc).
171    ///
172    /// See [Core Audio Properties: Device Properties](https://docs.microsoft.com/en-us/windows/desktop/coreaudio/core-audio-properties#device-properties).
173    pub fn description(&self) -> Result<String, GetPropertyError> {
174        self.property_store()?
175            .get_string_value(&PKEY_Device_DeviceDesc)?
176            .ok_or(GetPropertyError::NOT_FOUND)
177    }
178    fn volume(&self) -> Result<ComObject<IAudioEndpointVolume>, Win32Error> {
179        self.volume
180            .get_or_create(|| unsafe {
181                let mut ctrl = MaybeUninit::<*mut IAudioEndpointVolume>::uninit();
182                check(self.device.Activate(
183                    &IAudioEndpointVolume::uuidof(),
184                    CLSCTX_ALL,
185                    ptr::null_mut(),
186                    ctrl.as_mut_ptr() as *mut _,
187                ))?;
188                Ok(ComObject::take(ctrl.assume_init()))
189            })
190            .clone()
191    }
192    /// Checks whether the device is already muted.
193    pub fn get_mute(&self) -> Result<bool, Win32Error> {
194        unsafe {
195            trace!(self.logger, "Checking if we are muted...");
196            let mut mute = 0;
197            check(self.volume()?.GetMute(&mut mute))?;
198            debug!(self.logger, "Muted = {}", mute);
199            Ok(mute != 0)
200        }
201    }
202    /// Mutes or unmutes the device.
203    pub fn set_mute(&self, mute: bool) -> Result<(), Win32Error> {
204        unsafe {
205            let mute = if mute { 1 } else { 0 };
206            info!(self.logger, "Setting muted to {}...", mute);
207            check(self.volume()?.SetMute(mute, ptr::null_mut()))?;
208            Ok(())
209        }
210    }
211    /// Sets the volume of the device.
212    ///
213    /// Volumes range from 0.0 to 1.0.
214    ///
215    /// Volume can be controlled independent of muting.
216    pub fn set_volume(&self, volume: f32) -> Result<(), Win32Error> {
217        unsafe {
218            info!(self.logger, "Setting volume to {}...", volume);
219            check(
220                self.volume()?
221                    .SetMasterVolumeLevelScalar(volume, ptr::null_mut()),
222            )?;
223            Ok(())
224        }
225    }
226    /// Gets the volume of the device.
227    pub fn get_volume(&self) -> Result<f32, Win32Error> {
228        unsafe {
229            debug!(self.logger, "Getting volume...");
230            let mut volume = MaybeUninit::uninit();
231            check(
232                self.volume()?
233                    .GetMasterVolumeLevelScalar(volume.as_mut_ptr()),
234            )?;
235            let volume = volume.assume_init();
236            debug!(self.logger, "volume = {}", volume);
237            Ok(volume)
238        }
239    }
240    pub(crate) fn event_stream(&self) -> Result<VolumeEvents, Win32Error> {
241        VolumeEvents::new(self.volume()?)
242    }
243}
244
245/// Describes an error that occurred while retrieving a property from a device.
246#[derive(Debug)]
247pub enum GetPropertyError {
248    /// A Win32 error occurred.
249    Win32(Win32Error),
250    /// The returned value was not the expected type.
251    UnexpectedType(VARTYPE),
252}
253
254impl GetPropertyError {
255    pub(crate) const NOT_FOUND: GetPropertyError = GetPropertyError::UnexpectedType(0);
256}
257
258impl fmt::Display for GetPropertyError {
259    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
260        match self {
261            GetPropertyError::Win32(error) => error.fmt(f),
262            GetPropertyError::UnexpectedType(code) => {
263                write!(f, "returned property value was of unexpected type {}", code)
264            }
265        }
266    }
267}
268
269impl Error for GetPropertyError {
270    fn cause(&self) -> Option<&dyn Error> {
271        match self {
272            GetPropertyError::Win32(error) => Some(error),
273            _ => None,
274        }
275    }
276}
277
278impl From<Win32Error> for GetPropertyError {
279    fn from(error: Win32Error) -> Self {
280        GetPropertyError::Win32(error)
281    }
282}
283
284struct PropertyStore(ComObject<IPropertyStore>, Logger);
285
286impl PropertyStore {
287    unsafe fn get_value(&self, key: &PROPERTYKEY) -> Result<PROPVARIANT, Win32Error> {
288        trace!(self.1, "Getting property...");
289        let mut property_value = MaybeUninit::uninit();
290        check(self.0.GetValue(key, property_value.as_mut_ptr()))?;
291        Ok(property_value.assume_init())
292    }
293    #[allow(clippy::cast_ptr_alignment)]
294    fn get_string_value(&self, key: &PROPERTYKEY) -> Result<Option<String>, GetPropertyError> {
295        unsafe {
296            let mut property_value = self.get_value(key)?;
297            trace!(self.1, "Returned variant has type {}", property_value.vt);
298            // VT_EMPTY
299            if property_value.vt == 0 {
300                return Ok(None);
301            }
302            // VT_LPWSTR
303            if property_value.vt != 31 {
304                PropVariantClear(&mut property_value);
305                return Err(GetPropertyError::UnexpectedType(property_value.vt));
306            }
307            let chars = *property_value.data.pwszVal();
308            let length = (0..isize::MAX).position(|i| *chars.offset(i) == 0);
309            let str = length.map(|length| {
310                OsString::from_wide(slice::from_raw_parts(chars, length))
311                    .to_string_lossy()
312                    .into_owned()
313            });
314            PropVariantClear(&mut property_value);
315            let str = str.unwrap();
316            trace!(self.1, "Returned variant has value {}", &str);
317            Ok(Some(str))
318        }
319    }
320}
321
322/// Provides access to the devices available in the current Windows session.
323pub struct DeviceEnumerator(ComObject<IMMDeviceEnumerator>, Logger);
324
325impl DeviceEnumerator {
326    /// Creates a new device enumerator with the provided logger.
327    pub fn with_logger(logger: Logger) -> Result<Self, Win32Error> {
328        unsafe {
329            let _scope = ComScope::begin();
330            let mut enumerator = MaybeUninit::<*mut IMMDeviceEnumerator>::uninit();
331            trace!(logger, "Creating DeviceEnumerator...");
332            check(CoCreateInstance(
333                &CLSID_MMDeviceEnumerator,
334                ptr::null_mut(),
335                CLSCTX_ALL,
336                &IMMDeviceEnumerator::uuidof(),
337                enumerator.as_mut_ptr() as *mut _,
338            ))?;
339            trace!(logger, "Created DeviceEnumerator");
340            Ok(DeviceEnumerator(
341                ComObject::take(enumerator.assume_init()),
342                logger,
343            ))
344        }
345    }
346    /// Gets all active audio outputs.
347    #[allow(clippy::unnecessary_mut_passed)]
348    pub fn get_active_audio_endpoints(&self) -> Result<Vec<Endpoint>, Win32Error> {
349        unsafe {
350            trace!(self.1, "Getting active endpoints...");
351            let mut collection = MaybeUninit::uninit();
352            check(self.0.EnumAudioEndpoints(
353                eRender,
354                DEVICE_STATE_ACTIVE,
355                collection.as_mut_ptr(),
356            ))?;
357            let collection = collection.assume_init();
358            let mut count = 0;
359            check((*collection).GetCount(&mut count))?;
360            let mut result = Vec::with_capacity(count as usize);
361            for i in 0..count {
362                let mut device = MaybeUninit::uninit();
363                check((*collection).Item(i, device.as_mut_ptr()))?;
364                result.push(Endpoint::new(
365                    ComObject::take(device.assume_init()),
366                    self.1.clone(),
367                ))
368            }
369            Ok(result)
370        }
371    }
372    /// Gets the default audio output.
373    ///
374    /// There are multiple default audio outputs in Windows.
375    /// This function gets the device that would be used if the current application
376    /// were to play music or sound effects (as opposed to VOIP audio).
377    pub fn get_default_audio_endpoint(&self) -> Result<Endpoint, Win32Error> {
378        unsafe {
379            trace!(self.1, "Getting default endpoint...");
380            let mut device = MaybeUninit::uninit();
381            check(
382                self.0
383                    .GetDefaultAudioEndpoint(eRender, eConsole, device.as_mut_ptr()),
384            )?;
385            Ok(Endpoint::new(
386                ComObject::take(device.assume_init()),
387                self.1.clone(),
388            ))
389        }
390    }
391    /// Get a specific audio endpoint by its ID.
392    pub fn get_endpoint(&self, id: &OsStr) -> Result<Endpoint, Win32Error> {
393        trace!(self.1, "Getting endpoint...");
394        let buffer: Vec<_> = id.encode_wide().chain(Some(0)).collect();
395        unsafe {
396            let mut device = MaybeUninit::uninit();
397            check(self.0.GetDevice(buffer.as_ptr(), device.as_mut_ptr()))?;
398            Ok(Endpoint::new(
399                ComObject::take(device.assume_init()),
400                self.1.clone(),
401            ))
402        }
403    }
404}
405
406#[repr(C)]
407struct AudioEndpointVolumeCallback<C> {
408    lp_vtbl: *mut IAudioEndpointVolumeCallbackVtbl,
409    vtbl: IAudioEndpointVolumeCallbackVtbl,
410    refs: AtomicUsize,
411    callback: C,
412}
413
414impl<C> AudioEndpointVolumeCallback<C>
415where
416    C: Send + 'static + FnMut(&AUDIO_VOLUME_NOTIFICATION_DATA) -> Result<(), Win32Error>,
417{
418    /// Wraps a function in an `IAudioEndpointVolumeCallback`.
419    pub unsafe fn wrap(callback: C) -> *mut IAudioEndpointVolumeCallback {
420        let mut value = Box::new(AudioEndpointVolumeCallback::<C> {
421            lp_vtbl: ptr::null_mut(),
422            vtbl: IAudioEndpointVolumeCallbackVtbl {
423                parent: IUnknownVtbl {
424                    QueryInterface: callback_query_interface::<C>,
425                    AddRef: callback_add_ref::<C>,
426                    Release: callback_release::<C>,
427                },
428                OnNotify: callback_on_notify::<C>,
429            },
430            refs: AtomicUsize::new(1),
431            callback,
432        });
433        value.lp_vtbl = &mut value.vtbl as *mut _;
434        Box::into_raw(value) as *mut _
435    }
436}
437
438// ensures `this` is an instance of the expected type
439unsafe fn validate<I, C>(this: *mut I) -> Result<*mut AudioEndpointVolumeCallback<C>, Win32Error>
440where
441    I: Interface,
442{
443    let this = this as *mut IUnknown;
444    if this.is_null()
445        || (*this).lpVtbl.is_null()
446        || (*(*this).lpVtbl).QueryInterface as usize != callback_query_interface::<C> as usize
447    {
448        Err(Win32Error::new(E_INVALIDARG))
449    } else {
450        Ok(this as *mut AudioEndpointVolumeCallback<C>)
451    }
452}
453
454// converts a `Result` to an `HRESULT` so `?` can be used
455unsafe fn uncheck<E>(result: E) -> HRESULT
456where
457    E: FnOnce() -> Result<HRESULT, Win32Error>,
458{
459    match result() {
460        Ok(result) => result,
461        Err(Win32Error { code, .. }) => code,
462    }
463}
464
465unsafe extern "system" fn callback_query_interface<C>(
466    this: *mut IUnknown,
467    iid: REFIID,
468    object: *mut *mut c_void,
469) -> HRESULT {
470    uncheck(|| {
471        let this = validate::<_, C>(this)?;
472        let iid = iid.as_ref().unwrap();
473        if IsEqualIID(iid, &IUnknown::uuidof())
474            || IsEqualIID(iid, &IAudioEndpointVolumeCallback::uuidof())
475        {
476            (*this).refs.fetch_add(1, Ordering::Relaxed);
477            *object = this as *mut c_void;
478            Ok(0)
479        } else {
480            *object = ptr::null_mut();
481            Err(Win32Error::new(E_NOINTERFACE))
482        }
483    })
484}
485
486unsafe extern "system" fn callback_add_ref<C>(this: *mut IUnknown) -> ULONG {
487    match validate::<_, C>(this) {
488        Ok(this) => {
489            let count = (*this).refs.fetch_add(1, Ordering::Relaxed) + 1;
490            count as ULONG
491        }
492        Err(_) => 1,
493    }
494}
495
496unsafe extern "system" fn callback_release<C>(this: *mut IUnknown) -> ULONG {
497    match validate::<_, C>(this) {
498        Ok(this) => {
499            let count = (*this).refs.fetch_sub(1, Ordering::Release) - 1;
500            if count == 0 {
501                atomic::fence(Ordering::Acquire);
502                ptr::drop_in_place(this);
503                alloc::dealloc(
504                    this as *mut u8,
505                    alloc::Layout::for_value(this.as_ref().unwrap()),
506                );
507            }
508            count as ULONG
509        }
510        Err(_) => 1,
511    }
512}
513
514unsafe extern "system" fn callback_on_notify<C>(
515    this: *mut IAudioEndpointVolumeCallback,
516    notify: *mut AUDIO_VOLUME_NOTIFICATION_DATA,
517) -> HRESULT
518where
519    C: FnMut(&AUDIO_VOLUME_NOTIFICATION_DATA) -> Result<(), Win32Error>,
520{
521    uncheck(|| {
522        let this = validate::<_, C>(this)?;
523        ((*this).callback)(&*notify)?;
524        Ok(0)
525    })
526}