Skip to main content

rustfs_kafka/network/
connection.rs

1#[cfg(feature = "security")]
2use std::collections::HashMap;
3use std::fmt;
4use std::io::{Read, Write};
5use std::net::{Shutdown, TcpStream, ToSocketAddrs};
6use std::time::Duration;
7use tracing::debug;
8
9#[cfg(feature = "security")]
10use base64::Engine as _;
11#[cfg(feature = "security")]
12use base64::engine::general_purpose::STANDARD as BASE64;
13#[cfg(feature = "security")]
14use bytes::{Bytes, BytesMut};
15#[cfg(feature = "security")]
16use hmac::{Hmac, Mac};
17#[cfg(feature = "security")]
18use kafka_protocol::messages::{
19    ApiKey, RequestHeader, ResponseHeader, SaslAuthenticateRequest, SaslAuthenticateResponse,
20    SaslHandshakeRequest, SaslHandshakeResponse,
21};
22#[cfg(feature = "security")]
23use kafka_protocol::protocol::{Decodable, Encodable, HeaderVersion, StrBytes};
24#[cfg(feature = "security")]
25use pbkdf2::pbkdf2_hmac;
26#[cfg(feature = "security")]
27use rand::distr::{Alphanumeric, SampleString};
28#[cfg(feature = "security")]
29use sha2::{Digest, Sha256, Sha512};
30
31use crate::error::Result;
32#[cfg(feature = "security")]
33use crate::error::{Error, KafkaCode, ProtocolError};
34#[cfg(feature = "security")]
35use crate::tls::{RustlsConnector, TlsConfig, TlsStream};
36
37// --------------------------------------------------------------------
38
/// Security relevant configuration options for `KafkaClient`.
///
/// When a `SecurityConfig` is passed to `KafkaConnection::new`, the TCP stream is
/// always wrapped in TLS using `tls_config`. SASL authentication is additionally
/// performed only when `sasl_config` is set.
#[cfg(feature = "security")]
#[derive(Clone)]
pub struct SecurityConfig {
    /// TLS settings (hostname verification, CA and client certificate paths).
    pub(crate) tls_config: TlsConfig,
    /// Optional SASL settings; `None` skips SASL authentication entirely.
    pub(crate) sasl_config: Option<SaslConfig>,
}
46
/// SASL configuration options for `KafkaClient`.
///
/// Holds the mechanism name (matched case-insensitively when authenticating)
/// together with the credentials. The `Debug` output redacts the password, so a
/// configuration can be logged without leaking the secret.
#[cfg(feature = "security")]
#[derive(Clone)]
pub struct SaslConfig {
    /// Mechanism name, e.g. `PLAIN`, `SCRAM-SHA-256` or `SCRAM-SHA-512`.
    pub(crate) mechanism: String,
    /// Authentication identity sent to the broker.
    pub(crate) username: String,
    /// Secret credential; never included in `Debug` output.
    pub(crate) password: String,
}

#[cfg(feature = "security")]
impl fmt::Debug for SaslConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A derived `Debug` would print the password verbatim into any `{:?}` log line.
        f.debug_struct("SaslConfig")
            .field("mechanism", &self.mechanism)
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}
55
#[cfg(feature = "security")]
impl SaslConfig {
    /// Builds a SASL configuration from an explicit mechanism name and credentials.
    #[must_use]
    pub fn new(mechanism: String, username: String, password: String) -> Self {
        Self {
            mechanism,
            username,
            password,
        }
    }

    /// Builds a configuration for the `PLAIN` mechanism.
    #[must_use]
    pub fn plain(username: String, password: String) -> Self {
        Self {
            mechanism: String::from("PLAIN"),
            username,
            password,
        }
    }

    /// The configured SASL mechanism name, exactly as supplied.
    #[must_use]
    pub fn mechanism(&self) -> &str {
        self.mechanism.as_str()
    }

    /// The configured SASL username.
    #[must_use]
    pub fn username(&self) -> &str {
        self.username.as_str()
    }

    /// The configured SASL password.
    #[must_use]
    pub fn password(&self) -> &str {
        self.password.as_str()
    }
}
92
#[cfg(feature = "security")]
impl SecurityConfig {
    /// Builds a configuration with default TLS settings and SASL disabled.
    #[must_use]
    pub fn new() -> Self {
        Self::from_tls_config(TlsConfig::new())
    }

    /// Wraps an existing TLS configuration; SASL stays disabled.
    #[must_use]
    pub fn from_tls_config(tls_config: TlsConfig) -> SecurityConfig {
        Self {
            tls_config,
            sasl_config: None,
        }
    }

    /// Sets the `verify_hostname` flag of the TLS configuration.
    ///
    /// The flag is stored here and presumably consulted by the TLS connector when
    /// the session is established.
    #[must_use]
    pub fn with_hostname_verification(self, verify_hostname: bool) -> SecurityConfig {
        let mut cfg = self;
        cfg.tls_config.verify_hostname = verify_hostname;
        cfg
    }

