Skip to main content

tael_server/storage/
fanout.rs

1//! `FanoutStore` — the scatter-gather query layer for the sharded topology
2//! (`docs/tael-server-scaling-ha.md` §3, Phase 2). It implements [`Store`] over
3//! N shard `Store`s (typically [`RemoteStore`](super::RemoteStore)s, one per
4//! `tael-server` shard) so the REST/gRPC/CLI layers above the trait are
5//! unchanged — they see one logical store.
6//!
7//! ## Routing vs. fan-out
8//!
9//! The shard key is `trace_id` (the design's choice — it keeps `get_trace` and
10//! `correlate` local). Operations split two ways:
11//!
12//! - **Routed** to the single owning shard via `hash(trace_id) % N`: comment
13//!   reads/writes and the write path (`insert_*`). A batch is grouped by shard
14//!   and each group dispatched to its owner — the routing layer, in code.
15//! - **Fanned out** to all shards then merged: the core reads. `query_*`
16//!   concatenate and re-limit (each shard already returns newest-first, so a
17//!   sort+truncate is a k-way merge); `list_services`/`query_summary`/
18//!   `query_anomalies` re-aggregate, because counts sum but rates/averages do
//!   not. `get_trace`/`correlate` ask the owning shard first and fall back to a
//!   fan-out when it returns nothing, so routing-hash skew or a rebalancing
//!   window (design Open Q #1) can't hide a trace.
22//!
23//! `query_sql` is deliberately **not** distributed (design §3 recommendation
24//! (c)): arbitrary SQL over per-shard DuckDB projections can't be merged
25//! soundly, so it returns an error pointing at a single-node endpoint.
26//!
27//! ## Partial availability
28//!
29//! Fan-out reads tolerate a down shard: results from healthy shards are
30//! returned and the failure logged, matching the HA goal that losing one shard
31//! degrades rather than fails queries. A read errors only when *every* shard
32//! fails. (Routed ops have a single owner, so they surface that owner's error.)
33
34use std::collections::HashMap;
35use std::hash::{Hash, Hasher};
36use std::sync::Arc;
37
38use anyhow::{Result, bail};
39use serde_json::Value;
40
41use super::Store;
42use super::models::{
43    Anomaly, AnomalyReport, CorrelateReport, ErrorOperation, LogQuery, LogRecord, LogSummary,
44    MetricPoint, MetricQuery, MetricSummary, ServiceInfo, ServiceSummary, Span, SummaryReport,
45    TraceComment, TraceQuery, TraceSummary,
46};
47
48/// A [`Store`] that scatters reads across N shard stores and gathers/merges the
49/// results. See the module docs for routing vs. fan-out semantics.
50pub struct FanoutStore {
51    shards: Vec<Arc<dyn Store>>,
52}
53
/// FNV-1a 64-bit offset basis (published FNV parameter).
const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
/// FNV-1a 64-bit prime (published FNV parameter).
const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

/// 64-bit FNV-1a hasher. Unlike `std`'s `DefaultHasher`, whose algorithm is
/// explicitly unspecified and may change between Rust releases, FNV-1a is a
/// fixed, published algorithm: the same bytes hash the same on every build,
/// toolchain, and process.
struct Fnv1a64(u64);

impl Hasher for Fnv1a64 {
    fn write(&mut self, bytes: &[u8]) {
        for &b in bytes {
            self.0 ^= u64::from(b);
            self.0 = self.0.wrapping_mul(FNV_PRIME);
        }
    }

    fn finish(&self) -> u64 {
        self.0
    }
}

/// Murmur3 `fmix64` finalizer. FNV-1a's low bits mix poorly (the lowest bit
/// is just the parity of the input bytes' low bits), and `% n` reads exactly
/// those bits, so avalanche the hash first.
fn fmix64(mut h: u64) -> u64 {
    h ^= h >> 33;
    h = h.wrapping_mul(0xff51_afd7_ed55_8ccd);
    h ^= h >> 33;
    h = h.wrapping_mul(0xc4ce_b9fe_1a85_ec53);
    h ^= h >> 33;
    h
}

