1use std::net::SocketAddr;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::{Arc, RwLock};
6
7use protosocket::TcpSocketListener;
8use protosocket_messagepack::{MessagePackDecoder, MessagePackSerializer};
9use protosocket_rpc::Message;
10use protosocket_rpc::server::{ConnectionService, RpcResponder, SocketRpcServer, SocketService};
11use tokio::sync::watch;
12use tracing::metadata::LevelFilter;
13use tracing_cache::{ChanceHandle, EnabledPredicate, LevelHandle, SpanCache, SpanRecord};
14
15use crate::protocol::{Request, RequestBody, Response, WireLevel, WireLevelFilter};
16use crate::wire::{TimeBase, span_to_wire};
17
/// Server-side codec pair: serializes outgoing [`Response`]s and decodes
/// incoming [`Request`]s, both as MessagePack.
type ServerCodec = (MessagePackSerializer<Response>, MessagePackDecoder<Request>);
21
/// Capacity argument passed to `SpanCache::subscribe` for every span stream
/// (2^16). NOTE(review): presumably bounds the per-subscriber buffer —
/// confirm against `tracing_cache`.
const STREAM_SUBSCRIBER_CAPACITY: u64 = 1 << 16;
27
28#[derive(Debug, Default)]
31struct StreamState {
32 streaming: bool,
33 min_level: Option<WireLevel>,
34 sampling_rate: f64,
35}
36
37impl StreamState {
38 fn new() -> Self {
39 Self {
40 streaming: false,
41 min_level: None,
42 sampling_rate: 1.0,
43 }
44 }
45}
46
47#[derive(Clone)]
60pub(crate) struct CacheLevelBroadcast {
61 level_handle: LevelHandle,
62 level_tx: watch::Sender<WireLevelFilter>,
63 chance_handle: ChanceHandle,
64 chance_tx: watch::Sender<f64>,
65 active_streams: Arc<AtomicUsize>,
66}
67
68impl CacheLevelBroadcast {
69 pub fn new(level_handle: LevelHandle, chance_handle: ChanceHandle) -> Self {
70 let initial_level = WireLevelFilter::from_tracing(level_handle.get());
71 let initial_chance = chance_handle.get();
72 let (level_tx, _) = watch::channel(initial_level);
73 let (chance_tx, _) = watch::channel(initial_chance);
74 Self {
75 level_handle,
76 level_tx,
77 chance_handle,
78 chance_tx,
79 active_streams: Arc::new(AtomicUsize::new(0)),
80 }
81 }
82
83 fn set_level(&self, filter: WireLevelFilter) {
84 self.level_handle.set(filter.to_tracing());
85 let _ = self.level_tx.send(filter);
86 }
87
88 fn set_chance(&self, pct: f64) {
89 let pct = if pct.is_nan() {
93 0.0
94 } else {
95 pct.clamp(0.0, 100.0)
96 };
97 self.chance_handle.set(pct);
98 let _ = self.chance_tx.send(pct);
99 }
100
101 fn subscribe_level(&self) -> watch::Receiver<WireLevelFilter> {
102 self.level_tx.subscribe()
103 }
104
105 fn subscribe_chance(&self) -> watch::Receiver<f64> {
106 self.chance_tx.subscribe()
107 }
108
109 fn enter_stream(&self) -> StreamGuard {
118 self.active_streams.fetch_add(1, Ordering::SeqCst);
119 StreamGuard {
120 broadcast: self.clone(),
121 }
122 }
123}
124
125pub(crate) struct StreamGuard {
131 broadcast: CacheLevelBroadcast,
132}
133
134impl Drop for StreamGuard {
135 fn drop(&mut self) {
136 let prev = self.broadcast.active_streams.fetch_sub(1, Ordering::SeqCst);
137 if prev == 1 {
138 self.broadcast.level_handle.set(LevelFilter::OFF);
143 let _ = self.broadcast.level_tx.send(WireLevelFilter::Off);
144 self.broadcast.chance_handle.set(100.0);
145 let _ = self.broadcast.chance_tx.send(100.0);
146 }
147 }
148}
149
150pub(crate) struct ConnectionState<P: EnabledPredicate> {
156 cache: Arc<SpanCache<P>>,
157 base: TimeBase,
158 state: Arc<RwLock<StreamState>>,
159 level_bus: CacheLevelBroadcast,
160 stream_guard: Option<StreamGuard>,
167}
168
169impl<P: EnabledPredicate> ConnectionState<P> {
170 fn new(cache: Arc<SpanCache<P>>, base: TimeBase, level_bus: CacheLevelBroadcast) -> Self {
171 Self {
172 cache,
173 base,
174 state: Arc::new(RwLock::new(StreamState::new())),
175 level_bus,
176 stream_guard: None,
177 }
178 }
179}
180
181impl<P: EnabledPredicate> ConnectionService for ConnectionState<P> {
182 type Request = Request;
183 type Response = Response;
184
185 #[allow(clippy::expect_used, reason = "poisoned lock")]
186 fn new_rpc(&mut self, msg: Request, responder: RpcResponder<'_, Response>) {
187 let request_id = msg.message_id();
191 match msg.body {
192 RequestBody::StartStream => {
193 self.state
194 .write()
195 .expect("lock must not be poisoned")
196 .streaming = true;
197 if self.stream_guard.is_none() {
202 self.stream_guard = Some(self.level_bus.enter_stream());
203 }
204 let cache = Arc::clone(&self.cache);
205 let state = Arc::clone(&self.state);
206 let base = self.base;
207 let level_rx = self.level_bus.subscribe_level();
208 let chance_rx = self.level_bus.subscribe_chance();
209 tokio::spawn(responder.stream(span_stream(
210 cache, state, base, level_rx, chance_rx, request_id,
211 )));
212 }
213 RequestBody::StopStream => {
214 self.state
215 .write()
216 .expect("lock must not be poisoned")
217 .streaming = false;
218 responder.immediate(Response::ack().with_id(request_id));
219 }
220 RequestBody::SetLevel(level) => {
221 self.state
222 .write()
223 .expect("lock must not be poisoned")
224 .min_level = Some(level);
225 responder.immediate(Response::ack().with_id(request_id));
226 }
227 RequestBody::SetCacheLevel(filter) => {
228 self.level_bus.set_level(filter);
229 responder.immediate(Response::ack().with_id(request_id));
230 }
231 RequestBody::SetCacheChance(pct) => {
232 self.level_bus.set_chance(pct);
233 responder.immediate(Response::ack().with_id(request_id));
234 }
235 RequestBody::SetSamplingRate(rate) => {
236 if !(0.0..=1.0).contains(&rate) || rate.is_nan() {
237 responder.immediate(
238 Response::error(format!("sampling rate {rate} out of range [0.0, 1.0]"))
239 .with_id(request_id),
240 );
241 return;
242 }
243 self.state
244 .write()
245 .expect("lock must not be poisoned")
246 .sampling_rate = rate;
247 responder.immediate(Response::ack().with_id(request_id));
248 }
249 RequestBody::Noop => {}
250 }
251 }
252}
253
254fn span_stream<P: EnabledPredicate>(
264 cache: Arc<SpanCache<P>>,
265 state: Arc<RwLock<StreamState>>,
266 base: TimeBase,
267 mut level_rx: watch::Receiver<WireLevelFilter>,
268 mut chance_rx: watch::Receiver<f64>,
269 request_id: u64,
270) -> impl futures_core::Stream<Item = Response> {
271 async_stream::stream! {
272 let initial_level = *level_rx.borrow_and_update();
275 yield Response::cache_level(initial_level).with_id(request_id);
276 let initial_chance = *chance_rx.borrow_and_update();
277 yield Response::cache_chance(initial_chance).with_id(request_id);
278
279 let mut span_rx = cache.subscribe(STREAM_SUBSCRIBER_CAPACITY);
287
288 loop {
289 tokio::select! {
290 changed = level_rx.changed() => {
291 if changed.is_err() { break; }
292 let lvl = *level_rx.borrow_and_update();
293 yield Response::cache_level(lvl).with_id(request_id);
294 }
295 changed = chance_rx.changed() => {
296 if changed.is_err() { break; }
297 let pct = *chance_rx.borrow_and_update();
298 yield Response::cache_chance(pct).with_id(request_id);
299 }
300 batch = span_rx.next_batch() => {
301 let Some(batch) = batch else { break };
302 let (streaming, min_level, sampling_rate) = {
303 #[allow(clippy::expect_used, reason = "poisoned lock")]
304 let s = state.read().expect("lock must not be poisoned");
305 (s.streaming, s.min_level, s.sampling_rate)
306 };
307 if !streaming {
308 drop(batch);
312 continue;
313 }
314 for record in batch {
315 if let Some(min) = min_level
316 && !level_at_least(record.metadata.level(), min)
317 {
318 continue;
319 }
320 if !sampling_passes(&record, sampling_rate) {
321 continue;
322 }
323 yield Response::span(span_to_wire(&record, base)).with_id(request_id);
324 }
325 }
326 }
327 }
328 }
329}
330
331fn level_at_least(record_level: &tracing::Level, floor: WireLevel) -> bool {
336 record_level <= &floor.to_tracing()
337}
338
339fn sampling_passes(record: &SpanRecord, rate: f64) -> bool {
344 if rate >= 1.0 {
345 return true;
346 }
347 if rate <= 0.0 {
348 return false;
349 }
350 let bucket_id = record.parent_id.unwrap_or(record.id);
352 let mut x = bucket_id.wrapping_mul(0x9E37_79B9_7F4A_7C15);
354 x ^= x >> 33;
355 x = x.wrapping_mul(0xC2B2_AE3D_27D4_EB4F);
356 x ^= x >> 29;
357 let frac = (x as f64) / (u64::MAX as f64);
358 frac < rate
359}
360
361struct Service<P: EnabledPredicate> {
364 cache: Arc<SpanCache<P>>,
365 base: TimeBase,
366 level_bus: CacheLevelBroadcast,
367}
368
369impl<P: EnabledPredicate> SocketService for Service<P> {
370 type Codec = ServerCodec;
371 type ConnectionService = ConnectionState<P>;
372 type SocketListener = TcpSocketListener;
373
374 fn codec(&self) -> Self::Codec {
375 (
376 MessagePackSerializer::default(),
377 MessagePackDecoder::default(),
378 )
379 }
380
381 fn new_stream_service(
382 &self,
383 _stream: &<Self::SocketListener as protosocket::SocketListener>::Stream,
384 ) -> Self::ConnectionService {
385 ConnectionState::new(Arc::clone(&self.cache), self.base, self.level_bus.clone())
386 }
387}
388
389#[derive(Debug)]
391pub enum ServeError {
392 Io(std::io::Error),
393 Rpc(protosocket_rpc::Error),
394}
395impl std::fmt::Display for ServeError {
396 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
397 match self {
398 ServeError::Io(e) => write!(f, "io: {e}"),
399 ServeError::Rpc(e) => write!(f, "rpc: {e}"),
400 }
401 }
402}
403impl std::error::Error for ServeError {}
404impl From<std::io::Error> for ServeError {
405 fn from(e: std::io::Error) -> Self {
406 ServeError::Io(e)
407 }
408}
409impl From<protosocket_rpc::Error> for ServeError {
410 fn from(e: protosocket_rpc::Error) -> Self {
411 ServeError::Rpc(e)
412 }
413}
414
415pub async fn serve<P: EnabledPredicate>(
424 cache: Arc<SpanCache<P>>,
425 level_handle: LevelHandle,
426 chance_handle: ChanceHandle,
427 addr: SocketAddr,
428) -> Result<(), ServeError> {
429 let listener = TcpSocketListener::listen(addr, 1024, None)?;
431
432 let service = Service {
433 cache,
434 base: TimeBase::now(),
435 level_bus: CacheLevelBroadcast::new(level_handle, chance_handle),
436 };
437 let server: SocketRpcServer<Service<P>, _> = SocketRpcServer::new(
438 listener,
439 service,
440 16 * 1024 * 1024,
441 64 * 1024,
442 4096,
443 )?;
444 server.await?;
445 Ok(())
446}
447
#[cfg(test)]
mod tests {
    //! End-to-end tests: each async test spawns a real server on a localhost
    //! port, connects an RPC client and drives it. The sync tests at the
    //! bottom exercise `sampling_passes` on synthetic records.

