Skip to main content

selection_capture/
windows_runtime_adapter.rs

1use crate::windows::windows_default_runtime_event_source as windows_platform_runtime_event_source;
2#[cfg(target_os = "windows")]
3use crate::windows_observer::WindowsObserverBridge;
4use crate::windows_subscriber::{
5    set_windows_native_runtime_adapter, windows_native_runtime_adapter_registered,
6};
7#[cfg(target_os = "windows")]
8use std::io::{BufRead, BufReader};
9#[cfg(target_os = "windows")]
10use std::process::{Child, ChildStdout, Command, Stdio};
11#[cfg(target_os = "windows")]
12use std::sync::{
13    atomic::{AtomicBool, Ordering},
14    Arc, Mutex as StdMutex,
15};
16use std::sync::{Mutex, OnceLock};
17#[cfg(target_os = "windows")]
18use std::thread::{self, JoinHandle};
19#[cfg(target_os = "windows")]
20use std::time::Duration;
21
/// Snapshot of the default Windows runtime adapter's lifecycle and listener counters.
///
/// All flags start `false` and all counters start at zero (`Default`).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct WindowsDefaultRuntimeAdapterState {
    /// `true` between a successful attach and the matching successful detach.
    pub attached: bool,
    /// Whether the background listener worker is believed to be alive.
    pub worker_running: bool,
    /// Number of successful attach transitions.
    pub attach_calls: u64,
    /// Number of successful detach transitions.
    pub detach_calls: u64,
    /// Times the listener's output ended (EOF) or failed to read.
    pub listener_exits: u64,
    /// Attempts made to respawn the listener process.
    pub listener_restarts: u64,
    /// Failed attempts to start the listener or its worker.
    pub listener_failures: u64,
}
32
/// Callback asked for the text to publish each time the listener reports an event.
///
/// Returning `None` means nothing is published for that event.
pub type WindowsDefaultRuntimeEventSource = fn() -> Option<String>;
34
/// Line the PowerShell listener prints for every event it observes.
#[cfg(target_os = "windows")]
const WINDOWS_RUNTIME_EVENT_MARKER: &str = "__SC_EVENT__";

/// Worker spawn attempts made while attaching, before attaching fails.
#[cfg(target_os = "windows")]
const WINDOWS_ATTACH_RETRY_LIMIT: u32 = 4;

/// Respawn attempts made after the listener exits, before the worker gives up.
#[cfg(target_os = "windows")]
const WINDOWS_RESTART_RETRY_LIMIT: u32 = 8;

/// Delay before the first retry; later retries grow exponentially from it.
#[cfg(target_os = "windows")]
const WINDOWS_RETRY_BACKOFF_BASE: Duration = Duration::from_millis(50);

/// Upper bound for any single retry delay.
#[cfg(target_os = "windows")]
const WINDOWS_RETRY_BACKOFF_MAX: Duration = Duration::from_millis(800);

// Compile-time guard: the cap's whole-millisecond count must fit in `u64` so it
// can be turned back into a `Duration` with `Duration::from_millis` losslessly.
#[cfg(target_os = "windows")]
const _: () = {
    assert!(
        WINDOWS_RETRY_BACKOFF_MAX.as_millis() <= u64::MAX as u128,
        "WINDOWS_RETRY_BACKOFF_MAX exceeds u64::MAX ms"
    );
};
51
/// Delay to wait before retry number `attempt` (0-based).
///
/// Starts at `WINDOWS_RETRY_BACKOFF_BASE` and doubles per attempt (the exponent
/// stops growing after 6), never exceeding `WINDOWS_RETRY_BACKOFF_MAX`.
#[cfg(target_os = "windows")]
fn retry_backoff_delay(attempt: u32) -> Duration {
    let doublings = attempt.min(6);
    let scaled = WINDOWS_RETRY_BACKOFF_BASE.saturating_mul(1u32 << doublings);
    scaled.min(WINDOWS_RETRY_BACKOFF_MAX)
}
61
/// PowerShell program run by `spawn_windows_runtime_listener_process`.
///
/// It loads the UI Automation assemblies and registers a focus-changed handler.
/// The handler writes `__SC_EVENT__` to stdout and flushes; this text must match
/// `WINDOWS_RUNTIME_EVENT_MARKER`. The script then idles in a sleep loop until the
/// process is killed. The `finally` block unregisters the handler.
// NOTE(review): the parent ends this process with `Child::kill`, so the `finally`
// block is presumably never reached in practice — confirm.
// NOTE(review): UI Automation raises focus events on its own threads; confirm the
// PowerShell host can run this scriptblock delegate there (that thread may have no
// runspace) while the main runspace is busy in `Start-Sleep`, otherwise no markers
// would ever be printed.
#[cfg(target_os = "windows")]
const WINDOWS_RUNTIME_LISTENER_SCRIPT: &str = r#"
Add-Type -AssemblyName UIAutomationClient
Add-Type -AssemblyName UIAutomationTypes

