// sem_reg/reg/monitor.rs

1use futures::{
2    channel::oneshot,
3    select,
4    stream::{FusedStream, StreamExt},
5    FutureExt,
6};
7use map_self::MapSelf;
8use serde::Deserialize;
9use std::{collections::HashMap, pin::Pin};
10use thiserror::Error;
11use windows::{
12    core::PWSTR,
13    Win32::{
14        Foundation::ERROR_INSUFFICIENT_BUFFER,
15        Security::{
16            Authorization::ConvertSidToStringSidW, GetTokenInformation, TokenUser,
17            SID_AND_ATTRIBUTES, TOKEN_QUERY,
18        },
19        System::Threading::{GetCurrentProcess, OpenProcessToken},
20    },
21};
22use windows_helpers::{dual_call, FirstCallExpectation, ResGuard};
23use winreg::enums::{HKEY_CURRENT_USER, HKEY_USERS};
24use wmi::{query::quote_and_escape_wql_str, COMLibrary, WMIConnection, WMIError, WMIResult};
25
26use super::{hkey_to_str, RegValuePath};
27
28// Alternatively, a similar implementation could use `RegNotifyChangeKeyValue()`, which may be faster that WMI.
29/// Note that, on changes in very quick succession, reading a registry value after receiving a change event for it may yield newer data than from the write that triggered the event.
30pub struct RegValueMonitor<T: Copy> {
31    _wmi_con: WMIConnection,
32    ids_of_reg_value_changes: HashMap<RegValueChange, T>,
33    event_stream: Pin<Box<dyn FusedStream<Item = WMIResult<RegValueChange>>>>,
34}
35
36impl<T: Copy> RegValueMonitor<T> {
37    pub fn new<'a, I>(reg_value_paths: I) -> Result<Self, WMIError>
38    where
39        I: IntoIterator<Item = (T, &'a RegValuePath<'a>)>,
40    {
41        let wmi_con = WMIConnection::new(COMLibrary::new()?)?;
42
43        let mut ids_of_reg_value_changes = HashMap::new();
44        let mut sid = None;
45
46        let mut query = String::from(r"SELECT * FROM RegistryValueChangeEvent WHERE");
47
48        let mut first = true;
49        for (id, reg_value_path) in reg_value_paths {
50            //TODO: See <https://github.com/ohadravid/wmi-rs/issues/86> ("Helper to resolve registry links"). Otherwise, offer `current_user_sid()` to `whoami` crate.
51            // Resolve links.
52            let (corrected_hkey, subkey_path_prefix) = match reg_value_path.hkey {
53                HKEY_CURRENT_USER => {
54                    if sid.is_none() {
55                        sid = Some(current_user_sid().map_err(|error| WMIError::HResultError {
56                            hres: error.code().0,
57                        })?);
58                    }
59                    (HKEY_USERS, sid.as_ref())
60                }
61                // (`HKEY_CLASSES_ROOT` links to `HKEY_LOCAL_MACHINE\SOFTWARE\Classes` as well as `HKEY_CURRENT_USER\SOFTWARE\Classes` in a merging way, which is why it can't be resolved here.)
62                hkey => (hkey, None),
63            };
64
65            // Make proper path.
66            let expected_reg_value_change = RegValueChange {
67                hive: hkey_to_str(corrected_hkey).to_string(),
68                key_path: if let Some(prefix) = subkey_path_prefix {
69                    prefix.to_string() + r"\" + reg_value_path.subkey_path
70                } else {
71                    reg_value_path.subkey_path.to_string()
72                },
73                value_name: reg_value_path.value_name.to_string(),
74            };
75
76            // Build query.
77            // (Parentheses aren't necessary: "When more than one logical operator is used in a statement, the OR operators are evaluated after the AND operators." [https://learn.microsoft.com/en-us/windows/win32/wmisdk/wql-sql-for-wmi])
78            if !first {
79                query.push_str(r" OR");
80            }
81
82            query.push_str(r" Hive=");
83            query.push_str(&quote_and_escape_wql_str(&expected_reg_value_change.hive));
84
85            query.push_str(r" AND KeyPath=");
86            query.push_str(&quote_and_escape_wql_str(
87                &expected_reg_value_change.key_path,
88            ));
89
90            query.push_str(r" AND ValueName=");
91            query.push_str(&quote_and_escape_wql_str(
92                &expected_reg_value_change.value_name,
93            ));
94
95            // Build `HashMap` to associate events with registry value IDs from user.
96            ids_of_reg_value_changes.insert(expected_reg_value_change, id);
97
98            first = false;
99        }
100
101        let event_stream = Box::pin(
102            wmi_con
103                .async_raw_notification::<RegValueChange>(query)?
104                .fuse(),
105        );
106
107        Ok(Self {
108            _wmi_con: wmi_con,
109            ids_of_reg_value_changes,
110            event_stream,
111        })
112    }
113
114    pub async fn next_change(&mut self) -> Option<Result<T, WMIError>> {
115        loop {
116            break match self.event_stream.next().await {
117                Some(result) => Some(match result {
118                    Ok(changed_value) => {
119                        Ok(match self.ids_of_reg_value_changes.get(&changed_value) {
120                            Some(id) => *id,
121                            // Skip unrelated nonsense, which shouldn't actually happen.
122                            None => continue,
123                        })
124                    }
125                    Err(error) => Err(error),
126                }),
127                None => None,
128            };
129        }
130    }
131
132    pub fn r#loop<F, U, E>(
133        &mut self,
134        stop_receiver: Option<oneshot::Receiver<U>>,
135        mut callback: F,
136    ) -> Result<U, MonitorLoopError<E>>
137    where
138        F: FnMut(T) -> Option<Result<U, E>>,
139        U: Default,
140    {
141        //! Send a signal to the `stop_receiver` or return `Some(...)` from the callback to stop the loop.
142        //!
143        //! # Examples
144        //! ```ignore
145        //! fn main() {
146        //!     let (stop_sender, stop_receiver) = oneshot::channel();
147        //!
148        //!     let join_handle = thread::spawn(move || {
149        //!         let mut monitor = RegValueMonitor::new([
150        //!             (NightLightRegValueId::State, &NIGHT_LIGHT_STATE_REG_VALUE_PATH),
151        //!             (NightLightRegValueId::Settings, &NIGHT_LIGHT_SETTINGS_REG_VALUE_PATH),
152        //!         ])
153        //!         .unwrap();
154        //!
155        //!         monitor.r#loop(Some(stop_receiver), |changed_value_id| {
156        //!             println!("{:?}", changed_value_id);
157        //!             None
158        //!         })
159        //!         .unwrap();
160        //!     });
161        //!
162        //!     thread::sleep(Duration::from_secs(10));
163        //!     stop_sender.send(()).unwrap();
164        //!     join_handle.join().unwrap();
165        //! }
166        //!
167        //! #[derive(Clone, Copy, Debug)]
168        //! enum NightLightRegValueId {
169        //!     State,
170        //!     Settings,
171        //! }
172        //! ```
173
174        //. With no receiver, make one, so the loop works.
175        let (_stop_sender, mut stop_receiver) = if let Some(orig_receiver) = stop_receiver {
176            (None, orig_receiver)
177        } else {
178            oneshot::channel().map_self(|(sender, receiver)| (Some(sender), receiver))
179        };
180
181        futures::executor::block_on(async {
182            loop {
183                select! {
184                    change_event = self.next_change().fuse() => {
185                        match change_event {
186                            // New change.
187                            Some(Ok(id)) => if let Some(result) = callback(id) {
188                                result.map_err(|err_value| MonitorLoopError::Other(err_value))?;
189                            },
190                            // Stream error.
191                            Some(Err(error)) => break Err(MonitorLoopError::WmiError(error)),
192                            // Stream should never be exhausted: "The `notification` method returns an iterator that waits for any incoming events resulting from the provided query. Loops reading from this iterator will not end until they are broken." (https://docs.rs/wmi/latest/wmi/#subscribing-to-event-notifications)
193                            None => unreachable!(),
194                        }
195                    },
196                    // User desires to stop loop.
197                    value = stop_receiver => break Ok(value.unwrap_or_default()),
198                }
199            }
200        })
201    }
202}
203
204#[derive(Deserialize, PartialEq, Eq, Hash, Debug)]
205#[serde(rename = "RegistryValueChangeEvent")]
206#[serde(rename_all = "PascalCase")]
207struct RegValueChange {
208    hive: String,
209    key_path: String,
210    value_name: String,
211}
212
213#[derive(Error, Debug)]
214pub enum MonitorLoopError<T> {
215    #[error("WMI error: {0}")]
216    WmiError(#[from] WMIError),
217    #[error("monitor loop error: {0}")]
218    Other(T),
219}
220
221fn current_user_sid() -> Result<String, windows::core::Error> {
222    let process_token_handle = ResGuard::with_mut_acq_and_close_handle(|handle| unsafe {
223        OpenProcessToken(GetCurrentProcess(), TOKEN_QUERY, handle)
224    })?;
225
226    let mut sid_and_attrs_buffer = Vec::<u8>::new();
227    let mut sid_and_attrs_buffer_size = 0;
228
229    dual_call(
230        FirstCallExpectation::Win32Error(ERROR_INSUFFICIENT_BUFFER),
231        |getting_buffer_size| unsafe {
232            GetTokenInformation(
233                *process_token_handle,
234                TokenUser,
235                (!getting_buffer_size).then(|| {
236                    sid_and_attrs_buffer.resize(sid_and_attrs_buffer_size as _, 0);
237                    sid_and_attrs_buffer.as_mut_ptr().cast()
238                }),
239                sid_and_attrs_buffer_size,
240                &mut sid_and_attrs_buffer_size,
241            )
242        },
243    )?;
244
245    let string_sid = unsafe {
246        ResGuard::<PWSTR>::with_mut_acq_and_local_free(|pwstr| {
247            ConvertSidToStringSidW(
248                (&*sid_and_attrs_buffer.as_ptr().cast::<SID_AND_ATTRIBUTES>()).Sid,
249                pwstr,
250            )
251        })?
252        .to_string()?
253    };
254
255    Ok(string_sid)
256}