mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
Run Serve Tests on Windows (#10682)
This commit is contained in:
parent
e3aee6b434
commit
333f324b88
3 changed files with 14 additions and 6 deletions
|
@ -137,6 +137,7 @@ test_python() {
|
|||
if [ "${OSTYPE}" = msys ]; then
|
||||
pathsep=";"
|
||||
args+=(
|
||||
python/ray/serve/...
|
||||
python/ray/tests/...
|
||||
-python/ray/tests:test_advanced_2
|
||||
-python/ray/tests:test_advanced_3 # test_invalid_unicode_in_worker_log() fails on Windows
|
||||
|
|
|
@ -6,6 +6,8 @@ import pickle
|
|||
import json
|
||||
import numpy as np
|
||||
import requests
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.ensemble import GradientBoostingClassifier
|
||||
|
@ -32,9 +34,13 @@ model.fit(train_x, train_y)
|
|||
print("MSE:", mean_squared_error(model.predict(val_x), val_y))
|
||||
|
||||
# Save the model and label to file
|
||||
with open("/tmp/iris_model_logistic_regression.pkl", "wb") as f:
|
||||
MODEL_PATH = os.path.join(tempfile.gettempdir(),
|
||||
"iris_model_logistic_regression.pkl")
|
||||
LABEL_PATH = os.path.join(tempfile.gettempdir(), "iris_labels.json")
|
||||
|
||||
with open(MODEL_PATH, "wb") as f:
|
||||
pickle.dump(model, f)
|
||||
with open("/tmp/iris_labels.json", "w") as f:
|
||||
with open(LABEL_PATH, "w") as f:
|
||||
json.dump(target_names.tolist(), f)
|
||||
# __doc_train_model_end__
|
||||
|
||||
|
@ -42,9 +48,9 @@ with open("/tmp/iris_labels.json", "w") as f:
|
|||
# __doc_define_servable_begin__
|
||||
class BoostingModel:
|
||||
def __init__(self):
|
||||
with open("/tmp/iris_model_logistic_regression.pkl", "rb") as f:
|
||||
with open(MODEL_PATH, "rb") as f:
|
||||
self.model = pickle.load(f)
|
||||
with open("/tmp/iris_labels.json") as f:
|
||||
with open(LABEL_PATH) as f:
|
||||
self.label_list = json.load(f)
|
||||
|
||||
def __call__(self, flask_request):
|
||||
|
|
|
@ -3,13 +3,14 @@
|
|||
from ray import serve
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import requests
|
||||
# __doc_import_end__
|
||||
# yapf: enable
|
||||
|
||||
# __doc_train_model_begin__
|
||||
TRAINED_MODEL_PATH = "/tmp/mnist_model.h5"
|
||||
TRAINED_MODEL_PATH = os.path.join(tempfile.gettempdir(), "mnist_model.h5")
|
||||
|
||||
|
||||
def train_and_save_model():
|
||||
|
@ -69,7 +70,7 @@ class TFMnistModel:
|
|||
|
||||
# __doc_deploy_begin__
|
||||
client = serve.start()
|
||||
client.create_backend("tf:v1", TFMnistModel, "/tmp/mnist_model.h5")
|
||||
client.create_backend("tf:v1", TFMnistModel, TRAINED_MODEL_PATH)
|
||||
client.create_endpoint("tf_classifier", backend="tf:v1", route="/mnist")
|
||||
# __doc_deploy_end__
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue