Skip to main content

scion_sdk_reqwest_connect_rpc/
client.rs

1// Copyright 2025 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! Connect-RPC client library using reqwest.
15
16use std::{borrow::Cow, sync::Arc, time::Duration};
17
18use anyhow::Context as _;
19use bytes::Bytes;
20use reqwest::header::{self, HeaderMap, HeaderValue};
21use thiserror::Error;
22use tracing::Instrument;
23
24use crate::{
25    error::CrpcError,
26    token_source::{TokenSource, TokenSourceError},
27};
28
29/// Connect RPC client error.
30#[derive(Debug, Error)]
31pub enum CrpcClientError {
32    /// Error that occurs when there is a connection issue.
33    #[error("connection error {context}: {source:#?}")]
34    ConnectionError {
35        /// Additional context about the connection error.
36        context: Cow<'static, str>,
37        /// The underlying source error.
38        source: Box<dyn std::error::Error + Send + Sync + 'static>,
39    },
40    /// Error returned by the server.
41    #[error("server returned an error: {0:#?}")]
42    CrpcError(CrpcError),
43    /// Error decoding the response body.
44    #[error("failed to decode response body: {context}: {source:#?}")]
45    DecodeError {
46        /// Additional context about the decoding error.
47        context: Cow<'static, str>,
48        /// The underlying source error.
49        source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
50        /// The response body, if available.
51        body: Option<Bytes>,
52    },
53    /// Error retrieving a token from the token source.
54    #[error("failed to retrieve token: {0}")]
55    TokenSourceError(#[from] TokenSourceError),
56}
57
/// Value sent in the `Content-Type` header of every request.
const APPLICATION_PROTO: &str = "application/proto";
59
60/// A Connect-RPC client.
61pub struct CrpcClient {
62    http_client: reqwest::Client,
63    base_url: url::Url,
64    token_source: Option<Arc<dyn TokenSource>>,
65    user_agent: HeaderValue,
66}
67
68impl CrpcClient {
69    /// Creates a new [`CrpcClient`] for the given base URL.
70    pub fn new(base_url: &url::Url) -> anyhow::Result<Self> {
71        let http_client = reqwest::ClientBuilder::new()
72            .timeout(Duration::from_secs(5))
73            .build()
74            .context("error creating HTTP client")?;
75
76        Self::new_with_client(base_url, http_client)
77    }
78
79    /// Creates a new [`CrpcClient`] for the given base URL and explicit [`reqwest::Client`].
80    pub fn new_with_client(
81        base_url: &url::Url,
82        http_client: reqwest::Client,
83    ) -> anyhow::Result<Self> {
84        let user_agent =
85            HeaderValue::from_str(&format!("reqwest-crpc {}", env!("CARGO_PKG_VERSION")))
86                .context("error creating user agent header")?;
87
88        Ok(CrpcClient {
89            http_client,
90            base_url: base_url.clone(),
91            token_source: None,
92            user_agent,
93        })
94    }
95
96    /// Uses given token source for authentication of all following requests.
97    pub fn use_token_source(&mut self, token_source: Arc<dyn TokenSource>) -> &mut Self {
98        self.token_source = Some(token_source);
99        self
100    }
101
102    /// Sets the user agent header for all following requests.
103    pub fn use_user_agent(&mut self, user_agent: &str) -> anyhow::Result<&mut Self> {
104        self.user_agent = HeaderValue::from_str(user_agent)
105            .with_context(|| format!("error creating user agent header from {user_agent}"))?;
106        Ok(self)
107    }
108
109    /// Unary RPC request.
110    pub async fn unary_request<Req, Res>(
111        &self,
112        path: &str,
113        req: Req,
114    ) -> Result<Res, CrpcClientError>
115    where
116        Req: prost::Message + Default,
117        Res: prost::Message + Default,
118    {
119        self.do_unary_request(path, req)
120            .instrument(tracing::info_span!("request", %path, id = rand::random::<u16>()))
121            .await
122    }
123
124    /// Sends a unary request to the endhost API.
125    async fn do_unary_request<Req, Res>(&self, path: &str, req: Req) -> Result<Res, CrpcClientError>
126    where
127        Req: prost::Message + Default,
128        Res: prost::Message + Default,
129    {
130        let url = self.base_url.join(path).map_err(|e| {
131            CrpcClientError::ConnectionError {
132                context: "error joining base URL and path".into(),
133                source: e.into(),
134            }
135        })?;
136
137        let mut headers = HeaderMap::with_capacity(3);
138        headers.insert(
139            header::CONTENT_TYPE,
140            header::HeaderValue::from_static(APPLICATION_PROTO),
141        );
142        headers.insert(header::USER_AGENT, self.user_agent.clone());
143        if let Some(token_source) = &self.token_source {
144            let token = token_source.get_token().await?;
145            let token_header = header::HeaderValue::from_str(&token_source.format_header(token))
146                .map_err(|e| {
147                    CrpcClientError::TokenSourceError(
148                        format!("error formatting token as header value: {e:?}").into(),
149                    )
150                })?;
151
152            headers.insert(header::AUTHORIZATION, token_header);
153        }
154
155        tracing::debug!("Sending request");
156
157        let body = req.encode_to_vec();
158        let response = self
159            .http_client
160            .post(url)
161            .body(reqwest::Body::from(body))
162            .headers(headers)
163            .send()
164            .await
165            .map_err(|e| {
166                CrpcClientError::ConnectionError {
167                    context: "error sending request".into(),
168                    source: e.into(),
169                }
170            })?;
171
172        tracing::debug!(status=%response.status(), body_len=%response.content_length().unwrap_or(0), "Got response");
173
174        let status = response.status();
175        if !status.is_success() {
176            let response_raw = response
177                .text()
178                .await
179                .unwrap_or_else(|_| "<failed to read body>".to_string());
180
181            // Try to parse the body as a CrpcError, otherwise create a generic one.
182            match serde_json::from_str::<CrpcError>(&response_raw) {
183                Ok(crpc_err) => {
184                    return Err(CrpcClientError::CrpcError(crpc_err));
185                }
186                Err(_) => {
187                    return Err(CrpcClientError::CrpcError(CrpcError::new(
188                        status.into(),
189                        response_raw,
190                    )));
191                }
192            }
193        }
194
195        let body = response.bytes().await.map_err(|e| {
196            CrpcClientError::DecodeError {
197                context: "error reading response body".into(),
198                source: Some(e.into()),
199                body: None,
200            }
201        })?;
202
203        Res::decode(&body[..]).map_err(|e| {
204            CrpcClientError::DecodeError {
205                context: "error decoding response body".into(),
206                source: Some(e.into()),
207                body: Some(body.clone()),
208            }
209        })
210    }
211}