在 Python 中创建快速 RGB 查找表

2024-01-08

我有一个称为“rgb2something”的函数,它将 RGB 数据 [1x1x3] 转换为单个值(概率),循环输入 RGB 数据中的每个像素结果相当慢。

我尝试了以下方法来加快转换速度。生成 LUT(查找表):

import numpy as np

levels = 256
levels2 = levels**2
lut = [0] * (levels ** 3)

levels_range = range(0, levels)

for r in levels_range:
    for g in levels_range:
        for b in levels_range:
            lut[r + (g * levels) + (b * levels2)] = rgb2something(r, g, b)

并将 RGB 转换为变换后的概率图像:

result = np.take(lut, r_channel + (g_channel * 256) + (b_channel * 65536))

然而,生成 LUT 和计算结果仍然很慢。在 2 维中它相当快,但是在 3 维(r、g 和 b)中它很慢。我怎样才能提高这个性能?

EDIT

rgb2something(r, g, b)看起来像这样:

def rgb2something(r, g, b):
    y = np.array([[r, g, b]])
    y_mean = np.mean(y, axis=0)
    y_centered = y - y_mean
    y_cov = y_centered.T.dot(y_centered) / len(y_centered)
    m = len(Consts.x)
    n = len(y)
    q = m + n
    pool_cov = (m / q * x_cov) + (n / q * y_cov)
    inv_pool_cov = np.linalg.inv(pool_cov)
    g = Consts.x_mean - y_mean
    mah = g.T.dot(inv_pool_cov).dot(g) ** 0.5
    return mah

EDIT 2:

我正在尝试实现的完整工作代码示例,我正在使用 OpenCV,因此任何 OpenCV 方法,例如应用查找表 https://www.learnopencv.com/tag/lut/欢迎,C/C++ 方法也是如此:

import matplotlib.pyplot as plt
import numpy as np 
import cv2

class Model:
    x = np.array([
        [6, 5, 2],
        [2, 5, 7],
        [6, 3, 1]
    ])
    x_mean = np.mean(x, axis=0)
    x_centered = x - x_mean
    x_covariance = x_centered.T.dot(x_centered) / len(x_centered)
    m = len(x)
    n = 1  # Only ever comparing to a single pixel
    q = m + n
    pooled_covariance = (m / q * x_covariance)  # + (n / q * y_cov) -< Always 0 for a single point
    inverse_pooled_covariance = np.linalg.inv(pooled_covariance)

def rgb2something(r, g, b):
    #Calculates Mahalanobis Distance between pixel and model X
    y = np.array([[r, g, b]])
    y_mean = np.mean(y, axis=0)
    g = Model.x_mean - y_mean
    mah = g.T.dot(Model.inverse_pooled_covariance).dot(g) ** 0.5
    return mah

def generate_lut():
    levels = 256
    levels2 = levels**2
    lut = [0] * (levels ** 3)

    levels_range = range(0, levels)

    for r in levels_range:
        for g in levels_range:
            for b in levels_range:
                lut[r + (g * levels) + (b * levels2)] = rgb2something(r, g, b)

    return lut

def calculate_distance(lut, input_image):
    return np.take(lut, input_image[:, :, 0] + (input_image[:, :, 1] * 256) + (input_image[:, :, 2] * 65536))

lut = generate_lut()
rgb = np.random.randint(255, size=(1080, 1920, 3), dtype=np.uint8)
result = calculate_distance(lut, rgb)

cv2.imshow("Example", rgb)
cv2.imshow("Result", result)
cv2.waitKey(0)

更新:添加blas优化

有几个简单且非常有效的优化:

(1)矢量化,矢量化!对这段代码中的所有内容进行向量化并不是那么困难。见下文。

(2) 使用正确的查找,即花哨的索引,而不是np.take

(3)使用Cholesky分解。与布拉斯dtrmm我们可以利用它的三角形结构

