Example Models
PyTorch models used in example scripts.
Submodules
syllabus.examples.models.minigrid_model module
- class syllabus.examples.models.minigrid_model.Categorical(num_inputs, num_outputs)
Bases: Module
Categorical distribution (NN module).
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
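The note about calling the Module instance rather than forward() can be checked directly with a forward hook. In this sketch, `TinyHead` is a hypothetical stand-in (not a syllabus class). The hook fires only when the module is invoked through `__call__`:

```python
import torch
import torch.nn as nn


class TinyHead(nn.Module):
    """Hypothetical stand-in for any module on this page."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


calls = []


def record_hook(module, inputs, output):
    # Forward hooks receive (module, inputs, output) after forward() has run
    calls.append(output.shape)


model = TinyHead()
model.register_forward_hook(record_hook)
x = torch.zeros(3, 4)

model(x)          # __call__ runs forward() and then the registered hook
model.forward(x)  # direct call: the hook is silently skipped
print(len(calls))  # 1
```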
- class syllabus.examples.models.minigrid_model.FixedCategorical(probs=None, logits=None, validate_args=None)
Bases: torch.distributions.Categorical
Categorical distribution object.
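The two classes above likely work together: the `Categorical` NN module maps features to logits and returns a `FixedCategorical` distribution. The sketch below assumes this pairing and the `(batch, 1)` action shape, which follow the common pytorch-a2c-ppo-acktr pattern. The `log_probs` method name is an assumption and has not been checked against the syllabus source.

```python
import torch
import torch.nn as nn


class FixedCategorical(torch.distributions.Categorical):
    """Sketch: a Categorical whose actions carry a trailing dim, shape (batch, 1)."""

    def sample(self, sample_shape=torch.Size()):
        # (batch,) -> (batch, 1)
        return super().sample(sample_shape).unsqueeze(-1)

    def log_probs(self, actions):
        # Assumed helper: actions (batch, 1) -> log-probabilities (batch, 1)
        return super().log_prob(actions.squeeze(-1)).unsqueeze(-1)


class Categorical(nn.Module):
    """Sketch of the NN head: one linear layer producing action logits."""

    def __init__(self, num_inputs, num_outputs):
        super().__init__()
        self.linear = nn.Linear(num_inputs, num_outputs)

    def forward(self, x):
        return FixedCategorical(logits=self.linear(x))


head = Categorical(num_inputs=8, num_outputs=5)
dist = head(torch.randn(4, 8))
action = dist.sample()         # shape (4, 1)
logp = dist.log_probs(action)  # shape (4, 1)
```

Keeping the trailing action dimension lets actions, log-probabilities and values all share a `(batch, 1)` layout in rollout storage.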
- class syllabus.examples.models.minigrid_model.MinigridAgent(obs_shape, num_actions, arch='small', base_kwargs=None)
Bases: MinigridPolicy
- class syllabus.examples.models.minigrid_model.MinigridPolicy(obs_shape, num_actions, arch='small', base_kwargs=None)
Bases: Module
Actor-critic module.
- forward()
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- property is_recurrent
Whether the policy is recurrent, i.e. whether it carries a recurrent hidden state (rnn_hx).
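As a rough illustration of what "actor-critic module" means here, the sketch below shares one feature trunk between a policy head (action logits) and a value head (state value). `TinyActorCritic`, its layer sizes and its return signature are hypothetical and are not the `MinigridPolicy` API. The 7 actions match MiniGrid's default action space.

```python
import torch
import torch.nn as nn


class TinyActorCritic(nn.Module):
    """Hypothetical minimal actor-critic; sizes and outputs are assumptions."""

    def __init__(self, obs_dim, num_actions, hidden_size=64):
        super().__init__()
        # Shared trunk feeding both heads
        self.trunk = nn.Sequential(nn.Linear(obs_dim, hidden_size), nn.Tanh())
        self.actor = nn.Linear(hidden_size, num_actions)  # policy logits
        self.critic = nn.Linear(hidden_size, 1)           # state value V(s)

    @property
    def is_recurrent(self):
        # Feed-forward sketch: no recurrent hidden state is carried
        return False

    def forward(self, obs):
        features = self.trunk(obs)
        dist = torch.distributions.Categorical(logits=self.actor(features))
        value = self.critic(features)
        return dist, value


policy = TinyActorCritic(obs_dim=10, num_actions=7)
dist, value = policy(torch.randn(2, 10))
action = dist.sample()  # shape (2,)
```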
syllabus.examples.models.procgen_model module
- class syllabus.examples.models.procgen_model.BasicBlock(n_channels, stride=1)
Bases: Module
Residual network block.
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
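A residual block adds its input back onto the output of a small convolutional stack. The sketch below assumes an IMPALA-style pre-activation layout (ReLU, conv, ReLU, conv, plus skip) with plain `nn.Conv2d` instead of `Conv2d_tf`. It omits `stride`, because a stride greater than 1 would require a projection on the skip path. `BasicBlockSketch` is a hypothetical name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlockSketch(nn.Module):
    """Assumed pre-activation residual block; not the syllabus implementation."""

    def __init__(self, n_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(n_channels, n_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(n_channels, n_channels, kernel_size=3, padding=1)

    def forward(self, x):
        out = self.conv1(F.relu(x))
        out = self.conv2(F.relu(out))
        # Skip connection: shapes match because stride=1 and padding=1
        return out + x


block = BasicBlockSketch(16)
y = block(torch.randn(2, 16, 8, 8))  # same shape as the input
```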
- class syllabus.examples.models.procgen_model.Categorical(num_inputs, num_outputs)
Bases: Module
Categorical distribution (NN module).
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class syllabus.examples.models.procgen_model.Conv2d_tf(*args, **kwargs)
Bases: Conv2d
Conv2d with TensorFlow-style padding behavior.
- forward(input)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
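TensorFlow's "SAME" padding keeps the output size at `ceil(in / stride)`. When the required total padding is odd, TensorFlow puts the extra pixel on the bottom/right. The pure-Python sketch below computes this padding for one spatial dimension. Whether `Conv2d_tf` computes it in exactly this way is an assumption.

```python
import math


def tf_same_padding(in_size, kernel_size, stride=1, dilation=1):
    """Return (pad_before, pad_after) for TF-style 'SAME' padding along one dim."""
    out_size = math.ceil(in_size / stride)
    effective_k = (kernel_size - 1) * dilation + 1
    total = max((out_size - 1) * stride + effective_k - in_size, 0)
    # An odd total puts the extra pixel on the bottom/right side, as TF does
    before = total // 2
    return before, total - before


print(tf_same_padding(64, 3, stride=1))  # (1, 1)
print(tf_same_padding(64, 3, stride=2))  # (0, 1)
```

The asymmetric result can be applied with `torch.nn.functional.pad(x, [left, right, top, bottom])` before a convolution with `padding=0`. PyTorch's built-in `padding='same'` does not support strided convolutions, which is one plausible reason for a TF-style wrapper.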
- class syllabus.examples.models.procgen_model.FixedCategorical(probs=None, logits=None, validate_args=None)
Bases: torch.distributions.Categorical
Categorical distribution object.
- class syllabus.examples.models.procgen_model.Flatten(*args, **kwargs)
Bases: Module
Flatten a tensor.
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
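In a convolutional encoder, "flatten" typically collapses every dimension except the batch dimension before a fully connected layer. The sketch assumes that behavior for `Flatten`. `FlattenSketch` is hypothetical, and the result matches PyTorch's built-in `nn.Flatten()`, whose default `start_dim` is 1.

```python
import torch
import torch.nn as nn


class FlattenSketch(nn.Module):
    """Collapse all dims except batch: (N, C, H, W) -> (N, C*H*W). Assumed behavior."""

    def forward(self, x):
        return x.reshape(x.size(0), -1)


x = torch.randn(2, 32, 8, 8)
flat = FlattenSketch()(x)
print(flat.shape)  # torch.Size([2, 2048])
```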
- class syllabus.examples.models.procgen_model.NNBase(hidden_size)
Bases: Module
Actor-critic network (base class).
- property output_size
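A base class that records `hidden_size` lets policy heads size their input layers from `output_size` rather than hard-coding a width. In the sketch below, `NNBaseSketch` and `MLPBase` are hypothetical. The assumption that `output_size` returns `hidden_size` is not confirmed by this page.

```python
import torch.nn as nn


class NNBaseSketch(nn.Module):
    """Hypothetical base: subclasses produce feature vectors of width hidden_size."""

    def __init__(self, hidden_size):
        super().__init__()
        self._hidden_size = hidden_size

    @property
    def output_size(self):
        # Width of the feature vector handed to the actor/critic heads (assumed)
        return self._hidden_size


class MLPBase(NNBaseSketch):
    def __init__(self, num_inputs, hidden_size=64):
        super().__init__(hidden_size)
        self.body = nn.Sequential(nn.Linear(num_inputs, hidden_size), nn.ReLU())


base = MLPBase(num_inputs=10, hidden_size=32)
head = nn.Linear(base.output_size, 5)  # head sized from the base, not hard-coded
```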
- class syllabus.examples.models.procgen_model.ProcgenAgent(obs_shape, num_actions, arch='small', base_kwargs=None)
Bases: Module
- class syllabus.examples.models.procgen_model.ProcgenLSTMAgent(obs_shape, num_actions, base_kwargs=None)
Bases: Module
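A recurrent agent such as `ProcgenLSTMAgent` has to carry an LSTM state across environment steps and reset it when an episode ends. The sketch below shows one common way to do this, using a `masks` tensor that is 0.0 where an episode just finished. `LSTMCore`, the mask convention and the return values are all assumptions, not the syllabus API.

```python
import torch
import torch.nn as nn


class LSTMCore(nn.Module):
    """Hypothetical recurrent core carrying (h, c) between environment steps."""

    def __init__(self, input_size, hidden_size=256):
        super().__init__()
        self.cell = nn.LSTMCell(input_size, hidden_size)

    def forward(self, features, state, masks):
        h, c = state
        # masks: (batch, 1); zero where an episode just ended -> reset that row's state
        h, c = self.cell(features, (h * masks, c * masks))
        return h, (h, c)


core = LSTMCore(input_size=16, hidden_size=8)
state = (torch.zeros(4, 8), torch.zeros(4, 8))
masks = torch.tensor([[1.0], [0.0], [1.0], [1.0]])
out, state = core(torch.randn(4, 16), state, masks)
```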
- class syllabus.examples.models.procgen_model.ResNetBase(num_inputs, hidden_size=256, channels=[16, 32, 32])
Bases: NNBase
Residual network.
- forward(inputs)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
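The defaults `hidden_size=256` and `channels=[16, 32, 32]` match the IMPALA-style ResNet encoder commonly used on Procgen's 64x64 RGB observations. The sketch below assumes that layout. Each stage applies a 3x3 conv, then a stride-2 max-pool that halves the spatial size (64, 32, 16, 8), then two residual blocks. Those blocks are followed by a ReLU, a flatten and a linear layer to `hidden_size`, plus a critic head. The class names and the `(value, features)` return order are hypothetical. A tuple replaces the list default to avoid a mutable default argument.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        return x + self.conv2(F.relu(self.conv1(F.relu(x))))


class ResNetBaseSketch(nn.Module):
    """Assumed IMPALA-style encoder; not the syllabus implementation."""

    def __init__(self, num_inputs, hidden_size=256, channels=(16, 32, 32), obs_hw=64):
        super().__init__()
        layers, in_ch, hw = [], num_inputs, obs_hw
        for out_ch in channels:
            layers += [
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.MaxPool2d(3, stride=2, padding=1),  # H, W -> ceil(H / 2), ceil(W / 2)
                ResidualBlock(out_ch),
                ResidualBlock(out_ch),
            ]
            in_ch, hw = out_ch, (hw + 1) // 2
        self.layers = nn.Sequential(*layers)
        self.fc = nn.Linear(in_ch * hw * hw, hidden_size)  # 32 * 8 * 8 = 2048 inputs
        self.critic = nn.Linear(hidden_size, 1)

    def forward(self, inputs):
        x = F.relu(self.layers(inputs)).flatten(start_dim=1)
        x = F.relu(self.fc(x))
        return self.critic(x), x  # (value, features)


base = ResNetBaseSketch(num_inputs=3)
value, features = base(torch.randn(2, 3, 64, 64))
```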