    use super::*;
    use std::net::TcpListener as StdTcpListener;
    use std::time::Duration;

    use futures::StreamExt;
    use protosocket_messagepack::{MessagePackDecoder, MessagePackSerializer};
    use protosocket_rpc::client::{self, Configuration, RpcClient, TcpStreamConnector};
    use tracing_cache::{ChancePredicate, SpanCache};

    use crate::protocol::{ResponseBody, WireLevel};

    /// Client-side mirror of `ServerCodec`: encodes requests, decodes responses.
    type ClientCodec = (MessagePackSerializer<Request>, MessagePackDecoder<Response>);

    /// Picks a free localhost port by binding port 0 and releasing it.
    /// NOTE(review): another process could take the port before `serve` binds
    /// it (check-then-use race). This is usually acceptable in tests.
    fn pick_addr() -> SocketAddr {
        let listener = StdTcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        drop(listener);
        format!("127.0.0.1:{port}").parse().unwrap()
    }

    /// Builds a span cache with a TRACE level predicate wrapped in a 100%
    /// chance predicate. Spawns its driver and returns the cache plus both
    /// control handles.
    fn prepare_cache() -> (
        Arc<SpanCache<ChancePredicate<tracing_cache::LevelPredicate>>>,
        LevelHandle,
        ChanceHandle,
    ) {
        let level =
            tracing_cache::LevelPredicate::with_filter(tracing::metadata::LevelFilter::TRACE);
        let level_handle = level.handle();
        let predicate = ChancePredicate::new(level, 100.0);
        let chance_handle = predicate.handle();
        let (cache, driver) = SpanCache::with_predicate(1024, predicate);
        let cache = Arc::new(cache);
        tokio::spawn(driver.run());
        (cache, level_handle, chance_handle)
    }

