Source code for syllabus.curricula.sequential

import re
import warnings
from collections import deque
from typing import Any, Callable, List, Union

from syllabus.core import Curriculum
from syllabus.curricula import DomainRandomization, Constant
from syllabus.task_space import DiscreteTaskSpace, TaskSpace


class SequentialCurriculum(Curriculum):
    """Curriculum that iterates through a list of curricula based on stopping conditions.

    Stage ``i`` of ``curriculum_list`` is active until ``stopping_conditions[i]``
    evaluates to true, at which point the next stage becomes active. The final
    stage has no stopping condition and is used for the remainder of training.

    Stopping conditions are either zero-argument callables or strings of the
    form ``"<metric><comparator><value>"``. Supported metrics are ``steps``,
    ``total_steps``, ``episodes``, ``total_episodes``, ``tasks``,
    ``total_tasks`` and ``episode_return``; supported comparators are ``<``,
    ``>``, ``<=``, ``>=`` and ``=``. Simple conditions may be combined with
    ``&`` (all must hold) and ``|`` (any must hold); ``&`` binds tighter
    than ``|``.
    """

    # Maps a comparator token from a condition string to the comparison it performs.
    _COMPARATORS = {
        '<': lambda left, right: left < right,
        '>': lambda left, right: left > right,
        '<=': lambda left, right: left <= right,
        '>=': lambda left, right: left >= right,
        '=': lambda left, right: left == right,
    }

    def __init__(self, curriculum_list: List[Curriculum], stopping_conditions: List[Any], *curriculum_args, return_buffer_size: int = 1000, **curriculum_kwargs):
        """Build the sequence of stages and their stopping conditions.

        :param curriculum_list: Stages to run in order. Each item may be a
            Curriculum, a TaskSpace, a list of tasks, or a single task
            contained in this curriculum's task space.
        :param stopping_conditions: One condition per stage except the last;
            each is a zero-argument callable or a condition string.
        :param curriculum_args: Positional arguments forwarded to ``Curriculum``.
        :param return_buffer_size: Maximum number of recent episode returns
            kept for the ``episode_return`` metric.
        :param curriculum_kwargs: Keyword arguments forwarded to ``Curriculum``.
        :raises ValueError: If a curriculum item or stopping condition is invalid.
        """
        super().__init__(*curriculum_args, **curriculum_kwargs)
        assert len(curriculum_list) > 0, "Must provide at least one curriculum"
        assert len(stopping_conditions) == len(curriculum_list) - 1, \
            f"Stopping conditions must be one less than the number of curricula. Final curriculum is used for the remainder of training. Expected {len(curriculum_list) - 1}, got {len(stopping_conditions)}."
        if len(curriculum_list) == 1:
            warnings.warn(
                "Your sequential curriculum only containes one element. Consider using that element directly instead.",
                stacklevel=2)

        self.curriculum_list = self._parse_curriculum_list(curriculum_list)
        self.stopping_conditions = self._parse_stopping_conditions(stopping_conditions)
        self._curriculum_index = 0

        # Stopping metrics. The ``n_*`` counters and ``episode_returns`` are
        # reset on every stage change; the ``total_*`` counters never reset.
        self.n_steps = 0
        self.total_steps = 0
        self.n_episodes = 0
        self.total_episodes = 0
        self.n_tasks = 0
        self.total_tasks = 0
        self.episode_returns = deque(maxlen=return_buffer_size)

    def _parse_curriculum_list(self, curriculum_list: List[Curriculum]) -> List[Curriculum]:
        """Convert every item of ``curriculum_list`` into a Curriculum.

        Curriculum objects are kept as they are. TaskSpace objects and lists of
        tasks are wrapped in DomainRandomization, and a single task contained
        in this curriculum's task space is wrapped in Constant.

        :raises ValueError: If an item matches none of the accepted forms.
        """
        parsed_list = []
        for item in curriculum_list:
            if isinstance(item, Curriculum):
                parsed_list.append(item)
            elif isinstance(item, TaskSpace):
                parsed_list.append(DomainRandomization(item, task_names=self.task_names))
            elif isinstance(item, list):
                task_space = DiscreteTaskSpace(len(item), item)
                parsed_list.append(DomainRandomization(task_space, task_names=self.task_names))
            elif self.task_space.contains(item):
                parsed_list.append(Constant(item, self.task_space, task_names=self.task_names))
            else:
                raise ValueError(f"Invalid curriculum item: {item}")
        return parsed_list

    def _parse_stopping_conditions(self, stopping_conditions: List[Any]) -> List[Any]:
        """Convert every stopping condition into a zero-argument callable.

        Callables are kept as they are; strings are parsed with
        ``_parse_condition_string``.

        :raises ValueError: If an item is neither callable nor a string, or if
            a string cannot be parsed.
        """
        parsed_list = []
        for item in stopping_conditions:
            if callable(item):
                parsed_list.append(item)
            elif isinstance(item, str):
                parsed_list.append(self._parse_condition_string(item))
            else:
                raise ValueError(f"Invalid stopping condition: {item}")
        return parsed_list

    def _parse_condition_string(self, condition: str) -> Callable:
        """Parse a condition string into a zero-argument callable returning a bool.

        The whole string, including every sub-condition of a composite and the
        numeric threshold, is validated here so that a malformed condition
        fails when it is parsed rather than the first time it is checked.

        :raises ValueError: If the string is not a valid condition.
        """
        # Composite conditions: split on '|' first so that '&' binds tighter.
        # Sub-conditions are parsed once here, not on every evaluation.
        if '|' in condition:
            alternatives = [self._parse_condition_string(part) for part in condition.split('|')]
            return lambda: any(check() for check in alternatives)
        if '&' in condition:
            requirements = [self._parse_condition_string(part) for part in condition.split('&')]
            return lambda: all(check() for check in requirements)

        metric_fns = {
            "steps": self._get_steps,
            "total_steps": self._get_total_steps,
            "episodes": self._get_episodes,
            "total_episodes": self._get_total_episodes,
            "tasks": self._get_tasks,
            "total_tasks": self._get_total_tasks,
            "episode_return": self._get_episode_return,
        }
        try:
            # The capturing group keeps the comparator in the split result.
            # Unpacking raises ValueError unless there are exactly three clauses.
            metric, comparator, value = (
                clause.strip() for clause in re.split(r'(<=|>=|=|<|>)', condition)
            )
            if metric not in metric_fns:
                raise ValueError(f"Invalid metric name: {metric}")
            if comparator not in self._COMPARATORS:
                raise ValueError(f"Invalid comparator: {comparator}")
            threshold = float(value)
        except ValueError as e:
            raise ValueError(f"Invalid condition string: {condition}") from e

        metric_fn = metric_fns[metric]
        compare = self._COMPARATORS[comparator]
        return lambda: compare(metric_fn(), threshold)

    def _get_steps(self):
        """Return the number of steps recorded in the current stage."""
        return self.n_steps

    def _get_total_steps(self):
        """Return the number of steps recorded across all stages."""
        return self.total_steps

    def _get_episodes(self):
        """Return the number of episodes recorded in the current stage."""
        return self.n_episodes

    def _get_total_episodes(self):
        """Return the number of episodes recorded across all stages."""
        return self.total_episodes

    def _get_tasks(self):
        """Return the number of tasks sampled in the current stage."""
        return self.n_tasks

    def _get_total_tasks(self):
        """Return the number of tasks sampled across all stages."""
        return self.total_tasks

    def _get_episode_return(self):
        """Return the mean of the buffered episode returns, or 0 if the buffer is empty."""
        return sum(self.episode_returns) / len(self.episode_returns) if len(self.episode_returns) > 0 else 0

    @property
    def current_curriculum(self):
        """The curriculum of the currently active stage."""
        return self.curriculum_list[self._curriculum_index]

    @property
    def requires_step_updates(self):
        """True if any stage's curriculum requires step updates."""
        return any(curriculum.requires_step_updates for curriculum in self.curriculum_list)

    def sample(self, k: int = 1) -> Union[List, Any]:
        """Choose the next k tasks from the current stage's curriculum.

        Tasks are decoded with the stage's task space and re-encoded with this
        curriculum's task space before being returned.
        """
        curriculum = self.current_curriculum
        tasks = curriculum.sample(k)

        # Recode tasks into environment task space
        decoded_tasks = [curriculum.task_space.decode(task) for task in tasks]
        recoded_tasks = [self.task_space.encode(task) for task in decoded_tasks]

        self.n_tasks += k
        self.total_tasks += k

        # Check if we should move on to the next phase of the curriculum
        self.check_stopping_conditions()
        return recoded_tasks

    def update_on_episode(self, episode_return, length, task, progress, env_id=None):
        """Record a finished episode, forward it to the current stage, and check for a stage change."""
        self.n_episodes += 1
        self.total_episodes += 1
        self.n_steps += length
        self.total_steps += length
        self.episode_returns.append(episode_return)
        if self.stat_recorder is not None:
            self.stat_recorder.record(episode_return, length, task)

        # Update current curriculum
        self.current_curriculum.update_on_episode(episode_return, length, task, progress, env_id)
        self.check_stopping_conditions()

    def update_on_step(self, task, obs, rew, term, trunc, info, progress, env_id=None):
        """Forward a step update to the current stage if it requires step updates."""
        if self.current_curriculum.requires_step_updates:
            self.current_curriculum.update_on_step(task, obs, rew, term, trunc, info, progress, env_id)

    def update_on_step_batch(self, step_results, env_id=None):
        """Forward a batch of step updates to the current stage if it requires step updates."""
        if self.current_curriculum.requires_step_updates:
            self.current_curriculum.update_on_step_batch(step_results, env_id)

    def update_task_progress(self, task, progress, env_id=None):
        """Forward a task progress update to the current stage."""
        self.current_curriculum.update_task_progress(task, progress, env_id)

    def check_stopping_conditions(self):
        """Advance to the next stage if the current stage's stopping condition holds.

        At most one stage is advanced per call. On a stage change the per-stage
        counters and the episode return buffer are reset; the ``total_*``
        counters are left untouched.
        """
        if self._curriculum_index < len(self.stopping_conditions) and self.stopping_conditions[self._curriculum_index]():
            self._curriculum_index += 1
            self.n_episodes = 0
            self.n_steps = 0
            # Keep the configured buffer size instead of falling back to a
            # hard-coded one after the first stage change.
            self.episode_returns = deque(maxlen=self.episode_returns.maxlen)
            self.n_tasks = 0

    def _sample_distribution(self) -> List[float]:
        """Return the sampling distribution of the current stage's curriculum."""
        return self.current_curriculum._sample_distribution()

    def log_metrics(self, writer, logs, step=None, log_n_tasks=1):
        """Append stage-level metrics to ``logs`` and delegate writing to the current stage.

        Tasks that are in this curriculum's task space but not in the current
        stage's task space are logged with probability 0.
        """
        logs = [] if logs is None else logs
        logs.append(("curriculum/current_stage", self._curriculum_index))
        logs.append(("curriculum/steps", self.n_steps))
        logs.append(("curriculum/episodes", self.n_episodes))
        logs.append(("curriculum/episode_returns", self._get_episode_return()))

        # Set probability for tasks from other stages to 0
        current_tasks = set(self.current_curriculum.task_space.tasks)
        all_tasks = set(self.task_space.tasks)
        noncurrent_tasks = all_tasks - current_tasks
        for task in noncurrent_tasks:
            name = self.task_names(task, self.task_space.encode(task))
            logs.append((f"curriculum/{name}_prob", 0))

        # Current curriculum will pass data to the writer for us
        return self.current_curriculum.log_metrics(writer, logs, step=step, log_n_tasks=log_n_tasks)