import gymnasium as gym
from syllabus.core import TaskWrapper
from syllabus.task_space import DiscreteTaskSpace
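
# Min/max achievable episode returns for each Procgen game. These appear to
# match the easy-mode normalization constants from the Procgen paper
# (Cobbe et al., 2020); _task_completion uses them to scale episode returns
# into a roughly [0, 1] progress signal.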
PROCGEN_RETURN_BOUNDS = {
"coinrun": (5, 10),
"starpilot": (2.5, 64),
"caveflyer": (3.5, 12),
"dodgeball": (1.5, 19),
"fruitbot": (-1.5, 32.4),
"chaser": (0.5, 13),
"miner": (1.5, 13),
"jumper": (3, 10),
"leaper": (3, 10),
"maze": (5, 10),
"bigfish": (1, 40),
"heist": (3.5, 10),
"climber": (2, 12.6),
"plunder": (4.5, 30),
"ninja": (3.5, 10),
"bossfight": (0.5, 13),
}
class ProcgenTaskWrapper(TaskWrapper):
"""
This wrapper allows you to change the task of an NLE environment.
"""
    def __init__(self, env: gym.Env, env_id, seed=0):
        super().__init__(env)
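        # Tasks are Procgen level seeds; a discrete space of 200 matches the
        # benchmark's standard training set of 200 levels per game.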
        self.task_space = DiscreteTaskSpace(200)
        self.env_id = env_id
        self.task = seed
        self.seed(seed)
        self.episode_return = 0
        self.observation_space = self.env.observation_space
    def seed(self, seed):
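        # Procgen only accepts a seed through its underlying native
        # vectorized environment, so we reach through the wrapper layers.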
        self.env.unwrapped.gym_env.unwrapped._venv.seed(int(seed), 0)
    def reset(self, new_task=None, **kwargs):
        """
        Reset the environment and clear the episode return. If new_task is
        provided, the task is changed before the underlying reset is called.
        """
        self.episode_return = 0.0
        return super().reset(new_task=new_task, **kwargs)
    def change_task(self, new_task: int):
        """
        Change the task by re-seeding the environment. Because tasks are
        level seeds, re-seeding selects the level for the next episode.
        """
        seed = int(new_task)
        self.task = seed
        self.seed(seed)
    def step(self, action):
        """
        Step through the environment and accumulate the episode return,
        which is used to measure task completion.
        """
        obs, rew, term, trunc, info = super().step(action)
        self.episode_return += rew
        return self.observation(obs), rew, term, trunc, info
    def _task_completion(self, obs, rew, term, trunc, info) -> float:
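        # Progress is reported only at episode end, as the episode return
        # scaled into (roughly) [0, 1] by the game's known score bounds.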
        if not (term or trunc):
            return 0.0
        env_min, env_max = PROCGEN_RETURN_BOUNDS[self.env_id]
        normalized_return = (self.episode_return - env_min) / float(env_max - env_min)
        return normalized_return
        # clipped_return = 1 if normalized_return > 0.5 else 0  # Binary progress
        # return clipped_return
    def observation(self, observation):
        # Procgen observations are passed through unchanged; subclasses can
        # override this hook to add preprocessing.
        return observation
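
# A minimal usage sketch (hypothetical setup: assumes the `procgen` package
# is installed and its environments are exposed through a gymnasium-compatible
# wrapper; the exact environment id and wrapping chain may differ):
#
#     env = gym.make("procgen-bigfish-v0")
#     env = ProcgenTaskWrapper(env, env_id="bigfish", seed=0)
#     obs, info = env.reset(new_task=42)  # play level seed 42 next episode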