tpex_api/server/
state.rs

1use std::{collections::HashMap, pin::pin};
2
3use tpex::Action;
4
5use super::{PriceSummary, tokens};
6
7use tokio::io::AsyncBufReadExt;
8
9
/// Write adapter that forwards every write to `base` and keeps a copy of exactly the
/// bytes `base` reported as written.
///
/// Trait bounds live on the `impl` blocks that need them rather than on the struct itself.
struct CachedFileView<Stream> {
    /// Stream the writes are forwarded to.
    base: Stream,
    /// Bytes accepted by `base` so far, in write order.
    cache: Vec<u8>,
}
14impl<Stream: tokio::io::AsyncWrite> CachedFileView<Stream> {
15    fn new(base: Stream) -> Self {
16        CachedFileView { base, cache: Vec::new() }
17    }
18    fn extract(self) -> Vec<u8> {
19        self.cache
20    }
21}
22impl<Stream: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for CachedFileView<Stream> {
23    fn poll_write(
24        mut self: std::pin::Pin<&mut Self>,
25        cx: &mut std::task::Context<'_>,
26        buf: &[u8],
27    ) -> std::task::Poll<Result<usize, std::io::Error>> {
28        let ret = pin!(&mut self.base).poll_write(cx, buf);
29        if let std::task::Poll::Ready(Ok(len)) = ret {
30            self.cache.extend_from_slice(&buf[..len]);
31        }
32        ret
33    }
34
35    fn poll_flush(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), std::io::Error>> {
36        pin!(&mut self.base).poll_flush(cx)
37    }
38
39    fn poll_shutdown(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), std::io::Error>> {
40        pin!(&mut self.base).poll_shutdown(cx)
41    }
42}
43
44pub(crate) struct TPExState<Stream: tokio::io::AsyncWrite> {
45    state: tpex::State,
46    file: Stream,
47    cache: Vec<String>,
48    price_history: HashMap<tpex::AssetId, Vec<PriceSummary>>
49}
50impl<Stream: tokio::io::AsyncSeek + tokio::io::AsyncWrite + tokio::io::AsyncRead + Unpin + tokio::io::AsyncBufRead> TPExState<Stream> {
51    pub async fn replay(file: Stream) -> Result<Self, tpex::Error> {
52        // This is the state we will call apply on repeatedly
53        //
54        // When we're done, we'll extract all the information and add in the file, which will now be positioned at the end
55        let mut tmp_state = TPExState { state: tpex::State::new(), file: tokio::io::sink(), cache: Default::default(), price_history: Default::default() };
56        let mut lines = file.lines();
57        while let Some(line) = lines.next_line().await.expect("Could not read next action") {
58            let wrapped_action: tpex::WrappedAction = serde_json::from_str(&line).expect("Could not parse state");
59            let id = tmp_state.apply(wrapped_action.action, wrapped_action.time).await?;
60            assert_eq!(id, wrapped_action.id, "Wrapped action had out-of-order id");
61        }
62        Ok(Self {
63            file: lines.into_inner(),
64            state: tmp_state.state,
65            cache: tmp_state.cache,
66            price_history: tmp_state.price_history
67        })
68    }
69}
70impl<Stream: tokio::io::AsyncWrite + Unpin> TPExState<Stream> {
71    #[allow(dead_code)]
72    pub fn new(file: Stream, cache: Vec<String>) -> Self {
73        TPExState { state: tpex::State::new(), file, cache, price_history: Default::default() }
74    }
75
76    pub async fn apply(&mut self, action: Action, time: chrono::DateTime<chrono::Utc>) -> Result<u64, tpex::Error> {
77        // Grab the information to price history before we consume the action and modify everything
78        let maybe_asset = match &action {
79            tpex::Action::BuyOrder { asset, .. } => Some(asset.clone()),
80            tpex::Action::SellOrder { asset, .. } => Some(asset.clone()),
81            tpex::Action::CancelOrder { target } => Some(self.state.get_order(*target).expect("Invalid order id").asset.clone()),
82            _ => None
83        };
84
85        let mut stream = CachedFileView::new(&mut self.file);
86        let ret = self.state.apply_with_time(action, time, &mut stream).await?;
87        // If the price has changed, log it
88        if let Some(asset) = maybe_asset {
89            let (new_buy, new_sell) = self.state.get_prices(&asset);
90            let new_elem = PriceSummary {
91                time,
92                best_buy: new_buy.keys().next_back().cloned(),
93                n_buy: new_buy.values().sum(),
94                best_sell: new_sell.keys().next().cloned(),
95                n_sell: new_sell.values().sum()
96            };
97            let target = self.price_history.entry(asset).or_default();
98            target.push(new_elem);
99        }
100        self.cache.push(String::from_utf8(stream.extract()).expect("Produced non-utf8 log line"));
101        Ok(ret)
102    }
103
104    pub fn cache(&self) -> &[String] {
105        &self.cache
106    }
107
108    pub fn state(&self) -> &tpex::State {
109        &self.state
110    }
111
112    pub fn price_history(&self) -> &HashMap<tpex::AssetId, Vec<PriceSummary>> {
113        &self.price_history
114    }
115    // async fn get_lines(&mut self) -> Vec<u8> {
116    //     // Keeping everything in the log file means we can't have different versions of the same data
117    //     self.file.rewind().await.expect("Could not rewind trade file.");
118    //     let mut buf = Vec::new();
119    //     // This will seek to the end again, so pos is the same before and after get_lines
120    //     self.file.read_to_end(&mut buf).await.expect("Could not re-read trade file.");
121    //     buf
122    // }
123}
124
125pub(crate) struct StateStruct<Stream: tokio::io::AsyncSeek + tokio::io::AsyncWrite + tokio::io::AsyncRead + Unpin> {
126    pub(crate) tpex: tokio::sync::RwLock<TPExState<Stream>>,
127    pub(crate) tokens: tokens::TokenHandler,
128    pub(crate) updated: tokio::sync::watch::Sender<u64>,
129}
/// Expands to the type `Arc<StateStruct<impl AsyncBufRead + AsyncWrite + AsyncSeek + Unpin +
/// Send + Sync + 'static>>` (tokio I/O traits).
///
/// Because it contains `impl Trait`, the expansion is only valid where `impl Trait` is
/// allowed, e.g. a function parameter type. All paths are fully qualified so the macro works
/// at any call site, even one that has not imported the tokio I/O traits itself.
#[macro_export]
macro_rules! state_type {
    () => {
        ::std::sync::Arc<
            $crate::server::state::StateStruct<
                impl ::tokio::io::AsyncBufRead
                    + ::tokio::io::AsyncWrite
                    + ::tokio::io::AsyncSeek
                    + ::std::marker::Unpin
                    + ::std::marker::Send
                    + ::std::marker::Sync
                    + 'static,
            >,
        >
    };
}