Skip to main content

serde_tc/
http.rs

1use super::*;
2use axum::{
3    extract::Path,
4    http::{HeaderValue, StatusCode},
5    routing::{get, post},
6    Extension, Json, Router,
7};
8use reqwest::{Client, Method};
9use serde::{Deserialize, Serialize};
10use serde_json::{json, Value};
11use std::{collections::HashMap, sync::Arc};
12use thiserror::Error;
13use tower_http::cors::CorsLayer;
14
15#[derive(Error, Debug)]
16enum HttpError {
17    #[error("invalid request: {0}")]
18    InvalidRequest(String),
19    #[error("method not found: {0}")]
20    MethodNotFound(String),
21}
22
23pub trait HttpInterface:
24    DispatchStringDictAsync<Error = serde_json::Error, Poly = serde_json::Value>
25    + DispatchStringTupleAsync<Error = serde_json::Error>
26    + Send
27    + Sync
28    + 'static
29{
30}
31
32impl<T> HttpInterface for Arc<T> where T: HttpInterface + ?Sized {}
33pub fn create_http_object<T: ?Sized + HttpInterface>(x: Arc<T>) -> Arc<dyn HttpInterface> {
34    Arc::new(x) as Arc<dyn HttpInterface>
35}
36
37#[derive(Clone)]
38struct State {
39    pub registered_objects: HashMap<String, Arc<dyn HttpInterface>>,
40}
41
/// `GET /` handler: replies with a fixed plain-text hint on how to use the API.
async fn root() -> &'static str {
    const USAGE: &str =
        "This is a serde-tc JSON RPC server. Please access to /<object-name> with POST, to use the API.";
    USAGE
}
46
47async fn dispatch(
48    Path(path): Path<String>,
49    Json(args): Json<RawArg>,
50    Extension(state): Extension<Arc<State>>,
51) -> (StatusCode, Json<Value>) {
52    if let Some(object) = state.registered_objects.get(&path) {
53        match dispatch_raw(object.as_ref(), &args.method, args.params.clone()).await {
54            Ok(value) => (StatusCode::OK, Json(value)),
55            Err(err) => (
56                StatusCode::INTERNAL_SERVER_ERROR,
57                Json(json!({
58                    "error": "invalid http request",
59                    "error_message": err.to_string(),
60                    "request": args,
61                })),
62            ),
63        }
64    } else {
65        (
66            StatusCode::NOT_FOUND,
67            Json(json!({
68                "error": "object not found",
69                "obejct": &path.as_str()[1..],
70            })),
71        )
72    }
73}
74
75pub async fn run_server(port: u16, objects: HashMap<String, Arc<dyn HttpInterface>>) {
76    let app = Router::new().route("/", get(root));
77    let app = app.route("/:key", post(dispatch));
78    let app = app
79        .layer(Extension(Arc::new(State {
80            registered_objects: objects,
81        })))
82        .layer(
83            CorsLayer::new()
84                .allow_origin("*".parse::<HeaderValue>().unwrap())
85                .allow_headers([axum::http::header::CONTENT_TYPE])
86                .allow_methods([Method::POST]),
87        );
88    let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
89    axum::Server::bind(&addr)
90        .serve(app.into_make_service())
91        .await
92        .unwrap();
93}
94
95#[derive(Serialize, Deserialize, Debug)]
96#[serde(deny_unknown_fields)]
97struct RawArg {
98    method: String,
99    params: serde_json::Value,
100}
101
102async fn dispatch_raw<T>(
103    api: &T,
104    method: &str,
105    arguments: serde_json::Value,
106) -> std::result::Result<serde_json::Value, HttpError>
107where
108    T: HttpInterface + ?Sized,
109{
110    let result = if arguments.is_array() {
111        DispatchStringTupleAsync::dispatch(api, method, &arguments.to_string()).await
112    } else if arguments.is_object() {
113        DispatchStringDictAsync::dispatch(api, method, &arguments.to_string()).await
114    } else {
115        return Err(HttpError::InvalidRequest(format!(
116            "invalid argument type: {}",
117            arguments
118        )));
119    };
120
121    match result {
122        Ok(x) => Ok(serde_json::from_str(&x).unwrap()),
123        Err(Error::MethodNotFound(x)) => Err(HttpError::MethodNotFound(x)),
124        Err(x) => Err(HttpError::InvalidRequest(x.to_string())),
125    }
126}
127
128/// A RPC client. Use `123.1.2.3:123/object_name` for `addr`.
129pub struct HttpClient {
130    client: Client,
131    addr: String,
132}
133
134impl HttpClient {
135    pub fn new(addr: String, client: Client) -> Self {
136        HttpClient { client, addr }
137    }
138}
139
140#[async_trait]
141impl StubCall for HttpClient {
142    type Error = anyhow::Error;
143
144    async fn call(&self, method: &'static str, params: String) -> Result<String, Self::Error> {
145        let body = format!(
146            r#"{{"method": "{}",
147        "params": {}}}"#,
148            method, params
149        );
150        let response = self
151            .client
152            .request(Method::POST, &format!("http://{}", self.addr))
153            .header("content-type", "application/json")
154            .body(body)
155            .send()
156            .await?;
157
158        if response.status().as_u16() != 200 {
159            Err(anyhow::Error::msg(format!(
160                r#"HTTP request failed: "{}""#,
161                response.text().await?
162            )))
163        } else {
164            Ok(response.text().await?)
165        }
166    }
167}