pacai.gridworld.gamestate
"""
The game state for GridWorld games.
Besides the core game logic, this module renders the optional extended display
(MDP state values, policies, and Q-values computed by the agent).
"""

import logging
import math
import random
import typing

import PIL.Image
import PIL.ImageDraw

import pacai.core.action
import pacai.core.agentaction
import pacai.core.gamestate
import pacai.core.board
import pacai.core.font
import pacai.core.mdp
import pacai.core.spritesheet
import pacai.gridworld.board
import pacai.gridworld.mdp

AGENT_INDEX: int = 0
""" The fixed index of the only agent. """

QVALUE_TRIANGLE_POINT_OFFSETS: list[tuple[tuple[float, float], tuple[float, float], tuple[float, float]]] = [
    ((0.0, 0.0), (1.0, 0.0), (0.5, 0.5)),
    ((1.0, 0.0), (1.0, 1.0), (0.5, 0.5)),
    ((1.0, 1.0), (0.0, 1.0), (0.5, 0.5)),
    ((0.0, 1.0), (0.0, 0.0), (0.5, 0.5)),
]
"""
Offsets (as fractions of the sprite's width and height) of the three points of each Q-Value triangle.
Indexes line up with pacai.core.action.CARDINAL_DIRECTIONS.
"""

TRIANGLE_WIDTH: int = 1
""" Width of the Q-Value triangle borders. """

NON_SERIALIZED_FIELDS: list[str] = [
    '_mdp_state_values',
    '_minmax_mdp_state_values',
    '_policy',
    '_qvalues',
    '_minmax_qvalues',
]

