图像增强 cnn

2023-10-30

目录

实时图像增强,基于“间距自适应查找表”的方法(CVPR 2022)

Image-Adaptive-3DLUT

水下图像增强UWCNN-wtf

直方图均衡化:

CycleGan增强 2个项目


实时图像增强,基于“间距自适应查找表”的方法(CVPR 2022)

https://blog.csdn.net/jacke121/article/details/124777484

Image-Adaptive-3DLUT

https://github.com/HuiZeng/Image-Adaptive-3DLUT

该文是香港理工大学张磊老师及其学生在图像增强领域的又一颠覆性成果。它将深度学习技术与传统3DLUT图像增强技术结合,得到了一种更灵活、更高效的图像增强技术。所提方法能够以1.66ms的速度对4K分辨率图像进行增强(硬件平台:Titan RTX GPU)。

paper: https://www4.comp.polyu.edu.hk/~cslzhang/paper/PAMI_LUT.pdf

code: https://github.com/HuiZeng/Image-Adaptive-3DLUT

需要编译:

trilinear_cpp

from setuptools import setup
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension

if torch.cuda.is_available():
    print('Including CUDA code.')
    setup(
        name='trilinear',
        ext_modules=[
            CUDAExtension('trilinear', [
                'src/trilinear_cuda.cpp',
                'src/trilinear_kernel.cu',
            ])
        ],
        cmdclass={
            'build_ext': BuildExtension
        })
else:
    print('NO CUDA is found. Fall back to CPU.')
    setup(name='trilinear',
        ext_modules=[CppExtension('trilinear', ['src/trilinear.cpp'])],
        cmdclass={'build_ext': BuildExtension})

set DISTUTILS_USE_SDK=1
set MSSdk=1

编译成功,调用 dll找不到,解决方法:

把目录:Lib\site-packages\torch\lib

下面的dll拷贝到pyd目录下面,可以调用了。

水下图像增强UWCNN-wtf

https://github.com/MACILLAS/UWCNN

测试代码:

import glob

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, SpatialDropout2D, ReLU, Input, Concatenate, Add
from tensorflow.keras.losses import MeanAbsoluteError, MeanSquaredError
from tensorflow.keras.optimizers import Adam
import os
import pandas as pd
import cv2

class UWCNN(tf.keras.Model):

    def __init__(self):
        super(UWCNN, self).__init__()
        self.conv1 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze1")
        self.relu1 = ReLU()
        self.conv2 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze2")
        self.relu2 = ReLU()
        self.conv3 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze3")
        self.relu3 = ReLU()
        self.concat1 = Concatenate(axis=3)

        self.conv4 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze4")
        self.relu4 = ReLU()
        self.conv5 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze5")
        self.relu5 = ReLU()
        self.conv6 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze6")
        self.relu6 = ReLU()
        self.concat2 = Concatenate(axis=3)

        self.conv7 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze7")
        self.relu7 = ReLU()
        self.conv8 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze8")
        self.relu8 = ReLU()
        self.conv9 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze9")
        self.relu9 = ReLU()
        self.concat3 = Concatenate(axis=3)

        self.conv10 = Conv2D(3, 3, (1, 1), 'same', name="conv2d_dehaze10")
        self.add1 = Add()

    def call(self, inputs):
        image_conv1 = self.relu1(self.conv1(inputs))
        image_conv2 = self.relu2(self.conv2(image_conv1))
        image_conv3 = self.relu3(self.conv3(image_conv2))
        dehaze_concat1 = self.concat1([image_conv1, image_conv2, image_conv3, inputs])

        image_conv4 = self.relu4(self.conv4(dehaze_concat1))
        image_conv5 = self.relu5(self.conv5(image_conv4))
        image_conv6 = self.relu6(self.conv6(image_conv5))
        dehaze_concat2 = self.concat2([dehaze_concat1, image_conv4, image_conv5, image_conv6])

        image_conv7 = self.relu7(self.conv7(dehaze_concat2))
        image_conv8 = self.relu8(self.conv8(image_conv7))
        image_conv9 = self.relu9(self.conv9(image_conv8))
        dehaze_concat3 = self.concat3([dehaze_concat2, image_conv7, image_conv8, image_conv9])

        image_conv10 = self.conv10(dehaze_concat3)
        out = self.add1([inputs, image_conv10])
        return out

