Skip to main content

schema_registry_client/serdes/
protobuf.rs

1use crate::DESCRIPTOR_POOL;
2use crate::rest::models::{Kind, Mode, Phase};
3use crate::rest::models::{Schema, SchemaReference};
4use crate::rest::schema_registry_client::Client;
5use crate::serdes::config::{DeserializerConfig, SerializerConfig};
6use crate::serdes::rule_registry::RuleRegistry;
7use crate::serdes::serde::SerdeError::Serialization;
8use crate::serdes::serde::{
9    BaseDeserializer, BaseSerializer, FALLBACK_TYPE_CONFIG, FieldTransformer, FieldType,
10    KAFKA_CLUSTER_ID_CONFIG, RuleContext, SchemaId, Serde, SerdeError, SerdeFormat, SerdeSchema,
11    SerdeType, SerdeValue, SerializationContext, SubjectCacheKey, SubjectNameStrategyType,
12    get_executor, get_executors, load_associated_subject, parse_subject_name_strategy_type,
13    topic_name_strategy,
14};
15use async_recursion::async_recursion;
16use base64::Engine;
17use base64::prelude::BASE64_STANDARD;
18use dashmap::DashMap;
19use futures::future::FutureExt;
20use prost::{Message, bytes};
21use prost_reflect::prost_types::{DescriptorProto, FileDescriptorProto};
22use prost_reflect::{
23    DescriptorPool, DynamicMessage, FieldDescriptor, FileDescriptor, MessageDescriptor,
24    ReflectMessage, SerializeOptions, Value,
25};
26use prost_types::FileDescriptorSet;
27use std::collections::{HashMap, HashSet, VecDeque};
28use std::io::{BufReader, Cursor};
29use std::sync::Arc;
30
31pub mod confluent {
32    include!("../codegen/confluent.rs");
33    pub mod r#type {
34        include!("../codegen/confluent.r#type.rs");
35    }
36}
37
38#[derive(Clone, Debug)]
39pub(crate) struct ProtobufSerde {
40    parsed_schemas: DashMap<Schema, (FileDescriptor, DescriptorPool)>,
41    subject_cache: DashMap<SubjectCacheKey, Option<String>>,
42}
43
44#[derive(Clone)]
45pub struct ProtobufSerializer<'a, T: Client> {
46    reference_subject_name_strategy: ReferenceSubjectNameStrategy,
47    base: BaseSerializer<'a, T>,
48    serde: ProtobufSerde,
49    subject_name_strategy_type: SubjectNameStrategyType,
50}
51
52impl<'a, T: Client> std::fmt::Debug for ProtobufSerializer<'a, T> {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("ProtobufSerializer")
55            .field("reference_subject_name_strategy", &"<function>")
56            .field("serde", &self.serde)
57            .field(
58                "subject_name_strategy_type",
59                &self.subject_name_strategy_type,
60            )
61            .finish_non_exhaustive()
62    }
63}
64
65pub type ReferenceSubjectNameStrategy = fn(ref_name: &str, serde_type: &SerdeType) -> String;
66
67pub fn default_reference_subject_name_strategy(ref_name: &str, serde_type: &SerdeType) -> String {
68    ref_name.to_string()
69}
70
71impl<'a, T: Client + Sync> ProtobufSerializer<'a, T> {
72    pub fn new(
73        client: &'a T,
74        rule_registry: Option<RuleRegistry>,
75        serializer_config: SerializerConfig,
76    ) -> Result<ProtobufSerializer<'a, T>, SerdeError> {
77        Self::with_reference_subject_name_strategy(
78            client,
79            default_reference_subject_name_strategy,
80            rule_registry,
81            serializer_config,
82        )
83    }
84
85    pub fn with_reference_subject_name_strategy(
86        client: &'a T,
87        reference_subject_name_strategy: ReferenceSubjectNameStrategy,
88        rule_registry: Option<RuleRegistry>,
89        serializer_config: SerializerConfig,
90    ) -> Result<ProtobufSerializer<'a, T>, SerdeError> {
91        for executor in get_executors(rule_registry.as_ref()) {
92            executor.configure(client.config(), &serializer_config.rule_config)?;
93        }
94        Ok(ProtobufSerializer {
95            reference_subject_name_strategy,
96            base: BaseSerializer::new(Serde::new(client, rule_registry), serializer_config.clone()),
97            serde: ProtobufSerde {
98                parsed_schemas: DashMap::new(),
99                subject_cache: DashMap::new(),
100            },
101            subject_name_strategy_type: serializer_config.subject_name_strategy_type,
102        })
103    }
104
105    pub async fn serialize_with_file_desc_set<M: Message>(
106        &self,
107        ctx: &SerializationContext,
108        value: &M,
109        message_type_name: &str,
110        fds: FileDescriptorSet,
111    ) -> Result<Vec<u8>, SerdeError> {
112        let pool = DescriptorPool::from_file_descriptor_set(fds)?;
113        let md = pool
114            .get_message_by_name(message_type_name)
115            .ok_or(Serialization(format!(
116                "message descriptor {message_type_name} not found"
117            )))?;
118        self.serialize_with_message_desc(ctx, value, &md).await
119    }
120
121    pub async fn serialize<M: ReflectMessage>(
122        &self,
123        ctx: &SerializationContext,
124        value: &M,
125    ) -> Result<Vec<u8>, SerdeError> {
126        self.serialize_with_message_desc(ctx, value, &value.descriptor())
127            .await
128    }
129
130    pub async fn serialize_with_message_desc<M: Message>(
131        &self,
132        ctx: &SerializationContext,
133        value: &M,
134        md: &MessageDescriptor,
135    ) -> Result<Vec<u8>, SerdeError> {
136        // TODO pass schema (instead of None) to strategy later for Record/TopicRecord strategies
137        let subject = self.get_subject(&ctx.topic, &ctx.serde_type, None).await?;
138        let latest_schema = if let Some(ref subj) = subject {
139            self.base
140                .serde
141                .get_reader_schema(subj, Some("serialized"), &self.base.config.use_schema)
142                .await?
143        } else {
144            None
145        };
146        let subject = subject.ok_or_else(|| {
147            Serialization("Could not determine subject for serialization".to_string())
148        })?;
149
150        let mut schema_id;
151        if let Some(ref schema) = latest_schema {
152            schema_id = SchemaId::new(SerdeFormat::Protobuf, schema.id, schema.guid.clone(), None)?;
153        } else {
154            let references = self.resolve_dependencies(ctx, &md.parent_file()).await?;
155            let schema = Schema {
156                schema_type: Some("PROTOBUF".to_string()),
157                references: Some(references),
158                metadata: None,
159                rule_set: None,
160                schema: schema_to_str(&md.parent_file())?,
161            };
162            if self.base.config.auto_register_schemas {
163                let rs = self
164                    .base
165                    .serde
166                    .client
167                    .register_schema(&subject, &schema, self.base.config.normalize_schemas)
168                    .await?;
169                schema_id = SchemaId::new(SerdeFormat::Protobuf, rs.id, rs.guid.clone(), None)?;
170            } else {
171                let rs = self
172                    .base
173                    .serde
174                    .client
175                    .get_by_schema(&subject, &schema, self.base.config.normalize_schemas, false)
176                    .await?;
177                schema_id = SchemaId::new(SerdeFormat::Protobuf, rs.id, rs.guid.clone(), None)?;
178            }
179        }
180
181        let fd;
182        let pool;
183        let mut encoded_bytes = Vec::new();
184        if let Some(ref latest_schema) = latest_schema {
185            let schema = latest_schema.to_schema();
186            (fd, pool) = self.get_parsed_schema(&schema).await?;
187            let field_transformer: FieldTransformer =
188                Box::new(|ctx, value| transform_fields(ctx, value).boxed());
189            let mut msg = DynamicMessage::new(md.clone());
190            msg.transcode_from(value)?;
191            let serde_value = self
192                .base
193                .serde
194                .execute_rules(
195                    ctx,
196                    &subject,
197                    Mode::Write,
198                    None,
199                    Some(&schema),
200                    Some(&SerdeSchema::Protobuf(fd.clone())),
201                    &SerdeValue::Protobuf(Value::Message(msg)),
202                    Some(Arc::new(field_transformer)),
203                )
204                .await?;
205            msg = match serde_value {
206                SerdeValue::Protobuf(Value::Message(msg)) => msg,
207                _ => return Err(Serialization("unexpected serde value".to_string())),
208            };
209            msg.encode(&mut encoded_bytes)?;
210            if let Some(ref rule_set) = schema.rule_set
211                && rule_set.encoding_rules.is_some()
212            {
213                encoded_bytes = self
214                    .base
215                    .serde
216                    .execute_rules_with_phase(
217                        ctx,
218                        &subject,
219                        Phase::Encoding,
220                        Mode::Write,
221                        None,
222                        Some(&schema),
223                        None,
224                        &SerdeValue::new_bytes(SerdeFormat::Protobuf, &encoded_bytes),
225                        None,
226                    )
227                    .await?
228                    .as_bytes();
229            }
230        } else {
231            value.encode(&mut encoded_bytes)?;
232        }
233
234        schema_id.message_indexes = Some(self.to_index_array(md)?);
235        let id_ser = self.base.config.schema_id_serializer;
236        id_ser(&encoded_bytes, ctx, &schema_id)
237    }
238
239    #[async_recursion]
240    async fn resolve_dependencies(
241        &self,
242        ctx: &SerializationContext,
243        file_desc: &FileDescriptor,
244    ) -> Result<Vec<SchemaReference>, SerdeError> {
245        let mut references = Vec::new();
246        for dep in file_desc.dependencies() {
247            if is_builtin(dep.name()) {
248                continue;
249            }
250            let dep_refs = self.resolve_dependencies(ctx, &dep).await?;
251            let strategy = self.reference_subject_name_strategy;
252            let subject = strategy(dep.name(), &ctx.serde_type);
253            let schema = Schema {
254                schema_type: Some("PROTOBUF".to_string()),
255                references: Some(dep_refs),
256                metadata: None,
257                rule_set: None,
258                schema: schema_to_str(&dep)?,
259            };
260            if self.base.config.auto_register_schemas {
261                self.base
262                    .serde
263                    .client
264                    .register_schema(&subject, &schema, self.base.config.normalize_schemas)
265                    .await?;
266            }
267
268            let reference = self
269                .base
270                .serde
271                .client
272                .get_by_schema(&subject, &schema, self.base.config.normalize_schemas, false)
273                .await?;
274            references.push(SchemaReference {
275                name: Some(dep.name().to_string()),
276                subject: Some(subject),
277                version: reference.version,
278            });
279        }
280        Ok(references)
281    }
282
283    fn to_index_array(&self, desc: &MessageDescriptor) -> Result<Vec<i32>, SerdeError> {
284        let mut msg_idx = VecDeque::new();
285
286        // Walk the nested MessageDescriptor tree up to the root.
287        let mut previous = desc.clone();
288        let mut current = previous;
289        let mut found = false;
290        while current.parent_message().is_some() {
291            previous = current;
292            current = previous
293                .parent_message()
294                .ok_or(Serialization("parent message not found".to_string()))?;
295            for (idx, node) in current.child_messages().enumerate() {
296                if node == previous {
297                    msg_idx.push_front(idx as i32);
298                    found = true;
299                    break;
300                }
301            }
302            if !found {
303                return Err(Serialization(
304                    "nested MessageDescriptor not found".to_string(),
305                ));
306            }
307        }
308
309        // Add the index of the root MessageDescriptor in the FileDescriptor.
310        found = false;
311        for (idx, node) in current.parent_file().messages().enumerate() {
312            if node == current {
313                msg_idx.push_front(idx as i32);
314                found = true;
315                break;
316            }
317        }
318        if !found {
319            return Err(Serialization(
320                "MessageDescriptor not found in file".to_string(),
321            ));
322        }
323        Ok(msg_idx.into_iter().collect())
324    }
325
326    async fn get_parsed_schema(
327        &self,
328        schema: &Schema,
329    ) -> Result<(FileDescriptor, DescriptorPool), SerdeError> {
330        let result = self.serde.parsed_schemas.get(schema);
331        if let Some((parsed_schema, pool)) = result.as_deref() {
332            return Ok((parsed_schema.clone(), pool.clone()));
333        }
334        // Get a copy of the global descriptor pool with the Google well-known types.
335        let mut pool = DESCRIPTOR_POOL.clone();
336        resolve_named_schema(
337            schema,
338            self.base.serde.client,
339            &mut pool,
340            &mut HashSet::new(),
341        )
342        .await?;
343        let fd_proto = str_to_proto(&mut pool, "default", &schema.schema)?;
344        self.serde
345            .parsed_schemas
346            .insert(schema.clone(), (fd_proto.clone(), pool.clone()));
347        Ok((fd_proto, pool))
348    }
349
350    pub async fn get_record_name(&self, schema: &Schema) -> Result<String, SerdeError> {
351        let (file_descriptor, _) = self.get_parsed_schema(schema).await?;
352        file_descriptor
353            .messages()
354            .next()
355            .map(|m| m.full_name().to_string())
356            .ok_or_else(|| {
357                Serialization("Could not find message name in protobuf schema".to_string())
358            })
359    }
360
361    async fn get_subject(
362        &self,
363        topic: &str,
364        serde_type: &SerdeType,
365        schema: Option<&Schema>,
366    ) -> Result<Option<String>, SerdeError> {
367        self.get_subject_for_type(self.subject_name_strategy_type, topic, serde_type, schema)
368            .await
369    }
370
371    #[async_recursion]
372    async fn get_subject_for_type(
373        &self,
374        strategy_type: SubjectNameStrategyType,
375        topic: &str,
376        serde_type: &SerdeType,
377        schema: Option<&Schema>,
378    ) -> Result<Option<String>, SerdeError> {
379        match strategy_type {
380            SubjectNameStrategyType::Record => {
381                if let Some(schema) = schema {
382                    Ok(Some(self.get_record_name(schema).await?))
383                } else {
384                    Ok(None)
385                }
386            }
387            SubjectNameStrategyType::TopicRecord => {
388                if let Some(schema) = schema {
389                    let name = self.get_record_name(schema).await?;
390                    Ok(Some(format!("{topic}-{name}")))
391                } else {
392                    Ok(None)
393                }
394            }
395            SubjectNameStrategyType::Associated => {
396                match load_associated_subject(
397                    self.base.serde.client,
398                    &self.serde.subject_cache,
399                    &self.base.config.strategy_config,
400                    topic,
401                    serde_type,
402                    schema,
403                )
404                .await?
405                {
406                    Some(s) => Ok(Some(s)),
407                    None => {
408                        let fallback_type = self
409                            .base
410                            .config
411                            .strategy_config
412                            .get(FALLBACK_TYPE_CONFIG)
413                            .map(|s| parse_subject_name_strategy_type(s))
414                            .transpose()?
415                            .unwrap_or(SubjectNameStrategyType::Topic);
416                        match fallback_type {
417                            SubjectNameStrategyType::None | SubjectNameStrategyType::Associated => {
418                                Ok(None)
419                            }
420                            other => {
421                                self.get_subject_for_type(other, topic, serde_type, schema)
422                                    .await
423                            }
424                        }
425                    }
426                }
427            }
428            _ => Ok(Some(
429                topic_name_strategy(topic, serde_type, schema).unwrap(),
430            )),
431        }
432    }
433
434    fn close(&mut self) {}
435}
436
437fn schema_to_str(fd: &FileDescriptor) -> Result<String, SerdeError> {
438    let mut buf = Vec::new();
439    fd.encode(&mut buf)?;
440    Ok(BASE64_STANDARD.encode(&buf))
441}
442
443fn str_to_proto(
444    pool: &mut DescriptorPool,
445    name: &str,
446    s: &str,
447) -> Result<FileDescriptor, SerdeError> {
448    let result = BASE64_STANDARD.decode(s).map_err(|e| {
449        Serialization(format!(
450            "failed to decode base64 schema string for {name}: {e}"
451        ))
452    });
453    decode_file_descriptor_proto_with_name(pool, name, result.unwrap())?;
454    pool.get_file_by_name(name)
455        .ok_or(Serialization("file descriptor not found".to_string()))
456}
457
458// See https://github.com/andrewhickman/prost-reflect/issues/152 and
459// https://github.com/protocolbuffers/protobuf/issues/9257
460fn decode_file_descriptor_proto_with_name(
461    pool: &mut DescriptorPool,
462    name: &str,
463    mut bytes: Vec<u8>,
464) -> Result<(), SerdeError> {
465    FileDescriptorProto {
466        name: Some(name.to_owned()),
467        ..Default::default()
468    }
469    .encode(&mut bytes)?;
470    pool.decode_file_descriptor_proto(bytes.as_slice())?;
471    Ok(())
472}
473
474async fn transform_fields(
475    ctx: &mut RuleContext,
476    value: &SerdeValue,
477) -> Result<SerdeValue, SerdeError> {
478    if let Some(SerdeSchema::Protobuf(s)) = ctx.parsed_target.clone()
479        && let SerdeValue::Protobuf(v) = value
480        && let Value::Message(message) = v
481    {
482        let pool = s.parent_pool();
483        let desc = pool
484            .get_message_by_name(message.descriptor().full_name())
485            .ok_or(Serialization(format!(
486                "message descriptor {} not found",
487                message.descriptor().full_name()
488            )))?;
489        let value = transform(ctx, &desc, v).await?;
490        return Ok(SerdeValue::Protobuf(value));
491    }
492    Ok(value.clone())
493}
494
495#[derive(Clone)]
496pub struct ProtobufDeserializer<'a, T: Client> {
497    base: BaseDeserializer<'a, T>,
498    serde: ProtobufSerde,
499    subject_name_strategy_type: SubjectNameStrategyType,
500}
501
502impl<'a, T: Client> std::fmt::Debug for ProtobufDeserializer<'a, T> {
503    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
504        f.debug_struct("ProtobufDeserializer")
505            .field("serde", &self.serde)
506            .field(
507                "subject_name_strategy_type",
508                &self.subject_name_strategy_type,
509            )
510            .finish_non_exhaustive()
511    }
512}
513
514impl<'a, T: Client + Sync> ProtobufDeserializer<'a, T> {
515    pub fn new(
516        client: &'a T,
517        rule_registry: Option<RuleRegistry>,
518        deserializer_config: DeserializerConfig,
519    ) -> Result<ProtobufDeserializer<'a, T>, SerdeError> {
520        for executor in get_executors(rule_registry.as_ref()) {
521            executor.configure(client.config(), &deserializer_config.rule_config)?;
522        }
523        Ok(ProtobufDeserializer {
524            base: BaseDeserializer::new(
525                Serde::new(client, rule_registry),
526                deserializer_config.clone(),
527            ),
528            serde: ProtobufSerde {
529                parsed_schemas: DashMap::new(),
530                subject_cache: DashMap::new(),
531            },
532            subject_name_strategy_type: deserializer_config.subject_name_strategy_type,
533        })
534    }
535
536    pub async fn deserialize<M: Message + Default>(
537        &self,
538        ctx: &SerializationContext,
539        data: &[u8],
540    ) -> Result<M, SerdeError> {
541        // Get initial subject using configured subject name strategy (without schema)
542        let initial_subject = self
543            .get_subject(&ctx.topic, &ctx.serde_type, None)
544            .await
545            .ok()
546            .flatten();
547        let mut latest_schema = None;
548
549        // Try to get reader schema with initial subject
550        if let Some(ref init_subj) = initial_subject {
551            latest_schema = self
552                .base
553                .serde
554                .get_reader_schema(init_subj, Some("serialized"), &self.base.config.use_schema)
555                .await
556                .ok()
557                .flatten();
558        }
559
560        let mut schema_id = SchemaId::new(SerdeFormat::Protobuf, None, None, None)?;
561        let id_deser = self.base.config.schema_id_deserializer;
562        let bytes_read = id_deser(data, ctx, &mut schema_id)?;
563        let msg_index = schema_id.message_indexes.clone().unwrap_or_default();
564        let mut data = &data[bytes_read..];
565
566        let writer_schema_raw = self
567            .base
568            .get_writer_schema(&schema_id, initial_subject.as_deref(), Some("serialized"))
569            .await?;
570        let (writer_schema, mut pool) = self.get_parsed_schema(&writer_schema_raw).await?;
571        let writer_desc = self.get_message_desc(&pool, &writer_schema, &msg_index)?;
572
573        // Recompute subject with writer schema (needed for Record/TopicRecord strategies)
574        let subject = self
575            .get_subject(&ctx.topic, &ctx.serde_type, Some(&writer_schema_raw))
576            .await?;
577
578        // If subject changed, try to get reader schema again
579        if subject != initial_subject {
580            if let Some(ref subj) = subject {
581                if let Ok(Some(schema)) = self
582                    .base
583                    .serde
584                    .get_reader_schema(subj, Some("serialized"), &self.base.config.use_schema)
585                    .await
586                {
587                    latest_schema = Some(schema);
588                }
589            }
590        }
591
592        let subject = subject.ok_or_else(|| {
593            Serialization("Could not determine subject for deserialization".to_string())
594        })?;
595        let serde_value;
596        if let Some(ref rule_set) = writer_schema_raw.rule_set
597            && rule_set.encoding_rules.is_some()
598        {
599            serde_value = self
600                .base
601                .serde
602                .execute_rules_with_phase(
603                    ctx,
604                    &subject,
605                    Phase::Encoding,
606                    Mode::Read,
607                    None,
608                    Some(&writer_schema_raw),
609                    None,
610                    &SerdeValue::new_bytes(SerdeFormat::Protobuf, data),
611                    None,
612                )
613                .await?
614                .as_bytes();
615            data = &serde_value;
616        }
617
618        let migrations;
619        let reader_schema_raw;
620        let reader_schema;
621        if let Some(ref latest_schema) = latest_schema {
622            migrations = self
623                .base
624                .serde
625                .get_migrations(&subject, &writer_schema_raw, latest_schema, None)
626                .await?;
627            reader_schema_raw = latest_schema.to_schema();
628            (reader_schema, pool) = self.get_parsed_schema(&reader_schema_raw).await?;
629        } else {
630            migrations = Vec::new();
631            reader_schema_raw = writer_schema_raw.clone();
632            reader_schema = writer_schema.clone();
633        }
634
635        // Initialize reader desc to first message in file
636        let mut reader_desc = self.get_message_desc(&pool, &reader_schema, &[0i32])?;
637        // Attempt to find a reader desc with the same name as the writer
638        reader_desc = pool
639            .get_message_by_name(writer_desc.full_name())
640            .unwrap_or(reader_desc);
641
642        let mut msg: DynamicMessage;
643        if !migrations.is_empty() {
644            let reader = Cursor::new(data);
645            msg = DynamicMessage::decode(writer_desc, reader)?;
646            let mut serializer = serde_json::Serializer::new(vec![]);
647            let options = SerializeOptions::new().skip_default_fields(false);
648            msg.serialize_with_options(&mut serializer, &options)?;
649            let json: serde_json::Value = serde_json::from_slice(&serializer.into_inner())?;
650            let mut serde_value = SerdeValue::Json(json);
651            serde_value = self
652                .base
653                .serde
654                .execute_migrations(ctx, &subject, &migrations, &serde_value)
655                .await?;
656            let json_str = match serde_value {
657                SerdeValue::Json(v) => serde_json::to_string(&v)?,
658                _ => return Err(Serialization("unexpected serde value".to_string())),
659            };
660            let mut deserializer = serde_json::de::Deserializer::from_str(&json_str);
661            msg = DynamicMessage::deserialize(reader_desc, &mut deserializer)?;
662            deserializer.end()?;
663        } else {
664            let mut reader = Cursor::new(&data);
665            msg = DynamicMessage::decode(reader_desc, &mut reader)?;
666        }
667
668        let field_transformer: FieldTransformer =
669            Box::new(|ctx, value| transform_fields(ctx, value).boxed());
670        let serde_value = self
671            .base
672            .serde
673            .execute_rules(
674                ctx,
675                &subject,
676                Mode::Read,
677                None,
678                Some(&reader_schema_raw),
679                Some(&SerdeSchema::Protobuf(reader_schema.clone())),
680                &SerdeValue::Protobuf(Value::Message(msg)),
681                Some(Arc::new(field_transformer)),
682            )
683            .await?;
684        msg = match serde_value {
685            SerdeValue::Protobuf(Value::Message(msg)) => msg,
686            _ => return Err(Serialization("unexpected serde value".to_string())),
687        };
688
689        let result: M = msg.transcode_to()?;
690        Ok(result)
691    }
692
693    fn get_message_desc(
694        &self,
695        pool: &DescriptorPool,
696        fd: &FileDescriptor,
697        msg_index: &[i32],
698    ) -> Result<MessageDescriptor, SerdeError> {
699        let file_desc_proto = fd.file_descriptor_proto();
700        let (full_name, desc_proto) =
701            self.get_message_desc_proto_file("", file_desc_proto, msg_index)?;
702        let package = file_desc_proto.package.clone().unwrap_or_default();
703        let qualified_name = if package.is_empty() {
704            full_name
705        } else {
706            package + "." + &full_name
707        };
708        let msg = pool
709            .get_message_by_name(&qualified_name)
710            .ok_or(Serialization(format!(
711                "message descriptor {qualified_name} not found"
712            )))?;
713        Ok(msg)
714    }
715
716    fn get_message_desc_proto_file(
717        &self,
718        path: &str,
719        desc: &FileDescriptorProto,
720        msg_index: &[i32],
721    ) -> Result<(String, DescriptorProto), SerdeError> {
722        let index = msg_index[0] as usize;
723        let msg = desc.message_type.get(index).ok_or(Serialization(format!(
724            "message descriptor not found at index {index}"
725        )))?;
726        let name = msg.name.clone().unwrap_or(String::new());
727        let path = if !path.is_empty() && !name.is_empty() {
728            path.to_string() + "." + &name
729        } else {
730            name
731        };
732        if msg_index.len() == 1 {
733            Ok((path, msg.clone()))
734        } else {
735            self.get_message_desc_proto_nested(&path, msg, &msg_index[1..])
736        }
737    }
738
739    fn get_message_desc_proto_nested(
740        &self,
741        path: &str,
742        desc: &DescriptorProto,
743        msg_index: &[i32],
744    ) -> Result<(String, DescriptorProto), SerdeError> {
745        let index = msg_index[0] as usize;
746        let msg = desc.nested_type.get(index).ok_or(Serialization(format!(
747            "message descriptor not found at index {index}"
748        )))?;
749        let name = msg.name.clone().unwrap_or(String::new());
750        let path = if !path.is_empty() && !name.is_empty() {
751            path.to_string() + "." + &name
752        } else {
753            name
754        };
755        if msg_index.len() == 1 {
756            Ok((path, msg.clone()))
757        } else {
758            self.get_message_desc_proto_nested(&path, msg, &msg_index[1..])
759        }
760    }
761
762    async fn get_parsed_schema(
763        &self,
764        schema: &Schema,
765    ) -> Result<(FileDescriptor, DescriptorPool), SerdeError> {
766        let result = self.serde.parsed_schemas.get(schema);
767        if let Some((parsed_schema, pool)) = result.as_deref() {
768            return Ok((parsed_schema.clone(), pool.clone()));
769        }
770        // Get a copy of the global descriptor pool with the Google well-known types.
771        let mut pool = DESCRIPTOR_POOL.clone();
772        resolve_named_schema(
773            schema,
774            self.base.serde.client,
775            &mut pool,
776            &mut HashSet::new(),
777        )
778        .await?;
779        let fd_proto = str_to_proto(&mut pool, "default", &schema.schema)?;
780        self.serde
781            .parsed_schemas
782            .insert(schema.clone(), (fd_proto.clone(), pool.clone()));
783        Ok((fd_proto, pool))
784    }
785
786    pub async fn get_record_name(&self, schema: &Schema) -> Result<String, SerdeError> {
787        let (file_descriptor, _) = self.get_parsed_schema(schema).await?;
788        file_descriptor
789            .messages()
790            .next()
791            .map(|m| m.full_name().to_string())
792            .ok_or_else(|| {
793                Serialization("Could not find message name in protobuf schema".to_string())
794            })
795    }
796
797    async fn get_subject(
798        &self,
799        topic: &str,
800        serde_type: &SerdeType,
801        schema: Option<&Schema>,
802    ) -> Result<Option<String>, SerdeError> {
803        self.get_subject_for_type(self.subject_name_strategy_type, topic, serde_type, schema)
804            .await
805    }
806
807    #[async_recursion]
808    async fn get_subject_for_type(
809        &self,
810        strategy_type: SubjectNameStrategyType,
811        topic: &str,
812        serde_type: &SerdeType,
813        schema: Option<&Schema>,
814    ) -> Result<Option<String>, SerdeError> {
815        match strategy_type {
816            SubjectNameStrategyType::Record => {
817                if let Some(schema) = schema {
818                    Ok(Some(self.get_record_name(schema).await?))
819                } else {
820                    Ok(None)
821                }
822            }
823            SubjectNameStrategyType::TopicRecord => {
824                if let Some(schema) = schema {
825                    let name = self.get_record_name(schema).await?;
826                    Ok(Some(format!("{topic}-{name}")))
827                } else {
828                    Ok(None)
829                }
830            }
831            SubjectNameStrategyType::Associated => {
832                match load_associated_subject(
833                    self.base.serde.client,
834                    &self.serde.subject_cache,
835                    &self.base.config.strategy_config,
836                    topic,
837                    serde_type,
838                    schema,
839                )
840                .await?
841                {
842                    Some(s) => Ok(Some(s)),
843                    None => {
844                        let fallback_type = self
845                            .base
846                            .config
847                            .strategy_config
848                            .get(FALLBACK_TYPE_CONFIG)
849                            .map(|s| parse_subject_name_strategy_type(s))
850                            .transpose()?
851                            .unwrap_or(SubjectNameStrategyType::Topic);
852                        match fallback_type {
853                            SubjectNameStrategyType::None | SubjectNameStrategyType::Associated => {
854                                Ok(None)
855                            }
856                            other => {
857                                self.get_subject_for_type(other, topic, serde_type, schema)
858                                    .await
859                            }
860                        }
861                    }
862                }
863            }
864            _ => Ok(Some(
865                topic_name_strategy(topic, serde_type, schema).unwrap(),
866            )),
867        }
868    }
869}
870
871#[async_recursion]
872async fn resolve_named_schema<'a, T>(
873    schema: &Schema,
874    client: &'a T,
875    pool: &mut DescriptorPool,
876    visited: &mut HashSet<String>,
877) -> Result<(), SerdeError>
878where
879    T: Client + Sync,
880{
881    if let Some(refs) = schema.references.as_ref() {
882        for r in refs {
883            let name = r.name.clone().unwrap_or_default();
884            if is_builtin(&name) || visited.contains(&name) {
885                continue;
886            }
887            visited.insert(name.clone());
888            let ref_schema = client
889                .get_version(
890                    &r.subject.clone().unwrap_or_default(),
891                    r.version.unwrap_or(-1),
892                    true,
893                    Some("serialized"),
894                )
895                .await?;
896            resolve_named_schema(&ref_schema.to_schema(), client, pool, visited).await?;
897            str_to_proto(pool, &name, &ref_schema.schema.clone().unwrap_or_default())?;
898        }
899    }
900    Ok(())
901}
902
903#[async_recursion]
904async fn transform(
905    ctx: &mut RuleContext,
906    descriptor: &MessageDescriptor,
907    message: &Value,
908) -> Result<Value, SerdeError> {
909    match message {
910        Value::List(items) => {
911            let mut result = Vec::with_capacity(items.len());
912            for item in items {
913                let item = transform(ctx, descriptor, item).await?;
914                result.push(item);
915            }
916            return Ok(Value::List(result));
917        }
918        Value::Map(map) => {
919            let mut result = HashMap::new();
920            for (key, value) in map {
921                let value = transform(ctx, descriptor, value).await?;
922                result.insert(key.clone(), value);
923            }
924            return Ok(Value::Map(result));
925        }
926        Value::Message(message) => {
927            let mut result = message.clone();
928            for fd in descriptor.fields() {
929                let field = transform_field_with_ctx(ctx, &fd, descriptor, message).await?;
930                if let Some(field) = field {
931                    result.set_field(&fd, field);
932                }
933            }
934            return Ok(Value::Message(result));
935        }
936        _ => {
937            if let Some(field_ctx) = ctx.current_field() {
938                let rule_tags = ctx
939                    .rule
940                    .tags
941                    .clone()
942                    .map(|v| HashSet::from_iter(v.into_iter()));
943                if rule_tags.is_none_or(|tags| !tags.is_disjoint(&field_ctx.tags)) {
944                    let message_value = SerdeValue::Protobuf(message.clone());
945                    let field_executor_type = ctx.rule.r#type.clone();
946                    let executor = get_executor(ctx.rule_registry.as_ref(), &field_executor_type);
947                    if let Some(executor) = executor {
948                        let field_executor =
949                            executor
950                                .as_field_rule_executor()
951                                .ok_or(SerdeError::Rule(format!(
952                                    "executor {field_executor_type} is not a field rule executor"
953                                )))?;
954                        let new_value = field_executor.transform_field(ctx, &message_value).await?;
955                        if let SerdeValue::Protobuf(v) = new_value {
956                            return Ok(v);
957                        }
958                    }
959                }
960            }
961        }
962    }
963    Ok(message.clone())
964}
965
966async fn transform_field_with_ctx(
967    ctx: &mut RuleContext,
968    fd: &FieldDescriptor,
969    desc: &MessageDescriptor,
970    message: &DynamicMessage,
971) -> Result<Option<Value>, SerdeError> {
972    let message_value = SerdeValue::Protobuf(Value::Message(message.clone()));
973    ctx.enter_field(
974        message_value,
975        fd.full_name().to_string(),
976        fd.name().to_string(),
977        get_type(fd),
978        get_inline_tags(fd),
979    );
980    if fd.containing_oneof().is_some() && !message.has_field(fd) {
981        // skip oneof fields that are not set
982        return Ok(None);
983    }
984    let value = message.get_field(fd);
985    let new_value = transform(ctx, desc, &value).await?;
986    if let Some(Kind::Condition) = ctx.rule.kind
987        && let Value::Bool(b) = new_value
988        && !b
989    {
990        return Err(SerdeError::RuleCondition(Box::new(ctx.rule.clone())));
991    }
992    ctx.exit_field();
993    Ok(Some(new_value))
994}
995
996fn get_type(fd: &FieldDescriptor) -> FieldType {
997    if fd.is_map() {
998        return FieldType::Map;
999    }
1000    match fd.kind() {
1001        prost_reflect::Kind::Message(_) => FieldType::Record,
1002        prost_reflect::Kind::Enum(_) => FieldType::Enum,
1003        prost_reflect::Kind::String => FieldType::String,
1004        prost_reflect::Kind::Bytes => FieldType::Bytes,
1005        prost_reflect::Kind::Int32 => FieldType::Int,
1006        prost_reflect::Kind::Sint32 => FieldType::Int,
1007        prost_reflect::Kind::Uint32 => FieldType::Int,
1008        prost_reflect::Kind::Fixed32 => FieldType::Int,
1009        prost_reflect::Kind::Sfixed32 => FieldType::Int,
1010        prost_reflect::Kind::Int64 => FieldType::Long,
1011        prost_reflect::Kind::Sint64 => FieldType::Long,
1012        prost_reflect::Kind::Uint64 => FieldType::Long,
1013        prost_reflect::Kind::Fixed64 => FieldType::Long,
1014        prost_reflect::Kind::Sfixed64 => FieldType::Long,
1015        prost_reflect::Kind::Float => FieldType::Float,
1016        prost_reflect::Kind::Double => FieldType::Double,
1017        prost_reflect::Kind::Bool => FieldType::Boolean,
1018    }
1019}
1020
1021fn get_inline_tags(fd: &FieldDescriptor) -> HashSet<String> {
1022    let mut tag_set = HashSet::new();
1023    let field_ext = DESCRIPTOR_POOL
1024        .get_extension_by_name("confluent.field_meta")
1025        .unwrap();
1026    if fd.options().has_extension(&field_ext)
1027        && let Some(v) = fd.options().get_extension(&field_ext).as_message()
1028        && let Some(tags) = v.get_field_by_name("tags")
1029        && let Some(tags) = tags.as_list()
1030    {
1031        for tag in tags {
1032            if let Some(tag) = tag.as_str() {
1033                tag_set.insert(tag.to_string());
1034            }
1035        }
1036    }
1037    tag_set
1038}
1039
/// Returns `true` for import paths under `confluent/`, `google/protobuf/` or `google/type/`.
///
/// `resolve_named_schema` skips such references instead of fetching them from the
/// registry. Presumably this is because they are already present in `DESCRIPTOR_POOL`;
/// confirm when adding new prefixes.
fn is_builtin(name: &str) -> bool {
    const BUILTIN_PREFIXES: [&str; 3] = ["confluent/", "google/protobuf/", "google/type/"];
    BUILTIN_PREFIXES
        .into_iter()
        .any(|prefix| name.starts_with(prefix))
}
1045
1046#[cfg(test)]
1047#[cfg(feature = "rules")]
1048mod tests {
1049    use super::*;
1050    use crate::TEST_FILE_DESCRIPTOR_SET;
1051    use crate::rest::client_config::ClientConfig;
1052    use crate::rest::mock_dek_registry_client::MockDekRegistryClient;
1053    use crate::rest::mock_schema_registry_client::MockSchemaRegistryClient;
1054    use crate::rest::models::{Rule, RuleSet};
1055    use crate::rest::schema_registry_client::SchemaRegistryClient;
1056    use crate::rules::cel::cel_field_executor::CelFieldExecutor;
1057    use crate::rules::encryption::encrypt_executor::{
1058        EncryptionExecutor, FakeClock, FieldEncryptionExecutor,
1059    };
1060    use crate::rules::encryption::localkms::local_driver::LocalKmsDriver;
1061    use crate::serdes::config::SchemaSelector;
1062    use crate::serdes::protobuf::tests::test::Author;
1063    use crate::serdes::protobuf::tests::test::DependencyMessage;
1064    use crate::serdes::protobuf::tests::test::TestMessage;
1065    use crate::serdes::protobuf::tests::test::author::PiiOneof;
1066    use crate::serdes::serde::{SerdeFormat, SerdeHeaders, header_schema_id_serializer};
1067    use std::collections::BTreeMap;
1068
1069    pub(crate) mod test {
1070        include!("../codegen/test/test.rs");
1071    }
1072
1073    #[tokio::test]
1074    async fn test_basic_serialization() {
1075        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
1076        let client = MockSchemaRegistryClient::new(client_conf);
1077        let ser_conf = SerializerConfig::default();
1078        let obj = Author {
1079            name: "Kafka".to_string(),
1080            id: 123,
1081            picture: vec![1u8, 2u8, 3u8],
1082            works: vec!["Metamorphosis".to_string(), "The Trial".to_string()],
1083            pii_oneof: Some(PiiOneof::OneofString("oneof".to_string())),
1084        };
1085        let rule_registry = RuleRegistry::new();
1086        let ser = ProtobufSerializer::with_reference_subject_name_strategy(
1087            &client,
1088            default_reference_subject_name_strategy,
1089            Some(rule_registry.clone()),
1090            ser_conf,
1091        )
1092        .unwrap();
1093        let ser_ctx = SerializationContext {
1094            topic: "test".to_string(),
1095            serde_type: SerdeType::Value,
1096            serde_format: SerdeFormat::Protobuf,
1097            headers: None,
1098        };
1099        let bytes = ser.serialize(&ser_ctx, &obj).await.unwrap();
1100
1101        let deser = ProtobufDeserializer::new(
1102            &client,
1103            Some(rule_registry.clone()),
1104            DeserializerConfig::default(),
1105        )
1106        .unwrap();
1107        let obj2: Author = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
1108        assert_eq!(obj2, obj);
1109    }
1110
1111    #[tokio::test]
1112    async fn test_guid_in_header() {
1113        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
1114        let client = MockSchemaRegistryClient::new(client_conf);
1115        let mut ser_conf = SerializerConfig::default();
1116        ser_conf.schema_id_serializer = header_schema_id_serializer;
1117        let obj = Author {
1118            name: "Kafka".to_string(),
1119            id: 123,
1120            picture: vec![1u8, 2u8, 3u8],
1121            works: vec!["Metamorphosis".to_string(), "The Trial".to_string()],
1122            pii_oneof: Some(PiiOneof::OneofString("oneof".to_string())),
1123        };
1124        let rule_registry = RuleRegistry::new();
1125        let ser = ProtobufSerializer::with_reference_subject_name_strategy(
1126            &client,
1127            default_reference_subject_name_strategy,
1128            Some(rule_registry.clone()),
1129            ser_conf,
1130        )
1131        .unwrap();
1132        let ser_ctx = SerializationContext {
1133            topic: "test".to_string(),
1134            serde_type: SerdeType::Value,
1135            serde_format: SerdeFormat::Protobuf,
1136            headers: Some(SerdeHeaders::default()),
1137        };
1138        let bytes = ser.serialize(&ser_ctx, &obj).await.unwrap();
1139
1140        let deser = ProtobufDeserializer::new(
1141            &client,
1142            Some(rule_registry.clone()),
1143            DeserializerConfig::default(),
1144        )
1145        .unwrap();
1146        let obj2: Author = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
1147        assert_eq!(obj2, obj);
1148    }
1149
1150    #[tokio::test]
1151    async fn test_basic_serialization_with_file_desc_set() {
1152        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
1153        let client = MockSchemaRegistryClient::new(client_conf);
1154        let ser_conf = SerializerConfig::default();
1155        let obj = Author {
1156            name: "Kafka".to_string(),
1157            id: 123,
1158            picture: vec![1u8, 2u8, 3u8],
1159            works: vec!["Metamorphosis".to_string(), "The Trial".to_string()],
1160            pii_oneof: Some(PiiOneof::OneofString("oneof".to_string())),
1161        };
1162        let rule_registry = RuleRegistry::new();
1163        let ser = ProtobufSerializer::with_reference_subject_name_strategy(
1164            &client,
1165            default_reference_subject_name_strategy,
1166            Some(rule_registry.clone()),
1167            ser_conf,
1168        )
1169        .unwrap();
1170        let ser_ctx = SerializationContext {
1171            topic: "test".to_string(),
1172            serde_type: SerdeType::Value,
1173            serde_format: SerdeFormat::Protobuf,
1174            headers: None,
1175        };
1176        let fds: FileDescriptorSet = FileDescriptorSet::decode(TEST_FILE_DESCRIPTOR_SET).unwrap();
1177        let bytes = ser
1178            .serialize_with_file_desc_set(&ser_ctx, &obj, "test.Author", fds)
1179            .await
1180            .unwrap();
1181
1182        let deser = ProtobufDeserializer::new(
1183            &client,
1184            Some(rule_registry.clone()),
1185            DeserializerConfig::default(),
1186        )
1187        .unwrap();
1188        let obj2: Author = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
1189        assert_eq!(obj2, obj);
1190    }
1191
1192    #[tokio::test]
1193    async fn test_serialize_reference() {
1194        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
1195        let client = MockSchemaRegistryClient::new(client_conf);
1196        let ser_conf = SerializerConfig::default();
1197        let msg = TestMessage {
1198            test_string: "hi".to_string(),
1199            test_bool: true,
1200            test_bytes: vec![1u8, 2u8, 3u8],
1201            test_double: 1.23,
1202            test_float: 3.45,
1203            test_fixed32: 67,
1204            test_fixed64: 89,
1205            test_int32: 100,
1206            test_int64: 200,
1207            test_sfixed32: 300,
1208            test_sfixed64: 400,
1209            test_sint32: 500,
1210            test_sint64: 600,
1211            test_uint32: 700,
1212            test_uint64: 800,
1213        };
1214        let obj = DependencyMessage {
1215            is_active: true,
1216            test_message: Some(msg),
1217        };
1218        let rule_registry = RuleRegistry::new();
1219        let ser = ProtobufSerializer::with_reference_subject_name_strategy(
1220            &client,
1221            default_reference_subject_name_strategy,
1222            Some(rule_registry.clone()),
1223            ser_conf,
1224        )
1225        .unwrap();
1226        let ser_ctx = SerializationContext {
1227            topic: "test".to_string(),
1228            serde_type: SerdeType::Value,
1229            serde_format: SerdeFormat::Protobuf,
1230            headers: None,
1231        };
1232        let bytes = ser.serialize(&ser_ctx, &obj).await.unwrap();
1233
1234        let deser = ProtobufDeserializer::new(
1235            &client,
1236            Some(rule_registry.clone()),
1237            DeserializerConfig::default(),
1238        )
1239        .unwrap();
1240        let obj2: DependencyMessage = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
1241        assert_eq!(obj2, obj);
1242    }
1243
1244    #[tokio::test]
1245    async fn test_cel_field() {
1246        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
1247        let client = MockSchemaRegistryClient::new(client_conf);
1248        let ser_conf = SerializerConfig::new(
1249            false,
1250            Some(SchemaSelector::LatestVersion),
1251            true,
1252            false,
1253            HashMap::new(),
1254        );
1255        let mut obj = Author {
1256            name: "Kafka".to_string(),
1257            id: 123,
1258            picture: vec![1u8, 2u8, 3u8],
1259            works: vec!["Metamorphosis".to_string(), "The Trial".to_string()],
1260            pii_oneof: Some(PiiOneof::OneofString("oneof".to_string())),
1261        };
1262        let rule = Rule {
1263            name: "test-cel".to_string(),
1264            doc: None,
1265            kind: Some(Kind::Transform),
1266            mode: Some(Mode::Write),
1267            r#type: "CEL_FIELD".to_string(),
1268            tags: None,
1269            params: None,
1270            expr: Some("typeName == 'STRING' ; value + '-suffix'".to_string()),
1271            on_success: None,
1272            on_failure: None,
1273            disabled: None,
1274        };
1275        let rule_set = RuleSet {
1276            migration_rules: None,
1277            domain_rules: Some(vec![rule]),
1278            encoding_rules: None,
1279            enable_at: None,
1280        };
1281        let schema = Schema {
1282            schema_type: Some("PROTOBUF".to_string()),
1283            references: None,
1284            metadata: None,
1285            rule_set: Some(Box::new(rule_set)),
1286            schema: schema_to_str(&obj.descriptor().parent_file()).unwrap(),
1287        };
1288        client
1289            .register_schema("test-value", &schema, false)
1290            .await
1291            .unwrap();
1292        let rule_registry = RuleRegistry::new();
1293        rule_registry.register_executor(CelFieldExecutor::new());
1294        let ser = ProtobufSerializer::with_reference_subject_name_strategy(
1295            &client,
1296            default_reference_subject_name_strategy,
1297            Some(rule_registry.clone()),
1298            ser_conf,
1299        )
1300        .unwrap();
1301        let ser_ctx = SerializationContext {
1302            topic: "test".to_string(),
1303            serde_type: SerdeType::Value,
1304            serde_format: SerdeFormat::Protobuf,
1305            headers: None,
1306        };
1307        let bytes = ser.serialize(&ser_ctx, &obj).await.unwrap();
1308
1309        let deser = ProtobufDeserializer::new(
1310            &client,
1311            Some(rule_registry.clone()),
1312            DeserializerConfig::default(),
1313        )
1314        .unwrap();
1315
1316        obj = Author {
1317            name: "Kafka-suffix".to_string(),
1318            id: 123,
1319            picture: vec![1u8, 2u8, 3u8],
1320            works: vec![
1321                "Metamorphosis-suffix".to_string(),
1322                "The Trial-suffix".to_string(),
1323            ],
1324            pii_oneof: Some(PiiOneof::OneofString("oneof-suffix".to_string())),
1325        };
1326        let obj2: Author = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
1327        assert_eq!(obj2, obj);
1328    }
1329
1330    #[tokio::test]
1331    async fn test_encryption() {
1332        LocalKmsDriver::register();
1333
1334        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
1335        let client = MockSchemaRegistryClient::new(client_conf);
1336        let rule_conf = HashMap::from([("secret".to_string(), "mysecret".to_string())]);
1337        let ser_conf = SerializerConfig::new(false, None, true, false, rule_conf.clone());
1338        let deser_conf = DeserializerConfig::new(None, false, rule_conf);
1339        let mut obj = Author {
1340            name: "Kafka".to_string(),
1341            id: 123,
1342            picture: vec![1u8, 2u8, 3u8],
1343            works: vec!["Metamorphosis".to_string(), "The Trial".to_string()],
1344            pii_oneof: Some(PiiOneof::OneofString("oneof".to_string())),
1345        };
1346        let rule = Rule {
1347            name: "test-encrypt".to_string(),
1348            doc: None,
1349            kind: Some(Kind::Transform),
1350            mode: Some(Mode::WriteRead),
1351            r#type: "ENCRYPT".to_string(),
1352            tags: Some(vec!["PII".to_string()]),
1353            params: Some(BTreeMap::from([
1354                ("encrypt.kek.name".to_string(), "kek1".to_string()),
1355                ("encrypt.kms.type".to_string(), "local-kms".to_string()),
1356                ("encrypt.kms.key.id".to_string(), "mykey".to_string()),
1357            ])),
1358            expr: None,
1359            on_success: None,
1360            on_failure: Some("ERROR,NONE".to_string()),
1361            disabled: None,
1362        };
1363        let rule_set = RuleSet {
1364            migration_rules: None,
1365            domain_rules: Some(vec![rule]),
1366            encoding_rules: None,
1367            enable_at: None,
1368        };
1369        let schema = Schema {
1370            schema_type: Some("PROTOBUF".to_string()),
1371            references: None,
1372            metadata: None,
1373            rule_set: Some(Box::new(rule_set)),
1374            schema: schema_to_str(&obj.descriptor().parent_file()).unwrap(),
1375        };
1376        client
1377            .register_schema("test-value", &schema, false)
1378            .await
1379            .unwrap();
1380        let rule_registry = RuleRegistry::new();
1381        rule_registry.register_executor(FieldEncryptionExecutor::<MockDekRegistryClient>::new(
1382            FakeClock::new(0),
1383        ));
1384        let ser = ProtobufSerializer::with_reference_subject_name_strategy(
1385            &client,
1386            default_reference_subject_name_strategy,
1387            Some(rule_registry.clone()),
1388            ser_conf,
1389        )
1390        .unwrap();
1391        let ser_ctx = SerializationContext {
1392            topic: "test".to_string(),
1393            serde_type: SerdeType::Value,
1394            serde_format: SerdeFormat::Protobuf,
1395            headers: None,
1396        };
1397        let bytes = ser.serialize(&ser_ctx, &obj).await.unwrap();
1398
1399        let deser =
1400            ProtobufDeserializer::new(&client, Some(rule_registry.clone()), deser_conf).unwrap();
1401
1402        obj = Author {
1403            name: "Kafka".to_string(),
1404            id: 123,
1405            picture: vec![1u8, 2u8, 3u8],
1406            works: vec!["Metamorphosis".to_string(), "The Trial".to_string()],
1407            pii_oneof: Some(PiiOneof::OneofString("oneof".to_string())),
1408        };
1409        let obj2: Author = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
1410        assert_eq!(obj2, obj);
1411    }
1412
1413    #[tokio::test]
1414    async fn test_payload_encryption() {
1415        LocalKmsDriver::register();
1416
1417        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
1418        let client = MockSchemaRegistryClient::new(client_conf);
1419        let rule_conf = HashMap::from([("secret".to_string(), "mysecret".to_string())]);
1420        let ser_conf = SerializerConfig::new(false, None, true, false, rule_conf.clone());
1421        let deser_conf = DeserializerConfig::new(None, false, rule_conf);
1422        let mut obj = Author {
1423            name: "Kafka".to_string(),
1424            id: 123,
1425            picture: vec![1u8, 2u8, 3u8],
1426            works: vec!["Metamorphosis".to_string(), "The Trial".to_string()],
1427            pii_oneof: Some(PiiOneof::OneofString("oneof".to_string())),
1428        };
1429        let rule = Rule {
1430            name: "test-encrypt".to_string(),
1431            doc: None,
1432            kind: Some(Kind::Transform),
1433            mode: Some(Mode::WriteRead),
1434            r#type: "ENCRYPT_PAYLOAD".to_string(),
1435            tags: None,
1436            params: Some(BTreeMap::from([
1437                ("encrypt.kek.name".to_string(), "kek1".to_string()),
1438                ("encrypt.kms.type".to_string(), "local-kms".to_string()),
1439                ("encrypt.kms.key.id".to_string(), "mykey".to_string()),
1440            ])),
1441            expr: None,
1442            on_success: None,
1443            on_failure: Some("ERROR,NONE".to_string()),
1444            disabled: None,
1445        };
1446        let rule_set = RuleSet {
1447            migration_rules: None,
1448            domain_rules: None,
1449            encoding_rules: Some(vec![rule]),
1450            enable_at: None,
1451        };
1452        let schema = Schema {
1453            schema_type: Some("PROTOBUF".to_string()),
1454            references: None,
1455            metadata: None,
1456            rule_set: Some(Box::new(rule_set)),
1457            schema: schema_to_str(&obj.descriptor().parent_file()).unwrap(),
1458        };
1459        client
1460            .register_schema("test-value", &schema, false)
1461            .await
1462            .unwrap();
1463        let rule_registry = RuleRegistry::new();
1464        rule_registry.register_executor(EncryptionExecutor::<MockDekRegistryClient>::new(
1465            FakeClock::new(0),
1466        ));
1467        let ser = ProtobufSerializer::with_reference_subject_name_strategy(
1468            &client,
1469            default_reference_subject_name_strategy,
1470            Some(rule_registry.clone()),
1471            ser_conf,
1472        )
1473        .unwrap();
1474        let ser_ctx = SerializationContext {
1475            topic: "test".to_string(),
1476            serde_type: SerdeType::Value,
1477            serde_format: SerdeFormat::Protobuf,
1478            headers: None,
1479        };
1480        let bytes = ser.serialize(&ser_ctx, &obj).await.unwrap();
1481
1482        let deser =
1483            ProtobufDeserializer::new(&client, Some(rule_registry.clone()), deser_conf).unwrap();
1484
1485        obj = Author {
1486            name: "Kafka".to_string(),
1487            id: 123,
1488            picture: vec![1u8, 2u8, 3u8],
1489            works: vec!["Metamorphosis".to_string(), "The Trial".to_string()],
1490            pii_oneof: Some(PiiOneof::OneofString("oneof".to_string())),
1491        };
1492        let obj2: Author = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
1493        assert_eq!(obj2, obj);
1494    }
1495
1496    fn proto_demo_obj() -> Author {
1497        Author {
1498            name: "Kafka".to_string(),
1499            id: 123,
1500            picture: vec![1u8, 2u8],
1501            works: vec!["Metamorphosis".to_string()],
1502            pii_oneof: None,
1503        }
1504    }
1505
1506    fn proto_demo_schema(obj: &Author) -> Schema {
1507        Schema {
1508            schema_type: Some("PROTOBUF".to_string()),
1509            references: None,
1510            metadata: None,
1511            rule_set: None,
1512            schema: schema_to_str(&obj.descriptor().parent_file()).unwrap(),
1513        }
1514    }
1515
1516    fn proto_make_association(
1517        resource_id: &str,
1518        subject: &str,
1519        association_type: &str,
1520    ) -> crate::rest::models::AssociationCreateOrUpdateRequest {
1521        use crate::rest::models::{
1522            AssociationCreateOrUpdateInfo, AssociationCreateOrUpdateRequest,
1523        };
1524        AssociationCreateOrUpdateRequest {
1525            resource_name: Some("topic1".to_string()),
1526            resource_namespace: Some("-".to_string()),
1527            resource_id: Some(resource_id.to_string()),
1528            resource_type: Some("topic".to_string()),
1529            associations: Some(vec![AssociationCreateOrUpdateInfo {
1530                subject: Some(subject.to_string()),
1531                association_type: Some(association_type.to_string()),
1532                lifecycle: None,
1533                frozen: None,
1534                schema: None,
1535                normalize: None,
1536            }]),
1537        }
1538    }
1539
1540    #[tokio::test]
1541    async fn test_protobuf_serde_with_associated_name_strategy() {
1542        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
1543        let client = MockSchemaRegistryClient::new(client_conf);
1544
1545        let obj = proto_demo_obj();
1546        let schema = proto_demo_schema(&obj);
1547        client
1548            .register_schema("my-custom-subject", &schema, false)
1549            .await
1550            .unwrap();
1551        client
1552            .create_association(&proto_make_association(
1553                "lkc-123:topic1",
1554                "my-custom-subject",
1555                "value",
1556            ))
1557            .await
1558            .unwrap();
1559
1560        let rule_registry = RuleRegistry::new();
1561        let mut ser_conf = SerializerConfig::new(
1562            false,
1563            Some(SchemaSelector::LatestVersion),
1564            false,
1565            false,
1566            HashMap::new(),
1567        );
1568        ser_conf.subject_name_strategy_type = SubjectNameStrategyType::Associated;
1569        let ser = ProtobufSerializer::with_reference_subject_name_strategy(
1570            &client,
1571            default_reference_subject_name_strategy,
1572            Some(rule_registry.clone()),
1573            ser_conf,
1574        )
1575        .unwrap();
1576        let ser_ctx = SerializationContext {
1577            topic: "topic1".to_string(),
1578            serde_type: SerdeType::Value,
1579            serde_format: SerdeFormat::Protobuf,
1580            headers: None,
1581        };
1582        let bytes = ser.serialize(&ser_ctx, &obj).await.unwrap();
1583
1584        let mut deser_conf = DeserializerConfig::default();
1585        deser_conf.subject_name_strategy_type = SubjectNameStrategyType::Associated;
1586        let deser =
1587            ProtobufDeserializer::new(&client, Some(rule_registry.clone()), deser_conf).unwrap();
1588        let obj2: Author = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
1589        assert_eq!(obj2, obj);
1590    }
1591
1592    #[tokio::test]
1593    async fn test_protobuf_serde_with_associated_name_strategy_fallback_to_topic() {
1594        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
1595        let client = MockSchemaRegistryClient::new(client_conf);
1596
1597        let obj = proto_demo_obj();
1598        let schema = proto_demo_schema(&obj);
1599        // Register under topic name strategy subject (no association)
1600        client
1601            .register_schema("topic1-value", &schema, false)
1602            .await
1603            .unwrap();
1604
1605        let rule_registry = RuleRegistry::new();
1606        let mut ser_conf = SerializerConfig::new(
1607            false,
1608            Some(SchemaSelector::LatestVersion),
1609            false,
1610            false,
1611            HashMap::new(),
1612        );
1613        ser_conf.subject_name_strategy_type = SubjectNameStrategyType::Associated;
1614        let ser = ProtobufSerializer::with_reference_subject_name_strategy(
1615            &client,
1616            default_reference_subject_name_strategy,
1617            Some(rule_registry.clone()),
1618            ser_conf,
1619        )
1620        .unwrap();
1621        let ser_ctx = SerializationContext {
1622            topic: "topic1".to_string(),
1623            serde_type: SerdeType::Value,
1624            serde_format: SerdeFormat::Protobuf,
1625            headers: None,
1626        };
1627        let bytes = ser.serialize(&ser_ctx, &obj).await.unwrap();
1628
1629        let deser = ProtobufDeserializer::new(
1630            &client,
1631            Some(rule_registry.clone()),
1632            DeserializerConfig::default(),
1633        )
1634        .unwrap();
1635        let obj2: Author = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
1636        assert_eq!(obj2, obj);
1637    }
1638
1639    #[tokio::test]
1640    async fn test_protobuf_serde_with_associated_name_strategy_fallback_none() {
1641        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
1642        let client = MockSchemaRegistryClient::new(client_conf);
1643
1644        let rule_registry = RuleRegistry::new();
1645        let mut ser_conf = SerializerConfig::new(
1646            false,
1647            Some(SchemaSelector::LatestVersion),
1648            false,
1649            false,
1650            HashMap::new(),
1651        );
1652        ser_conf.subject_name_strategy_type = SubjectNameStrategyType::Associated;
1653        ser_conf.strategy_config =
1654            HashMap::from([(FALLBACK_TYPE_CONFIG.to_string(), "NONE".to_string())]);
1655        let ser = ProtobufSerializer::with_reference_subject_name_strategy(
1656            &client,
1657            default_reference_subject_name_strategy,
1658            Some(rule_registry.clone()),
1659            ser_conf,
1660        )
1661        .unwrap();
1662        let ser_ctx = SerializationContext {
1663            topic: "topic1".to_string(),
1664            serde_type: SerdeType::Value,
1665            serde_format: SerdeFormat::Protobuf,
1666            headers: None,
1667        };
1668        let result = ser.serialize(&ser_ctx, &proto_demo_obj()).await;
1669        assert!(result.is_err());
1670        assert!(
1671            result
1672                .unwrap_err()
1673                .to_string()
1674                .contains("Could not determine subject")
1675        );
1676    }
1677
1678    #[tokio::test]
1679    async fn test_protobuf_serde_with_associated_name_strategy_multiple_associations() {
1680        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
1681        let client = MockSchemaRegistryClient::new(client_conf);
1682
1683        let obj = proto_demo_obj();
1684        let schema = proto_demo_schema(&obj);
1685        client
1686            .register_schema("subject1", &schema, false)
1687            .await
1688            .unwrap();
1689        client
1690            .register_schema("subject2", &schema, false)
1691            .await
1692            .unwrap();
1693        client
1694            .create_association(&proto_make_association(
1695                "lkc-123:topic1",
1696                "subject1",
1697                "value",
1698            ))
1699            .await
1700            .unwrap();
1701        client
1702            .create_association(&proto_make_association(
1703                "lkc-456:topic1",
1704                "subject2",
1705                "value",
1706            ))
1707            .await
1708            .unwrap();
1709
1710        let rule_registry = RuleRegistry::new();
1711        let mut ser_conf = SerializerConfig::new(
1712            false,
1713            Some(SchemaSelector::LatestVersion),
1714            false,
1715            false,
1716            HashMap::new(),
1717        );
1718        ser_conf.subject_name_strategy_type = SubjectNameStrategyType::Associated;
1719        let ser = ProtobufSerializer::with_reference_subject_name_strategy(
1720            &client,
1721            default_reference_subject_name_strategy,
1722            Some(rule_registry.clone()),
1723            ser_conf,
1724        )
1725        .unwrap();
1726        let ser_ctx = SerializationContext {
1727            topic: "topic1".to_string(),
1728            serde_type: SerdeType::Value,
1729            serde_format: SerdeFormat::Protobuf,
1730            headers: None,
1731        };
1732        let result = ser.serialize(&ser_ctx, &proto_demo_obj()).await;
1733        assert!(result.is_err());
1734        assert!(
1735            result
1736                .unwrap_err()
1737                .to_string()
1738                .contains("multiple associated subjects found")
1739        );
1740    }
1741
1742    #[tokio::test]
1743    async fn test_protobuf_serde_with_associated_name_strategy_with_kafka_cluster_id() {
1744        use crate::rest::models::{
1745            AssociationCreateOrUpdateInfo, AssociationCreateOrUpdateRequest,
1746        };
1747
1748        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
1749        let client = MockSchemaRegistryClient::new(client_conf);
1750
1751        let obj = proto_demo_obj();
1752        let schema = proto_demo_schema(&obj);
1753        client
1754            .register_schema("my-custom-subject", &schema, false)
1755            .await
1756            .unwrap();
1757
1758        let request = AssociationCreateOrUpdateRequest {
1759            resource_name: Some("topic1".to_string()),
1760            resource_namespace: Some("lkc-my-cluster".to_string()),
1761            resource_id: Some("lkc-my-cluster:topic1".to_string()),
1762            resource_type: Some("topic".to_string()),
1763            associations: Some(vec![AssociationCreateOrUpdateInfo {
1764                subject: Some("my-custom-subject".to_string()),
1765                association_type: Some("value".to_string()),
1766                lifecycle: None,
1767                frozen: None,
1768                schema: None,
1769                normalize: None,
1770            }]),
1771        };
1772        client.create_association(&request).await.unwrap();
1773
1774        let rule_registry = RuleRegistry::new();
1775        let mut ser_conf = SerializerConfig::new(
1776            false,
1777            Some(SchemaSelector::LatestVersion),
1778            false,
1779            false,
1780            HashMap::new(),
1781        );
1782        ser_conf.subject_name_strategy_type = SubjectNameStrategyType::Associated;
1783        ser_conf.strategy_config = HashMap::from([(
1784            KAFKA_CLUSTER_ID_CONFIG.to_string(),
1785            "lkc-my-cluster".to_string(),
1786        )]);
1787        let ser = ProtobufSerializer::with_reference_subject_name_strategy(
1788            &client,
1789            default_reference_subject_name_strategy,
1790            Some(rule_registry.clone()),
1791            ser_conf,
1792        )
1793        .unwrap();
1794        let ser_ctx = SerializationContext {
1795            topic: "topic1".to_string(),
1796            serde_type: SerdeType::Value,
1797            serde_format: SerdeFormat::Protobuf,
1798            headers: None,
1799        };
1800        let bytes = ser.serialize(&ser_ctx, &obj).await.unwrap();
1801
1802        let mut deser_conf = DeserializerConfig::default();
1803        deser_conf.subject_name_strategy_type = SubjectNameStrategyType::Associated;
1804        deser_conf.strategy_config = HashMap::from([(
1805            KAFKA_CLUSTER_ID_CONFIG.to_string(),
1806            "lkc-my-cluster".to_string(),
1807        )]);
1808        let deser =
1809            ProtobufDeserializer::new(&client, Some(rule_registry.clone()), deser_conf).unwrap();
1810        let obj2: Author = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
1811        assert_eq!(obj2, obj);
1812    }
1813
1814    #[tokio::test]
1815    async fn test_protobuf_serde_with_associated_name_strategy_caching() {
1816        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
1817        let client = MockSchemaRegistryClient::new(client_conf);
1818
1819        let obj = proto_demo_obj();
1820        let schema = proto_demo_schema(&obj);
1821        client
1822            .register_schema("my-cached-subject", &schema, false)
1823            .await
1824            .unwrap();
1825        client
1826            .create_association(&proto_make_association(
1827                "lkc-123:topic1",
1828                "my-cached-subject",
1829                "value",
1830            ))
1831            .await
1832            .unwrap();
1833
1834        let rule_registry = RuleRegistry::new();
1835        let mut ser_conf = SerializerConfig::new(
1836            false,
1837            Some(SchemaSelector::LatestVersion),
1838            false,
1839            false,
1840            HashMap::new(),
1841        );
1842        ser_conf.subject_name_strategy_type = SubjectNameStrategyType::Associated;
1843        let ser = ProtobufSerializer::with_reference_subject_name_strategy(
1844            &client,
1845            default_reference_subject_name_strategy,
1846            Some(rule_registry.clone()),
1847            ser_conf,
1848        )
1849        .unwrap();
1850        let ser_ctx = SerializationContext {
1851            topic: "topic1".to_string(),
1852            serde_type: SerdeType::Value,
1853            serde_format: SerdeFormat::Protobuf,
1854            headers: None,
1855        };
1856
1857        let mut deser_conf = DeserializerConfig::default();
1858        deser_conf.subject_name_strategy_type = SubjectNameStrategyType::Associated;
1859        let deser =
1860            ProtobufDeserializer::new(&client, Some(rule_registry.clone()), deser_conf).unwrap();
1861
1862        for _ in 0..5 {
1863            let bytes = ser.serialize(&ser_ctx, &obj).await.unwrap();
1864            let obj2: Author = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
1865            assert_eq!(obj2, obj);
1866        }
1867    }
1868}