[Pytorch 源码阅读] —— TH中的 c 语言泛型编程

2021/6/20 22:56:07

本文主要是介绍[Pytorch 源码阅读] —— TH中的 c 语言泛型编程,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

文章目录

      • 前言
      • C 中宏的使用
        • 替换文本
        • 宏函数
        • 将宏转成字符串
        • 组合名字
        • 预定义宏
      • TH 中的 c 泛型编程
        • 泛型示例
      • 参考文章

前言

基于 pytorch 1.10.0 版本,master
commit 号:047925dac1c07a0ad2c86c281fac5610b084d1bd

万事开头难,还是咬着牙开始了 Pytorch 的源码阅读内容,虽然感觉难度很大,而且有点无从下手,希望坚持下去能有所进步!

这里光源码编译就花了一些时间,尝试了 macbook,windows,和 linux 下 git clone 源码然后按照官方说明操作,但是基本碍于系统或者网络问题,git 拉第三方以来的时候会有很多问题,最后只是在 linux 系统下完成了编译工作。

主要是 pytorch 会自动生成一些代码,想具体看一下自动生成的代码是啥样,所以想着源码编译一下。如果有需要并且编译也遇到各种问题的读者,可以留言评论,我把编译后的源码压缩发出来。

因为 aten/src 下面 TH 开头的文件夹主要都是用 c 写的,所以下面的内容会涉及到很多 c 语言。

C 中宏的使用

这里简单了解一下 c 语言中宏的几种用法,为下面对代码的理解做铺垫。

替换文本

#define MAX_SIZE 1024
在预处理器工作时就会把所有的 MAX_SIZE 替换成 1024
宏定义支持 ‘’,当宏命令过长时,可以用来整理代码,宏定义只是单纯的文本替换,不会有检查,所以使用需要小心。

宏函数

宏也可以有参数的
#define MIN(X, Y) ((X) < (Y)? (X): (Y))
都加小括号是为了防止二义性

将宏转成字符串

#define WARN(EXP) printf(#EXP)
在宏内部使用 ‘#’ 来产生字符串, WARN(test) 会被替换成 printf(“test”)

组合名字

当我们使用不同的宏产生名字时,最终要将它们组合起来:
#define CONCAT(A,B,C) A ## B ## C
用上面的宏产生 “Double_Matrix_add” 可以使用:
Double_Matrix CONCAT(Double, Matrix, add)(Double_Matrix *A, Double_Matrix *B); 来实现。

预定义宏

C 语言的预处理器也有一些预定义的宏,‘file’ 当前输入文件的名称,会展开为完整路径;‘LINE’ 当前输入行号。

TH 中的 c 泛型编程

对 tensor 最重要的两个结构就是THTensor 和 THStorage,根据 aten/src/README.md 文档,所有构造 THTensor 和 THStorage 的函数都是 new 开头,析构函数都是 free 开头的函数,PyTorch 通过这些命名规则来进行 tensor 的内存管理,这里以 THTensor 为例,可以在 aten/src/TH/generic/THTensor.h 看到下面的是 THTensor 的定义,其代表的是 at::TensorImpl 这种类型,然后 Pytorch 为了用来区分 tensor 类型,又定义了更细的宏定义,它们其实代表的含义都是一致的,at::TensorImpl 类型:

#define THTensor at::TensorImpl

// These used to be distinct types; for some measure of backwards compatibility and documentation
// alias these to the single THTensor type.
#define THFloatTensor THTensor
#define THDoubleTensor THTensor
#define THHalfTensor THTensor
#define THByteTensor THTensor
#define THCharTensor THTensor
#define THShortTensor THTensor
#define THIntTensor THTensor
#define THLongTensor THTensor
#define THBoolTensor THTensor
#define THBFloat16Tensor THTensor
#define THComplexFloatTensor THTensor
#define THComplexDoubleTensor THTensor

然后紧接着可以看到大量类似下面接口的声明:

TH_API THTensor* THTensor_(newWithSize1d)(int64_t size0_);
TH_API THTensor* THTensor_(newClone)(THTensor* self);
TH_API THTensor* THTensor_(newContiguous)(THTensor* tensor);
…

这里其实就用到了 c 语言的泛型编程,因为上面的 THxxTensor 的宏代表的都是 THTensor 的类型,所以其 api 基本也是一致的,为了避免相似的代码写很多遍,PyTorch 采用了宏来实现泛型。

aten/src/TH/THTensor.h 中可以看到下面的定义:

#define THTensor_(NAME)   TH_CONCAT_4(TH,Real,Tensor_,NAME)

进一步,在 torch/include/TH/THGeneral.h 中,有如下定义:

#define TH_CONCAT_4_EXPAND(x,y,z,w) x ## y ## z ## w
#define TH_CONCAT_4(x,y,z,w) TH_CONCAT_4_EXPAND(x,y,z,w)

所以这个宏就是单纯的将输入字符串拼接成:THRealTensor_NAME 的形式。
这里通过真实的数据类型宏定义 Real ,就实现了组成上面不同的 THTensor 的名字。
在相关的头文件中会有相关的宏定义。例如 torch/include/TH/THGenerateCharType.h 中有下面相关代码:

