Estos contenidos se han traducido de forma automática para su comodidad, pero Huawei Cloud no garantiza la exactitud de estos. Para consultar los contenidos originales, acceda a la versión en inglés.
Centro de ayuda/
ModelArts/
Implementación de inferencia/
Especificaciones de inferencia/
Ejemplos de scripts personalizados/
TensorFlow 2.1
Actualización más reciente 2024-09-25 GMT+08:00
TensorFlow 2.1
Entrenamiento y guardado de un modelo
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(256, activation='relu'),
tf.keras.layers.Dropout(0.2),
# Name the output layer output, which is used to obtain the result during model inference.
tf.keras.layers.Dense(10, activation='softmax', name="output")
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10)
tf.keras.models.save_model(model, "./mnist")
Código de inferencia
En el archivo de código de inferencia de modelo customize_service.py, agregue una clase de modelo hijo. Esta clase de modelo hijo hereda las propiedades de su clase de modelo padre. Para obtener más información sobre las instrucciones de importación de diferentes tipos de clases de modelo padre, consulte Tabla 1.
import logging
import threading
import numpy as np
import tensorflow as tf
from PIL import Image
from model_service.tfserving_model_service import TfServingBaseService
logger = logging.getLogger()
logger.setLevel(logging.INFO)
class MnistService(TfServingBaseService):
def __init__(self, model_name, model_path):
self.model_name = model_name
self.model_path = model_path
self.model = None
self.predict = None
# The label file can be loaded here and used in the post-processing function.
# Directories for storing the label.txt file on OBS and in the model package
# with open(os.path.join(self.model_path, 'label.txt')) as f:
# self.label = json.load(f)
# Load the model in saved_model format in non-blocking mode to prevent blocking timeout.
thread = threading.Thread(target=self.load_model)
thread.start()
def load_model(self):
# Load the model in saved_model format.
self.model = tf.saved_model.load(self.model_path)
signature_defs = self.model.signatures.keys()
signature = []
# only one signature allowed
for signature_def in signature_defs:
signature.append(signature_def)
if len(signature) == 1:
model_signature = signature[0]
else:
logging.warning("signatures more than one, use serving_default signature from %s", signature)
model_signature = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
self.predict = self.model.signatures[model_signature]
def _preprocess(self, data):
images = []
for k, v in data.items():
for file_name, file_content in v.items():
image1 = Image.open(file_content)
image1 = np.array(image1, dtype=np.float32)
image1.resize((28, 28, 1))
images.append(image1)
images = tf.convert_to_tensor(images, dtype=tf.dtypes.float32)
preprocessed_data = images
return preprocessed_data
def _inference(self, data):
return self.predict(data)
def _postprocess(self, data):
return {
"result": int(data["output"].numpy()[0].argmax())
}
Tema principal: Ejemplos de scripts personalizados
Comentarios
¿Le pareció útil esta página?
Deje algún comentario
Muchas gracias por sus comentarios. Seguiremos trabajando para mejorar la documentación.
El sistema está ocupado. Vuelva a intentarlo más tarde.