1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
//!
//! A module for manipulating additional HTTP headers, especially when
//! communicating with proxy servers.
////////////////////////////////////////////////////////////////////////////////
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
/// Name of the environment variable that carries a JSON-serialized
/// [`AdditionalHeaders`] between the library user and Rust-AI internals.
const ADDITIONAL_HEADERS_VAR: &str = "RUST_AI_ADDITIONAL_HEADERS";
/// A borrowed HTTP header name/value pair, as accepted by
/// [`AdditionalHeaders::set_header`].
type RawHeader<'a> = (&'a str, &'a str);
/// An owned HTTP header name/value pair, as stored inside
/// [`AdditionalHeaders`].
type Header = (String, String);
/// Runtime-level HTTP headers to attach to requests sent to endpoints such as
/// OpenAI. This is the only interface for setting and reading those headers.
///
/// # Usage
/// 1. Call `AdditionalHeaders::default()` to create an empty instance.
/// 2. Call `set_header()` to add the required headers.
/// 3. Depending on the caller:
///    - If you are using Rust-AI as a library, call `to_var()` to store the
///      headers in the related environment variable.
///    - If you are calling from Rust-AI itself, call `provide()` to turn the
///      current instance into a [`HeaderMap`] (`from_var()` loads what
///      `to_var()` stored).
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct AdditionalHeaders {
    /// Header name/value pairs in insertion order; repeated names are kept.
    headers: Vec<Header>,
}
impl Default for AdditionalHeaders {
    /// Returns an [`AdditionalHeaders`] that holds no headers at all.
    fn default() -> Self {
        let headers = Vec::new();
        Self { headers }
    }
}
impl AdditionalHeaders {
    /// Build an instance from the JSON stored in the
    /// `RUST_AI_ADDITIONAL_HEADERS` environment variable.
    ///
    /// This is best-effort: an empty instance (see [`Default`]) is returned
    /// when the variable is unset, is not valid Unicode, or does not contain
    /// JSON that deserializes into [`AdditionalHeaders`]. No error is reported
    /// in any of these cases. Headers loaded this way do not pass through
    /// `set_header()`.
    pub fn from_var() -> Self {
        std::env::var(ADDITIONAL_HEADERS_VAR)
            .ok()
            .and_then(|raw| serde_json::from_str::<Self>(&raw).ok())
            .unwrap_or_default()
    }
}
impl AdditionalHeaders {
/// Set a new header pair. Header name and values should never include NUL
/// characters (`\0`).
pub fn set_header<'a>(&mut self, header: RawHeader<'a>) {
let header_name = header.0.to_string();
let header_value = header.1.to_string();
if header_name.contains("\0") || header_value.contains("\0") {
panic!("`\0` cannot present in any field of the header");
}
self.headers.push((header_name, header_value));
}
/// Turn the contained headers into [`HeaderMap`]. Will not consume current
/// instance.
pub fn provide(&self) -> HeaderMap {
let mut hm = HeaderMap::new();
self.headers.iter().for_each(|h| {
hm.append(
HeaderName::from_lowercase(h.0.to_lowercase().as_bytes()).unwrap(),
HeaderValue::from_str(&h.1).unwrap(),
);
});
hm
}
/// Try to serialize current instance into string form and set to specific
/// environment variable.
pub fn to_var(&self) {
let serialized = serde_json::to_string(self).unwrap();
std::env::set_var(ADDITIONAL_HEADERS_VAR, serialized);
}
}