// viewpoint_core/network/auth/mod.rs

//! HTTP authentication handling.
//!
//! This module provides support for handling HTTP Basic and Digest authentication
//! challenges via the `Fetch.authRequired` CDP event.

6use std::sync::Arc;
7
8use tokio::sync::RwLock;
9use viewpoint_cdp::CdpConnection;
10use viewpoint_cdp::protocol::fetch::{
11    AuthChallenge, AuthChallengeResponse, AuthRequiredEvent, ContinueWithAuthParams,
12};
13
14use crate::error::NetworkError;
15
/// HTTP credentials for authentication.
#[derive(Debug, Clone)]
pub struct HttpCredentials {
    /// Username for authentication.
    pub username: String,
    /// Password for authentication.
    pub password: String,
    /// Optional origin to restrict credentials to.
    ///
    /// Accepts a full origin (`https://example.com:8443`), a bare host
    /// (`example.com`), or a `host:port` pair; see
    /// [`HttpCredentials::matches_origin`] for how each form is matched.
    /// If `None`, credentials apply to all origins.
    pub origin: Option<String>,
}

impl HttpCredentials {
    /// Create new HTTP credentials that apply to every origin.
    pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
        Self {
            username: username.into(),
            password: password.into(),
            origin: None,
        }
    }

    /// Create HTTP credentials restricted to a specific origin.
    pub fn for_origin(
        username: impl Into<String>,
        password: impl Into<String>,
        origin: impl Into<String>,
    ) -> Self {
        Self {
            username: username.into(),
            password: password.into(),
            origin: Some(origin.into()),
        }
    }

    /// Check if these credentials apply to the given challenge origin.
    ///
    /// `challenge_origin` is expected to look like `scheme://host[:port]`
    /// (the form CDP reports for auth challenges), though bare hosts are
    /// accepted too. Matching rules:
    ///
    /// - Without an origin restriction, every challenge matches.
    /// - Hosts are compared ASCII case-insensitively; the challenge host must
    ///   equal the restricted host or be a subdomain of it.
    /// - If the restriction names a scheme, the challenge scheme must match and
    ///   ports are compared with defaults filled in (`http`/`ws` → 80,
    ///   `https`/`wss` → 443).
    /// - If the restriction names an explicit port, the challenge's effective
    ///   port must equal it.
    /// - An empty restriction matches nothing.
    pub fn matches_origin(&self, challenge_origin: &str) -> bool {
        let origin = match &self.origin {
            Some(origin) => origin,
            // No origin restriction - apply to all.
            None => return true,
        };

        let wanted = OriginParts::parse(origin);
        if wanted.host.is_empty() {
            // Fail closed on a misconfigured (empty) restriction.
            return false;
        }
        let actual = OriginParts::parse(challenge_origin);

        if wanted.scheme.is_some() && wanted.scheme != actual.scheme {
            return false;
        }
        // A pinned scheme implies its default port, so ports are checked
        // whenever the restriction names a scheme or an explicit port.
        if (wanted.scheme.is_some() || wanted.port.is_some())
            && wanted.effective_port() != actual.effective_port()
        {
            return false;
        }

        // Match if the host matches exactly or is a subdomain.
        actual.host == wanted.host || actual.host.ends_with(&format!(".{}", wanted.host))
    }
}

/// Lowercased components of an origin-like string (`scheme://host[:port]`).
struct OriginParts {
    /// Scheme, if the input contained `://`.
    scheme: Option<String>,
    /// Host name or bracketed IPv6 literal.
    host: String,
    /// Explicit port, if one was written.
    port: Option<String>,
}

