1use std::{fmt::Debug, time::Duration};
2
3use crate::{shared::*, state_type};
4pub mod tokens;
5pub mod state;
6
7use axum::{extract::{ws::rejection::WebSocketUpgradeRejection, FromRequestParts}, response::IntoResponse, serve::Listener, Router};
8use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncSeek, AsyncSeekExt, AsyncWrite};
9use tokio_util::sync::CancellationToken;
10use tower_http::trace::TraceLayer;
11use tpex::StateSync;
12#[derive(clap::Parser)]
13pub struct Args {
14 pub trades: std::path::PathBuf,
15 pub db: String,
16 pub endpoint: String,
17}
18
19#[derive(Debug)]
20enum Error {
21 TPEx(tpex::Error),
22 UncontrolledUser,
23 TokenTooLowLevel,
24 TokenInvalid,
25 NotNextId{next_id: u64}
26}
27impl From<tpex::Error> for Error {
28 fn from(value: tpex::Error) -> Self {
29 Self::TPEx(value)
30 }
31}
32impl axum::response::IntoResponse for Error {
33 fn into_response(self) -> axum::response::Response {
34 let (code,err) = match self {
35 Self::TPEx(err) => (409, ErrorInfo{error:err.to_string()}),
36 Self::UncontrolledUser => (403, ErrorInfo{error:"This action would act on behalf of a different user.".to_owned()}),
37 Self::TokenTooLowLevel => (403, ErrorInfo{error:"This action requires a higher permission level".to_owned()}),
38 Self::NotNextId{next_id} => (409, ErrorInfo{error:format!("The requested ID was not the next, which is {next_id}")}),
39 Self::TokenInvalid => (409, ErrorInfo{error:"The given token does not exist".to_owned()})
40 };
41
42 let body = serde_json::to_vec(&err).expect("Unable to serialise error");
43
44 let body = axum::body::Body::from(body);
45
46 axum::response::Response::builder()
47 .status(code)
48 .header("Content-Type", "application/json")
49 .body(body)
50 .expect("Unable to create error response")
51 }
52}
53
54async fn state_patch(
55 axum::extract::State(state): axum::extract::State<state_type!()>,
56 token: TokenInfo,
57 axum_extra::extract::OptionalQuery(args): axum_extra::extract::OptionalQuery<StatePatchArgs>,
58 axum::extract::Json(action): axum::extract::Json<tpex::Action>
59) -> Result<axum::response::Json<u64>, Error> {
60 match token.level {
61 TokenLevel::ReadOnly => return Err(Error::TokenTooLowLevel),
62 TokenLevel::ProxyOne => {
63 let perms = state.tpex.read().await.state().perms(&action)?;
64 if perms.player != token.user {
65 return Err(Error::UncontrolledUser);
66 }
67 }
68 TokenLevel::ProxyAll => ()
70 }
71 let mut tpex_state = state.tpex.write().await;
72 let id =
73 if let Some(expected_id) = args.and_then(|i| i.id) {
74 let next_id = tpex_state.state().get_next_id();
75 if next_id != expected_id {
76 return Err(Error::NotNextId{next_id});
77 }
78 let id = tpex_state.apply(action).await?;
79 assert_eq!(id, next_id, "Somehow got ID mismatch");
80 id
81 }
82 else {
83 tpex_state.apply(action).await?
84 };
85 state.updated.send_replace(id);
89 Ok(axum::Json(id))
90}
91
92struct OptionalWebSocket(pub Option<axum::extract::ws::WebSocketUpgrade>);
93impl<S : Send + Sync> FromRequestParts<S> for OptionalWebSocket {
94 #[doc = " If the extractor fails it\'ll use this \"rejection\" type. A rejection is"]
95 #[doc = " a kind of error that can be converted into a response."]
96 type Rejection = WebSocketUpgradeRejection;
97
98 async fn from_request_parts(parts: &mut axum::http::request::Parts,state: &S,) -> Result<Self,Self::Rejection> {
99 match axum::extract::ws::WebSocketUpgrade::from_request_parts(parts, state).await {
100 Ok(x) => Ok(Self(Some(x))),
101 Err(WebSocketUpgradeRejection::MethodNotGet(_)) |
102 Err(WebSocketUpgradeRejection::MethodNotConnect(_)) |
103 Err(WebSocketUpgradeRejection::InvalidConnectionHeader(_)) |
104 Err(WebSocketUpgradeRejection::InvalidUpgradeHeader(_)) => Ok(Self(None)),
105 Err(e) => Err(e)
106 }
107 }
108}
109
110async fn state_get(
111 axum::extract::State(state): axum::extract::State<state_type!()>,
112 _token: TokenInfo,
114 axum_extra::extract::OptionalQuery(args): axum_extra::extract::OptionalQuery<StateGetArgs>,
115 OptionalWebSocket(upgrade): OptionalWebSocket
116) -> axum::response::Response {
117 let mut from = args.unwrap_or_default().from.unwrap_or(1);
118 if let Some(upgrade) = upgrade {
119 upgrade.on_upgrade(move |mut sock: axum::extract::ws::WebSocket| async move {
120 let mut subscription = state.updated.subscribe();
121 loop {
122 let should_ping = tokio::select! {
123 new_actions = subscription.wait_for(|i| *i >= from) => {
124 new_actions.expect("Failed to poll updated_recv");
125 false
126 },
127 _timeout = tokio::time::sleep(Duration::from_secs(10)) => true
128 };
129 if should_ping {
130 if sock.send(axum::extract::ws::Message::Ping(Default::default())).await.is_err() {
131 break;
132 }
133 else {
134 continue;
135 }
136 }
137 let tpex_state_handle = state.tpex.read().await;
138 let res =
140 tpex_state_handle.cache().iter()
141 .skip((from as usize).saturating_sub(1))
142 .map(Into::into)
143 .map(axum::extract::ws::Message::Text)
144 .collect::<Vec<_>>();
145 from = tpex_state_handle.state().get_next_id();
147 drop(tpex_state_handle);
149 for i in res {
151 if sock.send(i).await.is_err() {
152 break;
153 }
154 }
155 }
156 })
157 }
158 else {
159 let data =
160 state.tpex.read().await.cache().iter()
161 .skip(from as usize)
162 .fold(String::new(), |a, b| a + b);
163 let body = axum::body::Body::from(data);
164 axum::response::Response::builder()
165 .header("Content-Type", "text/plain")
166 .body(body)
167 .expect("Unable to create state_get response")
168 }
169}
170
171async fn token_get(
172 axum::extract::State(_state): axum::extract::State<state_type!()>,
173 token: TokenInfo
174) -> axum::Json<TokenInfo> {
175 axum::Json(token)
176}
177
178async fn token_post(
179 axum::extract::State(state): axum::extract::State<state_type!()>,
180 token: TokenInfo,
181 axum::extract::Json(args): axum::extract::Json<TokenPostArgs>,
182) -> Result<axum::Json<Token>, Error> {
183 if args.level > token.level {
184 return Err(Error::TokenTooLowLevel)
185 }
186 if args.user != token.user && token.level < TokenLevel::ProxyAll {
187 return Err(Error::UncontrolledUser)
188 }
189
190 Ok(axum::Json(state.tokens.create_token(args.level, args.user).await.expect("Cannot access DB")))
191}
192
193async fn token_delete(
194 axum::extract::State(state): axum::extract::State<state_type!()>,
195 token: TokenInfo,
196 axum::extract::Json(args): axum::extract::Json<TokenDeleteArgs>
197) -> Result<axum::Json<()>, Error> {
198 let target = args.token.unwrap_or(token.token);
199 if target != token.token && token.level < TokenLevel::ProxyOne {
201 return Err(Error::TokenTooLowLevel);
202 }
203 state.tokens.delete_token(&token.token).await
204 .map_or(Err(Error::TokenInvalid), |_| Ok(axum::Json(())))
205}
206
207async fn fastsync_get(
208 axum::extract::State(state): axum::extract::State<state_type!()>,
209 _token: TokenInfo,
210 OptionalWebSocket(upgrade): OptionalWebSocket
211) -> axum::response::Response {
212 if let Some(upgrade) = upgrade {
213 upgrade.on_upgrade(move |mut sock: axum::extract::ws::WebSocket| async move {
214 let mut subscription = state.updated.subscribe();
215 subscription.mark_changed();
216 loop {
217 let should_ping = tokio::select! {
218 new_actions = subscription.changed() => {
219 new_actions.expect("Failed to poll updated_recv");
220 false
221 },
222 _timeout = tokio::time::sleep(Duration::from_secs(10)) => true
223 };
224 if should_ping {
225 if sock.send(axum::extract::ws::Message::Ping(Default::default())).await.is_err() {
226 break;
227 }
228 else {
229 continue;
230 }
231 }
232 let res = StateSync::from(state.tpex.read().await.state());
233 if sock.send(axum::extract::ws::Message::Text(serde_json::to_string(&res).expect("Could not serialise state sync").into())).await.is_err() {
234 break;
235 }
236 }
237 })
238 }
239 else {
240 let res = StateSync::from(state.tpex.read().await.state());
241 axum::Json(res).into_response()
242 }
243}
244
245async fn inspect_balance_get(
246 axum::extract::State(state): axum::extract::State<state_type!()>,
247 _token: TokenInfo,
248 axum::extract::Query(args): axum::extract::Query<InspectBalanceGetArgs>
249) -> axum::response::Response {
250 axum::Json(state.tpex.read().await.state().get_bal(&args.player)).into_response()
251}
252
253async fn inspect_assets_get(
254 axum::extract::State(state): axum::extract::State<state_type!()>,
255 _token: TokenInfo,
256 axum::extract::Query(args): axum::extract::Query<InspectAssetsGetArgs>
257) -> axum::response::Response {
258 axum::Json(state.tpex.read().await.state().get_assets(&args.player)).into_response()
259}
260
261async fn inspect_audit_get(
262 axum::extract::State(state): axum::extract::State<state_type!()>,
263 _token: TokenInfo
264) -> axum::response::Response {
265 axum::Json(state.tpex.read().await.state().itemised_audit()).into_response()
266}
267
268pub async fn run_server<L: Listener>(
269 cancel: CancellationToken,
270 mut trade_log: impl AsyncWrite + AsyncBufRead + AsyncSeek + Unpin + Send + Sync + 'static,
271 token_handler: tokens::TokenHandler,
272 listener: L) where L::Addr : Debug
273{
274 let mut cache = Vec::new();
276 {
277 let mut lines = trade_log.lines();
278 while let Some(mut line) = lines.next_line().await.expect("Could not read trade file") {
279 line.push('\n');
280 cache.push(line);
281 }
282 trade_log = lines.into_inner();
283 trade_log.rewind().await.expect("Could not rewind trade file");
284 }
285
286 let mut tpex_state = tpex::State::new();
287 tpex_state.replay(&mut trade_log, true).await.expect("Could not replay trades");
288
289 let (updated, _) = tokio::sync::watch::channel(tpex_state.get_next_id().checked_sub(1).expect("Poll counter underflow"));
290 let state = state::StateStruct {
291 tpex: tokio::sync::RwLock::new(state::TPExState::new(tpex_state, trade_log, cache)),
292 tokens: token_handler,
293 updated
294 };
295
296 let cors = tower_http::cors::CorsLayer::new()
297 .allow_headers(tower_http::cors::Any)
298 .allow_origin(tower_http::cors::Any)
299 .allow_methods(tower_http::cors::Any);
300
301
302 let app = Router::new()
303 .route("/state", axum::routing::get(state_get))
304 .route("/state", axum::routing::connect(state_get))
305 .route("/state", axum::routing::patch(state_patch))
306
307 .route("/token", axum::routing::get(token_get))
308 .route("/token", axum::routing::post(token_post))
309 .route("/token", axum::routing::delete(token_delete))
310
311 .route("/fastsync", axum::routing::get(fastsync_get))
312
313 .route("/inspect/balance", axum::routing::get(inspect_balance_get))
314 .route("/inspect/assets", axum::routing::get(inspect_assets_get))
315 .route("/inspect/audit", axum::routing::get(inspect_audit_get))
316
317 .with_state(std::sync::Arc::new(state))
318
319 .layer(TraceLayer::new_for_http())
320
321 .route_layer(cors);
322
323 axum::serve(listener, app).with_graceful_shutdown(async move { cancel.cancelled().await }).await.expect("Failed to serve");
324}
325
326pub async fn run_server_with_args(args: Args, cancel: CancellationToken) {
327 run_server(
328 cancel,
329 tokio::io::BufStream::with_capacity(16<<20, 16<<20,
330 tokio::fs::File::options()
331 .read(true)
332 .write(true)
333 .truncate(false)
334 .create(true)
335 .open(args.trades).await.expect("Unable to open trade list")),
336 tokens::TokenHandler::new(&args.db).await.expect("Could not connect to DB"),
337 tokio::net::TcpListener::bind(args.endpoint).await.expect("Could not bind to endpoint")
338 ).await
339}