Example Models
PyTorch models used in example scripts.
Submodules
syllabus.examples.models.minigrid_model module
- class syllabus.examples.models.minigrid_model.Categorical(num_inputs, num_outputs)
Bases: torch.nn.Module
Categorical distribution (NN module).
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass must be defined within this function, call the Module instance afterwards rather than this method: calling the instance runs any registered hooks, while calling forward directly silently ignores them.
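The note above describes standard PyTorch behavior. A minimal sketch of the difference, using a plain `nn.Linear` rather than any class from this module: a hook registered with `register_forward_hook` fires when the module instance is called, but not when `forward` is invoked directly.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)  # any nn.Module subclass behaves the same way
calls = []


def record_call(module, inputs, output):
    # Forward hooks receive (module, inputs, output) after forward runs
    calls.append(output.shape)


handle = layer.register_forward_hook(record_call)

x = torch.randn(3, 4)
layer(x)          # goes through nn.Module.__call__, so the hook fires
layer.forward(x)  # bypasses __call__, so the hook is silently skipped
handle.remove()   # calls now holds exactly one entry
```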
- class syllabus.examples.models.minigrid_model.FixedCategorical(probs=None, logits=None, validate_args=None)
Bases: torch.distributions.Categorical
Categorical distribution object.
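The two classes above typically work as a pair: the `Categorical` module maps features to action logits and wraps them in a `FixedCategorical` distribution. The sketch below assumes that design. The class names `CategoricalHead` and `FixedCategoricalSketch`, the helper methods `log_probs` and `mode`, and the trailing action dimension are illustrative assumptions, not confirmed API of this module.

```python
import torch
import torch.nn as nn


class FixedCategoricalSketch(torch.distributions.Categorical):
    """Hypothetical Categorical wrapper that keeps a trailing action dimension."""

    def sample(self, sample_shape=torch.Size()):
        # (batch,) -> (batch, 1) so actions stack cleanly into rollout buffers
        return super().sample(sample_shape).unsqueeze(-1)

    def log_probs(self, actions):
        # actions: (batch, 1) -> log-probabilities of shape (batch, 1)
        return super().log_prob(actions.squeeze(-1)).unsqueeze(-1)

    def mode(self):
        # Greedy action, e.g. for deterministic evaluation
        return self.probs.argmax(dim=-1, keepdim=True)


class CategoricalHead(nn.Module):
    """Hypothetical linear policy head producing a FixedCategoricalSketch."""

    def __init__(self, num_inputs, num_outputs):
        super().__init__()
        self.linear = nn.Linear(num_inputs, num_outputs)

    def forward(self, x):
        return FixedCategoricalSketch(logits=self.linear(x))


head = CategoricalHead(num_inputs=8, num_outputs=5)
dist = head(torch.randn(3, 8))
actions = dist.sample()          # shape (3, 1)
log_p = dist.log_probs(actions)  # shape (3, 1)
```

Keeping the distribution logic in a `torch.distributions` subclass lets the head stay a one-line `nn.Linear` while sampling and log-probability code lives in one place.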
- class syllabus.examples.models.minigrid_model.MinigridAgent(obs_shape, num_actions, arch='small', base_kwargs=None)
Bases: MinigridPolicy
- class syllabus.examples.models.minigrid_model.MinigridPolicy(obs_shape, num_actions, arch='small', base_kwargs=None)
Bases: torch.nn.Module
Actor-critic module.
- forward()
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass must be defined within this function, call the Module instance afterwards rather than this method: calling the instance runs any registered hooks, while calling forward directly silently ignores them.
- property is_recurrent
Size of rnn_hx.
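An actor-critic module shares a feature trunk between a policy (actor) head and a state-value (critic) head. A minimal sketch of that layout, assuming flat vector observations for brevity; `TinyActorCritic` and its layer sizes are hypothetical and do not reproduce `MinigridPolicy`'s `arch` options.

```python
import torch
import torch.nn as nn


class TinyActorCritic(nn.Module):
    """Hypothetical actor-critic: shared trunk, policy head, value head."""

    def __init__(self, obs_dim: int, num_actions: int, hidden: int = 64):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
        self.actor = nn.Linear(hidden, num_actions)  # action logits
        self.critic = nn.Linear(hidden, 1)           # scalar state-value estimate

    @property
    def is_recurrent(self) -> bool:
        return False  # this sketch carries no recurrent hidden state

    def forward(self, obs):
        h = self.trunk(obs)
        dist = torch.distributions.Categorical(logits=self.actor(h))
        return dist, self.critic(h)


policy = TinyActorCritic(obs_dim=10, num_actions=7)
dist, value = policy(torch.randn(4, 10))
action = dist.sample()  # one action index per batch element
```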
syllabus.examples.models.procgen_model module
- class syllabus.examples.models.procgen_model.BasicBlock(n_channels, stride=1)
Bases: torch.nn.Module
Residual network block.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass must be defined within this function, call the Module instance afterwards rather than this method: calling the instance runs any registered hooks, while calling forward directly silently ignores them.
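A residual block adds its input back to the output of a small convolutional stack, so the block only has to learn a correction to the identity. The sketch below assumes the pre-activation, two-convolution layout common in IMPALA-style encoders; it uses plain `nn.Conv2d` with `padding=1` instead of `Conv2d_tf`, and `ResidualBlockSketch` is a hypothetical name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlockSketch(nn.Module):
    """Hypothetical block: x + conv2(relu(conv1(relu(x))))."""

    def __init__(self, n_channels: int):
        super().__init__()
        # padding=1 with a 3x3 kernel keeps H and W, so the skip add is shape-safe
        self.conv1 = nn.Conv2d(n_channels, n_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(n_channels, n_channels, kernel_size=3, padding=1)

    def forward(self, x):
        # Non-inplace ReLU so the identity branch still sees the raw input
        out = self.conv1(F.relu(x))
        out = self.conv2(F.relu(out))
        return out + x


block = ResidualBlockSketch(16)
x = torch.randn(2, 16, 8, 8)
y = block(x)  # same shape as x
```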
- class syllabus.examples.models.procgen_model.Categorical(num_inputs, num_outputs)
Bases: torch.nn.Module
Categorical distribution (NN module).
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass must be defined within this function, call the Module instance afterwards rather than this method: calling the instance runs any registered hooks, while calling forward directly silently ignores them.
- class syllabus.examples.models.procgen_model.Conv2d_tf(*args, **kwargs)
Bases: torch.nn.Conv2d
Conv2d with the padding behavior from TensorFlow.
- forward(input)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass must be defined within this function, call the Module instance afterwards rather than this method: calling the instance runs any registered hooks, while calling forward directly silently ignores them.
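TensorFlow's `'SAME'` padding chooses the padding from the input size so that each spatial output dimension has `ceil(in / stride)` elements, and it places any odd leftover pixel on the bottom/right side, which symmetric PyTorch `padding` cannot express. A sketch of that rule, assuming this is the behavior `Conv2d_tf` reproduces; `same_padding_1d` and `Conv2dSame` are hypothetical names.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def same_padding_1d(in_size, kernel, stride, dilation=1):
    """(before, after) padding that TensorFlow's 'SAME' mode would apply."""
    out_size = math.ceil(in_size / stride)
    effective_k = (kernel - 1) * dilation + 1
    total = max((out_size - 1) * stride + effective_k - in_size, 0)
    return total // 2, total - total // 2  # the extra pixel goes after


class Conv2dSame(nn.Conv2d):
    """Hypothetical Conv2d that pads like TF 'SAME' before convolving."""

    def forward(self, x):
        ph = same_padding_1d(x.size(-2), self.kernel_size[0], self.stride[0], self.dilation[0])
        pw = same_padding_1d(x.size(-1), self.kernel_size[1], self.stride[1], self.dilation[1])
        x = F.pad(x, (pw[0], pw[1], ph[0], ph[1]))  # F.pad order: left, right, top, bottom
        # Padding is already applied, so pass 0 and ignore self.padding
        return F.conv2d(x, self.weight, self.bias, self.stride, 0, self.dilation, self.groups)


conv = Conv2dSame(3, 8, kernel_size=3, stride=2)
y = conv(torch.randn(1, 3, 7, 7))  # 7 / 2 rounded up -> 4 per side
```

Computing the padding at call time matters for strided convolutions, where PyTorch's built-in `padding='same'` is not supported.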
- class syllabus.examples.models.procgen_model.FixedCategorical(probs=None, logits=None, validate_args=None)
Bases: torch.distributions.Categorical
Categorical distribution object.
- class syllabus.examples.models.procgen_model.Flatten(*args, **kwargs)
Bases: torch.nn.Module
Flatten a tensor.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass must be defined within this function, call the Module instance afterwards rather than this method: calling the instance runs any registered hooks, while calling forward directly silently ignores them.
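Flattening keeps the batch dimension and merges all remaining dimensions, typically between a convolutional encoder and a linear layer. A minimal sketch (`FlattenSketch` is hypothetical); the built-in `nn.Flatten()`, whose default is `start_dim=1`, gives the same result.

```python
import torch
import torch.nn as nn


class FlattenSketch(nn.Module):
    """Hypothetical flatten layer: (batch, *dims) -> (batch, prod(dims))."""

    def forward(self, x):
        # reshape (unlike view) also copes with non-contiguous inputs
        return x.reshape(x.size(0), -1)


x = torch.randn(2, 32, 8, 8)
flat = FlattenSketch()(x)  # shape (2, 2048)
```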
- class syllabus.examples.models.procgen_model.NNBase(hidden_size)
Bases: torch.nn.Module
Actor-critic network (base class).
- property output_size
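A base network that reports its `output_size` lets policy and value heads be sized without hard-coding feature widths. A minimal sketch, assuming `output_size` simply returns the `hidden_size` passed to the constructor; `BaseSketch` is hypothetical.

```python
import torch.nn as nn


class BaseSketch(nn.Module):
    """Hypothetical base: stores hidden_size and reports it as output_size."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self._hidden_size = hidden_size

    @property
    def output_size(self) -> int:
        return self._hidden_size


base = BaseSketch(256)
# Heads read the feature width from the base instead of repeating 256
critic = nn.Linear(base.output_size, 1)
```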
- class syllabus.examples.models.procgen_model.ProcgenAgent(obs_shape, num_actions, base_kwargs=None, detach_critic=False)
Bases: torch.nn.Module
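`ProcgenAgent` takes a `detach_critic` flag. Judging only from the name, we assume it stops value-loss gradients from flowing into the shared feature extractor, so that only the policy loss shapes the shared features. The sketch below illustrates that reading with `torch.Tensor.detach()`; `DetachableActorCritic` and its layer sizes are hypothetical, not this class's actual implementation.

```python
import torch
import torch.nn as nn


class DetachableActorCritic(nn.Module):
    """Hypothetical actor-critic illustrating one reading of detach_critic."""

    def __init__(self, in_dim, hidden, num_actions, detach_critic=False):
        super().__init__()
        self.base = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.actor = nn.Linear(hidden, num_actions)
        self.critic = nn.Linear(hidden, 1)
        self.detach_critic = detach_critic

    def forward(self, obs):
        features = self.base(obs)
        # detach() cuts the graph, so value-loss gradients never reach self.base
        critic_in = features.detach() if self.detach_critic else features
        return self.actor(features), self.critic(critic_in)


model = DetachableActorCritic(in_dim=6, hidden=16, num_actions=4, detach_critic=True)
logits, value = model(torch.randn(5, 6))
value.sum().backward()  # stands in for a value loss
```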
- class syllabus.examples.models.procgen_model.ProcgenLSTMAgent(obs_shape, num_actions, base_kwargs=None, detach_critic=False)
Bases: torch.nn.Module
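`ProcgenLSTMAgent` presumably adds a recurrent core on top of the feature extractor (an assumption based on its name). A common pattern for recurrent policies carries an LSTM state across steps and zeroes it wherever an episode ended, using a `masks` tensor of 0s and 1s. A minimal sketch of that pattern with `nn.LSTMCell`; `RecurrentCoreSketch` and the masking convention are assumptions.

```python
import torch
import torch.nn as nn


class RecurrentCoreSketch(nn.Module):
    """Hypothetical LSTM core; not the ProcgenLSTMAgent implementation."""

    def __init__(self, in_dim: int, hidden: int):
        super().__init__()
        self.lstm = nn.LSTMCell(in_dim, hidden)

    def forward(self, x, state, masks):
        h, c = state
        # masks is 0 where a new episode starts, so the carried state is zeroed there
        h, c = self.lstm(x, (h * masks, c * masks))
        return h, (h, c)


core = RecurrentCoreSketch(in_dim=8, hidden=16)
h0, c0 = torch.randn(4, 16), torch.randn(4, 16)
x = torch.randn(4, 8)
masks = torch.tensor([[1.0], [0.0], [1.0], [0.0]])  # rows 1 and 3 start new episodes
out, (h1, c1) = core(x, (h0, c0), masks)
```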
- class syllabus.examples.models.procgen_model.ResNetBase(num_inputs, hidden_size=256, channels=[16, 32, 32])
Bases: NNBase
Residual network.
- forward(inputs)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass must be defined within this function, call the Module instance afterwards rather than this method: calling the instance runs any registered hooks, while calling forward directly silently ignores them.
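`ResNetBase`'s defaults (`hidden_size=256`, `channels=[16, 32, 32]`) match the IMPALA-style encoder often used for Procgen: each channel count defines a stage of convolution, max-pooling, and two residual blocks, followed by a linear layer. A self-contained sketch under that assumption, for 64x64 RGB Procgen frames; `ImpalaStyleEncoder` is hypothetical and omits `NNBase`'s interface.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResBlock(nn.Module):
    def __init__(self, ch: int):
        super().__init__()
        self.conv1 = nn.Conv2d(ch, ch, 3, padding=1)
        self.conv2 = nn.Conv2d(ch, ch, 3, padding=1)

    def forward(self, x):
        return x + self.conv2(F.relu(self.conv1(F.relu(x))))


class ImpalaStyleEncoder(nn.Module):
    """Hypothetical encoder: one conv/pool/2x-residual stage per channel count."""

    def __init__(self, num_inputs, hidden_size=256, channels=(16, 32, 32), image_hw=64):
        super().__init__()
        layers, in_ch = [], num_inputs
        for out_ch in channels:
            layers += [
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.MaxPool2d(3, stride=2, padding=1),  # halves H and W, rounding up
                ResBlock(out_ch),
                ResBlock(out_ch),
            ]
            in_ch = out_ch
        self.stages = nn.Sequential(*layers)
        side = image_hw
        for _ in channels:
            side = (side + 1) // 2  # 64 -> 32 -> 16 -> 8
        self.fc = nn.Linear(in_ch * side * side, hidden_size)

    def forward(self, x):
        x = F.relu(self.stages(x))
        return F.relu(self.fc(x.flatten(1)))


enc = ImpalaStyleEncoder(num_inputs=3)
features = enc(torch.randn(2, 3, 64, 64))  # shape (2, 256)
```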