1use std::collections::HashMap;
4use std::net::SocketAddr;
5use std::sync::atomic::{AtomicUsize, Ordering};
6use std::sync::{Arc, RwLock};
7use std::time::Duration;
8
9use protosocket::TcpSocketListener;
10use protosocket_messagepack::{MessagePackDecoder, MessagePackSerializer};
11use protosocket_rpc::Message;
12use protosocket_rpc::server::{ConnectionService, RpcResponder, SocketRpcServer, SocketService};
13use tokio::sync::watch;
14use tracing::metadata::LevelFilter;
15use tracing_cache::{ChanceHandle, EnabledPredicate, LevelHandle, SpanCache, SpanRecord};
16
17use crate::protocol::{Request, RequestBody, Response, WireLevel, WireLevelFilter};
18use crate::wire::{TimeBase, span_to_wire};
19
/// Codec pair used by the server side of a connection: `Response` values are
/// serialized with MessagePack and incoming `Request` values are decoded from it.
type ServerCodec = (MessagePackSerializer<Response>, MessagePackDecoder<Request>);

/// Poll interval every new span stream starts with, before any adaptation.
const STREAM_POLL_INTERVAL_INITIAL: Duration = Duration::from_millis(50);
/// Maximum number of records requested from the cache in one `page` call.
const STREAM_BATCH: usize = 4096;
/// Batch size the adaptive poll interval steers toward (see `adjust_interval`).
const STREAM_TARGET_BATCH: usize = 32;
/// Lower bound of the adaptive poll interval; a zero interval makes the stream
/// loop yield to the scheduler instead of sleeping.
const STREAM_MIN_INTERVAL: Duration = Duration::ZERO;
/// Upper bound of the adaptive poll interval.
const STREAM_MAX_INTERVAL: Duration = Duration::from_millis(50);
50const STREAM_ADJUST_RATIO: f64 = 0.2; #[derive(Debug, Default)]
55struct StreamState {
56 streaming: bool,
57 min_level: Option<WireLevel>,
58 sampling_rate: f64,
59 root_filter: Option<String>,
61}
62
63impl StreamState {
64 fn new() -> Self {
65 Self {
66 streaming: false,
67 min_level: None,
68 sampling_rate: 1.0,
69 root_filter: None,
70 }
71 }
72}
73
/// Shared control surface for the cache's capture level and capture chance.
///
/// Writes go to the underlying handles and are mirrored on `watch` channels so
/// that running span streams can report the new values to their clients. Cloning
/// yields another handle onto the same channels and the same stream counter.
#[derive(Clone)]
pub(crate) struct CacheLevelBroadcast {
    /// Handle used to read and set the cache's level filter.
    level_handle: LevelHandle,
    /// Broadcasts the most recently set level filter to subscribed streams.
    level_tx: watch::Sender<WireLevelFilter>,
    /// Handle used to read and set the cache's capture chance.
    chance_handle: ChanceHandle,
    /// Broadcasts the most recently set capture chance to subscribed streams.
    chance_tx: watch::Sender<f64>,
    /// Callback that empties the span cache.
    clear_cache: Arc<dyn Fn() + Send + Sync>,
    /// Number of live `StreamGuard`s, i.e. connections that have started a stream.
    active_streams: Arc<AtomicUsize>,
}
99
100impl CacheLevelBroadcast {
101 pub fn new(
102 level_handle: LevelHandle,
103 chance_handle: ChanceHandle,
104 clear_cache: Arc<dyn Fn() + Send + Sync>,
105 ) -> Self {
106 let initial_level = WireLevelFilter::from_tracing(level_handle.get());
107 let initial_chance = chance_handle.get();
108 let (level_tx, _) = watch::channel(initial_level);
109 let (chance_tx, _) = watch::channel(initial_chance);
110 Self {
111 level_handle,
112 level_tx,
113 chance_handle,
114 chance_tx,
115 clear_cache,
116 active_streams: Arc::new(AtomicUsize::new(0)),
117 }
118 }
119
120 fn set_level(&self, filter: WireLevelFilter) {
121 if filter == WireLevelFilter::Off {
122 (self.clear_cache)();
123 }
124 self.level_handle.set(filter.to_tracing());
125 let _ = self.level_tx.send(filter);
126 }
127
128 fn set_chance(&self, pct: f64) {
129 let pct = if pct.is_nan() {
133 0.0
134 } else {
135 pct.clamp(0.0, 100.0)
136 };
137 self.chance_handle.set(pct);
138 let _ = self.chance_tx.send(pct);
139 }
140
141 fn subscribe_level(&self) -> watch::Receiver<WireLevelFilter> {
142 self.level_tx.subscribe()
143 }
144
145 fn subscribe_chance(&self) -> watch::Receiver<f64> {
146 self.chance_tx.subscribe()
147 }
148
149 fn enter_stream(&self) -> StreamGuard {
158 self.active_streams.fetch_add(1, Ordering::SeqCst);
159 StreamGuard {
160 broadcast: self.clone(),
161 }
162 }
163}
164
/// Marks one connection as having an active stream for as long as it lives.
///
/// Created by `CacheLevelBroadcast::enter_stream`; dropping it decrements the
/// shared stream counter.
pub(crate) struct StreamGuard {
    /// Handle onto the shared level/chance controls and the stream counter.
    broadcast: CacheLevelBroadcast,
}
173
174impl Drop for StreamGuard {
175 fn drop(&mut self) {
176 let prev = self.broadcast.active_streams.fetch_sub(1, Ordering::SeqCst);
177 if prev == 1 {
178 (self.broadcast.clear_cache)();
183 self.broadcast.level_handle.set(LevelFilter::OFF);
184 let _ = self.broadcast.level_tx.send(WireLevelFilter::Off);
185 self.broadcast.chance_handle.set(100.0);
186 let _ = self.broadcast.chance_tx.send(100.0);
187 }
188 }
189}
190
/// Per-connection RPC service state.
pub(crate) struct ConnectionState<P: EnabledPredicate> {
    /// Span cache that streams page records from.
    cache: Arc<SpanCache<P>>,
    /// Time base passed to `span_to_wire` when converting records.
    base: TimeBase,
    /// Streaming settings shared with the spawned stream tasks.
    state: Arc<RwLock<StreamState>>,
    /// Shared cache level/chance controls.
    level_bus: CacheLevelBroadcast,
    /// Present once this connection has started a stream; released when the
    /// connection state is dropped.
    stream_guard: Option<StreamGuard>,
    /// Memo of root-filter decisions keyed by span id, shared with stream tasks.
    root_decisions: Arc<RwLock<HashMap<u64, bool>>>,
}
212
213impl<P: EnabledPredicate> ConnectionState<P> {
214 fn new(cache: Arc<SpanCache<P>>, base: TimeBase, level_bus: CacheLevelBroadcast) -> Self {
215 Self {
216 cache,
217 base,
218 state: Arc::new(RwLock::new(StreamState::new())),
219 level_bus,
220 stream_guard: None,
221 root_decisions: Arc::new(RwLock::new(HashMap::new())),
222 }
223 }
224}
225
226impl<P: EnabledPredicate> ConnectionService for ConnectionState<P> {
227 type Request = Request;
228 type Response = Response;
229
230 #[allow(clippy::expect_used, reason = "poisoned lock")]
231 fn new_rpc(&mut self, msg: Request, responder: RpcResponder<'_, Response>) {
232 let request_id = msg.message_id();
236 match msg.body {
237 RequestBody::StartStream => {
238 self.state
239 .write()
240 .expect("lock must not be poisoned")
241 .streaming = true;
242 if self.stream_guard.is_none() {
247 self.stream_guard = Some(self.level_bus.enter_stream());
248 }
249 let cache = Arc::clone(&self.cache);
250 let state = Arc::clone(&self.state);
251 let roots = Arc::clone(&self.root_decisions);
252 let base = self.base;
253 let level_rx = self.level_bus.subscribe_level();
254 let chance_rx = self.level_bus.subscribe_chance();
255 tokio::spawn(responder.stream(span_stream(
256 cache, state, roots, base, level_rx, chance_rx, request_id,
257 )));
258 }
259 RequestBody::StopStream => {
260 self.state
261 .write()
262 .expect("lock must not be poisoned")
263 .streaming = false;
264 responder.immediate(Response::ack().with_id(request_id));
265 }
266 RequestBody::SetLevel(level) => {
267 self.state
268 .write()
269 .expect("lock must not be poisoned")
270 .min_level = Some(level);
271 responder.immediate(Response::ack().with_id(request_id));
272 }
273 RequestBody::SetCacheLevel(filter) => {
274 self.level_bus.set_level(filter);
275 responder.immediate(Response::ack().with_id(request_id));
276 }
277 RequestBody::SetCacheChance(pct) => {
278 self.level_bus.set_chance(pct);
279 responder.immediate(Response::ack().with_id(request_id));
280 }
281 RequestBody::SetSamplingRate(rate) => {
282 if !(0.0..=1.0).contains(&rate) || rate.is_nan() {
283 responder.immediate(
284 Response::error(format!("sampling rate {rate} out of range [0.0, 1.0]"))
285 .with_id(request_id),
286 );
287 return;
288 }
289 self.state
290 .write()
291 .expect("lock must not be poisoned")
292 .sampling_rate = rate;
293 self.root_decisions
294 .write()
295 .expect("lock must not be poisoned")
296 .clear();
297 responder.immediate(Response::ack().with_id(request_id));
298 }
299 RequestBody::SetFilter(f) => {
300 self.state
301 .write()
302 .expect("lock must not be poisoned")
303 .root_filter = f;
304 self.root_decisions
305 .write()
306 .expect("lock must not be poisoned")
307 .clear();
308 responder.immediate(Response::ack().with_id(request_id));
309 }
310 RequestBody::Noop => {}
311 }
312 }
313}
314
315fn span_stream<P: EnabledPredicate>(
325 cache: Arc<SpanCache<P>>,
326 state: Arc<RwLock<StreamState>>,
327 roots: Arc<RwLock<HashMap<u64, bool>>>,
328 base: TimeBase,
329 mut level_rx: watch::Receiver<WireLevelFilter>,
330 mut chance_rx: watch::Receiver<f64>,
331 request_id: u64,
332) -> impl futures_core::Stream<Item = Response> {
333 async_stream::stream! {
334 let initial_level = *level_rx.borrow_and_update();
337 yield Response::cache_level(initial_level).with_id(request_id);
338 let initial_chance = *chance_rx.borrow_and_update();
339 yield Response::cache_chance(initial_chance).with_id(request_id);
340
341 let mut cursor: u64 = 0;
342 let mut interval = STREAM_POLL_INTERVAL_INITIAL;
343 loop {
344 if interval.is_zero() {
345 if level_rx.has_changed().unwrap_or(false) {
350 let lvl = *level_rx.borrow_and_update();
351 yield Response::cache_level(lvl).with_id(request_id);
352 continue;
353 }
354 if chance_rx.has_changed().unwrap_or(false) {
355 let pct = *chance_rx.borrow_and_update();
356 yield Response::cache_chance(pct).with_id(request_id);
357 continue;
358 }
359 tokio::task::yield_now().await;
360 } else {
361 tokio::select! {
362 changed = level_rx.changed() => {
363 if changed.is_err() { break; }
364 let lvl = *level_rx.borrow_and_update();
365 yield Response::cache_level(lvl).with_id(request_id);
366 continue;
367 }
368 changed = chance_rx.changed() => {
369 if changed.is_err() { break; }
370 let pct = *chance_rx.borrow_and_update();
371 yield Response::cache_chance(pct).with_id(request_id);
372 continue;
373 }
374 _ = tokio::time::sleep(interval) => {}
375 }
376 }
377 let (streaming, min_level, sampling_rate, root_filter) = {
380 #[allow(clippy::expect_used, reason = "poisoned lock")]
381 let s = state.read().expect("lock must not be poisoned");
382 (s.streaming, s.min_level, s.sampling_rate, s.root_filter.clone())
383 };
384 if !streaming {
385 continue;
388 }
389 let batch = cache.page(cursor, STREAM_BATCH);
390 let count = batch.len();
391 for record in batch {
392 cursor = record.id;
393 if let Some(min) = min_level
394 && !level_at_least(record.metadata.level(), min)
395 {
396 continue;
397 }
398 if !sampling_passes(&record, sampling_rate) {
399 continue;
400 }
401 if !filter_passes(&record, &root_filter, &roots) {
402 continue;
403 }
404 yield Response::span(span_to_wire(&record, base)).with_id(request_id);
405 }
406 interval = adjust_interval(interval, count);
407 }
408 }
409}
410
411fn adjust_interval(current: Duration, count: usize) -> Duration {
417 if current.is_zero() {
424 return if count >= STREAM_TARGET_BATCH {
425 Duration::ZERO
426 } else {
427 Duration::from_micros(1)
428 };
429 }
430 let ratio = STREAM_TARGET_BATCH as f64 / count.max(1) as f64;
431 let raw = current.mul_f64(ratio);
432 let max_up = current.mul_f64(1.0 + STREAM_ADJUST_RATIO);
433 let min_down = current.mul_f64(1.0 - STREAM_ADJUST_RATIO);
434 let clamped = raw.clamp(min_down, max_up);
435 let bounded = clamped.clamp(STREAM_MIN_INTERVAL, STREAM_MAX_INTERVAL);
436 if bounded < Duration::from_micros(1) {
442 Duration::ZERO
443 } else {
444 bounded
445 }
446}
447
448fn level_at_least(record_level: &tracing::Level, floor: WireLevel) -> bool {
453 record_level <= &floor.to_tracing()
454}
455
456fn sampling_passes(record: &SpanRecord, rate: f64) -> bool {
461 if rate >= 1.0 {
462 return true;
463 }
464 if rate <= 0.0 {
465 return false;
466 }
467 let bucket_id = record.parent_id.unwrap_or(record.id);
469 let mut x = bucket_id.wrapping_mul(0x9E37_79B9_7F4A_7C15);
471 x ^= x >> 33;
472 x = x.wrapping_mul(0xC2B2_AE3D_27D4_EB4F);
473 x ^= x >> 29;
474 let frac = (x as f64) / (u64::MAX as f64);
475 frac < rate
476}
477
478fn filter_passes(
482 record: &SpanRecord,
483 filter: &Option<String>,
484 roots: &Arc<RwLock<HashMap<u64, bool>>>,
485) -> bool {
486 let needle = match filter {
487 None => return true,
488 Some(s) if s.is_empty() => return true,
489 Some(s) => s.as_str(),
490 };
491 #[allow(clippy::expect_used, reason = "poisoned lock")]
497 let Some(parent_id) = record.parent_id else {
498 let decision = root_matches(record, needle);
499 roots
500 .write()
501 .expect("lock must not be poisoned")
502 .insert(record.id, decision);
503 return decision;
504 };
505 let parent_decision = {
507 #[allow(clippy::expect_used, reason = "poisoned lock")]
508 let memo = roots.read().expect("lock must not be poisoned");
509 memo.get(&parent_id).copied()
510 };
511 let decision = parent_decision.unwrap_or(false);
512 #[allow(clippy::expect_used, reason = "poisoned lock")]
514 roots
515 .write()
516 .expect("lock must not be poisoned")
517 .insert(record.id, decision);
518 decision
519}
520
521fn root_matches(record: &SpanRecord, needle: &str) -> bool {
522 if record.metadata.name().contains(needle) {
523 return true;
524 }
525 record.fields.iter().any(|(_, v)| v.contains(needle))
526}
527
/// Listener-level service: holds what every new connection's `ConnectionState`
/// is built from.
struct Service<P: EnabledPredicate> {
    /// Span cache shared by all connections.
    cache: Arc<SpanCache<P>>,
    /// Time base handed to every connection.
    base: TimeBase,
    /// Shared cache level/chance controls, cloned into every connection.
    level_bus: CacheLevelBroadcast,
}
535
536impl<P: EnabledPredicate> SocketService for Service<P> {
537 type Codec = ServerCodec;
538 type ConnectionService = ConnectionState<P>;
539 type SocketListener = TcpSocketListener;
540
541 fn codec(&self) -> Self::Codec {
542 (
543 MessagePackSerializer::default(),
544 MessagePackDecoder::default(),
545 )
546 }
547
548 fn new_stream_service(
549 &self,
550 _stream: &<Self::SocketListener as protosocket::SocketListener>::Stream,
551 ) -> Self::ConnectionService {
552 ConnectionState::new(Arc::clone(&self.cache), self.base, self.level_bus.clone())
553 }
554}
555
/// Errors that can end [`serve`].
#[derive(Debug)]
pub enum ServeError {
    /// An I/O error.
    Io(std::io::Error),
    /// An error reported by the RPC layer.
    Rpc(protosocket_rpc::Error),
}
562impl std::fmt::Display for ServeError {
563 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
564 match self {
565 ServeError::Io(e) => write!(f, "io: {e}"),
566 ServeError::Rpc(e) => write!(f, "rpc: {e}"),
567 }
568 }
569}
570impl std::error::Error for ServeError {}
571impl From<std::io::Error> for ServeError {
572 fn from(e: std::io::Error) -> Self {
573 ServeError::Io(e)
574 }
575}
576impl From<protosocket_rpc::Error> for ServeError {
577 fn from(e: protosocket_rpc::Error) -> Self {
578 ServeError::Rpc(e)
579 }
580}
581
582pub async fn serve<P: EnabledPredicate>(
591 cache: Arc<SpanCache<P>>,
592 level_handle: LevelHandle,
593 chance_handle: ChanceHandle,
594 addr: SocketAddr,
595) -> Result<(), ServeError> {
596 let listener = TcpSocketListener::listen(addr, 1024, None)?;
598
599 let clear_cache: Arc<dyn Fn() + Send + Sync> = {
600 let cache = Arc::clone(&cache);
601 Arc::new(move || cache.clear())
602 };
603 let service = Service {
604 cache,
605 base: TimeBase::now(),
606 level_bus: CacheLevelBroadcast::new(level_handle, chance_handle, clear_cache),
607 };
608 let server: SocketRpcServer<Service<P>, _> = SocketRpcServer::new(
609 listener,
610 service,
611 16 * 1024 * 1024,
612 64 * 1024,
613 4096,
614 )?;
615 server.await?;
616 Ok(())
617}
618
/// Unit tests for the pure helpers (`adjust_interval`, `sampling_passes`,
/// `filter_passes`) and end-to-end tests that run a real server on a loopback
/// port and talk to it with an RPC client.
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener as StdTcpListener;
    use std::time::Duration;

    use futures::StreamExt;
    use protosocket_messagepack::{MessagePackDecoder, MessagePackSerializer};
    use protosocket_rpc::client::{self, Configuration, RpcClient, TcpStreamConnector};
    use tracing_cache::{ChancePredicate, SpanCache};

    use crate::protocol::{ResponseBody, WireLevel};

    /// Client-side mirror of `ServerCodec`: serializes requests, decodes responses.
    type ClientCodec = (MessagePackSerializer<Request>, MessagePackDecoder<Response>);

    /// A batch exactly at the target leaves the interval unchanged.
    #[test]
    fn adjust_interval_holds_steady_at_target() {
        let i = Duration::from_millis(50);
        assert_eq!(adjust_interval(i, STREAM_TARGET_BATCH), i);
    }

    /// A heavily oversized batch shrinks the interval by the 20% step cap only.
    #[test]
    fn adjust_interval_speeds_up_when_over_target_capped_at_20pct() {
        let i = Duration::from_millis(25);
        let next = adjust_interval(i, STREAM_TARGET_BATCH * 100);
        assert_eq!(next, Duration::from_millis(20));
    }

    /// Empty and single-record batches grow the interval by the 20% step cap only.
    #[test]
    fn adjust_interval_slows_down_when_under_target_capped_at_20pct() {
        let i = Duration::from_millis(25);
        let next_zero = adjust_interval(i, 0);
        assert_eq!(next_zero, Duration::from_millis(30));
        let next_one = adjust_interval(i, 1);
        assert_eq!(next_one, Duration::from_millis(30));
    }

    /// A batch 10% over target is inside the step cap, so the raw ratio applies.
    #[test]
    fn adjust_interval_takes_ratio_when_inside_20pct_band() {
        let i = Duration::from_millis(25);
        let count = STREAM_TARGET_BATCH + STREAM_TARGET_BATCH / 10;
        let next = adjust_interval(i, count);
        let expected = i.mul_f64(STREAM_TARGET_BATCH as f64 / count as f64);
        assert_eq!(next, expected);
    }

    /// Starting at the minimum interval under overload stays at the minimum.
    #[test]
    fn adjust_interval_clamps_to_min_floor() {
        let i = STREAM_MIN_INTERVAL;
        let next = adjust_interval(i, STREAM_TARGET_BATCH * 100);
        assert_eq!(next, STREAM_MIN_INTERVAL);
    }

    /// A zero interval with an undersized batch restarts at one microsecond.
    #[test]
    fn adjust_interval_escapes_zero_when_batches_undersized() {
        let next = adjust_interval(Duration::ZERO, 1);
        assert_eq!(next, Duration::from_micros(1));
    }

    /// Starting at the maximum interval with an empty batch stays at the maximum.
    #[test]
    fn adjust_interval_clamps_to_max_ceiling() {
        let i = STREAM_MAX_INTERVAL;
        let next = adjust_interval(i, 0);
        assert_eq!(next, STREAM_MAX_INTERVAL);
    }

    /// Under sustained overload the interval reaches the minimum in fewer than
    /// 100 steps.
    #[test]
    fn adjust_interval_reaches_min_in_bounded_steps_under_overload() {
        let mut i = Duration::from_millis(50);
        let mut steps = 0;
        // The 1000-step bound keeps a regression from hanging the test.
        while i > STREAM_MIN_INTERVAL && steps < 1000 {
            i = adjust_interval(i, STREAM_TARGET_BATCH * 1000);
            steps += 1;
        }
        assert_eq!(i, STREAM_MIN_INTERVAL);
        assert!(steps < 100, "took {steps} steps to reach floor");
    }

    /// Picks a loopback address by binding port 0, reading the assigned port and
    /// releasing the listener again.
    fn pick_addr() -> SocketAddr {
        let listener = StdTcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        // Release the port so the server under test can bind it.
        drop(listener);
        format!("127.0.0.1:{port}").parse().unwrap()
    }

    /// Builds a cache with a TRACE level filter and a chance of 100.0, runs `f`
    /// with that cache installed as the default subscriber, then flushes and
    /// drains it. Returns the cache together with its level and chance handles.
    fn cache_with_spans<F>(
        f: F,
    ) -> (
        Arc<SpanCache<ChancePredicate<tracing_cache::LevelPredicate>>>,
        LevelHandle,
        ChanceHandle,
    )
    where
        F: FnOnce(),
    {
        let level =
            tracing_cache::LevelPredicate::with_filter(tracing::metadata::LevelFilter::TRACE);
        let level_handle = level.handle();
        let predicate = ChancePredicate::new(level, 100.0);
        let chance_handle = predicate.handle();
        let (cache, driver) = SpanCache::with_predicate(1024, predicate);
        let cache = Arc::new(cache);
        tracing::subscriber::with_default(Arc::clone(&cache), f);
        cache.flush_pending();
        driver.drain_sync();
        (cache, level_handle, chance_handle)
    }

    /// Spawns `serve` on a fresh loopback address and waits (up to 50 attempts,
    /// 10 ms apart) until a TCP connection succeeds. Panics if the server never
    /// accepts.
    async fn spawn_server<P: EnabledPredicate>(
        cache: Arc<SpanCache<P>>,
        level_handle: LevelHandle,
        chance_handle: ChanceHandle,
    ) -> (SocketAddr, tokio::task::JoinHandle<()>) {
        let addr = pick_addr();
        let server_cache = Arc::clone(&cache);
        let serve_level = level_handle.clone();
        let serve_chance = chance_handle.clone();
        let handle = tokio::spawn(async move {
            let _ = serve(server_cache, serve_level, serve_chance, addr).await;
        });
        for _ in 0..50 {
            // A successful connect is the readiness signal.
            if std::net::TcpStream::connect(addr).is_ok() {
                return (addr, handle);
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
        panic!("server never came up on {addr}");
    }

    /// Connects an RPC client to `addr` and spawns its connection driver.
    async fn connect_client(addr: SocketAddr) -> RpcClient<Request, Response> {
        let cfg = Configuration::new(TcpStreamConnector);
        let (rpc_client, conn) = client::connect::<ClientCodec, _>(addr, &cfg).await.unwrap();
        tokio::spawn(conn);
        rpc_client
    }

    /// Reads from `stream` until `n` span responses were collected, the stream
    /// ends or errors, or `total_timeout` elapses. Non-span responses are skipped.
    async fn collect_spans(
        stream: &mut (impl futures::Stream<Item = Result<Response, protosocket_rpc::Error>> + Unpin),
        n: usize,
        total_timeout: Duration,
    ) -> Vec<crate::WireSpan> {
        let mut out = Vec::with_capacity(n);
        let deadline = tokio::time::Instant::now() + total_timeout;
        while out.len() < n {
            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
            match tokio::time::timeout(remaining, stream.next()).await {
                Ok(Some(Ok(resp))) => {
                    if let ResponseBody::Span(s) = resp.body {
                        out.push(s);
                    }
                }
                Ok(Some(Err(_))) | Ok(None) => break,
                Err(_) => break,
            }
        }
        out
    }

    /// Spans recorded before the stream starts are delivered, named correctly
    /// and carrying a close timestamp.
    #[tokio::test]
    async fn start_stream_delivers_closed_spans() {
        let (cache, level_handle, chance_handle) = cache_with_spans(|| {
            for _ in 0..3 {
                let span = tracing::span!(parent: None, tracing::Level::INFO, "test_a");
                let _g = span.enter();
            }
        });

        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;
        let client = connect_client(addr).await;
        let mut stream = client
            .send_streaming(Request::new(RequestBody::StartStream))
            .unwrap();

        let received = collect_spans(&mut stream, 3, Duration::from_secs(2)).await;
        assert_eq!(received.len(), 3);
        assert!(received.iter().all(|s| s.name == "test_a"));
        assert!(received.iter().all(|s| s.closed_at_ns.is_some()));

        server.abort();
    }

    /// After `StopStream` is acknowledged, fewer than five further spans arrive.
    #[tokio::test]
    async fn stop_stream_halts_delivery() {
        let (cache, level_handle, chance_handle) = cache_with_spans(|| {
            for _ in 0..5 {
                let span = tracing::span!(parent: None, tracing::Level::INFO, "test_b");
                let _g = span.enter();
            }
        });

        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;
        let client = connect_client(addr).await;
        let mut stream = client
            .send_streaming(Request::new(RequestBody::StartStream))
            .unwrap();

        // Make sure the stream is actually delivering before stopping it.
        let initial = collect_spans(&mut stream, 1, Duration::from_secs(2)).await;
        assert_eq!(initial.len(), 1);

        let ack = client
            .send_unary(Request::new(RequestBody::StopStream))
            .unwrap()
            .await
            .unwrap();
        assert!(matches!(ack.body, ResponseBody::Ack));

        let drained_after_stop = collect_spans(&mut stream, 100, Duration::from_millis(300)).await;
        assert!(
            drained_after_stop.len() < 5,
            "stream did not stop: got {} more spans after StopStream",
            drained_after_stop.len()
        );

        server.abort();
    }

    /// With the level floor set to `Info`, the INFO span is delivered and the
    /// DEBUG span is not.
    #[tokio::test]
    async fn set_level_filters_below_threshold() {
        let (cache, level_handle, chance_handle) = cache_with_spans(|| {
            let span_info = tracing::span!(parent: None, tracing::Level::INFO, "info_span");
            drop(span_info);
            let span_debug = tracing::span!(parent: None, tracing::Level::DEBUG, "debug_span");
            drop(span_debug);
        });

        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;
        let client = connect_client(addr).await;

        let ack = client
            .send_unary(Request::new(RequestBody::SetLevel(WireLevel::Info)))
            .unwrap()
            .await
            .unwrap();
        assert!(matches!(ack.body, ResponseBody::Ack));

        let mut stream = client
            .send_streaming(Request::new(RequestBody::StartStream))
            .unwrap();
        let received = collect_spans(&mut stream, 5, Duration::from_millis(500)).await;

        let names: Vec<_> = received.iter().map(|s| s.name.as_str()).collect();
        assert_eq!(names, vec!["info_span"], "got: {names:?}");

        server.abort();
    }

    /// A sampling rate of 0.0 suppresses every span.
    #[tokio::test]
    async fn set_sampling_rate_zero_drops_all() {
        let (cache, level_handle, chance_handle) = cache_with_spans(|| {
            for _ in 0..5 {
                let span = tracing::span!(parent: None, tracing::Level::INFO, "sampled");
                let _g = span.enter();
            }
        });

        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;
        let client = connect_client(addr).await;

        client
            .send_unary(Request::new(RequestBody::SetSamplingRate(0.0)))
            .unwrap()
            .await
            .unwrap();
        let mut stream = client
            .send_streaming(Request::new(RequestBody::StartStream))
            .unwrap();

        let received = collect_spans(&mut stream, 5, Duration::from_millis(400)).await;
        assert!(
            received.is_empty(),
            "rate=0 should drop everything; got {received:?}"
        );

        server.abort();
    }

    /// A filter matching the root "alpha" delivers that root and its child, and
    /// nothing from the "beta" tree.
    #[tokio::test]
    async fn set_filter_matches_root_and_inherits_to_children() {
        let (cache, level_handle, chance_handle) = cache_with_spans(|| {
            {
                let root = tracing::span!(parent: None, tracing::Level::INFO, "alpha");
                let _g = root.enter();
                let _child = tracing::span!(tracing::Level::INFO, "alpha_child");
            }
            {
                let root = tracing::span!(parent: None, tracing::Level::INFO, "beta");
                let _g = root.enter();
                let _child = tracing::span!(tracing::Level::INFO, "beta_child");
            }
        });

        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;
        let client = connect_client(addr).await;

        client
            .send_unary(Request::new(RequestBody::SetFilter(Some(
                "alpha".to_string(),
            ))))
            .unwrap()
            .await
            .unwrap();
        let mut stream = client
            .send_streaming(Request::new(RequestBody::StartStream))
            .unwrap();

        let received = collect_spans(&mut stream, 4, Duration::from_millis(500)).await;
        let mut names: Vec<_> = received.iter().map(|s| s.name.clone()).collect();
        // Delivery order is not asserted; compare as a sorted list.
        names.sort();
        assert_eq!(names, vec!["alpha".to_string(), "alpha_child".to_string()]);

        server.abort();
    }

    /// Changing the cache level while a stream is open keeps the stream alive
    /// and makes it report the new level.
    #[tokio::test]
    async fn set_cache_level_keeps_stream_open() {
        let (cache, level_handle, chance_handle) = cache_with_spans(|| {
            let s = tracing::span!(parent: None, tracing::Level::INFO, "pre_level");
            drop(s);
        });
        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;
        let client = connect_client(addr).await;
        let mut start = Request::new(RequestBody::StartStream);
        start.id = 100;
        let mut stream = client.send_streaming(start).unwrap();

        // The stream opens with the current cache level.
        let first = tokio::time::timeout(Duration::from_secs(1), stream.next())
            .await
            .unwrap()
            .unwrap()
            .unwrap();
        assert!(
            matches!(first.body, ResponseBody::CacheLevel(_)),
            "first message should be CacheLevel, got {:?}",
            first.body
        );

        let mut set = Request::new(RequestBody::SetCacheLevel(WireLevelFilter::Off));
        set.id = 101;
        let ack = client.send_unary(set).unwrap().await.unwrap();
        assert!(matches!(ack.body, ResponseBody::Ack));

        // Skip chance updates and spans until the next CacheLevel shows up or
        // the two-second deadline passes.
        let mut next_level: Option<WireLevelFilter> = None;
        let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
        while tokio::time::Instant::now() < deadline && next_level.is_none() {
            let item = tokio::time::timeout(Duration::from_millis(200), stream.next()).await;
            let Ok(Some(Ok(resp))) = item else { continue };
            match resp.body {
                ResponseBody::CacheLevel(l) => next_level = Some(l),
                ResponseBody::CacheChance(_) => continue,
                ResponseBody::Span(_) => continue,
                other => panic!("unexpected stream item: {other:?}"),
            }
        }
        assert_eq!(
            next_level,
            Some(WireLevelFilter::Off),
            "stream did not yield the updated CacheLevel (probably ended)",
        );

        server.abort();
    }

    /// Once the only streaming client goes away, the cache level is reset to OFF.
    #[tokio::test]
    async fn level_resets_to_off_when_last_console_disconnects() {
        let (cache, level_handle, chance_handle) = cache_with_spans(|| {
            let s = tracing::span!(parent: None, tracing::Level::INFO, "anchor");
            drop(s);
        });
        // Start from a non-OFF level so the reset is observable.
        level_handle.set(LevelFilter::INFO);

        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;

        {
            let client = connect_client(addr).await;
            let mut start = Request::new(RequestBody::StartStream);
            start.id = 200;
            let _stream = client.send_streaming(start).unwrap();
            // Give the server time to process StartStream before disconnecting.
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
        // Client and stream are dropped here; wait for the server to notice.
        tokio::time::sleep(Duration::from_millis(500)).await;

        assert_eq!(
            level_handle.get(),
            LevelFilter::OFF,
            "level should have reset to OFF after last console disconnected",
        );

        server.abort();
    }

    use std::time::Instant;
    use tracing::callsite::{Callsite, DefaultCallsite, Identifier};
    use tracing::field::FieldSet;
    use tracing::metadata::Kind;
    use tracing_cache::{FieldList, SpanRecord};

    // Hand-built callsite with INFO-level span metadata named "sampling_test"
    // and no fields, so tests can fabricate `SpanRecord`s without a subscriber.
    static SAMPLING_CALLSITE: DefaultCallsite = {
        static META: tracing::Metadata<'static> = tracing::Metadata::new(
            "sampling_test",
            "sampling::test",
            tracing::Level::INFO,
            None,
            None,
            None,
            FieldSet::new(&[], Identifier(&SAMPLING_CALLSITE)),
            Kind::SPAN,
        );
        DefaultCallsite::new(&META)
    };

    /// Fabricates a closed span record with the given ids, no fields and no
    /// events.
    fn synth_span(id: u64, parent_id: Option<u64>) -> SpanRecord {
        SpanRecord {
            id,
            parent_id,
            metadata: SAMPLING_CALLSITE.metadata(),
            fields: FieldList::default(),
            events: Vec::new(),
            opened_at: Instant::now(),
            closed_at: Some(Instant::now()),
        }
    }

    /// Rate 1.0 keeps every record regardless of id.
    #[test]
    fn sampling_passes_rate_one_short_circuits_true() {
        for id in [0u64, 1, 17, u64::MAX, 0x9E37_79B9_7F4A_7C15] {
            assert!(sampling_passes(&synth_span(id, None), 1.0));
        }
    }

    /// Rate 0.0 drops every record regardless of id.
    #[test]
    fn sampling_passes_rate_zero_short_circuits_false() {
        for id in [0u64, 1, 17, u64::MAX] {
            assert!(!sampling_passes(&synth_span(id, None), 0.0));
        }
    }

    /// Repeated calls for the same root id give the same answer.
    #[test]
    fn sampling_passes_is_deterministic_per_root_id() {
        for id in 1u64..=20 {
            let r = synth_span(id, None);
            let first = sampling_passes(&r, 0.5);
            for _ in 0..3 {
                assert_eq!(sampling_passes(&r, 0.5), first, "id={id}");
            }
        }
    }

    /// Direct children get the same decision as their parent, whatever their
    /// own id is.
    #[test]
    fn sampling_passes_children_inherit_parents_root_id_bucket() {
        let root = synth_span(7, None);
        let want = sampling_passes(&root, 0.5);
        for child_id in [100u64, 200, 300, u64::MAX] {
            let child = synth_span(child_id, Some(7));
            assert_eq!(sampling_passes(&child, 0.5), want);
        }
    }

    /// Over 5000 sequential ids the kept fraction is within 0.03 of the rate.
    #[test]
    fn sampling_passes_partitions_population_near_target_rate() {
        let rate = 0.3;
        let n = 5_000u64;
        let mut passed = 0usize;
        for id in 1..=n {
            if sampling_passes(&synth_span(id, None), rate) {
                passed += 1;
            }
        }
        let frac = passed as f64 / n as f64;
        assert!(
            (frac - rate).abs() < 0.03,
            "frac={frac} rate={rate} — hash distribution drifted",
        );
    }

    /// A child of a rejected parent is rejected, and that decision is memoized
    /// under the child's id.
    #[test]
    fn filter_passes_descendant_inherits_root_decision_via_memo() {
        let filter = Some("alpha".to_string());
        let roots: Arc<RwLock<HashMap<u64, bool>>> = Arc::new(RwLock::new(HashMap::new()));
        roots.write().unwrap().insert(42, false);

        let child = synth_span(43, Some(42));
        assert!(!filter_passes(&child, &filter, &roots));
        assert_eq!(roots.read().unwrap().get(&43).copied(), Some(false));
    }

    /// A matching root passes and its decision is memoized.
    #[test]
    fn filter_passes_root_caches_match_in_memo() {
        // "sampling" is a substring of the synthetic span's name.
        let filter = Some("sampling".to_string());
        let roots: Arc<RwLock<HashMap<u64, bool>>> = Arc::new(RwLock::new(HashMap::new()));
        let root = synth_span(10, None);
        assert!(filter_passes(&root, &filter, &roots));
        assert_eq!(roots.read().unwrap().get(&10).copied(), Some(true));
    }

    /// No filter and an empty filter both accept everything and leave the memo
    /// empty.
    #[test]
    fn filter_passes_empty_or_none_filter_accepts_everything() {
        let roots: Arc<RwLock<HashMap<u64, bool>>> = Arc::new(RwLock::new(HashMap::new()));
        let s = synth_span(1, None);
        assert!(filter_passes(&s, &None, &roots));
        assert!(filter_passes(&s, &Some(String::new()), &roots));
        assert!(roots.read().unwrap().is_empty());
    }

    /// Rates above 1.0, below 0.0 and NaN are answered with an error that
    /// mentions "sampling rate".
    #[tokio::test]
    async fn set_sampling_rate_rejects_out_of_range() {
        let (cache, level_handle, chance_handle) = cache_with_spans(|| {});
        let (addr, server) = spawn_server(
            Arc::clone(&cache),
            level_handle.clone(),
            chance_handle.clone(),
        )
        .await;
        let client = connect_client(addr).await;

        for bad in [1.5_f64, -0.1, f64::NAN] {
            let resp = client
                .send_unary(Request::new(RequestBody::SetSamplingRate(bad)))
                .unwrap()
                .await
                .unwrap();
            match resp.body {
                ResponseBody::Error(msg) => {
                    assert!(
                        msg.contains("sampling rate"),
                        "unexpected error message for {bad}: {msg}",
                    );
                }
                other => panic!("expected Error for rate={bad}, got {other:?}"),
            }
        }
        server.abort();
    }
}