spring_ai_rs/ai_interface/callback/group/mod.rs

1use std::{
2    collections::{HashMap, HashSet},
3    error::Error,
4    sync::RwLock,
5};
6
7use lazy_static::lazy_static;
8use serde::{Deserialize, Serialize};
9
10use crate::{
11    ai_interface::{
12        callback::{
13            group::{command_info::GroupSupportedCommand, order_preview::GroupOrderPreview},
14            unit::Unit,
15            unit_def::UnitDef,
16        },
17        AIInterface,
18    },
19    get_callback,
20};
21
22mod command;
23pub mod command_info;
24pub mod order_preview;
25
26lazy_static! {
27    static ref GROUP_UNIT_DEFS: RwLock<HashMap<i32, HashMap<i32, HashMap<i32, i32>>>> =
28        RwLock::new(HashMap::new());
29    static ref GROUP_UNITS: RwLock<HashMap<i32, HashMap<i32, HashSet<i32>>>> =
30        RwLock::new(HashMap::new());
31}
32
33pub(crate) fn init_group_unit_defs(ai_id: i32) -> Result<(), Box<dyn Error>> {
34    GROUP_UNIT_DEFS.try_write()?.insert(ai_id, HashMap::new());
35    Ok(())
36}
37
38pub(crate) fn init_group_units(ai_id: i32) -> Result<(), Box<dyn Error>> {
39    GROUP_UNITS.try_write()?.insert(ai_id, HashMap::new());
40    Ok(())
41}
42
/// Handle for AI-wide group queries (see `get_groups`), bound to one AI id.
#[derive(Clone, Copy, Debug)]
pub struct GroupInterface {
    /// Id of the AI whose callbacks this handle invokes.
    ai_id: i32,
}
47
/// Aggregate type for group-interface data; it currently carries no fields.
#[derive(Clone, Copy, Debug)]
pub struct GroupInterfaceAll {}
50
51#[derive(Debug, Copy, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
52pub struct Group {
53    pub ai_id: i32,
54    pub group_id: i32,
55}
56
57#[derive(Debug, Clone)]
58pub struct GroupAll {
59    supported_commands: Vec<GroupSupportedCommand>,
60    order_preview: GroupOrderPreview,
61    is_selected: bool,
62}
63
64impl AIInterface {
65    pub fn group_interface(&self) -> GroupInterface {
66        GroupInterface { ai_id: self.ai_id }
67    }
68}
69
70const MAX_GROUPS: usize = 64;
71
72impl GroupInterface {
73    pub fn get_groups(&self) -> Result<HashMap<i32, Group>, Box<dyn Error>> {
74        let get_groups_func = get_callback!(self.ai_id, getGroups)?;
75
76        let mut groups = [-1_i32; MAX_GROUPS];
77        unsafe { get_groups_func(self.ai_id, groups.as_mut_ptr(), MAX_GROUPS as i32) };
78
79        Ok(groups
80            .iter()
81            .filter_map(|&group_id| {
82                if group_id == -1 {
83                    None
84                } else {
85                    Some((
86                        group_id,
87                        Group {
88                            ai_id: self.ai_id,
89                            group_id,
90                        },
91                    ))
92                }
93            })
94            .collect())
95    }
96}
97
98impl Group {
99    pub fn supported_commands(&self) -> Result<Vec<GroupSupportedCommand>, Box<dyn Error>> {
100        let get_supported_commands_func = get_callback!(self.ai_id, Group_getSupportedCommands)?;
101
102        let number_of_commands = unsafe { get_supported_commands_func(self.ai_id, self.group_id) };
103
104        Ok((0..number_of_commands)
105            .map(|supported_command_index| GroupSupportedCommand {
106                ai_id: self.ai_id,
107                group_id: self.group_id,
108                supported_command_index,
109            })
110            .collect())
111    }
112
113    pub fn unit_defs(&self) -> Result<Vec<UnitDef>, Box<dyn Error>> {
114        let mut group_unit_defs_lock = GROUP_UNIT_DEFS.try_write()?;
115        let ai_group_unit_defs = &mut group_unit_defs_lock.get_mut(&self.ai_id).unwrap();
116        let group_unit_defs = ai_group_unit_defs
117            .entry(self.group_id)
118            .or_insert(HashMap::new());
119        Ok(group_unit_defs
120            .keys()
121            .map(|&id| UnitDef {
122                ai_id: self.ai_id,
123                def_id: id,
124            })
125            .collect())
126    }
127
128    pub fn units(&self) -> Result<Vec<Unit>, Box<dyn Error>> {
129        let mut group_units_lock = GROUP_UNITS.try_write()?;
130        let ai_group_units = &mut group_units_lock.get_mut(&self.ai_id).unwrap();
131        let group_units = ai_group_units
132            .entry(self.group_id)
133            .or_insert(HashSet::new());
134        Ok(group_units
135            .iter()
136            .map(|&id| Unit {
137                ai_id: self.ai_id,
138                unit_id: id,
139            })
140            .collect())
141    }
142
143    pub fn order_preview(&self) -> Result<GroupOrderPreview, Box<dyn Error>> {
144        let get_group_order_preview_id_func = get_callback!(self.ai_id, Group_OrderPreview_getId)?;
145
146        Ok(GroupOrderPreview {
147            ai_id: self.ai_id,
148            group_id: self.group_id,
149            group_order_preview_id: unsafe {
150                get_group_order_preview_id_func(self.ai_id, self.group_id)
151            },
152        })
153    }
154
155    pub fn is_selected(&self) -> Result<bool, Box<dyn Error>> {
156        let is_selected_func = get_callback!(self.ai_id, Group_isSelected)?;
157        Ok(unsafe { is_selected_func(self.ai_id, self.group_id) })
158    }
159
160    pub fn all(&self) -> Result<GroupAll, Box<dyn Error>> {
161        Ok(GroupAll {
162            supported_commands: self.supported_commands()?,
163            order_preview: self.order_preview()?,
164            is_selected: self.is_selected()?,
165        })
166    }
167}