python - Issue with tf.data.Dataset for a Keras multi-input model

Tags: python python-3.x tensorflow keras tensorflow2.0

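I built a multi-input model with the Keras functional API. The idea is to classify a text together with its metadata. The model works with NumPy-format inputs, but fails when I use a tf.data.Dataset: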

UnimplementedError:  Cast string to int32 is not supported
     [[node functional_5/Cast (defined at <ipython-input-3-8e2b230c1da3>:17) ]] [Op:__inference_train_function_24120]

Function call stack:
train_function
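I'm not sure how to interpret this, since the two kinds of input should be equivalent. Thanks in advance for any guidance. I've attached a dummy equivalent of my project below.

Model: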
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import Input, Model, layers
from transformers import DistilBertTokenizer, TFDistilBertModel


MAX_LEN = 20

STRING_CATEGORICAL_COLUMNS = [
    "Organization",
    "Sector",
    "Content_type",
    "Geography",
    "Themes",
]

VOCAB = {
    "Organization": ["BNS", "FED", "ECB"],
    "Sector": ["BANK", "ASS", "MARKET"],
    "Content_type": ["LAW", "NOTES", "PAPER"],
    "Geography": ["UK", "FR", "DE", "CH", "US", "ES", "NA"],
    "Themes": ["A", "B", "C", "D", "E", "F", "G"],
}

DIM = {
    "Organization": 7,
    "Sector": 2,
    "Content_type": 3,
    "Geography": 4,
    "Themes": 5,
}


# BERT branch
tf_model = TFDistilBertModel.from_pretrained("distilbert-base-uncased", name="tfbert")

input_ids = Input(shape=(MAX_LEN,), dtype=tf.int32, name="input_ids")
attention_mask = Input(shape=(MAX_LEN,), dtype=tf.int32, name="attention_mask")


embedding = tf_model(input_ids, attention_mask=attention_mask)[0][:, 0]

bert_input = {"input_ids": input_ids, "attention_mask": attention_mask}
model_bert = Model(inputs=[bert_input], outputs=[embedding])


# meta branch
meta_inputs = {}
meta_prepocs = []

for key in VOCAB:
    inputs = Input(shape=(None,), dtype=tf.string, name=key)
    meta_inputs[key] = inputs

    vocab_list = VOCAB[key]
    vocab_size = len(vocab_list)
    embed_dim = DIM[key]

    x = layers.experimental.preprocessing.StringLookup(
        vocabulary=vocab_list, num_oov_indices=1, mask_token="PAD", name="lookup_" + key
    )(inputs)

    x = layers.Embedding(
        input_dim=vocab_size + 2,  # 2 = PAD + NA
        output_dim=embed_dim,
        mask_zero=True,
        name="embedding_" + key,
    )(x)

    x = layers.GlobalAveragePooling1D(
        data_format="channels_last", name="poolembedding_" + key
    )(x)

    meta_prepocs.append(x)

meta_output = layers.concatenate(meta_prepocs, name="concatenate_meta")
model_meta = Model(meta_inputs, meta_output)


# combining branches
combined = layers.concatenate(
    [model_bert.output, model_meta.output], name="concatenate_all"
)
output = layers.Dense(128, activation="relu", name="dense")(combined)
output = layers.Dense(4, name="class_output")(output)
model = Model(inputs=[model_bert.input, model_meta.input], outputs=output)

model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

Dataset:
A dummy dataset containing 5 texts and their corresponding metadata.
# input meta
dict_meta = {
    "Organization": [
        ["BNS", "NA"],
        ["ECB", "PAD"],
        ["NA", "PAD"],
        ["NA", "PAD"],
        ["NA", "PAD"],
    ],
    "Sector": [
        ["BANK", "PAD", "PAD"],
        ["ASS", "PAD", "NA"],
        ["MARKET", "NA", "NA"],
        ["NA", "PAD", "NA"],
        ["NA", "PAD", "NA"],
    ],
    "Content_type": [
        ["NOTES", "PAD"],
        ["PAPER", "UNK"],
        ["LAW", "PAD"],
        ["LAW", "PAD"],
        ["LAW", "NOTES"],
    ],
    "Geography": [
        ["UK", "FR"],
        ["DE", "CH"],
        ["US", "ES"],
        ["ES", "PAD"],
        ["NA", "PAD"],
    ],
    "Themes": [["A", "B"], ["B", "C"], ["C", "PAD"], ["C", "PAD"], ["G", "PAD"]],
}

# input text
list_text = [
    "Trump in denial over election defeat as Biden gears up to fight Covid",
    "Feds seize $1 billion in bitcoins they say were stolen from Silk Road",
    "Kevin de Bruyne misses penalty as Manchester City and Liverpool draw",
    "United States nears 10 million coronavirus cases",
    "Fiji resort offers the ultimate in social distancing",
]

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
params = {
    "max_length": MAX_LEN,
    "padding": "max_length",
    "truncation": True,
}
tokenized = tokenizer(list_text, **params)
dict_text = tokenized.data

# input label
label = [[1], [0], [1], [0], [1]]

Training with NumPy-format inputs:
ds_meta = tf.data.Dataset.from_tensor_slices((dict_meta))
ds_meta = ds_meta.batch(5)
example_meta = next(iter(ds_meta))

ds_text = tf.data.Dataset.from_tensor_slices((dict_text))
ds_text = ds_text.batch(5)
example_text = next(iter(ds_text))

ds_label = tf.data.Dataset.from_tensor_slices((label))
ds_label = ds_label.batch(5)
example_label = next(iter(ds_label))

model.fit([example_text, example_meta], example_label)
1/1 [==============================] - 0s 1ms/step - loss: 2.4866

Training with tf.data.Dataset:
ds = tf.data.Dataset.from_tensor_slices(
    (
        {
            "attention_mask": dict_text["attention_mask"],
            "input_ids": dict_text["input_ids"],
            "Content_type": dict_meta["Content_type"],
            "Geography": dict_meta["Geography"],
            "Organization": dict_meta["Organization"],
            "Sector": dict_meta["Sector"],
            "Themes": dict_meta["Themes"],
        },
        {"class_output": label},
    )
)


ds = ds.batch(5)
model.fit(ds, epochs=1)
2020-11-10 14:52:47.502445: W tensorflow/core/framework/op_kernel.cc:1744] OP_REQUIRES failed at cast_op.cc:124 : Unimplemented: Cast string to int32 is not supported
Traceback (most recent call last):

  File "<ipython-input-10-a894466398cd>", line 1, in <module>
    model.fit(ds, epochs=1)

  File "/opt/miniconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)

  File "/opt/miniconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1098, in fit
    tmp_logs = train_function(iterator)

  File "/opt/miniconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)

  File "/opt/miniconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 807, in _call
    return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable

  File "/opt/miniconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2829, in __call__
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access

  File "/opt/miniconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1848, in _filtered_call
    cancellation_manager=cancellation_manager)

  File "/opt/miniconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1924, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))

  File "/opt/miniconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 550, in call
    ctx=ctx)

  File "/opt/miniconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)

UnimplementedError:  Cast string to int32 is not supported
     [[node functional_5/Cast (defined at <ipython-input-3-8e2b230c1da3>:17) ]] [Op:__inference_train_function_24120]

Function call stack:
train_function

Best answer

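You can combine the datasets with the zip function (tf.data.Dataset.zip). It accepts a nested structure of datasets as its argument, so we only need to reproduce the way you passed the data to fit in the NumPy case: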

ds_meta = tf.data.Dataset.from_tensor_slices((dict_meta))
ds_text = tf.data.Dataset.from_tensor_slices((dict_text))
ds_label = tf.data.Dataset.from_tensor_slices((label))
combined_dataset = tf.data.Dataset.zip(((ds_text,ds_meta),ds_label))
combined_dataset = combined_dataset.batch(5)
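
Before training, it can help to confirm that the zipped dataset really has the ((text, meta), label) nesting that mirrors inputs=[model_bert.input, model_meta.input]. The following is only a minimal sketch: the toy_* values are made-up stand-ins for dict_text, dict_meta and label, and no model is involved.

```python
import tensorflow as tf

# Hypothetical stand-ins for the question's dict_text, dict_meta and label.
toy_text = {
    "input_ids": [[101, 102], [101, 102]],
    "attention_mask": [[1, 1], [1, 0]],
}
toy_meta = {"Organization": [["BNS", "NA"], ["ECB", "PAD"]]}
toy_label = [[1], [0]]

ds_text = tf.data.Dataset.from_tensor_slices(toy_text)
ds_meta = tf.data.Dataset.from_tensor_slices(toy_meta)
ds_label = tf.data.Dataset.from_tensor_slices(toy_label)

# Same nesting as in the answer: a pair of feature dicts, then the labels.
combined = tf.data.Dataset.zip(((ds_text, ds_meta), ds_label)).batch(2)

# element_spec describes the structure and dtypes without running the pipeline.
features_spec, label_spec = combined.element_spec
text_spec, meta_spec = features_spec

print(text_spec)   # dict of integer TensorSpecs
print(meta_spec)   # dict of string TensorSpecs
print(label_spec)
```

If the string features show up under the first dict or the integer features under the second, the nesting does not match the model's inputs.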
Running it:
>>> model.fit(combined_dataset)
1/1 [==============================] - 0s 212us/step - loss: 2.2895
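
If the data already exists as a single flat-dict dataset, as in the question, another option is to regroup each element with Dataset.map instead of rebuilding three datasets. This is a sketch under assumptions, not something stated in the thread: TEXT_KEYS and to_nested are hypothetical names, the toy values are made up, and the labels are unwrapped from the "class_output" dict so that the result has the same ((text, meta), label) shape as the zipped dataset.

```python
import tensorflow as tf

# Keys that belong to the BERT branch; everything else is treated as metadata.
TEXT_KEYS = ("input_ids", "attention_mask")

# Toy flat dataset shaped like the one in the question (values are made up).
flat = tf.data.Dataset.from_tensor_slices(
    (
        {
            "input_ids": [[101, 102], [101, 102]],
            "attention_mask": [[1, 1], [1, 0]],
            "Organization": [["BNS", "NA"], ["ECB", "PAD"]],
        },
        {"class_output": [[1], [0]]},
    )
)


def to_nested(features, labels):
    """Split the flat feature dict into (text, meta) and unwrap the label."""
    text = {key: features[key] for key in TEXT_KEYS}
    meta = {key: value for key, value in features.items() if key not in TEXT_KEYS}
    return (text, meta), labels["class_output"]


nested = flat.map(to_nested).batch(2)

(text_spec, meta_spec), label_spec = nested.element_spec
print(sorted(text_spec))  # keys routed to the text branch
print(sorted(meta_spec))  # keys routed to the metadata branch
```

The resulting dataset should be usable in place of combined_dataset, since it carries the same nesting; this has not been verified against the full model from the question.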

Source: Stack Overflow, https://stackoverflow.com/questions/64770484/
