viewpoint_core/network/auth/
mod.rs1use std::sync::Arc;
8
9use tokio::sync::RwLock;
10use viewpoint_cdp::CdpConnection;
11use viewpoint_cdp::protocol::fetch::{
12 AuthChallenge, AuthChallengeResponse, AuthChallengeSource, AuthRequiredEvent,
13 ContinueWithAuthParams,
14};
15
16use crate::error::NetworkError;
17
/// Username/password pair used to answer proxy authentication challenges.
///
/// `Debug` is implemented by hand so that the password is never written to
/// logs or panic messages; only the username is shown.
#[derive(Clone)]
pub struct ProxyCredentials {
    /// Username sent to the proxy.
    pub username: String,
    /// Password sent to the proxy.
    pub password: String,
}

impl ProxyCredentials {
    /// Builds proxy credentials from anything convertible into `String`.
    pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
        Self {
            username: username.into(),
            password: password.into(),
        }
    }
}

impl std::fmt::Debug for ProxyCredentials {
    // Redact the secret: a derived `Debug` would print it verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProxyCredentials")
            .field("username", &self.username)
            .field("password", &format_args!("<redacted>"))
            .finish()
    }
}
36
/// HTTP (server) authentication credentials, optionally scoped to an origin.
///
/// `Debug` is implemented by hand so that the password is never written to
/// logs or panic messages.
#[derive(Clone)]
pub struct HttpCredentials {
    /// Username sent in response to the challenge.
    pub username: String,
    /// Password sent in response to the challenge.
    pub password: String,
    /// Origin these credentials are restricted to; `None` means "any origin".
    pub origin: Option<String>,
}

impl HttpCredentials {
    /// Builds credentials that are offered to every origin.
    pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
        Self {
            username: username.into(),
            password: password.into(),
            origin: None,
        }
    }

    /// Builds credentials that are only offered to `origin` (see
    /// [`HttpCredentials::matches_origin`] for the matching rule).
    pub fn for_origin(
        username: impl Into<String>,
        password: impl Into<String>,
        origin: impl Into<String>,
    ) -> Self {
        Self {
            username: username.into(),
            password: password.into(),
            origin: Some(origin.into()),
        }
    }

    /// Returns `true` if these credentials may be sent for `challenge_origin`.
    ///
    /// Unscoped credentials (`origin == None`) match everything. Scoped
    /// credentials match when `challenge_origin` is exactly the configured
    /// origin, or ends with `"." + origin` (so `example.com` also matches
    /// `api.example.com`). The comparison is a plain, case-sensitive string
    /// comparison.
    ///
    /// NOTE(review): because this is pure suffix matching, a host-only origin
    /// such as `example.com` will not match a challenge origin that carries a
    /// scheme (e.g. `https://example.com`); confirm which form callers use.
    pub fn matches_origin(&self, challenge_origin: &str) -> bool {
        match self.origin.as_deref() {
            None => true,
            // Strip the configured origin off the end; the remainder must be
            // empty (exact match) or end with '.' (subdomain match). This is
            // equivalent to `== origin || ends_with(".{origin}")` without
            // allocating the dotted suffix.
            Some(origin) => matches!(
                challenge_origin.strip_suffix(origin),
                Some(rest) if rest.is_empty() || rest.ends_with('.')
            ),
        }
    }
}

