module moplayground.moppo.factory

Network factories for multi-objective PPO (MOPPO).


function make_hypernetwork_inference_fn

make_hypernetwork_inference_fn(
    ppo_networks: moplayground.moppo.factory.MOPPONetworks
)

Creates params and an inference function for the MOPPO agent.
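The signature alone does not show how such an inference function is typically used. Below is a minimal sketch, assuming the common hypernetwork pattern: a preference vector over the `num_objectives` objectives is turned into policy weights once, and the resulting policy then maps observations to actions. `hyper_apply`, `make_inference_fn`, the `deterministic` flag, and all sizes are hypothetical stand-ins, not this module's API.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes; real values come from the environment and the networks.
OBS_SIZE, ACTION_SIZE, NUM_OBJECTIVES = 4, 2, 3
NUM_POLICY_PARAMS = OBS_SIZE * ACTION_SIZE + ACTION_SIZE


def hyper_apply(hyper_params, preference):
    """Toy linear hypernetwork: preference vector -> weights of a linear policy."""
    flat = preference @ hyper_params["w"] + hyper_params["b"]
    w = flat[: OBS_SIZE * ACTION_SIZE].reshape(OBS_SIZE, ACTION_SIZE)
    b = flat[OBS_SIZE * ACTION_SIZE :]
    return {"w": w, "b": b}


def make_inference_fn(hyper_params):
    """Returns make_policy(preference, deterministic) -> policy(obs, key)."""

    def make_policy(preference, deterministic=False):
        # Policy weights are generated once per preference, not per step.
        policy_params = hyper_apply(hyper_params, preference)

        def policy(obs, key):
            loc = obs @ policy_params["w"] + policy_params["b"]
            if deterministic:
                return jnp.tanh(loc)  # deterministic action: squashed mean
            # Unit-std Gaussian noise before squashing, for illustration only.
            return jnp.tanh(loc + jax.random.normal(key, loc.shape))

        return policy

    return make_policy


key = jax.random.PRNGKey(0)
hyper_params = {
    "w": 0.1 * jax.random.normal(key, (NUM_OBJECTIVES, NUM_POLICY_PARAMS)),
    "b": jnp.zeros(NUM_POLICY_PARAMS),
}
policy = make_inference_fn(hyper_params)(jnp.array([0.5, 0.3, 0.2]), deterministic=True)
action = policy(jnp.ones(OBS_SIZE), key)
```

Generating the weights outside the per-step `policy` closure keeps the hypernetwork forward pass off the hot path when the preference is fixed for an episode.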


function make_moppo_networks

make_moppo_networks(
    observation_size: Union[int, Mapping[str, Union[Tuple[int, ...], int]]],
    action_size: int,
    num_objectives: int,
    hypersize: tuple,
    key: jax.Array,
    target_policy_params: dict = None,
    target_value_params: dict = None,
    preprocess_observations_fn: brax.training.types.PreprocessObservationFn = identity_observation_preprocessor,
    policy_hidden_layer_sizes: Sequence[int] = (32, 32, 32, 32),
    value_hidden_layer_sizes: Sequence[int] = (256, 256, 256, 256, 256),
    activation: Callable[[jax.Array], jax.Array] = jax.nn.silu,
    policy_obs_key: str = 'state',
    value_obs_key: str = 'state',
    distribution_type: Literal['normal', 'tanh_normal'] = 'tanh_normal',
    noise_std_type: Literal['scalar', 'log'] = 'scalar',
    init_noise_std: float = 1.0,
    state_dependent_std: bool = False,
    hypertype: str = 'MLP',
    num_features: int = 8
) -> MOPPONetworks

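Makes the MOPPO networks (hypernetwork, policy network, value network, and parametric action distribution) with an observation preprocessor.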
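The default policy widths, `(32, 32, 32, 32)`, are much smaller than the default value widths, `(256, 256, 256, 256, 256)`. If the hypernetwork emits every target-policy parameter (an assumption suggested by the `target_policy_params` argument), the size of its output layer grows with the policy's parameter count. A quick count for a plain dense MLP, where `mlp_param_count` is a hypothetical helper and the observation and output sizes are made up (the real policy output width depends on `distribution_type` and `state_dependent_std`, which this sketch does not model):

```python
def mlp_param_count(in_size, hidden_sizes, out_size):
    """Weights plus biases of a dense MLP: sum over layers of (fan_in + 1) * fan_out."""
    sizes = [in_size, *hidden_sizes, out_size]
    return sum((fan_in + 1) * fan_out for fan_in, fan_out in zip(sizes[:-1], sizes[1:]))


# Hypothetical sizes: obs_size=10, policy out_size=4, value out_size=3 (one per objective).
policy_count = mlp_param_count(10, (32, 32, 32, 32), 4)
value_count = mlp_param_count(10, (256, 256, 256, 256, 256), 3)
print(policy_count, value_count)  # 3652 266755
```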


function make_hypernetwork

make_hypernetwork(
    observation_size: int,
    num_objectives: int,
    target_policy_dict: dict,
    hypersize: tuple,
    hypertype: str = 'MLP',
    policy_obs_key: str = 'state',
    num_features: int = 8,
    target_value_dict: dict = None
)
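One way to build a hypernetwork that targets an arbitrary policy parameter tree is to flatten the template tree, let an MLP emit a vector of that length, and unflatten the result. The sketch below assumes `hypersize` lists the hypernetwork's hidden widths and that the input is a preference vector of length `num_objectives`; `init_mlp`, `apply_mlp`, and `make_mlp_hypernetwork` are hypothetical names, and `get_features`/`num_features` are not modeled. It is not this function's implementation.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def init_mlp(key, sizes):
    """Dense layers as a list of {'w', 'b'} dicts with 1/sqrt(fan_in) scaling."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        {"w": jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in), "b": jnp.zeros(n_out)}
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]


def apply_mlp(params, x, activation=jax.nn.silu):
    for layer in params[:-1]:
        x = activation(x @ layer["w"] + layer["b"])
    return x @ params[-1]["w"] + params[-1]["b"]  # linear output layer


def make_mlp_hypernetwork(key, num_objectives, target_policy_params, hypersize):
    """Hypothetical: MLP mapping a preference vector to a full set of policy params."""
    flat_target, unravel = ravel_pytree(target_policy_params)
    hyper_params = init_mlp(key, [num_objectives, *hypersize, flat_target.size])

    def generate(hyper_params, preference):
        # Same pytree structure and leaf shapes as target_policy_params.
        return unravel(apply_mlp(hyper_params, preference))

    return hyper_params, generate


k_target, k_hyper = jax.random.split(jax.random.PRNGKey(0))
target = init_mlp(k_target, [10, 32, 4])  # template policy: 10 -> 32 -> 4
hyper_params, generate = make_mlp_hypernetwork(k_hyper, 3, target, (64, 64))
policy_params = generate(hyper_params, jnp.array([0.2, 0.3, 0.5]))
```

Because `generate` is a pure function of `hyper_params`, gradients of a policy loss flow back through `unravel` into the hypernetwork weights.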

function make_mo_value_network

make_mo_value_network(
    obs_size: Union[int, Mapping[str, Union[Tuple[int, ...], int]]],
    num_objectives: int,
    preprocess_observations_fn: brax.training.types.PreprocessObservationFn = identity_observation_preprocessor,
    hidden_layer_sizes: Sequence[int] = (256, 256),
    activation: Callable[[jax.Array], jax.Array] = jax.nn.relu,
    obs_key: str = 'state'
) -> FeedForwardNetwork

Creates a multi-objective value network.
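A minimal sketch of what a multi-objective value network can look like, assuming it is an MLP whose last layer has `num_objectives` units (one value estimate per objective) and that dict observations are indexed by `obs_key`. `make_mo_value_network_sketch` is a hypothetical name; the `(processor_params, params, obs)` argument order of `apply` is a choice made for this sketch.

```python
from collections.abc import Mapping

import jax
import jax.numpy as jnp


def make_mo_value_network_sketch(
    obs_size,
    num_objectives,
    hidden_layer_sizes=(256, 256),
    activation=jax.nn.relu,
    obs_key="state",
    preprocess_observations_fn=lambda obs, processor_params: obs,  # identity
):
    sizes = [obs_size, *hidden_layer_sizes, num_objectives]

    def init(key):
        keys = jax.random.split(key, len(sizes) - 1)
        return [
            {"w": jax.random.normal(k, (a, b)) / jnp.sqrt(a), "b": jnp.zeros(b)}
            for k, a, b in zip(keys, sizes[:-1], sizes[1:])
        ]

    def apply(processor_params, params, obs):
        if isinstance(obs, Mapping):  # dict observations: select the value input
            obs = obs[obs_key]
        x = preprocess_observations_fn(obs, processor_params)
        for layer in params[:-1]:
            x = activation(x @ layer["w"] + layer["b"])
        return x @ params[-1]["w"] + params[-1]["b"]  # shape (..., num_objectives)

    return init, apply


init, apply = make_mo_value_network_sketch(obs_size=6, num_objectives=3, hidden_layer_sizes=(16, 16))
params = init(jax.random.PRNGKey(0))
values = apply(None, params, {"state": jnp.ones((5, 6))})  # batch of 5 -> shape (5, 3)
```

Keeping one output per objective, rather than a single scalar, lets each objective's advantage be estimated separately before any scalarization.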


class FeedForwardHypernetwork

FeedForwardHypernetwork(init: Callable[..., Any], apply: Callable[..., Any], get_features: Callable[..., Any], get_flat_mlps: Callable[..., Any])

method FeedForwardHypernetwork.__init__

__init__(
    init: Callable[..., Any],
    apply: Callable[..., Any],
    get_features: Callable[..., Any],
    get_flat_mlps: Callable[..., Any]
) -> None

class MOPPONetworks

MOPPONetworks(hypernetwork: moplayground.moppo.factory.FeedForwardHypernetwork, policy_network: brax.training.networks.FeedForwardNetwork, value_network: brax.training.networks.FeedForwardNetwork, parametric_action_distribution: brax.training.distribution.ParametricDistribution)

method MOPPONetworks.__init__

__init__(
    hypernetwork: moplayground.moppo.factory.FeedForwardHypernetwork,
    policy_network: brax.training.networks.FeedForwardNetwork,
    value_network: brax.training.networks.FeedForwardNetwork,
    parametric_action_distribution: brax.training.distribution.ParametricDistribution
) -> None
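`parametric_action_distribution` turns raw policy outputs into actions, and `make_moppo_networks` defaults to `distribution_type='tanh_normal'`. The sketch below is a from-scratch tanh-squashed Gaussian, assuming the policy emits `2 * action_size` values split into a mean and a raw scale passed through softplus. It illustrates the generic technique only; it is not brax's `ParametricDistribution`, and `sample_tanh_normal`, `tanh_log_det_jacobian`, and `min_std` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def tanh_log_det_jacobian(u):
    """log|d tanh(u)/du| = log(1 - tanh(u)^2), in a numerically stable form."""
    return 2.0 * (jnp.log(2.0) - u - jax.nn.softplus(-2.0 * u))


def sample_tanh_normal(key, logits, min_std=1e-3):
    """Sample a squashed action and its log-probability from raw policy outputs.

    Assumes logits = concat([loc, raw_scale]) along the last axis.
    """
    loc, raw_scale = jnp.split(logits, 2, axis=-1)
    scale = jax.nn.softplus(raw_scale) + min_std  # keep the std strictly positive
    u = loc + scale * jax.random.normal(key, loc.shape)  # pre-tanh Gaussian sample
    gaussian_log_prob = jnp.sum(
        -0.5 * ((u - loc) / scale) ** 2 - jnp.log(scale) - 0.5 * jnp.log(2.0 * jnp.pi),
        axis=-1,
    )
    # Change of variables a = tanh(u): subtract the log-Jacobian per action dimension.
    log_prob = gaussian_log_prob - jnp.sum(tanh_log_det_jacobian(u), axis=-1)
    return jnp.tanh(u), log_prob


action, log_prob = sample_tanh_normal(jax.random.PRNGKey(0), jnp.zeros(4))  # action_size = 2
```

The stable form of the Jacobian term avoids taking `log(1 - tanh(u)**2)` directly, which underflows to `log(0)` once `tanh(u)` rounds to 1.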
