# Source code for neugym.utils.function

import pickle
import networkx as nx
import numpy as np


# Public API of this module: the names exported by a star-import.
__all__ = [
    "save_env",
    "load_env",
    "show_area_connection",
    "show_area"
]


def save_env(env, filename, protocol=pickle.HIGHEST_PROTOCOL):
    """Save environment in Python pickle format.

    The object is written to ``filename`` as a plain pickle stream. No
    compression is applied, whatever the file extension is.

    Parameters
    ----------
    env : environment object
        NeuGym environment object.
    filename : str
        Filename to write. The file is opened in binary write mode, so an
        existing file with the same name is overwritten.
    protocol : integer
        Pickling protocol to use. Default value: ``pickle.HIGHEST_PROTOCOL``.

    Examples
    --------
    >>> W = GridWorld()
    >>> ng.save_env(W, "test.pkl")

    References
    ----------
    .. [#] https://docs.python.org/3/library/pickle.html
    """
    with open(filename, 'wb') as f:
        pickle.dump(env, f, protocol=protocol)
def load_env(filename):
    """Load environment in Python pickle format.

    The file is read as a plain pickle stream. No decompression is
    performed, whatever the file extension is.

    Parameters
    ----------
    filename : str
        Filename to read.

    Returns
    -------
    W : environment object
        The object stored in the file, as returned by ``pickle.load``.

    Warnings
    --------
    Unpickling can execute arbitrary code. Only load files that come from
    a trusted source.

    Examples
    --------
    >>> W = GridWorld()
    >>> ng.save_env(W, "test.pkl")
    >>> W = ng.load_env("test.pkl")

    References
    ----------
    .. [#] https://docs.python.org/3/library/pickle.html
    """
    # SECURITY: pickle.load runs code embedded in the stream; callers must
    # not pass paths to untrusted files.
    with open(filename, 'rb') as f:
        return pickle.load(f)
def show_area_connection(env, layout='spring'):
    """Show environment area connections.

    Builds a graph with one node per area index in
    ``range(env.num_area + 1)`` and one edge per pair of areas joined by
    an edge of ``env.world``, then draws it with matplotlib.

    Parameters
    ----------
    env : environment object
        NeuGym environment object.
    layout : str {"circular", "spring", "shell", "spectral"} (default: "spring")
        Layout with which to show the area connections.

    Raises
    ------
    ValueError
        If ``layout`` is not one of the supported layout names.

    Examples
    --------
    >>> W = GridWorld()
    >>> W.add_area((1, 1))
    >>> W.add_path((0, 0, 0), (1, 0, 0))
    >>> W.add_area((1, 1))
    >>> W.add_path((1, 0, 0), (2, 0, 0), register_action=(0, -1))
    >>> W.add_area((1, 1))
    >>> W.add_path((2, 0, 0), (3, 0, 0), register_action=(-1, 0))
    >>> W.add_path((3, 0, 0), (0, 0, 0))
    >>> ng.show_area_connection(W)
    """
    # Validate the argument first, so that a bad layout name does not leave
    # an orphaned matplotlib figure behind.
    if layout not in ('circular', 'spring', 'shell', 'spectral'):
        msg = "Invalid layout '{}', should be one of " \
              "['circular', 'spring', 'shell', 'spectral']".format(layout)
        raise ValueError(msg)

    import matplotlib.pyplot as plt

    layout_funcs = {
        'circular': nx.circular_layout,
        'spring': nx.spring_layout,
        'shell': nx.shell_layout,
        'spectral': nx.spectral_layout,
    }

    g = nx.Graph()
    labels = {}
    for area_idx in range(env.num_area + 1):
        # An area for which get_area_name raises RuntimeError is labelled
        # with its index only.
        try:
            alias = env.get_area_name(area_idx)
        except RuntimeError:
            alias = None

        g.add_node(area_idx)
        if alias is not None:
            labels[area_idx] = '{}\n({})'.format(area_idx, alias)
        else:
            labels[area_idx] = str(area_idx)

    # The first coordinate of a world node is its area index; only edges
    # that cross from one area to another are area connections.
    for start, end in env.world.edges():
        if start[0] != end[0]:
            g.add_edge(start[0], end[0])

    pos = layout_funcs[layout](g)

    fig, ax = plt.subplots(1, 1)
    nx.draw_networkx(g, pos=pos, labels=labels, ax=ax)
    ax.axis('off')
    plt.tight_layout()
    plt.show()
