scion_sdk_reqwest_connect_rpc/
client.rs1use std::{borrow::Cow, sync::Arc, time::Duration};
17
18use anyhow::Context as _;
19use bytes::Bytes;
20use reqwest::header::{self, HeaderMap, HeaderValue};
21use thiserror::Error;
22use tracing::Instrument;
23
24use crate::{
25 error::CrpcError,
26 token_source::{TokenSource, TokenSourceError},
27};
28
29#[derive(Debug, Error)]
31pub enum CrpcClientError {
32 #[error("connection error {context}: {source:#?}")]
34 ConnectionError {
35 context: Cow<'static, str>,
37 source: Box<dyn std::error::Error + Send + Sync + 'static>,
39 },
40 #[error("server returned an error: {0:#?}")]
42 CrpcError(CrpcError),
43 #[error("failed to decode response body: {context}: {source:#?}")]
45 DecodeError {
46 context: Cow<'static, str>,
48 source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
50 body: Option<Bytes>,
52 },
53 #[error("failed to retrieve token: {0}")]
55 TokenSourceError(#[from] TokenSourceError),
56}
57
58const APPLICATION_PROTO: &str = "application/proto";
59
60pub struct CrpcClient {
62 http_client: reqwest::Client,
63 base_url: url::Url,
64 token_source: Option<Arc<dyn TokenSource>>,
65 user_agent: HeaderValue,
66}
67
68impl CrpcClient {
69 pub fn new(base_url: &url::Url) -> anyhow::Result<Self> {
71 let http_client = reqwest::ClientBuilder::new()
72 .timeout(Duration::from_secs(5))
73 .build()
74 .context("error creating HTTP client")?;
75
76 Self::new_with_client(base_url, http_client)
77 }
78
79 pub fn new_with_client(
81 base_url: &url::Url,
82 http_client: reqwest::Client,
83 ) -> anyhow::Result<Self> {
84 let user_agent =
85 HeaderValue::from_str(&format!("reqwest-crpc {}", env!("CARGO_PKG_VERSION")))
86 .context("error creating user agent header")?;
87
88 Ok(CrpcClient {
89 http_client,
90 base_url: base_url.clone(),
91 token_source: None,
92 user_agent,
93 })
94 }
95
96 pub fn use_token_source(&mut self, token_source: Arc<dyn TokenSource>) -> &mut Self {
98 self.token_source = Some(token_source);
99 self
100 }
101
102 pub fn use_user_agent(&mut self, user_agent: &str) -> anyhow::Result<&mut Self> {
104 self.user_agent = HeaderValue::from_str(user_agent)
105 .with_context(|| format!("error creating user agent header from {user_agent}"))?;
106 Ok(self)
107 }
108
109 pub async fn unary_request<Req, Res>(
111 &self,
112 path: &str,
113 req: Req,
114 ) -> Result<Res, CrpcClientError>
115 where
116 Req: prost::Message + Default,
117 Res: prost::Message + Default,
118 {
119 self.do_unary_request(path, req)
120 .instrument(tracing::info_span!("request", %path, id = rand::random::<u16>()))
121 .await
122 }
123
124 async fn do_unary_request<Req, Res>(&self, path: &str, req: Req) -> Result<Res, CrpcClientError>
126 where
127 Req: prost::Message + Default,
128 Res: prost::Message + Default,
129 {
130 let url = self.base_url.join(path).map_err(|e| {
131 CrpcClientError::ConnectionError {
132 context: "error joining base URL and path".into(),
133 source: e.into(),
134 }
135 })?;
136
137 let mut headers = HeaderMap::with_capacity(3);
138 headers.insert(
139 header::CONTENT_TYPE,
140 header::HeaderValue::from_static(APPLICATION_PROTO),
141 );
142 headers.insert(header::USER_AGENT, self.user_agent.clone());
143 if let Some(token_source) = &self.token_source {
144 let token = token_source.get_token().await?;
145 let token_header = header::HeaderValue::from_str(&token_source.format_header(token))
146 .map_err(|e| {
147 CrpcClientError::TokenSourceError(
148 format!("error formatting token as header value: {e:?}").into(),
149 )
150 })?;
151
152 headers.insert(header::AUTHORIZATION, token_header);
153 }
154
155 tracing::debug!("Sending request");
156
157 let body = req.encode_to_vec();
158 let response = self
159 .http_client
160 .post(url)
161 .body(reqwest::Body::from(body))
162 .headers(headers)
163 .send()
164 .await
165 .map_err(|e| {
166 CrpcClientError::ConnectionError {
167 context: "error sending request".into(),
168 source: e.into(),
169 }
170 })?;
171
172 tracing::debug!(status=%response.status(), body_len=%response.content_length().unwrap_or(0), "Got response");
173
174 let status = response.status();
175 if !status.is_success() {
176 let response_raw = response
177 .text()
178 .await
179 .unwrap_or_else(|_| "<failed to read body>".to_string());
180
181 match serde_json::from_str::<CrpcError>(&response_raw) {
183 Ok(crpc_err) => {
184 return Err(CrpcClientError::CrpcError(crpc_err));
185 }
186 Err(_) => {
187 return Err(CrpcClientError::CrpcError(CrpcError::new(
188 status.into(),
189 response_raw,
190 )));
191 }
192 }
193 }
194
195 let body = response.bytes().await.map_err(|e| {
196 CrpcClientError::DecodeError {
197 context: "error reading response body".into(),
198 source: Some(e.into()),
199 body: None,
200 }
201 })?;
202
203 Res::decode(&body[..]).map_err(|e| {
204 CrpcClientError::DecodeError {
205 context: "error decoding response body".into(),
206 source: Some(e.into()),
207 body: Some(body.clone()),
208 }
209 })
210 }
211}