1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
use crate::*;

/// Generates async support for `def` when it is, or implements, one of the
/// `Windows.Foundation` async interfaces recognized by `async_kind`.
///
/// The type's own definition is checked first. Failing that, the first entry
/// of `interfaces` that `async_kind` recognizes is used. The result is the
/// `(inherent methods, trait impls)` pair produced by `gen_async_kind`, or two
/// empty token streams when nothing matches.
pub fn gen_async(
    def: &tables::TypeDef,
    interfaces: &[InterfaceInfo],
    gen: &Gen,
) -> (TokenStream, TokenStream) {
    // The type itself is one of the async interfaces.
    let direct = async_kind(def);
    if direct != AsyncKind::None {
        return gen_async_kind(direct, def, def, gen);
    }

    // Otherwise, take the first implemented interface that is one.
    let inherited = interfaces
        .iter()
        .find_map(|interface| match async_kind(&interface.def) {
            AsyncKind::None => None,
            found => Some((found, &interface.def)),
        });

    match inherited {
        Some((found, base)) => gen_async_kind(found, base, def, gen),
        None => (TokenStream::new(), TokenStream::new()),
    }
}

/// Which `Windows.Foundation` async interface a type corresponds to, as
/// classified by `async_kind`.
///
/// `Clone`, `Copy`, `Debug` and `Eq` are derived alongside `PartialEq`. The
/// enum is a plain fieldless tag, so it can be copied freely and printed in
/// diagnostics.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AsyncKind {
    /// Not one of the recognized async interfaces.
    None,
    /// `IAsyncAction` (no result value).
    Action,
    /// `IAsyncActionWithProgress` (one generic parameter).
    ActionWithProgress,
    /// `IAsyncOperation` (one generic parameter: the result type).
    Operation,
    /// `IAsyncOperationWithProgress` (two generic parameters, the first being the result type).
    OperationWithProgress,
}

/// Classifies `def` as one of the `Windows.Foundation` async interfaces.
///
/// Returns `AsyncKind::None` for any type outside the `Windows.Foundation`
/// namespace, or whose name is not one of the four recognized interfaces.
pub fn async_kind(def: &tables::TypeDef) -> AsyncKind {
    if def.namespace() == "Windows.Foundation" {
        // Generic interfaces are matched by their metadata name, which includes
        // the backtick generic-arity suffix (e.g. "IAsyncOperation`1").
        match def.name() {
            "IAsyncAction" => AsyncKind::Action,
            "IAsyncActionWithProgress`1" => AsyncKind::ActionWithProgress,
            "IAsyncOperation`1" => AsyncKind::Operation,
            "IAsyncOperationWithProgress`2" => AsyncKind::OperationWithProgress,
            _ => AsyncKind::None,
        }
    } else {
        AsyncKind::None
    }
}

/// Generates async support for a type backed by a `Windows.Foundation` async interface.
///
/// * `kind` - which async interface was matched; must not be `AsyncKind::None`.
/// * `name` - the matched async interface definition. For the operation kinds,
///   its first generic argument supplies the result type.
/// * `self_name` - the type the code is generated for (the async interface itself
///   or a type implementing it). Its generated name and constraints are used in
///   the `Future` impl.
///
/// Returns a pair of token streams:
/// 1. an inherent `get(&self)` method that, if `Status()` is still `Started`,
///    registers a completion handler that signals a `::windows::Waiter`, then
///    returns `GetResults()`;
/// 2. an `impl ::std::future::Future` for the type, whose output is `GetResults()`.
///
/// # Panics
/// Panics (via `unreachable!`) if `kind` is `AsyncKind::None`.
fn gen_async_kind(
    kind: AsyncKind,
    name: &tables::TypeDef,
    self_name: &tables::TypeDef,
    gen: &Gen,
) -> (TokenStream, TokenStream) {
    // Actions complete with no value; operations yield their first generic argument.
    // Both matches list every variant (no `_` arm) so a new `AsyncKind` variant
    // is a compile error here rather than a silent fall-through.
    let return_type = match kind {
        AsyncKind::Operation | AsyncKind::OperationWithProgress => name.generics[0].gen_name(gen),
        AsyncKind::None | AsyncKind::Action | AsyncKind::ActionWithProgress => quote! { () },
    };

    let handler = match kind {
        AsyncKind::Action => quote! { AsyncActionCompletedHandler },
        AsyncKind::ActionWithProgress => quote! { AsyncActionWithProgressCompletedHandler },
        AsyncKind::Operation => quote! { AsyncOperationCompletedHandler },
        AsyncKind::OperationWithProgress => quote! { AsyncOperationWithProgressCompletedHandler },
        AsyncKind::None => unreachable!("Unexpected AsyncKind"),
    };

    let constraints = self_name.gen_constraints(gen);
    // Generated name of the type being extended (kept distinct from the `name` parameter).
    let type_name = self_name.gen_name(gen);
    let namespace = gen.namespace("Windows.Foundation");

    // NOTE(review): `waiter` is dropped at the end of the generated `if` block.
    // Presumably `::windows::Waiter`'s drop blocks until `signaler.signal()` runs,
    // which is what makes the `unsafe` call sound. Confirm against `Waiter`.
    let blocking_get = quote! {
        pub fn get(&self) -> ::windows::Result<#return_type> {
            if self.Status()? == #namespace AsyncStatus::Started {
                let (waiter, signaler) = ::windows::Waiter::new();
                self.SetCompleted(#namespace #handler::new(move |_sender, _args| {
                    // Safe because the waiter will only be dropped after being signaled.
                    unsafe { signaler.signal(); }
                    Ok(())
                }))?;
            }
            self.GetResults()
        }
    };

    // NOTE(review): the `SetCompleted` error is deliberately ignored. If the WinRT
    // object accepts only one completion handler, a re-poll before completion would
    // fail to register the newer waker, and only the first poll's waker would be
    // woken. Confirm whether this matters for the executors in use.
    let future_impl = quote! {
        impl<#constraints> ::std::future::Future for #type_name {
            type Output = ::windows::Result<#return_type>;

            fn poll(self: ::std::pin::Pin<&mut Self>, context: &mut ::std::task::Context<'_>) -> ::std::task::Poll<Self::Output> {
                if self.Status()? == #namespace AsyncStatus::Started {
                    let waker = context.waker().clone();

                    let _ = self.SetCompleted(#namespace #handler::new(move |_sender, _args| {
                        waker.wake_by_ref();
                        Ok(())
                    }));

                    ::std::task::Poll::Pending
                } else {
                    ::std::task::Poll::Ready(self.GetResults())
                }
            }
        }
    };

    (blocking_get, future_impl)
}