1use futures::{
2 channel::oneshot,
3 select,
4 stream::{FusedStream, StreamExt},
5 FutureExt,
6};
7use map_self::MapSelf;
8use serde::Deserialize;
9use std::{collections::HashMap, pin::Pin};
10use thiserror::Error;
11use windows::{
12 core::PWSTR,
13 Win32::{
14 Foundation::ERROR_INSUFFICIENT_BUFFER,
15 Security::{
16 Authorization::ConvertSidToStringSidW, GetTokenInformation, TokenUser,
17 SID_AND_ATTRIBUTES, TOKEN_QUERY,
18 },
19 System::Threading::{GetCurrentProcess, OpenProcessToken},
20 },
21};
22use windows_helpers::{dual_call, FirstCallExpectation, ResGuard};
23use winreg::enums::{HKEY_CURRENT_USER, HKEY_USERS};
24use wmi::{query::quote_and_escape_wql_str, COMLibrary, WMIConnection, WMIError, WMIResult};
25
26use super::{hkey_to_str, RegValuePath};
27
28pub struct RegValueMonitor<T: Copy> {
31 _wmi_con: WMIConnection,
32 ids_of_reg_value_changes: HashMap<RegValueChange, T>,
33 event_stream: Pin<Box<dyn FusedStream<Item = WMIResult<RegValueChange>>>>,
34}
35
36impl<T: Copy> RegValueMonitor<T> {
37 pub fn new<'a, I>(reg_value_paths: I) -> Result<Self, WMIError>
38 where
39 I: IntoIterator<Item = (T, &'a RegValuePath<'a>)>,
40 {
41 let wmi_con = WMIConnection::new(COMLibrary::new()?)?;
42
43 let mut ids_of_reg_value_changes = HashMap::new();
44 let mut sid = None;
45
46 let mut query = String::from(r"SELECT * FROM RegistryValueChangeEvent WHERE");
47
48 let mut first = true;
49 for (id, reg_value_path) in reg_value_paths {
50 let (corrected_hkey, subkey_path_prefix) = match reg_value_path.hkey {
53 HKEY_CURRENT_USER => {
54 if sid.is_none() {
55 sid = Some(current_user_sid().map_err(|error| WMIError::HResultError {
56 hres: error.code().0,
57 })?);
58 }
59 (HKEY_USERS, sid.as_ref())
60 }
61 hkey => (hkey, None),
63 };
64
65 let expected_reg_value_change = RegValueChange {
67 hive: hkey_to_str(corrected_hkey).to_string(),
68 key_path: if let Some(prefix) = subkey_path_prefix {
69 prefix.to_string() + r"\" + reg_value_path.subkey_path
70 } else {
71 reg_value_path.subkey_path.to_string()
72 },
73 value_name: reg_value_path.value_name.to_string(),
74 };
75
76 if !first {
79 query.push_str(r" OR");
80 }
81
82 query.push_str(r" Hive=");
83 query.push_str("e_and_escape_wql_str(&expected_reg_value_change.hive));
84
85 query.push_str(r" AND KeyPath=");
86 query.push_str("e_and_escape_wql_str(
87 &expected_reg_value_change.key_path,
88 ));
89
90 query.push_str(r" AND ValueName=");
91 query.push_str("e_and_escape_wql_str(
92 &expected_reg_value_change.value_name,
93 ));
94
95 ids_of_reg_value_changes.insert(expected_reg_value_change, id);
97
98 first = false;
99 }
100
101 let event_stream = Box::pin(
102 wmi_con
103 .async_raw_notification::<RegValueChange>(query)?
104 .fuse(),
105 );
106
107 Ok(Self {
108 _wmi_con: wmi_con,
109 ids_of_reg_value_changes,
110 event_stream,
111 })
112 }
113
114 pub async fn next_change(&mut self) -> Option<Result<T, WMIError>> {
115 loop {
116 break match self.event_stream.next().await {
117 Some(result) => Some(match result {
118 Ok(changed_value) => {
119 Ok(match self.ids_of_reg_value_changes.get(&changed_value) {
120 Some(id) => *id,
121 None => continue,
123 })
124 }
125 Err(error) => Err(error),
126 }),
127 None => None,
128 };
129 }
130 }
131
132 pub fn r#loop<F, U, E>(
133 &mut self,
134 stop_receiver: Option<oneshot::Receiver<U>>,
135 mut callback: F,
136 ) -> Result<U, MonitorLoopError<E>>
137 where
138 F: FnMut(T) -> Option<Result<U, E>>,
139 U: Default,
140 {
141 let (_stop_sender, mut stop_receiver) = if let Some(orig_receiver) = stop_receiver {
176 (None, orig_receiver)
177 } else {
178 oneshot::channel().map_self(|(sender, receiver)| (Some(sender), receiver))
179 };
180
181 futures::executor::block_on(async {
182 loop {
183 select! {
184 change_event = self.next_change().fuse() => {
185 match change_event {
186 Some(Ok(id)) => if let Some(result) = callback(id) {
188 result.map_err(|err_value| MonitorLoopError::Other(err_value))?;
189 },
190 Some(Err(error)) => break Err(MonitorLoopError::WmiError(error)),
192 None => unreachable!(),
194 }
195 },
196 value = stop_receiver => break Ok(value.unwrap_or_default()),
198 }
199 }
200 })
201 }
202}
203
204#[derive(Deserialize, PartialEq, Eq, Hash, Debug)]
205#[serde(rename = "RegistryValueChangeEvent")]
206#[serde(rename_all = "PascalCase")]
207struct RegValueChange {
208 hive: String,
209 key_path: String,
210 value_name: String,
211}
212
213#[derive(Error, Debug)]
214pub enum MonitorLoopError<T> {
215 #[error("WMI error: {0}")]
216 WmiError(#[from] WMIError),
217 #[error("monitor loop error: {0}")]
218 Other(T),
219}
220
221fn current_user_sid() -> Result<String, windows::core::Error> {
222 let process_token_handle = ResGuard::with_mut_acq_and_close_handle(|handle| unsafe {
223 OpenProcessToken(GetCurrentProcess(), TOKEN_QUERY, handle)
224 })?;
225
226 let mut sid_and_attrs_buffer = Vec::<u8>::new();
227 let mut sid_and_attrs_buffer_size = 0;
228
229 dual_call(
230 FirstCallExpectation::Win32Error(ERROR_INSUFFICIENT_BUFFER),
231 |getting_buffer_size| unsafe {
232 GetTokenInformation(
233 *process_token_handle,
234 TokenUser,
235 (!getting_buffer_size).then(|| {
236 sid_and_attrs_buffer.resize(sid_and_attrs_buffer_size as _, 0);
237 sid_and_attrs_buffer.as_mut_ptr().cast()
238 }),
239 sid_and_attrs_buffer_size,
240 &mut sid_and_attrs_buffer_size,
241 )
242 },
243 )?;
244
245 let string_sid = unsafe {
246 ResGuard::<PWSTR>::with_mut_acq_and_local_free(|pwstr| {
247 ConvertSidToStringSidW(
248 (&*sid_and_attrs_buffer.as_ptr().cast::<SID_AND_ATTRIBUTES>()).Sid,
249 pwstr,
250 )
251 })?
252 .to_string()?
253 };
254
255 Ok(string_sid)
256}