Toggle Light / Dark / Auto color theme
Toggle table of contents sidebar
Source code for syllabus.tests.sync_test_env
import warnings
from copy import copy
import gymnasium as gym
from syllabus.core import PettingZooTaskEnv , TaskEnv
from syllabus.task_space import DiscreteTaskSpace
[docs]
class SyncTestEnv(TaskEnv):
    """Simple single-agent test environment with a discrete task space.

    The task space holds one sentinel task, ``"error task"``, followed by
    ``num_episodes`` named tasks (``"task 1"`` ... ``"task N"``). Every step
    yields a reward of 1 and the episode terminates once ``num_steps`` steps
    have been taken.

    Args:
        num_episodes: Number of regular (non-error) tasks in the task space.
        num_steps: Number of steps after which an episode terminates.
    """

    # Name of the sentinel task; must match the first entry of the task space.
    _ERROR_TASK = "error task"

    def __init__(self, num_episodes, num_steps=100):
        super().__init__()
        self.num_steps = num_steps
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Tuple(
            (gym.spaces.Discrete(self.num_steps), gym.spaces.Discrete(2))
        )
        self.task_space = DiscreteTaskSpace(
            gym.spaces.Discrete(num_episodes + 1),
            [self._ERROR_TASK] + [f"task {i + 1}" for i in range(num_episodes)],
        )
        # Start on the sentinel task, spelled exactly as in the task space
        # (previously "error_task", which is not a member of the task space).
        self.task = self._ERROR_TASK

    def reset(self, new_task=None):
        """Start a new episode on ``new_task``.

        Args:
            new_task: Task to run. If ``None``, a warning is emitted and the
                environment falls back to the error task. If it is the error
                task itself, a warning is emitted and the task is used as is.

        Returns:
            A tuple ``((turn, None), info)`` where ``turn`` is 0 and ``info``
            contains ``"content": "reset"`` and the active ``"task"``.
        """
        if new_task is None:
            warnings.warn("No task provided. Resetting to error task.", stacklevel=2)
            # Actually perform the fallback the warning announces.
            new_task = self._ERROR_TASK
        elif new_task == self._ERROR_TASK:
            warnings.warn(
                "Received error task. This likely means that too many tasks are being requested.",
                stacklevel=2,
            )
        self.task = new_task
        self._turn = 0
        return (self._turn, None), {"content": "reset", "task": self.task}

    def step(self, action):
        """Advance the episode by one step.

        Args:
            action: Action taken this step; it is paired with the turn counter
                and passed through ``self.observation``.

        Returns:
            ``(obs, rew, term, trunc, info)`` where ``rew`` is always 1,
            ``term`` becomes true once ``num_steps`` steps have been taken,
            ``trunc`` is always false, and ``info`` carries
            ``"content": "step"`` and the ``"task_completion"`` value.
        """
        self._turn += 1
        obs = self.observation((self._turn, action))
        rew = 1
        term = self._turn >= self.num_steps
        trunc = False
        info = {
            "content": "step",
            "task_completion": self._task_completion(obs, rew, term, trunc, {}),
        }
        return obs, rew, term, trunc, info
[docs]
class PettingZooSyncTestEnv(PettingZooTaskEnv):
    """Simple two-agent test environment with a discrete task space.

    The task space holds one sentinel task, ``"error task"``, followed by
    ``num_episodes`` named tasks (``"task 1"`` ... ``"task N"``). Every step
    gives each agent a reward of 1 and all agents terminate together once
    ``num_steps`` steps have been taken.

    Args:
        num_episodes: Number of regular (non-error) tasks in the task space.
        num_steps: Number of steps after which an episode terminates.
    """

    # Name of the sentinel task; must match the first entry of the task space.
    _ERROR_TASK = "error task"

    def __init__(self, num_episodes, num_steps=100):
        super().__init__()
        self.num_steps = num_steps
        self.possible_agents = ["agent1", "agent2"]
        self._action_spaces = {
            agent: gym.spaces.Discrete(2) for agent in self.possible_agents
        }
        self.observation_spaces = {
            agent: gym.spaces.Tuple(
                (gym.spaces.Discrete(self.num_steps), gym.spaces.Discrete(2))
            )
            for agent in self.possible_agents
        }
        self.task_space = DiscreteTaskSpace(
            gym.spaces.Discrete(num_episodes + 1),
            [self._ERROR_TASK] + [f"task {i + 1}" for i in range(num_episodes)],
        )
        # Start on the sentinel task, spelled exactly as in the task space
        # (previously "error_task", which is not a member of the task space).
        self.task = self._ERROR_TASK
        self.metadata = {"render.modes": ["human"]}

    def action_space(self, agent):
        """Return the action space of ``agent``."""
        return self._action_spaces[agent]

    def reset(self, new_task=None):
        """Start a new episode on ``new_task`` with all possible agents active.

        Args:
            new_task: Task to run. Receiving the error task emits a warning;
                the task is stored as given either way.

        Returns:
            ``(obs, info)`` dictionaries keyed by agent. Each observation is
            the placeholder value 0.5 and each info dict contains
            ``"content": "reset"`` and the active ``"task"``.
        """
        self.agents = copy(self.possible_agents)
        if new_task == self._ERROR_TASK:
            # Emit a real warning instead of printing an exception object to
            # stdout, matching how SyncTestEnv reports the same condition.
            warnings.warn(
                "Received error task. This likely means that too many tasks are being requested.",
                stacklevel=2,
            )
        self.task = new_task
        self._turn = 0
        obs = {agent: 0.5 for agent in self.agents}
        info = {agent: {"content": "reset", "task": self.task} for agent in self.agents}
        return obs, info

    def step(self, action):
        """Advance the episode by one step for every active agent.

        Args:
            action: Mapping from agent name to that agent's action.

        Returns:
            ``(obs, rew, term, trunc, info)`` dictionaries keyed by agent.
            Rewards are always 1, ``term`` becomes true for every agent once
            ``num_steps`` steps have been taken, and ``trunc`` is always false.
            The list of active agents is cleared when the episode ends.
        """
        self._turn += 1
        obs = {
            agent: self.observation((self._turn, action[agent]))
            for agent in self.agents
        }
        rew = {agent: 1 for agent in self.agents}
        term = {agent: self._turn >= self.num_steps for agent in self.agents}
        trunc = {agent: False for agent in self.agents}

        # Episode-level flags are the same for every agent; compute them once.
        all_term = all(term.values())
        all_trunc = all(trunc.values())
        info = {
            agent: {
                "content": "step",
                "task_completion": self._task_completion(obs, rew, all_term, all_trunc, {}),
            }
            for agent in self.agents
        }
        if all_term or all_trunc:
            self.agents = []
        return obs, rew, term, trunc, info