module moplayground.moppo.acting
Brax training acting functions.
function actor_step
actor_step(
env: brax.envs.base.Env,
env_state: brax.envs.base.State,
policy: brax.training.types.Policy,
directive,
key: jax.Array,
extra_fields: Sequence[str] = ()
) → Tuple[brax.envs.base.State, moplayground.moppo.acting.MultiObjectiveTransition]
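Collect data for a single environment step: query the policy, step the environment, and return the next state together with the resulting transition.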
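The sketch below shows what one actor step amounts to. It is written in plain JAX so it runs without Brax. ToyState, ToyTransition, toy_env_step, toy_policy and actor_step_sketch are all hypothetical. Two details are assumptions about how this module uses its arguments, not confirmed behavior: the directive is passed to the policy, and the reward carries one entry per objective. The discount = 1 - done convention follows Brax's own Transition.

```python
from typing import NamedTuple, Tuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for brax.envs.base.State."""
    obs: jax.Array
    reward: jax.Array  # one entry per objective (assumed layout)
    done: jax.Array


class ToyTransition(NamedTuple):
    """Hypothetical stand-in for MultiObjectiveTransition."""
    observation: jax.Array
    action: jax.Array
    reward: jax.Array
    discount: jax.Array
    next_observation: jax.Array


def toy_env_step(state: ToyState, action: jax.Array) -> ToyState:
    # Hypothetical two-objective environment: objective 0 rewards large
    # actions, objective 1 penalizes their magnitude.
    reward = jnp.stack([action.sum(), -jnp.abs(action).sum()])
    return ToyState(obs=state.obs + action, reward=reward, done=jnp.array(0.0))


def toy_policy(obs: jax.Array, directive: jax.Array, key: jax.Array) -> jax.Array:
    # Hypothetical directive-conditioned policy: a noisy action whose mean
    # is the weight the directive places on objective 0.
    return directive[0] + 0.1 * jax.random.normal(key, obs.shape)


def actor_step_sketch(
    state: ToyState, directive: jax.Array, key: jax.Array
) -> Tuple[ToyState, ToyTransition]:
    action = toy_policy(state.obs, directive, key)  # 1. query the policy
    next_state = toy_env_step(state, action)        # 2. step the environment
    transition = ToyTransition(                     # 3. record the transition
        observation=state.obs,
        action=action,
        reward=next_state.reward,
        discount=1.0 - next_state.done,
        next_observation=next_state.obs,
    )
    return next_state, transition


state0 = ToyState(obs=jnp.zeros(3), reward=jnp.zeros(2), done=jnp.array(0.0))
state1, transition = actor_step_sketch(state0, jnp.array([0.7, 0.3]), jax.random.PRNGKey(0))
print(transition.reward.shape)  # (2,)
```

The transition stores the reward as a vector rather than a scalar. Any weighting of the objectives can therefore happen later in training instead of being fixed at collection time.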
function generate_unroll
generate_unroll(
env: brax.envs.base.Env,
env_state: brax.envs.base.State,
policy: brax.training.types.Policy,
directive: jax.Array,
key: jax.Array,
unroll_length: int,
extra_fields: Sequence[str] = ()
) → Tuple[brax.envs.base.State, moplayground.moppo.acting.MultiObjectiveTransition]
Collect trajectories of the given unroll_length.
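Brax's own generate_unroll repeats the single-step function inside jax.lax.scan, splitting the PRNG key at every step. The self-contained sketch below reproduces that pattern with a hypothetical toy step, assuming this module's version follows the same structure. toy_actor_step, ToyTransition and generate_unroll_sketch are illustrative names only.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyTransition(NamedTuple):
    """Hypothetical stand-in for MultiObjectiveTransition."""
    observation: jax.Array
    action: jax.Array
    reward: jax.Array  # one entry per objective (assumed layout)
    next_observation: jax.Array


def toy_actor_step(obs: jax.Array, directive: jax.Array, key: jax.Array):
    # Hypothetical single step: noisy directive-dependent action and a
    # two-objective reward.
    action = directive[0] + 0.1 * jax.random.normal(key, obs.shape)
    next_obs = obs + action
    reward = jnp.stack([action.sum(), -jnp.abs(action).sum()])
    return next_obs, ToyTransition(obs, action, reward, next_obs)


def generate_unroll_sketch(obs, directive, key, unroll_length: int):
    def body(carry, _):
        current_obs, current_key = carry
        # Fresh subkey every step; the remaining key is carried forward.
        current_key, step_key = jax.random.split(current_key)
        next_obs, transition = toy_actor_step(current_obs, directive, step_key)
        return (next_obs, current_key), transition

    # lax.scan stacks each per-step ToyTransition along a new leading time axis.
    (final_obs, _), data = jax.lax.scan(body, (obs, key), None, length=unroll_length)
    return final_obs, data


final_obs, data = generate_unroll_sketch(
    jnp.zeros(3), jnp.array([0.5, 0.5]), jax.random.PRNGKey(0), unroll_length=5
)
print(data.reward.shape)  # (5, 2)
```

Using lax.scan instead of a Python loop keeps the whole unroll traceable as a single JAX computation, so it can be jitted once for a fixed unroll_length.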
class MultiObjectiveTransition
Container for a transition.
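The exact fields of MultiObjectiveTransition are not listed here. The sketch below assumes they mirror Brax's Transition (observation, action, reward, discount, next_observation, extras), with the reward gaining a trailing axis of size num_objs. MOTransitionSketch is a hypothetical name. The example shows how a batch of stacked transitions behaves as a pytree and how per-objective returns fall out of the reward axis.

```python
from typing import Any, Mapping, NamedTuple

import jax
import jax.numpy as jnp


class MOTransitionSketch(NamedTuple):
    """Hypothetical layout: Brax's Transition fields with a vector reward."""
    observation: jax.Array
    action: jax.Array
    reward: jax.Array  # assumed trailing axis of size num_objs
    discount: jax.Array
    next_observation: jax.Array
    extras: Mapping[str, Any]


T, B, num_objs, obs_dim, act_dim = 3, 4, 2, 5, 1  # time, envs, objectives, dims
batch = MOTransitionSketch(
    observation=jnp.zeros((T, B, obs_dim)),
    action=jnp.zeros((T, B, act_dim)),
    reward=jnp.ones((T, B, num_objs)),
    discount=jnp.ones((T, B)),
    next_observation=jnp.zeros((T, B, obs_dim)),
    extras={},
)

# NamedTuples are pytrees, so one tree_map slices every field at once.
first_step = jax.tree_util.tree_map(lambda x: x[0], batch)
# Undiscounted per-objective return of each environment: sum over time.
returns = batch.reward.sum(axis=0)
print(returns.shape)  # (4, 2)
```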
class Evaluator
Class to run evaluations.
method Evaluator.__init__
__init__(
eval_env: brax.envs.base.Env,
num_objs: int,
eval_policy_fn,
num_eval_envs: int,
episode_length: int,
action_repeat: int,
key: jax.Array
)
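Brax's own Evaluator unrolls episode_length // action_repeat policy steps from num_eval_envs freshly reset environments. It also counts episode_length * num_eval_envs environment steps per evaluation call. The hypothetical helper below works through that arithmetic for the constructor arguments above, assuming this class follows the same layout. The per-environment directive shape built from num_objs is likewise an assumption.

```python
import jax


def eval_rollout_layout(
    num_eval_envs: int,
    episode_length: int,
    action_repeat: int,
    num_objs: int,
    key: jax.Array,
):
    """Hypothetical helper: how the constructor arguments size one eval rollout."""
    # The env repeats each action `action_repeat` times, so a full episode
    # takes episode_length // action_repeat policy steps.
    unroll_length = episode_length // action_repeat
    # One reset key per parallel evaluation environment.
    reset_keys = jax.random.split(key, num_eval_envs)
    # Environment steps simulated by one evaluation call.
    steps_per_unroll = episode_length * num_eval_envs
    # Assumption: one directive of length num_objs per evaluation environment.
    directive_shape = (num_eval_envs, num_objs)
    return unroll_length, reset_keys, steps_per_unroll, directive_shape


unroll_length, reset_keys, steps, directive_shape = eval_rollout_layout(
    num_eval_envs=128, episode_length=1000, action_repeat=2, num_objs=3,
    key=jax.random.PRNGKey(0),
)
print(unroll_length, steps, directive_shape)  # 500 128000 (128, 3)
```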
method Evaluator.run_evaluation
run_evaluation(
policy_params,
training_metrics: Mapping[str, jax.Array],
aggregate_episodes: bool = True
) → Mapping[str, jax.Array]
Run one epoch of evaluation.
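In Brax, run_evaluation summarizes each per-episode metric with a mean and a standard deviation when aggregate_episodes is True. It stores them under eval/episode_<name> and eval/episode_<name>_std and merges the result with training_metrics. The sketch below follows that convention using NumPy. summarize_eval_metrics is a hypothetical helper, and the (num_eval_envs, num_objs) metric shape is an assumption for the multi-objective case.

```python
from typing import Mapping

import numpy as np


def summarize_eval_metrics(
    episode_metrics: Mapping[str, np.ndarray],
    training_metrics: Mapping[str, float],
    aggregate_episodes: bool = True,
) -> dict:
    """Hypothetical helper: Brax-style mean/std summary of per-episode metrics."""
    metrics = {}
    for fn, suffix in ((np.mean, ""), (np.std, "_std")):
        for name, value in episode_metrics.items():
            # Reduce over the episode axis only, so a per-objective metric of
            # shape (num_eval_envs, num_objs) keeps its objective axis.
            metrics[f"eval/episode_{name}{suffix}"] = (
                fn(value, axis=0) if aggregate_episodes else value
            )
    # Training metrics are passed through alongside the evaluation results.
    return {**training_metrics, **metrics}


episode_metrics = {"reward": np.array([[1.0, 0.0], [3.0, 2.0]])}  # 2 episodes, 2 objectives
out = summarize_eval_metrics(episode_metrics, {"training/sps": 10.0})
print(out["eval/episode_reward"], out["eval/episode_reward_std"])  # [2. 1.] [1. 1.]
```

Brax itself reduces each metric over all of its axes. Reducing only over axis 0 here is a choice made for this sketch so that the per-objective breakdown survives aggregation.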