TCN

本文主要记录个人在学习TCN算法时间卷积网络(Temporal Convolutional Network)的个人理解。

摘要

TCN(Temporal Convolutional Networks)·1可以解释为时间卷积网络,此网络主要是把CNN中卷积的思想带入到了类似RNN的时序网络中,得到了极好的预测精度。

算法模型细节

膨胀因果卷积

TCN使用了机器学习中很常用的卷积运算,此部分算法细节我写在了FCN全卷积神经网络的文章中,在此就不再详细介绍。

在TCN网络中使用的卷积方法是膨胀因果卷积,此卷积主要特点是由膨胀卷积以及因果卷积组成。

膨胀卷积

膨胀卷积的做法是在普通卷积中加入空洞,以此来增加感受野,具体的计算方法如图所示2

具体到代码上,Pytorch一维卷积函数torch.nn.Conv1d的初始化参数直接包含了dilations参数也即膨胀系数,相关代码的修改编写非常方便。

因果卷积

因果卷积也即输入输出具有因果性,t时刻的输出只依赖于t及t以前的时刻的信息,不依赖t时刻以后的信息。图示为典型的因果卷积示意图3

如果按照上面示意图的结构,那么如果结果与前n个时间节点有关,那么卷积层数就要有n-1层,这样一般会导致网络层数太多深度太深,所以TCN使用了膨胀因果卷积,膨胀因子由dilations参数给出,依次是2的倍数,这有一个简单易懂的动图(膨胀因子为[1, 2, 4, 8]):

残差块结构

整个模型深度仍然很深,为减少过深网络带来的梯度消失等问题,TCN引入了和ResNet网络类似的残差块设计,将层与层之间的连接变成了残差结构,考虑到输入输出通道数可能不同,所以但通道数不同时引入了一个卷积层在做相加。

Pytorch 代码

以下代码主要来原论文作者提供的代码4,这里是TCN架构的核心部分。

TCN架构的Pytorch代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch.nn as nn
from torch.nn.utils import weight_norm


# 裁剪模块,裁剪掉多余的padding
class Chomp1d(nn.Module):
def __init__(self, chomp_size):
super(Chomp1d, self).__init__()
self.chomp_size = chomp_size

def forward(self, x):
return x[:, :, :-self.chomp_size].contiguous()


# 相当于一个残差网络模块
class TemporalBlock(nn.Module):
def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
super(TemporalBlock, self).__init__()
self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
stride=stride, padding=padding, dilation=dilation))
self.chomp1 = Chomp1d(padding)
self.relu1 = nn.ReLU()
self.dropout1 = nn.Dropout(dropout)

self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
stride=stride, padding=padding, dilation=dilation))
self.chomp2 = Chomp1d(padding)
self.relu2 = nn.ReLU()
self.dropout2 = nn.Dropout(dropout)

self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
self.conv2, self.chomp2, self.relu2, self.dropout2)
self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
self.relu = nn.ReLU()
self.init_weights()

def init_weights(self):
self.conv1.weight.data.normal_(0, 0.01)
self.conv2.weight.data.normal_(0, 0.01)
if self.downsample is not None:
self.downsample.weight.data.normal_(0, 0.01)

def forward(self, x):
out = self.net(x)
res = x if self.downsample is None else self.downsample(x)
return self.relu(out + res)


# TCN网络模块
class TemporalConvNet(nn.Module):
def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
super(TemporalConvNet, self).__init__()
layers = []
num_levels = len(num_channels)
for i in range(num_levels):
dilation_size = 2 ** i
in_channels = num_inputs if i == 0 else num_channels[i-1]
out_channels = num_channels[i]
layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,
padding=(kernel_size-1) * dilation_size, dropout=dropout)]

self.network = nn.Sequential(*layers)

def forward(self, x):
return self.network(x)

  1. TCN 网络原论文arxiv地址

  2. GitHub - vdumoulin/conv_arithmetic: A technical report on convolution arithmetic in the context of deep learning

  3. 因果卷积示意图来源,Wavenet模型论文

  4. TCN网络原作者公开的Pytorch代码

正在加载今日诗词....