def parse_function(filename, label):
    filename_image_string = tf.io.read_file(filename)
    label_image_string = tf.io.read_file(label)
    # Decode the filename_image_string
    filename_image = tf.image.decode_bmp(filename_image_string, channels=3)
    filename_image = tf.image.convert_image_dtype(filename_image, tf.float32)
    # Decode the label_image_string
    label_image = tf.image.decode_bmp(label_image_string, channels=3)
    label_image = tf.image.convert_image_dtype(label_image, tf.float32)
    return filename_image, label_image

def combloss (y_actual, y_predicted):
    '''
    This is the custom loss function for keras model
    :param y_actual:
    :param y_predicted:
    :return:
    '''
    # this is just l2 + lssim
    lssim = tf.constant(1, dtype=tf.float32) - tf.reduce_mean(tf.image.ssim(y_actual, y_predicted, max_val=1, filter_size=13)) #remove max_val=1.0
    lmse = MeanSquaredError(reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)(y_actual, y_predicted)
    lmse = tf.math.multiply(lmse, 4)
    return tf.math.add(lmse, lssim)

def train(datafile="data.csv", ckptpath="./train_type1/cp.ckpt", type='type1'):
    df = pd.read_csv(datafile)
    augfiles = list(df["AUGFILE"])
    gtfiles = list(df["GTFILE"])

    augImages = tf.constant(augfiles)
    gtImages = tf.constant(gtfiles)

    dataset = tf.data.Dataset.from_tensor_slices((augImages, gtImages))
    dataset = dataset.shuffle(len(augImages))
    #dataset = dataset.repeat()
    dataset = dataset.map(parse_function).batch(10)

    # Call backs
    #checkpoint_path = "./train_type1/cp.ckpt"
    checkpoint_path = ckptpath
    checkpoint_dir = os.path.dirname(checkpoint_path)

    # Create a callback that saves the model's weights
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True, verbose=1)

    model = UWCNN()
    model.compile(optimizer=Adam(), loss=combloss)
    model.fit(dataset, epochs=40, callbacks=[cp_callback])

    os.listdir(checkpoint_dir)
    #model.save('saved_model/my_model')
    model.save('save_model/'+type)



def model_test(imgfile="12433.png", ckdir="./train_type1/cp.ckpt", outdir="./results/", type='type1'):
    # model = tf.keras.models.load_model('save_model/'+type, custom_objects={'loss': combloss}, compile=False)

    model = UWCNN()
    # model.summary()
    model.compile(optimizer=Adam(), loss=combloss)
    model.load_weights(ckdir)
    filename_image_string = tf.io.read_file(imgfile)
    filename_image = tf.image.decode_png(filename_image_string, channels=3)
    filename_image = tf.image.convert_image_dtype(filename_image, tf.float32)
    filename_image = tf.image.resize(filename_image, (460, 620))
    l, w, c = filename_image.shape
    filename_image = tf.reshape(filename_image, [1, l, w, c])
    output = model.predict(filename_image)
    output = output.reshape((l, w, c)) * 255
    cv2.imwrite(outdir+type+"_"+os.path.basename(imgfile), output)


def eval_dir():

    dir_a=r'F:\project\zengqiang_shuixia\raw-890-s\raw_img'

    out_dir=r'F:\project\zengqiang_shuixia\raw-890-s\result/'

    files=glob.glob(dir_a+'/*.png')

    model = UWCNN()
    # model.summary()
    model.compile(optimizer=Adam(), loss=combloss)
    model.load_weights(ckdir)

    for imgfile in files:

        filename_image_string = tf.io.read_file(imgfile)
        filename_image = tf.image.decode_png(filename_image_string, channels=3)
        filename_image = tf.image.convert_image_dtype(filename_image, tf.float32)
        filename_image = tf.image.resize(filename_image, (256, 256))
        l, w, c = filename_image.shape
        filename_image = tf.reshape(filename_image, [1, l, w, c])
        output = model.predict(filename_image)
        output = output.reshape((l, w, c)) * 255
        cv2.imwrite(out_dir+ os.path.basename(imgfile), output)

if __name__ == "__main__":
    # train(datafile="data_type1.csv", ckptpath="./train_type1/cp.ckpt", type='type1')

    type = "type1"
    ckdir = "./train_type1/cp.ckpt"
    # model_test(imgfile="./test_images/532_img_.png", ckdir=ckdir, outdir="./results/", type=type)
    eval_dir()
    # model_test(imgdir="./test_images/", imgfile="602_img_.png", ckdir=ckdir, outdir="./results/", type=type)
    # model_test(imgdir="./test_images/", imgfile="617_img_.png", ckdir=ckdir, outdir="./results/", type=type)
    # model_test(imgdir="./test_images/", imgfile="12422.png", ckdir=ckdir, outdir="./results/", type=type)
    # model_test(imgdir="./test_images/", imgfile="12433.png", ckdir=ckdir, outdir="./results/", type=type)



