解决Keras TensorFlow 混编中 trainable=False设置无效问题

这是最近碰到一个问题,先描述下问题:

首先我有一个训练好的模型(例如vgg16),我要对这个模型进行一些改变,例如添加一层全连接层,用于种种原因,我只能用TensorFlow来进行模型优化,tf的优化器,默认情况下对所有tf.trainable_variables()进行权值更新,问题就出在这,明明将vgg16的模型设置为trainable=False,但是tf的优化器仍然对vgg16做权值更新

以上就是问题描述,经过谷歌百度等等,终于找到了解决办法,下面我们一点一点的来复原整个问题。

trainable=False 无效

首先,我们导入训练好的模型vgg16,对其设置成trainable=False

from keras.applications import VGG16
import tensorflow as tf
from keras import layers
# 导入模型
base_mode = VGG16(include_top=False)
# 查看可训练的变量
tf.trainable_variables()
[<tf.Variable 'block1_conv1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block1_conv2/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv2/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block2_conv1/kernel:0' shape=(3, 3, 64, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block2_conv2/kernel:0' shape=(3, 3, 128, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv2/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block3_conv1/kernel:0' shape=(3, 3, 128, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv2/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv2/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv3/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv3/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block4_conv1/kernel:0' shape=(3, 3, 256, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv2/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv2/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv3/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv3/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv2/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv2/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv3/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv3/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block1_conv1_1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv1_1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block1_conv2_1/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv2_1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block2_conv1_1/kernel:0' shape=(3, 3, 64, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv1_1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block2_conv2_1/kernel:0' shape=(3, 3, 128, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv2_1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block3_conv1_1/kernel:0' shape=(3, 3, 128, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv1_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv2_1/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv2_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv3_1/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv3_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block4_conv1_1/kernel:0' shape=(3, 3, 256, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv1_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv2_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv2_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv3_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv1_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv1_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv2_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv2_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv3_1/bias:0' shape=(512,) dtype=float32_ref>]
# 设置 trainable=False
# base_mode.trainable = False似乎也是可以的
for layer in base_mode.layers:
  layer.trainable = False

设置好trainable=False后,再次查看可训练的变量,发现并没有变化,也就是说设置无效

# 再次查看可训练的变量
tf.trainable_variables()

[<tf.Variable 'block1_conv1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block1_conv2/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv2/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block2_conv1/kernel:0' shape=(3, 3, 64, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block2_conv2/kernel:0' shape=(3, 3, 128, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv2/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block3_conv1/kernel:0' shape=(3, 3, 128, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv2/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv2/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv3/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv3/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block4_conv1/kernel:0' shape=(3, 3, 256, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv2/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv2/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv3/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv3/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv2/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv2/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv3/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv3/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block1_conv1_1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv1_1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block1_conv2_1/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv2_1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block2_conv1_1/kernel:0' shape=(3, 3, 64, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv1_1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block2_conv2_1/kernel:0' shape=(3, 3, 128, 128) dtype=float32_ref>,
 <tf.Variable 'block2_conv2_1/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'block3_conv1_1/kernel:0' shape=(3, 3, 128, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv1_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv2_1/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv2_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block3_conv3_1/kernel:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Variable 'block3_conv3_1/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'block4_conv1_1/kernel:0' shape=(3, 3, 256, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv1_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv2_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv2_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block4_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block4_conv3_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv1_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv1_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv2_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv2_1/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'block5_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv3_1/bias:0' shape=(512,) dtype=float32_ref>]

解决的办法

解决的办法就是在导入模型的时候建立一个variable_scope,将需要训练的变量放在另一个variable_scope,然后通过tf.get_collection获取需要训练的变量,最后通过tf的优化器中var_list指定需要训练的变量

from keras import models
with tf.variable_scope('base_model'):
  base_model = VGG16(include_top=False, input_shape=(224,224,3))
with tf.variable_scope('xxx'):
  model = models.Sequential()
  model.add(base_model)
  model.add(layers.Flatten())
  model.add(layers.Dense(10))
# 获取需要训练的变量
trainable_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'xxx')
trainable_var

[<tf.Variable 'xxx_2/dense_1/kernel:0' shape=(25088, 10) dtype=float32_ref>,
<tf.Variable 'xxx_2/dense_1/bias:0' shape=(10,) dtype=float32_ref>]

# 定义tf优化器进行训练,这里假设有一个loss
loss = model.output / 2; # 随便定义的,方便演示
train_step = tf.train.AdamOptimizer().minimize(loss, var_list=trainable_var)

总结

在keras与TensorFlow混编中,keras中设置trainable=False对于TensorFlow而言并不起作用

解决的办法就是通过variable_scope对变量进行区分,在通过tf.get_collection来获取需要训练的变量,最后通过tf优化器中var_list指定训练

以上这篇解决Keras TensorFlow 混编中 trainable=False设置无效问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

时间: 2020-06-28

ASP.NET在MVC中MaxLength特性设置无效的解决方法

本文实例讲述了ASP.NET在MVC中MaxLength特性设置无效的解决方法.分享给大家供大家参考.具体分析如下: 一.问题: 在ASP.NET MVC项目中,给某个Model打上了MaxLength特性如下: 复制代码 代码如下: public class SomeClass {     [MaxLength(16, ErrorMessage = "最大长度16")]     public string SomeProperty{get;set;} } 但在其对应的表单元素中并没有

Swift和Objective-C 混编注意事项

Swift和Objective-C 混编注意事项整理: 前言 Swift已推出数年,与Objective-C相比Swift的语言机制及使用简易程度上更接地气,大大降低了iOS入门门槛.当然这对新入行的童鞋没来讲,的确算是福音,但对于整个iOS编程从业者来讲,真真是,曾几何时"高大上",转瞬之间"矮矬穷".再加上培训班横行,批量批发之下,iOS再也看不到当年的辉煌.iOS10推出后,紧跟着Xcode8也推送了更新,细心者会发现,Xcode8下iOS版本最低适配已变为i

解决Keras 与 Tensorflow 版本之间的兼容性问题

在利用Keras进行实验的时候,后端为Tensorflow,出现了以下问题: 1. 服务器端激活Anaconda环境跑程序时,实验结果很差. 环境:tensorflow 1.4.0,keras 2.1.5 2. 服务器端未激活Anaconda环境跑程序时,实验结果回到正常值. 环境:tensorflow 1.7.0,keras 2.0.8 3. 自己PC端跑相同程序时,实验结果回到正常值. 环境:tensorflow 1.6.0,keras 2.1.5 怀疑实验结果的异常性是由于Keras和Te

浅谈python和C语言混编的几种方式(推荐)

Python这些年风头一直很盛,占据了很多领域的位置,Web.大数据.人工智能.运维均有它的身影,甚至图形界面做的也很顺,乃至full-stack这个词语刚出来的时候,似乎就是为了描述它. Python虽有GIL的问题导致多线程无法充分利用多核,但后来的multiprocess可以从多进程的角度来利用多核,甚至affinity可以绑定具体的CPU核,这个问题也算得到解决.虽基本为全栈语言,但有的时候为了效率,可能还是会去考虑和C语言混编.混编是计算机里一个不可回避的话题,涉及的东西很多,技术.架

解决vue单页面应用中动态修改title问题

详细信息查看:vue-weachat-title 解决问题: 1.Vuejs 单页应用在iOS系统下部分APP的webview中 标题不能通过 document.title = xxx 的方式修改 该插件只为解决该问题而生(兼容安卓) 2.在vue单页面中,通过浏览器分享到QQ.微信等应用中的链接,只有一个首页标题和默认icon图片 已测试:APP 微信 QQ 支付宝 淘宝 安装 npm install vue-wechat-title --save 用法 1.在main.js中引入 impor

keras tensorflow 实现在python下多进程运行

如下所示: from multiprocessing import Process import os def training_function(...): import keras # 此处需要在子进程中 ... if __name__ == '__main__': p = Process(target=training_function, args=(...,)) p.start() 原文地址:https://stackoverflow.com/questions/42504669/ker

解决webview内的iframe中的事件不可用的问题

最近做Android的Webview开发,使用iframe中嵌入了很多页面,嵌入的页面却不可用,最后发现是 webView.setWebViewClient(new WebViewClient() { @Override public boolean shouldOverrideUrlLoading(WebView view, String url) { return super.shouldOverrideUrlLoading(view, url); } 不要覆写 shouldOverride

完美解决phpdoc导出文档中@package的warning及Error的错误

今天在编写PHPDoc的导出文档的时候发现一个很郁闷的错误,虽然这个warning不是什么重要错误,但是看着总是很不爽的.于是就去网上找了很多相关的资料,可是郁闷的是不知道是我用的PHPDoc版本太新(1.4的版本),还是说很多人都没遇到这个问题,反正就是没有相关的这个资料找到,只是找到了一些从PHPDocumentor官方网倒出来的关于@package的使用注意事项,然后就只能一条一条检查,看了一个版本又一个版本,总算是被我解决了. 而且发现该方案可以解决@package之类相关的错误提示:

IE6浏览器中window.location.href无效的解决方法

本文实例讲述了IE6浏览器中window.location.href无效的解决方法.分享给大家供大家参考.具体方法如下: window.location.href是js中跳转功能,很多人在ie6中都会发现window.location.href不能跳转了,下面我给大家来介绍一下其原因与解决方法. 问题代码如下: 复制代码 代码如下: <a href="javascript:void(0);" onclick="javascript:test();">点击

Android 开发 使用WebUploader解决安卓微信浏览器上传图片中遇到的bug

先给大家分析下微信浏览器上传图片bug的原因 微信在新版本中采用的是自己的X5内核浏览器,而在较老的版本中还有可能是安卓的原生浏览器.具体的环境我也不太了解,但是经过实际多台安卓机型的测试,我采取的方案可以基本确保在安卓机中微信浏览器的成功上传.苹果机型没问题,因为微信的ios客户端使用的是Safari的内核,没有各种坑,且效果最好. 这里给出一个 WebUploader 官方关于移动端适配的 issues 链接.里面提供的方法确实有效,但就是解决的方案并没有很清楚的展示出来,从该issues中