1#[cfg(feature = "security")]
2use std::collections::HashMap;
3use std::fmt;
4use std::io::{Read, Write};
5use std::net::{Shutdown, TcpStream, ToSocketAddrs};
6use std::time::Duration;
7use tracing::debug;
8
9#[cfg(feature = "security")]
10use base64::Engine as _;
11#[cfg(feature = "security")]
12use base64::engine::general_purpose::STANDARD as BASE64;
13#[cfg(feature = "security")]
14use bytes::{Bytes, BytesMut};
15#[cfg(feature = "security")]
16use hmac::{Hmac, Mac};
17#[cfg(feature = "security")]
18use kafka_protocol::messages::{
19 ApiKey, RequestHeader, ResponseHeader, SaslAuthenticateRequest, SaslAuthenticateResponse,
20 SaslHandshakeRequest, SaslHandshakeResponse,
21};
22#[cfg(feature = "security")]
23use kafka_protocol::protocol::{Decodable, Encodable, HeaderVersion, StrBytes};
24#[cfg(feature = "security")]
25use pbkdf2::pbkdf2_hmac;
26#[cfg(feature = "security")]
27use rand::distr::{Alphanumeric, SampleString};
28#[cfg(feature = "security")]
29use sha2::{Digest, Sha256, Sha512};
30
31use crate::error::Result;
32#[cfg(feature = "security")]
33use crate::error::{Error, KafkaCode, ProtocolError};
34#[cfg(feature = "security")]
35use crate::tls::{RustlsConnector, TlsConfig, TlsStream};
36
/// Transport-security settings for broker connections: a TLS configuration
/// plus optional SASL credentials that are applied once the TLS session is up.
#[cfg(feature = "security")]
#[derive(Clone)]
pub struct SecurityConfig {
    /// TLS settings (CA certificate, client certificate/key, hostname verification).
    pub(crate) tls_config: TlsConfig,
    /// SASL credentials; `None` means no SASL exchange is performed.
    pub(crate) sasl_config: Option<SaslConfig>,
}
46
47#[cfg(feature = "security")]
49#[derive(Clone, Debug)]
50pub struct SaslConfig {
51 pub(crate) mechanism: String,
52 pub(crate) username: String,
53 pub(crate) password: String,
54}
55
#[cfg(feature = "security")]
impl SaslConfig {
    /// Builds a configuration for an arbitrary mechanism name.
    ///
    /// The name is compared ASCII case-insensitively during authentication;
    /// the synchronous path handles `PLAIN`, `SCRAM-SHA-256` and `SCRAM-SHA-512`.
    #[must_use]
    pub fn new(mechanism: String, username: String, password: String) -> Self {
        SaslConfig {
            username,
            password,
            mechanism,
        }
    }

    /// Builds a `PLAIN` mechanism configuration.
    #[must_use]
    pub fn plain(username: String, password: String) -> Self {
        let mechanism = String::from("PLAIN");
        Self::new(mechanism, username, password)
    }

    /// Mechanism name exactly as it was supplied.
    #[must_use]
    pub fn mechanism(&self) -> &str {
        self.mechanism.as_str()
    }

    /// Username sent to the broker.
    #[must_use]
    pub fn username(&self) -> &str {
        self.username.as_str()
    }

    /// Password for `username`.
    #[must_use]
    pub fn password(&self) -> &str {
        self.password.as_str()
    }
}
92
#[cfg(feature = "security")]
impl SecurityConfig {
    /// Creates a configuration from `TlsConfig::new()` with SASL disabled.
    #[must_use]
    pub fn new() -> Self {
        Self::from_tls_config(TlsConfig::new())
    }

    /// Wraps an already-built TLS configuration; SASL starts disabled.
    #[must_use]
    pub fn from_tls_config(tls_config: TlsConfig) -> SecurityConfig {
        SecurityConfig {
            sasl_config: None,
            tls_config,
        }
    }

    /// Sets the `verify_hostname` flag of the TLS configuration.
    #[must_use]
    pub fn with_hostname_verification(mut self, verify_hostname: bool) -> SecurityConfig {
        self.tls_config.verify_hostname = verify_hostname;
        self
    }

    /// Sets the CA certificate path used by the TLS configuration.
    #[must_use]
    pub fn with_ca_cert(mut self, path: String) -> SecurityConfig {
        self.tls_config.ca_cert_path = Some(path);
        self
    }

    /// Sets the client certificate and private-key paths.
    #[must_use]
    pub fn with_client_cert(mut self, cert_path: String, key_path: String) -> SecurityConfig {
        let tls = &mut self.tls_config;
        tls.client_cert_path = Some(cert_path);
        tls.client_key_path = Some(key_path);
        self
    }

