1#[cfg(feature = "native")]
2use crate::app_schema::default_app_schema;
3#[cfg(feature = "native")]
4use crate::binary_snapshot::decode_binary_snapshot_rows;
5use crate::binary_snapshot::SnapshotChunkRows;
6#[cfg(feature = "native")]
7use crate::binary_sync_pack::{decode_binary_sync_pack, is_binary_sync_pack_content_type};
8use crate::error::{ErrorKind, Result, SyncularError};
9use crate::protocol::*;
10#[cfg(feature = "native")]
11use flate2::read::GzDecoder;
12#[cfg(feature = "native")]
13use reqwest::blocking::Body as BlockingBody;
14#[cfg(feature = "native")]
15use reqwest::blocking::Client as HttpClient;
16#[cfg(feature = "native")]
17use reqwest::Method;
18#[cfg(test)]
19use serde_json::json;
20use serde_json::Value;
21#[cfg(feature = "native")]
22use sha2::{Digest, Sha256};
23use std::collections::BTreeMap;
24#[cfg(feature = "native")]
25use std::fs;
26#[cfg(feature = "native")]
27use std::fs::File;
28#[cfg(feature = "native")]
29use std::io::{Read, Write};
30#[cfg(feature = "native")]
31use std::net::{TcpStream, ToSocketAddrs};
32#[cfg(feature = "native")]
33use std::path::{Path, PathBuf};
34#[cfg(feature = "native")]
35use std::sync::{Arc, Mutex};
36#[cfg(feature = "native")]
37use std::time::{Duration, SystemTime};
38#[cfg(feature = "native")]
39use tungstenite::client::IntoClientRequest;
40#[cfg(feature = "native")]
41use tungstenite::stream::MaybeTlsStream;
42#[cfg(feature = "native")]
43use tungstenite::{client_tls_with_config, Message, WebSocket};
44#[cfg(feature = "native")]
45use uuid::Uuid;
46
// Browser transport implementation; compiled only when the `web-transport`
// feature is enabled AND the target is wasm32.
#[cfg(all(feature = "web-transport", target_arch = "wasm32"))]
pub mod web;

/// Extra HTTP headers (name -> value), such as authentication headers, that
/// transports attach to outgoing requests.
///
/// Backed by a `BTreeMap`, so iteration — and therefore the order in which
/// headers are applied — is sorted by key. Keys are case-sensitive in the map
/// itself; case-insensitive lookups must be done by the caller.
pub type SyncAuthHeaders = BTreeMap<String, String>;
51
/// An outgoing request handed to a [`SyncAuthSigner`] so that it can compute
/// authentication headers for exactly this request.
#[cfg(feature = "native")]
#[derive(Debug, Clone)]
pub struct SyncRequestToSign {
    /// HTTP method as passed by the transport (this file uses `"GET"`/`"POST"`).
    pub method: String,
    /// Full request URL. For the realtime socket this is the `ws://`/`wss://` URL
    /// including its query string.
    pub url: String,
    /// Exact request body bytes; empty for GET requests and the websocket handshake.
    pub body: Vec<u8>,
}
59
/// Callback that signs a request and returns headers to send with it.
///
/// The returned headers are merged over the static auth headers with
/// `BTreeMap::extend`, so a signed header replaces a static one with the exact
/// same (case-sensitive) name. An `Err(String)` is surfaced to callers as a
/// transport error.
#[cfg(feature = "native")]
pub type SyncAuthSigner =
    Arc<dyn Fn(SyncRequestToSign) -> std::result::Result<SyncAuthHeaders, String> + Send + Sync>;
63
/// Implemented by transports that accept a static set of auth headers.
pub trait SyncAuthHeaderStore {
    /// Installs `headers` as the static auth headers. The HTTP transport replaces
    /// the previous set wholesale rather than merging.
    fn set_auth_headers(&mut self, headers: SyncAuthHeaders);
}
67
/// Implemented by transports that can sign each request with a [`SyncAuthSigner`].
#[cfg(feature = "native")]
pub trait SyncAuthSignerStore {
    /// Installs (`Some`) or removes (`None`) the per-request signer.
    fn set_auth_signer(&mut self, signer: Option<SyncAuthSigner>);
}
72
/// Connection settings for [`HttpSyncTransport`] and [`RealtimeSocket`].
#[cfg(feature = "native")]
#[derive(Debug, Clone)]
pub struct SyncTransportConfig {
    /// Sync endpoint. Combined sync requests are POSTed to it directly, and it is
    /// the prefix for the `/snapshot-chunks`, `/snapshot-artifacts`, `/blobs`,
    /// `/auth-leases` and `/realtime` routes (a trailing `/` is trimmed first).
    pub base_url: String,
    /// Sent as the `clientId` query parameter when opening the realtime socket.
    pub client_id: String,
    /// NOTE(review): not read by the transport code in this section — confirm
    /// where it is consumed.
    pub actor_id: String,
    /// Network timeouts; see [`SyncTransportTimeouts`].
    pub timeouts: SyncTransportTimeouts,
}
81
/// Timeouts used by the HTTP and realtime transports.
#[cfg(feature = "native")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SyncTransportTimeouts {
    /// Connect timeout of the HTTP client.
    pub http_connect: Duration,
    /// Overall per-request timeout of the HTTP client.
    pub http_request: Duration,
    /// NOTE(review): not read by the transport code in this section — confirm usage.
    pub http_response_body: Duration,
    /// TCP connect timeout for the websocket (per resolved address), and the
    /// socket read/write timeout during the TLS/websocket handshake.
    pub websocket_open: Duration,
    /// Socket read timeout after the handshake; when it elapses, `read_event`
    /// returns `Ok(None)`.
    pub websocket_idle: Duration,
    /// How long `push_commit` waits for the matching push response.
    pub websocket_push_response: Duration,
    /// Socket write timeout after the handshake, and both read/write timeouts
    /// while closing.
    pub websocket_shutdown: Duration,
}
93
#[cfg(feature = "native")]
impl Default for SyncTransportTimeouts {
    /// Default timeouts: 10s to connect (HTTP and websocket), 30s per HTTP
    /// request and response body, a 1s websocket idle read, 10s to wait for a
    /// push response and 2s for shutdown.
    fn default() -> Self {
        let secs = Duration::from_secs;
        Self {
            http_connect: secs(10),
            http_request: secs(30),
            http_response_body: secs(30),
            websocket_open: secs(10),
            websocket_idle: secs(1),
            websocket_push_response: secs(10),
            websocket_shutdown: secs(2),
        }
    }
}
108
#[cfg(feature = "native")]
impl SyncTransportConfig {
    /// Creates a config for `base_url` identifying as `client_id`/`actor_id`,
    /// using [`SyncTransportTimeouts::default`] for every timeout.
    pub fn new(
        base_url: impl Into<String>,
        client_id: impl Into<String>,
        actor_id: impl Into<String>,
    ) -> Self {
        let (base_url, client_id, actor_id) = (base_url.into(), client_id.into(), actor_id.into());
        Self {
            timeouts: SyncTransportTimeouts::default(),
            base_url,
            client_id,
            actor_id,
        }
    }
}
124
/// Blocking HTTP implementation of [`SyncTransport`] and [`BlobTransport`].
#[cfg(feature = "native")]
pub struct HttpSyncTransport {
    /// Shared blocking client built with the configured connect/request timeouts.
    http: HttpClient,
    config: SyncTransportConfig,
    /// Static headers sent with authenticated requests.
    auth_headers: SyncAuthHeaders,
    /// Optional per-request signer whose headers are merged over `auth_headers`.
    auth_signer: Option<SyncAuthSigner>,
    /// Sent as `x-syncular-schema-version` (header) / `schemaVersion` (websocket query).
    schema_version: i32,
    /// Trace context of the most recent `post_sync`, reused by snapshot fetches.
    sync_trace_context: Mutex<Option<SyncTraceContext>>,
}
134
/// Blocking websocket connection implementing [`RealtimeTransport`].
#[cfg(feature = "native")]
pub struct RealtimeSocket {
    socket: WebSocket<MaybeTlsStream<TcpStream>>,
    /// Maximum time `push_commit` waits for its response.
    push_response_timeout: Duration,
    /// Socket read/write timeout applied while closing.
    shutdown_timeout: Duration,
}
141
/// Distributed-tracing identifiers propagated as `traceparent`, `sentry-trace`
/// and `x-syncular-sync-attempt-id` headers.
#[cfg(feature = "native")]
#[derive(Debug, Clone, PartialEq, Eq)]
struct SyncTraceContext {
    /// Correlates all requests of one sync attempt; defaults to `trace_id`.
    sync_attempt_id: String,
    /// 32-hex-digit trace id.
    trace_id: String,
    /// 16-hex-digit span (parent) id.
    span_id: String,
}
149
150pub use crate::protocol::{RealtimePresenceEntry, RealtimePresenceEvent};
151
/// An event received on the realtime channel.
#[derive(Debug, Clone)]
pub enum RealtimeEvent {
    /// The server signalled that new data is available to sync.
    Sync,
    /// A presence update for some scope.
    Presence(RealtimePresenceEvent),
    /// Any other server event, carrying its `event` name. The native socket
    /// passes an empty string when the message has no string `event` field.
    Other(String),
}
158
/// Request/response side of a sync transport, plus a factory for its realtime channel.
pub trait SyncTransport {
    /// Realtime connection type returned by [`SyncTransport::connect_realtime`].
    type Realtime: RealtimeTransport;

