Source code for syllabus.task_space.task_space

import itertools
import typing
import warnings
from typing import Any, List, Tuple, Union

import numpy as np
from gymnasium.spaces import Discrete, MultiDiscrete, Space

from syllabus.utils import UsageError


class TaskSpace:
    """
    TaskSpace is an extension of gym spaces that allows for efficient encoding and decoding of tasks.

    This is useful for environments that have a large number of tasks or require complex task
    representations. Encoding tasks provides several advantages:

    1. Minimizing the bandwidth required to transfer tasks between processes
    2. Simplifying the task formats that curricula need to support
    3. Allowing the environment to use a convenient and interpretable task format, with no impact
       on performance

    Subclasses implement :meth:`_encode` and :meth:`_decode`; the public :meth:`encode` and
    :meth:`decode` wrappers add generic safety checks around them.
    """

    def __init__(self, space_or_value: Union[Space, int, List, Tuple], tasks: List[Any] = None):
        """
        Generic TaskSpace initialization. Provides syntactic sugar for creating gym spaces.

        :param space_or_value: gym space or value that can be parsed into a gym space
        :type space_or_value: Union[Space, int, List, Tuple]
        :param tasks: The corresponding task representations
        :type tasks: List[Any], optional
        """
        # Syntactic sugar for creating gym spaces
        if isinstance(space_or_value, Space):
            self.gym_space = space_or_value
        else:
            self.gym_space = self._create_gym_space(space_or_value)

        # Autogenerate task names
        if tasks is None:
            tasks = self._generate_task_names(self.gym_space)

        self._task_set = set(tasks)
        self._task_list = tasks

    def _create_gym_space(self, gym_space: Space):
        """
        Create a gym space from a simple value.

        An int becomes a ``Discrete`` space and a tuple/list becomes a ``MultiDiscrete`` space.
        Any other value is returned unchanged.

        :param gym_space: A simple value to create a gym space from
        :type gym_space: Space
        :return: Created gym space
        :rtype: Space
        """
        if isinstance(gym_space, int):
            gym_space = Discrete(gym_space)
        elif isinstance(gym_space, (tuple, list)):
            gym_space = MultiDiscrete(gym_space)
        return gym_space

    def _generate_task_names(self, gym_space: Space):
        """
        Generate basic task names for a gym space.

        :param gym_space: A gymnasium space
        :type gym_space: Space
        :return: List of task names (empty for spaces other than Discrete/MultiDiscrete)
        :rtype: List
        """
        if isinstance(gym_space, Discrete):
            tasks = list(range(gym_space.n))
        elif isinstance(gym_space, MultiDiscrete):
            tasks = [tuple(range(dim)) for dim in gym_space.nvec]
        else:
            tasks = []
        return tasks

    def decode(self, encoding: Any) -> Any:
        """
        Convert the task encoding to the original task representation.

        This method provides generic decoding safety checks for all task spaces, and calls the
        specific _decode method for each task space. It will throw a UsageError if the encoding
        cannot be decoded into the task space.

        :param encoding: Encoding of the task
        :type encoding: Any
        :return: Decoded task that can be used by the environment
        :rtype: Any
        :raises UsageError: If the encoding cannot be decoded
        """
        try:
            return self._decode(encoding)
        except KeyError as e:
            # Check if task is already in decoded form
            try:
                self._encode(encoding)
                warnings.warn(f"Task encoding already decoded: {encoding}", stacklevel=2)
                return encoding
            except (KeyError, TypeError):
                raise UsageError(f"Failed to decode task encoding: {encoding}") from e
        except ValueError as e:
            raise UsageError(f"Failed to decode task encoding: {encoding}") from e

    def encode(self, task: Any) -> Any:
        """
        Convert the task to an efficient encoding to speed up multiprocessing.

        This method provides generic encoding safety checks for all task spaces, and calls the
        specific _encode method for each task space. It will throw a UsageError if the task is
        not in the task space or cannot be encoded.

        :param task: Task to encode
        :type task: Any
        :return: Encoded task
        :rtype: Any
        :raises UsageError: If the task cannot be encoded
        """
        try:
            return self._encode(task)
        except KeyError as e:
            # Check if task is already in encoded form
            try:
                self._decode(task)
                warnings.warn(f"Task already encoded: {task}", stacklevel=2)
                return task
            except (KeyError, TypeError):
                raise UsageError(f"Failed to encode task: {task}") from e
        except ValueError as e:
            raise UsageError(f"Failed to encode task: {task}") from e

    def _decode(self, encoding: Any) -> Any:
        """
        Convert the task encoding to the original task representation.

        Subclasses should implement this method for their decoding logic.

        :param encoding: Encoding of the task
        :type encoding: Any
        :return: Decoded task representation
        :rtype: Any
        """
        raise NotImplementedError

    def _encode(self, task: Any) -> Any:
        """
        Convert the task to an efficient encoding to speed up multiprocessing.

        Subclasses should implement this method for their encoding logic.

        :param task: Task to encode
        :type task: Any
        :return: Encoded task
        :rtype: Any
        """
        raise NotImplementedError

    def contains(self, encoding: Any) -> bool:
        """
        Check if the encoding is a valid task in the task space.

        A value is accepted if it is itself a member of the task set, or if it decodes to a
        member of the task set. Values that cannot be hashed or decoded are reported as not
        contained instead of raising.

        :param encoding: Encoding of the task
        :type encoding: Any
        :return: Boolean specifying if the encoding is a valid task
        :rtype: bool
        """
        try:
            return encoding in self._task_set or self._decode(encoding) in self._task_set
        except (AssertionError, KeyError, IndexError, TypeError, ValueError, UsageError):
            # _decode implementations signal invalid input with asserts, lookup errors or
            # UsageError; a membership test should answer False for those. NotImplementedError
            # from the abstract base is deliberately left to propagate.
            return False

    @property
    def tasks(self) -> List[Any]:
        """
        Return the list of all tasks in the task space.

        :return: List of all tasks
        :rtype: List[Any]
        """
        return self._task_list

    @property
    def num_tasks(self) -> int:
        """
        Return the number of tasks in the task space.

        :return: Number of tasks
        :rtype: int
        """
        return len(self._task_list)

    def task_name(self, task: int) -> str:
        """
        Return the name of the task.

        :param task: Task (encoding) to get the name of
        :type task: int
        :return: ``repr`` of the decoded task
        :rtype: str
        """
        return repr(self._decode(task))
