Source code for multigrid.core.agent

from __future__ import annotations

import numpy as np

from gymnasium import spaces
from numpy.typing import ArrayLike, NDArray as ndarray

from .actions import Action
from .constants import Color, Direction, Type
from .mission import Mission, MissionSpace
from .world_object import WorldObj

from ..utils.misc import front_pos, PropertyAlias
from ..utils.rendering import (
    fill_coords,
    point_in_triangle,
    rotate_fn,
)



class Agent:
    """
    Class representing an agent in the environment.

    :Observation Space:

        Observations are dictionaries with the following entries:

            * image : ndarray[int] of shape (view_size, view_size, :attr:`.WorldObj.dim`)
                Encoding of the agent's view of the environment
            * direction : int
                Agent's direction (0: right, 1: down, 2: left, 3: up)
            * mission : Mission
                Task string corresponding to the current environment configuration

    :Action Space:

        Actions are discrete integers, as enumerated in :class:`.Action`.

    Attributes
    ----------
    index : int
        Index of the agent in the environment
    state : AgentState
        State of the agent
    mission : Mission
        Current mission string for the agent
    action_space : gym.spaces.Discrete
        Action space for the agent
    observation_space : gym.spaces.Dict
        Observation space for the agent
    """

    def __init__(
        self,
        index: int,
        mission_space: MissionSpace | None = None,
        view_size: int = 7,
        see_through_walls: bool = False):
        """
        Parameters
        ----------
        index : int
            Index of the agent in the environment
        mission_space : MissionSpace, optional
            The mission space for the agent. When omitted, a new
            ``MissionSpace.from_string('maximize reward')`` is created
            for this agent.
        view_size : int
            The size of the agent's view (must be odd and at least 3)
        see_through_walls : bool
            Whether the agent can see through walls
        """
        # Build the default here rather than in the signature: a default
        # expression is evaluated once at definition time, which would make
        # every agent share the same MissionSpace instance.
        if mission_space is None:
            mission_space = MissionSpace.from_string('maximize reward')

        self.index: int = index
        self.state: AgentState = AgentState()
        self.mission: Mission | None = None

        # Number of cells (width and height) in the agent view
        assert view_size % 2 == 1
        assert view_size >= 3
        self.view_size = view_size
        self.see_through_walls = see_through_walls

        # Observations are dictionaries containing an
        # encoding of the grid and a textual 'mission' string
        self.observation_space = spaces.Dict({
            'image': spaces.Box(
                low=0,
                high=255,
                shape=(view_size, view_size, WorldObj.dim),
                dtype=int,
            ),
            'direction': spaces.Discrete(len(Direction)),
            'mission': mission_space,
        })

        # Actions are discrete integer values
        self.action_space = spaces.Discrete(len(Action))

    # AgentState Properties
    color = PropertyAlias(
        'state', 'color', doc='Alias for :attr:`AgentState.color`.')
    dir = PropertyAlias(
        'state', 'dir', doc='Alias for :attr:`AgentState.dir`.')
    pos = PropertyAlias(
        'state', 'pos', doc='Alias for :attr:`AgentState.pos`.')
    terminated = PropertyAlias(
        'state', 'terminated', doc='Alias for :attr:`AgentState.terminated`.')
    carrying = PropertyAlias(
        'state', 'carrying', doc='Alias for :attr:`AgentState.carrying`.')

    @property
    def front_pos(self) -> tuple[int, int]:
        """
        Get the position of the cell that is directly in front of the agent.
        """
        # Read straight from the state's plain-ndarray view, bypassing the
        # AgentState property accessors. Inside this method the name
        # ``front_pos`` resolves to the module-level helper, not this property.
        agent_dir = self.state._view[AgentState.DIR]
        agent_pos = self.state._view[AgentState.POS]
        return front_pos(*agent_pos, agent_dir)

    def reset(self, mission: Mission | None = None):
        """
        Reset the agent to an initial state.

        Position is set to ``(-1, -1)``, direction to ``-1``, the terminated
        flag to ``False``, and the carried object to ``None``.

        Parameters
        ----------
        mission : Mission, optional
            Mission string to use for the new episode. When omitted, a new
            ``Mission('maximize reward')`` is created for this reset.
        """
        # Build the default per call so that agents never end up holding
        # (and potentially mutating) one shared Mission object.
        if mission is None:
            mission = Mission('maximize reward')

        self.mission = mission
        self.state.pos = (-1, -1)
        self.state.dir = -1
        self.state.terminated = False
        self.state.carrying = None

    def encode(self) -> tuple[int, int, int]:
        """
        Encode a description of this agent as a 3-tuple of integers.

        Returns
        -------
        type_idx : int
            The index of the agent type
        color_idx : int
            The index of the agent color
        agent_dir : int
            The direction of the agent (0: right, 1: down, 2: left, 3: up)
        """
        return (Type.agent.to_index(), self.state.color.to_index(), self.state.dir)

    def render(self, img: ndarray[np.uint8]):
        """
        Draw the agent.

        Parameters
        ----------
        img : ndarray[int] of shape (width, height, 3)
            RGB image array to render agent on
        """
        tri_fn = point_in_triangle(
            (0.12, 0.19),
            (0.87, 0.50),
            (0.12, 0.81),
        )

        # Rotate the agent based on its direction
        tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * np.pi * self.state.dir)
        fill_coords(img, tri_fn, self.state.color.rgb())
