pacai.gridworld.gamestate

  1import logging
  2import math
  3import random
  4import typing
  5
  6import PIL.Image
  7import PIL.ImageDraw
  8
  9import pacai.core.action
 10import pacai.core.agentaction
 11import pacai.core.gamestate
 12import pacai.core.board
 13import pacai.core.font
 14import pacai.core.spritesheet
 15import pacai.gridworld.board
 16import pacai.gridworld.mdp
 17
# GridWorld is single-agent, so every lookup into agent-indexed data uses this index.
AGENT_INDEX: int = 0
""" The fixed index of the only agent. """

# Each entry is one triangle: two corners of the unit square followed by the square's center (0.5, 0.5).
# The values are fractions that get scaled by the sprite's width and height when drawing.
QVALUE_TRIANGLE_POINT_OFFSETS: list[tuple[tuple[float, float], tuple[float, float], tuple[float, float]]] = [
    ((0.0, 0.0), (1.0, 0.0), (0.5, 0.5)),
    ((1.0, 0.0), (1.0, 1.0), (0.5, 0.5)),
    ((1.0, 1.0), (0.0, 1.0), (0.5, 0.5)),
    ((0.0, 1.0), (0.0, 0.0), (0.5, 0.5)),
]
"""
Offsets (as position dimensions) of the points for Q-Value triangles.
Indexes line up with pacai.core.action.CARDINAL_DIRECTIONS.
"""

# Also used as the number of pixels that triangle corners on the sprite's edge are pulled inwards.
TRIANGLE_WIDTH: int = 1
""" Width of the Q-Value triangle borders. """

# Names of the GameState members that only hold agent-computed display data.
# GameState.to_dict() removes these keys from the serialized output.
NON_SERIALIZED_FIELDS: list[str] = [
    '_mdp_state_values',
    '_minmax_mdp_state_values',
    '_policy',
    '_qvalues',
    '_minmax_qvalues',
]
 42
 43class GameState(pacai.core.gamestate.GameState):
 44    """ A game state specific to a standard GridWorld game. """
 45
    def __init__(self, **kwargs: typing.Any) -> None:
        """
        Create a GridWorld game state.

        All keyword arguments are forwarded unchanged to the parent constructor.
        The display-only members (state values, policy, Q-values, and their min/max ranges)
        start out empty (or with a default range of (-1.0, 1.0)),
        and are filled in by agents_game_start().
        """

        super().__init__(**kwargs)

        self._win: bool = False
        """ Keep track if the agent exited the game on a winning state. """

        self._mdp_state_values: dict[pacai.core.mdp.MDPStatePosition, float] = {}
        """
        The MDP state values computed by the agent.
        This member will not be serialized.
        """

        self._minmax_mdp_state_values: tuple[float, float] = (-1.0, 1.0)
        """
        The min and max MDP state values computed by the agent.
        This member will not be serialized.
        """

        self._policy: dict[pacai.core.mdp.MDPStatePosition, pacai.core.action.Action] = {}
        """
        The policy computed by the agent.
        This member will not be serialized.
        """

        self._qvalues: dict[pacai.core.mdp.MDPStatePosition, dict[pacai.core.action.Action, float]] = {}
        """
        The Q-values computed by the agent.
        This member will not be serialized.
        """

        self._minmax_qvalues: tuple[float, float] = (-1.0, 1.0)
        """
        The min and max Q-values computed by the agent.
        This member will not be serialized.
        """
 81
 82    def agents_game_start(self, agent_responses: dict[int, pacai.core.agentaction.AgentActionRecord]) -> None:
 83        if (AGENT_INDEX not in agent_responses):
 84            return
 85
 86        agent_action = agent_responses[AGENT_INDEX].agent_action
 87        if (agent_action is None):
 88            return
 89
 90        if ('mdp_state_values' in agent_action.other_info):
 91            for (raw_mdp_state, value) in agent_action.other_info['mdp_state_values']:
 92                mdp_state = pacai.core.mdp.MDPStatePosition.from_dict(raw_mdp_state)
 93                self._mdp_state_values[mdp_state] = value
 94
 95                min_value = min(self._mdp_state_values.values())
 96                max_value = max(self._mdp_state_values.values())
 97                self._minmax_mdp_state_values = (min_value, max_value)
 98
 99        if ('policy' in agent_action.other_info):
