tpex_api/server/mod.rs

1use std::fmt::Debug;
2
3use crate::{shared::*, state_type};
4pub mod tokens;
5pub mod state;
6
7use axum::{extract::{ws::rejection::WebSocketUpgradeRejection, FromRequestParts}, response::IntoResponse, serve::Listener, Router};
8use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt, AsyncSeek, AsyncSeekExt, AsyncWrite};
9use tokio_util::sync::CancellationToken;
10use tower_http::trace::TraceLayer;
11use tpex::{ActionLevel, AssetId, AssetInfo, StateSync};
12#[derive(clap::Parser)]
13pub struct Args {
14    pub trades: std::path::PathBuf,
15    pub db: String,
16    pub endpoint: String,
17    pub assets: Option<std::path::PathBuf>,
18}
19
20#[derive(Debug)]
21enum Error {
22    TPEx(tpex::Error),
23    UncontrolledUser,
24    TokenTooLowLevel,
25    TokenInvalid,
26    NotNextId{next_id: u64}
27}
28impl From<tpex::Error> for Error {
29    fn from(value: tpex::Error) -> Self {
30        Self::TPEx(value)
31    }
32}
33impl axum::response::IntoResponse for Error {
34    fn into_response(self) -> axum::response::Response {
35        let (code,err) = match self {
36            Self::TPEx(err) => (409, ErrorInfo{error:err.to_string()}),
37            Self::UncontrolledUser => (403, ErrorInfo{error:"This action would act on behalf of a different user.".to_owned()}),
38            Self::TokenTooLowLevel => (403, ErrorInfo{error:"This action requires a higher permission level".to_owned()}),
39            Self::NotNextId{next_id} => (409, ErrorInfo{error:format!("The requested ID was not the next, which is {next_id}")}),
40            Self::TokenInvalid => (409, ErrorInfo{error:"The given token does not exist".to_owned()})
41        };
42
43        let body = serde_json::to_vec(&err).expect("Unable to serialise error");
44
45        let body = axum::body::Body::from(body);
46
47        axum::response::Response::builder()
48        .status(code)
49        .header("Content-Type", "application/json")
50        .body(body)
51        .expect("Unable to create error response")
52    }
53}
54
55async fn state_patch(
56    axum::extract::State(state): axum::extract::State<state_type!()>,
57    token: TokenInfo,
58    axum_extra::extract::OptionalQuery(args): axum_extra::extract::OptionalQuery<StatePatchArgs>,
59    axum::extract::Json(action): axum::extract::Json<tpex::Action>
60) -> Result<axum::response::Json<u64>, Error> {
61    match token.level {
62        TokenLevel::ReadOnly => return Err(Error::TokenTooLowLevel),
63        TokenLevel::ProxyOne => {
64            let perms = state.tpex.read().await.state().perms(&action)?;
65            if perms.player != token.user {
66                return Err(Error::UncontrolledUser);
67            }
68            if perms.level > ActionLevel::Normal {
69                return Err(Error::TokenTooLowLevel);
70            }
71        }
72        // Apply catches all banker perm mismatches, assuming that upstream has verified their action:
73        TokenLevel::ProxyAll => ()
74    }
75    let mut tpex_state = state.tpex.write().await;
76    let id =
77        if let Some(expected_id) = args.and_then(|i| i.id) {
78            let next_id = tpex_state.state().get_next_id();
79            if next_id != expected_id {
80                return Err(Error::NotNextId{next_id});
81            }
82            let id = tpex_state.apply(action).await?;
83            assert_eq!(id, next_id, "Somehow got ID mismatch");
84            id
85        }
86        else {
87            tpex_state.apply(action).await?
88        };
89    // We patched, so update the id
90    //
91    // We use send_replace so that we don't have to worry about if anyone's listening
92    state.updated.send_replace(id);
93    Ok(axum::Json(id))
94}
95
96struct OptionalWebSocket(pub Option<axum::extract::ws::WebSocketUpgrade>);
97impl<S : Send + Sync> FromRequestParts<S> for OptionalWebSocket {
98    #[doc = " If the extractor fails it\'ll use this \"rejection\" type. A rejection is"]
99    #[doc = " a kind of error that can be converted into a response."]
100    type Rejection = WebSocketUpgradeRejection;
101
102    async fn from_request_parts(parts: &mut axum::http::request::Parts,state: &S,) -> Result<Self,Self::Rejection> {
103        match axum::extract::ws::WebSocketUpgrade::from_request_parts(parts, state).await {
104            Ok(x) => Ok(Self(Some(x))),
105            Err(WebSocketUpgradeRejection::MethodNotGet(_)) |
106            Err(WebSocketUpgradeRejection::MethodNotConnect(_)) |
107            Err(WebSocketUpgradeRejection::InvalidConnectionHeader(_)) |
108            Err(WebSocketUpgradeRejection::InvalidUpgradeHeader(_)) => Ok(Self(None)),
109            Err(e) => Err(e)
110        }
111    }
112}
113
114async fn state_get(
115    axum::extract::State(state): axum::extract::State<state_type!()>,
116    // must extract token to auth
117    _token: TokenInfo,
118    axum_extra::extract::OptionalQuery(args): axum_extra::extract::OptionalQuery<StateGetArgs>,
119    OptionalWebSocket(upgrade): OptionalWebSocket
120) -> axum::response::Response {
121    let mut from = args.unwrap_or_default().from.unwrap_or(0);
122    if let Some(upgrade) = upgrade {
123        upgrade.on_upgrade(move |mut sock: axum::extract::ws::WebSocket| async move {
124            let mut subscription = state.updated.subscribe();
125            loop {
126                subscription.wait_for(|i| *i >= from).await.expect("Failed to poll updated_recv");
127
128                let tpex_state_handle = state.tpex.read().await;
129                // It's better to clone these out than hold state
130                let res =
131                    tpex_state_handle.cache().iter()
132                    .skip(from as usize)
133                    .map(Into::into)
134                    .map(axum::extract::ws::Message::Text)
135                    .collect::<Vec<_>>();
136                // rechecking the id prevents a race condition
137                from = tpex_state_handle.state().get_next_id() - 1;
138                // We have extracted all we need
139                drop(tpex_state_handle);
140                // Send it off
141                for i in res {
142                    if sock.send(i).await.is_err() {
143                        break;
144                    }
145                }
146            }
147        })
148    }
149    else {
150        let data =
151            state.tpex.read().await.cache().iter()
152            .skip(from as usize)
153            .fold(String::new(), |a, b| a + b);
154        let body = axum::body::Body::from(data);
155        axum::response::Response::builder()
156        .header("Content-Type", "text/plain")
157        .body(body)
158        .expect("Unable to create state_get response")
159    }
160}
161
162async fn token_get(
163    axum::extract::State(_state): axum::extract::State<state_type!()>,
164    token: TokenInfo
165) -> axum::Json<TokenInfo> {
166    axum::Json(token)
167}
168
169async fn token_post(
170    axum::extract::State(state): axum::extract::State<state_type!()>,
171    token: TokenInfo,
172    axum::extract::Json(args): axum::extract::Json<TokenPostArgs>,
173) -> Result<axum::Json<Token>, Error> {
174    if args.level > token.level {
175        return Err(Error::TokenTooLowLevel)
176    }
177    if args.user != token.user && token.level < TokenLevel::ProxyAll {
178        return Err(Error::UncontrolledUser)
179    }
180
181    Ok(axum::Json(state.tokens.create_token(args.level, args.user).await.expect("Cannot access DB")))
182}
183
184async fn token_delete(
185    axum::extract::State(state): axum::extract::State<state_type!()>,
186    token: TokenInfo,
187    axum::extract::Json(args): axum::extract::Json<TokenDeleteArgs>
188) -> Result<axum::Json<()>, Error> {
189    let target = args.token.unwrap_or(token.token);
190    // We only need perms to delete other tokens
191    if target != token.token && token.level < TokenLevel::ProxyAll {
192        return Err(Error::TokenTooLowLevel);
193    }
194    state.tokens.delete_token(&token.token).await
195    .map_or(Err(Error::TokenInvalid), |_| Ok(axum::Json(())))
196}
197
198async fn fastsync_get(
199    axum::extract::State(state): axum::extract::State<state_type!()>,
200    _token: TokenInfo,
201    OptionalWebSocket(upgrade): OptionalWebSocket
202) -> axum::response::Response {
203    if let Some(upgrade) = upgrade {
204        upgrade.on_upgrade(move |mut sock: axum::extract::ws::WebSocket| async move {
205            let mut subscription = state.updated.subscribe();
206            subscription.mark_changed();
207            loop {
208                subscription.changed().await.expect("Failed to poll updated_recv");
209                let res = StateSync::from(state.tpex.read().await.state());
210                if sock.send(axum::extract::ws::Message::Text(serde_json::to_string(&res).expect("Could not serialise state sync").into())).await.is_err() {
211                    break;
212                }
213            }
214        })
215    }
216    else {
217        let res = StateSync::from(state.tpex.read().await.state());
218        axum::Json(res).into_response()
219    }
220}
221
222async fn inspect_balance_get(
223    axum::extract::State(state): axum::extract::State<state_type!()>,
224    _token: TokenInfo,
225    axum::extract::Query(args): axum::extract::Query<InspectBalanceGetArgs>
226) -> axum::response::Response {
227    axum::Json(state.tpex.read().await.state().get_bal(&args.player)).into_response()
228}
229
230async fn inspect_assets_get(
231    axum::extract::State(state): axum::extract::State<state_type!()>,
232    _token: TokenInfo,
233    axum::extract::Query(args): axum::extract::Query<InspectAssetsGetArgs>
234) -> axum::response::Response {
235    axum::Json(state.tpex.read().await.state().get_assets(&args.player)).into_response()
236}
237
238async fn inspect_audit_get(
239    axum::extract::State(state): axum::extract::State<state_type!()>,
240    _token: TokenInfo
241) -> axum::response::Response {
242    axum::Json(state.tpex.read().await.state().itemised_audit()).into_response()
243}
244
245pub async fn run_server<L: Listener>(
246    cancel: CancellationToken,
247    mut trade_log: impl AsyncWrite + AsyncBufRead + AsyncSeek + Unpin + Send + Sync + 'static,
248    token_handler: tokens::TokenHandler,
249    listener: L,
250    assets: Option<std::collections::HashMap<AssetId, AssetInfo>>)
251    where L::Addr : Debug
252{
253    // Load cache
254    let mut cache = Vec::new();
255    {
256        let mut lines = trade_log.lines();
257        while let Some(mut line) = lines.next_line().await.expect("Could not read trade file") {
258            line.push('\n');
259            cache.push(line);
260        }
261        trade_log = lines.into_inner();
262        trade_log.rewind().await.expect("Could not rewind trade file");
263    }
264
265    let mut tpex_state = tpex::State::new();
266    if let Some(assets) = assets {
267        tpex_state.update_asset_info(assets)
268    }
269    tpex_state.replay(&mut trade_log, true).await.expect("Could not replay trades");
270
271    let (updated, _) = tokio::sync::watch::channel(tpex_state.get_next_id().checked_sub(1).expect("Poll counter underflow"));
272    let state = state::StateStruct {
273        tpex: tokio::sync::RwLock::new(state::TPExState::new(tpex_state, trade_log, cache)),
274        tokens: token_handler,
275        updated
276    };
277
278    let cors = tower_http::cors::CorsLayer::new()
279        .allow_headers(tower_http::cors::Any)
280        .allow_origin(tower_http::cors::Any)
281        .allow_methods(tower_http::cors::Any);
282
283
284    let app = Router::new()
285        .route("/state", axum::routing::get(state_get))
286        .route("/state", axum::routing::connect(state_get))
287        .route("/state", axum::routing::patch(state_patch))
288
289        .route("/token", axum::routing::get(token_get))
290        .route("/token", axum::routing::post(token_post))
291        .route("/token", axum::routing::delete(token_delete))
292
293        .route("/fastsync", axum::routing::get(fastsync_get))
294
295        .route("/inspect/balance", axum::routing::get(inspect_balance_get))
296        .route("/inspect/assets", axum::routing::get(inspect_assets_get))
297        .route("/inspect/audit", axum::routing::get(inspect_audit_get))
298
299        .with_state(std::sync::Arc::new(state))
300
301        .layer(TraceLayer::new_for_http())
302
303        .route_layer(cors);
304
305    axum::serve(listener, app).with_graceful_shutdown(async move { cancel.cancelled().await }).await.expect("Failed to serve");
306}
307
308pub async fn run_server_with_args(args: Args, cancel: CancellationToken) {
309    run_server(
310        cancel,
311        tokio::io::BufStream::with_capacity(16<<20, 16<<20,
312            tokio::fs::File::options()
313            .read(true)
314            .write(true)
315            .truncate(false)
316            .create(true)
317            .open(args.trades).await.expect("Unable to open trade list")),
318        tokens::TokenHandler::new(&args.db).await.expect("Could not connect to DB"),
319        tokio::net::TcpListener::bind(args.endpoint).await.expect("Could not bind to endpoint"),
320        match args.assets {
321            Some(asset_path) => {
322                let mut assets = String::new();
323                tokio::fs::File::open(asset_path).await.expect("Unable to open asset info")
324                .read_to_string(&mut assets).await.expect("Unable to read asset list");
325
326                Some(serde_json::from_str(&assets).expect("Unable to parse asset info"))
327            },
328            None=> None
329        }
330    ).await
331}