1use std::ops::{Add, Div};
2
3use crate::node::Node;
4use candid::utils::ArgumentEncoder;
5use ic_cdk::export::{
6 candid::{CandidType, Deserialize},
7 Principal,
8};
9
10use ic_kit::candid::{Decode, Encode};
11use ic_kit::ic;
12use ic_kit::interfaces::management::{self, CanisterSettings};
13use ic_kit::interfaces::Method;
14use serde::de::DeserializeOwned;
15
16#[derive(Clone, Debug, CandidType, Deserialize)]
17pub enum CanisterManagerEvent {
18 NodeCreated(Principal),
19 NodeDeleted(Principal),
20 Migrate(MigrateArgs),
21}
22
23#[derive(Clone, Debug, CandidType, Deserialize)]
24pub struct InstallArgs {
25 pub all_nodes: Vec<Principal>,
26}
27
28#[derive(Clone, Debug, CandidType, Deserialize)]
29pub struct WasmInitArgs {
30 pub position: usize, pub wasm_chunk: Vec<u8>,
32}
33
34#[derive(Clone, Debug, CandidType, Deserialize)]
35pub struct InitCanisterManagerParam {
36 pub args: Option<InstallArgs>,
37}
38
39#[derive(Clone, Debug, CandidType, Deserialize)]
40pub enum NodeStatus {
41 Initialized,
42 Ready,
43 Error(NodeError),
44 ShutDown,
45 Migrating,
46 ScaleUp,
47 ScaleDown,
48}
49
50#[derive(Clone, Debug, CandidType, Deserialize)]
51pub enum NodeError {
52 Migration(String),
53 ScaleUp(String),
54 Initialize(String),
55 Broadcast(String),
56}
57
58#[derive(Clone, Debug, CandidType, Deserialize)]
59pub struct NodeInfo {
60 pub all_nodes: Vec<String>,
61 pub prev_node_id: Option<Principal>,
62 pub next_node_id: Option<Principal>,
63 pub status: NodeStatus,
64 pub cycles_balance: u64,
65}
66
67#[derive(CandidType, Deserialize)]
68struct DataChunk<Data>
69where
70 Data: CandidType,
71{
72 data: Vec<(String, Data)>,
73}
74
75impl<Data> DataChunk<Data>
76where
77 Data: CandidType + DeserializeOwned,
78{
79 fn new(data: Vec<(String, Data)>) -> Self {
80 Self { data }
81 }
82
83 fn encode(self) -> Result<Vec<u8>, String> {
84 Encode!(&self).map_err(|e| e.to_string())
86 }
87
88 fn decode(data: &Vec<u8>) -> Result<Self, String> {
89 Decode!(data, DataChunk<Data>).map_err(|e| e.to_string())
90 }
91}
92
93impl<Data> From<&Vec<u8>> for DataChunk<Data>
94where
95 Data: CandidType + DeserializeOwned,
96{
97 fn from(data: &Vec<u8>) -> Self {
98 Decode!(data, DataChunk<Data>).unwrap()
99 }
100}
101
102#[derive(CandidType, Deserialize, Debug, Clone)]
103pub struct MigrateArgs {
104 #[serde(with = "serde_bytes")]
105 data: Vec<u8>,
106}
107
108type Canister<Data> = Node<Principal, Data>;
109
110pub struct CanisterManager<Data: Default + Clone> {
111 status: NodeStatus,
112 pub canister: Canister<Data>,
113 wasm_binary: Option<Vec<u8>>,
114 should_upgrade_func: fn(usize) -> bool,
115}
116
117impl<Data: Default + Clone + CandidType + DeserializeOwned> CanisterManager<Data> {
118 pub fn new(node_id: Principal, should_upgrade_func: fn(usize) -> bool) -> Self {
119 let mut new_canister: Node<Principal, Data> =
120 Node::new(node_id.clone(), Default::default());
121
122 new_canister.add_node(node_id);
123
124 Self {
125 status: NodeStatus::Initialized,
126 canister: new_canister,
127 wasm_binary: None, should_upgrade_func,
129 }
130 }
131
132 pub async fn forward_request<R, M, A>(
133 node_id: Principal,
134 method: M,
135 args: A,
136 ) -> Result<R, String>
137 where
138 M: Into<String>,
139 A: ArgumentEncoder,
140 R: CandidType + DeserializeOwned,
141 {
142 let result = ic::call::<_, (R,), _>(node_id, method, args).await;
143 match result {
144 Ok((result,)) => Ok(result),
145 Err((_, error)) => Err(error.to_string()),
146 }
147 }
148
149 fn get_status(&self) -> &NodeStatus {
173 &self.status
174 }
175
176 fn should_scale_up(&self) -> bool {
177 (self.should_upgrade_func)(self.canister.size())
178 && self.canister.next_node_id.is_none()
179 && matches!(self.status, NodeStatus::Ready)
180 }
181
182 fn should_scale_down(&self) -> bool {
183 false
184 }
185
186 pub fn lifecycle_init_wasm(&mut self, args: WasmInitArgs) -> bool {
187 match args.position {
188 0 => {
189 self.wasm_binary = Some(args.wasm_chunk);
190 true
191 }
192 1 | 2 => match self.wasm_binary.as_mut() {
193 Some(wasm_binary) => {
194 wasm_binary.extend_from_slice(&args.wasm_chunk);
195 if args.position == 2 {
196 self.status = NodeStatus::Ready;
197 }
198 true
199 }
200
201 None => false,
202 },
203 _ => false,
204 }
205 }
206
207 pub async fn lifecyle_init_node(&mut self, all_nodes: Option<Vec<Principal>>) -> () {
208 let node_id = self.canister.id;
209 let mut new_canister: Node<Principal, Data> = Node::new(node_id, Default::default());
210
211 if let Some(mut all_nodes) = all_nodes {
212 if all_nodes.len() > 1 {
213 let prev_node_id = all_nodes[all_nodes.len() - 2].clone(); new_canister.prev_node_id = Some(prev_node_id);
215 all_nodes.push(node_id);
216 for principal_id in all_nodes {
217 new_canister.add_node(principal_id);
218 }
219 }
220 }
221
222 self.canister = new_canister;
223
224 self.broadcast_event(CanisterManagerEvent::NodeCreated(self.canister.id))
225 .await;
226 }
227
228 pub async fn lifecyle_heartbeat_node(&mut self) -> () {
229 if self.should_scale_up() {
230 self.status = NodeStatus::ScaleUp;
231 let create_node_result = self.create_node().await;
232
233 match create_node_result {
234 Some(new_node_id) => {
235 self.canister.add_node(new_node_id.clone());
236 let result = self.initialize_node(new_node_id.clone()).await;
237 if !result {
238 self.canister.remove_node(&new_node_id);
239 self.status = NodeStatus::Error(NodeError::Initialize(format!(
240 "Failed to initialize node {}",
241 new_node_id
242 )));
243
244 return;
245 }
246 self.status = NodeStatus::Migrating;
247 let result = self.migrate_data(new_node_id).await;
248
249 if !result {
250 self.canister.remove_node(&new_node_id);
251 self.status = NodeStatus::Error(NodeError::Migration(format!(
252 "Failed to migrate data to node {}",
253 new_node_id
254 )));
255 return;
256 }
257
258 self.status = NodeStatus::Ready;
259 self.canister.next_node_id = Some(new_node_id);
260 self.broadcast_event(CanisterManagerEvent::NodeCreated(new_node_id))
261 .await;
262 }
263 None => {
264 self.status =
265 NodeStatus::Error(NodeError::ScaleUp("Failed to create node".to_string()));
266 }
267 }
268 } else if self.should_scale_down() {
269 }
272 }
273
274 async fn create_node(&mut self) -> Option<Principal> {
275 let arg = management::CreateCanisterArgument {
276 settings: Some(CanisterSettings {
277 compute_allocation: None,
278 controllers: Some(vec![self.canister.id]),
279 freezing_threshold: None,
280 memory_allocation: None, }),
282 };
283
284 let result = management::CreateCanister::perform_with_payment(
285 Principal::management_canister(),
286 (arg,),
287 ic::balance().div(self.canister.all_nodes().len().add(1) as u64),
288 )
289 .await;
290
291 match result {
292 Ok((result,)) => Some(result.canister_id),
293 Err(_) => None,
294 }
295 }
296
297 async fn initialize_node(&mut self, canister_id: Principal) -> bool {
298 let wasm_code = self.wasm_binary.clone().unwrap();
301
302 let install_args = management::InstallCodeArgument {
303 canister_id,
304 mode: management::InstallMode::Install,
305 wasm_module: wasm_code,
306 arg: Vec::<u8>::new(),
307 };
308
309 let result = management::InstallCode::perform_with_payment(
310 Principal::management_canister(),
311 (install_args,),
312 10_000_000,
313 )
314 .await;
315
316 if result.is_err() {
317 self.status = NodeStatus::Error(NodeError::Initialize(format!(
318 "Failed to initialize node {}",
319 canister_id
320 )));
321
322 return false;
323 }
324
325 let args = InitCanisterManagerParam {
326 args: Some(InstallArgs {
327 all_nodes: self.canister.all_nodes().into_iter().cloned().collect(),
328 }),
329 };
330
331 let result = ic::call::<_, (), _>(canister_id, "init_canister_manager", (args,)).await;
332
333 if result.is_err() {
334 self.status = NodeStatus::Error(NodeError::Initialize(format!(
335 "Failed to initialize node {}",
336 canister_id
337 )));
338
339 return false;
340 }
341
342 if !self.init_wasm(canister_id).await {
343 self.status = NodeStatus::Error(NodeError::Initialize(format!(
344 "Failed to initialize wasm {}",
345 canister_id
346 )));
347 return false;
348 }
349
350 true
351 }
352
353 async fn init_wasm(&self, canister_id: Principal) -> bool {
354 #[derive(CandidType, Deserialize)]
355 pub struct WasmInitArgs {
356 position: usize,
357 wasm_chunk: Vec<u8>,
358 }
359
360 async fn send_wasm(args: WasmInitArgs, canister_id: Principal) -> bool {
361 let result = ic::call::<_, (bool,), _>(canister_id, "init_wasm", (args,)).await;
362 result.is_ok()
363 }
364
365 let mut byte_iterator = self
366 .wasm_binary
367 .as_ref()
368 .unwrap()
369 .chunks(1024 * 1024)
370 .into_iter();
371
372 if !send_wasm(
373 WasmInitArgs {
374 position: 0,
375 wasm_chunk: byte_iterator.next().unwrap().to_vec(),
376 },
377 canister_id,
378 )
379 .await
380 {
381 return false;
382 }
383
384 while let Some(wasm_chunk) = byte_iterator.next() {
385 if !send_wasm(
386 WasmInitArgs {
387 position: 1,
388 wasm_chunk: wasm_chunk.to_vec(),
389 },
390 canister_id,
391 )
392 .await
393 {
394 return false;
395 }
396 }
397
398 if !send_wasm(
399 WasmInitArgs {
400 position: 2,
401 wasm_chunk: vec![],
402 },
403 canister_id,
404 )
405 .await
406 {
407 return false;
408 }
409
410 true
411 }
412
413 fn delete_node(&mut self) -> () {
414 }
417
418 async fn migrate_to_node(&mut self, canister_id: Principal, data: Vec<(String, Data)>) -> bool {
419 #[derive(CandidType, Deserialize)]
420 struct Response {
421 result: bool,
422 }
423
424 let call_migrate = |args: MigrateArgs| async {
425 ic::call::<_, (), _>(
426 canister_id,
427 "handle_event",
428 (CanisterManagerEvent::Migrate(args),),
429 )
430 .await
431 .map(|_| true)
432 .map_err(|e| e.1)
433 };
434
435 let encode_data_chunk = |data_chunk: DataChunk<Data>| -> Result<MigrateArgs, String> {
436 data_chunk.encode().map(|data| MigrateArgs { data })
437 };
438
439 for data_chunk in data.chunks(100) {
440 let result = match encode_data_chunk(DataChunk::new(data_chunk.to_vec())) {
441 Ok(args) => call_migrate(args).await,
442 Err(error) => Err(error),
443 };
444
445 match result {
446 Ok(response) => {
447 if !response {
448 self.status = NodeStatus::Error(NodeError::Migration(format!(
449 "Failed to migrate data to node {}",
450 canister_id
451 )));
452 return false;
453 }
454 }
455 Err(error) => {
456 self.status = NodeStatus::Error(NodeError::Migration(error));
457 return false;
458 }
459 }
460 }
461
462 true
463 }
464
465 fn handle_migrate(&mut self, args: MigrateArgs) -> bool {
466 match DataChunk::<Data>::decode(&args.data) {
467 Ok(data_chunk) => {
468 let data_chunk = data_chunk.data;
469 for (key, value) in data_chunk {
470 self.canister.insert_data(key, value);
471 }
472 true
473 }
474 Err(e) => {
475 self.status = NodeStatus::Error(NodeError::Migration(
476 "Failed to handle migrate data to node".to_string(),
477 ));
478 false
479 }
480 }
481 }
482
483 pub async fn lifecycle_handle_event(&mut self, event: CanisterManagerEvent) -> () {
484 match event {
485 CanisterManagerEvent::NodeCreated(node_id) => {
486 if node_id != self.canister.id {
487 self.canister.add_node(node_id);
488 self.migrate_data(node_id).await;
489 }
490 }
491 CanisterManagerEvent::NodeDeleted(node_id) => {
492 if node_id != self.canister.id {
493 self.canister.remove_node(&node_id);
494 self.migrate_data(node_id).await;
495 }
496 }
497 CanisterManagerEvent::Migrate(migrate_args) => {
498 self.handle_migrate(migrate_args);
499 }
500 }
501 }
502
503 async fn migrate_data(&mut self, node_id: Principal) -> bool {
504 let data_for_migration = self.canister.get_data_to_migrate();
505 let result = self.migrate_to_node(node_id, data_for_migration).await;
506 result
507 }
508
509 async fn broadcast_event(&mut self, event: CanisterManagerEvent) -> () {
510 let all_canisters = self.canister.all_nodes();
511 for &canister_id in all_canisters {
512 if self.canister.id != canister_id {
513 let result =
514 ic::call::<_, (), _>(canister_id, "handle_event", (event.clone(),)).await;
515
516 if let Err(e) = result {
517 self.status = NodeStatus::Error(NodeError::Broadcast(format!(
518 "Failed to broadcast event, error {} to node {}",
519 e.1, canister_id
520 )));
521 }
522 }
523 }
524 }
525
526 pub fn node_info(&self) -> NodeInfo {
527 NodeInfo {
528 all_nodes: self
529 .canister
530 .all_nodes()
531 .iter()
532 .map(|&principal| principal.to_string())
533 .collect(),
534 next_node_id: self.canister.next_node_id,
535 prev_node_id: self.canister.prev_node_id,
536 status: self.status.clone(),
537 cycles_balance: ic::balance(),
538 }
539 }
540}
541
#[cfg(test)]
mod tests {
    use crate::node_manager::NodeStatus;