    /// Runs `f` with `cache` as the default tracing subscriber, then calls
    /// `flush_pending` so the produced spans are handed on without waiting.
    fn emit_under<P: EnabledPredicate>(cache: &Arc<SpanCache<P>>, f: impl FnOnce()) {
        tracing::subscriber::with_default(Arc::clone(cache), f);
        cache.flush_pending();
    }

    /// Reads the stream until both the initial `CacheLevel` and `CacheChance`
    /// items have been seen. Panics if they have not.
    ///
    /// Any per-item timeout (200 ms), error or end of stream stops waiting
    /// immediately. The 2 s deadline only bounds a stream that keeps yielding
    /// other items.
    async fn wait_for_initial(
        stream: &mut (impl futures::Stream<Item = Result<Response, protosocket_rpc::Error>> + Unpin),
    ) {
        let mut got_level = false;
        let mut got_chance = false;
        let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
        while !(got_level && got_chance) && tokio::time::Instant::now() < deadline {
            match tokio::time::timeout(Duration::from_millis(200), stream.next()).await {
                Ok(Some(Ok(resp))) => match resp.body {
                    ResponseBody::CacheLevel(_) => got_level = true,
                    ResponseBody::CacheChance(_) => got_chance = true,
                    _ => {}
                },
                _ => break,
            }
        }
        assert!(
            got_level && got_chance,
            "stream did not yield initial CacheLevel/CacheChance",
        );
    }

