import time
from collections import deque
import numpy as np
class VecEnv:
    """
    An abstract asynchronous, vectorized environment.

    Used to batch data from multiple copies of an environment, so that
    each observation becomes a batch of observations, and the expected
    action is a batch of actions to be applied per-environment.

    Subclasses implement ``reset``, ``step_async`` and ``step_wait`` (the
    base versions are no-ops that return None), and may override
    ``close_extras`` to release their own resources when ``close`` runs.
    """
    closed = False
    viewer = None
    metadata = {
        'render.modes': ['human', 'rgb_array']
    }

    def __init__(self, num_envs, observation_space, action_space):
        """
        :param num_envs: number of environment copies in the batch.
        :param observation_space: observation space, stored as-is.
        :param action_space: action space, stored as-is.
        """
        self.num_envs = num_envs
        self.observation_space = observation_space
        self.action_space = action_space

    def reset(self):
        """
        Reset all the environments and return an array of
        observations, or a dict of observation arrays.

        The wrappers in this module (VecNormalize, VecMonitor) unpack the
        result as ``(obs, infos)``, so implementations used with them are
        expected to return that pair.

        If step_async is still doing work, that work will
        be cancelled and step_wait() should not be called
        until step_async() is invoked again.

        The base implementation does nothing and returns None.
        """
        pass

    def step_async(self, actions):
        """
        Tell all the environments to start taking a step
        with the given actions.
        Call step_wait() to get the results of the step.

        You should not call this if a step_async run is
        already pending.

        The base implementation does nothing.
        """
        pass

    def step_wait(self):
        """
        Wait for the step taken with step_async().

        Returns (obs, rews, dones, infos):
         - obs: an array of observations, or a dict of
                arrays of observations.
         - rews: an array of rewards
         - dones: an array of "episode done" booleans
         - infos: a sequence of info objects

        The wrappers VecNormalize and VecMonitor in this module instead
        unpack a 5-tuple (obs, rews, terms, truncs, infos), with separate
        termination and truncation flags.

        The base implementation does nothing and returns None.
        """
        pass

    def close_extras(self):
        """
        Release resources other than the viewer.

        Called once by ``close``. The default does nothing; subclasses that
        own extra resources override it. Without this default, ``close``
        would raise AttributeError on any subclass that does not define it.
        """
        pass

    def close(self):
        """
        Close the viewer (if any) and any extra resources, at most once.

        Repeated calls after the first are no-ops.
        """
        if self.closed:
            return
        if self.viewer is not None:
            self.viewer.close()
        self.close_extras()
        self.closed = True

    def step(self, actions):
        """
        Step the environments synchronously.

        This is available for backwards compatibility.
        """
        self.step_async(actions)
        return self.step_wait()

    def step_env(self, actions, reset_random=False):
        """
        Step synchronously through ``step_env_async``, or through
        ``step_env_reset_random_async`` when ``reset_random`` is true, then
        return ``step_wait()``.

        NOTE(review): neither ``step_env_async`` nor
        ``step_env_reset_random_async`` is defined on this class; presumably
        concrete subclasses provide them. Confirm before relying on this.
        """
        if reset_random:
            self.step_env_reset_random_async(actions)
        else:
            self.step_env_async(actions)
        return self.step_wait()

    def render(self, mode='human'):
        """Render the environments; not implemented in the base class."""
        raise NotImplementedError

    def get_images(self):
        """
        Return RGB images from each environment.
        """
        raise NotImplementedError

    @property
    def unwrapped(self):
        """The innermost environment, following ``venv`` through any wrappers."""
        if isinstance(self, VecEnvWrapper):
            return self.venv.unwrapped
        else:
            return self

    def get_viewer(self):
        """Lazily create and return a gym ``SimpleImageViewer``."""
        if self.viewer is None:
            # Imported lazily so gym's rendering module is only needed when a
            # viewer is actually requested.
            from gym.envs.classic_control import rendering
            self.viewer = rendering.SimpleImageViewer()
        return self.viewer
