使用inception模型进行迁移学习

2023-05-16

迁移学习相关知识可以参考:庄福振, 罗平, 何清,等. 迁移学习研究进展[J]. 软件学报, 2015, 26(1):26-39.
本文涉及内容主要有:
1.如何使用现有模型进行分类
2.如何得到样本的transfer-values
3.如何迁移学习
4.如何评价与验证
主要参考:https://github.com/Hvass-Labs/TensorFlow-Tutorials
其它参考:http://www.cnblogs.com/rgvb178/p/6052541.html
stackoverFlow-解决所有的问题
http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/
1.如何使用现有模型进行分类
模型:inception
可以直接参考上面给的github #8的例子。
注意,这里的模型restore是老方法,即现在tensorflow已经升级了模型的保存。模型使用链接已经给出。

2.如何得到样本的transfer-values
解决这个问题需要对tensorflow的神经网络具有一定的知识 ,我是个小白,所有中间绕了一些弯。仍然可以使用github #8例子,需要注意的是:
2.1 restore模型后,需要恢复使用的tensor,方法是:

self.graph.get_tensor_by_name(self.tensor_name_softmax)

其中self.tensor_name_softmax为”softmax:0”。注意这个名称的规范,softmax为tensorflow的名称,即name属性。必须使用这个形式才能正确获得tensor
2.2 需要知道神经网络模型是如何定义
我们的目标是得到transfer-value,这个值按照参考中的意思是全连接层的输入,对于inception模型是最后pool层的输出,所以,例子中的做法是:

tensor_name_transfer_layer = "pool_3:0" #对应第三个pool层
self.graph.get_tensor_by_name(self.tensor_name_softmax)

获得这个tensor后,就可以使用restore的模型得到transfer-values了,关键代码段为:

    def transfer_values(self, image_path=None, image=None):
        """
        Calculate the transfer-values for the given image.
        These are the values of the last layer of the Inception model before
        the softmax-layer, when inputting the image to the Inception model.

        The transfer-values allow us to use the Inception model in so-called
        Transfer Learning for other data-sets and different classifications.

        It may take several hours or more to calculate the transfer-values
        for all images in a data-set. It is therefore useful to cache the
        results using the function transfer_values_cache() below.

        :param image_path:
            The input image is a jpeg-file with this file-path.

        :param image:
            The input image is a 3-dim array which is already decoded.
            The pixels MUST be values between 0 and 255 (float or int).

        :return:
            The transfer-values for those images.
        """

        # Create a feed-dict for the TensorFlow graph with the input image.
        feed_dict = self._create_feed_dict(image_path=image_path, image=image)

        # Use TensorFlow to run the graph for the Inception model.
        # This calculates the values for the last layer of the Inception model
        # prior to the softmax-classification, which we call transfer-values.
        transfer_values = self.session.run(self.transfer_layer, feed_dict=feed_dict)

        # Reduce to a 1-dim array.
        transfer_values = np.squeeze(transfer_values)

        return transfer_values

2.3 使用cifar10数据集作为目标领域任务的数据集
通过Inception模型得到cifar10中的所有图片的transfer-values

3.如何迁移学习
到这一步,我们的任务已经很明了了。
目标领域任务:识别cifar10中的10个类别图形
迁移学习类型:源领域与目标领域数据集不同,任务相同
方法:通过源领域模型得到目标领域中数据
学习可以描述 为:使用2步得到的transfer-values放到新的模型中训练。直到达到一定的准确率。
新的神经网络为:

        with self.graph.as_default():####默认图与自定义图的关系
            #全连接层
            dense = tf.reshape(self.x, [-1, self.w_d.get_shape().as_list()[0]])
            dense = tf.nn.relu(tf.add(tf.matmul(dense, self.w_d), self.b_d))

            #dropout
#            dense = tf.nn.dropout(dense, self.keep_prob)

            #简单的softmax层
            y_pred = tf.matmul(dense,self.w_out) + self.b_out

            cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y_pred, self.y_true))
            optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(cost,self.global_step)

            #output pred class
            y_pred_cls = tf.argmax(y_pred,1,name='output')
            #找到预测正确的标签
            correct_pred = tf.equal(y_pred_cls,tf.argmax(self.y_true,1))
            accuracy = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
            init = tf.global_variables_initializer()

            self.session.run(init)

可以看到,其实只有一个全连接层,输出使用简单的softmax方法。

4.如何评价与检验
相信到这一步已经可以成功的进行测试了。在我的笔记本上,先是获得transfer-values消耗了2个小时,在新的网络上训练很快就达到了目标准确率,并且与参考中结果一致。

吴恩达说过:未来属于迁移学习。通过这个 小小的实验,可以发现,通过知识迁移,不仅省去了大量的时间,而且可以达到一个很高的准确率。

下一篇,计划给大家详细展示下如何保存模型与加载模型,并使用模型分类。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用inception模型进行迁移学习 的相关文章

随机推荐