这是代码。只需将其添加到OP代码的末尾(在编辑2下)。除非您非常有耐心,否则您可能还想注释掉lut = generate_lut() and result = calculate_distance(lut, rgb)行和所有对 cv2 的引用。我还添加了一个随机行x使其协方差矩阵非奇异。

class Full_Model(Model):
    ch = np.linalg.cholesky(Model.inverse_pooled_covariance)
    chx = Model.x_mean@ch

def rgb2something_vectorized(rgb):
    return np.sqrt(np.sum(((rgb - Full_Model.x_mean)@Full_Model.ch)**2,  axis=-1))

from scipy.linalg import blas

def rgb2something_blas(rgb):
    *shp, nchan = rgb.shape
    return np.sqrt(np.einsum('...i,...i', *2*(blas.dtrmm(1, Full_Model.ch.T, rgb.reshape(-1, nchan).T, 0, 0, 0, 0, 0).T - Full_Model.chx,))).reshape(shp)

def generate_lut_vectorized():
    return rgb2something_vectorized(np.transpose(np.indices((256, 256, 256))))

def generate_lut_blas():
    rng = np.arange(256)
    arr = np.empty((256, 256, 256, 3))
    arr[0, ..., 0]  = rng
    arr[0, ..., 1]  = rng[:, None]
    arr[1:, ...] = arr[0]
    arr[..., 2] = rng[:, None, None]
    return rgb2something_blas(arr)

def calculate_distance_vectorized(lut, input_image):
    return lut[input_image[..., 2], input_image[..., 1], input_image[..., 0]]

# test code

def random_check_lut(lut):
    """Because the original lut generator is excruciatingly slow,
    we only compare a random sample, using the original code
    """
    levels = 256
    levels2 = levels**2
    lut = lut.ravel()

    levels_range = range(0, levels)

    for r, g, b in np.random.randint(0, 256, (1000, 3)):
        assert np.isclose(lut[r + (g * levels) + (b * levels2)], rgb2something(r, g, b))

import time
td = []
td.append((time.time(), 'create lut vectorized'))
lutv = generate_lut_vectorized()
td.append((time.time(), 'create lut using blas'))
lutb = generate_lut_blas()
td.append((time.time(), 'lookup using np.take'))
res = calculate_distance(lutv, rgb)
td.append((time.time(), 'process on the fly (no lookup)'))
resotf = rgb2something_vectorized(rgb)
td.append((time.time(), 'process on the fly (blas)'))
resbla = rgb2something_blas(rgb)
td.append((time.time(), 'lookup using fancy indexing'))
resv = calculate_distance_vectorized(lutv, rgb)
td.append((time.time(), None))

print("sanity checks ... ", end='')
assert np.allclose(res, resotf) and np.allclose(res, resv) \
    and np.allclose(res, resbla) and np.allclose(lutv, lutb)
random_check_lut(lutv)
print('all ok\n')

t, d = zip(*td)
for ti, di in zip(np.diff(t), d):
    print(f'{di:32s} {ti:10.3f} seconds')

示例运行:

sanity checks ... all ok

create lut vectorized                 1.116 seconds
create lut using blas                 0.917 seconds
lookup using np.take                  0.398 seconds
process on the fly (no lookup)        0.127 seconds
process on the fly (blas)             0.069 seconds
lookup using fancy indexing           0.064 seconds

我们可以看到,最佳查找以微弱优势击败了最佳即时计算。也就是说,该示例可能高估了查找成本,因为随机像素可能比自然图像不太适合缓存。

原始答案(也许对某些人仍然有用)

如果 rgb2something 无法矢量化,并且您想处理一张典型图像,那么您可以使用以下命令获得不错的加速np.unique.

如果 rgb2 的东西很昂贵并且必须处理多个图像,那么unique可以与缓存结合使用,可以方便地使用functools.lru_cache---唯一的(次要)绊脚石:参数必须是可散列的。事实证明,强制执行的代码修改(将 RGB 数组转换为 3 字节字符串)恰好可以提高性能。

仅当您有大量像素覆盖大多数色调时,使用完整的查找表才值得。在这种情况下,最快的方法是使用 numpy 花式索引来进行实际查找。

