windows_capture/
capture.rs

1use std::{
2    mem,
3    os::windows::prelude::AsRawHandle,
4    sync::{
5        atomic::{self, AtomicBool},
6        mpsc, Arc,
7    },
8    thread::{self, JoinHandle},
9};
10
11use parking_lot::Mutex;
12use windows::{
13    Foundation::AsyncActionCompletedHandler,
14    Graphics::Capture::GraphicsCaptureItem,
15    Win32::{
16        Foundation::{HANDLE, LPARAM, WPARAM},
17        Graphics::Direct3D11::{ID3D11Device, ID3D11DeviceContext},
18        System::{
19            Threading::{GetCurrentThreadId, GetThreadId},
20            WinRT::{
21                CreateDispatcherQueueController, DispatcherQueueOptions, RoInitialize,
22                RoUninitialize, DQTAT_COM_NONE, DQTYPE_THREAD_CURRENT, RO_INIT_MULTITHREADED,
23            },
24        },
25        UI::WindowsAndMessaging::{
26            DispatchMessageW, GetMessageW, PostQuitMessage, PostThreadMessageW, TranslateMessage,
27            MSG, WM_QUIT,
28        },
29    },
30};
31
32use crate::{
33    d3d11::{self, create_d3d_device},
34    frame::Frame,
35    graphics_capture_api::{self, GraphicsCaptureApi, InternalCaptureControl},
36    settings::Settings,
37};
38
39#[derive(thiserror::Error, Debug)]
40pub enum CaptureControlError<E> {
41    #[error("Failed to join thread")]
42    FailedToJoinThread,
43    #[error("Thread handle is taken out of struct")]
44    ThreadHandleIsTaken,
45    #[error("Failed to post thread message")]
46    FailedToPostThreadMessage,
47    #[error("Stopped handler error: {0}")]
48    StoppedHandlerError(E),
49    #[error("Windows capture error: {0}")]
50    GraphicsCaptureApiError(#[from] GraphicsCaptureApiError<E>),
51}
52
53/// Used to control the capture session
54pub struct CaptureControl<T: GraphicsCaptureApiHandler + Send + 'static, E> {
55    thread_handle: Option<JoinHandle<Result<(), GraphicsCaptureApiError<E>>>>,
56    halt_handle: Arc<AtomicBool>,
57    callback: Arc<Mutex<T>>,
58}
59
60impl<T: GraphicsCaptureApiHandler + Send + 'static, E> CaptureControl<T, E> {
61    /// Creates a new Capture Control struct.
62    ///
63    /// # Arguments
64    ///
65    /// * `thread_handle` - The join handle for the capture thread.
66    /// * `halt_handle` - The atomic boolean used to pause the capture thread.
67    /// * `callback` - The mutex-protected callback struct used to call struct methods directly.
68    ///
69    /// # Returns
70    ///
71    /// The newly created CaptureControl struct.
72    #[must_use]
73    #[inline]
74    pub const fn new(
75        thread_handle: JoinHandle<Result<(), GraphicsCaptureApiError<E>>>,
76        halt_handle: Arc<AtomicBool>,
77        callback: Arc<Mutex<T>>,
78    ) -> Self {
79        Self {
80            thread_handle: Some(thread_handle),
81            halt_handle,
82            callback,
83        }
84    }
85
86    /// Checks to see if the capture thread is finished.
87    ///
88    /// # Returns
89    ///
90    /// `true` if the capture thread is finished, `false` otherwise.
91    #[must_use]
92    #[inline]
93    pub fn is_finished(&self) -> bool {
94        self.thread_handle
95            .as_ref()
96            .map_or(true, std::thread::JoinHandle::is_finished)
97    }
98
99    /// Gets the join handle for the capture thread.
100    ///
101    /// # Returns
102    ///
103    /// The join handle for the capture thread.
104    #[must_use]
105    #[inline]
106    pub fn into_thread_handle(self) -> JoinHandle<Result<(), GraphicsCaptureApiError<E>>> {
107        self.thread_handle.unwrap()
108    }
109
110    /// Gets the halt handle used to pause the capture thread.
111    ///
112    /// # Returns
113    ///
114    /// The halt handle used to pause the capture thread.
115    #[must_use]
116    #[inline]
117    pub fn halt_handle(&self) -> Arc<AtomicBool> {
118        self.halt_handle.clone()
119    }
120
121    /// Gets the callback struct used to call struct methods directly.
122    ///
123    /// # Returns
124    ///
125    /// The callback struct used to call struct methods directly.
126    #[must_use]
127    #[inline]
128    pub fn callback(&self) -> Arc<Mutex<T>> {
129        self.callback.clone()
130    }
131
132    /// Waits until the capturing thread stops.
133    ///
134    /// # Returns
135    ///
136    /// `Ok(())` if the capturing thread stops successfully, an error otherwise.
137    #[inline]
138    pub fn wait(mut self) -> Result<(), CaptureControlError<E>> {
139        if let Some(thread_handle) = self.thread_handle.take() {
140            match thread_handle.join() {
141                Ok(result) => result?,
142                Err(_) => {
143                    return Err(CaptureControlError::FailedToJoinThread);
144                }
145            }
146        } else {
147            return Err(CaptureControlError::ThreadHandleIsTaken);
148        }
149
150        Ok(())
151    }
152
153    /// Gracefully stops the capture thread.
154    ///
155    /// # Returns
156    ///
157    /// `Ok(())` if the capture thread stops successfully, an error otherwise.
158    #[inline]
159    pub fn stop(mut self) -> Result<(), CaptureControlError<E>> {
160        self.halt_handle.store(true, atomic::Ordering::Relaxed);
161
162        if let Some(thread_handle) = self.thread_handle.take() {
163            let handle = thread_handle.as_raw_handle();
164            let handle = HANDLE(handle);
165            let therad_id = unsafe { GetThreadId(handle) };
166
167            loop {
168                match unsafe {
169                    PostThreadMessageW(therad_id, WM_QUIT, WPARAM::default(), LPARAM::default())
170                } {
171                    Ok(()) => break,
172                    Err(e) => {
173                        if thread_handle.is_finished() {
174                            break;
175                        }
176
177                        if e.code().0 != -2_147_023_452 {
178                            Err(e).map_err(|_| CaptureControlError::FailedToPostThreadMessage)?;
179                        }
180                    }
181                }
182            }
183
184            match thread_handle.join() {
185                Ok(result) => result?,
186                Err(_) => {
187                    return Err(CaptureControlError::FailedToJoinThread);
188                }
189            }
190        } else {
191            return Err(CaptureControlError::ThreadHandleIsTaken);
192        }
193
194        Ok(())
195    }
196}
197
198#[derive(thiserror::Error, Eq, PartialEq, Clone, Debug)]
199pub enum GraphicsCaptureApiError<E> {
200    #[error("Failed to join thread")]
201    FailedToJoinThread,
202    #[error("Failed to initialize WinRT")]
203    FailedToInitWinRT,
204    #[error("Failed to create dispatcher queue controller")]
205    FailedToCreateDispatcherQueueController,
206    #[error("Failed to shutdown dispatcher queue")]
207    FailedToShutdownDispatcherQueue,
208    #[error("Failed to set dispatcher queue completed handler")]
209    FailedToSetDispatcherQueueCompletedHandler,
210    #[error("Failed to convert item to GraphicsCaptureItem")]
211    ItemConvertFailed,
212    #[error("DirectX error: {0}")]
213    DirectXError(#[from] d3d11::Error),
214    #[error("Graphics capture error: {0}")]
215    GraphicsCaptureApiError(graphics_capture_api::Error),
216    #[error("New handler error: {0}")]
217    NewHandlerError(E),
218    #[error("Frame handler error: {0}")]
219    FrameHandlerError(E),
220}
221
222/// A struct representing the context of the capture handler.
223pub struct Context<Flags> {
224    /// The flags that are gotten from the settings.
225    pub flags: Flags,
226    /// The direct3d device and context.
227    pub device: ID3D11Device,
228    /// The direct3d device context.
229    pub device_context: ID3D11DeviceContext,
230}
231
232/// A trait representing a graphics capture handler.
233pub trait GraphicsCaptureApiHandler: Sized {
234    /// The type of flags used to get the values from the settings.
235    type Flags;
236
237    /// The type of error that can occur during capture, the error will be returned from `CaptureControl` and `start` functions.
238    type Error: Send + Sync;
239
240    /// Starts the capture and takes control of the current thread.
241    ///
242    /// # Arguments
243    ///
244    /// * `settings` - The capture settings.
245    ///
246    /// # Returns
247    ///
248    /// Returns `Ok(())` if the capture was successful, otherwise returns an error of type `GraphicsCaptureApiError`.
249    #[inline]
250    fn start<T: TryInto<GraphicsCaptureItem>>(
251        settings: Settings<Self::Flags, T>,
252    ) -> Result<(), GraphicsCaptureApiError<Self::Error>>
253    where
254        Self: Send + 'static,
255        <Self as GraphicsCaptureApiHandler>::Flags: Send,
256    {
257        // Initialize WinRT
258        unsafe {
259            RoInitialize(RO_INIT_MULTITHREADED)
260                .map_err(|_| GraphicsCaptureApiError::FailedToInitWinRT)?;
261        };
262
263        // Create a dispatcher queue for the current thread
264        let options = DispatcherQueueOptions {
265            dwSize: u32::try_from(mem::size_of::<DispatcherQueueOptions>()).unwrap(),
266            threadType: DQTYPE_THREAD_CURRENT,
267            apartmentType: DQTAT_COM_NONE,
268        };
269        let controller = unsafe {
270            CreateDispatcherQueueController(options)
271                .map_err(|_| GraphicsCaptureApiError::FailedToCreateDispatcherQueueController)?
272        };
273
274        // Get current thread ID
275        let thread_id = unsafe { GetCurrentThreadId() };
276
277        // Create direct3d device and context
278        let (d3d_device, d3d_device_context) = create_d3d_device()?;
279
280        // Start capture
281        let result = Arc::new(Mutex::new(None));
282
283        let ctx = Context {
284            flags: settings.flags,
285            device: d3d_device.clone(),
286            device_context: d3d_device_context.clone(),
287        };
288
289        let callback = Arc::new(Mutex::new(
290            Self::new(ctx).map_err(GraphicsCaptureApiError::NewHandlerError)?,
291        ));
292
293        let item = settings
294            .item
295            .try_into()
296            .map_err(|_| GraphicsCaptureApiError::ItemConvertFailed)?;
297
298        let mut capture = GraphicsCaptureApi::new(
299            d3d_device,
300            d3d_device_context,
301            item,
302            callback,
303            settings.cursor_capture,
304            settings.draw_border,
305            settings.color_format,
306            thread_id,
307            result.clone(),
308        )
309        .map_err(GraphicsCaptureApiError::GraphicsCaptureApiError)?;
310        capture
311            .start_capture()
312            .map_err(GraphicsCaptureApiError::GraphicsCaptureApiError)?;
313
314        // Message loop
315        let mut message = MSG::default();
316        unsafe {
317            while GetMessageW(&mut message, None, 0, 0).as_bool() {
318                let _ = TranslateMessage(&message);
319                DispatchMessageW(&message);
320            }
321        }
322
323        // Shutdown dispatcher queue
324        let async_action = controller
325            .ShutdownQueueAsync()
326            .map_err(|_| GraphicsCaptureApiError::FailedToShutdownDispatcherQueue)?;
327        async_action
328            .SetCompleted(&AsyncActionCompletedHandler::new(
329                move |_, _| -> Result<(), windows::core::Error> {
330                    unsafe { PostQuitMessage(0) };
331                    Ok(())
332                },
333            ))
334            .map_err(|_| GraphicsCaptureApiError::FailedToSetDispatcherQueueCompletedHandler)?;
335
336        // Final message loop
337        let mut message = MSG::default();
338        unsafe {
339            while GetMessageW(&mut message, None, 0, 0).as_bool() {
340                let _ = TranslateMessage(&message);
341                DispatchMessageW(&message);
342            }
343        }
344
345        // Stop capture
346        capture.stop_capture();
347
348        // Uninitialize WinRT
349        unsafe { RoUninitialize() };
350
351        // Check handler result
352        let result = result.lock().take();
353        if let Some(e) = result {
354            return Err(GraphicsCaptureApiError::FrameHandlerError(e));
355        }
356
357        Ok(())
358    }
359
360    /// Starts the capture without taking control of the current thread.
361    ///
362    /// # Arguments
363    ///
364    /// * `settings` - The capture settings.
365    ///
366    /// # Returns
367    ///
368    /// Returns `Ok(CaptureControl)` if the capture was successful, otherwise returns an error of type `GraphicsCaptureApiError`.
369    #[inline]
370    fn start_free_threaded<T: TryInto<GraphicsCaptureItem> + Send + 'static>(
371        settings: Settings<Self::Flags, T>,
372    ) -> Result<CaptureControl<Self, Self::Error>, GraphicsCaptureApiError<Self::Error>>
373    where
374        Self: Send + 'static,
375        <Self as GraphicsCaptureApiHandler>::Flags: Send,
376    {
377        let (halt_sender, halt_receiver) = mpsc::channel::<Arc<AtomicBool>>();
378        let (callback_sender, callback_receiver) = mpsc::channel::<Arc<Mutex<Self>>>();
379
380        let thread_handle = thread::spawn(
381            move || -> Result<(), GraphicsCaptureApiError<Self::Error>> {
382                // Initialize WinRT
383                unsafe {
384                    RoInitialize(RO_INIT_MULTITHREADED)
385                        .map_err(|_| GraphicsCaptureApiError::FailedToInitWinRT)?;
386                };
387
388                // Create a dispatcher queue for the current thread
389                let options = DispatcherQueueOptions {
390                    dwSize: u32::try_from(mem::size_of::<DispatcherQueueOptions>()).unwrap(),
391                    threadType: DQTYPE_THREAD_CURRENT,
392                    apartmentType: DQTAT_COM_NONE,
393                };
394                let controller = unsafe {
395                    CreateDispatcherQueueController(options).map_err(|_| {
396                        GraphicsCaptureApiError::FailedToCreateDispatcherQueueController
397                    })?
398                };
399
400                // Get current thread ID
401                let thread_id = unsafe { GetCurrentThreadId() };
402
403                // Create direct3d device and context
404                let (d3d_device, d3d_device_context) = create_d3d_device()?;
405
406                // Start capture
407                let result = Arc::new(Mutex::new(None));
408
409                let ctx = Context {
410                    flags: settings.flags,
411                    device: d3d_device.clone(),
412                    device_context: d3d_device_context.clone(),
413                };
414
415                let callback = Arc::new(Mutex::new(
416                    Self::new(ctx).map_err(GraphicsCaptureApiError::NewHandlerError)?,
417                ));
418
419                let item = settings
420                    .item
421                    .try_into()
422                    .map_err(|_| GraphicsCaptureApiError::ItemConvertFailed)?;
423
424                let mut capture = GraphicsCaptureApi::new(
425                    d3d_device,
426                    d3d_device_context,
427                    item,
428                    callback.clone(),
429                    settings.cursor_capture,
430                    settings.draw_border,
431                    settings.color_format,
432                    thread_id,
433                    result.clone(),
434                )
435                .map_err(GraphicsCaptureApiError::GraphicsCaptureApiError)?;
436                capture
437                    .start_capture()
438                    .map_err(GraphicsCaptureApiError::GraphicsCaptureApiError)?;
439
440                // Send halt handle
441                let halt_handle = capture.halt_handle();
442                halt_sender.send(halt_handle).unwrap();
443
444                // Send callback
445                callback_sender.send(callback).unwrap();
446
447                // Message loop
448                let mut message = MSG::default();
449                unsafe {
450                    while GetMessageW(&mut message, None, 0, 0).as_bool() {
451                        let _ = TranslateMessage(&message);
452                        DispatchMessageW(&message);
453                    }
454                }
455
456                // Shutdown dispatcher queue
457                let async_action = controller
458                    .ShutdownQueueAsync()
459                    .map_err(|_| GraphicsCaptureApiError::FailedToShutdownDispatcherQueue)?;
460
461                async_action
462                    .SetCompleted(&AsyncActionCompletedHandler::new(
463                        move |_, _| -> Result<(), windows::core::Error> {
464                            unsafe { PostQuitMessage(0) };
465                            Ok(())
466                        },
467                    ))
468                    .map_err(|_| {
469                        GraphicsCaptureApiError::FailedToSetDispatcherQueueCompletedHandler
470                    })?;
471
472                // Final message loop
473                let mut message = MSG::default();
474                unsafe {
475                    while GetMessageW(&mut message, None, 0, 0).as_bool() {
476                        let _ = TranslateMessage(&message);
477                        DispatchMessageW(&message);
478                    }
479                }
480
481                // Stop capture
482                capture.stop_capture();
483
484                // Uninitialize WinRT
485                unsafe { RoUninitialize() };
486
487                // Check handler result
488                let result = result.lock().take();
489                if let Some(e) = result {
490                    return Err(GraphicsCaptureApiError::FrameHandlerError(e));
491                }
492
493                Ok(())
494            },
495        );
496
497        let Ok(halt_handle) = halt_receiver.recv() else {
498            match thread_handle.join() {
499                Ok(result) => return Err(result.err().unwrap()),
500                Err(_) => {
501                    return Err(GraphicsCaptureApiError::FailedToJoinThread);
502                }
503            }
504        };
505
506        let Ok(callback) = callback_receiver.recv() else {
507            match thread_handle.join() {
508                Ok(result) => return Err(result.err().unwrap()),
509                Err(_) => {
510                    return Err(GraphicsCaptureApiError::FailedToJoinThread);
511                }
512            }
513        };
514
515        Ok(CaptureControl::new(thread_handle, halt_handle, callback))
516    }
517
518    /// Function that will be called to create the struct. The flags can be passed from settings.
519    ///
520    /// # Arguments
521    ///
522    /// * `flags` - The flags used to create the struct.
523    ///
524    /// # Returns
525    ///
526    /// Returns `Ok(Self)` if the struct creation was successful, otherwise returns an error of type `Self::Error`.
527    fn new(ctx: Context<Self::Flags>) -> Result<Self, Self::Error>;
528
529    /// Called every time a new frame is available.
530    ///
531    /// # Arguments
532    ///
533    /// * `frame` - A mutable reference to the captured frame.
534    /// * `capture_control` - The internal capture control.
535    ///
536    /// # Returns
537    ///
538    /// Returns `Ok(())` if the frame processing was successful, otherwise returns an error of type `Self::Error`.
539    fn on_frame_arrived(
540        &mut self,
541        frame: &mut Frame,
542        capture_control: InternalCaptureControl,
543    ) -> Result<(), Self::Error>;
544
545    /// Optional handler called when the capture item (usually a window) closes.
546    ///
547    /// # Returns
548    ///
549    /// Returns `Ok(())` if the handler execution was successful, otherwise returns an error of type `Self::Error`.
550    #[inline]
551    fn on_closed(&mut self) -> Result<(), Self::Error> {
552        Ok(())
553    }
554}