module moplayground.moppo.losses

Proximal policy optimization training.

See: https://arxiv.org/pdf/1707.06347.pdf


function compute_mo_gae

compute_mo_gae(
    truncation: jax.Array,
    termination: jax.Array,
    rewards: jax.Array,
    values: jax.Array,
    bootstrap_value: jax.Array,
    lambda_: float = 1.0,
    discount: float = 0.99
)

Calculates Generalized Advantage Estimation (GAE) in the multi-objective setting. T is the number of time steps, B is the batch size, and M is the number of objectives. A sketch of the computation follows the Returns list below.

Args:

  • truncation [T, B]: truncation signal.
  • termination [T, B]: termination signal.
  • rewards [T, B, M]: multi-objective rewards generated by following the policy.
  • values [T, B, M]: multi-objective value function estimates with respect to the target policy.
  • bootstrap_value [B, M]: multi-objective value function estimate at time T.
  • lambda_: Mixes between the 1-step return (lambda_=0) and the n-step return (lambda_=1). Defaults to lambda_=1.
  • discount: TD discount.

Returns:

  • Values [T, B, M]: value targets vs_t, usable to train a baseline via the squared error (V(x_t) - vs_t)^2.
  • Advantages [T, B, M]: advantage estimates for the policy loss.
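
Below is a minimal sketch of this computation, assuming the standard (Brax-style) scalar GAE recursion applied independently to each of the M objectives; the module's actual implementation may differ in its details.

import jax
import jax.numpy as jnp

def compute_mo_gae(truncation, termination, rewards, values, bootstrap_value,
                   lambda_=1.0, discount=0.99):
    # Broadcast the [T, B] episode signals over the objective axis M.
    truncation_mask = (1.0 - truncation)[..., None]  # [T, B, 1]; zeroes out truncated steps
    continuation = (1.0 - termination)[..., None]    # [T, B, 1]; no bootstrapping past terminals
    # V(x_{t+1}) for every objective, bootstrapped with bootstrap_value at the final step T.
    values_t_plus_1 = jnp.concatenate([values[1:], bootstrap_value[None]], axis=0)
    # One-step TD errors: delta_t = r_t + discount * V(x_{t+1}) - V(x_t).
    deltas = (rewards + discount * continuation * values_t_plus_1 - values) * truncation_mask

    def backward_step(acc, xs):
        delta, cont, mask = xs
        # GAE recursion: A_t = delta_t + discount * lambda_ * A_{t+1}, reset at episode ends.
        acc = delta + discount * cont * mask * lambda_ * acc
        return acc, acc

    # Scan backwards over time; stacked outputs come back in forward time order.
    _, vs_minus_v_xs = jax.lax.scan(
        backward_step, jnp.zeros_like(bootstrap_value),
        (deltas, continuation, truncation_mask), reverse=True)
    vs = vs_minus_v_xs + values  # value targets, [T, B, M]
    vs_t_plus_1 = jnp.concatenate([vs[1:], bootstrap_value[None]], axis=0)
    advantages = (rewards + discount * continuation * vs_t_plus_1 - values) * truncation_mask
    return jax.lax.stop_gradient(vs), jax.lax.stop_gradient(advantages)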

function compute_mo_ppo_loss

compute_mo_ppo_loss(
    params: moplayground.moppo.losses.MOPPONetworkParams,
    normalizer_params: Any,
    data: moplayground.moppo.acting.MultiObjectiveTransition,
    rng: jax.Array,
    moppo_network: moplayground.moppo.factory.MOPPONetworks,
    entropy_cost: float = 0.0001,
    discounting: float = 0.9,
    reward_scaling: float = 1.0,
    gae_lambda: float = 0.95,
    clipping_epsilon: float = 0.3,
    normalize_advantage: bool = True
) → Tuple[jax.Array, Mapping[str, jax.Array]]

Computes PPO loss.

Args:

  • params: Network parameters.
  • normalizer_params: Parameters of the normalizer.
  • data: Transition with leading dimension [B, T]. Required extra fields: ['state_extras']['truncation'], ['policy_extras']['raw_action'], and ['policy_extras']['log_prob'].
  • rng: Random key.
  • moppo_network: MOPPO networks.
  • entropy_cost: Entropy cost.
  • discounting: Discount factor.
  • reward_scaling: Reward multiplier.
  • gae_lambda: Generalized advantage estimation lambda.
  • clipping_epsilon: Policy loss clipping epsilon.
  • normalize_advantage: Whether to normalize the advantage estimates.

Returns: A tuple (loss, metrics).
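
The core of any PPO loss is the clipped surrogate objective from the paper linked above. The sketch below shows how that term is typically assembled from the stored and recomputed log-probabilities; clipped_surrogate is an illustrative helper, not part of this module's API, and it assumes the [T, B, M] advantages have already been scalarized over objectives (e.g. by a preference-weighted sum).

import jax.numpy as jnp

def clipped_surrogate(new_log_prob, old_log_prob, advantages,
                      clipping_epsilon=0.3, normalize_advantage=True):
    # advantages: [T, B], assumed already scalarized over the M objectives.
    if normalize_advantage:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    rho = jnp.exp(new_log_prob - old_log_prob)  # importance ratio pi_new / pi_old
    unclipped = rho * advantages
    clipped = jnp.clip(rho, 1.0 - clipping_epsilon, 1.0 + clipping_epsilon) * advantages
    # PPO maximizes the pessimistic (element-wise minimum) surrogate; negate it for a loss.
    return -jnp.mean(jnp.minimum(unclipped, clipped))

Because compute_mo_ppo_loss returns a (loss, metrics) tuple, a training step would typically differentiate it with jax.grad(compute_mo_ppo_loss, has_aux=True).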


class MOPPONetworkParams

Contains training state for the learner.

method MOPPONetworkParams.__init__

__init__(hypernetwork: Any) → None
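
The generated __init__ signature suggests a plain dataclass-style container with a single field. A minimal sketch, assuming a flax.struct dataclass; the actual definition and the contents of hypernetwork may differ:

from typing import Any
from flax import struct

@struct.dataclass
class MOPPONetworkParams:
    """Trainable parameters passed to compute_mo_ppo_loss."""
    hypernetwork: Any  # pytree of hypernetwork parameters (exact structure is an assumption)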
