1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
use std::ops::{Deref, DerefMut};
use std::fmt;
use std::ptr;
use ::{ComIid, ComInterface, RtInterface, RtClassInterface, IInspectable, Guid};

use w::shared::ntdef::VOID;
use w::shared::minwindef::LPVOID;
use w::shared::winerror::S_OK;
use w::um::unknwnbase::IUnknown;
use w::um::combaseapi::CoTaskMemFree;

/// Reference-counting smart pointer for Windows Runtime (COM) objects.
///
/// A `ComPtr` owns exactly one reference to the underlying object: cloning it calls
/// `AddRef` and dropping it calls `Release`. Because the wrapped pointer is a `NonNull`,
/// both `ComPtr<T>` and `Option<ComPtr<T>>` are the size of a single raw pointer.
#[derive(Debug)]
#[repr(transparent)]
pub struct ComPtr<T>(ptr::NonNull<T>);

impl<T> fmt::Pointer for ComPtr<T> {
    /// Prints the address of the wrapped interface pointer (as used by `{:p}`),
    /// forwarding the caller's formatter flags unchanged.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let raw: *mut T = self.0.as_ptr();
        fmt::Pointer::fmt(&raw, f)
    }
}

// This is a helper method that is not exposed publically by the library
#[inline]
pub fn query_interface<T, Target>(interface: &T) -> Option<ComPtr<Target>> where Target: ComIid, T: ComInterface {
    let iid: &'static Guid = Target::iid();
    let as_unknown = unsafe { &mut *(interface  as *const T as *mut T as *mut IUnknown) };
    let mut res = ptr::null_mut();
    unsafe {
        match as_unknown.QueryInterface(iid.as_ref(), &mut res as *mut _ as *mut *mut VOID) {
            S_OK => Some(ComPtr::wrap(res)),
            _ => None
        }
    }
}

// This trait is not exported in the library interface
/// Internal hook used to obtain the runtime class name of an object.
///
/// `ComPtr::get_runtime_class_name` calls this on the wrapped object viewed as an
/// `IInspectable`. NOTE(review): the implementation lives outside this file; presumably it
/// forwards to `IInspectable::GetRuntimeClassName` — confirm.
pub trait HiddenGetRuntimeClassName {
    /// Returns the fully qualified runtime class name of `self` as an `HString`.
    fn get_runtime_class_name(&self) -> ::HString;
}

impl<T> ComPtr<T> {
    /// Creates a `ComPtr` to wrap a raw pointer.
    /// It takes ownership over the pointer which means it does __not__ call `AddRef`.
    /// `T` __must__ be a COM interface that inherits from `IUnknown`.
    /// The wrapped pointer must not be null.
    #[inline]
    pub unsafe fn wrap(ptr: *mut T) -> ComPtr<T> { // TODO: Add T: ComInterface bound
        debug_assert!(!ptr.is_null());
        ComPtr(ptr::NonNull::new_unchecked(ptr))
    }

    /// Creates an optional `ComPtr` to wrap a raw pointer that may be null.
    /// It takes ownership over the pointer which means it does __not__ call `AddRef`.
    /// `T` __must__ be a COM interface that inherits from `IUnknown`.
    #[inline]
    pub unsafe fn wrap_optional(ptr: *mut T) -> Option<ComPtr<T>> { // TODO: Add T: ComInterface bound
        if ptr.is_null() {
            None
        } else {
            Some(ComPtr(ptr::NonNull::new_unchecked(ptr)))
        }
    }

    /// Returns the underlying WinRT object as a reference to an `IInspectable` object.
    #[inline]
    fn as_inspectable(&self) -> &mut IInspectable where T: RtInterface {
        unsafe { &mut *(self.0.as_ptr() as *mut IInspectable) }
    }
    
    /// Returns the underlying WinRT or COM object as a reference to an `IUnknown` object.
    #[inline]
    fn as_unknown(&self) -> &mut IUnknown {
        unsafe { &mut *(self.0.as_ptr() as *mut IUnknown) }
    }

