import warnings
from typing import Any, List, Union
import gymnasium as gym
import numpy as np
import torch
from syllabus.core import Curriculum
from syllabus.core.evaluator import Evaluator
from syllabus.task_space import DiscreteTaskSpace, MultiDiscreteTaskSpace
from syllabus.utils import UsageError
from .task_sampler import TaskSampler


class RolloutStorage:
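    """Rollout buffer used by PLR to store per-environment observations, masks,
    rewards, and value predictions until enough steps accumulate to score tasks.
    """
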
def __init__(
self,
num_steps: int,
num_processes: int,
requires_value_buffers: bool,
observation_space: gym.Space, # TODO: Use np array when space is box or discrete
num_minibatches: int = 1,
buffer_size: int = 2,
action_space: gym.Space = None,
gamma: float = 0.999,
gae_lambda: float = 0.95,
lstm_size: int = None,
evaluator: Evaluator = None,
device: str = "cpu",
):
self.num_steps = num_steps
# Hack to prevent overflow from lagging updates.
self.buffer_steps = num_steps * buffer_size
self.num_processes = num_processes
self._requires_value_buffers = requires_value_buffers
self.num_minibatches = num_minibatches
self._gamma = gamma
self._gae_lambda = gae_lambda
self.evaluator = evaluator
self.device = device
if self.num_processes % self.num_minibatches != 0:
raise UsageError(
f"Number of processes {self.num_processes} must be divisible by the number of minibatches {self.num_minibatches}."
)
self.tasks = torch.zeros(self.buffer_steps, num_processes, 1, dtype=torch.int)
self.masks = torch.ones(self.buffer_steps + 1, num_processes, 1)
self.lstm_states = None
if lstm_size is not None:
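            # (hidden, cell) state buffers; the extra step slot holds the state for the next step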
self.lstm_states = (
torch.zeros(self.buffer_steps + 1, num_processes, lstm_size),
torch.zeros(self.buffer_steps + 1, num_processes, lstm_size),
)
self.obs = {env_idx: [None for _ in range(self.buffer_steps)] for env_idx in range(self.num_processes)}
self.env_steps = [0] * num_processes
self.value_steps = torch.zeros(num_processes, dtype=torch.int)
if requires_value_buffers:
self.returns = torch.zeros(self.buffer_steps + 1, num_processes, 1)
self.rewards = torch.zeros(self.buffer_steps, num_processes, 1)
self.value_preds = torch.zeros(self.buffer_steps + 1, num_processes, 1)
else:
if action_space is None:
raise ValueError(
"Action space must be provided to PLR for strategies 'policy_entropy', 'least_confidence', 'min_margin'"
)
self.action_log_dist = torch.zeros(self.buffer_steps, num_processes, action_space.n)
        # Lazily map external environment ids to contiguous buffer indices
        self.env_to_idx = {}
        self.max_idx = 0
self.to(self.device)

    @property
def using_lstm(self):
return self.lstm_states is not None

    def to(self, device):
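        """Move all buffer tensors to the given device."""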
self.device = device
self.masks = self.masks.to(device)
self.tasks = self.tasks.to(device)
if self.using_lstm:
self.lstm_states = (
self.lstm_states[0].to(device),
self.lstm_states[1].to(device),
)
if self._requires_value_buffers:
self.rewards = self.rewards.to(device)
self.value_preds = self.value_preds.to(device)
self.returns = self.returns.to(device)
else:
self.action_log_dist = self.action_log_dist.to(device)

    def get_index(self, env_index):
""" Map the environment ids to indices in the buffer. """
if env_index not in self.env_to_idx:
            assert self.max_idx < self.num_processes, f"Cannot add environment {env_index}: number of environments would exceed num_processes {self.num_processes}."
self.env_to_idx[env_index] = self.max_idx
self.max_idx += 1
return self.env_to_idx[env_index]

    def insert_at_index(self, env_index, mask, obs=None, reward=None, task=None, steps=1):
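        """Insert up to ``steps`` environment steps for ``env_index`` into the buffer.

        Returns the buffer index once enough steps and value predictions have
        accumulated to update the task sampler, and None otherwise.
        """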
assert steps < self.buffer_steps, f"Number of steps {steps} exceeds buffer size {self.buffer_steps}. Increase PLR's num_steps or decrease environment wrapper's batch size."
env_index = self.get_index(env_index)
step = self.env_steps[env_index]
end_step = step + steps
        assert end_step < self.buffer_steps, f"Inserting {steps} steps at step {step} exceeds buffer size {self.buffer_steps}. Increase PLR's num_steps or decrease environment wrapper's batch size."
self.masks[step + 1:end_step + 1, env_index].copy_(torch.as_tensor(mask[:, None]))
if obs is not None:
self.obs[env_index][step: end_step] = obs
if reward is not None:
self.rewards[step:end_step, env_index].copy_(torch.as_tensor(reward[:, None]))
# if action_log_dist is not None:
# self.action_log_dist[step:end_step, env_index].copy_(torch.as_tensor(action_log_dist[:, None]))
if task is not None:
            try:
                int(task[0])
            except TypeError:
                assert isinstance(
                    task, int), f"Provided task must be an integer, got {task} with type {type(task)} instead."
            self.tasks[step:end_step, env_index].copy_(torch.as_tensor(np.asarray(task).reshape(-1, 1)))
self.env_steps[env_index] += steps
# Get value predictions if batch is ready
value_steps = self.value_steps.numpy()
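        # Note: value_steps is a live numpy view of self.value_steps (a CPU tensor),
        # so increments made inside get_value_predictions are visible to this loop.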
while all((self.env_steps - value_steps) > 0):
self.get_value_predictions()
# Check if the buffer is ready to be updated. Wait until we have enough value predictions.
if self.value_steps[env_index] >= self.num_steps + 1:
if self._requires_value_buffers:
self.compute_returns(self._gamma, self._gae_lambda, env_index)
return env_index
return None

    def get_value_predictions(self):
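        """Query the evaluator for value predictions, one minibatch of environments at a time."""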
value_steps = self.value_steps.numpy()
process_chunks = np.split(np.arange(self.num_processes), self.num_minibatches)
for processes in process_chunks:
obs = [self.obs[env_idx][value_steps[env_idx]] for env_idx in processes]
lstm_states = dones = None
if self.using_lstm:
lstm_states = (
torch.unsqueeze(self.lstm_states[0][value_steps[processes], processes], 0),
torch.unsqueeze(self.lstm_states[1][value_steps[processes], processes], 0),
)
dones = torch.squeeze(1 - self.masks[value_steps[processes], processes], -1).int()
# Get value predictions and check for common usage errors
try:
values, lstm_states, _ = self.evaluator.get_value(obs, lstm_states, dones)
except RuntimeError as e:
raise UsageError(
"Encountered an error getting values for PLR. Check that lstm_size is set correctly and that there are no errors in the evaluator's get_value implementation."
) from e
self.value_preds[value_steps[processes], processes] = values.to(self.device)
            self.value_steps[processes] += 1  # Advance the index used to store lstm_states and for the next iteration
value_steps = self.value_steps.numpy()
if self.using_lstm:
assert lstm_states is not None, "Evaluator must return lstm_state in extras for PLR."
# Place new lstm_states in next step
self.lstm_states[0][value_steps[processes], processes] = lstm_states[0].to(self.lstm_states[0].device)
self.lstm_states[1][value_steps[processes], processes] = lstm_states[1].to(self.lstm_states[1].device)

    def after_update(self, env_index):
        """Drop the consumed first num_steps of data and shift the remaining buffer contents forward."""
self.tasks[:, env_index] = self.tasks[:, env_index].roll(-self.num_steps, 0)
self.masks[:, env_index] = self.masks[:, env_index].roll(-self.num_steps, 0)
self.obs[env_index] = self.obs[env_index][self.num_steps:]
if self.using_lstm:
self.lstm_states[0][:, env_index] = self.lstm_states[0][:, env_index].roll(-self.num_steps, 0)
self.lstm_states[1][:, env_index] = self.lstm_states[1][:, env_index].roll(-self.num_steps, 0)
if self._requires_value_buffers:
self.returns[:, env_index] = self.returns[:, env_index].roll(-self.num_steps, 0)
self.rewards[:, env_index] = self.rewards[:, env_index].roll(-self.num_steps, 0)
            self.value_preds[:, env_index] = self.value_preds[:, env_index].roll(-self.num_steps, 0)
else:
self.action_log_dist[:, env_index] = self.action_log_dist[:, env_index].roll(-self.num_steps, 0)
self.env_steps[env_index] -= self.num_steps
self.value_steps[env_index] -= self.num_steps

    def compute_returns(self, gamma, gae_lambda, env_index):
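        """Compute GAE return targets for one environment's chunk of steps.

        Implements the standard recurrence (Schulman et al., 2016):
            delta_t = r_t + gamma * V(s_{t+1}) * mask_{t+1} - V(s_t)
            A_t = delta_t + gamma * lambda * mask_{t+1} * A_{t+1}
        and stores R_t = A_t + V(s_t) in ``self.returns``.
        """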
        assert self._requires_value_buffers, "Selected strategy does not use compute_returns."
gae = 0
for step in reversed(range(self.num_steps)):
delta = (
self.rewards[step, env_index]
+ gamma * self.value_preds[step + 1, env_index] * self.masks[step + 1, env_index]
- self.value_preds[step, env_index]
)
gae = delta + gamma * gae_lambda * self.masks[step + 1, env_index] * gae
self.returns[step, env_index] = gae + self.value_preds[step, env_index]


class PrioritizedLevelReplay(Curriculum):
""" Prioritized Level Replay (PLR) Curriculum.

    Args:
        task_space (TaskSpace): The task space to use for the curriculum.
        observation_space (gym.Space): The observation space of the environment, used to store rollout observations.
        *curriculum_args: Positional arguments to pass to the curriculum.
        task_sampler_kwargs_dict (dict): Keyword arguments to pass to the task sampler. See TaskSampler for details.
        action_space (gym.Space): The action space to use for the curriculum. Required for some strategies.
        lstm_size (int): The hidden state size of the agent's LSTM, if it uses one.
        device (str): The device to use to store curriculum data, either "cpu" or "cuda".
        num_steps (int): The number of steps to store in the rollouts.
        num_processes (int): The number of parallel environments.
        num_minibatches (int): The number of minibatches to split environments into when computing value predictions.
        buffer_size (int): Multiplier on num_steps that gives the rollout buffer headroom for lagging updates.
        gamma (float): The discount factor used to compute returns.
        gae_lambda (float): The GAE lambda value.
        suppress_usage_warnings (bool): Whether to suppress warnings about improper usage.
        evaluator (Evaluator): Evaluator used to compute value predictions for prioritization.
        **curriculum_kwargs: Keyword arguments to pass to the curriculum.
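
    Example (illustrative sketch; ``task_space``, ``observation_space``, and
    ``evaluator`` are assumed to be constructed elsewhere):

        >>> curriculum = PrioritizedLevelReplay(
        ...     task_space,                  # e.g. a DiscreteTaskSpace
        ...     observation_space,
        ...     evaluator=evaluator,         # user-provided Evaluator
        ...     num_steps=256,
        ...     num_processes=64,
        ... )
        >>> tasks = curriculum.sample(k=4)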
"""
def __init__(
self,
task_space: Union[DiscreteTaskSpace, MultiDiscreteTaskSpace],
observation_space: gym.Space,
*curriculum_args,
task_sampler_kwargs_dict: dict = None,
action_space: gym.Space = None,
lstm_size: int = None,
device: str = "cpu",
num_steps: int = 256,
num_processes: int = 64,
num_minibatches: int = 1,
buffer_size: int = 4,
gamma: float = 0.999,
gae_lambda: float = 0.95,
suppress_usage_warnings=False,
evaluator: Evaluator = None,
**curriculum_kwargs,
):
        # Preprocess curriculum initialization args
if task_sampler_kwargs_dict is None:
task_sampler_kwargs_dict = {}
self._strategy = task_sampler_kwargs_dict.get("strategy", None)
if not isinstance(task_space, (DiscreteTaskSpace, MultiDiscreteTaskSpace)):
raise ValueError(
f"Task space must be discrete or multi-discrete, got {task_space}."
)
if "num_actors" in task_sampler_kwargs_dict and task_sampler_kwargs_dict['num_actors'] != num_processes:
warnings.warn(
f"Overwriting 'num_actors' {task_sampler_kwargs_dict['num_actors']} in task sampler kwargs with PLR num_processes {num_processes}.", stacklevel=2)
task_sampler_kwargs_dict["num_actors"] = num_processes
super().__init__(task_space, *curriculum_args, **curriculum_kwargs)
# Number of steps stored in rollouts and used to update task sampler
self._num_steps = num_steps
self._num_processes = num_processes # Number of parallel environments
        self._suppress_usage_warnings = suppress_usage_warnings
self.evaluator = evaluator
self._task2index = {task: i for i, task in enumerate(self.tasks)}
self._task_sampler = TaskSampler(self.tasks, self._num_steps,
action_space=action_space, **task_sampler_kwargs_dict)
self._rollouts = RolloutStorage(
self._num_steps,
self._num_processes,
self._task_sampler.requires_value_buffers,
observation_space,
num_minibatches=num_minibatches,
buffer_size=buffer_size,
action_space=action_space,
gamma=gamma,
gae_lambda=gae_lambda,
lstm_size=lstm_size,
evaluator=evaluator,
device=device,
)
self._rollouts.to(device)

    def requires_step_updates(self) -> bool:
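        """PLR scores tasks from rollout data, so it requires step-level updates."""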
return True

    def _sample_distribution(self) -> List[float]:
"""
Returns a sample distribution over the task space.
"""
return self._task_sampler.sample_weights()

    def sample(self, k: int = 1) -> Union[List, Any]:
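        """Sample ``k`` tasks from the replay distribution, using startup sampling until enough data is available."""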
if self._should_use_startup_sampling():
return self._startup_sample()
else:
return [self._task_sampler.sample() for _ in range(k)]

    def update_on_step(self, task, obs, rew, term, trunc, info, progress, env_id: int = None) -> None:
"""
Update the curriculum with the current step results from the environment.
"""
assert env_id is not None, "env_id must be provided for PLR updates."
# Update rollouts
update_id = self._rollouts.insert_at_index(
env_id,
mask=np.array([not (term or trunc)]),
reward=np.array([rew]),
obs=np.array([obs]),
)
# Update task sampler
if update_id is not None:
self._update_sampler(update_id)

    def update_on_step_batch(self, step_results, env_id=None) -> None:
"""
Update the curriculum with a batch of step results from the environment.
"""
assert env_id is not None, "env_id must be provided for PLR updates."
tasks, obs, rews, terms, truncs, _, _ = step_results
update_id = self._rollouts.insert_at_index(
env_id,
mask=np.logical_not(np.logical_or(terms, truncs)),
reward=rews,
obs=obs,
steps=len(rews),
task=tasks,
)
# Update task sampler
if update_id is not None:
self._update_sampler(update_id)

    def _update_sampler(self, env_id):
""" Update the task sampler with the current rollouts. """
self._task_sampler.update_with_rollouts(self._rollouts, env_id)
self._rollouts.after_update(env_id)
self._task_sampler.after_update()

    def log_metrics(self, writer, logs, step=None, log_n_tasks=1):
"""
Log the task distribution to the provided tensorboard writer.
"""
logs = [] if logs is None else logs
metrics = self._task_sampler.metrics()
logs.append(("curriculum/proportion_seen", metrics["proportion_seen"]))
logs.append(("curriculum/score", metrics["score"]))
return super().log_metrics(writer, logs, step=step, log_n_tasks=log_n_tasks)