# Source code for tinyms.model.cycle_gan.cycle_gan

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Cycle GAN network."""

import tinyms as ts
from tinyms import layers
from tinyms.primitives import OnesLike, GradOperation, Fill, DType, Shape, depend
from .resnet import ResNetGenerator
from .unet import UnetGenerator
from .common_net import ConvNormReLU, init_weights


def get_generator(model):
    """
    Build a generator network of the requested architecture and initialize its weights.

    Args:
        model (str): Architecture name. Currently it should be in [resnet, unet].

    Returns:
        Generator, a ResNetGenerator or UnetGenerator on which `init_weights`
        has been applied with init_type='normal' and init_gain=0.02.

    Raises:
        NotImplementedError: If `model` is not in [resnet, unet].
    """
    # Hyper-parameters shared by both supported architectures.
    shared_cfg = dict(ngf=64, n_layers=9, alpha=0.2, norm_mode='instance', dropout=True)

    if model == "resnet":
        generator = ResNetGenerator(in_planes=3, pad_mode='CONSTANT', **shared_cfg)
    elif model == "unet":
        generator = UnetGenerator(in_planes=3, out_planes=3, **shared_cfg)
    else:
        raise NotImplementedError(f'Model {model} not recognized.')

    # Both architectures use the same initialization scheme.
    init_weights(generator, init_type='normal', init_gain=0.02)
    return generator


def get_discriminator():
    """
    Build a discriminator network and initialize its weights.

    Returns:
        Discriminator, a 3-input-channel discriminator (ndf=64, n_layers=3,
        instance norm) on which `init_weights` has been applied with
        init_type='normal' and init_gain=0.02.
    """
    disc_cfg = dict(in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='instance')
    discriminator = Discriminator(**disc_cfg)
    init_weights(discriminator, init_type='normal', init_gain=0.02)
    return discriminator


class Discriminator(layers.Layer):
    """
    Discriminator of GAN.

    The network is a stack of 4x4 convolutions: a stride-2 Conv2d + LeakyReLU
    stem, `n_layers - 1` stride-2 ConvNormReLU blocks, one stride-1
    ConvNormReLU block and a final stride-1 Conv2d with a single output channel.

    Args:
        in_planes (int): Input channel. Default: 3.
        ndf (int): Output channel of the first convolution; later blocks use
            `min(2 ** i, 8) * ndf` channels. Default: 64.
        n_layers (int): The number of ConvNormReLU blocks. Default: 3.
        alpha (float): LeakyRelu slope. Default: 0.2.
        norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
            Default: "batch".

    Returns:
        Tensor, output tensor.

    Examples:
        >>> Discriminator(3, 64, 3)
    """
    def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):
        super().__init__()
        ksize = 4

        # Stem: strided convolution followed by LeakyReLU, no normalization.
        blocks = [
            layers.Conv2d(in_planes, ndf, ksize, 2, pad_mode='pad', padding=1),
            layers.LeakyReLU(alpha),
        ]

        # Stride-2 blocks; the channel multiplier doubles per block and is capped at 8.
        channels = ndf
        for depth in range(1, n_layers):
            prev_channels, channels = channels, min(2 ** depth, 8) * ndf
            blocks.append(ConvNormReLU(prev_channels, channels, ksize, 2, alpha, norm_mode, padding=1))

        # Last ConvNormReLU block uses stride 1.
        prev_channels, channels = channels, min(2 ** n_layers, 8) * ndf
        blocks.append(ConvNormReLU(prev_channels, channels, ksize, 1, alpha, norm_mode, padding=1))

        # Final projection to a single output channel.
        blocks.append(layers.Conv2d(channels, 1, ksize, 1, pad_mode='pad', padding=1))
        self.features = layers.SequentialLayer(blocks)

    def construct(self, x):
        """Run `x` through the sequential feature stack."""
        return self.features(x)


class Generator(layers.Layer):
    """
    Generator of CycleGAN, return fake_A, fake_B, rec_A, rec_B, identity_A and identity_B.

    Args:
        G_A (layers.Layer): The generator network of domain A to domain B.
        G_B (layers.Layer): The generator network of domain B to domain A.
        use_identity (bool): Use identity loss or not. Default: True.

    Returns:
        Tensors, fake_A, fake_B, rec_A, rec_B, identity_A and identity_B.

    Examples:
        >>> Generator(G_A, G_B)
    """

    def __init__(self, G_A, G_B, use_identity=True):
        super().__init__()
        self.G_A = G_A
        self.G_B = G_B
        self.use_identity = use_identity
        self.ones = OnesLike()

    def construct(self, img_A, img_B):
        """
        Translate both inputs, reconstruct them, and compute the identity outputs.

        If use_identity is False, the identity outputs are all-ones tensors
        produced by OnesLike on the inputs instead of generator outputs.
        """
        # Cross-domain translations.
        gen_A = self.G_B(img_B)
        gen_B = self.G_A(img_A)
        # Cycle reconstructions: translate the fakes back.
        cyc_A = self.G_B(gen_B)
        cyc_B = self.G_A(gen_A)
        if not self.use_identity:
            idt_A = self.ones(img_A)
            idt_B = self.ones(img_B)
        else:
            # Each generator is fed an image of the domain it outputs.
            idt_A = self.G_B(img_A)
            idt_B = self.G_A(img_B)
        return gen_A, gen_B, cyc_A, cyc_B, idt_A, idt_B


class WithLossCell(layers.Layer):
    """
    Wrap the network with loss function to return generator loss.

    The wrapped network must return exactly nine outputs; only the third
    one (the generator loss) is forwarded.

    Args:
        network (Layer): The target network to wrap.

    Returns:
        Generator loss, the third output of `network`.
    """
    def __init__(self, network):
        super().__init__(auto_prefix=False)
        self.network = network

    def construct(self, img_A, img_B):
        """Run the wrapped network and keep only its third output."""
        outputs = self.network(img_A, img_B)
        _, _, loss_g, _, _, _, _, _, _ = outputs
        return loss_g


class TrainOneStepG(layers.Layer):
    """
    Encapsulation class of Cycle GAN generator network training.

    Append an optimizer to the training network after that the construct
    function can be called to create the backward graph.

    Args:
        G (layers.Layer): Generator with loss Layer. Note that loss function should have been added.
            It must expose the discriminators as attributes `D_A` and `D_B`.
        generator (layers.Layer): Generator of CycleGAN; its trainable parameters are the ones updated.
        optimizer (Optimizer): Optimizer for updating the weights.
        sens (Number): The adjust parameter. Default: 1.0.

    Returns:
        Tuple of the nine outputs of `G`, where the third one (generator loss)
        is tied to the optimizer update through `depend`.
    """
    def __init__(self, G, generator, optimizer, sens=1.0):
        super().__init__(auto_prefix=False)
        self.optimizer = optimizer
        self.G = G
        self.G.set_grad()
        self.G.set_train()
        # The discriminators are switched out of grad/train mode for the generator step.
        for discriminator in (self.G.D_A, self.G.D_B):
            discriminator.set_grad(False)
            discriminator.set_train(False)
        self.grad = GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.weights = ts.ParameterTuple(generator.trainable_params())
        self.net = WithLossCell(G)

    def construct(self, img_A, img_B):
        """Run one generator training step on a pair of images."""
        fake_A, fake_B, loss_g, loss_ga, loss_gb, loss_ca, loss_cb, loss_ia, loss_ib = self.G(img_A, img_B)
        # Sensitivity tensor built from the loss dtype and shape and the value `self.sens`.
        sens = Fill()(DType()(loss_g), Shape()(loss_g), self.sens)
        grads_g = self.grad(self.net, self.weights)(img_A, img_B, sens)
        # Attach the optimizer update to the returned loss.
        loss_g = depend(loss_g, self.optimizer(grads_g))
        return fake_A, fake_B, loss_g, loss_ga, loss_gb, loss_ca, loss_cb, loss_ia, loss_ib


class TrainOneStepD(layers.Layer):
    """
    Encapsulation class of Cycle GAN discriminator network training.

    Append an optimizer to the training network after that the construct
    function can be called to create the backward graph.

    Args:
        D (layers.Layer): Discriminator with loss Layer. Note that loss function should have been added.
            It is called as `D(img_A, img_B, fake_A, fake_B)` and its output is used as the loss.
        optimizer (Optimizer): Optimizer for updating the weights.
        sens (Number): The adjust parameter. Default: 1.0.

    Returns:
        Discriminator loss returned by `D`, tied to the optimizer update through `depend`.
    """
    def __init__(self, D, optimizer, sens=1.0):
        super(TrainOneStepD, self).__init__(auto_prefix=False)
        self.optimizer = optimizer
        self.D = D
        self.D.set_grad()
        self.D.set_train()
        self.grad = GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        # Only the trainable parameters of D are updated by this step.
        self.weights = ts.ParameterTuple(D.trainable_params())

    def construct(self, img_A, img_B, fake_A, fake_B):
        """Run one discriminator training step and return the discriminator loss."""
        weights = self.weights
        ld = self.D(img_A, img_B, fake_A, fake_B)
        # Sensitivity tensor built from the loss dtype and shape and the value `self.sens`.
        sens_d = Fill()(DType()(ld), Shape()(ld), self.sens)
        grads_d = self.grad(self.D, weights)(img_A, img_B, fake_A, fake_B, sens_d)
        # Attach the optimizer update to the returned loss.
        return depend(ld, self.optimizer(grads_d))


def get_generator_discriminator(model='resnet'):
    """
    Get G_A, G_B generator network and D_A, D_B discriminator network.

    Args:
        model (str): The generator model, currently it should be in [resnet, unet].
            Default: 'resnet'.

    Returns:
        G_A, G_B, D_A, D_B network. Each network is a separately built instance.

    Raises:
        NotImplementedError: If `model` is not in [resnet, unet].

    Examples:
        >>> G_A, G_B, D_A, D_B = get_generator_discriminator('resnet')
    """
    # Validate up front so that no network is built for an unsupported model.
    if model not in ('resnet', 'unet'):
        raise NotImplementedError(f'Model {model} not recognized.')

    G_A = get_generator(model)
    G_B = get_generator(model)
    D_A = get_discriminator()
    D_B = get_discriminator()

    return G_A, G_B, D_A, D_B


def cycle_gan(G_A, G_B):
    """
    Get Cycle GAN network.

    Args:
        G_A (layers.Layer): The generator net of domain A to domain B,
            currently it should be in [resnet, unet].
        G_B (layers.Layer): The generator net of domain B to domain A,
            currently it should be in [resnet, unet].

    Returns:
        Generator, Cycle GAN instance wrapping `G_A` and `G_B`.

    Raises:
        NotImplementedError: If `G_A` or `G_B` is not an instance of layers.Layer.

    Examples:
        >>> gan_net = cycle_gan(G_A, G_B)
    """
    if not (isinstance(G_A, layers.Layer) and isinstance(G_B, layers.Layer)):
        raise NotImplementedError('G_A and G_B are not the instance of layers.Layer')
    return Generator(G_A, G_B)


def cycle_gan_infer(g_model='resnet'):
    """
    Get Cycle GAN generator networks for predict.

    Args:
        g_model (str): The generator model, currently it should be in [resnet, unet].
            Default: 'resnet'.

    Returns:
        G_A, G_B generator network. Each one is a separately built instance.

    Raises:
        NotImplementedError: If `g_model` is not in [resnet, unet].

    Examples:
        >>> G_A, G_B = cycle_gan_infer('resnet')
    """
    # Validate up front so that no network is built for an unsupported model.
    if g_model not in ('resnet', 'unet'):
        raise NotImplementedError(f'Model {g_model} not recognized.')

    G_A = get_generator(g_model)
    G_B = get_generator(g_model)
    return G_A, G_B