[air] Explicitly list out the args for BatchPredictor.predict_pipelined (#26551)
Signed-off-by: Eric Liang <ekhliang@gmail.com>
Parent: 15dbc0362a · Commit: f2401a14d9
2 changed files with 38 additions and 3 deletions
Changes to `Dataset` (adds per-window block-count statistics and logging):

```diff
@@ -3212,6 +3212,7 @@ class Dataset(Generic[T]):
         self._splits = blocks.split(split_size=blocks_per_window)
         try:
             sizes = [s.size_bytes() for s in self._splits]
+            num_blocks = [s.initial_num_blocks() for s in self._splits]
             assert [s > 0 for s in sizes], sizes
 
             def fmt(size_bytes):
@@ -3229,6 +3230,16 @@ class Dataset(Generic[T]):
                     fmt(int(np.mean(sizes))),
                 )
             )
+            logger.info(
+                "Blocks per window: "
+                "{} min, {} max, {} mean".format(
+                    min(num_blocks),
+                    max(num_blocks),
+                    int(np.mean(num_blocks)),
+                )
+            )
+            # TODO(ekl): log a warning if the blocks per window are much less
+            # than the available parallelism.
         except Exception as e:
             logger.info(
                 "Created DatasetPipeline with {} windows; "
```
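For context, the new `logger.info` call just reports the spread of block counts across pipeline windows. A minimal standalone sketch of that computation; `window_block_counts` is hypothetical sample data, not real Ray pipeline output:

```python
# Standalone sketch of the "Blocks per window" stat added above.
# `window_block_counts` is made-up sample data, not real Ray output.
import logging

import numpy as np

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ray.data")

window_block_counts = [8, 12, 10, 9]  # blocks in each DatasetPipeline window

logger.info(
    "Blocks per window: "
    "{} min, {} max, {} mean".format(
        min(window_block_counts),
        max(window_block_counts),
        int(np.mean(window_block_counts)),
    )
)
```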
Changes to `BatchPredictor.predict_pipelined` — the signature now lists the forwarded arguments explicitly instead of taking `**kwargs`:

```diff
@@ -139,7 +139,14 @@ class BatchPredictor:
         *,
         blocks_per_window: Optional[int] = None,
         bytes_per_window: Optional[int] = None,
-        **kwargs,
+        # The remaining args are from predict().
+        batch_size: int = 4096,
+        min_scoring_workers: int = 1,
+        max_scoring_workers: Optional[int] = None,
+        num_cpus_per_worker: int = 1,
+        num_gpus_per_worker: int = 0,
+        ray_remote_args: Optional[Dict[str, Any]] = None,
+        **predict_kwargs,
     ) -> ray.data.DatasetPipeline:
         """Setup a prediction pipeline for batch scoring.
 
```
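The design choice here: explicit keyword-only parameters make the signature self-documenting and turn misspelled argument names into immediate `TypeError`s instead of silently absorbing them. A toy illustration (not Ray code):

```python
# Toy illustration of explicit keyword-only args vs a **kwargs catch-all.
def with_kwargs(ds, **kwargs):
    return kwargs  # old style: a typo like `batch_sise` is silently accepted


def with_explicit(ds, *, batch_size: int = 4096):
    return batch_size  # new style: unknown names fail at the call site


print(with_kwargs(None, batch_sise=2048))  # {'batch_sise': 2048} -- typo hidden
try:
    with_explicit(None, batch_sise=2048)
except TypeError as err:
    print(err)  # ... got an unexpected keyword argument 'batch_sise'
```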
The docstring and the forwarding call are updated to match:

```diff
@@ -183,7 +190,15 @@ class BatchPredictor:
                 This will be treated as an upper bound for the window size, but each
                 window will still include at least one block. This is mutually
                 exclusive with ``blocks_per_window``.
-            kwargs: Keyword arguments passed to BatchPredictor.predict().
+            batch_size: Split dataset into batches of this size for prediction.
+            min_scoring_workers: Minimum number of scoring actors.
+            max_scoring_workers: If set, specify the maximum number of scoring actors.
+            num_cpus_per_worker: Number of CPUs to allocate per scoring worker.
+            num_gpus_per_worker: Number of GPUs to allocate per scoring worker.
+            ray_remote_args: Additional resource requirements to request from
+                ray.
+            predict_kwargs: Keyword arguments passed to the predictor's
+                ``predict()`` method.
 
         Returns:
             DatasetPipeline that generates scoring results.
@@ -199,4 +214,13 @@ class BatchPredictor:
             blocks_per_window=blocks_per_window, bytes_per_window=bytes_per_window
         )
 
-        return self.predict(pipe)
+        return self.predict(
+            pipe,
+            batch_size=batch_size,
+            min_scoring_workers=min_scoring_workers,
+            max_scoring_workers=max_scoring_workers,
+            num_cpus_per_worker=num_cpus_per_worker,
+            num_gpus_per_worker=num_gpus_per_worker,
+            ray_remote_args=ray_remote_args,
+            **predict_kwargs,
+        )
```
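For reference, a hedged usage sketch of the updated call. The checkpoint, predictor class, and dataset path are placeholders, and the `ray.train` import paths reflect the Ray 2.x layout of this era and may differ in other versions:

```python
import ray
from ray.train.batch_predictor import BatchPredictor
from ray.train.torch import TorchPredictor  # any Predictor subclass works

ds = ray.data.read_parquet("s3://my-bucket/data")  # placeholder dataset

# Placeholder: a Checkpoint produced by a real training run.
my_checkpoint = ...

batch_predictor = BatchPredictor.from_checkpoint(my_checkpoint, TorchPredictor)

# The pipelining and scoring args are now explicit and discoverable:
pipe = batch_predictor.predict_pipelined(
    ds,
    blocks_per_window=2,
    batch_size=1024,
    min_scoring_workers=2,
    num_gpus_per_worker=1,
)
for batch in pipe.iter_batches():
    ...  # consume scored batches window by window
```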