module moplayground.moppo.losses
Proximal policy optimization training.
See: https://arxiv.org/pdf/1707.06347.pdf
function compute_mo_gae
compute_mo_gae(
truncation: jax.Array,
termination: jax.Array,
rewards: jax.Array,
values: jax.Array,
bootstrap_value: jax.Array,
lambda_: float = 1.0,
discount: float = 0.99
)
Calculates Generalized Advantage Estimation (GAE) in the multi-objective setting, where T is the number of time steps, B is the batch size, and M is the number of objectives.
Args:
truncation [T, B]: truncation signal.
termination [T, B]: termination signal.
rewards [T, B, M]: multi-objective rewards generated by following the policy.
values [T, B, M]: multi-objective value function estimates w.r.t. the target policy.
bootstrap_value [B, M]: multi-objective value function estimate at time T.
lambda_: mix between 1-step (lambda_=0) and n-step (lambda_=1) bootstrapping. Defaults to 1.0.
discount: TD discount factor.
Returns:
Values [T, B, M]: can be used as targets to train a baseline with the loss (V(x_t) - vs_t)^2.
Advantages [T, B, M]: inputs to the policy loss.
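A minimal usage sketch, assuming `compute_mo_gae` is importable as documented above; the shapes (T=10, B=4, M=2) and all input values are illustrative placeholders, and the (values, advantages) return order follows the Returns section above:

```python
import jax
import jax.numpy as jnp

from moplayground.moppo.losses import compute_mo_gae

T, B, M = 10, 4, 2  # illustrative time, batch, and objective dimensions

rewards = jax.random.normal(jax.random.PRNGKey(0), (T, B, M))  # per-objective rewards
values = jnp.zeros((T, B, M))        # value estimates for steps 0..T-1
bootstrap_value = jnp.zeros((B, M))  # value estimate at time T
truncation = jnp.zeros((T, B))       # 1.0 where the episode was truncated
termination = jnp.zeros((T, B))      # 1.0 where the episode terminated

vs, advantages = compute_mo_gae(
    truncation=truncation,
    termination=termination,
    rewards=rewards,
    values=values,
    bootstrap_value=bootstrap_value,
    lambda_=0.95,
    discount=0.99,
)
assert vs.shape == (T, B, M) and advantages.shape == (T, B, M)
```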
function compute_mo_ppo_loss
compute_mo_ppo_loss(
params: moplayground.moppo.losses.MOPPONetworkParams,
normalizer_params: Any,
data: moplayground.moppo.acting.MultiObjectiveTransition,
rng: jax.Array,
moppo_network: moplayground.moppo.factory.MOPPONetworks,
entropy_cost: float = 0.0001,
discounting: float = 0.9,
reward_scaling: float = 1.0,
gae_lambda: float = 0.95,
clipping_epsilon: float = 0.3,
normalize_advantage: bool = True
) → Tuple[jax.Array, Mapping[str, jax.Array]]
Computes the multi-objective PPO loss.
Args:
params: Network parameters.
normalizer_params: Parameters of the observation normalizer.
data: Transitions with leading dimensions [B, T]. Extra fields required are ['state_extras']['truncation'], ['policy_extras']['raw_action'], and ['policy_extras']['log_prob'].
rng: Random key.
moppo_network: MOPPO networks.
entropy_cost: Entropy bonus cost.
discounting: Discounting factor.
reward_scaling: Reward multiplier.
gae_lambda: Generalized advantage estimation lambda.
clipping_epsilon: Policy loss clipping epsilon.
normalize_advantage: Whether to normalize the advantage estimates.
Returns:
A tuple (loss, metrics).
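For orientation, the policy term uses the clipped surrogate objective from the PPO paper linked above. The sketch below is a generic single-objective illustration of that objective, not the exact body of `compute_mo_ppo_loss`; in the multi-objective setting the [T, B, M] advantages would first have to be reduced to a scalar signal, and how that is done here is not specified in this doc.

```python
import jax.numpy as jnp


def clipped_surrogate_loss(
    log_prob: jnp.ndarray,            # [T, B] log-probs under the target policy
    behaviour_log_prob: jnp.ndarray,  # [T, B] log-probs under the behaviour policy
    advantages: jnp.ndarray,          # [T, B] (already scalarized) advantage estimates
    clipping_epsilon: float = 0.3,
) -> jnp.ndarray:
    """Standard PPO clipped surrogate (Schulman et al., 2017); illustrative only."""
    rho = jnp.exp(log_prob - behaviour_log_prob)  # importance ratio pi / pi_old
    unclipped = rho * advantages
    clipped = jnp.clip(rho, 1.0 - clipping_epsilon, 1.0 + clipping_epsilon) * advantages
    # Maximizing the surrogate corresponds to minimizing its negation.
    return -jnp.mean(jnp.minimum(unclipped, clipped))
```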
class MOPPONetworkParams
Contains the trainable network parameters for the learner.
method MOPPONetworkParams.__init__
__init__(hypernetwork: Any) → None
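Since `hypernetwork` is typed `Any`, construction only requires a parameter pytree; the dummy pytree below is an illustrative stand-in for whatever `moplayground.moppo.factory` actually produces:

```python
import jax.numpy as jnp

from moplayground.moppo.losses import MOPPONetworkParams

# Illustrative stand-in for the hypernetwork parameter pytree (shapes are arbitrary).
hyper_params = {"w": jnp.zeros((3, 3)), "b": jnp.zeros((3,))}

params = MOPPONetworkParams(hypernetwork=hyper_params)
```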