$handler = [System.Windows.Automation.AutomationFocusChangedEventHandler]{
    param($sender, $eventArgs)
    [Console]::Out.WriteLine("__SC_EVENT__")
    [Console]::Out.Flush()
}

[System.Windows.Automation.Automation]::AddAutomationFocusChangedEventHandler($handler)

try {
    while ($true) {
        Start-Sleep -Milliseconds 86400000
    }
}
finally {
    [System.Windows.Automation.Automation]::RemoveAutomationFocusChangedEventHandler($handler)
}
"#;
84
/// Handle to the background thread that reads the PowerShell listener's stdout.
#[cfg(target_os = "windows")]
struct WindowsRuntimeWorker {
    /// Set by `stop` to ask the reader thread (and any restart loop) to finish.
    stop: Arc<AtomicBool>,
    /// The currently installed listener process. It is shared with the reader
    /// thread so either side can replace or kill it.
    child: Arc<StdMutex<Option<Child>>>,
    /// Counters updated by the reader thread and read by `telemetry_snapshot`.
    telemetry: Arc<WindowsWorkerTelemetry>,
    /// Join handle of the reader thread.
    handle: JoinHandle<()>,
}
92
/// Listener counters shared between the reader thread and the adapter.
#[cfg(target_os = "windows")]
#[derive(Default)]
struct WindowsWorkerTelemetry {
    /// EOFs or read errors observed on the listener's stdout.
    listener_exits: std::sync::atomic::AtomicU64,
    /// Respawn attempts, successful or not.
    listener_restarts: std::sync::atomic::AtomicU64,
    /// Respawn attempts that failed to produce a listener.
    listener_failures: std::sync::atomic::AtomicU64,
}

#[cfg(target_os = "windows")]
impl WindowsWorkerTelemetry {
    /// Returns the counters as `(exits, restarts, failures)`.
    ///
    /// Each counter is loaded separately, so the tuple is not one atomic
    /// snapshot of all three.
    fn snapshot(&self) -> (u64, u64, u64) {
        let read = |counter: &std::sync::atomic::AtomicU64| counter.load(Ordering::SeqCst);
        let exits = read(&self.listener_exits);
        let restarts = read(&self.listener_restarts);
        let failures = read(&self.listener_failures);
        (exits, restarts, failures)
    }
}
111
#[cfg(target_os = "windows")]
impl WindowsRuntimeWorker {
    /// Starts a listener process and a reader thread that forwards its events.
    ///
    /// Returns `None` if the listener cannot be started or the thread cannot be
    /// spawned. In the latter case the just-started process is killed and reaped,
    /// because dropping a `Child` does not terminate it.
    fn spawn() -> Option<Self> {
        let stop = Arc::new(AtomicBool::new(false));
        let child = Arc::new(StdMutex::new(None));
        let telemetry = Arc::new(WindowsWorkerTelemetry::default());
        let stdout = install_new_windows_listener(&child)?;

        let stop_signal = Arc::clone(&stop);
        let child_signal = Arc::clone(&child);
        let telemetry_signal = Arc::clone(&telemetry);
        let spawned = thread::Builder::new()
            .name("selection-capture-win-runtime".to_string())
            .spawn(move || {
                Self::run_reader(stdout, &stop_signal, &child_signal, &telemetry_signal);
                // Whatever listener is installed when the loop ends is ours to reap.
                Self::reap_child(&child_signal);
            });
        let Ok(handle) = spawned else {
            // Without this the PowerShell listener would outlive the failed worker.
            Self::reap_child(&child);
            return None;
        };

        Some(Self {
            stop,
            child,
            telemetry,
            handle,
        })
    }

