如何对比pytorch的ReLU和自定义的class GuidedBackpropReLU
发表于:2025-12-03 作者:千家信息网编辑
千家信息网最后更新 2025年12月03日,这篇文章将为大家详细讲解有关如何对比pytorch的ReLU和自定义的class GuidedBackpropReLU,文章内容质量较高,因此小编分享给大家做个参考,希望大家阅读完这篇文章后对相关知识
千家信息网最后更新 2025年12月03日如何对比pytorch的ReLU和自定义的class GuidedBackpropReLU
这篇文章将为大家详细讲解有关如何对比pytorch的ReLU和自定义的class GuidedBackpropReLU,文章内容质量较高,因此小编分享给大家做个参考,希望大家阅读完这篇文章后对相关知识有一定的了解。
总结说明:GuidedBackpropReLU和ReLU的区别很明显,在反向传播时候,仅传播从上一层接受到的正数梯度,将负数梯度直接置零.而ReLU则全部接受上一层的梯度,不论该梯度值是正数还是负数.
实验代码展示(实验中在第58和59行,将coefficient设置成-1和+1会出现不同的效果):
import torchfrom torch.autograd import Functionclass GuidedBackpropReLU(Function):'''特殊的ReLU,区别在于反向传播时候只考虑大于零的输入和大于零的梯度''' # @staticmethod# def forward(ctx, input_img): # torch.Size([1, 64, 112, 112])# positive_mask = (input_img > 0).type_as(input_img) # torch.Size([1, 64, 112, 112])# # output = torch.addcmul(torch.zeros(input_img.size()).type_as(input_img), input_img, positive_mask)# output = input_img * positive_mask # 这行代码和上一行的功能相同# ctx.save_for_backward(input_img, output)# return output # torch.Size([1, 64, 112, 112])# 上部分定义的函数功能和以下定义的函数一致@staticmethoddef forward(ctx, input_img): # torch.Size([1, 64, 112, 112])output = torch.clamp(input_img, min=0.0)# print('函数中的输入张量requires_grad',input_img.requires_grad)ctx.save_for_backward(input_img, output)return output # torch.Size([1, 64, 112, 112])@staticmethoddef backward(ctx, grad_output): # torch.Size([1, 2048, 7, 7])input_img, output = ctx.saved_tensors # torch.Size([1, 2048, 7, 7]) torch.Size([1, 2048, 7, 7])# grad_input = None # 这行代码没作用positive_mask_1 = (input_img > 0).type_as(grad_output) # torch.Size([1, 2048, 7, 7]) 输入的特征大于零positive_mask_2 = (grad_output > 0).type_as(grad_output) # torch.Size([1, 2048, 7, 7]) 梯度大于零# grad_input = torch.addcmul(# torch.zeros(input_img.size()).type_as(input_img),# torch.addcmul(# torch.zeros(input_img.size()).type_as(input_img), # grad_output,# positive_mask_1# ), # positive_mask_2# )grad_input = grad_output * positive_mask_1 * positive_mask_2 # 这行代码的作用和上一行代码相同return grad_inputtorch.manual_seed(seed=20200910)size = (3,5)input_data_1 = input = torch.randn(*size, requires_grad=True)torch.manual_seed(seed=20200910)input_data_2 = input = torch.randn(*size, requires_grad=True)torch.manual_seed(seed=20200910)input_data_3 = input = torch.randn(*size, requires_grad=True)print('这三个输入数据的维度分别是:', input_data_1.shape, input_data_2.shape, input_data_3.shape)# print(input_data_1)# print(input_data_2)# print(input_data_3)coefficient = -1.0# coefficient = 1.0loss_1 = coefficient * torch.sum(torch.nn.ReLU()(input_data_1))loss_2 = coefficient * torch.sum(torch.nn.functional.relu(input_data_2))loss_3 = coefficient * torch.sum(GuidedBackpropReLU.apply(input_data_3))loss_1.backward()loss_2.backward()loss_3.backward()print(loss_1, loss_2, loss_3)print(loss_1.item(), loss_2.item(), loss_3.item())print('三个损失值是否相等', loss_1.item() == loss_2.item() == loss_3.item())print('简略打印三个梯度信息...')print(input_data_1.grad)print(input_data_2.grad)print(input_data_3.grad)print('这三个梯度的维度分别是:', input_data_1.grad.shape, input_data_2.grad.shape, input_data_3.grad.shape)print('检查这三个梯度是否两两相等...')print(torch.equal(input_data_1.grad, input_data_2.grad))print(torch.equal(input_data_1.grad, input_data_3.grad))print(torch.equal(input_data_2.grad, input_data_3.grad))控制台输出(#58 coefficient = -1.0):
Windows PowerShell版权所有 (C) Microsoft Corporation。保留所有权利。尝试新的跨平台 PowerShell https://aka.ms/pscore6加载个人及系统配置文件用了 842 毫秒。(base) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> conda activate ssd4pytorch2_2_0(ssd4pytorch2_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> & 'D:\Anaconda3\envs\ssd4pytorch2_2_0\python.exe' 'c:\Users\chenxuqi\.vscode\extensions\ms-python.python-2021.1.502429796\pythonFiles\lib\python\debugpy\launcher' '62123' '--' 'c:\Users\chenxuqi\Desktop\News4cxq\test4cxq\testReLU.py'这三个输入数据的维度分别是: torch.Size([3, 5]) torch.Size([3, 5]) torch.Size([3, 5])tensor(-7.1553, grad_fn=) tensor(-7.1553, grad_fn= ) tensor(-7.1553, grad_fn= ) -7.155285835266113 -7.155285835266113 -7.155285835266113三个损失值是否相等 True简略打印三个梯度信息...tensor([[-1., 0., -1., 0., 0.], [-1., -1., 0., -1., -1.], [-1., -1., 0., 0., 0.]])tensor([[-1., 0., -1., 0., 0.], [-1., -1., 0., -1., -1.], [-1., -1., 0., 0., 0.]])tensor([[-0., -0., -0., -0., -0.], [-0., -0., -0., -0., -0.], [-0., -0., -0., -0., -0.]])这三个梯度的维度分别是: torch.Size([3, 5]) torch.Size([3, 5]) torch.Size([3, 5])检查这三个梯度是否两两相等...TrueFalseFalse(ssd4pytorch2_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq>
控制台输出(#59 coefficient = 1.0):
Windows PowerShell版权所有 (C) Microsoft Corporation。保留所有权利。尝试新的跨平台 PowerShell https://aka.ms/pscore6加载个人及系统配置文件用了 846 毫秒。(base) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> & 'D:\Anaconda3\envs\ssd4pytorch2_2_0\python.exe' 'c:\Users\chenxuqi\.vscode\extensions\ms-python.python-2021.1.502429796\pythonFiles\lib\python\debugpy\launcher' '62135' '--' 'c:\Users\chenxuqi\Desktop\News4cxq\test4cxq\testReLU.py'这三个输入数据的维度分别是: torch.Size([3, 5]) torch.Size([3, 5]) torch.Size([3, 5])tensor(7.1553, grad_fn=) tensor(7.1553, grad_fn= ) tensor(7.1553, grad_fn= )7.155285835266113 7.155285835266113 7.155285835266113三个损失值是否相等 True简略打印三个梯度信息...tensor([[1., 0., 1., 0., 0.], [1., 1., 0., 1., 1.], [1., 1., 0., 0., 0.]])tensor([[1., 0., 1., 0., 0.], [1., 1., 0., 1., 1.], [1., 1., 0., 0., 0.]])tensor([[1., 0., 1., 0., 0.], [1., 1., 0., 1., 1.], [1., 1., 0., 0., 0.]])这三个梯度的维度分别是: torch.Size([3, 5]) torch.Size([3, 5]) torch.Size([3, 5])检查这三个梯度是否两两相等...TrueTrueTrue(base) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> conda activate ssd4pytorch2_2_0(ssd4pytorch2_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq>
关于如何对比pytorch的ReLU和自定义的class GuidedBackpropReLU就分享到这里了,希望以上内容可以对大家有一定的帮助,可以学到更多知识。如果觉得文章不错,可以把它分享出去让更多的人看到。
三个
梯度
维度
输入
代码
简略
信息
函数
损失
数据
传播
检查
相同
一行
个人
作用
内容
功能
控制台
文件
数据库的安全要保护哪些东西
数据库安全各自的含义是什么
生产安全数据库录入
数据库的安全性及管理
数据库安全策略包含哪些
海淀数据库安全审计系统
建立农村房屋安全信息数据库
易用的数据库客户端支持安全管理
连接数据库失败ssl安全错误
数据库的锁怎样保障安全
怎样设置手机服务器网速快
mc手机版侧搭服务器推荐
通用服务器如何挑选
软件开发项目监理需求
网络安全和信息化中心单位好吗
数据库安全网关功能
建材北京网络技术有限公司
松江区工商软件开发服务价格
服务器计算机管理
上海凤循网络技术有限公司
网络安全工作的目标是
网约车平台数据库接入
阿里云服务器配置
数据库的安全事件
全场景网络安全防护策略
奇峰网络技术
梦幻限时服务器
大芒果服务器
ibm软件开发会问哪些方面
突出安全保障提升网络安全防护力
软件开发方法所对应的模型
乐高与软件开发
软件开发sop模板
软件开发申请专利
数据库技术选择
企业邮箱海外服务器
论文数据库在哪找
db数据库查看存储过程
丰城网络安全专业
ibm软件开发会问哪些方面