Source code for syllabus.examples.task_wrappers.minigrid_task_wrapper
""" Task wrapper that can select a new MiniGrid task on reset. """importwarningsimportgymnasiumasgymimportnumpyasnpfromsyllabus.coreimportTaskWrapperfromsyllabus.task_spaceimportDiscreteTaskSpace
class MinigridTaskWrapper(TaskWrapper):
    """Task wrapper that selects a new MiniGrid task on reset.

    A task is an integer drawn from ``DiscreteTaskSpace(4000)``; changing the
    task seeds the wrapped environment with that integer. Observations handed
    back by ``reset`` and ``step`` are the ``"image"`` entry of the wrapped
    environment's dict observation, passed through ``self.observation``.
    """

    def __init__(self, env: gym.Env):
        """Wrap ``env`` and declare the observation and task spaces.

        Args:
            env: Environment to wrap. It must expose ``width`` and ``height``
                attributes, which size the declared observation space.
        """
        super().__init__(env)
        try:
            # Availability probe only: the imported names are not used below,
            # the import exists so a missing dependency is reported early.
            from gym_minigrid.minigrid import COLOR_TO_IDX, OBJECT_TO_IDX  # noqa: F401
        except ImportError:
            warnings.warn("Unable to import gym_minigrid.", stacklevel=2)

        # Channel-first uint8 box: (3, width, height), where width and height
        # are the number of grid cells.
        # NOTE(review): nothing in this class transposes the image to match
        # the channel-first layout; presumably ``self.observation`` (defined
        # outside this class) does so -- confirm.
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(3, self.env.width, self.env.height),
            dtype="uint8",
        )

        # Set up task space
        self.task_space = DiscreteTaskSpace(4000)
        self.task = None

    def reset(self, new_task=None, **kwargs):
        """Reset the environment, optionally switching to ``new_task`` first.

        Args:
            new_task: Task to switch to before resetting. ``None`` keeps the
                current task.
            **kwargs: Forwarded unchanged to the wrapped environment's
                ``reset``.

        Returns:
            ``(observation, info)`` when the wrapped environment's ``reset``
            returns a tuple (the API that pairs with the five-value ``step``
            used by this class); otherwise just the processed observation.
        """
        # Change task if new one is provided
        if new_task is not None:
            self.change_task(new_task)

        self.done = False
        self.episode_return = 0

        result = self.env.reset(**kwargs)
        if isinstance(result, tuple):
            # Indexing the (obs, info) tuple with "image" would raise
            # TypeError, so unpack it first and pass ``info`` through.
            obs, info = result
            return self.observation(obs["image"]), info
        return self.observation(result["image"])

    def change_task(self, new_task: int):
        """Switch to ``new_task`` by seeding the wrapped environment with it.

        The task is coerced with ``int`` and stored in ``self.task``. No
        validation against ``self.task_space`` is performed here.

        Args:
            new_task: Task identifier, used directly as the environment seed.
        """
        seed = int(new_task)
        self.task = seed
        # NOTE(review): relies on the wrapped environment exposing ``seed()``;
        # confirm for environments that only accept ``reset(seed=...)``.
        self.env.seed(seed)

    def step(self, action):
        """Step the environment and record task completion in ``info``.

        ``reset`` must be called first: it initialises ``episode_return``,
        which this method accumulates into.

        Args:
            action: Action forwarded to the wrapped environment.

        Returns:
            ``(obs, rew, term, trunc, info)`` where ``obs`` is the processed
            ``"image"`` observation and ``info["task_completion"]`` holds the
            value returned by ``self._task_completion``.
        """
        obs, rew, term, trunc, info = self.env.step(action)
        obs = self.observation(obs["image"])
        self.episode_return += rew
        self.done = term or trunc
        info["task_completion"] = self._task_completion(obs, rew, term, trunc, info)
        return obs, rew, term, trunc, info