    /// Reader-thread body.
    ///
    /// Forwards each event marker line to the observer bridge. When the
    /// listener's output ends, it respawns the listener. It stops once `stop` is
    /// set or the restart budget is exhausted.
    fn run_reader(
        stdout: ChildStdout,
        stop: &Arc<AtomicBool>,
        child: &Arc<StdMutex<Option<Child>>>,
        telemetry: &Arc<WindowsWorkerTelemetry>,
    ) {
        let mut reader = BufReader::new(stdout);
        // One buffer reused for every line instead of allocating per read.
        let mut line = String::new();
        while !stop.load(Ordering::SeqCst) {
            line.clear();
            match reader.read_line(&mut line) {
                Ok(0) | Err(_) => {
                    // An EOF or read error caused by `stop()` killing the listener is
                    // the shutdown we asked for, not a listener exit to count.
                    if stop.load(Ordering::SeqCst) {
                        break;
                    }
                    telemetry.listener_exits.fetch_add(1, Ordering::SeqCst);
                    if !restart_windows_listener(child, stop, telemetry, &mut reader) {
                        break;
                    }
                }
                Ok(_) => {
                    if line.trim() == WINDOWS_RUNTIME_EVENT_MARKER {
                        Self::publish_event();
                    }
                }
            }
        }
    }

    /// Asks the registered event source (if any) for text and pushes it to the bridge.
    fn publish_event() {
        if let Some(text) = windows_default_runtime_event_source().and_then(|source| source()) {
            let _ = WindowsObserverBridge::push_event(text);
        }
    }

    /// Takes the installed listener (if any) out of `slot`, kills it and waits for it.
    ///
    /// A poisoned lock is recovered so the process is always killed. This matters
    /// because a live listener would keep the reader blocked and `stop` would
    /// never return from `join`.
    fn reap_child(slot: &StdMutex<Option<Child>>) {
        let taken = slot
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .take();
        if let Some(mut child) = taken {
            let _ = child.kill();
            let _ = child.wait();
        }
    }

    /// Signals the reader thread to finish and kills the current listener so a
    /// blocked read returns. Then joins the thread.
    ///
    /// Returns `false` if the reader thread panicked.
    fn stop(self) -> bool {
        self.stop.store(true, Ordering::SeqCst);
        Self::reap_child(&self.child);
        self.handle.join().is_ok()
    }

    /// Current `(exits, restarts, failures)` counters of this worker.
    fn telemetry_snapshot(&self) -> (u64, u64, u64) {
        self.telemetry.snapshot()
    }

