1use super::*;
2use axum::{
3 extract::Path,
4 http::{HeaderValue, StatusCode},
5 routing::{get, post},
6 Extension, Json, Router,
7};
8use reqwest::{Client, Method};
9use serde::{Deserialize, Serialize};
10use serde_json::{json, Value};
11use std::{collections::HashMap, sync::Arc};
12use thiserror::Error;
13use tower_http::cors::CorsLayer;
14
15#[derive(Error, Debug)]
16enum HttpError {
17 #[error("invalid request: {0}")]
18 InvalidRequest(String),
19 #[error("method not found: {0}")]
20 MethodNotFound(String),
21}
22
23pub trait HttpInterface:
24 DispatchStringDictAsync<Error = serde_json::Error, Poly = serde_json::Value>
25 + DispatchStringTupleAsync<Error = serde_json::Error>
26 + Send
27 + Sync
28 + 'static
29{
30}
31
32impl<T> HttpInterface for Arc<T> where T: HttpInterface + ?Sized {}
33pub fn create_http_object<T: ?Sized + HttpInterface>(x: Arc<T>) -> Arc<dyn HttpInterface> {
34 Arc::new(x) as Arc<dyn HttpInterface>
35}
36
37#[derive(Clone)]
38struct State {
39 pub registered_objects: HashMap<String, Arc<dyn HttpInterface>>,
40}
41
/// `GET /` handler: returns a plain-text banner explaining how to use the API.
async fn root() -> &'static str {
    const BANNER: &str =
        "This is a serde-tc JSON RPC server. Please access to /<object-name> with POST, to use the API.";
    BANNER
}
46
47async fn dispatch(
48 Path(path): Path<String>,
49 Json(args): Json<RawArg>,
50 Extension(state): Extension<Arc<State>>,
51) -> (StatusCode, Json<Value>) {
52 if let Some(object) = state.registered_objects.get(&path) {
53 match dispatch_raw(object.as_ref(), &args.method, args.params.clone()).await {
54 Ok(value) => (StatusCode::OK, Json(value)),
55 Err(err) => (
56 StatusCode::INTERNAL_SERVER_ERROR,
57 Json(json!({
58 "error": "invalid http request",
59 "error_message": err.to_string(),
60 "request": args,
61 })),
62 ),
63 }
64 } else {
65 (
66 StatusCode::NOT_FOUND,
67 Json(json!({
68 "error": "object not found",
69 "obejct": &path.as_str()[1..],
70 })),
71 )
72 }
73}
74
75pub async fn run_server(port: u16, objects: HashMap<String, Arc<dyn HttpInterface>>) {
76 let app = Router::new().route("/", get(root));
77 let app = app.route("/:key", post(dispatch));
78 let app = app
79 .layer(Extension(Arc::new(State {
80 registered_objects: objects,
81 })))
82 .layer(
83 CorsLayer::new()
84 .allow_origin("*".parse::<HeaderValue>().unwrap())
85 .allow_headers([axum::http::header::CONTENT_TYPE])
86 .allow_methods([Method::POST]),
87 );
88 let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
89 axum::Server::bind(&addr)
90 .serve(app.into_make_service())
91 .await
92 .unwrap();
93}
94
95#[derive(Serialize, Deserialize, Debug)]
96#[serde(deny_unknown_fields)]
97struct RawArg {
98 method: String,
99 params: serde_json::Value,
100}
101
102async fn dispatch_raw<T>(
103 api: &T,
104 method: &str,
105 arguments: serde_json::Value,
106) -> std::result::Result<serde_json::Value, HttpError>
107where
108 T: HttpInterface + ?Sized,
109{
110 let result = if arguments.is_array() {
111 DispatchStringTupleAsync::dispatch(api, method, &arguments.to_string()).await
112 } else if arguments.is_object() {
113 DispatchStringDictAsync::dispatch(api, method, &arguments.to_string()).await
114 } else {
115 return Err(HttpError::InvalidRequest(format!(
116 "invalid argument type: {}",
117 arguments
118 )));
119 };
120
121 match result {
122 Ok(x) => Ok(serde_json::from_str(&x).unwrap()),
123 Err(Error::MethodNotFound(x)) => Err(HttpError::MethodNotFound(x)),
124 Err(x) => Err(HttpError::InvalidRequest(x.to_string())),
125 }
126}
127
128pub struct HttpClient {
130 client: Client,
131 addr: String,
132}
133
134impl HttpClient {
135 pub fn new(addr: String, client: Client) -> Self {
136 HttpClient { client, addr }
137 }
138}
139
140#[async_trait]
141impl StubCall for HttpClient {
142 type Error = anyhow::Error;
143
144 async fn call(&self, method: &'static str, params: String) -> Result<String, Self::Error> {
145 let body = format!(
146 r#"{{"method": "{}",
147 "params": {}}}"#,
148 method, params
149 );
150 let response = self
151 .client
152 .request(Method::POST, &format!("http://{}", self.addr))
153 .header("content-type", "application/json")
154 .body(body)
155 .send()
156 .await?;
157
158 if response.status().as_u16() != 200 {
159 Err(anyhow::Error::msg(format!(
160 r#"HTTP request failed: "{}""#,
161 response.text().await?
162 )))
163 } else {
164 Ok(response.text().await?)
165 }
166 }
167}