Source code for syllabus.core.task_interface.subclass_task_wrapper

""" Task wrapper for NLE that can change tasks at reset using the NLE's task definition format. """
import copy
from typing import List

import gymnasium as gym
import numpy as np
from gymnasium import spaces

from syllabus.task_space import DiscreteTaskSpace

from .task_wrapper import TaskWrapper


class SubclassTaskWrapper(TaskWrapper):
    # TODO: Automated tests
    """
    This is a general wrapper for tasks defined as subclasses of a base environment.

    This wrapper reinitializes the environment with the provided env function at the
    start of each episode. This is a simple, general solution for using Syllabus with
    tasks that need to be reinitialized, but it is inefficient. It's likely that you
    can achieve better performance by using a more specialized wrapper.
    """

    def __init__(self, env: gym.Env, task_subclasses: List[gym.Env] = None, **env_init_kwargs):
        super().__init__(env)

        self.task_list = task_subclasses
        self.task_space = DiscreteTaskSpace(len(self.task_list), self.task_list)
        self._env_init_kwargs = env_init_kwargs  # kwargs for reinitializing the base environment

        # Add goal space to observation
        self.observation_space = copy.deepcopy(self.env.observation_space)
        self.observation_space["goal"] = spaces.MultiBinary(len(self.task_list))

        # Tracking episode end
        self.done = True

        # Initialize all tasks
        original_class = self.env.__class__
        for task in self.task_list:
            self.env.__class__ = task
            self.env.__init__(**self._env_init_kwargs)

        self.env.__class__ = original_class
        self.env.__init__(**self._env_init_kwargs)

    @property
    def current_task(self):
        return self.env.__class__

    def _task_name(self, task):
        return task.__name__
    def reset(self, new_task: int = None, **kwargs):
        """
        Resets the environment and changes the current task if a new one is provided.
        """
        # Change task if new one is provided
        if new_task is not None:
            self.change_task(new_task)

        self.done = False
        obs, info = self.env.reset(**kwargs)
        return self.observation(obs), info
    def change_task(self, new_task: int):
        """
        Change the task by directly editing the environment class.

        This ensures that all instance variables are reset, not just the ones for the
        current task. We avoid unnecessary work by only reinitializing the environment
        when the new task's __init__ differs from the previous task's, since very few
        tasks override __init__.
        """
        # Disallow changing to a different task mid-episode
        if self.current_task.__init__ != self._task_class(new_task).__init__ and not self.done:
            raise RuntimeError("Cannot change task mid-episode.")

        # Reject unknown tasks
        if new_task >= len(self.task_list):
            raise RuntimeError(f"Unknown task {new_task}.")

        # Update current task
        prev_task = self.current_task
        self.task = new_task
        self.env.__class__ = self._task_class(new_task)

        # If the task requires reinitialization
        if type(self.env).__init__ != prev_task.__init__:
            self.env.__init__(**self._env_init_kwargs)
    def _encode_goal(self):
        goal_encoding = np.zeros(len(self.task_list))
        goal_encoding[self.task] = 1
        return goal_encoding
    def step(self, action):
        """
        Step through environment and update task completion.
        """
        obs, rew, term, trunc, info = self.env.step(action)
        self.done = term or trunc
        info["task_completion"] = self._task_completion(obs, rew, term, trunc, info)
        return self.observation(obs), rew, term, trunc, info
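
Usage sketch (not part of the module source): the example below illustrates how the wrapper might be applied when tasks are defined as subclasses of a shared base environment. BaseGridEnv, SmallGridTask, and LargeGridTask are hypothetical stand-ins, the import path follows the module name at the top of this page, and the sketch assumes the TaskWrapper base class provides the observation() and _task_completion() hooks used above and that integer task indices decode to the classes passed in task_subclasses.

import gymnasium as gym
from gymnasium import spaces

from syllabus.core.task_interface.subclass_task_wrapper import SubclassTaskWrapper


class BaseGridEnv(gym.Env):
    """Hypothetical base environment; each task variant subclasses it.

    The observation space is a Dict so the wrapper can attach its "goal" key.
    """

    def __init__(self, size: int = 5):
        self.size = size
        self.observation_space = spaces.Dict({"pos": spaces.Discrete(size * size)})
        self.action_space = spaces.Discrete(4)
        self._pos = 0

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self._pos = 0
        return {"pos": self._pos}, {}

    def step(self, action):
        # Toy dynamics: walk toward the last cell regardless of the action.
        self._pos = min(self._pos + 1, self.size * self.size - 1)
        terminated = self._pos == self.size * self.size - 1
        return {"pos": self._pos}, float(terminated), terminated, False, {}


class SmallGridTask(BaseGridEnv):
    def __init__(self, **kwargs):
        super().__init__(size=3)


class LargeGridTask(BaseGridEnv):
    def __init__(self, **kwargs):
        super().__init__(size=9)


env = SubclassTaskWrapper(BaseGridEnv(), task_subclasses=[SmallGridTask, LargeGridTask])
obs, info = env.reset(new_task=1)  # switch to LargeGridTask for this episode
obs, rew, term, trunc, info = env.step(env.action_space.sample())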