Skip to main content

tael_server/storage/
fanout.rs

1//! `FanoutStore` — the scatter-gather query layer for the sharded topology
2//! (`docs/tael-server-scaling-ha.md` §3, Phase 2). It implements [`Store`] over
3//! N shard `Store`s (typically [`RemoteStore`](super::RemoteStore)s, one per
4//! `tael-server` shard) so the REST/gRPC/CLI layers above the trait are
5//! unchanged — they see one logical store.
6//!
7//! ## Routing vs. fan-out
8//!
9//! The shard key is `trace_id` (the design's choice — it keeps `get_trace` and
10//! `correlate` local). Operations split two ways:
11//!
12//! - **Routed** to the single owning shard via `hash(trace_id) % N`: comment
13//!   reads/writes and the write path (`insert_*`). A batch is grouped by shard
14//!   and each group dispatched to its owner — the routing layer, in code.
15//! - **Fanned out** to all shards then merged: the core reads. `query_*`
16//!   concatenate and re-limit (each shard already returns newest-first, so a
17//!   sort+truncate is a k-way merge); `list_services`/`query_summary`/
18//!   `query_anomalies` re-aggregate, because counts sum but rates/averages do
19//!   not. `get_trace`/`correlate` fan out for correctness under routing-hash
20//!   skew or rebalancing windows (design Open Q #1), short-circuiting once the
21//!   owning shard answers.
22//!
23//! `query_sql` is deliberately **not** distributed (design §3 recommendation
24//! (c)): arbitrary SQL over per-shard DuckDB projections can't be merged
25//! soundly, so it returns an error pointing at a single-node endpoint.
26//!
27//! ## Partial availability
28//!
29//! Fan-out reads tolerate a down shard: results from healthy shards are
30//! returned and the failure logged, matching the HA goal that losing one shard
31//! degrades rather than fails queries. A read errors only when *every* shard
32//! fails. (Routed ops have a single owner, so they surface that owner's error.)
33
34use std::collections::HashMap;
35use std::hash::{Hash, Hasher};
36use std::sync::Arc;
37
38use anyhow::{Result, bail};
39use serde_json::Value;
40
41use super::Store;
42use super::models::{
43    Anomaly, AnomalyReport, CorrelateReport, ErrorOperation, LogQuery, LogRecord, LogSummary,
44    MetricPoint, MetricQuery, MetricSummary, ServiceInfo, ServiceSummary, Span, SummaryReport,
45    TraceComment, TraceQuery, TraceSummary,
46};
47
48/// A [`Store`] that scatters reads across N shard stores and gathers/merges the
49/// results. See the module docs for routing vs. fan-out semantics.
50pub struct FanoutStore {
51    shards: Vec<Arc<dyn Store>>,
52}
53
/// Map a routing key to a shard index in `0..n`.
///
/// `DefaultHasher::new()` is not randomly seeded (unlike hashers built from
/// `RandomState`): every instance created with `new` hashes a given key the
/// same way. Within one build, the same key therefore lands on the same shard
/// across processes and restarts. That is what lets `get_trace` find the shard
/// the ingest path wrote to.
///
/// NOTE(review): the standard library documents `DefaultHasher`'s algorithm as
/// unspecified and not to be relied upon across Rust releases. Routing is thus
/// only stable between binaries built with the same toolchain, and a toolchain
/// upgrade may re-map keys. Consider a hash with a fixed, documented algorithm
/// if routing must survive upgrades (changing it also re-maps existing data).
///
/// `n` must be non-zero: the modulo panics otherwise. `FanoutStore::new`
/// guarantees at least one shard.
fn shard_index(key: &str, n: usize) -> usize {
    use std::collections::hash_map::DefaultHasher;

    let digest = {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        hasher.finish()
    };
    let bucket = digest % n as u64;
    bucket as usize
}
63
64impl FanoutStore {
65    /// Build a fan-out over the given shards. Requires at least one shard.
66    pub fn new(shards: Vec<Arc<dyn Store>>) -> Result<Self> {
67        if shards.is_empty() {
68            bail!("FanoutStore requires at least one shard");
69        }
70        Ok(Self { shards })
71    }
72
73    fn shard_for(&self, key: &str) -> &Arc<dyn Store> {
74        &self.shards[shard_index(key, self.shards.len())]
75    }
76
77    /// Run `f` against every shard, collecting successes. Logs and tolerates
78    /// per-shard failures; errors only if *all* shards fail.
79    fn fan_out<T>(&self, op: &str, f: impl Fn(&Arc<dyn Store>) -> Result<T>) -> Result<Vec<T>> {
80        let mut results = Vec::with_capacity(self.shards.len());
81        let mut last_err = None;
82        for (i, shard) in self.shards.iter().enumerate() {
83            match f(shard) {
84                Ok(v) => results.push(v),
85                Err(e) => {
86                    tracing::warn!(shard = i, op, error = %e, "shard failed; serving partial results");
87                    last_err = Some(e);
88                }
89            }
90        }
91        if results.is_empty()
92            && let Some(e) = last_err
93        {
94            return Err(e.context(format!("all {} shards failed for {op}", self.shards.len())));
95        }
96        Ok(results)
97    }
98}
99
100impl Store for FanoutStore {
101    // ── Spans / traces ──────────────────────────────────────────────
102    fn insert_spans(&self, spans: &[Span]) -> Result<()> {
103        // Route each span to its trace's owning shard. A trace's spans thus all
104        // land together, keeping get_trace/correlate single-shard.
105        route_and_insert(
106            &self.shards,
107            spans,
108            |s| &s.trace_id,
109            |store, batch| store.insert_spans(batch),
110        )
111    }
112
113    fn query_traces(&self, query: &TraceQuery) -> Result<Vec<Span>> {
114        let limit = query.limit.unwrap_or(100) as usize;
115        let mut all: Vec<Span> = self
116            .fan_out("query_traces", |s| s.query_traces(query))?
117            .into_iter()
118            .flatten()
119            .collect();
120        // Each shard returned newest-first; merge by re-sorting and re-limiting.
121        all.sort_by(|a, b| b.start_time.cmp(&a.start_time));
122        all.truncate(limit);
123        Ok(all)
124    }
125
126    fn get_trace(&self, trace_id: &str) -> Result<Vec<Span>> {
127        // A trace lives on its owning shard, but fan out (short-circuiting on a
128        // hit) so a routing-hash mismatch or a rebalancing window can't drop it.
129        let mut spans = self.shard_for(trace_id).get_trace(trace_id)?;
130        if spans.is_empty() {
131            for (i, shard) in self.shards.iter().enumerate() {
132                match shard.get_trace(trace_id) {
133                    Ok(s) if !s.is_empty() => {
134                        spans = s;
135                        break;
136                    }
137                    Ok(_) => {}
138                    Err(e) => tracing::warn!(shard = i, error = %e, "get_trace shard failed"),
139                }
140            }
141        }
142        // Dedup by span_id in case spans transiently overlap shards.
143        let mut seen = std::collections::HashSet::new();
144        spans.retain(|s| seen.insert(s.span_id.clone()));
145        spans.sort_by_key(|s| s.start_time);
146        Ok(spans)
147    }
148
149    fn list_services(&self) -> Result<Vec<ServiceInfo>> {
150        let per_shard = self.fan_out("list_services", |s| s.list_services())?;
151        Ok(merge_services(per_shard))
152    }
153
154    // ── Comments ── routed to the trace's owning shard ──────────────
155    fn add_comment(
156        &self,
157        trace_id: &str,
158        span_id: Option<&str>,
159        author: &str,
160        body: &str,
161    ) -> Result<TraceComment> {
162        self.shard_for(trace_id)
163            .add_comment(trace_id, span_id, author, body)
164    }
165
166    fn get_comments(&self, trace_id: &str) -> Result<Vec<TraceComment>> {
167        self.shard_for(trace_id).get_comments(trace_id)
168    }
169
170    // ── Logs ────────────────────────────────────────────────────────
171    fn insert_logs(&self, logs: &[LogRecord]) -> Result<()> {
172        // Logs carry trace_id and route the same way; orphan logs (no trace)
173        // shard by service so a service's logs stay co-located-ish.
174        route_and_insert(
175            &self.shards,
176            logs,
177            |l| l.trace_id.as_deref().unwrap_or(&l.service),
178            |store, batch| store.insert_logs(batch),
179        )
180    }
181
182    fn query_logs(&self, query: &LogQuery) -> Result<Vec<LogRecord>> {
183        let limit = query.limit.unwrap_or(100) as usize;
184        let mut all: Vec<LogRecord> = self
185            .fan_out("query_logs", |s| s.query_logs(query))?
186            .into_iter()
187            .flatten()
188            .collect();
189        all.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
190        all.truncate(limit);
191        Ok(all)
192    }
193
194    // ── Metrics ─────────────────────────────────────────────────────
195    fn insert_metrics(&self, metrics: &[MetricPoint]) -> Result<()> {
196        // Metrics carry no trace; shard by name (design §3) so a series stays on
197        // one shard and unique-name counts merge by simple sum.
198        route_and_insert(
199            &self.shards,
200            metrics,
201            |m| &m.name,
202            |store, batch| store.insert_metrics(batch),
203        )
204    }
205
206    fn query_metrics(&self, query: &MetricQuery) -> Result<Vec<MetricPoint>> {
207        let limit = query.limit.unwrap_or(500) as usize;
208        let mut all: Vec<MetricPoint> = self
209            .fan_out("query_metrics", |s| s.query_metrics(query))?
210            .into_iter()
211            .flatten()
212            .collect();
213        all.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
214        all.truncate(limit);
215        Ok(all)
216    }
217
218    // ── Cross-signal analytics ──────────────────────────────────────
219    fn query_summary(&self, last_seconds: i64, service: Option<&str>) -> Result<SummaryReport> {
220        let per_shard =
221            self.fan_out("query_summary", |s| s.query_summary(last_seconds, service))?;
222        Ok(merge_summaries(per_shard, last_seconds, service))
223    }
224
225    fn query_anomalies(
226        &self,
227        current_seconds: i64,
228        baseline_seconds: i64,
229        service: Option<&str>,
230    ) -> Result<AnomalyReport> {
231        let per_shard = self.fan_out("query_anomalies", |s| {
232            s.query_anomalies(current_seconds, baseline_seconds, service)
233        })?;
234        Ok(merge_anomalies(
235            per_shard,
236            current_seconds,
237            baseline_seconds,
238            service,
239        ))
240    }
241
242    fn query_correlate(&self, trace_id: &str) -> Result<Option<CorrelateReport>> {
243        // The trace is single-shard; try its owner, then fall back to a fan-out.
244        if let Some(r) = self.shard_for(trace_id).query_correlate(trace_id)? {
245            return Ok(Some(r));
246        }
247        for (i, shard) in self.shards.iter().enumerate() {
248            match shard.query_correlate(trace_id) {
249                Ok(Some(r)) => return Ok(Some(r)),
250                Ok(None) => {}
251                Err(e) => tracing::warn!(shard = i, error = %e, "correlate shard failed"),
252            }
253        }
254        Ok(None)
255    }
256
257    fn query_sql(&self, _sql: &str) -> Result<Vec<Value>> {
258        // Arbitrary SQL over per-shard DuckDB projections doesn't distribute
259        // soundly (cross-shard GROUP BY/aggregates). Design §3 recommendation
260        // (c): keep it a single-node power tool.
261        bail!(
262            "query_sql is not distributed across shards; run it directly against a single shard's /api/v1/sql"
263        );
264    }
265
266    // ── Lifecycle ───────────────────────────────────────────────────
267    fn health(&self) -> Result<()> {
268        // Ready if at least one shard answers — the node can still serve partial
269        // results, consistent with fan-out read tolerance. Not-ready only when
270        // every shard is unreachable.
271        let mut healthy = 0usize;
272        for (i, shard) in self.shards.iter().enumerate() {
273            match shard.health() {
274                Ok(()) => healthy += 1,
275                Err(e) => tracing::warn!(shard = i, error = %e, "shard unhealthy"),
276            }
277        }
278        if healthy == 0 {
279            bail!("no healthy shards ({} total)", self.shards.len());
280        }
281        Ok(())
282    }
283
284    fn flush(&self) -> Result<()> {
285        for shard in &self.shards {
286            shard.flush()?;
287        }
288        Ok(())
289    }
290}
291
292/// Group `items` by the shard their key hashes to, then hand each group to its
293/// owning shard in a single call.
294fn route_and_insert<T: Clone>(
295    shards: &[Arc<dyn Store>],
296    items: &[T],
297    key: impl Fn(&T) -> &str,
298    insert: impl Fn(&Arc<dyn Store>, &[T]) -> Result<()>,
299) -> Result<()> {
300    if items.is_empty() {
301        return Ok(());
302    }
303    let n = shards.len();
304    let mut buckets: HashMap<usize, Vec<T>> = HashMap::new();
305    for item in items {
306        let idx = shard_index(key(item), n);
307        buckets.entry(idx).or_default().push(item.clone());
308    }
309    for (idx, batch) in buckets {
310        insert(&shards[idx], &batch)?;
311    }
312    Ok(())
313}
314
315/// Sum disjoint per-shard service rollups. `span_count`/`trace_count` add
316/// directly (trace_ids are disjoint across shards); `avg_duration_ms` and
317/// `error_rate` are recomputed as span-count-weighted means.
318fn merge_services(per_shard: Vec<Vec<ServiceInfo>>) -> Vec<ServiceInfo> {
319    struct Acc {
320        span_count: i64,
321        trace_count: i64,
322        dur_weighted: f64,
323        err_weighted: f64,
324    }
325    let mut acc: HashMap<String, Acc> = HashMap::new();
326    for shard in per_shard {
327        for s in shard {
328            let e = acc.entry(s.name).or_insert(Acc {
329                span_count: 0,
330                trace_count: 0,
331                dur_weighted: 0.0,
332                err_weighted: 0.0,
333            });
334            e.span_count += s.span_count;
335            e.trace_count += s.trace_count;
336            e.dur_weighted += s.avg_duration_ms * s.span_count as f64;
337            e.err_weighted += s.error_rate * s.span_count as f64;
338        }
339    }
340    let mut out: Vec<ServiceInfo> = acc
341        .into_iter()
342        .map(|(name, a)| {
343            let w = a.span_count.max(1) as f64;
344            ServiceInfo {
345                name,
346                span_count: a.span_count,
347                trace_count: a.trace_count,
348                avg_duration_ms: a.dur_weighted / w,
349                error_rate: a.err_weighted / w,
350            }
351        })
352        .collect();
353    out.sort_by(|a, b| b.span_count.cmp(&a.span_count));
354    out
355}
356
/// Weighted mean of `(value, weight)` pairs: `Σ value·weight / Σ weight`.
///
/// Returns 0.0 when the weights sum to zero (including empty input) instead of
/// dividing by zero.
fn weighted_mean(values: impl Iterator<Item = (f64, i64)>) -> f64 {
    let (weighted_sum, total_weight) = values.fold((0.0_f64, 0_i64), |(sum, weight), (v, w)| {
        (sum + v * w as f64, weight + w)
    });
    match total_weight {
        0 => 0.0,
        w => weighted_sum / w as f64,
    }
}
367
368/// Merge per-shard summaries. Counts sum; rates and averages recompute from
369/// component sums. Percentiles are span-count-weighted approximations — exact
370/// cross-shard percentiles need a mergeable sketch (t-digest), tracked as
371/// future work; they are an estimate, not a true global quantile.
372fn merge_summaries(
373    per_shard: Vec<SummaryReport>,
374    window_seconds: i64,
375    service: Option<&str>,
376) -> SummaryReport {
377    let mut traces = TraceSummary::default();
378    let mut logs = LogSummary::default();
379    let mut metrics = MetricSummary::default();
380
381    // Re-aggregate top_services and top_error_operations across shards.
382    struct SvcAcc {
383        span_count: i64,
384        err_weighted: f64,
385        p95_weighted: f64,
386    }
387    let mut svc: HashMap<String, SvcAcc> = HashMap::new();
388    let mut errops: HashMap<(String, String), i64> = HashMap::new();
389
390    // Collect (value, weight) pairs for the weighted trace fields.
391    let mut avg_pairs = Vec::new();
392    let mut p50_pairs = Vec::new();
393    let mut p95_pairs = Vec::new();
394    let mut p99_pairs = Vec::new();
395
396    for r in &per_shard {
397        let t = &r.traces;
398        traces.span_count += t.span_count;
399        traces.trace_count += t.trace_count;
400        traces.error_count += t.error_count;
401        traces.max_ms = traces.max_ms.max(t.max_ms);
402        avg_pairs.push((t.avg_ms, t.span_count));
403        p50_pairs.push((t.p50_ms, t.span_count));
404        p95_pairs.push((t.p95_ms, t.span_count));
405        p99_pairs.push((t.p99_ms, t.span_count));
406
407        logs.total += r.logs.total;
408        logs.error += r.logs.error;
409        logs.warn += r.logs.warn;
410        logs.info += r.logs.info;
411        logs.debug += r.logs.debug;
412
413        metrics.point_count += r.metrics.point_count;
414        // Metrics shard by name → names disjoint → unique counts sum.
415        metrics.unique_names += r.metrics.unique_names;
416
417        for s in &r.top_services {
418            let e = svc.entry(s.service.clone()).or_insert(SvcAcc {
419                span_count: 0,
420                err_weighted: 0.0,
421                p95_weighted: 0.0,
422            });
423            e.span_count += s.span_count;
424            e.err_weighted += s.error_rate * s.span_count as f64;
425            e.p95_weighted += s.p95_ms * s.span_count as f64;
426        }
427        for o in &r.top_error_operations {
428            *errops
429                .entry((o.service.clone(), o.operation.clone()))
430                .or_insert(0) += o.error_count;
431        }
432    }
433
434    traces.avg_ms = weighted_mean(avg_pairs.into_iter());
435    traces.p50_ms = weighted_mean(p50_pairs.into_iter());
436    traces.p95_ms = weighted_mean(p95_pairs.into_iter());
437    traces.p99_ms = weighted_mean(p99_pairs.into_iter());
438    traces.error_rate = if traces.span_count > 0 {
439        traces.error_count as f64 / traces.span_count as f64
440    } else {
441        0.0
442    };
443
444    let mut top_services: Vec<ServiceSummary> = svc
445        .into_iter()
446        .map(|(service, a)| {
447            let w = a.span_count.max(1) as f64;
448            ServiceSummary {
449                service,
450                span_count: a.span_count,
451                error_rate: a.err_weighted / w,
452                p95_ms: a.p95_weighted / w,
453            }
454        })
455        .collect();
456    top_services.sort_by(|a, b| b.span_count.cmp(&a.span_count));
457    top_services.truncate(10);
458
459    let mut top_error_operations: Vec<ErrorOperation> = errops
460        .into_iter()
461        .map(|((service, operation), error_count)| ErrorOperation {
462            service,
463            operation,
464            error_count,
465        })
466        .collect();
467    top_error_operations.sort_by(|a, b| b.error_count.cmp(&a.error_count));
468    top_error_operations.truncate(10);
469
470    SummaryReport {
471        window_seconds,
472        service_filter: service.map(str::to_string),
473        traces,
474        top_services,
475        top_error_operations,
476        logs,
477        metrics,
478    }
479}
480
481/// Merge per-shard anomaly reports. Because trace_id sharding spreads a single
482/// service's traffic across all shards, each shard sees only a slice, so these
483/// are merged best-effort: anomalies for the same (service, kind) are collapsed
484/// to the one with the largest |delta| (most significant signal). This is an
485/// approximation — a precise version would recompute current/baseline from raw
486/// per-shard partials.
487fn merge_anomalies(
488    per_shard: Vec<AnomalyReport>,
489    current_seconds: i64,
490    baseline_seconds: i64,
491    service: Option<&str>,
492) -> AnomalyReport {
493    let mut best: HashMap<(String, String), Anomaly> = HashMap::new();
494    for r in per_shard {
495        for a in r.anomalies {
496            let key = (a.service.clone(), a.kind.clone());
497            match best.get(&key) {
498                Some(existing) if existing.delta.abs() >= a.delta.abs() => {}
499                _ => {
500                    best.insert(key, a);
501                }
502            }
503        }
504    }
505    let mut anomalies: Vec<Anomaly> = best.into_values().collect();
506    anomalies.sort_by(|a, b| {
507        b.delta
508            .abs()
509            .partial_cmp(&a.delta.abs())
510            .unwrap_or(std::cmp::Ordering::Equal)
511    });
512    AnomalyReport {
513        current_seconds,
514        baseline_seconds,
515        service_filter: service.map(str::to_string),
516        anomalies,
517    }
518}
519
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::models::{SpanKind, SpanStatus};
    use chrono::{TimeZone, Utc};
    use std::collections::HashMap;
    use std::sync::Mutex;

    /// In-memory `Store` test double for exercising routing and merge logic
    /// with no network. Only the span methods keep real data; every other
    /// method is a stub returning an empty result or an "unused" error.
    #[derive(Default)]
    struct MockStore {
        spans: Mutex<Vec<Span>>,
    }

    impl MockStore {
        /// Copy of every span currently held by this mock shard.
        fn snapshot(&self) -> Vec<Span> {
            self.spans.lock().unwrap().clone()
        }
    }

    impl Store for MockStore {
        fn insert_spans(&self, spans: &[Span]) -> Result<()> {
            self.spans.lock().unwrap().extend(spans.iter().cloned());
            Ok(())
        }

        fn query_traces(&self, query: &TraceQuery) -> Result<Vec<Span>> {
            let cap = query.limit.unwrap_or(100) as usize;
            let mut rows = self.snapshot();
            // Mirror the newest-first contract real shards follow.
            rows.sort_by(|x, y| y.start_time.cmp(&x.start_time));
            rows.truncate(cap);
            Ok(rows)
        }

        fn get_trace(&self, trace_id: &str) -> Result<Vec<Span>> {
            let matching = self
                .snapshot()
                .into_iter()
                .filter(|s| s.trace_id == trace_id)
                .collect();
            Ok(matching)
        }

        fn list_services(&self) -> Result<Vec<ServiceInfo>> {
            // service -> (span count, summed duration, error count)
            let mut tally: HashMap<String, (i64, f64, i64)> = HashMap::new();
            for s in self.snapshot() {
                let (count, duration, errors) =
                    tally.entry(s.service.clone()).or_insert((0, 0.0, 0));
                *count += 1;
                *duration += s.duration_ms;
                if s.status == SpanStatus::Error {
                    *errors += 1;
                }
            }
            let infos = tally
                .into_iter()
                .map(|(name, (count, duration, errors))| {
                    let n = count as f64;
                    ServiceInfo {
                        name,
                        span_count: count,
                        trace_count: count,
                        avg_duration_ms: duration / n,
                        error_rate: errors as f64 / n,
                    }
                })
                .collect();
            Ok(infos)
        }

        fn add_comment(
            &self,
            _t: &str,
            _s: Option<&str>,
            _a: &str,
            _b: &str,
        ) -> Result<TraceComment> {
            bail!("unused")
        }

        fn get_comments(&self, _t: &str) -> Result<Vec<TraceComment>> {
            Ok(Vec::new())
        }

        fn insert_logs(&self, _l: &[LogRecord]) -> Result<()> {
            Ok(())
        }

        fn query_logs(&self, _q: &LogQuery) -> Result<Vec<LogRecord>> {
            Ok(Vec::new())
        }

        fn insert_metrics(&self, _m: &[MetricPoint]) -> Result<()> {
            Ok(())
        }

        fn query_metrics(&self, _q: &MetricQuery) -> Result<Vec<MetricPoint>> {
            Ok(Vec::new())
        }

        fn query_summary(&self, _l: i64, _s: Option<&str>) -> Result<SummaryReport> {
            bail!("unused")
        }

        fn query_anomalies(&self, _c: i64, _b: i64, _s: Option<&str>) -> Result<AnomalyReport> {
            bail!("unused")
        }

        fn query_correlate(&self, _t: &str) -> Result<Option<CorrelateReport>> {
            Ok(None)
        }

        fn query_sql(&self, _s: &str) -> Result<Vec<Value>> {
            bail!("unused")
        }
    }

    /// Build a 10 ms internal span starting (and ending) at `secs` since epoch.
    fn span(trace: &str, sid: &str, svc: &str, secs: i64, status: SpanStatus) -> Span {
        let at = Utc.timestamp_opt(secs, 0).unwrap();
        Span {
            trace_id: trace.into(),
            span_id: sid.into(),
            parent_span_id: None,
            service: svc.into(),
            operation: "op".into(),
            start_time: at,
            end_time: at,
            duration_ms: 10.0,
            status,
            attributes: HashMap::new(),
            events: Vec::new(),
            kind: SpanKind::Internal,
            llm: None,
        }
    }

    /// A `FanoutStore` over `n` fresh mocks, plus handles to inspect them.
    fn fanout(n: usize) -> (FanoutStore, Vec<Arc<MockStore>>) {
        let mocks: Vec<Arc<MockStore>> = std::iter::repeat_with(|| Arc::new(MockStore::default()))
            .take(n)
            .collect();
        let shards: Vec<Arc<dyn Store>> = mocks
            .iter()
            .map(|m| Arc::clone(m) as Arc<dyn Store>)
            .collect();
        let store = FanoutStore::new(shards).unwrap();
        (store, mocks)
    }

    #[test]
    fn new_rejects_empty_shards() {
        let result = FanoutStore::new(Vec::new());
        assert!(result.is_err());
    }

    #[test]
    fn shard_index_is_deterministic_and_bounded() {
        let keys = ["trace-abc", "trace-def", "svc-1", ""];
        for key in keys {
            let first = shard_index(key, 4);
            assert_eq!(
                first,
                shard_index(key, 4),
                "hashing must be stable across calls"
            );
            assert!(first < 4);
        }
    }

    #[test]
    fn insert_spans_routes_whole_trace_to_one_shard() {
        let (fo, mocks) = fanout(3);
        // Two traces, each with two spans, interleaved in one batch.
        let batch = [
            span("t1", "a", "api", 1, SpanStatus::Ok),
            span("t2", "b", "api", 2, SpanStatus::Ok),
            span("t1", "c", "db", 3, SpanStatus::Ok),
            span("t2", "d", "db", 4, SpanStatus::Ok),
        ];
        fo.insert_spans(&batch).unwrap();

        // Each trace must live on exactly one shard, with all its spans there.
        for tid in ["t1", "t2"] {
            let owners: Vec<usize> = (0..mocks.len())
                .filter(|&i| mocks[i].snapshot().iter().any(|s| s.trace_id == tid))
                .collect();
            assert_eq!(
                owners.len(),
                1,
                "trace {tid} split across shards: {owners:?}"
            );
            let on_owner = mocks[owners[0]]
                .snapshot()
                .iter()
                .filter(|s| s.trace_id == tid)
                .count();
            assert_eq!(on_owner, 2);
        }
    }

    #[test]
    fn query_traces_merges_newest_first_and_re_limits() {
        let (fo, _mocks) = fanout(3);
        // Six single-span traces at timestamps 0..=5.
        for ts in 0..6i64 {
            let trace = format!("t{ts}");
            fo.insert_spans(&[span(&trace, "s", "api", ts, SpanStatus::Ok)])
                .unwrap();
        }
        let query = TraceQuery {
            limit: Some(3),
            ..Default::default()
        };
        let got = fo.query_traces(&query).unwrap();
        assert_eq!(got.len(), 3, "must re-limit after gathering");
        // Global newest-first order: t5, t4, t3.
        let starts: Vec<i64> = got.iter().map(|s| s.start_time.timestamp()).collect();
        assert_eq!(starts, vec![5, 4, 3]);
    }

    #[test]
    fn get_trace_finds_trace_on_its_owning_shard() {
        let (fo, _mocks) = fanout(4);
        let batch = [
            span("trace-x", "1", "api", 1, SpanStatus::Ok),
            span("trace-x", "2", "db", 2, SpanStatus::Ok),
        ];
        fo.insert_spans(&batch).unwrap();

        let found = fo.get_trace("trace-x").unwrap();
        assert_eq!(found.len(), 2);
        assert!(found.iter().all(|s| s.trace_id == "trace-x"));

        let missing = fo.get_trace("nope").unwrap();
        assert!(missing.is_empty());
    }

    #[test]
    fn list_services_aggregates_across_shards() {
        let (fo, _mocks) = fanout(3);
        // Three "api" spans (one error) spread across shards by trace_id.
        let batch = [
            span("ta", "1", "api", 1, SpanStatus::Ok),
            span("tb", "2", "api", 2, SpanStatus::Ok),
            span("tc", "3", "api", 3, SpanStatus::Error),
        ];
        fo.insert_spans(&batch).unwrap();

        let services = fo.list_services().unwrap();
        let api = services.iter().find(|s| s.name == "api").unwrap();
        assert_eq!(api.span_count, 3, "counts must sum across shards");
        assert_eq!(api.trace_count, 3);
        assert!(
            (api.error_rate - 1.0 / 3.0).abs() < 1e-9,
            "error_rate recomputed from sums"
        );
        assert!((api.avg_duration_ms - 10.0).abs() < 1e-9);
    }

    #[test]
    fn query_sql_is_not_distributed() {
        let (fo, _mocks) = fanout(2);
        let outcome = fo.query_sql("SELECT 1");
        assert!(outcome.is_err());
    }

    #[test]
    fn merge_services_sums_counts_and_weights_rates() {
        let info = |span_count: i64, trace_count: i64, avg_duration_ms: f64, error_rate: f64| {
            ServiceInfo {
                name: "api".into(),
                span_count,
                trace_count,
                avg_duration_ms,
                error_rate,
            }
        };
        let merged = merge_services(vec![
            vec![info(10, 4, 20.0, 0.1)],
            vec![info(30, 6, 40.0, 0.5)],
        ]);
        let api = &merged[0];
        assert_eq!(api.span_count, 40);
        assert_eq!(api.trace_count, 10);
        // (20*10 + 40*30) / 40 = 35
        assert!((api.avg_duration_ms - 35.0).abs() < 1e-9);
        // (0.1*10 + 0.5*30) / 40 = 0.4
        assert!((api.error_rate - 0.4).abs() < 1e-9);
    }
}