impl OriginParts {
    /// Split `value` into scheme, host, and port, ignoring any userinfo,
    /// path, query, or fragment.
    fn parse(value: &str) -> Self {
        let value = value.trim();
        let (scheme, rest) = match value.split_once("://") {
            Some((scheme, rest)) => (Some(scheme.to_ascii_lowercase()), rest),
            None => (None, value),
        };
        // Keep only the authority component.
        let authority = rest
            .split(|c: char| matches!(c, '/' | '?' | '#'))
            .next()
            .unwrap_or("");
        // Drop a `user:pass@` prefix so its colon is not mistaken for a port.
        let authority = authority
            .rsplit_once('@')
            .map_or(authority, |(_, host)| host);
        let (host, port) = if authority.starts_with('[') {
            // Bracketed IPv6 literal, optionally followed by `:port`.
            match authority.split_once(']') {
                Some((address, tail)) => {
                    (&authority[..=address.len()], tail.strip_prefix(':'))
                }
                None => (authority, None),
            }
        } else {
            match authority.rsplit_once(':') {
                Some((host, port)) => (host, Some(port)),
                None => (authority, None),
            }
        };
        Self {
            scheme,
            host: host.to_ascii_lowercase(),
            port: port.filter(|port| !port.is_empty()).map(str::to_owned),
        }
    }