/// Deterministic shard selection: `fmix64(fnv1a64(key)) % n`.
///
/// The mapping must be identical in every process that routes or looks up
/// data — including across Rust toolchain upgrades — or `get_trace` and the
/// routed-only comment ops won't find what the ingest path wrote. That is
/// why this does not use `DefaultHasher`, whose output std does not promise
/// to keep stable across releases.
///
/// Changing this function remaps keys to shards: data already written under a
/// different mapping needs a rebalance, or routed-only ops will miss it.
///
/// `n` must be non-zero (`FanoutStore::new` guarantees at least one shard).
fn shard_index(key: &str, n: usize) -> usize {
    debug_assert!(n > 0, "shard_index requires at least one shard");
    let mut h = Fnv1a64(FNV_OFFSET_BASIS);
    // Feed the raw bytes rather than `key.hash(&mut h)`: `str`'s `Hash` framing
    // (a trailing 0xff) is a std implementation detail, not a stable contract.
    h.write(key.as_bytes());
    (fmix64(h.finish()) % n as u64) as usize
}
63
64impl FanoutStore {
65    /// Build a fan-out over the given shards. Requires at least one shard.
66    pub fn new(shards: Vec<Arc<dyn Store>>) -> Result<Self> {
67        if shards.is_empty() {
68            bail!("FanoutStore requires at least one shard");
69        }
70        Ok(Self { shards })
71    }
72
73    fn shard_for(&self, key: &str) -> &Arc<dyn Store> {
74        &self.shards[shard_index(key, self.shards.len())]
75    }
76
77    /// Run `f` against every shard, collecting successes. Logs and tolerates
78    /// per-shard failures; errors only if *all* shards fail.
79    fn fan_out<T>(
80        &self,
81        op: &str,
82        f: impl Fn(&Arc<dyn Store>) -> Result<T>,
83    ) -> Result<Vec<T>> {
84        let mut results = Vec::with_capacity(self.shards.len());
85        let mut last_err = None;
86        for (i, shard) in self.shards.iter().enumerate() {
87            match f(shard) {
88                Ok(v) => results.push(v),
89                Err(e) => {
90                    tracing::warn!(shard = i, op, error = %e, "shard failed; serving partial results");
91                    last_err = Some(e);
92                }
93            }
94        }
95        if results.is_empty()
96            && let Some(e) = last_err
97        {
98            return Err(e.context(format!("all {} shards failed for {op}", self.shards.len())));
99        }
100        Ok(results)
101    }
102}
103
104impl Store for FanoutStore {
105    // ── Spans / traces ──────────────────────────────────────────────
106    fn insert_spans(&self, spans: &[Span]) -> Result<()> {
107        // Route each span to its trace's owning shard. A trace's spans thus all
108        // land together, keeping get_trace/correlate single-shard.
109        route_and_insert(&self.shards, spans, |s| &s.trace_id, |store, batch| {
110            store.insert_spans(batch)
111        })
112    }
113
114    fn query_traces(&self, query: &TraceQuery) -> Result<Vec<Span>> {
115        let limit = query.limit.unwrap_or(100) as usize;
116        let mut all: Vec<Span> = self
117            .fan_out("query_traces", |s| s.query_traces(query))?
118            .into_iter()
119            .flatten()
120            .collect();
121        // Each shard returned newest-first; merge by re-sorting and re-limiting.
122        all.sort_by(|a, b| b.start_time.cmp(&a.start_time));
123        all.truncate(limit);
124        Ok(all)
125    }
126
127    fn get_trace(&self, trace_id: &str) -> Result<Vec<Span>> {
128        // A trace lives on its owning shard, but fan out (short-circuiting on a
129        // hit) so a routing-hash mismatch or a rebalancing window can't drop it.
130        let mut spans = self.shard_for(trace_id).get_trace(trace_id)?;
131        if spans.is_empty() {
132            for (i, shard) in self.shards.iter().enumerate() {
133                match shard.get_trace(trace_id) {
134                    Ok(s) if !s.is_empty() => {
135                        spans = s;
136                        break;
137                    }
138                    Ok(_) => {}
139                    Err(e) => tracing::warn!(shard = i, error = %e, "get_trace shard failed"),
140                }
141            }
142        }
143        // Dedup by span_id in case spans transiently overlap shards.
144        let mut seen = std::collections::HashSet::new();
145        spans.retain(|s| seen.insert(s.span_id.clone()));
146        spans.sort_by_key(|s| s.start_time);
147        Ok(spans)
148    }
149
150    fn list_services(&self) -> Result<Vec<ServiceInfo>> {
151        let per_shard = self.fan_out("list_services", |s| s.list_services())?;
152        Ok(merge_services(per_shard))
153    }
154
155    // ── Comments ── routed to the trace's owning shard ──────────────
156    fn add_comment(
157        &self,
158        trace_id: &str,
159        span_id: Option<&str>,
160        author: &str,
161        body: &str,
162    ) -> Result<TraceComment> {
163        self.shard_for(trace_id)
164            .add_comment(trace_id, span_id, author, body)
165    }
166
167    fn get_comments(&self, trace_id: &str) -> Result<Vec<TraceComment>> {
168        self.shard_for(trace_id).get_comments(trace_id)
169    }
170
171    // ── Logs ────────────────────────────────────────────────────────
172    fn insert_logs(&self, logs: &[LogRecord]) -> Result<()> {
173        // Logs carry trace_id and route the same way; orphan logs (no trace)
174        // shard by service so a service's logs stay co-located-ish.
175        route_and_insert(
176            &self.shards,
177            logs,
178            |l| l.trace_id.as_deref().unwrap_or(&l.service),
179            |store, batch| store.insert_logs(batch),
180        )
181    }
182
183    fn query_logs(&self, query: &LogQuery) -> Result<Vec<LogRecord>> {
184        let limit = query.limit.unwrap_or(100) as usize;
185        let mut all: Vec<LogRecord> = self
186            .fan_out("query_logs", |s| s.query_logs(query))?
187            .into_iter()
188            .flatten()
189            .collect();
190        all.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
191        all.truncate(limit);
192        Ok(all)
193    }
194
195    // ── Metrics ─────────────────────────────────────────────────────
196    fn insert_metrics(&self, metrics: &[MetricPoint]) -> Result<()> {
197        // Metrics carry no trace; shard by name (design §3) so a series stays on
198        // one shard and unique-name counts merge by simple sum.
199        route_and_insert(&self.shards, metrics, |m| &m.name, |store, batch| {
200            store.insert_metrics(batch)
201        })
202    }
203
204    fn query_metrics(&self, query: &MetricQuery) -> Result<Vec<MetricPoint>> {
205        let limit = query.limit.unwrap_or(500) as usize;
206        let mut all: Vec<MetricPoint> = self
207            .fan_out("query_metrics", |s| s.query_metrics(query))?
208            .into_iter()
209            .flatten()
210            .collect();
211        all.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
212        all.truncate(limit);
213        Ok(all)
214    }
215
216    // ── Cross-signal analytics ──────────────────────────────────────
217    fn query_summary(&self, last_seconds: i64, service: Option<&str>) -> Result<SummaryReport> {
218        let per_shard = self.fan_out("query_summary", |s| s.query_summary(last_seconds, service))?;
219        Ok(merge_summaries(per_shard, last_seconds, service))
220    }
221
222    fn query_anomalies(
223        &self,
224        current_seconds: i64,
225        baseline_seconds: i64,
226        service: Option<&str>,
227    ) -> Result<AnomalyReport> {
228        let per_shard = self.fan_out("query_anomalies", |s| {
229            s.query_anomalies(current_seconds, baseline_seconds, service)
230        })?;
231        Ok(merge_anomalies(
232            per_shard,
233            current_seconds,
234            baseline_seconds,
235            service,
236        ))
237    }
238
239    fn query_correlate(&self, trace_id: &str) -> Result<Option<CorrelateReport>> {
240        // The trace is single-shard; try its owner, then fall back to a fan-out.
241        if let Some(r) = self.shard_for(trace_id).query_correlate(trace_id)? {
242            return Ok(Some(r));
243        }
244        for (i, shard) in self.shards.iter().enumerate() {
245            match shard.query_correlate(trace_id) {
246                Ok(Some(r)) => return Ok(Some(r)),
247                Ok(None) => {}
248                Err(e) => tracing::warn!(shard = i, error = %e, "correlate shard failed"),
249            }
250        }
251        Ok(None)
252    }
253
254    fn query_sql(&self, _sql: &str) -> Result<Vec<Value>> {
255        // Arbitrary SQL over per-shard DuckDB projections doesn't distribute
256        // soundly (cross-shard GROUP BY/aggregates). Design §3 recommendation
257        // (c): keep it a single-node power tool.
258        bail!(
259            "query_sql is not distributed across shards; run it directly against a single shard's /api/v1/sql"
260        );
261    }
262
263    // ── Lifecycle ───────────────────────────────────────────────────
264    fn health(&self) -> Result<()> {
265        // Ready if at least one shard answers — the node can still serve partial
266        // results, consistent with fan-out read tolerance. Not-ready only when
267        // every shard is unreachable.
268        let mut healthy = 0usize;
269        for (i, shard) in self.shards.iter().enumerate() {
270            match shard.health() {
271                Ok(()) => healthy += 1,
272                Err(e) => tracing::warn!(shard = i, error = %e, "shard unhealthy"),
273            }
274        }
275        if healthy == 0 {
276            bail!("no healthy shards ({} total)", self.shards.len());
277        }
278        Ok(())
279    }
280
281    fn flush(&self) -> Result<()> {
282        for shard in &self.shards {
283            shard.flush()?;
284        }
285        Ok(())
286    }
287}
288
289/// Group `items` by the shard their key hashes to, then hand each group to its
290/// owning shard in a single call.
291fn route_and_insert<T: Clone>(
292    shards: &[Arc<dyn Store>],
293    items: &[T],
294    key: impl Fn(&T) -> &str,
295    insert: impl Fn(&Arc<dyn Store>, &[T]) -> Result<()>,
296) -> Result<()> {
297    if items.is_empty() {
298        return Ok(());
299    }
300    let n = shards.len();
301    let mut buckets: HashMap<usize, Vec<T>> = HashMap::new();
302    for item in items {
303        let idx = shard_index(key(item), n);
304        buckets.entry(idx).or_default().push(item.clone());
305    }
306    for (idx, batch) in buckets {
307        insert(&shards[idx], &batch)?;
308    }
309    Ok(())
310}
311
312/// Sum disjoint per-shard service rollups. `span_count`/`trace_count` add
313/// directly (trace_ids are disjoint across shards); `avg_duration_ms` and
314/// `error_rate` are recomputed as span-count-weighted means.
315fn merge_services(per_shard: Vec<Vec<ServiceInfo>>) -> Vec<ServiceInfo> {
316    struct Acc {
317        span_count: i64,
318        trace_count: i64,
319        dur_weighted: f64,
320        err_weighted: f64,
321    }
322    let mut acc: HashMap<String, Acc> = HashMap::new();
323    for shard in per_shard {
324        for s in shard {
325            let e = acc.entry(s.name).or_insert(Acc {
326                span_count: 0,
327                trace_count: 0,
328                dur_weighted: 0.0,
329                err_weighted: 0.0,
330            });
331            e.span_count += s.span_count;
332            e.trace_count += s.trace_count;
333            e.dur_weighted += s.avg_duration_ms * s.span_count as f64;
334            e.err_weighted += s.error_rate * s.span_count as f64;
335        }
336    }
337    let mut out: Vec<ServiceInfo> = acc
338        .into_iter()
339        .map(|(name, a)| {
340            let w = a.span_count.max(1) as f64;
341            ServiceInfo {
342                name,
343                span_count: a.span_count,
344                trace_count: a.trace_count,
345                avg_duration_ms: a.dur_weighted / w,
346                error_rate: a.err_weighted / w,
347            }
348        })
349        .collect();
350    out.sort_by(|a, b| b.span_count.cmp(&a.span_count));
351    out
352}
353
/// Weighted arithmetic mean of `(value, weight)` pairs — here, a per-shard
/// value weighted by that shard's span count. Yields `0.0` when the weights
/// sum to zero, including when there are no pairs at all.
fn weighted_mean(values: impl Iterator<Item = (f64, i64)>) -> f64 {
    let (weighted_sum, total_weight) =
        values.fold((0.0_f64, 0_i64), |(sum, total), (value, weight)| {
            (sum + value * weight as f64, total + weight)
        });
    match total_weight {
        0 => 0.0,
        total => weighted_sum / total as f64,
    }
}
364
365/// Merge per-shard summaries. Counts sum; rates and averages recompute from
366/// component sums. Percentiles are span-count-weighted approximations — exact
367/// cross-shard percentiles need a mergeable sketch (t-digest), tracked as
368/// future work; they are an estimate, not a true global quantile.
369fn merge_summaries(
370    per_shard: Vec<SummaryReport>,
371    window_seconds: i64,
372    service: Option<&str>,
373) -> SummaryReport {
374    let mut traces = TraceSummary::default();
375    let mut logs = LogSummary::default();
376    let mut metrics = MetricSummary::default();
377
378    // Re-aggregate top_services and top_error_operations across shards.
379    struct SvcAcc {
380        span_count: i64,
381        err_weighted: f64,
382        p95_weighted: f64,
383    }
384    let mut svc: HashMap<String, SvcAcc> = HashMap::new();
385    let mut errops: HashMap<(String, String), i64> = HashMap::new();
386
387    // Collect (value, weight) pairs for the weighted trace fields.
388    let mut avg_pairs = Vec::new();
389    let mut p50_pairs = Vec::new();
390    let mut p95_pairs = Vec::new();
391    let mut p99_pairs = Vec::new();
392
393    for r in &per_shard {
394        let t = &r.traces;
395        traces.span_count += t.span_count;
396        traces.trace_count += t.trace_count;
397        traces.error_count += t.error_count;
398        traces.max_ms = traces.max_ms.max(t.max_ms);
399        avg_pairs.push((t.avg_ms, t.span_count));
400        p50_pairs.push((t.p50_ms, t.span_count));
401        p95_pairs.push((t.p95_ms, t.span_count));
402        p99_pairs.push((t.p99_ms, t.span_count));
403
404        logs.total += r.logs.total;
405        logs.error += r.logs.error;
406        logs.warn += r.logs.warn;
407        logs.info += r.logs.info;
408        logs.debug += r.logs.debug;
409
410        metrics.point_count += r.metrics.point_count;
411        // Metrics shard by name → names disjoint → unique counts sum.
412        metrics.unique_names += r.metrics.unique_names;
413
414        for s in &r.top_services {
415            let e = svc.entry(s.service.clone()).or_insert(SvcAcc {
416                span_count: 0,
417                err_weighted: 0.0,
418                p95_weighted: 0.0,
419            });
420            e.span_count += s.span_count;
421            e.err_weighted += s.error_rate * s.span_count as f64;
422            e.p95_weighted += s.p95_ms * s.span_count as f64;
423        }
424        for o in &r.top_error_operations {
425            *errops
426                .entry((o.service.clone(), o.operation.clone()))
427                .or_insert(0) += o.error_count;
428        }
429    }
430
431    traces.avg_ms = weighted_mean(avg_pairs.into_iter());
432    traces.p50_ms = weighted_mean(p50_pairs.into_iter());
433    traces.p95_ms = weighted_mean(p95_pairs.into_iter());
434    traces.p99_ms = weighted_mean(p99_pairs.into_iter());
435    traces.error_rate = if traces.span_count > 0 {
436        traces.error_count as f64 / traces.span_count as f64
437    } else {
438        0.0
439    };
440
441    let mut top_services: Vec<ServiceSummary> = svc
442        .into_iter()
443        .map(|(service, a)| {
444            let w = a.span_count.max(1) as f64;
445            ServiceSummary {
446                service,
447                span_count: a.span_count,
448                error_rate: a.err_weighted / w,
449                p95_ms: a.p95_weighted / w,
450            }
451        })
452        .collect();
453    top_services.sort_by(|a, b| b.span_count.cmp(&a.span_count));
454    top_services.truncate(10);
455
456    let mut top_error_operations: Vec<ErrorOperation> = errops
457        .into_iter()
458        .map(|((service, operation), error_count)| ErrorOperation {
459            service,
460            operation,
461            error_count,
462        })
463        .collect();
464    top_error_operations.sort_by(|a, b| b.error_count.cmp(&a.error_count));
465    top_error_operations.truncate(10);
466
467    SummaryReport {
468        window_seconds,
469        service_filter: service.map(str::to_string),
470        traces,
471        top_services,
472        top_error_operations,
473        logs,
474        metrics,
475    }
476}
477
478/// Merge per-shard anomaly reports. Because trace_id sharding spreads a single
479/// service's traffic across all shards, each shard sees only a slice, so these
480/// are merged best-effort: anomalies for the same (service, kind) are collapsed
481/// to the one with the largest |delta| (most significant signal). This is an
482/// approximation — a precise version would recompute current/baseline from raw
483/// per-shard partials.
484fn merge_anomalies(
485    per_shard: Vec<AnomalyReport>,
486    current_seconds: i64,
487    baseline_seconds: i64,
488    service: Option<&str>,
489) -> AnomalyReport {
490    let mut best: HashMap<(String, String), Anomaly> = HashMap::new();
491    for r in per_shard {
492        for a in r.anomalies {
493            let key = (a.service.clone(), a.kind.clone());
494            match best.get(&key) {
495                Some(existing) if existing.delta.abs() >= a.delta.abs() => {}
496                _ => {
497                    best.insert(key, a);
498                }
499            }
500        }
501    }
502    let mut anomalies: Vec<Anomaly> = best.into_values().collect();
503    anomalies.sort_by(|a, b| {
504        b.delta
505            .abs()
506            .partial_cmp(&a.delta.abs())
507            .unwrap_or(std::cmp::Ordering::Equal)
508    });
509    AnomalyReport {
510        current_seconds,
511        baseline_seconds,
512        service_filter: service.map(str::to_string),
513        anomalies,
514    }
515}
516
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::models::{SpanKind, SpanStatus};
    use chrono::{TimeZone, Utc};
    use std::collections::HashMap;
    use std::sync::Mutex;

    /// Minimal in-memory `Store` for exercising fan-out routing and merge logic
    /// without a network. Implements the span surface the tests touch; of the
    /// rest, reads return empty results and unused operations `bail!`.
    #[derive(Default)]
    struct MockStore {
        spans: Mutex<Vec<Span>>,
    }

    impl Store for MockStore {
        fn insert_spans(&self, spans: &[Span]) -> Result<()> {
            self.spans.lock().unwrap().extend_from_slice(spans);
            Ok(())
        }
        fn query_traces(&self, query: &TraceQuery) -> Result<Vec<Span>> {
            let mut v = self.spans.lock().unwrap().clone();
            v.sort_by(|a, b| b.start_time.cmp(&a.start_time)); // newest-first contract
            v.truncate(query.limit.unwrap_or(100) as usize);
            Ok(v)
        }
        fn get_trace(&self, trace_id: &str) -> Result<Vec<Span>> {
            Ok(self
                .spans
                .lock()
                .unwrap()
                .iter()
                .filter(|s| s.trace_id == trace_id)
                .cloned()
                .collect())
        }
        fn list_services(&self) -> Result<Vec<ServiceInfo>> {
            let spans = self.spans.lock().unwrap();
            let mut by: HashMap<String, (i64, f64, i64)> = HashMap::new();
            for s in spans.iter() {
                let e = by.entry(s.service.clone()).or_insert((0, 0.0, 0));
                e.0 += 1;
                e.1 += s.duration_ms;
                if s.status == SpanStatus::Error {
                    e.2 += 1;
                }
            }
            Ok(by
                .into_iter()
                .map(|(name, (n, dur, err))| ServiceInfo {
                    name,
                    span_count: n,
                    trace_count: n,
                    avg_duration_ms: dur / n as f64,
                    error_rate: err as f64 / n as f64,
                })
                .collect())
        }
        fn add_comment(
            &self,
            _t: &str,
            _s: Option<&str>,
            _a: &str,
            _b: &str,
        ) -> Result<TraceComment> {
            bail!("unused")
        }
        fn get_comments(&self, _t: &str) -> Result<Vec<TraceComment>> {
            Ok(vec![])
        }
        fn insert_logs(&self, _l: &[LogRecord]) -> Result<()> {
            Ok(())
        }
        fn query_logs(&self, _q: &LogQuery) -> Result<Vec<LogRecord>> {
            Ok(vec![])
        }
        fn insert_metrics(&self, _m: &[MetricPoint]) -> Result<()> {
            Ok(())
        }
        fn query_metrics(&self, _q: &MetricQuery) -> Result<Vec<MetricPoint>> {
            Ok(vec![])
        }
        fn query_summary(&self, _l: i64, _s: Option<&str>) -> Result<SummaryReport> {
            bail!("unused")
        }
        fn query_anomalies(&self, _c: i64, _b: i64, _s: Option<&str>) -> Result<AnomalyReport> {
            bail!("unused")
        }
        fn query_correlate(&self, _t: &str) -> Result<Option<CorrelateReport>> {
            Ok(None)
        }
        fn query_sql(&self, _s: &str) -> Result<Vec<Value>> {
            bail!("unused")
        }
    }

    fn span(trace: &str, sid: &str, svc: &str, secs: i64, status: SpanStatus) -> Span {
        let t = Utc.timestamp_opt(secs, 0).unwrap();
        Span {
            trace_id: trace.into(),
            span_id: sid.into(),
            parent_span_id: None,
            service: svc.into(),
            operation: "op".into(),
            start_time: t,
            end_time: t,
            duration_ms: 10.0,
            status,
            attributes: HashMap::new(),
            events: vec![],
            kind: SpanKind::Internal,
            llm: None,
        }
    }

    fn fanout(n: usize) -> (FanoutStore, Vec<Arc<MockStore>>) {
        let mocks: Vec<Arc<MockStore>> = (0..n).map(|_| Arc::new(MockStore::default())).collect();
        let shards: Vec<Arc<dyn Store>> = mocks.iter().map(|m| m.clone() as Arc<dyn Store>).collect();
        (FanoutStore::new(shards).unwrap(), mocks)
    }

    #[test]
    fn new_rejects_empty_shards() {
        assert!(FanoutStore::new(vec![]).is_err());
    }

    #[test]
    fn shard_index_is_deterministic_and_bounded() {
        for key in ["trace-abc", "trace-def", "svc-1", ""] {
            let a = shard_index(key, 4);
            let b = shard_index(key, 4);
            assert_eq!(a, b, "hashing must be stable across calls");
            assert!(a < 4);
        }
    }

    #[test]
    fn insert_spans_routes_whole_trace_to_one_shard() {
        let (fo, mocks) = fanout(3);
        // Two traces, each with two spans, in one mixed batch.
        fo.insert_spans(&[
            span("t1", "a", "api", 1, SpanStatus::Ok),
            span("t2", "b", "api", 2, SpanStatus::Ok),
            span("t1", "c", "db", 3, SpanStatus::Ok),
            span("t2", "d", "db", 4, SpanStatus::Ok),
        ])
        .unwrap();

        // Every span of a trace lands on exactly one shard (the trace's owner).
        for tid in ["t1", "t2"] {
            let owners: Vec<usize> = mocks
                .iter()
                .enumerate()
                .filter(|(_, m)| m.spans.lock().unwrap().iter().any(|s| s.trace_id == tid))
                .map(|(i, _)| i)
                .collect();
            assert_eq!(owners.len(), 1, "trace {tid} split across shards: {owners:?}");
            let owned = mocks[owners[0]].spans.lock().unwrap();
            assert_eq!(owned.iter().filter(|s| s.trace_id == tid).count(), 2);
        }
    }

    #[test]
    fn query_traces_merges_newest_first_and_re_limits() {
        let (fo, _m) = fanout(3);
        // Insert 6 spans across distinct traces at increasing timestamps.
        for i in 0..6 {
            fo.insert_spans(&[span(&format!("t{i}"), "s", "api", i, SpanStatus::Ok)])
                .unwrap();
        }
        let q = TraceQuery {
            limit: Some(3),
            ..Default::default()
        };
        let got = fo.query_traces(&q).unwrap();
        assert_eq!(got.len(), 3, "must re-limit after gathering");
        // Newest-first global order: t5, t4, t3.
        let starts: Vec<i64> = got.iter().map(|s| s.start_time.timestamp()).collect();
        assert_eq!(starts, vec![5, 4, 3]);
    }

    #[test]
    fn get_trace_finds_trace_on_its_owning_shard() {
        let (fo, _m) = fanout(4);
        fo.insert_spans(&[
            span("trace-x", "1", "api", 1, SpanStatus::Ok),
            span("trace-x", "2", "db", 2, SpanStatus::Ok),
        ])
        .unwrap();
        let spans = fo.get_trace("trace-x").unwrap();
        assert_eq!(spans.len(), 2);
        assert!(spans.iter().all(|s| s.trace_id == "trace-x"));
        assert!(fo.get_trace("nope").unwrap().is_empty());
    }

    #[test]
    fn get_trace_dedups_repeated_span_ids() {
        let (fo, _m) = fanout(3);
        // The same span delivered twice lands twice on the owning shard.
        let dup = span("t-dup", "s1", "api", 1, SpanStatus::Ok);
        fo.insert_spans(&[dup.clone(), dup]).unwrap();
        let spans = fo.get_trace("t-dup").unwrap();
        assert_eq!(spans.len(), 1, "duplicate span_ids must collapse");
    }

    #[test]
    fn list_services_aggregates_across_shards() {
        let (fo, _m) = fanout(3);
        // 3 "api" spans (1 error) spread by trace_id across shards.
        fo.insert_spans(&[
            span("ta", "1", "api", 1, SpanStatus::Ok),
            span("tb", "2", "api", 2, SpanStatus::Ok),
            span("tc", "3", "api", 3, SpanStatus::Error),
        ])
        .unwrap();
        let svcs = fo.list_services().unwrap();
        let api = svcs.iter().find(|s| s.name == "api").unwrap();
        assert_eq!(api.span_count, 3, "counts must sum across shards");
        assert_eq!(api.trace_count, 3);
        assert!((api.error_rate - 1.0 / 3.0).abs() < 1e-9, "error_rate recomputed from sums");
        assert!((api.avg_duration_ms - 10.0).abs() < 1e-9);
    }

    #[test]
    fn query_sql_is_not_distributed() {
        let (fo, _m) = fanout(2);
        assert!(fo.query_sql("SELECT 1").is_err());
    }

    #[test]
    fn merge_services_sums_counts_and_weights_rates() {
        let shards = vec![
            vec![ServiceInfo {
                name: "api".into(),
                span_count: 10,
                trace_count: 4,
                avg_duration_ms: 20.0,
                error_rate: 0.1,
            }],
            vec![ServiceInfo {
                name: "api".into(),
                span_count: 30,
                trace_count: 6,
                avg_duration_ms: 40.0,
                error_rate: 0.5,
            }],
        ];
        let merged = merge_services(shards);
        let api = &merged[0];
        assert_eq!(api.span_count, 40);
        assert_eq!(api.trace_count, 10);
        // (20*10 + 40*30) / 40 = 35
        assert!((api.avg_duration_ms - 35.0).abs() < 1e-9);
        // (0.1*10 + 0.5*30) / 40 = 0.4
        assert!((api.error_rate - 0.4).abs() < 1e-9);
    }

    #[test]
    fn merge_services_orders_busiest_first() {
        let info = |name: &str, span_count| ServiceInfo {
            name: name.into(),
            span_count,
            trace_count: span_count,
            avg_duration_ms: 1.0,
            error_rate: 0.0,
        };
        let merged = merge_services(vec![
            vec![info("api", 5)],
            vec![info("db", 9), info("api", 1)],
        ]);
        let names: Vec<&str> = merged.iter().map(|s| s.name.as_str()).collect();
        assert_eq!(names, vec!["db", "api"]);
        assert_eq!(merged[1].span_count, 6);
    }

    #[test]
    fn weighted_mean_weights_by_count_and_handles_zero_weight() {
        assert_eq!(weighted_mean(std::iter::empty()), 0.0);
        assert_eq!(weighted_mean([(1.0, 0), (2.0, 0)].into_iter()), 0.0);
        // (10*1 + 20*3) / 4 = 17.5
        assert_eq!(weighted_mean([(10.0, 1), (20.0, 3)].into_iter()), 17.5);
    }

    #[test]
    fn merge_summaries_sums_counts_and_recomputes_rates() {
        let shard = |span_count, error_count, avg_ms, max_ms| SummaryReport {
            window_seconds: 60,
            service_filter: None,
            traces: TraceSummary {
                span_count,
                error_count,
                avg_ms,
                max_ms,
                ..Default::default()
            },
            top_services: vec![],
            top_error_operations: vec![],
            logs: LogSummary::default(),
            metrics: MetricSummary::default(),
        };
        let merged = merge_summaries(
            vec![shard(10, 1, 10.0, 20.0), shard(30, 9, 50.0, 90.0)],
            60,
            None,
        );
        assert_eq!(merged.traces.span_count, 40);
        assert_eq!(merged.traces.error_count, 10);
        // 10 errors / 40 spans, recomputed from sums rather than averaged.
        assert!((merged.traces.error_rate - 0.25).abs() < 1e-9);
        // (10*10 + 50*30) / 40 = 40
        assert!((merged.traces.avg_ms - 40.0).abs() < 1e-9);
        assert!((merged.traces.max_ms - 90.0).abs() < 1e-9);
        assert_eq!(merged.window_seconds, 60);
        assert!(merged.service_filter.is_none());
    }
}