xjp_oidc/
sse.rs

1//! Server-Sent Events (SSE) support for login flows
2//!
3//! This module provides support for real-time login status updates via SSE,
4//! commonly used for QR code login flows where the status needs to be
5//! monitored in real-time.
6
7#[cfg(not(target_arch = "wasm32"))]
8use crate::{
9    errors::{Error, Result},
10    http::HttpClient,
11};
12
13#[cfg(not(target_arch = "wasm32"))]
14use serde::{Deserialize, Serialize};
15
16/// Login status enum matching the backend implementation
17#[cfg(not(target_arch = "wasm32"))]
18#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
19#[serde(rename_all = "UPPERCASE")]
20pub enum LoginStatus {
21    /// Waiting for user action (e.g., QR code scan)
22    Pending,
23    /// User has scanned but not yet authorized
24    Scanned,
25    /// User has authorized the login
26    Authorized,
27    /// Login completed successfully
28    Success,
29    /// Login failed
30    Failed,
31    /// Login session expired
32    Expired,
33}
34
35/// Login state information
36#[cfg(not(target_arch = "wasm32"))]
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct LoginState {
39    /// Current status of the login session
40    pub status: LoginStatus,
41    /// OAuth authorization code (present when status is Success)
42    pub code: Option<String>,
43    /// Error message (present when status is Failed)
44    pub error: Option<String>,
45    /// Creation timestamp (Unix timestamp)
46    pub created_at: i64,
47}
48
49/// SSE event types
50#[cfg(not(target_arch = "wasm32"))]
51#[derive(Debug, Clone)]
52pub enum LoginEvent {
53    /// Status update event
54    StatusUpdate(LoginState),
55    /// Heartbeat event (no data)
56    Heartbeat,
57    /// Stream closed event
58    Close,
59    /// Error event
60    Error(String),
61}
62
63/// Start a login session and get a login ID for monitoring
64///
65/// This creates a new login session on the server and returns a login ID
66/// that can be used to monitor the login status via SSE.
67///
68/// # Example
69/// ```no_run
70/// # #[cfg(not(target_arch = "wasm32"))]
71/// # use xjp_oidc::{sse::start_login_session, http::ReqwestHttpClient};
72/// # #[cfg(not(target_arch = "wasm32"))]
73/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
74/// let http = ReqwestHttpClient::default();
75///
76/// let (login_id, qr_url) = start_login_session(
77///     "https://auth.example.com",
78///     "my-client-id",
79///     "https://app.example.com/callback",
80///     &http
81/// ).await?;
82///
83/// println!("Login ID: {}", login_id);
84/// println!("QR Code URL: {}", qr_url);
85/// # Ok(())
86/// # }
87/// ```
88#[cfg(not(target_arch = "wasm32"))]
89pub async fn start_login_session(
90    issuer: &str,
91    client_id: &str,
92    redirect_uri: &str,
93    http: &dyn HttpClient,
94) -> Result<(String, String)> {
95    // Validate parameters
96    if issuer.is_empty() {
97        return Err(Error::InvalidParam("issuer cannot be empty"));
98    }
99    if client_id.is_empty() {
100        return Err(Error::InvalidParam("client_id cannot be empty"));
101    }
102    if redirect_uri.is_empty() {
103        return Err(Error::InvalidParam("redirect_uri cannot be empty"));
104    }
105
106    // Build the login session endpoint URL
107    let session_endpoint = format!("{}/auth/wechat/qr", issuer.trim_end_matches('/'));
108
109    // Prepare request body
110    let body = serde_json::json!({
111        "client_id": client_id,
112        "redirect_uri": redirect_uri,
113    });
114
115    // Make the request
116    let response = http
117        .post_json_value(&session_endpoint, &body, None)
118        .await
119        .map_err(|e| Error::Network(format!("Failed to start login session: {}", e)))?;
120
121    // Extract login_id and qr_url from response
122    let login_id = response["login_id"]
123        .as_str()
124        .ok_or_else(|| Error::InvalidState("Missing login_id in response".to_string()))?
125        .to_string();
126
127    let qr_url = response["wechat_qr_url"]
128        .as_str()
129        .ok_or_else(|| Error::InvalidState("Missing wechat_qr_url in response".to_string()))?
130        .to_string();
131
132    Ok((login_id, qr_url))
133}
134
/// Configuration for SSE login monitoring
///
/// Only `issuer` and `login_id` are needed. The optional fields can be left as
/// `None`, e.g. `LoginMonitorConfig { issuer, login_id, ..Default::default() }`.
#[cfg(not(target_arch = "wasm32"))]
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LoginMonitorConfig {
    /// The issuer base URL (a trailing `/` is ignored when building endpoints)
    pub issuer: String,
    /// The login ID to monitor
    pub login_id: String,
    /// Optional inactivity timeout in seconds.
    ///
    /// This is the maximum time to wait between two stream items before a
    /// timeout error is yielded. It is not an overall deadline for the login.
    /// `None` disables the timeout entirely.
    pub timeout_secs: Option<u64>,
    /// Optional maximum number of reconnect attempts (default: 3)
    pub max_reconnects: Option<u32>,
}
148
/// Subscribe to login status updates via SSE
///
/// Opens a Server-Sent Events connection to
/// `{issuer}/auth/login-stream?login_id={login_id}` and returns a stream of
/// [`LoginEvent`]s that can be consumed to track the login progress in real time.
///
/// # Event mapping
///
/// - `pending`, `scanned`, `authorized`, `success`, `failed` and `expired` events
///   are parsed into [`LoginEvent::StatusUpdate`]. The event data may be a bare
///   [`LoginState`] object or a wrapper object carrying it under `"state"`. Data
///   that cannot be parsed yields `Err(Error::Verification(..))`.
/// - `close` events become [`LoginEvent::Close`].
/// - `error` events become [`LoginEvent::Error`] carrying the raw event data.
/// - SSE comments, `heartbeat` events and any other event type become
///   [`LoginEvent::Heartbeat`].
///
/// # Connection handling
///
/// - The client reconnects automatically with a 1s initial delay, capped at 5s.
///   The initial connection attempt is not retried.
/// - Every error item produced by the client is forwarded as
///   `Err(Error::Network(..))`. Once more than `max_reconnects` (default: 3)
///   consecutive error items have been seen, with no successful item in between,
///   that last error is yielded and the stream ends.
/// - If `timeout_secs` is set, an `Err(Error::Network("SSE stream timeout"))`
///   item is yielded whenever no item arrives within that many seconds. A timeout
///   does not end the stream; stop polling to give up. With `None`, no timeout
///   is applied.
///
/// # Errors
///
/// Returns `Error::Network` if the SSE client cannot be created (e.g. the
/// resulting URL is invalid).
///
/// # Example
/// ```no_run
/// # #[cfg(not(target_arch = "wasm32"))]
/// # use xjp_oidc::sse::{subscribe_login_events, LoginMonitorConfig, LoginEvent, LoginStatus};
/// # #[cfg(not(target_arch = "wasm32"))]
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// use futures_util::StreamExt;
///
/// let config = LoginMonitorConfig {
///     issuer: "https://auth.example.com".to_string(),
///     login_id: "login-123".to_string(),
///     timeout_secs: Some(300),
///     max_reconnects: Some(3),
/// };
///
/// let mut event_stream = subscribe_login_events(config).await?;
///
/// while let Some(event) = event_stream.next().await {
///     match event {
///         Ok(LoginEvent::StatusUpdate(state)) => {
///             println!("Status: {:?}", state.status);
///             if state.status == LoginStatus::Success {
///                 println!("Login successful! Code: {:?}", state.code);
///                 break;
///             }
///         }
///         Ok(LoginEvent::Heartbeat) => {
///             println!("Heartbeat received");
///         }
///         Ok(LoginEvent::Close) => {
///             println!("Stream closed");
///             break;
///         }
///         Ok(LoginEvent::Error(message)) => {
///             eprintln!("Server error: {}", message);
///             break;
///         }
///         Err(e) => {
///             eprintln!("Error: {}", e);
///             break;
///         }
///     }
/// }
/// # Ok(())
/// # }
/// ```
#[cfg(all(not(target_arch = "wasm32"), feature = "sse"))]
pub async fn subscribe_login_events(
    config: LoginMonitorConfig,
) -> Result<impl futures_util::Stream<Item = Result<LoginEvent>>> {
    use eventsource_client::{Client, ClientBuilder, ReconnectOptions, SSE};
    use futures_util::StreamExt;

    // Consecutive client errors tolerated when `max_reconnects` is `None`.
    const DEFAULT_MAX_RECONNECTS: u32 = 3;

    type BoxedLoginStream =
        std::pin::Pin<Box<dyn futures_util::Stream<Item = Result<LoginEvent>> + Send>>;

    // Parses SSE event data as a `LoginState`. It accepts either the bare state
    // object or a wrapper object that carries it under a `"state"` field.
    fn parse_sse_login_state(data: &str) -> Result<LoginState> {
        if let Ok(state) = serde_json::from_str::<LoginState>(data) {
            return Ok(state);
        }
        let mut value: serde_json::Value = serde_json::from_str(data)
            .map_err(|_| Error::Verification("Failed to parse SSE data as JSON".to_string()))?;
        let nested = value
            .get_mut("state")
            .map(serde_json::Value::take)
            .ok_or_else(|| {
                Error::Verification(
                    "SSE data is not valid LoginState and has no 'state' field".to_string(),
                )
            })?;
        serde_json::from_value(nested).map_err(|e| {
            Error::Verification(format!(
                "Failed to parse login state from nested SSE format: {}",
                e
            ))
        })
    }

    // The login ID goes into the query string, so it must be percent-encoded.
    let sse_url = format!(
        "{}/auth/login-stream?login_id={}",
        config.issuer.trim_end_matches('/'),
        urlencoding::encode(&config.login_id)
    );

    let reconnect = ReconnectOptions::reconnect(true)
        .retry_initial(false)
        .delay(std::time::Duration::from_secs(1))
        .delay_max(std::time::Duration::from_secs(5))
        .build();

    let client = ClientBuilder::for_url(&sse_url)
        .map_err(|e| Error::Network(format!("Failed to create SSE client: {}", e)))?
        .reconnect(reconnect)
        .build();

    let max_reconnects = config.max_reconnects.unwrap_or(DEFAULT_MAX_RECONNECTS);

    // Enforce `max_reconnects`: count consecutive error items from the client,
    // and reset the count on any successful item. The error that exceeds the
    // limit is still yielded. After that the stream ends without polling the
    // client again, so no further reconnect is waited for.
    let limited = futures_util::stream::unfold(
        (client.stream(), 0u32, false),
        move |(mut inner, consecutive_errors, exhausted)| async move {
            if exhausted {
                return None;
            }
            let item = match inner.next().await {
                Some(item) => item,
                None => return None,
            };
            let consecutive_errors = if item.is_ok() {
                0
            } else {
                consecutive_errors.saturating_add(1)
            };
            let exhausted = consecutive_errors > max_reconnects;
            Some((item, (inner, consecutive_errors, exhausted)))
        },
    );

    // Translate raw SSE items into `LoginEvent`s.
    let event_stream = limited.map(|item| match item {
        Ok(SSE::Event(event)) => match event.event_type.as_str() {
            "pending" | "scanned" | "authorized" | "success" | "failed" | "expired" => {
                parse_sse_login_state(&event.data).map(LoginEvent::StatusUpdate)
            }
            "close" => Ok(LoginEvent::Close),
            "error" => Ok(LoginEvent::Error(event.data)),
            // Heartbeats and any other event type carry no login state.
            _ => Ok(LoginEvent::Heartbeat),
        },
        Ok(SSE::Comment(_)) => Ok(LoginEvent::Heartbeat),
        Err(e) => Err(Error::Network(format!("SSE error: {}", e))),
    });

    // The timeout applies to the gap between consecutive items. An elapsed
    // timeout is surfaced as an error item, but the stream keeps running.
    let stream = if let Some(timeout_secs) = config.timeout_secs {
        let timed = tokio_stream::StreamExt::timeout(
            event_stream,
            std::time::Duration::from_secs(timeout_secs),
        )
        .map(|result| {
            result
                .map_err(|_| Error::Network("SSE stream timeout".to_string()))
                .and_then(|inner| inner)
        });
        Box::pin(timed) as BoxedLoginStream
    } else {
        Box::pin(event_stream) as BoxedLoginStream
    };

    Ok(stream)
}
281
282/// Check login status once (non-streaming)
283///
284/// This is useful for polling the login status without using SSE,
285/// or as a fallback when SSE is not available.
286///
287/// # Example
288/// ```no_run
289/// # #[cfg(not(target_arch = "wasm32"))]
290/// # use xjp_oidc::{sse::check_login_status, http::ReqwestHttpClient};
291/// # #[cfg(not(target_arch = "wasm32"))]
292/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
293/// let http = ReqwestHttpClient::default();
294///
295/// let state = check_login_status(
296///     "https://auth.example.com",
297///     "login-123",
298///     &http
299/// ).await?;
300///
301/// println!("Current status: {:?}", state.status);
302/// # Ok(())
303/// # }
304/// ```
305#[cfg(not(target_arch = "wasm32"))]
306pub async fn check_login_status(
307    issuer: &str,
308    login_id: &str,
309    http: &dyn HttpClient,
310) -> Result<LoginState> {
311    // Validate parameters
312    if issuer.is_empty() {
313        return Err(Error::InvalidParam("issuer cannot be empty"));
314    }
315    if login_id.is_empty() {
316        return Err(Error::InvalidParam("login_id cannot be empty"));
317    }
318
319    // Build status check endpoint URL
320    let status_endpoint = format!(
321        "{}/auth/login-status/{}",
322        issuer.trim_end_matches('/'),
323        urlencoding::encode(login_id)
324    );
325
326    // Make the request
327    let response = http
328        .get_value(&status_endpoint)
329        .await
330        .map_err(|e| Error::Network(format!("Failed to check login status: {}", e)))?;
331
332    // Parse response - handle both direct and nested formats
333    let state: LoginState = if let Some(nested_state) = response.get("state") {
334        // New format: { "loginId": "...", "state": { ... } }
335        serde_json::from_value(nested_state.clone())
336            .map_err(|e| Error::Verification(format!("Failed to parse login state from nested format: {}", e)))?
337    } else {
338        // Legacy format: direct LoginState object
339        serde_json::from_value(response)
340            .map_err(|e| Error::Verification(format!("Failed to parse login state from direct format: {}", e)))?
341    };
342
343    Ok(state)
344}
345
346// WASM stub implementations
347#[cfg(target_arch = "wasm32")]
348use crate::errors::{Error, Result};
349
/// WASM stub for `start_login_session`.
///
/// Login sessions are not supported on `wasm32`, so this always returns
/// `Error::ServerOnly`.
#[cfg(target_arch = "wasm32")]
pub async fn start_login_session(
    _issuer: &str,
    _client_id: &str,
    _redirect_uri: &str,
    _http: &dyn crate::http::HttpClient,
) -> Result<(String, String)> {
    let unsupported = Error::ServerOnly("SSE login sessions");
    Err(unsupported)
}
359
/// WASM stub for `check_login_status`: it always returns `Error::ServerOnly`.
///
/// NOTE(review): the native version returns `Result<LoginState>`, but this stub
/// returns `Result<()>`, because `LoginState` is only compiled for non-wasm32
/// targets. Cross-target callers need their own `cfg` gates to use the state.
#[cfg(target_arch = "wasm32")]
pub async fn check_login_status(
    _issuer: &str,
    _login_id: &str,
    _http: &dyn crate::http::HttpClient,
) -> Result<()> {
    let unsupported = Error::ServerOnly("SSE login status");
    Err(unsupported)
}
368
#[cfg(test)]
#[cfg(not(target_arch = "wasm32"))]
mod tests {
    use super::*;