    /// The explicit port, or the well-known default for the scheme.
    fn effective_port(&self) -> Option<&str> {
        match (self.port.as_deref(), self.scheme.as_deref()) {
            (Some(port), _) => Some(port),
            (None, Some("http" | "ws")) => Some("80"),
            (None, Some("https" | "wss")) => Some("443"),
            _ => None,
        }
    }
}
62
63/// Handler for HTTP authentication challenges.
64#[derive(Debug)]
65pub struct AuthHandler {
66    /// CDP connection.
67    connection: Arc<CdpConnection>,
68    /// Session ID for CDP commands.
69    session_id: String,
70    /// Stored credentials.
71    credentials: RwLock<Option<HttpCredentials>>,
72    /// How many times to retry with credentials before canceling.
73    max_retries: u32,
74    /// Current retry count per origin.
75    retry_counts: RwLock<std::collections::HashMap<String, u32>>,
76}
77
78impl AuthHandler {
79    /// Create a new auth handler.
80    pub fn new(connection: Arc<CdpConnection>, session_id: String) -> Self {
81        Self {
82            connection,
83            session_id,
84            credentials: RwLock::new(None),
85            max_retries: 3,
86            retry_counts: RwLock::new(std::collections::HashMap::new()),
87        }
88    }
89
90    /// Create an auth handler with pre-configured credentials.
91    pub fn with_credentials(
92        connection: Arc<CdpConnection>,
93        session_id: String,
94        credentials: HttpCredentials,
95    ) -> Self {
96        Self {
97            connection,
98            session_id,
99            credentials: RwLock::new(Some(credentials)),
100            max_retries: 3,
101            retry_counts: RwLock::new(std::collections::HashMap::new()),
102        }
103    }
104
105    /// Set HTTP credentials.
106    pub async fn set_credentials(&self, credentials: HttpCredentials) {
107        let mut creds = self.credentials.write().await;
108        *creds = Some(credentials);
109    }
110
111    /// Set HTTP credentials synchronously (for use during construction).
112    ///
113    /// This uses `blocking_write` which should only be called from non-async contexts.
114    pub fn set_credentials_sync(&self, credentials: HttpCredentials) {
115        // Use try_write to avoid blocking - this is called during construction
116        // before any async tasks are running, so it should always succeed.
117        if let Ok(mut creds) = self.credentials.try_write() {
118            *creds = Some(credentials);
119        }
120    }
121
122    /// Clear HTTP credentials.
123    pub async fn clear_credentials(&self) {
124        let mut creds = self.credentials.write().await;
125        *creds = None;
126    }
127
128    /// Handle an authentication challenge.
129    ///
130    /// Returns true if the challenge was handled, false if no credentials available.
131    ///
132    /// # Errors
133    ///
134    /// Returns an error if the CDP command to continue with authentication fails,
135    /// such as when the connection is closed or the browser rejects the request.
136    pub async fn handle_auth_challenge(
137        &self,
138        event: &AuthRequiredEvent,
139    ) -> Result<bool, NetworkError> {
140        let creds = self.credentials.read().await;
141
142        if let Some(credentials) = &*creds {
143            // Check if credentials match the challenge origin
144            if !credentials.matches_origin(&event.auth_challenge.origin) {
145                tracing::debug!(
146                    origin = %event.auth_challenge.origin,
147                    "No matching credentials for origin"
148                );
149                return self.cancel_auth(&event.request_id).await.map(|()| false);
150            }
151
152            // Check retry count
153            {
154                let mut counts = self.retry_counts.write().await;
155                let count = counts
156                    .entry(event.auth_challenge.origin.clone())
157                    .or_insert(0);
158
159                if *count >= self.max_retries {
160                    tracing::warn!(
161                        origin = %event.auth_challenge.origin,
162                        retries = self.max_retries,
163                        "Max auth retries exceeded, canceling"
164                    );
165                    return self.cancel_auth(&event.request_id).await.map(|()| false);
166                }
167
168                *count += 1;
169            }
170
171            // Provide credentials based on the authentication scheme
172            self.provide_credentials(
173                &event.request_id,
174                &event.auth_challenge,
175                &credentials.username,
176                &credentials.password,
177            )
178            .await?;
179
180            Ok(true)
181        } else {
182            tracing::debug!(
183                origin = %event.auth_challenge.origin,
184                scheme = %event.auth_challenge.scheme,
185                "No credentials available, deferring to default"
186            );
187            // No credentials - let browser handle it (show dialog or fail)
188            self.default_auth(&event.request_id).await?;
189            Ok(false)
190        }
191    }
192
193    /// Provide credentials for an auth challenge.
194    async fn provide_credentials(
195        &self,
196        request_id: &str,
197        challenge: &AuthChallenge,
198        username: &str,
199        password: &str,
200    ) -> Result<(), NetworkError> {
201        tracing::debug!(
202            origin = %challenge.origin,
203            scheme = %challenge.scheme,
204            realm = %challenge.realm,
205            "Providing credentials for auth challenge"
206        );
207
208        self.connection
209            .send_command::<_, serde_json::Value>(
210                "Fetch.continueWithAuth",
211                Some(ContinueWithAuthParams {
212                    request_id: request_id.to_string(),
213                    auth_challenge_response: AuthChallengeResponse::provide_credentials(
214                        username, password,
215                    ),
216                }),
217                Some(&self.session_id),
218            )
219            .await?;
220
221        Ok(())
222    }
223
224    /// Cancel authentication.
225    async fn cancel_auth(&self, request_id: &str) -> Result<(), NetworkError> {
226        tracing::debug!("Canceling auth challenge");
227
228        self.connection
229            .send_command::<_, serde_json::Value>(
230                "Fetch.continueWithAuth",
231                Some(ContinueWithAuthParams {
232                    request_id: request_id.to_string(),
233                    auth_challenge_response: AuthChallengeResponse::cancel(),
234                }),
235                Some(&self.session_id),
236            )
237            .await?;
238
239        Ok(())
240    }
241
242    /// Use default browser behavior for auth.
243    async fn default_auth(&self, request_id: &str) -> Result<(), NetworkError> {
244        self.connection
245            .send_command::<_, serde_json::Value>(
246                "Fetch.continueWithAuth",
247                Some(ContinueWithAuthParams {
248                    request_id: request_id.to_string(),
249                    auth_challenge_response: AuthChallengeResponse::default_response(),
250                }),
251                Some(&self.session_id),
252            )
253            .await?;
254
255        Ok(())
256    }
257
258    /// Reset retry counts (call after successful auth).
259    pub async fn reset_retries(&self, origin: &str) {
260        let mut counts = self.retry_counts.write().await;
261        counts.remove(origin);
262    }
263
264    /// Reset all retry counts.
265    pub async fn reset_all_retries(&self) {
266        let mut counts = self.retry_counts.write().await;
267        counts.clear();
268    }
269}
270
// Unit tests for this module live in a separate `tests` module file and are
// compiled only for test builds.
#[cfg(test)]
mod tests;