Run Serve Tests on Windows (#10682)

This commit is contained in:
architkulkarni 2020-09-10 15:54:37 -07:00 committed by GitHub
parent e3aee6b434
commit 333f324b88
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 6 deletions

View file

@ -137,6 +137,7 @@ test_python() {
if [ "${OSTYPE}" = msys ]; then
pathsep=";"
args+=(
python/ray/serve/...
python/ray/tests/...
-python/ray/tests:test_advanced_2
-python/ray/tests:test_advanced_3 # test_invalid_unicode_in_worker_log() fails on Windows

View file

@ -6,6 +6,8 @@ import pickle
import json
import numpy as np
import requests
import os
import tempfile
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier
@ -32,9 +34,13 @@ model.fit(train_x, train_y)
print("MSE:", mean_squared_error(model.predict(val_x), val_y))
# Save the model and label to file
with open("/tmp/iris_model_logistic_regression.pkl", "wb") as f:
MODEL_PATH = os.path.join(tempfile.gettempdir(),
"iris_model_logistic_regression.pkl")
LABEL_PATH = os.path.join(tempfile.gettempdir(), "iris_labels.json")
with open(MODEL_PATH, "wb") as f:
pickle.dump(model, f)
with open("/tmp/iris_labels.json", "w") as f:
with open(LABEL_PATH, "w") as f:
json.dump(target_names.tolist(), f)
# __doc_train_model_end__
@ -42,9 +48,9 @@ with open("/tmp/iris_labels.json", "w") as f:
# __doc_define_servable_begin__
class BoostingModel:
def __init__(self):
with open("/tmp/iris_model_logistic_regression.pkl", "rb") as f:
with open(MODEL_PATH, "rb") as f:
self.model = pickle.load(f)
with open("/tmp/iris_labels.json") as f:
with open(LABEL_PATH) as f:
self.label_list = json.load(f)
def __call__(self, flask_request):

View file

@ -3,13 +3,14 @@
from ray import serve
import os
import tempfile
import numpy as np
import requests
# __doc_import_end__
# yapf: enable
# __doc_train_model_begin__
TRAINED_MODEL_PATH = "/tmp/mnist_model.h5"
TRAINED_MODEL_PATH = os.path.join(tempfile.gettempdir(), "mnist_model.h5")
def train_and_save_model():
@ -69,7 +70,7 @@ class TFMnistModel:
# __doc_deploy_begin__
client = serve.start()
client.create_backend("tf:v1", TFMnistModel, "/tmp/mnist_model.h5")
client.create_backend("tf:v1", TFMnistModel, TRAINED_MODEL_PATH)
client.create_endpoint("tf_classifier", backend="tf:v1", route="/mnist")
# __doc_deploy_end__