    /// Sends one combined sync request and returns the server's response.
    fn post_sync(&self, request: &CombinedRequest) -> Result<CombinedResponse>;
    /// Downloads and decodes the rows of one snapshot chunk. `scopes` are
    /// forwarded to the server (the HTTP transport sends them in the
    /// `x-syncular-snapshot-scopes` header).
    fn fetch_snapshot_chunk_rows(
        &self,
        chunk: &SnapshotChunkRef,
        scopes: &ScopeValues,
    ) -> Result<SnapshotChunkRows>;
    /// Downloads the raw bytes of a scoped snapshot artifact.
    ///
    /// The default implementation always returns a protocol error, so transports
    /// without artifact support need not override it.
    fn fetch_snapshot_artifact_bytes(
        &self,
        _artifact: &ScopedSnapshotArtifactRef,
        _scopes: &ScopeValues,
    ) -> Result<Vec<u8>> {
        Err(SyncularError::protocol_message(
            "snapshot artifact transport is not implemented",
        ))
    }
    /// Opens the realtime channel.
    fn connect_realtime(&self) -> Result<Self::Realtime>;
}
179
/// Upload/download of content-addressed blobs.
pub trait BlobTransport {
    /// Uploads `bytes` as the content of `blob`.
    fn upload_blob(&self, blob: &BlobRef, bytes: &[u8]) -> Result<()>;
    /// Downloads the full content of `blob` into memory.
    fn download_blob(&self, blob: &BlobRef) -> Result<Vec<u8>>;

    /// Uploads the file at `path` as `blob`.
    ///
    /// Default: reads the whole file into memory, then calls `upload_blob`.
    /// Implementations may override this to stream.
    #[cfg(feature = "native")]
    fn upload_blob_file(&self, blob: &BlobRef, path: &Path) -> Result<()> {
        let bytes = fs::read(path).map_err(|err| {
            SyncularError::storage(err).context(format!("read blob file {path:?}"))
        })?;
        self.upload_blob(blob, &bytes)
    }

    /// Downloads `blob` into the file at `path`.
    ///
    /// Default: downloads into memory via `download_blob`, then writes `path`
    /// directly (not via a temporary file, so a failed write can leave a
    /// partial file). Implementations may override this to stream.
    #[cfg(feature = "native")]
    fn download_blob_to_file(&self, blob: &BlobRef, path: &Path) -> Result<()> {
        let bytes = self.download_blob(blob)?;
        fs::write(path, bytes)
            .map_err(|err| SyncularError::storage(err).context(format!("write blob file {path:?}")))
    }
}
199
/// A bidirectional realtime channel (push commits, presence, server events).
pub trait RealtimeTransport {
    /// Pushes one commit and waits for the server's response to it.
    fn push_commit(&mut self, commit: PushCommitRequest) -> Result<PushCommitResponse>;
    /// Sends a presence `action` for `scope_key` with optional `metadata`.
    ///
    /// The default implementation returns a transport error ("not supported").
    fn send_presence(
        &mut self,
        action: &str,
        scope_key: &str,
        metadata: Option<&Value>,
    ) -> Result<()> {
        let _ = (action, scope_key, metadata);
        Err(SyncularError::message(
            ErrorKind::Transport,
            "realtime presence is not supported by this transport",
        ))
    }
    /// Reads at most one event. `Ok(None)` means nothing usable was received
    /// (the native socket also returns it on its idle read timeout).
    fn read_event(&mut self) -> Result<Option<RealtimeEvent>>;
    /// Closes the channel; errors are not reported.
    fn close(&mut self);
}
217
#[cfg(feature = "native")]
impl HttpSyncTransport {
    /// Builds a transport from `config`, using the default app schema version
    /// and no auth headers or signer.
    ///
    /// If the configured HTTP client cannot be built, this falls back to
    /// `HttpClient::new()`, which does NOT carry the configured connect/request
    /// timeouts.
    pub fn new(config: SyncTransportConfig) -> Self {
        let timeouts = config.timeouts;
        let configured = HttpClient::builder()
            .connect_timeout(timeouts.http_connect)
            .timeout(timeouts.http_request)
            .build();
        let http = match configured {
            Ok(client) => client,
            Err(_) => HttpClient::new(),
        };
        let schema_version = default_app_schema().current_schema_version();
        Self {
            http,
            config,
            auth_headers: SyncAuthHeaders::new(),
            auth_signer: None,
            schema_version,
            sync_trace_context: Mutex::new(None),
        }
    }

    /// Overrides the schema version advertised to the server.
    pub fn with_schema_version(self, schema_version: i32) -> Self {
        Self {
            schema_version,
            ..self
        }
    }

    /// Requests a new auth lease from `{base_url}/auth-leases/issue`.
    ///
    /// # Errors
    ///
    /// Serialization, signing or network failures; a non-2xx status (the
    /// message includes the response body); an undecodable response; or a
    /// response whose `ok` flag is false.
    pub fn issue_auth_lease(
        &self,
        request: &AuthLeaseIssueRequest,
    ) -> Result<AuthLeaseIssueResponse> {
        let base = self.config.base_url.trim_end_matches('/');
        let url = format!("{base}/auth-leases/issue");
        let payload = serde_json::to_vec(request)?;

        let unsigned = self
            .http
            .post(&url)
            .header("accept", "application/json")
            .header("content-type", "application/json")
            .header("x-syncular-schema-version", self.schema_version.to_string());
        let signed = self.apply_auth(unsigned, "POST", &url, &payload)?;
        let response = signed
            .body(payload)
            .send()
            .map_err(|err| SyncularError::transport(err).context(format!("POST {url}")))?;

        let status = response.status();
        if !status.is_success() {
            let detail = response.text().unwrap_or_default();
            return Err(SyncularError::message(
                ErrorKind::Transport,
                format!("auth lease issue failed with HTTP {status}: {detail}"),
            ));
        }

        let lease: AuthLeaseIssueResponse = response.json()?;
        if lease.ok {
            Ok(lease)
        } else {
            Err(SyncularError::message(
                ErrorKind::Transport,
                "auth lease issue returned ok=false",
            ))
        }
    }
}
280
#[cfg(feature = "native")]
impl SyncAuthHeaderStore for HttpSyncTransport {
    /// Replaces the static auth headers. This affects later requests and later
    /// `connect_realtime` calls, not an already-open socket.
    fn set_auth_headers(&mut self, headers: SyncAuthHeaders) {
        self.auth_headers = headers;
    }
}
287
#[cfg(feature = "native")]
impl SyncAuthSignerStore for HttpSyncTransport {
    /// Installs or clears the per-request signer used for subsequent requests.
    fn set_auth_signer(&mut self, signer: Option<SyncAuthSigner>) {
        self.auth_signer = signer;
    }
}
294
#[cfg(feature = "native")]
impl SyncTransport for HttpSyncTransport {
    type Realtime = RealtimeSocket;

    /// Posts a combined sync request directly to `config.base_url`.
    ///
    /// The JSON body is signed (if a signer is configured) as `POST base_url`.
    /// Trace headers are added where the auth headers do not already carry
    /// them, and the resulting trace context is remembered so that later
    /// snapshot fetches reuse it.
    ///
    /// The response is decoded as a binary sync pack when its `content-type`
    /// identifies one, and as JSON otherwise.
    ///
    /// # Errors
    ///
    /// Serialization, signing or network failures; a non-2xx status (the
    /// message includes the response body); or an undecodable response.
    fn post_sync(&self, request: &CombinedRequest) -> Result<CombinedResponse> {
        let body = serde_json::to_vec(request)?;
        let mut headers = self.signed_auth_headers("POST", &self.config.base_url, &body)?;
        let trace_context = SyncTraceContext::from_headers_or_new(&headers);
        trace_context.insert_missing_headers(&mut headers);
        self.set_sync_trace_context(trace_context);
        let builder = self
            .http
            .post(&self.config.base_url)
            .header("content-type", "application/json")
            .header("x-syncular-schema-version", self.schema_version.to_string())
            .header("x-syncular-transport-path", "direct");
        let response = self
            .apply_headers(builder, &headers)
            .body(body)
            .send()
            .map_err(|err| {
                SyncularError::transport(err).context(format!("POST {}", self.config.base_url))
            })?;

        let status = response.status();
        if !status.is_success() {
            let body = response.text().unwrap_or_default();
            return Err(SyncularError::message(
                ErrorKind::Transport,
                format!("sync failed with HTTP {status}: {body}"),
            ));
        }

        let content_type = response
            .headers()
            .get(reqwest::header::CONTENT_TYPE)
            .and_then(|value| value.to_str().ok())
            .map(str::to_string);
        if is_binary_sync_pack_content_type(content_type.as_deref()) {
            let bytes = response.bytes()?.to_vec();
            return decode_binary_sync_pack(&bytes);
        }

        Ok(response.json()?)
    }