    use super::CanisterManager;
    use super::WasmInitArgs;
    use async_std::test as async_test;
    use ic_kit::mock_principals;
    use ic_kit::MockContext;
    use ic_kit::Principal;

    #[test]
    fn new_node() {
        let node_id = Principal::anonymous();
        let cm = CanisterManager::<String>::new(node_id, |size| size > 10);
        let node_info = cm.node_info();

        assert_eq!(node_info.all_nodes, vec![node_id.to_string()]);
    }

    #[async_test]
    async fn node_initialized_properly() {
        let node_id = mock_principals::alice();
        let previous_node = mock_principals::bob();

        MockContext::new()
            .with_caller(previous_node)
            .with_id(node_id)
            .with_constant_return_handler(())
            .inject();

        let mut cm = CanisterManager::<String>::new(node_id, |size| size > 10);
        let all_nodes = vec![previous_node];

        cm.lifecyle_init_node(Some(all_nodes)).await;
        let node_info = cm.node_info();

        assert_eq!(
            node_info.all_nodes,
            vec![previous_node.to_string(), node_id.to_string()]
        );

        assert_eq!(cm.canister.prev_node_id, Some(previous_node));
        // `matches!` only yields a bool; it has to be asserted to actually check anything.
        assert!(matches!(cm.get_status(), NodeStatus::Initialized));
    }