100            for (raw_mdp_state, raw_action) in agent_action.other_info['policy']:
101                mdp_state = pacai.core.mdp.MDPStatePosition.from_dict(raw_mdp_state)
102                action = pacai.core.action.Action(raw_action)
103                self._policy[mdp_state] = action
104
105        if ('qvalues' in agent_action.other_info):
106            values = []
107
108            for (raw_mdp_state, raw_action, qvalue) in agent_action.other_info['qvalues']:
109                mdp_state = pacai.core.mdp.MDPStatePosition.from_dict(raw_mdp_state)
110                action = pacai.core.action.Action(raw_action)
111
112                if (mdp_state not in self._qvalues):
113                    self._qvalues[mdp_state] = {}
114
115                self._qvalues[mdp_state][action] = qvalue
116                values.append(qvalue)
117
118            self._minmax_qvalues = (min(values), max(values))
119
120    def game_complete(self) -> list[int]:
121        # If the agent exited on a positive terminal position, they win.
122        if (self._win):
123            return [AGENT_INDEX]
124
125        return []
126
127    def get_legal_actions(self, position: pacai.core.board.Position | None = None) -> list[pacai.core.action.Action]:
128        board = typing.cast(pacai.gridworld.board.Board, self.board)
129
130        if (position is None):
131            position = self.get_agent_position()
132
133        # If we are on a terminal position, we can only exit.
134        if ((position is not None) and board.is_terminal_position(position)):
135            return [pacai.core.mdp.ACTION_EXIT]
136
137        # Otherwise, all cardinal directions (and STOP) are legal actions.
138        return pacai.core.action.CARDINAL_DIRECTIONS + [pacai.core.action.STOP]
139
140    def sprite_lookup(self,
141            sprite_sheet: pacai.core.spritesheet.SpriteSheet,
142            position: pacai.core.board.Position,
143            marker: pacai.core.board.Marker | None = None,
144            action: pacai.core.action.Action | None = None,
145            adjacency: pacai.core.board.AdjacencyString | None = None,
146            animation_key: str | None = None,
147            ) -> PIL.Image.Image:
148        board = typing.cast(pacai.gridworld.board.Board, self.board)
149
150        sprite = super().sprite_lookup(sprite_sheet, position, marker = marker, action = action, adjacency = adjacency, animation_key = animation_key)
151
152        if (marker == pacai.gridworld.board.MARKER_DISPLAY_VALUE):
153            # Draw MDP state values and policies.
154            sprite = self._add_mdp_state_value_sprite_info(sprite_sheet, sprite, position)
155        elif (marker == pacai.gridworld.board.MARKER_DISPLAY_QVALUE):
156            # Draw Q-values.
157            sprite = self._add_qvalue_sprite_info(sprite_sheet, sprite, position)
158        elif ((marker == pacai.gridworld.board.MARKER_TERMINAL)
159                and ((position.row > board._original_height) or (position.col > board._original_width))):
160            # Draw terminal values on the extended q-display.
161            sprite = self._add_terminal_sprite_info(sprite_sheet, sprite, position)
162
163        return sprite
164
165    def skip_draw(self,
166            marker: pacai.core.board.Marker,
167            position: pacai.core.board.Position,
168            static: bool = False,
169            ) -> bool:
170        if (static):
171            return False
172
173        board = typing.cast(pacai.gridworld.board.Board, self.board)
174        return (position.row >= board._original_height) or (position.col >= board._original_width)
175
176    def get_static_positions(self) -> list[pacai.core.board.Position]:
177        board = typing.cast(pacai.gridworld.board.Board, self.board)
178
179        positions = []
180        for row in range(board.height):
181            for col in range(board.width):
182                if (row >= board._original_height) or (col >= board._original_width):
183                    positions.append(pacai.core.board.Position(row, col))
184
185        return positions
186
187    def _add_terminal_sprite_info(self,
188            sprite_sheet: pacai.core.spritesheet.SpriteSheet,
189            sprite: PIL.Image.Image,
190            position: pacai.core.board.Position) -> PIL.Image.Image:
191        """ Add coloring to the terminal positions. """
192
193        board = typing.cast(pacai.gridworld.board.Board, self.board)
194
195        if (not board.is_terminal_position(position)):
196            return sprite
197
198        sprite = sprite.copy()
199        canvas = PIL.ImageDraw.Draw(sprite)
200
201        min_value = min(board._terminal_values.values())
202        max_value = max(board._terminal_values.values())
203
204        value = board.get_terminal_value(position)
205        color = self._red_green_gradient(value, min_value, max_value)
206
207        points = [(1, 1), (sprite_sheet.width - 1, sprite_sheet.height - 1)]
208        canvas.rectangle(points, fill = color, outline = sprite_sheet.text, width = 1)
209
210        return sprite
211
212    def _add_mdp_state_value_sprite_info(self,
213            sprite_sheet: pacai.core.spritesheet.SpriteSheet,
214            sprite: PIL.Image.Image,
215            position: pacai.core.board.Position) -> PIL.Image.Image:
216        """ Add the colored q-value triangles to the sprite. """
217
218        board = typing.cast(pacai.gridworld.board.Board, self.board)
219
220        sprite = sprite.copy()
221        canvas = PIL.ImageDraw.Draw(sprite)
222
223        # The offset from the visualization position to the true board position.
224        # The MDP state values are to the right of the true board.
225        base_offset = pacai.core.board.Position(0, -(board._original_width + 1))
226        base_position = position.add(base_offset)
227        mdp_state = pacai.core.mdp.MDPStatePosition(position = base_position)
228
229        value = self._mdp_state_values.get(mdp_state, 0.0)
230        color = self._red_green_gradient(value, self._minmax_mdp_state_values[0], self._minmax_mdp_state_values[1])
231
232        points = [(1, 1), (sprite_sheet.width - 1, sprite_sheet.height - 1)]
233        canvas.rectangle(points, fill = color, outline = sprite_sheet.text, width = 1)
234
235        return sprite
236
237    def _add_qvalue_sprite_info(self,
238            sprite_sheet: pacai.core.spritesheet.SpriteSheet,
239            sprite: PIL.Image.Image,
240            position: pacai.core.board.Position) -> PIL.Image.Image:
241        """ Add the colored q-value triangles to the sprite. """
242
243        board = typing.cast(pacai.gridworld.board.Board, self.board)
244
245        sprite = sprite.copy()
246        canvas = PIL.ImageDraw.Draw(sprite)
247
248        # The offset from the visualization position to the true board position.
249        # The Q-values are below the true board.
250        base_offset = pacai.core.board.Position(-(board._original_height + 1), 0)
251        base_position = position.add(base_offset)
252        mdp_state = pacai.core.mdp.MDPStatePosition(position = base_position)
253
254        for (direction_index, point_offsets) in enumerate(QVALUE_TRIANGLE_POINT_OFFSETS):
255            points = []
256
257            for point_offset in point_offsets:
258                # Offset the outer points of the triangle towards the inside of the triangle to avoid border overlaps.
259                origin = [0, 0]
260                for (i, offset) in enumerate(point_offset):
261                    if (math.isclose(offset, 0.0)):
262                        origin[i] = TRIANGLE_WIDTH
263                    elif (math.isclose(offset, 1.0)):
264                        origin[i] = -TRIANGLE_WIDTH
265
266                point = (
267                    (origin[0] + (sprite_sheet.width * point_offset[0])),
268                    (origin[1] + (sprite_sheet.height * point_offset[1])),
269                )
270                points.append(tuple(point))
271
272            qvalue = self._qvalues.get(mdp_state, {}).get(pacai.core.action.CARDINAL_DIRECTIONS[direction_index], 0.0)
273            color = self._red_green_gradient(qvalue, self._minmax_qvalues[0], self._minmax_qvalues[1])
274            canvas.polygon(points, fill = color, outline = sprite_sheet.text, width = 1)
275
276        return sprite
277
278    def _red_green_gradient(self, value: float,
279            min_value: float, max_value: float, divider: float = 0.0,
280            blue_intensity: float = 0.00) -> tuple[int, int, int]:
281        """
282        Get a color (RGB) between red (min) and green (max) based on the given value.
283        Values under the divider will always be red, and values over will always be green.
284        """
285
286        if ((min_value > divider) or (divider > max_value)):
287            raise ValueError(("Gradient values are not in the correct order."
288                    + f"Found: min = {min_value}, divider = {divider}, max = {max_value}."))
289
290        red_intensity = 0.0
291        green_intensity = 0.0
292
293        red_mass = max(0.01, divider - min_value)
294        green_mass = max(0.01, max_value - divider)
295
296        value = min(max_value, max(min_value, value))
297
298        if (math.isclose(value, divider)):
299            blue_intensity = 1.0
300            red_intensity = 0.25
301            green_intensity = 0.25
302        elif (value < divider):
303            red_intensity = 0.25 + (0.75 * (divider - value) / red_mass)
304        else:
305            green_intensity = 0.25 + (0.75 * (value - divider) / green_mass)
306
307        return (int(255 * red_intensity), int(255 * green_intensity), int(255 * blue_intensity))
308
    def process_turn(self,
            action: pacai.core.action.Action,
            rng: random.Random | None = None,
            mdp: pacai.gridworld.mdp.GridWorldMDP | None = None,
            **kwargs: typing.Any) -> None:
        """
        Apply the agent's requested action to this state using the transitions supplied by the MDP.

        If no RNG is given, a warning is logged and an RNG with a fixed seed is used.
        A ValueError is raised if no MDP is given,
        or if the agent's position cannot be found when applying the transition.

        When the MDP only offers a single EXIT transition, the game is marked as over and nothing else changes.
        Otherwise, a transition is randomly chosen (weighted by probability) and applied:
        the transition's reward is added to the score,
        the agent's marker is moved if the position changed,
        and the win flag is set if the new position is a terminal position with a positive value.
        """

        if (rng is None):
            logging.warning("No RNG passed to pacai.gridworld.gamestate.GameState.process_turn().")
            rng = random.Random(4)

        if (mdp is None):
            raise ValueError("No MDP passed to pacai.gridworld.gamestate.GameState.process_turn().")

        board = typing.cast(pacai.gridworld.board.Board, self.board)

        # Get the possible transitions from the MDP.
        # NOTE(review): The transitions are always requested from the MDP's starting state.
        # Presumably the caller supplies an MDP whose starting state is the agent's current state, confirm against the caller.
        transitions = mdp.get_transitions(mdp.get_starting_state(), action)

        # If there is only an exit transition, exit.
        # Note that the score and board are left untouched in this case.
        if ((len(transitions) == 1) and (transitions[0].action == pacai.core.mdp.ACTION_EXIT)):
            logging.debug("Got the %s action, game is over.", pacai.core.mdp.ACTION_EXIT)
            self.game_over = True
            return

        # Choose a transition.
        transition = self._choose_transition(transitions, rng)

        # Apply the transition.

        # The reward is added before the agent's position is validated,
        # so the score has already changed if the error below is raised.
        self.score += transition.reward

        old_position = self.get_agent_position(AGENT_INDEX)
        if (old_position is None):
            raise ValueError("GridWorld agent was removed from board.")

        new_position = transition.state.position

        # Only touch the board if the agent actually moved.
        if (old_position != new_position):
            board.remove_marker(pacai.gridworld.board.AGENT_MARKER, old_position)
            board.place_marker(pacai.gridworld.board.AGENT_MARKER, new_position)

        # Check if we are going to "win".
        # The reward for a terminal state is awarded on the action before EXIT.
        if (board.is_terminal_position(new_position) and (board.get_terminal_value(new_position) > 0)):
            self._win = True

        logging.debug("Requested Action: '%s', Actual Action: '%s', Reward: %0.2f.", action, transition.action, transition.reward)