    #[test]
    fn test_login_state_parsing() {
        let json = serde_json::json!({
            "status": "SUCCESS",
            "code": "auth_code_123",
            "error": null,
            "created_at": 1234567890
        });

        let state: LoginState = serde_json::from_value(json).unwrap();
        assert_eq!(state.status, LoginStatus::Success);
        assert_eq!(state.code, Some("auth_code_123".to_string()));
        assert!(state.error.is_none());
        assert_eq!(state.created_at, 1234567890);
    }

    #[test]
    fn test_login_status_serialization() {
        let status = LoginStatus::Pending;
        let json = serde_json::to_string(&status).unwrap();
        assert_eq!(json, r#""PENDING""#);

        let status = LoginStatus::Success;
        let json = serde_json::to_string(&status).unwrap();
        assert_eq!(json, r#""SUCCESS""#);
    }

    #[test]
    fn test_login_status_round_trips_every_variant() {
        let cases = [
            (LoginStatus::Pending, r#""PENDING""#),
            (LoginStatus::Scanned, r#""SCANNED""#),
            (LoginStatus::Authorized, r#""AUTHORIZED""#),
            (LoginStatus::Success, r#""SUCCESS""#),
            (LoginStatus::Failed, r#""FAILED""#),
            (LoginStatus::Expired, r#""EXPIRED""#),
        ];

        for (status, json) in cases {
            assert_eq!(serde_json::to_string(&status).unwrap(), json);
            let parsed: LoginStatus = serde_json::from_str(json).unwrap();
            assert_eq!(parsed, status);
        }

        // Status strings are case-sensitive.
        assert!(serde_json::from_str::<LoginStatus>(r#""pending""#).is_err());
    }

    #[test]
    fn test_login_state_missing_optional_fields() {
        // `code` and `error` are optional and default to `None` when absent.
        let json = serde_json::json!({
            "status": "EXPIRED",
            "created_at": 42
        });

        let state: LoginState = serde_json::from_value(json).unwrap();
        assert_eq!(state.status, LoginStatus::Expired);
        assert!(state.code.is_none());
        assert!(state.error.is_none());
        assert_eq!(state.created_at, 42);
    }

    #[test]
    fn test_nested_login_state_parsing() {
        // Test new nested format from server
        let nested_json = serde_json::json!({
            "loginId": "login_123",
            "state": {
                "status": "PENDING",
                "code": null,
                "error": null,
                "created_at": 1234567890
            }
        });

        // Simulate the parsing logic from check_login_status
        let state: LoginState = if let Some(nested_state) = nested_json.get("state") {
            serde_json::from_value(nested_state.clone()).unwrap()
        } else {
            serde_json::from_value(nested_json).unwrap()
        };

        assert_eq!(state.status, LoginStatus::Pending);
        assert!(state.code.is_none());
        assert!(state.error.is_none());
        assert_eq!(state.created_at, 1234567890);
    }

    #[test]
    fn test_sse_event_data_parsing() {
        // Test SSE event data parsing with direct format
        let direct_data = r#"{"status":"SUCCESS","code":"auth_123","error":null,"created_at":1234567890}"#;
        let state: LoginState = serde_json::from_str(direct_data).unwrap();
        assert_eq!(state.status, LoginStatus::Success);
        assert_eq!(state.code, Some("auth_123".to_string()));

        // Test SSE event data parsing with nested format
        let nested_data = r#"{"loginId":"login_123","state":{"status":"FAILED","code":null,"error":"Auth failed","created_at":1234567890}}"#;

        // Simulate the parsing logic from SSE event handler
        let parse_result: Result<LoginState> = {
            if let Ok(state) = serde_json::from_str::<LoginState>(nested_data) {
                Ok(state)
            } else {
                if let Ok(response) = serde_json::from_str::<serde_json::Value>(nested_data) {
                    if let Some(nested_state) = response.get("state") {
                        serde_json::from_value(nested_state.clone())
                            .map_err(|e| Error::Verification(format!("Failed to parse: {}", e)))
                    } else {
                        Err(Error::Verification("No state field".to_string()))
                    }
                } else {
                    Err(Error::Verification("Invalid JSON".to_string()))
                }
            }
        };

        let state = parse_result.unwrap();
        assert_eq!(state.status, LoginStatus::Failed);
        assert!(state.code.is_none());
        assert_eq!(state.error, Some("Auth failed".to_string()));
    }
}
461}