module moplayground.moppo.acting

Acting functions for Brax training: stepping environments, collecting trajectories, and running evaluations.


function actor_step

actor_step(
    env: brax.envs.base.Env,
    env_state: brax.envs.base.State,
    policy: brax.training.types.Policy,
    directive,
    key: jax.Array,
    extra_fields: Sequence[str] = ()
) -> Tuple[brax.envs.base.State, moplayground.moppo.acting.MultiObjectiveTransition]

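Collect data by taking a single environment step with the policy. Returns the next environment state and the resulting transition.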


function generate_unroll

generate_unroll(
    env: brax.envs.base.Env,
    env_state: brax.envs.base.State,
    policy: brax.training.types.Policy,
    directive: jax.Array,
    key: jax.Array,
    unroll_length: int,
    extra_fields: Sequence[str] = ()
) -> Tuple[brax.envs.base.State, moplayground.moppo.acting.MultiObjectiveTransition]

Collect trajectories of given unroll_length.
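
As a rough illustration of the step-then-unroll pattern these two signatures suggest, the sketch below uses plain JAX with a toy environment. It does not call moplayground or Brax: `ToyTransition`, `toy_policy`, `toy_step`, `toy_actor_step`, and `toy_unroll` are hypothetical names, and the way `directive` is fed into the policy is an assumption made for the example.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyTransition(NamedTuple):
    """Hypothetical stand-in for a multi-objective transition."""
    observation: jax.Array
    action: jax.Array
    reward: jax.Array  # one entry per objective


def toy_policy(obs, directive, key):
    # Assumption: the policy conditions on a directive (e.g. objective weights).
    noise = 0.1 * jax.random.normal(key, obs.shape)
    return -obs * directive.sum() + noise


def toy_step(obs, action):
    # Toy dynamics with a two-objective reward vector.
    next_obs = obs + 0.1 * action
    reward = jnp.stack([-jnp.sum(next_obs ** 2), -jnp.sum(action ** 2)])
    return next_obs, reward


def toy_actor_step(obs, directive, key):
    # One step: pick an action, advance the state, record the transition.
    action = toy_policy(obs, directive, key)
    next_obs, reward = toy_step(obs, action)
    return next_obs, ToyTransition(obs, action, reward)


def toy_unroll(obs, directive, key, unroll_length):
    # Repeat the single step with lax.scan, threading state and PRNG key.
    def body(carry, _):
        obs, key = carry
        key, step_key = jax.random.split(key)
        next_obs, transition = toy_actor_step(obs, directive, step_key)
        return (next_obs, key), transition

    (final_obs, _), transitions = jax.lax.scan(
        body, (obs, key), None, length=unroll_length)
    return final_obs, transitions


final_obs, traj = toy_unroll(
    jnp.ones(3), jnp.array([0.5, 0.5]), jax.random.PRNGKey(0), unroll_length=5)
```

Each field of the returned `traj` carries a leading axis of length `unroll_length`, because `jax.lax.scan` stacks the per-step outputs along a new first dimension.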


class MultiObjectiveTransition

Container for a transition.


class Evaluator

Class to run evaluations.

method Evaluator.__init__

__init__(
    eval_env: brax.envs.base.Env,
    num_objs: int,
    eval_policy_fn,
    num_eval_envs: int,
    episode_length: int,
    action_repeat: int,
    key: jax.Array
)

method Evaluator.run_evaluation

run_evaluation(
    policy_params,
    training_metrics: Mapping[str, jax.Array],
    aggregate_episodes: bool = True
) -> Mapping[str, jax.Array]

Run one epoch of evaluation.
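
The signature does not say what `aggregate_episodes` changes. Assuming it switches between per-episode values and summary statistics, and assuming episode returns are stored as an array of shape `(num_eval_envs, num_objs)`, the aggregation could look like the sketch below. The function `summarize_returns` and the metric keys are hypothetical and are not part of moplayground.

```python
from typing import Mapping

import jax
import jax.numpy as jnp


def summarize_returns(episode_returns: jax.Array,
                      aggregate_episodes: bool = True) -> Mapping[str, jax.Array]:
    # episode_returns: assumed shape (num_eval_envs, num_objs).
    metrics = {}
    for i in range(episode_returns.shape[1]):
        per_episode = episode_returns[:, i]
        if aggregate_episodes:
            # Collapse the episode axis into summary statistics.
            metrics[f"eval/obj_{i}_mean"] = jnp.mean(per_episode)
            metrics[f"eval/obj_{i}_std"] = jnp.std(per_episode)
        else:
            # Keep one value per evaluation episode.
            metrics[f"eval/obj_{i}"] = per_episode
    return metrics


returns = jnp.array([[1.0, 10.0], [3.0, 30.0]])
summary = summarize_returns(returns)
per_episode = summarize_returns(returns, aggregate_episodes=False)
```

Reporting each objective under its own key keeps the result a flat `Mapping[str, jax.Array]`, matching the return type shown in the signature.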