class AgentState(np.ndarray):
    """
    State for an :class:`.Agent` object.

    ``AgentState`` objects also support vectorized operations,
    in which case the ``AgentState`` object represents the states of multiple agents.

    Attributes
    ----------
    color : Color or ndarray[str]
        Agent color
    dir : Direction or ndarray[int]
        Agent direction (0: right, 1: down, 2: left, 3: up)
    pos : tuple[int, int] or ndarray[int]
        Agent (x, y) position
    terminated : bool or ndarray[bool]
        Whether the agent has terminated
    carrying : WorldObj or None or ndarray[object]
        Object the agent is carrying

    Examples
    --------
    Create a vectorized agent state for 3 agents:

    >>> agent_state = AgentState(3)
    >>> agent_state
    AgentState(3)

    Access and set state attributes for one agent at a time:

    >>> a = agent_state[0]
    >>> a
    AgentState()
    >>> a.color
    'red'
    >>> a.color = 'yellow'

    The underlying vectorized state is automatically updated as well:

    >>> agent_state.color
    array(['yellow', 'green', 'blue'])

    Access and set state attributes all at once:

    >>> agent_state.dir
    array([-1, -1, -1])
    >>> agent_state.dir = np.random.randint(4, size=(len(agent_state)))
    >>> agent_state.dir
    array([2, 3, 0])
    >>> a.dir
    2
    """

    # State vector indices
    TYPE = 0
    COLOR = 1
    DIR = 2
    ENCODING = slice(0, 3)
    POS = slice(3, 5)
    TERMINATED = 5
    CARRYING = slice(6, 6 + WorldObj.dim)

    # State vector dimension
    dim = 6 + WorldObj.dim

    def __new__(cls, *dims: int):
        """
        Parameters
        ----------
        dims : int, optional
            Shape of vectorized agent state
        """
        obj = np.zeros(dims + (cls.dim,), dtype=int).view(cls)

        # Plain-ndarray view of the same buffer (faster indexing). The default
        # values below are written through it rather than through ``obj`` so
        # that the overridden ``__getitem__`` is never invoked before the
        # auxiliary attributes it reads have been attached.
        raw = obj.view(np.ndarray)

        # Set default values
        raw[..., AgentState.TYPE] = Type.agent
        raw[..., AgentState.COLOR].flat = Color.cycle(np.prod(dims))
        raw[..., AgentState.DIR] = -1
        raw[..., AgentState.POS] = (-1, -1)

        # Other attributes
        obj._carried_obj = np.empty(dims, dtype=object)  # object references
        obj._terminated = np.zeros(dims, dtype=bool)  # cache for faster access
        obj._view = raw  # view of the underlying array (faster indexing)

        return obj

    def __repr__(self):
        # Shape without the trailing state-vector axis; "(3,)" is shown as "(3)"
        shape = str(self.shape[:-1]).replace(",)", ")")
        return f'{self.__class__.__name__}{shape}'

    def __getitem__(self, idx):
        """
        Index the state array, propagating the auxiliary caches.

        When the result still ends in an axis of length :attr:`dim`, the
        result is given ``_view``, ``_carried_obj`` and ``_terminated``
        attributes obtained by applying the same index to this object's
        corresponding arrays.
        """
        out = super().__getitem__(idx)
        if out.shape and out.shape[-1] == self.dim:
            # Flatten the index into a single tuple. Nesting a tuple index
            # inside another tuple (``arr[idx, ...]``) would make NumPy treat
            # it as a fancy index over the first axis instead of one index
            # per axis.
            key = idx if isinstance(idx, tuple) else (idx,)

            # Append an Ellipsis so that a full index into ``_carried_obj`` /
            # ``_terminated`` yields a 0-d array view rather than a detached
            # scalar. An index that already has an Ellipsis is used unchanged
            # (NumPy permits only one). Identity comparison is used because
            # ``key`` may contain arrays.
            if not any(part is Ellipsis for part in key):
                key = key + (Ellipsis,)

            out._view = self._view[key]
            out._carried_obj = self._carried_obj[key]  # set carried object reference
            out._terminated = self._terminated[key]  # set terminated cache

        return out

    @property
    def color(self) -> Color | ndarray[np.str_]:
        """
        Return the agent color.
        """
        return Color.from_index(self._view[..., AgentState.COLOR])

    @color.setter
    def color(self, value: str | ArrayLike):
        """
        Set the agent color (a color string, or an array-like of them).
        """
        self[..., AgentState.COLOR] = np.vectorize(lambda c: Color(c).to_index())(value)

    @property
    def dir(self) -> Direction | ndarray[np.int_]:
        """
        Return the agent direction.
        """
        out = self._view[..., AgentState.DIR]
        return Direction(out.item()) if out.ndim == 0 else out

    @dir.setter
    def dir(self, value: int | ArrayLike):
        """
        Set the agent direction (an int, or an array-like of ints).
        """
        self[..., AgentState.DIR] = value

    @property
    def pos(self) -> tuple[int, int] | ndarray[np.int_]:
        """
        Return the agent's (x, y) position.
        """
        out = self._view[..., AgentState.POS]
        return tuple(out) if out.ndim == 1 else out

    @pos.setter
    def pos(self, value: ArrayLike):
        """
        Set the agent's (x, y) position (one pair of ints, or an array-like of pairs).
        """
        self[..., AgentState.POS] = value

    @property
    def terminated(self) -> bool | ndarray[np.bool_]:
        """
        Return whether the agent has terminated.
        """
        out = self._terminated
        return out.item() if out.ndim == 0 else out

    @terminated.setter
    def terminated(self, value: bool | ArrayLike):
        """
        Set whether the agent has terminated (a bool, or an array-like of bools).
        """
        # Keep the state vector and the boolean cache in sync
        self[..., AgentState.TERMINATED] = value
        self._terminated[...] = value

    @property
    def carrying(self) -> WorldObj | None | ndarray[np.object_]:
        """
        Return the object the agent is carrying.
        """
        out = self._carried_obj
        return out.item() if out.ndim == 0 else out

    @carrying.setter
    def carrying(self, obj: WorldObj | None | ArrayLike):
        """
        Set the object the agent is carrying.

        ``obj`` may be a single :class:`.WorldObj`, ``None``, or an
        array-like of objects.
        """
        # ``None`` is stored in the state vector as ``WorldObj.empty()``
        self[..., AgentState.CARRYING] = WorldObj.empty() if obj is None else obj

        if isinstance(obj, (WorldObj, type(None))):
            # ``fill`` stores the object reference itself in every slot
            self._carried_obj[...].fill(obj)
        else:
            self._carried_obj[...] = obj