Source code for syllabus.core.task_interface.subclass_task_wrapper
""" Task wrapper for NLE that can change tasks at reset using the NLE's task definition format. """importcopyfromtypingimportListimportgymnasiumasgymimportnumpyasnpfromgymnasiumimportspacesfromsyllabus.task_spaceimportDiscreteTaskSpacefrom.task_wrapperimportTaskWrapper
class SubclassTaskWrapper(TaskWrapper):
    # TODO: Automated tests
    """General task wrapper for tasks defined as subclasses of a base environment.

    Each task is a subclass of the wrapped environment's class. Changing task swaps
    ``self.env.__class__`` to the chosen subclass and, when that subclass defines a
    different ``__init__`` than the current class, re-runs ``__init__`` with the
    keyword arguments given to this wrapper. This is a simple, general way to use
    Syllabus with tasks that need reinitialization, but it is inefficient; a more
    specialized wrapper will usually perform better.

    Args:
        env: The environment instance to wrap. Its class is swapped in place.
        task_subclasses: The environment subclasses (classes, not instances) that
            make up the task space. Task ``i`` corresponds to ``task_subclasses[i]``.
        **env_init_kwargs: Keyword arguments passed to ``__init__`` whenever the
            environment is (re)initialized.
    """

    def __init__(self, env: gym.Env, task_subclasses: List[gym.Env] = None, **env_init_kwargs):
        super().__init__(env)

        if not task_subclasses:
            # The task space, goal encoding and task lookups all require at least one task.
            raise ValueError("task_subclasses must be a non-empty list of environment subclasses.")

        self.task_list = task_subclasses
        self.task_space = DiscreteTaskSpace(len(self.task_list), self.task_list)
        self._env_init_kwargs = env_init_kwargs  # kwargs for reinitializing the base environment

        # Add a "goal" entry to the observation space with one bit per task.
        # NOTE(review): item assignment assumes the wrapped observation space is a Dict space.
        self.observation_space = copy.deepcopy(self.env.observation_space)
        self.observation_space["goal"] = spaces.MultiBinary(len(self.task_list))

        # Tracking episode end. Starts True so a task change before the first step is allowed.
        self.done = True

        # Initialize the environment once as every task subclass, then restore the
        # original class and initialize it again.
        # NOTE(review): presumably this surfaces construction errors for every task up front — confirm intent.
        original_class = self.env.__class__
        for task in self.task_list:
            self.env.__class__ = task
            self.env.__init__(**self._env_init_kwargs)

        self.env.__class__ = original_class
        self.env.__init__(**self._env_init_kwargs)

    @property
    def current_task(self):
        """The class the wrapped environment is currently an instance of."""
        return self.env.__class__

    def _task_name(self, task):
        """Return the class name of ``task``.

        Args:
            task: A task subclass, or an integer index into ``task_list``.
        """
        # Anything that is not a class is treated as an index into the task list.
        if not isinstance(task, type):
            task = self.task_list[task]
        return task.__name__

    def reset(self, new_task: int = None, **kwargs):
        """Reset the wrapped environment, optionally switching to a new task first.

        Args:
            new_task: Index into ``task_list`` of the task to switch to, or ``None``
                to keep the current task.
            **kwargs: Forwarded to the wrapped environment's ``reset``.

        Returns:
            The ``(observation, info)`` pair, with the observation passed through
            ``self.observation``.
        """
        if new_task is not None:
            # A reset is always an episode boundary, so switching tasks must be allowed
            # here even if the previous episode ended without term/trunc being observed.
            self.done = True
            self.change_task(new_task)

        self.done = False
        obs, info = self.env.reset(**kwargs)
        return self.observation(obs), info

    def change_task(self, new_task: int):
        """Change task by directly editing the environment's class.

        Swaps ``self.env.__class__`` to ``task_list[new_task]``. If the new class's
        ``__init__`` differs from the current class's, the environment is
        reinitialized with the stored kwargs so that all instance variables are
        reset. Otherwise only the class is swapped, which avoids a needless
        reinitialization.

        Args:
            new_task: Index into ``task_list``.

        Raises:
            RuntimeError: If ``new_task`` is out of range, or if the switch would
                reinitialize the environment while an episode is in progress.
        """
        # Validate the index before using it, so an unknown task reports the intended
        # error. Negative indices are rejected because they would silently wrap around.
        if not 0 <= new_task < len(self.task_list):
            raise RuntimeError(f"Unknown task {new_task}.")

        new_class = self.task_list[new_task]
        # Compare against the previous *class*, not the previous task index.
        prev_class = self.current_task
        needs_reinit = prev_class.__init__ != new_class.__init__

        # Switching mid-episode is refused only when it would reinitialize the environment.
        if needs_reinit and not self.done:
            raise RuntimeError("Cannot change task mid-episode.")

        # Update current task
        self.task = new_task
        self.env.__class__ = new_class

        # Reinitialize only when the new subclass defines a different constructor.
        if needs_reinit:
            self.env.__init__(**self._env_init_kwargs)

    def step(self, action):
        """Step the wrapped environment and record task completion.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. The observation is
            passed through ``self.observation``, and ``info`` gains a
            ``"task_completion"`` entry computed by ``self._task_completion``.
        """
        obs, rew, term, trunc, info = self.env.step(action)
        # Remember whether the episode ended; change_task consults this flag.
        self.done = term or trunc
        info["task_completion"] = self._task_completion(obs, rew, term, trunc, info)
        return self.observation(obs), rew, term, trunc, info