import warnings
from copy import copy
import gymnasium as gym
from syllabus.core import PettingZooTaskEnv, TaskEnv
from syllabus.task_space import DiscreteTaskSpace
class SyncTestEnv(TaskEnv):
    """Fixed-length task environment that pays a reward of 1 on every step.

    The task space holds ``num_episodes + 1`` named tasks: a sentinel
    ``"error task"`` followed by ``"task 1"`` ... ``"task <num_episodes>"``.
    Each episode terminates once ``num_steps`` steps have been taken.

    Args:
        num_episodes: Number of regular (non-error) tasks in the task space.
        num_steps: Number of steps after which an episode terminates.
    """

    # Sentinel task name; it is the first entry of the task space.
    _ERROR_TASK = "error task"

    def __init__(self, num_episodes, num_steps=100):
        super().__init__()
        self.num_steps = num_steps
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Tuple(
            (gym.spaces.Discrete(self.num_steps), gym.spaces.Discrete(2))
        )
        task_names = [self._ERROR_TASK] + [f"task {i + 1}" for i in range(num_episodes)]
        self.task_space = DiscreteTaskSpace(gym.spaces.Discrete(num_episodes + 1), task_names)
        # Start on the sentinel, spelled exactly as it appears in the task space.
        self.task = self._ERROR_TASK

    def reset(self, new_task=None):
        """Start a new episode on ``new_task``.

        A warning is issued when no task is given or when the error task is
        received. A missing task falls back to the error task, as the warning
        message states, rather than leaving ``self.task`` set to ``None``.

        Args:
            new_task: Task for the coming episode, or ``None``.

        Returns:
            A tuple ``((turn, None), info)`` where ``turn`` is 0 and ``info``
            carries ``"content": "reset"`` and the active ``"task"``.
        """
        if new_task is None:
            warnings.warn("No task provided. Resetting to error task.", stacklevel=2)
            new_task = self._ERROR_TASK
        elif new_task == self._ERROR_TASK:
            warnings.warn(
                "Received error task. This likely means that too many tasks are being requested.",
                stacklevel=2,
            )
        self.task = new_task
        self._turn = 0
        return (self._turn, None), {"content": "reset", "task": self.task}

    def step(self, action):
        """Advance the episode by one turn.

        Args:
            action: Action taken this turn; it is echoed back in the observation.

        Returns:
            ``(obs, rew, term, trunc, info)`` where ``obs`` is the result of
            ``self.observation((turn, action))``, ``rew`` is always 1, ``term``
            is true once the turn count reaches ``num_steps``, ``trunc`` is
            always false, and ``info`` carries ``"content": "step"`` and the
            ``"task_completion"`` value.
        """
        self._turn += 1
        obs = self.observation((self._turn, action))
        rew = 1
        term = self._turn >= self.num_steps
        trunc = False
        info = {
            "content": "step",
            "task_completion": self._task_completion(obs, rew, term, trunc, {}),
        }
        return obs, rew, term, trunc, info
class PettingZooSyncTestEnv(PettingZooTaskEnv):
    """Two-agent, fixed-length task environment paying each agent 1 per step.

    The task space holds ``num_episodes + 1`` named tasks: a sentinel
    ``"error task"`` followed by ``"task 1"`` ... ``"task <num_episodes>"``.
    Each episode terminates for every agent once ``num_steps`` steps have
    been taken.

    Args:
        num_episodes: Number of regular (non-error) tasks in the task space.
        num_steps: Number of steps after which an episode terminates.
    """

    # Sentinel task name; it is the first entry of the task space.
    _ERROR_TASK = "error task"

    def __init__(self, num_episodes, num_steps=100):
        super().__init__()
        self.num_steps = num_steps
        self.possible_agents = ["agent1", "agent2"]
        self._action_spaces = {agent: gym.spaces.Discrete(2) for agent in self.possible_agents}
        self.observation_spaces = {
            agent: gym.spaces.Tuple((gym.spaces.Discrete(self.num_steps), gym.spaces.Discrete(2)))
            for agent in self.possible_agents
        }
        task_names = [self._ERROR_TASK] + [f"task {i + 1}" for i in range(num_episodes)]
        self.task_space = DiscreteTaskSpace(gym.spaces.Discrete(num_episodes + 1), task_names)
        # Start on the sentinel, spelled exactly as it appears in the task space.
        self.task = self._ERROR_TASK
        self.metadata = {"render.modes": ["human"]}

    def action_space(self, agent):
        """Return the action space registered for ``agent``."""
        return self._action_spaces[agent]

    def reset(self, new_task=None):
        """Start a new episode on ``new_task`` with all possible agents active.

        A warning is issued when no task is given or when the error task is
        received. A missing task falls back to the error task instead of
        storing ``None``.

        Args:
            new_task: Task for the coming episode, or ``None``.

        Returns:
            ``(obs, info)`` dictionaries keyed by agent. Every observation is
            the placeholder value 0.5 and every info dict carries
            ``"content": "reset"`` and the active ``"task"``.
        """
        self.agents = copy(self.possible_agents)
        if new_task is None:
            warnings.warn("No task provided. Resetting to error task.", stacklevel=2)
            new_task = self._ERROR_TASK
        elif new_task == self._ERROR_TASK:
            # Use the warnings machinery instead of printing an exception object.
            warnings.warn(
                "Received error task. This likely means that too many tasks are being requested.",
                stacklevel=2,
            )
        self.task = new_task
        self._turn = 0
        obs = {agent: 0.5 for agent in self.agents}
        info = {agent: {"content": "reset", "task": self.task} for agent in self.agents}
        return obs, info

    def step(self, action):
        """Advance the episode by one turn for every active agent.

        Args:
            action: Mapping from agent name to that agent's action.

        Returns:
            ``(obs, rew, term, trunc, info)`` dictionaries keyed by agent.
            Rewards are always 1, ``term`` is true once the turn count reaches
            ``num_steps``, ``trunc`` is always false, and each info dict
            carries ``"content": "step"`` and the ``"task_completion"`` value.
        """
        self._turn += 1
        obs = {agent: self.observation((self._turn, action[agent])) for agent in self.agents}
        rew = {agent: 1 for agent in self.agents}
        term = {agent: self._turn >= self.num_steps for agent in self.agents}
        trunc = {agent: False for agent in self.agents}
        info = {
            agent: {
                "content": "step",
                "task_completion": self._task_completion(
                    obs, rew, all(term.values()), all(trunc.values()), {}
                ),
            }
            for agent in self.agents
        }
        # Clear the active agent list once the episode has ended for everyone.
        if all(term.values()) or all(trunc.values()):
            self.agents = []
        return obs, rew, term, trunc, info