import copy
import signal
import sys
import threading
import time
import warnings
from functools import wraps
from multiprocessing.shared_memory import ShareableList
from queue import Empty
from typing import Dict
import ray
from torch.multiprocessing import Lock, Queue
from syllabus.core import Curriculum
from syllabus.utils import UsageError, decorate_all_functions
class CurriculumWrapper:
    """Base wrapper that forwards curriculum API calls to a wrapped curriculum.
    Subclasses add multiprocessing or Ray synchronization on top of this interface.
    """
def __init__(self, curriculum: Curriculum) -> None:
self.curriculum = curriculum
if hasattr(curriculum, "unwrapped") and curriculum.unwrapped is not None:
self.unwrapped = curriculum.unwrapped
else:
self.unwrapped = curriculum
self.task_space = self.unwrapped.task_space
@property
def num_tasks(self):
return self.task_space.num_tasks
    def count_tasks(self, task_space=None):
return self.task_space.count_tasks(gym_space=task_space)
@property
def tasks(self):
return self.task_space.tasks
@property
def requires_step_updates(self):
return self.curriculum.requires_step_updates
    def get_tasks(self, task_space=None):
return self.task_space.get_tasks(gym_space=task_space)
    def sample(self, k=1):
return self.curriculum.sample(k=k)
    def update_task_progress(self, task, progress, env_id=None):
        self.curriculum.update_task_progress(task, progress, env_id=env_id)
    def update_on_step(self, task, obs, reward, term, trunc, info, progress, env_id=None):
        self.curriculum.update_on_step(task, obs, reward, term, trunc, info, progress, env_id=env_id)
    def log_metrics(self, writer, logs, step=None, log_n_tasks=1):
return self.curriculum.log_metrics(writer, logs, step=step, log_n_tasks=log_n_tasks)
    def update_on_step_batch(self, step_results, env_id=None):
self.curriculum.update_on_step_batch(step_results, env_id=env_id)
    def update_on_episode(self, episode_return, length, task, progress, env_id=None):
self.curriculum.update_on_episode(episode_return, length, task, progress, env_id=env_id)
    def normalize(self, rewards, task):
return self.curriculum.normalize(rewards, task)
    def __getattr__(self, attr):
        # Forward attribute lookups that the wrapper does not define to the wrapped curriculum.
        try:
            return getattr(self.curriculum, attr)
        except AttributeError as e:
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {attr!r}") from e
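# A minimal usage sketch for CurriculumWrapper (illustrative only; the curriculum and
# task space classes named below are placeholders and may differ across Syllabus versions):
#
#   from syllabus.curricula import DomainRandomization
#   from syllabus.task_space import TaskSpace
#
#   curriculum = CurriculumWrapper(DomainRandomization(TaskSpace(10)))
#   tasks = curriculum.sample(k=4)                 # forwarded to the wrapped curriculum
#   curriculum.update_task_progress(tasks[0], 0.5)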
class MultiProcessingComponents:
    """Shared multiprocessing primitives (task queue, update queue, lock, and environment
    counter) used to synchronize a curriculum with multiple environment processes."""
def __init__(self, requires_step_updates, max_queue_size=1000000, timeout=60, max_envs=None):
self.requires_step_updates = requires_step_updates
self.task_queue = Queue(maxsize=max_queue_size)
self.update_queue = Queue(maxsize=max_queue_size)
self._instance_lock = Lock()
self._env_count = ShareableList([0])
self._debug = True
self.timeout = timeout
self.max_envs = max_envs
self._maxsize = max_queue_size
self.started = False
    def peek_id(self):
return self._env_count[0]
    def get_id(self):
with self._instance_lock:
instance_id = self._env_count[0]
self._env_count[0] += 1
return instance_id
    def should_sync(self, env_id):
# Only receive step updates from self.max_envs environments
if self.max_envs is not None and env_id >= self.max_envs:
return False
return True
    def put_task(self, task):
self.task_queue.put(task, block=False)
    def get_task(self):
        try:
            if self.started and self.task_queue.empty():
                warnings.warn(
                    f"Task queue is empty ({self.task_queue.qsize()} / {self._maxsize} items). "
                    "The program may deadlock while waiting for a task. If the update queue keeps growing, "
                    "consider optimizing your curriculum or reducing the number of environments. "
                    "Otherwise, consider increasing the buffer_size of your environment sync wrapper.")
            task = self.task_queue.get(block=True, timeout=self.timeout)
            return task
        except Empty as e:
            raise UsageError(
                f"Failed to get a task from the queue after {self.timeout}s. "
                f"Queue size is {self.task_queue.qsize()} / {self._maxsize} items.") from e
    def put_update(self, update):
self.update_queue.put(copy.deepcopy(update), block=False)
    def get_update(self):
update = self.update_queue.get(block=False)
return update
    def close(self):
if self._env_count is not None:
self._env_count.shm.close()
try:
self._env_count.shm.unlink()
except FileNotFoundError:
pass # Already unlinked
self.task_queue.close()
self.update_queue.close()
self._env_count = None
    def get_metrics(self, log_n_tasks=1):
logs = []
logs.append(("curriculum/updates_in_queue", self.update_queue.qsize()))
logs.append(("curriculum/tasks_in_queue", self.task_queue.qsize()))
return logs
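# Sketch of how an environment process is expected to talk to MultiProcessingComponents.
# This is a simplified view of the protocol used by the environment sync wrapper; the
# message fields are inferred from _update_queues and route_update below:
#
#   env_id = components.get_id()                      # unique id for this environment
#   components.put_update({
#       "update_type": "episode",
#       "metrics": (episode_return, length, task, progress),
#       "env_id": env_id,
#       "request_sample": True,                       # ask the curriculum for a new task
#   })
#   task = components.get_task()["next_task"]         # blocks until a task is queued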
class CurriculumSyncWrapper(CurriculumWrapper):
    """Curriculum wrapper that synchronizes task sampling and curriculum updates across
    multiple environment processes via shared queues."""
def __init__(
self,
curriculum: Curriculum,
**kwargs,
):
super().__init__(curriculum)
self.update_thread = None
self.should_update = False
self.added_tasks = []
self.num_assigned_tasks = 0
self.components = MultiProcessingComponents(self.curriculum.requires_step_updates, **kwargs)
    def start(self):
        """
        Start the background thread that reads updates from the update queue and fills the task queue.
        """
if not self.should_update:
self.update_thread = threading.Thread(name='update', target=self._update_queues, daemon=True)
self.should_update = True
self.components.started = True
signal.signal(signal.SIGINT, self._sigint_handler)
self.update_thread.start()
    def stop(self):
        """
        Stop the background update thread and release the shared multiprocessing resources.
        """
self.should_update = False
self.components.started = False
self.update_thread.join()
self.components.close()
def _sigint_handler(self, sig, frame):
self.stop()
sys.exit(0)
def _update_queues(self):
"""
Continuously process completed tasks and sample new tasks.
"""
# Update curriculum with environment results:
while self.should_update:
if not self.components.update_queue.empty():
                update = self.components.get_update()  # Non-blocking; the queue was checked above
if isinstance(update, list):
update = update[0]
# Sample new tasks if requested
if "request_sample" in update and update["request_sample"]:
new_tasks = self.curriculum.sample(k=1)
for task in new_tasks:
message = {"next_task": task}
self.components.put_task(message)
self.num_assigned_tasks += 1
self.route_update(update)
time.sleep(0.0)
else:
time.sleep(0.01)
    def route_update(self, update_data: Dict[str, tuple]):
        """Update the curriculum with the specified update type.
        TODO: Change method header to not use dictionary, use enums?
        :param update_data: Dictionary with an "update_type" key mapping to one of
            ["step", "step_batch", "episode", "task_progress", "noop"], a "metrics" key
            holding a tuple of the arguments for that update type, and an optional "env_id" key.
        :type update_data: Dict[str, tuple]
        :raises NotImplementedError: If the update type is not recognized.
        """
update_type = update_data["update_type"]
args = update_data["metrics"]
env_id = update_data["env_id"] if "env_id" in update_data else None
if update_type == "step":
self.update_on_step(*args, env_id=env_id)
elif update_type == "step_batch":
self.update_on_step_batch(*args, env_id=env_id)
elif update_type == "episode":
self.update_on_episode(*args, env_id=env_id)
elif update_type == "task_progress":
self.update_task_progress(*args, env_id=env_id)
elif update_type == "noop":
# Used to request tasks from the synchronization layer
pass
else:
raise NotImplementedError(f"Update type {update_type} not implemented.")
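    # Example payload accepted by route_update (for illustration only; the values are placeholders):
    #   {
    #       "update_type": "episode",
    #       "metrics": (episode_return, episode_length, task, progress),
    #       "env_id": 0,
    #   }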
    def log_metrics(self, writer, logs, step=None, log_n_tasks=1):
logs = [] if logs is None else logs
logs += self.components.get_metrics(log_n_tasks=log_n_tasks)
return super().log_metrics(writer, logs, step=step, log_n_tasks=log_n_tasks)
def remote_call(func):
    """
    Decorator for automatically forwarding calls to the curriculum via Ray remote calls.
    Note that this causes functions to block, so it should only be used for operations
    that do not require parallelization.
    """
@wraps(func)
def wrapper(self, *args, **kw):
f_name = func.__name__
parent_func = getattr(CurriculumWrapper, f_name)
child_func = getattr(self, f_name)
# Only forward call if subclass does not explicitly override the function.
if child_func == parent_func:
curriculum_func = getattr(self.curriculum, f_name)
return ray.get(curriculum_func.remote(*args, **kw))
return wrapper
def make_multiprocessing_curriculum(curriculum, start=True, **kwargs):
    """
    Helper function for creating a CurriculumSyncWrapper and, optionally, starting its update thread.
    """
mp_curriculum = CurriculumSyncWrapper(curriculum, **kwargs)
if start:
mp_curriculum.start()
return mp_curriculum
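# Example: creating and shutting down a synchronized curriculum. This is a sketch;
# `my_curriculum` is a placeholder, and the environment-side sync wrapper mentioned
# below is assumed to exist elsewhere in Syllabus (its name may differ by version):
#
#   curriculum = make_multiprocessing_curriculum(my_curriculum)   # starts the update thread
#   # ... wrap each environment with the matching environment sync wrapper, passing
#   # curriculum.components so it can request tasks and push updates ...
#   curriculum.stop()                                             # join the thread, free shared memory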
@ray.remote
class RayCurriculumWrapper(CurriculumWrapper):
def __init__(self, curriculum: Curriculum) -> None:
super().__init__(curriculum)
def get_remote_attr(self, name: str):
next_obj = getattr(self.curriculum, name)
return next_obj
@decorate_all_functions(remote_call)
class RayCurriculumSyncWrapper(CurriculumWrapper):
    """
    Curriculum wrapper that uses a Ray actor to share tasks and receive feedback from the
    environment. The wrapped curriculum runs inside the remote RayCurriculumWrapper actor.
    The @decorate_all_functions(remote_call) annotation automatically forwards all functions not explicitly
    overridden here to the remote curriculum. This is intended to forward private functions of Curriculum
    subclasses for convenience.
    # TODO: Implement the Curriculum methods explicitly
    """
def __init__(self, curriculum, actor_name="curriculum") -> None:
super().__init__(curriculum)
self.curriculum = RayCurriculumWrapper.options(name=actor_name).remote(curriculum)
self.unwrapped = None
self.task_space = curriculum.task_space
self.added_tasks = []
# If you choose to override a function, you will need to forward the call to the remote curriculum.
# This method is shown here as an example. If you remove it, the same functionality will be provided automatically.
    def sample(self, k: int = 1):
return ray.get(self.curriculum.sample.remote(k=k))
    def update_on_step_batch(self, step_results, env_id=None) -> None:
        ray.get(self.curriculum.update_on_step_batch.remote(step_results, env_id=env_id))
def make_ray_curriculum(curriculum, actor_name="curriculum", **kwargs):
    """
    Helper function for creating a RayCurriculumSyncWrapper backed by a Ray actor.
    """
return RayCurriculumSyncWrapper(curriculum, actor_name=actor_name, **kwargs)
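# Example: creating a Ray-backed curriculum. This is a sketch; `my_curriculum` is a
# placeholder and the curriculum instance is assumed to be picklable:
#
#   import ray
#   ray.init()
#   curriculum = make_ray_curriculum(my_curriculum, actor_name="curriculum")
#   tasks = curriculum.sample(k=2)    # executed on the remote actor via ray.get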