# Source code for syllabus.curricula.learning_progress

import math
import random
import warnings
from typing import Any, List, Union

import numpy as np
from scipy.stats import norm

from syllabus.core import Curriculum
from syllabus.task_space import DiscreteTaskSpace, MultiDiscreteTaskSpace


class LearningProgress(Curriculum):
    """
    Provides an interface for tracking success rates of discrete tasks and sampling tasks
    based on their success rate using the method from https://arxiv.org/abs/2106.14876.

    Each task keeps two exponential moving averages (EMAs) of its success rate: a fast one
    (``_p_fast``) and a slow one (``_p_slow``) that tracks the fast one. The absolute gap
    between the two is used as the task's learning progress, and tasks are sampled in
    proportion to a sigmoid of their z-scored learning progress.

    TODO: Support task spaces aside from Discrete
    """

    def __init__(self, eval_envs, evaluator, *args, ema_alpha=0.1, eval_interval=None, eval_interval_steps=None, **kwargs):
        """
        Create the curriculum and run one initial evaluation of every task.

        :param eval_envs: Environment(s) used for evaluation. Must provide ``reset(options=task)``
            returning ``(obs, info)`` and ``step(actions)`` returning
            ``(obs, rewards, terminateds, truncateds, infos)``.
        :param evaluator: Object providing ``get_action(obs)`` whose first return value is the action(s).
        :param ema_alpha: Smoothing factor used by both the fast and the slow EMA.
        :param eval_interval: If set, re-evaluate all tasks every ``eval_interval`` completed episodes.
        :param eval_interval_steps: If set, re-evaluate all tasks once more than this many steps
            have accumulated since the last step-triggered evaluation.
        :param args: Forwarded unchanged to the ``Curriculum`` constructor.
        :param kwargs: Forwarded unchanged to the ``Curriculum`` constructor.
        """
        super().__init__(*args, **kwargs)
        self.eval_envs = eval_envs
        self.evaluator = evaluator
        self.ema_alpha = ema_alpha
        self.eval_interval = eval_interval
        self.eval_interval_steps = eval_interval_steps
        assert self.eval_interval is None or self.eval_interval_steps is None, "Only one of eval_interval or eval_interval_steps can be set."
        self.completed_episodes = 0
        self.completed_steps = 0

        # NOTE(review): task_space and num_tasks are presumably provided by Curriculum -- confirm.
        assert isinstance(
            self.task_space, (DiscreteTaskSpace, MultiDiscreteTaskSpace)
        ), f"LearningProgressCurriculum only supports Discrete and MultiDiscrete task spaces. Got {self.task_space.__class__.__name__}."
        self._p_fast = np.zeros(self.num_tasks)
        self._p_slow = np.zeros(self.num_tasks)

        # Seed both EMAs with one evaluation pass over every task.
        self._evaluate_all_tasks()

    def _ema(self, new_value, old_value):
        """ Blend ``new_value`` into ``old_value`` with weight ``ema_alpha`` (scalars or arrays). """
        return (new_value * self.ema_alpha) + (old_value * (1.0 - self.ema_alpha))

    @staticmethod
    def _final_task_completion(infos, env_idx):
        """
        Read the ``task_completion`` value reported for the episode that just ended in env ``env_idx``.

        Supports the two layouts handled by this curriculum: a list with one info dict per
        environment, or a single dict whose ``"final_info"`` entry is indexed per environment.

        :raises TypeError: If ``infos`` is neither a list nor a dict.
        """
        if isinstance(infos, list):
            return infos[env_idx]["final_info"]['task_completion']
        if isinstance(infos, dict):
            return infos["final_info"][env_idx]['task_completion']
        # Previously this case left the value unbound (or silently reused the last episode's value).
        raise TypeError(f"Unsupported infos type: {type(infos).__name__}. Expected list or dict.")

    def _evaluate_all_tasks(self, eval_eps=1):
        """
        Run the evaluator on every task and fold the measured success rates into the EMAs.

        :param eval_eps: Minimum number of finished episodes to collect per task.
        :return: Array with the mean task completion measured for each task.
        """
        task_success_rates = np.zeros(self.task_space.num_tasks)
        for task_idx, task in enumerate(self.task_space.tasks):
            obss, _ = self.eval_envs.reset(options=task)
            finished_episodes = 0
            total_completion = 0.0
            while finished_episodes < eval_eps:
                actions, _, _ = self.evaluator.get_action(obss)
                obss, _rewards, terminateds, truncateds, infos = self.eval_envs.step(actions)
                for env_idx, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
                    if terminated or truncated:
                        total_completion += self._final_task_completion(infos, env_idx)
                        finished_episodes += 1
            # Several environments can finish on the same step, so more than eval_eps episodes
            # may have been counted; average over the episodes actually observed.
            if finished_episodes:
                task_success_rates[task_idx] = total_completion / finished_episodes

        # Update task scores with the averaged rates (not the raw sums), so the EMAs stay on
        # the same scale as the per-episode updates in update_task_progress.
        self._p_fast = self._ema(task_success_rates, self._p_fast)
        self._p_slow = self._ema(self._p_fast, self._p_slow)

        return task_success_rates

    def update_task_progress(self, task: int, progress: Union[float, bool], env_id: int = None):
        """
        Update the success rate for the given task using a fast and slow exponential moving average.

        Calls with ``task`` set to None, or with ``progress`` equal to 0.0 (which includes
        ``False``), are ignored and leave both averages untouched.
        """
        if task is None or progress == 0.0:
            return
        super().update_task_progress(task, progress)

        self._p_fast[task] = self._ema(progress, self._p_fast[task])
        # The slow average follows the freshly updated fast average, not the raw progress.
        self._p_slow[task] = self._ema(self._p_fast[task], self._p_slow[task])

    def update_on_episode(self, episode_return: float, length: int, task: Any, progress: Union[float, bool], env_id: int = None) -> None:
        """
        Count the finished episode and its steps, and re-evaluate all tasks when an interval elapses.

        Only ``length`` is used here; the remaining arguments are accepted but not read.
        """
        self.completed_episodes += 1
        self.completed_steps += length
        if self.eval_interval is not None and self.completed_episodes % self.eval_interval == 0:
            self._evaluate_all_tasks()
        if self.eval_interval_steps is not None and self.completed_steps > self.eval_interval_steps:
            self._evaluate_all_tasks()
            # The step counter restarts after a step-triggered evaluation; the episode counter never resets.
            self.completed_steps = 0

    def _learning_progress(self, reweight: bool = True) -> np.ndarray:
        """
        Compute the learning progress metric for every task.

        :param reweight: Apply the reweighting function to both averages before comparing them.
        :return: Array of absolute differences between the fast and slow averages, one entry per task.
        """
        slow = self._reweight(self._p_slow) if reweight else self._p_slow
        fast = self._reweight(self._p_fast) if reweight else self._p_fast
        return np.abs(fast - slow)

    def _reweight(self, p: np.ndarray, p_theta: float = 0.1) -> np.ndarray:
        """
        Reweight the given success rate using the reweighting function from the paper.

        Works elementwise on arrays and also accepts a plain scalar.
        """
        numerator = p * (1.0 - p_theta)
        denominator = p + p_theta * (1.0 - 2.0 * p)
        return numerator / denominator

    def _sigmoid(self, x: np.ndarray):
        """ Sigmoid function for reweighting the learning progress."""
        return 1 / (1 + np.exp(-x))

    def _sample_distribution(self) -> Union[List[float], np.ndarray]:
        """
        Return sampling distribution over the task space based on the learning progress.

        :return: An empty list when there are no tasks, otherwise an array with one probability per task.
        """
        if self.num_tasks == 0:
            return []

        task_lps = self._learning_progress()
        posidxs = np.flatnonzero(task_lps > 0)
        zeroout = len(posidxs) > 0

        # Only tasks with positive learning progress compete, unless no task has any.
        subprobs = task_lps[posidxs] if zeroout else task_lps
        std = np.std(subprobs)
        subprobs = (subprobs - np.mean(subprobs)) / (std if std else 1)  # z-score
        subprobs = self._sigmoid(subprobs)  # sigmoid
        subprobs = subprobs / np.sum(subprobs)  # normalize

        if not zeroout:
            # If all tasks have 0 progress, every z-score is 0, so this is the uniform distribution
            return subprobs

        # If some tasks have nonzero progress, zero out the rest
        task_dist = np.zeros(len(task_lps))
        task_dist[posidxs] = subprobs
        return task_dist
