//! Masking, aggregation and unmasking of models.
//!
//! # Models
//! A [`Model`] is a collection of weights/parameters which are represented as finite numerical
//! values (i.e. rational numbers) of arbitrary precision. As such, a model in itself is not bound
//! to any particular primitive data type, but it can be created from those and converted back into
//! them.
//!
//! Currently, the primitive data types [`f32`], [`f64`], [`i32`] and [`i64`] are supported.
//! Support for further types might be added in the future.
//!
//! ```
//! # use xaynet::mask::{FromPrimitives, IntoPrimitives, Model};
//! let weights = vec![0_f32; 10];
//! let model = Model::from_primitives_bounded(weights.into_iter());
//! assert_eq!(
//!     model.into_primitives_unchecked().collect::<Vec<f32>>(),
//!     vec![0_f32; 10],
//! );
//! ```
//!
//! # Masking configurations
//! The masking, aggregation and unmasking of models require certain information about the models
//! to guarantee that no information is lost during the process. This information is configured
//! via the [`MaskConfig`]. Each masking configuration consists of the group type, data type, bound
//! type and model type. Usually, a masking configuration is chosen depending on the specific
//! machine learning use case as part of the setup for the XayNet federated learning platform.
//!
//! Currently, each of these types offers a fixed catalogue of variants, but we aim to generalize
//! this in the future to more flexible masking configurations that allow for a more fine-grained
//! tradeoff between representability and performance.
//!
//! ## Group type
//! The [`GroupType`] describes the order of the finite group in which the masked model weights are
//! embedded. The smaller the gap between the largest possible embedded weight and the group order,
//! the less information about the masks can theoretically leak. On the other hand, specific group
//! orders may allow for higher performance, so the choice is always a tradeoff between security
//! and performance. The group type variants are:
//! - Integer: no gap but potentially slowest performance.
//! - Prime: usually small gap with higher performance.
//! - Power2: usually higher gap with potentially highest performance.
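//!
//! As a rough illustration of the gap, the following sketch compares candidate group orders for a
//! hypothetical number of distinct embedded values. It assumes that each variant picks the
//! smallest suitable order, which is a simplification made for this example. The helpers
//! `is_prime` and `next_prime` are made up here and are not part of this module.
//!
//! ```rust
//! // trial-division primality test, sufficient for small illustrative numbers
//! fn is_prime(n: u64) -> bool {
//!     if n < 2 {
//!         return false;
//!     }
//!     let mut divisor = 2;
//!     while divisor * divisor <= n {
//!         if n % divisor == 0 {
//!             return false;
//!         }
//!         divisor += 1;
//!     }
//!     true
//! }
//!
//! // smallest prime greater than or equal to `n`
//! fn next_prime(n: u64) -> u64 {
//!     (n..).find(|&candidate| is_prime(candidate)).unwrap()
//! }
//!
//! fn main() {
//!     // hypothetical number of distinct embedded weight values (assumption)
//!     let values: u64 = 2_000_001;
//!
//!     let integer_order = values;
//!     let prime_order = next_prime(values);
//!     let power2_order = values.next_power_of_two();
//!
//!     // the integer order leaves no gap, the prime order a small one
//!     // and the power of two order the largest one
//!     assert_eq!(integer_order - values, 0);
//!     assert!(prime_order >= values && prime_order < power2_order);
//!     assert_eq!(power2_order, 2_097_152);
//! }
//! ```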
//!
//! ## Data type
//! The [`DataType`] describes the original primitive data type of the model weights. In
//! combination with the bound type, it determines how many decimal places of the model weights are
//! preserved during the masking, aggregation and unmasking process:
//! - F32: 10 decimal places for bounded model weights and 45 decimal places for unbounded.
//! - F64: 20 decimal places for bounded model weights and 324 decimal places for unbounded.
//! - I32 and I64: 10 decimal places (required for scaled aggregation).
//!
//! Currently the primitive data types [`f32`], [`f64`], [`i32`] and [`i64`] are supported via the
//! data type variants.
//!
//! ## Bound type
//! The [`BoundType`] describes the absolute bounds on all model weights. The smaller the bounds of
//! the model weights, the fewer bytes are required to represent the masked model weights. These
//! bounds are enforced on the model weights before masking them to prevent information loss during
//! the masking, aggregation and unmasking process. The bound type variants are:
//! - B0: all model weights are absolutely bounded by 1.
//! - B2: all model weights are absolutely bounded by 100.
//! - B4: all model weights are absolutely bounded by 10,000.
//! - B6: all model weights are absolutely bounded by 1,000,000.
//! - Bmax: all model weights are absolutely bounded by their primitive data type's absolute
//!   maximum value.
//!
//! ## Model type
//! The [`ModelType`] describes the maximum number of masked models that can be aggregated without
//! information loss. The smaller the number of masked models, the fewer bytes are required to
//! represent masked model weights. The model type variants are:
//! - M3: at most 1,000 masked models may be aggregated.
//! - M6: at most 1,000,000 masked models may be aggregated.
//! - M9: at most 1,000,000,000 masked models may be aggregated.
//! - M12: at most 1,000,000,000,000 masked models may be aggregated.
//!
//! # Masking, aggregation and unmasking
//! Local models should be masked (i.e. encrypted) before they are communicated somewhere else to
//! protect the possibly sensitive information learned from local data. The masking should allow
//! for masked models to be aggregated while they are still masked (i.e. homomorphic encryption).
//! Then the aggregated masked model can safely be unmasked without jeopardizing the secrecy of
//! personal information if the model is generalized enough.
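//!
//! The principle can be illustrated with plain modular arithmetic. The sketch below is not the
//! implementation used by this module: the toy group order, the fixed masks and the helper
//! functions `mask`, `aggregate` and `unmask` are assumptions chosen for the example, and real
//! masks would be derived from a cryptographically secure random seed.
//!
//! ```rust
//! // toy group order (assumption), real configurations use far larger orders
//! const ORDER: u64 = 1_000_003;
//!
//! // add a mask element to each embedded weight modulo `ORDER`
//! fn mask(weights: &[u64], mask_values: &[u64]) -> Vec<u64> {
//!     weights.iter().zip(mask_values).map(|(w, m)| (w + m) % ORDER).collect()
//! }
//!
//! // element-wise sum modulo `ORDER`, usable for masked models and for masks alike
//! fn aggregate(a: &[u64], b: &[u64]) -> Vec<u64> {
//!     a.iter().zip(b).map(|(x, y)| (x + y) % ORDER).collect()
//! }
//!
//! // subtract a mask element from each masked weight modulo `ORDER`
//! fn unmask(masked: &[u64], mask_values: &[u64]) -> Vec<u64> {
//!     masked.iter().zip(mask_values).map(|(x, m)| (x + ORDER - m) % ORDER).collect()
//! }
//!
//! fn main() {
//!     // already embedded local models and their masks, all values below `ORDER`
//!     let model_1: [u64; 3] = [1, 2, 3];
//!     let model_2: [u64; 3] = [10, 20, 30];
//!     let mask_1: [u64; 3] = [999_999, 5, 123_456];
//!     let mask_2: [u64; 3] = [42, 1_000_002, 7];
//!
//!     // aggregate the masked models and the masks separately
//!     let global_masked = aggregate(&mask(&model_1, &mask_1), &mask(&model_2, &mask_2));
//!     let global_mask = aggregate(&mask_1, &mask_2);
//!
//!     // unmasking with the aggregated mask yields the sum of the local models
//!     assert_eq!(unmask(&global_masked, &global_mask), vec![11, 22, 33]);
//! }
//! ```
//!
//! Only the aggregated mask is needed for unmasking in this sketch, so no individual local model
//! is ever revealed in the clear.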
//!
//! ## Masking
//! A [`Model`] can be masked with a [`Masker`], which requires a [`MaskConfig`]. During the
//! masking, the model weights are scaled, then embedded as elements of the chosen finite group and
//! finally masked by randomly generated elements from that very same finite group. The scalar
//! used for the scaling makes it possible to realize different aggregation strategies, for example
//! federated averaging. The masked model is returned as a [`MaskObject`] and the mask used to mask
//! the model can be derived from the additionally returned [`MaskSeed`].
//!
//! ```
//! # use xaynet::mask::{BoundType, DataType, FromPrimitives, GroupType, MaskConfig, Masker, Model, ModelType};
//! // create local models and a fitting masking configuration
//! let number_weights = 10;
//! let scalar = 0.5;
//! let local_model_1 = Model::from_primitives_bounded(vec![0_f32; number_weights].into_iter());
//! let local_model_2 = Model::from_primitives_bounded(vec![1_f32; number_weights].into_iter());
//! let config = MaskConfig {
//!     group_type: GroupType::Prime,
//!     data_type: DataType::F32,
//!     bound_type: BoundType::B0,
//!     model_type: ModelType::M3,
//! };
//!
//! // mask the local models
//! let (local_mask_seed_1, masked_local_model_1) = Masker::new(config).mask(scalar, local_model_1);
//! let (local_mask_seed_2, masked_local_model_2) = Masker::new(config).mask(scalar, local_model_2);
//!
//! // derive the masks of the local masked models
//! let local_mask_1 = local_mask_seed_1.derive_mask(number_weights, config);
//! let local_mask_2 = local_mask_seed_2.derive_mask(number_weights, config);
//! ```
//!
//! ## Aggregation
//! Masked models can be aggregated via an [`Aggregation`]. Masks themselves can be aggregated via
//! an [`Aggregation`] as well. An aggregated masked model can only be unmasked by the aggregation
//! of the masks of all models it contains. Each aggregation step should be validated beforehand
//! to ensure that it can be safely performed with respect to the chosen masking configuration
//! without loss of information.
//!
//! ```
//! # use xaynet::mask::{Aggregation, BoundType, DataType, FromPrimitives, GroupType, MaskConfig, Masker, MaskObject, Model, ModelType};
//! # let number_weights = 10;
//! # let scalar = 0.5;
//! # let local_model_1 = Model::from_primitives_bounded(vec![0_f32; number_weights].into_iter());
//! # let local_model_2 = Model::from_primitives_bounded(vec![1_f32; number_weights].into_iter());
//! # let config = MaskConfig { group_type: GroupType::Prime, data_type: DataType::F32, bound_type: BoundType::B0, model_type: ModelType::M3};
//! # let (local_mask_seed_1, masked_local_model_1) = Masker::new(config).mask(scalar, local_model_1);
//! # let (local_mask_seed_2, masked_local_model_2) = Masker::new(config).mask(scalar, local_model_2);
//! # let local_mask_1 = local_mask_seed_1.derive_mask(number_weights, config);
//! # let local_mask_2 = local_mask_seed_2.derive_mask(number_weights, config);
//! // aggregate the local masks, validating each aggregation step beforehand
//! let mut mask_aggregator = Aggregation::new(config, number_weights);
//! if mask_aggregator.validate_aggregation(&local_mask_1).is_ok() {
//!     mask_aggregator.aggregate(local_mask_1);
//! }
//! if mask_aggregator.validate_aggregation(&local_mask_2).is_ok() {
//!     mask_aggregator.aggregate(local_mask_2);
//! }
//! let global_mask: MaskObject = mask_aggregator.into();
//!
//! // aggregate the local masked models, validating each aggregation step beforehand
//! let mut model_aggregator = Aggregation::new(config, number_weights);
//! if model_aggregator.validate_aggregation(&masked_local_model_1).is_ok() {
//!     model_aggregator.aggregate(masked_local_model_1);
//! }
//! if model_aggregator.validate_aggregation(&masked_local_model_2).is_ok() {
//!     model_aggregator.aggregate(masked_local_model_2);
//! }
//! ```
//!
//! ## Unmasking
//! A masked model can be unmasked by the corresponding mask via an [`Aggregation`]. Unmasking
//! should always be validated beforehand to ensure that it can be safely performed with respect to
//! the chosen masking configuration without loss of information.
//!
//! ```
//! # use xaynet::mask::{Aggregation, BoundType, DataType, FromPrimitives, GroupType, MaskConfig, Masker, MaskObject, Model, ModelType};
//! # let number_weights = 10;
//! # let scalar = 0.5;
//! # let local_model_1 = Model::from_primitives_bounded(vec![0_f32; number_weights].into_iter());
//! # let local_model_2 = Model::from_primitives_bounded(vec![1_f32; number_weights].into_iter());
//! # let config = MaskConfig { group_type: GroupType::Prime, data_type: DataType::F32, bound_type: BoundType::B0, model_type: ModelType::M3};
//! # let (local_mask_seed_1, masked_local_model_1) = Masker::new(config).mask(scalar, local_model_1);
//! # let (local_mask_seed_2, masked_local_model_2) = Masker::new(config).mask(scalar, local_model_2);
//! # let local_mask_1 = local_mask_seed_1.derive_mask(number_weights, config);
//! # let local_mask_2 = local_mask_seed_2.derive_mask(number_weights, config);
//! # let mut mask_aggregator = Aggregation::new(config, number_weights);
//! # if let Ok(_) = mask_aggregator.validate_aggregation(&local_mask_1) { mask_aggregator.aggregate(local_mask_1); };
//! # if let Ok(_) = mask_aggregator.validate_aggregation(&local_mask_2) { mask_aggregator.aggregate(local_mask_2); };
//! # let global_mask: MaskObject = mask_aggregator.into();
//! # let mut model_aggregator = Aggregation::new(config, number_weights);
//! # if let Ok(_) = model_aggregator.validate_aggregation(&masked_local_model_1) { model_aggregator.aggregate(masked_local_model_1); };
//! # if let Ok(_) = model_aggregator.validate_aggregation(&masked_local_model_2) { model_aggregator.aggregate(masked_local_model_2); };
//! // unmask the aggregated masked model with the aggregated mask
//! if model_aggregator.validate_unmasking(&global_mask).is_ok() {
//!     let global_model = model_aggregator.unmask(global_mask);
//!     assert_eq!(
//!         global_model,
//!         Model::from_primitives_bounded(vec![0.5_f32; number_weights].into_iter()),
//!     );
//! }
//! ```

pub(crate) mod config;
pub(crate) mod masking;
pub(crate) mod model;
pub(crate) mod object;
pub(crate) mod seed;

pub use self::{
    config::{
        serialization::MaskConfigBuffer,
        BoundType,
        DataType,
        GroupType,
        InvalidMaskConfigError,
        MaskConfig,
        ModelType,
    },
    masking::{Aggregation, AggregationError, Masker, UnmaskingError},
    model::{FromPrimitives, IntoPrimitives, Model, ModelCastError, PrimitiveCastError},
    object::{serialization::MaskObjectBuffer, InvalidMaskObjectError, MaskObject},
    seed::{EncryptedMaskSeed, MaskSeed},
};