    /// Downloads `{base_url}/snapshot-chunks/{id}` and decodes its rows.
    ///
    /// Both the declared size and the declared format of `chunk` are validated
    /// before any network I/O, so an unsupported chunk never costs a request.
    fn fetch_snapshot_chunk_rows(
        &self,
        chunk: &SnapshotChunkRef,
        scopes: &ScopeValues,
    ) -> Result<SnapshotChunkRows> {
        validate_snapshot_chunk_ref_size(chunk)?;
        syncular_protocol::validate_snapshot_chunk_format(chunk)?;
        let url = format!(
            "{}/snapshot-chunks/{}",
            self.config.base_url.trim_end_matches('/'),
            chunk.id
        );
        let compressed = self.get_snapshot_resource_bytes(&url, scopes, "snapshot chunk")?;
        decode_compressed_snapshot_chunk_rows(chunk, &compressed)
    }

    /// Downloads `{base_url}/snapshot-artifacts/{id}` and validates/decodes its bytes.
    fn fetch_snapshot_artifact_bytes(
        &self,
        artifact: &ScopedSnapshotArtifactRef,
        scopes: &ScopeValues,
    ) -> Result<Vec<u8>> {
        validate_snapshot_artifact_ref_size(artifact)?;
        let url = format!(
            "{}/snapshot-artifacts/{}",
            self.config.base_url.trim_end_matches('/'),
            artifact.id
        );
        let bytes = self.get_snapshot_resource_bytes(&url, scopes, "snapshot artifact")?;
        decode_snapshot_artifact_bytes(artifact, &bytes)
    }

    /// Opens the realtime websocket with the current auth headers and signer.
    fn connect_realtime(&self) -> Result<RealtimeSocket> {
        RealtimeSocket::connect(
            &self.config,
            &self.auth_headers,
            self.auth_signer.clone(),
            self.schema_version,
        )
    }
}

#[cfg(feature = "native")]
impl HttpSyncTransport {
    /// Performs an authenticated snapshot `GET url` and returns the raw body.
    ///
    /// `scopes` are sent as JSON in `x-syncular-snapshot-scopes`. The request
    /// is signed as `GET url` with an empty body and tagged with the trace
    /// context of the latest sync. A non-2xx status becomes a transport error
    /// `"{what} failed with HTTP {status}: {body}"`.
    fn get_snapshot_resource_bytes(
        &self,
        url: &str,
        scopes: &ScopeValues,
        what: &str,
    ) -> Result<Vec<u8>> {
        let request = self
            .http
            .get(url)
            .header("x-syncular-snapshot-scopes", serde_json::to_string(scopes)?);
        let mut headers = self.signed_auth_headers("GET", url, &[])?;
        let trace_context = self.sync_trace_context(&headers);
        trace_context.insert_missing_headers(&mut headers);
        let response = self
            .apply_headers(request, &headers)
            .send()
            .map_err(|err| SyncularError::transport(err).context(format!("GET {url}")))?;
        let status = response.status();
        if !status.is_success() {
            let body = response.text().unwrap_or_default();
            return Err(SyncularError::message(
                ErrorKind::Transport,
                format!("{what} failed with HTTP {status}: {body}"),
            ));
        }
        Ok(response.bytes()?.to_vec())
    }
}
419
#[cfg(feature = "native")]
impl BlobTransport for HttpSyncTransport {
    /// Verifies `bytes` against `blob`'s hash/size, then uploads them.
    fn upload_blob(&self, blob: &BlobRef, bytes: &[u8]) -> Result<()> {
        validate_blob_bytes(blob, bytes)?;
        self.upload_blob_body(blob, BlockingBody::from(bytes.to_vec()))
    }

    /// Streams the file at `path` as `blob` without loading it into memory.
    ///
    /// The first pass hashes the file and checks it against `blob`. The file is
    /// then reopened and streamed with `blob.size` as the body length. If the
    /// file changes between the two opens, the uploaded bytes may differ from
    /// the verified ones.
    fn upload_blob_file(&self, blob: &BlobRef, path: &Path) -> Result<()> {
        let file = File::open(path).map_err(|err| {
            SyncularError::storage(err).context(format!("open blob file {path:?}"))
        })?;
        let (actual_hash, actual_size) = blob_hash_reader(file)?;
        validate_blob_digest(blob, &actual_hash, actual_size)?;
        let file = File::open(path).map_err(|err| {
            SyncularError::storage(err).context(format!("reopen blob file {path:?}"))
        })?;
        let len = u64::try_from(blob.size)
            .map_err(|_| SyncularError::protocol_message("blob size cannot be negative"))?;
        self.upload_blob_body(blob, BlockingBody::sized(file, len))
    }

    /// Downloads `blob` into memory and verifies its hash and size.
    fn download_blob(&self, blob: &BlobRef) -> Result<Vec<u8>> {
        validate_blob_hash(&blob.hash)?;
        validate_blob_ref_size(blob)?;
        let response = self.open_blob_download(blob)?;
        let bytes = response.bytes()?.to_vec();
        validate_blob_bytes(blob, &bytes)?;
        Ok(bytes)
    }

    /// Streams `blob` to a temporary file next to `path`, verifies its digest,
    /// then renames it over `path`.
    ///
    /// On any failure the temporary file is removed (best effort), so
    /// unverified or partial downloads are never left behind and `path` is
    /// never replaced by bad data.
    fn download_blob_to_file(&self, blob: &BlobRef, path: &Path) -> Result<()> {
        validate_blob_hash(&blob.hash)?;
        validate_blob_ref_size(blob)?;
        let response = self.open_blob_download(blob)?;
        let temp_path = temp_download_path(path);
        let result = write_verified_blob_temp_file(blob, response, &temp_path).and_then(|()| {
            fs::rename(&temp_path, path).map_err(|err| {
                SyncularError::storage(err)
                    .context(format!("move blob temp file {temp_path:?} to {path:?}"))
            })
        });
        if result.is_err() {
            // The removal error (for example, when the file was never created)
            // is ignored so that the caller sees the original failure.
            let _ = fs::remove_file(&temp_path);
        }
        result
    }
}

/// Copies the download `response` into a newly created `temp_path` while
/// hashing it, then checks the SHA-256 digest and size against `blob`.
///
/// Stops early with a protocol error once the stream exceeds `blob.size`, so a
/// misbehaving server cannot fill the disk. The file handle is closed when
/// this function returns, before the caller renames the file.
#[cfg(feature = "native")]
fn write_verified_blob_temp_file(
    blob: &BlobRef,
    mut response: reqwest::blocking::Response,
    temp_path: &Path,
) -> Result<()> {
    let mut file = File::create(temp_path).map_err(|err| {
        SyncularError::storage(err).context(format!("create blob temp file {temp_path:?}"))
    })?;
    let mut hasher = Sha256::new();
    let mut size = 0i64;
    let mut buffer = [0u8; 64 * 1024];
    loop {
        let read = response.read(&mut buffer).map_err(|err| {
            SyncularError::transport(err).context("read blob download response")
        })?;
        if read == 0 {
            break;
        }
        size = size
            .checked_add(i64::try_from(read).map_err(|_| {
                SyncularError::protocol_message("blob chunk is too large for size metadata")
            })?)
            .ok_or_else(|| SyncularError::protocol_message("blob is too large"))?;
        if size > blob.size {
            return Err(SyncularError::protocol_message(format!(
                "blob download exceeded expected size of {} bytes",
                blob.size
            )));
        }
        hasher.update(&buffer[..read]);
        file.write_all(&buffer[..read]).map_err(|err| {
            SyncularError::storage(err).context(format!("write blob temp file {temp_path:?}"))
        })?;
    }
    file.flush().map_err(|err| {
        SyncularError::storage(err).context(format!("flush blob temp file {temp_path:?}"))
    })?;
    validate_blob_digest(
        blob,
        &format!("sha256:{}", hex::encode(hasher.finalize())),
        size,
    )
}
493
#[cfg(feature = "native")]
impl HttpSyncTransport {
    /// Signs `method url body` (when a signer is configured) and applies the
    /// static plus signed auth headers to `builder`.
    fn apply_auth(
        &self,
        builder: reqwest::blocking::RequestBuilder,
        method: &str,
        url: &str,
        body: &[u8],
    ) -> Result<reqwest::blocking::RequestBuilder> {
        Ok(self.apply_headers(builder, &self.signed_auth_headers(method, url, body)?))
    }

