1use std::collections::HashMap;
22use std::sync::Arc;
23use std::time::{Duration, Instant};
24
25use parking_lot::RwLock;
26use reqwest::Client as HttpClient;
27use serde::{Deserialize, Serialize};
28use thiserror::Error;
29use url::Url;
30
31use solid_pod_rs::security::ssrf::is_safe_url;
32
/// How long a fetched Client Identifier Document is served from the
/// [`ClientStore`] cache before `find` fetches it again (five minutes).
///
/// Clients added with `ClientStore::insert` are not subject to this TTL.
pub const CLIENT_CACHE_TTL: Duration = Duration::from_secs(300);
36
/// Client registration request body, using OAuth 2.0 Dynamic Client
/// Registration (RFC 7591) metadata field names.
///
/// Every field may be absent on the wire (`#[serde(default)]`), and `None`
/// fields are omitted when serializing. See [`register_client`] for how each
/// field is interpreted.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct RegistrationRequest {
    /// When this is an `http://` or `https://` URL, `register_client` resolves
    /// it as a Client Identifier Document instead of minting a new client.
    /// Any other value is ignored.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub client_id: Option<String>,
    /// Redirect URIs; must be non-empty when a new opaque client is minted.
    #[serde(default)]
    pub redirect_uris: Vec<String>,
    /// Human-readable name, copied into the stored [`ClientDocument`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub client_name: Option<String>,
    /// Accepted on the wire; `register_client` does not copy it into the
    /// stored [`ClientDocument`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub client_uri: Option<String>,
    /// Accepted on the wire; `register_client` does not copy it into the
    /// stored [`ClientDocument`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub logo_uri: Option<String>,
    /// Accepted on the wire; `register_client` does not copy it into the
    /// stored [`ClientDocument`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub policy_uri: Option<String>,
    /// Accepted on the wire; `register_client` does not copy it into the
    /// stored [`ClientDocument`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tos_uri: Option<String>,
    /// Requested scope; `register_client` substitutes `"openid webid"` when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
    /// Requested grant types; empty means `authorization_code` + `refresh_token`.
    #[serde(default)]
    pub grant_types: Vec<String>,
    /// Requested response types; empty means `code`.
    #[serde(default)]
    pub response_types: Vec<String>,
    /// Token endpoint auth method. `None` means `"none"` (no secret issued);
    /// any other value causes `register_client` to issue a client secret.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub token_endpoint_auth_method: Option<String>,
    /// Application type; `register_client` substitutes `"web"` when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub application_type: Option<String>,
}
68
/// Stored metadata for a client, as returned by [`register_client`] and
/// [`ClientStore::find`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientDocument {
    /// Client identifier: a minted `client_…` id for opaque clients, or the
    /// document URL for clients resolved from a Client Identifier Document.
    pub client_id: String,
    /// Secret issued by `register_client` when the auth method is not
    /// `"none"`; `None` otherwise. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_secret: Option<String>,
    /// Unix time in whole seconds at which this record was created
    /// (for fetched documents: when the document was fetched).
    pub client_id_issued_at: u64,
    /// Redirect URIs allowed for this client.
    pub redirect_uris: Vec<String>,
    /// Human-readable client name, if one was provided.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_name: Option<String>,
    /// Grant types allowed for this client.
    pub grant_types: Vec<String>,
    /// Response types allowed for this client.
    pub response_types: Vec<String>,
    /// Token endpoint auth method (`"none"` for public clients).
    pub token_endpoint_auth_method: String,
    /// Application type, e.g. `"web"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub application_type: Option<String>,
    /// Scope granted to this client.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
    /// URL this metadata was fetched from: `Some(url)` for documents fetched
    /// by `ClientStore`, `None` for opaque clients minted by `register_client`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_id_document_url: Option<String>,
}
102
103impl ClientDocument {
104 fn now_secs() -> u64 {
105 use std::time::{SystemTime, UNIX_EPOCH};
106 SystemTime::now()
107 .duration_since(UNIX_EPOCH)
108 .map(|d| d.as_secs())
109 .unwrap_or(0)
110 }
111}
112
/// Errors produced while registering or resolving a client.
#[derive(Debug, Error)]
pub enum RegError {
    /// The registration request itself is unacceptable, for example an
    /// opaque registration with no `redirect_uris`.
    #[error("invalid registration: {0}")]
    InvalidRequest(String),

    /// A Client Identifier Document URL was rejected by the SSRF guard
    /// (`is_safe_url`); carries that guard's error text.
    #[error("SSRF-blocked: {0}")]
    Ssrf(String),

    /// Fetching a Client Identifier Document failed, for example a transport
    /// error, a non-success HTTP status, or no HTTP client being configured.
    #[error("fetch failed: {0}")]
    Fetch(String),

