结合代码理解ST-GCN的图卷积

Intro

本文主要对时空图卷积论文中gcn的实现,就个人的理解做一些介绍.原论文如下:

1
2
3
4
5
6
@inproceedings{stgcn2018aaai,
title = {Spatial Temporal Graph Convolutional Networks for Skeleton-Based Action Recognition},
author = {Sijie Yan and Yuanjun Xiong and Dahua Lin},
booktitle = {AAAI},
year = {2018},
}

原文中3.6节 "Implementing ST-GCN"对于gcn的实现说的并不是很清楚,数学符号不太容易理解.

GCN实现原理

图卷积(GCN)是由卷积网络(CNN)推广而来,是CNN的超集,CNN可以看做是GCN的一个特例.
考虑一个3×33 \times 3的卷积核,该卷积核作用于一个3×33 \times 3的feature map时,可以看做九个元素和权重相乘在加和,得到一个标量,并作为中间位置的值.
cnn neighbors
如图,对于中间红色节点,周围8个蓝色节点就是其8个距离为1的邻居组(红色节点自身可认为是距离自己为0的邻居),每个邻居组里有1个邻居.卷积对于每个邻居组有一个独立的权重,考虑距离为0和1的邻居,共9组,所以该卷积有9个权重参数.
用图结构来理解上述卷积,对于当前红色节点,其距离为1的邻居共有8组,每组中有1个节点,距离为0的节点有1个,对所有节点加权求和,即为图卷积.

节点分区

节点分区指的是将当前节点的邻居节点分为若干邻居组,又可称为labeling, coloring, partitioning等.
分区策略在论文3.4节 Partition Strategies有比较详细的介绍,也很好理解.下文中选取空间分区法(Spatial configuration partitioning),按照节点与重心距离的不同分为3类.第一类是当前节点本身,第二类是与重心距离小于当前节点与重心距离的,第三类是与重心距离大于当前节点与重心距离的.类比到上面提到的 3×33 \times 3卷积的8个邻居,这种分区方法只能得到2个距离为1的邻居组和1个距离为0的当前节点,实现图卷积所需要的参数也只有3个.
与上述卷积邻居组不同的是,卷积的每个邻居组确定地有且只有1个元素(节点),而图卷积的邻居组中可能含有多个节点,也可能没有节点.这时,就需要对图卷积的邻居组中的所有节点进行一些处理,使1个邻居组体现为1个节点,最简单的方法就是在各特征维度上加权求和(或加权平均).例如,输入特征是3维,分别是xyz坐标,一个邻居组中有多组坐标,则加权平均,最终得到一个合成的3维坐标作为该邻居组的特征向量.

构造邻接阵A

将当前节点和其1-近邻分为3组后,需要对每组节点进行aggregation,论文中采取了一种非常简单的方法,即取平均.如下图,是我论文Kinectics的模型中导出的邻接阵A的具体值.其关键点共18个,用openpose提取的.关键点编号与实际关节对应关系参考openpose项目.
adj mat
每一列对应某个节点及其两个邻居组中的节点,观察每一列可以发现,如果一个邻居组中有2个节点,每个节点的feature就乘0.5,有3个节点的话,就乘0.333.实现起来也很方便,直接将feature矩阵resize为 (1,18)(-1, 18) 再分别乘这3个矩阵,再resize回原维度即可.
但是邻居组中每个节点对于邻居的特征贡献可能是不平均的,有的节点影响会更大,文中便提出了用一个 learnable weight matrix M与A按元素相乘,在训练过程中学习每个节点的贡献度.M的值我也从Kinectics模型中导出如下:
mask M

利用卷积实现图卷积

至此,邻居组有了,邻居组的特征有了,就要进行卷积(加权求和)了.考虑到特征通常不止1维,可以用 1×11 \times 1的卷积作用在邻居组合成特征上,得到一个邻居组的各个特征维度的加和输出.
分别对3个邻居组用1×11 \times 1卷积作用,最终加起来,即完成了图卷积.其中1×11 \times 1卷积作用相当于CNN的权重与特征相乘,加的作用相当于3×33 \times 3 CNN的9个位置的加.
具体的代码为:

1
2
3
4
5
6
for i, a in enumerate(A):
xa = x.view(-1, V).mm(a).view(N, C, T, V)
if i == 0:
y = self.conv_list[i](xa)
else:
y = y + self.conv_list[i](xa)

其中,conv_list包含3个1×11 \times 1的卷积层.

Code Tells

代码来自GitHub: yysijie/st-gcn

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
(gcn0): unit_gcn(
(conv_list): ModuleList(
(0): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
(1): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
(2): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
(relu): ReLU()
)
(tcn0): Unit2D(
(conv): Conv2d(64, 64, kernel_size=(9, 1), stride=(1, 1), padding=(4, 0))
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
(relu): ReLU()
(dropout): Dropout(p=0)
)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
class unit_gcn(nn.Module):
def __init__(self,
in_channels,
out_channels,
A,
use_local_bn=False,
kernel_size=1,
stride=1,
mask_learning=False):
super(unit_gcn, self).__init__()

# ==========================================
# number of nodes
self.V = A.size()[-1]

# the adjacency matrixes of the graph
self.A = Variable(
A.clone(), requires_grad=False).view(-1, self.V, self.V)

# number of input channels
self.in_channels = in_channels

# number of output channels
self.out_channels = out_channels

# if true, use mask matrix to reweight the adjacency matrix
self.mask_learning = mask_learning

# number of adjacency matrix (number of partitions)
self.num_A = self.A.size()[0]

# if true, each node have specific parameters of batch normalizaion layer.
# if false, all nodes share parameters.
self.use_local_bn = use_local_bn
# ==========================================

self.conv_list = nn.ModuleList([
nn.Conv2d(
self.in_channels,
self.out_channels,
kernel_size=(kernel_size, 1),
padding=(int((kernel_size - 1) / 2), 0),
stride=(stride, 1)) for i in range(self.num_A)
])

if mask_learning:
self.mask = nn.Parameter(torch.ones(self.A.size()))
if use_local_bn:
self.bn = nn.BatchNorm1d(self.out_channels * self.V)
else:
self.bn = nn.BatchNorm2d(self.out_channels)

self.relu = nn.ReLU()

# initialize
for conv in self.conv_list:
conv_init(conv)

def forward(self, x):
N, C, T, V = x.size()
self.A = self.A.cuda(x.get_device())
A = self.A

# reweight adjacency matrix
if self.mask_learning:
A = A * self.mask

# graph convolution
for i, a in enumerate(A):
xa = x.view(-1, V).mm(a).view(N, C, T, V)

if i == 0:
y = self.conv_list[i](xa)
else:
y = y + self.conv_list[i](xa)

# batch normalization
if self.use_local_bn:
y = y.permute(0, 1, 3, 2).contiguous().view(
N, self.out_channels * V, T)
y = self.bn(y)
y = y.view(N, self.out_channels, V, T).permute(0, 1, 3, 2)
else:
y = self.bn(y)

# nonliner
y = self.relu(y)

return y