    /// Changes the type of the underlying COM object to a different interface without doing `QueryInterface`.
    /// This is a runtime no-op, but you need to be sure that the interface is compatible.
    #[inline]
    pub unsafe fn into_unchecked<Interface>(self) -> ComPtr<Interface> where Interface: ComInterface {
        ::std::mem::transmute(self)
    }
    
    /// Gets the fully qualified name of the current Windows Runtime object.
    /// This is only available for interfaces that inherit from `IInspectable` and
    /// are not factory or statics interfaces.
    ///
    /// # Examples
    ///
    /// Basic usage:
    ///
    /// ```
    /// use winrt::*;
    /// use winrt::windows::foundation::Uri;
    ///
    /// let uri = FastHString::new("https://www.rust-lang.org");
    /// let uri = Uri::create_uri(&uri).unwrap();
    /// assert_eq!("Windows.Foundation.Uri", uri.get_runtime_class_name().to_string());
    /// ```
    #[inline]
    pub fn get_runtime_class_name(&self) -> ::HString where T: RtClassInterface {
        HiddenGetRuntimeClassName::get_runtime_class_name(self.as_inspectable())
    }
    
    /// Retrieves a `ComPtr` to the specified interface, if it is supported by the underlying object.
    /// If the requested interface is not supported, `None` is returned.
    #[inline]
    pub fn query_interface<Target>(&self) -> Option<ComPtr<Target>> where Target: ComIid, T: ComInterface {
        query_interface::<_, Target>(&**self)
    }
}
impl<T> Deref for ComPtr<T> {
    type Target = T;

    /// Borrows the wrapped interface for as long as the `ComPtr` itself is borrowed.
    #[inline]
    fn deref(&self) -> &T {
        let raw: *const T = self.0.as_ptr();
        // SAFETY: the pointer is non-null and this `ComPtr` owns a reference, so the object
        // outlives the returned borrow.
        unsafe { &*raw }
    }
}
impl<T> DerefMut for ComPtr<T> {
    /// Mutably borrows the wrapped interface for as long as the `ComPtr` is mutably borrowed.
    #[inline]
    fn deref_mut(&mut self) -> &mut T {
        let raw: *mut T = self.0.as_ptr();
        // SAFETY: non-null, kept alive by the owned reference, and `&mut self` guarantees
        // no other borrow of this `ComPtr` is active.
        unsafe { &mut *raw }
    }
}
impl<T> Clone for ComPtr<T> {
    #[inline]
    fn clone(&self) -> Self {
        unsafe { 
            self.as_unknown().AddRef();
            ComPtr::wrap(self.0.as_ptr())
        }
    }
}
impl<T> Drop for ComPtr<T> {
    #[inline]
    fn drop(&mut self) {
        unsafe { self.as_unknown().Release() };
    }
}
impl<T> PartialEq<ComPtr<T>> for ComPtr<T> {
    /// Two `ComPtr`s compare equal when they hold the very same interface pointer.
    /// This is pointer identity of the `T` interface, not a `QueryInterface`-based
    /// object-identity check.
    #[inline]
    fn eq(&self, other: &ComPtr<T>) -> bool {
        let lhs: *mut T = self.0.as_ptr();
        let rhs: *mut T = other.0.as_ptr();
        lhs == rhs
    }
}

/// Owned array returned by WinRT methods whose result is an array.
///
/// The memory block is allocated by WinRT. When the `ComArray` is dropped, its elements
/// are dropped in place and the block is released with `CoTaskMemFree`.
pub struct ComArray<T: ::RtType> {
    /// Number of elements in the block.
    size: u32,
    /// Pointer to the first element, in its ABI representation.
    first: ptr::NonNull<T::Abi>,
}