    /// Enables SASL with `sasl_config`, replacing any previous SASL settings.
    #[must_use]
    pub fn with_sasl(mut self, sasl_config: SaslConfig) -> SecurityConfig {
        self.sasl_config = Some(sasl_config);
        self
    }

    /// Shorthand for `with_sasl(SaslConfig::plain(username, password))`.
    #[must_use]
    pub fn with_sasl_plain(self, username: String, password: String) -> SecurityConfig {
        self.with_sasl(SaslConfig::plain(username, password))
    }

    /// Borrows the TLS configuration.
    #[must_use]
    pub fn tls_config(&self) -> &TlsConfig {
        &self.tls_config
    }

    /// Borrows the SASL configuration, if SASL is enabled.
    #[must_use]
    pub fn sasl_config(&self) -> Option<&SaslConfig> {
        match &self.sasl_config {
            Some(sasl) => Some(sasl),
            None => None,
        }
    }
}
161
/// Equivalent to `SecurityConfig::new()`.
#[cfg(feature = "security")]
impl Default for SecurityConfig {
    fn default() -> Self {
        SecurityConfig::new()
    }
}
168
/// Prints only the hostname-verification flag; certificate paths and any SASL
/// credentials are omitted from the output.
#[cfg(feature = "security")]
impl fmt::Debug for SecurityConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let verify_hostname = self.tls_config.verify_hostname;
        f.write_fmt(format_args!(
            "SecurityConfig {{ verify_hostname: {verify_hostname} }}"
        ))
    }
}
179
/// Without the `security` feature every broker connection is a plain TCP socket.
#[cfg(not(feature = "security"))]
pub(crate) type KafkaStream = TcpStream;

/// Broker transport: either a raw TCP socket or a TLS stream built on top of one.
#[cfg(feature = "security")]
pub(crate) enum KafkaStream {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// TLS-wrapped connection, boxed as a trait object.
    Tls(Box<dyn TlsStream>),
}
190
/// Socket-level operations every `KafkaStream` flavour must provide.
pub(crate) trait StreamOps {
    /// Returns `true` when traffic on this stream is TLS-encrypted.
    fn is_secured(&self) -> bool;

    /// Sets the read timeout; `None` clears it.
    fn set_read_timeout(&mut self, dur: Option<Duration>) -> std::io::Result<()>;

    /// Sets the write timeout; `None` clears it.
    fn set_write_timeout(&mut self, dur: Option<Duration>) -> std::io::Result<()>;

    /// Shuts the stream down. TLS implementations may ignore `how`.
    fn shutdown(&mut self, how: Shutdown) -> std::io::Result<()>;
}
197
198#[cfg(not(feature = "security"))]
199impl StreamOps for KafkaStream {
200 fn is_secured(&self) -> bool {
201 false
202 }
203
204 fn set_read_timeout(&mut self, dur: Option<Duration>) -> std::io::Result<()> {
205 TcpStream::set_read_timeout(self, dur)
206 }
207
208 fn set_write_timeout(&mut self, dur: Option<Duration>) -> std::io::Result<()> {
209 TcpStream::set_write_timeout(self, dur)
210 }
211
212 fn shutdown(&mut self, how: Shutdown) -> std::io::Result<()> {
213 TcpStream::shutdown(self, how)
214 }
215}
216
/// Dispatches socket operations to whichever transport the stream wraps.
#[cfg(feature = "security")]
impl StreamOps for KafkaStream {
    fn is_secured(&self) -> bool {
        matches!(self, Self::Tls(_))
    }

    fn set_read_timeout(&mut self, dur: Option<Duration>) -> std::io::Result<()> {
        match self {
            Self::Plain(tcp) => tcp.set_read_timeout(dur),
            Self::Tls(tls) => tls.set_read_timeout(dur),
        }
    }

    fn set_write_timeout(&mut self, dur: Option<Duration>) -> std::io::Result<()> {
        match self {
            Self::Plain(tcp) => tcp.set_write_timeout(dur),
            Self::Tls(tls) => tls.set_write_timeout(dur),
        }
    }

