""" Task wrapper for NLE that can change tasks at reset using the NLE's task definition format. """
from typing import Any, Dict, List, Tuple
import gymnasium as gym
import numpy as np
from nle import nethack
from nle.env import base
from nle.env.tasks import TASK_ACTIONS, NetHackChallenge, NetHackGold, NetHackScore
from shimmy.openai_gym_compatibility import GymV21CompatibilityV0
from syllabus.core import TaskWrapper
from syllabus.task_space import DiscreteTaskSpace
class NetHackSeed(NetHackScore):
    """Environment for seed-based curricula, configured like the NetHack Challenge.

    The task is an augmentation of the standard NLE Score task, with some
    subtle differences:

    * the action space is fixed to include the full keyboard
    * menus and "<More>" tokens are not skipped
    * the starting character is randomly assigned (``character="@"``)
    * wizard mode is disabled
    """
def __init__(
self,
*args,
character="@",
allow_all_yn_questions=True,
allow_all_modes=True,
penalty_mode="constant",
penalty_step: float = -0.00,
penalty_time: float = -0.0,
        max_episode_steps: int = 1_000_000,
observation_keys=(
"glyphs",
"chars",
"colors",
"specials",
"blstats",
"message",
"inv_glyphs",
"inv_strs",
"inv_letters",
"inv_oclasses",
"tty_chars",
"tty_colors",
"tty_cursor",
"misc",
),
no_progress_timeout: int = 10_000,
**kwargs,
):
actions = nethack.ACTIONS
kwargs["wizard"] = False
super().__init__(
*args,
actions=actions,
character=character,
allow_all_yn_questions=allow_all_yn_questions,
allow_all_modes=allow_all_modes,
penalty_mode=penalty_mode,
penalty_step=penalty_step,
penalty_time=penalty_time,
max_episode_steps=max_episode_steps,
observation_keys=observation_keys,
**kwargs,
)
        # If the in-game turn count doesn't change for no_progress_timeout steps, we abort
self.no_progress_timeout = no_progress_timeout
    def reset(self, *args, **kwargs):
self._turns = None
self._no_progress_count = 0
return super().reset(*args, **kwargs)
    def _check_abort(self, observation):
        """Check whether the in-game turn counter has stopped advancing for long
        enough (e.g. the agent is stuck in a menu loop) to trigger an abort."""
turns = observation[self._blstats_index][nethack.NLE_BL_TIME]
if self._turns == turns:
self._no_progress_count += 1
else:
self._turns = turns
self._no_progress_count = 0
return (
self._steps >= self._max_episode_steps
or self._no_progress_count >= self.no_progress_timeout
)
class NetHackDescend(NetHackScore):
    r"""Environment for the "descend" task.

    This task rewards the agent for reaching new depths of the dungeon.
    The reward function is :math:`100 \cdot I + \text{TP}`, where :math:`I` is 1
    if the agent has just reached a deeper dungeon level than ever before in the
    episode, and 0 otherwise, and :math:`\text{TP}` is the time step function as
    defined by `NetHackScore`. For example, descending to Dlvl 2 for the first
    time yields 100 plus the time penalty; revisiting it yields only the time
    penalty.
    """
    def reset(self, wizkit_items=None):
self.max_dungeon_level = 1
return super().reset(wizkit_items=wizkit_items)
    def _is_episode_end(self, observation):
        # The task itself never ends the episode; only step/time limits do
        return self.StepStatus.RUNNING
def _reward_fn(self, last_observation, action, observation, end_status):
del action, end_status
time_penalty = self._get_time_penalty(last_observation, observation)
        dungeon_level = observation[self._blstats_index][nethack.NLE_BL_DEPTH]
if dungeon_level > self.max_dungeon_level:
reward = 100
self.max_dungeon_level = dungeon_level
else:
reward = 0
return reward + time_penalty
class NetHackCollect(NetHackGold):
    """Environment for the "collect" task.

    This task rewards the agent for collecting items and gold. Each previously
    unseen inventory glyph is worth 10 reward, and the per-step gold reward
    from `NetHackGold` (which includes the time penalty) is clipped at 10.
    """
    def __init__(self, *args, **kwargs):
        # Extend the task actions so the agent can actually pick up and drop items
        actions = kwargs.pop("actions", TASK_ACTIONS + (nethack.Command.PICKUP, nethack.Command.DROP))
        super().__init__(*args, actions=actions, **kwargs)
    def reset(self, wizkit_items=None):
observation = super().reset(wizkit_items=wizkit_items)
inventory = observation["inv_glyphs"]
self.collected_items = set(inventory)
self._inv_glphys_index = self._observation_keys.index("inv_glyphs")
return observation
def _is_episode_end(self, observation):
return self.StepStatus.RUNNING
def _reward_fn(self, last_observation, action, observation, end_status):
gold_reward = min(10, super()._reward_fn(last_observation, action, observation, end_status))
        inventory = observation[self._inv_glyphs_index]
item_reward = 0
for item in inventory:
if item not in self.collected_items:
self.collected_items.add(item)
item_reward += 10
return item_reward + gold_reward
class NetHackSatiate(NetHackScore):
    """Environment for the "satiate" task.

    The task is similar to the one defined by `NetHackScore`, but the reward
    uses positive changes in the character's hunger level (e.g. by consuming
    comestibles or monster corpses), rather than the score. The per-step reward
    is clipped at 10 and is zero while the character is already satiated.
    """
def _reward_fn(self, last_observation, action, observation, end_status):
"""Difference between previous hunger and new hunger."""
del end_status # Unused
del action # Unused
if not self.nethack.in_normal_game():
# Before game started and after it ended blstats are zero.
return 0.0
        old_internal = last_observation[self._internal_index]
        internal = observation[self._internal_index]
        old_blstats = last_observation[self._blstats_index]
        # internal[7] is uhunger, the character's raw nutrition counter
        old_uhunger = old_internal[7]
        uhunger = internal[7]
        # Hunger state 0 in blstats means "satiated"
        is_satiated = old_blstats[nethack.NLE_BL_HUNGER] == 0
if is_satiated:
# If the agent is satiated, we don't want to reward it for eating
reward = 0
else:
# Give a reward for eating, but cap it at 10
reward = min(10, uhunger - old_uhunger)
time_penalty = self._get_time_penalty(last_observation, observation)
return reward + time_penalty
class NetHackScoutClipped(NetHackScore):
    """Environment for the "scout" task with a clipped reward.

    The task is similar to the one defined by `NetHackScore`, but the score is
    defined by the change in glyphs discovered by the agent, clipped at 5 per step.
    """
    def reset(self, *args, **kwargs):
self.dungeon_explored = {}
return super().reset(*args, **kwargs)
def _reward_fn(self, last_observation, action, observation, end_status):
del end_status # Unused
del action # Unused
if not self.nethack.in_normal_game():
# Before game started and after it ended blstats are zero.
return 0.0
reward = 0
glyphs = observation[self._glyph_index]
blstats = observation[self._blstats_index]
dungeon_num = blstats[nethack.NLE_BL_DNUM]
dungeon_level = blstats[nethack.NLE_BL_DLEVEL]
key = (dungeon_num, dungeon_level)
explored = np.sum(glyphs != nethack.GLYPH_CMAP_OFF)
explored_old = 0
if key in self.dungeon_explored:
explored_old = self.dungeon_explored[key]
reward = min(5, explored - explored_old)
self.dungeon_explored[key] = explored
time_penalty = self._get_time_penalty(last_observation, observation)
return reward + time_penalty
class NethackTaskWrapper(TaskWrapper):
    """
    This wrapper allows you to change the task of an NLE environment.

    It was designed to meet two goals:

    1. Allow us to change the task of the NLE environment at the start of an episode.
    2. Allow us to use the predefined NLE task definitions without copying or
       modifying their code, making it easier to integrate with other work on
       NetHack tasks or curricula.

    Each task is defined as a subclass of the base NLE environment, so you would
    normally need to cast and reinitialize the environment to change its task. This
    wrapper manipulates the __class__ property to achieve this, but does so in a
    safe way. Specifically, we ensure that the instance variables needed for each
    task are available and reset at the start of the episode regardless of which
    task is active.
    """
def __init__(
self,
env: gym.Env,
additional_tasks: List[base.NLE] = None,
use_default_tasks: bool = True,
        env_kwargs: Dict[str, Any] = None,
wrappers: List[Tuple[gym.Wrapper, List[Any], Dict[str, Any]]] = None,
seed: int = None,
):
super().__init__(env)
self.env = env
self.task = NetHackScore
        self._init_kwargs = env_kwargs if env_kwargs is not None else {}
if self.env.__class__ == NetHackChallenge:
self._no_progress_timeout = self._init_kwargs.pop("no_progress_timeout", 150)
# This is set to False during reset
self.done = True
# Add nethack tasks provided by the base NLE
task_list: List[base.NLE] = []
if use_default_tasks:
task_list = [
NetHackScore,
NetHackDescend,
NetHackCollect,
NetHackSatiate,
NetHackScoutClipped,
]
# Add in custom nethack tasks
if additional_tasks:
for task in additional_tasks:
                assert issubclass(task, base.NLE), "Task must subclass the base NLE"
task_list.append(task)
self.task_list = task_list
self.task_space = DiscreteTaskSpace(len(task_list), task_list)
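        # DiscreteTaskSpace pairs each task class with an integer index: with the
        # default task list above, task_space.encode(NetHackScore) returns 0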
# Add goal space to observation
# self.observation_space = copy.deepcopy(self.env.observation_space)
# self.observation_space["goal"] = spaces.MultiBinary(len(self.task_list))
# Task completion metrics
self.episode_return = 0
        # TODO: Deal with wrappers
        # Unwrap to the underlying NLE environment so its class can be swapped later
        self._nethack_env = self.env
while self._nethack_env.__class__ not in self.task_list and self._nethack_env.__class__ != NetHackChallenge:
if self._nethack_env.__class__ == GymV21CompatibilityV0:
self._nethack_env = self._nethack_env.gym_env
else:
self._nethack_env = self._nethack_env.env
# Initialize missing instance variables
self._nethack_env.oracle_glyph = None
if seed is not None:
self.seed(seed)
    def seed(self, seed):
self.env.env.seed(core=seed, disp=seed)
def _task_name(self, task):
return task.__name__
    def reset(self, new_task=None, **kwargs):
        """
        Resets the environment along with all available tasks, and changes the current task.
        This ensures that all instance variables are reset, not just the ones for the
        current task. We do this efficiently by keeping track of which reset functions
        have already been called, since very few tasks override reset. If new_task is
        provided, we change the task before calling the final reset.
        """
# Change task if new one is provided
        # Change task if a new one is provided, either directly or via "options"
        if new_task is None:
            new_task = kwargs.pop("options", None)
if new_task is not None:
self.change_task(new_task)
self.done = False
self.episode_return = 0
obs, info = self.env.reset(**kwargs)
obs["prev_action"] = 0
obs["tty_cursor"] = self.task_space.encode(self.task)
return self.observation(obs), info
    def change_task(self, new_task: int):
        """
        Change the task by directly editing the environment's class.
        Ignores requests for unknown tasks or task changes outside of a reset.
        """
        # Ignore the new task if it would require reinitialization mid-episode
        if self.task.__init__ != new_task.__init__ and not self.done:
            print(f"Given task {self._task_name(new_task)} needs to be reinitialized. "
                  f"Ignoring request to change task and keeping {self.task.__name__}")
            return
        # Ignore the task if it is unknown
        if new_task not in self.task_list:
            print(f"Given task {new_task} is not in the task list. "
                  f"Ignoring request to change task and keeping {self.env.__class__.__name__}")
            return
        # Update the current task by swapping the underlying environment's class
        self.task = new_task
        self._nethack_env.__class__ = new_task
# If task requires reinitialization
# if type(self._nethack_env).__init__ != NetHackScore.__init__:
# self._nethack_env.__init__(actions=nethack.ACTIONS, **self._init_kwargs)
def _encode_goal(self):
goal_encoding = np.zeros(len(self.task_list))
index = self.task_list.index(self.task)
goal_encoding[index] = 1
return goal_encoding
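    # For example, with the five default tasks and NetHackCollect active,
    # _encode_goal() returns array([0., 0., 1., 0., 0.])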
    def observation(self, observation):
        """
        Returns the (optionally modified) observation. Currently a passthrough;
        the goal encoding below can be re-enabled to append the current task.
        """
        # Add goal to observation
        # observation['goal'] = self._encode_goal()
        return observation
    def _task_completion(self, obs, rew, term, trunc, info):
        # TODO: Add real task completion metrics
        # self.task is a class, so compare its encoded index, not the class itself
        task_idx = self.task_space.encode(self.task)
        if task_idx in (0, 4):
            completion = self.episode_return / 1000
        elif task_idx in (1, 2, 3):
            completion = self.episode_return
        elif task_idx == 5:
            completion = self.episode_return / 10
        elif task_idx == 6:
            completion = self.episode_return / 100
        else:
            completion = 0.0
        return min(max(completion, 0.0), 1.0)
    def step(self, action):
        """
        Step through the environment and update task completion.
        """
        obs, rew, term, trunc, info = self.env.step(action)
        self.episode_return += rew  # accumulate return for the completion metric
        self.done = term or trunc
        info["task_completion"] = self._task_completion(obs, rew, term, trunc, info)
        obs["prev_action"] = action
        # Reuse the tty_cursor slot to expose the encoded task to the agent
        obs["tty_cursor"] = self.task_space.encode(self.task)
        return self.observation(obs), rew, term, trunc, info
class NethackSeedWrapper(TaskWrapper):
    """
    This wrapper allows you to change the seed of an NLE environment at the start
    of an episode.

    Each task is a seed for NetHack's dungeon generation, drawn from a discrete
    space of num_seeds seeds. This lets a curriculum control which dungeons the
    agent trains on without copying or modifying the underlying task definition.
    """
def __init__(
self,
env: gym.Env,
seed: int = 0,
num_seeds: int = 200,
):
super().__init__(env)
self.env = env
self.task_space = DiscreteTaskSpace(num_seeds)
# Task completion metrics
self.episode_return = 0
self.task = seed
if seed is not None:
self.seed(seed)
    def seed(self, seed):
        # Unwrap nested wrappers to reach the base NLE environment, which
        # exposes separate core (dungeon) and display seeds
        env = self.env
        while hasattr(env, "env"):
            env = env.env
        env.seed(core=seed, disp=seed)
    def _task_name(self, task):
        # Tasks in this wrapper are integer seeds, not classes
        return str(task)
    def reset(self, new_task=None, **kwargs):
        """
        Resets the environment and changes the seed if a new task is provided,
        either directly or via the "options" kwarg.
        """
        # Change seed if a new one is provided, either directly or via "options"
        if new_task is None:
            new_task = kwargs.pop("options", None)
if new_task is not None:
self.change_task(new_task)
self.episode_return = 0
obs, info = self.env.reset(**kwargs)
        if isinstance(obs, dict):
            obs["prev_action"] = 0
            # Reuse the tty_cursor slot to expose the encoded seed to the agent
            encoded_task = self.task_space.encode(self.task)
            obs["tty_cursor"] = encoded_task if encoded_task is not None else -1
return self.observation(obs), info
    def change_task(self, new_task: int):
        """
        Change the task by setting the seed.
        """
        self.task = new_task
        self.seed(new_task)
    def observation(self, observation):
        """
        Returns the observation unchanged. Provided as a hook for subclasses.
        """
        return observation
    def step(self, action):
        """
        Step through the environment, annotating the observation with the current seed.
        """
        obs, rew, term, trunc, info = self.env.step(action)
        if isinstance(obs, dict):
            obs["prev_action"] = action
            # Reuse the tty_cursor slot to expose the encoded seed to the agent
            encoded_task = self.task_space.encode(self.task)
            obs["tty_cursor"] = encoded_task if encoded_task is not None else -1
return self.observation(obs), rew, term, trunc, info