Tensorflow: 何时使用tf.expand_dims?

5 浏览
0 Comments

Tensorflow: 何时使用tf.expand_dims?

Tensorflow教程中包括使用tf.expand_dims来给张量添加一个“批处理维度”。我已经阅读了这个函数的文档,但对我来说仍然有些神秘。有人知道在什么情况下必须使用它吗?

我的代码如下。我的意图是根据预测和实际的bin之间的距离计算损失。(例如,如果predictedBin = 10truthBin = 7,那么binDistanceLoss = 3)。

batch_size = tf.size(truthValues_placeholder)
labels = tf.expand_dims(truthValues_placeholder, 1)
predictedBin = tf.argmax(logits)
binDistanceLoss = tf.abs(tf.sub(labels, logits))

在这种情况下,我需要对predictedBinbinDistanceLoss应用tf.expand_dims吗?谢谢。

0