Skip to main content

rustrade/
services.rs

//! Optional framework-side services wired in via builder methods on
//! [`Bot`](crate::Bot):
//!
//! - [`MarketFeedService`] — `Bot::with_market_source(...)`. Drives a
//!   [`MarketSource`] under supervisor control; the source publishes
//!   events to the in-process `MarketDataBus` (the bus reference is the
//!   source implementor's responsibility — typically obtained via
//!   `bot.market_data_bus().clone()` before construction).
//! - [`FillRoutingService`] — `Bot::with_fill_source(...)`. Polls a
//!   [`FillSource`], calls [`Brain::on_fill`] on each brain, refreshes
//!   the per-symbol position cache from the exchange, and auto-feeds
//!   realised PnL into the risk gates using weighted-average entry
//!   accounting.
//! - [`CandlePollerService`] — `Bot::with_candle_poller(...)`. Periodic
//!   poll of a [`CandleSource`]; publishes every candle newer than the
//!   last one it published for its `(symbol, interval)` pair to the
//!   market-data bus.
17
18use std::sync::Arc;
19use std::sync::atomic::{AtomicU64, Ordering};
20use std::time::Duration;
21
22use async_trait::async_trait;
23use rustrade_core::{
24    Brain, CandleSource, Exchange, ExchangeClient, Fill, FillSource, MarketDataBus,
25    MarketDataEvent, MarketSource, MetricsSink, Side, Symbol,
26};
27use rustrade_supervisor::{RestartPolicy, TradingService};
28use tokio_util::sync::CancellationToken;
29
30use crate::risk_state::{PositionCache, RiskPersister, RiskStateMap};
31
32// ───────────────────────────────────────────────────────────────────────
33// MarketFeedService
34// ───────────────────────────────────────────────────────────────────────
35
/// Drives a [`MarketSource`] under supervisor control.
///
/// The wrapper does not interact with the bus directly — the source's
/// `run` method is expected to publish events to whatever bus it was
/// constructed with. This service just makes the source restartable
/// (its restart policy is `RestartPolicy::OnFailure`) and responsive to
/// the supervisor's cancellation token: when the token fires, the
/// in-flight `run` future is dropped and the service returns `Ok(())`.
pub struct MarketFeedService {
    /// Supervisor-facing service name, built once at construction as
    /// `market-feed[<source name>]`.
    name: String,
    /// The wrapped source whose `run` future this service drives.
    source: Arc<dyn MarketSource>,
}
46
47impl MarketFeedService {
48    /// Wrap a [`MarketSource`] into a [`TradingService`].
49    pub fn new(source: Arc<dyn MarketSource>) -> Self {
50        let name = format!("market-feed[{}]", source.name());
51        Self { name, source }
52    }
53}
54
55#[async_trait]
56impl TradingService for MarketFeedService {
57    fn name(&self) -> &str {
58        &self.name
59    }
60
61    fn restart_policy(&self) -> RestartPolicy {
62        RestartPolicy::OnFailure
63    }
64
65    async fn run(&self, cancel: CancellationToken) -> anyhow::Result<()> {
66        tracing::info!(service = %self.name, "market feed starting");
67        tokio::select! {
68            _ = cancel.cancelled() => {
69                tracing::info!(service = %self.name, "market feed cancelled");
70                Ok(())
71            }
72            r = self.source.run() => {
73                match &r {
74                    Ok(()) => tracing::info!(service = %self.name, "market feed exited cleanly"),
75                    Err(e) => tracing::warn!(service = %self.name, error = %e, "market feed exited with error"),
76                }
77                r.map_err(|e| anyhow::anyhow!("market source error: {e}"))
78            }
79        }
80    }
81}
82
83// ───────────────────────────────────────────────────────────────────────
84// FillRoutingService
85// ───────────────────────────────────────────────────────────────────────
86
/// Routes fills from a [`FillSource`] to every brain, refreshes the
/// position cache, and auto-feeds realised PnL into the risk state.
///
/// # PnL accounting
///
/// The service uses a **weighted-average entry** model (the same model
/// the backtest engine uses). It reads the cached `Position` *before*
/// refreshing it from the exchange, so the `entry_price` available is
/// the pre-fill average. From that:
///
/// - A fill in the same direction as the open position **adds** to it.
///   No realised PnL emitted; the post-refresh average from
///   `exchange.get_position` becomes the new entry.
/// - A fill in the opposite direction **reduces** the position. Gross
///   PnL = `(fill_price - entry) * closed_qty * direction`. The
///   service records the closed portion directly on the symbol's
///   `session_pnl` and `circuit_breaker` risk state.
/// - A fill that **flips** the position emits realised PnL for the
///   closed portion only; the opening leg is left for the next
///   reducing fill.
///
/// Fees come from `Fill.fee`, apportioned pro-rata to the closed
/// fraction of the fill. Hosts that need a different accounting
/// model (FIFO, LIFO, tax-lot) should compute PnL themselves and call
/// `BotHandle::record_trade_outcome` directly — but cannot also wire a
/// `FillRoutingService`, since the two would double-count.
pub struct FillRoutingService {
    /// Stream of fills; `run` awaits `next_fill()` on it in a loop.
    source: Arc<dyn FillSource>,
    /// Every brain here receives `on_fill` for every fill.
    brains: Arc<Vec<Arc<dyn Brain>>>,
    /// Queried with `get_position` after each fill to refresh the cache.
    exchange: Arc<dyn ExchangeClient>,
    /// Per-symbol position cache: read before the fill is processed
    /// (pre-fill snapshot) and overwritten after the exchange refresh.
    positions: PositionCache,
    /// Per-symbol risk state that receives realised PnL and win/loss
    /// outcomes.
    risk: RiskStateMap,
    /// Sink for the service's counters and the realised-PnL histogram.
    metrics: Arc<dyn MetricsSink>,
    /// When present, the symbol's risk state is persisted after each
    /// recorded PnL closure.
    persister: Option<RiskPersister>,
    /// Number of fills fully processed by `run`.
    fills_routed: AtomicU64,
    /// Number of failed `exchange.get_position` refreshes.
    refresh_errors: AtomicU64,
    /// Number of realised-PnL closures written to the risk state.
    trades_recorded: AtomicU64,
}
124
125impl FillRoutingService {
126    pub(crate) fn new(
127        source: Arc<dyn FillSource>,
128        brains: Arc<Vec<Arc<dyn Brain>>>,
129        exchange: Arc<dyn ExchangeClient>,
130        positions: PositionCache,
131        risk: RiskStateMap,
132        metrics: Arc<dyn MetricsSink>,
133        persister: Option<RiskPersister>,
134    ) -> Self {
135        Self {
136            source,
137            brains,
138            exchange,
139            positions,
140            risk,
141            metrics,
142            persister,
143            fills_routed: AtomicU64::new(0),
144            refresh_errors: AtomicU64::new(0),
145            trades_recorded: AtomicU64::new(0),
146        }
147    }
148
149    /// Total fills delivered to brains since service start.
150    pub fn fills_routed(&self) -> u64 {
151        self.fills_routed.load(Ordering::Relaxed)
152    }
153
154    /// Total `exchange.get_position` failures during cache refresh.
155    pub fn refresh_errors(&self) -> u64 {
156        self.refresh_errors.load(Ordering::Relaxed)
157    }
158
159    /// Total realised-PnL closures fed into the risk state.
160    pub fn trades_recorded(&self) -> u64 {
161        self.trades_recorded.load(Ordering::Relaxed)
162    }
163
164    /// Compute realised PnL from a reducing fill and feed the risk state.
165    /// Returns the gross PnL portion attributable to this fill.
166    async fn maybe_record_pnl(&self, fill: &Fill, prior_qty: f64, prior_entry: Option<f64>) {
167        // Only reducing or flipping fills produce realised PnL.
168        let signed_fill_qty = match fill.side {
169            Side::Buy => fill.size.value(),
170            Side::Sell => -fill.size.value(),
171        };
172        if prior_qty == 0.0 || prior_qty.signum() == signed_fill_qty.signum() {
173            return;
174        }
175        let Some(entry) = prior_entry else {
176            // Reducing fill but no entry price recorded — can't compute
177            // PnL. Log and skip.
178            tracing::debug!(
179                symbol = %fill.symbol,
180                "reducing fill but cached position has no entry price; skipping auto-PnL"
181            );
182            return;
183        };
184        let closed_qty = prior_qty.abs().min(fill.size.value());
185        if closed_qty <= 0.0 {
186            return;
187        }
188        let direction = prior_qty.signum();
189        let gross = (fill.price.value() - entry) * direction * closed_qty;
190        // Apportion fee by closing fraction so a flip fill charges
191        // fees pro-rata to the closing portion.
192        let fee_share = if fill.size.value() > 0.0 {
193            fill.fee * (closed_qty / fill.size.value())
194        } else {
195            0.0
196        };
197
198        // Update the per-symbol risk state directly.
199        let recorded = {
200            let mut map = self.risk.write().await;
201            if let Some(risk) = map.get_mut(&fill.symbol) {
202                risk.session_pnl.record_close(gross, fee_share);
203                let net = gross - fee_share;
204                if net > 0.0 {
205                    risk.circuit_breaker.record_win();
206                } else if net < 0.0 {
207                    risk.circuit_breaker.record_loss();
208                }
209                self.trades_recorded.fetch_add(1, Ordering::Relaxed);
210                self.metrics.histogram(
211                    "rustrade_realised_pnl_quote",
212                    &[("symbol", fill.symbol.as_str())],
213                    net,
214                );
215                true
216            } else {
217                tracing::debug!(
218                    symbol = %fill.symbol,
219                    "auto-PnL: symbol not in risk-state map (was it configured?)"
220                );
221                false
222            }
223        };
224
225        // Persist the updated risk state (lock released) if a store is wired.
226        if recorded && let Some(persister) = &self.persister {
227            persister.persist_symbol(&self.risk, &fill.symbol).await;
228        }
229    }
230}
231
232#[async_trait]
233impl TradingService for FillRoutingService {
234    fn name(&self) -> &str {
235        "fill-routing"
236    }
237
238    fn restart_policy(&self) -> RestartPolicy {
239        RestartPolicy::OnFailure
240    }
241
242    async fn run(&self, cancel: CancellationToken) -> anyhow::Result<()> {
243        tracing::info!("fill-routing service starting");
244        loop {
245            tokio::select! {
246                _ = cancel.cancelled() => {
247                    tracing::info!(
248                        routed = self.fills_routed(),
249                        refresh_errors = self.refresh_errors(),
250                        trades_recorded = self.trades_recorded(),
251                        "fill-routing service shutting down"
252                    );
253                    return Ok(());
254                }
255                next = self.source.next_fill() => {
256                    let Some(fill) = next else {
257                        tracing::info!("fill source closed; exiting");
258                        return Ok(());
259                    };
260
261                    let symbol = fill.symbol.clone();
262
263                    // Snapshot the pre-fill position so we can compute
264                    // realised PnL before the exchange refreshes the
265                    // entry price.
266                    let (prior_qty, prior_entry) = {
267                        let map = self.positions.read().await;
268                        let p = map.get(&symbol).copied().unwrap_or(rustrade_core::Position::FLAT);
269                        (p.qty, p.entry_price)
270                    };
271
272                    // Route to every brain. Errors are logged but don't
273                    // stop the service — the brain's on_fill is
274                    // informational by contract.
275                    for brain in self.brains.iter() {
276                        if let Err(e) = brain.on_fill(&fill).await {
277                            tracing::warn!(
278                                brain = brain.name(),
279                                error = %e,
280                                "brain on_fill returned error"
281                            );
282                        }
283                    }
284
285                    self.maybe_record_pnl(&fill, prior_qty, prior_entry).await;
286
287                    // Refresh position cache from the exchange.
288                    match self.exchange.get_position(&symbol).await {
289                        Ok(p) => {
290                            self.positions.write().await.insert(symbol.clone(), p);
291                            tracing::debug!(symbol = %symbol, qty = p.qty, "refreshed position");
292                        }
293                        Err(e) => {
294                            self.refresh_errors.fetch_add(1, Ordering::Relaxed);
295                            self.metrics.inc("rustrade_position_refresh_errors_total");
296                            tracing::warn!(
297                                symbol = %symbol,
298                                error = %e,
299                                "failed to refresh position after fill"
300                            );
301                        }
302                    }
303
304                    self.fills_routed.fetch_add(1, Ordering::Relaxed);
305                    self.metrics.counter(
306                        "rustrade_fills_routed_total",
307                        &[("symbol", symbol.as_str())],
308                        1,
309                    );
310                }
311            }
312        }
313    }
314}
315
316// ───────────────────────────────────────────────────────────────────────
317// CandlePollerService
318// ───────────────────────────────────────────────────────────────────────
319
/// Periodic poll of a [`CandleSource`] for a single `(symbol, interval)`
/// pair. Publishes each candle newer than the last one it published to
/// the [`MarketDataBus`].
///
/// NOTE(review): the service publishes whatever the source returns; it
/// does not itself check that a candle is closed — confirm that sources
/// only return closed candles.
///
/// Per-symbol cadences are achieved by spawning multiple services —
/// `Bot::with_candle_poller(...)` accepts repeated calls and spawns one
/// service per registered tuple.
///
/// # Deduplication
///
/// The service tracks the highest `Candle::time` it has already
/// published; only candles with a strictly greater timestamp are
/// published. This is robust against exchanges that return overlapping
/// windows on consecutive polls.
pub struct CandlePollerService {
    /// Supervisor-facing name, `candle-poller[<symbol>@<interval secs>s]`.
    name: String,
    /// Source polled on every tick; its name is also used to build the
    /// `Exchange` tag on published events.
    source: Arc<dyn CandleSource>,
    /// Symbol requested from the source and stamped on published events.
    symbol: Symbol,
    /// Candle interval forwarded to `CandleSource::poll`.
    interval: Duration,
    /// Sleep between consecutive polls.
    poll_cadence: Duration,
    /// Forwarded to `CandleSource::poll` — presumably the maximum number
    /// of candles to return; confirm against the trait.
    limit: usize,
    /// Bus that receives `MarketDataEvent::Candle` events.
    bus: MarketDataBus,
    /// Sink for the published / poll-error counters.
    metrics: Arc<dyn MetricsSink>,
    /// Highest candle time published so far; starts at `i64::MIN` so the
    /// first poll publishes everything it receives.
    last_time: std::sync::Mutex<i64>,
    /// Number of polls that returned `Ok`.
    polled: AtomicU64,
    /// Number of polls that returned `Err`.
    poll_errors: AtomicU64,
    /// Number of candles published to the bus.
    published: AtomicU64,
}
348
349impl CandlePollerService {
350    pub(crate) fn new(
351        source: Arc<dyn CandleSource>,
352        symbol: Symbol,
353        interval: Duration,
354        poll_cadence: Duration,
355        limit: usize,
356        bus: MarketDataBus,
357        metrics: Arc<dyn MetricsSink>,
358    ) -> Self {
359        let name = format!("candle-poller[{}@{}s]", symbol.as_str(), interval.as_secs());
360        Self {
361            name,
362            source,
363            symbol,
364            interval,
365            poll_cadence,
366            limit,
367            bus,
368            metrics,
369            last_time: std::sync::Mutex::new(i64::MIN),
370            polled: AtomicU64::new(0),
371            poll_errors: AtomicU64::new(0),
372            published: AtomicU64::new(0),
373        }
374    }
375
376    /// Total successful polls.
377    pub fn polled(&self) -> u64 {
378        self.polled.load(Ordering::Relaxed)
379    }
380    /// Total failed polls.
381    pub fn poll_errors(&self) -> u64 {
382        self.poll_errors.load(Ordering::Relaxed)
383    }
384    /// Total candles published (deduplicated).
385    pub fn published(&self) -> u64 {
386        self.published.load(Ordering::Relaxed)
387    }
388}
389
390#[async_trait]
391impl TradingService for CandlePollerService {
392    fn name(&self) -> &str {
393        &self.name
394    }
395
396    fn restart_policy(&self) -> RestartPolicy {
397        RestartPolicy::OnFailure
398    }
399
400    async fn run(&self, cancel: CancellationToken) -> anyhow::Result<()> {
401        tracing::info!(service = %self.name, "candle poller starting");
402        let exchange = Exchange::from(self.source.name());
403
404        loop {
405            tokio::select! {
406                _ = cancel.cancelled() => {
407                    tracing::info!(
408                        service = %self.name,
409                        polled = self.polled(),
410                        published = self.published(),
411                        errors = self.poll_errors(),
412                        "candle poller shutting down"
413                    );
414                    return Ok(());
415                }
416                _ = tokio::time::sleep(self.poll_cadence) => {
417                    match self.source.poll(&self.symbol, self.interval, self.limit).await {
418                        Ok(candles) => {
419                            self.polled.fetch_add(1, Ordering::Relaxed);
420                            let mut last = self.last_time.lock().expect("last_time poisoned");
421                            let mut new_high = *last;
422                            for candle in candles {
423                                if candle.time <= *last {
424                                    continue;
425                                }
426                                new_high = new_high.max(candle.time);
427                                self.bus.publish(MarketDataEvent::Candle {
428                                    exchange: exchange.clone(),
429                                    symbol: self.symbol.clone(),
430                                    candle,
431                                });
432                                self.published.fetch_add(1, Ordering::Relaxed);
433                                self.metrics.counter(
434                                    "rustrade_candles_published_total",
435                                    &[("symbol", self.symbol.as_str())],
436                                    1,
437                                );
438                            }
439                            *last = new_high;
440                        }
441                        Err(e) => {
442                            self.poll_errors.fetch_add(1, Ordering::Relaxed);
443                            self.metrics.inc("rustrade_candle_poll_errors_total");
444                            tracing::warn!(
445                                service = %self.name,
446                                error = %e,
447                                "candle poll failed"
448                            );
449                        }
450                    }
451                }
452            }
453        }
454    }
455}