class DiscreteTaskSpace(TaskSpace):
    """Task space for discrete tasks."""

    def __init__(self, space_or_value: Union[Space, int], tasks=None):
        """
        Initialize a discrete task space.

        Runs of consecutive integers are stored as just their first and last value; any other
        task list gets explicit encode/decode lookup tables.

        :param space_or_value: gym space or value that can be parsed into a gym space
        :type space_or_value: Union[Space, int]
        :param tasks: The corresponding tasks representations
        :type tasks: List[Any], optional
        """
        super().__init__(space_or_value, tasks)

        # Use space efficient implementation for sequential task spaces
        self._sequential = self._is_sequential(self.tasks)
        if not self._sequential:
            self._decode_map = dict(enumerate(self.tasks))
            self._encode_map = {task: index for index, task in self._decode_map.items()}
        else:
            self._first_task = self.tasks[0]  # First and smallest task
            self._last_task = self.tasks[-1]  # Last and largest task

    def _is_sequential(self, tasks: List[int]):
        """
        Check if the tasks are sequential integers.

        :param tasks: List of tasks
        :type tasks: List[int]
        :return: Boolean specifying if the tasks are sequential integers
        :rtype: bool
        """
        first = tasks[0]
        if not isinstance(first, (int, np.integer)):
            return False
        return tuple(tasks) == tuple(range(first, tasks[-1] + 1))

    def _decode(self, encoding: int) -> int:
        """
        Convert the task encoding to the original task representation.

        :param encoding: Encoding of the task
        :type encoding: int
        :return: Decoded task representation
        :rtype: int
        :raises UsageError: If a sequential space is given an out-of-range encoding
        """
        assert isinstance(encoding, (int, np.integer)), f"Encoding must be an integer. Got {type(encoding)} instead."
        if not self._sequential:
            return self._decode_map[encoding]

        # Sequential spaces encode a task as its offset from the first task
        decoded = encoding + self._first_task
        if not (self._first_task <= decoded <= self._last_task):
            raise UsageError(f"Encoding {encoding} does not map to a task in the task space")
        return decoded

    def _encode(self, task: int) -> int:
        """
        Convert the task to an efficient encoding.

        :param task: Task to encode
        :type task: int
        :return: Encoded task
        :rtype: int
        :raises UsageError: If a sequential space is given a task outside its range
        """
        if not self._sequential:
            return self._encode_map[task]

        assert isinstance(task, (int, np.integer)), f"Task must be an integer. Got {type(task)} instead."
        if not (self._first_task <= task <= self._last_task):
            raise UsageError(f"Task {task} is not in the task space")
        return task - self._first_task

    def sample(self) -> int:
        """
        Sample a task from the task space.

        :return: Sampled task
        :rtype: int
        """
        return self._decode(self.gym_space.sample())

    def seed(self, seed: int):
        """
        Seed the task space.

        :param seed: Seed value
        :type seed: int
        """
        self.gym_space.seed(seed)
