|
| 1 | +# クラスタリングモジュール |
| 2 | + |
| 3 | + |
| 4 | +## クラスタリングモジュールの目的 |
| 5 | + |
| 6 | +複数のエージェントを動かす場合は、それらのエージェントにどのように協調させるかが重要になります。RRSでは多くのチームが、エージェントに各々の担当地域を持たせ役割分担をおこなう協調を取り入れています(他の手段による協調も取り入れています)。担当地域を割り振るためには、地図上のオブジェクトをいくつかのグループに分ける必要があります。このようなグループ分けをしてそれらを管理する場合には、クラスタリングモジュールと呼ばれるモジュールを用います。 |
| 7 | + |
| 8 | +本資料では、多くの世界大会参加チームが使用しているアルゴリズムを用いたクラスタリングモジュールの実装をおこないます。 |
| 9 | + |
| 10 | +## 開発するクラスタリングモジュールの概要 |
| 11 | + |
| 12 | +本資料で開発するモジュールは下の画像のように、 |
| 13 | + |
| 14 | +1. k-means++アルゴリズムによって地図上のオブジェクトをエージェント数分の区画に分けます。 |
| 15 | +1. Hungarianアルゴリズムによってそれらの区画とエージェントを (間の距離の総和が最も小さくなるように)1対1で結びつけます。 |
| 16 | + |
| 17 | + |
| 18 | + |
| 19 | + |
| 20 | +## クラスタリングモジュールの実装 |
| 21 | + |
| 22 | +:::{note} |
| 23 | +以降の作業では、カレントディレクトリがプロジェクトのルートディレクトリであることを前提としています。 |
| 24 | +::: |
| 25 | + |
| 26 | +まず、クラスタリングモジュールを記述するためのファイルを作成します。 |
| 27 | + |
| 28 | +```bash |
| 29 | +mkdir -p src/<your_team_name>/module/algorithm |
| 30 | +touch src/<your_team_name>/module/algorithm/k_means_pp_clustering.py |
| 31 | +``` |
| 32 | + |
| 33 | +次に、クラスタリングモジュールの実装を行います。 以下のコードを `k_means_pp_clustering.py` に記述してください。 |
| 34 | + |
| 35 | +```python |
| 36 | +import numpy as np |
| 37 | +from adf_core_python.core.agent.develop.develop_data import DevelopData |
| 38 | +from adf_core_python.core.agent.info.agent_info import AgentInfo |
| 39 | +from adf_core_python.core.agent.info.scenario_info import ScenarioInfo, ScenarioInfoKeys |
| 40 | +from adf_core_python.core.agent.info.world_info import WorldInfo |
| 41 | +from adf_core_python.core.agent.module.module_manager import ModuleManager |
| 42 | +from adf_core_python.core.component.module.algorithm.clustering import Clustering |
| 43 | +from adf_core_python.core.logger.logger import get_logger |
| 44 | +from rcrs_core.connection.URN import Entity as EntityURN |
| 45 | +from rcrs_core.entities.ambulanceCenter import AmbulanceCentre |
| 46 | +from rcrs_core.entities.building import Building |
| 47 | +from rcrs_core.entities.entity import Entity |
| 48 | +from rcrs_core.entities.fireStation import FireStation |
| 49 | +from rcrs_core.entities.gasStation import GasStation |
| 50 | +from rcrs_core.entities.hydrant import Hydrant |
| 51 | +from rcrs_core.entities.policeOffice import PoliceOffice |
| 52 | +from rcrs_core.entities.refuge import Refuge |
| 53 | +from rcrs_core.entities.road import Road |
| 54 | +from rcrs_core.worldmodel.entityID import EntityID |
| 55 | +from scipy.optimize import linear_sum_assignment |
| 56 | +from sklearn.cluster import KMeans |
| 57 | + |
| 58 | +# クラスタリングのシード値 |
| 59 | +SEED = 42 |
| 60 | + |
| 61 | + |
| 62 | +class KMeansPPClustering(Clustering): |
| 63 | + def __init__( |
| 64 | + self, |
| 65 | + agent_info: AgentInfo, |
| 66 | + world_info: WorldInfo, |
| 67 | + scenario_info: ScenarioInfo, |
| 68 | + module_manager: ModuleManager, |
| 69 | + develop_data: DevelopData, |
| 70 | + ) -> None: |
| 71 | + super().__init__( |
| 72 | + agent_info, world_info, scenario_info, module_manager, develop_data |
| 73 | + ) |
| 74 | + self._logger = get_logger(f"{self.__class__.__name__}") |
| 75 | + |
| 76 | + # クラスター数の設定 |
| 77 | + self._cluster_number: int = 1 |
| 78 | + match agent_info.get_myself().get_urn(): |
| 79 | + case EntityURN.AMBULANCE_TEAM: |
| 80 | + self._cluster_number = scenario_info.get_value( |
| 81 | + ScenarioInfoKeys.SCENARIO_AGENTS_AT, |
| 82 | + 1, |
| 83 | + ) |
| 84 | + case EntityURN.POLICE_FORCE: |
| 85 | + self._cluster_number = scenario_info.get_value( |
| 86 | + ScenarioInfoKeys.SCENARIO_AGENTS_PF, |
| 87 | + 1, |
| 88 | + ) |
| 89 | + case EntityURN.FIRE_BRIGADE: |
| 90 | + self._cluster_number = scenario_info.get_value( |
| 91 | + ScenarioInfoKeys.SCENARIO_AGENTS_FB, |
| 92 | + 1, |
| 93 | + ) |
| 94 | + |
| 95 | + # 自分と同じクラスのエージェントのリストを取得 |
| 96 | + self._agents: list[Entity] = world_info.get_entities_of_types( |
| 97 | + [ |
| 98 | + agent_info.get_myself().__class__, |
| 99 | + ] |
| 100 | + ) |
| 101 | + |
| 102 | + # クラスタリング結果を保持する変数 |
| 103 | + self._cluster_entities: list[list[Entity]] = [] |
| 104 | + |
| 105 | + # クラスタリング対象のエンティティのリストを取得 |
| 106 | + self._entities: list[Entity] = world_info.get_entities_of_types( |
| 107 | + [ |
| 108 | + AmbulanceCentre, |
| 109 | + FireStation, |
| 110 | + GasStation, |
| 111 | + Hydrant, |
| 112 | + PoliceOffice, |
| 113 | + Refuge, |
| 114 | + Road, |
| 115 | + Building, |
| 116 | + ] |
| 117 | + ) |
| 118 | + |
| 119 | + def calculate(self) -> Clustering: |
| 120 | + return self |
| 121 | + |
| 122 | + def get_cluster_number(self) -> int: |
| 123 | + """ |
| 124 | + クラスター数を取得する |
| 125 | +
|
| 126 | + Returns |
| 127 | + ------- |
| 128 | + int |
| 129 | + クラスター数 |
| 130 | + """ |
| 131 | + return self._cluster_number |
| 132 | + |
| 133 | + def get_cluster_index(self, entity_id: EntityID) -> int: |
| 134 | + """ |
| 135 | + エージェントに割り当てられたクラスターのインデックスを取得する |
| 136 | +
|
| 137 | + Parameters |
| 138 | + ---------- |
| 139 | + entity_id : EntityID |
| 140 | + エージェントのID |
| 141 | +
|
| 142 | + Returns |
| 143 | + ------- |
| 144 | + int |
| 145 | + クラスターのインデックス |
| 146 | + """ |
| 147 | + return self._agent_cluster_indices.get(entity_id, 0) |
| 148 | + |
| 149 | + def get_cluster_entities(self, cluster_index: int) -> list[Entity]: |
| 150 | + """ |
| 151 | + クラスターのエンティティのリストを取得する |
| 152 | +
|
| 153 | + Parameters |
| 154 | + ---------- |
| 155 | + cluster_index : int |
| 156 | + クラスターのインデックス |
| 157 | +
|
| 158 | + Returns |
| 159 | + ------- |
| 160 | + list[Entity] |
| 161 | + クラスターのエンティティのリスト |
| 162 | + """ |
| 163 | + if cluster_index >= len(self._cluster_entities): |
| 164 | + return [] |
| 165 | + return self._cluster_entities[cluster_index] |
| 166 | + |
| 167 | + def get_cluster_entity_ids(self, cluster_index: int) -> list[EntityID]: |
| 168 | + """ |
| 169 | + クラスターのエンティティのIDのリストを取得する |
| 170 | +
|
| 171 | + Parameters |
| 172 | + ---------- |
| 173 | + cluster_index : int |
| 174 | + クラスターのインデックス |
| 175 | +
|
| 176 | + Returns |
| 177 | + ------- |
| 178 | + list[EntityID] |
| 179 | + クラスターのエンティティのIDのリスト |
| 180 | + """ |
| 181 | + if cluster_index >= len(self._cluster_entities): |
| 182 | + return [] |
| 183 | + return [entity.get_id() for entity in self._cluster_entities[cluster_index]] |
| 184 | + |
| 185 | + def prepare(self) -> Clustering: |
| 186 | + """ |
| 187 | + エージェントの起動時に一回のみ実行される処理 |
| 188 | + """ |
| 189 | + super().prepare() |
| 190 | + if self.get_count_prepare() > 1: |
| 191 | + return self |
| 192 | + |
| 193 | + # クラスタリングを実行 |
| 194 | + kmeans_pp = self._perform_kmeans_pp(self._entities, self._cluster_number) |
| 195 | + |
| 196 | + # クラスタリング結果を保持 |
| 197 | + self._cluster_entities = [[] for _ in range(self._cluster_number)] |
| 198 | + for entity, cluster_index in zip(self._entities, kmeans_pp.labels_): |
| 199 | + self._cluster_entities[cluster_index].append(entity) |
| 200 | + |
| 201 | + # エージェントとクラスターのエンティティの距離を計算し、最も全体の合計の距離が短くなるようにエージェントとクラスターを対応付ける |
| 202 | + agent_cluster_indices = self._agent_cluster_assignment( |
| 203 | + self._agents, kmeans_pp.cluster_centers_ |
| 204 | + ) |
| 205 | + |
| 206 | + # エージェントとクラスターの対応付け結果を保持 |
| 207 | + self._agent_cluster_indices = { |
| 208 | + entity.get_id(): cluster_index |
| 209 | + for entity, cluster_index in zip(self._agents, agent_cluster_indices) |
| 210 | + } |
| 211 | + |
| 212 | + # デバッグ用のログ出力 |
| 213 | + self._logger.info( |
| 214 | + f"Clustered entities: {[[entity.get_id().get_value() for entity in cluster] for cluster in self._cluster_entities]}" |
| 215 | + ) |
| 216 | + |
| 217 | + self._logger.info( |
| 218 | + f"Agent cluster indices: {[([self._world_info.get_entity(entity_id).get_x(), self._world_info.get_entity(entity_id).get_y()], int(cluster_index)) for entity_id, cluster_index in self._agent_cluster_indices.items()]}" |
| 219 | + ) |
| 220 | + |
| 221 | + return self |
| 222 | + |
| 223 | + def _perform_kmeans_pp(self, entities: list[Entity], n_clusters: int = 1) -> KMeans: |
| 224 | + """ |
| 225 | + K-means++法によるクラスタリングを実行する |
| 226 | +
|
| 227 | + Parameters |
| 228 | + ---------- |
| 229 | + entities : list[Entity] |
| 230 | + クラスタリング対象のエンティティのリスト |
| 231 | +
|
| 232 | + n_clusters : int, optional |
| 233 | + クラスター数, by default 1 |
| 234 | +
|
| 235 | + Returns |
| 236 | + ------- |
| 237 | + KMeans |
| 238 | + クラスタリング結果 |
| 239 | + """ |
| 240 | + entity_positions: np.ndarray = np.array( |
| 241 | + [ |
| 242 | + [entity.get_x(), entity.get_y()] |
| 243 | + for entity in entities |
| 244 | + if entity.get_x() is not None and entity.get_y() is not None |
| 245 | + ] |
| 246 | + ) |
| 247 | + |
| 248 | + entity_positions = entity_positions.reshape(-1, 2) |
| 249 | + kmeans_pp = KMeans( |
| 250 | + n_clusters=n_clusters, |
| 251 | + init="k-means++", |
| 252 | + random_state=SEED, |
| 253 | + ) |
| 254 | + kmeans_pp.fit(entity_positions) |
| 255 | + return kmeans_pp |
| 256 | + |
| 257 | + def _agent_cluster_assignment( |
| 258 | + self, agents: list[Entity], cluster_positions: np.ndarray |
| 259 | + ) -> np.ndarray: |
| 260 | + """ |
| 261 | + エージェントとクラスターの対応付けを行う |
| 262 | +
|
| 263 | + Parameters |
| 264 | + ---------- |
| 265 | + agents : list[Entity] |
| 266 | + エージェントのリスト |
| 267 | +
|
| 268 | + cluster_positions : np.ndarray |
| 269 | + クラスターの位置のリスト |
| 270 | +
|
| 271 | + Returns |
| 272 | + ------- |
| 273 | + np.ndarray |
| 274 | + エージェントとクラスターの対応付け結果 |
| 275 | + """ |
| 276 | + agent_positions = np.array( |
| 277 | + [ |
| 278 | + [agent.get_x(), agent.get_y()] |
| 279 | + for agent in agents |
| 280 | + if agent.get_x() is not None and agent.get_y() is not None |
| 281 | + ] |
| 282 | + ) |
| 283 | + |
| 284 | + agent_positions = agent_positions.reshape(-1, 2) |
| 285 | + cost_matrix = np.linalg.norm( |
| 286 | + agent_positions[:, np.newaxis] - cluster_positions, axis=2 |
| 287 | + ) |
| 288 | + _, col_ind = linear_sum_assignment(cost_matrix) |
| 289 | + return col_ind |
| 290 | +``` |
| 291 | + |
| 292 | +次に、作成したモジュールを登録します。`config/module.yaml` を以下のように編集してください。 |
| 293 | + |
| 294 | +```yaml |
| 295 | +SampleSearch: |
| 296 | + PathPlanning: adf_core_python.implement.module.algorithm.a_star_path_planning.AStarPathPlanning |
| 297 | + Clustering: src.test-agent.module.algorithm.k_means_pp_clustering.KMeansPPClustering |
| 298 | + |
| 299 | +SampleHumanDetector: |
| 300 | + Clustering: src.test-agent.module.algorithm.k_means_pp_clustering.KMeansPPClustering |
| 301 | +``` |
| 302 | +
|
| 303 | +
|
0 commit comments