mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[RLlib] Better exceptions with traceback in TorchPolicy (#17690)
This commit is contained in:
parent
811d71b368
commit
6891dee6ea
1 changed files with 7 additions and 6 deletions
|
@ -994,9 +994,10 @@ class TorchPolicy(Policy):
|
|||
results[shard_idx] = (all_grads, grad_info)
|
||||
except Exception as e:
|
||||
with lock:
|
||||
results[shard_idx] = ValueError(
|
||||
results[shard_idx] = (ValueError(
|
||||
e.args[0] + "\n" +
|
||||
"In tower {} on device {}".format(shard_idx, device))
|
||||
"In tower {} on device {}".format(shard_idx, device)),
|
||||
e)
|
||||
|
||||
# Single device (GPU) or fake-GPU case (serialize for better
|
||||
# debugging).
|
||||
|
@ -1006,8 +1007,8 @@ class TorchPolicy(Policy):
|
|||
_worker(shard_idx, model, sample_batch, device)
|
||||
# Raise errors right away for better debugging.
|
||||
last_result = results[len(results) - 1]
|
||||
if isinstance(last_result, ValueError):
|
||||
raise last_result
|
||||
if isinstance(last_result[0], ValueError):
|
||||
raise last_result[0] from last_result[1]
|
||||
# Multi device (GPU) case: Parallelize via threads.
|
||||
else:
|
||||
threads = [
|
||||
|
@ -1027,8 +1028,8 @@ class TorchPolicy(Policy):
|
|||
outputs = []
|
||||
for shard_idx in range(len(sample_batches)):
|
||||
output = results[shard_idx]
|
||||
if isinstance(output, Exception):
|
||||
raise output
|
||||
if isinstance(output[0], Exception):
|
||||
raise output[0] from output[1]
|
||||
outputs.append(results[shard_idx])
|
||||
return outputs
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue