keras 自定义loss层+接受输入实例

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642
X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:keras中自定义 loss损失函数和修改不同样本的loss权重(样本权重、类别权重)

首先辨析一下概念:

1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的

2. metric只是作为评价网络表现的一种“指标”, 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程

一、keras自定义损失函数

在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如:

# 方式一
def vae_loss(x, x_decoded_mean):
 xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
 kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
 return xent_loss + kl_loss

vae.compile(optimizer='rmsprop', loss=vae_loss)

或者通过自定义一个keras的层(layer)来达到目的, 作为model的最后一层,最后令model.compile中的loss=None:

# 方式二
# Custom loss layer
class CustomVariationalLayer(Layer):

 def __init__(self, **kwargs):
  self.is_placeholder = True
  super(CustomVariationalLayer, self).__init__(**kwargs)
 def vae_loss(self, x, x_decoded_mean_squash):

  x = K.flatten(x)
  x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)
  xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
  kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
  return K.mean(xent_loss + kl_loss)

 def call(self, inputs):

  x = inputs[0]
  x_decoded_mean_squash = inputs[1]
  loss = self.vae_loss(x, x_decoded_mean_squash)
  self.add_loss(loss, inputs=inputs)
  # We don't use this output.
  return x

y = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)

在keras中自定义metric非常简单,需要用y_pred和y_true作为自定义metric函数的输入参数 点击查看metric的设置

注意事项:

1. keras中定义loss,返回的是batch_size长度的tensor, 而不是像tensorflow中那样是一个scalar

2. 为了能够将自定义的loss保存到model, 以及可以之后能够顺利load model, 需要把自定义的loss拷贝到keras.losses.py 源代码文件下,否则运行时找不到相关信息,keras会报错

有时需要不同的sample的loss施加不同的权重,这时需要用到sample_weight,例如

discriminator.train_on_batch(imgs, [valid, labels], class_weight=class_weights)

二、keras中的样本权重

# Import
import numpy as np
from sklearn.utils import class_weight

# Example model
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))

# Use binary crossentropy loss
model.compile(optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['accuracy'])

# Calculate the weights for each class so that we can balance the data
weights = class_weight.compute_class_weight('balanced',
           np.unique(y_train),
           y_train)

# Add the class weights to the training
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)

Note that the output of the class_weight.compute_class_weight() is an numpy array like this: [2.57569845 0.68250928].

以上这篇keras 自定义loss层+接受输入实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

时间: 2020-06-26

Keras之自定义损失(loss)函数用法说明

在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式: def my_loss(y_true, y_pred): # y_true: True labels. TensorFlow/Theano tensor # y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true . . . return scalar #返回一个标量

使用Keras加载含有自定义层或函数的模型操作

当我们导入的模型含有自定义层或者自定义函数时,需要使用custom_objects来指定目标层或目标函数. 例如: 我的一个模型含有自定义层"SincConv1D",需要使用下面的代码导入: from keras.models import load_model model = load_model('model.h5', custom_objects={'SincConv1D': SincConv1D}) 如果不加custom_objects指定目标层Layer,则会出现以下报错:

keras 自定义loss model.add_loss的使用详解

一点见解,不断学习,欢迎指正 1.自定义loss层作为网络一层加进model,同时该loss的输出作为网络优化的目标函数 from keras.models import Model import keras.layers as KL import keras.backend as K import numpy as np from keras.utils.vis_utils import plot_model x_train=np.random.normal(1,1,(100,784)) x_

keras 自定义loss损失函数,sample在loss上的加权和metric详解

首先辨析一下概念: 1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的 2. metric只是作为评价网络表现的一种"指标", 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程 在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如: # 方式一 def vae_loss(x, x_decoded_mean): xent_loss = objectives.binary_

JavaWeb实现文件上传与下载实例详解

在Web应用程序开发中,文件上传与下载功能是非常常用的功能,下面通过本文给大家介绍JavaWeb实现文件上传与下载实例详解. 对于文件上传,浏览器在上传的过程中是将文件以流的形式提交到服务器端的,如果直接使用Servlet获取上传文件的输入流然后再解析里面的请求参数是比较麻烦,所以一般选择采用apache的开源工具common-fileupload这个文件上传组件.这个common-fileupload上传组件的jar包可以去apache官网上面下载,common-fileupload是依赖于c

Apache 文件上传与文件下载案例详解