355
356    def _choose_transition(self,
357            transitions: list[pacai.core.mdp.Transition],
358            rng: random.Random) -> pacai.core.mdp.Transition:
359        probability_sum = 0.0
360        point = rng.random()
361
362        for transition in transitions:
363            probability_sum += transition.probability
364            if (probability_sum > 1.0):
365                raise ValueError(f"Transition probabilities is over 1.0, found at least {probability_sum}.")
366
367            if (point < probability_sum):
368                return transition
369
370        raise ValueError(f"Transition probabilities is less than 1.0, found {probability_sum}.")
371
372    def get_static_text(self) -> list[pacai.core.font.BoardText]:
373        board = typing.cast(pacai.gridworld.board.Board, self.board)
374
375        texts = []
376
377        # Add on terminal values.
378        for (position, value) in board._terminal_values.items():
379            text_color = None
380            if ((position.row > board._original_height) or (position.col > board._original_width)):
381                text_color = (0, 0, 0)
382
383            texts.append(pacai.core.font.BoardText(position, str(value), size = pacai.core.font.FontSize.SMALL, color = text_color))
384
385        # If we are using the extended display, fill in all the information.
386        if (board.display_qvalues()):
387            texts += self._get_qdisplay_static_text()
388
389        return texts
390
391    def _get_qdisplay_static_text(self) -> list[pacai.core.font.BoardText]:
392        texts = []
393
394        board = typing.cast(pacai.gridworld.board.Board, self.board)
395
396        # Add labels on the separator.
397        row = board._original_height
398        texts.append(pacai.core.font.BoardText(pacai.core.board.Position(row, 1), ' ↓ Q-Values'))
399        texts.append(pacai.core.font.BoardText(pacai.core.board.Position(row, board._original_width + 3), '  ↑ Values & Policy'))
400
401        # Add text for MDP state values and policies.
402        texts += self._get_qdisplay_static_text_mdp_state_values()
403
404        # Add text for Q-values.
405        texts += self._get_qdisplay_static_text_qvalues()
406
407        return texts
408
409    def _get_qdisplay_static_text_mdp_state_values(self) -> list[pacai.core.font.BoardText]:
410        texts = []
411
412        board = typing.cast(pacai.gridworld.board.Board, self.board)
413
414        # The offset from the visualization position to the true board position.
415        # The MDP state values are to the right of the true board.
416        base_offset = pacai.core.board.Position(0, -(board._original_width + 1))
417
418        for position in self.board.get_marker_positions(pacai.gridworld.board.MARKER_DISPLAY_VALUE):
419            base_position = position.add(base_offset)
420            mdp_state = pacai.core.mdp.MDPStatePosition(position = base_position)
421
422            value = self._mdp_state_values.get(mdp_state, None)
423            policy_action = self._policy.get(mdp_state, pacai.core.action.STOP)
424
425            value_text = '?'
426            if (value is not None):
427                value_text = f"{value:0.2f}"
428
429            policy_text = pacai.core.action.CARDINAL_DIRECTION_ARROWS.get(policy_action, '?')
430
431            text = f"{value_text}\n{policy_text}"
432
433            texts.append(pacai.core.font.BoardText(position, text,
434                    size = pacai.core.font.FontSize.SMALL,
435                    color = (0, 0, 0)))
436
437        return texts
438
439    def _get_qdisplay_static_text_qvalues(self) -> list[pacai.core.font.BoardText]:
440        texts = []
441
442        board = typing.cast(pacai.gridworld.board.Board, self.board)
443
444        # The offset from the visualization position to the true board position.
445        # The Q-values are below the true board.
446        base_offset = pacai.core.board.Position(-(board._original_height + 1), 0)
447
448        for position in self.board.get_marker_positions(pacai.gridworld.board.MARKER_DISPLAY_QVALUE):
449            base_position = position.add(base_offset)
450            mdp_state = pacai.core.mdp.MDPStatePosition(position = base_position)
451
452            # [(vertical alignment, horizontal alignment), ...]
453            alignments = [
454                (pacai.core.font.TextVerticalAlign.TOP, pacai.core.font.TextHorizontalAlign.CENTER),
455                (pacai.core.font.TextVerticalAlign.MIDDLE, pacai.core.font.TextHorizontalAlign.RIGHT),
456                (pacai.core.font.TextVerticalAlign.BOTTOM, pacai.core.font.TextHorizontalAlign.CENTER),
457                (pacai.core.font.TextVerticalAlign.MIDDLE, pacai.core.font.TextHorizontalAlign.LEFT),
458            ]
459
460            for (i, alignment) in enumerate(alignments):
461                action = pacai.core.action.CARDINAL_DIRECTIONS[i]
462                qvalue = self._qvalues.get(mdp_state, {}).get(action, None)
463
464                text = '?'
465                if (qvalue is not None):
466                    text = f"{qvalue:0.2f}"
467
468                vertical_align, horizontal_align = alignment
469                texts.append(pacai.core.font.BoardText(
470                        position, text,
471                        size = pacai.core.font.FontSize.TINY,
472                        vertical_align = vertical_align,
473                        horizontal_align = horizontal_align,
474                        color = (0, 0, 0)))
475
476        return texts
477
478    def to_dict(self) -> dict[str, typing.Any]:
479        data = super().to_dict()
480        data['_win'] = self._win
481
482        for key in NON_SERIALIZED_FIELDS:
483            if (key in data):
484                del data[key]
485
486        return data
487
488    @classmethod
489    def from_dict(cls, data: dict[str, typing.Any]) -> typing.Any:
490        game_state = super().from_dict(data)
491        game_state._win = data['_win']
492        return game_state
AGENT_INDEX: int = 0

