¿Cómo realizo entrenamiento incremental cuando uso MoXing?
Si no está satisfecho con los resultados del entrenamiento al usar MoXing para crear un modelo, puede realizar un entrenamiento incremental después de modificar algunos datos y la información de etiquetas.
Adición de parámetros de entrenamiento incrementales a mox.run
Después de modificar los datos de etiquetado o los conjuntos de datos, puede modificar el parámetro log_dir y agregar el parámetro checkpoint_path a mox.run. Establezca log_dir en un nuevo directorio y checkpoint_path en la ruta de salida de los resultados de entrenamiento anteriores. Si la ruta de salida es un directorio de OBS, establezca la ruta en un valor que comience por obs://.
Si se cambian las etiquetas para los datos de etiquetas, realice operaciones de Si se cambian las etiquetas antes de ejecutar mox.run.
mox.run(input_fn=input_fn, model_fn=model_fn, optimizer_fn=optimizer_fn, run_mode=flags.run_mode, inter_mode=mox.ModeKeys.EVAL if use_eval_data else None, log_dir=log_dir, batch_size=batch_size_per_device, auto_batch=False, max_number_of_steps=max_number_of_steps, log_every_n_steps=flags.log_every_n_steps, save_summary_steps=save_summary_steps, save_model_secs=save_model_secs, checkpoint_path=flags.checkpoint_url, export_model=mox.ExportKeys.TF_SERVING)
Si se cambian las etiquetas
Si las etiquetas de un conjunto de datos han cambiado, ejecute la siguiente instrucción. La instrucción debe ejecutarse antes de ejecutar mox.run.
En la instrucción, la variable logits indica los pesos de capa de clasificación en diferentes redes y se configuran diferentes parámetros. Establezca este parámetro en la palabra clave correspondiente.
mox.set_flag('checkpoint_exclude_patterns', 'logits')
Si se utiliza la red integrada de MoXing, la palabra clave correspondiente debe obtenerse llamando a la siguiente API. En este ejemplo, la palabra clave Resnet_v1_50 es el valor de logits.
import moxing.tensorflow as mox model_meta = mox.get_model_meta(mox.NetworkKeys.RESNET_V1_50) logits_pattern = model_meta.default_logits_pattern print(logits_pattern)
También puede obtener una lista de redes compatibles con MoXing llamando a la siguiente API:
import moxing.tensorflow as mox print(help(mox.NetworkKeys))
Se muestra la siguiente información:
Help on class NetworkKeys in module moxing.tensorflow.nets.nets_factory: class NetworkKeys(builtins.object) | Data descriptors defined here: | | __dict__ | dictionary for instance variables (if defined) | | __weakref__ | list of weak references to the object (if defined) | | ---------------------------------------------------------------------- | Data and other attributes defined here: | | ALEXNET_V2 = 'alexnet_v2' | | CIFARNET = 'cifarnet' | | INCEPTION_RESNET_V2 = 'inception_resnet_v2' | | INCEPTION_V1 = 'inception_v1' | | INCEPTION_V2 = 'inception_v2' | | INCEPTION_V3 = 'inception_v3' | | INCEPTION_V4 = 'inception_v4' | | LENET = 'lenet' | | MOBILENET_V1 = 'mobilenet_v1' | | MOBILENET_V1_025 = 'mobilenet_v1_025' | | MOBILENET_V1_050 = 'mobilenet_v1_050' | | MOBILENET_V1_075 = 'mobilenet_v1_075' | | MOBILENET_V2 = 'mobilenet_v2' | | MOBILENET_V2_035 = 'mobilenet_v2_035' | | MOBILENET_V2_140 = 'mobilenet_v2_140' | | NASNET_CIFAR = 'nasnet_cifar' | | NASNET_LARGE = 'nasnet_large' | | NASNET_MOBILE = 'nasnet_mobile' | | OVERFEAT = 'overfeat' | | PNASNET_LARGE = 'pnasnet_large' | | PNASNET_MOBILE = 'pnasnet_mobile' | | PVANET = 'pvanet' | | RESNET_V1_101 = 'resnet_v1_101' | | RESNET_V1_110 = 'resnet_v1_110' | | RESNET_V1_152 = 'resnet_v1_152' | | RESNET_V1_18 = 'resnet_v1_18' | | RESNET_V1_20 = 'resnet_v1_20' | | RESNET_V1_200 = 'resnet_v1_200' | | RESNET_V1_50 = 'resnet_v1_50' | | RESNET_V1_50_8K = 'resnet_v1_50_8k' | | RESNET_V1_50_MOX = 'resnet_v1_50_mox' | | RESNET_V1_50_OCT = 'resnet_v1_50_oct' | | RESNET_V2_101 = 'resnet_v2_101' | | RESNET_V2_152 = 'resnet_v2_152' | | RESNET_V2_200 = 'resnet_v2_200' | | RESNET_V2_50 = 'resnet_v2_50' | | RESNEXT_B_101 = 'resnext_b_101' | | RESNEXT_B_50 = 'resnext_b_50' | | RESNEXT_C_101 = 'resnext_c_101' | | RESNEXT_C_50 = 'resnext_c_50' | | VGG_16 = 'vgg_16' | | VGG_16_BN = 'vgg_16_bn' | | VGG_19 = 'vgg_19' | | VGG_19_BN = 'vgg_19_bn' | | VGG_A = 'vgg_a' | | VGG_A_BN = 'vgg_a_bn' | | XCEPTION_41 = 'xception_41' | | XCEPTION_65 = 'xception_65' | | XCEPTION_71 = 'xception_71'