impl<T> ComArray<T> where T: ::RtType {
    /// Takes ownership of a WinRT-allocated array of `size` elements starting at `first`.
    ///
    /// # Safety
    ///
    /// `first` must point to `size` initialized elements inside a block allocated with the
    /// COM task allocator. Ownership of the block and of its elements passes to the returned
    /// `ComArray`, which drops the elements and frees the block with `CoTaskMemFree`.
    ///
    /// # Panics
    ///
    /// Panics if `first` is null.
    #[inline]
    pub unsafe fn from_raw(size: u32, first: *mut T::Abi) -> ComArray<T> {
        assert!(!first.is_null());
        ComArray {
            size,
            first: ptr::NonNull::new_unchecked(first)
        }
    }

    /// Returns the length of the array.
    #[inline]
    pub fn len(&self) -> usize {
        self.size as usize
    }

    /// Returns `true` if the array contains no elements.
    ///
    /// Pairs with `len` (clippy `len_without_is_empty`); gives the same answer as the
    /// slice method reachable through `Deref`.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.size == 0
    }
}

impl<T> Deref for ComArray<T> where T: ::RtType {
    type Target = [T::OutNonNull];

    /// Views the WinRT-allocated block as a slice of `size` elements.
    #[inline]
    fn deref(&self) -> &[T::OutNonNull] {
        // NOTE(review): reinterpreting `T::Abi` storage as `T::OutNonNull` assumes the two
        // types share a layout — confirm against the `RtType` implementations.
        let data = self.first.as_ptr() as *mut T::OutNonNull;
        let count = self.size as usize;
        unsafe { ::std::slice::from_raw_parts(data, count) }
    }
}
impl<T> DerefMut for ComArray<T> where T: ::RtType {
    /// Views the WinRT-allocated block as a mutable slice of `size` elements.
    #[inline]
    fn deref_mut(&mut self) -> &mut [T::OutNonNull] {
        // NOTE(review): same `T::Abi` -> `T::OutNonNull` layout assumption as `deref`.
        let data = self.first.as_ptr() as *mut T::OutNonNull;
        let count = self.size as usize;
        unsafe { ::std::slice::from_raw_parts_mut(data, count) }
    }
}

impl<T> Drop for ComArray<T> where T: ::RtType {
    #[inline]
    fn drop(&mut self) {
        unsafe {
            ::std::ptr::drop_in_place(&mut self[..]);
            CoTaskMemFree(self.first.as_ptr() as LPVOID)
        };
    }
}

mod extra {
    // Makes sure that compilation fails when `ComPtr` is not pointer-sized (historically:
    // when a compiler version is used that still has drop flags): `transmute` is rejected at
    // compile time whenever source and destination sizes differ.
    //
    // The function is never called; it only has to type-check, hence `allow(dead_code)`.
    // It transmutes *from* a `ComPtr` into a raw pointer rather than from a null raw pointer
    // into a `ComPtr`, so the body would still be sound if it were ever called. It just
    // hands the reference back as a raw pointer, instead of producing a null `NonNull`
    // (undefined behavior) whose drop would call `Release` on null.
    #[allow(dead_code)]
    #[inline]
    fn assert_no_dropflags(p: ::ComPtr<::IInspectable>) -> *mut ::IInspectable {
        unsafe { ::std::mem::transmute(p) }
    }
}

#[cfg(test)]
mod tests {
    // No `extern crate test;` here: nothing in this module uses the unstable `test` crate,
    // and linking it would tie the test build to nightly's `feature(test)`.

    #[test]
    fn check_sizes() {
        use ::std::mem::size_of;

        // make sure that ComPtr is pointer-sized
        assert_eq!(size_of::<::ComPtr<::IInspectable>>(), size_of::<*mut ::IInspectable>());
        // and that the NonNull niche makes Option<ComPtr> free
        assert_eq!(size_of::<Option<::ComPtr<::IInspectable>>>(), size_of::<*mut ::IInspectable>());
    }
}