    /// Returns the static auth headers, with any signer-produced headers merged
    /// over them (signed values win on identical keys).
    fn signed_auth_headers(&self, method: &str, url: &str, body: &[u8]) -> Result<SyncAuthHeaders> {
        let mut headers = self.auth_headers.clone();
        if let Some(signer) = &self.auth_signer {
            let signed = signer(SyncRequestToSign {
                method: method.to_string(),
                url: url.to_string(),
                body: body.to_vec(),
            })
            .map_err(|err| {
                SyncularError::message(ErrorKind::Transport, format!("sign sync request: {err}"))
            })?;
            headers.extend(signed);
        }
        Ok(headers)
    }

    /// Adds every entry of `headers` to `builder`.
    fn apply_headers(
        &self,
        builder: reqwest::blocking::RequestBuilder,
        headers: &SyncAuthHeaders,
    ) -> reqwest::blocking::RequestBuilder {
        apply_auth_headers(builder, headers)
    }

    /// Remembers the trace context of the latest sync.
    ///
    /// A poisoned lock is recovered rather than skipped: the guarded value is a
    /// plain `Option` that a panicking holder cannot leave half-updated.
    fn set_sync_trace_context(&self, trace_context: SyncTraceContext) {
        let mut current = self
            .sync_trace_context
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        *current = Some(trace_context);
    }

    /// Returns the remembered sync trace context, or derives one from `headers`
    /// (or generates a fresh one) when no sync has recorded a context yet.
    fn sync_trace_context(&self, headers: &SyncAuthHeaders) -> SyncTraceContext {
        let stored = self
            .sync_trace_context
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .clone();
        stored.unwrap_or_else(|| SyncTraceContext::from_headers_or_new(headers))
    }

    /// Returns `response` unchanged when its status is 2xx. Otherwise drains
    /// the body (best effort) into a transport error
    /// `"{what} failed with HTTP {status}: {body}"`.
    fn check_http_status(
        response: reqwest::blocking::Response,
        what: &str,
    ) -> Result<reqwest::blocking::Response> {
        let status = response.status();
        if status.is_success() {
            return Ok(response);
        }
        let body = response.text().unwrap_or_default();
        Err(SyncularError::message(
            ErrorKind::Transport,
            format!("{what} failed with HTTP {status}: {body}"),
        ))
    }

    /// Runs the three-step blob upload protocol:
    ///
    /// 1. `POST {base}/blobs/upload` (authenticated) with hash/size/mime type.
    ///    If the server reports that the blob already exists, this returns early.
    /// 2. Sends `body` to the returned `uploadUrl`, using the returned method
    ///    (default `PUT`) and headers. No auth headers are added to this request.
    /// 3. `POST {base}/blobs/{hash}/complete` (authenticated) and checks `ok`.
    fn upload_blob_body(&self, blob: &BlobRef, body: BlockingBody) -> Result<()> {
        let base_url = self.config.base_url.trim_end_matches('/');
        let url = format!("{base_url}/blobs/upload");
        let request = BlobUploadInitRequest {
            hash: blob.hash.clone(),
            size: blob.size,
            mime_type: blob.mime_type.clone(),
        };
        let request_body = serde_json::to_vec(&request)?;
        let response = self
            .apply_auth(
                self.http
                    .post(&url)
                    .header("content-type", "application/json"),
                "POST",
                &url,
                &request_body,
            )?
            .body(request_body)
            .send()
            .map_err(|err| SyncularError::transport(err).context(format!("POST {url}")))?;
        let init: BlobUploadInitResponse =
            Self::check_http_status(response, "blob upload init")?.json()?;
        if init.exists {
            return Ok(());
        }

        let upload_url = init.upload_url.ok_or_else(|| {
            SyncularError::protocol_message("blob upload init response missing uploadUrl")
        })?;
        let method = init.upload_method.as_deref().unwrap_or("PUT");
        let method = Method::from_bytes(method.as_bytes())
            .map_err(|err| SyncularError::protocol(err).context("blob upload method"))?;
        let mut upload = self.http.request(method, &upload_url).body(body);
        for (name, value) in init.upload_headers {
            upload = upload.header(name, value);
        }
        let response = upload.send().map_err(|err| {
            SyncularError::transport(err).context(format!("upload blob to {upload_url}"))
        })?;
        Self::check_http_status(response, "blob upload")?;

        let complete_url = format!("{base_url}/blobs/{}/complete", blob_hash_path(&blob.hash)?);
        let response = self
            .apply_auth(self.http.post(&complete_url), "POST", &complete_url, &[])?
            .send()
            .map_err(|err| SyncularError::transport(err).context(format!("POST {complete_url}")))?;
        let complete: BlobUploadCompleteResponse =
            Self::check_http_status(response, "blob upload complete")?.json()?;
        if !complete.ok {
            return Err(SyncularError::protocol_message(
                complete
                    .error
                    .unwrap_or_else(|| "failed to complete blob upload".to_string()),
            ));
        }
        Ok(())
    }

    /// Resolves the download URL for `blob` via `GET {base}/blobs/{hash}/url`
    /// (authenticated), then opens an unauthenticated `GET` to that URL and
    /// returns the still-unread 2xx response.
    fn open_blob_download(&self, blob: &BlobRef) -> Result<reqwest::blocking::Response> {
        let url = format!(
            "{}/blobs/{}/url",
            self.config.base_url.trim_end_matches('/'),
            blob_hash_path(&blob.hash)?
        );
        let response = self
            .apply_auth(self.http.get(&url), "GET", &url, &[])?
            .send()
            .map_err(|err| SyncularError::transport(err).context(format!("GET {url}")))?;
        let download: BlobDownloadUrlResponse =
            Self::check_http_status(response, "blob download url")?.json()?;
        let response = self.http.get(&download.url).send().map_err(|err| {
            SyncularError::transport(err).context(format!("GET {}", download.url))
        })?;
        Self::check_http_status(response, "blob download")
    }
}
662
/// Returns a unique hidden sibling path of `path`
/// (`.{name}.syncular-download-{uuid}`) for staging a download.
///
/// A file name that is missing or not valid UTF-8 is replaced by `blob`.
#[cfg(feature = "native")]
fn temp_download_path(path: &Path) -> PathBuf {
    let base_name = match path.file_name().and_then(|name| name.to_str()) {
        Some(name) => name,
        None => "blob",
    };
    path.with_file_name(format!(
        ".{base_name}.syncular-download-{}",
        Uuid::new_v4()
    ))
}
672
#[cfg(feature = "native")]
impl RealtimeSocket {
    /// Opens the realtime websocket at `{base_url}/realtime`.
    ///
    /// The handshake is a `GET` signed over the full websocket URL (when a
    /// signer is given). It carries the auth and trace headers plus
    /// `x-syncular-schema-version`.
    ///
    /// # Errors
    ///
    /// An invalid base URL, signing failure, invalid header name/value, TCP
    /// connect failure or timeout, or a failed TLS/websocket handshake.
    pub fn connect(
        config: &SyncTransportConfig,
        auth_headers: &SyncAuthHeaders,
        auth_signer: Option<SyncAuthSigner>,
        schema_version: i32,
    ) -> Result<Self> {
        let url = ws_url(&config.base_url, &config.client_id, schema_version)?;
        let mut auth_headers = signed_realtime_auth_headers(auth_headers, auth_signer, &url)?;
        let trace_context = SyncTraceContext::from_headers_or_new(&auth_headers);
        trace_context.insert_missing_headers(&mut auth_headers);
        let mut request = url
            .into_client_request()
            .map_err(|err| SyncularError::transport(err).context("build websocket request"))?;
        for (name, value) in effective_auth_headers(&auth_headers) {
            let name = reqwest::header::HeaderName::from_bytes(name.as_bytes())
                .map_err(SyncularError::transport)?;
            let value = reqwest::header::HeaderValue::from_str(&value)?;
            request.headers_mut().insert(name, value);
        }
        // Inserted after the auth headers: `insert` replaces, so this value wins
        // over any same-named header coming from the auth set.
        request.headers_mut().insert(
            "x-syncular-schema-version",
            schema_version.to_string().parse()?,
        );

        let stream = connect_websocket_tcp(request.uri(), config.timeouts.websocket_open)?;
        // Socket option failures are ignored (best effort).
        stream.set_nodelay(true).ok();
        // Bound the TLS + websocket handshake by the open timeout.
        stream
            .set_read_timeout(Some(config.timeouts.websocket_open))
            .ok();
        stream
            .set_write_timeout(Some(config.timeouts.websocket_open))
            .ok();

        let (mut socket, _response) = client_tls_with_config(request, stream, None, None)
            .map_err(|err| SyncularError::transport(err).context("connect websocket handshake"))?;
        // After the handshake: short idle reads (so `read_event` can return
        // `Ok(None)`), and writes bounded by the shutdown timeout.
        set_websocket_stream_timeouts(
            socket.get_mut(),
            Some(config.timeouts.websocket_idle),
            Some(config.timeouts.websocket_shutdown),
        );
        Ok(Self {
            socket,
            push_response_timeout: config.timeouts.websocket_push_response,
            shutdown_timeout: config.timeouts.websocket_shutdown,
        })
    }
}
722
#[cfg(feature = "native")]
impl SyncTraceContext {
    /// Starts a brand-new trace.
    ///
    /// The trace id is one random UUID in simple (hyphen-free) form. The span
    /// id is the first 16 characters of another. The sync attempt id reuses
    /// the trace id.
    fn new() -> Self {
        let trace_id = Uuid::new_v4().simple().to_string();
        let span_id: String = Uuid::new_v4()
            .simple()
            .to_string()
            .chars()
            .take(16)
            .collect();
        Self {
            sync_attempt_id: trace_id.clone(),
            trace_id,
            span_id,
        }
    }

