Skip to main content

schema_registry_client/serdes/
avro.rs

1use crate::rest::models::{Kind, Mode};
2use crate::rest::models::{Phase, Schema};
3use crate::rest::schema_registry_client::Client;
4use crate::serdes::config::{DeserializerConfig, SerializerConfig};
5use crate::serdes::rule_registry::RuleRegistry;
6use crate::serdes::serde::SerdeError::Serialization;
7use crate::serdes::serde::{
8    BaseDeserializer, BaseSerializer, FieldTransformer, FieldType, RuleContext, SchemaId, Serde,
9    SerdeError, SerdeFormat, SerdeSchema, SerdeType, SerdeValue, SerializationContext,
10    get_executor, get_executors,
11};
12use apache_avro::schema::{Name, RecordField, RecordSchema, UnionSchema};
13use apache_avro::types::Value;
14use async_recursion::async_recursion;
15use dashmap::DashMap;
16use futures::StreamExt;
17use futures::future::FutureExt;
18use serde::Serialize;
19use std::collections::{HashMap, HashSet};
20use std::io::Cursor;
21use std::sync::Arc;
22use uuid::Uuid;
23
24#[derive(Clone, Debug)]
25pub(crate) struct AvroSerde {
26    parsed_schemas: DashMap<Schema, (apache_avro::Schema, Vec<apache_avro::Schema>)>,
27}
28
29#[derive(Clone, Debug)]
30pub struct AvroSerializer<'a, T: Client> {
31    schema: Option<&'a Schema>,
32    base: BaseSerializer<'a, T>,
33    serde: AvroSerde,
34}
35
36impl<'a, T: Client + Sync> AvroSerializer<'a, T> {
37    pub fn new(
38        client: &'a T,
39        schema: Option<&'a Schema>,
40        rule_registry: Option<RuleRegistry>,
41        serializer_config: SerializerConfig,
42    ) -> Result<AvroSerializer<'a, T>, SerdeError> {
43        for executor in get_executors(rule_registry.as_ref()) {
44            executor.configure(client.config(), &serializer_config.rule_config)?;
45        }
46        Ok(AvroSerializer {
47            schema,
48            base: BaseSerializer::new(Serde::new(client, rule_registry), serializer_config),
49            serde: AvroSerde {
50                parsed_schemas: DashMap::new(),
51            },
52        })
53    }
54
55    pub async fn serialize_ser(
56        &self,
57        ctx: &SerializationContext,
58        value: impl Serialize,
59    ) -> Result<Vec<u8>, SerdeError> {
60        let v = apache_avro::to_value(value)?;
61        self.serialize(ctx, v).await
62    }
63
64    pub async fn serialize(
65        &self,
66        ctx: &SerializationContext,
67        value: Value,
68    ) -> Result<Vec<u8>, SerdeError> {
69        let mut value = value;
70        let strategy = self.base.config.subject_name_strategy;
71        let subject = strategy(&ctx.topic, &ctx.serde_type, self.schema);
72        let subject = subject.ok_or(Serialization(
73            "subject name strategy returned None".to_string(),
74        ))?;
75        let latest_schema = self
76            .base
77            .serde
78            .get_reader_schema(&subject, None, &self.base.config.use_schema)
79            .await?;
80
81        let schema_id;
82        if let Some(ref schema) = latest_schema {
83            schema_id = SchemaId::new(SerdeFormat::Avro, schema.id, schema.guid.clone(), None)?;
84        } else {
85            let schema = self
86                .schema
87                .ok_or(Serialization("schema needs to be set".to_string()))?;
88            if self.base.config.auto_register_schemas {
89                let rs = self
90                    .base
91                    .serde
92                    .client
93                    .register_schema(&subject, schema, self.base.config.normalize_schemas)
94                    .await?;
95                schema_id = SchemaId::new(SerdeFormat::Avro, rs.id, rs.guid.clone(), None)?;
96            } else {
97                let rs = self
98                    .base
99                    .serde
100                    .client
101                    .get_by_schema(&subject, schema, self.base.config.normalize_schemas, false)
102                    .await?;
103                schema_id = SchemaId::new(SerdeFormat::Avro, rs.id, rs.guid.clone(), None)?;
104            }
105        }
106
107        let schema_tuple;
108        if let Some(ref latest_schema) = latest_schema {
109            let schema = latest_schema.to_schema();
110            schema_tuple = self.get_parsed_schema(&schema).await?;
111            let field_transformer: FieldTransformer =
112                Box::new(|ctx, value| transform_fields(ctx, value).boxed());
113            let serde_value = self
114                .base
115                .serde
116                .execute_rules(
117                    ctx,
118                    &subject,
119                    Mode::Write,
120                    None,
121                    Some(&schema),
122                    Some(&SerdeSchema::Avro(schema_tuple.clone())),
123                    &SerdeValue::Avro(value),
124                    Some(Arc::new(field_transformer)),
125                )
126                .await?;
127            value = match serde_value {
128                SerdeValue::Avro(value) => value,
129                _ => return Err(Serialization("unexpected serde value".to_string())),
130            }
131        } else {
132            let schema = self
133                .schema
134                .ok_or(Serialization("schema needs to be set".to_string()))?;
135            schema_tuple = self.get_parsed_schema(schema).await?;
136        }
137
138        let mut encoded_bytes = apache_avro::to_avro_datum_schemata(
139            &schema_tuple.0,
140            schema_tuple.1.iter().collect(),
141            value,
142        )?;
143        if let Some(ref latest_schema) = latest_schema {
144            let schema = latest_schema.to_schema();
145            if let Some(ref rule_set) = schema.rule_set {
146                if rule_set.encoding_rules.is_some() {
147                    encoded_bytes = self
148                        .base
149                        .serde
150                        .execute_rules_with_phase(
151                            ctx,
152                            &subject,
153                            Phase::Encoding,
154                            Mode::Write,
155                            None,
156                            Some(&schema),
157                            None,
158                            &SerdeValue::new_bytes(SerdeFormat::Avro, &encoded_bytes),
159                            None,
160                        )
161                        .await?
162                        .as_bytes();
163                }
164            }
165        }
166
167        let id_ser = self.base.config.schema_id_serializer;
168        id_ser(&encoded_bytes, ctx, &schema_id)
169    }
170
171    async fn get_parsed_schema(
172        &self,
173        schema: &Schema,
174    ) -> Result<(apache_avro::Schema, Vec<apache_avro::Schema>), SerdeError> {
175        let parsed_schema = self.serde.parsed_schemas.get(schema);
176        if let Some(parsed_schema) = parsed_schema {
177            return Ok(parsed_schema.clone());
178        }
179        let mut schemas = Vec::new();
180        resolve_named_schema(
181            schema,
182            self.base.serde.client,
183            &mut schemas,
184            &mut HashSet::new(),
185        )
186        .await?;
187        let parsed_schema = apache_avro::Schema::parse_str_with_list(
188            &schema.schema,
189            &schemas.iter().map(|s| s.as_str()).collect::<Vec<&str>>(),
190        )?;
191        self.serde
192            .parsed_schemas
193            .insert(schema.clone(), parsed_schema.clone());
194        Ok(parsed_schema)
195    }
196
197    fn close(&mut self) {}
198}
199
200async fn transform_fields(
201    ctx: &mut RuleContext,
202    value: &SerdeValue,
203) -> Result<SerdeValue, SerdeError> {
204    if let Some(SerdeSchema::Avro((s, named))) = ctx.parsed_target.clone() {
205        if let SerdeValue::Avro(v) = value {
206            let value = transform(ctx, &s, &named, v).await?;
207            return Ok(SerdeValue::Avro(value));
208        }
209    }
210    Ok(value.clone())
211}
212
213#[derive(Clone, Debug, PartialEq)]
214pub struct NamedValue {
215    pub name: Option<Name>,
216    pub value: Value,
217}
218
219#[derive(Clone, Debug)]
220pub struct AvroDeserializer<'a, T: Client> {
221    base: BaseDeserializer<'a, T>,
222    serde: AvroSerde,
223}
224
225impl<'a, T: Client + Sync> AvroDeserializer<'a, T> {
226    pub fn new(
227        client: &'a T,
228        rule_registry: Option<RuleRegistry>,
229        deserializer_config: DeserializerConfig,
230    ) -> Result<AvroDeserializer<'a, T>, SerdeError> {
231        for executor in get_executors(rule_registry.as_ref()) {
232            executor.configure(client.config(), &deserializer_config.rule_config)?;
233        }
234        Ok(AvroDeserializer {
235            base: BaseDeserializer::new(Serde::new(client, rule_registry), deserializer_config),
236            serde: AvroSerde {
237                parsed_schemas: DashMap::new(),
238            },
239        })
240    }
241
242    pub async fn deserialize(
243        &self,
244        ctx: &SerializationContext,
245        data: &[u8],
246    ) -> Result<NamedValue, SerdeError> {
247        let strategy = self.base.config.subject_name_strategy;
248        let mut subject = strategy(&ctx.topic, &ctx.serde_type, None);
249        let mut latest_schema = None;
250        let has_subject = subject.is_some();
251        if has_subject {
252            latest_schema = self
253                .base
254                .serde
255                .get_reader_schema(
256                    subject.as_ref().unwrap(),
257                    None,
258                    &self.base.config.use_schema,
259                )
260                .await?;
261        }
262
263        let mut schema_id = SchemaId::new(SerdeFormat::Avro, None, None, None)?;
264        let id_deser = self.base.config.schema_id_deserializer;
265        let bytes_read = id_deser(data, ctx, &mut schema_id)?;
266        let mut data = &data[bytes_read..];
267
268        let writer_schema_raw = self
269            .base
270            .get_writer_schema(&schema_id, subject.as_deref(), None)
271            .await?;
272        let (writer_schema, writer_named) = self.get_parsed_schema(&writer_schema_raw).await?;
273
274        if !has_subject {
275            subject = strategy(&ctx.topic, &ctx.serde_type, Some(&writer_schema_raw));
276            if let Some(subject) = subject.as_ref() {
277                latest_schema = self
278                    .base
279                    .serde
280                    .get_reader_schema(subject, None, &self.base.config.use_schema)
281                    .await?;
282            }
283        }
284        let subject = subject.unwrap();
285        let serde_value;
286        if let Some(ref rule_set) = writer_schema_raw.rule_set {
287            if rule_set.encoding_rules.is_some() {
288                serde_value = self
289                    .base
290                    .serde
291                    .execute_rules_with_phase(
292                        ctx,
293                        &subject,
294                        Phase::Encoding,
295                        Mode::Read,
296                        None,
297                        Some(&writer_schema_raw),
298                        None,
299                        &SerdeValue::new_bytes(SerdeFormat::Avro, data),
300                        None,
301                    )
302                    .await?
303                    .as_bytes();
304                data = &serde_value;
305            }
306        }
307
308        let migrations;
309        let reader_schema_raw;
310        let reader_schema;
311        let reader_named;
312        if let Some(ref latest_schema) = latest_schema {
313            migrations = self
314                .base
315                .serde
316                .get_migrations(&subject, &writer_schema_raw, latest_schema, None)
317                .await?;
318            reader_schema_raw = latest_schema.to_schema();
319            (reader_schema, reader_named) = self.get_parsed_schema(&reader_schema_raw).await?;
320        } else {
321            migrations = Vec::new();
322            reader_schema_raw = writer_schema_raw.clone();
323            reader_schema = writer_schema.clone();
324            reader_named = writer_named.clone();
325        }
326
327        let mut reader = Cursor::new(data);
328        let mut value;
329        if let Some(ref latest_schema) = latest_schema {
330            value = apache_avro::from_avro_datum_schemata(
331                &writer_schema,
332                writer_named.iter().collect(),
333                &mut reader,
334                None,
335            )?;
336            let json = from_avro_value(value.clone())?;
337            let mut serde_value = SerdeValue::Json(json);
338            serde_value = self
339                .base
340                .serde
341                .execute_migrations(ctx, &subject, &migrations, &serde_value)
342                .await?;
343            value = match serde_value {
344                SerdeValue::Json(v) => to_avro_value(&value, &v)?,
345                _ => return Err(Serialization("unexpected serde value".to_string())),
346            }
347        } else {
348            value = apache_avro::from_avro_datum_reader_schemata(
349                &writer_schema,
350                writer_named.iter().collect(),
351                &mut reader,
352                Some(&reader_schema),
353                reader_named.iter().collect(),
354            )?;
355        }
356
357        let field_transformer: FieldTransformer =
358            Box::new(|ctx, value| transform_fields(ctx, value).boxed());
359        let serde_value = self
360            .base
361            .serde
362            .execute_rules(
363                ctx,
364                &subject,
365                Mode::Read,
366                None,
367                Some(&reader_schema_raw),
368                Some(&SerdeSchema::Avro((
369                    reader_schema.clone(),
370                    reader_named.clone(),
371                ))),
372                &SerdeValue::Avro(value),
373                Some(Arc::new(field_transformer)),
374            )
375            .await?;
376        value = match serde_value {
377            SerdeValue::Avro(value) => value,
378            _ => return Err(Serialization("unexpected serde value".to_string())),
379        };
380
381        Ok(NamedValue {
382            name: self.get_name(&reader_schema),
383            value,
384        })
385    }
386
387    fn get_name(&self, schema: &apache_avro::Schema) -> Option<Name> {
388        match schema {
389            apache_avro::Schema::Record(schema) => Some(schema.name.clone()),
390            _ => None,
391        }
392    }
393
394    async fn get_parsed_schema(
395        &self,
396        schema: &Schema,
397    ) -> Result<(apache_avro::Schema, Vec<apache_avro::Schema>), SerdeError> {
398        let parsed_schema = self.serde.parsed_schemas.get(schema);
399        if let Some(parsed_schema) = parsed_schema {
400            return Ok(parsed_schema.clone());
401        }
402        let mut schemas = Vec::new();
403        resolve_named_schema(
404            schema,
405            self.base.serde.client,
406            &mut schemas,
407            &mut HashSet::new(),
408        )
409        .await?;
410        let parsed_schema = apache_avro::Schema::parse_str_with_list(
411            &schema.schema,
412            &schemas.iter().map(|s| s.as_str()).collect::<Vec<&str>>(),
413        )?;
414        self.serde
415            .parsed_schemas
416            .insert(schema.clone(), parsed_schema.clone());
417        Ok(parsed_schema)
418    }
419}
420
421#[async_recursion]
422async fn resolve_named_schema<'a, T>(
423    schema: &Schema,
424    client: &'a T,
425    schemas: &mut Vec<String>,
426    visited: &mut HashSet<String>,
427) -> Result<(), SerdeError>
428where
429    T: Client + Sync,
430{
431    if let Some(refs) = schema.references.as_ref() {
432        for r in refs {
433            let name = r.name.clone().unwrap_or_default();
434            if visited.contains(&name) {
435                continue;
436            }
437            visited.insert(name);
438            let ref_schema = client
439                .get_version(
440                    &r.subject.clone().unwrap_or_default(),
441                    r.version.unwrap_or(-1),
442                    true,
443                    None,
444                )
445                .await?;
446            resolve_named_schema(&ref_schema.to_schema(), client, schemas, visited).await?;
447            schemas.push(ref_schema.schema.clone().unwrap_or_default());
448        }
449    }
450    Ok(())
451}
452
453#[async_recursion]
454async fn transform(
455    ctx: &mut RuleContext,
456    schema: &apache_avro::Schema,
457    named_schemas: &[apache_avro::Schema],
458    message: &Value,
459) -> Result<Value, SerdeError> {
460    match schema {
461        apache_avro::Schema::Union(union) => {
462            let subschema = resolve_union(union, message);
463            if subschema.is_none() {
464                return Ok(message.clone());
465            }
466            let result = transform(ctx, subschema.unwrap().1, named_schemas, message).await?;
467            return Ok(result);
468        }
469        apache_avro::Schema::Array(array) => {
470            if let Value::Array(items) = message {
471                let mut result = Vec::with_capacity(items.len());
472                for item in items {
473                    let item = transform(ctx, &array.items, named_schemas, item).await?;
474                    result.push(item);
475                }
476                return Ok(Value::Array(result));
477            }
478        }
479        apache_avro::Schema::Map(map) => {
480            if let Value::Map(values) = message {
481                let mut result: HashMap<String, Value> = HashMap::with_capacity(values.len());
482                for (key, value) in values {
483                    let value = transform(ctx, &map.types, named_schemas, value).await?;
484                    result.insert(key.clone(), value);
485                }
486                return Ok(Value::Map(result));
487            }
488        }
489        apache_avro::Schema::Record(record) => {
490            if let Value::Record(fields) = message {
491                let mut result = Vec::with_capacity(fields.len());
492                for field in fields {
493                    let field =
494                        transform_field_with_ctx(ctx, record, named_schemas, field, fields).await?;
495                    result.push(field);
496                }
497                return Ok(Value::Record(result));
498            }
499        }
500        _ => {}
501    }
502    if let Some(field_ctx) = ctx.current_field() {
503        field_ctx.set_field_type(get_type(schema));
504        let rule_tags = ctx
505            .rule
506            .tags
507            .clone()
508            .map(|v| HashSet::from_iter(v.into_iter()));
509        if rule_tags.is_none_or(|tags| !tags.is_disjoint(&field_ctx.tags)) {
510            let message_value = SerdeValue::Avro(message.clone());
511            let field_executor_type = ctx.rule.r#type.clone();
512            let executor = get_executor(ctx.rule_registry.as_ref(), &field_executor_type);
513            if let Some(executor) = executor {
514                let field_executor =
515                    executor
516                        .as_field_rule_executor()
517                        .ok_or(SerdeError::Rule(format!(
518                            "executor {field_executor_type} is not a field rule executor"
519                        )))?;
520                let new_value = field_executor.transform_field(ctx, &message_value).await?;
521                if let SerdeValue::Avro(v) = new_value {
522                    return Ok(v);
523                }
524            }
525        }
526    }
527    Ok(message.clone())
528}
529
530async fn transform_field_with_ctx(
531    ctx: &mut RuleContext,
532    schema: &RecordSchema,
533    named_schemas: &[apache_avro::Schema],
534    field: &(String, Value),
535    message: &[(String, Value)],
536) -> Result<(String, Value), SerdeError> {
537    let field_schema = schema
538        .fields
539        .iter()
540        .find(|f| f.name == field.0)
541        .ok_or(SerdeError::Rule(format!(
542            "field {} not found in schema {}",
543            field.0, schema.name
544        )))?;
545    let field_type = get_type(&field_schema.schema);
546    let name = field.0.to_string();
547    let full_name = schema.name.to_string() + "." + &name;
548    let message_value = SerdeValue::Avro(Value::Record(message.to_vec()));
549    ctx.enter_field(
550        message_value,
551        full_name,
552        name,
553        field_type,
554        get_inline_tags(field_schema),
555    );
556    let new_value = transform(ctx, &field_schema.schema, named_schemas, &field.1).await?;
557    if let Some(Kind::Condition) = ctx.rule.kind {
558        if let Value::Boolean(b) = new_value {
559            if !b {
560                return Err(SerdeError::RuleCondition(Box::new(ctx.rule.clone())));
561            }
562        }
563    }
564    ctx.exit_field();
565    Ok((field.0.clone(), new_value))
566}
567
568fn get_type(schema: &apache_avro::Schema) -> FieldType {
569    match schema {
570        apache_avro::Schema::Null => FieldType::Null,
571        apache_avro::Schema::Boolean => FieldType::Boolean,
572        apache_avro::Schema::Int => FieldType::Int,
573        apache_avro::Schema::Long => FieldType::Long,
574        apache_avro::Schema::Float => FieldType::Float,
575        apache_avro::Schema::Double => FieldType::Double,
576        apache_avro::Schema::Bytes => FieldType::Bytes,
577        apache_avro::Schema::String => FieldType::String,
578        apache_avro::Schema::Fixed(_) => FieldType::Fixed,
579        apache_avro::Schema::Enum(_) => FieldType::Enum,
580        apache_avro::Schema::Array(_) => FieldType::Array,
581        apache_avro::Schema::Map(_) => FieldType::Map,
582        apache_avro::Schema::Union(_) => FieldType::Combined,
583        apache_avro::Schema::Record(_) => FieldType::Record,
584        apache_avro::Schema::Decimal(_) => FieldType::Bytes,
585        apache_avro::Schema::BigDecimal => FieldType::Bytes,
586        apache_avro::Schema::Uuid => FieldType::String,
587        apache_avro::Schema::Date => FieldType::Int,
588        apache_avro::Schema::TimeMillis => FieldType::Int,
589        apache_avro::Schema::TimeMicros => FieldType::Long,
590        apache_avro::Schema::TimestampMillis => FieldType::Long,
591        apache_avro::Schema::TimestampMicros => FieldType::Long,
592        apache_avro::Schema::TimestampNanos => FieldType::Long,
593        apache_avro::Schema::LocalTimestampMillis => FieldType::Long,
594        apache_avro::Schema::LocalTimestampMicros => FieldType::Long,
595        apache_avro::Schema::LocalTimestampNanos => FieldType::Long,
596        apache_avro::Schema::Duration => FieldType::Fixed,
597        // TODO assume Ref is a record, is this correct?
598        apache_avro::Schema::Ref { name: _ } => FieldType::Record,
599    }
600}
601
602fn get_inline_tags(field: &RecordField) -> HashSet<String> {
603    let tags = field.custom_attributes.get("confluent:tags");
604    if let Some(serde_json::Value::Array(tags)) = tags {
605        return tags
606            .iter()
607            .filter_map(|v| v.as_str().map(|s| s.to_string()))
608            .collect();
609    }
610    HashSet::new()
611}
612
613fn resolve_union<'a>(
614    union: &'a UnionSchema,
615    message: &Value,
616) -> Option<(usize, &'a apache_avro::Schema)> {
617    union.find_schema_with_known_schemata::<apache_avro::Schema>(message, None, &None)
618}
619
620fn from_avro_value(value: Value) -> Result<serde_json::Value, SerdeError> {
621    Ok(serde_json::Value::try_from(value)?)
622}
623
624fn to_avro_value(input: &Value, value: &serde_json::Value) -> Result<Value, SerdeError> {
625    let result = match value {
626        serde_json::Value::Null => Value::Null,
627        serde_json::Value::Bool(b) => (*b).into(),
628        serde_json::Value::Number(n) => match input {
629            Value::Long(_l) => Value::Long(n.as_i64().unwrap()),
630            Value::Float(_f) => Value::Float(n.as_f64().unwrap() as f32),
631            Value::Double(_d) => Value::Double(n.as_f64().unwrap()),
632            Value::Date(_d) => Value::Date(n.as_i64().unwrap() as i32),
633            Value::TimeMillis(_t) => Value::TimeMillis(n.as_i64().unwrap() as i32),
634            Value::TimeMicros(_t) => Value::TimeMicros(n.as_i64().unwrap()),
635            Value::TimestampMillis(_t) => Value::TimestampMillis(n.as_i64().unwrap()),
636            Value::TimestampMicros(_t) => Value::TimestampMicros(n.as_i64().unwrap()),
637            Value::TimestampNanos(_t) => Value::TimestampNanos(n.as_i64().unwrap()),
638            Value::LocalTimestampMillis(_t) => Value::LocalTimestampMillis(n.as_i64().unwrap()),
639            Value::LocalTimestampMicros(_t) => Value::LocalTimestampMicros(n.as_i64().unwrap()),
640            Value::LocalTimestampNanos(_t) => Value::LocalTimestampNanos(n.as_i64().unwrap()),
641            _ => Value::Int(n.as_i64().unwrap() as i32),
642        },
643        serde_json::Value::String(s) => match input {
644            Value::Enum(i, _s) => Value::Enum(*i, s.to_string()),
645            Value::Uuid(_uuid) => Value::Uuid(Uuid::parse_str(s)?),
646            _ => s.as_str().into(),
647        },
648        serde_json::Value::Array(items) => match input {
649            Value::Bytes(_bytes) => {
650                Value::Bytes(items.iter().map(|v| v.as_u64().unwrap() as u8).collect())
651            }
652            Value::Fixed(size, _items) => Value::Fixed(
653                *size,
654                items.iter().map(|v| v.as_u64().unwrap() as u8).collect(),
655            ),
656            Value::Decimal(_d) => {
657                let items: Vec<u8> = items.iter().map(|v| v.as_u64().unwrap() as u8).collect();
658                Value::Decimal(items.into())
659            }
660            // TODO BigDecimal
661            _ => Value::Array(
662                items
663                    .iter()
664                    .map(|v| to_avro_value(input, v))
665                    .collect::<Result<Vec<Value>, SerdeError>>()?,
666            ),
667        },
668        serde_json::Value::Object(props) => match input {
669            Value::Record(fields) => {
670                let mut result = Vec::new();
671                // use the order of the input fields
672                for (k, _v) in fields {
673                    let v = props
674                        .get(k)
675                        .ok_or(Serialization(format!("missing field {k}")))?;
676                    result.push((k.to_string(), to_avro_value(input, v)?));
677                }
678                Value::Record(result)
679            }
680            _ => {
681                let mut result = HashMap::new();
682                for (k, v) in props {
683                    result.insert(k.to_string(), to_avro_value(input, v)?);
684                }
685                Value::Map(result)
686            }
687        },
688    };
689    Ok(result)
690}
691
692impl From<uuid::Error> for SerdeError {
693    fn from(value: uuid::Error) -> Self {
694        Serialization(format!("UUID error: {value}"))
695    }
696}
697
698#[cfg(test)]
699#[cfg(feature = "rules")]
700mod tests {
701    use super::*;
702    use crate::rest::client_config::ClientConfig;
703    use crate::rest::dek_registry_client::Client as DekClient;
704    use crate::rest::mock_dek_registry_client::MockDekRegistryClient;
705    use crate::rest::mock_schema_registry_client::MockSchemaRegistryClient;
706    use crate::rest::models::dek::Algorithm;
707    use crate::rest::models::{
708        CreateDekRequest, CreateKekRequest, Metadata, Rule, RuleSet, SchemaReference, ServerConfig,
709    };
710    use crate::rest::schema_registry_client::Client;
711    use crate::rules::cel::cel_executor::CelExecutor;
712    use crate::rules::cel::cel_field_executor::CelFieldExecutor;
713    use crate::rules::encryption::encrypt_executor::{
714        EncryptionExecutor, FakeClock, FieldEncryptionExecutor,
715    };
716    use crate::rules::encryption::localkms::local_driver::LocalKmsDriver;
717    use crate::rules::jsonata::jsonata_executor::JsonataExecutor;
718    use crate::serdes::config::SchemaSelector;
719    use crate::serdes::serde::{SerdeFormat, SerdeHeaders, header_schema_id_serializer};
720    use apache_avro::types::Value::{Record, Union};
721    use std::collections::BTreeMap;
722
723    #[tokio::test]
724    async fn test_basic_serialization() {
725        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
726        let client = MockSchemaRegistryClient::new(client_conf);
727        let ser_conf = SerializerConfig::default();
728        let schema_str = r#"
729        {
730            "type": "record",
731            "name": "test",
732            "fields": [
733                {"name": "intField", "type": "int"},
734                {"name": "doubleField", "type": "double"},
735                {"name": "stringField", "type": "string"},
736                {"name": "booleanField", "type": "boolean"},
737                {"name": "bytesField", "type": "bytes"}
738            ]
739        }
740        "#;
741        let schema = Schema {
742            schema_type: Some("AVRO".to_string()),
743            references: None,
744            metadata: None,
745            rule_set: None,
746            schema: schema_str.to_string(),
747        };
748        let fields = vec![
749            ("intField".to_string(), Value::Int(123)),
750            ("doubleField".to_string(), Value::Double(45.67)),
751            ("stringField".to_string(), Value::String("hi".to_string())),
752            ("booleanField".to_string(), Value::Boolean(true)),
753            ("bytesField".to_string(), Value::Bytes(vec![1, 2, 3])),
754        ];
755        let obj = Record(fields.clone());
756        let rule_registry = RuleRegistry::new();
757        let ser = AvroSerializer::new(
758            &client,
759            Some(&schema),
760            Some(rule_registry.clone()),
761            ser_conf,
762        )
763        .unwrap();
764        let ser_ctx = SerializationContext {
765            topic: "test".to_string(),
766            serde_type: SerdeType::Value,
767            serde_format: SerdeFormat::Avro,
768            headers: None,
769        };
770        let bytes = ser.serialize(&ser_ctx, obj).await.unwrap();
771
772        let deser = AvroDeserializer::new(
773            &client,
774            Some(rule_registry.clone()),
775            DeserializerConfig::default(),
776        )
777        .unwrap();
778        let obj2 = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
779        if let Record(v) = obj2.value {
780            assert_eq!(v, fields);
781        } else {
782            unreachable!();
783        }
784    }
785
786    #[tokio::test]
787    async fn test_guid_in_header() {
788        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
789        let client = MockSchemaRegistryClient::new(client_conf);
790        let mut ser_conf = SerializerConfig::default();
791        ser_conf.schema_id_serializer = header_schema_id_serializer;
792        let schema_str = r#"
793        {
794            "type": "record",
795            "name": "test",
796            "fields": [
797                {"name": "intField", "type": "int"},
798                {"name": "doubleField", "type": "double"},
799                {"name": "stringField", "type": "string"},
800                {"name": "booleanField", "type": "boolean"},
801                {"name": "bytesField", "type": "bytes"}
802            ]
803        }
804        "#;
805        let schema = Schema {
806            schema_type: Some("AVRO".to_string()),
807            references: None,
808            metadata: None,
809            rule_set: None,
810            schema: schema_str.to_string(),
811        };
812        let fields = vec![
813            ("intField".to_string(), Value::Int(123)),
814            ("doubleField".to_string(), Value::Double(45.67)),
815            ("stringField".to_string(), Value::String("hi".to_string())),
816            ("booleanField".to_string(), Value::Boolean(true)),
817            ("bytesField".to_string(), Value::Bytes(vec![1, 2, 3])),
818        ];
819        let obj = Record(fields.clone());
820        let rule_registry = RuleRegistry::new();
821        let ser = AvroSerializer::new(
822            &client,
823            Some(&schema),
824            Some(rule_registry.clone()),
825            ser_conf,
826        )
827        .unwrap();
828        let ser_ctx = SerializationContext {
829            topic: "test".to_string(),
830            serde_type: SerdeType::Value,
831            serde_format: SerdeFormat::Avro,
832            headers: Some(SerdeHeaders::default()),
833        };
834        let bytes = ser.serialize(&ser_ctx, obj).await.unwrap();
835
836        let deser = AvroDeserializer::new(
837            &client,
838            Some(rule_registry.clone()),
839            DeserializerConfig::default(),
840        )
841        .unwrap();
842        let obj2 = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
843        if let Record(v) = obj2.value {
844            assert_eq!(v, fields);
845        } else {
846            unreachable!();
847        }
848    }
849
850    #[tokio::test]
851    async fn test_union_with_references() {
852        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
853        let client = MockSchemaRegistryClient::new(client_conf);
854        let ser_conf = SerializerConfig::new(
855            false,
856            Some(SchemaSelector::LatestVersion),
857            true,
858            false,
859            HashMap::new(),
860        );
861        let ref_schema_str = r#"
862        {
863            "type": "record",
864            "name": "ref",
865            "fields": [
866                {"name": "intField", "type": "int"},
867                {"name": "doubleField", "type": "double"},
868                {"name": "stringField", "type": "string", "confluent:tags": ["PII"]},
869                {"name": "booleanField", "type": "boolean"},
870                {"name": "bytesField", "type": "bytes", "confluent:tags": ["PII"]}
871            ]
872        }
873        "#;
874        let ref_schema = Schema {
875            schema_type: Some("AVRO".to_string()),
876            references: None,
877            metadata: None,
878            rule_set: None,
879            schema: ref_schema_str.to_string(),
880        };
881        client
882            .register_schema("ref", &ref_schema, false)
883            .await
884            .unwrap();
885        let ref2_schema_str = r#"
886        {
887            "type": "record",
888            "name": "ref2",
889            "fields": [
890                {"name": "otherField", "type": "string"}
891            ]
892        }
893        "#;
894        let ref2_schema = Schema {
895            schema_type: Some("AVRO".to_string()),
896            references: None,
897            metadata: None,
898            rule_set: None,
899            schema: ref2_schema_str.to_string(),
900        };
901        client
902            .register_schema("ref2", &ref2_schema, false)
903            .await
904            .unwrap();
905        let schema_str = r#"["ref", "ref2"]"#;
906        let refs = vec![
907            SchemaReference {
908                name: Some("ref".to_string()),
909                subject: Some("ref".to_string()),
910                version: Some(1),
911            },
912            SchemaReference {
913                name: Some("ref2".to_string()),
914                subject: Some("ref2".to_string()),
915                version: Some(1),
916            },
917        ];
918        let schema = Schema {
919            schema_type: Some("AVRO".to_string()),
920            references: Some(refs),
921            metadata: None,
922            rule_set: None,
923            schema: schema_str.to_string(),
924        };
925        client
926            .register_schema("test-value", &schema, false)
927            .await
928            .unwrap();
929
930        let fields = vec![
931            ("intField".to_string(), Value::Int(123)),
932            ("doubleField".to_string(), Value::Double(45.67)),
933            ("stringField".to_string(), Value::String("hi".to_string())),
934            ("booleanField".to_string(), Value::Boolean(true)),
935            ("bytesField".to_string(), Value::Bytes(vec![1, 2, 3])),
936        ];
937        let obj = Record(fields.clone());
938        let rule_registry = RuleRegistry::new();
939        let ser =
940            AvroSerializer::new(&client, None, Some(rule_registry.clone()), ser_conf).unwrap();
941        let ser_ctx = SerializationContext {
942            topic: "test".to_string(),
943            serde_type: SerdeType::Value,
944            serde_format: SerdeFormat::Avro,
945            headers: None,
946        };
947        let bytes = ser.serialize(&ser_ctx, obj).await.unwrap();
948        let deser = AvroDeserializer::new(
949            &client,
950            Some(rule_registry.clone()),
951            DeserializerConfig::default(),
952        )
953        .unwrap();
954
955        let fields2 = vec![
956            ("intField".to_string(), Value::Int(123)),
957            ("doubleField".to_string(), Value::Double(45.67)),
958            ("stringField".to_string(), Value::String("hi".to_string())),
959            ("booleanField".to_string(), Value::Boolean(true)),
960            ("bytesField".to_string(), Value::Bytes(vec![1, 2, 3])),
961        ];
962        let obj2 = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
963        if let Union(_, v) = obj2.value {
964            assert_eq!(*v, Record(fields2));
965        } else {
966            unreachable!();
967        }
968    }
969
970    #[tokio::test]
971    async fn test_cel_condition() {
972        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
973        let client = MockSchemaRegistryClient::new(client_conf);
974        let ser_conf = SerializerConfig::new(
975            false,
976            Some(SchemaSelector::LatestVersion),
977            true,
978            false,
979            HashMap::new(),
980        );
981        let schema_str = r#"
982        {
983            "type": "record",
984            "name": "test",
985            "fields": [
986                {"name": "intField", "type": "int"},
987                {"name": "doubleField", "type": "double"},
988                {"name": "stringField", "type": "string"},
989                {"name": "booleanField", "type": "boolean"},
990                {"name": "bytesField", "type": "bytes"}
991            ]
992        }
993        "#;
994        let rule = Rule {
995            name: "test-cel".to_string(),
996            doc: None,
997            kind: Some(Kind::Condition),
998            mode: Some(Mode::Write),
999            r#type: "CEL".to_string(),
1000            tags: None,
1001            params: None,
1002            expr: Some("message.stringField == 'hi'".to_string()),
1003            on_success: None,
1004            on_failure: None,
1005            disabled: None,
1006        };
1007        let rule_set = RuleSet {
1008            migration_rules: None,
1009            domain_rules: Some(vec![rule]),
1010            encoding_rules: None,
1011        };
1012        let schema = Schema {
1013            schema_type: Some("AVRO".to_string()),
1014            references: None,
1015            metadata: None,
1016            rule_set: Some(Box::new(rule_set)),
1017            schema: schema_str.to_string(),
1018        };
1019        client
1020            .register_schema("test-value", &schema, false)
1021            .await
1022            .unwrap();
1023        let fields = vec![
1024            ("intField".to_string(), Value::Int(123)),
1025            ("doubleField".to_string(), Value::Double(45.67)),
1026            ("stringField".to_string(), Value::String("hi".to_string())),
1027            ("booleanField".to_string(), Value::Boolean(true)),
1028            ("bytesField".to_string(), Value::Bytes(vec![1, 2, 3])),
1029        ];
1030        let obj = Record(fields.clone());
1031        let rule_registry = RuleRegistry::new();
1032        rule_registry.register_executor(CelExecutor::new());
1033        let ser =
1034            AvroSerializer::new(&client, None, Some(rule_registry.clone()), ser_conf).unwrap();
1035        let ser_ctx = SerializationContext {
1036            topic: "test".to_string(),
1037            serde_type: SerdeType::Value,
1038            serde_format: SerdeFormat::Avro,
1039            headers: None,
1040        };
1041        let bytes = ser.serialize(&ser_ctx, obj).await.unwrap();
1042
1043        let deser = AvroDeserializer::new(
1044            &client,
1045            Some(rule_registry.clone()),
1046            DeserializerConfig::default(),
1047        )
1048        .unwrap();
1049
1050        let fields2 = vec![
1051            ("intField".to_string(), Value::Int(123)),
1052            ("doubleField".to_string(), Value::Double(45.67)),
1053            ("stringField".to_string(), Value::String("hi".to_string())),
1054            ("booleanField".to_string(), Value::Boolean(true)),
1055            ("bytesField".to_string(), Value::Bytes(vec![1, 2, 3])),
1056        ];
1057        let obj2 = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
1058        if let Record(v) = obj2.value {
1059            assert_eq!(v, fields2);
1060        } else {
1061            unreachable!();
1062        }
1063    }
1064
1065    #[tokio::test]
1066    async fn test_cel_field() {
1067        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
1068        let client = MockSchemaRegistryClient::new(client_conf);
1069        let ser_conf = SerializerConfig::new(
1070            false,
1071            Some(SchemaSelector::LatestVersion),
1072            true,
1073            false,
1074            HashMap::new(),
1075        );
1076        let schema_str = r#"
1077        {
1078            "type": "record",
1079            "name": "test",
1080            "fields": [
1081                {"name": "intField", "type": "int"},
1082                {"name": "doubleField", "type": "double"},
1083                {"name": "stringField", "type": "string"},
1084                {"name": "booleanField", "type": "boolean"},
1085                {"name": "bytesField", "type": "bytes"}
1086            ]
1087        }
1088        "#;
1089        let rule = Rule {
1090            name: "test-cel".to_string(),
1091            doc: None,
1092            kind: Some(Kind::Transform),
1093            mode: Some(Mode::Write),
1094            r#type: "CEL_FIELD".to_string(),
1095            tags: None,
1096            params: None,
1097            expr: Some("name == 'stringField' ; value + '-suffix'".to_string()),
1098            on_success: None,
1099            on_failure: None,
1100            disabled: None,
1101        };
1102        let rule_set = RuleSet {
1103            migration_rules: None,
1104            domain_rules: Some(vec![rule]),
1105            encoding_rules: None,
1106        };
1107        let schema = Schema {
1108            schema_type: Some("AVRO".to_string()),
1109            references: None,
1110            metadata: None,
1111            rule_set: Some(Box::new(rule_set)),
1112            schema: schema_str.to_string(),
1113        };
1114        client
1115            .register_schema("test-value", &schema, false)
1116            .await
1117            .unwrap();
1118        let fields = vec![
1119            ("intField".to_string(), Value::Int(123)),
1120            ("doubleField".to_string(), Value::Double(45.67)),
1121            ("stringField".to_string(), Value::String("hi".to_string())),
1122            ("booleanField".to_string(), Value::Boolean(true)),
1123            ("bytesField".to_string(), Value::Bytes(vec![1, 2, 3])),
1124        ];
1125        let obj = Record(fields.clone());
1126        let rule_registry = RuleRegistry::new();
1127        rule_registry.register_executor(CelFieldExecutor::new());
1128        let ser =
1129            AvroSerializer::new(&client, None, Some(rule_registry.clone()), ser_conf).unwrap();
1130        let ser_ctx = SerializationContext {
1131            topic: "test".to_string(),
1132            serde_type: SerdeType::Value,
1133            serde_format: SerdeFormat::Avro,
1134            headers: None,
1135        };
1136        let bytes = ser.serialize(&ser_ctx, obj).await.unwrap();
1137
1138        let deser = AvroDeserializer::new(
1139            &client,
1140            Some(rule_registry.clone()),
1141            DeserializerConfig::default(),
1142        )
1143        .unwrap();
1144
1145        let fields2 = vec![
1146            ("intField".to_string(), Value::Int(123)),
1147            ("doubleField".to_string(), Value::Double(45.67)),
1148            (
1149                "stringField".to_string(),
1150                Value::String("hi-suffix".to_string()),
1151            ),
1152            ("booleanField".to_string(), Value::Boolean(true)),
1153            ("bytesField".to_string(), Value::Bytes(vec![1, 2, 3])),
1154        ];
1155        let obj2 = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
1156        if let Record(v) = obj2.value {
1157            assert_eq!(v, fields2);
1158        } else {
1159            unreachable!();
1160        }
1161    }
1162
1163    #[tokio::test]
1164    async fn test_jsonata_with_cel_field() {
1165        let rule1_to_2 =
1166            "$merge([$sift($, function($v, $k) {$k != 'size'}), {'height': $.'size'}])";
1167
1168        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
1169        let client = MockSchemaRegistryClient::new(client_conf);
1170        let server_config = ServerConfig {
1171            compatibility_group: Some("application.version".to_string()),
1172            ..Default::default()
1173        };
1174        client
1175            .update_config("test-value", &server_config)
1176            .await
1177            .unwrap();
1178
1179        let schema_str = r#"
1180        {
1181            "type": "record",
1182            "name": "old",
1183            "fields": [
1184                {"name": "name", "type": "string"},
1185                {"name": "size", "type": "int"},
1186                {"name": "version", "type": "int"}
1187            ]
1188        }
1189        "#;
1190        let metadata = Metadata {
1191            tags: None,
1192            properties: Some(BTreeMap::from([(
1193                "application.version".to_string(),
1194                "v1".to_string(),
1195            )])),
1196            sensitive: None,
1197        };
1198        let schema = Schema {
1199            schema_type: Some("AVRO".to_string()),
1200            references: None,
1201            metadata: Some(Box::new(metadata)),
1202            rule_set: None,
1203            schema: schema_str.to_string(),
1204        };
1205        client
1206            .register_schema("test-value", &schema, false)
1207            .await
1208            .unwrap();
1209        let schema_str = r#"
1210        {
1211            "type": "record",
1212            "name": "new",
1213            "fields": [
1214                {"name": "name", "type": "string"},
1215                {"name": "height", "type": "int"},
1216                {"name": "version", "type": "int"}
1217            ]
1218        }
1219        "#;
1220        let rule1 = Rule {
1221            name: "test-jsonata".to_string(),
1222            doc: None,
1223            kind: Some(Kind::Transform),
1224            mode: Some(Mode::Upgrade),
1225            r#type: "JSONATA".to_string(),
1226            tags: None,
1227            params: None,
1228            expr: Some(rule1_to_2.to_string()),
1229            on_success: None,
1230            on_failure: None,
1231            disabled: None,
1232        };
1233        let rule2 = Rule {
1234            name: "test-cel".to_string(),
1235            doc: None,
1236            kind: Some(Kind::Transform),
1237            mode: Some(Mode::Read),
1238            r#type: "CEL_FIELD".to_string(),
1239            tags: None,
1240            params: None,
1241            expr: Some("name == 'name' ; value + '-suffix'".to_string()),
1242            on_success: None,
1243            on_failure: None,
1244            disabled: None,
1245        };
1246        let rule_set = RuleSet {
1247            migration_rules: Some(vec![rule1]),
1248            domain_rules: Some(vec![rule2]),
1249            encoding_rules: None,
1250        };
1251        let metadata = Metadata {
1252            tags: None,
1253            properties: Some(BTreeMap::from([(
1254                "application.version".to_string(),
1255                "v2".to_string(),
1256            )])),
1257            sensitive: None,
1258        };
1259        let schema = Schema {
1260            schema_type: Some("AVRO".to_string()),
1261            references: None,
1262            metadata: Some(Box::new(metadata)),
1263            rule_set: Some(Box::new(rule_set)),
1264            schema: schema_str.to_string(),
1265        };
1266        client
1267            .register_schema("test-value", &schema, false)
1268            .await
1269            .unwrap();
1270        let fields = vec![
1271            ("name".to_string(), Value::String("alice".to_string())),
1272            ("size".to_string(), Value::Int(123)),
1273            ("version".to_string(), Value::Int(1)),
1274        ];
1275        let obj = Record(fields.clone());
1276        let rule_registry = RuleRegistry::new();
1277        rule_registry.register_executor(CelFieldExecutor::new());
1278        rule_registry.register_executor(JsonataExecutor::new());
1279        let ser_conf = SerializerConfig::new(
1280            false,
1281            Some(SchemaSelector::LatestWithMetadata(HashMap::from([(
1282                "application.version".to_string(),
1283                "v1".to_string(),
1284            )]))),
1285            false,
1286            false,
1287            HashMap::new(),
1288        );
1289        let ser =
1290            AvroSerializer::new(&client, None, Some(rule_registry.clone()), ser_conf).unwrap();
1291        let ser_ctx = SerializationContext {
1292            topic: "test".to_string(),
1293            serde_type: SerdeType::Value,
1294            serde_format: SerdeFormat::Avro,
1295            headers: None,
1296        };
1297        let bytes = ser.serialize(&ser_ctx, obj).await.unwrap();
1298
1299        let deser_conf = DeserializerConfig::new(
1300            Some(SchemaSelector::LatestWithMetadata(HashMap::from([(
1301                "application.version".to_string(),
1302                "v2".to_string(),
1303            )]))),
1304            false,
1305            HashMap::new(),
1306        );
1307        let deser =
1308            AvroDeserializer::new(&client, Some(rule_registry.clone()), deser_conf).unwrap();
1309
1310        let fields2 = vec![
1311            (
1312                "name".to_string(),
1313                Value::String("alice-suffix".to_string()),
1314            ),
1315            ("height".to_string(), Value::Int(123)),
1316            ("version".to_string(), Value::Int(1)),
1317        ];
1318        let obj2 = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
1319        if let Record(v) = obj2.value {
1320            assert_eq!(v, fields2);
1321        } else {
1322            unreachable!();
1323        }
1324    }
1325
1326    #[tokio::test]
1327    async fn test_encryption() {
1328        LocalKmsDriver::register();
1329
1330        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
1331        let client = MockSchemaRegistryClient::new(client_conf);
1332        let rule_conf = HashMap::from([("secret".to_string(), "mysecret".to_string())]);
1333        let ser_conf = SerializerConfig::new(
1334            false,
1335            Some(SchemaSelector::LatestVersion),
1336            false,
1337            false,
1338            rule_conf,
1339        );
1340        let schema_str = r#"
1341        {
1342            "type": "record",
1343            "name": "test",
1344            "fields": [
1345                {"name": "intField", "type": "int"},
1346                {"name": "doubleField", "type": "double"},
1347                {"name": "stringField", "type": "string", "confluent:tags": ["PII"]},
1348                {"name": "booleanField", "type": "boolean"},
1349                {"name": "bytesField", "type": "bytes", "confluent:tags": ["PII"]}
1350            ]
1351        }
1352        "#;
1353        let rule = Rule {
1354            name: "test-encrypt".to_string(),
1355            doc: None,
1356            kind: Some(Kind::Transform),
1357            mode: Some(Mode::WriteRead),
1358            r#type: "ENCRYPT".to_string(),
1359            tags: Some(vec!["PII".to_string()]),
1360            params: Some(BTreeMap::from([
1361                ("encrypt.kek.name".to_string(), "kek1".to_string()),
1362                ("encrypt.kms.type".to_string(), "local-kms".to_string()),
1363                ("encrypt.kms.key.id".to_string(), "mykey".to_string()),
1364            ])),
1365            expr: None,
1366            on_success: None,
1367            on_failure: Some("ERROR,NONE".to_string()),
1368            disabled: None,
1369        };
1370        let rule_set = RuleSet {
1371            migration_rules: None,
1372            domain_rules: Some(vec![rule]),
1373            encoding_rules: None,
1374        };
1375        let schema = Schema {
1376            schema_type: Some("AVRO".to_string()),
1377            references: None,
1378            metadata: None,
1379            rule_set: Some(Box::new(rule_set)),
1380            schema: schema_str.to_string(),
1381        };
1382        client
1383            .register_schema("test-value", &schema, false)
1384            .await
1385            .unwrap();
1386        let fields = vec![
1387            ("intField".to_string(), Value::Int(123)),
1388            ("doubleField".to_string(), Value::Double(45.67)),
1389            ("stringField".to_string(), Value::String("hi".to_string())),
1390            ("booleanField".to_string(), Value::Boolean(true)),
1391            ("bytesField".to_string(), Value::Bytes(vec![1, 2, 3])),
1392        ];
1393        let obj = Record(fields.clone());
1394        let rule_registry = RuleRegistry::new();
1395        rule_registry.register_executor(FieldEncryptionExecutor::<MockDekRegistryClient>::new(
1396            FakeClock::new(0),
1397        ));
1398        let ser =
1399            AvroSerializer::new(&client, None, Some(rule_registry.clone()), ser_conf).unwrap();
1400        let ser_ctx = SerializationContext {
1401            topic: "test".to_string(),
1402            serde_type: SerdeType::Value,
1403            serde_format: SerdeFormat::Avro,
1404            headers: None,
1405        };
1406        let bytes = ser.serialize(&ser_ctx, obj).await.unwrap();
1407        let deser = AvroDeserializer::new(
1408            &client,
1409            Some(rule_registry.clone()),
1410            DeserializerConfig::default(),
1411        )
1412        .unwrap();
1413
1414        let fields2 = vec![
1415            ("intField".to_string(), Value::Int(123)),
1416            ("doubleField".to_string(), Value::Double(45.67)),
1417            ("stringField".to_string(), Value::String("hi".to_string())),
1418            ("booleanField".to_string(), Value::Boolean(true)),
1419            ("bytesField".to_string(), Value::Bytes(vec![1, 2, 3])),
1420        ];
1421        let obj2 = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
1422        if let Record(v) = obj2.value {
1423            assert_eq!(v, fields2);
1424        } else {
1425            unreachable!();
1426        }
1427    }
1428
1429    #[tokio::test]
1430    async fn test_payload_encryption() {
1431        LocalKmsDriver::register();
1432
1433        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
1434        let client = MockSchemaRegistryClient::new(client_conf);
1435        let rule_conf = HashMap::from([("secret".to_string(), "mysecret".to_string())]);
1436        let ser_conf = SerializerConfig::new(
1437            false,
1438            Some(SchemaSelector::LatestVersion),
1439            false,
1440            false,
1441            rule_conf,
1442        );
1443        let schema_str = r#"
1444        {
1445            "type": "record",
1446            "name": "test",
1447            "fields": [
1448                {"name": "intField", "type": "int"},
1449                {"name": "doubleField", "type": "double"},
1450                {"name": "stringField", "type": "string", "confluent:tags": ["PII"]},
1451                {"name": "booleanField", "type": "boolean"},
1452                {"name": "bytesField", "type": "bytes", "confluent:tags": ["PII"]}
1453            ]
1454        }
1455        "#;
1456        let rule = Rule {
1457            name: "test-encrypt".to_string(),
1458            doc: None,
1459            kind: Some(Kind::Transform),
1460            mode: Some(Mode::WriteRead),
1461            r#type: "ENCRYPT_PAYLOAD".to_string(),
1462            tags: None,
1463            params: Some(BTreeMap::from([
1464                ("encrypt.kek.name".to_string(), "kek1".to_string()),
1465                ("encrypt.kms.type".to_string(), "local-kms".to_string()),
1466                ("encrypt.kms.key.id".to_string(), "mykey".to_string()),
1467            ])),
1468            expr: None,
1469            on_success: None,
1470            on_failure: Some("ERROR,NONE".to_string()),
1471            disabled: None,
1472        };
1473        let rule_set = RuleSet {
1474            migration_rules: None,
1475            domain_rules: None,
1476            encoding_rules: Some(vec![rule]),
1477        };
1478        let schema = Schema {
1479            schema_type: Some("AVRO".to_string()),
1480            references: None,
1481            metadata: None,
1482            rule_set: Some(Box::new(rule_set)),
1483            schema: schema_str.to_string(),
1484        };
1485        client
1486            .register_schema("test-value", &schema, false)
1487            .await
1488            .unwrap();
1489        let fields = vec![
1490            ("intField".to_string(), Value::Int(123)),
1491            ("doubleField".to_string(), Value::Double(45.67)),
1492            ("stringField".to_string(), Value::String("hi".to_string())),
1493            ("booleanField".to_string(), Value::Boolean(true)),
1494            ("bytesField".to_string(), Value::Bytes(vec![1, 2, 3])),
1495        ];
1496        let obj = Record(fields.clone());
1497        let rule_registry = RuleRegistry::new();
1498        rule_registry.register_executor(EncryptionExecutor::<MockDekRegistryClient>::new(
1499            FakeClock::new(0),
1500        ));
1501        let ser =
1502            AvroSerializer::new(&client, None, Some(rule_registry.clone()), ser_conf).unwrap();
1503        let ser_ctx = SerializationContext {
1504            topic: "test".to_string(),
1505            serde_type: SerdeType::Value,
1506            serde_format: SerdeFormat::Avro,
1507            headers: None,
1508        };
1509        let bytes = ser.serialize(&ser_ctx, obj).await.unwrap();
1510        let deser = AvroDeserializer::new(
1511            &client,
1512            Some(rule_registry.clone()),
1513            DeserializerConfig::default(),
1514        )
1515        .unwrap();
1516
1517        let fields2 = vec![
1518            ("intField".to_string(), Value::Int(123)),
1519            ("doubleField".to_string(), Value::Double(45.67)),
1520            ("stringField".to_string(), Value::String("hi".to_string())),
1521            ("booleanField".to_string(), Value::Boolean(true)),
1522            ("bytesField".to_string(), Value::Bytes(vec![1, 2, 3])),
1523        ];
1524        let obj2 = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
1525        if let Record(v) = obj2.value {
1526            assert_eq!(v, fields2);
1527        } else {
1528            unreachable!();
1529        }
1530    }
1531
1532    #[tokio::test]
1533    async fn test_encryption_f1_preserialized() {
1534        LocalKmsDriver::register();
1535
1536        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
1537        let client = MockSchemaRegistryClient::new(client_conf);
1538        let rule_conf = HashMap::from([("secret".to_string(), "mysecret".to_string())]);
1539        let schema_str = r#"
1540        {
1541            "type": "record",
1542            "name": "f1Schema",
1543            "fields": [
1544                {"name": "f1", "type": "string", "confluent:tags": ["PII"]}
1545            ]
1546        }
1547        "#;
1548        let rule = Rule {
1549            name: "test-encrypt".to_string(),
1550            doc: None,
1551            kind: Some(Kind::Transform),
1552            mode: Some(Mode::WriteRead),
1553            r#type: "ENCRYPT".to_string(),
1554            tags: Some(vec!["PII".to_string()]),
1555            params: Some(BTreeMap::from([
1556                ("encrypt.kek.name".to_string(), "kek1-f1".to_string()),
1557                ("encrypt.kms.type".to_string(), "local-kms".to_string()),
1558                ("encrypt.kms.key.id".to_string(), "mykey".to_string()),
1559            ])),
1560            expr: None,
1561            on_success: None,
1562            on_failure: Some("ERROR,ERROR".to_string()),
1563            disabled: None,
1564        };
1565        let rule_set = RuleSet {
1566            migration_rules: None,
1567            domain_rules: Some(vec![rule]),
1568            encoding_rules: None,
1569        };
1570        let schema = Schema {
1571            schema_type: Some("AVRO".to_string()),
1572            references: None,
1573            metadata: None,
1574            rule_set: Some(Box::new(rule_set)),
1575            schema: schema_str.to_string(),
1576        };
1577        client
1578            .register_schema("test-value", &schema, false)
1579            .await
1580            .unwrap();
1581        let fields = vec![("f1".to_string(), Value::String("hello world".to_string()))];
1582        let obj = Record(fields.clone());
1583        let rule_registry = RuleRegistry::new();
1584        rule_registry.register_executor(FieldEncryptionExecutor::<MockDekRegistryClient>::new(
1585            FakeClock::new(0),
1586        ));
1587
1588        let ser_ctx = SerializationContext {
1589            topic: "test".to_string(),
1590            serde_type: SerdeType::Value,
1591            serde_format: SerdeFormat::Avro,
1592            headers: None,
1593        };
1594        let deser_conf = DeserializerConfig::new(None, false, rule_conf);
1595        let deser =
1596            AvroDeserializer::new(&client, Some(rule_registry.clone()), deser_conf).unwrap();
1597
1598        let executor = rule_registry.get_executor("ENCRYPT").unwrap();
1599        let field_executor = executor
1600            .as_any()
1601            .downcast_ref::<FieldEncryptionExecutor<MockDekRegistryClient>>()
1602            .unwrap();
1603        let dek_client = field_executor.executor.client().unwrap();
1604        let kek_req = CreateKekRequest {
1605            name: "kek1-f1".to_string(),
1606            kms_type: "local-kms".to_string(),
1607            kms_key_id: "mykey".to_string(),
1608            kms_props: None,
1609            doc: None,
1610            shared: false,
1611        };
1612        dek_client.register_kek(kek_req).await.unwrap();
1613
1614        let encrypted_dek =
1615            "07V2ndh02DA73p+dTybwZFm7DKQSZN1tEwQh+FoX1DZLk4Yj2LLu4omYjp/84tAg3BYlkfGSz+zZacJHIE4=";
1616        let dek_req = CreateDekRequest {
1617            subject: "test-value".to_string(),
1618            version: None,
1619            algorithm: None,
1620            encrypted_key_material: Some(encrypted_dek.to_string()),
1621        };
1622        dek_client.register_dek("kek1-f1", dek_req).await.unwrap();
1623
1624        let bytes = [
1625            0, 0, 0, 0, 1, 104, 122, 103, 121, 47, 106, 70, 78, 77, 86, 47, 101, 70, 105, 108, 97,
1626            72, 114, 77, 121, 101, 66, 103, 100, 97, 86, 122, 114, 82, 48, 117, 100, 71, 101, 111,
1627            116, 87, 56, 99, 65, 47, 74, 97, 108, 55, 117, 107, 114, 43, 77, 47, 121, 122,
1628        ];
1629
1630        let obj2 = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
1631        if let Record(v) = obj2.value {
1632            assert_eq!(v, fields);
1633        } else {
1634            unreachable!();
1635        }
1636    }
1637
1638    #[tokio::test]
1639    async fn test_encryption_deterministic_f1_preserialized() {
1640        LocalKmsDriver::register();
1641
1642        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
1643        let client = MockSchemaRegistryClient::new(client_conf);
1644        let rule_conf = HashMap::from([("secret".to_string(), "mysecret".to_string())]);
1645        let schema_str = r#"
1646        {
1647            "type": "record",
1648            "name": "f1Schema",
1649            "fields": [
1650                {"name": "f1", "type": "string", "confluent:tags": ["PII"]}
1651            ]
1652        }
1653        "#;
1654        let rule = Rule {
1655            name: "test-encrypt".to_string(),
1656            doc: None,
1657            kind: Some(Kind::Transform),
1658            mode: Some(Mode::WriteRead),
1659            r#type: "ENCRYPT".to_string(),
1660            tags: Some(vec!["PII".to_string()]),
1661            params: Some(BTreeMap::from([
1662                ("encrypt.kek.name".to_string(), "kek1-det-f1".to_string()),
1663                ("encrypt.kms.type".to_string(), "local-kms".to_string()),
1664                ("encrypt.kms.key.id".to_string(), "mykey".to_string()),
1665                (
1666                    "encrypt.dek.algorithm".to_string(),
1667                    "AES256_SIV".to_string(),
1668                ),
1669            ])),
1670            expr: None,
1671            on_success: None,
1672            on_failure: Some("ERROR,ERROR".to_string()),
1673            disabled: None,
1674        };
1675        let rule_set = RuleSet {
1676            migration_rules: None,
1677            domain_rules: Some(vec![rule]),
1678            encoding_rules: None,
1679        };
1680        let schema = Schema {
1681            schema_type: Some("AVRO".to_string()),
1682            references: None,
1683            metadata: None,
1684            rule_set: Some(Box::new(rule_set)),
1685            schema: schema_str.to_string(),
1686        };
1687        client
1688            .register_schema("test-value", &schema, false)
1689            .await
1690            .unwrap();
1691        let fields = vec![("f1".to_string(), Value::String("hello world".to_string()))];
1692        let obj = Record(fields.clone());
1693        let rule_registry = RuleRegistry::new();
1694        rule_registry.register_executor(FieldEncryptionExecutor::<MockDekRegistryClient>::new(
1695            FakeClock::new(0),
1696        ));
1697
1698        let ser_ctx = SerializationContext {
1699            topic: "test".to_string(),
1700            serde_type: SerdeType::Value,
1701            serde_format: SerdeFormat::Avro,
1702            headers: None,
1703        };
1704        let deser_conf = DeserializerConfig::new(None, false, rule_conf);
1705        let deser =
1706            AvroDeserializer::new(&client, Some(rule_registry.clone()), deser_conf).unwrap();
1707
1708        let executor = rule_registry.get_executor("ENCRYPT").unwrap();
1709        let field_executor = executor
1710            .as_any()
1711            .downcast_ref::<FieldEncryptionExecutor<MockDekRegistryClient>>()
1712            .unwrap();
1713        let dek_client = field_executor.executor.client().unwrap();
1714        let kek_req = CreateKekRequest {
1715            name: "kek1-det-f1".to_string(),
1716            kms_type: "local-kms".to_string(),
1717            kms_key_id: "mykey".to_string(),
1718            kms_props: None,
1719            doc: None,
1720            shared: false,
1721        };
1722        dek_client.register_kek(kek_req).await.unwrap();
1723
1724        let encrypted_dek = "YSx3DTlAHrmpoDChquJMifmPntBzxgRVdMzgYL82rgWBKn7aUSnG+WIu9ozBNS3y2vXd++mBtK07w4/W/G6w0da39X9hfOVZsGnkSvry/QRht84V8yz3dqKxGMOK5A==";
1725        let dek_req = CreateDekRequest {
1726            subject: "test-value".to_string(),
1727            version: None,
1728            algorithm: Some(Algorithm::Aes256Siv),
1729            encrypted_key_material: Some(encrypted_dek.to_string()),
1730        };
1731        dek_client
1732            .register_dek("kek1-det-f1", dek_req)
1733            .await
1734            .unwrap();
1735
1736        let bytes = [
1737            0, 0, 0, 0, 1, 72, 68, 54, 89, 116, 120, 114, 108, 66, 110, 107, 84, 87, 87, 57, 78,
1738            54, 86, 98, 107, 51, 73, 73, 110, 106, 87, 72, 56, 49, 120, 109, 89, 104, 51, 107, 52,
1739            100,
1740        ];
1741
1742        let obj2 = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
1743        if let Record(v) = obj2.value {
1744            assert_eq!(v, fields);
1745        } else {
1746            unreachable!();
1747        }
1748    }
1749
1750    #[tokio::test]
1751    async fn test_encryption_dek_rotation_f1_preserialized() {
1752        LocalKmsDriver::register();
1753
1754        let client_conf = ClientConfig::new(vec!["mock://".to_string()]);
1755        let client = MockSchemaRegistryClient::new(client_conf);
1756        let rule_conf = HashMap::from([("secret".to_string(), "mysecret".to_string())]);
1757        let schema_str = r#"
1758        {
1759            "type": "record",
1760            "name": "f1Schema",
1761            "fields": [
1762                {"name": "f1", "type": "string", "confluent:tags": ["PII"]}
1763            ]
1764        }
1765        "#;
1766        let rule = Rule {
1767            name: "test-encrypt".to_string(),
1768            doc: None,
1769            kind: Some(Kind::Transform),
1770            mode: Some(Mode::WriteRead),
1771            r#type: "ENCRYPT".to_string(),
1772            tags: Some(vec!["PII".to_string()]),
1773            params: Some(BTreeMap::from([
1774                ("encrypt.kek.name".to_string(), "kek1-rot-f1".to_string()),
1775                ("encrypt.kms.type".to_string(), "local-kms".to_string()),
1776                ("encrypt.kms.key.id".to_string(), "mykey".to_string()),
1777                ("encrypt.dek.expiry.days".to_string(), "1".to_string()),
1778            ])),
1779            expr: None,
1780            on_success: None,
1781            on_failure: Some("ERROR,ERROR".to_string()),
1782            disabled: None,
1783        };
1784        let rule_set = RuleSet {
1785            migration_rules: None,
1786            domain_rules: Some(vec![rule]),
1787            encoding_rules: None,
1788        };
1789        let schema = Schema {
1790            schema_type: Some("AVRO".to_string()),
1791            references: None,
1792            metadata: None,
1793            rule_set: Some(Box::new(rule_set)),
1794            schema: schema_str.to_string(),
1795        };
1796        client
1797            .register_schema("test-value", &schema, false)
1798            .await
1799            .unwrap();
1800        let fields = vec![("f1".to_string(), Value::String("hello world".to_string()))];
1801        let obj = Record(fields.clone());
1802        let rule_registry = RuleRegistry::new();
1803        rule_registry.register_executor(FieldEncryptionExecutor::<MockDekRegistryClient>::new(
1804            FakeClock::new(0),
1805        ));
1806
1807        let ser_ctx = SerializationContext {
1808            topic: "test".to_string(),
1809            serde_type: SerdeType::Value,
1810            serde_format: SerdeFormat::Avro,
1811            headers: None,
1812        };
1813        let deser_conf = DeserializerConfig::new(None, false, rule_conf);
1814        let deser =
1815            AvroDeserializer::new(&client, Some(rule_registry.clone()), deser_conf).unwrap();
1816
1817        let executor = rule_registry.get_executor("ENCRYPT").unwrap();
1818        let field_executor = executor
1819            .as_any()
1820            .downcast_ref::<FieldEncryptionExecutor<MockDekRegistryClient>>()
1821            .unwrap();
1822        let dek_client = field_executor.executor.client().unwrap();
1823        let kek_req = CreateKekRequest {
1824            name: "kek1-rot-f1".to_string(),
1825            kms_type: "local-kms".to_string(),
1826            kms_key_id: "mykey".to_string(),
1827            kms_props: None,
1828            doc: None,
1829            shared: false,
1830        };
1831        dek_client.register_kek(kek_req).await.unwrap();
1832
1833        let encrypted_dek =
1834            "W/v6hOQYq1idVAcs1pPWz9UUONMVZW4IrglTnG88TsWjeCjxmtRQ4VaNe/I5dCfm2zyY9Cu0nqdvqImtUk4=";
1835        let dek_req = CreateDekRequest {
1836            subject: "test-value".to_string(),
1837            version: None,
1838            algorithm: Some(Algorithm::Aes256Gcm),
1839            encrypted_key_material: Some(encrypted_dek.to_string()),
1840        };
1841        dek_client
1842            .register_dek("kek1-rot-f1", dek_req)
1843            .await
1844            .unwrap();
1845
1846        let bytes = [
1847            0, 0, 0, 0, 1, 120, 65, 65, 65, 65, 65, 65, 71, 52, 72, 73, 54, 98, 49, 110, 88, 80,
1848            88, 113, 76, 121, 71, 56, 99, 73, 73, 51, 53, 78, 72, 81, 115, 101, 113, 113, 85, 67,
1849            100, 43, 73, 101, 76, 101, 70, 86, 65, 101, 78, 112, 83, 83, 51, 102, 120, 80, 110, 74,
1850            51, 50, 65, 61,
1851        ];
1852
1853        let obj2 = deser.deserialize(&ser_ctx, &bytes).await.unwrap();
1854        if let Record(v) = obj2.value {
1855            assert_eq!(v, fields);
1856        } else {
1857            unreachable!();
1858        }
1859    }
1860}