class VecEnvWrapper(VecEnv):
    """
    An environment wrapper that applies to an entire batch
    of environments at once.

    By default every operation is forwarded to the wrapped ``venv``. Any
    public attribute not found on the wrapper itself is looked up on
    ``venv`` as well (see ``__getattr__``).
    """

    def __init__(self, venv, observation_space=None, action_space=None):
        """
        :param venv: the vectorized environment being wrapped.
        :param observation_space: overrides ``venv.observation_space`` when truthy.
        :param action_space: overrides ``venv.action_space`` when truthy.
        """
        self.venv = venv
        VecEnv.__init__(self, num_envs=venv.num_envs,
                        observation_space=observation_space or venv.observation_space,
                        action_space=action_space or venv.action_space)

    def reset(self, **kwargs):
        """
        Reset the wrapped environments, forwarding any keyword arguments.

        Defined explicitly because ``VecEnv.reset`` would otherwise be found
        through the MRO (returning None), so ``__getattr__`` forwarding
        would never apply to it.
        """
        return self.venv.reset(**kwargs)

    def step_async(self, actions):
        """Forward the actions to the wrapped environments."""
        self.venv.step_async(actions)

    def step_wait(self):
        """Return the wrapped environments' step results unchanged."""
        return self.venv.step_wait()

    def close(self):
        """Close the wrapped environment."""
        return self.venv.close()

    def render(self, mode='human'):
        """Render via the wrapped environment."""
        return self.venv.render(mode=mode)

    def get_images(self):
        """Return images from the wrapped environment."""
        return self.venv.get_images()

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails on the wrapper.
        if name.startswith('_'):
            raise AttributeError("attempted to get missing private attribute '{}'".format(name))
        if name == 'venv':
            # ``venv`` is missing only before __init__ has run (e.g. on an
            # instance built via __new__ during copy/unpickle). Forwarding
            # would call self.venv again and recurse without end.
            raise AttributeError(
                "'{}' object has no attribute 'venv'".format(type(self).__name__))
        return getattr(self.venv, name)
class VecEnvObservationWrapper(VecEnvWrapper):
    """
    Base class for wrappers that only transform observations via ``process``.

    Works with wrapped environments that use either the newer API (reset
    returns ``(obs, infos)``, step_wait returns a 5-tuple) or the older one
    (reset returns bare observations, step_wait returns a 4-tuple). The
    newer form is always returned.
    """

    def process(self, obs):
        """
        Transform a batch of observations.

        Subclasses are expected to override this; the base returns None.
        """
        pass

    def reset(self, seed=None, options=None):
        """
        Reset the wrapped environments and return ``(process(obs), infos)``.

        ``seed``/``options`` are forwarded only when at least one is given,
        so environments whose ``reset`` takes no arguments keep working.
        """
        if seed is not None or options is not None:
            outputs = self.venv.reset(seed=seed, options=options)
        else:
            outputs = self.venv.reset()
        # Only a 2-tuple is the (obs, infos) form. A bare observation array
        # for 2 envs, or an observation dict with 2 keys, also has len() == 2
        # and must not be unpacked.
        if isinstance(outputs, tuple) and len(outputs) == 2:
            obs, infos = outputs
        else:
            obs, infos = outputs, {}
        return self.process(obs), infos

    def step_wait(self):
        """
        Return ``(process(obs), rews, terms, truncs, infos)``.

        A 4-tuple result from the wrapped env is upgraded by adding an
        all-zero ``truncs`` array shaped like ``terms``.
        """
        env_outputs = self.venv.step_wait()
        if len(env_outputs) == 4:
            obs, rews, terms, infos = env_outputs
            truncs = np.zeros_like(terms)
        else:
            obs, rews, terms, truncs, infos = env_outputs
        return self.process(obs), rews, terms, truncs, infos
class VecNormalize(VecEnvWrapper):
    """
    A vectorized wrapper that normalizes the observations
    and returns from an environment.

    Observations are standardized with a running mean/variance and clipped
    to ``[-clipob, clipob]``. Rewards are divided by the running standard
    deviation of the per-env discounted return and clipped to
    ``[-cliprew, cliprew]``.
    """

    def __init__(self, venv, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8, use_tf=False):
        """
        :param venv: environment to wrap.
        :param ob: whether to normalize observations.
        :param ret: whether to scale rewards by the return statistics.
        :param clipob: absolute clip bound for normalized observations.
        :param cliprew: absolute clip bound for scaled rewards.
        :param gamma: discount used to accumulate the per-env return.
        :param epsilon: added to variances before the square root.
        :param use_tf: accepted for signature compatibility; unused here.
        """
        VecEnvWrapper.__init__(self, venv)
        self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
        self.ret_rms = RunningMeanStd(shape=()) if ret else None
        self.clipob = clipob
        self.cliprew = cliprew
        # Discounted return accumulated per environment.
        self.ret = np.zeros(self.num_envs)
        self.gamma = gamma
        self.epsilon = epsilon

    def step_wait(self):
        """Step the wrapped env and return normalized (obs, rews, terms, truncs, infos)."""
        obs, rews, terms, truncs, infos = self.venv.step_wait()
        news = np.logical_or(terms, truncs)
        self.ret = self.ret * self.gamma + rews
        obs = self._obfilt(obs)
        if self.ret_rms is not None:
            self.ret_rms.update(self.ret)
            rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
        # Restart the discounted return for envs whose episode just ended.
        self.ret[news] = 0.
        return obs, rews, terms, truncs, infos

    def _obfilt(self, obs):
        """Update observation statistics and return standardized, clipped obs."""
        if self.ob_rms is None:
            return obs
        self.ob_rms.update(obs)
        return np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob)

    def reset(self, seed=None, options=None):
        """
        Reset the return accumulator and the wrapped envs.

        Returns ``(normalized_obs, infos)``. ``seed``/``options`` are
        forwarded when either is given; previously ``options`` was silently
        dropped whenever ``seed`` was None.
        """
        self.ret = np.zeros(self.num_envs)
        if seed is not None or options is not None:
            obs, infos = self.venv.reset(seed=seed, options=options)
        else:
            obs, infos = self.venv.reset()
        return self._obfilt(obs), infos