class GameState(pacai.core.gamestate.GameState):
    """ A game state specific to a standard GridWorld game. """

    def __init__(self, **kwargs: typing.Any) -> None:
        super().__init__(**kwargs)

        self._win: bool = False
        """ Keep track if the agent exited the game on a winning state. """

        self._mdp_state_values: dict[pacai.core.mdp.MDPStatePosition, float] = {}
        """
        The MDP state values computed by the agent.
        This member will not be serialized.
        """

        self._minmax_mdp_state_values: tuple[float, float] = (-1.0, 1.0)
        """
        The min and max MDP state values computed by the agent.
        This member will not be serialized.
        """

        self._policy: dict[pacai.core.mdp.MDPStatePosition, pacai.core.action.Action] = {}
        """
        The policy computed by the agent.
        This member will not be serialized.
        """

        self._qvalues: dict[pacai.core.mdp.MDPStatePosition, dict[pacai.core.action.Action, float]] = {}
        """
        The Q-values computed by the agent.
        This member will not be serialized.
        """

        self._minmax_qvalues: tuple[float, float] = (-1.0, 1.0)
        """
        The min and max Q-values computed by the agent.
        This member will not be serialized.
        """

    def agents_game_start(self, agent_responses: dict[int, pacai.core.agentaction.AgentActionRecord]) -> None:
        """
        Load the extra information the agent reported at the start of the game.

        The agent's action may carry any of these keys in its `other_info`:
         - 'mdp_state_values': pairs of (serialized MDP state, value).
         - 'policy': pairs of (serialized MDP state, raw action).
         - 'qvalues': triples of (serialized MDP state, raw action, Q-value).
        Within this class, this information is only used to render the extended (Q-value) display.
        """

        if (AGENT_INDEX not in agent_responses):
            return

        agent_action = agent_responses[AGENT_INDEX].agent_action
        if (agent_action is None):
            return

        if ('mdp_state_values' in agent_action.other_info):
            for (raw_mdp_state, value) in agent_action.other_info['mdp_state_values']:
                mdp_state = pacai.core.mdp.MDPStatePosition.from_dict(raw_mdp_state)
                self._mdp_state_values[mdp_state] = value

            # min()/max() raise on an empty sequence, so keep the default range when no values were sent.
            if (len(self._mdp_state_values) > 0):
                min_value = min(self._mdp_state_values.values())
                max_value = max(self._mdp_state_values.values())
                self._minmax_mdp_state_values = (min_value, max_value)

        if ('policy' in agent_action.other_info):
            for (raw_mdp_state, raw_action) in agent_action.other_info['policy']:
                mdp_state = pacai.core.mdp.MDPStatePosition.from_dict(raw_mdp_state)
                self._policy[mdp_state] = pacai.core.action.Action(raw_action)

        if ('qvalues' in agent_action.other_info):
            values = []

            for (raw_mdp_state, raw_action, qvalue) in agent_action.other_info['qvalues']:
                mdp_state = pacai.core.mdp.MDPStatePosition.from_dict(raw_mdp_state)
                action = pacai.core.action.Action(raw_action)

                self._qvalues.setdefault(mdp_state, {})[action] = qvalue
                values.append(qvalue)

            # Keep the default range when no Q-values were sent (min()/max() raise on empty input).
            if (len(values) > 0):
                self._minmax_qvalues = (min(values), max(values))

    def game_complete(self) -> list[int]:
        """ Return [AGENT_INDEX] if the agent won (see process_turn()), otherwise an empty list. """

        # If the agent exited on a positive terminal position, they win.
        if (self._win):
            return [AGENT_INDEX]

        return []

    def get_legal_actions(self, position: pacai.core.board.Position | None = None) -> list[pacai.core.action.Action]:
        """
        Get the legal actions at a position (defaults to the agent's current position).
        Terminal positions only allow EXIT; every other position allows all cardinal directions and STOP.
        """

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        if (position is None):
            position = self.get_agent_position()

        # If we are on a terminal position, we can only exit.
        if ((position is not None) and board.is_terminal_position(position)):
            return [pacai.core.mdp.ACTION_EXIT]

        # Otherwise, all cardinal directions (and STOP) are legal actions.
        return pacai.core.action.CARDINAL_DIRECTIONS + [pacai.core.action.STOP]

    def sprite_lookup(self,
            sprite_sheet: pacai.core.spritesheet.SpriteSheet,
            position: pacai.core.board.Position,
            marker: pacai.core.board.Marker | None = None,
            action: pacai.core.action.Action | None = None,
            adjacency: pacai.core.board.AdjacencyString | None = None,
            animation_key: str | None = None,
            ) -> PIL.Image.Image:
        """
        Get the base sprite for a position/marker,
        then decorate the extended-display markers (state values, Q-values, and terminal copies).
        """

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        sprite = super().sprite_lookup(sprite_sheet, position, marker = marker, action = action, adjacency = adjacency, animation_key = animation_key)

        if (marker == pacai.gridworld.board.MARKER_DISPLAY_VALUE):
            # Draw MDP state values and policies.
            sprite = self._add_mdp_state_value_sprite_info(sprite_sheet, sprite, position)
        elif (marker == pacai.gridworld.board.MARKER_DISPLAY_QVALUE):
            # Draw Q-values.
            sprite = self._add_qvalue_sprite_info(sprite_sheet, sprite, position)
        elif ((marker == pacai.gridworld.board.MARKER_TERMINAL)
                and ((position.row > board._original_height) or (position.col > board._original_width))):
            # Draw terminal values on the extended q-display.
            sprite = self._add_terminal_sprite_info(sprite_sheet, sprite, position)

        return sprite

    def skip_draw(self,
            marker: pacai.core.board.Marker,
            position: pacai.core.board.Position,
            static: bool = False,
            ) -> bool:
        """
        Skip non-static draws of positions outside of the original board (the extended display area).
        Static draws are never skipped.
        """

        if (static):
            return False

        board = typing.cast(pacai.gridworld.board.Board, self.board)
        return (position.row >= board._original_height) or (position.col >= board._original_width)

    def get_static_positions(self) -> list[pacai.core.board.Position]:
        """ Get every position outside of the original board (the extended display area), in row-major order. """

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        return [
            pacai.core.board.Position(row, col)
            for row in range(board.height)
            for col in range(board.width)
            if ((row >= board._original_height) or (col >= board._original_width))
        ]

    def _add_terminal_sprite_info(self,
            sprite_sheet: pacai.core.spritesheet.SpriteSheet,
            sprite: PIL.Image.Image,
            position: pacai.core.board.Position) -> PIL.Image.Image:
        """
        Add coloring to the terminal positions.
        Returns a colored copy of the sprite, or the original sprite if the position is not terminal.
        """

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        if (not board.is_terminal_position(position)):
            return sprite

        sprite = sprite.copy()
        canvas = PIL.ImageDraw.Draw(sprite)

        min_value = min(board._terminal_values.values())
        max_value = max(board._terminal_values.values())

        value = board.get_terminal_value(position)
        color = self._red_green_gradient(value, min_value, max_value)

        points = [(1, 1), (sprite_sheet.width - 1, sprite_sheet.height - 1)]
        canvas.rectangle(points, fill = color, outline = sprite_sheet.text, width = 1)

        return sprite

    def _add_mdp_state_value_sprite_info(self,
            sprite_sheet: pacai.core.spritesheet.SpriteSheet,
            sprite: PIL.Image.Image,
            position: pacai.core.board.Position) -> PIL.Image.Image:
        """
        Return a copy of the sprite filled with a rectangle colored by the MDP state value
        of the true board position that this extended-display position mirrors.
        States without a reported value are colored as 0.0.
        """

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        sprite = sprite.copy()
        canvas = PIL.ImageDraw.Draw(sprite)

        # The offset from the visualization position to the true board position.
        # The MDP state values are to the right of the true board.
        base_offset = pacai.core.board.Position(0, -(board._original_width + 1))
        base_position = position.add(base_offset)
        mdp_state = pacai.core.mdp.MDPStatePosition(position = base_position)

        value = self._mdp_state_values.get(mdp_state, 0.0)
        color = self._red_green_gradient(value, self._minmax_mdp_state_values[0], self._minmax_mdp_state_values[1])

        points = [(1, 1), (sprite_sheet.width - 1, sprite_sheet.height - 1)]
        canvas.rectangle(points, fill = color, outline = sprite_sheet.text, width = 1)

        return sprite

    def _add_qvalue_sprite_info(self,
            sprite_sheet: pacai.core.spritesheet.SpriteSheet,
            sprite: PIL.Image.Image,
            position: pacai.core.board.Position) -> PIL.Image.Image:
        """
        Return a copy of the sprite with one colored Q-value triangle per cardinal direction,
        for the true board position that this extended-display position mirrors.
        Missing Q-values are colored as 0.0.
        """

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        sprite = sprite.copy()
        canvas = PIL.ImageDraw.Draw(sprite)

        # The offset from the visualization position to the true board position.
        # The Q-values are below the true board.
        base_offset = pacai.core.board.Position(-(board._original_height + 1), 0)
        base_position = position.add(base_offset)
        mdp_state = pacai.core.mdp.MDPStatePosition(position = base_position)

        for (direction_index, point_offsets) in enumerate(QVALUE_TRIANGLE_POINT_OFFSETS):
            points = []

            for point_offset in point_offsets:
                # Offset the outer points of the triangle towards the inside of the triangle to avoid border overlaps.
                origin = [0, 0]
                for (i, offset) in enumerate(point_offset):
                    if (math.isclose(offset, 0.0)):
                        origin[i] = TRIANGLE_WIDTH
                    elif (math.isclose(offset, 1.0)):
                        origin[i] = -TRIANGLE_WIDTH

                points.append((
                    (origin[0] + (sprite_sheet.width * point_offset[0])),
                    (origin[1] + (sprite_sheet.height * point_offset[1])),
                ))

            action = pacai.core.action.CARDINAL_DIRECTIONS[direction_index]
            qvalue = self._qvalues.get(mdp_state, {}).get(action, 0.0)
            color = self._red_green_gradient(qvalue, self._minmax_qvalues[0], self._minmax_qvalues[1])
            canvas.polygon(points, fill = color, outline = sprite_sheet.text, width = 1)

        return sprite

    def _red_green_gradient(self, value: float,
            min_value: float, max_value: float, divider: float = 0.0,
            blue_intensity: float = 0.00) -> tuple[int, int, int]:
        """
        Get a color (RGB) between red (min) and green (max) based on the given value.
        Values under the divider will always be red, and values over will always be green.
        Values at (close to) the divider are drawn blue.

        The range is widened to always include the divider,
        so ranges that sit entirely on one side of it (e.g., only positive values) are still valid
        (all-positive values are drawn green, all-negative values red).

        Raises a ValueError if min_value is greater than max_value.
        """

        if (min_value > max_value):
            raise ValueError(("Gradient values are not in the correct order. "
                    + f"Found: min = {min_value}, divider = {divider}, max = {max_value}."))

        # Make sure the divider is inside the range (see docstring).
        min_value = min(min_value, divider)
        max_value = max(max_value, divider)

        red_intensity = 0.0
        green_intensity = 0.0

        # Floor the masses to avoid dividing by zero on degenerate ranges.
        red_mass = max(0.01, divider - min_value)
        green_mass = max(0.01, max_value - divider)

        value = min(max_value, max(min_value, value))

        if (math.isclose(value, divider)):
            blue_intensity = 1.0
            red_intensity = 0.25
            green_intensity = 0.25
        elif (value < divider):
            red_intensity = 0.25 + (0.75 * (divider - value) / red_mass)
        else:
            green_intensity = 0.25 + (0.75 * (value - divider) / green_mass)

        return (int(255 * red_intensity), int(255 * green_intensity), int(255 * blue_intensity))

    def process_turn(self,
            action: pacai.core.action.Action,
            rng: random.Random | None = None,
            mdp: pacai.gridworld.mdp.GridWorldMDP | None = None,
            **kwargs: typing.Any) -> None:
        """
        Apply one turn: sample a transition for the requested action from the MDP,
        add its reward to the score, move the agent, and note a win if the agent lands on a positive terminal.
        If the only possible transition is EXIT, the game is marked over instead.

        Raises a ValueError if no MDP is given or the agent is missing from the board.
        """

        if (rng is None):
            logging.warning("No RNG passed to pacai.gridworld.gamestate.GameState.process_turn().")
            rng = random.Random(4)

        if (mdp is None):
            raise ValueError("No MDP passed to pacai.gridworld.gamestate.GameState.process_turn().")

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        # Get the possible transitions from the MDP.
        # NOTE(review): this uses the MDP's starting state, which presumes the MDP tracks the agent's current position — confirm.
        transitions = mdp.get_transitions(mdp.get_starting_state(), action)

        # If there is only an exit transition, exit.
        if ((len(transitions) == 1) and (transitions[0].action == pacai.core.mdp.ACTION_EXIT)):
            logging.debug("Got the %s action, game is over.", pacai.core.mdp.ACTION_EXIT)
            self.game_over = True
            return

        # Choose a transition.
        transition = self._choose_transition(transitions, rng)

        # Apply the transition.

        self.score += transition.reward

        old_position = self.get_agent_position(AGENT_INDEX)
        if (old_position is None):
            raise ValueError("GridWorld agent was removed from board.")

        new_position = transition.state.position

        if (old_position != new_position):
            board.remove_marker(pacai.gridworld.board.AGENT_MARKER, old_position)
            board.place_marker(pacai.gridworld.board.AGENT_MARKER, new_position)

        # Check if we are going to "win".
        # The reward for a terminal state is awarded on the action before EXIT.
        if (board.is_terminal_position(new_position) and (board.get_terminal_value(new_position) > 0)):
            self._win = True

        logging.debug("Requested Action: '%s', Actual Action: '%s', Reward: %0.2f.", action, transition.action, transition.reward)

    def _choose_transition(self,
            transitions: list[pacai.core.mdp.Transition],
            rng: random.Random) -> pacai.core.mdp.Transition:
        """
        Randomly pick one transition, weighted by the transitions' probabilities.

        The probabilities must sum to 1.0, up to floating-point tolerance.
        Raises a ValueError if the running sum clearly exceeds 1.0 or the total clearly falls short of it.
        """

        probability_sum = 0.0
        point = rng.random()

        for transition in transitions:
            probability_sum += transition.probability
            if ((probability_sum > 1.0) and (not math.isclose(probability_sum, 1.0))):
                raise ValueError(f"Transition probabilities is over 1.0, found at least {probability_sum}.")

            if (point < probability_sum):
                return transition

        # Floating-point accumulation can leave the total a hair under 1.0 (e.g., 0.7 + 0.1 + 0.1 + 0.1),
        # so a point in that sliver falls through the loop: give it to the last transition that can actually happen.
        possible_transitions = [transition for transition in transitions if (transition.probability > 0.0)]
        if ((len(possible_transitions) > 0) and math.isclose(probability_sum, 1.0)):
            return possible_transitions[-1]

        raise ValueError(f"Transition probabilities is less than 1.0, found {probability_sum}.")

    def get_static_text(self) -> list[pacai.core.font.BoardText]:
        """
        Get the static board text: every terminal value,
        plus the extended display text when the board displays Q-values.
        """

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        texts = []

        # Add on terminal values.
        for (position, value) in board._terminal_values.items():
            # Terminal copies on the extended display get black text; the rest use the default color.
            text_color = None
            if ((position.row > board._original_height) or (position.col > board._original_width)):
                text_color = (0, 0, 0)

            texts.append(pacai.core.font.BoardText(position, str(value), size = pacai.core.font.FontSize.SMALL, color = text_color))

        # If we are using the extended display, fill in all the information.
        if (board.display_qvalues()):
            texts += self._get_qdisplay_static_text()

        return texts

    def _get_qdisplay_static_text(self) -> list[pacai.core.font.BoardText]:
        """ Get the text for the extended display: separator labels, state values/policies, and Q-values. """

        texts = []

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        # Add labels on the separator.
        row = board._original_height
        texts.append(pacai.core.font.BoardText(pacai.core.board.Position(row, 1), ' ↓ Q-Values'))
        texts.append(pacai.core.font.BoardText(pacai.core.board.Position(row, board._original_width + 3), ' ↑ Values & Policy'))

        # Add text for MDP state values and policies.
        texts += self._get_qdisplay_static_text_mdp_state_values()

        # Add text for Q-values.
        texts += self._get_qdisplay_static_text_qvalues()

        return texts

    def _get_qdisplay_static_text_mdp_state_values(self) -> list[pacai.core.font.BoardText]:
        """
        Get one text per value-display position: the state's value and its policy arrow.
        Unknown values are shown as '?'; unknown policies default to STOP.
        """

        texts = []

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        # The offset from the visualization position to the true board position.
        # The MDP state values are to the right of the true board.
        base_offset = pacai.core.board.Position(0, -(board._original_width + 1))

        for position in self.board.get_marker_positions(pacai.gridworld.board.MARKER_DISPLAY_VALUE):
            base_position = position.add(base_offset)
            mdp_state = pacai.core.mdp.MDPStatePosition(position = base_position)

            value = self._mdp_state_values.get(mdp_state, None)
            policy_action = self._policy.get(mdp_state, pacai.core.action.STOP)

            value_text = '?'
            if (value is not None):
                value_text = f"{value:0.2f}"

            policy_text = pacai.core.action.CARDINAL_DIRECTION_ARROWS.get(policy_action, '?')

            text = f"{value_text}\n{policy_text}"

            texts.append(pacai.core.font.BoardText(position, text,
                    size = pacai.core.font.FontSize.SMALL,
                    color = (0, 0, 0)))

        return texts

    def _get_qdisplay_static_text_qvalues(self) -> list[pacai.core.font.BoardText]:
        """
        Get the Q-value texts: four per Q-value display position, one per cardinal direction,
        each aligned towards that direction's edge. Unknown Q-values are shown as '?'.
        """

        texts = []

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        # The offset from the visualization position to the true board position.
        # The Q-values are below the true board.
        base_offset = pacai.core.board.Position(-(board._original_height + 1), 0)

        # [(vertical alignment, horizontal alignment), ...]
        # Indexes line up with pacai.core.action.CARDINAL_DIRECTIONS.
        alignments = [
            (pacai.core.font.TextVerticalAlign.TOP, pacai.core.font.TextHorizontalAlign.CENTER),
            (pacai.core.font.TextVerticalAlign.MIDDLE, pacai.core.font.TextHorizontalAlign.RIGHT),
            (pacai.core.font.TextVerticalAlign.BOTTOM, pacai.core.font.TextHorizontalAlign.CENTER),
            (pacai.core.font.TextVerticalAlign.MIDDLE, pacai.core.font.TextHorizontalAlign.LEFT),
        ]

        for position in self.board.get_marker_positions(pacai.gridworld.board.MARKER_DISPLAY_QVALUE):
            base_position = position.add(base_offset)
            mdp_state = pacai.core.mdp.MDPStatePosition(position = base_position)

            for (i, (vertical_align, horizontal_align)) in enumerate(alignments):
                action = pacai.core.action.CARDINAL_DIRECTIONS[i]
                qvalue = self._qvalues.get(mdp_state, {}).get(action, None)

                text = '?'
                if (qvalue is not None):
                    text = f"{qvalue:0.2f}"

                texts.append(pacai.core.font.BoardText(
                        position, text,
                        size = pacai.core.font.FontSize.TINY,
                        vertical_align = vertical_align,
                        horizontal_align = horizontal_align,
                        color = (0, 0, 0)))

        return texts

    def to_dict(self) -> dict[str, typing.Any]:
        """ Serialize this game state, adding the win flag and dropping the fields in NON_SERIALIZED_FIELDS. """

        data = super().to_dict()
        data['_win'] = self._win

        for key in NON_SERIALIZED_FIELDS:
            data.pop(key, None)

        return data

    @classmethod
    def from_dict(cls, data: dict[str, typing.Any]) -> typing.Any:
        """
        Rebuild a game state from the output of to_dict().
        Data without a '_win' entry is treated as a game that was not won.
        """

        game_state = super().from_dict(data)
        game_state._win = data.get('_win', False)
        return game_state
