2020-04-10 00:56:08 -07:00
|
|
|
from ray.util.iter import LocalIterator
|
|
|
|
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
|
2022-06-14 20:59:15 +00:00
|
|
|
from ray.rllib.utils.typing import Dict, SampleBatchType
|
2020-12-24 06:30:33 -08:00
|
|
|
from ray.util.iter_metrics import MetricsContext
|
2020-04-10 00:56:08 -07:00
|
|
|
|
2021-12-21 08:39:05 +01:00
|
|
|
# Backward compatibility.
|
|
|
|
from ray.rllib.utils.metrics import ( # noqa: F401
|
|
|
|
LAST_TARGET_UPDATE_TS,
|
|
|
|
NUM_TARGET_UPDATES,
|
|
|
|
APPLY_GRADS_TIMER,
|
|
|
|
COMPUTE_GRADS_TIMER,
|
|
|
|
WORKER_UPDATE_TIMER,
|
|
|
|
GRAD_WAIT_TIMER,
|
|
|
|
SAMPLE_TIMER,
|
|
|
|
LEARN_ON_BATCH_TIMER,
|
|
|
|
LOAD_BATCH_TIMER,
|
2022-01-29 18:41:57 -08:00
|
|
|
)
|
2021-12-21 08:39:05 +01:00
|
|
|
|
2020-04-10 00:56:08 -07:00
|
|
|
# String keys for the step counters. STEPS_SAMPLED_COUNTER is the key that
# `_get_global_vars` looks up in the shared metrics' `counters`.

# Sampling-related keys.
STEPS_SAMPLED_COUNTER = "num_steps_sampled"
AGENT_STEPS_SAMPLED_COUNTER = "num_agent_steps_sampled"

# Training-related keys.
STEPS_TRAINED_COUNTER = "num_steps_trained"
AGENT_STEPS_TRAINED_COUNTER = "num_agent_steps_trained"
STEPS_TRAINED_THIS_ITER_COUNTER = "num_steps_trained_this_iter"
|
2020-04-10 00:56:08 -07:00
|
|
|
|
2021-12-21 08:39:05 +01:00
|
|
|
# End: Backward compatibility.
|
2020-04-10 00:56:08 -07:00
|
|
|
|
|
|
|
|
|
|
|
# Raises a ValueError unless the object is a SampleBatch or MultiAgentBatch.
|
2020-12-24 06:30:33 -08:00
|
|
|
def _check_sample_batch_type(batch: SampleBatchType) -> None:
    """Validate that ``batch`` is a SampleBatch or a MultiAgentBatch.

    Args:
        batch: The object to check.

    Raises:
        ValueError: If ``batch`` is an instance of neither class. The
            message includes the object's type and the object itself.
    """
    accepted_types = (SampleBatch, MultiAgentBatch)
    if isinstance(batch, accepted_types):
        # Valid batch: nothing to report.
        return

    message = (
        "Expected either SampleBatch or MultiAgentBatch, "
        "got {}: {}".format(type(batch), batch)
    )
    raise ValueError(message)
|
2020-04-10 00:56:08 -07:00
|
|
|
|
|
|
|
|
2022-06-14 20:59:15 +00:00
|
|
|
# Returns pipeline global vars that should be periodically sent to each worker.
|
|
|
|
def _get_global_vars() -> Dict:
    """Build the global-vars dict from the current iterator metrics.

    Returns:
        A dict with the single key "timestep", whose value is the
        STEPS_SAMPLED_COUNTER entry of the metrics' ``counters``.
    """
    sampled_steps = LocalIterator.get_metrics().counters[STEPS_SAMPLED_COUNTER]
    return {"timestep": sampled_steps}
|
|
|
|
|
|
|
|
|
2020-12-24 06:30:33 -08:00
|
|
|
def _get_shared_metrics() -> MetricsContext:
    """Fetch the metrics shared across the training workflow.

    This only applies if this algorithm has an execution plan.

    Returns:
        The object returned by ``LocalIterator.get_metrics()``.
    """
    shared_metrics = LocalIterator.get_metrics()
    return shared_metrics
|