    /// Adopts a valid W3C `traceparent` from `headers` (case-insensitive name).
    ///
    /// When adopted, `x-syncular-sync-attempt-id` is taken from `headers` if
    /// present, otherwise it is the trace id. Without a valid `traceparent`, a
    /// new context is generated.
    fn from_headers_or_new(headers: &SyncAuthHeaders) -> Self {
        match header_value(headers, "traceparent").and_then(parse_w3c_traceparent) {
            Some((trace_id, span_id)) => {
                let sync_attempt_id = match header_value(headers, "x-syncular-sync-attempt-id") {
                    Some(attempt) => attempt.to_string(),
                    None => trace_id.clone(),
                };
                Self {
                    sync_attempt_id,
                    trace_id,
                    span_id,
                }
            }
            None => Self::new(),
        }
    }

    /// W3C `traceparent` value (version 00, sampled flag set).
    fn traceparent(&self) -> String {
        format!(
            "00-{trace}-{span}-01",
            trace = self.trace_id,
            span = self.span_id
        )
    }

    /// Sentry `sentry-trace` value (sampled).
    fn sentry_trace(&self) -> String {
        format!(
            "{trace}-{span}-1",
            trace = self.trace_id,
            span = self.span_id
        )
    }

    /// Adds the trace headers to `headers`, leaving any header already present
    /// (compared case-insensitively) untouched.
    fn insert_missing_headers(&self, headers: &mut SyncAuthHeaders) {
        let entries = [
            ("traceparent", self.traceparent()),
            ("sentry-trace", self.sentry_trace()),
            ("x-syncular-sync-attempt-id", self.sync_attempt_id.clone()),
        ];
        for (name, value) in entries {
            insert_header_if_missing(headers, name, value);
        }
    }
}
770
#[cfg(feature = "native")]
impl RealtimeTransport for RealtimeSocket {
    /// Sends `commit` as a realtime push request and waits up to
    /// `push_response_timeout` for the response matching its request id and
    /// client commit id.
    ///
    /// The wait uses the monotonic clock, so wall-clock adjustments cannot
    /// stretch or cut it short. A timeout too large to represent means waiting
    /// without a deadline. Pings are answered while waiting.
    ///
    /// NOTE(review): other text messages received while waiting (e.g. sync or
    /// presence events) are consumed here and are not surfaced via `read_event`.
    ///
    /// # Errors
    ///
    /// An oversized frame, send/read failure, close frame, protocol error in
    /// the matching response, or timeout.
    fn push_commit(&mut self, commit: PushCommitRequest) -> Result<PushCommitResponse> {
        use std::time::Instant;

        let request_id = Uuid::new_v4().to_string();
        let client_commit_id = commit.client_commit_id.clone();
        let message = RealtimePushRequest::from_commit(request_id.clone(), commit);

        let message = serde_json::to_string(&message)?;
        validate_websocket_text_frame_size(&message)?;
        self.socket.send(Message::Text(message.into()))?;

        let deadline = Instant::now().checked_add(self.push_response_timeout);

        while !matches!(deadline, Some(limit) if Instant::now() >= limit) {
            match self.socket.read() {
                Ok(Message::Text(text)) => {
                    validate_websocket_text_frame_size(&text)?;
                    let value: Value = match serde_json::from_str(&text) {
                        Ok(value) => value,
                        Err(_) => continue,
                    };
                    if let Some(response) = syncular_protocol::realtime_push_response_from_value(
                        &value,
                        &request_id,
                        &client_commit_id,
                    )? {
                        return Ok(response);
                    }
                }
                Ok(Message::Ping(bytes)) => {
                    self.socket.send(Message::Pong(bytes))?;
                }
                Ok(Message::Close(_)) => {
                    return Err(SyncularError::message(
                        ErrorKind::Transport,
                        "websocket closed during push",
                    ));
                }
                Ok(_) => {}
                // Idle read timeout: keep waiting until the deadline.
                Err(tungstenite::Error::Io(err))
                    if err.kind() == std::io::ErrorKind::WouldBlock
                        || err.kind() == std::io::ErrorKind::TimedOut => {}
                Err(err) => {
                    return Err(
                        SyncularError::transport(err).context("read websocket push response")
                    );
                }
            }
        }

        Err(SyncularError::message(
            ErrorKind::Transport,
            "timed out waiting for websocket push-response",
        ))
    }

    /// Sends a presence message without waiting for any reply.
    fn send_presence(
        &mut self,
        action: &str,
        scope_key: &str,
        metadata: Option<&Value>,
    ) -> Result<()> {
        let message = RealtimePresenceRequest::new(action, scope_key, metadata.cloned());
        let message = serde_json::to_string(&message)?;
        validate_websocket_text_frame_size(&message)?;
        self.socket.send(Message::Text(message.into()))?;
        Ok(())
    }

    /// Reads one frame and maps it to an event.
    ///
    /// Returns `Ok(None)` on the idle read timeout, for non-JSON text, pings
    /// (which are answered), non-text frames, and presence events that fail to
    /// parse. Returns an error on a close frame or any other read failure.
    fn read_event(&mut self) -> Result<Option<RealtimeEvent>> {
        match self.socket.read() {
            Ok(Message::Text(text)) => {
                validate_websocket_text_frame_size(&text)?;
                let value: Value = match serde_json::from_str(&text) {
                    Ok(value) => value,
                    Err(_) => return Ok(None),
                };
                let event = value.get("event").and_then(Value::as_str).unwrap_or("");
                if event == REALTIME_SERVER_EVENT_SYNC {
                    Ok(Some(RealtimeEvent::Sync))
                } else if event == REALTIME_SERVER_EVENT_PRESENCE {
                    Ok(
                        syncular_protocol::realtime_presence_event_from_value(&value)
                            .map(RealtimeEvent::Presence),
                    )
                } else {
                    Ok(Some(RealtimeEvent::Other(event.to_string())))
                }
            }
            Ok(Message::Ping(bytes)) => {
                self.socket.send(Message::Pong(bytes))?;
                Ok(None)
            }
            Ok(Message::Close(_)) => Err(SyncularError::message(
                ErrorKind::Transport,
                "websocket closed",
            )),
            Ok(_) => Ok(None),
            Err(tungstenite::Error::Io(err))
                if err.kind() == std::io::ErrorKind::WouldBlock
                    || err.kind() == std::io::ErrorKind::TimedOut =>
            {
                Ok(None)
            }
            Err(err) => Err(SyncularError::transport(err).context("read websocket message")),
        }
    }

