Source code for syllabus.examples.task_wrappers.nethack_wrappers

""" Task wrapper for NLE that can change tasks at reset using the NLE's task definition format. """
import re
from typing import Any, Dict, List, Tuple

import gymnasium as gym
import numpy as np
from nle import nethack
from nle.env import base
from nle.env.base import ASCII_SPACE, ASCII_ESC, SKIP_EXCEPTIONS
from nle.env.tasks import TASK_ACTIONS, NetHackChallenge, NetHackGold
from shimmy.openai_gym_compatibility import GymV21CompatibilityV0

from syllabus.core import TaskWrapper
from syllabus.task_space import DiscreteTaskSpace


EXTENDED_TASK_ACTIONS = (
    *TASK_ACTIONS,
    nethack.Command.CAST,
    nethack.Command.QUAFF,
    nethack.Command.PRAY,
    nethack.Command.PICKUP,
    nethack.Command.PAY,
)

SKIP_EXCEPTIONS = (*SKIP_EXCEPTIONS, b"drink")


class NetHackScore(base.NLE):
    r"""Environment for the "score" task.

    The task is an augmentation of the standard NLE task. The return function is
    defined as:
    :math:`\text{score}_t - \text{score}_{t-1} + \text{TP}`,
    where the :math:`\text{TP}` is a time penalty that grows with the amount of
    environment steps that do not change the state (such as navigating menus).

    Args:
        penalty_mode (str): name of the mode for calculating the time step penalty.
            Can be ``constant``, ``exp``, ``square``, ``linear``, or ``always``.
            Defaults to ``constant``.
        penalty_step (float): constant applied to the number of frozen steps.
            Defaults to -0.01.
        penalty_time (float): constant applied to the in-game time delta between
            steps. Defaults to -0.0.
    """

    def __init__(
        self,
        *args,
        penalty_mode="constant",
        penalty_step: float = -0.01,
        penalty_time: float = -0.0,
        **kwargs,
    ):
        self.penalty_mode = penalty_mode
        self.penalty_step = penalty_step
        self.penalty_time = penalty_time
        self._frozen_steps = 0
        self.dungeon_explored = {}

        actions = kwargs.pop("actions", TASK_ACTIONS)
        super().__init__(*args, actions=actions, **kwargs)

    def _get_time_penalty(self, last_observation, observation):
        blstats_old = last_observation[self._blstats_index]
        blstats_new = observation[self._blstats_index]

        old_time = blstats_old[nethack.NLE_BL_TIME]
        new_time = blstats_new[nethack.NLE_BL_TIME]

        if old_time == new_time:
            self._frozen_steps += 1
        else:
            self._frozen_steps = 0

        penalty = 0
        if self.penalty_mode == "constant":
            if self._frozen_steps > 0:
                penalty += self.penalty_step
        elif self.penalty_mode == "exp":
            penalty += 2**self._frozen_steps * self.penalty_step
        elif self.penalty_mode == "square":
            penalty += self._frozen_steps**2 * self.penalty_step
        elif self.penalty_mode == "linear":
            penalty += self._frozen_steps * self.penalty_step
        elif self.penalty_mode == "always":
            penalty += self.penalty_step
        else:  # default
            raise ValueError("Unknown penalty_mode '%s'" % self.penalty_mode)
        penalty += (new_time - old_time) * self.penalty_time
        return penalty

    def _reward_fn(self, last_observation, action, observation, end_status):
        """Score delta, but with an added state-loop penalty."""
        score_diff = super()._reward_fn(
            last_observation, action, observation, end_status
        )
        time_penalty = self._get_time_penalty(last_observation, observation)
        return score_diff + time_penalty

    def _perform_known_steps(self, observation, done, exceptions=True):
        while not done:
            if observation[self._internal_index][3]:  # xwaitforspace
                # Make sure to include information about going down the stairs.
                previous_msg = observation[self._message_index].copy()
                msg_str = bytes(previous_msg)
                observation, done = self.nethack.step(ASCII_SPACE)
                if b"You descend the stairs." in msg_str:
                    observation = (
                        *observation[: self._message_index],
                        previous_msg,
                        *observation[self._message_index + 1:],
                    )
                continue

            internal = observation[self._internal_index]
            in_yn_function = internal[1]
            in_getline = internal[2]

            if in_getline:  # Game asking for a line of text. We don't do that.
                observation, done = self.nethack.step(ASCII_ESC)
                continue

            if in_yn_function:  # Game asking for a single character.
                # Note: No auto-yes to final questions thanks to the disclose option.
                if exceptions:
                    # This causes an annoying unnecessary copy...
                    msg = bytes(observation[self._message_index])
                    # Do not skip some questions, to allow the agent to select
                    # stuff to eat, attack, and to select directions.
                    # Also do not skip if all are allowed or an allowed message appears.
                    if self._allow_all_yn_questions or any(
                        el in msg for el in SKIP_EXCEPTIONS
                    ):
                        break

                # Otherwise, auto-decline.
                observation, done = self.nethack.step(ASCII_ESC)

            break

        return observation, done
    def get_scout_score(self, last_observation):
        glyphs = last_observation[self._glyph_index]
        blstats = last_observation[self._blstats_index]

        dungeon_num = blstats[nethack.NLE_BL_DNUM]
        dungeon_level = blstats[nethack.NLE_BL_DLEVEL]

        key = (dungeon_num, dungeon_level)
        explored = np.sum(glyphs != nethack.GLYPH_CMAP_OFF)
        self.dungeon_explored[key] = explored
        total_explored = 0
        for key, value in self.dungeon_explored.items():
            total_explored += value
        return total_explored
    def step(self, action: int):
        """Steps the environment.

        Args:
            action (int): action integer as defined by ``self.action_space``.

        Returns:
            (dict, float, bool, dict): a tuple containing
                - (*dict*): an observation of the state; this will contain the keys
                  specified by ``self.observation_space``.
                - (*float*): a reward; see ``self._reward_fn`` to see how it is
                  specified.
                - (*bool*): True if the state is terminal, False otherwise.
                - (*dict*): a dictionary of extra information (such as
                  `end_status`, i.e. a status info -- death, task win, etc. --
                  for the terminal state).
        """
        # Careful: By default we re-use Numpy arrays, so copy before!
        last_observation = tuple(a.copy() for a in self.last_observation)

        # Fix the eating action so that it is possible to eat all items.
        last_msg = bytes(last_observation[self._message_index]).decode("utf-8")

        if "What do you want to eat" in last_msg:
            pattern = r"\[([a-zA-Z]+)"
            match = re.search(pattern, last_msg)
            if match and self.actions[action] == ord("y"):
                # Action 'y' for 'yes' would otherwise eat a random inventory item;
                # instead select the first offered item letter.
                action = ord(match.group(1)[0])
            else:
                # Otherwise dismiss the prompt.
                action = ASCII_SPACE
        else:
            action = self.actions[action]

        observation, done = self.nethack.step(action)
        is_game_over = observation[self._program_state_index][0] == 1
        if is_game_over or not self._allow_all_modes:
            observation, done = self._perform_known_steps(
                observation, done, exceptions=True
            )

        self._steps += 1
        self.last_observation = observation

        if self._check_abort(observation):
            end_status = self.StepStatus.ABORTED
        else:
            end_status = self._is_episode_end(observation)
        end_status = self.StepStatus(done or end_status)

        reward = float(
            self._reward_fn(last_observation, action, observation, end_status)
        )

        if end_status and not done:
            # Try to end the game nicely.
            self._quit_game(observation, done)
            done = True

        info = {}
        info["end_status"] = end_status
        info["is_ascended"] = self.nethack.how_done() == nethack.ASCENDED
        info["dlvl"] = last_observation[self._blstats_index][12]
        info["gold"] = last_observation[self._blstats_index][13]
        info["xlvl"] = last_observation[self._blstats_index][18]
        info["scout"] = self.get_scout_score(last_observation)

        return self._get_observation(observation), reward, done, info
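
# Reward sketch (illustration only, not part of the module): the per-step return
# is the score delta plus a time penalty. With penalty_step = -0.01, the penalty
# after `n` consecutive frozen steps (steps that do not advance in-game time) is:
#
#     constant: -0.01          linear: -0.01 * n
#     square:   -0.01 * n**2   exp:    -0.01 * 2**n
#
# Minimal usage sketch, assuming a working NLE install:
#
#     env = NetHackScore(penalty_mode="linear")
#     obs = env.reset()
#     obs, reward, done, info = env.step(env.action_space.sample())
#     # info["scout"] is the total number of explored tiles across all levels.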
class NetHackExtendedActionEnv(base.NLE):
    def __init__(self, *args, no_progress_timeout: int = 10_000, **kwargs):
        super().__init__(*args, **kwargs)
        self.no_progress_timeout = no_progress_timeout

    def _perform_known_steps(self, observation, done, action=None, exceptions=True):
        while not done:
            message_str = bytes(observation[self._message_index]).decode("utf-8")

            # Macro-action for casting a spell.
            if action == nethack.Command.CAST:
                observation, done = self.nethack.step(ord("a"))

            # Enhance skills whenever possible.
            elif "You feel more confident in" in message_str:
                self.enhance = True

            if observation[self._internal_index][3]:  # xwaitforspace
                if "You feel more confident in" in message_str:
                    self.enhance = True

                # Make sure to include information about going down the stairs.
                previous_msg = observation[self._message_index].copy()
                msg_str = bytes(previous_msg)
                observation, done = self.nethack.step(ASCII_SPACE)
                action = ASCII_SPACE
                if b"You descend the stairs." in msg_str:
                    observation = (
                        *observation[: self._message_index],
                        previous_msg,
                        *observation[self._message_index + 1:],
                    )
                continue

            internal = observation[self._internal_index]
            in_yn_function = internal[1]
            in_getline = internal[2]

            if in_getline:  # Game asking for a line of text. We don't do that.
                observation, done = self.nethack.step(ASCII_ESC)
                action = ASCII_ESC
                continue

            if in_yn_function:  # Game asking for a single character.
                # Note: No auto-yes to final questions thanks to the disclose option.
                if exceptions:
                    # This causes an annoying unnecessary copy...
                    msg = bytes(observation[self._message_index])
                    # Do not skip some questions, to allow the agent to select
                    # stuff to eat, attack, and to select directions.
                    # Also do not skip if all are allowed or an allowed message appears.
                    if self._allow_all_yn_questions or any(
                        el in msg for el in SKIP_EXCEPTIONS
                    ):
                        break

                # Otherwise, auto-decline.
                observation, done = self.nethack.step(ASCII_ESC)
                action = ASCII_ESC

            if self.enhance:
                observation, done = self.nethack.step(nethack.Command.ENHANCE)
                observation, done = self.nethack.step(ord("a"))
                self.enhance = False

            break

        return observation, done
    def get_scout_score(self, last_observation):
        glyphs = last_observation[self._glyph_index]
        blstats = last_observation[self._blstats_index]

        dungeon_num = blstats[nethack.NLE_BL_DNUM]
        dungeon_level = blstats[nethack.NLE_BL_DLEVEL]

        key = (dungeon_num, dungeon_level)
        explored = np.sum(glyphs != nethack.GLYPH_CMAP_OFF)
        self.dungeon_explored[key] = explored
        total_explored = 0
        for key, value in self.dungeon_explored.items():
            total_explored += value
        return total_explored
    def reset(self, *args, **kwargs):
        self.dungeon_explored = {}
        self.enhance = False
        self.num_kills = 0
        # Progress tracking used by _check_abort.
        self._turns = None
        self._no_progress_count = 0
        return super().reset(*args, **kwargs)
    def step(self, action: int):
        """Steps the environment.

        Args:
            action (int): action integer as defined by ``self.action_space``.

        Returns:
            (dict, float, bool, dict): a tuple containing
                - (*dict*): an observation of the state; this will contain the keys
                  specified by ``self.observation_space``.
                - (*float*): a reward; see ``self._reward_fn`` to see how it is
                  specified.
                - (*bool*): True if the state is terminal, False otherwise.
                - (*dict*): a dictionary of extra information (such as
                  `end_status`, i.e. a status info -- death, task win, etc. --
                  for the terminal state).
        """
        # Careful: By default we re-use Numpy arrays, so copy before!
        last_observation = tuple(a.copy() for a in self.last_observation)

        # Fix the eating action so that it is possible to eat all items.
        last_msg = bytes(last_observation[self._message_index]).decode("utf-8")

        if "you kill " in last_msg.lower():
            self.num_kills += 1

        if "What do you want to eat" in last_msg:
            pattern = r"\[([a-zA-Z]+)"
            match = re.search(pattern, last_msg)
            if match and self.actions[action] == ord("y"):
                # Action 'y' for 'yes' would otherwise eat a random inventory item;
                # instead select the first offered item letter.
                action = ord(match.group(1)[0])
            else:
                # Otherwise dismiss the prompt.
                action = ASCII_SPACE
        elif "What do you want to drink" in last_msg:
            pattern = r"\[([a-zA-Z]+)"
            match = re.search(pattern, last_msg)
            if match and self.actions[action] == ord("y"):
                action = ord(match.group(1)[0])
            else:
                action = ASCII_SPACE
        else:
            action = self.actions[action]

        observation, done = self.nethack.step(action)
        is_game_over = observation[self._program_state_index][0] == 1

        # Perform known steps
        if is_game_over or not self._allow_all_modes:
            observation, done = self._perform_known_steps(
                observation, done, action=action, exceptions=True
            )

        self._steps += 1
        self.last_observation = observation

        if self._check_abort(observation):
            end_status = self.StepStatus.ABORTED
        else:
            end_status = self._is_episode_end(observation)
        end_status = self.StepStatus(done or end_status)

        reward = float(
            self._reward_fn(last_observation, action, observation, end_status)
        )

        if end_status and not done:
            # Try to end the game nicely.
            self._quit_game(observation, done)
            done = True

        info = {}
        info["end_status"] = end_status
        info["is_ascended"] = self.nethack.how_done() == nethack.ASCENDED
        info["dlvl"] = last_observation[self._blstats_index][12]
        info["gold"] = last_observation[self._blstats_index][13]
        info["xlvl"] = last_observation[self._blstats_index][18]
        info["scout"] = self.get_scout_score(last_observation)
        info["kill_counts"] = self.num_kills

        observation = self._get_observation(observation)
        return observation, reward, done, info
    def _check_abort(self, observation):
        """Check whether in-game time has stopped advancing for long enough to trigger an abort."""
        turns = observation[self._blstats_index][nethack.NLE_BL_TIME]
        if self._turns == turns:
            self._no_progress_count += 1
        else:
            self._turns = turns
            self._no_progress_count = 0
        return (
            self._steps >= self._max_episode_steps
            or self._no_progress_count >= self.no_progress_timeout
        )
class NetHackScoreExtendedActions(NetHackExtendedActionEnv, NetHackScore):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.actions = EXTENDED_TASK_ACTIONS
        self.action_space = gym.spaces.Discrete(len(self.actions))
class NetHackSeed(NetHackScore):
    """Environment for the NetHack Challenge.

    The task is an augmentation of the standard NLE task. This is the NLE Score
    Task but with some subtle differences:
        * the action space is fixed to include the full keyboard
        * menus and "<More>" tokens are not skipped
        * starting character is randomly assigned
    """

    def __init__(
        self,
        *args,
        character="@",
        allow_all_yn_questions=True,
        allow_all_modes=True,
        penalty_mode="constant",
        penalty_step: float = -0.00,
        penalty_time: float = -0.0,
        max_episode_steps: int = 1_000_000,
        observation_keys=(
            "glyphs",
            "chars",
            "colors",
            "specials",
            "blstats",
            "message",
            "inv_glyphs",
            "inv_strs",
            "inv_letters",
            "inv_oclasses",
            "tty_chars",
            "tty_colors",
            "tty_cursor",
            "misc",
        ),
        no_progress_timeout: int = 10_000,
        **kwargs,
    ):
        actions = nethack.ACTIONS
        kwargs["wizard"] = False
        super().__init__(
            *args,
            actions=actions,
            character=character,
            allow_all_yn_questions=allow_all_yn_questions,
            allow_all_modes=allow_all_modes,
            penalty_mode=penalty_mode,
            penalty_step=penalty_step,
            penalty_time=penalty_time,
            max_episode_steps=max_episode_steps,
            observation_keys=observation_keys,
            **kwargs,
        )
        # If the in-game turn count doesn't change for `no_progress_timeout`
        # steps, we abort.
        self.no_progress_timeout = no_progress_timeout
    def reset(self, *args, **kwargs):
        self._turns = None
        self._no_progress_count = 0
        return super().reset(*args, **kwargs)
    def _check_abort(self, observation):
        """Check whether in-game time has stopped advancing for long enough to trigger an abort."""
        turns = observation[self._blstats_index][nethack.NLE_BL_TIME]
        if self._turns == turns:
            self._no_progress_count += 1
        else:
            self._turns = turns
            self._no_progress_count = 0
        return (
            self._steps >= self._max_episode_steps
            or self._no_progress_count >= self.no_progress_timeout
        )
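
# Usage sketch for NetHackSeed (illustration only, not part of the module): the
# environment aborts an episode once the in-game turn counter has not advanced
# for `no_progress_timeout` consecutive steps, e.g. when the agent is stuck in menus:
#
#     env = NetHackSeed(no_progress_timeout=1_000)
#     obs = env.reset()
#     # If no game turn passes for 1_000 steps, _check_abort() returns True and
#     # the episode ends with StepStatus.ABORTED.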
class NetHackDescend(NetHackScore):
    r"""Environment for the "descend" task.

    This task rewards the agent for descending to new depths of the dungeon.
    The reward function is :math:`100 \cdot I + \text{TP}`, where :math:`I` is 1
    if the agent reached a new deepest dungeon level this step and 0 otherwise,
    and :math:`\text{TP}` is the time step penalty as defined by `NetHackScore`.
    """
    def reset(self, wizkit_items=None):
        self.max_dungeon_level = 1
        return super().reset(wizkit_items=wizkit_items)
    def _is_episode_end(self, observation):
        return self.StepStatus.RUNNING

    def _reward_fn(self, last_observation, action, observation, end_status):
        del action, end_status
        time_penalty = self._get_time_penalty(last_observation, observation)
        # blstats[12] is the current dungeon level.
        dungeon_level = observation[self._blstats_index][12]
        if dungeon_level > self.max_dungeon_level:
            reward = 100
            self.max_dungeon_level = dungeon_level
        else:
            reward = 0
        return reward + time_penalty
class NetHackCollect(NetHackGold):
    """Environment for the "collect" task.

    This task rewards the agent for collecting gold and items. The reward is the
    `NetHackGold` reward clipped at 10 per step, plus 10 for each new item glyph
    that appears in the agent's inventory.
    """

    def __init__(self, *args, **kwargs):
        actions = kwargs.pop(
            "actions", TASK_ACTIONS + (nethack.Command.PICKUP, nethack.Command.DROP)
        )
        super().__init__(*args, actions=actions, **kwargs)
    def reset(self, wizkit_items=None):
        observation = super().reset(wizkit_items=wizkit_items)
        inventory = observation["inv_glyphs"]
        self.collected_items = set(inventory)
        self._inv_glyphs_index = self._observation_keys.index("inv_glyphs")
        return observation

    def _is_episode_end(self, observation):
        return self.StepStatus.RUNNING

    def _reward_fn(self, last_observation, action, observation, end_status):
        # Gold reward from NetHackGold, clipped at 10 per step.
        gold_reward = min(
            10, super()._reward_fn(last_observation, action, observation, end_status)
        )
        inventory = observation[self._inv_glyphs_index]
        item_reward = 0
        for item in inventory:
            if item not in self.collected_items:
                self.collected_items.add(item)
                item_reward += 10
        return item_reward + gold_reward
class NetHackSatiate(NetHackScore):
    """Environment for the "eat" task.

    The task is similar to the one defined by `NetHackScore`, but the reward uses
    positive changes in the character's hunger level (e.g. by consuming
    comestibles or monster corpses) rather than the score. The eating reward is
    clipped at 10 per step and suppressed while the character is already satiated.
    """

    def _reward_fn(self, last_observation, action, observation, end_status):
        """Difference between previous hunger and new hunger."""
        del end_status  # Unused
        del action  # Unused

        if not self.nethack.in_normal_game():
            # Before the game starts and after it ends, blstats are zero.
            return 0.0

        old_internal = last_observation[self._internal_index]
        internal = observation[self._internal_index]
        old_blstats = last_observation[self._blstats_index]

        old_uhunger = old_internal[7]
        uhunger = internal[7]

        is_satiated = old_blstats[21] == 0  # Hunger state 0 == SATIATED
        if is_satiated:
            # If the agent is satiated, we don't want to reward it for eating.
            reward = 0
        else:
            # Give a reward for eating, but cap it at 10.
            reward = min(10, uhunger - old_uhunger)

        time_penalty = self._get_time_penalty(last_observation, observation)
        return reward + time_penalty
class NetHackScoutClipped(NetHackScore):
    """Environment for the "scout" task.

    The task is similar to the one defined by `NetHackScore`, but the reward is
    the number of newly discovered glyphs on the current dungeon level, clipped
    at 5 per step.
    """
    def reset(self, *args, **kwargs):
        self.dungeon_explored = {}
        return super().reset(*args, **kwargs)
    def _reward_fn(self, last_observation, action, observation, end_status):
        del end_status  # Unused
        del action  # Unused

        if not self.nethack.in_normal_game():
            # Before the game starts and after it ends, blstats are zero.
            return 0.0

        reward = 0
        glyphs = observation[self._glyph_index]
        blstats = observation[self._blstats_index]

        dungeon_num = blstats[nethack.NLE_BL_DNUM]
        dungeon_level = blstats[nethack.NLE_BL_DLEVEL]

        key = (dungeon_num, dungeon_level)
        explored = np.sum(glyphs != nethack.GLYPH_CMAP_OFF)
        explored_old = 0
        if key in self.dungeon_explored:
            explored_old = self.dungeon_explored[key]
        reward = min(5, explored - explored_old)
        self.dungeon_explored[key] = explored
        time_penalty = self._get_time_penalty(last_observation, observation)
        return reward + time_penalty
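
# Reward sketch for NetHackScoutClipped (illustration only): if 40 previously
# unseen map glyphs become visible on the current level in a single step, the
# scouting reward for that step is min(5, 40) = 5, so no single step can earn
# more than 5 before the time penalty is applied.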
class NethackTaskWrapper(TaskWrapper):
    """
    This wrapper allows you to change the task of an NLE environment.

    This wrapper was designed to meet two goals.
        1. Allow us to change the task of the NLE environment at the start of an episode
        2. Allow us to use the predefined NLE task definitions without copying or
           modifying their code. This makes it easier to integrate with other work
           on nethack tasks or curricula.

    Each task is defined as a subclass of the NLE, so you need to cast and
    reinitialize the environment to change its task. This wrapper manipulates the
    __class__ property to achieve this, but does so in a safe way. Specifically,
    we ensure that the instance variables needed for each task are available and
    reset at the start of the episode regardless of which task is active.
    """

    def __init__(
        self,
        env: gym.Env,
        additional_tasks: List[base.NLE] = None,
        use_default_tasks: bool = True,
        env_kwargs: Dict[str, Any] = {},
        wrappers: List[Tuple[gym.Wrapper, List[Any], Dict[str, Any]]] = None,
        seed: int = None,
    ):
        super().__init__(env)
        self.env = env
        self.task = NetHackScore
        self._init_kwargs = env_kwargs
        if self.env.__class__ == NetHackChallenge:
            self._no_progress_timeout = self._init_kwargs.pop("no_progress_timeout", 150)

        # This is set to False during reset
        self.done = True

        # Add nethack tasks provided by the base NLE
        task_list: List[base.NLE] = []
        if use_default_tasks:
            task_list = [
                NetHackScore,
                NetHackDescend,
                NetHackCollect,
                NetHackSatiate,
                NetHackScoutClipped,
            ]

        # Add in custom nethack tasks
        if additional_tasks:
            for task in additional_tasks:
                assert issubclass(task, base.NLE), "Env must subclass the base NLE"
                task_list.append(task)

        self.task_list = task_list
        self.task_space = DiscreteTaskSpace(len(task_list), task_list)

        # Add goal space to observation
        # self.observation_space = copy.deepcopy(self.env.observation_space)
        # self.observation_space["goal"] = spaces.MultiBinary(len(self.task_list))

        # Task completion metrics
        self.episode_return = 0

        # TODO: Deal with wrappers
        # Unwrap any intermediate wrappers to find the underlying NLE instance.
        self._nethack_env = self.env
        while self._nethack_env.__class__ not in self.task_list and self._nethack_env.__class__ != NetHackChallenge:
            if self._nethack_env.__class__ == GymV21CompatibilityV0:
                self._nethack_env = self._nethack_env.gym_env
            else:
                self._nethack_env = self._nethack_env.env

        # Initialize missing instance variables
        self._nethack_env.oracle_glyph = None

        if seed is not None:
            self.seed(seed)
    def seed(self, seed):
        self.env.env.seed(core=seed, disp=seed)
    def _task_name(self, task):
        return task.__name__
    def reset(self, new_task=None, **kwargs):
        """
        Resets the environment along with all available tasks, and changes the current task.

        This ensures that all instance variables are reset, not just the ones for the current
        task. We do this efficiently by keeping track of which reset functions have already
        been called, since very few tasks override reset. If new_task is provided, we change
        the task before calling the final reset.
        """
        # Change task if a new one is provided
        if new_task is None:
            new_task = kwargs.get("options", None)
            kwargs.pop("options", None)

        if new_task is not None:
            self.change_task(new_task)

        self.done = False
        self.episode_return = 0

        obs, info = self.env.reset(**kwargs)
        obs["prev_action"] = 0
        # Expose the encoded task to the agent via the tty_cursor field.
        obs["tty_cursor"] = self.task_space.encode(self.task)
        return self.observation(obs), info
    def change_task(self, new_task: int):
        """
        Change task by directly editing environment class.

        Ignores requests for unknown tasks or task changes outside of a reset.
        """
        # Ignore new task if mid-episode
        if self.task.__init__ != new_task.__init__ and not self.done:
            print(
                f"Given task {self._task_name(new_task)} needs to be reinitialized. "
                f"Ignoring request to change task and keeping {self.task.__name__}"
            )
            return

        # Ignore if task is unknown
        if new_task not in self.task_list:
            print(
                f"Given task {new_task} not in task list. "
                f"Ignoring request to change task and keeping {self.env.__class__.__name__}"
            )
            return

        # Update current task
        self.task = new_task
        self._nethack_env.__class__ = new_task
        # If task requires reinitialization
        # if type(self._nethack_env).__init__ != NetHackScore.__init__:
        #     self._nethack_env.__init__(actions=nethack.ACTIONS, **self._init_kwargs)

    def _encode_goal(self):
        goal_encoding = np.zeros(len(self.task_list))
        index = self.task_list.index(self.task)
        goal_encoding[index] = 1
        return goal_encoding
    def observation(self, observation):
        """
        Returns a modified observation. Currently a no-op; the commented-out code
        below would add a one-hot goal encoding to the observation.
        """
        # Add goal to observation
        # observation['goal'] = self._encode_goal()
        return observation
    def _task_completion(self, obs, rew, term, trunc, info):
        # TODO: Add real task completion metrics
        completion = 0.0
        if self.task == 0:
            completion = self.episode_return / 1000
        elif self.task == 1:
            completion = self.episode_return
        elif self.task == 2:
            completion = self.episode_return
        elif self.task == 3:
            completion = self.episode_return
        elif self.task == 4:
            completion = self.episode_return / 1000
        elif self.task == 5:
            completion = self.episode_return / 10
        elif self.task == 6:
            completion = self.episode_return / 100
        return min(max(completion, 0.0), 1.0)
    def step(self, action):
        """
        Step through environment and update task completion.
        """
        obs, rew, term, trunc, info = self.env.step(action)
        # self.episode_return += rew
        self.done = term or trunc
        info["task_completion"] = self._task_completion(obs, rew, term, trunc, info)
        obs["prev_action"] = action
        # Expose the encoded task to the agent via the tty_cursor field.
        obs["tty_cursor"] = self.task_space.encode(self.task)
        return self.observation(obs), rew, term, trunc, info
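
# Usage sketch for NethackTaskWrapper (illustration only; assumes a working NLE
# install and that DiscreteTaskSpace exposes a sample() method):
#
#     env = NetHackScore()
#     env = GymV21CompatibilityV0(env=env)
#     env = NethackTaskWrapper(env)
#     task = env.task_space.sample()          # e.g. NetHackDescend
#     obs, info = env.reset(new_task=task)    # recast the env to the new task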
class NethackSeedWrapper(TaskWrapper):
    """
    This wrapper allows you to change the seed of an NLE environment at reset,
    treating each seed as a separate task. This makes it possible to run
    seed-based curricula over a fixed pool of dungeon layouts without
    modifying the underlying NLE code.
    """

    def __init__(
        self,
        env: gym.Env,
        seed: int = 0,
        num_seeds: int = 200,
    ):
        super().__init__(env)
        self.env = env
        self.task_space = DiscreteTaskSpace(num_seeds)

        # Task completion metrics
        self.episode_return = 0
        self.task = seed

        if seed is not None:
            self.seed(seed)

        observation_dict = {key: space for key, space in self.env.observation_space.spaces.items()}
        observation_dict["prev_action"] = gym.spaces.Discrete(1)
        self.observation_space = gym.spaces.Dict(spaces=observation_dict)
    def seed(self, seed):
        env = self.env
        while hasattr(env, "env"):
            env = env.env
        env.seed(core=seed, disp=seed)
    def _task_name(self, task):
        return task.__name__
    def reset(self, new_task=None, **kwargs):
        """
        Resets the environment, changing the current seed first if new_task is provided.
        """
        self.episode_return = 0
        obs, info = super().reset(new_task=new_task, **kwargs)
        if isinstance(obs, dict):
            obs["prev_action"] = 0
        return self.observation(obs), info
    def change_task(self, new_task: int):
        """
        Change task by setting the seed.
        """
        self.task = new_task
        self.seed(new_task)
    def observation(self, observation):
        """
        Returns a modified observation.
        """
        if isinstance(observation, dict):
            # Encode the current seed into the tty_cursor field so the policy can observe it.
            encoded_task = self.task_space.encode(self.task)
            cursor = (encoded_task, 0) if encoded_task is not None else (-1, 0)
            observation["tty_cursor"] = np.asarray(cursor, dtype=np.uint8)
        return observation
    def _task_completion(self, obs, rew, term, trunc, info):
        return min(max(self.episode_return / 1000, 0.0), 1.0)
    def step(self, action):
        """
        Step through environment and update task completion.
        """
        obs, rew, term, trunc, info = super().step(action)
        self.episode_return += rew
        if isinstance(obs, dict):
            obs["prev_action"] = action
        return obs, rew, term, trunc, info
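
# Usage sketch for NethackSeedWrapper (illustration only; same assumptions as
# above):
#
#     env = NetHackSeed()
#     env = GymV21CompatibilityV0(env=env)
#     env = NethackSeedWrapper(env, num_seeds=200)
#     obs, info = env.reset(new_task=42)      # reseed the dungeon to seed 42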
class NethackDummyWrapper(TaskWrapper):
    def __init__(
        self,
        env: gym.Env,
        num_seeds: int = 200,
    ):
        super().__init__(env)
        self.env = env
        self.task_space = DiscreteTaskSpace(num_seeds)