tpex_api/server/
mod.rs

1use std::{fmt::Debug, time::Duration};
2
3use crate::{shared::*, state_type};
4pub mod tokens;
5pub mod state;
6
7use axum::{extract::{ws::rejection::WebSocketUpgradeRejection, FromRequestParts}, response::IntoResponse, serve::Listener, Router};
8use tokio::io::{AsyncBufRead, AsyncSeek, AsyncWrite};
9use tokio_util::sync::CancellationToken;
10use tower_http::trace::TraceLayer;
11use tpex::StateSync;
12#[derive(clap::Parser)]
13pub struct Args {
14    pub trades: std::path::PathBuf,
15    pub db: String,
16    pub endpoint: String,
17}
18
19#[derive(Debug)]
20enum Error {
21    TPEx(tpex::Error),
22    UncontrolledUser,
23    TokenTooLowLevel,
24    TokenInvalid,
25    NotNextId{next_id: u64}
26}
27impl From<tpex::Error> for Error {
28    fn from(value: tpex::Error) -> Self {
29        Self::TPEx(value)
30    }
31}
32impl axum::response::IntoResponse for Error {
33    fn into_response(self) -> axum::response::Response {
34        let (code,err) = match self {
35            Self::TPEx(err) => (409, ErrorInfo{error:err.to_string()}),
36            Self::UncontrolledUser => (403, ErrorInfo{error:"This action would act on behalf of a different user.".to_owned()}),
37            Self::TokenTooLowLevel => (403, ErrorInfo{error:"This action requires a higher permission level".to_owned()}),
38            Self::NotNextId{next_id} => (409, ErrorInfo{error:format!("The requested ID was not the next, which is {next_id}")}),
39            Self::TokenInvalid => (409, ErrorInfo{error:"The given token does not exist".to_owned()})
40        };
41
42        let body = serde_json::to_vec(&err).expect("Unable to serialise error");
43
44        let body = axum::body::Body::from(body);
45
46        axum::response::Response::builder()
47        .status(code)
48        .header("Content-Type", "application/json")
49        .body(body)
50        .expect("Unable to create error response")
51    }
52}
53
54async fn state_patch(
55    axum::extract::State(state): axum::extract::State<state_type!()>,
56    token: TokenInfo,
57    axum_extra::extract::OptionalQuery(args): axum_extra::extract::OptionalQuery<StatePatchArgs>,
58    axum::extract::Json(action): axum::extract::Json<tpex::Action>
59) -> Result<axum::response::Json<u64>, Error> {
60    match token.level {
61        TokenLevel::ReadOnly => return Err(Error::TokenTooLowLevel),
62        TokenLevel::ProxyOne => {
63            let perms = state.tpex.read().await.state().perms(&action)?;
64            if perms.player != token.user {
65                return Err(Error::UncontrolledUser);
66            }
67        }
68        // Apply catches all banker perm mismatches, assuming that upstream has verified their action:
69        TokenLevel::ProxyAll => ()
70    }
71    let mut tpex_state = state.tpex.write().await;
72    // We have to do this *after* we lock, or else we could get out-of-order timing or with the time being out
73    //
74    // tbh we can get that anyway because we're using the system clock, but let's not make it worse, ok?
75    let now = chrono::Utc::now();
76    let id =
77        if let Some(expected_id) = args.and_then(|i| i.id) {
78            let next_id = tpex_state.state().get_next_id();
79            if next_id != expected_id {
80                return Err(Error::NotNextId{next_id});
81            }
82            let id = tpex_state.apply(action, now).await?;
83            assert_eq!(id, next_id, "Somehow got ID mismatch");
84            id
85        }
86        else {
87            tpex_state.apply(action, now).await?
88        };
89    // We patched, so update the id
90    //
91    // We use send_replace so that we don't have to worry about if anyone's listening
92    state.updated.send_replace(id);
93    // Finally, we can drop the lock
94    drop(tpex_state);
95    Ok(axum::Json(id))
96}
97
98struct OptionalWebSocket(pub Option<axum::extract::ws::WebSocketUpgrade>);
99impl<S : Send + Sync> FromRequestParts<S> for OptionalWebSocket {
100    #[doc = " If the extractor fails it\'ll use this \"rejection\" type. A rejection is"]
101    #[doc = " a kind of error that can be converted into a response."]
102    type Rejection = WebSocketUpgradeRejection;
103
104    async fn from_request_parts(parts: &mut axum::http::request::Parts,state: &S,) -> Result<Self,Self::Rejection> {
105        match axum::extract::ws::WebSocketUpgrade::from_request_parts(parts, state).await {
106            Ok(x) => Ok(Self(Some(x))),
107            Err(WebSocketUpgradeRejection::MethodNotGet(_)) |
108            Err(WebSocketUpgradeRejection::MethodNotConnect(_)) |
109            Err(WebSocketUpgradeRejection::InvalidConnectionHeader(_)) |
110            Err(WebSocketUpgradeRejection::InvalidUpgradeHeader(_)) => Ok(Self(None)),
111            Err(e) => Err(e)
112        }
113    }
114}
115
116async fn state_get(
117    axum::extract::State(state): axum::extract::State<state_type!()>,
118    // must extract token to auth
119    _token: TokenInfo,
120    axum_extra::extract::OptionalQuery(args): axum_extra::extract::OptionalQuery<StateGetArgs>,
121    OptionalWebSocket(upgrade): OptionalWebSocket
122) -> axum::response::Response {
123    let mut from = args.unwrap_or_default().from.unwrap_or(1);
124    if let Some(upgrade) = upgrade {
125        upgrade.on_upgrade(move |mut sock: axum::extract::ws::WebSocket| async move {
126            let mut subscription = state.updated.subscribe();
127            loop {
128                let should_ping = tokio::select! {
129                    new_actions = subscription.wait_for(|i| *i >= from) => {
130                        new_actions.expect("Failed to poll updated_recv");
131                        false
132                    },
133                    _timeout = tokio::time::sleep(Duration::from_secs(10)) => true
134                };
135                if should_ping {
136                    if sock.send(axum::extract::ws::Message::Ping(Default::default())).await.is_err() {
137                        break;
138                    }
139                    else {
140                        continue;
141                    }
142                }
143                let tpex_state_handle = state.tpex.read().await;
144                // It's better to clone these out than hold state
145                let res =
146                    tpex_state_handle.cache().iter()
147                    .skip((from as usize).saturating_sub(1))
148                    .map(Into::into)
149                    .map(axum::extract::ws::Message::Text)
150                    .collect::<Vec<_>>();
151                // rechecking the id prevents a race condition
152                from = tpex_state_handle.state().get_next_id();
153                // We have extracted all we need
154                drop(tpex_state_handle);
155                // Send it off
156                for i in res {
157                    if sock.send(i).await.is_err() {
158                        break;
159                    }
160                }
161            }
162        })
163    }
164    else {
165        let data =
166            state.tpex.read().await.cache().iter()
167            .skip(from as usize)
168            .fold(String::new(), |a, b| a + b);
169        let body = axum::body::Body::from(data);
170        axum::response::Response::builder()
171        .header("Content-Type", "text/plain")
172        .body(body)
173        .expect("Unable to create state_get response")
174    }
175}
176
177async fn token_get(
178    axum::extract::State(_state): axum::extract::State<state_type!()>,
179    token: TokenInfo
180) -> axum::Json<TokenInfo> {
181    axum::Json(token)
182}
183
184async fn token_post(
185    axum::extract::State(state): axum::extract::State<state_type!()>,
186    token: TokenInfo,
187    axum::extract::Json(args): axum::extract::Json<TokenPostArgs>,
188) -> Result<axum::Json<Token>, Error> {
189    if args.level > token.level {
190        return Err(Error::TokenTooLowLevel)
191    }
192    if args.user != token.user && token.level < TokenLevel::ProxyAll {
193        return Err(Error::UncontrolledUser)
194    }
195
196    Ok(axum::Json(state.tokens.create_token(args.level, args.user).await.expect("Cannot access DB")))
197}
198
199async fn token_delete(
200    axum::extract::State(state): axum::extract::State<state_type!()>,
201    token: TokenInfo,
202    axum::extract::Json(args): axum::extract::Json<TokenDeleteArgs>
203) -> Result<axum::Json<()>, Error> {
204    let target = args.token.unwrap_or(token.token);
205    // We only need perms to delete other tokens
206    if target != token.token && token.level < TokenLevel::ProxyOne {
207        return Err(Error::TokenTooLowLevel);
208    }
209    state.tokens.delete_token(&token.token).await
210    .map_or(Err(Error::TokenInvalid), |_| Ok(axum::Json(())))
211}
212
213async fn fastsync_get(
214    axum::extract::State(state): axum::extract::State<state_type!()>,
215    _token: TokenInfo,
216    OptionalWebSocket(upgrade): OptionalWebSocket
217) -> axum::response::Response {
218    if let Some(upgrade) = upgrade {
219        upgrade.on_upgrade(move |mut sock: axum::extract::ws::WebSocket| async move {
220            let mut subscription = state.updated.subscribe();
221            subscription.mark_changed();
222            loop {
223                let should_ping = tokio::select! {
224                    new_actions = subscription.changed() => {
225                        new_actions.expect("Failed to poll updated_recv");
226                        false
227                    },
228                    _timeout = tokio::time::sleep(Duration::from_secs(10)) => true
229                };
230                if should_ping {
231                    if sock.send(axum::extract::ws::Message::Ping(Default::default())).await.is_err() {
232                        break;
233                    }
234                    else {
235                        continue;
236                    }
237                }
238                let res = StateSync::from(state.tpex.read().await.state());
239                if sock.send(axum::extract::ws::Message::Text(serde_json::to_string(&res).expect("Could not serialise state sync").into())).await.is_err() {
240                    break;
241                }
242            }
243        })
244    }
245    else {
246        let res = StateSync::from(state.tpex.read().await.state());
247        axum::Json(res).into_response()
248    }
249}
250
251async fn inspect_balance_get(
252    axum::extract::State(state): axum::extract::State<state_type!()>,
253    _token: TokenInfo,
254    axum::extract::Query(args): axum::extract::Query<InspectBalanceGetArgs>
255) -> axum::response::Response {
256    axum::Json(state.tpex.read().await.state().get_bal(&args.player)).into_response()
257}
258
259async fn inspect_assets_get(
260    axum::extract::State(state): axum::extract::State<state_type!()>,
261    _token: TokenInfo,
262    axum::extract::Query(args): axum::extract::Query<InspectAssetsGetArgs>
263) -> axum::response::Response {
264    axum::Json(state.tpex.read().await.state().get_assets(&args.player)).into_response()
265}
266
267async fn inspect_audit_get(
268    axum::extract::State(state): axum::extract::State<state_type!()>,
269    _token: TokenInfo
270) -> axum::response::Response {
271    axum::Json(state.tpex.read().await.state().itemised_audit()).into_response()
272}
273
274async fn price_history_get(
275    axum::extract::State(state): axum::extract::State<state_type!()>,
276    _token: TokenInfo,
277    axum::extract::Query(args): axum::extract::Query<PriceHistoryArgs>
278) -> axum::response::Response {
279    axum::Json(state.tpex.read().await.price_history().get(&args.asset).unwrap_or(&Vec::default())).into_response()
280}
281
282pub async fn run_server<L: Listener>(
283    cancel: CancellationToken,
284    trade_log: impl AsyncWrite + AsyncBufRead + AsyncSeek + Unpin + Send + Sync + 'static,
285    token_handler: tokens::TokenHandler,
286    listener: L) where L::Addr : Debug
287{
288    let (updated, _) = tokio::sync::watch::channel(1);
289    let state = state::StateStruct {
290        tpex: tokio::sync::RwLock::new(state::TPExState::replay(trade_log).await.expect("Could not replace trades")),
291        tokens: token_handler,
292        updated,
293    };
294
295    let cors = tower_http::cors::CorsLayer::new()
296        .allow_headers(tower_http::cors::Any)
297        .allow_origin(tower_http::cors::Any)
298        .allow_methods(tower_http::cors::Any);
299
300
301    let app = Router::new()
302        .route("/state", axum::routing::get(state_get))
303        .route("/state", axum::routing::connect(state_get))
304        .route("/state", axum::routing::patch(state_patch))
305
306        .route("/token", axum::routing::get(token_get))
307        .route("/token", axum::routing::post(token_post))
308        .route("/token", axum::routing::delete(token_delete))
309
310        .route("/fastsync", axum::routing::get(fastsync_get))
311
312        .route("/inspect/balance", axum::routing::get(inspect_balance_get))
313        .route("/inspect/assets", axum::routing::get(inspect_assets_get))
314        .route("/inspect/audit", axum::routing::get(inspect_audit_get))
315
316        .route("/price/history", axum::routing::get(price_history_get))
317
318        .with_state(std::sync::Arc::new(state))
319
320        .layer(TraceLayer::new_for_http())
321
322        .route_layer(cors);
323
324    axum::serve(listener, app).with_graceful_shutdown(async move { cancel.cancelled().await }).await.expect("Failed to serve");
325}
326
327pub async fn run_server_with_args(args: Args, cancel: CancellationToken) {
328    run_server(
329        cancel,
330        tokio::io::BufStream::with_capacity(16<<20, 16<<20,
331            tokio::fs::File::options()
332            .read(true)
333            .write(true)
334            .truncate(false)
335            .create(true)
336            .open(args.trades).await.expect("Unable to open trade list")),
337        tokens::TokenHandler::new(&args.db).await.expect("Could not connect to DB"),
338        tokio::net::TcpListener::bind(args.endpoint).await.expect("Could not bind to endpoint")
339    ).await
340}