    /// Sends a close frame with socket I/O bounded by the shutdown timeout.
    /// Errors are ignored, and the peer's close reply is not read.
    fn close(&mut self) {
        set_websocket_stream_timeouts(
            self.socket.get_mut(),
            Some(self.shutdown_timeout),
            Some(self.shutdown_timeout),
        );
        self.socket.close(None).ok();
    }
}
890
/// Returns `request` with every entry of `auth_headers` added as a header, in
/// map (sorted key) order.
#[cfg(feature = "native")]
fn apply_auth_headers(
    request: reqwest::blocking::RequestBuilder,
    auth_headers: &SyncAuthHeaders,
) -> reqwest::blocking::RequestBuilder {
    effective_auth_headers(auth_headers)
        .iter()
        .fold(request, |builder, (name, value)| {
            builder.header(name.as_str(), value.as_str())
        })
}
901
/// Snapshots the auth headers as owned `(name, value)` pairs in sorted key order.
#[cfg(feature = "native")]
fn effective_auth_headers(auth_headers: &SyncAuthHeaders) -> Vec<(String, String)> {
    auth_headers.clone().into_iter().collect()
}
909
/// Looks up a header by ASCII-case-insensitive name. The first match in
/// sorted key order wins.
#[cfg(feature = "native")]
fn header_value<'a>(headers: &'a SyncAuthHeaders, name: &str) -> Option<&'a str> {
    for (candidate, value) in headers {
        if candidate.eq_ignore_ascii_case(name) {
            return Some(value.as_str());
        }
    }
    None
}
917
/// Inserts `name: value` unless a header with that name (compared
/// case-insensitively) already exists.
#[cfg(feature = "native")]
fn insert_header_if_missing(headers: &mut SyncAuthHeaders, name: &str, value: String) {
    let already_present = headers.keys().any(|key| key.eq_ignore_ascii_case(name));
    if !already_present {
        headers.insert(name.to_owned(), value);
    }
}
924
/// Parses a version-00 W3C `traceparent` (`00-{trace:32}-{span:16}-{flags:2}`,
/// surrounding whitespace ignored).
///
/// Returns the lowercased `(trace_id, span_id)`. Returns `None` for any other
/// version, wrong field count or width, non-hex digits, or all-zero ids.
#[cfg(feature = "native")]
fn parse_w3c_traceparent(traceparent: &str) -> Option<(String, String)> {
    let fields: Vec<&str> = traceparent.trim().split('-').collect();
    match fields.as_slice() {
        [version, trace_id, span_id, flags]
            if *version == "00"
                && is_valid_trace_hex(trace_id, 32)
                && is_valid_trace_hex(span_id, 16)
                && is_valid_trace_hex(flags, 2)
                && trace_id.bytes().any(|byte| byte != b'0')
                && span_id.bytes().any(|byte| byte != b'0') =>
        {
            Some((trace_id.to_ascii_lowercase(), span_id.to_ascii_lowercase()))
        }
        _ => None,
    }
}
944
/// True when `value` is exactly `len` ASCII hex digits (either case).
#[cfg(feature = "native")]
fn is_valid_trace_hex(value: &str, len: usize) -> bool {
    if value.len() != len {
        return false;
    }
    value.bytes().all(|byte| byte.is_ascii_hexdigit())
}
949
/// Headers for the websocket handshake: the static auth headers, with the
/// signer's output for `GET url` (empty body) merged over them when a signer
/// is given.
#[cfg(feature = "native")]
fn signed_realtime_auth_headers(
    auth_headers: &SyncAuthHeaders,
    auth_signer: Option<SyncAuthSigner>,
    url: &str,
) -> Result<SyncAuthHeaders> {
    let mut merged = auth_headers.clone();
    match auth_signer {
        None => Ok(merged),
        Some(signer) => {
            let to_sign = SyncRequestToSign {
                method: "GET".to_string(),
                url: url.to_string(),
                body: Vec::new(),
            };
            let signed = signer(to_sign).map_err(|err| {
                SyncularError::message(
                    ErrorKind::Transport,
                    format!("sign websocket request: {err}"),
                )
            })?;
            merged.extend(signed);
            Ok(merged)
        }
    }
}
973
/// Derives the realtime websocket URL from the sync base URL.
///
/// `http`/`https` become `ws`/`wss` (`ws`/`wss` are kept as-is), and
/// `/realtime` is appended to the path. `clientId`, `schemaVersion` and
/// `transportPath=direct` are appended to any existing query. Other schemes
/// are rejected with a config error.
#[cfg(feature = "native")]
fn ws_url(base_url: &str, client_id: &str, schema_version: i32) -> Result<String> {
    let mut url = reqwest::Url::parse(base_url).map_err(|err| {
        SyncularError::config(format!("invalid base url for websocket: {base_url}")).context(err)
    })?;

    let upgrade = match url.scheme() {
        "http" => Some(("ws", "failed to set ws scheme")),
        "https" => Some(("wss", "failed to set wss scheme")),
        "ws" | "wss" => None,
        other => {
            return Err(SyncularError::config(format!(
                "unsupported websocket base url scheme: {other}"
            )));
        }
    };
    if let Some((scheme, failure)) = upgrade {
        url.set_scheme(scheme)
            .map_err(|_| SyncularError::config(failure))?;
    }

    let realtime_path = format!("{}/realtime", url.path().trim_end_matches('/'));
    url.set_path(&realtime_path);
    let schema = schema_version.to_string();
    url.query_pairs_mut()
        .append_pair("clientId", client_id)
        .append_pair("schemaVersion", &schema)
        .append_pair("transportPath", "direct");
    Ok(url.as_str().to_owned())
}
1001
/// Opens a TCP connection to the websocket URI's host.
///
/// Each resolved address is tried in order with `timeout`, and the first
/// success is returned. If every attempt fails, the last connect error is
/// reported. Host-name resolution itself is not bounded by `timeout`.
#[cfg(feature = "native")]
fn connect_websocket_tcp(uri: &tungstenite::http::Uri, timeout: Duration) -> Result<TcpStream> {
    let raw_host = match uri.host() {
        Some(host) => host,
        None => {
            return Err(SyncularError::message(
                ErrorKind::Transport,
                "websocket url is missing a host",
            ))
        }
    };
    // Bracketed IPv6 literals ("[::1]") are unwrapped before resolution.
    let host = match raw_host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
    {
        Some(unbracketed) => unbracketed,
        None => raw_host,
    };

    // The scheme is validated even when the URI carries an explicit port.
    let default_port: u16 = match uri.scheme_str() {
        Some("ws") => 80,
        Some("wss") => 443,
        Some(scheme) => {
            return Err(SyncularError::message(
                ErrorKind::Transport,
                format!("unsupported websocket url scheme: {scheme}"),
            ));
        }
        None => {
            return Err(SyncularError::message(
                ErrorKind::Transport,
                "websocket url is missing a scheme",
            ));
        }
    };
    let port = uri.port_u16().unwrap_or(default_port);

    let candidates = (host, port)
        .to_socket_addrs()
        .map_err(|err| SyncularError::transport(err).context("resolve websocket host"))?;

    let mut last_failure: Option<std::io::Error> = None;
    for candidate in candidates {
        match TcpStream::connect_timeout(&candidate, timeout) {
            Ok(stream) => return Ok(stream),
            Err(err) => last_failure = Some(err),
        }
    }

    let message = match last_failure {
        Some(err) => format!("connect websocket tcp: {err}"),
        None => "connect websocket tcp: host resolved to no addresses".to_string(),
    };
    Err(SyncularError::message(ErrorKind::Transport, message))
}
1044
#[cfg(feature = "native")]
/// Applies read/write timeouts to the TCP socket underneath a websocket stream.
///
/// Plain and rustls streams expose their `TcpStream`. Any other stream variant is
/// left untouched. This is best-effort: errors from the socket setters are
/// deliberately ignored.
fn set_websocket_stream_timeouts(
    stream: &mut MaybeTlsStream<TcpStream>,
    read_timeout: Option<Duration>,
    write_timeout: Option<Duration>,
) {
    let socket: &TcpStream = match stream {
        MaybeTlsStream::Plain(plain) => &*plain,
        MaybeTlsStream::Rustls(tls) => &tls.sock,
        _ => return,
    };
    let _ = socket.set_read_timeout(read_timeout);
    let _ = socket.set_write_timeout(write_timeout);
}
1063
#[cfg(feature = "native")]
/// Converts a validated `sha256:<hex>` blob hash into its URL path segment.
///
/// The `:` separator is emitted percent-encoded as `%3A`, giving `sha256%3A<hex>`.
///
/// # Errors
/// Propagates any error from `validate_blob_hash`.
///
/// # Panics
/// Only if `validate_blob_hash` accepts a value that lacks the `sha256:` prefix,
/// which would be a broken invariant.
fn blob_hash_path(hash: &str) -> Result<String> {
    validate_blob_hash(hash)?;
    let digest_hex = hash
        .strip_prefix("sha256:")
        .expect("validated hash should have sha256 prefix");
    let mut segment = String::from("sha256%3A");
    segment.push_str(digest_hex);
    Ok(segment)
}
1072
#[cfg(feature = "native")]
/// Verifies and decodes a downloaded, gzip-compressed snapshot chunk.
///
/// The steps run in this order, and any failure aborts with that error:
/// 1. validate the chunk's declared format;
/// 2. run the compressed-bytes check on `compressed`;
/// 3. check the SHA-256 of `compressed` (lowercase hex) against the chunk;
/// 4. gunzip, then run the decompressed-bytes check;
/// 5. decode rows according to the chunk's encoding.
///
/// NOTE(review): the gzip stream is fully inflated into memory before
/// `validate_snapshot_chunk_decompressed_bytes` runs. If that helper enforces a
/// size cap, consider bounding the read (e.g. `Read::take`) so oversized payloads
/// are rejected before they are fully decompressed. Confirm against the helper.
fn decode_compressed_snapshot_chunk_rows(
    chunk: &SnapshotChunkRef,
    compressed: &[u8],
) -> Result<SnapshotChunkRows> {
    syncular_protocol::validate_snapshot_chunk_format(chunk)?;
    validate_snapshot_chunk_compressed_bytes(chunk, compressed)?;

    let digest_hex = hex::encode(Sha256::digest(compressed));
    syncular_protocol::validate_snapshot_chunk_hash_hex(chunk, &digest_hex)?;

    let mut inflated = Vec::new();
    GzDecoder::new(compressed).read_to_end(&mut inflated)?;
    validate_snapshot_chunk_decompressed_bytes(&inflated)?;

    decode_snapshot_chunk_rows(chunk, &inflated)
}
1090
#[cfg(feature = "native")]
/// Checks downloaded snapshot-artifact bytes against their reference.
///
/// It first validates the artifact reference itself and runs the compressed-bytes
/// check. It then requires the lowercase-hex SHA-256 of `bytes` to equal
/// `artifact.sha256` exactly (a case-sensitive string comparison).
///
/// # Errors
/// Propagates validation errors. On a digest mismatch it returns a protocol error
/// naming both the expected and the actual hash.
fn validate_snapshot_artifact_bytes(
    artifact: &ScopedSnapshotArtifactRef,
    bytes: &[u8],
) -> Result<()> {
    syncular_protocol::validate_scoped_snapshot_artifact_ref(artifact)?;
    validate_snapshot_artifact_compressed_bytes(artifact, bytes)?;

    let computed = hex::encode(Sha256::digest(bytes));
    if computed == artifact.sha256 {
        return Ok(());
    }
    Err(SyncularError::protocol_message(format!(
        "snapshot artifact sha256 mismatch: expected {}, got {}",
        artifact.sha256, computed
    )))
}
1107
#[cfg(feature = "native")]
/// Verifies a compressed snapshot artifact and returns its decompressed bytes.
///
/// Integrity is checked first via `validate_snapshot_artifact_bytes`. Only gzip
/// (`SNAPSHOT_CHUNK_COMPRESSION_GZIP`) is accepted as the artifact's compression.
/// The inflated output is passed through
/// `validate_snapshot_artifact_decompressed_bytes` before being returned.
///
/// NOTE(review): decompression reads the whole gzip stream into memory before the
/// decompressed-bytes check runs. If that check enforces a size cap, consider
/// bounding the read with `Read::take`. Confirm against the helper.
///
/// # Errors
/// Validation failures, an unsupported compression (protocol error), or a gzip
/// I/O error.
fn decode_snapshot_artifact_bytes(
    artifact: &ScopedSnapshotArtifactRef,
    compressed: &[u8],
) -> Result<Vec<u8>> {
    validate_snapshot_artifact_bytes(artifact, compressed)?;

    let is_gzip = artifact.compression == SNAPSHOT_CHUNK_COMPRESSION_GZIP;
    if !is_gzip {
        return Err(SyncularError::protocol_message(format!(
            "unsupported snapshot artifact compression {}",
            artifact.compression
        )));
    }

    let mut inflated = Vec::new();
    GzDecoder::new(compressed).read_to_end(&mut inflated)?;
    validate_snapshot_artifact_decompressed_bytes(&inflated)?;
    Ok(inflated)
}
1126
#[cfg(feature = "native")]
/// Decodes already-decompressed snapshot chunk `bytes` into rows.
///
/// The chunk format is validated first. Only the
/// `SNAPSHOT_CHUNK_ENCODING_BINARY_TABLE_V1` encoding is supported, and its rows
/// are wrapped as `SnapshotChunkRows::Binary`.
///
/// # Errors
/// Format validation errors, binary decode errors, or a protocol error for any
/// other encoding.
fn decode_snapshot_chunk_rows(chunk: &SnapshotChunkRef, bytes: &[u8]) -> Result<SnapshotChunkRows> {
    syncular_protocol::validate_snapshot_chunk_format(chunk)?;

    let encoding = chunk.encoding.as_str();
    if encoding != SNAPSHOT_CHUNK_ENCODING_BINARY_TABLE_V1 {
        return Err(SyncularError::protocol_message(format!(
            "unsupported snapshot chunk encoding: {encoding}"
        )));
    }
    decode_binary_snapshot_rows(bytes).map(SnapshotChunkRows::Binary)
}
1140
#[cfg(all(test, feature = "native"))]
mod tests {
    // Unit and loopback-socket tests for the native HTTP / websocket transport.
    // Servers are spun up on 127.0.0.1:0 in background threads. The tests
    // communicate the headers they observe back over mpsc channels.
    use super::*;
    use std::net::TcpListener;
    use std::sync::mpsc;
    use std::sync::{Arc, Mutex};
    use std::thread;
    use std::time::Instant;

