"""Base class for gridworld environment."""
import copy
import warnings
import networkx as nx
import numpy as np
import neugym as ng
from ._agent import _Agent
from ._object import _Object
# Names exported by ``from <module> import *``; only the environment class is public.
__all__ = [
    "GridWorld"
]
class GridWorld:
    r"""Base class for gridworld environment.

    A ``GridWorld`` environment consists of a ``world``, ``objects``, and one
    ``agent``.

    The world consists of multiple connected rectangular areas and each area
    is represented by a two-dimensional grid graph, which has each node
    connected to its four nearest neighbors. Each node represents a state in
    the world, and has an attribute ``altitude``. When the agent moves from
    one state toward another state, it will get a reward generated by the
    altitude change (if there is any), i.e.

    .. math::

        R_{move} = A_s - A_{s + 1}

    where :math:`R_{move}` is the movement reward and :math:`A` represents the
    altitude of current state :math:`s` and next state :math:`s + 1`.

    At each position (state), the agent can choose from 5 ``actions`` to move
    towards **UP**, **DOWN**, **LEFT**, **RIGHT**, and **STAY** in the same
    state. When the performed movement would make the agent get out of the
    world, the agent is forced to stay in the same state.

    Objects where the agent can get reward are placed at different states
    and each state can only contain one object. Each object has its own
    adjustable probability (``prob``) of giving a ``reward`` when the agent
    reaches the state with this object; if the agent fails to get a reward,
    it will get a punishment (``punish``).

    .. math::

        R_{object} =
        \left \{
        \begin{aligned}
        & reward, P=p \\
        & punish, P=1-p
        \end{aligned}
        \right.

    Under this situation, the total reward for this step will be
    the movement reward plus the object reward.

    .. math::

        R_{total} = R_{move} + R_{object}

    So long as the agent gets to an object (no matter whether it was reward
    or punish that it got), this trial is finished and the agent is sent back
    to the start state of each trial.

    Parameters
    ----------
    origin_shape : tuple of ints (optional, default: None)
        Shape of the world origin. If not provided, the origin will be
        initialized to be only one state (0, 0, 0), otherwise it will
        be a rectangular area of shape ``origin_shape``.

    Examples
    --------
    Initialize a gridworld environment with only an origin state.

    >>> W = GridWorld()

    W can be grown in several aspects.

    **Areas:**

    Add one area of shape (2, 2).

    >>> W.add_area((2, 2))

    Remove areas.

    >>> W.remove_area(1)

    Set area altitude.

    >>> W.add_area((3, 3))
    >>> W.set_altitude(1, altitude_mat=np.random.randn(3, 3))

    **Paths:**

    Add inter-area paths.

    >>> W = GridWorld()
    >>> W.add_area((2, 2))
    >>> W.add_path((0, 0, 0), (1, 0, 0))
    >>> W.add_area((2, 2))
    >>> W.add_path((1, 1, 1), (2, 1, 0), register_action=(0, 1))

    Remove paths.

    >>> W.remove_path((1, 1, 1), (2, 1, 0))

    **Objects:**

    Add objects.

    >>> W = GridWorld()
    >>> W.add_area((2, 2))
    >>> W.add_object((0, 0, 0), reward=1, prob=0.7)
    >>> W.add_object((1, 0, 0), reward=1, prob=0.3, punish=-10)

    Remove objects.

    >>> W.remove_object((1, 0, 0))

    Update object attributes.

    >>> W.update_object((0, 0, 0), reward=10)
    >>> W.update_object((0, 0, 0), reward=1, prob=0.8)

    **Agent:**

    >>> W = GridWorld()
    >>> W.init_agent()

    One can also manually set the agent initial state.

    >>> W = GridWorld()
    >>> W.add_area((2, 2))
    >>> W.add_path((0, 0, 0), (1, 0, 0))
    >>> W.init_agent(init_coord=(1, 1, 1))

    When the agent is initialized, the agent can move in the
    world and get rewards.

    >>> next_state, reward, done = W.step(action=(0, 1))

    **Reset:**

    To reset the environment, first set a reset checkpoint.

    >>> W.set_reset_checkpoint()

    Then the environment can be reset if needed.

    >>> W.reset()
    """

    def __init__(self, origin_shape=None):
        """Initialize a gridworld environment.

        Parameters
        ----------
        origin_shape : tuple of ints (optional, default: None)
            Shape of the world origin. If not provided, the origin will be
            initialized to be only one state ``(0, 0, 0)``, otherwise it will
            be a rectangular area of shape ``origin_shape``.

        Examples
        --------
        Initialize a gridworld environment by default.

        >>> W = GridWorld()

        Manually set origin shape.

        >>> W = GridWorld((3, 4))
        """
        self._world = nx.Graph()
        self._time = 0
        # Number of areas excluding the origin (area index 0).
        self._num_area = 0
        # Maps alias name -> area index.
        self._area_alias = {}
        # Maps an off-grid "alias" coordinate (state + action offset) to the
        # real state the agent lands on when stepping across an inter-area path.
        self._path_alias = {}
        self._objects = []
        self._actions = ((0, 0), (1, 0), (-1, 0), (0, 1), (0, -1))

        # Add origin.
        if origin_shape is None:
            origin_shape = (1, 1)
            self._world.add_node((0, 0, 0))
        else:
            m, n = origin_shape
            origin = nx.grid_2d_graph(m, n)
            # Prefix every (x, y) node with the origin's area index 0.
            origin = nx.relabel_nodes(
                origin, {coord: (0, *coord) for coord in origin.nodes})
            self._world.update(origin)

        self.set_area_name(0, 'origin')
        self.set_altitude(0, np.zeros(origin_shape))
        nx.set_node_attributes(self._world, False, 'blocked')

        # Agent.
        self._agent = None

        # Reset state. Keys must match the private attribute names without
        # the leading underscore (see ``set_reset_checkpoint``).
        self._has_reset_checkpoint = False
        self._reset_state = {
            "world": None,
            "time": None,
            "num_area": None,
            "area_alias": None,
            "path_alias": None,
            "objects": None,
            "agent": None
        }

    @staticmethod
    def _offset(coord, dx, dy):
        """Return ``coord`` shifted by ``(dx, dy)`` inside its own area."""
        return coord[0], coord[1] + dx, coord[2] + dy

    def _resolve_area_index(self, area):
        """Turn an area index or alias name into a validated area index.

        Raises
        ------
        TypeError
            If ``area`` is neither an int nor a str.
        ValueError
            If the name is unknown or the index is out of range.
        """
        if isinstance(area, str):
            area_idx = self.get_area_index(area)
        elif isinstance(area, int) and not isinstance(area, bool):
            area_idx = area
        else:
            msg = "int for area index or str for area name " \
                  "expected, got '{}'".format(type(area))
            raise TypeError(msg)

        if area_idx > self._num_area or area_idx < 0:
            msg = "Area {} not found".format(area_idx)
            raise ValueError(msg)
        return area_idx

    def add_area(self, shape, name=None):
        """Add a new area to the world.

        .. note::
            Using this function people can add a new dangling area to the
            world, i.e. for now, it cannot be accessed from any of the other
            areas. To make this new area accessible, use
            ``GridWorld.add_path()`` function to build inter-area bridges.

        Parameters
        ----------
        shape : tuple of ints
            Shape of the new area.
        name : str (optional, default: None)
            Alias name of the area to be added.

        Raises
        ------
        RuntimeError
            If ``name`` is already used by another area.
        ValueError
            If ``shape`` has a non-positive dimension.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 3))
        >>> W.add_area((2, 2), name="Right")
        """
        if name in self._area_alias:
            msg = "Alias name already exists, try another name"
            raise RuntimeError(msg)

        m, n = shape
        # Reject empty areas before touching the world so that a failure
        # cannot leave ``_num_area`` out of sync with the graph.
        if m < 1 or n < 1:
            msg = "Area shape must be positive along both axes, " \
                  "got {}".format(shape)
            raise ValueError(msg)

        # Create new area.
        area_idx = self._num_area + 1
        new_area = nx.grid_2d_graph(m, n)
        new_area = nx.relabel_nodes(
            new_area, {coord: (area_idx, *coord) for coord in new_area.nodes})
        nx.set_node_attributes(new_area, False, 'blocked')

        self._world.update(new_area)
        self._num_area = area_idx
        if name is not None:
            self._area_alias[name] = area_idx

        # Add area altitude.
        self.set_altitude(area_idx, np.zeros(shape))

    def remove_area(self, area):
        """Remove an area from the world.

        Index for all other areas left will be automatically reset.
        (Minus one if their original index are larger than the index of
        removed area.)

        .. note::
            - All inter-area path and objects related to the to be removed
              area will be automatically removed.
            - Origin of the world is not allowed to be removed.

        Parameters
        ----------
        area : int or str
            Index or name of the area to be removed.

        Raises
        ------
        TypeError
            If ``area`` is neither an int nor a str.
        ValueError
            If the area does not exist.
        NeuGymPermissionError
            If ``area`` refers to the origin.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.remove_area(1)
        >>> W.add_area((2, 2), name="example")
        >>> W.remove_area("example")
        """
        # Validate before doing any work; an unknown index must not
        # decrement the area counter.
        area_idx = self._resolve_area_index(area)
        if area_idx == 0:
            raise ng.NeuGymPermissionError("Not allowed to remove origin area")

        def shifted(coord):
            # Areas after the removed one move down by one index.
            if coord[0] > area_idx:
                return (coord[0] - 1, *coord[1:])
            return coord

        # Remove area. Work on a copy so ``self._world`` is only replaced
        # once the new graph is complete.
        new_world = self._world.copy()
        new_world.remove_nodes_from(
            [node for node in self._world.nodes if node[0] == area_idx])
        mapping = {node: shifted(node)
                   for node in new_world.nodes if node[0] > area_idx}
        if mapping:
            # A single relabel with copy=True builds a fresh graph, so old
            # and new labels cannot collide.
            new_world = nx.relabel_nodes(new_world, mapping, copy=True)
        self._world = new_world
        self._num_area -= 1

        # Remove invalid area alias.
        self._area_alias = {
            name: (idx - 1 if idx > area_idx else idx)
            for name, idx in self._area_alias.items()
            if idx != area_idx
        }

        # Remove invalid path alias (either end inside the removed area).
        new_path_alias = {}
        for alias, target in self._path_alias.items():
            if alias[0] == area_idx or target[0] == area_idx:
                continue
            new_path_alias[shifted(alias)] = shifted(target)
        self._path_alias = new_path_alias

        # Remove objects in the area to be removed.
        new_objects = []
        for obj in self._objects:
            if obj.coord[0] == area_idx:
                continue
            if obj.coord[0] > area_idx:
                obj.coord = shifted(obj.coord)
            new_objects.append(obj)
        self._objects = new_objects

        # NOTE(review): the agent's stored coordinates are not re-indexed
        # here; if the agent lives in the removed area or in a later one its
        # state becomes stale. The _Agent API is not visible from this file,
        # so this is left unchanged -- confirm and handle in a follow-up.

    def add_path(self, coord_from, coord_to, register_action=None):
        """Add a new inter-area connection.

        .. note::
            - Creating a path within the same area is not allowed.
            - When an inter-area path from ``coord_from`` to ``coord_to`` with
              action ``register_action`` is built, a reverse path is also
              built at the same time, i.e. the agent can also move from
              ``coord_to`` to ``coord_from`` with the reverse action of
              ``register_action`` (e.g. the reverse action of **UP(1, 0)** is
              **DOWN(-1, 0)**).

        Parameters
        ----------
        coord_from : tuple of ints
            Coordinate of the path start state.
        coord_to : tuple of ints
            Coordinate of the path end state.
        register_action : tuple of ints (optional, default: None)
            Register an action to transport the agent from ``coord_from`` to
            ``coord_to``. If None, possible action to register will
            be searched in the order of
            [**UP(1, 0)**, **DOWN(-1, 0)**, **RIGHT(0, 1)**, **LEFT(0, -1)**],
            and the first possible action will be registered.

        Raises
        ------
        NeuGymPermissionError
            If both coordinates are in the same area.
        ValueError
            If a coordinate is malformed or not in the world, or if
            ``register_action`` is not a legal action.
        NeuGymConnectivityError
            If an endpoint already has 4 connections or no action is free.
        NeuGymOverwriteError
            If the path already exists.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.add_path((0, 0, 0), (1, 0, 0))
        >>> W.add_area((3, 3))
        >>> W.add_path((1, 1, 1), (2, 1, 0), register_action=(0, 1))
        """
        if coord_from[0] == coord_to[0]:
            msg = "Not allowed to add path within an area"
            raise ng.NeuGymPermissionError(msg)

        if len(coord_from) != 3:
            msg = "Tuple of length 3 expected for argument " \
                  "'coord_from', got {}".format(len(coord_from))
            raise ValueError(msg)
        if not self._world.has_node(coord_from):
            msg = "'coord_from' coordinate {} out of world".format(coord_from)
            raise ValueError(msg)
        if self._world.degree(coord_from) == 4:
            msg = "Maximum number of connections (4) for position " \
                  "{} reached, not allowed to access from it".format(coord_from)
            raise ng.NeuGymConnectivityError(msg)

        if len(coord_to) != 3:
            msg = "Tuple of length 3 expected for argument " \
                  "'coord_to', got {}".format(len(coord_to))
            raise ValueError(msg)
        if not self._world.has_node(coord_to):
            msg = "'coord_to' coordinate {} out of world".format(coord_to)
            raise ValueError(msg)
        if self._world.degree(coord_to) == 4:
            msg = "Maximum number of connections (4) for position " \
                  "{} reached, not allowed to access to it".format(coord_to)
            raise ng.NeuGymConnectivityError(msg)

        if (coord_from, coord_to) in self._world.edges:
            msg = "Path already exists between {} and {}".format(coord_from, coord_to)
            raise ng.NeuGymOverwriteError(msg)

        # Search for free actions that can be registered. An action is free
        # when neither the forward alias (from-state + action) nor the
        # backward alias (to-state - action) is a real state or an alias
        # that is already taken. STAY (0, 0) is never free because its alias
        # is the state itself.
        free_actions = []
        for dx, dy in self._actions:
            alias_to = self._offset(coord_from, dx, dy)
            alias_from = self._offset(coord_to, -dx, -dy)
            taken = (self._world.has_node(alias_to)
                     or self._world.has_node(alias_from)
                     or alias_to in self._path_alias
                     or alias_from in self._path_alias)
            if not taken:
                free_actions.append((dx, dy))

        if not free_actions:
            msg = "Unable to connect two areas from 'coord_from' {} to 'coord_to' {}, " \
                  "all allowed actions allocated".format(coord_from, coord_to)
            raise ng.NeuGymConnectivityError(msg)

        if register_action is not None:
            if register_action not in self._actions:
                msg = "Illegal 'register_action' {}, " \
                      "expected one of {}".format(register_action, self._actions)
                raise ValueError(msg)
            if register_action not in free_actions:
                msg = "Unable to register action 'register_action' {}, " \
                      "already allocated".format(register_action)
                raise ng.NeuGymConnectivityError(msg)
            dx, dy = register_action
        else:
            dx, dy = free_actions[0]

        # Register action in both directions, then connect the two states.
        self._path_alias[self._offset(coord_from, dx, dy)] = coord_to
        self._path_alias[self._offset(coord_to, -dx, -dy)] = coord_from
        self._world.add_edge(coord_from, coord_to)

    def remove_path(self, coord_from, coord_to):
        """Remove one inter-area connection from the world.

        If no inter-area path exists between the two states a
        ``RuntimeWarning`` is issued and nothing is changed.

        Parameters
        ----------
        coord_from : tuple of ints
            Coordinate of the path start state.
        coord_to : tuple of ints
            Coordinate of the path end state.

        Raises
        ------
        NeuGymPermissionError
            If both coordinates are in the same area.
        ValueError
            If a coordinate is not of length 3.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.add_path((0, 0, 0), (1, 0, 0), register_action=(0, 1))
        >>> W.remove_path((0, 0, 0), (1, 0, 0))
        """
        if coord_from[0] == coord_to[0]:
            msg = "Not allowed to remove path within an area, " \
                  "try using `GridWorld.block()` instead"
            raise ng.NeuGymPermissionError(msg)

        if len(coord_from) != 3 or len(coord_to) != 3:
            msg = "Tuple of length 3 expected for position coordinate"
            raise ValueError(msg)

        # Find the action under which the path was registered; both the
        # forward and the backward alias must point at the opposite end.
        for dx, dy in self._actions:
            alias_to = self._offset(coord_from, dx, dy)
            alias_from = self._offset(coord_to, -dx, -dy)
            if self._path_alias.get(alias_to) == coord_to and \
                    self._path_alias.get(alias_from) == coord_from:
                del self._path_alias[alias_to]
                del self._path_alias[alias_from]
                self._world.remove_edge(coord_from, coord_to)
                return

        msg = "Inter-area path not found between {} and {}, " \
              "nothing to do".format(coord_from, coord_to)
        warnings.warn(RuntimeWarning(msg), stacklevel=2)

    def add_object(self, coord, reward, prob, punish=0):
        """Add one object to the world.

        Each state can only have one object.

        Parameters
        ----------
        coord : tuple of ints
            Coordinate of the state to place the object.
        reward : int or float
            Reward that the object can generate.
        prob : float
            Probability for the object to generate a reward.
        punish : int or float (optional, default: 0)
            Punishment that the object will generate if failed
            to generate a reward.

        Raises
        ------
        ValueError
            If ``coord`` is not a state of the world.
        NeuGymOverwriteError
            If the state already holds an object.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.add_object((0, 0, 0), reward=1, prob=0.7)
        >>> W.add_object((1, 0, 0), reward=1, prob=0.3, punish=-10)
        """
        if coord not in self._world.nodes:
            msg = "Coordinate {} out of world".format(coord)
            raise ValueError(msg)

        # ``step``, ``remove_object`` and ``update_object`` only ever see the
        # first object at a coordinate, so a second one would be unreachable.
        if any(obj.coord == coord for obj in self._objects):
            msg = "Object already exists at {}, each state can only " \
                  "have one object".format(coord)
            raise ng.NeuGymOverwriteError(msg)

        self._objects.append(_Object(reward, punish, prob, coord))

    def remove_object(self, coord):
        """Remove one object from the world.

        Parameters
        ----------
        coord : tuple of ints
            Coordinate of the state whose object will be removed.

        Raises
        ------
        ValueError
            If no object is placed at ``coord``.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.add_object((0, 0, 0), reward=1, prob=0.7)
        >>> W.remove_object((0, 0, 0))
        """
        for idx, obj in enumerate(self._objects):
            if coord == obj.coord:
                self._objects.pop(idx)
                return

        msg = "No object found at {}".format(coord)
        raise ValueError(msg)

    def update_object(self, coord, **attr):
        """Reset object attributes.

        Except the object coordinate ``coord``, all other
        three attributes could be updated (``reward``,
        ``prob``, ``punish``).

        Parameters
        ----------
        coord : tuple of ints
            Coordinate of the state whose object attribute will be updated.
        attr : keyword arguments
            Attribute and new value to reset, e.g.
            ``reward=..., prob=..., punish=...``. Names the object does not
            have are ignored with a ``RuntimeWarning``.

        Raises
        ------
        ValueError
            If no object is placed at ``coord``.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.add_object((0, 0, 0), reward=1, prob=0.7)
        >>> W.update_object((0, 0, 0), reward=10, prob=0.8, punish=-1)
        """
        for obj in self._objects:
            if coord != obj.coord:
                continue
            for key, value in attr.items():
                if hasattr(obj, key):
                    setattr(obj, key, value)
                else:
                    msg = "'Object' object doesn't have attribute " \
                          "'{}', ignored".format(key)
                    warnings.warn(RuntimeWarning(msg))
            return

        msg = "No object found at {}".format(coord)
        raise ValueError(msg)

    def get_object_attribute(self, coord, attr):
        """Get the value of object attribute.

        Parameters
        ----------
        coord : tuple of ints
            Coordinate of the state whose object attribute will be looked for.
        attr : str {'reward', 'prob', 'punish'}
            Object attribute to look for.

        Returns
        -------
        attribute_value : int or float
            Value of object attribute ``attr``.

        Raises
        ------
        ValueError
            If no object is placed at ``coord`` or the object has no
            attribute ``attr``.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_object((0, 0, 0), reward=1, prob=0.7)
        >>> W.get_object_attribute((0, 0, 0), 'reward')
        1
        >>> W.get_object_attribute((0, 0, 0), 'prob')
        0.7
        """
        for obj in self._objects:
            if coord != obj.coord:
                continue
            if not hasattr(obj, attr):
                msg = "'Object' object doesn't have attribute '{}'".format(attr)
                raise ValueError(msg)
            return getattr(obj, attr)

        msg = "No object found at {}".format(coord)
        raise ValueError(msg)

    def block(self, coord):
        """Block one state.

        When the agent tries to enter a blocked state,
        it will be forced to stay in the current state,
        i.e. no movements.

        Parameters
        ----------
        coord : tuple of ints
            Coordinate of the state to block.

        Raises
        ------
        ValueError
            If ``coord`` is not a state of the world.
        RuntimeError
            If the agent is currently in that state.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((1, 1))
        >>> W.block((1, 0, 0))
        """
        if coord not in self._world.nodes:
            msg = "Coordinate {} out of world".format(coord)
            raise ValueError(msg)

        if self._agent is not None and coord == self.get_agent_state():
            msg = "Unable to block state '{}', where the agent is currently in".format(coord)
            raise RuntimeError(msg)

        self._world.nodes[coord]['blocked'] = True

    def unblock(self, coord):
        """Unblock one state.

        Parameters
        ----------
        coord : tuple of ints
            Coordinate of the state to unblock.

        Raises
        ------
        ValueError
            If ``coord`` is not a state of the world.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((1, 1))
        >>> W.block((1, 0, 0))
        >>> W.unblock((1, 0, 0))
        """
        if coord not in self._world.nodes:
            msg = "Coordinate {} out of world".format(coord)
            raise ValueError(msg)

        self._world.nodes[coord]['blocked'] = False

    def set_altitude(self, area, altitude_mat):
        """Set the altitude of each state for one area.

        Parameters
        ----------
        area : int or str
            Index or name of the area to set altitude.
        altitude_mat : numpy.ndarray or array-like
            A matrix of the same shape as the area.
            Each element in the matrix corresponds to the altitude of one
            state in the area.

        Raises
        ------
        TypeError
            If ``area`` is neither an int nor a str.
        ValueError
            If the area does not exist or the matrix shape does not match
            the area shape.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 3))
        >>> mat = np.random.randn(2, 3)
        >>> W.set_altitude(1, altitude_mat=mat)
        >>> W.add_area((2, 3), name="Right")
        >>> W.set_altitude("Right", altitude_mat=mat)
        """
        area_idx = self._resolve_area_index(area)
        area_shape = self.get_area_shape(area_idx)

        # Accept nested lists and other array-likes, not only ndarrays.
        altitude_mat = np.asarray(altitude_mat)
        if altitude_mat.shape != area_shape:
            msg = "Mismatch shape between Area({}) {} and " \
                  "altitude matrix {}".format(area_idx,
                                              area_shape,
                                              altitude_mat.shape)
            raise ValueError(msg)

        node_data = self._world.nodes
        for x in range(area_shape[0]):
            for y in range(area_shape[1]):
                node_data[(area_idx, x, y)]['altitude'] = altitude_mat[x, y]

    def set_area_name(self, area, name):
        """Set an alias name for an area.

        Can be used to reset the alias name of an area.

        Parameters
        ----------
        area : int or str
            Index or name of the area to set name.
        name : str
            Name of the area to be set.

        Raises
        ------
        RuntimeError
            If ``name`` is already used as an alias.
        ValueError
            If the area does not exist.
        TypeError
            If ``area`` is neither an int nor a str.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.set_area_name(1, "Up")
        """
        if name in self._area_alias:
            msg = "Alias name already exists, try another name"
            raise RuntimeError(msg)

        if isinstance(area, str):
            area_idx = self.get_area_index(area)
            self._area_alias[name] = area_idx
            self._area_alias.pop(area)
        elif isinstance(area, int) and not isinstance(area, bool):
            if area > self._num_area or area < 0:
                msg = "Area with index '{}' not found".format(area)
                raise ValueError(msg)
            # Drop the previous alias, if the area had one.
            try:
                old_name = self.get_area_name(area)
            except RuntimeError:
                pass
            else:
                self._area_alias.pop(old_name)
            self._area_alias[name] = area
        else:
            msg = "int for area index or str for area name " \
                  "expected to find the area, got '{}'".format(type(area))
            raise TypeError(msg)

    def get_area_name(self, area_idx):
        """Get the alias name of an area using area index.

        Parameters
        ----------
        area_idx : int
            Index of the area to get name.

        Returns
        -------
        name : str
            Name of the area with index ``area_idx``.

        Raises
        ------
        ValueError
            If ``area_idx`` is out of range.
        RuntimeError
            If the area exists but has no alias name.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2), name="Up")
        >>> W.get_area_name(1)
        'Up'
        """
        # Negative indices are rejected too, consistent with the other
        # area accessors.
        if area_idx > self._num_area or area_idx < 0:
            msg = "Area index '{}' out of range".format(area_idx)
            raise ValueError(msg)

        for name, idx in self._area_alias.items():
            if idx == area_idx:
                return name

        msg = "Area with index '{}' don't have an alias name".format(area_idx)
        raise RuntimeError(msg)

    def get_area_index(self, area_name):
        """Get the index of an area with its alias name.

        Parameters
        ----------
        area_name : str
            Name of the area to get index.

        Returns
        -------
        index : int
            Index of the area with alias name ``area_name``.

        Raises
        ------
        ValueError
            If no area carries the alias ``area_name``.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2), name="Up")
        >>> W.get_area_index("Up")
        1
        """
        try:
            return self._area_alias[area_name]
        except (KeyError, TypeError):
            # TypeError covers unhashable names, which can never be aliases.
            msg = "Area with name '{}' not found".format(area_name)
            raise ValueError(msg) from None

    def get_area_shape(self, area):
        """Get the shape of one area.

        Parameters
        ----------
        area : int or str
            Index or name of the area to get its shape.

        Returns
        -------
        shape : tuple of ints
            Shape of the area.

        Raises
        ------
        TypeError
            If ``area`` is neither an int nor a str.
        ValueError
            If the area does not exist.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((3, 10))
        >>> W.get_area_shape(1)
        (3, 10)
        >>> W.add_area((5, 10), name="Right")
        >>> W.get_area_shape("Right")
        (5, 10)
        """
        area_idx = self._resolve_area_index(area)

        # The shape is recovered from the largest x / y coordinate that
        # occurs among the area's nodes.
        max_x = 0
        max_y = 0
        for idx, x, y in self._world.nodes:
            if idx != area_idx:
                continue
            max_x = max(max_x, x)
            max_y = max(max_y, y)

        return max_x + 1, max_y + 1

    def get_area_altitude(self, area):
        """Get the altitude of each state in one area.

        Parameters
        ----------
        area : int or str
            Index or name of the area to get its state altitude.

        Returns
        -------
        altitude_matrix : numpy.ndarray
            Altitude matrix of the area.
            Each element in the matrix corresponds to the altitude of one
            state in the area.

        Raises
        ------
        TypeError
            If ``area`` is neither an int nor a str.
        ValueError
            If the area does not exist.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((3, 5))
        >>> W.add_area((3, 5), name="Right")
        >>> mat = np.ones((3, 5))
        >>> W.set_altitude(1, altitude_mat=mat)
        >>> W.get_area_altitude(1)
        array([[1., 1., 1., 1., 1.],
               [1., 1., 1., 1., 1.],
               [1., 1., 1., 1., 1.]])
        >>> W.get_area_altitude("Right")
        array([[0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0.]])
        """
        area_idx = self._resolve_area_index(area)
        altitude_mat = np.zeros(self.get_area_shape(area_idx))

        # Read each node's data once instead of rebuilding the whole
        # attribute dictionary for every node.
        for (idx, x, y), data in self._world.nodes.items():
            if idx == area_idx:
                altitude_mat[x, y] = data['altitude']

        return altitude_mat

    def init_agent(self, init_coord=None, overwrite=False):
        """Initialize an agent in the world.

        Parameters
        ----------
        init_coord : tuple of ints (optional, default: None)
            Coordinate of the agent initial state. If not
            provided, the agent will be initialized at ``(0, 0, 0)``
            by default.
        overwrite : bool (default: False)
            Whether to overwrite the existing agent.

        Raises
        ------
        ValueError
            If ``init_coord`` is not a state of the world.
        RuntimeError
            If ``init_coord`` is a blocked state.
        NeuGymOverwriteError
            If an agent already exists and ``overwrite`` is False.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.init_agent()
        >>> W.add_area((2, 4))
        >>> W.init_agent((1, 1, 3), overwrite=True)
        """
        if init_coord is None:
            init_coord = (0, 0, 0)

        if not self._world.has_node(init_coord):
            msg = "Initial state coordinate {} out of world".format(init_coord)
            raise ValueError(msg)

        if self._world.nodes[init_coord]['blocked']:
            msg = "Unable to initialize an agent at a blocked state '{}'".format(init_coord)
            raise RuntimeError(msg)

        if self._agent is not None and not overwrite:
            raise ng.NeuGymOverwriteError("Agent already exists, "
                                          "set 'overwrite=True' to overwrite")
        self._agent = _Agent(init_coord)

    def get_agent_state(self, when="current"):
        """Get state of the agent.

        Parameters
        ----------
        when : str {"current", "init"} (default: "current")
            Choose to get the initial ("init") or current ("current") state
            of the agent.

        Returns
        -------
        agent_state : tuple of ints
            Coordinate of the requested agent state.

        Raises
        ------
        ValueError
            If ``when`` is neither "current" nor "init".

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.add_path((0, 0, 0), (1, 0, 0))
        >>> W.init_agent()
        >>> W.step((1, 0))
        ((1, 0, 0), 0.0, False)
        >>> W.get_agent_state()
        (1, 0, 0)
        >>> W.get_agent_state(when="init")
        (0, 0, 0)
        """
        if when == "current":
            return self._agent.current_state
        if when == "init":
            return self._agent.init_state

        msg = "Unrecognized parameter '{}', 'current' or 'init' expected".format(when)
        raise ValueError(msg)

    @property
    def world(self):
        """A copy of ``world`` attribute of the gridworld environment.

        ``GridWorld.world`` is a NetworkX Graph object which represents
        here the areas, states and their connections in the gridworld
        environment. Each node in the graph is a state named by its
        global coordinate ``(area_idx, x, y)``, and it has an attribute
        ``altitude`` which represents the altitude of the state.
        Each edge in the graph denotes the connections between two
        states (including inter-area connections).

        .. note::
            More information about NetworkX Graph object can be found at
            `networkx.Graph
            <https://networkx.org/documentation/stable/reference/classes/graph.html>`_

        Returns
        -------
        world : networkx.Graph
            World attribute of the gridworld environment represented by
            a NetworkX Graph object.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> G = W.world
        >>> G.nodes
        NodeView(((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1)))

        References
        ----------
        .. [#] NetworkX Documentation: https://networkx.org/
        """
        return self._world.copy()

    @property
    def time(self):
        """Gridworld environment time.

        ``time`` attribute of gridworld environment represents the
        number of steps that the agent has moved.

        Returns
        -------
        time : int
            Current time of gridworld environment.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 3))
        >>> W.add_path((0, 0, 0), (1, 0, 0))
        >>> W.init_agent()
        >>> W.step((1, 0))
        ((1, 0, 0), 0.0, False)
        >>> W.step((0, 0))
        ((1, 0, 0), 0.0, False)
        >>> W.time
        2
        """
        return self._time

    @property
    def num_area(self):
        """Number of areas in the ``world`` of gridworld environment.

        .. note::
            Origin is not included when counting the number.

        Returns
        -------
        num_area : int
            Number of areas in the world.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 2))
        >>> W.add_area((3, 3))
        >>> W.num_area
        2
        """
        return self._num_area

    @property
    def actions(self):
        """Action space of the gridworld environment.

        Each action in the action space is represented
        with a tuple ``(dx, dy)``.

        Returns
        -------
        actions : tuple
            Action space of the gridworld environment.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.actions
        ((0, 0), (1, 0), (-1, 0), (0, 1), (0, -1))
        """
        return self._actions

    @property
    def has_reset_checkpoint(self):
        """Whether there is a reset checkpoint for the gridworld environment.

        Returns
        -------
        has_reset_checkpoint : bool
            Whether a reset checkpoint of the environment has
            been created.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.has_reset_checkpoint
        False
        >>> W.set_reset_checkpoint()
        >>> W.has_reset_checkpoint
        True
        """
        return self._has_reset_checkpoint

    def step(self, action):
        """Make the agent move toward direction given by ``action``.

        .. note::
            - If one movement will cause the agent get out of the world,
              the agent will be forced to stay in the same position (state)
              instead.
            - If the agent reaches a state with an object, no matter whether
              the agent gets a reward or punishment from the object, this
              trial will end and the agent will be transported back to its
              initial state.

        Parameters
        ----------
        action : tuple of ints
            Direction of the agent movement, one of
            ``{(0, 0), (1, 0), (-1, 0), (0, 1), (0, -1)}``.

        Returns
        -------
        next_state : tuple of ints
            Next state of the agent after movement. When the trial ends this
            is the state of the object that was reached, not the initial
            state the agent is sent back to.
        reward : int or float
            Reward that the agent gets through this movement.
        done : bool
            Whether this trial ends.

        Raises
        ------
        ValueError
            If ``action`` is not in the action space.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((2, 3))
        >>> W.add_path((0, 0, 0), (1, 0, 0))
        >>> W.init_agent()
        >>> W.step((1, 0))
        ((1, 0, 0), 0.0, False)
        """
        if action not in self._actions:
            msg = "Illegal action {}, should be one of {}".format(action, self._actions)
            raise ValueError(msg)
        dx, dy = action

        current_state = self._agent.current_state
        next_state = self._offset(current_state, dx, dy)

        # Leaving the grid either follows a registered inter-area path or
        # leaves the agent where it is.
        if not self._world.has_node(next_state):
            next_state = self._path_alias.get(next_state, current_state)

        # Look up only the nodes involved rather than rebuilding attribute
        # dictionaries for the whole graph on every step.
        node_data = self._world.nodes
        if node_data[next_state]['blocked']:
            next_state = current_state

        reward = 0
        reward += node_data[current_state]['altitude'] - node_data[next_state]['altitude']

        done = False
        for obj in self._objects:
            if obj.coord == next_state:
                reward += obj.get_reward()
                done = True
                break

        self._time += 1
        if done:
            self._agent.reset()
        else:
            self._agent.update(current_state=next_state)

        return next_state, reward, done

    def set_reset_checkpoint(self, overwrite=False):
        """Set environment checkpoint for reset.

        Parameters
        ----------
        overwrite : bool (default: False)
            Whether to overwrite existing checkpoint.

        Raises
        ------
        NeuGymOverwriteError
            If a checkpoint already exists and ``overwrite`` is False.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((3, 3))
        >>> W.add_object((1, 2, 1), reward=1, prob=0.7)
        >>> W.set_reset_checkpoint()
        >>> W.has_reset_checkpoint
        True
        """
        if self._has_reset_checkpoint and not overwrite:
            raise ng.NeuGymOverwriteError("Reset state already exists, "
                                          "set 'overwrite=True' to overwrite")

        # Each key names the private attribute ``_<key>`` to snapshot.
        for key in self._reset_state:
            self._reset_state[key] = copy.deepcopy(getattr(self, '_' + key))
        self._has_reset_checkpoint = True

    def reset(self):
        """Reset the environment to the checkpoint state.

        Raises
        ------
        NeuGymCheckpointError
            If no checkpoint has been set.

        Examples
        --------
        >>> W = GridWorld()
        >>> W.add_area((3, 3))
        >>> W.add_path((0, 0, 0), (1, 0, 0))
        >>> W.add_object((1, 2, 1), reward=1, prob=0.7)
        >>> W.set_reset_checkpoint()
        >>> W.init_agent()
        >>> W.step((1, 0))
        ((1, 0, 0), 0.0, False)
        >>> W.reset()
        """
        if not self._has_reset_checkpoint:
            raise ng.NeuGymCheckpointError(
                "Reset state not found, use 'set_reset_checkpoint()' "
                "to set the reset checkpoint first")

        # Restore deep copies so the checkpoint itself stays pristine and
        # can be reused for later resets.
        for key, value in self._reset_state.items():
            setattr(self, '_' + key, copy.deepcopy(value))

    def __repr__(self):
        """Return a multi-line, human-readable summary of the environment."""
        rule = "=" * 10
        lines = ["GridWorld:", rule, "time: {}".format(self.time), "areas: "]

        for idx in range(self._num_area + 1):
            alias = next((name for name, value in self._area_alias.items()
                          if value == idx), "")
            lines.append("\t[{}][{}] Area(shape={})".format(
                idx, alias, self.get_area_shape(idx)))

        if not self._path_alias:
            lines.append("inter-area connections: None")
        else:
            lines.append("inter-area connections:")
            # Iterate the internal graph directly; no copy is needed for
            # read-only access.
            for u, v in self._world.edges():
                if u[0] == v[0]:
                    continue
                for action in self._actions:
                    if self._path_alias.get(self._offset(u, *action)) == v:
                        lines.append("\t{} + {} -> {}".format(u, action, v))

        if not self._objects:
            lines.append("objects: None")
        else:
            lines.append("objects:")
            for idx, obj in enumerate(self._objects):
                lines.append("\t[{}] {}".format(idx, str(obj)))

        lines.append("actions: {}".format(self._actions))
        lines.append("agent: {}".format(str(self._agent)))
        lines.append("has_reset_state: {}".format(self._has_reset_checkpoint))
        lines.append(rule)
        return "\n".join(lines)