viewpoint_core/network/auth/
mod.rs1use std::sync::Arc;
7
8use tokio::sync::RwLock;
9use viewpoint_cdp::protocol::fetch::{
10 AuthChallenge, AuthChallengeResponse, AuthRequiredEvent,
11 ContinueWithAuthParams,
12};
13use viewpoint_cdp::CdpConnection;
14
15use crate::error::NetworkError;
16
/// Username/password pair used to answer HTTP authentication challenges.
///
/// `Debug` is implemented by hand so the password is never written to logs.
#[derive(Clone)]
pub struct HttpCredentials {
    /// Username sent in response to a challenge.
    pub username: String,
    /// Password sent in response to a challenge.
    pub password: String,
    /// Optional origin restriction; `None` means the credentials are offered to any origin.
    pub origin: Option<String>,
}

impl std::fmt::Debug for HttpCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The password is deliberately redacted so handlers holding credentials
        // can be debug-printed without leaking secrets.
        f.debug_struct("HttpCredentials")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .field("origin", &self.origin)
            .finish()
    }
}

impl HttpCredentials {
    /// Creates credentials that are offered to every origin.
    pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
        Self {
            username: username.into(),
            password: password.into(),
            origin: None,
        }
    }

    /// Creates credentials restricted to `origin`.
    ///
    /// `origin` may be either:
    /// - a full origin (`"https://example.com"`, `"http://localhost:8080"`), which
    ///   matches only that exact origin, or
    /// - a bare host, optionally with a port (`"example.com"`, `"example.com:8080"`).
    ///   This matches that host and any of its subdomains, under any scheme.
    ///   Without a port it matches any port.
    pub fn for_origin(
        username: impl Into<String>,
        password: impl Into<String>,
        origin: impl Into<String>,
    ) -> Self {
        Self {
            username: username.into(),
            password: password.into(),
            origin: Some(origin.into()),
        }
    }

    /// Returns whether these credentials may be sent to `challenge_origin`.
    ///
    /// Comparisons are ASCII case-insensitive and ignore a trailing `/`.
    /// Unrestricted credentials (`origin == None`) match every challenge.
    /// See [`HttpCredentials::for_origin`] for the accepted restriction forms.
    pub fn matches_origin(&self, challenge_origin: &str) -> bool {
        let Some(restriction) = self.origin.as_deref() else {
            return true;
        };
        let restriction = restriction.trim_end_matches('/');
        let challenge = challenge_origin.trim_end_matches('/');

        if restriction.eq_ignore_ascii_case(challenge) {
            return true;
        }
        // A full `scheme://...` restriction only ever matches exactly, and an
        // empty restriction must not turn into a "matches anything ending in '.'" rule.
        if restriction.is_empty() || restriction.contains("://") {
            return false;
        }

        // Bare host restriction: compare against the challenge's host, keeping
        // the port only when the restriction itself pins one.
        let authority = Self::authority_of(challenge);
        let candidate = if Self::strip_port(restriction) == restriction {
            Self::strip_port(authority)
        } else {
            authority
        };
        let candidate = candidate.to_ascii_lowercase();
        let restriction = restriction.to_ascii_lowercase();
        candidate == restriction || candidate.ends_with(&format!(".{restriction}"))
    }

    /// Returns the `host[:port]` part of an origin-like string, dropping any
    /// `scheme://` prefix and anything from the first `/` after it.
    fn authority_of(origin: &str) -> &str {
        let rest = origin.split_once("://").map_or(origin, |(_, rest)| rest);
        rest.split_once('/').map_or(rest, |(authority, _)| authority)
    }

    /// Removes a trailing `:<digits>` port, leaving bracketed IPv6 hosts intact.
    fn strip_port(authority: &str) -> &str {
        match authority.rsplit_once(':') {
            Some((host, port)) if !port.is_empty() && port.bytes().all(|b| b.is_ascii_digit()) => {
                host
            }
            _ => authority,
        }
    }
}
63
64#[derive(Debug)]
66pub struct AuthHandler {
67 connection: Arc<CdpConnection>,
69 session_id: String,
71 credentials: RwLock<Option<HttpCredentials>>,
73 max_retries: u32,
75 retry_counts: RwLock<std::collections::HashMap<String, u32>>,
77}
78
79impl AuthHandler {
80 pub fn new(connection: Arc<CdpConnection>, session_id: String) -> Self {
82 Self {
83 connection,
84 session_id,
85 credentials: RwLock::new(None),
86 max_retries: 3,
87 retry_counts: RwLock::new(std::collections::HashMap::new()),
88 }
89 }
90
91 pub fn with_credentials(
93 connection: Arc<CdpConnection>,
94 session_id: String,
95 credentials: HttpCredentials,
96 ) -> Self {
97 Self {
98 connection,
99 session_id,
100 credentials: RwLock::new(Some(credentials)),
101 max_retries: 3,
102 retry_counts: RwLock::new(std::collections::HashMap::new()),
103 }
104 }
105
106 pub async fn set_credentials(&self, credentials: HttpCredentials) {
108 let mut creds = self.credentials.write().await;
109 *creds = Some(credentials);
110 }
111
112 pub fn set_credentials_sync(&self, credentials: HttpCredentials) {
116 if let Ok(mut creds) = self.credentials.try_write() {
119 *creds = Some(credentials);
120 }
121 }
122
123 pub async fn clear_credentials(&self) {
125 let mut creds = self.credentials.write().await;
126 *creds = None;
127 }
128
129 pub async fn handle_auth_challenge(
138 &self,
139 event: &AuthRequiredEvent,
140 ) -> Result<bool, NetworkError> {
141 let creds = self.credentials.read().await;
142
143 if let Some(credentials) = &*creds {
144 if !credentials.matches_origin(&event.auth_challenge.origin) {
146 tracing::debug!(
147 origin = %event.auth_challenge.origin,
148 "No matching credentials for origin"
149 );
150 return self.cancel_auth(&event.request_id).await.map(|()| false);
151 }
152
153 {
155 let mut counts = self.retry_counts.write().await;
156 let count = counts.entry(event.auth_challenge.origin.clone()).or_insert(0);
157
158 if *count >= self.max_retries {
159 tracing::warn!(
160 origin = %event.auth_challenge.origin,
161 retries = self.max_retries,
162 "Max auth retries exceeded, canceling"
163 );
164 return self.cancel_auth(&event.request_id).await.map(|()| false);
165 }
166
167 *count += 1;
168 }
169
170 self.provide_credentials(
172 &event.request_id,
173 &event.auth_challenge,
174 &credentials.username,
175 &credentials.password,
176 )
177 .await?;
178
179 Ok(true)
180 } else {
181 tracing::debug!(
182 origin = %event.auth_challenge.origin,
183 scheme = %event.auth_challenge.scheme,
184 "No credentials available, deferring to default"
185 );
186 self.default_auth(&event.request_id).await?;
188 Ok(false)
189 }
190 }
191
192 async fn provide_credentials(
194 &self,
195 request_id: &str,
196 challenge: &AuthChallenge,
197 username: &str,
198 password: &str,
199 ) -> Result<(), NetworkError> {
200 tracing::debug!(
201 origin = %challenge.origin,
202 scheme = %challenge.scheme,
203 realm = %challenge.realm,
204 "Providing credentials for auth challenge"
205 );
206
207 self.connection
208 .send_command::<_, serde_json::Value>(
209 "Fetch.continueWithAuth",
210 Some(ContinueWithAuthParams {
211 request_id: request_id.to_string(),
212 auth_challenge_response: AuthChallengeResponse::provide_credentials(
213 username,
214 password,
215 ),
216 }),
217 Some(&self.session_id),
218 )
219 .await?;
220
221 Ok(())
222 }
223
224 async fn cancel_auth(&self, request_id: &str) -> Result<(), NetworkError> {
226 tracing::debug!("Canceling auth challenge");
227
228 self.connection
229 .send_command::<_, serde_json::Value>(
230 "Fetch.continueWithAuth",
231 Some(ContinueWithAuthParams {
232 request_id: request_id.to_string(),
233 auth_challenge_response: AuthChallengeResponse::cancel(),
234 }),
235 Some(&self.session_id),
236 )
237 .await?;
238
239 Ok(())
240 }
241
242 async fn default_auth(&self, request_id: &str) -> Result<(), NetworkError> {
244 self.connection
245 .send_command::<_, serde_json::Value>(
246 "Fetch.continueWithAuth",
247 Some(ContinueWithAuthParams {
248 request_id: request_id.to_string(),
249 auth_challenge_response: AuthChallengeResponse::default_response(),
250 }),
251 Some(&self.session_id),
252 )
253 .await?;
254
255 Ok(())
256 }
257
258 pub async fn reset_retries(&self, origin: &str) {
260 let mut counts = self.retry_counts.write().await;
261 counts.remove(origin);
262 }
263
264 pub async fn reset_all_retries(&self) {
266 let mut counts = self.retry_counts.write().await;
267 counts.clear();
268 }
269}
270
271#[cfg(test)]
272mod tests;