# Source code for neugym.environment.gridworld

"""Base class for gridworld environment."""

import copy
import warnings

import networkx as nx
import numpy as np

import neugym as ng
from ._agent import _Agent
from ._object import _Object

# Public API of this module: only the environment class is exported.
__all__ = ["GridWorld"]


class GridWorld:
    r"""Base class for gridworld environment.

    ``Gridworld`` environment consists of a ``world``, ``objects``, and one
    ``agent``.

    The world consists of multiple connected rectangle areas and each area is
    represented by a two-dimensional grid graph, which has each node connected
    to its four nearest neighbors. Each node represents a state in the world,
    and has an attribute ``altitude``. When the agent moves from one state
    toward another state, it will get a reward generated by the altitude
    change (if there is one), i.e.

    .. math::

        R_{move} = A_s - A_{s + 1}

    where :math:`R_{move}` is the movement reward and :math:`A` represents the
    altitude of the current state :math:`s` and the next state :math:`s + 1`.

    At each position (state), the agent can choose from 5 ``actions`` to move
    towards **UP**, **DOWN**, **LEFT**, **RIGHT**, and **STAY** in the same
    state. When the performed movement would make the agent get out of the
    world, the agent is forced to stay in the same state.

    Objects where the agent can get reward are placed at different states and
    each state can hold at most one object. Each object has its own adjustable
    probability (``prob``) of giving a ``reward`` when the agent reaches the
    state with this object; if the agent fails to get a reward, it gets a
    punishment (``punish``).

    .. math::

        R_{object} = \left \{
        \begin{aligned}
        & reward, P=p \\
        & punish, P=1-p
        \end{aligned}
        \right.

    Under this situation, the total reward for this step is the movement
    reward plus the object reward.

    .. math::

        R_{total} = R_{move} + R_{object}

    As soon as the agent gets to an object (no matter whether it received the
    reward or the punishment), the trial is finished and the agent is sent
    back to its initial state.

    Parameters
    ----------
    origin_shape : tuple of ints (optional, default: None)
        Shape of the world origin. If not provided, the origin will be
        initialized to be only one state ``(0, 0, 0)``, otherwise it will be a
        rectangular area of shape ``origin_shape``.

    Examples
    --------
    Initialize a gridworld environment with only an origin state.

    >>> W = GridWorld()

    W can be grown in several aspects.

    **Areas:**

    Add one area of shape (2, 2).

    >>> W.add_area((2, 2))

    Remove areas.

    >>> W.remove_area(1)

    Set area altitude.

    >>> W.add_area((3, 3))
    >>> W.set_altitude(1, altitude_mat=np.random.randn(3, 3))

    **Paths:**

    Add inter-area paths.

    >>> W = GridWorld()
    >>> W.add_area((2, 2))
    >>> W.add_path((0, 0, 0), (1, 0, 0))
    >>> W.add_area((2, 2))
    >>> W.add_path((1, 1, 1), (2, 1, 0), register_action=(0, 1))

    Remove paths.

    >>> W.remove_path((1, 1, 1), (2, 1, 0))

    **Objects:**

    Add objects.

    >>> W = GridWorld()
    >>> W.add_area((2, 2))
    >>> W.add_object((0, 0, 0), reward=1, prob=0.7)
    >>> W.add_object((1, 0, 0), reward=1, prob=0.3, punish=-10)

    Remove objects.

    >>> W.remove_object((1, 0, 0))

    Update object attributes.

    >>> W.update_object((0, 0, 0), reward=10)
    >>> W.update_object((0, 0, 0), reward=1, prob=0.8)

    **Agent:**

    >>> W = GridWorld()
    >>> W.init_agent()

    One can also manually set the agent initial state.

    >>> W = GridWorld()
    >>> W.add_area((2, 2))
    >>> W.add_path((0, 0, 0), (1, 0, 0))
    >>> W.init_agent(init_coord=(1, 1, 1))

    When the agent is initialized, the agent can move in the world and get
    rewards.

    >>> next_state, reward, done = W.step(action=(0, 1))

    **Reset:**

    To reset the environment, first set a reset checkpoint.

    >>> W.set_reset_checkpoint()

    Then the environment can be reset if needed.

    >>> W.reset()
    """

    def __init__(self, origin_shape=None):
        """Initialize a gridworld environment.

        Parameters
        ----------
        origin_shape : tuple of ints (optional, default: None)
            Shape of the world origin. If not provided, the origin will be
            initialized to be only one state ``(0, 0, 0)``, otherwise it will
            be a rectangular area of shape ``origin_shape``.

        Examples
        --------
        Initialize a gridworld environment by default.

        >>> W = GridWorld()

        Manually set origin shape.

        >>> W = GridWorld((3, 4))
        """
        self._world = nx.Graph()
        self._time = 0
        self._num_area = 0
        self._area_alias = {}
        # Maps "phantom" off-grid coordinates to the real state an inter-area
        # path leads to (see ``add_path`` and ``step``).
        self._path_alias = {}
        self._objects = []
        self._actions = ((0, 0), (1, 0), (-1, 0), (0, 1), (0, -1))

        # Add origin (area index 0).
        if origin_shape is None:
            origin_shape = (1, 1)
            self._world.add_node((0, 0, 0))
        else:
            m, n = origin_shape
            origin = nx.grid_2d_graph(m, n)
            mapping = {coord: (0,) + tuple(coord) for coord in origin.nodes}
            origin = nx.relabel_nodes(origin, mapping)
            self._world.update(origin)
        origin_altitude_mat = np.zeros(origin_shape)
        self.set_area_name(0, 'origin')
        self.set_altitude(0, origin_altitude_mat)
        nx.set_node_attributes(self._world, False, 'blocked')

        # Agent.
        self._agent = None

        # Reset state. Keys mirror private attribute names without the
        # leading underscore (see ``set_reset_checkpoint`` / ``reset``).
        self._has_reset_checkpoint = False
        self._reset_state = {
            "world": None,
            "time": None,
            "num_area": None,
            "area_alias": None,
            "path_alias": None,
            "objects": None,
            "agent": None
        }

    def add_area(self, shape, name=None):
        """Add a new area to the world.

        .. note::

            Using this function people can add a new dangling area to the
            world, i.e. for now, it cannot be accessed from any of the other
            areas. To make this new area accessible, use
            ``GridWorld.add_path()`` to build inter-area bridges.

        Parameters
        ----------
        shape : tuple of ints
            Shape of the new area.
        name : str (optional, default: None)
            Alias name of the area to be added.

        Raises
        ------
        RuntimeError
            If ``name`` is already used by another area.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 3))
        >>> W.add_area((2, 2), name="Right")
        """
        if name in self._area_alias.keys():
            msg = "Alias name already exists, try another name"
            raise RuntimeError(msg)

        # Create new area; its nodes are labelled (area_idx, x, y).
        m, n = shape
        new_area = nx.grid_2d_graph(m, n)
        mapping = {coord: (self._num_area + 1,) + tuple(coord)
                   for coord in new_area.nodes}
        new_area = nx.relabel_nodes(new_area, mapping)
        nx.set_node_attributes(new_area, False, 'blocked')
        self._world.update(new_area)
        self._num_area += 1
        if name is not None:
            self._area_alias[name] = self._num_area

        # Add area altitude.
        altitude_mat = np.zeros(shape)
        self.set_altitude(self._num_area, altitude_mat)

    def remove_area(self, area):
        """Remove an area from the world.

        Index for all other areas left will be automatically reset (minus one
        if their original index is larger than the index of the removed
        area).

        .. note::

            - All inter-area paths and objects related to the area to be
              removed will be automatically removed.
            - Origin of the world is not allowed to be removed.

        Parameters
        ----------
        area : int or str
            Index or name of the area to be removed.

        Raises
        ------
        TypeError
            If ``area`` is neither an int nor a str.
        NeuGymPermissionError
            If the origin (index 0) is targeted.
        ValueError
            If the area does not exist.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.remove_area(1)
        >>> W.add_area((2, 2), name="example")
        >>> W.remove_area("example")
        """
        if type(area) == str:
            area_idx = self.get_area_index(area)
        elif type(area) == int:
            area_idx = area
        else:
            msg = "int for area index or str for area name " \
                  "expected, got '{}'".format(type(area))
            raise TypeError(msg)

        if area_idx == 0:
            raise ng.NeuGymPermissionError("Not allowed to remove origin area")
        # Without this check an unknown index would still decrement
        # ``_num_area`` (and a negative one would relabel every area).
        if area_idx > self._num_area or area_idx < 0:
            msg = "Area {} not found".format(area_idx)
            raise ValueError(msg)

        # Remove the area's states, then shift every later area down by one
        # in a single simultaneous relabel (copy=True handles labels that
        # overlap between old and new names, and avoids copying the graph
        # once per node).
        new_world = copy.deepcopy(self._world)
        new_world.remove_nodes_from(
            [node for node in new_world.nodes if node[0] == area_idx])
        mapping = {node: (node[0] - 1,) + tuple(node[1:])
                   for node in new_world.nodes if node[0] > area_idx}
        self._world = nx.relabel_nodes(new_world, mapping, copy=True)
        self._num_area -= 1

        # Remove invalid area alias.
        new_area_alias = {}
        for key, value in self._area_alias.items():
            if value == area_idx:
                continue
            elif value > area_idx:
                new_area_alias[key] = value - 1
            else:
                new_area_alias[key] = value
        self._area_alias = new_area_alias

        # Remove invalid path alias (either end inside the removed area).
        new_path_alias = {}
        for key, value in self._path_alias.items():
            if key[0] == area_idx:
                continue
            elif key[0] > area_idx:
                new_key = (key[0] - 1,) + tuple(key[1:])
            else:
                new_key = key

            if value[0] == area_idx:
                continue
            elif value[0] > area_idx:
                new_value = (value[0] - 1,) + tuple(value[1:])
            else:
                new_value = value
            new_path_alias[new_key] = new_value
        self._path_alias = new_path_alias

        # Remove objects in the area to be removed; shift later ones.
        new_objects = []
        for obj in self._objects:
            if obj.coord[0] < area_idx:
                new_objects.append(obj)
            elif obj.coord[0] == area_idx:
                continue
            else:
                obj.coord = (obj.coord[0] - 1,) + tuple(obj.coord[1:])
                new_objects.append(obj)
        self._objects = new_objects
        # NOTE(review): ``self._agent`` is not touched here, so an agent
        # whose state lies in or beyond the removed area keeps a stale
        # coordinate. Consider re-initializing the agent after this call.

    def add_path(self, coord_from, coord_to, register_action=None):
        """Add a new inter-area connection.

        .. note::

            - Creating a path within the same area is not allowed.
            - When an inter-area path from ``coord_from`` to ``coord_to`` with
              action ``register_action`` is built, a reverse path is also
              built at the same time, i.e. the agent can also move from
              ``coord_to`` to ``coord_from`` with the reverse action of
              ``register_action`` (e.g. the reverse action of **UP(1, 0)** is
              **DOWN(-1, 0)**).

        Parameters
        ----------
        coord_from : tuple of ints
            Coordinate of the path start state.
        coord_to : tuple of ints
            Coordinate of the path end state.
        register_action : tuple of ints (optional, default: None)
            Register an action to transport the agent from ``coord_from`` to
            ``coord_to``. If None, possible actions to register are searched
            in the order of [**UP(1, 0)**, **DOWN(-1, 0)**, **RIGHT(0, 1)**,
            **LEFT(0, -1)**], and the first possible action is registered.

        Raises
        ------
        NeuGymPermissionError
            If both coordinates are in the same area.
        ValueError
            If a coordinate is malformed, out of the world, or
            ``register_action`` is not a legal action.
        NeuGymConnectivityError
            If a state already has 4 connections or no action is available.
        NeuGymOverwriteError
            If the path already exists.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.add_path((0, 0, 0), (1, 0, 0))
        >>> W.add_area((3, 3))
        >>> W.add_path((1, 1, 1), (2, 1, 0), register_action=(0, 1))
        """
        if coord_from[0] == coord_to[0]:
            msg = "Not allowed to add path within an area"
            raise ng.NeuGymPermissionError(msg)

        if len(coord_from) != 3:
            msg = "Tuple of length 3 expected for argument " \
                  "'coord_from', got {}".format(len(coord_from))
            raise ValueError(msg)
        if not self._world.has_node(coord_from):
            msg = "'coord_from' coordinate {} out of world".format(coord_from)
            raise ValueError(msg)
        if self._world.degree(coord_from) == 4:
            msg = "Maximum number of connections (4) for position " \
                  "{} reached, not allowed to access from it".format(coord_from)
            raise ng.NeuGymConnectivityError(msg)

        if len(coord_to) != 3:
            msg = "Tuple of length 3 expected for argument " \
                  "'coord_to', got {}".format(len(coord_to))
            raise ValueError(msg)
        if not self._world.has_node(coord_to):
            msg = "'coord_to' coordinate {} out of world".format(coord_to)
            raise ValueError(msg)
        elif self._world.degree(coord_to) == 4:
            msg = "Maximum number of connections (4) for position " \
                  "{} reached, not allowed to access to it".format(coord_to)
            raise ng.NeuGymConnectivityError(msg)

        if (coord_from, coord_to) in self._world.edges:
            msg = "Path already exists between {} and {}".format(coord_from, coord_to)
            raise ng.NeuGymOverwriteError(msg)

        # Search for free actions that can be registered. An action is free
        # when the off-grid coordinate it reaches from ``coord_from`` and the
        # reverse off-grid coordinate from ``coord_to`` are neither real
        # states nor already claimed by another path.
        free_actions = []
        for action in self._actions:
            dx, dy = action
            alias_to = (coord_from[0], coord_from[1] + dx, coord_from[2] + dy)
            alias_from = (coord_to[0], coord_to[1] - dx, coord_to[2] - dy)
            if self._world.has_node(alias_to) or self._world.has_node(alias_from) or \
                    alias_to in self._path_alias.keys() or alias_from in self._path_alias.keys():
                continue
            free_actions.append(action)

        if len(free_actions) == 0:
            msg = "Unable to connect two areas from 'coord_from' {} to 'coord_to' {}, " \
                  "all allowed actions allocated".format(coord_from, coord_to)
            raise ng.NeuGymConnectivityError(msg)

        if register_action is not None:
            if register_action not in self._actions:
                msg = "Illegal 'register_action' {}, " \
                      "expected one of {}".format(register_action, self._actions)
                raise ValueError(msg)
            if register_action not in free_actions:
                msg = "Unable to register action 'register_action' {}, " \
                      "already allocated".format(register_action)
                raise ng.NeuGymConnectivityError(msg)
            dx, dy = register_action
        else:
            dx, dy = free_actions[0]

        # Register action in both directions.
        self._path_alias[(coord_from[0], coord_from[1] + dx, coord_from[2] + dy)] = coord_to
        self._path_alias[(coord_to[0], coord_to[1] - dx, coord_to[2] - dy)] = coord_from
        self._world.add_edge(coord_from, coord_to)

    def remove_path(self, coord_from, coord_to):
        """Remove one inter-area connection from the world.

        Parameters
        ----------
        coord_from : tuple of ints
            Coordinate of the path start state.
        coord_to : tuple of ints
            Coordinate of the path end state.

        Raises
        ------
        NeuGymPermissionError
            If both coordinates are in the same area.
        ValueError
            If a coordinate is not of length 3.

        Warns
        -----
        RuntimeWarning
            If no inter-area path exists between the two states.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.add_path((0, 0, 0), (1, 0, 0), register_action=(0, 1))
        >>> W.remove_path((0, 0, 0), (1, 0, 0))
        """
        if coord_from[0] == coord_to[0]:
            msg = "Not allowed to remove path within an area, " \
                  "try using `GridWorld.block()` instead"
            raise ng.NeuGymPermissionError(msg)

        if len(coord_from) != 3 or len(coord_to) != 3:
            msg = "Tuple of length 3 expected for position coordinate"
            raise ValueError(msg)

        # Find the pair of aliases (forward and reverse) registered for this path.
        remove_list = []
        for dx, dy in self._actions:
            alias_to = (coord_from[0], coord_from[1] + dx, coord_from[2] + dy)
            alias_from = (coord_to[0], coord_to[1] - dx, coord_to[2] - dy)
            if self._path_alias.get(alias_to) == coord_to and \
                    self._path_alias.get(alias_from) == coord_from:
                remove_list.append(alias_to)
                remove_list.append(alias_from)

        if len(remove_list) == 0:
            msg = "Inter-area path not found between {} and {}, " \
                  "nothing to do".format(coord_from, coord_to)
            warnings.warn(RuntimeWarning(msg))
        else:
            # ``add_path`` registers exactly one action per path.
            assert len(remove_list) == 2
            for key in remove_list:
                self._path_alias.pop(key)
            self._world.remove_edge(coord_from, coord_to)

    def add_object(self, coord, reward, prob, punish=0):
        """Add one object to the world. Each state can hold at most one object.

        Parameters
        ----------
        coord : tuple of ints
            Coordinate of the state to place the object.
        reward : int or float
            Reward that the object can generate.
        prob : float
            Probability for the object to generate a reward.
        punish : int or float (optional, default: 0)
            Punishment that the object will generate if it fails to generate a
            reward.

        Raises
        ------
        ValueError
            If ``coord`` is not a state of the world.
        NeuGymOverwriteError
            If an object already exists at ``coord``.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.add_object((0, 0, 0), reward=1, prob=0.7)
        >>> W.add_object((1, 0, 0), reward=1, prob=0.3, punish=-10)
        """
        if coord not in self._world.nodes:
            msg = "Coordinate {} out of world".format(coord)
            raise ValueError(msg)
        # Enforce the one-object-per-state rule; duplicates would otherwise
        # shadow each other in ``step`` and ``remove_object``.
        if any(obj.coord == coord for obj in self._objects):
            msg = "Object already exists at {}, use 'update_object()' " \
                  "to modify it".format(coord)
            raise ng.NeuGymOverwriteError(msg)
        self._objects.append(_Object(reward, punish, prob, coord))

    def remove_object(self, coord):
        """Remove one object from the world.

        Parameters
        ----------
        coord : tuple of ints
            Coordinate of the state whose object will be removed.

        Raises
        ------
        ValueError
            If no object is found at ``coord``.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.add_object((0, 0, 0), reward=1, prob=0.7)
        >>> W.remove_object((0, 0, 0))
        """
        for i, obj in enumerate(self._objects):
            if coord == obj.coord:
                self._objects.pop(i)
                return
        msg = "No object found at {}".format(coord)
        raise ValueError(msg)

    def update_object(self, coord, **attr):
        """Reset object attributes.

        Except the object coordinate ``coord``, all other three attributes
        can be updated (``reward``, ``prob``, ``punish``).

        Parameters
        ----------
        coord : tuple of ints
            Coordinate of the state whose object attribute will be updated.
        attr : keyword arguments {'reward': int or float, 'prob': int or float, 'punish': int or float}
            Attribute and new value to reset. Unknown attribute names are
            ignored with a ``RuntimeWarning``.

        Raises
        ------
        ValueError
            If no object is found at ``coord``.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.add_object((0, 0, 0), reward=1, prob=0.7)
        >>> W.update_object((0, 0, 0), reward=10, prob=0.8, punish=-1)
        """
        for obj in self._objects:
            if coord == obj.coord:
                for key, value in attr.items():
                    if hasattr(obj, key):
                        setattr(obj, key, value)
                    else:
                        msg = "'Object' object doesn't have attribute " \
                              "'{}', ignored".format(key)
                        warnings.warn(RuntimeWarning(msg))
                return
        msg = "No object found at {}".format(coord)
        raise ValueError(msg)

    def get_object_attribute(self, coord, attr):
        """Get the value of an object attribute.

        Parameters
        ----------
        coord : tuple of ints
            Coordinate of the state whose object attribute will be looked for.
        attr : str {'reward', 'prob', 'punish'}
            Object attribute to look for.

        Returns
        -------
        attribute_value : int or float
            Value of object attribute ``attr``.

        Raises
        ------
        ValueError
            If no object is found at ``coord`` or it lacks ``attr``.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_object((0, 0, 0), reward=1, prob=0.7)
        >>> W.get_object_attribute((0, 0, 0), 'reward')
        1
        >>> W.get_object_attribute((0, 0, 0), 'prob')
        0.7
        """
        for obj in self._objects:
            if coord == obj.coord:
                if hasattr(obj, attr):
                    return getattr(obj, attr)
                msg = "'Object' object doesn't have attribute '{}'".format(attr)
                raise ValueError(msg)
        msg = "No object found at {}".format(coord)
        raise ValueError(msg)

    def block(self, coord):
        """Block one state.

        When the agent tries to enter a blocked state, it will be forced to
        stay in the current state, i.e. no movement.

        Parameters
        ----------
        coord : tuple of ints
            Coordinate of the state to block.

        Raises
        ------
        ValueError
            If ``coord`` is not a state of the world.
        RuntimeError
            If the agent currently occupies ``coord``.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((1, 1))
        >>> W.block((1, 0, 0))
        """
        if coord not in self._world.nodes:
            msg = "Coordinate {} out of world".format(coord)
            raise ValueError(msg)

        if self._agent is not None:
            agent_state = self.get_agent_state()
            if coord == agent_state:
                msg = "Unable to block state '{}', where the agent is currently in".format(coord)
                raise RuntimeError(msg)

        nx.set_node_attributes(self._world, {coord: True}, 'blocked')

    def unblock(self, coord):
        """Unblock one state.

        Parameters
        ----------
        coord : tuple of ints
            Coordinate of the state to unblock.

        Raises
        ------
        ValueError
            If ``coord`` is not a state of the world.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((1, 1))
        >>> W.block((1, 0, 0))
        >>> W.unblock((1, 0, 0))
        """
        if coord in self._world.nodes:
            nx.set_node_attributes(self._world, {coord: False}, 'blocked')
        else:
            msg = "Coordinate {} out of world".format(coord)
            raise ValueError(msg)

    def set_altitude(self, area, altitude_mat):
        """Set the altitude of each state for one area.

        Parameters
        ----------
        area : int or str
            Index or name of the area to set altitude.
        altitude_mat : numpy.ndarray
            A matrix of the same shape as the area. Each element in the
            matrix corresponds to the altitude of one state in the area.

        Raises
        ------
        TypeError
            If ``area`` is neither an int nor a str.
        ValueError
            If the area does not exist or the shapes mismatch.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 3))
        >>> mat = np.random.randn(2, 3)
        >>> W.set_altitude(1, altitude_mat=mat)
        >>> W.add_area((2, 3), name="Right")
        >>> W.set_altitude("Right", altitude_mat=mat)
        """
        if type(area) == str:
            area_idx = self.get_area_index(area)
        elif type(area) == int:
            area_idx = area
        else:
            msg = "int for area index or str for area name " \
                  "expected, got '{}'".format(type(area))
            raise TypeError(msg)

        if area_idx > self._num_area or area_idx < 0:
            msg = "Area {} not found".format(area_idx)
            raise ValueError(msg)

        area_shape = self.get_area_shape(area_idx)
        if altitude_mat.shape != area_shape:
            msg = "Mismatch shape between Area({}) {} and " \
                  "altitude matrix {}".format(area_idx, area_shape, altitude_mat.shape)
            raise ValueError(msg)

        altitude_mapping = {}
        for x in range(area_shape[0]):
            for y in range(area_shape[1]):
                altitude_mapping[(area_idx, x, y)] = altitude_mat[x, y]
        nx.set_node_attributes(self._world, altitude_mapping, 'altitude')

    def set_area_name(self, area, name):
        """Set an alias name for an area.

        Can also be used to reset the alias name of an area.

        Parameters
        ----------
        area : int or str
            Index or current name of the area.
        name : str
            Name of the area to be set.

        Raises
        ------
        RuntimeError
            If ``name`` is already used.
        ValueError
            If the area does not exist.
        TypeError
            If ``area`` is neither an int nor a str.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.set_area_name(1, "Up")
        """
        if name in self._area_alias.keys():
            msg = "Alias name already exists, try another name"
            raise RuntimeError(msg)

        if type(area) == str:
            area_idx = self.get_area_index(area)
            self._area_alias[name] = area_idx
            self._area_alias.pop(area)
        elif type(area) == int:
            if area > self._num_area or area < 0:
                msg = "Area with index '{}' not found".format(area)
                raise ValueError(msg)
            else:
                # Drop the previous alias, if the area had one.
                try:
                    old_name = self.get_area_name(area)
                except RuntimeError:
                    pass
                else:
                    self._area_alias.pop(old_name)
                self._area_alias[name] = area
        else:
            msg = "int for area index or str for area name " \
                  "expected to find the area, got '{}'".format(type(area))
            raise TypeError(msg)

    def get_area_name(self, area_idx):
        """Get the alias name of an area using the area index.

        Parameters
        ----------
        area_idx : int
            Index of the area to get name.

        Returns
        -------
        name : str
            Name of the area with index ``area_idx``.

        Raises
        ------
        ValueError
            If ``area_idx`` is out of range (negative or too large).
        RuntimeError
            If the area has no alias name.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2), name="Up")
        >>> W.get_area_name(1)
        'Up'
        """
        if area_idx > self._num_area or area_idx < 0:
            msg = "Area index '{}' out of range".format(area_idx)
            raise ValueError(msg)

        for key, value in self._area_alias.items():
            if value == area_idx:
                return key

        msg = "Area with index '{}' don't have an alias name".format(area_idx)
        raise RuntimeError(msg)

    def get_area_index(self, area_name):
        """Get the index of an area with its alias name.

        Parameters
        ----------
        area_name : str
            Name of the area to get index.

        Returns
        -------
        index : int
            Index of the area with alias name ``area_name``.

        Raises
        ------
        ValueError
            If no area has this name.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2), name="Up")
        >>> W.get_area_index("Up")
        1
        """
        for key, value in self._area_alias.items():
            if key == area_name:
                return value

        msg = "Area with name '{}' not found".format(area_name)
        raise ValueError(msg)

    def get_area_shape(self, area):
        """Get the shape of one area.

        Parameters
        ----------
        area : int or str
            Index or name of the area to get its shape.

        Returns
        -------
        shape : tuple of ints
            Shape of the area, derived from the largest ``x`` and ``y`` of its
            states.

        Raises
        ------
        TypeError
            If ``area`` is neither an int nor a str.
        ValueError
            If the area does not exist.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((3, 10))
        >>> W.get_area_shape(1)
        (3, 10)
        >>> W.add_area((5, 10), name="Right")
        >>> W.get_area_shape("Right")
        (5, 10)
        """
        if type(area) == str:
            area_idx = self.get_area_index(area)
        elif type(area) == int:
            area_idx = area
        else:
            msg = "int for area index or str for area name " \
                  "expected, got '{}'".format(type(area))
            raise TypeError(msg)

        if area_idx > self._num_area or area_idx < 0:
            msg = "Area {} not found".format(area_idx)
            raise ValueError(msg)

        max_x = 0
        max_y = 0
        for idx, x, y in self._world.nodes:
            if idx != area_idx:
                continue
            if x > max_x:
                max_x = x
            if y > max_y:
                max_y = y

        return max_x + 1, max_y + 1

    def get_area_altitude(self, area):
        """Get the altitude of each state in one area.

        Parameters
        ----------
        area : int or str
            Index or name of the area to get its state altitude.

        Returns
        -------
        altitude_matrix : numpy.ndarray
            Altitude matrix of the area. Each element in the matrix
            corresponds to the altitude of one state in the area.

        Raises
        ------
        TypeError
            If ``area`` is neither an int nor a str.
        ValueError
            If the area does not exist.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((3, 5))
        >>> W.add_area((3, 5), name="Right")
        >>> mat = np.ones((3, 5))
        >>> W.set_altitude(1, altitude_mat=mat)
        >>> W.get_area_altitude(1)
        array([[1., 1., 1., 1., 1.],
               [1., 1., 1., 1., 1.],
               [1., 1., 1., 1., 1.]])
        >>> W.get_area_altitude("Right")
        array([[0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0.]])
        """
        if type(area) == str:
            area_idx = self.get_area_index(area)
        elif type(area) == int:
            area_idx = area
        else:
            msg = "int for area index or str for area name " \
                  "expected, got '{}'".format(type(area))
            raise TypeError(msg)

        if area_idx > self._num_area or area_idx < 0:
            msg = "Area {} not found".format(area_idx)
            raise ValueError(msg)

        area_shape = self.get_area_shape(area_idx)
        altitude_mat = np.zeros(area_shape)
        # Per-node lookup: rebuilding the whole attribute dict for every
        # node (as before) made this O(V^2).
        nodes = self._world.nodes
        for coord in nodes:
            if coord[0] == area_idx:
                altitude_mat[coord[1], coord[2]] = nodes[coord]['altitude']

        return altitude_mat

    def init_agent(self, init_coord=None, overwrite=False):
        """Initialize an agent in the world.

        Parameters
        ----------
        init_coord : tuple of ints (optional, default: None)
            Coordinate of the agent initial state. If not provided, the agent
            will be initialized at ``(0, 0, 0)`` by default.
        overwrite : bool (default: False)
            Whether to overwrite the existing agent.

        Raises
        ------
        ValueError
            If ``init_coord`` is not a state of the world.
        RuntimeError
            If ``init_coord`` is blocked.
        NeuGymOverwriteError
            If an agent exists and ``overwrite`` is False.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.init_agent()
        >>> W.add_area((2, 4))
        >>> W.init_agent((1, 1, 3), overwrite=True)
        """
        if init_coord is None:
            init_coord = (0, 0, 0)

        if not self._world.has_node(init_coord):
            msg = "Initial state coordinate {} out of world".format(init_coord)
            raise ValueError(msg)

        if self._world.nodes[init_coord]['blocked']:
            msg = "Unable to initialize an agent at a blocked state '{}'".format(init_coord)
            raise RuntimeError(msg)

        if self._agent is None or overwrite:
            self._agent = _Agent(init_coord)
        else:
            raise ng.NeuGymOverwriteError("Agent already exists, "
                                          "set 'overwrite=True' to overwrite")

    def get_agent_state(self, when="current"):
        """Get state of the agent.

        Parameters
        ----------
        when : str {"current", "init"} (default: "current")
            Choose to get the initial ("init") or current ("current") state
            of the agent.

        Returns
        -------
        agent_state : tuple of ints
            Coordinate of the requested agent state.

        Raises
        ------
        RuntimeError
            If no agent has been initialized.
        ValueError
            If ``when`` is not "current" or "init".

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.add_path((0, 0, 0), (1, 0, 0))
        >>> W.init_agent()
        >>> W.step((1, 0))
        ((1, 0, 0), 0.0, False)
        >>> W.get_agent_state()
        (1, 0, 0)
        >>> W.get_agent_state(when="init")
        (0, 0, 0)
        """
        if self._agent is None:
            raise RuntimeError("Agent not initialized, use 'init_agent()' first")
        if when == "current":
            return self._agent.current_state
        elif when == "init":
            return self._agent.init_state
        else:
            msg = "Unrecognized parameter '{}', 'current' or 'init' expected".format(when)
            raise ValueError(msg)

    @property
    def world(self):
        """A copy of the ``world`` attribute of the gridworld environment.

        ``GridWorld.world`` is a NetworkX Graph object which represents the
        areas, states and their connections in the gridworld environment. Each
        node in the graph is a state named by its global coordinate
        ``(area_idx, x, y)``, and it has an attribute ``altitude`` which
        represents the altitude of the state. Each edge in the graph denotes
        the connection between two states (including inter-area connections).

        .. note::

            More information about the NetworkX Graph object can be found at
            `networkx.Graph <https://networkx.org/documentation/stable/reference/classes/graph.html>`_

        Returns
        -------
        world : networkx.Graph
            World attribute of the gridworld environment represented by a
            NetworkX Graph object.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> G = W.world
        >>> G.nodes
        NodeView(((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1)))

        References
        ----------
        .. [#] NetworkX Documentation: https://networkx.org/
        """
        return self._world.copy()

    @property
    def time(self):
        """Gridworld environment time.

        The ``time`` attribute represents the number of steps that the agent
        has moved.

        Returns
        -------
        time : int
            Current time of gridworld environment.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 3))
        >>> W.add_path((0, 0, 0), (1, 0, 0))
        >>> W.init_agent()
        >>> W.step((1, 0))
        ((1, 0, 0), 0.0, False)
        >>> W.step((0, 0))
        ((1, 0, 0), 0.0, False)
        >>> W.time
        2
        """
        return self._time

    @property
    def num_area(self):
        """Number of areas in the ``world`` of gridworld environment.

        .. note::

            Origin is not included when counting the number.

        Returns
        -------
        num_area : int
            Number of areas in the world.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.add_area((3, 3))
        >>> W.num_area
        2
        """
        return self._num_area

    @property
    def actions(self):
        """Action space of the gridworld environment.

        Each action in the action space is represented with a tuple
        ``(dx, dy)``.

        Returns
        -------
        actions : tuple
            Action space of the gridworld environment.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.actions
        ((0, 0), (1, 0), (-1, 0), (0, 1), (0, -1))
        """
        return self._actions

    @property
    def has_reset_checkpoint(self):
        """Whether there is a reset checkpoint for the gridworld environment.

        Returns
        -------
        has_reset_checkpoint : bool
            Whether a reset checkpoint of the environment has been created.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.has_reset_checkpoint
        False
        >>> W.set_reset_checkpoint()
        >>> W.has_reset_checkpoint
        True
        """
        return self._has_reset_checkpoint

    def step(self, action):
        """Make the agent move toward the direction given by ``action``.

        .. note::

            - If a movement would take the agent out of the world (and is not
              an inter-area path), or into a blocked state, the agent is
              forced to stay in the same state instead.
            - If the agent reaches a state with an object, no matter whether
              the agent gets a reward or a punishment from the object, the
              trial ends and the agent is transported back to its initial
              state. The returned ``next_state`` is still the object state.

        Parameters
        ----------
        action : tuple of ints {(0, 0), (1, 0), (-1, 0), (0, 1), (0, -1)}
            Direction of the agent movement.

        Returns
        -------
        next_state : tuple of ints
            Next state of the agent after movement.
        reward : int or float
            Reward that the agent gets through this movement.
        done : bool
            Whether this trial ends.

        Raises
        ------
        ValueError
            If ``action`` is not a legal action.
        RuntimeError
            If no agent has been initialized.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 3))
        >>> W.add_path((0, 0, 0), (1, 0, 0))
        >>> W.init_agent()
        >>> W.step((1, 0))
        ((1, 0, 0), 0.0, False)
        """
        if action not in self._actions:
            msg = "Illegal action {}, should be one of {}".format(action, self._actions)
            raise ValueError(msg)
        if self._agent is None:
            raise RuntimeError("Agent not initialized, use 'init_agent()' first")

        dx, dy = action
        done = False
        reward = 0
        current_state = self._agent.current_state
        next_state = (current_state[0], current_state[1] + dx, current_state[2] + dy)
        if not self._world.has_node(next_state):
            # Off-grid: follow an inter-area path if one is registered here,
            # otherwise stay put.
            next_state = self._path_alias.get(next_state, current_state)

        # O(1) per-node attribute lookups (previously the full attribute
        # dicts of the whole world were rebuilt on every step).
        nodes = self._world.nodes
        if nodes[next_state]['blocked']:
            next_state = current_state

        reward += nodes[current_state]['altitude'] - nodes[next_state]['altitude']

        for obj in self._objects:
            if obj.coord == next_state:
                reward += obj.get_reward()
                done = True
                break

        self._time += 1
        if done:
            self._agent.reset()
        else:
            self._agent.update(current_state=next_state)

        return next_state, reward, done

    def set_reset_checkpoint(self, overwrite=False):
        """Set environment checkpoint for reset.

        Deep copies of the attributes listed in the reset-state dict are
        stored.

        Parameters
        ----------
        overwrite : bool (default: False)
            Whether to overwrite an existing checkpoint.

        Raises
        ------
        NeuGymOverwriteError
            If a checkpoint exists and ``overwrite`` is False.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((3, 3))
        >>> W.add_object((1, 2, 1), reward=1, prob=0.7)
        >>> W.set_reset_checkpoint()
        >>> W.has_reset_checkpoint
        True
        """
        if not self._has_reset_checkpoint or overwrite:
            for key in self._reset_state.keys():
                self._reset_state[key] = copy.deepcopy(getattr(self, '_' + key))
            self._has_reset_checkpoint = True
        else:
            raise ng.NeuGymOverwriteError("Reset state already exists, "
                                          "set 'overwrite=True' to overwrite")

    def reset(self):
        """Reset the environment to the checkpoint state.

        The checkpoint itself is kept (a deep copy is restored), so ``reset``
        can be called repeatedly.

        Raises
        ------
        NeuGymCheckpointError
            If no checkpoint has been set.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((3, 3))
        >>> W.add_path((0, 0, 0), (1, 0, 0))
        >>> W.add_object((1, 2, 1), reward=1, prob=0.7)
        >>> W.set_reset_checkpoint()
        >>> W.init_agent()
        >>> W.step((1, 0))
        ((1, 0, 0), 0.0, False)
        >>> W.reset()
        """
        if not self._has_reset_checkpoint:
            raise ng.NeuGymCheckpointError(
                "Reset state not found, use 'set_reset_checkpoint()' "
                "to set the reset checkpoint first")

        for key, value in self._reset_state.items():
            setattr(self, '_' + key, copy.deepcopy(value))

    def __repr__(self):
        separator = "=" * 10
        lines = ["GridWorld:", separator,
                 "time: {}".format(self._time),
                 "areas: "]

        for i in range(self._num_area + 1):
            alias = next((key for key, value in self._area_alias.items() if value == i), "")
            lines.append("\t[{}][{}] Area(shape={})".format(i, alias, self.get_area_shape(i)))

        if len(self._path_alias) == 0:
            lines.append("inter-area connections: None")
        else:
            lines.append("inter-area connections:")
            # Iterate the graph directly; ``self.world`` would copy it.
            for u, v in self._world.edges():
                if u[0] == v[0]:
                    continue
                for a in self._actions:
                    dx, dy = a
                    alias = (u[0], u[1] + dx, u[2] + dy)
                    if self._path_alias.get(alias) == v:
                        lines.append("\t{} + {} -> {}".format(u, a, v))

        if len(self._objects) == 0:
            lines.append("objects: None")
        else:
            lines.append("objects:")
            for i, obj in enumerate(self._objects):
                lines.append("\t[{}] {}".format(i, str(obj)))

        lines.append("actions: {}".format(self._actions))
        lines.append("agent: {}".format(str(self._agent)))
        lines.append("has_reset_state: {}".format(self._has_reset_checkpoint))
        lines.append(separator)
        return "\n".join(lines)