class VecMonitor(VecEnvWrapper):
    """
    Track per-environment episode return and length.

    When an environment's episode ends (terminated or truncated), an
    ``'episode'`` dict ``{'r': return, 'l': length, 't': seconds since this
    wrapper was created}`` is added to that environment's info, plus any
    keys named in ``info_keywords`` copied from the info.
    """

    def __init__(self, venv, filename=None, keep_buf=0, info_keywords=()):
        """
        :param venv: environment to wrap.
        :param filename: accepted but not used by this class.
        :param keep_buf: if > 0, keep the last ``keep_buf`` episode returns
            and lengths in ``epret_buf`` / ``eplen_buf``.
        :param info_keywords: info keys to copy into each ``'episode'`` dict.
        """
        VecEnvWrapper.__init__(self, venv)
        self.eprets = None
        self.eplens = None
        self.epcount = 0
        self.tstart = time.time()
        # NOTE(review): never assigned a writer here; presumably set externally.
        self.results_writer = None
        self.info_keywords = info_keywords
        self.keep_buf = keep_buf
        if self.keep_buf:
            self.epret_buf = deque([], maxlen=keep_buf)
            self.eplen_buf = deque([], maxlen=keep_buf)

    def reset(self, seed=None, options=None):
        """
        Reset the wrapped envs and zero the episode counters.

        Returns ``(obs, infos)``. ``seed``/``options`` are forwarded when
        either is given; previously ``options`` was dropped whenever ``seed``
        was None.
        """
        if seed is not None or options is not None:
            obs, infos = self.venv.reset(seed=seed, options=options)
        else:
            obs, infos = self.venv.reset()
        self.eprets = np.zeros(self.num_envs, 'f')
        self.eplens = np.zeros(self.num_envs, 'i')
        return obs, infos

    def step_wait(self):
        """
        Step the wrapped envs and return ``(obs, rews, terms, truncs, infos)``.

        ``infos`` is returned as a list with one dict per env. Envs whose
        episode ended get an ``'episode'`` entry, and their counters are reset.
        """
        obs, rews, terms, truncs, infos = self.venv.step_wait()
        dones = np.logical_or(terms, truncs)
        self.eprets += rews
        self.eplens += 1
        if isinstance(infos, dict):
            if infos:
                # Convert dict of per-env sequences to a list of per-env dicts.
                infos = [dict(zip(infos, t)) for t in zip(*infos.values())]
            else:
                # An empty dict would zip to [], leaving nothing to index below.
                infos = [{} for _ in range(len(dones))]
        newinfos = list(infos)
        for i, done in enumerate(dones):
            if not done:
                continue
            info = infos[i].copy()
            ret = self.eprets[i]
            eplen = self.eplens[i]
            epinfo = {'r': ret, 'l': eplen, 't': round(time.time() - self.tstart, 6)}
            for k in self.info_keywords:
                epinfo[k] = info[k]
            info['episode'] = epinfo
            if self.keep_buf:
                self.epret_buf.append(ret)
                self.eplen_buf.append(eplen)
            self.epcount += 1
            self.eprets[i] = 0
            self.eplens[i] = 0
            if self.results_writer:
                self.results_writer.write_row(epinfo)
            newinfos[i] = info
        return obs, rews, terms, truncs, newinfos
class RunningMeanStd():
    """
    Running mean and variance over batches of samples along axis 0.

    Batches are merged into the accumulated statistics with the parallel
    variance combination formula.
    """
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

    def __init__(self, epsilon=1e-4, shape=()):
        """
        :param epsilon: initial (pseudo-)count. Being non-zero keeps the
            combined count positive, and the initial mean 0 / var 1 carry
            only this much weight.
        :param shape: shape of a single sample.
        """
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, x):
        """
        Fold a batch ``x`` (samples along axis 0) into the statistics.

        An empty batch is ignored. Its mean/variance would be NaN and would
        permanently corrupt the running statistics.
        """
        x = np.asarray(x)
        batch_count = x.shape[0]
        if batch_count == 0:
            return
        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Merge precomputed batch moments into the running statistics."""
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)
def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    """
    Combine two sets of (mean, population variance, count) into one.

    Implements the parallel variance algorithm: the combined sum of squared
    deviations is the sum of each part's own, plus a correction for the gap
    between the two means.

    :return: tuple ``(combined_mean, combined_var, combined_count)``.
    """
    combined_count = count + batch_count
    shift = batch_mean - mean
    combined_mean = mean + shift * batch_count / combined_count
    # Total squared deviation = each part's own + between-means correction.
    sq_dev_total = (var * count
                    + batch_var * batch_count
                    + np.square(shift) * count * batch_count / combined_count)
    return combined_mean, sq_dev_total / combined_count, combined_count