class BoxTaskSpace(TaskSpace):
    """Task space for continuous tasks."""

    def _decode(self, encoding: np.ndarray) -> np.ndarray:
        """
        Convert the task encoding to the original task representation.

        Box tasks are their own encoding, so this only validates the input.

        :param encoding: Encoding of the task
        :type encoding: np.ndarray
        :return: Decoded task representation
        :rtype: np.ndarray
        :raises UsageError: If encoding does not map to a task in the task space
        """
        assert isinstance(encoding, np.ndarray), f"Encoding must be a numpy array. Got {type(encoding)} instead."
        if not self.contains(encoding):
            raise UsageError(f"Encoding {encoding} does not map to a task in the task space")
        return encoding

    def _encode(self, task: np.ndarray) -> np.ndarray:
        """
        Convert the task to an efficient encoding.

        :param task: Task to encode
        :type task: np.ndarray
        :return: Encoded task (the task itself, unchanged)
        :rtype: np.ndarray
        """
        return task

    def sample(self) -> np.ndarray:
        """
        Sample a task from the task space.

        :return: Sampled task
        :rtype: np.ndarray
        """
        sample = self.gym_space.sample()
        return self._decode(sample)

    def seed(self, seed: int):
        """
        Seed the task space.

        :param seed: Seed value
        :type seed: int
        """
        self.gym_space.seed(seed)

    @property
    def tasks(self) -> List[Any]:
        """
        Return the list of all tasks in the task space.

        :return: Always None; this class does not enumerate its tasks
        :rtype: List[Any]
        """
        return None

    @property
    def num_tasks(self) -> int:
        """
        Return the number of tasks in the task space.

        :return: Always -1; this class does not count its tasks
        :rtype: int
        """
        return -1

    def task_name(self, task: np.ndarray) -> str:
        """
        Return the name of the task.

        :param task: Task to get the name of
        :type task: np.ndarray
        :return: Name of the task
        :rtype: str
        """
        return repr(self._decode(task))

    def contains(self, encoding: np.ndarray) -> bool:
        """
        Return boolean specifying if encoding is a valid member of this space.

        :param encoding: Encoding of the task
        :type encoding: np.ndarray
        :return: Boolean specifying if encoding is a valid task
        :rtype: bool
        """
        if not isinstance(encoding, np.ndarray):
            try:
                encoding = np.asarray(encoding, dtype=self.gym_space.dtype)
            except (ValueError, TypeError):
                return False

        # Check the shape before comparing against the bounds: comparing arrays whose shapes
        # cannot be broadcast raises ValueError instead of answering the membership question.
        if encoding.shape != self.gym_space.shape:
            return False

        try:
            within_bounds = (encoding >= self.gym_space.low) & (encoding <= self.gym_space.high)
        except TypeError:
            # Arrays whose dtype cannot be ordered against the bounds are not valid tasks
            return False
        return bool(np.all(within_bounds))

    # def to_multidiscrete(self, grid_points: Union[int, List[int]]):
    #     # Convert to Box Task Space to MultiDiscrete Task Space
    #     if isinstance(self.gym_space, Box):
    #         elements = self.gym_space.shape[0]
    #         print(self.gym_space.shape)
