mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[rllib] Fix num_gpus cast and raise error on large batch (#4652)
This commit is contained in:
parent
be2cbdf130
commit
5a562bbf12
3 changed files with 13 additions and 4 deletions
|
@@ -61,8 +61,8 @@ class Preprocessor(object):
|
|||
self._obs_space, observation)
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
"Observation for a Box space should be an np.array, "
|
||||
"not a Python list.", observation)
|
||||
"Observation for a Box/MultiBinary/MultiDiscrete space "
|
||||
"should be an np.array, not a Python list.", observation)
|
||||
self._i += 1
|
||||
|
||||
@property
|
||||
|
|
|
@@ -47,9 +47,13 @@ class TFMultiGPULearner(LearnerThread):
|
|||
if not num_gpus:
|
||||
self.devices = ["/cpu:0"]
|
||||
elif _fake_gpus:
|
||||
self.devices = ["/cpu:{}".format(i) for i in range(num_gpus)]
|
||||
self.devices = [
|
||||
"/cpu:{}".format(i) for i in range(int(math.ceil(num_gpus)))
|
||||
]
|
||||
else:
|
||||
self.devices = ["/gpu:{}".format(i) for i in range(num_gpus)]
|
||||
self.devices = [
|
||||
"/gpu:{}".format(i) for i in range(int(math.ceil(num_gpus)))
|
||||
]
|
||||
logger.info("TFMultiGPULearner devices {}".format(self.devices))
|
||||
assert self.train_batch_size % len(self.devices) == 0
|
||||
assert self.train_batch_size >= len(self.devices), "batch too small"
|
||||
|
|
|
@@ -68,6 +68,11 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):
|
|||
batches = tmp
|
||||
|
||||
for batch in batches:
|
||||
if batch.count > self.max_buffer_size:
|
||||
raise ValueError(
|
||||
"The size of a single sample batch exceeds the replay "
|
||||
"buffer size ({} > {})".format(batch.count,
|
||||
self.max_buffer_size))
|
||||
self.replay_buffer.append(batch)
|
||||
self.num_steps_sampled += batch.count
|
||||
self.buffer_size += batch.count
|
||||
|
|
Loading…
Add table
Reference in a new issue