    /// Uses the CA certificate file at `path` for broker verification.
    #[must_use]
    pub fn with_ca_cert(self, path: String) -> SecurityConfig {
        let mut cfg = self;
        cfg.tls_config.ca_cert_path = Some(path);
        cfg
    }

    /// Uses the given client certificate and private key files (mutual TLS).
    #[must_use]
    pub fn with_client_cert(self, cert_path: String, key_path: String) -> SecurityConfig {
        let mut cfg = self;
        cfg.tls_config.client_cert_path = Some(cert_path);
        cfg.tls_config.client_key_path = Some(key_path);
        cfg
    }

    /// Enables SASL authentication with the given configuration.
    #[must_use]
    pub fn with_sasl(self, sasl_config: SaslConfig) -> SecurityConfig {
        Self {
            sasl_config: Some(sasl_config),
            ..self
        }
    }

    /// Enables SASL/PLAIN authentication with the given credentials.
    #[must_use]
    pub fn with_sasl_plain(self, username: String, password: String) -> SecurityConfig {
        self.with_sasl(SaslConfig::plain(username, password))
    }

    /// Borrows the TLS part of this configuration.
    #[must_use]
    pub fn tls_config(&self) -> &TlsConfig {
        &self.tls_config
    }

    /// Borrows the SASL part of this configuration, if any.
    #[must_use]
    pub fn sasl_config(&self) -> Option<&SaslConfig> {
        Option::as_ref(&self.sasl_config)
    }
}
161
#[cfg(feature = "security")]
impl Default for SecurityConfig {
    /// Same as [`SecurityConfig::new`]: default TLS settings, no SASL.
    fn default() -> Self {
        SecurityConfig::new()
    }
}
168
#[cfg(feature = "security")]
impl fmt::Debug for SecurityConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only the hostname-verification flag is printed; paths and SASL settings are not.
        let verify_hostname = self.tls_config.verify_hostname;
        write!(f, "SecurityConfig {{ verify_hostname: {verify_hostname} }}")
    }
}
179
180// --------------------------------------------------------------------
181
/// Transport used by a connection: plain TCP when the `security` feature is disabled.
#[cfg(not(feature = "security"))]
pub(crate) type KafkaStream = TcpStream;

/// Transport used by a connection when the `security` feature is enabled.
#[cfg(feature = "security")]
pub(crate) enum KafkaStream {
    /// Unencrypted TCP; chosen when no `SecurityConfig` is supplied.
    Plain(TcpStream),
    /// TLS-wrapped stream produced by `RustlsConnector`.
    Tls(Box<dyn TlsStream>),
}
190
/// Socket-level operations shared by every `KafkaStream` flavour.
///
/// Implemented either for the plain `TcpStream` alias or for the plain/TLS enum,
/// depending on the `security` feature.
pub(crate) trait StreamOps {
    /// Returns `true` when traffic on this stream is TLS-encrypted.
    fn is_secured(&self) -> bool;
    /// Applies a read timeout (`None` presumably disables it, as for `TcpStream`).
    fn set_read_timeout(&mut self, dur: Option<Duration>) -> std::io::Result<()>;
    /// Applies a write timeout (`None` presumably disables it, as for `TcpStream`).
    fn set_write_timeout(&mut self, dur: Option<Duration>) -> std::io::Result<()>;
    /// Shuts the stream down; the TLS implementation does not forward `how`.
    fn shutdown(&mut self, how: Shutdown) -> std::io::Result<()>;
}
197
198#[cfg(not(feature = "security"))]
199impl StreamOps for KafkaStream {
200    fn is_secured(&self) -> bool {
201        false
202    }
203
204    fn set_read_timeout(&mut self, dur: Option<Duration>) -> std::io::Result<()> {
205        TcpStream::set_read_timeout(self, dur)
206    }
207
208    fn set_write_timeout(&mut self, dur: Option<Duration>) -> std::io::Result<()> {
209        TcpStream::set_write_timeout(self, dur)
210    }
211
212    fn shutdown(&mut self, how: Shutdown) -> std::io::Result<()> {
213        TcpStream::shutdown(self, how)
214    }
215}
216
#[cfg(feature = "security")]
impl StreamOps for KafkaStream {
    fn is_secured(&self) -> bool {
        matches!(self, Self::Tls(_))
    }

    fn set_read_timeout(&mut self, dur: Option<Duration>) -> std::io::Result<()> {
        match self {
            Self::Plain(tcp) => tcp.set_read_timeout(dur),
            Self::Tls(tls) => tls.set_read_timeout(dur),
        }
    }

    fn set_write_timeout(&mut self, dur: Option<Duration>) -> std::io::Result<()> {
        match self {
            Self::Plain(tcp) => tcp.set_write_timeout(dur),
            Self::Tls(tls) => tls.set_write_timeout(dur),
        }
    }