if __name__ == "__main__":
    # Demo: feed synthetic success histories into the curriculum, print the resulting
    # sampling distributions and, when matplotlib is available, plot the internals.

    def sample_binomial(p=0.5, n=200):
        """ Return the fraction of successes over n Bernoulli(p) draws. """
        return sum(random.random() < p for _ in range(n)) / n

    def generate_history(center=0, curve=1.0, n=100):
        """
        Simulate n noisy success-rate estimates whose true probability follows a sigmoid.

        :param center: Step at which the true probability crosses 0.5; a falsy value means n / 2.
        :param curve: Steepness of the sigmoid.
        :return: (noisy estimates, true probabilities), both of length n.
        """
        center = center if center else n / 2.0

        def sig(x, x_0=center, curve=curve):
            return 1.0 / (1.0 + math.e**(curve * (x_0 - x)))

        history = []
        probs = []
        success_prob = 0.0  # the first step always uses probability 0
        for i in range(n):
            probs.append(success_prob)
            history.append(sample_binomial(p=success_prob))
            success_prob = sig(i)

        return history, probs

    class _InstantEvalEnvs:
        """ Stand-in evaluation env: every episode ends on its first step with zero task completion. """

        def reset(self, options=None):
            return [None], {}

        def step(self, actions):
            infos = {"final_info": [{"task_completion": 0.0}]}
            return [None], [0.0], [True], [False], infos

    class _NoOpEvaluator:
        """ Stand-in evaluator returning a placeholder action. """

        def get_action(self, obss):
            return [None], None, None

    def make_curriculum(num_tasks):
        """ Build a LearningProgress curriculum whose initial evaluation leaves both EMAs at zero. """
        # LearningProgress requires eval_envs and an evaluator before the Curriculum arguments.
        # NOTE(review): assumes Curriculum.__init__ takes the task space as its first positional
        # argument, as the previous call here implied -- confirm.
        return LearningProgress(_InstantEvalEnvs(), _NoOpEvaluator(), DiscreteTaskSpace(num_tasks))

    # --- Sampling distribution over 20 tasks with random learning curves ---
    tasks = range(20)
    histories = {task: generate_history(center=random.randint(0, 100), curve=random.random()) for task in tasks}

    curriculum = make_curriculum(len(tasks))
    for i in range(len(histories[0][0])):
        for task in tasks:
            curriculum.update_task_progress(task, histories[task][0][i])
        if i > 10:
            distribution = curriculum._sample_distribution()
            print("[" + ", ".join(f"{prob:.3f}" for prob in distribution) + "]")

    # --- Trace the internals of a single task ---
    tasks = [0]
    histories = {task: generate_history(n=200, center=75, curve=0.1) for task in tasks}

    curriculum = make_curriculum(len(tasks))
    lp_raw = []
    lp_reweight = []
    p_fast = []
    p_slow = []
    true_probs = []
    estimates = []
    for estimate, true_prob in zip(histories[0][0], histories[0][1]):
        curriculum.update_task_progress(tasks[0], estimate)
        # _learning_progress takes no task argument; it returns one value per task, so index the result.
        lp_raw.append(curriculum._learning_progress(reweight=False)[tasks[0]])
        lp_reweight.append(curriculum._learning_progress()[tasks[0]])
        p_fast.append(curriculum._p_fast[0])
        p_slow.append(curriculum._p_slow[0])
        true_probs.append(true_prob)
        estimates.append(estimate)

    # Only the optional import is guarded, so unrelated errors in the plotting code are not masked.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        warnings.warn("Matplotlib not installed. Plotting will not work.")
    else:
        # TODO: Plot probabilities
        def plot_history(true_probs, estimates, p_slow, p_fast, lp_reweight, lp_raw):
            x_axis = range(0, len(true_probs))
            plt.plot(x_axis, true_probs, color="#222222", label="True Success Probability")
            plt.plot(x_axis, estimates, color="#888888", label="Estimated Success Probability")
            plt.plot(x_axis, p_slow, color="#ee3333", label="p_slow")
            plt.plot(x_axis, p_fast, color="#33ee33", label="p_fast")
            plt.plot(x_axis, lp_raw, color="#c4c25b", label="Learning Progress")
            plt.plot(x_axis, lp_reweight, color="#1544ee", label="Learning Progress Reweighted")
            plt.xlabel('Time step')
            plt.ylabel('Learning Progress')
            plt.legend()
            plt.show()

        plot_history(true_probs, estimates, p_slow, p_fast, lp_reweight, lp_raw)

        # Reweight Plot
        x_axis = np.linspace(0, 1, num=100)
        y_axis = [curriculum._reweight(x) for x in x_axis]
        plt.plot(x_axis, y_axis, color="blue", label="p_theta = 0.1")
        plt.xlabel('p')
        plt.ylabel('reweight')
        plt.legend()
        plt.show()

        # Z-score plot
        tasks = list(range(50))
        curriculum = make_curriculum(len(tasks))
        histories = {task: generate_history(n=200, center=60, curve=0.09) for task in tasks}
        for i in range(len(histories[0][0])):
            for task in tasks:
                curriculum.update_task_progress(task, histories[task][0][i])
        distribution = curriculum._sample_distribution()
        # Only the number of tasks is used from the distribution: it sets the plot resolution.
        x_axis = np.linspace(-3, 3, num=len(distribution))
        sigmoid_axis = curriculum._sigmoid(x_axis)
        plt.plot(x_axis, norm.pdf(x_axis, 0, 1), color="blue", label="Normal distribution")
        plt.plot(x_axis, sigmoid_axis, color="orange", label="Sampling weight")
        plt.xlabel('Z-scored distributed learning progress')
        plt.legend()
        plt.show()