1use std::collections::HashMap;
23use std::sync::Arc;
24use std::sync::atomic::{AtomicU64, Ordering};
25use std::time::Duration;
26
27use async_trait::async_trait;
28use chrono::{DateTime, Utc};
29use rustrade_core::{ExchangeClient, MetricsSink, Order, OrderKind, Symbol};
30use rustrade_supervisor::{RestartPolicy, TradingService};
31use tokio::sync::RwLock;
32use tokio_util::sync::CancellationToken;
33
34#[derive(Debug, Clone, PartialEq)]
36pub struct TrackedOrder {
37 pub order_id: String,
39 pub symbol: Symbol,
41 pub placed_at: DateTime<Utc>,
43}
44
45#[derive(Clone, Default)]
50pub struct OrderTracker {
51 inner: Arc<RwLock<HashMap<String, TrackedOrder>>>,
52}
53
54impl OrderTracker {
55 pub fn new() -> Self {
57 Self::default()
58 }
59
60 pub(crate) async fn record(&self, order_id: String, order: &Order) {
65 if matches!(order.kind, OrderKind::Market) {
66 return;
67 }
68 self.inner.write().await.insert(
69 order_id.clone(),
70 TrackedOrder {
71 order_id,
72 symbol: order.symbol.clone(),
73 placed_at: Utc::now(),
74 },
75 );
76 }
77
78 pub(crate) async fn forget(&self, order_id: &str) {
80 self.inner.write().await.remove(order_id);
81 }
82
83 pub async fn len(&self) -> usize {
85 self.inner.read().await.len()
86 }
87
88 pub async fn is_empty(&self) -> bool {
90 self.inner.read().await.is_empty()
91 }
92
93 pub async fn snapshot(&self) -> Vec<TrackedOrder> {
95 self.inner.read().await.values().cloned().collect()
96 }
97}
98
99pub struct OrderReaperService {
106 exchange: Arc<dyn ExchangeClient>,
107 tracker: OrderTracker,
108 symbols: Vec<Symbol>,
109 ttl: Duration,
110 poll_cadence: Duration,
111 metrics: Arc<dyn MetricsSink>,
112 cancelled: AtomicU64,
113 reconciled: AtomicU64,
114 sweeps: AtomicU64,
115}
116
117impl OrderReaperService {
118 pub(crate) fn new(
119 exchange: Arc<dyn ExchangeClient>,
120 tracker: OrderTracker,
121 symbols: Vec<Symbol>,
122 ttl: Duration,
123 poll_cadence: Duration,
124 metrics: Arc<dyn MetricsSink>,
125 ) -> Self {
126 Self {
127 exchange,
128 tracker,
129 symbols,
130 ttl,
131 poll_cadence,
132 metrics,
133 cancelled: AtomicU64::new(0),
134 reconciled: AtomicU64::new(0),
135 sweeps: AtomicU64::new(0),
136 }
137 }
138
139 pub fn cancelled(&self) -> u64 {
141 self.cancelled.load(Ordering::Relaxed)
142 }
143 pub fn reconciled(&self) -> u64 {
146 self.reconciled.load(Ordering::Relaxed)
147 }
148 pub fn sweeps(&self) -> u64 {
150 self.sweeps.load(Ordering::Relaxed)
151 }
152
153 pub(crate) async fn sweep_once(&self) {
156 self.sweeps.fetch_add(1, Ordering::Relaxed);
157 let now = Utc::now();
158
159 for symbol in &self.symbols {
160 let open = match self.exchange.get_open_orders(symbol).await {
161 Ok(o) => o,
162 Err(e) => {
163 tracing::warn!(symbol = %symbol, error = %e, "get_open_orders failed; skipping sweep for symbol");
164 continue;
165 }
166 };
167 let live: HashMap<&str, &rustrade_core::OpenOrder> =
169 open.iter().map(|o| (o.order_id.as_str(), o)).collect();
170
171 let tracked: Vec<TrackedOrder> = self
174 .tracker
175 .snapshot()
176 .await
177 .into_iter()
178 .filter(|t| &t.symbol == symbol)
179 .collect();
180
181 for t in tracked {
182 match live.get(t.order_id.as_str()) {
183 None => {
184 self.tracker.forget(&t.order_id).await;
187 self.reconciled.fetch_add(1, Ordering::Relaxed);
188 tracing::debug!(symbol = %symbol, order_id = %t.order_id, "reconciled away (no longer open)");
189 }
190 Some(oo) => {
191 let age_from = oo.created_at.unwrap_or(t.placed_at);
195 let age = now.signed_duration_since(age_from);
196 if age.num_milliseconds().max(0) as u128 >= self.ttl.as_millis() {
197 match self.exchange.cancel_order(symbol, &t.order_id).await {
198 Ok(_) => {
199 self.tracker.forget(&t.order_id).await;
200 self.cancelled.fetch_add(1, Ordering::Relaxed);
201 self.metrics.counter(
202 "rustrade_orders_cancelled_ttl_total",
203 &[("symbol", symbol.as_str())],
204 1,
205 );
206 tracing::info!(symbol = %symbol, order_id = %t.order_id, ttl_secs = self.ttl.as_secs(), "cancelled stale resting order (TTL)");
207 }
208 Err(e) => {
209 tracing::warn!(symbol = %symbol, order_id = %t.order_id, error = %e, "TTL cancel failed; will retry next sweep")
210 }
211 }
212 }
213 }
214 }
215 }
216 }
217 }
218}
219
220#[async_trait]
221impl TradingService for OrderReaperService {
222 fn name(&self) -> &str {
223 "order-reaper"
224 }
225
226 fn restart_policy(&self) -> RestartPolicy {
227 RestartPolicy::OnFailure
228 }
229
230 async fn run(&self, cancel: CancellationToken) -> anyhow::Result<()> {
231 tracing::info!(
232 ttl_secs = self.ttl.as_secs(),
233 cadence_secs = self.poll_cadence.as_secs(),
234 symbols = self.symbols.len(),
235 "order-reaper starting"
236 );
237 loop {
238 tokio::select! {
239 _ = cancel.cancelled() => {
240 tracing::info!(
241 sweeps = self.sweeps(),
242 cancelled = self.cancelled(),
243 reconciled = self.reconciled(),
244 "order-reaper shutting down"
245 );
246 return Ok(());
247 }
248 _ = tokio::time::sleep(self.poll_cadence) => {
249 self.sweep_once().await;
250 }
251 }
252 }
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use rustrade_core::{Capability, NoopSink, Position, Price, Result, Side, Volume};

    // Test helper: a limit `Order` on `symbol` built with Side::Buy,
    // Volume(1.0) and Price(100.0).
    fn limit(symbol: &str) -> Order {
        Order::limit(symbol, Side::Buy, Volume(1.0), Price(100.0))
    }

    // `record` must skip market orders but keep limit orders.
    #[tokio::test]
    async fn tracker_ignores_market_orders() {
        let t = OrderTracker::new();
        t.record(
            "m1".into(),
            &Order::market("BTCUSDT", Side::Buy, Volume(1.0)),
        )
        .await;
        assert!(t.is_empty().await, "market orders must not be tracked");

        t.record("l1".into(), &limit("BTCUSDT")).await;
        assert_eq!(t.len().await, 1);
    }

    // `forget` removes a previously recorded order.
    #[tokio::test]
    async fn tracker_forget_removes() {
        let t = OrderTracker::new();
        t.record("l1".into(), &limit("BTCUSDT")).await;
        t.forget("l1").await;
        assert!(t.is_empty().await);
    }

    // In-memory `ExchangeClient` stub. `get_open_orders` returns the same
    // fixed list regardless of symbol; `cancel_order` records every id it is
    // asked to cancel and always returns Ok(true). Other methods return
    // placeholder values.
    struct MockEx {
        open: std::sync::Mutex<Vec<rustrade_core::OpenOrder>>,
        cancels: std::sync::Mutex<Vec<String>>,
    }
    impl MockEx {
        fn new(open: Vec<rustrade_core::OpenOrder>) -> Arc<Self> {
            Arc::new(Self {
                open: std::sync::Mutex::new(open),
                cancels: std::sync::Mutex::new(Vec::new()),
            })
        }
    }
    #[async_trait]
    impl ExchangeClient for MockEx {
        fn name(&self) -> &str {
            "mock"
        }
        async fn place_order(&self, _o: &Order) -> Result<String> {
            Ok("x".into())
        }
        async fn cancel_all(&self, _s: &Symbol) -> Result<usize> {
            Ok(0)
        }
        async fn close_position(&self, _s: &Symbol, _p: &Position) -> Result<String> {
            Ok("c".into())
        }
        async fn get_position(&self, _s: &Symbol) -> Result<Position> {
            Ok(Position::FLAT)
        }
        async fn get_balance(&self, _c: &str) -> Result<f64> {
            Ok(0.0)
        }
        fn supports(&self, c: Capability) -> bool {
            matches!(c, Capability::OrderTracking)
        }
        async fn get_open_orders(&self, _s: &Symbol) -> Result<Vec<rustrade_core::OpenOrder>> {
            Ok(self.open.lock().unwrap().clone())
        }
        async fn cancel_order(&self, _s: &Symbol, order_id: &str) -> Result<bool> {
            self.cancels.lock().unwrap().push(order_id.to_string());
            Ok(true)
        }
    }

    // Builds an open, unfilled BTCUSDT limit buy with the given id and
    // optional exchange-side creation time.
    fn open_order(id: &str, created_at: Option<DateTime<Utc>>) -> rustrade_core::OpenOrder {
        rustrade_core::OpenOrder {
            order_id: id.into(),
            client_id: None,
            symbol: Symbol::from("BTCUSDT"),
            side: Side::Buy,
            kind: OrderKind::Limit,
            limit_price: Some(Price(100.0)),
            size: Volume(1.0),
            filled: Volume(0.0),
            status: rustrade_core::OrderStatus::Open,
            created_at,
        }
    }

    // Reaper over BTCUSDT only, with a 60s cadence (unused here because the
    // tests call `sweep_once` directly) and a no-op metrics sink.
    fn reaper(ex: Arc<MockEx>, tracker: OrderTracker, ttl: Duration) -> OrderReaperService {
        OrderReaperService::new(
            ex,
            tracker,
            vec![Symbol::from("BTCUSDT")],
            ttl,
            Duration::from_secs(60),
            Arc::new(NoopSink),
        )
    }

    // Tracked order absent from the exchange: forgotten and counted as
    // reconciled, with no cancel issued.
    #[tokio::test]
    async fn sweep_reconciles_away_vanished_order() {
        let tracker = OrderTracker::new();
        tracker.record("gone".into(), &limit("BTCUSDT")).await;
        let ex = MockEx::new(vec![]);
        let svc = reaper(ex.clone(), tracker.clone(), Duration::from_secs(3600));

        svc.sweep_once().await;
        assert!(
            tracker.is_empty().await,
            "vanished order should be reconciled away"
        );
        assert_eq!(svc.reconciled(), 1);
        assert_eq!(svc.cancelled(), 0);
        assert!(ex.cancels.lock().unwrap().is_empty());
    }

    // Order still open and younger than the 1h TTL: left alone.
    #[tokio::test]
    async fn sweep_keeps_fresh_resting_order() {
        let tracker = OrderTracker::new();
        tracker.record("fresh".into(), &limit("BTCUSDT")).await;
        let ex = MockEx::new(vec![open_order("fresh", Some(Utc::now()))]);
        let svc = reaper(ex.clone(), tracker.clone(), Duration::from_secs(3600));

        svc.sweep_once().await;
        assert_eq!(tracker.len().await, 1, "fresh order should remain tracked");
        assert_eq!(svc.cancelled(), 0);
        assert!(ex.cancels.lock().unwrap().is_empty());
    }

    // Order created an hour ago with a 1s TTL: cancelled exactly once, then
    // forgotten.
    #[tokio::test]
    async fn sweep_cancels_order_past_ttl() {
        let tracker = OrderTracker::new();
        tracker.record("stale".into(), &limit("BTCUSDT")).await;
        let created = Utc::now() - chrono::Duration::hours(1);
        let ex = MockEx::new(vec![open_order("stale", Some(created))]);
        let svc = reaper(ex.clone(), tracker.clone(), Duration::from_secs(1));

        svc.sweep_once().await;
        assert_eq!(svc.cancelled(), 1, "stale order should be cancelled");
        assert!(
            tracker.is_empty().await,
            "cancelled order should be forgotten"
        );
        assert_eq!(
            ex.cancels.lock().unwrap().as_slice(),
            &["stale".to_string()]
        );
    }
}