impl std::fmt::Debug for HttpCredentials {
    // Redact the secret: a derived `Debug` would print it verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpCredentials")
            .field("username", &self.username)
            .field("password", &format_args!("<redacted>"))
            .field("origin", &self.origin)
            .finish()
    }
}
83
84#[derive(Debug)]
86pub struct AuthHandler {
87 connection: Arc<CdpConnection>,
89 session_id: String,
91 credentials: RwLock<Option<HttpCredentials>>,
93 proxy_credentials: RwLock<Option<ProxyCredentials>>,
95 max_retries: u32,
97 retry_counts: RwLock<std::collections::HashMap<String, u32>>,
99}
100
101impl AuthHandler {
102 pub fn new(connection: Arc<CdpConnection>, session_id: String) -> Self {
104 Self {
105 connection,
106 session_id,
107 credentials: RwLock::new(None),
108 proxy_credentials: RwLock::new(None),
109 max_retries: 3,
110 retry_counts: RwLock::new(std::collections::HashMap::new()),
111 }
112 }
113
114 pub fn with_credentials(
116 connection: Arc<CdpConnection>,
117 session_id: String,
118 credentials: HttpCredentials,
119 ) -> Self {
120 Self {
121 connection,
122 session_id,
123 credentials: RwLock::new(Some(credentials)),
124 proxy_credentials: RwLock::new(None),
125 max_retries: 3,
126 retry_counts: RwLock::new(std::collections::HashMap::new()),
127 }
128 }
129
130 pub fn with_proxy_credentials(
132 connection: Arc<CdpConnection>,
133 session_id: String,
134 proxy_credentials: ProxyCredentials,
135 ) -> Self {
136 Self {
137 connection,
138 session_id,
139 credentials: RwLock::new(None),
140 proxy_credentials: RwLock::new(Some(proxy_credentials)),
141 max_retries: 3,
142 retry_counts: RwLock::new(std::collections::HashMap::new()),
143 }
144 }
145
146 pub fn with_all_credentials(
148 connection: Arc<CdpConnection>,
149 session_id: String,
150 http_credentials: Option<HttpCredentials>,
151 proxy_credentials: Option<ProxyCredentials>,
152 ) -> Self {
153 Self {
154 connection,
155 session_id,
156 credentials: RwLock::new(http_credentials),
157 proxy_credentials: RwLock::new(proxy_credentials),
158 max_retries: 3,
159 retry_counts: RwLock::new(std::collections::HashMap::new()),
160 }
161 }
162
163 pub async fn set_credentials(&self, credentials: HttpCredentials) {
165 let mut creds = self.credentials.write().await;
166 *creds = Some(credentials);
167 }
168
169 pub fn set_credentials_sync(&self, credentials: HttpCredentials) {
173 if let Ok(mut creds) = self.credentials.try_write() {
176 *creds = Some(credentials);
177 }
178 }
179
180 pub async fn clear_credentials(&self) {
182 let mut creds = self.credentials.write().await;
183 *creds = None;
184 }
185
186 pub async fn set_proxy_credentials(&self, credentials: ProxyCredentials) {
188 let mut creds = self.proxy_credentials.write().await;
189 *creds = Some(credentials);
190 }
191
192 pub fn set_proxy_credentials_sync(&self, credentials: ProxyCredentials) {
196 if let Ok(mut creds) = self.proxy_credentials.try_write() {
197 *creds = Some(credentials);
198 }
199 }
200
201 pub async fn clear_proxy_credentials(&self) {
203 let mut creds = self.proxy_credentials.write().await;
204 *creds = None;
205 }
206
207 pub async fn handle_auth_challenge(
216 &self,
217 event: &AuthRequiredEvent,
218 ) -> Result<bool, NetworkError> {
219 if event.auth_challenge.source == AuthChallengeSource::Proxy {
221 return self.handle_proxy_auth(event).await;
222 }
223
224 let creds = self.credentials.read().await;
226
227 if let Some(credentials) = &*creds {
228 if !credentials.matches_origin(&event.auth_challenge.origin) {
230 tracing::debug!(
231 origin = %event.auth_challenge.origin,
232 "No matching credentials for origin"
233 );
234 return self.cancel_auth(&event.request_id).await.map(|()| false);
235 }
236
237 {
239 let mut counts = self.retry_counts.write().await;
240 let count = counts
241 .entry(event.auth_challenge.origin.clone())
242 .or_insert(0);
243
244 if *count >= self.max_retries {
245 tracing::warn!(
246 origin = %event.auth_challenge.origin,
247 retries = self.max_retries,
248 "Max auth retries exceeded, canceling"
249 );
250 return self.cancel_auth(&event.request_id).await.map(|()| false);
251 }
252
253 *count += 1;
254 }
255
256 self.provide_credentials(
258 &event.request_id,
259 &event.auth_challenge,
260 &credentials.username,
261 &credentials.password,
262 )
263 .await?;
264
265 Ok(true)
266 } else {
267 tracing::debug!(
268 origin = %event.auth_challenge.origin,
269 scheme = %event.auth_challenge.scheme,
270 "No credentials available, deferring to default"
271 );
272 self.default_auth(&event.request_id).await?;
274 Ok(false)
275 }
276 }
277
278 async fn handle_proxy_auth(&self, event: &AuthRequiredEvent) -> Result<bool, NetworkError> {
280 let proxy_creds = self.proxy_credentials.read().await;
281
282 if let Some(credentials) = &*proxy_creds {
283 let retry_key = format!("proxy:{}", event.auth_challenge.origin);
285 {
286 let mut counts = self.retry_counts.write().await;
287 let count = counts.entry(retry_key.clone()).or_insert(0);
288
289 if *count >= self.max_retries {
290 tracing::warn!(
291 origin = %event.auth_challenge.origin,
292 retries = self.max_retries,
293 "Max proxy auth retries exceeded, canceling"
294 );
295 return self.cancel_auth(&event.request_id).await.map(|()| false);
296 }
297
298 *count += 1;
299 }
300
301 tracing::debug!(
302 origin = %event.auth_challenge.origin,
303 scheme = %event.auth_challenge.scheme,
304 "Providing proxy credentials"
305 );
306
307 self.provide_credentials(
309 &event.request_id,
310 &event.auth_challenge,
311 &credentials.username,
312 &credentials.password,
313 )
314 .await?;
315
316 Ok(true)
317 } else {
318 tracing::debug!(
319 origin = %event.auth_challenge.origin,
320 scheme = %event.auth_challenge.scheme,
321 "No proxy credentials available, deferring to default"
322 );
323 self.default_auth(&event.request_id).await?;
325 Ok(false)
326 }
327 }
328
329 async fn provide_credentials(
331 &self,
332 request_id: &str,
333 challenge: &AuthChallenge,
334 username: &str,
335 password: &str,
336 ) -> Result<(), NetworkError> {
337 tracing::debug!(
338 origin = %challenge.origin,
339 scheme = %challenge.scheme,
340 realm = %challenge.realm,
341 "Providing credentials for auth challenge"
342 );
343
344 self.connection
345 .send_command::<_, serde_json::Value>(
346 "Fetch.continueWithAuth",
347 Some(ContinueWithAuthParams {
348 request_id: request_id.to_string(),
349 auth_challenge_response: AuthChallengeResponse::provide_credentials(
350 username, password,
351 ),
352 }),
353 Some(&self.session_id),
354 )
355 .await?;
356
357 Ok(())
358 }
359
360 async fn cancel_auth(&self, request_id: &str) -> Result<(), NetworkError> {
362 tracing::debug!("Canceling auth challenge");
363
364 self.connection
365 .send_command::<_, serde_json::Value>(
366 "Fetch.continueWithAuth",
367 Some(ContinueWithAuthParams {
368 request_id: request_id.to_string(),
369 auth_challenge_response: AuthChallengeResponse::cancel(),
370 }),
371 Some(&self.session_id),
372 )
373 .await?;
374
375 Ok(())
376 }
377
378 async fn default_auth(&self, request_id: &str) -> Result<(), NetworkError> {
380 self.connection
381 .send_command::<_, serde_json::Value>(
382 "Fetch.continueWithAuth",
383 Some(ContinueWithAuthParams {
384 request_id: request_id.to_string(),
385 auth_challenge_response: AuthChallengeResponse::default_response(),
386 }),
387 Some(&self.session_id),
388 )
389 .await?;
390
391 Ok(())
392 }
393
394 pub async fn reset_retries(&self, origin: &str) {
396 let mut counts = self.retry_counts.write().await;
397 counts.remove(origin);
398 }
399
400 pub async fn reset_all_retries(&self) {
402 let mut counts = self.retry_counts.write().await;
403 counts.clear();
404 }
405}
406
407#[cfg(test)]
408mod tests;