1use std::collections::HashMap;
23use std::sync::Arc;
24use std::sync::atomic::{AtomicU64, Ordering};
25use std::time::Duration;
26
27use async_trait::async_trait;
28use chrono::{DateTime, Utc};
29use rustrade_core::{ExchangeClient, MetricsSink, Order, OrderKind, Symbol};
30use rustrade_supervisor::{RestartPolicy, TradingService};
31use tokio::sync::RwLock;
32use tokio_util::sync::CancellationToken;
33
/// A resting order recorded by [`OrderTracker`].
#[derive(Debug, Clone, PartialEq)]
pub struct TrackedOrder {
    /// Order id; the reaper matches it against the ids in the exchange's
    /// open-order listing.
    pub order_id: String,
    /// Symbol the order was placed on.
    pub symbol: Symbol,
    /// Local wall-clock time (UTC) at which the order was recorded. The reaper
    /// measures age from this only when the exchange reports no creation time.
    pub placed_at: DateTime<Utc>,
}
44
45#[derive(Clone, Default)]
50pub struct OrderTracker {
51 inner: Arc<RwLock<HashMap<String, TrackedOrder>>>,
52}
53
54impl OrderTracker {
55 pub fn new() -> Self {
57 Self::default()
58 }
59
60 pub(crate) async fn record(&self, order_id: String, order: &Order) {
65 if matches!(order.kind, OrderKind::Market) {
66 return;
67 }
68 self.inner.write().await.insert(
69 order_id.clone(),
70 TrackedOrder {
71 order_id,
72 symbol: order.symbol.clone(),
73 placed_at: Utc::now(),
74 },
75 );
76 }
77
78 pub(crate) async fn forget(&self, order_id: &str) {
80 self.inner.write().await.remove(order_id);
81 }
82
83 pub async fn len(&self) -> usize {
85 self.inner.read().await.len()
86 }
87
88 pub async fn is_empty(&self) -> bool {
90 self.inner.read().await.is_empty()
91 }
92
93 pub async fn snapshot(&self) -> Vec<TrackedOrder> {
95 self.inner.read().await.values().cloned().collect()
96 }
97}
98
99#[derive(Clone, Default)]
110pub struct OcoRegistry {
111 inner: Arc<RwLock<HashMap<String, OcoEntry>>>,
112}
113
114#[derive(Clone)]
115struct OcoEntry {
116 sibling: String,
117 symbol: Symbol,
118}
119
120impl OcoRegistry {
121 pub fn new() -> Self {
123 Self::default()
124 }
125
126 pub(crate) async fn register(&self, symbol: Symbol, a: String, b: String) {
129 let mut map = self.inner.write().await;
130 map.insert(
131 a.clone(),
132 OcoEntry {
133 sibling: b.clone(),
134 symbol: symbol.clone(),
135 },
136 );
137 map.insert(b, OcoEntry { sibling: a, symbol });
138 }
139
140 pub(crate) async fn take_sibling(&self, order_id: &str) -> Option<(Symbol, String)> {
144 let mut map = self.inner.write().await;
145 let entry = map.remove(order_id)?;
146 map.remove(&entry.sibling);
148 Some((entry.symbol, entry.sibling))
149 }
150
151 pub async fn len(&self) -> usize {
153 self.inner.read().await.len()
154 }
155
156 pub async fn is_empty(&self) -> bool {
158 self.inner.read().await.is_empty()
159 }
160}
161
/// Service that periodically reconciles an [`OrderTracker`] against the
/// exchange's open orders and cancels tracked orders that have outlived a TTL.
pub struct OrderReaperService {
    /// Client used to list and cancel open orders.
    exchange: Arc<dyn ExchangeClient>,
    /// Orders this service reconciles and may cancel.
    tracker: OrderTracker,
    /// Symbols visited, in order, on every sweep.
    symbols: Vec<Symbol>,
    /// Age at or beyond which a still-open tracked order is cancelled.
    ttl: Duration,
    /// Sleep between consecutive sweeps when running as a service.
    poll_cadence: Duration,
    /// Receives the `rustrade_orders_cancelled_ttl_total` counter.
    metrics: Arc<dyn MetricsSink>,
    /// Count of orders cancelled for exceeding `ttl`.
    cancelled: AtomicU64,
    /// Count of tracked orders dropped because the exchange no longer listed them.
    reconciled: AtomicU64,
    /// Count of sweeps started.
    sweeps: AtomicU64,
}
179
180impl OrderReaperService {
181 pub(crate) fn new(
182 exchange: Arc<dyn ExchangeClient>,
183 tracker: OrderTracker,
184 symbols: Vec<Symbol>,
185 ttl: Duration,
186 poll_cadence: Duration,
187 metrics: Arc<dyn MetricsSink>,
188 ) -> Self {
189 Self {
190 exchange,
191 tracker,
192 symbols,
193 ttl,
194 poll_cadence,
195 metrics,
196 cancelled: AtomicU64::new(0),
197 reconciled: AtomicU64::new(0),
198 sweeps: AtomicU64::new(0),
199 }
200 }
201
202 pub fn cancelled(&self) -> u64 {
204 self.cancelled.load(Ordering::Relaxed)
205 }
206 pub fn reconciled(&self) -> u64 {
209 self.reconciled.load(Ordering::Relaxed)
210 }
211 pub fn sweeps(&self) -> u64 {
213 self.sweeps.load(Ordering::Relaxed)
214 }
215
216 pub(crate) async fn sweep_once(&self) {
219 self.sweeps.fetch_add(1, Ordering::Relaxed);
220 let now = Utc::now();
221
222 for symbol in &self.symbols {
223 let open = match self.exchange.get_open_orders(symbol).await {
224 Ok(o) => o,
225 Err(e) => {
226 tracing::warn!(symbol = %symbol, error = %e, "get_open_orders failed; skipping sweep for symbol");
227 continue;
228 }
229 };
230 let live: HashMap<&str, &rustrade_core::OpenOrder> =
232 open.iter().map(|o| (o.order_id.as_str(), o)).collect();
233
234 let tracked: Vec<TrackedOrder> = self
237 .tracker
238 .snapshot()
239 .await
240 .into_iter()
241 .filter(|t| &t.symbol == symbol)
242 .collect();
243
244 for t in tracked {
245 match live.get(t.order_id.as_str()) {
246 None => {
247 self.tracker.forget(&t.order_id).await;
250 self.reconciled.fetch_add(1, Ordering::Relaxed);
251 tracing::debug!(symbol = %symbol, order_id = %t.order_id, "reconciled away (no longer open)");
252 }
253 Some(oo) => {
254 let age_from = oo.created_at.unwrap_or(t.placed_at);
258 let age = now.signed_duration_since(age_from);
259 if age.num_milliseconds().max(0) as u128 >= self.ttl.as_millis() {
260 match self.exchange.cancel_order(symbol, &t.order_id).await {
261 Ok(_) => {
262 self.tracker.forget(&t.order_id).await;
263 self.cancelled.fetch_add(1, Ordering::Relaxed);
264 self.metrics.counter(
265 "rustrade_orders_cancelled_ttl_total",
266 &[("symbol", symbol.as_str())],
267 1,
268 );
269 tracing::info!(symbol = %symbol, order_id = %t.order_id, ttl_secs = self.ttl.as_secs(), "cancelled stale resting order (TTL)");
270 }
271 Err(e) => {
272 tracing::warn!(symbol = %symbol, order_id = %t.order_id, error = %e, "TTL cancel failed; will retry next sweep")
273 }
274 }
275 }
276 }
277 }
278 }
279 }
280 }
281}
282
283#[async_trait]
284impl TradingService for OrderReaperService {
285 fn name(&self) -> &str {
286 "order-reaper"
287 }
288
289 fn restart_policy(&self) -> RestartPolicy {
290 RestartPolicy::OnFailure
291 }
292
293 async fn run(&self, cancel: CancellationToken) -> anyhow::Result<()> {
294 tracing::info!(
295 ttl_secs = self.ttl.as_secs(),
296 cadence_secs = self.poll_cadence.as_secs(),
297 symbols = self.symbols.len(),
298 "order-reaper starting"
299 );
300 loop {
301 tokio::select! {
302 _ = cancel.cancelled() => {
303 tracing::info!(
304 sweeps = self.sweeps(),
305 cancelled = self.cancelled(),
306 reconciled = self.reconciled(),
307 "order-reaper shutting down"
308 );
309 return Ok(());
310 }
311 _ = tokio::time::sleep(self.poll_cadence) => {
312 self.sweep_once().await;
313 }
314 }
315 }
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use rustrade_core::{Capability, NoopSink, Position, Price, Result, Side, Volume};

    /// Buy limit order for 1.0 at 100.0 on `symbol`.
    fn limit(symbol: &str) -> Order {
        Order::limit(symbol, Side::Buy, Volume(1.0), Price(100.0))
    }

    #[tokio::test]
    async fn tracker_ignores_market_orders() {
        let t = OrderTracker::new();
        t.record(
            "m1".into(),
            &Order::market("BTCUSDT", Side::Buy, Volume(1.0)),
        )
        .await;
        assert!(t.is_empty().await, "market orders must not be tracked");

        t.record("l1".into(), &limit("BTCUSDT")).await;
        assert_eq!(t.len().await, 1);
    }

    #[tokio::test]
    async fn tracker_forget_removes() {
        let t = OrderTracker::new();
        t.record("l1".into(), &limit("BTCUSDT")).await;
        t.forget("l1").await;
        assert!(t.is_empty().await);
    }

    /// Exchange double: `get_open_orders` returns `open` for every symbol, and
    /// `cancel_order` records the id and reports `Ok(true)`.
    struct MockEx {
        open: std::sync::Mutex<Vec<rustrade_core::OpenOrder>>,
        cancels: std::sync::Mutex<Vec<String>>,
    }
    impl MockEx {
        fn new(open: Vec<rustrade_core::OpenOrder>) -> Arc<Self> {
            Arc::new(Self {
                open: std::sync::Mutex::new(open),
                cancels: std::sync::Mutex::new(Vec::new()),
            })
        }
    }
    #[async_trait]
    impl ExchangeClient for MockEx {
        fn name(&self) -> &str {
            "mock"
        }
        async fn place_order(&self, _o: &Order) -> Result<String> {
            Ok("x".into())
        }
        async fn cancel_all(&self, _s: &Symbol) -> Result<usize> {
            Ok(0)
        }
        async fn close_position(&self, _s: &Symbol, _p: &Position) -> Result<String> {
            Ok("c".into())
        }
        async fn get_position(&self, _s: &Symbol) -> Result<Position> {
            Ok(Position::FLAT)
        }
        async fn get_balance(&self, _c: &str) -> Result<f64> {
            Ok(0.0)
        }
        fn supports(&self, c: Capability) -> bool {
            matches!(c, Capability::OrderTracking)
        }
        async fn get_open_orders(&self, _s: &Symbol) -> Result<Vec<rustrade_core::OpenOrder>> {
            Ok(self.open.lock().unwrap().clone())
        }
        async fn cancel_order(&self, _s: &Symbol, order_id: &str) -> Result<bool> {
            self.cancels.lock().unwrap().push(order_id.to_string());
            Ok(true)
        }
    }

    /// Open, unfilled BTCUSDT buy limit with the given exchange creation time.
    fn open_order(id: &str, created_at: Option<DateTime<Utc>>) -> rustrade_core::OpenOrder {
        rustrade_core::OpenOrder {
            order_id: id.into(),
            client_id: None,
            symbol: Symbol::from("BTCUSDT"),
            side: Side::Buy,
            kind: OrderKind::Limit,
            limit_price: Some(Price(100.0)),
            size: Volume(1.0),
            filled: Volume(0.0),
            status: rustrade_core::OrderStatus::Open,
            created_at,
        }
    }

    /// Reaper over BTCUSDT only, with a no-op metrics sink.
    fn reaper(ex: Arc<MockEx>, tracker: OrderTracker, ttl: Duration) -> OrderReaperService {
        OrderReaperService::new(
            ex,
            tracker,
            vec![Symbol::from("BTCUSDT")],
            ttl,
            Duration::from_secs(60),
            Arc::new(NoopSink),
        )
    }

    #[tokio::test]
    async fn sweep_reconciles_away_vanished_order() {
        let tracker = OrderTracker::new();
        tracker.record("gone".into(), &limit("BTCUSDT")).await;
        let ex = MockEx::new(vec![]);
        let svc = reaper(ex.clone(), tracker.clone(), Duration::from_secs(3600));

        svc.sweep_once().await;
        assert!(
            tracker.is_empty().await,
            "vanished order should be reconciled away"
        );
        assert_eq!(svc.reconciled(), 1);
        assert_eq!(svc.cancelled(), 0);
        assert!(ex.cancels.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn sweep_keeps_fresh_resting_order() {
        let tracker = OrderTracker::new();
        tracker.record("fresh".into(), &limit("BTCUSDT")).await;
        let ex = MockEx::new(vec![open_order("fresh", Some(Utc::now()))]);
        let svc = reaper(ex.clone(), tracker.clone(), Duration::from_secs(3600));

        svc.sweep_once().await;
        assert_eq!(tracker.len().await, 1, "fresh order should remain tracked");
        assert_eq!(svc.cancelled(), 0);
        assert!(ex.cancels.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn sweep_cancels_order_past_ttl() {
        let tracker = OrderTracker::new();
        tracker.record("stale".into(), &limit("BTCUSDT")).await;
        let created = Utc::now() - chrono::Duration::hours(1);
        let ex = MockEx::new(vec![open_order("stale", Some(created))]);
        let svc = reaper(ex.clone(), tracker.clone(), Duration::from_secs(1));

        svc.sweep_once().await;
        assert_eq!(svc.cancelled(), 1, "stale order should be cancelled");
        assert!(
            tracker.is_empty().await,
            "cancelled order should be forgotten"
        );
        assert_eq!(
            ex.cancels.lock().unwrap().as_slice(),
            &["stale".to_string()]
        );
    }

    #[tokio::test]
    async fn sweep_falls_back_to_placed_at_without_exchange_timestamp() {
        let tracker = OrderTracker::new();
        tracker.record("no-ts".into(), &limit("BTCUSDT")).await;
        let ex = MockEx::new(vec![open_order("no-ts", None)]);

        // Just recorded and no exchange timestamp: a long TTL keeps it.
        let svc = reaper(ex.clone(), tracker.clone(), Duration::from_secs(3600));
        svc.sweep_once().await;
        assert_eq!(svc.cancelled(), 0);
        assert_eq!(tracker.len().await, 1);

        // A zero TTL expires it based on the local placement time alone.
        let svc = reaper(ex.clone(), tracker.clone(), Duration::from_secs(0));
        svc.sweep_once().await;
        assert_eq!(svc.cancelled(), 1);
        assert!(tracker.is_empty().await);
        assert_eq!(
            ex.cancels.lock().unwrap().as_slice(),
            &["no-ts".to_string()]
        );
    }

    #[tokio::test]
    async fn sweep_leaves_unswept_symbols_alone() {
        let tracker = OrderTracker::new();
        tracker.record("eth".into(), &limit("ETHUSDT")).await;
        let ex = MockEx::new(vec![]);
        let svc = reaper(ex.clone(), tracker.clone(), Duration::from_secs(1));

        svc.sweep_once().await;
        assert_eq!(svc.sweeps(), 1);
        assert_eq!(
            tracker.len().await,
            1,
            "orders on symbols the reaper does not sweep must stay tracked"
        );
        assert_eq!(svc.reconciled(), 0);
        assert_eq!(svc.cancelled(), 0);
    }

    #[tokio::test]
    async fn sweep_never_cancels_untracked_orders() {
        let tracker = OrderTracker::new();
        let created = Utc::now() - chrono::Duration::hours(1);
        let ex = MockEx::new(vec![open_order("foreign", Some(created))]);
        let svc = reaper(ex.clone(), tracker.clone(), Duration::from_secs(1));

        svc.sweep_once().await;
        assert_eq!(svc.cancelled(), 0);
        assert!(
            ex.cancels.lock().unwrap().is_empty(),
            "orders this process never recorded must not be touched"
        );
    }

    #[tokio::test]
    async fn oco_register_and_take_sibling_is_symmetric() {
        let oco = OcoRegistry::new();
        let sym = Symbol::from("BTCUSDT");
        oco.register(sym.clone(), "sl".into(), "tp".into()).await;
        assert_eq!(oco.len().await, 2);

        let sib = oco.take_sibling("sl").await;
        assert_eq!(sib, Some((sym, "tp".to_string())));
        assert!(oco.is_empty().await, "both legs cleared after one fills");

        assert!(oco.take_sibling("tp").await.is_none());
    }

    #[tokio::test]
    async fn oco_take_sibling_from_either_leg() {
        let oco = OcoRegistry::new();
        let sym = Symbol::from("ETHUSDT");
        oco.register(sym.clone(), "a".into(), "b".into()).await;
        assert_eq!(oco.take_sibling("b").await, Some((sym, "a".to_string())));
        assert!(oco.is_empty().await);
    }

    #[tokio::test]
    async fn oco_unknown_id_is_none() {
        let oco = OcoRegistry::new();
        assert!(oco.take_sibling("nope").await.is_none());
    }
}