实际上,您可以通过简单的矩阵乘法来解决这个问题。
result = M * (1:size(M, 2)).';
3
1
2
1
3
这是通过将 M x 3 矩阵与 3 x 1 数组相乘来实现的,其中 3x1 的元素只是[1; 2; 3]
。简而言之,对于每一行M
,对 3 x 1 数组执行逐元素乘法。行中只有 1M
会产生任何结果。然后将按元素相乘的结果相加。因为每行只有一个“1”,所以结果将是该 1 所在的列索引。
例如对于第一行M
.
element_wise_multiplication = [0 0 1] .* [1 2 3]
[0, 0, 3]
sum(element_wise_multiplication)
3
Update
基于提供的解决方案@reyryeng https://stackoverflow.com/a/35953070/670206 and @Luis https://stackoverflow.com/a/35966907/670206下面,我决定进行比较,看看各种方法的性能如何比较。
设置测试矩阵(M
)我创建了一个原始问题中指定形式的矩阵并改变了行数。哪一列有 1 是使用随机选择的randi([1 nCols], size(M, 1))
。执行时间分析使用timeit
.
运行时使用M
类型的double
(MATLAB 的默认值)您将得到以下执行时间。
If M
is a logical
,那么矩阵乘法会受到影响,因为它必须在矩阵乘法之前转换为数值类型,而其他两个则有一些性能改进。
这是我使用的测试代码。
sizes = round(linspace(100, 100000, 100));
times = zeros(numel(sizes), 3);
for k = 1:numel(sizes)
M = generateM(sizes(k));
times(k,1) = timeit(@()M * (1:size(M, 2)).');
M = generateM(sizes(k));
times(k,2) = timeit(@()max(M, [], 2), 2);
M = generateM(sizes(k));
times(k,3) = timeit(@()find(M.'), 2);
end
figure
plot(range, times / 1000);
legend({'Multiplication', 'Max', 'Find'})
xlabel('Number of rows in M')
ylabel('Execution Time (ms)')
function M = generateM(nRows)
M = zeros(nRows, 3);
col = randi([1 size(M, 2)], 1, size(M, 1));
M(sub2ind(size(M), 1:numel(col), col)) = 1;
end