Source code for tinyms.text.transforms

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import tinyms as ts

from . import _transform_ops
from ._transform_ops import *
from ..data import BertDataset


__all__ = [
    'bert_transform', 'BertDatasetTransform',
]
__all__.extend(_transform_ops.__all__)


[docs]class BertDatasetTransform(object): r''' Apply preprocess operation on GeneratorDataset instance. ''' def __init__(self): pass def apply_ds(self, data_set, batch_size): assert isinstance(data_set, BertDataset), "For BertDatasetTransform, BertDataset is needed" type_cast_op = TypeCast(ts.int32) data_set = data_set.map(operations=type_cast_op, input_columns="masked_lm_ids") data_set = data_set.map(operations=type_cast_op, input_columns="masked_lm_positions") data_set = data_set.map(operations=type_cast_op, input_columns="next_sentence_labels") data_set = data_set.map(operations=type_cast_op, input_columns="segment_ids") data_set = data_set.map(operations=type_cast_op, input_columns="input_mask") data_set = data_set.map(operations=type_cast_op, input_columns="input_ids") # apply batch operations data_set = data_set.batch(batch_size, drop_remainder=True) return data_set
bert_transform = BertDatasetTransform()