TinyMS CycleGAN Tutorial

This tutorial demonstrates how to train a CycleGAN model and run inference with it using the TinyMS API.

Requirements

  • Ubuntu: 18.04

  • Python: 3.7.x

  • Flask: 1.1.2

  • MindSpore: CPU-1.1.1

  • TinyMS: 0.1.0

  • numpy: 1.17.5

  • Pillow: 8.1.0

  • pip: 21.0.1

  • requests: 2.18.4

  • matplotlib: 3.3.4
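
Before continuing, it can help to compare the installed package versions against the list above. The following is a minimal sketch, not part of TinyMS: the helper names are hypothetical, and the PyPI distribution names (for example `mindspore` for "MindSpore: CPU-1.1.1") are assumptions.

```python
import sys

# Versions from the requirements list above.
# The distribution names are assumptions about how each package is published.
REQUIRED = {
    "Flask": "1.1.2",
    "mindspore": "1.1.1",
    "tinyms": "0.1.0",
    "numpy": "1.17.5",
    "Pillow": "8.1.0",
    "requests": "2.18.4",
    "matplotlib": "3.3.4",
}


def installed_version(name):
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        from importlib.metadata import version, PackageNotFoundError  # Python 3.8+
    except ImportError:
        # Python 3.7 fallback; pkg_resources ships with setuptools.
        import pkg_resources
        try:
            return pkg_resources.get_distribution(name).version
        except pkg_resources.DistributionNotFound:
            return None
    try:
        return version(name)
    except PackageNotFoundError:
        return None


def find_mismatches(required, lookup=installed_version):
    """Map each package whose installed version differs to (expected, found)."""
    mismatches = {}
    for name, expected in required.items():
        found = lookup(name)
        if found != expected:
            mismatches[name] = (expected, found)
    return mismatches


if __name__ == "__main__":
    print("Python", sys.version.split()[0])
    for name, (expected, found) in find_mismatches(REQUIRED).items():
        print(f"{name}: expected {expected}, found {found}")
```

A newer version than the one listed is reported as a mismatch too, so treat the output as a hint rather than a hard failure.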

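Introduction

TinyMS is a high-level API designed to make it easier for beginners to get started with deep learning. It reduces the number of steps needed to build, train, evaluate, and run inference with a model. TinyMS also provides tutorials and documentation to help developers get started and build on it.

This tutorial consists of five steps: download the dataset, train the model, define servable.json, start the server, and run inference. The server is started in a subprocess.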

[1]:
import os
import argparse
import json
import tinyms as ts
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from tinyms import context, Tensor
from tinyms.serving import start_server, predict, list_servables, shutdown, server_started
from tinyms.data import GeneratorDataset, UnalignedDataset, GanImageFolderDataset, DistributedSampler
from tinyms.vision import cyclegan_transform
from tinyms.model.cycle_gan.cycle_gan import get_generator_discriminator, cycle_gan, TrainOneStepG, TrainOneStepD
from tinyms.losses import CycleGANDiscriminatorLoss, CycleGANGeneratorLoss
from tinyms.optimizers import Adam
from tinyms.data.utils import save_image, generate_image_list
from tinyms.utils.common_utils import GanReporter, gan_load_ckpt, GanImagePool
from tinyms.utils.train import cyclegan_lr
from tinyms.utils.eval import CityScapes, fast_hist, get_scores
[WARNING] ME(29932:140184175310656,MainProcess):2021-03-21-15:16:53.948.651 [mindspore/ops/operations/array_ops.py:2302] WARN_DEPRECATED: The usage of Pack is deprecated. Please use Stack.
WARNING: 'ControlDepend' is deprecated from version 1.1 and will be removed in a future version, use 'Depend' instead.

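1. Download the dataset

This tutorial uses the cityscapes dataset. Go to the official website to request access, then download and preprocess the dataset.

2. Train the model

Set the parameters, then start the training and evaluation process.

Note: training takes a very long time. We recommend skipping the training step and directly downloading the ckpt files provided with this tutorial for the inference steps that follow.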
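
The training code in this step takes two epoch counts: `max_epoch` (total epochs) and `n_epoch` (epochs trained at the initial learning rate). CycleGAN training commonly keeps the rate constant at first and then decays it linearly. The sketch below illustrates that schedule in plain Python. The function name, the base rate of 0.0002, and the exact decay shape are assumptions for illustration; `cyclegan_lr` in TinyMS may differ in its details.

```python
def linear_decay_schedule(max_epoch, n_epoch, steps_per_epoch, base_lr=0.0002):
    """Per-step learning rates: constant for n_epoch epochs, then linear decay.

    base_lr and the decay shape are assumptions for illustration.
    """
    n_epoch = min(max_epoch, n_epoch)
    decay_epochs = max_epoch - n_epoch
    # Constant phase: one entry per training step.
    rates = [base_lr] * (n_epoch * steps_per_epoch)
    for epoch in range(decay_epochs):
        # Shrinks from base_lr down to base_lr / decay_epochs over the decay phase.
        epoch_lr = base_lr * (decay_epochs - epoch) / decay_epochs
        rates.extend([epoch_lr] * steps_per_epoch)
    return rates
```

With the defaults used in this tutorial (`max_epoch=200`, `n_epoch=100`), such a schedule would hold the rate for the first half of training and decay it during the second half.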
[ ]:
def parse_args():
    parser = argparse.ArgumentParser(description='MindSpore Cycle GAN Example')
    parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
                        help='device where the code will be implemented (default: CPU)')
    parser.add_argument('--dataset_path', type=str, default="/root/dataset/cityscapes", help='cityscape dataset path.')
    parser.add_argument('--phase', type=str, default="train", help='train, eval or predict.')
    parser.add_argument('--model', type=str, default="resnet", choices=("resnet", "unet"),
                        help='generator model, should be in [resnet, unet].')
    parser.add_argument('--max_epoch', type=int, default=200, help='epoch size for training, default is 200.')
    parser.add_argument('--n_epoch', type=int, default=100,
                        help='number of epochs with the initial learning rate, default is 100')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size.')
    parser.add_argument("--save_checkpoint_epochs", type=int, default=10,
                        help="Save checkpoint epochs, default is 10.")
    parser.add_argument("--G_A_ckpt", type=str, default="/etc/tinyms/serving/cyclegan_cityscape/G_A.ckpt", help="pretrained checkpoint file path of G_A.")
    parser.add_argument("--G_B_ckpt", type=str, default="/etc/tinyms/serving/cyclegan_cityscape/G_B.ckpt", help="pretrained checkpoint file path of G_B.")
    parser.add_argument("--D_A_ckpt", type=str, default=None, help="pretrained checkpoint file path of D_A.")
    parser.add_argument("--D_B_ckpt", type=str, default=None, help="pretrained checkpoint file path of D_B.")
    parser.add_argument('--outputs_dir', type=str, default='/root/',
                        help='models are saved here, default is /root/.')
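    # argparse's type=bool treats any non-empty string (even "False") as True, so parse the text explicitly
    parser.add_argument('--save_imgs', type=lambda s: str(s).lower() in ('true', '1', 'yes'), default=True,
                        help='whether save imgs when epoch end, default is True.')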
    parser.add_argument("--cityscapes_dir", type=str, default="/root/dataset/cityscapes/testA", help="Path to the original cityscapes dataset")
    parser.add_argument("--result_dir", type=str, default="/root/dataset/cityscapes/testB", help="Path to the generated images to be evaluated")
    args_opt = parser.parse_args(args=[])
    return args_opt


args_opt = parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

dataset_path = args_opt.dataset_path
phase = args_opt.phase
G_A_ckpt = args_opt.G_A_ckpt
G_B_ckpt = args_opt.G_B_ckpt
repeat_size = 1


model = args_opt.model
batch_size = args_opt.batch_size
max_dataset_size = float("inf")
outputs_dir = args_opt.outputs_dir

max_epoch = args_opt.max_epoch
n_epoch = args_opt.n_epoch
n_epoch = min(max_epoch, n_epoch)


def create_dataset(dataset_path, batch_size=1, repeat_size=1, max_dataset_size=None,
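                   shuffle=True, num_parallel_workers=1, phase='train', data_dir='testA'):
    """Create the CycleGAN dataset for training or evaluation.
    Args:
        dataset_path: Root path of the dataset
        batch_size: The number of data records in each batch
        repeat_size: The number of times the dataset is repeated
        max_dataset_size: Maximum number of samples to load
        shuffle: Whether to shuffle the data
        num_parallel_workers: The number of parallel workers
        phase: 'train' loads unaligned A/B pairs; any other value loads a single image folder
        data_dir: Sub-folder of dataset_path to read when phase is not 'train'
    """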
    # define dataset and apply the transform func
    if phase == 'train':
        ds = UnalignedDataset(dataset_path, phase, max_dataset_size=max_dataset_size, shuffle=True)

        device_num = 1
        distributed_sampler = DistributedSampler(len(ds), num_replicas=device_num, rank=0, shuffle=shuffle)
        gan_generator_ds = GeneratorDataset(ds, column_names=["image_A", "image_B"], sampler=distributed_sampler,
                                            num_parallel_workers=num_parallel_workers)
    else:
        datadir = os.path.join(dataset_path, data_dir)
        ds = GanImageFolderDataset(datadir, max_dataset_size=max_dataset_size)
        gan_generator_ds = GeneratorDataset(ds, column_names=["image", "image_name"],
                                            num_parallel_workers=num_parallel_workers)

    gan_generator_ds = cyclegan_transform.apply_ds(gan_generator_ds,
                                                   repeat_size=repeat_size,
                                                   batch_size=batch_size,
                                                   num_parallel_workers=num_parallel_workers,
                                                   shuffle=shuffle,
                                                   phase=phase)
    dataset_size = len(ds)
    return gan_generator_ds, dataset_size


# create dataset
dataset, args_opt.dataset_size = create_dataset(dataset_path, batch_size=batch_size, repeat_size=1,
                                                max_dataset_size=max_dataset_size, shuffle=True,
                                                num_parallel_workers=1,
                                                phase="train",
                                                data_dir=None)


G_A, G_B, D_A, D_B = get_generator_discriminator(model)
gan_load_ckpt(args_opt.G_A_ckpt, args_opt.G_B_ckpt, args_opt.D_A_ckpt, args_opt.D_B_ckpt,
              G_A, G_B, D_A, D_B)
generator_net = cycle_gan(G_A, G_B)

# define loss function and optimizer
loss_D = CycleGANDiscriminatorLoss(D_A, D_B)
loss_G = CycleGANGeneratorLoss(generator_net, D_A, D_B)
lr = cyclegan_lr(max_epoch, n_epoch, args_opt.dataset_size)

optimizer_G = Adam(generator_net.trainable_params(),
                   cyclegan_lr(max_epoch, n_epoch, args_opt.dataset_size), beta1=0.5)
optimizer_D = Adam(loss_D.trainable_params(),
                   cyclegan_lr(max_epoch, n_epoch, args_opt.dataset_size), beta1=0.5)

# build two nets: the generator net and the discriminator net
net_G = TrainOneStepG(loss_G, generator_net, optimizer_G)
net_D = TrainOneStepD(loss_D, optimizer_D)

# train process
image_pool_A = GanImagePool(pool_size=50)
image_pool_B = GanImagePool(pool_size=50)


def train_process(args_opt, data_loader, net_G, net_D, image_pool_A, image_pool_B):
    reporter = GanReporter(args_opt)
    reporter.info('==========start training===============')
    for _ in range(max_epoch):
        reporter.epoch_start()
        for data in data_loader:
            img_A = data["image_A"]
            img_B = data["image_B"]
            # generator step: returns the generated images first, followed by the loss terms
            res_G = net_G(img_A, img_B)
            fake_A = res_G[0]
            fake_B = res_G[1]
            # discriminator step: fake images are drawn through the image pools
            res_D = net_D(img_A, img_B, image_pool_A.query(fake_A), image_pool_B.query(fake_B))
            reporter.step_end(res_G, res_D)
            reporter.visualizer(img_A, img_B, fake_A, fake_B)
        reporter.epoch_end(net_G)

    reporter.info('==========end training===============')


data_loader = dataset.create_dict_iterator()
train_process(args_opt, data_loader, net_G, net_D, image_pool_A, image_pool_B)

# eval
# original image dir
cityscapes_dir = args_opt.cityscapes_dir

# fake image dir generated after predict
result_dir = args_opt.result_dir


def eval_process(args_opt, cityscapes_dir, result_dir):
    CS = CityScapes()
    hist_perframe = ts.zeros((CS.class_num, CS.class_num)).asnumpy()
    cityscapes = generate_image_list(cityscapes_dir)
    args_opt.dataset_size = len(cityscapes)
    reporter = GanReporter(args_opt)
    reporter.start_eval()
    for i, img_path in enumerate(cityscapes):
        if i % 100 == 0:
            reporter.info('Evaluating: %d/%d' % (i, len(cityscapes)))
        img_name = os.path.split(img_path)[1]
        ids1 = CS.get_id(os.path.join(cityscapes_dir, img_name))
        ids2 = CS.get_id(os.path.join(result_dir, img_name))
        hist_perframe += fast_hist(ids1.flatten(), ids2.flatten(), CS.class_num)

    mean_pixel_acc, mean_class_acc, mean_class_iou, per_class_acc, per_class_iou = get_scores(hist_perframe)
    reporter.info("mean_pixel_acc:{}, mean_class_acc:{}, mean_class_iou: {}".format(mean_pixel_acc,
                                                                                    mean_class_acc,
                                                                                    mean_class_iou))
    reporter.info("************ Per class numbers below ************")
    for i, cl in enumerate(CS.classes):
        while len(cl) < 15:
            cl = cl + ' '
        reporter.info("{}: acc = {}, iou = {}".format(cl, per_class_acc[i], per_class_iou[i]))
    reporter.end_eval()


# Compare the similarity between the original image and the fake image
eval_process(args_opt, cityscapes_dir, result_dir)
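
The evaluation above accumulates a confusion matrix with `fast_hist` and converts it into scores with `get_scores`. The sketch below shows, in plain Python, how pixel accuracy, per-class accuracy, and per-class IoU are conventionally derived from a confusion matrix. The function names here are hypothetical, and the TinyMS helpers may treat ignored labels and empty classes differently.

```python
def confusion_matrix(true_ids, pred_ids, class_num):
    """Count (true, predicted) label pairs; pairs with an out-of-range label are skipped."""
    hist = [[0] * class_num for _ in range(class_num)]
    for t, p in zip(true_ids, pred_ids):
        if 0 <= t < class_num and 0 <= p < class_num:
            hist[t][p] += 1
    return hist


def scores_from_hist(hist):
    """Return (pixel accuracy, per-class accuracy, per-class IoU)."""
    n = len(hist)
    total = sum(sum(row) for row in hist)
    correct = sum(hist[i][i] for i in range(n))
    pixel_acc = correct / total if total else 0.0
    per_class_acc, per_class_iou = [], []
    for i in range(n):
        row_sum = sum(hist[i])                       # pixels whose true class is i
        col_sum = sum(hist[r][i] for r in range(n))  # pixels predicted as class i
        union = row_sum + col_sum - hist[i][i]
        # Classes that never occur get 0.0 here; that choice is an assumption.
        per_class_acc.append(hist[i][i] / row_sum if row_sum else 0.0)
        per_class_iou.append(hist[i][i] / union if union else 0.0)
    return pixel_acc, per_class_acc, per_class_iou
```

Mean class accuracy and mean IoU are then the averages of the two per-class lists.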
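Note: if you skipped the training step, download the pretrained ckpt files and continue with the inference steps.

Download G_A.ckpt and G_B.ckpt, and save both ckpt files to /etc/tinyms/serving/cyclegan_cityscape.

Alternatively, run the following code to download both ckpt files: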

[2]:
ckpt_folder = '/etc/tinyms/serving/cyclegan_cityscape'
G_A_ckpt_path = '/etc/tinyms/serving/cyclegan_cityscape/G_A.ckpt'
G_B_ckpt_path = '/etc/tinyms/serving/cyclegan_cityscape/G_B.ckpt'

if not os.path.exists(ckpt_folder):
    !mkdir -p  /etc/tinyms/serving/cyclegan_cityscape
    !wget -P /etc/tinyms/serving/cyclegan_cityscape https://ascend-tutorials.obs.cn-north-4.myhuaweicloud.com/ckpt_files/cityscapes/G_A.ckpt
    !wget -P /etc/tinyms/serving/cyclegan_cityscape https://ascend-tutorials.obs.cn-north-4.myhuaweicloud.com/ckpt_files/cityscapes/G_B.ckpt
else:
    print('ckpt folder already exists')
    if not os.path.exists(G_A_ckpt_path):
        !wget -P /etc/tinyms/serving/cyclegan_cityscape https://ascend-tutorials.obs.cn-north-4.myhuaweicloud.com/ckpt_files/cityscapes/G_A.ckpt
    if not os.path.exists(G_B_ckpt_path):
        !wget -P /etc/tinyms/serving/cyclegan_cityscape https://ascend-tutorials.obs.cn-north-4.myhuaweicloud.com/ckpt_files/cityscapes/G_B.ckpt
ckpt folder already exists
--2021-03-21 15:16:58--  https://ascend-tutorials.obs.cn-north-4.myhuaweicloud.com/ckpt_files/cityscapes/G_A.ckpt
Resolving ascend-tutorials.obs.cn-north-4.myhuaweicloud.com (ascend-tutorials.obs.cn-north-4.myhuaweicloud.com)... 121.36.121.44, 49.4.112.5, 49.4.112.90, ...
Connecting to ascend-tutorials.obs.cn-north-4.myhuaweicloud.com (ascend-tutorials.obs.cn-north-4.myhuaweicloud.com)|121.36.121.44|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7785480 (7.4M) [binary/octet-stream]
Saving to: ‘/etc/tinyms/serving/cyclegan_cityscape/G_A.ckpt’

G_A.ckpt            100%[===================>]   7.42M  9.60MB/s    in 0.8s

2021-03-21 15:16:59 (9.60 MB/s) - ‘/etc/tinyms/serving/cyclegan_cityscape/G_A.ckpt’ saved [7785480/7785480]

--2021-03-21 15:16:59--  https://ascend-tutorials.obs.cn-north-4.myhuaweicloud.com/ckpt_files/cityscapes/G_B.ckpt
Resolving ascend-tutorials.obs.cn-north-4.myhuaweicloud.com (ascend-tutorials.obs.cn-north-4.myhuaweicloud.com)... 121.36.121.44, 49.4.112.5, 49.4.112.90, ...
Connecting to ascend-tutorials.obs.cn-north-4.myhuaweicloud.com (ascend-tutorials.obs.cn-north-4.myhuaweicloud.com)|121.36.121.44|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7785480 (7.4M) [binary/octet-stream]
Saving to: ‘/etc/tinyms/serving/cyclegan_cityscape/G_B.ckpt’

G_B.ckpt            100%[===================>]   7.42M  10.6MB/s    in 0.7s

2021-03-21 15:17:01 (10.6 MB/s) - ‘/etc/tinyms/serving/cyclegan_cityscape/G_B.ckpt’ saved [7785480/7785480]

3. Define servable.json

Run the following code to define the servable JSON file:

[3]:
servable_json = [{'name': 'cyclegan_cityscape',
                  'description': 'This servable hosts a Cycle GAN model predicting for cityscape dataset',
                  'model': {
                      "name": "cycle_gan",
                      "format": "ckpt",
                      "g_model": "resnet"}}]
os.chdir("/etc/tinyms/serving")
json_data = json.dumps(servable_json, indent=4)

with open('servable.json', 'w') as json_file:
    json_file.write(json_data)
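
To confirm that the file was written as intended, it can be read back and checked. The sketch below writes to a temporary directory instead of /etc/tinyms/serving so that it runs without special permissions. The helper names are hypothetical, and the list of required keys is an assumption based only on the fields used in this tutorial.

```python
import json
import os
import tempfile

# Assumption: these are the keys used in this tutorial, not a TinyMS schema.
REQUIRED_KEYS = ("name", "description", "model")


def write_servables(entries, folder):
    """Write the servable list to <folder>/servable.json and return the path."""
    path = os.path.join(folder, "servable.json")
    with open(path, "w") as json_file:
        json.dump(entries, json_file, indent=4)
    return path


def load_servables(path):
    """Read servable.json back and check that each entry has the expected keys."""
    with open(path) as json_file:
        entries = json.load(json_file)
    for entry in entries:
        missing = [key for key in REQUIRED_KEYS if key not in entry]
        if missing:
            raise ValueError(f"servable entry is missing keys: {missing}")
    return entries


servable_json = [{"name": "cyclegan_cityscape",
                  "description": "This servable hosts a Cycle GAN model predicting for cityscape dataset",
                  "model": {"name": "cycle_gan", "format": "ckpt", "g_model": "resnet"}}]

# Round trip through a temporary folder: write, read back, validate.
with tempfile.TemporaryDirectory() as folder:
    loaded = load_servables(write_servables(servable_json, folder))
```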

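4. Start the server

4.1 Introduction

TinyMS inference uses a client/server (C/S) architecture. TinyMS uses Flask, a lightweight web server framework, as the foundation for client/server communication. To run inference on a model, the user must first start the server. Once started successfully, the server runs in a subprocess, listens for POST requests sent to address 127.0.0.1 on port 5000, and uses MindSpore as the backend to handle them. The backend builds the model, runs inference, and returns the result to the client.

4.2 Start the server

Run the following code to start the server: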

[4]:
start_server()
Server starts at host 127.0.0.1, port 5000
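
If a later request fails, one quick check is whether anything is accepting TCP connections at 127.0.0.1:5000. The sketch below uses only the standard library. `is_listening` is a hypothetical helper, not part of TinyMS, and a successful connection only shows that some process is listening there, not that it is the TinyMS server.

```python
import socket


def is_listening(host="127.0.0.1", port=5000, timeout=1.0):
    """Return True if a TCP connection to host:port succeeds within the timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts, and unreachable hosts.
        return False
```

The `server_started` function imported at the top of this tutorial remains the TinyMS way to check the server state; this helper only tests the network side.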

5. Inference

5.1 List servables

Use the list_servables function to check which models the backend is currently serving.

[5]:
list_servables()
[5]:
[{'description': 'This servable hosts a Cycle GAN model predicting for cityscape dataset',
  'model': {'format': 'ckpt', 'g_model': 'resnet', 'name': 'cycle_gan'},
  'name': 'cyclegan_cityscape'}]

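If the description field in the output shows that this is a CycleGAN model, proceed to the next step and send an inference request.

5.2 Send an inference request

Run the predict function to send inference requests. Both strategies, gray to color and color to gray, are executed. We recommend using images from the cityscapes/test dataset directly, which makes the style comparison easier. If you upload your own image, resize it to 256*256 first.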
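
For images that do not come from the dataset, the resize to 256*256 can be done with Pillow, which is already in the requirements list. The following is a minimal sketch: the function name is hypothetical, and converting to RGB and using bicubic interpolation are assumptions rather than documented TinyMS requirements.

```python
from PIL import Image


def resize_for_cyclegan(src_path, dst_path, size=(256, 256)):
    """Resize an image to 256x256, save it, and return the new size."""
    with Image.open(src_path) as img:
        # Assumption: a 3-channel RGB input; bicubic is one reasonable filter choice.
        resized = img.convert("RGB").resize(size, Image.BICUBIC)
    resized.save(dst_path)
    return resized.size
```

The saved file path can then be passed to predict in place of the dataset image path. Note that a plain resize does not preserve the aspect ratio of non-square images.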

[6]:
servable_name = 'cyclegan_cityscape'
dataset_name = 'cityscape'

if server_started() is True:
    # gray to color
    testA_path = '/root/dataset/cityscapes/testA/1.jpg'
    strategy = 'gray2color'
    fakeB_data = predict(testA_path, servable_name, dataset_name, strategy)

    # color to gray
    testB_path = '/root/dataset/cityscapes/testB/1.jpg'
    strategy = 'color2gray'
    fakeA_data = predict(testB_path, servable_name, dataset_name, strategy)

    # draw the plot
    plt.figure(dpi=160, figsize=(10, 10))

    plt.subplot(221)
    plt.imshow(Image.open(testA_path))
    plt.axis('off')
    plt.title(testA_path)

    plt.subplot(222)
    plt.imshow(Image.fromarray(np.uint8(fakeB_data)))
    plt.axis('off')
    plt.title("fakeB.jpg")

    plt.subplot(223)
    plt.imshow(Image.open(testB_path))
    plt.axis('off')
    plt.title(testB_path)

    plt.subplot(224)
    plt.imshow(Image.fromarray(np.uint8(fakeA_data)))
    plt.axis('off')
    plt.title("fakeA.jpg")

    plt.show()
else:
    print("Server not started")
[Figure: 2x2 grid with the original testA/1.jpg and testB/1.jpg in the left column and the generated fakeB.jpg and fakeA.jpg in the right column]

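Check the output

If four images are displayed, the left column shows the images from the original dataset and the right column shows the generated images.

Shut down the server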

[7]:
shutdown()
[7]:
'Server shutting down...'