viewpoint_core/network/auth/
mod.rs1use std::sync::Arc;
7
8use tokio::sync::RwLock;
9use viewpoint_cdp::CdpConnection;
10use viewpoint_cdp::protocol::fetch::{
11 AuthChallenge, AuthChallengeResponse, AuthRequiredEvent, ContinueWithAuthParams,
12};
13
14use crate::error::NetworkError;
15
/// Username/password pair used to answer HTTP authentication challenges.
///
/// The `Debug` implementation redacts the password so that values of this
/// type (and types embedding it) can be logged without leaking the secret.
#[derive(Clone)]
pub struct HttpCredentials {
    /// Username sent in response to a challenge.
    pub username: String,
    /// Password sent in response to a challenge.
    pub password: String,
    /// Optional origin restriction. `None` means the credentials are offered to
    /// every challenger; see [`HttpCredentials::matches_origin`] for the rules.
    pub origin: Option<String>,
}

impl std::fmt::Debug for HttpCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpCredentials")
            .field("username", &self.username)
            // Never print the secret itself.
            .field("password", &"<redacted>")
            .field("origin", &self.origin)
            .finish()
    }
}

impl HttpCredentials {
    /// Creates credentials that are offered to any origin.
    pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
        Self {
            username: username.into(),
            password: password.into(),
            origin: None,
        }
    }

    /// Creates credentials restricted to `origin`.
    ///
    /// `origin` may be a full origin (`https://example.com`) or a bare
    /// `host[:port]` (`example.com`). The bare form also matches subdomains.
    pub fn for_origin(
        username: impl Into<String>,
        password: impl Into<String>,
        origin: impl Into<String>,
    ) -> Self {
        Self {
            username: username.into(),
            password: password.into(),
            origin: Some(origin.into()),
        }
    }

    /// Returns `true` if these credentials may be sent to `challenge_origin`.
    ///
    /// * No configured origin: always `true`.
    /// * Configured origin containing `://`: the challenge origin must equal it
    ///   (ASCII case-insensitively).
    /// * Configured origin without a scheme (e.g. `example.com`): it is matched
    ///   against the challenger's `host[:port]` part, and matches either that
    ///   exact host or any subdomain of it (`api.example.com`).
    pub fn matches_origin(&self, challenge_origin: &str) -> bool {
        let Some(origin) = self.origin.as_deref() else {
            return true;
        };

        // Scheme and host are case-insensitive, so compare lowercased forms.
        let origin = origin.to_ascii_lowercase();
        let challenge = challenge_origin.to_ascii_lowercase();

        // Without this, a bare `example.com` would match `https://api.example.com`
        // (via the `.example.com` suffix) but not `https://example.com` itself.
        let candidate = if origin.contains("://") {
            challenge.as_str()
        } else {
            challenge
                .split_once("://")
                .map_or(challenge.as_str(), |(_, authority)| authority)
        };

        candidate == origin.as_str() || candidate.ends_with(&format!(".{origin}"))
    }
}
62
63#[derive(Debug)]
65pub struct AuthHandler {
66 connection: Arc<CdpConnection>,
68 session_id: String,
70 credentials: RwLock<Option<HttpCredentials>>,
72 max_retries: u32,
74 retry_counts: RwLock<std::collections::HashMap<String, u32>>,
76}
77
78impl AuthHandler {
79 pub fn new(connection: Arc<CdpConnection>, session_id: String) -> Self {
81 Self {
82 connection,
83 session_id,
84 credentials: RwLock::new(None),
85 max_retries: 3,
86 retry_counts: RwLock::new(std::collections::HashMap::new()),
87 }
88 }
89
90 pub fn with_credentials(
92 connection: Arc<CdpConnection>,
93 session_id: String,
94 credentials: HttpCredentials,
95 ) -> Self {
96 Self {
97 connection,
98 session_id,
99 credentials: RwLock::new(Some(credentials)),
100 max_retries: 3,
101 retry_counts: RwLock::new(std::collections::HashMap::new()),
102 }
103 }
104
105 pub async fn set_credentials(&self, credentials: HttpCredentials) {
107 let mut creds = self.credentials.write().await;
108 *creds = Some(credentials);
109 }
110
111 pub fn set_credentials_sync(&self, credentials: HttpCredentials) {
115 if let Ok(mut creds) = self.credentials.try_write() {
118 *creds = Some(credentials);
119 }
120 }
121
122 pub async fn clear_credentials(&self) {
124 let mut creds = self.credentials.write().await;
125 *creds = None;
126 }
127
128 pub async fn handle_auth_challenge(
137 &self,
138 event: &AuthRequiredEvent,
139 ) -> Result<bool, NetworkError> {
140 let creds = self.credentials.read().await;
141
142 if let Some(credentials) = &*creds {
143 if !credentials.matches_origin(&event.auth_challenge.origin) {
145 tracing::debug!(
146 origin = %event.auth_challenge.origin,
147 "No matching credentials for origin"
148 );
149 return self.cancel_auth(&event.request_id).await.map(|()| false);
150 }
151
152 {
154 let mut counts = self.retry_counts.write().await;
155 let count = counts
156 .entry(event.auth_challenge.origin.clone())
157 .or_insert(0);
158
159 if *count >= self.max_retries {
160 tracing::warn!(
161 origin = %event.auth_challenge.origin,
162 retries = self.max_retries,
163 "Max auth retries exceeded, canceling"
164 );
165 return self.cancel_auth(&event.request_id).await.map(|()| false);
166 }
167
168 *count += 1;
169 }
170
171 self.provide_credentials(
173 &event.request_id,
174 &event.auth_challenge,
175 &credentials.username,
176 &credentials.password,
177 )
178 .await?;
179
180 Ok(true)
181 } else {
182 tracing::debug!(
183 origin = %event.auth_challenge.origin,
184 scheme = %event.auth_challenge.scheme,
185 "No credentials available, deferring to default"
186 );
187 self.default_auth(&event.request_id).await?;
189 Ok(false)
190 }
191 }
192
193 async fn provide_credentials(
195 &self,
196 request_id: &str,
197 challenge: &AuthChallenge,
198 username: &str,
199 password: &str,
200 ) -> Result<(), NetworkError> {
201 tracing::debug!(
202 origin = %challenge.origin,
203 scheme = %challenge.scheme,
204 realm = %challenge.realm,
205 "Providing credentials for auth challenge"
206 );
207
208 self.connection
209 .send_command::<_, serde_json::Value>(
210 "Fetch.continueWithAuth",
211 Some(ContinueWithAuthParams {
212 request_id: request_id.to_string(),
213 auth_challenge_response: AuthChallengeResponse::provide_credentials(
214 username, password,
215 ),
216 }),
217 Some(&self.session_id),
218 )
219 .await?;
220
221 Ok(())
222 }
223
224 async fn cancel_auth(&self, request_id: &str) -> Result<(), NetworkError> {
226 tracing::debug!("Canceling auth challenge");
227
228 self.connection
229 .send_command::<_, serde_json::Value>(
230 "Fetch.continueWithAuth",
231 Some(ContinueWithAuthParams {
232 request_id: request_id.to_string(),
233 auth_challenge_response: AuthChallengeResponse::cancel(),
234 }),
235 Some(&self.session_id),
236 )
237 .await?;
238
239 Ok(())
240 }
241
242 async fn default_auth(&self, request_id: &str) -> Result<(), NetworkError> {
244 self.connection
245 .send_command::<_, serde_json::Value>(
246 "Fetch.continueWithAuth",
247 Some(ContinueWithAuthParams {
248 request_id: request_id.to_string(),
249 auth_challenge_response: AuthChallengeResponse::default_response(),
250 }),
251 Some(&self.session_id),
252 )
253 .await?;
254
255 Ok(())
256 }
257
258 pub async fn reset_retries(&self, origin: &str) {
260 let mut counts = self.retry_counts.write().await;
261 counts.remove(origin);
262 }
263
264 pub async fn reset_all_retries(&self) {
266 let mut counts = self.retry_counts.write().await;
267 counts.clear();
268 }
269}
270
271#[cfg(test)]
272mod tests;