The fixed index of the only agent.

QVALUE_TRIANGLE_POINT_OFFSETS: list[tuple[tuple[float, float], tuple[float, float], tuple[float, float]]] = [((0.0, 0.0), (1.0, 0.0), (0.5, 0.5)), ((1.0, 0.0), (1.0, 1.0), (0.5, 0.5)), ((1.0, 1.0), (0.0, 1.0), (0.5, 0.5)), ((0.0, 1.0), (0.0, 0.0), (0.5, 0.5))]

Offsets (as position dimensions) of the points for Q-Value triangles. Indexes line up with pacai.core.action.CARDINAL_DIRECTIONS.

TRIANGLE_WIDTH: int = 1

Width of the Q-Value triangle borders.

NON_SERIALIZED_FIELDS: list[str] = ['_mdp_state_values', '_minmax_mdp_state_values', '_policy', '_qvalues', '_minmax_qvalues']
class GameState(pacai.core.gamestate.GameState):
 44class GameState(pacai.core.gamestate.GameState):
 45    """ A game state specific to a standard GridWorld game. """
 46
 47    def __init__(self, **kwargs: typing.Any) -> None:
 48        super().__init__(**kwargs)
 49
 50        self._win: bool = False
 51        """ Keep track if the agent exited the game on a winning state. """
 52
 53        self._mdp_state_values: dict[pacai.core.mdp.MDPStatePosition, float] = {}
 54        """
 55        The MDP state values computed by the agent.
 56        This member will not be serialized.
 57        """
 58
 59        self._minmax_mdp_state_values: tuple[float, float] = (-1.0, 1.0)
 60        """
 61        The min and max MDP state values computed by the agent.
 62        This member will not be serialized.
 63        """
 64
 65        self._policy: dict[pacai.core.mdp.MDPStatePosition, pacai.core.action.Action] = {}
 66        """
 67        The policy computed by the agent.
 68        This member will not be serialized.
 69        """
 70
 71        self._qvalues: dict[pacai.core.mdp.MDPStatePosition, dict[pacai.core.action.Action, float]] = {}
 72        """
 73        The Q-values computed by the agent.
 74        This member will not be serialized.
 75        """
 76
 77        self._minmax_qvalues: tuple[float, float] = (-1.0, 1.0)
 78        """
 79        The min and max Q-values computed by the agent.
 80        This member will not be serialized.
 81        """
 82
 83    def agents_game_start(self, agent_responses: dict[int, pacai.core.agentaction.AgentActionRecord]) -> None:
 84        if (AGENT_INDEX not in agent_responses):
 85            return
 86
 87        agent_action = agent_responses[AGENT_INDEX].agent_action
 88        if (agent_action is None):
 89            return
 90
 91        if ('mdp_state_values' in agent_action.other_info):
 92            for (raw_mdp_state, value) in agent_action.other_info['mdp_state_values']:
 93                mdp_state = pacai.core.mdp.MDPStatePosition.from_dict(raw_mdp_state)
 94                self._mdp_state_values[mdp_state] = value
 95
 96                min_value = min(self._mdp_state_values.values())
 97                max_value = max(self._mdp_state_values.values())
 98                self._minmax_mdp_state_values = (min_value, max_value)
 99