直方图均衡化:

# This code normalizes the output images in HSI space
# Code Inspired By: Li C, Anwar S, Porikli F. Underwater scene prior inspired deep learning image and video enhancement[J]. Pattern Recognition, 2020, 98: 107038
# Implementation By: Max Midwinter
#HSI and RGB conversion code by: DaiPuWei

import os
import cv2
import numpy as np
import math

def RGB2HSI(rgb_img):
    """
         This is the function to convert RGB color image to HSI image
         :param rgm_img: RGB color image
         :return: HSI image
    """
    #Save the number of rows and columns of the original image
    row = np.shape(rgb_img)[0]
    col = np.shape(rgb_img)[1]
    #Copy the original image
    hsi_img = rgb_img.copy()
    #Channel splitting the image
    B,G,R = cv2.split(rgb_img)
    #R, G, B = cv2.split(rgb_img)
    # Normalize the channel to [0,1]
    [B,G,R] = [ i/ 255.0 for i in ([B,G,R])]
    H = np.zeros((row, col))    #Define H channel
    I = (R + G + B) / 3.0       #Calculate I channel
    S = np.zeros((row,col))      #Define S channel
    for i in range(row):
        den = np.sqrt((R[i]-G[i])**2+(R[i]-B[i])*(G[i]-B[i]))
        thetha = np.arccos(0.5*(R[i]-B[i]+R[i]-G[i])/den)   #Calculate the included angle
        h = np.zeros(col)               #Define temporary array
        #den>0 and G>=B element h is assigned to thetha
        h[B[i]<=G[i]] = thetha[B[i]<=G[i]]
        #den>0 and G<=B element h is assigned to thetha
        h[G[i]<B[i]] = 2*np.pi-thetha[G[i]<B[i]]
        #den<0 element h is assigned a value of 0
        h[den == 0] = 0
        H[i] = h/(2*np.pi)      #Assign to the H channel after radiating
    #Calculate S channel
    for i in range(row):
        min = []
        #Find the minimum value of each group of RGB values
        for j in range(col):
            arr = [B[i][j],G[i][j],R[i][j]]
            min.append(np.min(arr))
        min = np.array(min)
        #Calculate S channel
        S[i] = 1 - min*3/(R[i]+B[i]+G[i])
        #I is 0 directly assigned to 0
        S[i][R[i]+B[i]+G[i] == 0] = 0
    #Extend to 255 for easy display, generally H component is between [0,2pi], S and I are between [0,1]
    hsi_img[:,:,0] = H*255
    hsi_img[:,:,1] = S*255
    hsi_img[:,:,2] = I*255
    return hsi_img

def HSI2RGB(hsi_img):
    """
         This is the function to convert HSI image to RGB image
         :param hsi_img: HSI color image
         :return: RGB image
    """
    # Save the number of rows and columns of the original image
    row = np.shape(hsi_img)[0]
    col = np.shape(hsi_img)[1]
    #Copy the original image
    rgb_img = hsi_img.copy()
    #Channel splitting the image
    H,S,I = cv2.split(hsi_img)
    # Normalize the channel to [0,1]
    [H,S,I] = [ i/ 255.0 for i in ([H,S,I])]
    R,G,B = H,S,I
    for i in range(row):
        h = H[i]*2*np.pi
        #H is greater than or equal to 0 and less than 120 degrees
        a1 = h >=0
        a2 = h < 2*np.pi/3
        a = a1 & a2         #Fancy index of the first case
        tmp = np.cos(np.pi / 3 - h)
        b = I[i] * (1 - S[i])
        r = I[i]*(1+S[i]*np.cos(h)/tmp)
        g = 3*I[i]-r-b
        B[i][a] = b[a]
        R[i][a] = r[a]
        G[i][a] = g[a]
        #H is greater than or equal to 120 degrees and less than 240 degrees
        a1 = h >= 2*np.pi/3
        a2 = h < 4*np.pi/3
        a = a1 & a2         #Fancy index of the second case
        tmp = np.cos(np.pi - h)
        r = I[i] * (1 - S[i])
        g = I[i]*(1+S[i]*np.cos(h-2*np.pi/3)/tmp)
        b = 3 * I[i] - r - g
        R[i][a] = r[a]
        G[i][a] = g[a]
        B[i][a] = b[a]
        #H is greater than or equal to 240 degrees and less than 360 degrees
        a1 = h >= 4 * np.pi / 3
        a2 = h < 2 * np.pi
        a = a1 & a2             #Fancy index of the third case
        tmp = np.cos(5 * np.pi / 3 - h)
        g = I[i] * (1-S[i])
        b = I[i]*(1+S[i]*np.cos(h-4*np.pi/3)/tmp)
        r = 3 * I[i] - g - b
        B[i][a] = b[a]
        G[i][a] = g[a]
        R[i][a] = r[a]
    rgb_img[:,:,0] = B*255
    rgb_img[:,:,1] = G*255
    rgb_img[:,:,2] = R*255
    return rgb_img