写一个Apache文件上传与文件下载的案例:以供今后学习 web.xml配置如下: <span style="font-family:SimSun;font-size:14px;"><?xml version="1.0" encoding="UTF-8"?> <web-app xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns=&

FasfDFS整合Java实现文件上传下载功能实例详解

在上篇文章给大家介绍了FastDFS安装和配置整合Nginx-1.13.3的方法,大家可以点击查看下. 今天使用Java代码实现文件的上传和下载.对此作者提供了Java API支持,下载fastdfs-client-java将源码添加到项目中.或者在Maven项目pom.xml文件中添加依赖 <dependency> <groupId>org.csource</groupId> <artifactId>fastdfs-client-java</arti

Socket+JDBC+IO实现Java文件上传下载器DEMO详解

该demo实现的功能有: 1.用户注册: 注册时输入两次密码,若两次输入不一致,则注册失败,需要重新输入.若用户名被注册过,则提示用户重新输入用户名: 2.用户登录: 需要验证数据库中是否有对应的用户名和密码,若密码输错三次,则终止用户的登录操作: 3.文件上传: 从本地上传文件到文件数据库中 4.文件下载: 从数据库中下载文件到本地 5.文件更新: 根据id可更新数据库中的文件名 6.文件删除: 根据id删除数据库中某一个文件 7.看数据库所有文件; 8.查看文件(根据用户名); 9.查看文件

MySql在Mac上的安装与配置详解

一.下载安装 官网下载社区版dmg安装文件: https://dev.mysql.com/downloads/mysql/ 1.执行安装文件,按步骤完成安装. 2.安装完成后终端输入: mysql --version; ----显示版本号说明正常,若显示command not found,在终端输入如下,"/usr/local/mysql/bin/mysql"为mysql默认安装路径: $ cd /usr/local/bin/ $ sudo ln -fs /usr/local/mysq

微信小程序 上传头像的实例详解

微信小程序 上传头像的实例详解 最近在做微信小程序上传头像和上传照片功能就随手写一下代码: 上传头像html: <view class="edit-list"> <text class="list-name list-first">头像</text> <view class="edit-righr-bar"> <image class="head-portrait" src

基于BootStrap Metronic开发框架经验小结【五】Bootstrap File Input文件上传插件的用法详解

Bootstrap文件上传插件File Input是一个不错的文件上传控件,但是搜索使用到的案例不多,使用的时候,也是一步一个脚印一样摸着石头过河,这个控件在界面呈现上,叫我之前使用过的Uploadify 好看一些,功能也强大些,本文主要基于我自己的框架代码案例,介绍其中文件上传插件File Input的使用. 1.文件上传插件File Input介绍 这个插件主页地址是:http://plugins.krajee.com/file-input,可以从这里看到很多Demo的代码展示:http:/

spring boot 图片上传与显示功能实例详解

首先描述一下问题,spring boot 使用的是内嵌的tomcat, 所以不清楚文件上传到哪里去了, 而且spring boot 把静态的文件全部在启动的时候都会加载到classpath的目录下的,所以上传的文件不知相对于应用目录在哪,也不知怎么写访问路径合适,对于新手的自己真的一头雾水. 后面想起了官方的例子,没想到一开始被自己找到的官方例子,后面太依赖百度谷歌了,结果发现只有官方的例子能帮上忙,而且帮上大忙,直接上密码的代码 package hello; import static org

matplotlib在python上绘制3D散点图实例详解

大家可以先参考官方演示文档: 效果图: ''' ============== 3D scatterplot ============== Demonstration of a basic scatterplot in 3D. ''' from mpl_toolkits.mplot3d import Axes3D import matplotlib.pyplot as plt import numpy as np def randrange(n, vmin, vmax): ''' Helper f

使用 Python 在京东上抢口罩的思路详解

全国抗"疫"这么久终于见到曙光,在家待了将近一个月,现在终于可以去上班了,可是却发现出门必备的口罩却一直买不到.最近看到京东上每天都会有口罩的秒杀活动,试了几次却怎么也抢不到,到了抢购的时间,浏览器的页面根本就刷新不出来,等刷出来秒杀也结束了.现在每天只放出一万个,却有几百万人在抢,很想知道别人是怎么抢到的,于是就在网上找了大神公开出来的抢购代码.看了下代码并不复杂,现在我们就报着学习的态度一起看看. 使用模块 首先打开项目中 requirements.txt 文件,看下它都需要哪些模