100        if ('policy' in agent_action.other_info):
101            for (raw_mdp_state, raw_action) in agent_action.other_info['policy']:
102                mdp_state = pacai.core.mdp.MDPStatePosition.from_dict(raw_mdp_state)
103                action = pacai.core.action.Action(raw_action)
104                self._policy[mdp_state] = action
105
106        if ('qvalues' in agent_action.other_info):
107            values = []
108
109            for (raw_mdp_state, raw_action, qvalue) in agent_action.other_info['qvalues']:
110                mdp_state = pacai.core.mdp.MDPStatePosition.from_dict(raw_mdp_state)
111                action = pacai.core.action.Action(raw_action)
112
113                if (mdp_state not in self._qvalues):
114                    self._qvalues[mdp_state] = {}
115
116                self._qvalues[mdp_state][action] = qvalue
117                values.append(qvalue)
118
119            self._minmax_qvalues = (min(values), max(values))
120
121    def game_complete(self) -> list[int]:
122        # If the agent exited on a positive terminal position, they win.
123        if (self._win):
124            return [AGENT_INDEX]
125
126        return []
127
128    def get_legal_actions(self, position: pacai.core.board.Position | None = None) -> list[pacai.core.action.Action]:
129        board = typing.cast(pacai.gridworld.board.Board, self.board)
130
131        if (position is None):
132            position = self.get_agent_position()
133
134        # If we are on a terminal position, we can only exit.
135        if ((position is not None) and board.is_terminal_position(position)):
136            return [pacai.core.mdp.ACTION_EXIT]
137
138        # Otherwise, all cardinal directions (and STOP) are legal actions.
139        return pacai.core.action.CARDINAL_DIRECTIONS + [pacai.core.action.STOP]
140
141    def sprite_lookup(self,
142            sprite_sheet: pacai.core.spritesheet.SpriteSheet,
143            position: pacai.core.board.Position,
144            marker: pacai.core.board.Marker | None = None,
145            action: pacai.core.action.Action | None = None,
146            adjacency: pacai.core.board.AdjacencyString | None = None,
147            animation_key: str | None = None,
148            ) -> PIL.Image.Image:
149        board = typing.cast(pacai.gridworld.board.Board, self.board)
150
151        sprite = super().sprite_lookup(sprite_sheet, position, marker = marker, action = action, adjacency = adjacency, animation_key = animation_key)
152
153        if (marker == pacai.gridworld.board.MARKER_DISPLAY_VALUE):
154            # Draw MDP state values and policies.
155            sprite = self._add_mdp_state_value_sprite_info(sprite_sheet, sprite, position)
156        elif (marker == pacai.gridworld.board.MARKER_DISPLAY_QVALUE):
157            # Draw Q-values.
158            sprite = self._add_qvalue_sprite_info(sprite_sheet, sprite, position)
159        elif ((marker == pacai.gridworld.board.MARKER_TERMINAL)
160                and ((position.row > board._original_height) or (position.col > board._original_width))):
161            # Draw terminal values on the extended q-display.
162            sprite = self._add_terminal_sprite_info(sprite_sheet, sprite, position)
163
164        return sprite
165
166    def skip_draw(self,
167            marker: pacai.core.board.Marker,
168            position: pacai.core.board.Position,
169            static: bool = False,
170            ) -> bool:
171        if (static):
172            return False
173
174        board = typing.cast(pacai.gridworld.board.Board, self.board)
175        return (position.row >= board._original_height) or (position.col >= board._original_width)
176
177    def get_static_positions(self) -> list[pacai.core.board.Position]:
178        board = typing.cast(pacai.gridworld.board.Board, self.board)
179
180        positions = []
181        for row in range(board.height):
182            for col in range(board.width):
183                if (row >= board._original_height) or (col >= board._original_width):
184                    positions.append(pacai.core.board.Position(row, col))
185
186        return positions
187
188    def _add_terminal_sprite_info(self,
189            sprite_sheet: pacai.core.spritesheet.SpriteSheet,
190            sprite: PIL.Image.Image,
191            position: pacai.core.board.Position) -> PIL.Image.Image:
192        """ Add coloring to the terminal positions. """
193
194        board = typing.cast(pacai.gridworld.board.Board, self.board)
195
196        if (not board.is_terminal_position(position)):
197            return sprite
198
199        sprite = sprite.copy()
200        canvas = PIL.ImageDraw.Draw(sprite)
201
202        min_value = min(board._terminal_values.values())
203        max_value = max(board._terminal_values.values())
204
205        value = board.get_terminal_value(position)
206        color = self._red_green_gradient(value, min_value, max_value)
207
208        points = [(1, 1), (sprite_sheet.width - 1, sprite_sheet.height - 1)]
209        canvas.rectangle(points, fill = color, outline = sprite_sheet.text, width = 1)
210
211        return sprite
212
213    def _add_mdp_state_value_sprite_info(self,
214            sprite_sheet: pacai.core.spritesheet.SpriteSheet,
215            sprite: PIL.Image.Image,
216            position: pacai.core.board.Position) -> PIL.Image.Image:
217        """ Add the colored q-value triangles to the sprite. """
218
219        board = typing.cast(pacai.gridworld.board.Board, self.board)
220
221        sprite = sprite.copy()
222        canvas = PIL.ImageDraw.Draw(sprite)
223
224        # The offset from the visualization position to the true board position.
225        # The MDP state values are to the right of the true board.
226        base_offset = pacai.core.board.Position(0, -(board._original_width + 1))
227        base_position = position.add(base_offset)
228        mdp_state = pacai.core.mdp.MDPStatePosition(position = base_position)
229
230        value = self._mdp_state_values.get(mdp_state, 0.0)
231        color = self._red_green_gradient(value, self._minmax_mdp_state_values[0], self._minmax_mdp_state_values[1])
232
233        points = [(1, 1), (sprite_sheet.width - 1, sprite_sheet.height - 1)]
234        canvas.rectangle(points, fill = color, outline = sprite_sheet.text, width = 1)
235
236        return sprite
237
238    def _add_qvalue_sprite_info(self,
239            sprite_sheet: pacai.core.spritesheet.SpriteSheet,
240            sprite: PIL.Image.Image,
241            position: pacai.core.board.Position) -> PIL.Image.Image:
242        """ Add the colored q-value triangles to the sprite. """
243
244        board = typing.cast(pacai.gridworld.board.Board, self.board)
245
246        sprite = sprite.copy()
247        canvas = PIL.ImageDraw.Draw(sprite)
248
249        # The offset from the visualization position to the true board position.
250        # The Q-values are below the true board.
251        base_offset = pacai.core.board.Position(-(board._original_height + 1), 0)
252        base_position = position.add(base_offset)
253        mdp_state = pacai.core.mdp.MDPStatePosition(position = base_position)
254
255        for (direction_index, point_offsets) in enumerate(QVALUE_TRIANGLE_POINT_OFFSETS):
256            points = []
257
258            for point_offset in point_offsets:
259                # Offset the outer points of the triangle towards the inside of the triangle to avoid border overlaps.
260                origin = [0, 0]
261                for (i, offset) in enumerate(point_offset):
262                    if (math.isclose(offset, 0.0)):
263                        origin[i] = TRIANGLE_WIDTH
264                    elif (math.isclose(offset, 1.0)):
265                        origin[i] = -TRIANGLE_WIDTH
266
267                point = (
268                    (origin[0] + (sprite_sheet.width * point_offset[0])),
269                    (origin[1] + (sprite_sheet.height * point_offset[1])),
270                )
271                points.append(tuple(point))
272
273            qvalue = self._qvalues.get(mdp_state, {}).get(pacai.core.action.CARDINAL_DIRECTIONS[direction_index], 0.0)
274            color = self._red_green_gradient(qvalue, self._minmax_qvalues[0], self._minmax_qvalues[1])
275            canvas.polygon(points, fill = color, outline = sprite_sheet.text, width = 1)
276
277        return sprite
278
279    def _red_green_gradient(self, value: float,
280            min_value: float, max_value: float, divider: float = 0.0,
281            blue_intensity: float = 0.00) -> tuple[int, int, int]:
282        """
283        Get a color (RGB) between red (min) and green (max) based on the given value.
284        Values under the divider will always be red, and values over will always be green.
285        """
286
287        if ((min_value > divider) or (divider > max_value)):
288            raise ValueError(("Gradient values are not in the correct order."
289                    + f"Found: min = {min_value}, divider = {divider}, max = {max_value}."))
290
291        red_intensity = 0.0
292        green_intensity = 0.0
293
294        red_mass = max(0.01, divider - min_value)
295        green_mass = max(0.01, max_value - divider)
296
297        value = min(max_value, max(min_value, value))
298
299        if (math.isclose(value, divider)):
300            blue_intensity = 1.0
301            red_intensity = 0.25
302            green_intensity = 0.25
303        elif (value < divider):
304            red_intensity = 0.25 + (0.75 * (divider - value) / red_mass)
305        else:
306            green_intensity = 0.25 + (0.75 * (value - divider) / green_mass)
307
308        return (int(255 * red_intensity), int(255 * green_intensity), int(255 * blue_intensity))
309
310    def process_turn(self,
311            action: pacai.core.action.Action,
312            rng: random.Random | None = None,
313            mdp: pacai.gridworld.mdp.GridWorldMDP | None = None,
314            **kwargs: typing.Any) -> None:
315        if (rng is None):
316            logging.warning("No RNG passed to pacai.gridworld.gamestate.GameState.process_turn().")
317            rng = random.Random(4)
318
319        if (mdp is None):
320            raise ValueError("No MDP passed to pacai.gridworld.gamestate.GameState.process_turn().")
321
322        board = typing.cast(pacai.gridworld.board.Board, self.board)
323
324        # Get the possible transitions from the MDP.
325        transitions = mdp.get_transitions(mdp.get_starting_state(), action)
326
327        # If there is only an exit transition, exit.
328        if ((len(transitions) == 1) and (transitions[0].action == pacai.core.mdp.ACTION_EXIT)):
329            logging.debug("Got the %s action, game is over.", pacai.core.mdp.ACTION_EXIT)
330            self.game_over = True
331            return
332
333        # Choose a transition.
334        transition = self._choose_transition(transitions, rng)
335
336        # Apply the transition.
337
338        self.score += transition.reward
339
340        old_position = self.get_agent_position(AGENT_INDEX)
341        if (old_position is None):
342            raise ValueError("GridWorld agent was removed from board.")
343
344        new_position = transition.state.position
345
346        if (old_position != new_position):
347            board.remove_marker(pacai.gridworld.board.AGENT_MARKER, old_position)
348            board.place_marker(pacai.gridworld.board.AGENT_MARKER, new_position)
349
350        # Check if we are going to "win".
351        # The reward for a terminal state is awarded on the action before EXIT.
352        if (board.is_terminal_position(new_position) and (board.get_terminal_value(new_position) > 0)):
353            self._win = True
354
355        logging.debug("Requested Action: '%s', Actual Action: '%s', Reward: %0.2f.", action, transition.action, transition.reward)
356
357    def _choose_transition(self,
358            transitions: list[pacai.core.mdp.Transition],
359            rng: random.Random) -> pacai.core.mdp.Transition:
360        probability_sum = 0.0
361        point = rng.random()
362
363        for transition in transitions:
364            probability_sum += transition.probability
365            if (probability_sum > 1.0):
366                raise ValueError(f"Transition probabilities is over 1.0, found at least {probability_sum}.")
367
368            if (point < probability_sum):
369                return transition
370
371        raise ValueError(f"Transition probabilities is less than 1.0, found {probability_sum}.")
372
373    def get_static_text(self) -> list[pacai.core.font.BoardText]:
374        board = typing.cast(pacai.gridworld.board.Board, self.board)
375
376        texts = []
377
378        # Add on terminal values.
379        for (position, value) in board._terminal_values.items():
380            text_color = None
381            if ((position.row > board._original_height) or (position.col > board._original_width)):
382                text_color = (0, 0, 0)
383
384            texts.append(pacai.core.font.BoardText(position, str(value), size = pacai.core.font.FontSize.SMALL, color = text_color))
385
386        # If we are using the extended display, fill in all the information.
387        if (board.display_qvalues()):
388            texts += self._get_qdisplay_static_text()
389
390        return texts
391
392    def _get_qdisplay_static_text(self) -> list[pacai.core.font.BoardText]:
393        texts = []
394
395        board = typing.cast(pacai.gridworld.board.Board, self.board)
396
397        # Add labels on the separator.
398        row = board._original_height
399        texts.append(pacai.core.font.BoardText(pacai.core.board.Position(row, 1), ' ↓ Q-Values'))
400        texts.append(pacai.core.font.BoardText(pacai.core.board.Position(row, board._original_width + 3), '  ↑ Values & Policy'))
401
402        # Add text for MDP state values and policies.
403        texts += self._get_qdisplay_static_text_mdp_state_values()
404
405        # Add text for Q-values.
406        texts += self._get_qdisplay_static_text_qvalues()
407
408        return texts
409
410    def _get_qdisplay_static_text_mdp_state_values(self) -> list[pacai.core.font.BoardText]:
411        texts = []
412
413        board = typing.cast(pacai.gridworld.board.Board, self.board)
414
415        # The offset from the visualization position to the true board position.
416        # The MDP state values are to the right of the true board.
417        base_offset = pacai.core.board.Position(0, -(board._original_width + 1))
418
419        for position in self.board.get_marker_positions(pacai.gridworld.board.MARKER_DISPLAY_VALUE):
420            base_position = position.add(base_offset)
421            mdp_state = pacai.core.mdp.MDPStatePosition(position = base_position)
422
423            value = self._mdp_state_values.get(mdp_state, None)
424            policy_action = self._policy.get(mdp_state, pacai.core.action.STOP)
425
426            value_text = '?'
427            if (value is not None):
428                value_text = f"{value:0.2f}"
429
430            policy_text = pacai.core.action.CARDINAL_DIRECTION_ARROWS.get(policy_action, '?')
431
432            text = f"{value_text}\n{policy_text}"
433
434            texts.append(pacai.core.font.BoardText(position, text,
435                    size = pacai.core.font.FontSize.SMALL,
436                    color = (0, 0, 0)))
437
438        return texts
439
440    def _get_qdisplay_static_text_qvalues(self) -> list[pacai.core.font.BoardText]:
441        texts = []
442
443        board = typing.cast(pacai.gridworld.board.Board, self.board)
444
445        # The offset from the visualization position to the true board position.
446        # The Q-values are below the true board.
447        base_offset = pacai.core.board.Position(-(board._original_height + 1), 0)
448
449        for position in self.board.get_marker_positions(pacai.gridworld.board.MARKER_DISPLAY_QVALUE):
450            base_position = position.add(base_offset)
451            mdp_state = pacai.core.mdp.MDPStatePosition(position = base_position)
452
453            # [(vertical alignment, horizontal alignment), ...]
454            alignments = [
455                (pacai.core.font.TextVerticalAlign.TOP, pacai.core.font.TextHorizontalAlign.CENTER),
456                (pacai.core.font.TextVerticalAlign.MIDDLE, pacai.core.font.TextHorizontalAlign.RIGHT),
457                (pacai.core.font.TextVerticalAlign.BOTTOM, pacai.core.font.TextHorizontalAlign.CENTER),
458                (pacai.core.font.TextVerticalAlign.MIDDLE, pacai.core.font.TextHorizontalAlign.LEFT),
459            ]
460
461            for (i, alignment) in enumerate(alignments):
462                action = pacai.core.action.CARDINAL_DIRECTIONS[i]
463                qvalue = self._qvalues.get(mdp_state, {}).get(action, None)
464
465                text = '?'
466                if (qvalue is not None):
467                    text = f"{qvalue:0.2f}"
468
469                vertical_align, horizontal_align = alignment
470                texts.append(pacai.core.font.BoardText(
471                        position, text,
472                        size = pacai.core.font.FontSize.TINY,
473                        vertical_align = vertical_align,
474                        horizontal_align = horizontal_align,
475                        color = (0, 0, 0)))
476
477        return texts
478
479    def to_dict(self) -> dict[str, typing.Any]:
480        data = super().to_dict()
481        data['_win'] = self._win
482
483        for key in NON_SERIALIZED_FIELDS:
484            if (key in data):
485                del data[key]
486
487        return data
488
489    @classmethod
490    def from_dict(cls, data: dict[str, typing.Any]) -> typing.Any:
491        game_state = super().from_dict(data)
492        game_state._win = data['_win']
493        return game_state