def transform (dir = None):
    img = cv2.imread(dir)
    hsi_img = RGB2HSI(img)/255
    h, s, i = cv2.split(hsi_img)
    s = cv2.normalize(s, dst=None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX)
    i = cv2.normalize(i, dst=None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX)
    norm_hsi_img = cv2.merge((h, s, i))
    norm_hsi_img = norm_hsi_img*255
    rgb_img = HSI2RGB(norm_hsi_img)
    cv2.imwrite(dir, rgb_img)

def allImgInDir (path = './results'):
    fname = []
    for root, d_names, f_names in os.walk(path):
        for f in f_names:
            fname.append(os.path.join(root, f))
            print("File: "+str(f))
            transform(os.path.join(root, f))
    print("fname  %s" % fname)

if __name__ ==  "__main__":
    allImgInDir('./results')

CycleGan增强 2个项目

GitHub - ioannispol/UnderWaterGAN: CycleGAN model to generate images with underwater features

 

GitHub - darkmatter18/Underwater-image-enhancement: A Deep Learning CycleGAN Based application, that can enhance the underwater images.

5年前: 

GitHub - aitorzip/PyTorch-CycleGAN: A clean and readable Pytorch implementation of CycleGAN

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

图像增强 cnn 的相关文章

随机推荐

  • Linux下在非root权限下修改gcc版本(亲测可用)

    一 前言 最近在安装 香港中文大学 商汤科技联合实验室开源的基于 PyTorch 的检测库 mmdetection时候发现gcc版本需要在4 9以上 但是考虑到实验室服务器集群上gcc的版本还是比较旧的 作为一个非root用户又没有操作权限
  • STM32移植到GD32(以32的工程为模板简单三步完成移植)

    STM32移植到GD32 一 移植说明 最近有个项目想用GD替代原有的STM32 因为GD的成本更低 然后我就找了一些GD的资料 发现目前网上已有的一些资料都比较老 比如ST移植到GD的攻略 很多都停留在GD刚推广不久的过渡时期 目前已经不
  • 数据结构之排序:快速排序

    快速排序 Quick Sort 由 C A Hoare 在1962年提出 是冒泡排序的一种改进 采用了分治策略 将原问题划分成若干个规模更小但与原问题相似的子问题 然后递归方法解决 合并问题的解 基本思想 通过一趟排序将序列分割成独立的两个
  • C++ & QT 琐碎知识点

    此文仅记录C 和QT 学习过程中一些琐碎知识点 shadow build 是将源码路径和构建路径分开 主要将makefile和其它生成的文件分开 保证源码文件的清洁 qmke和cmake都有采用 pro user 用于记录打开工程的路径 所
  • SpringCloud Stream消息驱动

    目录 一 SpringCloud Stream概述 二 Binder 三 Consumer Groups 针对消费者 四 Publish Subscribe 介绍一下yml配置的含义 五 消息分组 六 消息分区 1 生产者方配置 2 消费者
  • osgEarth的Rex引擎原理分析(三十五)osgEarth地球椭球体ellipsoid 大地基准面datum 地图投影Projection详解

    目标 二十九 中的问题83 地球椭球体的中心为地心 形状为椭球体 大地基准面是适应某一区域的椭球体 球体中心不一定在地心 地图投影是球面和平面映射关系的方法 Horizontal Datum A datum is a reference p
  • 学Transformer前,你需要了解的Attention机制(基于注意力机制的Seq2seq可视化神经机器翻译模型)

    在我们开始学习transformer之前 应该了先解下什么是attention注意力机制 相关内容获取 欢迎关注公众号 AI技术星球 发送 222 序列到序列 Sequence to sequence 模型已经在机器翻译 文本摘要和图像字幕
  • ubuntu18.04 配置nfs服务

    1 安装nfs服务器软件 sudo apt install nfs kernel server 2 修改配置文件 添加nfs server上用于共享的目录 并设置允许访问该目录的客户机IP 及其读写权限 sudo vim etc expor
  • YOLOV5 和 Yolov5s各个版本的 发展史、论文、各个版本代码资源分享合集 !!!

    点击上方 码农的后花园 选择 星标 公众号 精选文章 第一时间送达 2020年2月份YOLO之父Joseph Redmon宣布退出计算机视觉的研究的时候 很多人都以为目标检测神器YOLO系列就此终结 没想到的是 2020年4月份曾经参与YO
  • 后台获取数据库时间出现的格式问题记录

    问题描述 要从数据库获取时间类型然后传给前台页面 数据库中的时间格式是yyyy MM dd HH mm ss 如图所示 但是获取出来时格式就变成了Mon Dec 13 10 04 16 CST 2021这种 解决 可以在前端或者后端解决 后
  • MYSQL数据库和表

    一 安装MYSQL数据库时生成系统使用的数据库 1 显示数据库 2 创建数据库 3 选择数据库 mysql gt use stusys Database changed 4 修改数据库 mysql gt alter database stu
  • win32应用程序_不是有效的win32应用程序怎么解决

    在日常办公中经常用到电脑 有许多使用技巧 本次给大家介绍不是有效的win32应用程序怎么办 快来看看吧 方法一 不是有效的win32应用程序表示这个应用程序和系统不兼容 用户可以在计算机属性页面查看系统是32位还是64位 之后下载相对应的应
  • 【源码】贝叶斯变化点检测与时间序列分解

    BEAST 突变 季节性和趋势的贝叶斯估计器 是一种快速 通用的贝叶斯模型平均算法 用于将时间序列或1D序列数据分解为单个组件 例如突变 趋势和周期 季节性变化 如Zhao等人 2019 所述 BEAST可用于变化点检测 即断点或结构中断
  • mysql进阶1——proxysql中间件

    文章目录 一 基本了解 二 安装部署 三 proxysql管理配置 3 1 内置库 3 1 1 main库表 3 1 2 stats库表 3 1 3 monitor库 3 2 常用管理变量 3 2 1 添加管理用户 3 2 2 添加普通用户
  • WSL无法访问网络的解决办法

    今天在用WSL的时候突然网络抽风 域名解析出了问题 apt update都用不了 网上查了很多方法 什么vEthernet的IP啊 ifconfigip啊 ip route add default啥的 都不管用 最后还是看了一下 etc r
  • 多益网络提前批前端面试(凉)

    题外话 面试时间是晚上7点多 多益还是加班严重啊 这点哈哈哈哈 下面正文 自我介绍 问项目 问看过的书籍 这里就是挖坑了 尽量找自己会的说 XHR HTTP1和HTTP2的区别 隐藏一个HTML标签 v for 为什么不能用index做ke
  • QT基础部件学习笔记

    目录 一 QT程序开发流程 二 QT基础部件分类 1 按钮类 普通 工具 单选 复选 命令连接 编辑 编辑 2 布局类 水平 垂直 网格 两列 该类的实例具体与其他类同时使用 编辑 3 输出类 标签 文本浏览器 日历 七段数码管 进度条 4
  • 反编译解析数组为什么可以使用foreach

    反编译解析数组为什么可以使用foreach 一 说明 二 集合使用foreach 三 数组使用foreach 四 数组使用for 五 javap反编译程序 5 1 TestCollection结果 5 2 TestArray结果 5 3 T
  • 阿里云mysql gtid_阿里云RDS mysql报错:Statement violates GTID consistency

    近日有用户反馈使用RDS mysql8 0时 在执行语句 create table select时报错了 主要错误是 Statement violates GTID consistency 字面理解是语句违反GTID一致性 报错截图 Sta
  • 图像增强 cnn

    目录 实时图像增强 基于 间距自适应查找表 的方法 CVPR 2022 Image Adaptive 3DLUT 水下图像增强UWCNN wtf 直方图均衡化 CycleGan增强 2个项目 实时图像增强 基于 间距自适应查找表 的方法 C