import warnings
from typing import Any, Callable, Dict, List, Tuple, TypeVar, Union
import numpy as np
from syllabus.task_space import TaskSpace
from .stat_recorder import StatRecorder
# Unconstrained type variable used to annotate the agent objects passed to
# and returned from the curriculum's agent-tracking methods.
Agent = TypeVar("Agent")
class Curriculum:
    """Base class and API for defining curricula to interface with Gym environments.

    Curricula that maintain a probability distribution over tasks implement
    :meth:`_sample_distribution`; the default :meth:`sample` and
    :meth:`log_metrics` implementations build on it.
    """

    def __init__(self, task_space: TaskSpace, random_start_tasks: int = 0, task_names: Callable = None, record_stats: bool = False) -> None:
        """Initialize the base Curriculum.

        :param task_space: the environment's task space from which new tasks are sampled
        :param random_start_tasks: Number of uniform random tasks to sample before using the algorithm's sample method, defaults to 0
        :param task_names: Callable ``(task, idx) -> name`` used to label tasks when logging, defaults to None
            (in which case the task index is used as the name)
        :param record_stats: Whether to record statistics for each task, defaults to False
        """
        assert isinstance(
            task_space, TaskSpace), f"task_space must be a TaskSpace object. Got {type(task_space)} instead."

        self.task_space = task_space
        self.random_start_tasks = random_start_tasks
        # Incremented by update_task_progress(); compared against
        # random_start_tasks to decide when the random-start phase is over.
        self.completed_tasks = 0
        self.task_names = task_names if task_names is not None else lambda task, idx: idx
        # The recorder receives the caller's original ``task_names`` value
        # (possibly None), not the index-based fallback stored above.
        self.stat_recorder = StatRecorder(self.task_space, task_names=task_names) if record_stats else None

        if self.num_tasks == 0:
            warnings.warn("Task space is empty. This will cause errors during sampling if no tasks are added.", stacklevel=2)

    @property
    def requires_step_updates(self) -> bool:
        """Returns whether the curriculum requires step updates from the environment.

        :return: True if the curriculum requires step updates, False otherwise
        """
        return False

    @property
    def num_tasks(self) -> int:
        """Counts the number of tasks in the task space.

        :return: The value of ``task_space.num_tasks``
        """
        # TODO: decide what to return (e.g. -1) when the task space is not countable.
        return self.task_space.num_tasks

    @property
    def tasks(self) -> List[tuple]:
        """List all of the tasks in the task space.

        :return: The value of ``task_space.tasks``
        """
        # TODO: decide what to return (e.g. an empty list) when the task space is not enumerable.
        return self.task_space.tasks

    def update_task_progress(self, task: Any, progress: Union[float, bool], env_id: int = None) -> None:
        """Update the curriculum with a task and its progress. This is used for binary tasks that can be completed mid-episode.

        The base implementation only counts the call; every call increments
        ``completed_tasks`` by one regardless of the ``progress`` value.

        :param task: Task for which progress is being updated.
        :param progress: Progress toward completion or success rate of the given task. 1.0 or True typically indicates a complete task.
        :param env_id: Environment identifier
        """
        self.completed_tasks += 1

    def update_on_step(self, task: Any, obs: Any, rew: float, term: bool, trunc: bool, info: dict, progress: Union[float, bool], env_id: int = None) -> None:
        """Update the curriculum with the current step results from the environment.

        :param task: Task the environment was running when the step was taken
        :param obs: Observation from the environment
        :param rew: Reward from the environment
        :param term: True if the episode ended on this step, False otherwise
        :param trunc: True if the episode was truncated on this step, False otherwise
        :param info: Extra information from the environment
        :param progress: Progress toward completion or success rate of the given task. 1.0 or True typically indicates a complete task.
        :param env_id: Environment identifier
        :raises NotImplementedError: Always, in the base class, which does not use step updates.
        """
        raise NotImplementedError(
            "This curriculum does not require step updates. Set update_on_step for the environment sync wrapper to False to improve performance and prevent this error.")

    def update_on_step_batch(self, step_results: Tuple[List[Any], List[Any], List[int], List[bool], List[bool], List[Dict], List[int]], env_id: int = None) -> None:
        """Update the curriculum with a batch of step results from the environment.

        This method can be overridden to provide a more efficient implementation. It is used
        as a convenience function and to optimize the multiprocessing message passing throughput.

        :param step_results: Seven parallel sequences, in order: tasks, observations, rewards,
            terminations, truncations, infos and progresses
        :param env_id: Environment identifier
        """
        tasks, obs, rews, terms, truncs, infos, progresses = step_results
        # Replay the batch one step at a time through the single-step hook.
        for step in zip(tasks, obs, rews, terms, truncs, infos, progresses):
            self.update_on_step(*step, env_id=env_id)

    def update_on_episode(self, episode_return: float, length: int, task: Any, progress: Union[float, bool], env_id: int = None) -> None:
        """Update the curriculum with episode results from the environment.

        The base implementation only forwards the episode to the stat recorder, if one is enabled.

        :param episode_return: Episodic return
        :param length: Length of the episode
        :param task: Task for which the episode was completed
        :param progress: Progress toward completion or success rate of the given task. 1.0 or True typically indicates a complete task.
        :param env_id: Environment identifier
        """
        if self.stat_recorder is not None:
            self.stat_recorder.record(episode_return, length, task, env_id)

    def get_agent(self, agent_id: int) -> Agent:
        """Load an agent from the buffer of saved agents.

        :param agent_id: Identifier of the agent to load
        :return: Loaded agent
        :raises NotImplementedError: Always, in the base class, which does not track agents.
        """
        raise NotImplementedError("This curriculum does not track agents.")

    def add_agent(self, agent: Agent):
        """Add an agent to the curriculum.

        :param agent: Agent to add to the curriculum
        :return agent_id: Identifier of the added agent
        :raises NotImplementedError: Always, in the base class, which does not track agents.
        """
        raise NotImplementedError("This curriculum does not track agents.")

    def _sample_distribution(self) -> List[float]:
        """Returns a sample distribution over the task space.

        Any curriculum that maintains a true probability distribution should implement this method to retrieve it.

        :raises NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def _should_use_startup_sampling(self) -> bool:
        """Return True while the curriculum is still in its random-start phase.

        The phase is active when ``random_start_tasks`` is positive and fewer
        than that many task-progress updates have been counted.
        """
        return self.random_start_tasks > 0 and self.completed_tasks < self.random_start_tasks

    def _startup_sample(self) -> List:
        """Draw a task directly from the task space, bypassing the curriculum's distribution."""
        return self.task_space.sample()

    def sample(self, k: int = 1) -> Union[List, Any]:
        """Sample k tasks from the curriculum.

        :param k: Number of tasks to sample, defaults to 1
        :return: During the random-start phase, the result of a single ``task_space.sample()``
            call. Otherwise an array of ``k`` task indices drawn according to
            :meth:`_sample_distribution`.
        """
        if self._should_use_startup_sampling():
            # NOTE(review): the startup path ignores ``k`` and returns one draw;
            # confirm whether callers expect ``k`` draws here.
            return self._startup_sample()

        # Sample indices rather than tasks because np.choice does not play nice with tuple tasks.
        # Passing the integer count is equivalent to passing list(range(num_tasks)).
        task_dist = self._sample_distribution()
        return np.random.choice(self.num_tasks, size=k, p=task_dist)

    def normalize(self, reward: float, task: Any) -> float:
        """Normalize reward by task.

        :param reward: Reward to normalize
        :param task: Task for which the reward was received
        :return: Normalized reward
        :raises AssertionError: If the curriculum was created without ``record_stats=True``.
        """
        assert self.stat_recorder is not None, "Curriculum must be initialized with record_stats=True to use normalize()"
        return self.stat_recorder.normalize(reward, task)

    def log_metrics(self, writer, logs: List[Tuple[str, Any]], step: int = None, log_n_tasks: int = 1) -> List[Tuple[str, Any]]:
        """Log the task distribution to the provided writer.

        Stat-recorder metrics (when enabled) and per-task sampling probabilities (when the
        curriculum implements :meth:`_sample_distribution`) are appended to ``logs``, and
        every entry in ``logs`` is then written to ``writer``. Failures while logging are
        reported as warnings instead of being raised.

        :param writer: Tensorboard summary writer or wandb object
        :param logs: Cumulative list of ``(name, value)`` entries to write; extended in place. May be None.
        :param step: Global step number
        :param log_n_tasks: Maximum number of tasks to log, defaults to 1. Use -1 to log all tasks.
        :return: Updated logs list
        """
        logs = [] if logs is None else logs
        if self.stat_recorder is not None:
            logs += self.stat_recorder.get_metrics(log_n_tasks=log_n_tasks)

        # wandb is optional; the writer is treated as wandb only if it is the module itself.
        try:
            import wandb
            use_wandb = writer == wandb
        except ImportError:
            use_wandb = False

        try:
            # A curriculum without a task distribution is still allowed to log
            # its other metrics, so a missing implementation is not an error.
            try:
                task_dist = self._sample_distribution()
            except NotImplementedError:
                task_dist = None

            if task_dist is not None:
                if log_n_tasks != -1 and len(self.tasks) > log_n_tasks:
                    warnings.warn(
                        f"Too many tasks to log {len(self.tasks)}. Only logging stats for {log_n_tasks} task(s).",
                        stacklevel=2)
                    task_dist = task_dist[:log_n_tasks]

                # Add basic logs
                for idx, prob in enumerate(task_dist):
                    name = self.task_names(self.tasks[idx], idx)
                    logs.append((f"curriculum/{name}_prob", prob))

            # Write logs
            for name, value in logs:
                if use_wandb:
                    writer.log({name: value, "global_step": step})
                else:
                    writer.add_scalar(name, value, step)
        except Exception as e:
            # No need to crash over logging :)
            warnings.warn(f"Failed to log curriculum stats to wandb. Ignoring error {e}", stacklevel=2)
        return logs