Method 1:
# Example input: two rows of ten integers (one ascending, one descending).
data = tf.constant([
    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    [100, 99, 98, 97, 96, 95, 94, 93, 92, 91],
])
# Take the 8 largest entries of each row; the result is unpacked into the
# selected values and their positions within the row.
value, ids = tf.nn.top_k(data, k=8)
Output (`value`, the 8 largest entries of each row):
[[ 10 9 8 7 6 5 4 3]
[100 99 98 97 96 95 94 93]]
Method 2:
# Example input: two rows of ten integers (one ascending, one descending).
data = tf.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [100, 99, 98,
97, 96, 95, 94, 93, 92, 91]])
# Sort every row from largest to smallest.
sorted_data = tf.sort(data, direction='DESCENDING')
# Build a grid of column indices with the same shape as sorted_data:
# start from [0, 1, ..., n_cols - 1], add a leading axis, then repeat it
# once per row.
data_range = tf.range(0, sorted_data.shape[1], 1)
data_range = tf.expand_dims(data_range, 0)
data_range = tf.tile(data_range, [sorted_data.shape[0], 1])
# Keep the first 6 columns (the 6 largest values of each row) and replace
# everything else with zeros, so the result keeps the input's shape.
out = tf.where(tf.less(data_range, 6), sorted_data, tf.zeros_like(sorted_data))
Output (`out`, the 6 largest entries of each row followed by zeros):
[[ 10 9 8 7 6 5 0 0 0 0]
[100 99 98 97 96 95 0 0 0 0]]