wae_authentication/saml/
service.rs1use crate::saml::{
4 AuthnContextComparison, SamlAuthnRequest, SamlConfig, SamlError, SamlLogoutRequest, SamlLogoutResponse, SamlNameIdPolicy,
5 SamlRequestedAuthnContext, SamlResponse, SamlResult, SpMetadataBuilder,
6};
7use base64::{Engine, engine::general_purpose::STANDARD as BASE64_STANDARD};
8use chrono::Utc;
9use flate2::{Compression, read::DeflateDecoder, write::DeflateEncoder};
10use std::io::{Read, Write};
11use uuid::Uuid;
12
13#[derive(Debug, Clone)]
15pub struct SamlService {
16 config: SamlConfig,
17}
18
19impl SamlService {
20 pub fn new(config: SamlConfig) -> Self {
22 Self { config }
23 }
24
25 pub fn create_authn_request_url(&self) -> SamlResult<String> {
27 let request = self.create_authn_request()?;
28 let xml = self.serialize_authn_request(&request)?;
29 let encoded = self.encode_redirect_request(&xml)?;
30
31 let mut url = self.config.idp.sso_url.clone();
32 url.push_str("?SAMLRequest=");
33 url.push_str(&encoded);
34
35 Ok(url)
36 }
37
38 pub fn create_authn_request(&self) -> SamlResult<SamlAuthnRequest> {
40 let id = format!("id{}", Uuid::new_v4().simple());
41 let name_id_policy = SamlNameIdPolicy::new().with_format(self.config.sp.name_id_format.clone()).with_allow_create(true);
42
43 let authn_context = SamlRequestedAuthnContext::new(vec![
44 "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport".to_string(),
45 ])
46 .with_comparison(AuthnContextComparison::Minimum);
47
48 Ok(SamlAuthnRequest::new(id, &self.config.sp.entity_id)
49 .with_destination(&self.config.idp.sso_url)
50 .with_protocol_binding(self.config.idp.sso_binding)
51 .with_acs_url(&self.config.sp.acs_url)
52 .with_name_id_policy(name_id_policy)
53 .with_authn_context(authn_context))
54 }
55
56 pub fn process_authn_response(&self, saml_response: &str) -> SamlResult<SamlResponse> {
58 let decoded = BASE64_STANDARD.decode(saml_response).map_err(|e| SamlError::Base64DecodeError(e.to_string()))?;
59
60 let xml = String::from_utf8(decoded).map_err(|e| SamlError::XmlParsingError(e.to_string()))?;
61
62 let response: SamlResponse = quick_xml::de::from_str(&xml).map_err(|e| SamlError::XmlParsingError(e.to_string()))?;
63
64 self.validate_response(&response)?;
65
66 Ok(response)
67 }
68
69 fn validate_response(&self, response: &SamlResponse) -> SamlResult<()> {
71 if !response.is_success() {
72 return Err(SamlError::InvalidResponse(format!("SAML response status: {:?}", response.status.code())));
73 }
74
75 if self.config.validate_issuer {
76 if response.issuer != self.config.idp.entity_id {
77 return Err(SamlError::IssuerValidationFailed {
78 expected: self.config.idp.entity_id.clone(),
79 actual: response.issuer.clone(),
80 });
81 }
82 }
83
84 if let Some(ref assertion) = response.assertion {
85 self.validate_assertion(assertion)?;
86 }
87
88 Ok(())
89 }
90
91 fn validate_assertion(&self, assertion: &crate::saml::SamlAssertion) -> SamlResult<()> {
93 let now = Utc::now().timestamp();
94
95 if let Some(ref conditions) = assertion.conditions {
96 if let Some(not_before) = conditions.not_before {
97 if now < not_before.timestamp() - self.config.clock_skew_seconds {
98 return Err(SamlError::AssertionNotYetValid);
99 }
100 }
101
102 if let Some(not_on_or_after) = conditions.not_on_or_after {
103 if now >= not_on_or_after.timestamp() + self.config.clock_skew_seconds {
104 return Err(SamlError::AssertionExpired);
105 }
106 }
107
108 if self.config.validate_audience {
109 if let Some(ref restriction) = conditions.audience_restriction {
110 if !restriction.audience.contains(&self.config.sp.entity_id) {
111 return Err(SamlError::AudienceValidationFailed {
112 expected: self.config.sp.entity_id.clone(),
113 actual: restriction.audience.first().cloned().unwrap_or_default(),
114 });
115 }
116 }
117 }
118 }
119
120 Ok(())
121 }
122
123 pub fn create_logout_request_url(&self, name_id: &str, session_index: Option<&str>) -> SamlResult<String> {
125 let slo_url =
126 self.config.idp.slo_url.as_ref().ok_or_else(|| SamlError::ConfigurationError("SLO URL not configured".into()))?;
127
128 let request = self.create_logout_request(name_id, session_index)?;
129 let xml = self.serialize_logout_request(&request)?;
130 let encoded = self.encode_redirect_request(&xml)?;
131
132 let mut url = slo_url.clone();
133 url.push_str("?SAMLRequest=");
134 url.push_str(&encoded);
135
136 Ok(url)
137 }
138
139 pub fn create_logout_request(&self, name_id: &str, session_index: Option<&str>) -> SamlResult<SamlLogoutRequest> {
141 let id = format!("id{}", Uuid::new_v4().simple());
142 let slo_url =
143 self.config.idp.slo_url.as_ref().ok_or_else(|| SamlError::ConfigurationError("SLO URL not configured".into()))?;
144
145 let mut request = SamlLogoutRequest::new(id, &self.config.sp.entity_id)
146 .with_destination(slo_url)
147 .with_name_id(crate::saml::SamlNameId::new(name_id));
148
149 if let Some(idx) = session_index {
150 request = request.with_session_index(idx);
151 }
152
153 Ok(request)
154 }
155
156 pub fn process_logout_response(&self, saml_response: &str) -> SamlResult<SamlLogoutResponse> {
158 let decoded = BASE64_STANDARD.decode(saml_response).map_err(|e| SamlError::Base64DecodeError(e.to_string()))?;
159
160 let xml = String::from_utf8(decoded).map_err(|e| SamlError::XmlParsingError(e.to_string()))?;
161
162 let response: SamlLogoutResponse =
163 quick_xml::de::from_str(&xml).map_err(|e| SamlError::XmlParsingError(e.to_string()))?;
164
165 Ok(response)
166 }
167
168 pub fn generate_sp_metadata(&self) -> String {
170 let builder = SpMetadataBuilder::new(&self.config.sp.entity_id, &self.config.sp.acs_url)
171 .with_want_assertions_signed(self.config.sp.want_assertions_signed)
172 .with_authn_requests_signed(self.config.sp.want_response_signed);
173
174 let metadata = if let Some(ref slo_url) = self.config.sp.slo_url { builder.with_slo_url(slo_url) } else { builder };
175
176 let entity = metadata.build();
177
178 quick_xml::se::to_string(&entity).unwrap_or_default()
179 }
180
181 fn serialize_authn_request(&self, request: &SamlAuthnRequest) -> SamlResult<String> {
183 quick_xml::se::to_string(request).map_err(|e| SamlError::XmlParsingError(e.to_string()))
184 }
185
186 fn serialize_logout_request(&self, request: &SamlLogoutRequest) -> SamlResult<String> {
188 quick_xml::se::to_string(request).map_err(|e| SamlError::XmlParsingError(e.to_string()))
189 }
190
191 fn encode_redirect_request(&self, xml: &str) -> SamlResult<String> {
193 let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
194 encoder.write_all(xml.as_bytes()).map_err(|e| SamlError::CompressionError(e.to_string()))?;
195 let compressed = encoder.finish().map_err(|e| SamlError::CompressionError(e.to_string()))?;
196
197 Ok(BASE64_STANDARD.encode(&compressed))
198 }
199
200 pub fn decode_redirect_request(&self, encoded: &str) -> SamlResult<String> {
202 let decoded = BASE64_STANDARD.decode(encoded).map_err(|e| SamlError::Base64DecodeError(e.to_string()))?;
203
204 let mut decoder = DeflateDecoder::new(&decoded[..]);
205 let mut decompressed = String::new();
206 decoder.read_to_string(&mut decompressed).map_err(|e| SamlError::CompressionError(e.to_string()))?;
207
208 Ok(decompressed)
209 }
210
211 pub fn config(&self) -> &SamlConfig {
213 &self.config
214 }
215}
216
217pub fn create_saml_service(
219 sp_entity_id: impl Into<String>,
220 sp_acs_url: impl Into<String>,
221 idp_entity_id: impl Into<String>,
222 idp_sso_url: impl Into<String>,
223 idp_certificate: impl Into<String>,
224) -> SamlService {
225 use crate::saml::{IdentityProviderConfig, ServiceProviderConfig};
226
227 let sp = ServiceProviderConfig::new(sp_entity_id, sp_acs_url);
228 let idp = IdentityProviderConfig::new(idp_entity_id, idp_sso_url, idp_certificate);
229 let config = SamlConfig::new(sp, idp);
230
231 SamlService::new(config)
232}