mirror of
https://github.com/vale981/ray
synced 2025-03-04 17:41:43 -05:00
309 lines
10 KiB
Python
import os
|
|
import pickle
|
|
import threading
|
|
from collections import deque
|
|
from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Type
|
|
|
|
import gym
|
|
|
|
from ray.rllib.policy.policy import PolicySpec
|
|
from ray.rllib.utils.annotations import PublicAPI, override
|
|
from ray.rllib.utils.framework import try_import_tf
|
|
from ray.rllib.utils.policy import create_policy_for_framework
|
|
from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary
|
|
from ray.rllib.utils.threading import with_lock
|
|
from ray.rllib.utils.typing import (
|
|
AlgorithmConfigDict,
|
|
PartialAlgorithmConfigDict,
|
|
PolicyID,
|
|
)
|
|
from ray.tune.utils.util import merge_dicts
|
|
|
|
if TYPE_CHECKING:
|
|
from ray.rllib.policy.policy import Policy
|
|
|
|
tf1, tf, tfv = try_import_tf()
|
|
|
|
|
|
@PublicAPI
class PolicyMap(dict):
    """Maps policy IDs to Policy objects.

    Thereby, keeps n policies in memory and - when capacity is reached -
    writes the least recently used to disk. This allows adding 100s of
    policies to a Algorithm for league-based setups w/o running out of memory.
    """

    def __init__(
        self,
        worker_index: int,
        num_workers: int,
        capacity: Optional[int] = None,
        path: Optional[str] = None,
        policy_config: Optional[AlgorithmConfigDict] = None,
        session_creator: Optional[Callable[[], "tf1.Session"]] = None,
        seed: Optional[int] = None,
    ):
        """Initializes a PolicyMap instance.

        Args:
            worker_index: The worker index of the RolloutWorker this map
                resides in.
            num_workers: The total number of remote workers in the
                WorkerSet to which this map's RolloutWorker belongs to.
            capacity: The maximum number of policies to hold in memory.
                The least used ones are written to disk/S3 and retrieved
                when needed.
            path: The path to store the policy pickle files to. Files
                will have the name: [policy_id].[worker idx].policy.pkl.
            policy_config: The Algorithm's base config dict.
            session_creator: An optional tf1.Session creation callable.
            seed: An optional seed (used to seed tf policies).
        """
        super().__init__()

        self.worker_index = worker_index
        self.num_workers = num_workers
        self.session_creator = session_creator
        self.seed = seed

        # The file extension for stashed policies (that are no longer available
        # in-memory but can be reinstated any time from storage).
        self.extension = f".{self.worker_index}.policy.pkl"

        # Set of keys that may be looked up (cached in memory or stashed
        # to disk).
        self.valid_keys: Set[str] = set()
        # The actual cache with the in-memory policy objects.
        self.cache: Dict[str, Policy] = {}
        # LRU order of the currently in-memory policy IDs. The leftmost
        # item is the least recently used one (next eviction candidate).
        self.deque = deque(maxlen=capacity or 10)
        # The file path where to store overflowing policies.
        self.path = path or "."
        # The core config to use. Each single policy's config override is
        # added on top of this.
        self.policy_config: AlgorithmConfigDict = policy_config or {}
        # The orig classes/obs+act spaces, and config overrides of the
        # Policies. Needed to re-create stashed policies from disk.
        self.policy_specs: Dict[PolicyID, PolicySpec] = {}

        # Lock used for locking some methods on the object-level.
        # This prevents possible race conditions when accessing the map
        # and the underlying structures, like self.deque and others.
        self._lock = threading.RLock()

    def create_policy(
        self,
        policy_id: PolicyID,
        policy_cls: Type["Policy"],
        observation_space: gym.Space,
        action_space: gym.Space,
        config_override: PartialAlgorithmConfigDict,
        merged_config: AlgorithmConfigDict,
    ) -> None:
        """Creates a new policy and stores it to the cache.

        Args:
            policy_id: The policy ID. This is the key under which
                the created policy will be stored in this map.
            policy_cls: The (original) policy class to use.
                This may still be altered in case tf-eager (and tracing)
                is used.
            observation_space: The observation space of the policy.
            action_space: The action space of the policy.
            config_override: The config override dict for this policy.
                This is the partial dict provided by the user.
            merged_config: The entire config (merged
                default config + `config_override`).
        """
        _class = get_tf_eager_cls_if_necessary(policy_cls, merged_config)

        # `self[policy_id] = ...` (i.e. __setitem__) takes care of LRU
        # bookkeeping and possible eviction of the least recently used item.
        self[policy_id] = create_policy_for_framework(
            policy_id,
            _class,
            merged_config,
            observation_space,
            action_space,
            self.worker_index,
            self.session_creator,
            self.seed,
        )

        # Store spec (class, obs-space, act-space, and config overrides) such
        # that the map will be able to reproduce on-the-fly added policies
        # from disk.
        self.policy_specs[policy_id] = PolicySpec(
            policy_class=policy_cls,
            observation_space=observation_space,
            action_space=action_space,
            config=config_override,
        )

    @with_lock
    @override(dict)
    def __getitem__(self, item):
        """Returns the policy under `item`, restoring it from disk if stashed.

        Raises:
            KeyError: If `item` was never added to this map.
        """
        # Never seen this key -> Error.
        if item not in self.valid_keys:
            raise KeyError(f"PolicyID '{item}' not found in this PolicyMap!")

        # Item already in cache -> Rearrange deque (least recently used) and
        # return.
        if item in self.cache:
            self.deque.remove(item)
            self.deque.append(item)
        # Item not currently in cache -> Get from disk and - if at capacity -
        # remove leftmost one.
        else:
            self._read_from_disk(policy_id=item)

        return self.cache[item]

    @with_lock
    @override(dict)
    def __setitem__(self, key, value):
        """Stores a policy, stashing the least recently used one if at capacity."""
        # Item already in cache -> Rearrange deque (least recently used).
        if key in self.cache:
            self.deque.remove(key)
            self.deque.append(key)
            self.cache[key] = value
        # Item not currently in cache -> store new value and - if at capacity -
        # remove leftmost one.
        else:
            # Cache at capacity -> Drop leftmost item.
            if len(self.deque) == self.deque.maxlen:
                self._stash_to_disk()
            self.deque.append(key)
            self.cache[key] = value
        self.valid_keys.add(key)

    @with_lock
    @override(dict)
    def __delitem__(self, key):
        """Removes a policy from this map (both memory and disk)."""
        # Make key invalid.
        self.valid_keys.remove(key)
        # Remove policy from memory if currently cached.
        if key in self.cache:
            policy = self.cache[key]
            self._close_session(policy)
            del self.cache[key]
            # Also drop the key from the LRU deque. Otherwise, a later
            # eviction (`_stash_to_disk`) could pop this stale key and fail
            # with a KeyError on the cache lookup.
            self.deque.remove(key)
        # Remove file associated with the policy, if it exists.
        filename = self._policy_file(key)
        if os.path.isfile(filename):
            os.remove(filename)

    @override(dict)
    def __iter__(self):
        """Iterates over all policy IDs, even the stashed-to-disk ones."""
        return iter(self.keys())

    @override(dict)
    def items(self):
        """Iterates over all (ID, policy) pairs, even the stashed-to-disk ones."""
        # Snapshot the valid keys under the lock so that concurrent
        # insertions/deletions cannot invalidate the iteration
        # ("set changed size during iteration").
        with self._lock:
            ks = list(self.valid_keys)

        def gen():
            for key in ks:
                yield (key, self[key])

        return gen()

    @override(dict)
    def keys(self):
        """Returns a generator over all policy IDs, even the stashed ones."""
        # `with` guarantees the lock is released even on error (the previous
        # manual acquire/release pair could leak the lock).
        with self._lock:
            ks = list(self.valid_keys)

        def gen():
            yield from ks

        return gen()

    @override(dict)
    def values(self):
        """Returns a generator over all policies, even the stashed ones."""
        # Materialize all policies while holding the lock. `with` ensures the
        # lock is released even if restoring a policy from disk fails
        # (`self[k]` may raise inside `_read_from_disk`).
        with self._lock:
            vs = [self[k] for k in self.valid_keys]

        def gen():
            yield from vs

        return gen()

    @with_lock
    @override(dict)
    def update(self, __m=None, **kwargs):
        """Updates this map from a mapping and/or kwargs (like dict.update)."""
        # `__m` is optional, matching the built-in `dict.update` signature.
        if __m is not None:
            for k, v in __m.items():
                self[k] = v
        for k, v in kwargs.items():
            self[k] = v

    @with_lock
    @override(dict)
    def get(self, key, default=None):
        """Returns the policy under `key`, or `default` if `key` is unknown."""
        if key not in self.valid_keys:
            return default
        return self[key]

    @with_lock
    @override(dict)
    def __len__(self):
        """Returns number of all policies, including the stashed-to-disk ones."""
        return len(self.valid_keys)

    @with_lock
    @override(dict)
    def __contains__(self, item):
        """Returns True if `item` is a known policy ID (cached or stashed)."""
        return item in self.valid_keys

    def _policy_file(self, policy_id: PolicyID) -> str:
        """Returns the pickle file path used to stash the given policy ID."""
        # os.path.join (instead of concatenating with "/") keeps the path
        # valid on all platforms.
        return os.path.join(self.path, policy_id + self.extension)

    def _stash_to_disk(self) -> None:
        """Writes the least-recently used policy to disk and rearranges cache.

        Also closes the session - if applicable - of the stashed policy.
        """
        # Get least recently used policy (all the way on the left in deque).
        delkey = self.deque.popleft()
        policy = self.cache[delkey]
        # Get its state for writing to disk.
        policy_state = policy.get_state()
        # Closes policy's tf session, if any.
        self._close_session(policy)
        # Remove from memory. This will clear the tf Graph as well.
        del self.cache[delkey]
        # Write state to disk.
        with open(self._policy_file(delkey), "wb") as f:
            pickle.dump(policy_state, file=f)

    def _read_from_disk(self, policy_id) -> None:
        """Reads a policy ID from disk and re-adds it to the cache."""
        # Make sure this policy ID is not in the cache right now.
        assert policy_id not in self.cache
        # Read policy state from disk.
        # NOTE(review): These pickle files are produced by this map itself
        # (`_stash_to_disk`). Never point `self.path` at untrusted data, as
        # unpickling can execute arbitrary code.
        with open(self._policy_file(policy_id), "rb") as f:
            policy_state = pickle.load(f)

        # Get class and config override.
        merged_conf = merge_dicts(
            self.policy_config, self.policy_specs[policy_id].config
        )

        # Create policy object (from its spec: cls, obs-space, act-space,
        # config).
        self.create_policy(
            policy_id,
            self.policy_specs[policy_id].policy_class,
            self.policy_specs[policy_id].observation_space,
            self.policy_specs[policy_id].action_space,
            self.policy_specs[policy_id].config,
            merged_conf,
        )
        # Restore policy's state.
        policy = self[policy_id]
        policy.set_state(policy_state)

    def _close_session(self, policy) -> None:
        """Closes the given policy's tf1 session, if it has one."""
        sess = policy.get_session()
        # Closes the tf session, if any.
        if sess is not None:
            sess.close()