import copy
import time
import torch
from typing import Any, Callable, Dict
import gymnasium as gym
import numpy as np
import ray
from gymnasium.utils.step_api_compatibility import step_api_compatibility
from pettingzoo.utils.wrappers.base_parallel import BaseParallelWrapper
from syllabus.core import Curriculum, MultiProcessingComponents
from syllabus.core.task_interface import PettingZooTaskWrapper, TaskEnv, TaskWrapper
from syllabus.task_space import TaskSpace
class GymnasiumSyncWrapper(gym.Wrapper):
"""
This wrapper sets the task on reset for Gymnasium environments running
in parallel processes created with multiprocessing.Process. Meant to be used
with a QueueLearningProgressCurriculum running on the main process.
"""
def __init__(self,
env,
task_space: TaskSpace,
components: MultiProcessingComponents,
batch_size: int = 100,
buffer_size: int = 2, # Having an extra task in the buffer minimizes wait time at reset
remove_keys: list = None,
change_task_on_completion: bool = False,
global_task_completion: Callable[[Curriculum, np.ndarray, float, bool, Dict[str, Any]], bool] = None):
# TODO: reimplement global task progress metrics
assert isinstance(
task_space, TaskSpace), f"task_space must be a TaskSpace object. Got {type(task_space)} instead."
super().__init__(env)
self.env = env
self.task_space = task_space
self.components = components
self._latest_task = None
self.batch_size = batch_size
self.remove_keys = remove_keys if remove_keys is not None else []
self.change_task_on_completion = change_task_on_completion
self.global_task_completion = global_task_completion
self.task_progress = 0.0
self._batch_step = 0
self.instance_id = components.get_id()
self.update_on_step = components.requires_step_updates and components.should_sync(self.instance_id)
self.episode_length = 0
self.episode_return = 0
# Create batch buffers for step updates
if self.update_on_step:
self._obs = [None] * self.batch_size
self._rews = np.zeros(self.batch_size, dtype=np.float32)
self._terms = np.zeros(self.batch_size, dtype=bool)
self._truncs = np.zeros(self.batch_size, dtype=bool)
self._infos = [None] * self.batch_size
self._tasks = [None] * self.batch_size
self._task_progresses = np.zeros(self.batch_size, dtype=np.float32)
# Request initial task
assert buffer_size > 0, "Buffer size must be greater than 0 to sample initial task for envs."
for _ in range(buffer_size):
update = {
"update_type": "noop",
"metrics": None,
"request_sample": True,
}
self.components.put_update(update)
def reset(self, *args, **kwargs):
self.task_progress = 0.0
self.episode_length = 0
self.episode_return = 0
message = self.components.get_task() # Blocks until a task is available
next_task = self.task_space.decode(message["next_task"])
self._latest_task = next_task
obs, info = self.env.reset(*args, new_task=next_task, **kwargs)
info["task"] = self.task_space.encode(self.get_task())
if self.update_on_step:
self._update_step(obs, 0.0, False, False, info, send=False)
return obs, info
def step(self, action):
obs, rew, term, trunc, info = step_api_compatibility(self.env.step(action), output_truncation_bool=True)
info["task"] = self.task_space.encode(self.get_task())
self.episode_length += 1
self.episode_return += rew
self.task_progress = info.get("task_completion", 0.0)
# Update curriculum with step info
if self.update_on_step:
self._update_step(obs, rew, term, trunc, info)
# Episode update
if term or trunc:
episode_update = {
"update_type": "episode",
"metrics": (self.episode_return, self.episode_length, self.task_space.encode(self.get_task()), self.task_progress),
"env_id": self.instance_id,
"request_sample": True
}
self.components.put_update([episode_update])
if self.change_task_on_completion and self.task_progress >= 1.0:
update = {
"update_type": "task_progress",
"metrics": (self.task_space.encode(self.get_task()), self.task_progress),
"env_id": self.instance_id,
"request_sample": True
}
self.components.put_update(update)
message = self.components.get_task() # Blocks until a task is available
next_task = self.task_space.decode(message["next_task"])
self.env.change_task(next_task)
self._latest_task = next_task
return obs, rew, term, trunc, info
def _update_step(self, obs, rew, term, trunc, info, send=True):
trimmed_obs = {key: obs[key]
for key in obs.keys() if key not in self.remove_keys} if isinstance(obs, dict) else obs
self._obs[self._batch_step] = trimmed_obs
self._rews[self._batch_step] = rew
self._terms[self._batch_step] = term
self._truncs[self._batch_step] = trunc
self._infos[self._batch_step] = info
self._tasks[self._batch_step] = self.task_space.encode(self.get_task())
self._task_progresses[self._batch_step] = self.task_progress
self._batch_step += 1
# Send batched updates
if send and (self._batch_step >= self.batch_size or term or trunc):
updates = self._package_step_updates()
self.components.put_update(updates)
self._batch_step = 0
def _package_step_updates(self):
return [{
"update_type": "step_batch",
"metrics": ([
self._tasks[:self._batch_step],
self._obs[:self._batch_step],
self._rews[:self._batch_step],
self._terms[:self._batch_step],
self._truncs[:self._batch_step],
self._infos[:self._batch_step],
self._task_progresses[:self._batch_step],
],),
"env_id": self.instance_id,
"request_sample": False
}]
def get_task(self):
# Report the environment's own task if it exposes one, since it may reject or modify the sampled task
if hasattr(self.env, "task"):
return self.env.task
return self._latest_task
def __getattr__(self, attr):
env_attr = getattr(self.env, attr, None)
if env_attr is not None:
return env_attr
class PettingZooSyncWrapper(BaseParallelWrapper):
"""
This wrapper sets the task on reset for PettingZoo parallel environments running
in parallel processes created with multiprocessing.Process. Meant to be used
with a QueueLearningProgressCurriculum running on the main process.
"""
def __init__(self,
env,
task_space: TaskSpace,
components: MultiProcessingComponents,
batch_size: int = 100,
buffer_size: int = 2, # Having an extra task in the buffer minimizes wait time at reset
remove_keys: list = None,
change_task_on_completion: bool = False,
global_task_completion: Callable[[Curriculum, np.ndarray, float, bool, Dict[str, Any]], bool] = None):
# TODO: reimplement global task progress metrics
assert isinstance(
task_space, TaskSpace), f"task_space must be a TaskSpace object. Got {type(task_space)} instead."
super().__init__(env)
self.env = env
self.task_space = task_space
self.components = components
self._latest_task = None
self.batch_size = batch_size
self.remove_keys = remove_keys if remove_keys is not None else []
self.change_task_on_completion = change_task_on_completion
self.global_task_completion = global_task_completion
self._batch_step = 0
self.instance_id = components.get_id()
self.update_on_step = components.requires_step_updates and components.should_sync(self.instance_id)
self.task_progress = 0.0
self.episode_length = 0
self.episode_returns = {agent: 0 for agent in self.env.possible_agents}
# Create template values for reset step update
_template_rews = {agent: 0 for agent in self.env.possible_agents}
_template_terms = {agent: False for agent in self.env.possible_agents}
_template_truncs = {agent: False for agent in self.env.possible_agents}
self._template_args = (_template_rews, _template_terms, _template_truncs)
# Create batch buffers for step updates
if self.update_on_step:
num_agents = len(self.env.possible_agents)
self.agent_map = {agent: i for i, agent in enumerate(self.env.possible_agents)}
self._obs = [[None for _ in range(num_agents)] for _ in range(self.batch_size)]
self._rews = np.zeros((self.batch_size, num_agents), dtype=np.float32)
self._terms = np.zeros((self.batch_size, num_agents), dtype=bool)
self._truncs = np.zeros((self.batch_size, num_agents), dtype=bool)
self._infos = [[None for _ in range(num_agents)] for _ in range(self.batch_size)]
self._tasks = [[None for _ in range(num_agents)] for _ in range(self.batch_size)]
self._task_progresses = np.zeros((self.batch_size, num_agents), dtype=np.float32)
# Request initial task
assert buffer_size > 0, "Buffer size must be greater than 0 to sample initial task for envs."
for _ in range(buffer_size):
update = {
"update_type": "noop",
"metrics": None,
"request_sample": True,
}
self.components.put_update(update)
def reset(self, *args, **kwargs):
self.task_progress = 0.0
self.episode_length = 0
self.episode_returns = {agent: 0 for agent in self.env.possible_agents}
message = self.components.get_task() # Blocks until a task is available
next_task = self.task_space.decode(message["next_task"])
self._latest_task = next_task
obs, info = self.env.reset(*args, new_task=next_task, **kwargs)
info["task"] = self.task_space.encode(self.get_task())
if self.update_on_step:
self._update_step(obs, *self._template_args, info, False, send=False)
return obs, info
def step(self, actions):
obs, rews, terms, truncs, infos = self.env.step(actions)
self.episode_length += 1
for agent in rews.keys():
self.episode_returns[agent] += rews[agent]
if "task_completion" in list(infos.values())[0]:
self.task_progress = max([info["task_completion"] for info in infos.values()])
is_finished = (len(self.env.agents) == 0) or all(terms.values())
# Update curriculum with step info
if self.update_on_step:
self._update_step(obs, rews, terms, truncs, infos, is_finished)
if is_finished:
episode_update = {
"update_type": "episode",
"metrics": (self.episode_returns, self.episode_length, self.task_space.encode(self.env.task), self.task_progress),
"env_id": self.instance_id,
"request_sample": True
}
self.components.put_update([episode_update])
if self.change_task_on_completion and self.task_progress >= 1.0:
update = {
"update_type": "task_progress",
"metrics": (self.task_space.encode(self.get_task()), self.task_progress),
"env_id": self.instance_id,
"request_sample": True
}
self.components.put_update(update)
message = self.components.get_task() # Blocks until a task is available
next_task = self.task_space.decode(message["next_task"])
self.env.change_task(next_task)
self._latest_task = next_task
return obs, rews, terms, truncs, infos
def _update_step(self, obs, rews, terms, truncs, infos, is_finished, send=True):
agent_indices = [self.agent_map[agent] for agent in rews.keys()]
# Environment outputs
trimmed_obs = self._trim_obs(obs)
self._obs[self._batch_step] = trimmed_obs
self._rews[self._batch_step][agent_indices] = list(rews.values())
self._terms[self._batch_step][agent_indices] = list(terms.values())
self._truncs[self._batch_step][agent_indices] = list(truncs.values())
self._infos[self._batch_step] = infos
self._tasks[self._batch_step] = self.task_space.encode(self.get_task())
self._task_progresses[self._batch_step] = self.task_progress
self._batch_step += 1
# Send batched updates
if self._batch_step >= self.batch_size or is_finished:
updates = self._package_step_updates()
self.components.put_update(updates)
self._batch_step = 0
def _package_step_updates(self):
return [{
"update_type": "step_batch",
"metrics": ([
self._tasks[:self._batch_step],
self._obs[:self._batch_step],
self._rews[:self._batch_step],
self._terms[:self._batch_step],
self._truncs[:self._batch_step],
self._infos[:self._batch_step],
self._task_progresses[:self._batch_step],
],),
"env_id": self.instance_id,
"request_sample": False
}]
def _trim_obs(self, obs):
if len(self.agents) > 0 and isinstance(obs[self.agents[0]], dict):
return {agent: {key: obs[agent][key] for key in obs[agent].keys() if key not in self.remove_keys} for agent in self.agents}
else:
return obs
def get_task(self):
# Report the environment's own task if it exposes one, since it may reject or modify the sampled task
if hasattr(self.env, "task"):
return self.env.task
return self._latest_task
def __getattr__(self, attr):
env_attr = getattr(self.env, attr, None)
if env_attr is not None:
return env_attr
class RayGymnasiumSyncWrapper(gym.Wrapper):
"""
This wrapper sets the task on reset for Gymnasium environments running
in parallel processes created using Ray. Meant to be used with a
RayLearningProgressCurriculum running on the main process.
"""
def __init__(self,
env,
update_on_step: bool = True,
task_space: gym.Space = None,
global_task_completion: Callable[[Curriculum, np.ndarray, float, bool, Dict[str, Any]], bool] = None):
assert isinstance(env, TaskWrapper) or isinstance(env, TaskEnv) or isinstance(
env, PettingZooTaskWrapper), "Env must implement the task API"
super().__init__(env)
self.env = env
self.update_on_step = update_on_step # Disable to improve performance
self.task_space = task_space
self.curriculum = ray.get_actor("curriculum")
self.task_completion = 0.0
self.global_task_completion = global_task_completion
self.step_results = []
def reset(self, *args, **kwargs):
self.step_results = []
# Update curriculum
update = {
"update_type": "task_progress",
"metrics": (self.env.task, self.task_completion),
"request_sample": True
}
self.curriculum.update.remote(update)
self.task_completion = 0.0
# Sample new task
sample = ray.get(self.curriculum.sample.remote())
next_task = sample[0]
return self.env.reset(*args, new_task=next_task, **kwargs)
def step(self, action):
obs, rew, term, trunc, info = self.env.step(action)
if "task_completion" in info:
if self.global_task_completion is not None:
# TODO: Hide rllib interface?
self.task_completion = self.global_task_completion(self.curriculum, obs, rew, term, trunc, info)
else:
self.task_completion = info["task_completion"]
# TODO: Optimize
if self.update_on_step:
self.step_results.append((obs, rew, term, trunc, info))
if len(self.step_results) >= 1000 or term or trunc:
update = {
"update_type": "step_batch",
"metrics": (self.step_results,),
"request_sample": False
}
self.curriculum.update.remote(update)
self.step_results = []
return obs, rew, term, trunc, info
def change_task(self, new_task):
"""
Changes the task of the existing environment to new_task.
Each environment will implement tasks differently. The simplest approach is to
call a function or set an instance variable that changes the task.
Some environments may need to be reset or even reinitialized to change the task.
If you need to reset or re-initialize the environment here, make sure to check
that it is not in the middle of an episode to avoid unexpected behavior.
"""
self.env.change_task(new_task)
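# Hedged sketch of the env-side contract this method delegates to: a task-based
# env typically just records the new task and applies it on its next reset.
# `MyTaskEnv` and `self._current_task` below are illustrative names, not part of
# Syllabus:
#
#     class MyTaskEnv(TaskEnv):
#         def change_task(self, new_task):
#             # Store the task; apply it when the next episode starts
#             self._current_task = new_task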
def __getattr__(self, attr):
env_attr = getattr(self.env, attr, None)
if env_attr:
return env_attr
class RayPettingZooSyncWrapper(BaseParallelWrapper):
"""
This wrapper sets the task on reset for PettingZoo parallel environments running
in parallel processes created using Ray. Meant to be used with a
RayLearningProgressCurriculum running on the main process.
"""
def __init__(self,
env,
task_space: TaskSpace,
update_on_step: bool = True,
global_task_completion: Callable[[Curriculum, np.ndarray, float, bool, Dict[str, Any]], bool] = None):
assert isinstance(env, TaskWrapper) or isinstance(env, TaskEnv) or isinstance(
env, PettingZooTaskWrapper), "Env must implement the task API"
super().__init__(env)
self.env = env
self.update_on_step = update_on_step # Disable to improve performance
self.task_space = task_space
self.curriculum = ray.get_actor("curriculum")
self.task_completion = 0.0
self.global_task_completion = global_task_completion
self.step_results = []
def reset(self, *args, **kwargs):
self.step_results = []
# Update curriculum
update = {
"update_type": "task_progress",
"metrics": (self.env.task, self.task_completion),
"request_sample": True
}
self.curriculum.update.remote(update)
self.task_completion = 0.0
# Sample new task
sample = ray.get(self.curriculum.sample.remote())
next_task = sample[0]
return self.env.reset(*args, new_task=next_task, **kwargs)
def change_task(self, new_task):
"""
Changes the task of the existing environment to new_task.
Each environment will implement tasks differently. The simplest approach is to
call a function or set an instance variable that changes the task.
Some environments may need to be reset or even reinitialized to change the task.
If you need to reset or re-initialize the environment here, make sure to check
that it is not in the middle of an episode to avoid unexpected behavior.
"""
self.env.change_task(new_task)
def step(self, actions):
obs, rews, terms, truncs, infos = self.env.step(actions)
if len(infos) > 0 and "task_completion" in next(iter(infos.values())):
if self.global_task_completion is not None:
# TODO: Hide rllib interface?
self.task_completion = self.global_task_completion(self.curriculum, obs, rews, terms, truncs, infos)
else:
self.task_completion = max([info["task_completion"] for info in infos.values()])
# TODO: Optimize
if self.update_on_step:
self.step_results.append((obs, rews, terms, truncs, infos))
if len(self.step_results) >= 1000 or len(self.env.agents) == 0 or all(terms.values()):
update = {
"update_type": "step_batch",
"metrics": (self.step_results,),
"request_sample": False
}
self.curriculum.update.remote(update)
self.step_results = []
return obs, rews, terms, truncs, infos
def __getattr__(self, attr):
env_attr = getattr(self.env, attr, None)
if env_attr:
return env_attr