ureq_proto/client/
sendreq.rs

1use std::io::Write;
2
3use base64::prelude::BASE64_STANDARD;
4use base64::Engine;
5use http::uri::Scheme;
6use http::{header, HeaderMap, HeaderName, HeaderValue, Method, Uri, Version};
7
8use crate::client::amended::AmendedRequest;
9use crate::ext::{AuthorityExt, MethodExt, SchemeExt};
10use crate::util::Writer;
11use crate::Error;
12
13use super::state::SendRequest;
14use super::{BodyState, Call, RequestPhase, SendRequestResult};
15
16impl Call<SendRequest> {
17    /// Write the request to the buffer.
18    ///
19    /// Writes incrementally, it can be called repeatedly in situations where the output
20    /// buffer is small.
21    ///
22    /// This includes the first row, i.e. `GET / HTTP/1.1` and all headers.
23    /// The output buffer needs to be large enough for the longest row.
24    ///
25    /// Example:
26    ///
27    /// ```text
28    /// POST /bar HTTP/1.1\r\n
29    /// Host: my.server.test\r\n
30    /// User-Agent: myspecialthing\r\n
31    /// \r\n
32    /// <body data>
33    /// ```
34    ///
35    /// The buffer would need to be at least 28 bytes big, since the `User-Agent` row is
36    /// 28 bytes long.
37    ///
38    /// If the output is too small for the longest line, the result is an `OutputOverflow` error.
39    ///
40    /// The `Ok(usize)` is the number of bytes of the `output` buffer that was used.
41    pub fn write(&mut self, output: &mut [u8]) -> Result<usize, Error> {
42        self.maybe_analyze_request()?;
43
44        let mut w = Writer::new(output);
45        try_write_prelude(&self.inner.request, &mut self.inner.state, &mut w)?;
46
47        let output_used = w.len();
48
49        Ok(output_used)
50    }
51
52    /// The configured method.
53    pub fn method(&self) -> &Method {
54        self.inner.request.method()
55    }
56
57    /// The uri being requested.
58    pub fn uri(&self) -> &Uri {
59        self.inner.request.uri()
60    }
61
62    /// Version of the request.
63    ///
64    /// This can only be 1.0 or 1.1.
65    pub fn version(&self) -> Version {
66        self.inner.request.version()
67    }
68
69    /// The configured headers.
70    pub fn headers_map(&mut self) -> Result<HeaderMap, Error> {
71        self.maybe_analyze_request()?;
72        let mut map = HeaderMap::new();
73        for (k, v) in self.inner.request.headers() {
74            map.insert(k, v.clone());
75        }
76        Ok(map)
77    }
78
79    /// Check whether the entire request has been sent.
80    ///
81    /// This is useful when the output buffer is small and we need to repeatedly
82    /// call `write()` to send the entire request.
83    pub fn can_proceed(&self) -> bool {
84        !self.inner.state.phase.is_prelude()
85    }
86
87    /// Attempt to proceed from this state to the next.
88    ///
89    /// Returns `None` if the entire request has not been sent. It is guaranteed that if
90    /// `can_proceed()` returns `true`, this will return `Some`.
91    pub fn proceed(mut self) -> Result<Option<SendRequestResult>, Error> {
92        if !self.can_proceed() {
93            return Ok(None);
94        }
95
96        if self.inner.state.writer.has_body() {
97            if self.inner.await_100_continue {
98                Ok(Some(SendRequestResult::Await100(Call::wrap(self.inner))))
99            } else {
100                // TODO(martin): is this needed?
101                self.maybe_analyze_request()?;
102                let call = Call::wrap(self.inner);
103                Ok(Some(SendRequestResult::SendBody(call)))
104            }
105        } else {
106            let call = Call::wrap(self.inner);
107            Ok(Some(SendRequestResult::RecvResponse(call)))
108        }
109    }
110
111    pub(crate) fn maybe_analyze_request(&mut self) -> Result<(), Error> {
112        if self.inner.analyzed {
113            return Ok(());
114        }
115
116        let info = self.inner.request.analyze(
117            self.inner.state.writer,
118            self.inner.state.allow_non_standard_methods,
119        )?;
120
121        let method = self.inner.request.method();
122        let send_body = (method.allow_request_body() || self.inner.force_send_body)
123            && info.body_mode.has_body();
124
125        if !send_body && info.body_mode.has_body() {
126            return Err(Error::BodyNotAllowed);
127        }
128
129        if !info.req_host_header && method != Method::CONNECT {
130            if let Some(host) = self.inner.request.uri().host() {
131                // User did not set a host header, and there is one in uri, we set that.
132                // We need an owned value to set the host header.
133
134                // This might append the port if it differs from the scheme default.
135                let value = maybe_with_port(host, self.inner.request.uri())?;
136
137                self.inner.request.set_header(header::HOST, value)?;
138            }
139        }
140
141        if let Some(auth) = self.inner.request.uri().authority() {
142            if self.inner.request.method() != Method::CONNECT {
143                if auth.userinfo().is_some() && !info.req_auth_header {
144                    let user = auth.username().unwrap_or_default();
145                    let pass = auth.password().unwrap_or_default();
146                    let creds = BASE64_STANDARD.encode(format!("{}:{}", user, pass));
147                    let auth = format!("Basic {}", creds);
148                    self.inner.request.set_header(header::AUTHORIZATION, auth)?;
149                }
150            } else if !info.req_host_header {
151                self.inner
152                    .request
153                    .set_header(header::HOST, auth.clone().as_str())?;
154            }
155        }
156
157        if !info.req_body_header && info.body_mode.has_body() {
158            // User did not set a body header, we set one.
159            let header = info.body_mode.body_header();
160            self.inner.request.set_header(header.0, header.1)?;
161        }
162
163        self.inner.state.writer = info.body_mode;
164
165        self.inner.analyzed = true;
166        Ok(())
167    }
168}
169
170fn maybe_with_port(host: &str, uri: &Uri) -> Result<HeaderValue, Error> {
171    fn from_str(src: &str) -> Result<HeaderValue, Error> {
172        HeaderValue::from_str(src).map_err(|e| Error::BadHeader(e.to_string()))
173    }
174
175    if let Some(port) = uri.port() {
176        let scheme = uri.scheme().unwrap_or(&Scheme::HTTP);
177        if let Some(scheme_default) = scheme.default_port() {
178            if port != scheme_default {
179                // This allocates, so we only do it if we absolutely have to.
180                let host_port = format!("{}:{}", host, port);
181                return from_str(&host_port);
182            }
183        }
184    }
185
186    // Fall back on no port (without allocating).
187    from_str(host)
188}
189
190fn try_write_prelude(
191    request: &AmendedRequest,
192    state: &mut BodyState,
193    w: &mut Writer,
194) -> Result<(), Error> {
195    let at_start = w.len();
196
197    loop {
198        if try_write_prelude_part(request, state, w) {
199            continue;
200        }
201
202        let written = w.len() - at_start;
203
204        if written > 0 || state.phase.is_body() {
205            return Ok(());
206        } else {
207            return Err(Error::OutputOverflow);
208        }
209    }
210}
211
212fn try_write_prelude_part(request: &AmendedRequest, state: &mut BodyState, w: &mut Writer) -> bool {
213    match &mut state.phase {
214        RequestPhase::Line => {
215            let success = do_write_send_line(request.prelude(), w);
216            if success {
217                state.phase = RequestPhase::Headers(0);
218            }
219            success
220        }
221
222        RequestPhase::Headers(index) => {
223            let header_count = request.headers_len();
224            let all = request.headers();
225            let skipped = all.skip(*index);
226
227            if header_count > 0 {
228                do_write_headers(skipped, index, header_count - 1, w);
229            }
230
231            if *index == header_count {
232                state.phase = RequestPhase::Body;
233            }
234            false
235        }
236
237        // We're past the header.
238        _ => false,
239    }
240}
241
242fn do_write_send_line(line: (&Method, &str, Version), w: &mut Writer) -> bool {
243    // Ensure origin-form request-target starts with "/" when only a query is present
244    // per RFC 9112 ยง3.2.1 (@https://datatracker.ietf.org/doc/html/rfc9112#section-3.2.1).
245    let need_initial_slash = line.1.starts_with('?');
246    let slash = if need_initial_slash { "/" } else { "" };
247
248    w.try_write(|w| write!(w, "{} {}{} {:?}\r\n", line.0, slash, line.1, line.2))
249}
250
251fn do_write_headers<'a, I>(headers: I, index: &mut usize, last_index: usize, w: &mut Writer)
252where
253    I: Iterator<Item = (&'a HeaderName, &'a HeaderValue)>,
254{
255    for h in headers {
256        let success = w.try_write(|w| {
257            write!(w, "{}: ", h.0)?;
258            w.write_all(h.1.as_bytes())?;
259            write!(w, "\r\n")?;
260            if *index == last_index {
261                write!(w, "\r\n")?;
262            }
263            Ok(())
264        });
265
266        if success {
267            *index += 1;
268        } else {
269            break;
270        }
271    }
272}