tensorflow - 机器学习中的多标签/多任务/多类回归

标签 tensorflow keras

我的挑战是训练一个神经网络来识别不同类别任务的某些 Action 和事件,或者在给定输入的情况下如何调用它。 我发现训练神经网络时大多数输入/输出都是 0 或 1 或 [0,1]。但在我的场景中,我希望输入采用任意大的整数形式,并且输出预计具有相同的形式。

举个例子: 输入

X = [ 23, 4, 0, 1233423, 1, 0, 0] ->
Y = [ 2, 1, 1]

现在 X[i] 中的每个元素代表同一实体的不同属性。 假设它想要描述一个人:

23 -> maps to a place he/she was born
4 -> maps to a school they graduated 

等等

另一方面,Y[i] 中的每个条目表示人类在 3 个不同类别中更有可能执行的操作(在本例中 len(Y) 为 3):

Y[0] = 2 -> maps to eating icecream ( from a variety of other choices )
Y[1] = 1 -> maps to a time of day moment ( morning, noon, afternoon, evening, etc...)
Y[2] = 1 -> maps to a day of the week for example

现在,当然,如果任务只是一个多标签问题,我会在输出层应用 sigmoid 并执行 binary_crossentropy 作为损失函数,但这当然不起作用。 这里因为我的输出显然不在[0,1]之间。 另外,我不太确定要应用什么损失函数,因为我希望正确预测 Y 中的所有类/子类。我基本上想说的是,每个 Y[i] 本身就是一个自己的类。 如果我的输出采用 (3, labels_per_class) 的形式,则会更准确 损失函数将计算 3 个不同类别中每一个类别的损失 尝试以这样的方式优化结果,使 3 个类别中的每一个类别都有正确的标签。 我不确定这是否可能或至少如何实现。

我的神经网络知识确实还处于起步阶段,所以很明显我正在努力解决这个问题。

但实际上,更简单地说,我更好地知道如何描述它。它或多或少类似于自动编码器,但输入和输出都是整数。不同之处在于,在我的例子中,输出的大小与输入的大小不同,而在自动编码器中它们是相同的。

我的解决方案是在输出层应用relu(当然也在所有其他层上应用类似relu的激活)和binary_crossentropy作为损失函数但网络的准确率很低,在15%左右。

最佳答案

对于标准分类,您可能会创建一个密集层,其节点数等于类数,然后应用 softmax。损失将是 tf.losses.softmax_cross_entropy。如果您想允许多个类,而不仅仅是一个类,您可以使用 sigmoid。

现在您有多个分类任务。一种方法是采用最后一个隐藏层(执行 softmax 的层之前的一层)。对于每个任务,做一个密集层,其节点数等于该任务的类数,并应用 softmax。要计算损失,只需将损失加在一起即可。

如果任务差异太大,您可能希望每个预测都有多个层。

如果,比如说,吃冰淇淋比正确安排一天中的时间重要得多,您也可以对不同的损失进行一些权重。

仅当预测空间连续时才使用 relu。假设一天中的时间是连续的,但吃冰淇淋、上类、看电视之间的选择却不是连续的。如果您使用 relu,请使用 L1(tf.losses.absolut_difference) 或 L2 (tf.losses.mean_squared_error) 等损失。

关于tensorflow - 机器学习中的多标签/多任务/多类回归,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47452806/

相关文章:

python - 使用数据生成器时,Keras 自定义指标 self.validation_data 为 none

python - 张量不是该图的元素;部署 Keras 模型

python - 安装turicreate时出现tensorflow错误?

python - 神经网络将所有内容分类为一类,在不平衡数据集上召回率=1

machine-learning - 确保 Python 代码是在 GPU 还是 CPU 上运行

keras:如何按顺序预测类(class)?

python - Keras 模型无法预测是否在线程中调用

javascript - "Dependency was not found"for tfjs in Vue/Webpack project with yarn

python - Tensorflow CNN 实现的准确性较差

linux - 通过 Flask 内存泄漏的 Tensorflow Inception 模型