    /// Whether the reader thread has not finished yet.
    fn is_running(&self) -> bool {
        !self.handle.is_finished()
    }
}
205
/// Launches `powershell` running `WINDOWS_RUNTIME_LISTENER_SCRIPT`.
///
/// Stdout is piped back to this process; stdin and stderr are discarded.
/// Returns `None` if the process could not be started.
#[cfg(target_os = "windows")]
fn spawn_windows_runtime_listener_process() -> Option<Child> {
    const FLAGS: [&str; 5] = ["-NoProfile", "-NoLogo", "-NonInteractive", "-STA", "-Command"];
    let mut command = Command::new("powershell");
    command
        .args(FLAGS)
        .arg(WINDOWS_RUNTIME_LISTENER_SCRIPT)
        .stdin(Stdio::null())
        .stdout(Stdio::piped())
        .stderr(Stdio::null());
    match command.spawn() {
        Ok(child) => Some(child),
        Err(_) => None,
    }
}
223
/// Spawns a fresh listener process and installs it in `child_slot`.
///
/// Any process it replaces is killed and reaped. Returns the new process's
/// stdout for reading.
///
/// Returns `None` if the process cannot be spawned or has no piped stdout. In
/// the latter case the process is killed rather than leaked. A poisoned
/// `child_slot` is recovered so the new process is always tracked; an untracked
/// listener could never be killed by `stop`, leaving its reader blocked forever.
#[cfg(target_os = "windows")]
fn install_new_windows_listener(child_slot: &Arc<StdMutex<Option<Child>>>) -> Option<ChildStdout> {
    let mut child = spawn_windows_runtime_listener_process()?;
    let Some(stdout) = child.stdout.take() else {
        let _ = child.kill();
        let _ = child.wait();
        return None;
    };
    // Swap under the lock, but kill/wait the old process after releasing it so
    // other users of the slot are not blocked while it shuts down.
    let previous = child_slot
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .replace(child);
    if let Some(mut previous) = previous {
        let _ = previous.kill();
        let _ = previous.wait();
    }
    Some(stdout)
}
236
/// Tries to respawn the listener after it exited.
///
/// On success, `reader` is pointed at the new process's stdout.
///
/// Makes up to `WINDOWS_RESTART_RETRY_LIMIT` attempts with exponential backoff
/// between them. Every attempt is counted in `listener_restarts` and every failed
/// one in `listener_failures`. Returns `false` if `stop_signal` is set or every
/// attempt fails.
#[cfg(target_os = "windows")]
fn restart_windows_listener(
    child_slot: &Arc<StdMutex<Option<Child>>>,
    stop_signal: &Arc<AtomicBool>,
    telemetry: &Arc<WindowsWorkerTelemetry>,
    reader: &mut BufReader<ChildStdout>,
) -> bool {
    for attempt in 0..WINDOWS_RESTART_RETRY_LIMIT {
        if stop_signal.load(Ordering::SeqCst) {
            return false;
        }
        telemetry.listener_restarts.fetch_add(1, Ordering::SeqCst);
        if let Some(stdout) = install_new_windows_listener(child_slot) {
            *reader = BufReader::new(stdout);
            return true;
        }
        telemetry.listener_failures.fetch_add(1, Ordering::SeqCst);
        // Back off only when another attempt follows. Sleeping after the final
        // failure would just delay giving up (and any pending `stop()` join).
        if attempt + 1 < WINDOWS_RESTART_RETRY_LIMIT {
            thread::sleep(retry_backoff_delay(attempt));
        }
    }
    false
}
258
259#[derive(Default)]
260struct WindowsDefaultRuntimeAdapterRuntime {
261    state: WindowsDefaultRuntimeAdapterState,
262    #[cfg(target_os = "windows")]
263    worker: Option<WindowsRuntimeWorker>,
264}
265
266fn adapter_runtime() -> &'static Mutex<WindowsDefaultRuntimeAdapterRuntime> {
267    static RUNTIME: OnceLock<Mutex<WindowsDefaultRuntimeAdapterRuntime>> = OnceLock::new();
268    RUNTIME.get_or_init(|| Mutex::new(WindowsDefaultRuntimeAdapterRuntime::default()))
269}
270
271fn event_source_slot() -> &'static Mutex<Option<WindowsDefaultRuntimeEventSource>> {
272    static SOURCE: OnceLock<Mutex<Option<WindowsDefaultRuntimeEventSource>>> = OnceLock::new();
273    SOURCE.get_or_init(|| Mutex::new(None))
274}
275
276#[cfg(target_os = "windows")]
277fn windows_default_runtime_event_source() -> Option<WindowsDefaultRuntimeEventSource> {
278    event_source_slot().lock().ok().and_then(|slot| *slot)
279}
280
281fn attach_default_windows_listener(runtime: &mut WindowsDefaultRuntimeAdapterRuntime) -> bool {
282    #[cfg(target_os = "windows")]
283    {
284        if runtime.worker.is_some() {
285            return true;
286        }
287        for attempt in 0..WINDOWS_ATTACH_RETRY_LIMIT {
288            if let Some(worker) = WindowsRuntimeWorker::spawn() {
289                runtime.worker = Some(worker);
290                return true;
291            }
292            runtime.state.listener_failures += 1;
293            thread::sleep(retry_backoff_delay(attempt));
294        }
295        false
296    }
297    #[cfg(not(target_os = "windows"))]
298    {
299        let _ = runtime;
300        true
301    }
302}
303
304fn detach_default_windows_listener(runtime: &mut WindowsDefaultRuntimeAdapterRuntime) -> bool {
305    #[cfg(target_os = "windows")]
306    {
307        runtime
308            .worker
309            .take()
310            .map(|worker| worker.stop())
311            .unwrap_or(true)
312    }
313    #[cfg(not(target_os = "windows"))]
314    {
315        let _ = runtime;
316        true
317    }
318}
319
320fn default_windows_runtime_adapter(active: bool) -> bool {
321    let Ok(mut runtime) = adapter_runtime().lock() else {
322        return false;
323    };
324
325    if active {
326        if runtime.state.attached {
327            return true;
328        }
329        if !attach_default_windows_listener(&mut runtime) {
330            return false;
331        }
332        runtime.state.attached = true;
333        runtime.state.worker_running = cfg!(target_os = "windows");
334        runtime.state.attach_calls += 1;
335        return true;
336    }
337
338    if !runtime.state.attached {
339        return true;
340    }
341    if !detach_default_windows_listener(&mut runtime) {
342        return false;
343    }
344    runtime.state.attached = false;
345    runtime.state.worker_running = false;
346    runtime.state.detach_calls += 1;
347    true
348}
349
350pub fn windows_default_runtime_adapter_state() -> WindowsDefaultRuntimeAdapterState {
351    adapter_runtime()
352        .lock()
353        .map(|runtime| {
354            #[cfg(target_os = "windows")]
355            {
356                let mut state = runtime.state;
357                if let Some(worker) = runtime.worker.as_ref() {
358                    state.worker_running = state.worker_running && worker.is_running();
359                    let (listener_exits, listener_restarts, listener_failures) =
360                        worker.telemetry_snapshot();
361                    state.listener_exits = state.listener_exits.saturating_add(listener_exits);
362                    state.listener_restarts =
363                        state.listener_restarts.saturating_add(listener_restarts);
364                    state.listener_failures =
365                        state.listener_failures.saturating_add(listener_failures);
366                }
367                state
368            }
369            #[cfg(not(target_os = "windows"))]
370            {
371                runtime.state
372            }
373        })
374        .unwrap_or_default()
375}
376
377pub fn set_windows_default_runtime_event_source(source: Option<WindowsDefaultRuntimeEventSource>) {
378    if let Ok(mut slot) = event_source_slot().lock() {
379        *slot = source;
380    }
381}
382
383pub fn windows_default_runtime_event_source_registered() -> bool {
384    event_source_slot()
385        .lock()
386        .map(|slot| slot.is_some())
387        .unwrap_or(false)
388}
389
/// Test helper that restores the adapter to its defaults.
///
/// Detaches any live listener first, then replaces the runtime with its default
/// value and clears the registered event source.
#[cfg(test)]
fn reset_windows_default_runtime_adapter_state() {
    let _ = default_windows_runtime_adapter(false);
    let _ = adapter_runtime()
        .lock()
        .map(|mut runtime| *runtime = WindowsDefaultRuntimeAdapterRuntime::default());
    set_windows_default_runtime_event_source(None);
}
398
/// Test hook that kills the current listener process without stopping the worker.
///
/// Returns `true` only if a listener was found and `kill` succeeded. Returns
/// `false` if there is no worker or listener, or a lock is poisoned.
#[cfg(all(test, target_os = "windows"))]
fn kill_windows_listener_for_tests() -> bool {
    let runtime = match adapter_runtime().lock() {
        Ok(guard) => guard,
        Err(_) => return false,
    };
    let Some(worker) = runtime.worker.as_ref() else {
        return false;
    };
    let kill_result = match worker.child.lock() {
        Ok(mut slot) => slot.as_mut().map(|child| child.kill().is_ok()),
        Err(_) => None,
    };
    kill_result.unwrap_or(false)
}
415
416pub fn install_default_windows_runtime_adapter_if_absent() {
417    if !windows_default_runtime_event_source_registered() {
418        set_windows_default_runtime_event_source(Some(windows_platform_runtime_event_source));
419    }
420    if !windows_native_runtime_adapter_registered() {
421        set_windows_native_runtime_adapter(Some(default_windows_runtime_adapter));
422    }
423}
424
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ensure_windows_native_subscriber_hook_installed, windows_native_subscriber_stats,
        windows_observer::windows_observer_test_lock, WindowsObserverBridge,
    };

    /// Installing the default adapter registers both the adapter and its event
    /// source. A bridge start/stop cycle must not decrease `adapter_attempts`.
    #[test]
    fn installing_default_adapter_enables_lifecycle_attempt_tracking() {
        let _guard = windows_observer_test_lock()
            .lock()
            .expect("test lock poisoned");
        let _ = WindowsObserverBridge::stop();
        WindowsObserverBridge::set_lifecycle_hook(None);
        reset_windows_default_runtime_adapter_state();
        set_windows_native_runtime_adapter(None);
        set_windows_default_runtime_event_source(None);
        ensure_windows_native_subscriber_hook_installed();
        install_default_windows_runtime_adapter_if_absent();
        assert!(windows_native_runtime_adapter_registered());
        assert!(windows_default_runtime_event_source_registered());

        let before = windows_native_subscriber_stats();
        let _ = WindowsObserverBridge::start();
        let _ = WindowsObserverBridge::stop();
        let after = windows_native_subscriber_stats();

        assert!(after.adapter_attempts >= before.adapter_attempts);
    }

    /// Repeated attach or detach requests only count the first transition.
    #[test]
    fn default_adapter_state_tracks_attach_detach_idempotently() {
        let _guard = windows_observer_test_lock()
            .lock()
            .expect("test lock poisoned");
        reset_windows_default_runtime_adapter_state();
        assert_eq!(
            windows_default_runtime_adapter_state(),
            WindowsDefaultRuntimeAdapterState::default()
        );

        assert!(default_windows_runtime_adapter(true));
        assert!(default_windows_runtime_adapter(true));
        let started = windows_default_runtime_adapter_state();
        assert!(started.attached);
        assert_eq!(started.worker_running, cfg!(target_os = "windows"));
        assert_eq!(started.attach_calls, 1);
        assert_eq!(started.detach_calls, 0);

        assert!(default_windows_runtime_adapter(false));
        assert!(default_windows_runtime_adapter(false));
        let stopped = windows_default_runtime_adapter_state();
        assert!(!stopped.attached);
        assert!(!stopped.worker_running);
        assert_eq!(stopped.attach_calls, 1);
        assert_eq!(stopped.detach_calls, 1);
    }

    #[test]
    #[cfg(target_os = "windows")]
    fn retry_backoff_delay_is_bounded_exponential() {
        assert_eq!(retry_backoff_delay(0), Duration::from_millis(50));
        assert_eq!(retry_backoff_delay(1), Duration::from_millis(100));
        assert_eq!(retry_backoff_delay(2), Duration::from_millis(200));
        assert_eq!(retry_backoff_delay(4), Duration::from_millis(800));
        assert_eq!(retry_backoff_delay(8), Duration::from_millis(800));
    }

    /// Killing the listener behind the worker's back must be recorded as an exit.
    #[test]
    #[cfg(target_os = "windows")]
    fn listener_restart_updates_telemetry_after_forced_kill() {
        let _guard = windows_observer_test_lock()
            .lock()
            .expect("test lock poisoned");
        let _ = WindowsObserverBridge::stop();
        WindowsObserverBridge::set_lifecycle_hook(None);
        reset_windows_default_runtime_adapter_state();
        set_windows_native_runtime_adapter(None);
        set_windows_default_runtime_event_source(None);
        ensure_windows_native_subscriber_hook_installed();
        install_default_windows_runtime_adapter_if_absent();

        let _ = WindowsObserverBridge::start();
        let before = windows_default_runtime_adapter_state();
        if !before.attached || !before.worker_running {
            let _ = WindowsObserverBridge::stop();
            return;
        }

        if !kill_windows_listener_for_tests() {
            let _ = WindowsObserverBridge::stop();
            return;
        }

        let mut after = before;
        for _ in 0..30 {
            std::thread::sleep(Duration::from_millis(50));
            after = windows_default_runtime_adapter_state();
            if after.listener_restarts > before.listener_restarts
                || after.listener_exits > before.listener_exits
            {
                break;
            }
        }

        // Stop the bridge before asserting so a failed assertion does not skip
        // the cleanup.
        let _ = WindowsObserverBridge::stop();

        // The EOF from the kill is counted before any restart is attempted, so an
        // exit must be visible. The restart itself may still be in flight.
        assert!(
            after.listener_exits > before.listener_exits,
            "forced kill was not recorded as a listener exit: before={before:?}, after={after:?}"
        );
        assert!(after.listener_restarts >= before.listener_restarts);
    }
}