    #[test]
    fn effective_auth_headers_are_empty_without_app_headers() {
        let headers = effective_auth_headers(&SyncAuthHeaders::new());

        assert_eq!(headers, Vec::<(String, String)>::new());
    }

    #[test]
    fn transport_config_has_production_timeout_defaults() {
        let config = SyncTransportConfig::new("https://api.example.test/sync", "client", "actor");

        assert_eq!(config.timeouts.http_connect, Duration::from_secs(10));
        assert_eq!(config.timeouts.http_request, Duration::from_secs(30));
        assert_eq!(config.timeouts.http_response_body, Duration::from_secs(30));
        assert_eq!(config.timeouts.websocket_open, Duration::from_secs(10));
        assert_eq!(config.timeouts.websocket_idle, Duration::from_secs(1));
        assert_eq!(
            config.timeouts.websocket_push_response,
            Duration::from_secs(10)
        );
        assert_eq!(config.timeouts.websocket_shutdown, Duration::from_secs(2));
    }

    #[test]
    fn effective_auth_headers_use_supplied_headers_without_dev_actor_headers() {
        let mut auth_headers = SyncAuthHeaders::new();
        auth_headers.insert("authorization".to_string(), "Bearer token-1".to_string());

        let headers = effective_auth_headers(&auth_headers);

        assert_eq!(
            headers,
            vec![("authorization".to_string(), "Bearer token-1".to_string())]
        );
    }

    #[test]
    fn sync_trace_context_derives_attempt_from_existing_traceparent() {
        let trace_id = "4bf92f3577b34da6a3ce929d0e0e4736";
        let span_id = "00f067aa0ba902b7";
        let headers = SyncAuthHeaders::from([(
            "TraceParent".to_string(),
            format!("00-{trace_id}-{span_id}-01"),
        )]);

        let context = SyncTraceContext::from_headers_or_new(&headers);

        assert_eq!(context.sync_attempt_id, trace_id);
        assert_eq!(context.trace_id, trace_id);
        assert_eq!(context.span_id, span_id);
    }

    #[test]
    fn http_sync_reuses_trace_context_for_snapshot_chunks() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind sync trace server");
        let address = listener.local_addr().expect("sync trace server address");
        let (headers_tx, headers_rx) = mpsc::channel::<(BTreeMap<String, String>, String)>();

        let compressed_chunk = gzip_bytes(b"not-binary-table");
        let chunk = SnapshotChunkRef {
            id: "trace-chunk".to_string(),
            byte_length: compressed_chunk.len() as i64,
            sha256: hex::encode(Sha256::digest(&compressed_chunk)),
            encoding: SNAPSHOT_CHUNK_ENCODING_BINARY_TABLE_V1.to_string(),
            compression: SNAPSHOT_CHUNK_COMPRESSION_GZIP.to_string(),
        };
        let server_chunk = compressed_chunk.clone();

        let server = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept sync request");
            let post = read_http_request_raw(&mut stream);
            let post_headers = http_headers(&post);
            let attempt_id = post_headers
                .get("x-syncular-sync-attempt-id")
                .expect("post sync attempt id")
                .to_string();
            headers_tx
                .send((post_headers, "post".to_string()))
                .expect("send post headers");
            write_http_json_response(
                &mut stream,
                json!({
                    "ok": true,
                    "push": null,
                    "pull": null
                }),
            );

