Skip to main content

wae_authentication/saml/
service.rs

//! SAML service implementation.
2
3use crate::saml::{
4    AuthnContextComparison, SamlAuthnRequest, SamlConfig, SamlError, SamlLogoutRequest, SamlLogoutResponse, SamlNameIdPolicy,
5    SamlRequestedAuthnContext, SamlResponse, SamlResult, SpMetadataBuilder,
6};
7use base64::{Engine, engine::general_purpose::STANDARD as BASE64_STANDARD};
8use chrono::Utc;
9use flate2::{Compression, read::DeflateDecoder, write::DeflateEncoder};
10use std::io::{Read, Write};
11use uuid::Uuid;
12
13/// SAML 服务
14#[derive(Debug, Clone)]
15pub struct SamlService {
16    config: SamlConfig,
17}
18
19impl SamlService {
20    /// 创建新的 SAML 服务
21    pub fn new(config: SamlConfig) -> Self {
22        Self { config }
23    }
24
25    /// 创建认证请求 URL (HTTP Redirect 绑定)
26    pub fn create_authn_request_url(&self) -> SamlResult<String> {
27        let request = self.create_authn_request()?;
28        let xml = self.serialize_authn_request(&request)?;
29        let encoded = self.encode_redirect_request(&xml)?;
30
31        let mut url = self.config.idp.sso_url.clone();
32        url.push_str("?SAMLRequest=");
33        url.push_str(&encoded);
34
35        Ok(url)
36    }
37
38    /// 创建认证请求
39    pub fn create_authn_request(&self) -> SamlResult<SamlAuthnRequest> {
40        let id = format!("id{}", Uuid::new_v4().simple());
41        let name_id_policy = SamlNameIdPolicy::new().with_format(self.config.sp.name_id_format.clone()).with_allow_create(true);
42
43        let authn_context = SamlRequestedAuthnContext::new(vec![
44            "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport".to_string(),
45        ])
46        .with_comparison(AuthnContextComparison::Minimum);
47
48        Ok(SamlAuthnRequest::new(id, &self.config.sp.entity_id)
49            .with_destination(&self.config.idp.sso_url)
50            .with_protocol_binding(self.config.idp.sso_binding)
51            .with_acs_url(&self.config.sp.acs_url)
52            .with_name_id_policy(name_id_policy)
53            .with_authn_context(authn_context))
54    }
55
56    /// 处理认证响应
57    pub fn process_authn_response(&self, saml_response: &str) -> SamlResult<SamlResponse> {
58        let decoded = BASE64_STANDARD.decode(saml_response).map_err(|e| SamlError::Base64DecodeError(e.to_string()))?;
59
60        let xml = String::from_utf8(decoded).map_err(|e| SamlError::XmlParsingError(e.to_string()))?;
61
62        let response: SamlResponse = quick_xml::de::from_str(&xml).map_err(|e| SamlError::XmlParsingError(e.to_string()))?;
63
64        self.validate_response(&response)?;
65
66        Ok(response)
67    }
68
69    /// 验证响应
70    fn validate_response(&self, response: &SamlResponse) -> SamlResult<()> {
71        if !response.is_success() {
72            return Err(SamlError::InvalidResponse(format!("SAML response status: {:?}", response.status.code())));
73        }
74
75        if self.config.validate_issuer {
76            if response.issuer != self.config.idp.entity_id {
77                return Err(SamlError::IssuerValidationFailed {
78                    expected: self.config.idp.entity_id.clone(),
79                    actual: response.issuer.clone(),
80                });
81            }
82        }
83
84        if let Some(ref assertion) = response.assertion {
85            self.validate_assertion(assertion)?;
86        }
87
88        Ok(())
89    }
90
91    /// 验证断言
92    fn validate_assertion(&self, assertion: &crate::saml::SamlAssertion) -> SamlResult<()> {
93        let now = Utc::now().timestamp();
94
95        if let Some(ref conditions) = assertion.conditions {
96            if let Some(not_before) = conditions.not_before {
97                if now < not_before.timestamp() - self.config.clock_skew_seconds {
98                    return Err(SamlError::AssertionNotYetValid);
99                }
100            }
101
102            if let Some(not_on_or_after) = conditions.not_on_or_after {
103                if now >= not_on_or_after.timestamp() + self.config.clock_skew_seconds {
104                    return Err(SamlError::AssertionExpired);
105                }
106            }
107
108            if self.config.validate_audience {
109                if let Some(ref restriction) = conditions.audience_restriction {
110                    if !restriction.audience.contains(&self.config.sp.entity_id) {
111                        return Err(SamlError::AudienceValidationFailed {
112                            expected: self.config.sp.entity_id.clone(),
113                            actual: restriction.audience.first().cloned().unwrap_or_default(),
114                        });
115                    }
116                }
117            }
118        }
119
120        Ok(())
121    }
122
123    /// 创建登出请求 URL
124    pub fn create_logout_request_url(&self, name_id: &str, session_index: Option<&str>) -> SamlResult<String> {
125        let slo_url =
126            self.config.idp.slo_url.as_ref().ok_or_else(|| SamlError::ConfigurationError("SLO URL not configured".into()))?;
127
128        let request = self.create_logout_request(name_id, session_index)?;
129        let xml = self.serialize_logout_request(&request)?;
130        let encoded = self.encode_redirect_request(&xml)?;
131
132        let mut url = slo_url.clone();
133        url.push_str("?SAMLRequest=");
134        url.push_str(&encoded);
135
136        Ok(url)
137    }
138
139    /// 创建登出请求
140    pub fn create_logout_request(&self, name_id: &str, session_index: Option<&str>) -> SamlResult<SamlLogoutRequest> {
141        let id = format!("id{}", Uuid::new_v4().simple());
142        let slo_url =
143            self.config.idp.slo_url.as_ref().ok_or_else(|| SamlError::ConfigurationError("SLO URL not configured".into()))?;
144
145        let mut request = SamlLogoutRequest::new(id, &self.config.sp.entity_id)
146            .with_destination(slo_url)
147            .with_name_id(crate::saml::SamlNameId::new(name_id));
148
149        if let Some(idx) = session_index {
150            request = request.with_session_index(idx);
151        }
152
153        Ok(request)
154    }
155
156    /// 处理登出响应
157    pub fn process_logout_response(&self, saml_response: &str) -> SamlResult<SamlLogoutResponse> {
158        let decoded = BASE64_STANDARD.decode(saml_response).map_err(|e| SamlError::Base64DecodeError(e.to_string()))?;
159
160        let xml = String::from_utf8(decoded).map_err(|e| SamlError::XmlParsingError(e.to_string()))?;
161
162        let response: SamlLogoutResponse =
163            quick_xml::de::from_str(&xml).map_err(|e| SamlError::XmlParsingError(e.to_string()))?;
164
165        Ok(response)
166    }
167
168    /// 生成 SP 元数据
169    pub fn generate_sp_metadata(&self) -> String {
170        let builder = SpMetadataBuilder::new(&self.config.sp.entity_id, &self.config.sp.acs_url)
171            .with_want_assertions_signed(self.config.sp.want_assertions_signed)
172            .with_authn_requests_signed(self.config.sp.want_response_signed);
173
174        let metadata = if let Some(ref slo_url) = self.config.sp.slo_url { builder.with_slo_url(slo_url) } else { builder };
175
176        let entity = metadata.build();
177
178        quick_xml::se::to_string(&entity).unwrap_or_default()
179    }
180
181    /// 序列化认证请求
182    fn serialize_authn_request(&self, request: &SamlAuthnRequest) -> SamlResult<String> {
183        quick_xml::se::to_string(request).map_err(|e| SamlError::XmlParsingError(e.to_string()))
184    }
185
186    /// 序列化登出请求
187    fn serialize_logout_request(&self, request: &SamlLogoutRequest) -> SamlResult<String> {
188        quick_xml::se::to_string(request).map_err(|e| SamlError::XmlParsingError(e.to_string()))
189    }
190
191    /// 编码重定向请求 (Deflate + Base64)
192    fn encode_redirect_request(&self, xml: &str) -> SamlResult<String> {
193        let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
194        encoder.write_all(xml.as_bytes()).map_err(|e| SamlError::CompressionError(e.to_string()))?;
195        let compressed = encoder.finish().map_err(|e| SamlError::CompressionError(e.to_string()))?;
196
197        Ok(BASE64_STANDARD.encode(&compressed))
198    }
199
200    /// 解码重定向请求
201    pub fn decode_redirect_request(&self, encoded: &str) -> SamlResult<String> {
202        let decoded = BASE64_STANDARD.decode(encoded).map_err(|e| SamlError::Base64DecodeError(e.to_string()))?;
203
204        let mut decoder = DeflateDecoder::new(&decoded[..]);
205        let mut decompressed = String::new();
206        decoder.read_to_string(&mut decompressed).map_err(|e| SamlError::CompressionError(e.to_string()))?;
207
208        Ok(decompressed)
209    }
210
211    /// 获取配置
212    pub fn config(&self) -> &SamlConfig {
213        &self.config
214    }
215}
216
217/// 便捷函数:创建 SAML 服务
218pub fn create_saml_service(
219    sp_entity_id: impl Into<String>,
220    sp_acs_url: impl Into<String>,
221    idp_entity_id: impl Into<String>,
222    idp_sso_url: impl Into<String>,
223    idp_certificate: impl Into<String>,
224) -> SamlService {
225    use crate::saml::{IdentityProviderConfig, ServiceProviderConfig};
226
227    let sp = ServiceProviderConfig::new(sp_entity_id, sp_acs_url);
228    let idp = IdentityProviderConfig::new(idp_entity_id, idp_sso_url, idp_certificate);
229    let config = SamlConfig::new(sp, idp);
230
231    SamlService::new(config)
232}