import numpy as np
import time
import functools

def rgb2something(rgb):
    # waste some time:
    np.exp(0.1*rgb)
    return rgb.mean()

@functools.lru_cache(None)
def rgb2something_lru(rgb):
    rgb = np.frombuffer(rgb, np.uint8)
    # waste some time:
    np.exp(0.1*rgb)
    return rgb.mean()

def apply_to_img(img):
    shp = img.shape
    return np.reshape([rgb2something(x) for x in img.reshape(-1, shp[-1])], shp[:2])

def apply_to_img_lru(img):
    shp = img.shape
    return np.reshape([rgb2something_lru(x) for x in img.ravel().view('S3')], shp[:2])

def apply_to_img_smart(img, print_stats=True):
    shp = img.shape
    unq, bck = np.unique(img.reshape(-1, shp[-1]), return_inverse=True, axis=0)
    if print_stats:
        print('total no pixels', shp[0]*shp[1], '\nno unique pixels', len(unq))
    return np.array([rgb2something(x) for x in unq])[bck].reshape(shp[:2])

def apply_to_img_smarter(img, print_stats=True):
    shp = img.shape
    unq, bck = np.unique(img.ravel().view('S3'), return_inverse=True)
    if print_stats:
        print('total no pixels', shp[0]*shp[1], '\nno unique pixels', len(unq))
    return np.array([rgb2something_lru(x) for x in unq])[bck].reshape(shp[:2])

def make_full_lut():
    x = np.empty((3,), np.uint8)
    return np.reshape([rgb2something(x) for x[0] in range(256)
                       for x[1] in range(256) for x[2] in range(256)],
                      (256, 256, 256))

def make_full_lut_cheat(): # for quicker testing lookup
    i, j, k = np.ogrid[:256, :256, :256]
    return (i + j + k) / 3

def apply_to_img_full_lut(img, lut):
    return lut[(*np.moveaxis(img, 2, 0),)]

from scipy.misc import face

t0 = time.perf_counter()
bw = apply_to_img(face())
t1 = time.perf_counter()
print('naive                 ', t1-t0, 'seconds')

t0 = time.perf_counter()
bw = apply_to_img_lru(face())
t1 = time.perf_counter()
print('lru first time        ', t1-t0, 'seconds')

t0 = time.perf_counter()
bw = apply_to_img_lru(face())
t1 = time.perf_counter()
print('lru second time       ', t1-t0, 'seconds')

t0 = time.perf_counter()
bw = apply_to_img_smart(face(), False)
t1 = time.perf_counter()
print('using unique:         ', t1-t0, 'seconds')

rgb2something_lru.cache_clear()

t0 = time.perf_counter()
bw = apply_to_img_smarter(face(), False)
t1 = time.perf_counter()
print('unique and lru first: ', t1-t0, 'seconds')

t0 = time.perf_counter()
bw = apply_to_img_smarter(face(), False)
t1 = time.perf_counter()
print('unique and lru second:', t1-t0, 'seconds')

t0 = time.perf_counter()
lut = make_full_lut_cheat()
t1 = time.perf_counter()
print('creating full lut:    ', t1-t0, 'seconds')

t0 = time.perf_counter()
bw = apply_to_img_full_lut(face(), lut)
t1 = time.perf_counter()
print('using full lut:       ', t1-t0, 'seconds')

print()
apply_to_img_smart(face())

import Image
Image.fromarray(bw.astype(np.uint8)).save('bw.png')

示例运行:

naive                  6.8886632949870545 seconds
lru first time         1.7458112589956727 seconds
lru second time        0.4085628940083552 seconds
using unique:          2.0951434450107627 seconds
unique and lru first:  2.0168916099937633 seconds
unique and lru second: 0.3118703299842309 seconds
creating full lut:     151.17599205300212 seconds
using full lut:        0.12164952099556103 seconds

total no pixels 786432 
no unique pixels 134105
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 Python 中创建快速 RGB 查找表 的相关文章

随机推荐