Source code for syllabus.core.multiagent_curriculum_wrappers

from syllabus.core import CurriculumWrapper


class MultiagentSharedCurriculumWrapper(CurriculumWrapper):
    """Share one wrapped curriculum across all agents of a multi-agent environment.

    Each agent's updates are forwarded to the wrapped curriculum under its own
    ``env_id`` slot, computed as ``env_id * num_agents + agent_index``, where
    ``agent_index`` is the agent's position in ``possible_agents``. When
    ``joint_policy`` is True, every agent of an environment reports under that
    environment's ``env_id`` unchanged, and receives the full observation dict
    instead of only its own entry.

    Args:
        curriculum: The curriculum to wrap; forwarded to ``CurriculumWrapper``.
        possible_agents: Ordered sequence of every agent id the environment can
            produce. Its order determines each agent's env_id slot.
        *args: Extra positional arguments forwarded to ``CurriculumWrapper``.
        joint_policy: Whether all agents are driven by a single joint policy.
        **kwargs: Extra keyword arguments forwarded to ``CurriculumWrapper``.
    """

    def __init__(self, curriculum, possible_agents, *args, joint_policy=False, **kwargs):
        super().__init__(curriculum, *args, **kwargs)
        self.possible_agents = possible_agents
        self.joint_policy = joint_policy
        self.num_agents = len(possible_agents)

    def _env_index(self, env_id, agent_index):
        """Return the env_id under which an agent's update is sent to the curriculum.

        Under a joint policy, or when ``env_id`` is None, ``env_id`` is returned
        unchanged. A per-agent slot cannot be derived without an environment id,
        and None is the wrapped curriculum's own default.
        """
        if self.joint_policy or env_id is None:
            return env_id
        return env_id * self.num_agents + agent_index

    @staticmethod
    def _per_agent(values, agent, position):
        """Select one agent's entry from a per-agent container.

        Dicts are looked up by agent id. Any other indexable is looked up by
        ``position``, the agent's position in the step's observation dict.
        """
        if isinstance(values, dict):
            return values[agent]
        return values[position]

    def update_task_progress(self, task, progress, env_id=None):
        """Report progress on ``task`` for every agent slot of environment ``env_id``.

        Under a joint policy all agents share a single slot, so the update is
        forwarded exactly once rather than once per agent.
        """
        if self.joint_policy:
            self.curriculum.update_task_progress(task, progress, env_id=env_id)
            return
        for agent_index, _agent in enumerate(self.possible_agents):
            self.curriculum.update_task_progress(
                task, progress, env_id=self._env_index(env_id, agent_index)
            )

    def update_on_step(self, task, obs, reward, term, trunc, info, progress, env_id: int = None) -> None:
        """Update the curriculum with the current step results from the environment.

        ``obs`` is iterated as a mapping from agent id to observation, and one
        update is forwarded per agent present in it. ``reward``, ``term`` and
        ``trunc`` may be dicts keyed by agent id, or sequences indexed by the
        agent's position in ``obs``. ``info`` is indexed by agent id.
        ``progress`` may be a per-agent dict or one value shared by all agents.

        Raises:
            ValueError: If an agent in ``obs`` is not in ``possible_agents``.
        """
        for position, agent in enumerate(obs):
            agent_index = self.possible_agents.index(agent)
            agent_obs = obs if self.joint_policy else obs[agent]
            agent_progress = progress[agent] if isinstance(progress, dict) else progress
            self.curriculum.update_on_step(
                task,
                agent_obs,
                self._per_agent(reward, agent, position),
                self._per_agent(term, agent, position),
                self._per_agent(trunc, agent, position),
                info[agent],
                agent_progress,
                env_id=self._env_index(env_id, agent_index),
            )

    def update_on_step_batch(self, step_results, env_id: int = None) -> None:
        """Apply ``update_on_step`` to each step in a batch from environment ``env_id``.

        ``step_results`` is a 7-tuple of parallel sequences:
        (tasks, obs, rewards, terms, truncs, infos, progresses).
        """
        tasks, obs, rews, terms, truncs, infos, progresses = step_results
        for t, o, r, te, tr, i, p in zip(tasks, obs, rews, terms, truncs, infos, progresses):
            self.update_on_step(t, o, r, te, tr, i, p, env_id=env_id)

    def update_on_episode(self, episode_return, length, task, progress, env_id=None) -> None:
        """Update the curriculum with episode results from the environment.

        ``episode_return`` maps agent id to that agent's return. One update is
        forwarded per agent; ``length``, ``task`` and ``progress`` are shared.

        Raises:
            ValueError: If an agent in ``episode_return`` is not in ``possible_agents``.
        """
        for agent, agent_return in episode_return.items():
            agent_index = self.possible_agents.index(agent)
            self.curriculum.update_on_episode(
                agent_return, length, task, progress, env_id=self._env_index(env_id, agent_index)
            )
class MultiagentIndependentCurriculumWrapper(CurriculumWrapper):
    """Multi-agent curriculum wrapper that records the environment's agents.

    Args:
        curriculum: The curriculum to wrap; forwarded to ``CurriculumWrapper``.
        possible_agents: Sequence of every agent id the environment can produce.
        *args: Extra positional arguments forwarded to ``CurriculumWrapper``.
        **kwargs: Extra keyword arguments forwarded to ``CurriculumWrapper``.
    """

    def __init__(self, curriculum, possible_agents, *args, **kwargs):
        super().__init__(curriculum, *args, **kwargs)
        # Keep the agent list and derive its size from the stored attribute.
        self.possible_agents = possible_agents
        self.num_agents = len(self.possible_agents)