#define Real Char
...
#undef Real

torch/include/TH/THGenerateFloatType.h 中有:

#ifndef TH_GENERIC_FILE
#error "You must define TH_GENERIC_FILE before including THGenerateFloatType.h"
#endif

#define scalar_t float
#define accreal double
#define Real Float
#define TH_REAL_IS_FLOAT
#line 1 TH_GENERIC_FILE
#include TH_GENERIC_FILE
#undef accreal
#undef scalar_t
#undef Real
#undef TH_REAL_IS_FLOAT

#ifndef THGenerateManyTypes
#undef TH_GENERIC_FILE
#endif

泛型示例

现在我们利用上述宏的规则需要构建一个泛型的 add 函数,形如:

struct NumVector {
  num *data;
  int n;
}
void NumVector_add(NumVector * A, NumVector * B, NumVector * C) {
  int i, j,n;
  n = C->n;
  for (i=0; i<n; i++) {
    C->data[i] = A->data[i] + B->data[i];
  }
}

现在考虑如何把 NumVector_add 特化成 FloatVector_add 等函数名称, 通过下面的宏来实现:

#define Vector_(NAME) Num##Vector_##add
#define Vector Num##Vector

#define num float
#define Num Float

struct NumVector
{
  num *data;
  int n;
};
void NumVector_add(NumVector *A, NumVector *B, NumVector *C)
{
  // codes
}

但是上述代码实际只会产生 “NumVector”的名字,因为 C 中的 # 和 ## 不会展开宏名,需要使用一个中间宏来展开宏名再组合它们:

#define CONCAT_3_HELPER(A, B, C) A ## B ## C
#define CONCAT_3(A, B, C) CONCAT_3_HELPER(A, B, C)
#define CONCAT_2_HELPER(A, B) A ## B
#define CONCAT_2(A, B) CONCAT_2_HELPER(A, B)  // 中间传递的时候会展开宏,达到想要的效果

#define Vector_(NAME) CONCAT_3(Num, Vector_, NAME)
#define Vector CONCAT_2(Num, Vector)

#define num float
#define Num Float

struct NumVector
{
  num *data;
  int n;
};
void Vector_(add)(Vector *A, Vector *B, Vector *C)
{
 // codes
}

#undef num
#undef Num

#define num double
#define Num Double

struct NumVector
{
  num *data;
  int n;
};
void Vector_(add)(Vector *A, Vector *B, Vector *C)
{
  // codes
}

#undef num
#undef Num

由上面不停的复制粘贴之前的命令并不实际,如果泛型的代码在另外一个文件里的话, 每次需要直接读取相关头文件就好了,于是:
将原先的 add.c 拓展成下面的结构:

add.h —— 用来展开 generic/add.h
add.c —— 用来展开 generic/add.c
generic/
    add.h ——  泛型 Vector 类型的定义
    add.c —— 泛型 add 函数的定义

add.h

#define Vector_(NAME) CONCAT_3(Num, Vector_, NAME)
#define Vector CONCAT_2(Num, Vector)

#define num float
#define Num Float
#include “generic/add.h”
#undef num
#undef Num

#define num double
#define Num Double
#include “generic/add.h”
#undef num
#undef Num

add.c

#include "add.h"
#define num float
#define Num Float
#include “generic/add.c”
#undef num
#undef Num

#define num double
#define Num Double
#include “generic/add.c”
#undef num
#undef Num

但是我们不能每次都写一遍 “num”这个泛型的定义,可以将它打包到一个头文件“generateFloat.h”里去,然后定义一个宏 “GENERIC_FILE” 来存储要特例化的文件名,首先判断是否有这个宏了:
generateFloat.h

#ifndef GENERIC_FILE
#error "You must define GENERIC_FILE before including GenerateFloat.h"
#endif

#define num float
#define Num Float
#line 1 GENERIC_FILE
#include GENERIC_FILE
#undef num
#undef Num

其中 #line是修改预编译的行号(__LINE__)和文件名(__FILE__),使每次加载 GENERIC_FILE 时都是从 line 1 开始的,都好像是重新读取,再修改 generic/add.h 和 generic/add.c :

generic/add.h

#ifndef GENERIC_FILE
#define GENERIC_FILE "generic/add.h"
#else

struct NumVector
{
  num *data;
  int n;
};
void Vector_(add)(Vector *A, Vector *B, Vector *C);
#endif

/generic/add.c

#ifndef GENERIC_FILE
#define GENERIC_FILE "generic/add.c"
#else
void Vector_(add)(Vector *A, Vector *B, Vector *C){
  //codes
}
#endif

通过上述修改就可以达到 C 中泛型的效果了,现在回看 PyTorch 中 TH 目录下的整体结构就是如此。

参考文章

https://zhuanlan.zhihu.com/p/34496542



这篇关于[Pytorch 源码阅读] —— TH中的 c 语言泛型编程的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!


扫一扫关注最新编程教程