A game state specific to a standard GridWorld game.

GameState(**kwargs: Any)
    def __init__(self, **kwargs: typing.Any) -> None:
        """
        Construct a GridWorld game state.
        All keyword arguments are passed along to the parent game state.
        The win flag starts as false, and all agent-computed display information starts empty.
        """

        super().__init__(**kwargs)

        self._win: bool = False
        """ Keep track if the agent exited the game on a winning state. """

        self._mdp_state_values: dict[pacai.core.mdp.MDPStatePosition, float] = {}
        """
        The MDP state values computed by the agent.
        This member will not be serialized.
        """

        self._minmax_mdp_state_values: tuple[float, float] = (-1.0, 1.0)
        """
        The min and max MDP state values computed by the agent.
        This member will not be serialized.
        """

        self._policy: dict[pacai.core.mdp.MDPStatePosition, pacai.core.action.Action] = {}
        """
        The policy computed by the agent.
        This member will not be serialized.
        """

        self._qvalues: dict[pacai.core.mdp.MDPStatePosition, dict[pacai.core.action.Action, float]] = {}
        """
        The Q-values computed by the agent.
        This member will not be serialized.
        """

        self._minmax_qvalues: tuple[float, float] = (-1.0, 1.0)
        """
        The min and max Q-values computed by the agent.
        This member will not be serialized.
        """
