Source code for syllabus.core.dual_curriculum_wrapper

from typing import List, Tuple, TypeVar

from syllabus.core import Agent, Curriculum, CurriculumWrapper
from syllabus.task_space import TupleTaskSpace

# Type variable used in annotations for the environment half of a task pair.
EnvTask = TypeVar("EnvTask")
# Type variable used in annotations for the agent half of a task pair.
AgentTask = TypeVar("AgentTask")


class DualCurriculumWrapper(CurriculumWrapper):
    """Curriculum wrapper containing both an agent and environment-based curriculum.

    Tasks handled by this wrapper are ``(env_task, agent_task)`` pairs: index 0
    is forwarded to the environment curriculum and index 1 to the agent
    curriculum. Attributes that are not found on the wrapper itself are looked
    up on the environment curriculum first and then on the agent curriculum.
    """

    def __init__(
        self,
        env_curriculum: Curriculum,
        agent_curriculum: Curriculum,
        batch_agent_tasks: bool = False,
        batch_size: int = 32,
        *args,
        **kwargs,
    ) -> None:
        """Store both curricula and build the combined task space.

        Args:
            env_curriculum: Curriculum that receives element 0 of every task.
            agent_curriculum: Curriculum that receives element 1 of every task.
            batch_agent_tasks: Stored on the instance as ``batch_agent_tasks``.
            batch_size: Number of times a sampled agent task is repeated when
                the pool of pending agent tasks is refilled in ``sample``.
            *args: Forwarded to ``CurriculumWrapper.__init__`` after the task space.
            **kwargs: Forwarded to ``CurriculumWrapper.__init__``.
        """
        # The curricula are assigned before calling the base initializer so
        # that attribute delegation in __getattr__ can already reach them.
        self.agent_curriculum = agent_curriculum
        self.env_curriculum = env_curriculum
        self.task_space = TupleTaskSpace(
            env_curriculum.task_space.gym_space,
            agent_curriculum.task_space.gym_space,
        )
        # NOTE(review): batch_agent_tasks is stored but no method in this class
        # reads it; sample() always uses the batched path. Confirm intent.
        self.batch_agent_tasks = batch_agent_tasks
        self.batch_size = batch_size
        self.batched_tasks = []
        self.agent_task = None
        super().__init__(self.task_space, *args, **kwargs)

    def sample(self, k=1) -> List[Tuple[EnvTask, AgentTask]]:
        """Sets new tasks for the environment and agent curricula.

        Args:
            k: Number of task pairs to return.

        Returns:
            A list pairing the environment curriculum's samples with agent
            tasks popped from the pending pool (``zip`` stops at the shorter
            of the two sequences).
        """
        env_task = self.env_curriculum.sample(k=k)
        if len(self.batched_tasks) < k:
            # Draw as many batches as are needed to serve all k requests. A
            # single batch (the common case, k <= batch_size) behaves exactly
            # as before; previously k > batch_size popped from an empty list.
            if self.batch_size > 0:
                rounds = max(1, -(-k // self.batch_size))  # ceil(k / batch_size)
            else:
                rounds = 1
            refilled = []
            for _ in range(rounds):
                # assumes agent_curriculum.sample returns a list - TODO confirm
                refilled.extend(self.agent_curriculum.sample(k=1) * self.batch_size)
            # Any leftover pending tasks are discarded in favour of the new pool.
            self.batched_tasks = refilled
        agent_task = [self.batched_tasks.pop() for _ in range(k)]
        return list(zip(env_task, agent_task))

    def get_agent(self, agent: AgentTask) -> Agent:
        """Return the result of the agent curriculum's ``get_opponent(agent)``."""
        return self.agent_curriculum.get_opponent(agent)

    def update_agent(self, agent: Agent) -> int:
        """Forward ``agent`` to the agent curriculum's ``update_agent`` and return its result."""
        return self.agent_curriculum.update_agent(agent)

    def update_on_episode(self, episode_return, length, task, progress, env_id=None):
        """Forward an episode result to both curricula.

        ``task`` is indexed as a pair: ``task[0]`` goes to the environment
        curriculum and ``task[1]`` to the agent curriculum. All other
        arguments are passed through unchanged.
        """
        self.env_curriculum.update_on_episode(
            episode_return, length, task[0], progress, env_id
        )
        self.agent_curriculum.update_on_episode(
            episode_return, length, task[1], progress, env_id
        )

    def update_on_step(self, task, obs, reward, term, trunc, info, env_id=None):
        """Forward a single step to each curriculum that requires step updates.

        A curriculum is only updated when its ``requires_step_updates``
        attribute is truthy. ``task[0]`` is sent to the environment curriculum
        and ``task[1]`` to the agent curriculum.
        """
        if self.env_curriculum.requires_step_updates:
            self.env_curriculum.update_on_step(
                task[0], obs, reward, term, trunc, info, env_id=env_id
            )
        if self.agent_curriculum.requires_step_updates:
            self.agent_curriculum.update_on_step(
                task[1], obs, reward, term, trunc, info, env_id=env_id
            )

    def update_on_step_batch(self, step_results, env_id=None):
        """Forward a batch of steps to each curriculum that requires step updates.

        ``step_results`` is unpacked into seven elements, the first being a
        sequence of task pairs. That sequence is split into environment tasks
        and agent tasks; the remaining six elements are passed on unchanged.
        """
        tasks, o, r, t, tr, i, p = step_results
        env_step_results = ([task[0] for task in tasks], o, r, t, tr, i, p)
        agent_step_results = ([task[1] for task in tasks], o, r, t, tr, i, p)
        if self.env_curriculum.requires_step_updates:
            self.env_curriculum.update_on_step_batch(env_step_results, env_id=env_id)
        if self.agent_curriculum.requires_step_updates:
            self.agent_curriculum.update_on_step_batch(agent_step_results, env_id=env_id)

    def update_task_progress(self, task, progress):
        """Forward ``progress`` to both curricula with their half of ``task``."""
        self.env_curriculum.update_task_progress(task[0], progress)
        self.agent_curriculum.update_task_progress(task[1], progress)

    def __getattr__(self, name):
        """Delegate attribute lookup to the curricula if not found.

        The environment curriculum is consulted first, then the agent
        curriculum.

        Raises:
            AttributeError: If neither curriculum provides ``name``, or if
                ``name`` is one of the curriculum attributes themselves and it
                has not been set on this instance.
        """
        # __getattr__ only runs when normal lookup fails. If the curriculum
        # attributes themselves are missing (e.g. copy/pickle probing an
        # instance created without __init__), reading them below would
        # re-enter this method forever, so fail fast instead.
        if name in ("env_curriculum", "agent_curriculum"):
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            )
        for delegate_name in ("env_curriculum", "agent_curriculum"):
            curriculum = getattr(self, delegate_name, None)
            if curriculum is None:
                continue
            try:
                return getattr(curriculum, name)
            except AttributeError:
                continue
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{name}'"
        )