tpex_api/server/
mod.rs

1use std::fmt::Debug;
2
3use crate::{shared::*, state_type};
4pub mod tokens;
5pub mod state;
6
7use axum::{extract::{ws::rejection::WebSocketUpgradeRejection, FromRequestParts}, response::IntoResponse, serve::Listener, Router};
8use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncSeek, AsyncSeekExt, AsyncWrite};
9use tokio_util::sync::CancellationToken;
10use tower_http::trace::TraceLayer;
11use tpex::{ActionLevel, StateSync};
12#[derive(clap::Parser)]
13pub struct Args {
14    pub trades: std::path::PathBuf,
15    pub db: String,
16    pub endpoint: String,
17}
18
19#[derive(Debug)]
20enum Error {
21    TPEx(tpex::Error),
22    UncontrolledUser,
23    TokenTooLowLevel,
24    TokenInvalid,
25    NotNextId{next_id: u64}
26}
27impl From<tpex::Error> for Error {
28    fn from(value: tpex::Error) -> Self {
29        Self::TPEx(value)
30    }
31}
32impl axum::response::IntoResponse for Error {
33    fn into_response(self) -> axum::response::Response {
34        let (code,err) = match self {
35            Self::TPEx(err) => (409, ErrorInfo{error:err.to_string()}),
36            Self::UncontrolledUser => (403, ErrorInfo{error:"This action would act on behalf of a different user.".to_owned()}),
37            Self::TokenTooLowLevel => (403, ErrorInfo{error:"This action requires a higher permission level".to_owned()}),
38            Self::NotNextId{next_id} => (409, ErrorInfo{error:format!("The requested ID was not the next, which is {next_id}")}),
39            Self::TokenInvalid => (409, ErrorInfo{error:"The given token does not exist".to_owned()})
40        };
41
42        let body = serde_json::to_vec(&err).expect("Unable to serialise error");
43
44        let body = axum::body::Body::from(body);
45
46        axum::response::Response::builder()
47        .status(code)
48        .header("Content-Type", "application/json")
49        .body(body)
50        .expect("Unable to create error response")
51    }
52}
53
54async fn state_patch(
55    axum::extract::State(state): axum::extract::State<state_type!()>,
56    token: TokenInfo,
57    axum_extra::extract::OptionalQuery(args): axum_extra::extract::OptionalQuery<StatePatchArgs>,
58    axum::extract::Json(action): axum::extract::Json<tpex::Action>
59) -> Result<axum::response::Json<u64>, Error> {
60    match token.level {
61        TokenLevel::ReadOnly => return Err(Error::TokenTooLowLevel),
62        TokenLevel::ProxyOne => {
63            let perms = state.tpex.read().await.state().perms(&action)?;
64            if perms.player != token.user {
65                return Err(Error::UncontrolledUser);
66            }
67            if perms.level > ActionLevel::Normal {
68                return Err(Error::TokenTooLowLevel);
69            }
70        }
71        // Apply catches all banker perm mismatches, assuming that upstream has verified their action:
72        TokenLevel::ProxyAll => ()
73    }
74    let mut tpex_state = state.tpex.write().await;
75    let id =
76        if let Some(expected_id) = args.and_then(|i| i.id) {
77            let next_id = tpex_state.state().get_next_id();
78            if next_id != expected_id {
79                return Err(Error::NotNextId{next_id});
80            }
81            let id = tpex_state.apply(action).await?;
82            assert_eq!(id, next_id, "Somehow got ID mismatch");
83            id
84        }
85        else {
86            tpex_state.apply(action).await?
87        };
88    // We patched, so update the id
89    //
90    // We use send_replace so that we don't have to worry about if anyone's listening
91    state.updated.send_replace(id);
92    Ok(axum::Json(id))
93}
94
95struct OptionalWebSocket(pub Option<axum::extract::ws::WebSocketUpgrade>);
96impl<S : Send + Sync> FromRequestParts<S> for OptionalWebSocket {
97    #[doc = " If the extractor fails it\'ll use this \"rejection\" type. A rejection is"]
98    #[doc = " a kind of error that can be converted into a response."]
99    type Rejection = WebSocketUpgradeRejection;
100
101    async fn from_request_parts(parts: &mut axum::http::request::Parts,state: &S,) -> Result<Self,Self::Rejection> {
102        match axum::extract::ws::WebSocketUpgrade::from_request_parts(parts, state).await {
103            Ok(x) => Ok(Self(Some(x))),
104            Err(WebSocketUpgradeRejection::MethodNotGet(_)) |
105            Err(WebSocketUpgradeRejection::MethodNotConnect(_)) |
106            Err(WebSocketUpgradeRejection::InvalidConnectionHeader(_)) |
107            Err(WebSocketUpgradeRejection::InvalidUpgradeHeader(_)) => Ok(Self(None)),
108            Err(e) => Err(e)
109        }
110    }
111}
112
113async fn state_get(
114    axum::extract::State(state): axum::extract::State<state_type!()>,
115    // must extract token to auth
116    _token: TokenInfo,
117    axum_extra::extract::OptionalQuery(args): axum_extra::extract::OptionalQuery<StateGetArgs>,
118    OptionalWebSocket(upgrade): OptionalWebSocket
119) -> axum::response::Response {
120    let mut from = args.unwrap_or_default().from.unwrap_or(0);
121    if let Some(upgrade) = upgrade {
122        upgrade.on_upgrade(move |mut sock: axum::extract::ws::WebSocket| async move {
123            let mut subscription = state.updated.subscribe();
124            loop {
125                subscription.wait_for(|i| *i >= from).await.expect("Failed to poll updated_recv");
126
127                let tpex_state_handle = state.tpex.read().await;
128                // It's better to clone these out than hold state
129                let res =
130                    tpex_state_handle.cache().iter()
131                    .skip(from as usize)
132                    .map(Into::into)
133                    .map(axum::extract::ws::Message::Text)
134                    .collect::<Vec<_>>();
135                // rechecking the id prevents a race condition
136                from = tpex_state_handle.state().get_next_id() - 1;
137                // We have extracted all we need
138                drop(tpex_state_handle);
139                // Send it off
140                for i in res {
141                    if sock.send(i).await.is_err() {
142                        break;
143                    }
144                }
145            }
146        })
147    }
148    else {
149        let data =
150            state.tpex.read().await.cache().iter()
151            .skip(from as usize)
152            .fold(String::new(), |a, b| a + b);
153        let body = axum::body::Body::from(data);
154        axum::response::Response::builder()
155        .header("Content-Type", "text/plain")
156        .body(body)
157        .expect("Unable to create state_get response")
158    }
159}
160
161async fn token_get(
162    axum::extract::State(_state): axum::extract::State<state_type!()>,
163    token: TokenInfo
164) -> axum::Json<TokenInfo> {
165    axum::Json(token)
166}
167
168async fn token_post(
169    axum::extract::State(state): axum::extract::State<state_type!()>,
170    token: TokenInfo,
171    axum::extract::Json(args): axum::extract::Json<TokenPostArgs>,
172) -> Result<axum::Json<Token>, Error> {
173    if args.level > token.level {
174        return Err(Error::TokenTooLowLevel)
175    }
176    if args.user != token.user && token.level < TokenLevel::ProxyAll {
177        return Err(Error::UncontrolledUser)
178    }
179
180    Ok(axum::Json(state.tokens.create_token(args.level, args.user).await.expect("Cannot access DB")))
181}
182
183async fn token_delete(
184    axum::extract::State(state): axum::extract::State<state_type!()>,
185    token: TokenInfo,
186    axum::extract::Json(args): axum::extract::Json<TokenDeleteArgs>
187) -> Result<axum::Json<()>, Error> {
188    let target = args.token.unwrap_or(token.token);
189    // We only need perms to delete other tokens
190    if target != token.token && token.level < TokenLevel::ProxyAll {
191        return Err(Error::TokenTooLowLevel);
192    }
193    state.tokens.delete_token(&token.token).await
194    .map_or(Err(Error::TokenInvalid), |_| Ok(axum::Json(())))
195}
196
197async fn fastsync_get(
198    axum::extract::State(state): axum::extract::State<state_type!()>,
199    _token: TokenInfo,
200    OptionalWebSocket(upgrade): OptionalWebSocket
201) -> axum::response::Response {
202    if let Some(upgrade) = upgrade {
203        upgrade.on_upgrade(move |mut sock: axum::extract::ws::WebSocket| async move {
204            let mut subscription = state.updated.subscribe();
205            subscription.mark_changed();
206            loop {
207                subscription.changed().await.expect("Failed to poll updated_recv");
208                let res = StateSync::from(state.tpex.read().await.state());
209                if sock.send(axum::extract::ws::Message::Text(serde_json::to_string(&res).expect("Could not serialise state sync").into())).await.is_err() {
210                    break;
211                }
212            }
213        })
214    }
215    else {
216        let res = StateSync::from(state.tpex.read().await.state());
217        axum::Json(res).into_response()
218    }
219}
220
221async fn inspect_balance_get(
222    axum::extract::State(state): axum::extract::State<state_type!()>,
223    _token: TokenInfo,
224    axum::extract::Query(args): axum::extract::Query<InspectBalanceGetArgs>
225) -> axum::response::Response {
226    axum::Json(state.tpex.read().await.state().get_bal(&args.player)).into_response()
227}
228
229async fn inspect_assets_get(
230    axum::extract::State(state): axum::extract::State<state_type!()>,
231    _token: TokenInfo,
232    axum::extract::Query(args): axum::extract::Query<InspectAssetsGetArgs>
233) -> axum::response::Response {
234    axum::Json(state.tpex.read().await.state().get_assets(&args.player)).into_response()
235}
236
237async fn inspect_audit_get(
238    axum::extract::State(state): axum::extract::State<state_type!()>,
239    _token: TokenInfo
240) -> axum::response::Response {
241    axum::Json(state.tpex.read().await.state().itemised_audit()).into_response()
242}
243
244pub async fn run_server<L: Listener>(
245    cancel: CancellationToken,
246    mut trade_log: impl AsyncWrite + AsyncBufRead + AsyncSeek + Unpin + Send + Sync + 'static,
247    token_handler: tokens::TokenHandler,
248    listener: L) where L::Addr : Debug
249{
250    // Load cache
251    let mut cache = Vec::new();
252    {
253        let mut lines = trade_log.lines();
254        while let Some(mut line) = lines.next_line().await.expect("Could not read trade file") {
255            line.push('\n');
256            cache.push(line);
257        }
258        trade_log = lines.into_inner();
259        trade_log.rewind().await.expect("Could not rewind trade file");
260    }
261
262    let mut tpex_state = tpex::State::new();
263    tpex_state.replay(&mut trade_log, true).await.expect("Could not replay trades");
264
265    let (updated, _) = tokio::sync::watch::channel(tpex_state.get_next_id().checked_sub(1).expect("Poll counter underflow"));
266    let state = state::StateStruct {
267        tpex: tokio::sync::RwLock::new(state::TPExState::new(tpex_state, trade_log, cache)),
268        tokens: token_handler,
269        updated
270    };
271
272    let cors = tower_http::cors::CorsLayer::new()
273        .allow_headers(tower_http::cors::Any)
274        .allow_origin(tower_http::cors::Any)
275        .allow_methods(tower_http::cors::Any);
276
277
278    let app = Router::new()
279        .route("/state", axum::routing::get(state_get))
280        .route("/state", axum::routing::connect(state_get))
281        .route("/state", axum::routing::patch(state_patch))
282
283        .route("/token", axum::routing::get(token_get))
284        .route("/token", axum::routing::post(token_post))
285        .route("/token", axum::routing::delete(token_delete))
286
287        .route("/fastsync", axum::routing::get(fastsync_get))
288
289        .route("/inspect/balance", axum::routing::get(inspect_balance_get))
290        .route("/inspect/assets", axum::routing::get(inspect_assets_get))
291        .route("/inspect/audit", axum::routing::get(inspect_audit_get))
292
293        .with_state(std::sync::Arc::new(state))
294
295        .layer(TraceLayer::new_for_http())
296
297        .route_layer(cors);
298
299    axum::serve(listener, app).with_graceful_shutdown(async move { cancel.cancelled().await }).await.expect("Failed to serve");
300}
301
302pub async fn run_server_with_args(args: Args, cancel: CancellationToken) {
303    run_server(
304        cancel,
305        tokio::io::BufStream::with_capacity(16<<20, 16<<20,
306            tokio::fs::File::options()
307            .read(true)
308            .write(true)
309            .truncate(false)
310            .create(true)
311            .open(args.trades).await.expect("Unable to open trade list")),
312        tokens::TokenHandler::new(&args.db).await.expect("Could not connect to DB"),
313        tokio::net::TcpListener::bind(args.endpoint).await.expect("Could not bind to endpoint")
314    ).await
315}