The fixed index of the only agent.
Offsets (as fractions of the sprite's width and height) of the three points of each Q-Value triangle. Indexes line up with pacai.core.action.CARDINAL_DIRECTIONS.
Width of the Q-Value triangle borders.
class GameState(pacai.core.gamestate.GameState):
    """ A game state specific to a standard GridWorld game. """

    def __init__(self, **kwargs: typing.Any) -> None:
        super().__init__(**kwargs)

        self._win: bool = False
        """ Keep track if the agent exited the game on a winning state. """

        self._mdp_state_values: dict[pacai.core.mdp.MDPStatePosition, float] = {}
        """
        The MDP state values computed by the agent.
        This member will not be serialized.
        """

        self._minmax_mdp_state_values: tuple[float, float] = (-1.0, 1.0)
        """
        The min and max MDP state values computed by the agent.
        This member will not be serialized.
        """

        self._policy: dict[pacai.core.mdp.MDPStatePosition, pacai.core.action.Action] = {}
        """
        The policy computed by the agent.
        This member will not be serialized.
        """

        self._qvalues: dict[pacai.core.mdp.MDPStatePosition, dict[pacai.core.action.Action, float]] = {}
        """
        The Q-values computed by the agent.
        This member will not be serialized.
        """

        self._minmax_qvalues: tuple[float, float] = (-1.0, 1.0)
        """
        The min and max Q-values computed by the agent.
        This member will not be serialized.
        """

    def agents_game_start(self, agent_responses: dict[int, pacai.core.agentaction.AgentActionRecord]) -> None:
        """
        Load the extra information the agent reported at the start of the game.

        The agent's action may carry any of these keys in its `other_info`:
         - 'mdp_state_values': pairs of (serialized MDP state, value).
         - 'policy': pairs of (serialized MDP state, raw action).
         - 'qvalues': triples of (serialized MDP state, raw action, Q-value).
        Within this class, this information is only used to render the extended (Q-value) display.
        """

        if (AGENT_INDEX not in agent_responses):
            return

        agent_action = agent_responses[AGENT_INDEX].agent_action
        if (agent_action is None):
            return

        if ('mdp_state_values' in agent_action.other_info):
            for (raw_mdp_state, value) in agent_action.other_info['mdp_state_values']:
                mdp_state = pacai.core.mdp.MDPStatePosition.from_dict(raw_mdp_state)
                self._mdp_state_values[mdp_state] = value

            # min()/max() raise on an empty sequence, so keep the default range when no values were sent.
            if (len(self._mdp_state_values) > 0):
                min_value = min(self._mdp_state_values.values())
                max_value = max(self._mdp_state_values.values())
                self._minmax_mdp_state_values = (min_value, max_value)

        if ('policy' in agent_action.other_info):
            for (raw_mdp_state, raw_action) in agent_action.other_info['policy']:
                mdp_state = pacai.core.mdp.MDPStatePosition.from_dict(raw_mdp_state)
                self._policy[mdp_state] = pacai.core.action.Action(raw_action)

        if ('qvalues' in agent_action.other_info):
            values = []

            for (raw_mdp_state, raw_action, qvalue) in agent_action.other_info['qvalues']:
                mdp_state = pacai.core.mdp.MDPStatePosition.from_dict(raw_mdp_state)
                action = pacai.core.action.Action(raw_action)

                self._qvalues.setdefault(mdp_state, {})[action] = qvalue
                values.append(qvalue)

            # Keep the default range when no Q-values were sent (min()/max() raise on empty input).
            if (len(values) > 0):
                self._minmax_qvalues = (min(values), max(values))

    def game_complete(self) -> list[int]:
        """ Return [AGENT_INDEX] if the agent won (see process_turn()), otherwise an empty list. """

        # If the agent exited on a positive terminal position, they win.
        if (self._win):
            return [AGENT_INDEX]

        return []

    def get_legal_actions(self, position: pacai.core.board.Position | None = None) -> list[pacai.core.action.Action]:
        """
        Get the legal actions at a position (defaults to the agent's current position).
        Terminal positions only allow EXIT; every other position allows all cardinal directions and STOP.
        """

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        if (position is None):
            position = self.get_agent_position()

        # If we are on a terminal position, we can only exit.
        if ((position is not None) and board.is_terminal_position(position)):
            return [pacai.core.mdp.ACTION_EXIT]

        # Otherwise, all cardinal directions (and STOP) are legal actions.
        return pacai.core.action.CARDINAL_DIRECTIONS + [pacai.core.action.STOP]

    def sprite_lookup(self,
            sprite_sheet: pacai.core.spritesheet.SpriteSheet,
            position: pacai.core.board.Position,
            marker: pacai.core.board.Marker | None = None,
            action: pacai.core.action.Action | None = None,
            adjacency: pacai.core.board.AdjacencyString | None = None,
            animation_key: str | None = None,
            ) -> PIL.Image.Image:
        """
        Get the base sprite for a position/marker,
        then decorate the extended-display markers (state values, Q-values, and terminal copies).
        """

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        sprite = super().sprite_lookup(sprite_sheet, position, marker = marker, action = action, adjacency = adjacency, animation_key = animation_key)

        if (marker == pacai.gridworld.board.MARKER_DISPLAY_VALUE):
            # Draw MDP state values and policies.
            sprite = self._add_mdp_state_value_sprite_info(sprite_sheet, sprite, position)
        elif (marker == pacai.gridworld.board.MARKER_DISPLAY_QVALUE):
            # Draw Q-values.
            sprite = self._add_qvalue_sprite_info(sprite_sheet, sprite, position)
        elif ((marker == pacai.gridworld.board.MARKER_TERMINAL)
                and ((position.row > board._original_height) or (position.col > board._original_width))):
            # Draw terminal values on the extended q-display.
            sprite = self._add_terminal_sprite_info(sprite_sheet, sprite, position)

        return sprite

    def skip_draw(self,
            marker: pacai.core.board.Marker,
            position: pacai.core.board.Position,
            static: bool = False,
            ) -> bool:
        """
        Skip non-static draws of positions outside of the original board (the extended display area).
        Static draws are never skipped.
        """

        if (static):
            return False

        board = typing.cast(pacai.gridworld.board.Board, self.board)
        return (position.row >= board._original_height) or (position.col >= board._original_width)

    def get_static_positions(self) -> list[pacai.core.board.Position]:
        """ Get every position outside of the original board (the extended display area), in row-major order. """

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        return [
            pacai.core.board.Position(row, col)
            for row in range(board.height)
            for col in range(board.width)
            if ((row >= board._original_height) or (col >= board._original_width))
        ]

    def _add_terminal_sprite_info(self,
            sprite_sheet: pacai.core.spritesheet.SpriteSheet,
            sprite: PIL.Image.Image,
            position: pacai.core.board.Position) -> PIL.Image.Image:
        """
        Add coloring to the terminal positions.
        Returns a colored copy of the sprite, or the original sprite if the position is not terminal.
        """

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        if (not board.is_terminal_position(position)):
            return sprite

        sprite = sprite.copy()
        canvas = PIL.ImageDraw.Draw(sprite)

        min_value = min(board._terminal_values.values())
        max_value = max(board._terminal_values.values())

        value = board.get_terminal_value(position)
        color = self._red_green_gradient(value, min_value, max_value)

        points = [(1, 1), (sprite_sheet.width - 1, sprite_sheet.height - 1)]
        canvas.rectangle(points, fill = color, outline = sprite_sheet.text, width = 1)

        return sprite

    def _add_mdp_state_value_sprite_info(self,
            sprite_sheet: pacai.core.spritesheet.SpriteSheet,
            sprite: PIL.Image.Image,
            position: pacai.core.board.Position) -> PIL.Image.Image:
        """
        Return a copy of the sprite filled with a rectangle colored by the MDP state value
        of the true board position that this extended-display position mirrors.
        States without a reported value are colored as 0.0.
        """

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        sprite = sprite.copy()
        canvas = PIL.ImageDraw.Draw(sprite)

        # The offset from the visualization position to the true board position.
        # The MDP state values are to the right of the true board.
        base_offset = pacai.core.board.Position(0, -(board._original_width + 1))
        base_position = position.add(base_offset)
        mdp_state = pacai.core.mdp.MDPStatePosition(position = base_position)

        value = self._mdp_state_values.get(mdp_state, 0.0)
        color = self._red_green_gradient(value, self._minmax_mdp_state_values[0], self._minmax_mdp_state_values[1])

        points = [(1, 1), (sprite_sheet.width - 1, sprite_sheet.height - 1)]
        canvas.rectangle(points, fill = color, outline = sprite_sheet.text, width = 1)

        return sprite

    def _add_qvalue_sprite_info(self,
            sprite_sheet: pacai.core.spritesheet.SpriteSheet,
            sprite: PIL.Image.Image,
            position: pacai.core.board.Position) -> PIL.Image.Image:
        """
        Return a copy of the sprite with one colored Q-value triangle per cardinal direction,
        for the true board position that this extended-display position mirrors.
        Missing Q-values are colored as 0.0.
        """

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        sprite = sprite.copy()
        canvas = PIL.ImageDraw.Draw(sprite)

        # The offset from the visualization position to the true board position.
        # The Q-values are below the true board.
        base_offset = pacai.core.board.Position(-(board._original_height + 1), 0)
        base_position = position.add(base_offset)
        mdp_state = pacai.core.mdp.MDPStatePosition(position = base_position)

        for (direction_index, point_offsets) in enumerate(QVALUE_TRIANGLE_POINT_OFFSETS):
            points = []

            for point_offset in point_offsets:
                # Offset the outer points of the triangle towards the inside of the triangle to avoid border overlaps.
                origin = [0, 0]
                for (i, offset) in enumerate(point_offset):
                    if (math.isclose(offset, 0.0)):
                        origin[i] = TRIANGLE_WIDTH
                    elif (math.isclose(offset, 1.0)):
                        origin[i] = -TRIANGLE_WIDTH

                points.append((
                    (origin[0] + (sprite_sheet.width * point_offset[0])),
                    (origin[1] + (sprite_sheet.height * point_offset[1])),
                ))

            action = pacai.core.action.CARDINAL_DIRECTIONS[direction_index]
            qvalue = self._qvalues.get(mdp_state, {}).get(action, 0.0)
            color = self._red_green_gradient(qvalue, self._minmax_qvalues[0], self._minmax_qvalues[1])
            canvas.polygon(points, fill = color, outline = sprite_sheet.text, width = 1)

        return sprite

    def _red_green_gradient(self, value: float,
            min_value: float, max_value: float, divider: float = 0.0,
            blue_intensity: float = 0.00) -> tuple[int, int, int]:
        """
        Get a color (RGB) between red (min) and green (max) based on the given value.
        Values under the divider will always be red, and values over will always be green.
        Values at (close to) the divider are drawn blue.

        The range is widened to always include the divider,
        so ranges that sit entirely on one side of it (e.g., only positive values) are still valid
        (all-positive values are drawn green, all-negative values red).

        Raises a ValueError if min_value is greater than max_value.
        """

        if (min_value > max_value):
            raise ValueError(("Gradient values are not in the correct order. "
                    + f"Found: min = {min_value}, divider = {divider}, max = {max_value}."))

        # Make sure the divider is inside the range (see docstring).
        min_value = min(min_value, divider)
        max_value = max(max_value, divider)

        red_intensity = 0.0
        green_intensity = 0.0

        # Floor the masses to avoid dividing by zero on degenerate ranges.
        red_mass = max(0.01, divider - min_value)
        green_mass = max(0.01, max_value - divider)

        value = min(max_value, max(min_value, value))

        if (math.isclose(value, divider)):
            blue_intensity = 1.0
            red_intensity = 0.25
            green_intensity = 0.25
        elif (value < divider):
            red_intensity = 0.25 + (0.75 * (divider - value) / red_mass)
        else:
            green_intensity = 0.25 + (0.75 * (value - divider) / green_mass)

        return (int(255 * red_intensity), int(255 * green_intensity), int(255 * blue_intensity))

    def process_turn(self,
            action: pacai.core.action.Action,
            rng: random.Random | None = None,
            mdp: pacai.gridworld.mdp.GridWorldMDP | None = None,
            **kwargs: typing.Any) -> None:
        """
        Apply one turn: sample a transition for the requested action from the MDP,
        add its reward to the score, move the agent, and note a win if the agent lands on a positive terminal.
        If the only possible transition is EXIT, the game is marked over instead.

        Raises a ValueError if no MDP is given or the agent is missing from the board.
        """

        if (rng is None):
            logging.warning("No RNG passed to pacai.gridworld.gamestate.GameState.process_turn().")
            rng = random.Random(4)

        if (mdp is None):
            raise ValueError("No MDP passed to pacai.gridworld.gamestate.GameState.process_turn().")

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        # Get the possible transitions from the MDP.
        # NOTE(review): this uses the MDP's starting state, which presumes the MDP tracks the agent's current position — confirm.
        transitions = mdp.get_transitions(mdp.get_starting_state(), action)

        # If there is only an exit transition, exit.
        if ((len(transitions) == 1) and (transitions[0].action == pacai.core.mdp.ACTION_EXIT)):
            logging.debug("Got the %s action, game is over.", pacai.core.mdp.ACTION_EXIT)
            self.game_over = True
            return

        # Choose a transition.
        transition = self._choose_transition(transitions, rng)

        # Apply the transition.

        self.score += transition.reward

        old_position = self.get_agent_position(AGENT_INDEX)
        if (old_position is None):
            raise ValueError("GridWorld agent was removed from board.")

        new_position = transition.state.position

        if (old_position != new_position):
            board.remove_marker(pacai.gridworld.board.AGENT_MARKER, old_position)
            board.place_marker(pacai.gridworld.board.AGENT_MARKER, new_position)

        # Check if we are going to "win".
        # The reward for a terminal state is awarded on the action before EXIT.
        if (board.is_terminal_position(new_position) and (board.get_terminal_value(new_position) > 0)):
            self._win = True

        logging.debug("Requested Action: '%s', Actual Action: '%s', Reward: %0.2f.", action, transition.action, transition.reward)

    def _choose_transition(self,
            transitions: list[pacai.core.mdp.Transition],
            rng: random.Random) -> pacai.core.mdp.Transition:
        """
        Randomly pick one transition, weighted by the transitions' probabilities.

        The probabilities must sum to 1.0, up to floating-point tolerance.
        Raises a ValueError if the running sum clearly exceeds 1.0 or the total clearly falls short of it.
        """

        probability_sum = 0.0
        point = rng.random()

        for transition in transitions:
            probability_sum += transition.probability
            if ((probability_sum > 1.0) and (not math.isclose(probability_sum, 1.0))):
                raise ValueError(f"Transition probabilities is over 1.0, found at least {probability_sum}.")

            if (point < probability_sum):
                return transition

        # Floating-point accumulation can leave the total a hair under 1.0 (e.g., 0.7 + 0.1 + 0.1 + 0.1),
        # so a point in that sliver falls through the loop: give it to the last transition that can actually happen.
        possible_transitions = [transition for transition in transitions if (transition.probability > 0.0)]
        if ((len(possible_transitions) > 0) and math.isclose(probability_sum, 1.0)):
            return possible_transitions[-1]

        raise ValueError(f"Transition probabilities is less than 1.0, found {probability_sum}.")

    def get_static_text(self) -> list[pacai.core.font.BoardText]:
        """
        Get the static board text: every terminal value,
        plus the extended display text when the board displays Q-values.
        """

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        texts = []

        # Add on terminal values.
        for (position, value) in board._terminal_values.items():
            # Terminal copies on the extended display get black text; the rest use the default color.
            text_color = None
            if ((position.row > board._original_height) or (position.col > board._original_width)):
                text_color = (0, 0, 0)

            texts.append(pacai.core.font.BoardText(position, str(value), size = pacai.core.font.FontSize.SMALL, color = text_color))

        # If we are using the extended display, fill in all the information.
        if (board.display_qvalues()):
            texts += self._get_qdisplay_static_text()

        return texts

    def _get_qdisplay_static_text(self) -> list[pacai.core.font.BoardText]:
        """ Get the text for the extended display: separator labels, state values/policies, and Q-values. """

        texts = []

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        # Add labels on the separator.
        row = board._original_height
        texts.append(pacai.core.font.BoardText(pacai.core.board.Position(row, 1), ' ↓ Q-Values'))
        texts.append(pacai.core.font.BoardText(pacai.core.board.Position(row, board._original_width + 3), ' ↑ Values & Policy'))

        # Add text for MDP state values and policies.
        texts += self._get_qdisplay_static_text_mdp_state_values()

        # Add text for Q-values.
        texts += self._get_qdisplay_static_text_qvalues()

        return texts

    def _get_qdisplay_static_text_mdp_state_values(self) -> list[pacai.core.font.BoardText]:
        """
        Get one text per value-display position: the state's value and its policy arrow.
        Unknown values are shown as '?'; unknown policies default to STOP.
        """

        texts = []

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        # The offset from the visualization position to the true board position.
        # The MDP state values are to the right of the true board.
        base_offset = pacai.core.board.Position(0, -(board._original_width + 1))

        for position in self.board.get_marker_positions(pacai.gridworld.board.MARKER_DISPLAY_VALUE):
            base_position = position.add(base_offset)
            mdp_state = pacai.core.mdp.MDPStatePosition(position = base_position)

            value = self._mdp_state_values.get(mdp_state, None)
            policy_action = self._policy.get(mdp_state, pacai.core.action.STOP)

            value_text = '?'
            if (value is not None):
                value_text = f"{value:0.2f}"

            policy_text = pacai.core.action.CARDINAL_DIRECTION_ARROWS.get(policy_action, '?')

            text = f"{value_text}\n{policy_text}"

            texts.append(pacai.core.font.BoardText(position, text,
                    size = pacai.core.font.FontSize.SMALL,
                    color = (0, 0, 0)))

        return texts

    def _get_qdisplay_static_text_qvalues(self) -> list[pacai.core.font.BoardText]:
        """
        Get the Q-value texts: four per Q-value display position, one per cardinal direction,
        each aligned towards that direction's edge. Unknown Q-values are shown as '?'.
        """

        texts = []

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        # The offset from the visualization position to the true board position.
        # The Q-values are below the true board.
        base_offset = pacai.core.board.Position(-(board._original_height + 1), 0)

        # [(vertical alignment, horizontal alignment), ...]
        # Indexes line up with pacai.core.action.CARDINAL_DIRECTIONS.
        alignments = [
            (pacai.core.font.TextVerticalAlign.TOP, pacai.core.font.TextHorizontalAlign.CENTER),
            (pacai.core.font.TextVerticalAlign.MIDDLE, pacai.core.font.TextHorizontalAlign.RIGHT),
            (pacai.core.font.TextVerticalAlign.BOTTOM, pacai.core.font.TextHorizontalAlign.CENTER),
            (pacai.core.font.TextVerticalAlign.MIDDLE, pacai.core.font.TextHorizontalAlign.LEFT),
        ]

        for position in self.board.get_marker_positions(pacai.gridworld.board.MARKER_DISPLAY_QVALUE):
            base_position = position.add(base_offset)
            mdp_state = pacai.core.mdp.MDPStatePosition(position = base_position)

            for (i, (vertical_align, horizontal_align)) in enumerate(alignments):
                action = pacai.core.action.CARDINAL_DIRECTIONS[i]
                qvalue = self._qvalues.get(mdp_state, {}).get(action, None)

                text = '?'
                if (qvalue is not None):
                    text = f"{qvalue:0.2f}"

                texts.append(pacai.core.font.BoardText(
                        position, text,
                        size = pacai.core.font.FontSize.TINY,
                        vertical_align = vertical_align,
                        horizontal_align = horizontal_align,
                        color = (0, 0, 0)))

        return texts

    def to_dict(self) -> dict[str, typing.Any]:
        """ Serialize this game state, adding the win flag and dropping the fields in NON_SERIALIZED_FIELDS. """

        data = super().to_dict()
        data['_win'] = self._win

        for key in NON_SERIALIZED_FIELDS:
            data.pop(key, None)

        return data

    @classmethod
    def from_dict(cls, data: dict[str, typing.Any]) -> typing.Any:
        """
        Rebuild a game state from the output of to_dict().
        Data without a '_win' entry is treated as a game that was not won.
        """

        game_state = super().from_dict(data)
        game_state._win = data.get('_win', False)
        return game_state
A game state specific to a standard GridWorld game.
def __init__(self, **kwargs: typing.Any) -> None:
    """
    Build a GridWorld game state.

    Every keyword argument is forwarded unchanged to pacai.core.gamestate.GameState.
    The agent-reported display data (state values, policy, and Q-values) starts out empty,
    and both min/max display ranges start at (-1.0, 1.0); agents_game_start() may replace them.
    """

    super().__init__(**kwargs)

    self._win: bool = False
    """ Set once the agent moves onto a terminal position whose value is positive. """

    self._mdp_state_values: dict[pacai.core.mdp.MDPStatePosition, float] = {}
    """
    Per-state values reported by the agent.
    Excluded from serialization.
    """

    self._minmax_mdp_state_values: tuple[float, float] = (-1.0, 1.0)
    """
    The (min, max) range of the reported per-state values.
    Excluded from serialization.
    """

    self._policy: dict[pacai.core.mdp.MDPStatePosition, pacai.core.action.Action] = {}
    """
    Per-state action choices reported by the agent.
    Excluded from serialization.
    """

    self._qvalues: dict[pacai.core.mdp.MDPStatePosition, dict[pacai.core.action.Action, float]] = {}
    """
    Per-state, per-action Q-values reported by the agent.
    Excluded from serialization.
    """

    self._minmax_qvalues: tuple[float, float] = (-1.0, 1.0)
    """
    The (min, max) range of the reported Q-values.
    Excluded from serialization.
    """
def agents_game_start(self, agent_responses: dict[int, pacai.core.agentaction.AgentActionRecord]) -> None:
    """
    Load the display data the agent reported at the start of the game.

    Only the response for AGENT_INDEX is consulted.
    Nothing happens if that response is missing or carries no agent action.
    The following optional keys of the action's `other_info` are read:
     - 'mdp_state_values': (raw MDP state, value) pairs.
     - 'policy': (raw MDP state, raw action) pairs.
     - 'qvalues': (raw MDP state, raw action, Q-value) triples.
    Raw MDP states are decoded with pacai.core.mdp.MDPStatePosition.from_dict(),
    and raw actions with pacai.core.action.Action().

    The min/max display ranges are only recomputed when at least one value is present.
    An empty listing therefore keeps the current range instead of raising ValueError from min()/max().
    """

    if (AGENT_INDEX not in agent_responses):
        return

    agent_action = agent_responses[AGENT_INDEX].agent_action
    if (agent_action is None):
        return

    other_info = agent_action.other_info

    if ('mdp_state_values' in other_info):
        for (raw_mdp_state, value) in other_info['mdp_state_values']:
            mdp_state = pacai.core.mdp.MDPStatePosition.from_dict(raw_mdp_state)
            self._mdp_state_values[mdp_state] = value

        # min()/max() raise on an empty sequence, so only update the range when there is data.
        if (len(self._mdp_state_values) > 0):
            state_values = self._mdp_state_values.values()
            self._minmax_mdp_state_values = (min(state_values), max(state_values))

    if ('policy' in other_info):
        for (raw_mdp_state, raw_action) in other_info['policy']:
            mdp_state = pacai.core.mdp.MDPStatePosition.from_dict(raw_mdp_state)
            self._policy[mdp_state] = pacai.core.action.Action(raw_action)

    if ('qvalues' in other_info):
        reported_qvalues = []

        for (raw_mdp_state, raw_action, qvalue) in other_info['qvalues']:
            mdp_state = pacai.core.mdp.MDPStatePosition.from_dict(raw_mdp_state)
            action = pacai.core.action.Action(raw_action)

            self._qvalues.setdefault(mdp_state, {})[action] = qvalue
            reported_qvalues.append(qvalue)

        # Same empty-input guard as for the state values above.
        if (len(reported_qvalues) > 0):
            self._minmax_qvalues = (min(reported_qvalues), max(reported_qvalues))
Indicate that the agents have been started, passing along each agent's response.
def game_complete(self) -> list[int]:
    """
    Report the winners once the game has ended.

    The single agent wins only if it moved onto a terminal position with a positive value
    (recorded in `self._win`); otherwise there are no winners.
    """

    return [AGENT_INDEX] if self._win else []
Indicate that the game has ended. The state should take any final actions and return the indexes of the winning agents (if any).
def get_legal_actions(self, position: pacai.core.board.Position | None = None) -> list[pacai.core.action.Action]:
    """
    List the actions the agent may take.

    If `position` is None, the current agent's position is used.
    On a terminal position the only legal action is pacai.core.mdp.ACTION_EXIT.
    Everywhere else, including when no position is known, the cardinal directions plus STOP are legal.
    """

    gridworld_board = typing.cast(pacai.gridworld.board.Board, self.board)

    if (position is None):
        position = self.get_agent_position()

    on_terminal = (position is not None) and gridworld_board.is_terminal_position(position)
    if (not on_terminal):
        return pacai.core.action.CARDINAL_DIRECTIONS + [pacai.core.action.STOP]

    return [pacai.core.mdp.ACTION_EXIT]
Get the moves that the current agent is allowed to make. Stopping is generally considered a legal action (unless a game redefines this behavior).
If a position is provided, it will override the current agent's position.
def sprite_lookup(self,
        sprite_sheet: pacai.core.spritesheet.SpriteSheet,
        position: pacai.core.board.Position,
        marker: pacai.core.board.Marker | None = None,
        action: pacai.core.action.Action | None = None,
        adjacency: pacai.core.board.AdjacencyString | None = None,
        animation_key: str | None = None,
        ) -> PIL.Image.Image:
    """
    Get the sprite for a position, decorating the extended Q-display markers.

    The base sprite comes from the parent implementation. Then:
     - value-display markers get the MDP state value and policy drawn on them,
     - Q-value-display markers get the Q-values drawn on them,
     - terminal markers outside the original board area get their terminal value drawn on them.
    Any other sprite is returned unchanged.
    """

    gridworld_board = typing.cast(pacai.gridworld.board.Board, self.board)

    base_sprite = super().sprite_lookup(sprite_sheet, position, marker = marker, action = action, adjacency = adjacency, animation_key = animation_key)

    if (marker == pacai.gridworld.board.MARKER_DISPLAY_VALUE):
        return self._add_mdp_state_value_sprite_info(sprite_sheet, base_sprite, position)

    if (marker == pacai.gridworld.board.MARKER_DISPLAY_QVALUE):
        return self._add_qvalue_sprite_info(sprite_sheet, base_sprite, position)

    if (marker == pacai.gridworld.board.MARKER_TERMINAL):
        # Only terminals copied into the extended display (past the original rows/cols) get decorated.
        if ((position.row > gridworld_board._original_height) or (position.col > gridworld_board._original_width)):
            return self._add_terminal_sprite_info(sprite_sheet, base_sprite, position)

    return base_sprite
Look up the proper sprite for a situation. By default this just calls into the sprite sheet, but children may override it for more expressive functionality.
def skip_draw(self,
        marker: pacai.core.board.Marker,
        position: pacai.core.board.Position,
        static: bool = False,
        ) -> bool:
    """
    Decide whether drawing this marker at this position should be skipped.

    Static drawing never skips anything.
    Dynamic drawing skips every position outside the original board area,
    i.e. any position whose row or column is at or past the original height/width.
    """

    if (not static):
        gridworld_board = typing.cast(pacai.gridworld.board.Board, self.board)
        inside_original = (position.row < gridworld_board._original_height) and (position.col < gridworld_board._original_width)
        return not inside_original

    return False
Return True if this marker/position combination should not be drawn on the board.
def get_static_positions(self) -> list[pacai.core.board.Position]:
    """
    List the positions that are drawn statically.

    These are all board positions outside the original board area
    (row at or past the original height, or column at or past the original width),
    in row-major order.
    """

    gridworld_board = typing.cast(pacai.gridworld.board.Board, self.board)
    original_height = gridworld_board._original_height
    original_width = gridworld_board._original_width

    return [
        pacai.core.board.Position(row, col)
        for row in range(gridworld_board.height)
        for col in range(gridworld_board.width)
        if (row >= original_height) or (col >= original_width)
    ]
Get a list of positions to draw on the board statically.
def process_turn(self,
        action: pacai.core.action.Action,
        rng: random.Random | None = None,
        mdp: pacai.gridworld.mdp.GridWorldMDP | None = None,
        **kwargs: typing.Any) -> None:
    """
    Advance the game by one agent move.

    The requested action is resolved through the MDP's transitions.
    If the only transition is pacai.core.mdp.ACTION_EXIT, the game is marked over and nothing else changes.
    Otherwise one transition is sampled with `rng` via _choose_transition(), its reward is added to the score,
    the agent marker is moved to the transition's position,
    and the win flag is set if that position is a terminal with a positive value.

    If no RNG is given, a warning is logged and a generator seeded with 4 is used.
    Raises ValueError if no MDP is given or the agent is no longer on the board.
    """

    if (rng is None):
        logging.warning("No RNG passed to pacai.gridworld.gamestate.GameState.process_turn().")
        rng = random.Random(4)

    if (mdp is None):
        raise ValueError("No MDP passed to pacai.gridworld.gamestate.GameState.process_turn().")

    gridworld_board = typing.cast(pacai.gridworld.board.Board, self.board)

    # NOTE(review): transitions are computed from mdp.get_starting_state(), not from this state's agent position.
    # This presumably relies on the MDP tracking the live board; confirm against GridWorldMDP.
    possible_transitions = mdp.get_transitions(mdp.get_starting_state(), action)

    exit_only = (len(possible_transitions) == 1) and (possible_transitions[0].action == pacai.core.mdp.ACTION_EXIT)
    if (exit_only):
        logging.debug("Got the %s action, game is over.", pacai.core.mdp.ACTION_EXIT)
        self.game_over = True
        return

    chosen = self._choose_transition(possible_transitions, rng)

    # The reward is applied before the agent-presence check, matching the established ordering.
    self.score += chosen.reward

    previous_position = self.get_agent_position(AGENT_INDEX)
    if (previous_position is None):
        raise ValueError("GridWorld agent was removed from board.")

    destination = chosen.state.position

    if (previous_position != destination):
        gridworld_board.remove_marker(pacai.gridworld.board.AGENT_MARKER, previous_position)
        gridworld_board.place_marker(pacai.gridworld.board.AGENT_MARKER, destination)

    # A terminal's reward arrives on the move onto it (before EXIT), so the win is decided here.
    landed_on_winner = gridworld_board.is_terminal_position(destination) and (gridworld_board.get_terminal_value(destination) > 0)
    if (landed_on_winner):
        self._win = True

    logging.debug("Requested Action: '%s', Actual Action: '%s', Reward: %0.2f.", action, chosen.action, chosen.reward)
Process the current agent's turn with the given action. This may modify the current state. To get a copy of a potential successor state, use generate_successor().
def get_static_text(self) -> list[pacai.core.font.BoardText]:
    """
    Collect the text drawn statically on the board.

    Every terminal position is labelled with its terminal value, in small font.
    Labels outside the original board area are black; the rest keep the default color (None).
    When the board reports display_qvalues(), the extended Q-display text is appended.
    """

    gridworld_board = typing.cast(pacai.gridworld.board.Board, self.board)

    texts: list[pacai.core.font.BoardText] = []

    for (position, value) in gridworld_board._terminal_values.items():
        in_extended_area = (position.row > gridworld_board._original_height) or (position.col > gridworld_board._original_width)
        text_color = (0, 0, 0) if in_extended_area else None

        texts.append(pacai.core.font.BoardText(position, str(value), size = pacai.core.font.FontSize.SMALL, color = text_color))

    if (gridworld_board.display_qvalues()):
        texts.extend(self._get_qdisplay_static_text())

    return texts
Get any static text to display on board positions.
def to_dict(self) -> dict[str, typing.Any]:
    """
    Serialize this state to a dict.

    Extends the parent's serialization with the `_win` flag,
    and drops the agent-reported display data named in NON_SERIALIZED_FIELDS.
    """

    serialized = super().to_dict()
    serialized['_win'] = self._win

    # Agent-reported display data is not part of the serialized state.
    for field_name in NON_SERIALIZED_FIELDS:
        serialized.pop(field_name, None)

    return serialized
Return a dict that can be used to represent this object. If the dict is passed to from_dict(), an identical object should be reconstructed.
@classmethod
def from_dict(cls, data: dict[str, typing.Any]) -> typing.Any:
    """
    Rebuild a state from a dict produced by to_dict().

    The parent restores the common fields, then the `_win` flag is copied over.
    The `_win` key is required; a missing key raises KeyError.
    The agent-reported display data is not restored because to_dict() never stores it.
    """

    restored_state = super().from_dict(data)
    restored_state._win = data['_win']

    return restored_state
Return an instance of this subclass created using the given dict. If the dict came from to_dict(), the returned object should be identical to the original.
Inherited Members
- pacai.core.gamestate.GameState
- board
- seed
- agent_index
- last_agent_index
- game_over
- agent_actions
- score
- turn_count
- move_delays
- tickets
- copy
- game_start
- get_num_agents
- get_agent_indexes
- get_agent_positions
- get_agent_position
- get_agent_actions
- get_last_agent_action
- get_reverse_action
- generate_successor
- process_agent_timeout
- process_agent_crash
- process_game_timeout
- process_turn_full
- compute_move_delay
- get_next_agent_index
- get_nonstatic_text