    #[test]
    fn node_wasm_initialized_properly() {
        let node_id = mock_principals::alice();
        let mut cm = CanisterManager::<String>::new(node_id, |size| size > 10);

        assert!(cm.lifecycle_init_wasm(WasmInitArgs {
            position: 0,
            wasm_chunk: Vec::<u8>::default(),
        }));

        assert!(cm.lifecycle_init_wasm(WasmInitArgs {
            position: 1,
            wasm_chunk: Vec::<u8>::default(),
        }));

        assert!(cm.lifecycle_init_wasm(WasmInitArgs {
            position: 2,
            wasm_chunk: Vec::<u8>::default(),
        }));

        assert!(matches!(cm.get_status(), NodeStatus::Ready));
    }

    #[test]
    fn out_of_order_wasm_chunks_are_rejected() {
        let node_id = mock_principals::alice();
        let mut cm = CanisterManager::<String>::new(node_id, |size| size > 10);

        // Continuation and final chunks are refused until a first chunk (position 0) arrives.
        assert!(!cm.lifecycle_init_wasm(WasmInitArgs {
            position: 1,
            wasm_chunk: vec![1],
        }));
        assert!(!cm.lifecycle_init_wasm(WasmInitArgs {
            position: 2,
            wasm_chunk: Vec::new(),
        }));
        // Unknown positions are always refused.
        assert!(!cm.lifecycle_init_wasm(WasmInitArgs {
            position: 3,
            wasm_chunk: Vec::new(),
        }));

        assert!(matches!(cm.get_status(), NodeStatus::Initialized));
    }
}
611
612