    /// For TLS streams `how` is not forwarded; the TLS stream's own
    /// `shutdown()` is called instead.
    fn shutdown(&mut self, how: Shutdown) -> std::io::Result<()> {
        match self {
            Self::Plain(tcp) => tcp.shutdown(how),
            Self::Tls(tls) => tls.shutdown(),
        }
    }
}
247
/// Reads from the underlying transport (plaintext socket or TLS stream).
#[cfg(feature = "security")]
impl Read for KafkaStream {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        match self {
            Self::Plain(tcp) => tcp.read(buf),
            Self::Tls(tls) => tls.read(buf),
        }
    }
}
257
/// Writes to and flushes the underlying transport (plaintext socket or TLS stream).
#[cfg(feature = "security")]
impl Write for KafkaStream {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        match self {
            Self::Plain(tcp) => tcp.write(buf),
            Self::Tls(tls) => tls.write(buf),
        }
    }

    fn flush(&mut self) -> std::io::Result<()> {
        match self {
            Self::Plain(tcp) => tcp.flush(),
            Self::Tls(tls) => tls.flush(),
        }
    }
}
274
275pub struct KafkaConnection {
279 id: u32,
280 host: String,
281 stream: KafkaStream,
282 state: ConnectionState,
283}
284
/// Lifecycle of a `KafkaConnection`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ConnectionState {
    /// Socket open; no I/O failure observed yet.
    Connected,
    /// An I/O error occurred or `shutdown` was called; presumably the
    /// connection should not be used again.
    Terminated,
}
291
292impl fmt::Debug for KafkaConnection {
293 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
294 write!(
295 f,
296 "KafkaConnection {{ id: {}, secured: {}, state: {:?}, host: \"{}\" }}",
297 self.id,
298 self.stream.is_secured(),
299 self.state,
300 self.host
301 )
302 }
303}
304
305fn configure_tcp_socket(socket: &socket2::Socket) -> std::io::Result<()> {
307 use socket2::TcpKeepalive;
308
309 let keepalive = TcpKeepalive::new()
310 .with_time(Duration::from_secs(10))
311 .with_interval(Duration::from_secs(20));
312 socket.set_tcp_keepalive(&keepalive)?;
313 socket.set_tcp_nodelay(true)?;
314 Ok(())
315}
316
/// `SaslHandshake` API version used on connect. With v1 the credential
/// exchange itself is carried in `SaslAuthenticate` requests.
#[cfg(feature = "security")]
const API_VERSION_SASL_HANDSHAKE: i16 = 1;
/// `SaslAuthenticate` API version used for every authentication round trip.
#[cfg(feature = "security")]
const API_VERSION_SASL_AUTHENTICATE: i16 = 1;
/// `client_id` placed in the request header of every SASL request.
#[cfg(feature = "security")]
const DEFAULT_CLIENT_ID: &str = "rustfs-kafka";
323
/// Hash family selected by the SCRAM mechanism name
/// (`SCRAM-SHA-256` or `SCRAM-SHA-512`).
#[cfg(feature = "security")]
#[derive(Clone, Copy)]
enum ScramAlgorithm {
    /// PBKDF2/HMAC over SHA-256.
    Sha256,
    /// PBKDF2/HMAC over SHA-512.
    Sha512,
}
330
/// Authenticates a freshly opened broker stream with SASL.
///
/// The flow has four steps:
/// 1. Send a `SaslHandshake` naming the configured mechanism, using correlation id 1.
/// 2. Check the broker's error code.
/// 3. Check the broker's advertised mechanism list (an empty list is accepted).
/// 4. Run the mechanism-specific exchange, starting at correlation id 2.
///
/// Mechanism names are compared ASCII case-insensitively.
///
/// # Errors
/// - The mapped broker error if the handshake reports a non-zero code.
/// - `KafkaCode::UnsupportedSaslMechanism` if the broker advertises a
///   non-empty list without the mechanism.
/// - `Error::Config` if the mechanism is not PLAIN or SCRAM-SHA-256/512.
///   This check happens only after the handshake round trip.
#[cfg(feature = "security")]
fn perform_sasl_authentication(stream: &mut KafkaStream, sasl: &SaslConfig) -> Result<()> {
    let requested = sasl.mechanism();
    let handshake_id: i32 = 1;

    let header = RequestHeader::default()
        .with_client_id(Some(StrBytes::from_string(DEFAULT_CLIENT_ID.to_owned())))
        .with_request_api_key(ApiKey::SaslHandshake as i16)
        .with_request_api_version(API_VERSION_SASL_HANDSHAKE)
        .with_correlation_id(handshake_id);
    let request =
        SaslHandshakeRequest::default().with_mechanism(StrBytes::from_string(requested.to_owned()));
    send_kp_request_on_stream(stream, &header, &request, API_VERSION_SASL_HANDSHAKE)?;

    let response: SaslHandshakeResponse =
        get_kp_response_from_stream(stream, API_VERSION_SASL_HANDSHAKE)?;
    if response.error_code != 0 {
        return Err(map_kafka_code_or_unknown(response.error_code));
    }

    let advertised = &response.mechanisms;
    let broker_accepts = advertised.is_empty()
        || advertised
            .iter()
            .any(|m| m.as_str().eq_ignore_ascii_case(requested));
    if !broker_accepts {
        return Err(Error::Kafka(KafkaCode::UnsupportedSaslMechanism));
    }

    let auth_id = handshake_id + 1;
    match requested.to_ascii_uppercase().as_str() {
        "PLAIN" => perform_sasl_plain_authenticate(stream, sasl, auth_id),
        "SCRAM-SHA-256" => {
            perform_sasl_scram_authenticate(stream, sasl, ScramAlgorithm::Sha256, auth_id)
        }
        "SCRAM-SHA-512" => {
            perform_sasl_scram_authenticate(stream, sasl, ScramAlgorithm::Sha512, auth_id)
        }
        _ => Err(Error::Config(format!(
            "unsupported SASL mechanism for sync path: {}",
            sasl.mechanism()
        ))),
    }
}
391
/// Completes SASL/PLAIN with a single `SaslAuthenticate` request.
///
/// The request payload is `\0username\0password`, and the broker's error code
/// is checked in the response.
///
/// # Errors
/// I/O and codec errors are propagated; a non-zero broker error code is
/// mapped through `map_kafka_code_or_unknown`.
#[cfg(feature = "security")]
fn perform_sasl_plain_authenticate(
    stream: &mut KafkaStream,
    sasl: &SaslConfig,
    correlation_id: i32,
) -> Result<()> {
    let request =
        SaslAuthenticateRequest::default().with_auth_bytes(build_sasl_plain_auth_bytes(sasl));
    let header = RequestHeader::default()
        .with_client_id(Some(StrBytes::from_string(DEFAULT_CLIENT_ID.to_owned())))
        .with_request_api_key(ApiKey::SaslAuthenticate as i16)
        .with_request_api_version(API_VERSION_SASL_AUTHENTICATE)
        .with_correlation_id(correlation_id);
    send_kp_request_on_stream(stream, &header, &request, API_VERSION_SASL_AUTHENTICATE)?;

    let response: SaslAuthenticateResponse =
        get_kp_response_from_stream(stream, API_VERSION_SASL_AUTHENTICATE)?;
    match response.error_code {
        0 => Ok(()),
        code => Err(map_kafka_code_or_unknown(code)),
    }
}
421
/// Authenticates with SCRAM-SHA-256/512 (RFC 5802 / RFC 7677) over two
/// `SaslAuthenticate` round trips.
///
/// 1. Sends client-first `n,,n=<escaped user>,r=<client nonce>` with
///    `correlation_id`.
/// 2. Validates the server-first challenge (nonce prefix, salt, iteration
///    count), derives the client proof, and sends client-final
///    `c=biws,r=<server nonce>,p=<proof>` with `correlation_id + 1`.
/// 3. Checks the server signature (`v=`) in server-final, so the broker must
///    also prove knowledge of the credentials.
///
/// # Errors
/// - Non-zero broker error codes are mapped through `map_kafka_code_or_unknown`.
/// - Non-UTF-8 payloads yield a codec error.
/// - Malformed or inconsistent SCRAM messages, server-reported `e=` errors
///   and server-signature mismatches yield `Error::Config`.
#[cfg(feature = "security")]
fn perform_sasl_scram_authenticate(
    stream: &mut KafkaStream,
    sasl: &SaslConfig,
    algorithm: ScramAlgorithm,
    correlation_id: i32,
) -> Result<()> {
    let client_nonce = generate_scram_nonce();
    let user = scram_escape_username(sasl.username());
    let client_first_bare = format!("n={user},r={client_nonce}");
    // "n,," is the GS2 header: no channel binding and no authorization identity.
    let client_first = format!("n,,{client_first_bare}");

    // Round 1: client-first -> server-first.
    let server_first_bytes = sasl_authenticate_round_trip(stream, correlation_id, client_first)?;
    let server_first =
        std::str::from_utf8(&server_first_bytes).map_err(|_| Error::codec())?;
    let server_first_attrs = parse_scram_attributes(server_first)?;
    if let Some(err_msg) = server_first_attrs.get("e") {
        return Err(Error::Config(format!("SCRAM server error: {err_msg}")));
    }

    let server_nonce = server_first_attrs
        .get("r")
        .ok_or_else(|| Error::Config("SCRAM challenge missing nonce".to_owned()))?;
    // The server must extend our nonce, not replace it (RFC 5802).
    if !server_nonce.starts_with(&client_nonce) {
        return Err(Error::Config(
            "SCRAM server nonce does not include client nonce prefix".to_owned(),
        ));
    }
    let salt_b64 = server_first_attrs
        .get("s")
        .ok_or_else(|| Error::Config("SCRAM challenge missing salt".to_owned()))?;
    let salt = BASE64
        .decode(salt_b64)
        .map_err(|e| Error::Config(format!("invalid SCRAM salt encoding: {e}")))?;
    let iterations = server_first_attrs
        .get("i")
        .ok_or_else(|| Error::Config("SCRAM challenge missing iterations".to_owned()))?
        .parse::<u32>()
        .map_err(|e| Error::Config(format!("invalid SCRAM iterations: {e}")))?;
    if iterations == 0 {
        return Err(Error::Config(
            "invalid SCRAM iterations: must be > 0".to_owned(),
        ));
    }

    // "biws" is base64("n,,"): the GS2 header echoed back as channel-binding data.
    let client_final_without_proof = format!("c=biws,r={server_nonce}");
    let auth_message = format!("{client_first_bare},{server_first},{client_final_without_proof}");
    let (client_proof, expected_server_signature) = compute_scram_proof_and_server_signature(
        algorithm,
        sasl.password(),
        &salt,
        iterations,
        &auth_message,
    )?;
    let client_final = format!(
        "{client_final_without_proof},p={}",
        BASE64.encode(client_proof)
    );

    // Round 2: client-final -> server-final.
    let server_final_bytes =
        sasl_authenticate_round_trip(stream, correlation_id + 1, client_final)?;
    let server_final =
        std::str::from_utf8(&server_final_bytes).map_err(|_| Error::codec())?;
    let server_final_attrs = parse_scram_attributes(server_final)?;
    if let Some(err_msg) = server_final_attrs.get("e") {
        return Err(Error::Config(format!(
            "SCRAM authentication failed: {err_msg}"
        )));
    }
    let server_signature_b64 = server_final_attrs
        .get("v")
        .ok_or_else(|| Error::Config("SCRAM final message missing server signature".to_owned()))?;
    let server_signature = BASE64
        .decode(server_signature_b64)
        .map_err(|e| Error::Config(format!("invalid SCRAM server signature encoding: {e}")))?;
    if server_signature != expected_server_signature {
        return Err(Error::Config(
            "SCRAM server signature verification failed".to_owned(),
        ));
    }

    Ok(())
}