def agents_game_start( self, agent_responses: dict[int, pacai.core.agentaction.AgentActionRecord]) -> None:
 83    def agents_game_start(self, agent_responses: dict[int, pacai.core.agentaction.AgentActionRecord]) -> None:
 84        if (AGENT_INDEX not in agent_responses):
 85            return
 86
 87        agent_action = agent_responses[AGENT_INDEX].agent_action
 88        if (agent_action is None):
 89            return
 90
 91        if ('mdp_state_values' in agent_action.other_info):
 92            for (raw_mdp_state, value) in agent_action.other_info['mdp_state_values']:
 93                mdp_state = pacai.core.mdp.MDPStatePosition.from_dict(raw_mdp_state)
 94                self._mdp_state_values[mdp_state] = value
 95
 96                min_value = min(self._mdp_state_values.values())
 97                max_value = max(self._mdp_state_values.values())
 98                self._minmax_mdp_state_values = (min_value, max_value)
 99
100        if ('policy' in agent_action.other_info):
101            for (raw_mdp_state, raw_action) in agent_action.other_info['policy']:
102                mdp_state = pacai.core.mdp.MDPStatePosition.from_dict(raw_mdp_state)
103                action = pacai.core.action.Action(raw_action)
104                self._policy[mdp_state] = action
105
106        if ('qvalues' in agent_action.other_info):
107            values = []
108
109            for (raw_mdp_state, raw_action, qvalue) in agent_action.other_info['qvalues']:
110                mdp_state = pacai.core.mdp.MDPStatePosition.from_dict(raw_mdp_state)
111                action = pacai.core.action.Action(raw_action)
112
113                if (mdp_state not in self._qvalues):
114                    self._qvalues[mdp_state] = {}
115
116                self._qvalues[mdp_state][action] = qvalue
117                values.append(qvalue)
118
119            self._minmax_qvalues = (min(values), max(values))

Indicate that agents have been started.

def game_complete(self) -> list[int]:
121    def game_complete(self) -> list[int]:
122        # If the agent exited on a positive terminal position, they win.
123        if (self._win):
124            return [AGENT_INDEX]
125
126        return []

Indicate that the game has ended. The state should take any final actions and return the indexes of the winning agents (if any).

