import os

from ray.rllib.utils.annotations import DeveloperAPI


@DeveloperAPI
class EvaluatorInterface:
    """This is the interface between policy optimizers and policy evaluation.

    See also: RolloutWorker
    """

    @DeveloperAPI
    def sample(self):
        """Returns a batch of experience sampled from this evaluator.

        This method must be implemented by subclasses.

        Returns:
            SampleBatch|MultiAgentBatch: A columnar batch of experiences
            (e.g., tensors), or a multi-agent batch.

        Examples:
            >>> print(ev.sample())
            SampleBatch({"obs": [1, 2, 3], "action": [0, 1, 0], ...})
        """

        raise NotImplementedError

    @DeveloperAPI
    def learn_on_batch(self, samples):
        """Updates policies based on the given batch.

        This is equivalent to apply_gradients(compute_gradients(samples)),
        but can be optimized to avoid pulling gradients into CPU memory.

        Either this or the combination of compute/apply grads must be
        implemented by subclasses.

        Returns:
            info: dictionary of extra metadata from compute_gradients().

        Examples:
            >>> batch = ev.sample()
            >>> ev.learn_on_batch(batch)
        """

        grads, info = self.compute_gradients(samples)
        self.apply_gradients(grads)
        return info
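
    # The default learn_on_batch() above materializes gradients in CPU
    # memory between the compute and apply steps; subclasses can override
    # it to fuse both steps and keep gradients on-device.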

    @DeveloperAPI
    def compute_gradients(self, samples):
        """Returns a gradient computed w.r.t. the specified samples.

        Either this or learn_on_batch() must be implemented by subclasses.

        Returns:
            (grads, info): A list of gradients that can be applied on a
            compatible evaluator. In the multi-agent case, returns a dict
            of gradients keyed by policy ids. An info dictionary of
            extra metadata is also returned.

        Examples:
            >>> batch = ev.sample()
            >>> grads, info = ev2.compute_gradients(batch)
        """

        raise NotImplementedError

    @DeveloperAPI
    def apply_gradients(self, grads):
        """Applies the given gradients to this evaluator's weights.

        Either this or learn_on_batch() must be implemented by subclasses.

        Examples:
            >>> samples = ev1.sample()
            >>> grads, info = ev2.compute_gradients(samples)
            >>> ev1.apply_gradients(grads)
        """

        raise NotImplementedError

    @DeveloperAPI
    def get_weights(self):
        """Returns the model weights of this Evaluator.

        This method must be implemented by subclasses.

        Returns:
            object: weights that can be set on a compatible evaluator.

        Examples:
            >>> weights = ev1.get_weights()
        """

        raise NotImplementedError

    @DeveloperAPI
    def set_weights(self, weights):
        """Sets the model weights of this Evaluator.

        This method must be implemented by subclasses.

        Examples:
            >>> weights = ev1.get_weights()
            >>> ev2.set_weights(weights)
        """

        raise NotImplementedError

    @DeveloperAPI
    def get_host(self):
        """Returns the hostname of the process running this evaluator."""
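
        # os.uname() is only available on Unix; a portable implementation
        # would use socket.gethostname() instead.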
        return os.uname()[1]

    @DeveloperAPI
    def apply(self, func, *args):
        """Applies the given function to this evaluator instance."""

        return func(self, *args)
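

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the interface above): a minimal concrete
# subclass showing how sample(), compute_gradients(), apply_gradients(), and
# the weight-sync methods compose. Everything below -- the LinearEvaluator
# name, the fake batch, the squared loss -- is a hypothetical stand-in, not
# an RLlib API; only numpy is assumed.
# ---------------------------------------------------------------------------

import numpy as np


class LinearEvaluator(EvaluatorInterface):
    """Toy evaluator that fits a weight vector to fixed targets."""

    def __init__(self, size=3, lr=0.1):
        self.w = np.zeros(size)
        self.lr = lr

    def sample(self):
        # A real evaluator would roll out a policy in an environment and
        # return a SampleBatch; here we fabricate a constant target.
        return {"target": np.ones_like(self.w)}

    def compute_gradients(self, samples):
        # Gradient of the squared loss 0.5 * ||w - target||^2 w.r.t. w.
        grads = self.w - samples["target"]
        info = {"loss": 0.5 * float(np.sum(grads ** 2))}
        return grads, info

    def apply_gradients(self, grads):
        self.w -= self.lr * grads

    def get_weights(self):
        return self.w.copy()

    def set_weights(self, weights):
        self.w = weights.copy()


if __name__ == "__main__":
    # The two-evaluator flow from the docstrings: ev2 computes gradients on
    # ev1's samples, and ev1 applies them after syncing weights over.
    ev1, ev2 = LinearEvaluator(), LinearEvaluator()
    for _ in range(50):
        batch = ev1.sample()
        ev2.set_weights(ev1.get_weights())
        grads, info = ev2.compute_gradients(batch)
        ev1.apply_gradients(grads)
    # learn_on_batch() falls back to the compute/apply pair defined above.
    info = ev1.learn_on_batch(ev1.sample())
    print("loss:", info["loss"], "weights:", ev1.get_weights())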