Activation Functions: tanh Introduction and C++/PyTorch Implementation

2021/7/29 9:05:58


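Many kinds of activation functions are used in deep neural networks; this post introduces tanh. Its formula, as listed on Wikipedia (https://en.wikipedia.org/wiki/Activation_function), is:

tanh(x) = (e^x - e^(-x)) / (e^x + e^(-x))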
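The derivative used throughout this post, 1 - y^2 with y = tanh(x), follows from the quotient rule. The numerator's derivative is the denominator and vice versa, which gives:

$$
\tanh(x) = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}, \qquad
\frac{d}{dx}\tanh(x) = \frac{(e^{x} + e^{-x})^{2} - (e^{x} - e^{-x})^{2}}{(e^{x} + e^{-x})^{2}} = 1 - \tanh^{2}(x)
$$

This is why the derivative can be computed directly from the forward output, without evaluating the exponentials again.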

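tanh, the hyperbolic tangent, addresses sigmoid's non-zero-centered output. Its values lie in (-1, 1) and are centered on zero (tanh(0) = 0), and it is also nonlinear. Like sigmoid, however, it does not fully solve the vanishing gradient problem. When |x| is large the function saturates near ±1 and its derivative approaches 0.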
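The properties above can be checked numerically. The following is a minimal sketch that uses only Python's standard math module. The helpers sigmoid, tanh_via_sigmoid, tanh_grad and sigmoid_grad are hypothetical names written for this illustration. The sketch relies on the identity tanh(x) = 2*sigmoid(2x) - 1, which shows that tanh is a rescaled, zero-centered sigmoid. It also compares the two gradients: tanh's peaks at 1 at x = 0 and sigmoid's at 0.25, but both shrink toward 0 as |x| grows.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic sigmoid, output in (0, 1); fine for the moderate inputs used here."""
    return 1.0 / (1.0 + math.exp(-x))


def tanh_via_sigmoid(x: float) -> float:
    """tanh written as a shifted and scaled sigmoid: 2 * sigmoid(2x) - 1."""
    return 2.0 * sigmoid(2.0 * x) - 1.0


def tanh_grad(x: float) -> float:
    """d/dx tanh(x) = 1 - tanh(x)^2."""
    t = math.tanh(x)
    return 1.0 - t * t


def sigmoid_grad(x: float) -> float:
    """d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))."""
    s = sigmoid(x)
    return s * (1.0 - s)


if __name__ == "__main__":
    for x in (0.0, 1.0, 3.0, 5.0):
        print(f"x={x:4.1f}  tanh={math.tanh(x):+.6f}  sigmoid={sigmoid(x):.6f}  "
              f"tanh'={tanh_grad(x):.6f}  sigmoid'={sigmoid_grad(x):.6f}")
```

At x = 5 the tanh gradient is already below 10^-3. This saturation is what drives the vanishing gradient problem mentioned above.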

The C++ implementation is as follows:

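#include <cmath>
#include <cstdio>
#include <vector>

namespace fbc {

// Element-wise tanh: dst[i] = tanh(src[i]).
// Evaluated as (1 - e^(-2|x|)) / (1 + e^(-2|x|)) with the sign of x restored;
// this equals (e^x - e^(-x)) / (e^x + e^(-x)) but cannot overflow, whereas the
// direct form yields inf/inf = NaN once |x| exceeds about 88.7 for float.
template<typename _Tp>
int activation_function_tanh(const _Tp* src, _Tp* dst, int length)
{
	for (int i = 0; i < length; ++i) {
		_Tp e = std::exp(_Tp(-2) * std::abs(src[i])); // always in (0, 1]
		_Tp t = (_Tp(1) - e) / (_Tp(1) + e);          // tanh(|x|)
		dst[i] = src[i] < _Tp(0) ? -t : t;
	}

	return 0;
}

// Derivative of tanh expressed through its OUTPUT: src must hold y = tanh(x),
// and dst[i] = 1 - y^2. src and dst may point to the same buffer.
template<typename _Tp>
int activation_function_tanh_derivative(const _Tp* src, _Tp* dst, int length)
{
	for (int i = 0; i < length; ++i) {
		dst[i] = (_Tp)1. - src[i] * src[i];
	}

	return 0;
}

} // namespace fbc

// fbc::print_matrix is a printing helper from the author's NN_Test repository (not shown here).
int test_activation_function()
{
	std::vector<float> src{ 1.1f, -2.2f, 3.3f, 0.4f, -0.5f, -1.6f };
	int length = static_cast<int>(src.size());
	std::vector<float> dst(length);

	fprintf(stderr, "source vector: \n");
	fbc::print_matrix(src);
	fprintf(stderr, "calculate activation function:\n");

	fprintf(stderr, "type: tanh result: \n");
	fbc::activation_function_tanh(src.data(), dst.data(), length);
	fbc::print_matrix(dst);
	fprintf(stderr, "type: tanh derivative result: \n");
	// the derivative takes the tanh output (dst) and is computed in place
	fbc::activation_function_tanh_derivative(dst.data(), dst.data(), length);
	fbc::print_matrix(dst);

	return 0;
}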
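There is one practical caveat with evaluating (e^x - e^(-x)) / (e^x + e^(-x)) literally. For float in C++, std::exp overflows to infinity once its argument exceeds about 88.7, and the ratio then becomes inf/inf = NaN. In Python, math.exp raises OverflowError instead, for arguments above roughly 709. Since tanh(-x) = -tanh(x), one can instead evaluate tanh(|x|) as (1 - e^(-2|x|)) / (1 + e^(-2|x|)) and then restore the sign. This form comes from multiplying the numerator and denominator by e^(-|x|), and its exponential always lies in (0, 1]. The following Python sketch compares the two forms against math.tanh. It uses only the standard library, and the helper names tanh_naive and tanh_stable are hypothetical.

```python
import math


def tanh_naive(x: float) -> float:
    """Direct formula; math.exp raises OverflowError for |x| above roughly 709."""
    ep, em = math.exp(x), math.exp(-x)
    return (ep - em) / (ep + em)


def tanh_stable(x: float) -> float:
    """Same function via exp(-2|x|), which stays in (0, 1] and never overflows."""
    e = math.exp(-2.0 * abs(x))
    t = (1.0 - e) / (1.0 + e)  # tanh(|x|)
    return -t if x < 0 else t  # tanh is odd, so restore the sign of x


if __name__ == "__main__":
    for x in (0.4, -2.2, 20.0, 1000.0):
        try:
            naive = tanh_naive(x)
        except OverflowError:
            naive = float("nan")  # marker: the direct formula could not be evaluated
        print(f"x={x:8.1f}  naive={naive:+.6f}  stable={tanh_stable(x):+.6f}  "
              f"math.tanh={math.tanh(x):+.6f}")
```

For everyday inputs both forms agree. Only the rewritten form survives extreme inputs, where it returns exactly ±1.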

The execution results are as follows:

The Python and PyTorch implementations are as follows:

import numpy as np
import torch

data = [1.1, -2.2, 3.3, 0.4, -0.5, -1.6]

# numpy impl: tanh(x) = (e^x - e^(-x)) / (e^x + e^(-x)), element-wise
def tanh(x):
	x = np.asarray(x, dtype=np.float64)
	return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x))

# derivative with respect to x: 1 - tanh(x)^2
def tanh_derivative(x):
	return 1 - np.power(tanh(x), 2)

output = [round(float(value), 4) for value in tanh(data)] # round to 4 decimal places
print("numpy tanh:", output)
print("numpy tanh derivative:", [round(float(value), 4) for value in tanh_derivative(data)])
print("numpy tanh derivative2:", [round(float(1. - value*value), 4) for value in tanh(data)])

# call pytorch interface
input_tensor = torch.tensor(data, dtype=torch.float32)
m = torch.nn.Tanh()
output2 = m(input_tensor)
print("pytorch tanh:", output2)
# derivative computed from the forward output y: 1 - y^2
print("pytorch tanh derivative:", 1. - output2*output2)
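Both the C++ activation_function_tanh_derivative and the 1 - y*y lines compute the derivative from the tanh output y rather than from x. A central finite difference, (f(x+h) - f(x-h)) / (2h), is a simple way to sanity-check that convention. The following is a minimal sketch. It uses Python's math.tanh as the forward function on the same sample data, and the helper names are hypothetical. PyTorch's autograd could serve the same purpose, but it is left out here so that the example needs only the standard library.

```python
import math

# Same sample inputs as the script above
data = [1.1, -2.2, 3.3, 0.4, -0.5, -1.6]


def tanh_derivative_from_output(y: float) -> float:
    """Derivative expressed via the forward output y = tanh(x): 1 - y^2."""
    return 1.0 - y * y


def central_difference(f, x: float, h: float = 1e-5) -> float:
    """Numerical derivative (f(x + h) - f(x - h)) / (2h)."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


def max_gradient_error(xs) -> float:
    """Largest absolute gap between the analytic and numerical derivatives."""
    errors = []
    for x in xs:
        analytic = tanh_derivative_from_output(math.tanh(x))
        numeric = central_difference(math.tanh, x)
        errors.append(abs(analytic - numeric))
    return max(errors)


if __name__ == "__main__":
    for x in data:
        y = math.tanh(x)
        print(f"x={x:+.1f}  1-y^2={tanh_derivative_from_output(y):.6f}  "
              f"finite diff={central_difference(math.tanh, x):.6f}")
    print("max abs error:", max_gradient_error(data))
```

A tiny maximum error confirms that feeding the forward output into 1 - y^2 gives the derivative with respect to x.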

The execution results are as follows:

The results above show that the C++, Python (NumPy), and PyTorch implementations agree to the displayed precision.

GitHub:

https://github.com/fengbingchun/NN_Test

https://github.com/fengbingchun/PyTorch_Test