            let (mut stream, _) = listener.accept().expect("accept snapshot chunk request");
            let get = read_http_request_raw(&mut stream);
            let get_headers = http_headers(&get);
            assert_eq!(
                get_headers.get("x-syncular-sync-attempt-id"),
                Some(&attempt_id)
            );
            headers_tx
                .send((get_headers, "get".to_string()))
                .expect("send get headers");
            write_http_bytes_response(&mut stream, "application/octet-stream", &server_chunk);
        });

        let transport = HttpSyncTransport::new(SyncTransportConfig::new(
            format!("http://{address}/sync"),
            "native-trace-client",
            "native-trace-actor",
        ));
        let request = CombinedRequest {
            client_id: "native-trace-client".to_string(),
            push: None,
            pull: None,
        };
        transport.post_sync(&request).expect("post sync");
        // The chunk payload is deliberately not a valid binary table; only the
        // request headers matter here, so the decode result is ignored.
        let _ = transport.fetch_snapshot_chunk_rows(&chunk, &ScopeValues::new());

        let (post_headers, post_kind) = headers_rx
            .recv_timeout(Duration::from_secs(2))
            .expect("post headers");
        let (get_headers, get_kind) = headers_rx
            .recv_timeout(Duration::from_secs(2))
            .expect("get headers");
        assert_eq!(post_kind, "post");
        assert_eq!(get_kind, "get");
        let attempt_id = post_headers
            .get("x-syncular-sync-attempt-id")
            .expect("post attempt id");
        assert_eq!(
            get_headers.get("x-syncular-sync-attempt-id"),
            Some(attempt_id)
        );
        assert!(post_headers
            .get("traceparent")
            .is_some_and(|value| value.contains(attempt_id)));
        assert!(post_headers.contains_key("sentry-trace"));
        assert_eq!(
            get_headers.get("traceparent"),
            post_headers.get("traceparent")
        );
        server.join().expect("sync trace server finished");
    }

    #[test]
    fn realtime_auth_headers_are_signed_for_websocket_get_request() {
        let captured = Arc::new(Mutex::new(None::<SyncRequestToSign>));
        let captured_for_signer = Arc::clone(&captured);
        let signer: SyncAuthSigner = Arc::new(move |request| {
            *captured_for_signer.lock().expect("capture signer request") = Some(request);
            Ok(SyncAuthHeaders::from([(
                "x-signed-realtime".to_string(),
                "yes".to_string(),
            )]))
        });

        let headers = signed_realtime_auth_headers(
            &SyncAuthHeaders::new(),
            Some(signer),
            "wss://api.notsuru.app/sync/realtime?clientId=flutter-shell",
        )
        .expect("signed realtime headers");

        assert_eq!(headers["x-signed-realtime"], "yes");
        let request = captured
            .lock()
            .expect("captured request lock")
            .clone()
            .expect("request was signed");
        assert_eq!(request.method, "GET");
        assert_eq!(
            request.url,
            "wss://api.notsuru.app/sync/realtime?clientId=flutter-shell"
        );
        assert!(request.body.is_empty());
    }

    #[test]
    fn realtime_socket_handshake_uses_auth_signer_and_reads_sync_wakeup() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind websocket test server");
        let address = listener.local_addr().expect("websocket server address");
        let (headers_tx, headers_rx) = mpsc::channel::<(String, String)>();

        let server = thread::spawn(move || {
            let (stream, _) = listener.accept().expect("accept websocket client");
            let mut socket = tungstenite::accept_hdr(
                stream,
                |request: &tungstenite::handshake::server::Request, response| {
                    let signed = request
                        .headers()
                        .get("x-signed-realtime")
                        .and_then(|value| value.to_str().ok())
                        .unwrap_or("")
                        .to_string();
                    let schema = request
                        .headers()
                        .get("x-syncular-schema-version")
                        .and_then(|value| value.to_str().ok())
                        .unwrap_or("")
                        .to_string();
                    headers_tx
                        .send((signed, schema))
                        .expect("send captured websocket headers");
                    Ok(response)
                },
            )
            .expect("complete websocket handshake");
            socket
                .send(Message::Text(
                    json!({"event": "sync", "data": {"cursor": 42}})
                        .to_string()
                        .into(),
                ))
                .expect("send realtime sync event");
            socket.close(None).ok();
        });

        let signer: SyncAuthSigner = Arc::new(|request| {
            assert_eq!(request.method, "GET");
            assert!(request.url.starts_with("ws://127.0.0.1:"));
            assert!(request.url.contains("/api/sync/realtime?"));
            assert!(request.body.is_empty());
            Ok(SyncAuthHeaders::from([(
                "x-signed-realtime".to_string(),
                "yes".to_string(),
            )]))
        });
        let config = SyncTransportConfig::new(
            format!("ws://{address}/api/sync"),
            "flutter-shell",
            "passkey:user-test",
        );

        let mut socket = RealtimeSocket::connect(&config, &SyncAuthHeaders::new(), Some(signer), 7)
            .expect("connect realtime websocket");

        assert!(matches!(socket.read_event(), Ok(Some(RealtimeEvent::Sync))));
        let (signed, schema) = headers_rx
            .recv_timeout(Duration::from_secs(2))
            .expect("captured websocket headers");
        assert_eq!(signed, "yes");
        assert_eq!(schema, "7");
        server.join().expect("websocket test server finished");
    }

    #[test]
    fn realtime_socket_connect_uses_websocket_open_timeout() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind websocket test server");
        let address = listener.local_addr().expect("websocket server address");

        // Accept the TCP connection but never answer the upgrade, so only the
        // configured open timeout can end the client's handshake.
        let server = thread::spawn(move || {
            if let Ok((stream, _)) = listener.accept() {
                thread::sleep(Duration::from_millis(350));
                drop(stream);
            }
        });

        let mut config = SyncTransportConfig::new(
            format!("ws://{address}/api/sync"),
            "flutter-shell",
            "passkey:user-test",
        );
        config.timeouts.websocket_open = Duration::from_millis(75);

        let started = Instant::now();
        let result = RealtimeSocket::connect(&config, &SyncAuthHeaders::new(), None, 7);

        let elapsed = started.elapsed();
        let error = match result {
            Ok(_) => panic!("websocket connect should time out"),
            Err(error) => error,
        };
        assert_eq!(error.kind(), ErrorKind::Transport);
        assert!(
            elapsed < Duration::from_millis(250),
            "websocket open ignored configured timeout: {elapsed:?}"
        );
        server.join().expect("websocket test server finished");
    }

    /// Gzip-compresses `bytes` with the fastest compression level.
    fn gzip_bytes(bytes: &[u8]) -> Vec<u8> {
        let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::fast());
        encoder.write_all(bytes).expect("write gzip payload");
        encoder.finish().expect("finish gzip payload")
    }

    /// Reads one HTTP request from `stream` (headers plus a `content-length` body)
    /// and returns it as lossy UTF-8. Gives up on EOF, and panics if the socket
    /// stays silent for 2 seconds.
    fn read_http_request_raw(stream: &mut std::net::TcpStream) -> String {
        stream
            .set_read_timeout(Some(Duration::from_secs(2)))
            .expect("set request read timeout");
        let mut buffer = Vec::new();
        let mut chunk = [0u8; 4096];
        loop {
            let read = stream.read(&mut chunk).expect("read http request");
            if read == 0 {
                break;
            }
            buffer.extend_from_slice(&chunk[..read]);
            if http_request_complete(&buffer) {
                break;
            }
        }
        String::from_utf8_lossy(&buffer).into_owned()
    }

    /// Returns `true` once `buffer` holds the complete header section plus the
    /// number of body bytes advertised by `content-length` (0 when absent).
    fn http_request_complete(buffer: &[u8]) -> bool {
        let Some(header_end) = buffer
            .windows(4)
            .position(|window| window == b"\r\n\r\n")
        else {
            return false;
        };
        // Only the header section is searched, and the `content-length` header may
        // appear at any position among the headers, not just first.
        let head = String::from_utf8_lossy(&buffer[..header_end]);
        let content_length = head
            .lines()
            .filter_map(|line| line.split_once(':'))
            .find(|(name, _)| name.trim().eq_ignore_ascii_case("content-length"))
            .and_then(|(_, value)| value.trim().parse::<usize>().ok())
            .unwrap_or(0);
        buffer.len() >= header_end + 4 + content_length
    }

    /// Parses the header lines of a raw request into a lowercase-name map.
    /// The request line is skipped and parsing stops at the blank line.
    fn http_headers(request: &str) -> BTreeMap<String, String> {
        request
            .lines()
            .skip(1)
            .take_while(|line| !line.trim().is_empty())
            .filter_map(|line| line.split_once(':'))
            .map(|(name, value)| (name.trim().to_ascii_lowercase(), value.trim().to_string()))
            .collect()
    }

    /// Writes a `200 OK` response whose body is the serialized JSON `body`.
    fn write_http_json_response(stream: &mut std::net::TcpStream, body: Value) {
        write_http_bytes_response(stream, "application/json", body.to_string().as_bytes());
    }

    /// Writes a `200 OK`, `connection: close` response with the given content type
    /// and body.
    fn write_http_bytes_response(
        stream: &mut std::net::TcpStream,
        content_type: &str,
        body: &[u8],
    ) {
        let headers = format!(
            "HTTP/1.1 200 OK\r\ncontent-type: {content_type}\r\ncontent-length: {}\r\nconnection: close\r\n\r\n",
            body.len()
        );
        stream
            .write_all(headers.as_bytes())
            .expect("write http response headers");
        stream.write_all(body).expect("write http response body");
    }
}