new computational graph chapter (#419)

* new computational graph chapter

* Update main.yml

* Update main.yml

---------

Co-authored-by: Luo Mai <luo.mai.cs@gmail.com>
This commit is contained in:
Jiarong Han
2023-02-16 19:03:17 +08:00
committed by GitHub
parent 15ccf2fecf
commit 0c0bff1b83
20 changed files with 215 additions and 125 deletions

View File

@@ -1,20 +1,21 @@
## 计算图的设计背景和作用
![基于计算图的架构](../img/ch03/dag.svg)
![基于计算图的架构](../img/ch03/graph.png)
:width:`800px`
:label:`dag`
早期机器学习框架主要为了支持基于卷积神经网络的图像分类问题。这些神经网络的拓扑结构简单神经网络层往往通过串行构建),他们的拓扑结构可以用简的配置文件表达例如Caffe基于Protocol
Buffer格式的模型定义。随着机器学习的进一步发展模型的拓扑日益复杂包括混合专家生成对抗网络多注意力模型。这些复杂的模型拓扑结构例如分支结构带有条件的if-else循环会影响模型算子的执行、自动化梯度计算一般称为自动微分以及训练参数的自动化判断。为此我们需要一个更加通用的技术来执行任意机器学习模型计算图应运而生。综合来看计算图对于一个机器学习框架提供了以下几个关键作用
早期机器学习框架主要针对全连接和卷积神经网络设计,这些神经网络的拓扑结构简单神经网络层之间通过串行连接。因此,它们的拓扑结构可以用简的配置文件表达例如Caffe基于Protocol Buffer格式的模型定义
- **对于输入数据、算子和算子执行顺序的统一表达。**
机器学习框架用户可以用多种高层次编程语言PythonJulia和C++来编写训练程序。这些高层次程序需要统一的表达成框架底层C和C++算子的执行。因此,计算图的第一个核心作用是可以作为一个统一的数据结构来表达用户用不同语言编写的训练程序。这个数据结构可以准确表述用户的输入数据、模型所带有的多个算子,以及算子之间的执行顺序。
现代机器学习模型的拓扑结构日益复杂,显著的例子包括混合专家模型、生成对抗网络、注意力模型等。复杂的模型结构(例如带有分支的循环结构等)需要机器学习框架能够对模型算子的执行依赖关系、梯度计算以及训练参数进行快速高效的分析,便于优化模型结构、制定调度执行策略以及实现自动化梯度计算,从而提高机器学习框架训练复杂模型的效率。因此,机器学习系统设计者需要一个通用的数据结构来理解、表达和执行机器学习模型。为了应对这个需求,如:numref:`dag`所示基于计算图的机器学习框架应运而生,框架延续前端语言与后端语言分离的设计。从高层次来看,计算图实现了以下关键功能:
- **定义中间状态和模型状态**
一个用户训练程序中,用户会生成中间变量(神经网络层之间传递的激活值和梯度)来完成复杂的训练过程。而这其中,只有模型参数需要最后持久化,从而为后续的模型推理做准备。通过计算图,机器学习框架可以准确分析出中间状态的生命周期(一个中间变量何时生成,以及何时销毁),从而帮助框架更好的管理内存
- **统一的计算过程表达**
编写机器学习模型程序的过程中用户希望使用高层次编程语言如Python、Julia和C++。然而硬件加速器等设备往往只提供了C和C++编程接口因此机器学习系统的实现通常需要基于C和C++。用不同的高层次语言编写的程序因此需要被表达为一个统一的数据结构从而被底层共享的C和C++系统模块执行。这个数据结构(即计算图)可以表述用户的输入数据、模型中的计算逻辑(通常称为算子)以及算子之间的执行顺序
- **自动化计算梯度。**
用户给定的训练程序仅仅包含了一个机器学习模型如何将用户输入(一般为训练数据)转化为输出(一般为损失函数)的过程。而为了训练这个模型,机器学习框架需要分析任意机器学习模型和其中的算子,找出自动化计算梯度的方法。计算图的出现让自动化分析模型定义和自动化计算梯度成为可能
用户的模型训练程序接收训练数据集的数据样本,通过神经网络前向计算,最终计算出损失值。根据损失值,机器学习系统为每个模型参数计算出梯度来更新模型参数。考虑到用户可以写出任意的模型拓扑和损失值计算方法,计算梯度的方法必须通用并且能实现自动运行。计算图可以辅助机器学习系统快速分析参数之间的梯度传递关系,实现自动化计算梯度的目标
- **分析模型变量生命周期。**
在用户训练模型的过程中,系统会通过计算产生临时的中间变量,如前向计算中的激活值和反向计算中的梯度。前向计算的中间变量可能与梯度共同参与到模型的参数更新过程中。通过计算图,系统可以准确分析出中间变量的生命周期(一个中间变量生成以及销毁时机),从而帮助框架优化内存管理。
- **优化程序执行。**
用户给定的模型程序往往是"串行化"地连接起来多个神经网络层。通过利用计算图来分析模型算子执行关系,机器学习框架可以更好地发现将算子进行异步执行的机会,从而以更快的速度完成模型程序的执行。
用户给定的模型程序具备不同的网络拓扑结构。机器学习框架利用计算图来分析模型结构和算子执行依赖关系,并自动寻找算子并行计算的策略,从而提高模型的执行效率