def _resolve_area_index(env, area):
    """Translate an area name or index into a validated area index."""
    if isinstance(area, str):
        area_idx = env.get_area_index(area)
    elif isinstance(area, (int, np.integer)) and not isinstance(area, bool):
        # bool is a subclass of int; it is rejected explicitly because it
        # has never been accepted as an area index.
        area_idx = int(area)
    else:
        msg = "int for area index or str for area name " \
              "expected, got '{}'".format(type(area))
        raise TypeError(msg)

    if area_idx > env.num_area or area_idx < 0:
        msg = "Area {} not found".format(area_idx)
        raise ValueError(msg)

    return area_idx


def show_area(env, area, show_altitude=False, figsize=None):
    """Show details for one area.

    Visualize altitude, objects, and blocks within one area.
    Grid color indicates state altitude. Blocked states will be
    marked with black cross. Objects are shown in red dots.
    Paths leaving the area are drawn as red dotted lines labelled
    with the state they lead to.

    Parameters
    ----------
    env : environment object
        NeuGym environment object.
    area : int or str
        Index or name of the area to show.
    show_altitude : bool (default: False)
        Whether to show state altitude value.
    figsize : tuple of ints (optional, default=None)
        Size of the figure.

    Raises
    ------
    TypeError
        If ``area`` is neither an integer index nor a string name.
    ValueError
        If the area index is negative or greater than ``env.num_area``.

    Examples
    --------
    >>> W = GridWorld()
    >>> W.add_area((3, 5), name='slope')
    >>> W.add_path((0, 0, 0), (1, 0, 0))
    >>> W.set_altitude(1, np.random.randn(3, 5))
    >>> W.block((1, 2, 4))
    >>> W.add_object((1, 2, 4), 0.5, 1)
    >>> W.add_object((1, 0, 3), 0.5, 1)
    >>> ng.show_area(W, 1, show_altitude=True)
    """
    # Validate before importing matplotlib so that argument errors surface
    # first and no figure is created for invalid input.
    area_idx = _resolve_area_index(env, area)

    import matplotlib.pyplot as plt

    # The color scale spans the altitudes of the whole world, not only of
    # this area, so colors are comparable between areas.
    altitudes = nx.get_node_attributes(env.world, "altitude").values()
    vmin = min(altitudes)
    vmax = max(altitudes)

    # Build the attribute map once instead of once per grid cell.
    blocked_states = nx.get_node_attributes(env.world, 'blocked')

    shape = env.get_area_shape(area_idx)

    fig, ax = plt.subplots(1, 1, figsize=figsize)

    title = "Area[{}]".format(area_idx)
    try:
        alias = env.get_area_name(area_idx)
    except RuntimeError:
        # No name available for this area: keep the index-only title.
        pass
    else:
        title += " ({})".format(alias)

    mat = env.get_area_altitude(area_idx)
    ax.matshow(mat, cmap='Blues', vmin=vmin, vmax=vmax)
    ax.set_title(title)

    text_align = dict(horizontalalignment='center',
                      verticalalignment='bottom')

    for x in range(shape[0]):
        for y in range(shape[1]):
            coord = (area_idx, x, y)
            if blocked_states[coord]:
                ax.scatter(y, x, s=500, color='k', marker='X')
            elif show_altitude:
                ax.annotate(
                    "{}\n{}".format(coord, np.around(mat[x, y], decimals=2)),
                    (y, x), **text_align)
            else:
                ax.annotate("{}".format(coord), (y, x), **text_align)

            # NOTE(review): the nesting of this loop was reconstructed from
            # a whitespace-collapsed source; it is assumed to run for every
            # cell, blocked or not - confirm against the upstream file.
            for dx, dy in env.actions:
                # A neighbor position registered in env._path_alias is the
                # entrance of a path; the mapped value is the state reached.
                neighbor = (area_idx, x + dx, y + dy)
                try:
                    next_state = env._path_alias[neighbor]
                except KeyError:
                    continue
                ax.plot((y, neighbor[2]), (x, neighbor[1]),
                        color='r', linewidth=3, linestyle=':')
                ax.annotate("{}".format(next_state),
                            (neighbor[2], neighbor[1]), **text_align)

    for obj in env._objects:
        obj_area, obj_x, obj_y = obj.coord
        if obj_area == area_idx:
            ax.scatter(obj_y, obj_x, s=500, color='r', alpha=0.5)

    ax.axis('off')
    plt.tight_layout()
    plt.show()