/// Sends one `SaslAuthenticate` request carrying `auth_bytes` and returns the
/// broker's `auth_bytes` reply.
///
/// # Errors
/// I/O and codec errors are propagated; a non-zero broker error code is
/// mapped through `map_kafka_code_or_unknown`.
#[cfg(feature = "security")]
fn sasl_authenticate_round_trip(
    stream: &mut KafkaStream,
    correlation_id: i32,
    auth_bytes: String,
) -> Result<Bytes> {
    let header = RequestHeader::default()
        .with_client_id(Some(StrBytes::from_string(DEFAULT_CLIENT_ID.to_owned())))
        .with_request_api_key(ApiKey::SaslAuthenticate as i16)
        .with_request_api_version(API_VERSION_SASL_AUTHENTICATE)
        .with_correlation_id(correlation_id);
    let request = SaslAuthenticateRequest::default().with_auth_bytes(Bytes::from(auth_bytes));
    send_kp_request_on_stream(stream, &header, &request, API_VERSION_SASL_AUTHENTICATE)?;

    let response: SaslAuthenticateResponse =
        get_kp_response_from_stream(stream, API_VERSION_SASL_AUTHENTICATE)?;
    if response.error_code != 0 {
        return Err(map_kafka_code_or_unknown(response.error_code));
    }
    Ok(response.auth_bytes)
}
541
/// Encodes the SASL/PLAIN message `authzid NUL authcid NUL passwd` with an
/// empty authorization identity, i.e. `\0username\0password`.
#[cfg(feature = "security")]
fn build_sasl_plain_auth_bytes(sasl: &SaslConfig) -> Bytes {
    const NUL: u8 = 0;
    let authcid = sasl.username().as_bytes();
    let passwd = sasl.password().as_bytes();
    let message: Vec<u8> = std::iter::once(NUL)
        .chain(authcid.iter().copied())
        .chain(std::iter::once(NUL))
        .chain(passwd.iter().copied())
        .collect();
    Bytes::from(message)
}
551
/// Computes `(ClientProof, ServerSignature)` using the hash selected by `algorithm`.
#[cfg(feature = "security")]
fn compute_scram_proof_and_server_signature(
    algorithm: ScramAlgorithm,
    password: &str,
    salt: &[u8],
    iterations: u32,
    auth_message: &str,
) -> Result<(Vec<u8>, Vec<u8>)> {
    let derive: fn(&str, &[u8], u32, &str) -> Result<(Vec<u8>, Vec<u8>)> = match algorithm {
        ScramAlgorithm::Sha256 => compute_scram_sha256,
        ScramAlgorithm::Sha512 => compute_scram_sha512,
    };
    derive(password, salt, iterations, auth_message)
}
565
/// SCRAM-SHA-256 key derivation (RFC 5802 / RFC 7677).
///
/// Definitions:
/// - `SaltedPassword = PBKDF2-HMAC-SHA-256(password, salt, iterations)` (32 bytes).
/// - `ClientKey = HMAC(SaltedPassword, "Client Key")`.
/// - `ClientProof = ClientKey XOR HMAC(SHA-256(ClientKey), AuthMessage)`.
/// - `ServerSignature = HMAC(HMAC(SaltedPassword, "Server Key"), AuthMessage)`.
///
/// Returns `(ClientProof, ServerSignature)`.
#[cfg(feature = "security")]
fn compute_scram_sha256(
    password: &str,
    salt: &[u8],
    iterations: u32,
    auth_message: &str,
) -> Result<(Vec<u8>, Vec<u8>)> {
    type HmacSha256 = Hmac<Sha256>;
    let message = auth_message.as_bytes();

    let mut salted = [0u8; 32];
    pbkdf2_hmac::<Sha256>(password.as_bytes(), salt, iterations, &mut salted);

    let client_key = hmac_bytes::<HmacSha256>(&salted, b"Client Key")?;
    let stored_key = Sha256::digest(&client_key).to_vec();
    let client_signature = hmac_bytes::<HmacSha256>(&stored_key, message)?;
    let proof = xor_bytes(&client_key, &client_signature)?;

    let server_key = hmac_bytes::<HmacSha256>(&salted, b"Server Key")?;
    let server_signature = hmac_bytes::<HmacSha256>(&server_key, message)?;
    Ok((proof, server_signature))
}
585
/// SCRAM-SHA-512 key derivation: same construction as SCRAM-SHA-256, with
/// SHA-512 and a 64-byte `SaltedPassword`.
///
/// Returns `(ClientProof, ServerSignature)`.
#[cfg(feature = "security")]
fn compute_scram_sha512(
    password: &str,
    salt: &[u8],
    iterations: u32,
    auth_message: &str,
) -> Result<(Vec<u8>, Vec<u8>)> {
    type HmacSha512 = Hmac<Sha512>;
    let message = auth_message.as_bytes();

    let mut salted = [0u8; 64];
    pbkdf2_hmac::<Sha512>(password.as_bytes(), salt, iterations, &mut salted);

    let client_key = hmac_bytes::<HmacSha512>(&salted, b"Client Key")?;
    let stored_key = Sha512::digest(&client_key).to_vec();
    let client_signature = hmac_bytes::<HmacSha512>(&stored_key, message)?;
    let proof = xor_bytes(&client_key, &client_signature)?;

    let server_key = hmac_bytes::<HmacSha512>(&salted, b"Server Key")?;
    let server_signature = hmac_bytes::<HmacSha512>(&server_key, message)?;
    Ok((proof, server_signature))
}
605
/// Computes the MAC of `data` under `key` with the MAC type `M` and returns
/// the tag as an owned byte vector.
///
/// # Errors
/// `Error::Config` if `M` rejects the key.
#[cfg(feature = "security")]
fn hmac_bytes<M>(key: &[u8], data: &[u8]) -> Result<Vec<u8>>
where
    M: Mac + hmac::digest::KeyInit,
{
    let mut mac = match <M as hmac::digest::KeyInit>::new_from_slice(key) {
        Ok(initialized) => initialized,
        Err(e) => return Err(Error::Config(format!("hmac init failed: {e}"))),
    };
    mac.update(data);
    let tag = mac.finalize();
    Ok(tag.into_bytes().to_vec())
}
616
/// Byte-wise XOR of two equally long buffers (used for `ClientProof`).
///
/// # Errors
/// `Error::Config` if the lengths differ.
#[cfg(feature = "security")]
fn xor_bytes(left: &[u8], right: &[u8]) -> Result<Vec<u8>> {
    if left.len() == right.len() {
        let mut mixed = Vec::with_capacity(left.len());
        for (a, b) in left.iter().zip(right) {
            mixed.push(a ^ b);
        }
        Ok(mixed)
    } else {
        Err(Error::Config(
            "SCRAM proof construction failed: buffer length mismatch".to_owned(),
        ))
    }
}
626
/// Splits a SCRAM message into `key -> value` pairs.
///
/// Behaviour:
/// - Segments are separated by `,`, and empty segments are skipped.
/// - Each segment is split at its first `=`, so values such as base64 padding
///   keep any later `=`.
/// - A repeated key keeps its last value.
///
/// # Errors
/// `Error::Config` for the first non-empty segment that contains no `=`.
#[cfg(feature = "security")]
fn parse_scram_attributes(input: &str) -> Result<HashMap<String, String>> {
    input
        .split(',')
        .filter(|segment| !segment.is_empty())
        .map(|segment| {
            segment
                .split_once('=')
                .map(|(key, value)| (key.to_owned(), value.to_owned()))
                .ok_or_else(|| {
                    Error::Config(format!("invalid SCRAM attribute segment: {segment}"))
                })
        })
        .collect()
}
643
/// Returns a fresh 24-character alphanumeric client nonce drawn from the
/// thread-local RNG.
#[cfg(feature = "security")]
fn generate_scram_nonce() -> String {
    const NONCE_LEN: usize = 24;
    let mut rng = rand::rng();
    Alphanumeric.sample_string(&mut rng, NONCE_LEN)
}
648
649#[cfg(feature = "security")]
650fn scram_escape_username(username: &str) -> String {
651 username.replace('=', "=3D").replace(',', "=2C")
652}
653
/// Encodes `header` and `body` as one length-prefixed Kafka request frame,
/// writes it to `stream` in a single `write_all`, and flushes.
///
/// The frame is a big-endian `i32` length of header plus body, followed by
/// both. The header version comes from `T::header_version(api_version)`.
///
/// # Errors
/// Encoding failures and frames longer than `i32::MAX` yield
/// `Error::Protocol(ProtocolError::Codec)`; I/O errors are propagated.
#[cfg(feature = "security")]
fn send_kp_request_on_stream<T>(
    stream: &mut KafkaStream,
    header: &RequestHeader,
    body: &T,
    api_version: i16,
) -> Result<()>
where
    T: Encodable + HeaderVersion,
{
    let mut encoded_header = BytesMut::new();
    header
        .encode(&mut encoded_header, T::header_version(api_version))
        .map_err(|_| Error::Protocol(ProtocolError::Codec))?;

    let mut encoded_body = BytesMut::new();
    body.encode(&mut encoded_body, api_version)
        .map_err(|_| Error::Protocol(ProtocolError::Codec))?;

    let payload_len = encoded_header.len() + encoded_body.len();
    let length_prefix =
        i32::try_from(payload_len).map_err(|_| Error::Protocol(ProtocolError::Codec))?;

    let mut frame = BytesMut::with_capacity(4 + payload_len);
    frame.extend_from_slice(&length_prefix.to_be_bytes());
    frame.extend_from_slice(&encoded_header);
    frame.extend_from_slice(&encoded_body);

    stream.write_all(&frame).map_err(Error::from)?;
    stream.flush().map_err(Error::from)
}
687
/// Reads one size-prefixed Kafka response frame from `stream` and decodes it as `R`.
///
/// The frame is a big-endian `i32` length, then the response header (whose
/// version is `R::header_version(api_version)`), then the body. The header is
/// decoded only to advance past it; its correlation id is not checked.
///
/// The declared length comes straight off the wire. It is therefore bounded by
/// `MAX_RESPONSE_BYTES` before any buffer is allocated. Otherwise a misbehaving
/// or non-Kafka peer could announce a frame of up to 2 GiB and force a zeroed
/// allocation of that size.
///
/// # Errors
/// I/O errors are propagated. Negative or oversized lengths and decode
/// failures yield `Error::Protocol(ProtocolError::Codec)`.
#[cfg(feature = "security")]
fn get_kp_response_from_stream<R>(stream: &mut KafkaStream, api_version: i16) -> Result<R>
where
    R: Decodable + HeaderVersion,
{
    // This helper reads SASL handshake/authenticate replies, which are expected
    // to be small; the ceiling only needs to sit far above any legitimate frame.
    const MAX_RESPONSE_BYTES: usize = 1024 * 1024;

    let mut size_buf = [0u8; 4];
    stream.read_exact(&mut size_buf).map_err(Error::from)?;
    // A negative length fails the conversion and is reported as a codec error.
    let size = usize::try_from(i32::from_be_bytes(size_buf))
        .map_err(|_| Error::Protocol(ProtocolError::Codec))?;
    if size > MAX_RESPONSE_BYTES {
        return Err(Error::Protocol(ProtocolError::Codec));
    }

    let mut payload = vec![0u8; size];
    stream.read_exact(&mut payload).map_err(Error::from)?;
    let mut bytes = Bytes::from(payload);

    let response_header_version = R::header_version(api_version);
    let _resp_header = ResponseHeader::decode(&mut bytes, response_header_version)
        .map_err(|_| Error::Protocol(ProtocolError::Codec))?;

    R::decode(&mut bytes, api_version).map_err(|_| Error::Protocol(ProtocolError::Codec))
}
710
/// Converts a broker error code into an `Error`, falling back to
/// `KafkaCode::Unknown` when `Error::from_protocol` does not recognize it.
#[cfg(feature = "security")]
fn map_kafka_code_or_unknown(code: i16) -> Error {
    let fallback = Error::Kafka(KafkaCode::Unknown);
    Error::from_protocol(code).unwrap_or(fallback)
}
715
716impl KafkaConnection {
717 pub fn send(&mut self, msg: &[u8]) -> Result<usize> {
718 self.stream.write(msg).map_err(|e| {
719 self.state = ConnectionState::Terminated;
720 From::from(e)
721 })
722 }
723
724 pub(crate) fn is_terminated(&self) -> bool {
725 self.state == ConnectionState::Terminated
726 }
727
728 pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
729 self.stream.read_exact(buf).map_err(|e| {
730 self.state = ConnectionState::Terminated;
731 From::from(e)
732 })
733 }
734
735 pub fn read_exact_alloc(&mut self, size: u64) -> Result<bytes::Bytes> {
736 let len = usize::try_from(size).expect("response size exceeds usize");
737 let mut buf = bytes::BytesMut::with_capacity(len);
738 buf.resize(len, 0);
739 self.read_exact(&mut buf)?;
740 Ok(buf.freeze())
741 }
742
743 pub(crate) fn shutdown(&mut self) -> Result<()> {
744 self.state = ConnectionState::Terminated;
745 let r = StreamOps::shutdown(&mut self.stream, Shutdown::Both);
746 debug!("Shut down: {:?} => {:?}", self, r);
747 r.map_err(From::from)
748 }
749
750 fn from_stream(
751 mut stream: KafkaStream,
752 id: u32,
753 host: &str,
754 rw_timeout: Option<Duration>,
755 ) -> Result<KafkaConnection> {
756 StreamOps::set_read_timeout(&mut stream, rw_timeout)?;
757 StreamOps::set_write_timeout(&mut stream, rw_timeout)?;
758 Ok(KafkaConnection {
759 id,
760 host: host.to_owned(),
761 stream,
762 state: ConnectionState::Connected,
763 })
764 }
765
766 fn new_tcp_stream(host: &str) -> std::io::Result<TcpStream> {
767 let mut last_err: Option<std::io::Error> = None;
768 for addr in host.to_socket_addrs()? {
769 let domain = match addr {
770 std::net::SocketAddr::V4(_) => socket2::Domain::IPV4,
771 std::net::SocketAddr::V6(_) => socket2::Domain::IPV6,
772 };
773 let socket =
774 socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;
775
776 match socket.connect(&socket2::SockAddr::from(addr)) {
777 Ok(()) => {
778 configure_tcp_socket(&socket)?;
779 return Ok(socket.into());
780 }
781 Err(e) => last_err = Some(e),
782 }
783 }
784
785 Err(last_err.unwrap_or_else(|| {
786 std::io::Error::new(
787 std::io::ErrorKind::AddrNotAvailable,
788 format!("unable to resolve broker address: {host}"),
789 )
790 }))
791 }
792
793 #[cfg(not(feature = "security"))]
794 pub(crate) fn new(
795 id: u32,
796 host: &str,
797 rw_timeout: Option<Duration>,
798 ) -> Result<KafkaConnection> {
799 KafkaConnection::from_stream(Self::new_tcp_stream(host)?, id, host, rw_timeout)
800 }
801
802 #[cfg(feature = "security")]
803 pub(crate) fn new(
804 id: u32,
805 host: &str,
806 rw_timeout: Option<Duration>,
807 security: Option<&SecurityConfig>,
808 ) -> Result<KafkaConnection> {
809 let tcp_stream = Self::new_tcp_stream(host)?;
810
811 let mut stream = match security.map(SecurityConfig::tls_config) {
812 Some(config) => {
813 let domain = match host.rfind(':') {
814 None => host,
815 Some(i) => &host[..i],
816 };
817 let connector = RustlsConnector::new(config)?;
818 let tls_stream = connector.connect(domain, tcp_stream)?;
819 KafkaStream::Tls(tls_stream)
820 }
821 None => KafkaStream::Plain(tcp_stream),
822 };
823
824 if let Some(sasl) = security.and_then(SecurityConfig::sasl_config) {
825 perform_sasl_authentication(&mut stream, sasl)?;
826 }
827
828 KafkaConnection::from_stream(stream, id, host, rw_timeout)
829 }
830}