spring_ai_rs/ai_interface/callback/group/
mod.rs1use std::{
2 collections::{HashMap, HashSet},
3 error::Error,
4 sync::RwLock,
5};
6
7use lazy_static::lazy_static;
8use serde::{Deserialize, Serialize};
9
10use crate::{
11 ai_interface::{
12 callback::{
13 group::{command_info::GroupSupportedCommand, order_preview::GroupOrderPreview},
14 unit::Unit,
15 unit_def::UnitDef,
16 },
17 AIInterface,
18 },
19 get_callback,
20};
21
22mod command;
23pub mod command_info;
24pub mod order_preview;
25
26lazy_static! {
27 static ref GROUP_UNIT_DEFS: RwLock<HashMap<i32, HashMap<i32, HashMap<i32, i32>>>> =
28 RwLock::new(HashMap::new());
29 static ref GROUP_UNITS: RwLock<HashMap<i32, HashMap<i32, HashSet<i32>>>> =
30 RwLock::new(HashMap::new());
31}
32
33pub(crate) fn init_group_unit_defs(ai_id: i32) -> Result<(), Box<dyn Error>> {
34 GROUP_UNIT_DEFS.try_write()?.insert(ai_id, HashMap::new());
35 Ok(())
36}
37
38pub(crate) fn init_group_units(ai_id: i32) -> Result<(), Box<dyn Error>> {
39 GROUP_UNITS.try_write()?.insert(ai_id, HashMap::new());
40 Ok(())
41}
42
/// Entry point for group queries made on behalf of a single AI.
///
/// Obtained via `AIInterface::group_interface`.
#[derive(Clone, Copy, Debug)]
pub struct GroupInterface {
    /// AI id passed as the first argument to the engine callbacks.
    ai_id: i32,
}
47
/// Field-less marker type; it carries no data.
#[derive(Clone, Copy, Debug)]
pub struct GroupInterfaceAll {}
50
51#[derive(Debug, Copy, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
52pub struct Group {
53 pub ai_id: i32,
54 pub group_id: i32,
55}
56
57#[derive(Debug, Clone)]
58pub struct GroupAll {
59 supported_commands: Vec<GroupSupportedCommand>,
60 order_preview: GroupOrderPreview,
61 is_selected: bool,
62}
63
64impl AIInterface {
65 pub fn group_interface(&self) -> GroupInterface {
66 GroupInterface { ai_id: self.ai_id }
67 }
68}
69
70const MAX_GROUPS: usize = 64;
71
72impl GroupInterface {
73 pub fn get_groups(&self) -> Result<HashMap<i32, Group>, Box<dyn Error>> {
74 let get_groups_func = get_callback!(self.ai_id, getGroups)?;
75
76 let mut groups = [-1_i32; MAX_GROUPS];
77 unsafe { get_groups_func(self.ai_id, groups.as_mut_ptr(), MAX_GROUPS as i32) };
78
79 Ok(groups
80 .iter()
81 .filter_map(|&group_id| {
82 if group_id == -1 {
83 None
84 } else {
85 Some((
86 group_id,
87 Group {
88 ai_id: self.ai_id,
89 group_id,
90 },
91 ))
92 }
93 })
94 .collect())
95 }
96}
97
98impl Group {
99 pub fn supported_commands(&self) -> Result<Vec<GroupSupportedCommand>, Box<dyn Error>> {
100 let get_supported_commands_func = get_callback!(self.ai_id, Group_getSupportedCommands)?;
101
102 let number_of_commands = unsafe { get_supported_commands_func(self.ai_id, self.group_id) };
103
104 Ok((0..number_of_commands)
105 .map(|supported_command_index| GroupSupportedCommand {
106 ai_id: self.ai_id,
107 group_id: self.group_id,
108 supported_command_index,
109 })
110 .collect())
111 }
112
113 pub fn unit_defs(&self) -> Result<Vec<UnitDef>, Box<dyn Error>> {
114 let mut group_unit_defs_lock = GROUP_UNIT_DEFS.try_write()?;
115 let ai_group_unit_defs = &mut group_unit_defs_lock.get_mut(&self.ai_id).unwrap();
116 let group_unit_defs = ai_group_unit_defs
117 .entry(self.group_id)
118 .or_insert(HashMap::new());
119 Ok(group_unit_defs
120 .keys()
121 .map(|&id| UnitDef {
122 ai_id: self.ai_id,
123 def_id: id,
124 })
125 .collect())
126 }
127
128 pub fn units(&self) -> Result<Vec<Unit>, Box<dyn Error>> {
129 let mut group_units_lock = GROUP_UNITS.try_write()?;
130 let ai_group_units = &mut group_units_lock.get_mut(&self.ai_id).unwrap();
131 let group_units = ai_group_units
132 .entry(self.group_id)
133 .or_insert(HashSet::new());
134 Ok(group_units
135 .iter()
136 .map(|&id| Unit {
137 ai_id: self.ai_id,
138 unit_id: id,
139 })
140 .collect())
141 }
142
143 pub fn order_preview(&self) -> Result<GroupOrderPreview, Box<dyn Error>> {
144 let get_group_order_preview_id_func = get_callback!(self.ai_id, Group_OrderPreview_getId)?;
145
146 Ok(GroupOrderPreview {
147 ai_id: self.ai_id,
148 group_id: self.group_id,
149 group_order_preview_id: unsafe {
150 get_group_order_preview_id_func(self.ai_id, self.group_id)
151 },
152 })
153 }
154
155 pub fn is_selected(&self) -> Result<bool, Box<dyn Error>> {
156 let is_selected_func = get_callback!(self.ai_id, Group_isSelected)?;
157 Ok(unsafe { is_selected_func(self.ai_id, self.group_id) })
158 }
159
160 pub fn all(&self) -> Result<GroupAll, Box<dyn Error>> {
161 Ok(GroupAll {
162 supported_commands: self.supported_commands()?,
163 order_preview: self.order_preview()?,
164 is_selected: self.is_selected()?,
165 })
166 }
167}