1use bytes::{Bytes, BytesMut};
4use kafka_protocol::messages::{
5 ApiKey, MetadataRequest, MetadataResponse, ProduceRequest, ProduceResponse, RequestHeader,
6 ResponseHeader, TopicName, metadata_request::MetadataRequestTopic,
7};
8use kafka_protocol::protocol::{Decodable, Encodable, StrBytes};
9use kafka_protocol::records::{
10 Record as KpRecord, RecordBatchEncoder, RecordEncodeOptions, TimestampType,
11};
12use rustfs_kafka::client::{Compression, RequiredAcks, SecurityConfig};
13use rustfs_kafka::error::{ConnectionError, Error, KafkaCode, ProtocolError, Result};
14use rustfs_kafka::producer::{AsBytes, Record};
15use std::collections::HashMap;
16use std::collections::hash_map::Entry;
17use std::sync::atomic::{AtomicI32, Ordering};
18use std::time::Duration;
19use tokio::sync::Mutex;
20use tracing::debug;
21
22use crate::AsyncKafkaClient;
23use crate::connection::AsyncConnection;
24
25const API_VERSION_PRODUCE: i16 = 9;
26const API_VERSION_METADATA: i16 = 1;
27
28struct NativeProducer {
29 client: Mutex<AsyncKafkaClient>,
30 state: Mutex<NativeProducerState>,
31 required_acks: i16,
32 ack_timeout_ms: i32,
33 compression: Compression,
34 correlation: AtomicI32,
35}
36
37#[derive(Default)]
38struct NativeProducerState {
39 brokers: HashMap<i32, String>,
40 topics: HashMap<String, TopicRoute>,
41 round_robin: HashMap<String, usize>,
42}
43
44#[derive(Default)]
45struct TopicRoute {
46 partitions: HashMap<i32, i32>, available_partitions: Vec<i32>,
48}
49
50enum AsyncProducerMode {
51 Native(Box<NativeProducer>),
52}
53
54pub struct AsyncProducer {
58 mode: AsyncProducerMode,
59}
60
61pub struct AsyncProducerConfig {
63 required_acks: RequiredAcks,
64 ack_timeout: Duration,
65 compression: Compression,
66 security: Option<SecurityConfig>,
67}
68
69impl AsyncProducerConfig {
70 #[must_use]
71 pub fn new() -> Self {
72 Self {
73 required_acks: RequiredAcks::One,
74 ack_timeout: Duration::from_secs(30),
75 compression: Compression::NONE,
76 security: None,
77 }
78 }
79
80 #[must_use]
81 pub fn with_required_acks(mut self, required_acks: RequiredAcks) -> Self {
82 self.required_acks = required_acks;
83 self
84 }
85
86 #[must_use]
87 pub fn with_ack_timeout(mut self, ack_timeout: Duration) -> Self {
88 self.ack_timeout = ack_timeout;
89 self
90 }
91
92 #[must_use]
93 pub fn with_compression(mut self, compression: Compression) -> Self {
94 self.compression = compression;
95 self
96 }
97
98 #[must_use]
99 pub fn with_security(mut self, security: SecurityConfig) -> Self {
100 self.security = Some(security);
101 self
102 }
103}
104
105impl Default for AsyncProducerConfig {
106 fn default() -> Self {
107 Self::new()
108 }
109}
110
111pub struct AsyncProducerBuilder {
113 hosts: Vec<String>,
114 client_id: String,
115 config: AsyncProducerConfig,
116 channel_capacity: usize,
117 native_async: bool,
118}
119
120impl AsyncProducerBuilder {
121 #[must_use]
123 pub fn new(hosts: Vec<String>) -> Self {
124 Self {
125 hosts,
126 client_id: "rustfs-kafka-async".to_owned(),
127 config: AsyncProducerConfig::default(),
128 channel_capacity: 256,
129 native_async: true,
130 }
131 }
132
133 #[must_use]
135 pub fn with_client_id(mut self, client_id: String) -> Self {
136 self.client_id = client_id;
137 self
138 }
139
140 #[must_use]
142 pub fn with_required_acks(mut self, required_acks: RequiredAcks) -> Self {
143 self.config = self.config.with_required_acks(required_acks);
144 self
145 }
146
147 #[must_use]
149 pub fn with_ack_timeout(mut self, ack_timeout: Duration) -> Self {
150 self.config = self.config.with_ack_timeout(ack_timeout);
151 self
152 }
153
154 #[must_use]
156 pub fn with_compression(mut self, compression: Compression) -> Self {
157 self.config = self.config.with_compression(compression);
158 self
159 }
160
161 #[must_use]
163 pub fn with_security(mut self, security: SecurityConfig) -> Self {
164 self.config = self.config.with_security(security);
165 self
166 }
167
168 #[must_use]
170 pub fn with_channel_capacity(mut self, channel_capacity: usize) -> Self {
171 self.channel_capacity = channel_capacity.max(1);
172 self
173 }
174
175 #[must_use]
177 pub fn with_native_async(mut self, native_async: bool) -> Self {
178 self.native_async = native_async;
179 self
180 }
181
182 pub async fn build(self) -> Result<AsyncProducer> {
184 let AsyncProducerBuilder {
185 hosts,
186 client_id,
187 config,
188 channel_capacity,
189 native_async,
190 } = self;
191
192 if !native_async {
193 debug!(
194 "AsyncProducerBuilder::with_native_async(false) is ignored: producer always uses native async I/O"
195 );
196 }
197 let _ = channel_capacity;
198
199 let client = AsyncKafkaClient::with_client_id_and_security(
200 hosts,
201 client_id,
202 config.security.clone(),
203 )
204 .await?;
205 AsyncProducer::from_native(client, config)
206 }
207}
208
209impl AsyncProducer {
210 #[must_use]
212 pub fn builder(hosts: Vec<String>) -> AsyncProducerBuilder {
213 AsyncProducerBuilder::new(hosts)
214 }
215
216 pub async fn new(client: AsyncKafkaClient) -> Result<Self> {
218 Self::new_with_config(client, AsyncProducerConfig::default()).await
219 }
220
221 pub async fn new_with_config(
223 client: AsyncKafkaClient,
224 config: AsyncProducerConfig,
225 ) -> Result<Self> {
226 if config.security.is_some() && client.security().is_none() {
227 return Self::builder(client.bootstrap_hosts().to_vec())
228 .with_client_id(client.client_id().to_owned())
229 .with_required_acks(config.required_acks)
230 .with_ack_timeout(config.ack_timeout)
231 .with_compression(config.compression)
232 .build_with_optional_security(config.security)
233 .await;
234 }
235
236 Self::from_native(client, config)
237 }
238
239 pub async fn from_hosts(hosts: Vec<String>) -> Result<Self> {
241 Self::builder(hosts).build().await
242 }
243
244 pub async fn from_hosts_with_config(
246 hosts: Vec<String>,
247 config: AsyncProducerConfig,
248 ) -> Result<Self> {
249 Self::builder(hosts)
250 .with_required_acks(config.required_acks)
251 .with_ack_timeout(config.ack_timeout)
252 .with_compression(config.compression)
253 .build_with_optional_security(config.security)
254 .await
255 }
256
257 pub async fn send<K, V>(&self, record: &Record<'_, K, V>) -> Result<()>
259 where
260 K: AsBytes,
261 V: AsBytes,
262 {
263 match &self.mode {
264 AsyncProducerMode::Native(native) => native.send(record).await,
265 }
266 }
267
268 pub async fn flush(&self) -> Result<()> {
270 Ok(())
271 }
272
273 pub async fn close(self) -> Result<()> {
275 Ok(())
276 }
277
278 fn from_native(client: AsyncKafkaClient, config: AsyncProducerConfig) -> Result<Self> {
279 if client.bootstrap_hosts().is_empty() {
280 return Err(no_host_reachable_error());
281 }
282
283 let ack_timeout_ms = to_millis_i32(config.ack_timeout)?;
284 Ok(Self {
285 mode: AsyncProducerMode::Native(
286 NativeProducer {
287 client: Mutex::new(client),
288 state: Mutex::new(NativeProducerState::default()),
289 required_acks: config.required_acks as i16,
290 ack_timeout_ms,
291 compression: config.compression,
292 correlation: AtomicI32::new(1),
293 }
294 .into(),
295 ),
296 })
297 }
298}
299
300impl AsyncProducerBuilder {
301 async fn build_with_optional_security(
302 self,
303 security: Option<SecurityConfig>,
304 ) -> Result<AsyncProducer> {
305 if let Some(security) = security {
306 self.with_security(security).build().await
307 } else {
308 self.build().await
309 }
310 }
311}
312
313impl NativeProducer {
314 async fn send<K, V>(&self, record: &Record<'_, K, V>) -> Result<()>
315 where
316 K: AsBytes,
317 V: AsBytes,
318 {
319 let topic = record.topic.to_owned();
320 let requested_partition = record.partition;
321 let key = Bytes::copy_from_slice(record.key.as_bytes());
322 let value = Bytes::copy_from_slice(record.value.as_bytes());
323 let headers: Vec<(String, Bytes)> = record.headers.iter().cloned().collect();
324
325 let correlation_id = self.correlation.fetch_add(1, Ordering::Relaxed);
326 let mut client = self.client.lock().await;
327 let mut state = self.state.lock().await;
328 client.ensure_connected().await?;
329
330 let (partition, leader_host) = resolve_partition_and_leader(
331 &mut client,
332 &mut state,
333 &topic,
334 requested_partition,
335 correlation_id,
336 )
337 .await?;
338 let client_id = client.client_id().to_owned();
339 let conn = client.get_connection(&leader_host).await?;
340
341 let (header, request) = build_single_produce_request(
342 correlation_id,
343 &client_id,
344 self.required_acks,
345 self.ack_timeout_ms,
346 self.compression,
347 &topic,
348 partition,
349 key.as_ref(),
350 value.as_ref(),
351 &headers,
352 );
353
354 send_kp_request(conn, &header, &request, API_VERSION_PRODUCE).await?;
355 if self.required_acks == 0 {
356 return Ok(());
357 }
358
359 let response = get_kp_response::<ProduceResponse>(conn, API_VERSION_PRODUCE).await?;
360 for topic_resp in response.responses {
361 for part in topic_resp.partition_responses {
362 if part.error_code != 0 {
363 if let Some(code) = map_kafka_code(part.error_code) {
364 return Err(Error::Kafka(code));
365 }
366 return Err(Error::Kafka(KafkaCode::Unknown));
367 }
368 }
369 }
370
371 Ok(())
372 }
373}
374
375async fn resolve_partition_and_leader(
376 client: &mut AsyncKafkaClient,
377 state: &mut NativeProducerState,
378 topic: &str,
379 requested_partition: i32,
380 correlation_id: i32,
381) -> Result<(i32, String)> {
382 for _ in 0..2 {
383 if let Some((partition, leader_host)) =
384 try_resolve_from_cache(state, topic, requested_partition)
385 {
386 return Ok((partition, leader_host));
387 }
388
389 refresh_topic_metadata(client, state, topic, correlation_id).await?;
390 }
391
392 Err(Error::Kafka(KafkaCode::UnknownTopicOrPartition))
393}
394
395fn try_resolve_from_cache(
396 state: &mut NativeProducerState,
397 topic: &str,
398 requested_partition: i32,
399) -> Option<(i32, String)> {
400 let route = state.topics.get(topic)?;
401 let partitions = route.partitions.clone();
402 let available_partitions = route.available_partitions.clone();
403 let partition = if requested_partition >= 0 {
404 requested_partition
405 } else {
406 pick_round_robin_partition(state, topic, &available_partitions)?
407 };
408
409 let leader_id = *partitions.get(&partition)?;
410 if leader_id < 0 {
411 return None;
412 }
413 let leader_host = state.brokers.get(&leader_id)?.clone();
414 Some((partition, leader_host))
415}
416
417fn pick_round_robin_partition(
418 state: &mut NativeProducerState,
419 topic: &str,
420 available_partitions: &[i32],
421) -> Option<i32> {
422 if available_partitions.is_empty() {
423 return None;
424 }
425
426 let len = available_partitions.len();
427 let idx = match state.round_robin.entry(topic.to_owned()) {
428 Entry::Occupied(mut occupied) => {
429 let idx = *occupied.get() % len;
430 *occupied.get_mut() = occupied.get().wrapping_add(1);
431 idx
432 }
433 Entry::Vacant(vacant) => {
434 vacant.insert(1);
435 0
436 }
437 };
438 available_partitions.get(idx).copied()
439}
440
441async fn refresh_topic_metadata(
442 client: &mut AsyncKafkaClient,
443 state: &mut NativeProducerState,
444 topic: &str,
445 correlation_id: i32,
446) -> Result<()> {
447 let request_host = pick_request_host(client).ok_or_else(no_host_reachable_error)?;
448 let client_id = client.client_id().to_owned();
449 let conn = client.get_connection(&request_host).await?;
450 let (header, request) = build_metadata_request(correlation_id, &client_id, topic);
451
452 send_kp_request(conn, &header, &request, API_VERSION_METADATA).await?;
453 let response = get_kp_response::<MetadataResponse>(conn, API_VERSION_METADATA).await?;
454
455 for broker in response.brokers {
456 state.brokers.insert(
457 i32::from(broker.node_id),
458 format!("{}:{}", broker.host, broker.port),
459 );
460 }
461
462 for topic_meta in response.topics {
463 let Some(name) = topic_meta.name else {
464 continue;
465 };
466 if name.as_str() != topic {
467 continue;
468 }
469
470 let mut route = TopicRoute::default();
471 for part in topic_meta.partitions {
472 let partition = part.partition_index;
473 let leader = i32::from(part.leader_id);
474 route.partitions.insert(partition, leader);
475 if leader >= 0 {
476 route.available_partitions.push(partition);
477 }
478 }
479
480 route.available_partitions.sort_unstable();
481 route.available_partitions.dedup();
482 state.topics.insert(topic.to_owned(), route);
483 return Ok(());
484 }
485
486 Err(Error::Kafka(KafkaCode::UnknownTopicOrPartition))
487}
488
489fn pick_request_host(client: &AsyncKafkaClient) -> Option<String> {
490 if let Some(connected) = client.connected_hosts().first() {
491 return Some((*connected).to_owned());
492 }
493 client.bootstrap_hosts().first().cloned()
494}
495
496fn build_metadata_request(
497 correlation_id: i32,
498 client_id: &str,
499 topic: &str,
500) -> (RequestHeader, MetadataRequest) {
501 let header = RequestHeader::default()
502 .with_client_id(Some(StrBytes::from_string(client_id.to_owned())))
503 .with_request_api_key(ApiKey::Metadata as i16)
504 .with_request_api_version(API_VERSION_METADATA)
505 .with_correlation_id(correlation_id);
506
507 let request = MetadataRequest::default().with_topics(Some(vec![
508 MetadataRequestTopic::default().with_name(Some(TopicName::from(StrBytes::from_string(
509 topic.to_owned(),
510 )))),
511 ]));
512
513 (header, request)
514}
515
516#[allow(clippy::too_many_arguments)]
517fn build_single_produce_request(
518 correlation_id: i32,
519 client_id: &str,
520 required_acks: i16,
521 timeout_ms: i32,
522 compression: Compression,
523 topic: &str,
524 partition: i32,
525 key: &[u8],
526 value: &[u8],
527 headers: &[(String, Bytes)],
528) -> (RequestHeader, ProduceRequest) {
529 let header = RequestHeader::default()
530 .with_client_id(Some(StrBytes::from_string(client_id.to_owned())))
531 .with_request_api_key(ApiKey::Produce as i16)
532 .with_request_api_version(API_VERSION_PRODUCE)
533 .with_correlation_id(correlation_id);
534
535 let kp_headers = headers
536 .iter()
537 .map(|(k, v)| (StrBytes::from_string(k.clone()), Some(v.clone())))
538 .collect();
539
540 let record = KpRecord {
541 transactional: false,
542 control: false,
543 partition_leader_epoch: -1,
544 producer_id: -1,
545 producer_epoch: -1,
546 timestamp_type: TimestampType::Creation,
547 offset: 0,
548 sequence: -1,
549 timestamp: 0,
550 key: if key.is_empty() {
551 None
552 } else {
553 Some(Bytes::copy_from_slice(key))
554 },
555 value: if value.is_empty() {
556 None
557 } else {
558 Some(Bytes::copy_from_slice(value))
559 },
560 headers: kp_headers,
561 };
562
563 let mut buf = BytesMut::new();
564 let options = RecordEncodeOptions {
565 version: 2,
566 compression: to_kp_compression(compression),
567 };
568 RecordBatchEncoder::encode(&mut buf, &[record], &options)
569 .expect("failed to encode record batch");
570
571 let partition_data = kafka_protocol::messages::produce_request::PartitionProduceData::default()
572 .with_index(partition)
573 .with_records(Some(buf.freeze()));
574
575 let topic_data = kafka_protocol::messages::produce_request::TopicProduceData::default()
576 .with_name(TopicName::from(StrBytes::from_string(topic.to_owned())))
577 .with_partition_data(vec![partition_data]);
578
579 let request = ProduceRequest::default()
580 .with_transactional_id(None)
581 .with_acks(required_acks)
582 .with_timeout_ms(timeout_ms)
583 .with_topic_data(vec![topic_data]);
584
585 (header, request)
586}
587
588fn to_kp_compression(c: Compression) -> kafka_protocol::records::Compression {
589 match c {
590 Compression::NONE => kafka_protocol::records::Compression::None,
591 Compression::GZIP => kafka_protocol::records::Compression::Gzip,
592 Compression::SNAPPY => kafka_protocol::records::Compression::Snappy,
593 Compression::LZ4 => kafka_protocol::records::Compression::Lz4,
594 Compression::ZSTD => kafka_protocol::records::Compression::Zstd,
595 }
596}
597
598fn map_kafka_code(code: i16) -> Option<KafkaCode> {
599 match code {
600 0 => None,
601 1 => Some(KafkaCode::OffsetOutOfRange),
602 2 => Some(KafkaCode::CorruptMessage),
603 3 => Some(KafkaCode::UnknownTopicOrPartition),
604 4 => Some(KafkaCode::InvalidMessageSize),
605 5 => Some(KafkaCode::LeaderNotAvailable),
606 6 => Some(KafkaCode::NotLeaderForPartition),
607 7 => Some(KafkaCode::RequestTimedOut),
608 8 => Some(KafkaCode::BrokerNotAvailable),
609 9 => Some(KafkaCode::ReplicaNotAvailable),
610 10 => Some(KafkaCode::MessageSizeTooLarge),
611 11 => Some(KafkaCode::StaleControllerEpoch),
612 12 => Some(KafkaCode::OffsetMetadataTooLarge),
613 13 => Some(KafkaCode::NetworkException),
614 14 => Some(KafkaCode::GroupLoadInProgress),
615 15 => Some(KafkaCode::GroupCoordinatorNotAvailable),
616 16 => Some(KafkaCode::NotCoordinatorForGroup),
617 17 => Some(KafkaCode::InvalidTopic),
618 22 => Some(KafkaCode::IllegalGeneration),
619 23 => Some(KafkaCode::InconsistentGroupProtocol),
620 24 => Some(KafkaCode::InvalidGroupId),
621 25 => Some(KafkaCode::UnknownMemberId),
622 26 => Some(KafkaCode::InvalidSessionTimeout),
623 27 => Some(KafkaCode::RebalanceInProgress),
624 28 => Some(KafkaCode::InvalidCommitOffsetSize),
625 29 => Some(KafkaCode::TopicAuthorizationFailed),
626 30 => Some(KafkaCode::GroupAuthorizationFailed),
627 31 => Some(KafkaCode::ClusterAuthorizationFailed),
628 32 => Some(KafkaCode::InvalidTimestamp),
629 33 => Some(KafkaCode::UnsupportedSaslMechanism),
630 34 => Some(KafkaCode::IllegalSaslState),
631 35 => Some(KafkaCode::UnsupportedVersion),
632 _ => Some(KafkaCode::Unknown),
633 }
634}
635
636async fn send_kp_request<T>(
637 conn: &mut AsyncConnection,
638 header: &RequestHeader,
639 body: &T,
640 api_version: i16,
641) -> Result<()>
642where
643 T: Encodable + kafka_protocol::protocol::HeaderVersion,
644{
645 let header_version = T::header_version(api_version);
646
647 let mut header_buf = BytesMut::new();
648 header
649 .encode(&mut header_buf, header_version)
650 .map_err(|_| Error::Protocol(ProtocolError::Codec))?;
651
652 let mut body_buf = BytesMut::new();
653 body.encode(&mut body_buf, api_version)
654 .map_err(|_| Error::Protocol(ProtocolError::Codec))?;
655
656 let total_len = usize_to_i32(header_buf.len() + body_buf.len())?;
657 let mut out = BytesMut::with_capacity(4 + non_negative_i32_to_usize(total_len)?);
658 out.extend_from_slice(&total_len.to_be_bytes());
659 out.extend_from_slice(&header_buf);
660 out.extend_from_slice(&body_buf);
661
662 conn.send(&out).await
663}
664
665async fn get_kp_response<R>(conn: &mut AsyncConnection, api_version: i16) -> Result<R>
666where
667 R: Decodable + kafka_protocol::protocol::HeaderVersion,
668{
669 let size_bytes = conn.read_exact(4).await?;
670 let size = i32::from_be_bytes(
671 <[u8; 4]>::try_from(size_bytes.as_ref())
672 .map_err(|_| Error::Protocol(ProtocolError::Codec))?,
673 );
674 let mut bytes = conn.read_exact(non_negative_i32_to_u64(size)?).await?;
675
676 let response_header_version = R::header_version(api_version);
677 let _resp_header = ResponseHeader::decode(&mut bytes, response_header_version)
678 .map_err(|_| Error::Protocol(ProtocolError::Codec))?;
679
680 R::decode(&mut bytes, api_version).map_err(|_| Error::Protocol(ProtocolError::Codec))
681}
682
683fn to_millis_i32(d: Duration) -> Result<i32> {
684 let m = d
685 .as_secs()
686 .saturating_mul(1_000)
687 .saturating_add(u64::from(d.subsec_millis()));
688 if m > i32::MAX as u64 {
689 Err(Error::Protocol(ProtocolError::InvalidDuration))
690 } else {
691 i32::try_from(m).map_err(|_| Error::Protocol(ProtocolError::InvalidDuration))
692 }
693}
694
695fn usize_to_i32(value: usize) -> Result<i32> {
696 i32::try_from(value).map_err(|_| Error::Protocol(ProtocolError::Codec))
697}
698
699fn non_negative_i32_to_usize(value: i32) -> Result<usize> {
700 usize::try_from(value).map_err(|_| Error::Protocol(ProtocolError::Codec))
701}
702
703fn non_negative_i32_to_u64(value: i32) -> Result<u64> {
704 u64::try_from(value).map_err(|_| Error::Protocol(ProtocolError::Codec))
705}
706
707fn no_host_reachable_error() -> Error {
708 Error::Connection(ConnectionError::NoHostReachable)
709}
710
#[cfg(test)]
mod tests {
    use rustfs_kafka::error::{ConnectionError, Error};

    use super::*;

    /// True when construction failed because no bootstrap host could be used.
    fn failed_with_no_host(result: &Result<AsyncProducer>) -> bool {
        matches!(
            result,
            Err(Error::Connection(ConnectionError::NoHostReachable))
        )
    }

    #[tokio::test]
    async fn from_hosts_fails_with_unreachable_hosts() {
        // Presumably nothing listens on loopback port 1 in the test environment.
        let producer = AsyncProducer::from_hosts(vec![String::from("127.0.0.1:1")]).await;
        assert!(failed_with_no_host(&producer));
    }

    #[tokio::test]
    async fn new_fails_with_empty_hosts() {
        let client = AsyncKafkaClient::new(Vec::new()).await.unwrap();
        let producer = AsyncProducer::new(client).await;
        assert!(failed_with_no_host(&producer));
    }
}