wit_ai_rs/
client.rs

1//! Contains a client struct for interacting with the wit.ai API
2
3use crate::errors::{Error, ErrorResponse};
4use reqwest::{header::ACCEPT, Method, StatusCode};
5use serde::{de::DeserializeOwned, Serialize};
6
/// Base URL of the wit.ai API that [`WitClient::new`] starts from; it can be replaced
/// afterwards with [`WitClient::set_api_host`].
const DEFAULT_API_HOST: &str = "https://api.wit.ai";
8
9/// The main struct for interacting with the Wit API
10#[derive(Debug, Clone)]
11pub struct WitClient {
12    pub(crate) api_host: String,
13    version: String,
14    pub(crate) auth_token: String,
15    // reqwest stores the client in an `Arc` internally, so it can be safely cloned
16    pub(crate) reqwest_client: reqwest::Client,
17}
18
19impl WitClient {
20    /// Create a new WitClient with the given `auth_token` and `version` and the default
21    /// API host. `version` is a date string of the form yyyymmdd (ex. 20231231)
22    ///
23    /// Example:
24    /// ```rust
25    /// # use wit_ai_rs::client::WitClient;
26    /// let wit_client = WitClient::new("TOKEN".to_string(), "20240215".to_string());
27    /// ```
28    pub fn new(auth_token: String, version: String) -> Self {
29        let api_host = String::from(DEFAULT_API_HOST);
30
31        let reqwest_client = reqwest::Client::new();
32
33        Self {
34            api_host,
35            version,
36            auth_token,
37            reqwest_client,
38        }
39    }
40
41    /// Changes the API host--only recommended for use while testing
42    ///
43    /// Example:
44    /// ```rust
45    /// # use wit_ai_rs::client::WitClient;
46    /// let wit_client = WitClient::new("TOKEN".to_string(), "20240215".to_string())
47    ///     .set_api_host("https://host.com".to_string());
48    /// ```
49    pub fn set_api_host(self, api_host: String) -> Self {
50        Self {
51            api_host,
52            auth_token: self.auth_token,
53            version: self.version,
54            reqwest_client: self.reqwest_client.clone(),
55        }
56    }
57
58    pub(crate) async fn make_request<T: DeserializeOwned>(
59        &self,
60        method: Method,
61        endpoint: &str,
62        url_params: Vec<(String, String)>,
63        body: Option<impl Serialize>,
64    ) -> Result<T, Error> {
65        let url = format!("{}{endpoint}?v={}", self.api_host, self.version);
66
67        let mut request = match method {
68            Method::GET => self.reqwest_client.get(url),
69            Method::POST => self.reqwest_client.post(url),
70            Method::DELETE => self.reqwest_client.delete(url),
71            Method::PUT => self.reqwest_client.put(url),
72            _ => panic!("invalid method passed to internal `make_request` method"),
73        };
74
75        request = request.query(&url_params);
76
77        request = match body {
78            // .json() internally sets the content type header to application/json
79            Some(body) => request.json(&body),
80            None => request,
81        };
82
83        let response = request
84            .bearer_auth(&self.auth_token)
85            .header(ACCEPT, format!("application/vnd.wit.{}+json", self.version))
86            .send()
87            .await?;
88
89        let data = match response.status() {
90            StatusCode::OK => Ok(response.json::<T>().await?),
91            _ => Err(response.json::<ErrorResponse>().await?),
92        }?;
93
94        Ok(data)
95    }
96
97    /// Getter for `WitClient` version
98    pub fn get_version(&self) -> &str {
99        &self.version
100    }
101}