class MultiDiscreteTaskSpace(TaskSpace):
    """Task space for multi-discrete tasks."""

    def __init__(self, space_or_value: Union[MultiDiscrete, int], tasks: Union[List[Any], Tuple[Any]] = None, flatten: bool = False):
        """
        Initialize a multi-discrete task space.

        :param space_or_value: gym space or value that can be parsed into a gym space
        :type space_or_value: Union[MultiDiscrete, int]
        :param tasks: The corresponding tasks representations, one sequence per dimension
        :type tasks: Union[List[Any], Tuple[Any]], optional
        :param flatten: Whether to flatten the encoding into a discrete list
        :type flatten: bool
        """
        super().__init__(space_or_value, tasks)

        self.flatten = flatten
        # Every combination of per-dimension tasks, built eagerly
        self._all_tasks = list(itertools.product(*self._task_list))
        # One lookup table per dimension, in both directions
        self._encode_maps = [{task: i for i, task in enumerate(tasks)} for tasks in self._task_list]
        self._decode_maps = [{i: task for i, task in enumerate(tasks)} for tasks in self._task_list]

    def _is_sequential(self, tasks: List[int]) -> bool:
        """
        Check if the tasks are sequential integers.

        :param tasks: List of tasks
        :type tasks: List[int]
        :return: Boolean specifying if the tasks are sequential integers
        :rtype: bool
        """
        return isinstance(tasks[0], (int, np.integer)) and tuple(tasks) == tuple(range(tasks[0], tasks[-1] + 1))

    def _decode(self, encoding: Union[int, Tuple[int]]) -> Tuple[int]:
        """
        Convert the task encoding to the original task representation.

        :param encoding: Encoding of the task; an integer when ``flatten`` is set, otherwise a
            tuple with one index per dimension
        :type encoding: Union[int, Tuple[int]]
        :return: Decoded task representation
        :rtype: Tuple[int]
        :raises UsageError: If the encoding length does not match the number of dimensions
        """
        assert isinstance(encoding, (int, np.integer, tuple)
                          ), f"Encoding must be an integer or tuple. Got {type(encoding)} instead."
        if self.flatten:
            assert isinstance(encoding, (int, np.integer)
                              ), f"Encoding must be an integer. Got {type(encoding)} instead."
            # Expand the flat index back into one index per dimension
            encoding = np.unravel_index(encoding, self.gym_space.nvec)

        if len(encoding) != len(self._decode_maps):
            raise UsageError(
                f"Encoding length ({len(encoding)}) must match number of discrete spaces ({len(self._decode_maps)})")

        return tuple(decode_map[t] for decode_map, t in zip(self._decode_maps, encoding))

    def _encode(self, task: Tuple[Any]) -> int:
        """
        Convert the task to an efficient encoding.

        :param task: Task to encode
        :type task: Tuple[Any]
        :return: Encoded task; a flat integer index when ``flatten`` is set, otherwise a tuple
            of per-dimension indices
        :rtype: int
        :raises UsageError: If the task length does not match the number of dimensions
        """
        if len(task) != len(self._encode_maps):
            raise UsageError(
                f"Task length ({len(task)}) must match number of discrete spaces ({len(self._encode_maps)})")

        encoding = tuple(encode_map[t] for encode_map, t in zip(self._encode_maps, task))
        if self.flatten:
            encoding = np.ravel_multi_index(encoding, self.gym_space.nvec)
        return encoding

    def contains(self, encoding: Union[int, Tuple[Any]]) -> bool:
        """
        Check if the encoding is a valid task in the task space.

        Like the base class, both encoded and already-decoded tasks are accepted. Membership is
        checked per dimension, because the base-class task set holds the per-dimension task
        sequences rather than the combined tasks.

        :param encoding: Encoding of the task, or a decoded task tuple
        :type encoding: Union[int, Tuple[Any]]
        :return: Boolean specifying if the encoding is a valid task
        :rtype: bool
        """
        if isinstance(encoding, (int, np.integer)):
            # Integers are only valid encodings in flattened mode
            return bool(self.flatten and 0 <= encoding < self.num_tasks)

        if not isinstance(encoding, (tuple, list, np.ndarray)):
            return False
        if len(encoding) != len(self._decode_maps):
            return False

        try:
            is_task = all(t in encode_map for encode_map, t in zip(self._encode_maps, encoding))
            # Per-dimension index tuples are only valid encodings when not flattened
            is_encoding = not self.flatten and all(
                e in decode_map for decode_map, e in zip(self._decode_maps, encoding))
        except TypeError:
            # Unhashable components can never be keys of the lookup tables
            return False
        return bool(is_task or is_encoding)

    def sample(self):
        """
        Sample a task from the task space.

        :return: Sampled (decoded) task
        :rtype: Tuple[Any]
        """
        # The gym space samples a numpy array, but _decode only accepts ints and tuples
        sample = tuple(int(index) for index in self.gym_space.sample())
        if self.flatten:
            sample = np.ravel_multi_index(sample, self.gym_space.nvec)
        return self._decode(sample)

    def seed(self, seed: int):
        """
        Seed the task space.

        :param seed: Seed value
        :type seed: int
        """
        self.gym_space.seed(seed)

    @property
    def tasks(self) -> List[Any]:
        """
        Return the list of all tasks in the task space.

        :return: List of all task combinations
        :rtype: List[Any]
        """
        return self._all_tasks

    @property
    def num_tasks(self) -> int:
        """
        Return the number of tasks in the task space.

        :return: Number of tasks (product of the dimension sizes)
        :rtype: int
        """
        return int(np.prod(self.gym_space.nvec))
