1use crate::filter::RemoteTarget;
10use crate::graph::NodeId;
11use serde::{Deserialize, Serialize};
12
13#[derive(Debug, Clone, Default, Serialize, Deserialize)]
15#[serde(tag = "type")]
16#[non_exhaustive]
17pub enum TrainingStrategy {
18 #[default]
20 Local,
21
22 DataParallel {
25 num_replicas: usize,
26 aggregation: GradientAggregation,
27 },
28
29 ModelParallel {
32 partitions: Vec<Partition>,
33 communication: CommunicationProtocol,
34 },
35
36 Federated {
39 num_clients: usize,
40 rounds: usize,
41 aggregation: FederatedAggregation,
42 client_selection: ClientSelection,
43 },
44
45 PopulationBased {
48 population_size: usize,
49 generations: usize,
50 exploit: ExploitStrategy,
51 explore: ExploreStrategy,
52 },
53
54 Custom {
56 coordinator: String,
57 config: serde_json::Value,
58 },
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63#[serde(tag = "method")]
64#[non_exhaustive]
65pub enum GradientAggregation {
66 AllReduce,
68 ParameterServer,
70 Decentralized { topology: String },
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct Partition {
80 pub node_ids: Vec<NodeId>,
81 pub target: RemoteTarget,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
86#[serde(tag = "protocol")]
87#[non_exhaustive]
88pub enum CommunicationProtocol {
89 DataStore,
91 Direct,
93 Pipeline { micro_batch_size: usize },
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize)]
99#[serde(tag = "method")]
100#[non_exhaustive]
101pub enum FederatedAggregation {
102 FedAvg,
104 FedProx { mu: f64 },
106 FedYogi { beta1: f64, beta2: f64, tau: f64 },
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
112#[serde(tag = "method")]
113#[non_exhaustive]
114pub enum ClientSelection {
115 All,
117 Random { fraction: f64 },
119 ByCapability { required_tags: Vec<String> },
121}
122
123#[derive(Debug, Clone, Serialize, Deserialize)]
125#[serde(tag = "method")]
126#[non_exhaustive]
127pub enum ExploitStrategy {
128 Truncation { fraction: f64 },
130 Binary { threshold: f64 },
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize)]
136#[serde(tag = "method")]
137#[non_exhaustive]
138pub enum ExploreStrategy {
139 Perturbation { factor: f64 },
141 Resample,
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Serializes `strategy` to a JSON string and parses it back, panicking on any serde error.
    fn roundtrip(strategy: &TrainingStrategy) -> TrainingStrategy {
        let json = serde_json::to_string(strategy).expect("strategy should serialize");
        serde_json::from_str(&json).expect("serialized strategy should deserialize")
    }

    #[test]
    fn default_is_local() {
        assert!(matches!(
            TrainingStrategy::default(),
            TrainingStrategy::Local
        ));
    }

    #[test]
    fn local_is_tagged_by_type() {
        assert_eq!(
            serde_json::to_value(TrainingStrategy::Local).unwrap(),
            json!({ "type": "Local" })
        );
        assert!(matches!(
            roundtrip(&TrainingStrategy::Local),
            TrainingStrategy::Local
        ));
    }

    #[test]
    fn serde_roundtrip_data_parallel() {
        let strategy = TrainingStrategy::DataParallel {
            num_replicas: 4,
            aggregation: GradientAggregation::AllReduce,
        };
        // Pin the wire format: the outer enum is tagged by `type`, the nested one by `method`.
        assert_eq!(
            serde_json::to_value(&strategy).unwrap(),
            json!({
                "type": "DataParallel",
                "num_replicas": 4,
                "aggregation": { "method": "AllReduce" }
            })
        );
        assert!(matches!(
            roundtrip(&strategy),
            TrainingStrategy::DataParallel {
                num_replicas: 4,
                aggregation: GradientAggregation::AllReduce,
            }
        ));
    }

    #[test]
    fn serde_roundtrip_decentralized_aggregation() {
        let strategy = TrainingStrategy::DataParallel {
            num_replicas: 2,
            aggregation: GradientAggregation::Decentralized {
                topology: "ring".into(),
            },
        };
        assert!(matches!(
            roundtrip(&strategy),
            TrainingStrategy::DataParallel {
                num_replicas: 2,
                aggregation: GradientAggregation::Decentralized { topology },
            } if topology == "ring"
        ));
    }

    #[test]
    fn serde_roundtrip_model_parallel() {
        let strategy = TrainingStrategy::ModelParallel {
            partitions: vec![
                Partition {
                    node_ids: vec!["embed".into(), "backbone".into()],
                    target: RemoteTarget::Tag("gpu-0".into()),
                },
                Partition {
                    node_ids: vec!["head_a".into()],
                    target: RemoteTarget::Tag("gpu-1".into()),
                },
            ],
            communication: CommunicationProtocol::Pipeline {
                micro_batch_size: 4,
            },
        };
        match roundtrip(&strategy) {
            TrainingStrategy::ModelParallel {
                partitions,
                communication,
            } => {
                // Partition order and per-partition node counts must survive the round trip.
                let sizes: Vec<usize> = partitions.iter().map(|p| p.node_ids.len()).collect();
                assert_eq!(sizes, [2, 1]);
                assert!(partitions
                    .iter()
                    .all(|p| matches!(p.target, RemoteTarget::Tag(_))));
                assert!(matches!(
                    communication,
                    CommunicationProtocol::Pipeline {
                        micro_batch_size: 4
                    }
                ));
            }
            other => panic!("expected ModelParallel, got {other:?}"),
        }
    }

    #[test]
    fn serde_roundtrip_federated() {
        let strategy = TrainingStrategy::Federated {
            num_clients: 10,
            rounds: 50,
            aggregation: FederatedAggregation::FedProx { mu: 0.01 },
            client_selection: ClientSelection::Random { fraction: 0.3 },
        };
        match roundtrip(&strategy) {
            TrainingStrategy::Federated {
                num_clients,
                rounds,
                aggregation,
                client_selection,
            } => {
                assert_eq!((num_clients, rounds), (10, 50));
                // Exact float comparison is intended: these literals must survive the
                // JSON round trip bit-for-bit.
                assert!(matches!(
                    aggregation,
                    FederatedAggregation::FedProx { mu } if mu == 0.01
                ));
                assert!(matches!(
                    client_selection,
                    ClientSelection::Random { fraction } if fraction == 0.3
                ));
            }
            other => panic!("expected Federated, got {other:?}"),
        }
    }

    #[test]
    fn serde_roundtrip_pbt() {
        let strategy = TrainingStrategy::PopulationBased {
            population_size: 20,
            generations: 50,
            exploit: ExploitStrategy::Truncation { fraction: 0.2 },
            explore: ExploreStrategy::Perturbation { factor: 0.2 },
        };
        match roundtrip(&strategy) {
            TrainingStrategy::PopulationBased {
                population_size,
                generations,
                exploit,
                explore,
            } => {
                assert_eq!((population_size, generations), (20, 50));
                // Exact float comparison is intended (see serde_roundtrip_federated).
                assert!(matches!(
                    exploit,
                    ExploitStrategy::Truncation { fraction } if fraction == 0.2
                ));
                assert!(matches!(
                    explore,
                    ExploreStrategy::Perturbation { factor } if factor == 0.2
                ));
            }
            other => panic!("expected PopulationBased, got {other:?}"),
        }
    }

    #[test]
    fn serde_roundtrip_custom_preserves_opaque_config() {
        // The nested "type" key checks that only the top-level tag is interpreted.
        let config = json!({
            "workers": 8,
            "backend": "nccl",
            "nested": { "type": "inner" }
        });
        let strategy = TrainingStrategy::Custom {
            coordinator: "my_coordinator".into(),
            config: config.clone(),
        };
        match roundtrip(&strategy) {
            TrainingStrategy::Custom {
                coordinator,
                config: parsed,
            } => {
                assert_eq!(coordinator, "my_coordinator");
                assert_eq!(parsed, config);
            }
            other => panic!("expected Custom, got {other:?}"),
        }
    }

    #[test]
    fn unknown_or_missing_tag_is_rejected() {
        assert!(serde_json::from_str::<TrainingStrategy>(r#"{"type":"Bogus"}"#).is_err());
        assert!(serde_json::from_str::<TrainingStrategy>("{}").is_err());
    }
}