(注:您现在可以获得这段代码的更精致的版本作为 github 上的要点 https://gist.github.com/dave-andersen/265e68a5e879b5540ebc.)
你绝对可以做到,但你需要定义自己的优化标准(对于 k 均值,通常是最大迭代计数以及分配何时稳定)。这是一个如何做到这一点的示例(可能有更优化的方法来实现它,并且肯定有更好的方法来选择初始点)。基本上就像你在 numpy 中做的那样,如果你真的努力避免在 python 中迭代地做事情:
import tensorflow as tf
import numpy as np
import time
N=10000
K=4
MAX_ITERS = 1000
start = time.time()
points = tf.Variable(tf.random_uniform([N,2]))
cluster_assignments = tf.Variable(tf.zeros([N], dtype=tf.int64))
# Silly initialization: Use the first two points as the starting
# centroids. In the real world, do this better.
centroids = tf.Variable(tf.slice(points.initialized_value(), [0,0], [K,2]))
# Replicate to N copies of each centroid and K copies of each
# point, then subtract and compute the sum of squared distances.
rep_centroids = tf.reshape(tf.tile(centroids, [N, 1]), [N, K, 2])
rep_points = tf.reshape(tf.tile(points, [1, K]), [N, K, 2])
sum_squares = tf.reduce_sum(tf.square(rep_points - rep_centroids),
reduction_indices=2)
# Use argmin to select the lowest-distance point
best_centroids = tf.argmin(sum_squares, 1)
did_assignments_change = tf.reduce_any(tf.not_equal(best_centroids,
cluster_assignments))
def bucket_mean(data, bucket_ids, num_buckets):
total = tf.unsorted_segment_sum(data, bucket_ids, num_buckets)
count = tf.unsorted_segment_sum(tf.ones_like(data), bucket_ids, num_buckets)
return total / count
means = bucket_mean(points, best_centroids, K)
# Do not write to the assigned clusters variable until after
# computing whether the assignments have changed - hence with_dependencies
with tf.control_dependencies([did_assignments_change]):
do_updates = tf.group(
centroids.assign(means),
cluster_assignments.assign(best_centroids))
sess = tf.Session()
sess.run(tf.initialize_all_variables())
changed = True
iters = 0
while changed and iters < MAX_ITERS:
iters += 1
[changed, _] = sess.run([did_assignments_change, do_updates])
[centers, assignments] = sess.run([centroids, cluster_assignments])
end = time.time()
print ("Found in %.2f seconds" % (end-start)), iters, "iterations"
print "Centroids:"
print centers
print "Cluster assignments:", assignments
(请注意,真正的实现需要更加小心地选择初始集群,避免所有点都进入一个集群的问题情况等。这只是一个快速演示。我已经更新了之前的答案,使其变得有点更清晰且“值得举例”。)