module moplayground.moppo.moppo

Proximal policy optimization training.

See: https://arxiv.org/pdf/1707.06347.pdf


function sample_preferences

sample_preferences(
    key,
    it,
    sampling,
    K,
    warmup_frac,
    alpha,
    num_evals,
    num_envs,
    num_objs
)

function train

train(
    environment: brax.envs.base.Env,
    num_timesteps: int,
    max_devices_per_host: Optional[int] = None,
    wrap_env: bool = True,
    madrona_backend: bool = False,
    augment_pixels: bool = False,
    num_envs: int = 1,
    episode_length: Optional[int] = None,
    action_repeat: int = 1,
    wrap_env_fn: Optional[Callable[[Any], Any]] = None,
    randomization_fn: Optional[Callable[[brax.base.System, jax.Array], Tuple[brax.base.System, brax.base.System]]] = None,
    learning_rate: float = 0.0001,
    entropy_cost: float = 0.0001,
    discounting: float = 0.9,
    unroll_length: int = 10,
    batch_size: int = 32,
    num_minibatches: int = 16,
    num_updates_per_batch: int = 2,
    num_resets_per_eval: int = 0,
    normalize_observations: bool = False,
    reward_scaling: float = 1.0,
    clipping_epsilon: float = 0.3,
    gae_lambda: float = 0.95,
    max_grad_norm: Optional[float] = None,
    normalize_advantage: bool = True,
    alpha: float = 1.0,
    warmup_frac: float = 0.0,
    sampling: str = 'dense',
    k: int = 4,
    network_factory: brax.training.types.NetworkFactory[moplayground.moppo.factory.MOPPONetworks] = make_moppo_networks,
    init_policy_params: dict = None,
    init_normalizer_params: dict = None,
    init_value_params: dict = None,
    seed: int = 0,
    use_pmap_on_reset: bool = True,
    num_evals: int = 1,
    eval_env: Optional[brax.envs.base.Env] = None,
    num_eval_envs: int = 128,
    deterministic_eval: bool = False,
    log_training_metrics: bool = False,
    training_metrics_steps: Optional[int] = None,
    progress_fn: Callable[[int, Mapping[str, jax.Array]], NoneType] = <no-op lambda>,
    policy_params_fn: Callable[..., NoneType] = <no-op lambda>,
    save_checkpoint_path: Optional[str] = None,
    restore_checkpoint_path: Optional[str] = None,
    restore_params: Optional[Any] = None,
    restore_value_fn: bool = True,
    run_evals: bool = True
)

class MOTrainingState

Contains training state for the learner.

method MOTrainingState.__init__

__init__(
    optimizer_state: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, Iterable[ForwardRef('ArrayTree')], Mapping[Any, ForwardRef('ArrayTree')]],
    params: moplayground.moppo.losses.MOPPONetworkParams,
    normalizer_params: brax.training.acme.running_statistics.RunningStatisticsState,
    env_steps: brax.training.types.UInt64
) -> None

This site uses Just the Docs, a documentation theme for Jekyll.