Source code for syllabus.curricula.plr.task_sampler

# Code heavily based on the original Prioritized Level Replay implementation from https://github.com/facebookresearch/level-replay
# If you use this code, please cite the above codebase and original PLR paper: https://arxiv.org/abs/2010.03934
import gymnasium as gym
import numpy as np
import torch


class TaskSampler:
    """Task sampler for Prioritized Level Replay (PLR).

    A short illustrative usage sketch appears at the end of this file.

    Args:
        tasks (list): List of tasks to sample from
        num_steps (int): Number of steps per rollout used when computing score updates
        action_space (gym.spaces.Space): Action space of the environment
        num_actors (int): Number of actors/processes
        strategy (str): Strategy for sampling tasks. One of "value_l1", "gae", "policy_entropy", "least_confidence", "min_margin", "one_step_td_error".
        replay_schedule (str): Schedule for sampling replay levels. One of "fixed" or "proportionate".
        score_transform (str): Transform to apply to task scores. One of "constant", "max", "eps_greedy", "rank", "power", "softmax".
        temperature (float): Temperature for score transform. Increasing temperature makes the sampling distribution more uniform.
        eps (float): Epsilon for eps-greedy score transform.
        rho (float): Proportion of seen tasks before replay sampling is allowed.
        nu (float): Probability of sampling a replay level if using a fixed replay_schedule.
        alpha (float): Linear interpolation weight for score updates. 0.0 means only use old scores, 1.0 means only use new scores.
        staleness_coef (float): Linear interpolation weight for task staleness vs. task score. 0.0 means only use task score, 1.0 means only use staleness.
        staleness_transform (str): Transform to apply to task staleness. One of "constant", "max", "eps_greedy", "rank", "power", "softmax".
        staleness_temperature (float): Temperature for staleness transform. Increasing temperature makes the sampling distribution more uniform.
    """

    def __init__(
        self,
        tasks: list,
        num_steps: int,
        action_space: gym.spaces.Space = None,
        num_actors: int = 1,
        strategy: str = "value_l1",
        replay_schedule: str = "proportionate",
        score_transform: str = "rank",
        temperature: float = 0.1,
        eps: float = 0.05,
        rho: float = 1.0,
        nu: float = 0.5,
        alpha: float = 1.0,
        staleness_coef: float = 0.1,
        staleness_transform: str = "power",
        staleness_temperature: float = 1.0,
    ):
        self.action_space = action_space
        self.tasks = tasks
        self.num_tasks = len(self.tasks)
        self.num_steps = num_steps

        self.strategy = strategy
        self.replay_schedule = replay_schedule
        self.score_transform = score_transform
        self.temperature = temperature
        self.eps = eps
        self.rho = rho
        self.nu = nu
        self.alpha = float(alpha)
        self.staleness_coef = staleness_coef
        self.staleness_transform = staleness_transform
        self.staleness_temperature = staleness_temperature

        self.unseen_task_weights = np.array([1.0] * self.num_tasks)
        self.task_scores = np.array([0.0] * self.num_tasks, dtype=float)
        self.partial_task_scores = np.zeros((num_actors, self.num_tasks), dtype=float)
        self.partial_task_steps = np.zeros((num_actors, self.num_tasks), dtype=np.int64)
        self.task_staleness = np.array([0.0] * self.num_tasks, dtype=float)

        self.next_task_index = 0  # Only used for sequential strategy

        # Logging metrics
        self._last_score = 0.0

        if not self.requires_value_buffers and self.action_space is None:
            raise ValueError(
                'Must provide action space to PLR if using "policy_entropy", "least_confidence", or "min_margin" strategies'
            )

    def update_with_rollouts(self, rollouts, actor_id=None):
        if self.strategy == "random":
            return

        # Update with a RolloutStorage object
        if self.strategy == "policy_entropy":
            score_function = self._average_entropy
        elif self.strategy == "least_confidence":
            score_function = self._average_least_confidence
        elif self.strategy == "min_margin":
            score_function = self._average_min_margin
        elif self.strategy == "gae":
            score_function = self._average_gae
        elif self.strategy == "value_l1":
            score_function = self._average_value_l1
        elif self.strategy == "one_step_td_error":
            score_function = self._one_step_td_error
        else:
            raise ValueError(f"Unsupported strategy, {self.strategy}")

        self._update_with_rollouts(rollouts, score_function, actor_index=actor_id)
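
    # Illustrative numbers for the alpha-weighted score update below (values
    # chosen for this comment, not taken from the source): with alpha = 0.5,
    # an old task score of 0.2 and a new episode score of 0.8 give
    # (1 - 0.5) * 0.2 + 0.5 * 0.8 = 0.5. With the default alpha = 1.0, the
    # new score simply replaces the old one.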

    def update_task_score(self, actor_index, task_idx, score, num_steps):
        score = self._partial_update_task_score(actor_index, task_idx, score, num_steps, done=True)

        self.unseen_task_weights[task_idx] = 0.0  # No longer unseen

        old_score = self.task_scores[task_idx]
        self.task_scores[task_idx] = (1.0 - self.alpha) * old_score + self.alpha * score

    def _partial_update_task_score(self, actor_index, task_idx, score, num_steps, done=False):
        partial_score = self.partial_task_scores[actor_index][task_idx]
        partial_num_steps = self.partial_task_steps[actor_index][task_idx]

        running_num_steps = partial_num_steps + num_steps
        merged_score = partial_score + (score - partial_score) * num_steps / float(running_num_steps)

        if done:
            self.partial_task_scores[actor_index][task_idx] = 0.0  # zero partial score, partial num_steps
            self.partial_task_steps[actor_index][task_idx] = 0
        else:
            self.partial_task_scores[actor_index][task_idx] = merged_score
            self.partial_task_steps[actor_index][task_idx] = running_num_steps

        return merged_score

    def _average_entropy(self, **kwargs):
        episode_logits = kwargs["episode_logits"]
        num_actions = self.action_space.n
        max_entropy = -(1.0 / num_actions) * np.log(1.0 / num_actions) * num_actions
        return (-torch.exp(episode_logits) * episode_logits).sum(-1).mean().item() / max_entropy

    def _average_least_confidence(self, **kwargs):
        episode_logits = kwargs["episode_logits"]
        return (1 - torch.exp(episode_logits.max(-1, keepdim=True)[0])).mean().item()

    def _average_min_margin(self, **kwargs):
        episode_logits = kwargs["episode_logits"]
        top2_confidence = torch.exp(episode_logits.topk(2, dim=-1)[0])
        return 1 - (top2_confidence[:, 0] - top2_confidence[:, 1]).mean().item()

    def _average_gae(self, **kwargs):
        returns = kwargs["returns"]
        value_preds = kwargs["value_preds"]

        advantages = returns - value_preds
        return advantages.mean().item()

    def _average_value_l1(self, **kwargs):
        returns = kwargs["returns"]
        value_preds = kwargs["value_preds"]

        advantages = returns - value_preds
        return advantages.abs().mean().item()

    def _one_step_td_error(self, **kwargs):
        rewards = kwargs["rewards"]
        value_preds = kwargs["value_preds"]

        max_t = len(rewards)
        td_errors = (rewards[:-1] + value_preds[: max_t - 1] - value_preds[1:max_t]).abs()
        assert not torch.isnan(
            td_errors.abs().mean()
        ), f"Got invalid values for 'rewards' or 'value_preds'. Check that reward length: {len(rewards)}"
        return td_errors.abs().mean().item()

    @property
    def requires_value_buffers(self):
        return self.strategy in ["gae", "value_l1", "one_step_td_error"]

    def _update_with_scores(self, rollouts):
        tasks = rollouts.tasks
        scores = rollouts.scores
        done = ~(rollouts.masks > 0)
        num_actors = rollouts.tasks.shape[1]

        for actor_index in range(num_actors):
            done_steps = done[:, actor_index].nonzero()[:self.num_steps, 0]
            start_t = 0

            for t in done_steps:
                if not start_t < self.num_steps:
                    break

                if t == 0:  # if t is 0, then this done step caused a full update of previous last cycle
                    continue

                task_idx_t = tasks[start_t, actor_index].item()

                score = scores[start_t:t, actor_index].mean().item()
                num_steps = t - start_t
                self.update_task_score(actor_index, task_idx_t, score, num_steps)

                start_t = t.item()

            if start_t < self.num_steps:
                task_idx_t = tasks[start_t, actor_index].item()

                score = scores[start_t:, actor_index].mean().item()
                self._last_score = score
                num_steps = self.num_steps - start_t
                self._partial_update_task_score(actor_index, task_idx_t, score, num_steps)

    def _update_with_rollouts(self, rollouts, score_function, actor_index=None):
        tasks = rollouts.tasks
        if not self.requires_value_buffers:
            policy_logits = rollouts.action_log_dist
        done = ~(rollouts.masks > 0)
        num_actors = rollouts.tasks.shape[1]
        actors = [actor_index] if actor_index is not None else range(num_actors)

        for actor_index in actors:
            done_steps = done[:, actor_index].nonzero()[:self.num_steps, 0]
            start_t = 0

            for t in done_steps:
                if not start_t < self.num_steps:
                    break

                if t == 0:  # if t is 0, then this done step caused a full update of previous last cycle
                    continue

                # If there is only 1 step, we can't calculate the one-step td error
                if self.strategy == "one_step_td_error" and t - start_t <= 1:
                    continue

                task_idx_t = tasks[start_t, actor_index].item()

                # Store kwargs for score function
                score_function_kwargs = {}
                if self.requires_value_buffers:
                    score_function_kwargs["returns"] = rollouts.returns[start_t:t, actor_index]
                    score_function_kwargs["rewards"] = rollouts.rewards[start_t:t, actor_index]
                    score_function_kwargs["value_preds"] = rollouts.value_preds[start_t:t, actor_index]
                else:
                    episode_logits = policy_logits[start_t:t, actor_index]
                    score_function_kwargs["episode_logits"] = torch.log_softmax(episode_logits, -1)

                score = score_function(**score_function_kwargs)
                num_steps = len(rollouts.tasks[start_t:t, actor_index])
                self.update_task_score(actor_index, task_idx_t, score, num_steps)

                start_t = t.item()

            if start_t < self.num_steps:
                # If there is only 1 step, we can't calculate the one-step td error
                if self.strategy == "one_step_td_error" and start_t == self.num_steps - 1:
                    continue

                task_idx_t = tasks[start_t, actor_index].item()

                # Store kwargs for score function
                score_function_kwargs = {}
                if self.requires_value_buffers:
                    score_function_kwargs["returns"] = rollouts.returns[start_t:, actor_index]
                    score_function_kwargs["rewards"] = rollouts.rewards[start_t:, actor_index]
                    score_function_kwargs["value_preds"] = rollouts.value_preds[start_t:, actor_index]
                else:
                    episode_logits = policy_logits[start_t:, actor_index]
                    score_function_kwargs["episode_logits"] = torch.log_softmax(episode_logits, -1)

                score = score_function(**score_function_kwargs)
                self._last_score = score
                num_steps = len(rollouts.tasks[start_t:, actor_index])
                self._partial_update_task_score(actor_index, task_idx_t, score, num_steps)

    def after_update(self, actor_indices=None):
        # Reset partial updates, since weights have changed, and thus logits are now stale
        actor_indices = range(self.partial_task_scores.shape[0]) if actor_indices is None else actor_indices
        for actor_index in actor_indices:
            for task_idx in range(self.partial_task_scores.shape[1]):
                if self.partial_task_scores[actor_index][task_idx] != 0:
                    self.update_task_score(actor_index, task_idx, 0, 0)
        self.partial_task_scores.fill(0)
        self.partial_task_steps.fill(0)

    def _update_staleness(self, selected_idx):
        if self.staleness_coef > 0:
            self.task_staleness = self.task_staleness + 1
            self.task_staleness[selected_idx] = 0

    def _sample_replay_level(self):
        sample_weights = self.sample_weights()

        if np.isclose(np.sum(sample_weights), 0):
            sample_weights = np.ones_like(sample_weights, dtype=float) / len(sample_weights)

        task_idx = np.random.choice(range(self.num_tasks), 1, p=sample_weights)[0]
        self._update_staleness(task_idx)

        return task_idx

    def _sample_unseen_level(self):
        sample_weights = self.unseen_task_weights / self.unseen_task_weights.sum()
        task_idx = np.random.choice(range(self.num_tasks), 1, p=sample_weights)[0]
        self._update_staleness(task_idx)

        return task_idx

    def sample(self, strategy=None):
        if not strategy:
            strategy = self.strategy

        if strategy == "random":
            return np.random.choice(range(self.num_tasks))

        if strategy == "sequential":
            task_idx = self.next_task_index
            self.next_task_index = (self.next_task_index + 1) % self.num_tasks
            return task_idx

        num_unseen = (self.unseen_task_weights > 0).sum()
        proportion_seen = (self.num_tasks - num_unseen) / self.num_tasks

        if self.replay_schedule == "fixed":
            if proportion_seen >= self.rho:
                # Sample replay level with fixed prob = 1 - nu OR if all levels seen
                if np.random.rand() > self.nu or not proportion_seen < 1.0:
                    return self._sample_replay_level()
            # Otherwise, sample a new level
            return self._sample_unseen_level()
        elif self.replay_schedule == "proportionate":
            if proportion_seen >= self.rho and np.random.rand() < proportion_seen:
                return self._sample_replay_level()
            else:
                return self._sample_unseen_level()
        else:
            raise NotImplementedError(
                f"Unsupported replay schedule: {self.replay_schedule}. Must be 'fixed' or 'proportionate'."
            )
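
    # When staleness_coef > 0, the replay weights computed below mix two
    # normalized distributions, the transformed task scores and the
    # transformed task staleness:
    #     weights = (1 - staleness_coef) * score_weights + staleness_coef * staleness_weights
    # with unseen tasks zeroed out of both terms before normalization.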

    def sample_weights(self):
        weights = self._score_transform(self.score_transform, self.temperature, self.task_scores)
        weights = weights * (1 - self.unseen_task_weights)  # zero out unseen levels

        z = np.sum(weights)
        if z > 0:
            weights /= z

        staleness_weights = 0
        if self.staleness_coef > 0:
            staleness_weights = self._score_transform(
                self.staleness_transform,
                self.staleness_temperature,
                self.task_staleness,
            )
            staleness_weights = staleness_weights * (1 - self.unseen_task_weights)
            z = np.sum(staleness_weights)
            if z > 0:
                staleness_weights /= z

            weights = (1 - self.staleness_coef) * weights + self.staleness_coef * staleness_weights

        return weights
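
    # Worked example for the "rank" transform below (illustrative numbers,
    # not from the library): with task_scores = [0.2, 0.9, 0.5] and
    # temperature = 0.1, the ranks are [3, 1, 2] (highest score gets rank 1),
    # so the unnormalized weights are 1 / rank**(1 / 0.1) =
    # [1/3**10, 1, 1/2**10] ~= [1.7e-5, 1.0, 9.8e-4]. After normalization in
    # sample_weights(), the top-scoring task receives nearly all of the
    # replay probability; larger temperatures flatten this distribution.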

    def _score_transform(self, transform, temperature, scores):
        if transform == "constant":
            weights = np.ones_like(scores)
        elif transform == "max":
            weights = np.zeros_like(scores)
            scores = np.array(scores)  # copy so masking unseen levels does not overwrite task_scores
            scores[self.unseen_task_weights > 0] = -float("inf")  # only argmax over seen levels
            argmax = np.random.choice(np.flatnonzero(np.isclose(scores, scores.max())))
            weights[argmax] = 1.0
        elif transform == "eps_greedy":
            weights = np.zeros_like(scores)
            weights[scores.argmax()] = 1.0 - self.eps
            weights += self.eps / self.num_tasks
        elif transform == "rank":
            temp = np.flip(scores.argsort())
            ranks = np.empty_like(temp)
            ranks[temp] = np.arange(len(temp)) + 1
            weights = 1 / ranks ** (1.0 / temperature)
        elif transform == "power":
            eps = 0 if self.staleness_coef > 0 else 1e-3
            weights = (np.array(scores) + eps) ** (1.0 / temperature)
        elif transform == "softmax":
            weights = np.exp(np.array(scores) / temperature)
        else:
            raise NotImplementedError(f"Unsupported score transform: {transform}")

        return weights

    def metrics(self):
        return {
            "task_scores": self.task_scores,
            "unseen_task_weights": self.unseen_task_weights,
            "task_staleness": self.task_staleness,
            "proportion_seen": (self.num_tasks - (self.unseen_task_weights > 0).sum()) / self.num_tasks,
            "score": self._last_score,
        }
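

# --- Illustrative usage sketch (not part of the library) ---
# A minimal example of driving the sampler directly. The task ids, rollout
# length, and episode score below are made up for demonstration; real training
# code would compute scores from rollouts via update_with_rollouts().
if __name__ == "__main__":
    sampler = TaskSampler(
        tasks=list(range(10)),               # 10 hypothetical task/level ids
        num_steps=256,                        # rollout length used for score updates
        action_space=gym.spaces.Discrete(5),  # optional for value-based strategies
        num_actors=1,
        strategy="value_l1",
        replay_schedule="proportionate",
        score_transform="rank",
        temperature=0.1,
        rho=0.1,                              # allow replay once 10% of tasks are seen (toy value)
        staleness_coef=0.0,                   # disable staleness mixing for this toy example
    )

    # Before any scores exist, sampling falls back to unseen levels.
    first_task = sampler.sample()

    # Pretend one rollout on that task produced a value-loss score of 0.5 over
    # 256 steps; this marks the task as seen and updates its running score.
    sampler.update_task_score(actor_index=0, task_idx=first_task, score=0.5, num_steps=256)

    print(sampler.sample())                       # may replay the seen task or pick a new one
    print(sampler.metrics()["proportion_seen"])   # 0.1 after one task has been seen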