    /// Spawns `serve` on a fresh port and returns once a TCP connect succeeds.
    /// It probes up to 50 times with 10 ms sleeps between attempts.
    ///
    /// The probe uses a blocking `std::net::TcpStream::connect`. Each probe
    /// connection is dropped immediately.
    async fn spawn_server<P: EnabledPredicate>(
        cache: Arc<SpanCache<P>>,
        level_handle: LevelHandle,
        chance_handle: ChanceHandle,
    ) -> (SocketAddr, tokio::task::JoinHandle<()>) {
        let addr = pick_addr();
        let server_cache = Arc::clone(&cache);
        let serve_level = level_handle.clone();
        let serve_chance = chance_handle.clone();
        let handle = tokio::spawn(async move {
            let _ = serve(server_cache, serve_level, serve_chance, addr).await;
        });
        for _ in 0..50 {
            if std::net::TcpStream::connect(addr).is_ok() {
                return (addr, handle);
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
        panic!("server never came up on {addr}");
    }

    /// Connects an RPC client to `addr` and spawns its connection future.
    async fn connect_client(addr: SocketAddr) -> RpcClient<Request, Response> {
        let cfg = Configuration::new(TcpStreamConnector);
        let (rpc_client, conn) = client::connect::<ClientCodec, _>(addr, &cfg).await.unwrap();
        tokio::spawn(conn);
        rpc_client
    }

    /// Collects up to `n` span items from the stream, ignoring non-span items.
    /// Stops early on error, end of stream, or when `total_timeout` elapses.
    async fn collect_spans(
        stream: &mut (impl futures::Stream<Item = Result<Response, protosocket_rpc::Error>> + Unpin),
        n: usize,
        total_timeout: Duration,
    ) -> Vec<crate::WireSpan> {
        let mut out = Vec::with_capacity(n);
        let deadline = tokio::time::Instant::now() + total_timeout;
        while out.len() < n {
            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
            match tokio::time::timeout(remaining, stream.next()).await {
                Ok(Some(Ok(resp))) => {
                    if let ResponseBody::Span(s) = resp.body {
                        out.push(s);
                    }
                }
                Ok(Some(Err(_))) | Ok(None) => break,
                Err(_) => break,
            }
        }
        out
    }

    /// Spans opened and closed after `StartStream` arrive with a close time.
    #[tokio::test]
    async fn start_stream_delivers_closed_spans() {
        let (cache, level_handle, chance_handle) = prepare_cache();
        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;
        let client = connect_client(addr).await;
        let mut stream = client
            .send_streaming(Request::new(RequestBody::StartStream))
            .unwrap();
        wait_for_initial(&mut stream).await;

        emit_under(&cache, || {
            for _ in 0..3 {
                let span = tracing::span!(parent: None, tracing::Level::INFO, "test_a");
                let _g = span.enter();
            }
        });

        let received = collect_spans(&mut stream, 3, Duration::from_secs(2)).await;
        assert_eq!(received.len(), 3);
        assert!(received.iter().all(|s| s.name == "test_a"));
        assert!(received.iter().all(|s| s.closed_at_ns.is_some()));

        server.abort();
    }

    /// After `StopStream` is acked, fewer than 5 of 5 newly emitted spans may
    /// arrive. The bound is deliberately loose, to tolerate items already in
    /// flight.
    #[tokio::test]
    async fn stop_stream_halts_delivery() {
        let (cache, level_handle, chance_handle) = prepare_cache();
        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;
        let client = connect_client(addr).await;
        let mut stream = client
            .send_streaming(Request::new(RequestBody::StartStream))
            .unwrap();
        wait_for_initial(&mut stream).await;

        emit_under(&cache, || {
            let _g = tracing::span!(parent: None, tracing::Level::INFO, "test_b").entered();
        });
        let initial = collect_spans(&mut stream, 1, Duration::from_secs(2)).await;
        assert_eq!(initial.len(), 1);

        let ack = client
            .send_unary(Request::new(RequestBody::StopStream))
            .unwrap()
            .await
            .unwrap();
        assert!(matches!(ack.body, ResponseBody::Ack));
        // Give the server a moment to observe `streaming = false`.
        tokio::time::sleep(Duration::from_millis(50)).await;

        emit_under(&cache, || {
            for _ in 0..5 {
                let _g = tracing::span!(parent: None, tracing::Level::INFO, "test_b").entered();
            }
        });
        let drained_after_stop = collect_spans(&mut stream, 5, Duration::from_millis(300)).await;
        assert!(
            drained_after_stop.len() < 5,
            "stream did not stop: got {} more spans after StopStream",
            drained_after_stop.len(),
        );

        server.abort();
    }

    /// With `SetLevel(Info)`, an INFO span is delivered and a DEBUG span is not.
    #[tokio::test]
    async fn set_level_filters_below_threshold() {
        let (cache, level_handle, chance_handle) = prepare_cache();
        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;
        let client = connect_client(addr).await;

        let ack = client
            .send_unary(Request::new(RequestBody::SetLevel(WireLevel::Info)))
            .unwrap()
            .await
            .unwrap();
        assert!(matches!(ack.body, ResponseBody::Ack));

        let mut stream = client
            .send_streaming(Request::new(RequestBody::StartStream))
            .unwrap();
        wait_for_initial(&mut stream).await;

        emit_under(&cache, || {
            drop(tracing::span!(parent: None, tracing::Level::INFO, "info_span"));
            drop(tracing::span!(parent: None, tracing::Level::DEBUG, "debug_span"));
        });

        let received = collect_spans(&mut stream, 2, Duration::from_millis(500)).await;
        let names: Vec<_> = received.iter().map(|s| s.name.as_str()).collect();
        assert_eq!(names, vec!["info_span"], "got: {names:?}");

        server.abort();
    }

    /// A sampling rate of 0.0 suppresses every span on the stream.
    #[tokio::test]
    async fn set_sampling_rate_zero_drops_all() {
        let (cache, level_handle, chance_handle) = prepare_cache();
        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;
        let client = connect_client(addr).await;

        client
            .send_unary(Request::new(RequestBody::SetSamplingRate(0.0)))
            .unwrap()
            .await
            .unwrap();
        let mut stream = client
            .send_streaming(Request::new(RequestBody::StartStream))
            .unwrap();
        wait_for_initial(&mut stream).await;

        emit_under(&cache, || {
            for _ in 0..5 {
                let _g = tracing::span!(parent: None, tracing::Level::INFO, "sampled").entered();
            }
        });

        let received = collect_spans(&mut stream, 5, Duration::from_millis(400)).await;
        assert!(
            received.is_empty(),
            "rate=0 should drop everything; got {received:?}",
        );

        server.abort();
    }

    /// `SetCacheLevel` sent on the same connection must not end the open stream.
    /// The stream should go on to yield the updated `CacheLevel`.
    #[tokio::test]
    async fn set_cache_level_keeps_stream_open() {
        let (cache, level_handle, chance_handle) = prepare_cache();
        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;
        let client = connect_client(addr).await;
        let mut start = Request::new(RequestBody::StartStream);
        start.id = 100;
        let mut stream = client.send_streaming(start).unwrap();

        let first = tokio::time::timeout(Duration::from_secs(1), stream.next())
            .await
            .unwrap()
            .unwrap()
            .unwrap();
        assert!(
            matches!(first.body, ResponseBody::CacheLevel(_)),
            "first message should be CacheLevel, got {:?}",
            first.body
        );

        let mut set = Request::new(RequestBody::SetCacheLevel(WireLevelFilter::Off));
        set.id = 101;
        let ack = client.send_unary(set).unwrap().await.unwrap();
        assert!(matches!(ack.body, ResponseBody::Ack));

        // Skip chance/span items until the next CacheLevel (or the deadline).
        // NOTE(review): if the stream ends, `stream.next()` keeps returning
        // `None` and this loop spins until the deadline.
        let mut next_level: Option<WireLevelFilter> = None;
        let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
        while tokio::time::Instant::now() < deadline && next_level.is_none() {
            let item = tokio::time::timeout(Duration::from_millis(200), stream.next()).await;
            let Ok(Some(Ok(resp))) = item else { continue };
            match resp.body {
                ResponseBody::CacheLevel(l) => next_level = Some(l),
                ResponseBody::CacheChance(_) => continue,
                ResponseBody::Span(_) => continue,
                other => panic!("unexpected stream item: {other:?}"),
            }
        }
        assert_eq!(
            next_level,
            Some(WireLevelFilter::Off),
            "stream did not yield the updated CacheLevel (probably ended)",
        );

        server.abort();
    }

    /// When the only streaming client goes away, the cache level is reset to
    /// OFF. The level starts at INFO so the reset is observable.
    #[tokio::test]
    async fn level_resets_to_off_when_last_console_disconnects() {
        let (cache, level_handle, chance_handle) = prepare_cache();
        level_handle.set(LevelFilter::INFO);

        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;

        {
            // Client and stream are dropped at the end of this scope.
            let client = connect_client(addr).await;
            let mut start = Request::new(RequestBody::StartStream);
            start.id = 200;
            let _stream = client.send_streaming(start).unwrap();
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
        // Give the server time to notice the disconnect and drop its guard.
        tokio::time::sleep(Duration::from_millis(500)).await;

        assert_eq!(
            level_handle.get(),
            LevelFilter::OFF,
            "level should have reset to OFF after last console disconnected",
        );

        server.abort();
    }

    use std::time::Instant;
    use tracing::callsite::{Callsite, DefaultCallsite, Identifier};
    use tracing::field::FieldSet;
    use tracing::metadata::Kind;
    use tracing_cache::{FieldList, SpanRecord};

    /// Static callsite supplying `&'static Metadata` for synthetic span records.
    static SAMPLING_CALLSITE: DefaultCallsite = {
        static META: tracing::Metadata<'static> = tracing::Metadata::new(
            "sampling_test",
            "sampling::test",
            tracing::Level::INFO,
            None,
            None,
            None,
            FieldSet::new(&[], Identifier(&SAMPLING_CALLSITE)),
            Kind::SPAN,
        );
        DefaultCallsite::new(&META)
    };

    /// Builds a closed `SpanRecord` with the given ids and no fields or events.
    fn synth_span(id: u64, parent_id: Option<u64>) -> SpanRecord {
        SpanRecord {
            id,
            parent_id,
            metadata: SAMPLING_CALLSITE.metadata(),
            fields: FieldList::default(),
            events: Vec::new(),
            opened_at: Instant::now(),
            closed_at: Some(Instant::now()),
        }
    }

    #[test]
    fn sampling_passes_rate_one_short_circuits_true() {
        for id in [0u64, 1, 17, u64::MAX, 0x9E37_79B9_7F4A_7C15] {
            assert!(sampling_passes(&synth_span(id, None), 1.0));
        }
    }

    #[test]
    fn sampling_passes_rate_zero_short_circuits_false() {
        for id in [0u64, 1, 17, u64::MAX] {
            assert!(!sampling_passes(&synth_span(id, None), 0.0));
        }
    }

    #[test]
    fn sampling_passes_is_deterministic_per_root_id() {
        for id in 1u64..=20 {
            let r = synth_span(id, None);
            let first = sampling_passes(&r, 0.5);
            for _ in 0..3 {
                assert_eq!(sampling_passes(&r, 0.5), first, "id={id}");
            }
        }
    }

    /// Direct children of root 7 share root 7's sampling decision.
    #[test]
    fn sampling_passes_children_inherit_parents_root_id_bucket() {
        let root = synth_span(7, None);
        let want = sampling_passes(&root, 0.5);
        for child_id in [100u64, 200, 300, u64::MAX] {
            let child = synth_span(child_id, Some(7));
            assert_eq!(sampling_passes(&child, 0.5), want);
        }
    }

    /// Over 5000 sequential root ids, the pass fraction stays within ±0.03 of
    /// the requested rate.
    #[test]
    fn sampling_passes_partitions_population_near_target_rate() {
        let rate = 0.3;
        let n = 5_000u64;
        let mut passed = 0usize;
        for id in 1..=n {
            if sampling_passes(&synth_span(id, None), rate) {
                passed += 1;
            }
        }
        let frac = passed as f64 / n as f64;
        assert!(
            (frac - rate).abs() < 0.03,
            "frac={frac} rate={rate} — hash distribution drifted",
        );
    }

    /// Rates above 1.0, below 0.0, and NaN are each rejected with an Error reply.
    #[tokio::test]
    async fn set_sampling_rate_rejects_out_of_range() {
        let (cache, level_handle, chance_handle) = prepare_cache();
        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;
        let client = connect_client(addr).await;

        for bad in [1.5_f64, -0.1, f64::NAN] {
            let resp = client
                .send_unary(Request::new(RequestBody::SetSamplingRate(bad)))
                .unwrap()
                .await
                .unwrap();
            match resp.body {
                ResponseBody::Error(msg) => {
                    assert!(
                        msg.contains("sampling rate"),
                        "unexpected error message for {bad}: {msg}",
                    );
                }
                other => panic!("expected Error for rate={bad}, got {other:?}"),
            }
        }
        server.abort();
    }
}