module moplayground.moppo.factory
Factory functions for multi-objective PPO (MOPPO) networks.
function make_hypernetwork_inference_fn
make_hypernetwork_inference_fn(
ppo_networks: moplayground.moppo.factory.MOPPONetworks
)
Creates params and inference function for the PPO agent.
function make_moppo_networks
make_moppo_networks(
observation_size: Union[int, Mapping[str, Union[Tuple[int, ...], int]]],
action_size: int,
num_objectives: int,
hypersize: tuple,
key: jax.Array,
target_policy_params: dict = None,
target_value_params: dict = None,
preprocess_observations_fn: brax.training.types.PreprocessObservationFn = identity_observation_preprocessor,
policy_hidden_layer_sizes: Sequence[int] = (32, 32, 32, 32),
value_hidden_layer_sizes: Sequence[int] = (256, 256, 256, 256, 256),
activation: Callable[[jax.Array], jax.Array] = jax.nn.silu,
policy_obs_key: str = 'state',
value_obs_key: str = 'state',
distribution_type: Literal['normal', 'tanh_normal'] = 'tanh_normal',
noise_std_type: Literal['scalar', 'log'] = 'scalar',
init_noise_std: float = 1.0,
state_dependent_std: bool = False,
hypertype: str = 'MLP',
num_features: int = 8
) → MOPPONetworks
Makes the MOPPO networks with an observation preprocessor.
function make_hypernetwork
make_hypernetwork(
observation_size: int,
num_objectives: int,
target_policy_dict: dict,
hypersize: tuple,
hypertype: str = 'MLP',
policy_obs_key: str = 'state',
num_features: int = 8,
target_value_dict: dict = None
)
function make_mo_value_network
make_mo_value_network(
obs_size: Union[int, Mapping[str, Union[Tuple[int, ...], int]]],
num_objectives: int,
preprocess_observations_fn: brax.training.types.PreprocessObservationFn = identity_observation_preprocessor,
hidden_layer_sizes: Sequence[int] = (256, 256),
activation: Callable[[jax.Array], jax.Array] = <jax._src.custom_derivatives.custom_jvp object>,
obs_key: str = 'state'
) → FeedForwardNetwork
Creates a multi-objective value network.
class FeedForwardHypernetwork
FeedForwardHypernetwork(init: Callable[..., Any], apply: Callable[..., Any], get_features: Callable[..., Any], get_flat_mlps: Callable[..., Any])
method FeedForwardHypernetwork.__init__
__init__(
init: Callable[..., Any],
apply: Callable[..., Any],
get_features: Callable[..., Any],
get_flat_mlps: Callable[..., Any]
) → None
class MOPPONetworks
MOPPONetworks(hypernetwork: moplayground.moppo.factory.FeedForwardHypernetwork, policy_network: brax.training.networks.FeedForwardNetwork, value_network: brax.training.networks.FeedForwardNetwork, parametric_action_distribution: brax.training.distribution.ParametricDistribution)
method MOPPONetworks.__init__
__init__(
hypernetwork: moplayground.moppo.factory.FeedForwardHypernetwork,
policy_network: brax.training.networks.FeedForwardNetwork,
value_network: brax.training.networks.FeedForwardNetwork,
parametric_action_distribution: brax.training.distribution.ParametricDistribution
) → None