windows_capture/
capture.rs

1use std::mem;
2use std::os::windows::prelude::AsRawHandle;
3use std::sync::atomic::{self, AtomicBool};
4use std::sync::{Arc, OnceLock, mpsc};
5use std::thread::{self, JoinHandle};
6
7use parking_lot::Mutex;
8use windows::Win32::Foundation::{HANDLE, LPARAM, S_FALSE, WPARAM};
9use windows::Win32::Graphics::Direct3D11::{ID3D11Device, ID3D11DeviceContext};
10use windows::Win32::System::Com::CoIncrementMTAUsage;
11use windows::Win32::System::Threading::{GetCurrentThreadId, GetThreadId};
12use windows::Win32::System::WinRT::{
13    CreateDispatcherQueueController, DQTAT_COM_NONE, DQTYPE_THREAD_CURRENT, DispatcherQueueOptions,
14    RO_INIT_MULTITHREADED, RoInitialize,
15};
16use windows::Win32::UI::WindowsAndMessaging::{
17    DispatchMessageW, GetMessageW, MSG, PostQuitMessage, PostThreadMessageW, TranslateMessage,
18    WM_QUIT,
19};
20use windows::core::Result as WindowsResult;
21use windows_future::AsyncActionCompletedHandler;
22
23use crate::d3d11::{self, create_d3d_device};
24use crate::frame::Frame;
25use crate::graphics_capture_api::{self, GraphicsCaptureApi, InternalCaptureControl};
26use crate::settings::{Settings, TryIntoCaptureItemWithType};
27
28#[derive(thiserror::Error, Debug)]
29pub enum CaptureControlError<E> {
30    #[error("Failed to join thread")]
31    FailedToJoinThread,
32    #[error("Thread handle is taken out of the struct")]
33    ThreadHandleIsTaken,
34    #[error("Failed to post thread message")]
35    FailedToPostThreadMessage,
36    #[error("Stopped handler error: {0}")]
37    StoppedHandlerError(E),
38    #[error("Windows capture error: {0}")]
39    GraphicsCaptureApiError(#[from] GraphicsCaptureApiError<E>),
40}
41
42/// Used to control the capture session
43pub struct CaptureControl<T: GraphicsCaptureApiHandler + Send + 'static, E> {
44    thread_handle: Option<JoinHandle<Result<(), GraphicsCaptureApiError<E>>>>,
45    halt_handle: Arc<AtomicBool>,
46    callback: Arc<Mutex<T>>,
47}
48
49impl<T: GraphicsCaptureApiHandler + Send + 'static, E> CaptureControl<T, E> {
50    /// Creates a new Capture Control struct.
51    ///
52    /// # Arguments
53    ///
54    /// * `thread_handle` - The join handle for the capture thread.
55    /// * `halt_handle` - The atomic boolean used to pause the capture thread.
56    /// * `callback` - The mutex-protected callback struct used to call struct
57    ///   methods directly.
58    ///
59    /// # Returns
60    ///
61    /// The newly created `CaptureControl` struct.
62    #[must_use]
63    #[inline]
64    pub const fn new(
65        thread_handle: JoinHandle<Result<(), GraphicsCaptureApiError<E>>>,
66        halt_handle: Arc<AtomicBool>,
67        callback: Arc<Mutex<T>>,
68    ) -> Self {
69        Self { thread_handle: Some(thread_handle), halt_handle, callback }
70    }
71
72    /// Checks to see if the capture thread is finished.
73    ///
74    /// # Returns
75    ///
76    /// `true` if the capture thread is finished, `false` otherwise.
77    #[must_use]
78    #[inline]
79    pub fn is_finished(&self) -> bool {
80        self.thread_handle.as_ref().is_none_or(std::thread::JoinHandle::is_finished)
81    }
82
83    /// Gets the join handle for the capture thread.
84    ///
85    /// # Returns
86    ///
87    /// The join handle for the capture thread.
88    #[must_use]
89    #[inline]
90    pub fn into_thread_handle(self) -> JoinHandle<Result<(), GraphicsCaptureApiError<E>>> {
91        self.thread_handle.unwrap()
92    }
93
94    /// Gets the halt handle used to pause the capture thread.
95    ///
96    /// # Returns
97    ///
98    /// The halt handle used to pause the capture thread.
99    #[must_use]
100    #[inline]
101    pub fn halt_handle(&self) -> Arc<AtomicBool> {
102        self.halt_handle.clone()
103    }
104
105    /// Gets the callback struct used to call struct methods directly.
106    ///
107    /// # Returns
108    ///
109    /// The callback struct used to call struct methods directly.
110    #[must_use]
111    #[inline]
112    pub fn callback(&self) -> Arc<Mutex<T>> {
113        self.callback.clone()
114    }
115
116    /// Waits for the capturing thread to stop.
117    ///
118    /// # Returns
119    ///
120    /// `Ok(())` if the capturing thread stops successfully, an error otherwise.
121    #[inline]
122    pub fn wait(mut self) -> Result<(), CaptureControlError<E>> {
123        if let Some(thread_handle) = self.thread_handle.take() {
124            match thread_handle.join() {
125                Ok(result) => result?,
126                Err(_) => {
127                    return Err(CaptureControlError::FailedToJoinThread);
128                }
129            }
130        } else {
131            return Err(CaptureControlError::ThreadHandleIsTaken);
132        }
133
134        Ok(())
135    }
136
137    /// Gracefully stops the capture thread.
138    ///
139    /// # Returns
140    ///
141    /// `Ok(())` if the capture thread stops successfully, an error otherwise.
142    #[inline]
143    pub fn stop(mut self) -> Result<(), CaptureControlError<E>> {
144        self.halt_handle.store(true, atomic::Ordering::Relaxed);
145
146        if let Some(thread_handle) = self.thread_handle.take() {
147            let handle = thread_handle.as_raw_handle();
148            let handle = HANDLE(handle);
149            let thread_id = unsafe { GetThreadId(handle) };
150
151            loop {
152                match unsafe {
153                    PostThreadMessageW(thread_id, WM_QUIT, WPARAM::default(), LPARAM::default())
154                } {
155                    Ok(()) => break,
156                    Err(e) => {
157                        if thread_handle.is_finished() {
158                            break;
159                        }
160
161                        if e.code().0 != -2_147_023_452 {
162                            Err(e).map_err(|_| CaptureControlError::FailedToPostThreadMessage)?;
163                        }
164                    }
165                }
166            }
167
168            match thread_handle.join() {
169                Ok(result) => result?,
170                Err(_) => {
171                    return Err(CaptureControlError::FailedToJoinThread);
172                }
173            }
174        } else {
175            return Err(CaptureControlError::ThreadHandleIsTaken);
176        }
177
178        Ok(())
179    }
180}
181
182#[derive(thiserror::Error, Eq, PartialEq, Clone, Debug)]
183pub enum GraphicsCaptureApiError<E> {
184    #[error("Failed to join thread")]
185    FailedToJoinThread,
186    #[error("Failed to initialize WinRT")]
187    FailedToInitWinRT,
188    #[error("Failed to create dispatcher queue controller")]
189    FailedToCreateDispatcherQueueController,
190    #[error("Failed to shut down dispatcher queue")]
191    FailedToShutdownDispatcherQueue,
192    #[error("Failed to set dispatcher queue completed handler")]
193    FailedToSetDispatcherQueueCompletedHandler,
194    #[error("Failed to convert item to `GraphicsCaptureItem`")]
195    ItemConvertFailed,
196    #[error("DirectX error: {0}")]
197    DirectXError(#[from] d3d11::Error),
198    #[error("Graphics capture error: {0}")]
199    GraphicsCaptureApiError(graphics_capture_api::Error),
200    #[error("New handler error: {0}")]
201    NewHandlerError(E),
202    #[error("Frame handler error: {0}")]
203    FrameHandlerError(E),
204}
205
206/// A struct representing the context of the capture handler.
207pub struct Context<Flags> {
208    /// The flags that are retrieved from the settings.
209    pub flags: Flags,
210    /// The Direct3D device.
211    pub device: ID3D11Device,
212    /// The Direct3D device context.
213    pub device_context: ID3D11DeviceContext,
214}
215
216/// A trait representing a graphics capture handler.
217pub trait GraphicsCaptureApiHandler: Sized {
218    /// The type of flags used to get the values from the settings.
219    type Flags;
220
221    /// The type of error that can occur during capture. The error will be returned from the `CaptureControl` and `start` functions.
222    type Error: Send + Sync;
223
224    /// Starts the capture and takes control of the current thread.
225    ///
226    /// # Arguments
227    ///
228    /// * `settings` - The capture settings.
229    ///
230    /// # Returns
231    ///
232    /// Returns `Ok(())` if the capture was successful; otherwise, it returns an error of type `GraphicsCaptureApiError`.
233    #[inline]
234    fn start<T: TryIntoCaptureItemWithType>(
235        settings: Settings<Self::Flags, T>,
236    ) -> Result<(), GraphicsCaptureApiError<Self::Error>>
237    where
238        Self: Send + 'static,
239        <Self as GraphicsCaptureApiHandler>::Flags: Send,
240    {
241        // Initialize WinRT
242        static INIT_MTA: OnceLock<()> = OnceLock::new();
243        INIT_MTA.get_or_init(|| {
244            unsafe {
245                CoIncrementMTAUsage().expect("Failed to increment MTA usage");
246            };
247        });
248
249        match unsafe { RoInitialize(RO_INIT_MULTITHREADED) } {
250            Ok(_) => (),
251            Err(e) => {
252                if e.code() == S_FALSE {
253                    // Already initialized
254                } else {
255                    return Err(GraphicsCaptureApiError::FailedToInitWinRT);
256                }
257            }
258        }
259
260        // Create a dispatcher queue for the current thread
261        let options = DispatcherQueueOptions {
262            dwSize: u32::try_from(mem::size_of::<DispatcherQueueOptions>()).unwrap(),
263            threadType: DQTYPE_THREAD_CURRENT,
264            apartmentType: DQTAT_COM_NONE,
265        };
266        let controller = unsafe {
267            CreateDispatcherQueueController(options)
268                .map_err(|_| GraphicsCaptureApiError::FailedToCreateDispatcherQueueController)?
269        };
270
271        // Get current thread ID
272        let thread_id = unsafe { GetCurrentThreadId() };
273
274        // Create direct3d device and context
275        let (d3d_device, d3d_device_context) = create_d3d_device()?;
276
277        // Start capture
278        let result = Arc::new(Mutex::new(None));
279
280        let ctx = Context {
281            flags: settings.flags,
282            device: d3d_device.clone(),
283            device_context: d3d_device_context.clone(),
284        };
285
286        let callback =
287            Arc::new(Mutex::new(Self::new(ctx).map_err(GraphicsCaptureApiError::NewHandlerError)?));
288
289        // Convert the item into a GraphicsCaptureItem and its type
290        let (item, item_type) = settings
291            .item
292            .try_into_capture_item()
293            .map_err(|_| GraphicsCaptureApiError::ItemConvertFailed)?;
294
295        let mut capture = GraphicsCaptureApi::new(
296            d3d_device,
297            d3d_device_context,
298            item,
299            item_type,
300            callback,
301            settings.cursor_capture_settings,
302            settings.draw_border_settings,
303            settings.secondary_window_settings,
304            settings.minimum_update_interval_settings,
305            settings.dirty_region_settings,
306            settings.color_format,
307            thread_id,
308            result.clone(),
309        )
310        .map_err(GraphicsCaptureApiError::GraphicsCaptureApiError)?;
311        capture.start_capture().map_err(GraphicsCaptureApiError::GraphicsCaptureApiError)?;
312
313        // Message loop
314        let mut message = MSG::default();
315        unsafe {
316            while GetMessageW(&mut message, None, 0, 0).as_bool() {
317                let _ = TranslateMessage(&message);
318                DispatchMessageW(&message);
319            }
320        }
321
322        // Shutdown dispatcher queue
323        let async_action = controller
324            .ShutdownQueueAsync()
325            .map_err(|_| GraphicsCaptureApiError::FailedToShutdownDispatcherQueue)?;
326
327        async_action
328            .SetCompleted(&AsyncActionCompletedHandler::new(move |_, _| -> WindowsResult<()> {
329                unsafe { PostQuitMessage(0) };
330                Ok(())
331            }))
332            .map_err(|_| GraphicsCaptureApiError::FailedToSetDispatcherQueueCompletedHandler)?;
333
334        // Final message loop
335        let mut message = MSG::default();
336        unsafe {
337            while GetMessageW(&mut message, None, 0, 0).as_bool() {
338                let _ = TranslateMessage(&message);
339                DispatchMessageW(&message);
340            }
341        }
342
343        // Stop capture
344        capture.stop_capture();
345
346        // Uninitialize WinRT
347        // unsafe { RoUninitialize() }; // Not sure if this is needed here
348
349        // Check handler result
350        let result = result.lock().take();
351        if let Some(e) = result {
352            return Err(GraphicsCaptureApiError::FrameHandlerError(e));
353        }
354
355        Ok(())
356    }
357
358    /// Starts the capture without taking control of the current thread.
359    ///
360    /// # Arguments
361    ///
362    /// * `settings` - The capture settings.
363    ///
364    /// # Returns
365    ///
366    /// Returns a `Result` containing the `CaptureControl` if the capture was successful; otherwise, it returns an error of type `GraphicsCaptureApiError`.
367    #[inline]
368    fn start_free_threaded<T: TryIntoCaptureItemWithType + Send + 'static>(
369        settings: Settings<Self::Flags, T>,
370    ) -> Result<CaptureControl<Self, Self::Error>, GraphicsCaptureApiError<Self::Error>>
371    where
372        Self: Send + 'static,
373        <Self as GraphicsCaptureApiHandler>::Flags: Send,
374    {
375        let (halt_sender, halt_receiver) = mpsc::channel::<Arc<AtomicBool>>();
376        let (callback_sender, callback_receiver) = mpsc::channel::<Arc<Mutex<Self>>>();
377
378        let thread_handle =
379            thread::spawn(move || -> Result<(), GraphicsCaptureApiError<Self::Error>> {
380                // Initialize WinRT
381                static INIT_MTA: OnceLock<()> = OnceLock::new();
382                INIT_MTA.get_or_init(|| {
383                    unsafe {
384                        CoIncrementMTAUsage().expect("Failed to increment MTA usage");
385                    };
386                });
387
388                match unsafe { RoInitialize(RO_INIT_MULTITHREADED) } {
389                    Ok(_) => (),
390                    Err(e) => {
391                        if e.code() == S_FALSE {
392                            // Already initialized
393                        } else {
394                            return Err(GraphicsCaptureApiError::FailedToInitWinRT);
395                        }
396                    }
397                }
398
399                // Create a dispatcher queue for the current thread
400                let options = DispatcherQueueOptions {
401                    dwSize: u32::try_from(mem::size_of::<DispatcherQueueOptions>()).unwrap(),
402                    threadType: DQTYPE_THREAD_CURRENT,
403                    apartmentType: DQTAT_COM_NONE,
404                };
405                let controller = unsafe {
406                    CreateDispatcherQueueController(options).map_err(|_| {
407                        GraphicsCaptureApiError::FailedToCreateDispatcherQueueController
408                    })?
409                };
410
411                // Get current thread ID
412                let thread_id = unsafe { GetCurrentThreadId() };
413
414                // Create direct3d device and context
415                let (d3d_device, d3d_device_context) = create_d3d_device()?;
416
417                // Start capture
418                let result = Arc::new(Mutex::new(None));
419
420                let ctx = Context {
421                    flags: settings.flags,
422                    device: d3d_device.clone(),
423                    device_context: d3d_device_context.clone(),
424                };
425
426                let callback = Arc::new(Mutex::new(
427                    Self::new(ctx).map_err(GraphicsCaptureApiError::NewHandlerError)?,
428                ));
429
430                // Convert the item into a GraphicsCaptureItem and its type
431                let (item, item_type) = settings
432                    .item
433                    .try_into_capture_item()
434                    .map_err(|_| GraphicsCaptureApiError::ItemConvertFailed)?;
435
436                let mut capture = GraphicsCaptureApi::new(
437                    d3d_device,
438                    d3d_device_context,
439                    item,
440                    item_type,
441                    callback.clone(),
442                    settings.cursor_capture_settings,
443                    settings.draw_border_settings,
444                    settings.secondary_window_settings,
445                    settings.minimum_update_interval_settings,
446                    settings.dirty_region_settings,
447                    settings.color_format,
448                    thread_id,
449                    result.clone(),
450                )
451                .map_err(GraphicsCaptureApiError::GraphicsCaptureApiError)?;
452
453                capture
454                    .start_capture()
455                    .map_err(GraphicsCaptureApiError::GraphicsCaptureApiError)?;
456
457                // Send halt handle
458                let halt_handle = capture.halt_handle();
459                halt_sender.send(halt_handle).unwrap();
460
461                // Send callback
462                callback_sender.send(callback).unwrap();
463
464                // Message loop
465                let mut message = MSG::default();
466                unsafe {
467                    while GetMessageW(&mut message, None, 0, 0).as_bool() {
468                        let _ = TranslateMessage(&message);
469                        DispatchMessageW(&message);
470                    }
471                }
472
473                // Shutdown dispatcher queue
474                let async_action = controller
475                    .ShutdownQueueAsync()
476                    .map_err(|_| GraphicsCaptureApiError::FailedToShutdownDispatcherQueue)?;
477
478                async_action
479                    .SetCompleted(&AsyncActionCompletedHandler::new(
480                        move |_, _| -> Result<(), windows::core::Error> {
481                            unsafe { PostQuitMessage(0) };
482                            Ok(())
483                        },
484                    ))
485                    .map_err(|_| {
486                        GraphicsCaptureApiError::FailedToSetDispatcherQueueCompletedHandler
487                    })?;
488
489                // Final message loop
490                let mut message = MSG::default();
491                unsafe {
492                    while GetMessageW(&mut message, None, 0, 0).as_bool() {
493                        let _ = TranslateMessage(&message);
494                        DispatchMessageW(&message);
495                    }
496                }
497
498                // Stop capture
499                capture.stop_capture();
500
501                // Uninitialize WinRT
502                // unsafe { RoUninitialize() }; // Not sure if this is needed here
503
504                // Check handler result
505                let result = result.lock().take();
506                if let Some(e) = result {
507                    return Err(GraphicsCaptureApiError::FrameHandlerError(e));
508                }
509
510                Ok(())
511            });
512
513        let Ok(halt_handle) = halt_receiver.recv() else {
514            match thread_handle.join() {
515                Ok(result) => return Err(result.err().unwrap()),
516                Err(_) => {
517                    return Err(GraphicsCaptureApiError::FailedToJoinThread);
518                }
519            }
520        };
521
522        let Ok(callback) = callback_receiver.recv() else {
523            match thread_handle.join() {
524                Ok(result) => return Err(result.err().unwrap()),
525                Err(_) => {
526                    return Err(GraphicsCaptureApiError::FailedToJoinThread);
527                }
528            }
529        };
530
531        Ok(CaptureControl::new(thread_handle, halt_handle, callback))
532    }
533
534    /// Function that will be called to create the struct. The flags can be
535    /// passed from settings.
536    ///
537    /// # Arguments
538    ///
539    /// * `flags` - The flags used to create the struct.
540    ///
541    /// # Returns
542    ///
543    /// Returns `Ok(Self)` if the struct creation was successful; otherwise, it returns an error of type `Self::Error`.
544    fn new(ctx: Context<Self::Flags>) -> Result<Self, Self::Error>;
545
546    /// Called every time a new frame is available.
547    ///
548    /// # Arguments
549    ///
550    /// * `frame` - A mutable reference to the captured frame.
551    /// * `capture_control` - The internal capture control.
552    ///
553    /// # Returns
554    ///
555    /// Returns `Ok(())` if the frame processing was successful; otherwise, it returns an error of type `Self::Error`.
556    fn on_frame_arrived(
557        &mut self,
558        frame: &mut Frame,
559        capture_control: InternalCaptureControl,
560    ) -> Result<(), Self::Error>;
561
562    /// Optional handler called when the capture item (usually a window) closes.
563    ///
564    /// # Returns
565    ///
566    /// Returns `Ok(())` if the handler executed successfully; otherwise, it returns an error of type `Self::Error`.
567    #[inline]
568    fn on_closed(&mut self) -> Result<(), Self::Error> {
569        Ok(())
570    }
571}