def sprite_lookup( self, sprite_sheet: pacai.core.spritesheet.SpriteSheet, position: pacai.core.board.Position, marker: pacai.core.board.Marker | None = None, action: pacai.core.action.Action | None = None, adjacency: pacai.core.board.AdjacencyString | None = None, animation_key: str | None = None) -> PIL.Image.Image:
141    def sprite_lookup(self,
142            sprite_sheet: pacai.core.spritesheet.SpriteSheet,
143            position: pacai.core.board.Position,
144            marker: pacai.core.board.Marker | None = None,
145            action: pacai.core.action.Action | None = None,
146            adjacency: pacai.core.board.AdjacencyString | None = None,
147            animation_key: str | None = None,
148            ) -> PIL.Image.Image:
149        board = typing.cast(pacai.gridworld.board.Board, self.board)
150
151        sprite = super().sprite_lookup(sprite_sheet, position, marker = marker, action = action, adjacency = adjacency, animation_key = animation_key)
152
153        if (marker == pacai.gridworld.board.MARKER_DISPLAY_VALUE):
154            # Draw MDP state values and policies.
155            sprite = self._add_mdp_state_value_sprite_info(sprite_sheet, sprite, position)
156        elif (marker == pacai.gridworld.board.MARKER_DISPLAY_QVALUE):
157            # Draw Q-values.
158            sprite = self._add_qvalue_sprite_info(sprite_sheet, sprite, position)
159        elif ((marker == pacai.gridworld.board.MARKER_TERMINAL)
160                and ((position.row > board._original_height) or (position.col > board._original_width))):
161            # Draw terminal values on the extended q-display.
162            sprite = self._add_terminal_sprite_info(sprite_sheet, sprite, position)
163
164        return sprite

Lookup the proper sprite for a situation. By default this just calls into the sprite sheet, but children may override for more expressive functionality.

def skip_draw( self, marker: pacai.core.board.Marker, position: pacai.core.board.Position, static: bool = False) -> bool:
166    def skip_draw(self,
167            marker: pacai.core.board.Marker,
168            position: pacai.core.board.Position,
169            static: bool = False,
170            ) -> bool:
171        if (static):
172            return False
173
174        board = typing.cast(pacai.gridworld.board.Board, self.board)
175        return (position.row >= board._original_height) or (position.col >= board._original_width)

Return true if this marker/position combination should not be drawn on the board.

def get_static_positions(self) -> list[pacai.core.board.Position]:
177    def get_static_positions(self) -> list[pacai.core.board.Position]:
178        board = typing.cast(pacai.gridworld.board.Board, self.board)
179
180        positions = []
181        for row in range(board.height):
182            for col in range(board.width):
183                if (row >= board._original_height) or (col >= board._original_width):
184                    positions.append(pacai.core.board.Position(row, col))
185
186        return positions

Get a list of positions to draw on the board statically.

def process_turn( self, action: pacai.core.action.Action, rng: random.Random | None = None, mdp: pacai.gridworld.mdp.GridWorldMDP | None = None, **kwargs: Any) -> None:
310    def process_turn(self,
311            action: pacai.core.action.Action,
312            rng: random.Random | None = None,
313            mdp: pacai.gridworld.mdp.GridWorldMDP | None = None,
314            **kwargs: typing.Any) -> None:
315        if (rng is None):
316            logging.warning("No RNG passed to pacai.gridworld.gamestate.GameState.process_turn().")
317            rng = random.Random(4)
318
319        if (mdp is None):
320            raise ValueError("No MDP passed to pacai.gridworld.gamestate.GameState.process_turn().")
321
322        board = typing.cast(pacai.gridworld.board.Board, self.board)
323
324        # Get the possible transitions from the MDP.
325        transitions = mdp.get_transitions(mdp.get_starting_state(), action)
326
327        # If there is only an exit transition, exit.
328        if ((len(transitions) == 1) and (transitions[0].action == pacai.core.mdp.ACTION_EXIT)):
329            logging.debug("Got the %s action, game is over.", pacai.core.mdp.ACTION_EXIT)
330            self.game_over = True
331            return
332
333        # Choose a transition.
334        transition = self._choose_transition(transitions, rng)
335
336        # Apply the transition.
337
338        self.score += transition.reward
339
340        old_position = self.get_agent_position(AGENT_INDEX)
341        if (old_position is None):
342            raise ValueError("GridWorld agent was removed from board.")
343
344        new_position = transition.state.position
345
346        if (old_position != new_position):
347            board.remove_marker(pacai.gridworld.board.AGENT_MARKER, old_position)
348            board.place_marker(pacai.gridworld.board.AGENT_MARKER, new_position)
349
350        # Check if we are going to "win".
351        # The reward for a terminal state is awarded on the action before EXIT.
352        if (board.is_terminal_position(new_position) and (board.get_terminal_value(new_position) > 0)):
353            self._win = True
354
355        logging.debug("Requested Action: '%s', Actual Action: '%s', Reward: %0.2f.", action, transition.action, transition.reward)

Process the current agent's turn with the given action. This may modify the current state. To get a copy of a potential successor state, use generate_successor().

def get_static_text(self) -> list[pacai.core.font.BoardText]:
373    def get_static_text(self) -> list[pacai.core.font.BoardText]:
374        board = typing.cast(pacai.gridworld.board.Board, self.board)
375
376        texts = []
377
378        # Add on terminal values.
379        for (position, value) in board._terminal_values.items():
380            text_color = None
381            if ((position.row > board._original_height) or (position.col > board._original_width)):
382                text_color = (0, 0, 0)
383
384            texts.append(pacai.core.font.BoardText(position, str(value), size = pacai.core.font.FontSize.SMALL, color = text_color))
385
386        # If we are using the extended display, fill in all the information.
387        if (board.display_qvalues()):
388            texts += self._get_qdisplay_static_text()
389
390        return texts

Get any static text to display on board positions.

def to_dict(self) -> dict[str, typing.Any]:
479    def to_dict(self) -> dict[str, typing.Any]:
480        data = super().to_dict()
481        data['_win'] = self._win
482
483        for key in NON_SERIALIZED_FIELDS:
484            if (key in data):
485                del data[key]
486
487        return data

Return a dict that can be used to represent this object. If the dict is passed to from_dict(), an identical object should be reconstructed.

@classmethod
def from_dict(cls, data: dict[str, typing.Any]) -> Any:
489    @classmethod
490    def from_dict(cls, data: dict[str, typing.Any]) -> typing.Any:
491        game_state = super().from_dict(data)
492        game_state._win = data['_win']
493        return game_state

Return an instance of this subclass created using the given dict. If the dict came from to_dict(), the returned object should be identical to the original.