    fn shutdown(&mut self, how: Shutdown) -> std::io::Result<()> {
        match self {
            Self::Plain(tcp) => tcp.shutdown(how),
            // The TLS stream's shutdown takes no direction, so `how` is not forwarded.
            Self::Tls(tls) => tls.shutdown(),
        }
    }
}
247
#[cfg(feature = "security")]
impl Read for KafkaStream {
    /// Delegates to whichever transport (TCP or TLS) this stream wraps.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        match self {
            Self::Plain(tcp) => tcp.read(buf),
            Self::Tls(tls) => tls.read(buf),
        }
    }
}
257
#[cfg(feature = "security")]
impl Write for KafkaStream {
    /// Delegates to whichever transport (TCP or TLS) this stream wraps.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        match self {
            Self::Plain(tcp) => tcp.write(buf),
            Self::Tls(tls) => tls.write(buf),
        }
    }

    /// Flushes the wrapped transport.
    fn flush(&mut self) -> std::io::Result<()> {
        match self {
            Self::Plain(tcp) => tcp.flush(),
            Self::Tls(tls) => tls.flush(),
        }
    }
}
274
275// --------------------------------------------------------------------
276
/// A TCP stream to a remote Kafka broker.
///
/// Wraps the (optionally TLS-secured) stream together with bookkeeping used for
/// diagnostics (`id`, `host`) and health tracking (`state`).
pub struct KafkaConnection {
    /// Caller-supplied identifier; shown in `Debug` output.
    id: u32,
    /// The `host:port` string this connection was opened to.
    host: String,
    /// Underlying transport.
    stream: KafkaStream,
    /// Becomes `Terminated` after an I/O failure or an explicit `shutdown`.
    state: ConnectionState,
}
284
/// Connection health state for detecting broken connections.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ConnectionState {
    /// Initial state, set when the connection is built in `from_stream`.
    Connected,
    /// Set after a failed read/write or `shutdown`; nothing in this file resets it.
    Terminated,
}
291
292impl fmt::Debug for KafkaConnection {
293    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
294        write!(
295            f,
296            "KafkaConnection {{ id: {}, secured: {}, state: {:?}, host: \"{}\" }}",
297            self.id,
298            self.stream.is_secured(),
299            self.state,
300            self.host
301        )
302    }
303}
304
305/// Configure a TCP socket with keepalive and nodelay for Kafka compatibility.
306fn configure_tcp_socket(socket: &socket2::Socket) -> std::io::Result<()> {
307    use socket2::TcpKeepalive;
308
309    let keepalive = TcpKeepalive::new()
310        .with_time(Duration::from_secs(10))
311        .with_interval(Duration::from_secs(20));
312    socket.set_tcp_keepalive(&keepalive)?;
313    socket.set_tcp_nodelay(true)?;
314    Ok(())
315}
316
/// `SaslHandshake` API version used by the synchronous SASL path in this file.
#[cfg(feature = "security")]
const API_VERSION_SASL_HANDSHAKE: i16 = 1;
/// `SaslAuthenticate` API version used by the synchronous SASL path in this file.
#[cfg(feature = "security")]
const API_VERSION_SASL_AUTHENTICATE: i16 = 1;
/// Client id placed in every request header built by the SASL helpers below.
#[cfg(feature = "security")]
const DEFAULT_CLIENT_ID: &str = "rustfs-kafka";
323
/// Hash function selecting the SCRAM variant (`SCRAM-SHA-256` or `SCRAM-SHA-512`).
#[cfg(feature = "security")]
#[derive(Clone, Copy)]
enum ScramAlgorithm {
    /// SHA-256; dispatched to `compute_scram_sha256` (32-byte salted password).
    Sha256,
    /// SHA-512; dispatched to `compute_scram_sha512` (64-byte salted password).
    Sha512,
}
330
/// Runs the SASL handshake and authentication exchange on a freshly opened stream.
///
/// Supported mechanisms (matched case-insensitively) are `PLAIN`, `SCRAM-SHA-256`
/// and `SCRAM-SHA-512`. The handshake uses correlation id 1 and the following
/// `SaslAuthenticate` round(s) continue from 2.
///
/// # Errors
///
/// - `Error::Config` if the configured mechanism is not supported by this client.
///   This is detected before any bytes are sent to the broker.
/// - The Kafka error mapped from a non-zero handshake error code.
/// - `Error::Kafka(KafkaCode::UnsupportedSaslMechanism)` if the broker advertises
///   a non-empty mechanism list that lacks the configured mechanism.
/// - Any error from the mechanism-specific exchange or from stream I/O.
#[cfg(feature = "security")]
fn perform_sasl_authentication(stream: &mut KafkaStream, sasl: &SaslConfig) -> Result<()> {
    /// Mechanism-specific authentication flow selected from the configuration.
    enum Flow {
        Plain,
        Scram(ScramAlgorithm),
    }

    let mechanism = sasl.mechanism();

    // Resolve the mechanism before touching the network, so an unsupported
    // configuration fails fast instead of after a wasted handshake round-trip.
    let flow = if mechanism.eq_ignore_ascii_case("PLAIN") {
        Flow::Plain
    } else if mechanism.eq_ignore_ascii_case("SCRAM-SHA-256") {
        Flow::Scram(ScramAlgorithm::Sha256)
    } else if mechanism.eq_ignore_ascii_case("SCRAM-SHA-512") {
        Flow::Scram(ScramAlgorithm::Sha512)
    } else {
        return Err(Error::Config(format!(
            "unsupported SASL mechanism for sync path: {mechanism}"
        )));
    };

    let correlation_id = 1;
    let handshake_header = RequestHeader::default()
        .with_client_id(Some(StrBytes::from_string(DEFAULT_CLIENT_ID.to_owned())))
        .with_request_api_key(ApiKey::SaslHandshake as i16)
        .with_request_api_version(API_VERSION_SASL_HANDSHAKE)
        .with_correlation_id(correlation_id);
    let handshake_request = SaslHandshakeRequest::default()
        .with_mechanism(StrBytes::from_string(mechanism.to_owned()));

    send_kp_request_on_stream(
        stream,
        &handshake_header,
        &handshake_request,
        API_VERSION_SASL_HANDSHAKE,
    )?;
    let handshake_response: SaslHandshakeResponse =
        get_kp_response_from_stream(stream, API_VERSION_SASL_HANDSHAKE)?;

    if handshake_response.error_code != 0 {
        return Err(map_kafka_code_or_unknown(handshake_response.error_code));
    }

    // An empty advertised list is tolerated; only a non-empty list that lacks
    // the configured mechanism is rejected.
    if !handshake_response.mechanisms.is_empty()
        && !handshake_response
            .mechanisms
            .iter()
            .any(|m| m.as_str().eq_ignore_ascii_case(mechanism))
    {
        return Err(Error::Kafka(KafkaCode::UnsupportedSaslMechanism));
    }

    match flow {
        Flow::Plain => perform_sasl_plain_authenticate(stream, sasl, correlation_id + 1),
        Flow::Scram(algorithm) => {
            perform_sasl_scram_authenticate(stream, sasl, algorithm, correlation_id + 1)
        }
    }
}
391
/// Sends a single SASL/PLAIN `SaslAuthenticate` request and checks the broker's reply.
///
/// # Errors
///
/// Returns the Kafka error mapped from a non-zero response error code, or any
/// encoding/I/O error from the request/response helpers.
#[cfg(feature = "security")]
fn perform_sasl_plain_authenticate(
    stream: &mut KafkaStream,
    sasl: &SaslConfig,
    correlation_id: i32,
) -> Result<()> {
    let client_id = StrBytes::from_string(DEFAULT_CLIENT_ID.to_owned());
    let header = RequestHeader::default()
        .with_client_id(Some(client_id))
        .with_request_api_key(ApiKey::SaslAuthenticate as i16)
        .with_request_api_version(API_VERSION_SASL_AUTHENTICATE)
        .with_correlation_id(correlation_id);
    let credentials = build_sasl_plain_auth_bytes(sasl);
    let request = SaslAuthenticateRequest::default().with_auth_bytes(credentials);

    send_kp_request_on_stream(stream, &header, &request, API_VERSION_SASL_AUTHENTICATE)?;
    let response: SaslAuthenticateResponse =
        get_kp_response_from_stream(stream, API_VERSION_SASL_AUTHENTICATE)?;

    match response.error_code {
        0 => Ok(()),
        code => Err(map_kafka_code_or_unknown(code)),
    }
}
421
/// Performs a SCRAM client exchange (RFC 5802 message formats) over two
/// `SaslAuthenticate` round-trips.
///
/// Steps, in order:
/// 1. Send `client-first` (`n,,n=<user>,r=<client nonce>`) using `correlation_id`.
/// 2. Parse `server-first`. Reject `e=`, a nonce that does not extend ours, a bad
///    base64 salt and an iteration count of zero.
/// 3. Send `client-final` carrying the computed proof, using `correlation_id + 1`.
/// 4. Verify the server signature (`v=`) in `server-final`.
///
/// # Errors
///
/// - Kafka errors mapped from non-zero response error codes.
/// - `Error::codec()` for non-UTF-8 server messages.
/// - `Error::Config` for malformed or failed SCRAM messages, including a server
///   signature mismatch.
/// - Any stream I/O or codec error.
#[cfg(feature = "security")]
// Allowed because both round-trips and their validation live in this one function.
#[allow(clippy::too_many_lines)]
fn perform_sasl_scram_authenticate(
    stream: &mut KafkaStream,
    sasl: &SaslConfig,
    algorithm: ScramAlgorithm,
    correlation_id: i32,
) -> Result<()> {
    // client-first-message: GS2 header "n,," (no channel binding, no authzid)
    // followed by the bare part, which is reused later in the AuthMessage.
    let client_nonce = generate_scram_nonce();
    let user = scram_escape_username(sasl.username());
    let client_first_bare = format!("n={user},r={client_nonce}");
    let client_first = format!("n,,{client_first_bare}");

    let auth_header_1 = RequestHeader::default()
        .with_client_id(Some(StrBytes::from_string(DEFAULT_CLIENT_ID.to_owned())))
        .with_request_api_key(ApiKey::SaslAuthenticate as i16)
        .with_request_api_version(API_VERSION_SASL_AUTHENTICATE)
        .with_correlation_id(correlation_id);
    let auth_request_1 =
        SaslAuthenticateRequest::default().with_auth_bytes(Bytes::from(client_first));
    send_kp_request_on_stream(
        stream,
        &auth_header_1,
        &auth_request_1,
        API_VERSION_SASL_AUTHENTICATE,
    )?;
    let auth_response_1: SaslAuthenticateResponse =
        get_kp_response_from_stream(stream, API_VERSION_SASL_AUTHENTICATE)?;
    if auth_response_1.error_code != 0 {
        return Err(map_kafka_code_or_unknown(auth_response_1.error_code));
    }

    // server-first-message: expected attributes are r= (nonce), s= (salt) and
    // i= (iterations); e= signals a server-side error.
    let server_first =
        std::str::from_utf8(&auth_response_1.auth_bytes).map_err(|_| Error::codec())?;
    let server_first_attrs = parse_scram_attributes(server_first)?;
    if let Some(err_msg) = server_first_attrs.get("e") {
        return Err(Error::Config(format!("SCRAM server error: {err_msg}")));
    }

    // The server nonce must start with our client nonce, otherwise the reply
    // does not belong to this exchange.
    let server_nonce = server_first_attrs
        .get("r")
        .ok_or_else(|| Error::Config("SCRAM challenge missing nonce".to_owned()))?;
    if !server_nonce.starts_with(&client_nonce) {
        return Err(Error::Config(
            "SCRAM server nonce does not include client nonce prefix".to_owned(),
        ));
    }
    let salt_b64 = server_first_attrs
        .get("s")
        .ok_or_else(|| Error::Config("SCRAM challenge missing salt".to_owned()))?;
    let salt = BASE64
        .decode(salt_b64)
        .map_err(|e| Error::Config(format!("invalid SCRAM salt encoding: {e}")))?;
    let iterations = server_first_attrs
        .get("i")
        .ok_or_else(|| Error::Config("SCRAM challenge missing iterations".to_owned()))?
        .parse::<u32>()
        .map_err(|e| Error::Config(format!("invalid SCRAM iterations: {e}")))?;
    // PBKDF2 needs at least one iteration; reject zero before deriving keys.
    if iterations == 0 {
        return Err(Error::Config(
            "invalid SCRAM iterations: must be > 0".to_owned(),
        ));
    }

    // "biws" is base64("n,,"), i.e. the GS2 header echoed as channel-binding data.
    let client_final_without_proof = format!("c=biws,r={server_nonce}");
    // AuthMessage = client-first-bare "," server-first "," client-final-without-proof
    let auth_message = format!("{client_first_bare},{server_first},{client_final_without_proof}");
    let (client_proof, expected_server_signature) = compute_scram_proof_and_server_signature(
        algorithm,
        sasl.password(),
        &salt,
        iterations,
        &auth_message,
    )?;

    let client_final = format!(
        "{client_final_without_proof},p={}",
        BASE64.encode(client_proof)
    );
    // Second round-trip: next correlation id.
    let auth_header_2 = RequestHeader::default()
        .with_client_id(Some(StrBytes::from_string(DEFAULT_CLIENT_ID.to_owned())))
        .with_request_api_key(ApiKey::SaslAuthenticate as i16)
        .with_request_api_version(API_VERSION_SASL_AUTHENTICATE)
        .with_correlation_id(correlation_id + 1);
    let auth_request_2 =
        SaslAuthenticateRequest::default().with_auth_bytes(Bytes::from(client_final));
    send_kp_request_on_stream(
        stream,
        &auth_header_2,
        &auth_request_2,
        API_VERSION_SASL_AUTHENTICATE,
    )?;
    let auth_response_2: SaslAuthenticateResponse =
        get_kp_response_from_stream(stream, API_VERSION_SASL_AUTHENTICATE)?;
    if auth_response_2.error_code != 0 {
        return Err(map_kafka_code_or_unknown(auth_response_2.error_code));
    }

    // server-final-message: either e= (error) or v= (server signature).
    let server_final =
        std::str::from_utf8(&auth_response_2.auth_bytes).map_err(|_| Error::codec())?;
    let server_final_attrs = parse_scram_attributes(server_final)?;
    if let Some(err_msg) = server_final_attrs.get("e") {
        return Err(Error::Config(format!(
            "SCRAM authentication failed: {err_msg}"
        )));
    }
    let server_signature_b64 = server_final_attrs
        .get("v")
        .ok_or_else(|| Error::Config("SCRAM final message missing server signature".to_owned()))?;
    let server_signature = BASE64
        .decode(server_signature_b64)
        .map_err(|e| Error::Config(format!("invalid SCRAM server signature encoding: {e}")))?;
    // Checking v= authenticates the server: it must have derived the same
    // ServerKey from the salted password.
    if server_signature != expected_server_signature {
        return Err(Error::Config(
            "SCRAM server signature verification failed".to_owned(),
        ));
    }

    Ok(())
}
541
/// Encodes SASL/PLAIN credentials as `\0<username>\0<password>`.
///
/// The leading empty field is the authorization identity, which is left blank.
#[cfg(feature = "security")]
fn build_sasl_plain_auth_bytes(sasl: &SaslConfig) -> Bytes {
    let user = sasl.username().as_bytes();
    let pass = sasl.password().as_bytes();
    let separator: &[u8] = &[0];
    Bytes::from([separator, user, separator, pass].concat())
}
551
/// Dispatches SCRAM key derivation to the hash-specific implementation.
///
/// Returns `(client_proof, server_signature)`.
#[cfg(feature = "security")]
fn compute_scram_proof_and_server_signature(
    algorithm: ScramAlgorithm,
    password: &str,
    salt: &[u8],
    iterations: u32,
    auth_message: &str,
) -> Result<(Vec<u8>, Vec<u8>)> {
    let derive: fn(&str, &[u8], u32, &str) -> Result<(Vec<u8>, Vec<u8>)> = match algorithm {
        ScramAlgorithm::Sha256 => compute_scram_sha256,
        ScramAlgorithm::Sha512 => compute_scram_sha512,
    };
    derive(password, salt, iterations, auth_message)
}
565
/// SCRAM-SHA-256 derivation returning `(ClientProof, ServerSignature)`.
#[cfg(feature = "security")]
fn compute_scram_sha256(
    password: &str,
    salt: &[u8],
    iterations: u32,
    auth_message: &str,
) -> Result<(Vec<u8>, Vec<u8>)> {
    type HmacSha256 = Hmac<Sha256>;

    let message = auth_message.as_bytes();

    // SaltedPassword := PBKDF2-HMAC-SHA-256(password, salt, iterations)
    let mut salted = [0u8; 32];
    pbkdf2_hmac::<Sha256>(password.as_bytes(), salt, iterations, &mut salted);

    // ClientProof := ClientKey XOR HMAC(H(ClientKey), AuthMessage)
    let client_key = hmac_bytes::<HmacSha256>(&salted, b"Client Key")?;
    let stored_key = Sha256::digest(&client_key).to_vec();
    let proof = xor_bytes(&client_key, &hmac_bytes::<HmacSha256>(&stored_key, message)?)?;

    // ServerSignature := HMAC(HMAC(SaltedPassword, "Server Key"), AuthMessage)
    let server_key = hmac_bytes::<HmacSha256>(&salted, b"Server Key")?;
    let signature = hmac_bytes::<HmacSha256>(&server_key, message)?;

    Ok((proof, signature))
}
585
/// SCRAM-SHA-512 derivation returning `(ClientProof, ServerSignature)`.
#[cfg(feature = "security")]
fn compute_scram_sha512(
    password: &str,
    salt: &[u8],
    iterations: u32,
    auth_message: &str,
) -> Result<(Vec<u8>, Vec<u8>)> {
    type HmacSha512 = Hmac<Sha512>;

    let message = auth_message.as_bytes();

    // SaltedPassword := PBKDF2-HMAC-SHA-512(password, salt, iterations)
    let mut salted = [0u8; 64];
    pbkdf2_hmac::<Sha512>(password.as_bytes(), salt, iterations, &mut salted);

    // ClientProof := ClientKey XOR HMAC(H(ClientKey), AuthMessage)
    let client_key = hmac_bytes::<HmacSha512>(&salted, b"Client Key")?;
    let stored_key = Sha512::digest(&client_key).to_vec();
    let proof = xor_bytes(&client_key, &hmac_bytes::<HmacSha512>(&stored_key, message)?)?;

    // ServerSignature := HMAC(HMAC(SaltedPassword, "Server Key"), AuthMessage)
    let server_key = hmac_bytes::<HmacSha512>(&salted, b"Server Key")?;
    let signature = hmac_bytes::<HmacSha512>(&server_key, message)?;

    Ok((proof, signature))
}
605
/// Computes `MAC(key, data)` with the MAC type `M` and returns the tag bytes.
///
/// # Errors
///
/// `Error::Config` if `M` rejects the key.
#[cfg(feature = "security")]
fn hmac_bytes<M>(key: &[u8], data: &[u8]) -> Result<Vec<u8>>
where
    M: Mac + hmac::digest::KeyInit,
{
    let mut mac = <M as hmac::digest::KeyInit>::new_from_slice(key)
        .map_err(|e| Error::Config(format!("hmac init failed: {e}")))?;
    Mac::update(&mut mac, data);
    let tag = mac.finalize();
    Ok(tag.into_bytes().to_vec())
}
616
/// XORs two equal-length byte slices element-wise.
///
/// # Errors
///
/// `Error::Config` if the slices differ in length.
#[cfg(feature = "security")]
fn xor_bytes(left: &[u8], right: &[u8]) -> Result<Vec<u8>> {
    if left.len() == right.len() {
        Ok(left.iter().zip(right).map(|(&a, &b)| a ^ b).collect())
    } else {
        Err(Error::Config(
            "SCRAM proof construction failed: buffer length mismatch".to_owned(),
        ))
    }
}
626
/// Splits a SCRAM message into `key=value` attributes.
///
/// Empty segments are skipped, and a repeated key keeps its last value.
///
/// # Errors
///
/// `Error::Config` for the first non-empty segment that has no `=`.
#[cfg(feature = "security")]
fn parse_scram_attributes(input: &str) -> Result<HashMap<String, String>> {
    input
        .split(',')
        .filter(|segment| !segment.is_empty())
        .map(|segment| {
            segment
                .split_once('=')
                .map(|(key, value)| (key.to_owned(), value.to_owned()))
                .ok_or_else(|| {
                    Error::Config(format!("invalid SCRAM attribute segment: {segment}"))
                })
        })
        .collect()
}
643
/// Produces a fresh 24-character alphanumeric client nonce.
#[cfg(feature = "security")]
fn generate_scram_nonce() -> String {
    const NONCE_LEN: usize = 24;
    let mut rng = rand::rng();
    Alphanumeric.sample_string(&mut rng, NONCE_LEN)
}
648
/// Escapes a username for SCRAM: `=` becomes `=3D` and `,` becomes `=2C`.
#[cfg(feature = "security")]
fn scram_escape_username(username: &str) -> String {
    let mut escaped = String::with_capacity(username.len());
    for ch in username.chars() {
        match ch {
            '=' => escaped.push_str("=3D"),
            ',' => escaped.push_str("=2C"),
            other => escaped.push(other),
        }
    }
    escaped
}
653
/// Encodes `header` + `body` as one length-prefixed Kafka request, writes it and flushes.
///
/// The frame is a 4-byte big-endian length followed by the encoded header and body.
///
/// # Errors
///
/// `Error::Protocol(ProtocolError::Codec)` if encoding fails or the frame exceeds
/// `i32::MAX` bytes; an I/O error if the write or flush fails.
#[cfg(feature = "security")]
fn send_kp_request_on_stream<T>(
    stream: &mut KafkaStream,
    header: &RequestHeader,
    body: &T,
    api_version: i16,
) -> Result<()>
where
    T: Encodable + HeaderVersion,
{
    let mut encoded_header = BytesMut::new();
    header
        .encode(&mut encoded_header, T::header_version(api_version))
        .map_err(|_| Error::Protocol(ProtocolError::Codec))?;

    let mut encoded_body = BytesMut::new();
    body.encode(&mut encoded_body, api_version)
        .map_err(|_| Error::Protocol(ProtocolError::Codec))?;

    let payload_len = encoded_header.len() + encoded_body.len();
    let prefix = i32::try_from(payload_len)
        .map_err(|_| Error::Protocol(ProtocolError::Codec))?
        .to_be_bytes();

    let mut frame = BytesMut::with_capacity(prefix.len() + payload_len);
    frame.extend_from_slice(&prefix);
    frame.extend_from_slice(&encoded_header);
    frame.extend_from_slice(&encoded_body);

    stream.write_all(&frame).map_err(Error::from)?;
    stream.flush().map_err(Error::from)
}
687
/// Reads one length-prefixed Kafka response and decodes it as `R`.
///
/// The response header is decoded only to skip past it; its contents are not checked.
///
/// # Errors
///
/// `Error::Protocol(ProtocolError::Codec)` for a negative length or a decode failure;
/// an I/O error if the stream read fails.
#[cfg(feature = "security")]
fn get_kp_response_from_stream<R>(stream: &mut KafkaStream, api_version: i16) -> Result<R>
where
    R: Decodable + HeaderVersion,
{
    let mut len_prefix = [0u8; 4];
    stream.read_exact(&mut len_prefix).map_err(Error::from)?;
    let frame_len = i32::from_be_bytes(len_prefix);
    if frame_len < 0 {
        return Err(Error::Protocol(ProtocolError::Codec));
    }
    let frame_len = usize::try_from(frame_len).map_err(|_| Error::codec())?;

    let mut frame = vec![0u8; frame_len];
    stream.read_exact(&mut frame).map_err(Error::from)?;
    let mut body = Bytes::from(frame);

    let _header = ResponseHeader::decode(&mut body, R::header_version(api_version))
        .map_err(|_| Error::Protocol(ProtocolError::Codec))?;
    R::decode(&mut body, api_version).map_err(|_| Error::Protocol(ProtocolError::Codec))
}
710
/// Converts a non-zero Kafka error code into an `Error`.
///
/// Falls back to `KafkaCode::Unknown` when the code has no mapping.
#[cfg(feature = "security")]
fn map_kafka_code_or_unknown(code: i16) -> Error {
    Error::from_protocol(code).unwrap_or_else(|| Error::Kafka(KafkaCode::Unknown))
}
715
716impl KafkaConnection {
717    pub fn send(&mut self, msg: &[u8]) -> Result<usize> {
718        self.stream.write(msg).map_err(|e| {
719            self.state = ConnectionState::Terminated;
720            From::from(e)
721        })
722    }
723
724    pub(crate) fn is_terminated(&self) -> bool {
725        self.state == ConnectionState::Terminated
726    }
727
728    pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
729        self.stream.read_exact(buf).map_err(|e| {
730            self.state = ConnectionState::Terminated;
731            From::from(e)
732        })
733    }
734
735    pub fn read_exact_alloc(&mut self, size: u64) -> Result<bytes::Bytes> {
736        let len = usize::try_from(size).expect("response size exceeds usize");
737        let mut buf = bytes::BytesMut::with_capacity(len);
738        buf.resize(len, 0);
739        self.read_exact(&mut buf)?;
740        Ok(buf.freeze())
741    }
742
743    pub(crate) fn shutdown(&mut self) -> Result<()> {
744        self.state = ConnectionState::Terminated;
745        let r = StreamOps::shutdown(&mut self.stream, Shutdown::Both);
746        debug!("Shut down: {:?} => {:?}", self, r);
747        r.map_err(From::from)
748    }
749
750    fn from_stream(
751        mut stream: KafkaStream,
752        id: u32,
753        host: &str,
754        rw_timeout: Option<Duration>,
755    ) -> Result<KafkaConnection> {
756        StreamOps::set_read_timeout(&mut stream, rw_timeout)?;
757        StreamOps::set_write_timeout(&mut stream, rw_timeout)?;
758        Ok(KafkaConnection {
759            id,
760            host: host.to_owned(),
761            stream,
762            state: ConnectionState::Connected,
763        })
764    }
765
766    fn new_tcp_stream(host: &str) -> std::io::Result<TcpStream> {
767        let mut last_err: Option<std::io::Error> = None;
768        for addr in host.to_socket_addrs()? {
769            let domain = match addr {
770                std::net::SocketAddr::V4(_) => socket2::Domain::IPV4,
771                std::net::SocketAddr::V6(_) => socket2::Domain::IPV6,
772            };
773            let socket =
774                socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;
775
776            match socket.connect(&socket2::SockAddr::from(addr)) {
777                Ok(()) => {
778                    configure_tcp_socket(&socket)?;
779                    return Ok(socket.into());
780                }
781                Err(e) => last_err = Some(e),
782            }
783        }
784
785        Err(last_err.unwrap_or_else(|| {
786            std::io::Error::new(
787                std::io::ErrorKind::AddrNotAvailable,
788                format!("unable to resolve broker address: {host}"),
789            )
790        }))
791    }
792
793    #[cfg(not(feature = "security"))]
794    pub(crate) fn new(
795        id: u32,
796        host: &str,
797        rw_timeout: Option<Duration>,
798    ) -> Result<KafkaConnection> {
799        KafkaConnection::from_stream(Self::new_tcp_stream(host)?, id, host, rw_timeout)
800    }
801
802    #[cfg(feature = "security")]
803    pub(crate) fn new(
804        id: u32,
805        host: &str,
806        rw_timeout: Option<Duration>,
807        security: Option<&SecurityConfig>,
808    ) -> Result<KafkaConnection> {
809        let tcp_stream = Self::new_tcp_stream(host)?;
810
811        let mut stream = match security.map(SecurityConfig::tls_config) {
812            Some(config) => {
813                let domain = match host.rfind(':') {
814                    None => host,
815                    Some(i) => &host[..i],
816                };
817                let connector = RustlsConnector::new(config)?;
818                let tls_stream = connector.connect(domain, tcp_stream)?;
819                KafkaStream::Tls(tls_stream)
820            }
821            None => KafkaStream::Plain(tcp_stream),
822        };
823
824        if let Some(sasl) = security.and_then(SecurityConfig::sasl_config) {
825            perform_sasl_authentication(&mut stream, sasl)?;
826        }
827
828        KafkaConnection::from_stream(stream, id, host, rw_timeout)
829    }
830}