如何在图构建时获得张量(在TensorFlow中)的维度?

5 浏览
0 Comments

如何在图构建时获得张量(在TensorFlow中)的维度?

我正在尝试一个不按预期行为的Op。

graph = tf.Graph()
with graph.as_default():
  train_dataset = tf.placeholder(tf.int32, shape=[128, 2])
  embeddings = tf.Variable(
    tf.random_uniform([50000, 64], -1.0, 1.0))
  embed = tf.nn.embedding_lookup(embeddings, train_dataset)
  embed = tf.reduce_sum(embed, reduction_indices=0)

因此,我需要知道张量`embed`的维度。我知道这可以在运行时完成,但对于这样一个简单的操作来说,这太麻烦了。有什么更简单的方法吗?

0