picture to code (#268)

Co-authored-by: liangzhibo <liangzhibo@huawei.com>
This commit is contained in:
Liang ZhiBo
2022-04-11 16:55:31 +08:00
committed by GitHub
parent 34d2152e88
commit 6db4bb9711

View File

@@ -50,7 +50,7 @@ GraphDAG)、控制流图(Control-Flow GraphCFG)等。
AST抽象语法树采用树型中间表示的形式是一种接近源代码层次的表示。对于表达式$a*5+a*5*b$其AST表示如 :numref:`AST_DAG`所示。可以看到AST形式包含$a*5$的两个不同副本存在冗余。在AST的基础上DAG提供了简化的表达形式一个节点可以有多个父节点相同子树可以重用。如果编译器能够证明$a$的值没有改变则DAG可以重用子树降低求值过程的代价。
![AST图和DAG图](../img/ch04/中间表示-ASTDAG.svg)
:width:`600px`
:width:`400px`
:label:`AST_DAG`
3、混合中间表示
@@ -99,12 +99,36 @@ IR作为PyTorch模型的中间表示通过JIT即时编译的形式将Pytho
PyTorch框架采用命令式编程方式其TorchScript
IR以基于SSA的线性IR为基本组成形式并通过JIT即时编译的Tracing和Scripting两种方法将Python代码转换成TorchScript
IR。如 :numref:`TorchScript_IR`给出了Python示例代码及其TorchScript
IR。
IR。如下Python代码使用了Scripting方法并打印其对应的中间表示图
```python
import torch
![Python代码及输出的TorchScript IR](../img/ch04/中间表示-torchscript.png)
:width:`800px`
:label:`TorchScript_IR`
@torch.jit.script
def test_func(input):
rv = 10.0
for i in range(5):
rv = rv + input
rv = rv/2
return rv
print(test_func.graph)
```
该中间表示图的结构为:
```
graph(%input.1 : Tensor):
%9 : int = prim::Constant[value=1]()
%5 : bool = prim::Constant[value=1]() # test.py:6:1
%rv.1 : float = prim::Constant[value=10.]() # test.py:5:6
%2 : int = prim::Constant[value=5]() # test.py:6:16
%14 : int = prim::Constant[value=2]() # test.py:8:10
%rv : float = prim::Loop(%2, %5, %rv.1) # test.py:6:1
block0(%i : int, %rv.9 : float):
%rv.3 : Tensor = aten::add(%input.1, %rv.9, %9) # <string>:5:9
%12 : float = aten::FloatImplicit(%rv.3) # test.py:7:2
%rv.6 : float = aten::div(%12, %14) # test.py:8:7
-> (%5, %rv.6)
return (%rv)
```
TorchScript是PyTorch的JIT实现支持使用Python训练模型然后通过JIT转换为语言无关的模块从而提升模型部署能力提高编译性能。同时TorchScript
@@ -116,12 +140,38 @@ Jax机器学习框架同时支持静态图和动态图其中间表示采用Ja
Representation) IR。Jaxpr
IR是一种强类型、纯函数的中间表示其输入、输出都带有类型信息函数输出只依赖输入不依赖全局变量。
![ANF文法与Jaxpr IR](../img/ch04/中间表示-Jaxpr.png)
:width:`800px`
:label:`Jaxpr`
Jaxpr IR的表达采用ANF(A-norm
Form)函数式表达形式,如 :numref:`Jaxpr`所示。ANF形式将表达式划分为两类原子表达式(aexp)和复合表达式(cexp)。原子表达式用于表示常数、变量、原语、匿名函数,复合表达式由多个原子表达式组成,可看作一个匿名函数或原语函数调用,组合的第一个输入是调用的函数,其余输入是调用的参数。
Form)函数式表达形式,ANF文法如下所示
```
<aexp> ::= NUMBER | STRING | VAR | BOOLEAN | PRIMOP
| (lambda (VAR ...) <exp>)
<cexp> ::= (<aexp> <aexp> ...)
(if <aexp> <exp> <exp>)
<exp> ::= (let ([VAR <cexp>]) <exp>) | <cexp> | <aexp>
```
ANF形式将表达式划分为两类原子表达式(aexp)和复合表达式(cexp)。原子表达式用于表示常数、变量、原语、匿名函数复合表达式由多个原子表达式组成可看作一个匿名函数或原语函数调用组合的第一个输入是调用的函数其余输入是调用的参数。如下代码打印了一个函数对应的JaxPr
```python
from jax import make_jaxpr
import jax.numpy as jnp
def test_func(x, y):
ret = x + jnp.sin(y) * 3
return jnp.sum(ret)
print(make_jaxpr(test_func)(jnp.zeros(8), jnp.ones(8)))
```
其对应的JaxPr为
```
{ lambda ; a:f32[8] b:f32[8]. let
c:f32[8] = sin b
d:f32[8] = mul c 3.0
e:f32[8] = add a d
f:f32[] = reduce_sum[axes=(0,)] e
in (f,) }
```
Jax框架结合了Autograd 和 JIT基于Jaxpr
IR支持循环、分支、递归、闭包函数求导以及三阶求导并且支持自动微分的反向传播和前向传播。
@@ -182,13 +232,34 @@ typed。每个节点需要有一个具体的类型这个对于性能最大
:width:`800px`
:label:`MindIR`
接下来我们通过 :numref:`MindIR_example`的一段程序作为示例来进一步分析MindIR。
接下来我们通过如下的一段程序作为示例来进一步分析MindIR。
![MindIR的ANF表达](../img/ch04/中间表示-MindIR示例.png)
:width:`600px`
:label:`MindIR_example`
```python
def func(x, y):
return x / y
在ANF中每个表达式都用let表达式绑定为一个变量通过对变量的引用来表示对表达式输出的依赖而在MindIR中每个表达式都绑定为一个节点通过节点与节点之间的有向边表示依赖关系。其函数图表示如 :numref:`MindIR_graph`所示。
@ms_function
def test_f(x, y):
a = x - 1
b = a + y
c = b * func(a, b)
return c
```
该函数对应的ANF表达式为
```
lambda (x, y)
let a = x - 1 in
let b = a + y in
let func = lambda (x, y)
let ret = x / y in
ret end in
let %1 = func(a, b) in
let c = b * %1 in
c end
```
在ANF中每个表达式都用let表达式绑定为一个变量通过对变量的引用来表示对表达式输出的依赖而在MindIR中每个表达式都绑定为一个节点通过节点与节点之间的有向边表示依赖关系。该函数对应的MindIR的可视化表示如 :numref:`MindIR_graph`所示。
![MindIR的函数图表示](../img/ch04/中间表示-MindIR图.png)
:width:`800px`