class TupleTaskSpace(TaskSpace):
    """Task space for tuple tasks. Can be used to combine multiple task spaces into a single task space."""

    def __init__(self, task_spaces: Tuple[TaskSpace], space_names: Tuple = None, flatten: bool = False):
        """
        Initialize a tuple task space.

        :param task_spaces: Tuple of task spaces
        :type task_spaces: Tuple[TaskSpace]
        :param space_names: Names of the spaces
        :type space_names: Tuple, optional
        :param flatten: Whether to flatten the encoding into a discrete list
        :type flatten: bool
        """
        super().__init__(None, None)
        self.task_spaces = task_spaces
        self.space_names = space_names
        self.flatten = flatten
        if self.flatten:
            # Flattening the tuple requires every subspace that supports it to flatten too
            for space in self.task_spaces:
                if hasattr(space, "flatten"):
                    space.flatten = self.flatten

        self._all_tasks = None  # Built lazily by the ``tasks`` property
        self._task_nums = tuple(space.num_tasks for space in self.task_spaces)

    def _is_sequential(self, tasks: Tuple[int]) -> bool:
        """
        Check if the tasks are sequential integers.

        :param tasks: List of tasks
        :type tasks: Tuple[int]
        :return: Boolean specifying if the tasks are sequential integers
        :rtype: bool
        """
        return isinstance(tasks[0], (int, np.integer)) and tuple(tasks) == tuple(range(tasks[0], tasks[-1] + 1))

    def _decode(self, encoding: Union[int, Tuple[Any]]) -> Tuple[int]:
        """
        Convert the task encoding to the original task representation.

        :param encoding: Encoding of the task; an integer when ``flatten`` is set, otherwise a
            tuple with one encoding per subspace
        :type encoding: Union[int, Tuple[Any]]
        :return: Decoded task representation
        :rtype: Tuple[Any]
        :raises UsageError: If the encoding length does not match the number of task spaces
        """
        assert isinstance(encoding, (int, np.integer, tuple)
                          ), f"Encoding must be an integer or tuple. Got {type(encoding)} instead."
        if self.flatten:
            assert isinstance(encoding, (int, np.integer)
                              ), f"Encoding must be an integer. Got {type(encoding)} instead."
            # Expand the flat index back into one index per subspace
            encoding = np.unravel_index(encoding, self._task_nums)

        if len(encoding) != len(self.task_spaces):
            raise UsageError(
                f"Encoding length ({len(encoding)}) must match number of task spaces ({len(self.task_spaces)})")
        return tuple(space.decode(t) for space, t in zip(self.task_spaces, encoding))

    def _encode(self, task: Tuple[Any]):
        """
        Convert the task to an efficient encoding.

        :param task: Task to encode
        :type task: Tuple[Any]
        :return: Encoded task; a flat integer index when ``flatten`` is set, otherwise a tuple
            of per-subspace encodings
        :rtype: int
        :raises UsageError: If the task length does not match the number of task spaces
        """
        if len(task) != len(self.task_spaces):
            raise UsageError(
                f"Task length ({len(task)}) must match number of task spaces ({len(self.task_spaces)})")
        encoding = tuple(space.encode(t) for space, t in zip(self.task_spaces, task))
        if self.flatten:
            encoding = np.ravel_multi_index(encoding, self._task_nums)
        return encoding

    def contains(self, encoding: Union[int, Tuple[Any]]) -> bool:
        """
        Check if the encoding is a valid task in the task space.

        :param encoding: Encoding of the task; a flat integer (flattened spaces only) or a
            sequence with one element per subspace
        :type encoding: Union[int, Tuple[Any]]
        :return: Boolean specifying if the encoding is a valid task
        :rtype: bool
        """
        if isinstance(encoding, (int, np.integer)):
            # Flat integer encodings only exist in flattened mode
            return bool(self.flatten and 0 <= encoding < self.num_tasks)

        try:
            elements = tuple(encoding)
        except TypeError:
            # Not iterable, so it cannot hold one element per subspace
            return False

        # zip() would silently ignore missing or surplus elements, so compare lengths first
        if len(elements) != len(self.task_spaces):
            return False
        return all(space.contains(element) for element, space in zip(elements, self.task_spaces))

    def sample(self) -> Tuple[Any]:
        """
        Sample a task from the task space.

        :return: Sampled task, one element per subspace
        :rtype: Tuple[Any]
        """
        # Return a tuple, the form that encode/decode and task_name operate on
        return tuple(space.sample() for space in self.task_spaces)

    def seed(self, seed: int):
        """
        Seed all subspaces.

        :param seed: Seed value
        :type seed: int
        """
        for space in self.task_spaces:
            space.seed(seed)

    @property
    def tasks(self) -> List[Any]:
        """
        Return the list of all tasks in the task space.

        The cartesian product of the subspace tasks is computed on first access and cached.

        :return: List of all tasks
        :rtype: List[Any]
        """
        if self._all_tasks is None:
            # Expand lazily: the full product can be very large
            task_lists = [space.tasks for space in self.task_spaces]
            self._all_tasks = list(itertools.product(*task_lists))
        return self._all_tasks

    @property
    def num_tasks(self) -> int:
        """
        Return the number of tasks in the task space.

        :return: Number of tasks (product of the subspace task counts)
        :rtype: int
        """
        return int(np.prod(self._task_nums))

    def task_name(self, task: Tuple[int]) -> str:
        """
        Return the name of the task.

        :param task: Task (encoding) to get the name of
        :type task: Tuple[int]
        :return: ``repr`` of the decoded task
        :rtype: str
        """
        return repr(self._decode(task))