    /// The Client Identifier Document (or its URL) is invalid, for example an
    /// unparsable URL, an unsupported scheme, non-JSON content, a `client_id`
    /// that does not match the URL, or missing `redirect_uris`.
    #[error("invalid client document: {0}")]
    InvalidDocument(String),
}
132
/// In-memory registry of clients plus a TTL cache of fetched Client
/// Identifier Documents.
///
/// Cloning is cheap, and clones share the same underlying maps because the
/// state lives behind an `Arc<RwLock<…>>`.
#[derive(Clone)]
pub struct ClientStore {
    // Shared registered-client map and document cache.
    inner: Arc<RwLock<ClientStoreInner>>,
    // HTTP client used for document fetches. `None` when the default client
    // failed to build in `new` and none was supplied via `with_http`.
    http: Option<HttpClient>,
    // When true, the `is_safe_url` SSRF check is skipped (test-only switch
    // set by `allow_unsafe_urls_for_testing`).
    allow_unsafe_urls: bool,
}
149
150impl Default for ClientStore {
151 fn default() -> Self {
152 Self::new()
153 }
154}
155
/// Mutable state behind [`ClientStore`]'s lock.
#[derive(Default)]
struct ClientStoreInner {
    // Clients added via `ClientStore::insert`, keyed by `client_id`.
    // Consulted first by `find` and never expired.
    registered: HashMap<String, ClientDocument>,
    // Fetched Client Identifier Documents keyed by URL, with the instant they
    // were stored; `find` only serves entries younger than `CLIENT_CACHE_TTL`.
    cache: HashMap<String, (ClientDocument, Instant)>,
}
163
164impl ClientStore {
165 pub fn new() -> Self {
169 Self {
170 inner: Arc::new(RwLock::new(ClientStoreInner::default())),
171 http: HttpClient::builder()
172 .timeout(Duration::from_secs(10))
173 .redirect(reqwest::redirect::Policy::limited(3))
174 .build()
175 .ok(),
176 allow_unsafe_urls: false,
177 }
178 }
179
180 pub fn with_http(mut self, client: HttpClient) -> Self {
183 self.http = Some(client);
184 self
185 }
186
187 #[doc(hidden)]
192 pub fn allow_unsafe_urls_for_testing(mut self) -> Self {
193 self.allow_unsafe_urls = true;
194 self
195 }
196
197 pub fn insert(&self, client: ClientDocument) {
199 let mut inner = self.inner.write();
200 inner.registered.insert(client.client_id.clone(), client);
201 }
202
203 pub async fn find(&self, client_id: &str) -> Result<Option<ClientDocument>, RegError> {
207 if let Some(doc) = self.inner.read().registered.get(client_id).cloned() {
209 return Ok(Some(doc));
210 }
211
212 {
214 let inner = self.inner.read();
215 if let Some((doc, ts)) = inner.cache.get(client_id) {
216 if ts.elapsed() < CLIENT_CACHE_TTL {
217 return Ok(Some(doc.clone()));
218 }
219 }
220 }
221
222 if client_id.starts_with("http://") || client_id.starts_with("https://") {
224 let doc = self.fetch_client_document(client_id).await?;
225 let mut inner = self.inner.write();
226 inner
227 .cache
228 .insert(client_id.to_string(), (doc.clone(), Instant::now()));
229 return Ok(Some(doc));
230 }
231
232 Ok(None)
233 }
234
235 async fn fetch_client_document(&self, url: &str) -> Result<ClientDocument, RegError> {
236 if !self.allow_unsafe_urls {
241 is_safe_url(url).map_err(|e| RegError::Ssrf(e.to_string()))?;
242 }
243
244 let parsed = Url::parse(url)
247 .map_err(|e| RegError::InvalidDocument(format!("URL parse: {e}")))?;
248 if !matches!(parsed.scheme(), "http" | "https") {
254 return Err(RegError::InvalidDocument(format!(
255 "unsupported scheme: {}",
256 parsed.scheme()
257 )));
258 }
259
260 let http = self
261 .http
262 .as_ref()
263 .ok_or_else(|| RegError::Fetch("no HTTP client configured".into()))?;
264 let resp = http
265 .get(url)
266 .header("Accept", "application/ld+json, application/json")
267 .send()
268 .await
269 .map_err(|e| RegError::Fetch(e.to_string()))?;
270
271 if !resp.status().is_success() {
272 return Err(RegError::Fetch(format!(
273 "HTTP {} from {url}",
274 resp.status()
275 )));
276 }
277
278 let body: serde_json::Value = resp
279 .json()
280 .await
281 .map_err(|e| RegError::InvalidDocument(format!("JSON parse: {e}")))?;
282
283 if let Some(declared) = body.get("client_id").and_then(|v| v.as_str()) {
286 if declared != url {
287 return Err(RegError::InvalidDocument(format!(
288 "client_id mismatch: document says {declared}, URL is {url}"
289 )));
290 }
291 }
292
293 let redirect_uris: Vec<String> = body
296 .get("redirect_uris")
297 .and_then(|v| v.as_array())
298 .map(|arr| {
299 arr.iter()
300 .filter_map(|v| v.as_str().map(str::to_string))
301 .collect()
302 })
303 .unwrap_or_default();
304
305 if redirect_uris.is_empty() {
306 return Err(RegError::InvalidDocument(
307 "Client Identifier Document is missing redirect_uris".into(),
308 ));
309 }
310
311 let client_name = body
312 .get("client_name")
313 .and_then(|v| v.as_str())
314 .or_else(|| body.get("name").and_then(|v| v.as_str()))
315 .map(str::to_string);
316
317 let scope = body
318 .get("scope")
319 .and_then(|v| v.as_str())
320 .map(str::to_string)
321 .or_else(|| Some("openid webid".into()));
322
323 Ok(ClientDocument {
324 client_id: url.to_string(),
325 client_secret: None,
326 client_id_issued_at: ClientDocument::now_secs(),
327 redirect_uris,
328 client_name,
329 grant_types: vec!["authorization_code".into(), "refresh_token".into()],
332 response_types: vec!["code".into()],
333 token_endpoint_auth_method: "none".into(),
334 application_type: Some("web".into()),
335 scope,
336 client_id_document_url: Some(url.to_string()),
337 })
338 }
339}
340
341pub async fn register_client(
347 store: &ClientStore,
348 req: RegistrationRequest,
349) -> Result<ClientDocument, RegError> {
350 if let Some(id) = req.client_id.as_deref() {
354 if id.starts_with("http://") || id.starts_with("https://") {
355 if let Some(doc) = store.find(id).await? {
357 return Ok(doc);
358 }
359 return Err(RegError::InvalidDocument(
360 "Client Identifier Document fetch returned no document".into(),
361 ));
362 }
363 }
364
365 if req.redirect_uris.is_empty() {
366 return Err(RegError::InvalidRequest(
367 "redirect_uris is required for authorization-code flow".into(),
368 ));
369 }
370
371 let id_ts = u128::from(ClientDocument::now_secs()).max(1);
373 let ts36 = to_base36(id_ts);
374 let rand_tail: String = rand_base36(8);
375 let client_id = format!("client_{ts36}_{rand_tail}");
376
377 let auth_method = req
378 .token_endpoint_auth_method
379 .clone()
380 .unwrap_or_else(|| "none".into());
381 let client_secret = if auth_method == "none" {
382 None
383 } else {
384 Some(format!("secret-{}", uuid::Uuid::new_v4()))
385 };
386
387 let grant_types = if req.grant_types.is_empty() {
388 vec!["authorization_code".into(), "refresh_token".into()]
389 } else {
390 req.grant_types.clone()
391 };
392 let response_types = if req.response_types.is_empty() {
393 vec!["code".into()]
394 } else {
395 req.response_types.clone()
396 };
397
398 let doc = ClientDocument {
399 client_id,
400 client_secret,
401 client_id_issued_at: ClientDocument::now_secs(),
402 redirect_uris: req.redirect_uris,
403 client_name: req.client_name,
404 grant_types,
405 response_types,
406 token_endpoint_auth_method: auth_method,
407 application_type: req.application_type.or_else(|| Some("web".into())),
408 scope: req.scope.or_else(|| Some("openid webid".into())),
409 client_id_document_url: None,
410 };
411 store.insert(doc.clone());
412 Ok(doc)
413}
414
/// Renders `n` in lowercase base 36 (`0-9a-z`), most significant digit first.
///
/// Zero is rendered as `"0"`.
fn to_base36(n: u128) -> String {
    const DIGITS: &[u8; 36] = b"0123456789abcdefghijklmnopqrstuvwxyz";

    // Successive quotients n, n/36, n/36², … down to the leading digit. Each
    // one's remainder mod 36 yields a digit, least significant first.
    let mut little_endian: Vec<char> =
        std::iter::successors(Some(n), |&q| if q >= 36 { Some(q / 36) } else { None })
            // `q % 36 < 36`, so the cast to usize cannot truncate.
            .map(|q| char::from(DIGITS[(q % 36) as usize]))
            .collect();
    little_endian.reverse();
    little_endian.into_iter().collect()
}
429
430fn rand_base36(len: usize) -> String {
431 use rand::Rng;
432 const ALPHA: &[u8] = b"0123456789abcdefghijklmnopqrstuvwxyz";
433 let mut rng = rand::thread_rng();
434 (0..len)
435 .map(|_| ALPHA[rng.gen_range(0..36)] as char)
436 .collect()
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Opaque registration mints a `client_…` id, issues no secret for the
    // default "none" auth method, and makes the client findable by id.
    #[tokio::test]
    async fn opaque_registration_assigns_prefixed_client_id() {
        let store = ClientStore::new();
        let req = RegistrationRequest {
            redirect_uris: vec!["https://app.example/cb".into()],
            client_name: Some("App".into()),
            ..Default::default()
        };
        let doc = register_client(&store, req).await.unwrap();
        assert!(doc.client_id.starts_with("client_"));
        assert!(doc.client_secret.is_none());
        let again = store.find(&doc.client_id).await.unwrap().unwrap();
        assert_eq!(again.client_id, doc.client_id);
    }

    // An opaque registration with no redirect_uris is rejected as an invalid request.
    #[tokio::test]
    async fn registration_without_redirect_uris_is_rejected() {
        let store = ClientStore::new();
        let err = register_client(
            &store,
            RegistrationRequest {
                ..Default::default()
            },
        )
        .await
        .unwrap_err();
        assert!(matches!(err, RegError::InvalidRequest(_)));
    }

    // A URL client_id is fetched once and then served from the cache. The
    // mock's `.expect(1)` is verified when `server` is dropped at the end of
    // the test, so the second `find` must not hit the network.
    #[tokio::test]
    async fn client_identifier_document_is_fetched_and_cached() {
        let server = MockServer::start().await;
        let cid_url = format!("{}/client#id", server.uri());

        let body = serde_json::json!({
            "@context": "https://www.w3.org/ns/solid/oidc-context.jsonld",
            "client_id": cid_url,
            "client_name": "Federated App",
            "redirect_uris": ["https://app.example/cb"],
            "grant_types": ["authorization_code", "refresh_token"],
            "scope": "openid webid profile"
        });

        Mock::given(method("GET"))
            .and(path("/client"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body.clone()))
            .expect(1)
            .mount(&server)
            .await;

        // The mock server listens on loopback, so the SSRF guard is disabled.
        let store = ClientStore::new().allow_unsafe_urls_for_testing();
        let doc = store.find(&cid_url).await.unwrap().unwrap();
        assert_eq!(doc.client_id, cid_url);
        assert_eq!(doc.redirect_uris, vec!["https://app.example/cb".to_string()]);
        assert_eq!(doc.client_name.as_deref(), Some("Federated App"));
        assert_eq!(doc.client_id_document_url.as_deref(), Some(cid_url.as_str()));

        // Second lookup: must be a cache hit (checked by `.expect(1)` above).
        let _ = store.find(&cid_url).await.unwrap().unwrap();
    }

    // A document whose declared client_id differs from its URL is rejected.
    #[tokio::test]
    async fn client_identifier_document_rejects_id_mismatch() {
        let server = MockServer::start().await;
        let cid_url = format!("{}/client", server.uri());

        Mock::given(method("GET"))
            .and(path("/client"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "client_id": "https://malicious.example/evil",
                "redirect_uris": ["https://malicious.example/cb"]
            })))
            .mount(&server)
            .await;

        let store = ClientStore::new().allow_unsafe_urls_for_testing();
        let err = store.find(&cid_url).await.unwrap_err();
        assert!(matches!(err, RegError::InvalidDocument(_)));
    }

    // With the default SSRF guard, a loopback client_id URL is refused before any request is made.
    #[tokio::test]
    async fn client_identifier_document_rejects_private_ip() {
        let store = ClientStore::new();
        let err = store.find("http://127.0.0.1/client").await.unwrap_err();
        assert!(matches!(err, RegError::Ssrf(_)));
    }

    // A document without redirect_uris is rejected.
    #[tokio::test]
    async fn client_identifier_document_requires_redirect_uris() {
        let server = MockServer::start().await;
        let cid_url = format!("{}/client", server.uri());

        Mock::given(method("GET"))
            .and(path("/client"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "client_id": cid_url,
                "client_name": "Incomplete"
            })))
            .mount(&server)
            .await;

        let store = ClientStore::new().allow_unsafe_urls_for_testing();
        let err = store.find(&cid_url).await.unwrap_err();
        assert!(matches!(err, RegError::InvalidDocument(_)));
    }

    // Spot checks of the base36 encoder used in minted client ids.
    #[test]
    fn base36_encode_sanity() {
        assert_eq!(to_base36(0), "0");
        assert_eq!(to_base36(35), "z");
        assert_eq!(to_base36(36), "10");
    }
}