diff --git a/chapter_frontend_and_ir/intermediate_representation.md b/chapter_frontend_and_ir/intermediate_representation.md index 94321ca..2a04c46 100644 --- a/chapter_frontend_and_ir/intermediate_representation.md +++ b/chapter_frontend_and_ir/intermediate_representation.md @@ -50,7 +50,7 @@ Graph,DAG)、控制流图(Control-Flow Graph,CFG)等。 AST抽象语法树采用树型中间表示的形式,是一种接近源代码层次的表示。对于表达式$a*5+a*5*b$,其AST表示如 :numref:`AST_DAG`所示。可以看到,AST形式包含$a*5$的两个不同副本,存在冗余。在AST的基础上,DAG提供了简化的表达形式,一个节点可以有多个父节点,相同子树可以重用。如果编译器能够证明$a$的值没有改变,则DAG可以重用子树,降低求值过程的代价。 ![AST图和DAG图](../img/ch04/中间表示-ASTDAG.svg) -:width:`600px` +:width:`400px` :label:`AST_DAG` 3、混合中间表示 @@ -99,12 +99,36 @@ IR作为PyTorch模型的中间表示,通过JIT即时编译的形式,将Pytho PyTorch框架采用命令式编程方式,其TorchScript IR以基于SSA的线性IR为基本组成形式,并通过JIT即时编译的Tracing和Scripting两种方法将Python代码转换成TorchScript -IR。如 :numref:`TorchScript_IR`给出了Python示例代码及其TorchScript -IR。 +IR。如下Python代码使用了Scripting方法并打印其对应的中间表示图: +```python +import torch -![Python代码及输出的TorchScript IR](../img/ch04/中间表示-torchscript.png) -:width:`800px` -:label:`TorchScript_IR` +@torch.jit.script +def test_func(input): + rv = 10.0 + for i in range(5): + rv = rv + input + rv = rv/2 + return rv + +print(test_func.graph) +``` +该中间表示图的结构为: +``` +graph(%input.1 : Tensor): + %9 : int = prim::Constant[value=1]() + %5 : bool = prim::Constant[value=1]() # test.py:6:1 + %rv.1 : float = prim::Constant[value=10.]() # test.py:5:6 + %2 : int = prim::Constant[value=5]() # test.py:6:16 + %14 : int = prim::Constant[value=2]() # test.py:8:10 + %rv : float = prim::Loop(%2, %5, %rv.1) # test.py:6:1 + block0(%i : int, %rv.9 : float): + %rv.3 : Tensor = aten::add(%input.1, %rv.9, %9) # :5:9 + %12 : float = aten::FloatImplicit(%rv.3) # test.py:7:2 + %rv.6 : float = aten::div(%12, %14) # test.py:8:7 + -> (%5, %rv.6) + return (%rv) +``` TorchScript是PyTorch的JIT实现,支持使用Python训练模型,然后通过JIT转换为语言无关的模块,从而提升模型部署能力,提高编译性能。同时,TorchScript @@ -116,12 +140,38 @@ Jax机器学习框架同时支持静态图和动态图,其中间表示采用Ja Representation) IR。Jaxpr IR是一种强类型、纯函数的中间表示,其输入、输出都带有类型信息,函数输出只依赖输入,不依赖全局变量。 -![ANF文法与Jaxpr IR](../img/ch04/中间表示-Jaxpr.png) -:width:`800px` -:label:`Jaxpr` Jaxpr IR的表达采用ANF(A-norm -Form)函数式表达形式,如 :numref:`Jaxpr`所示。ANF形式将表达式划分为两类:原子表达式(aexp)和复合表达式(cexp)。原子表达式用于表示常数、变量、原语、匿名函数,复合表达式由多个原子表达式组成,可看作一个匿名函数或原语函数调用,组合的第一个输入是调用的函数,其余输入是调用的参数。 +Form)函数式表达形式,ANF文法如下所示: + +``` + ::= NUMBER | STRING | VAR | BOOLEAN | PRIMOP + | (lambda (VAR ...) ) + ::= ( ...) + | (if ) + ::= (let ([VAR ]) ) | | +``` + +ANF形式将表达式划分为两类:原子表达式(aexp)和复合表达式(cexp)。原子表达式用于表示常数、变量、原语、匿名函数,复合表达式由多个原子表达式组成,可看作一个匿名函数或原语函数调用,组合的第一个输入是调用的函数,其余输入是调用的参数。如下代码打印了一个函数对应的JaxPr: +```python +from jax import make_jaxpr +import jax.numpy as jnp + +def test_func(x, y): + ret = x + jnp.sin(y) * 3 + return jnp.sum(ret) + +print(make_jaxpr(test_func)(jnp.zeros(8), jnp.ones(8))) +``` +其对应的JaxPr为: +``` +{ lambda ; a:f32[8] b:f32[8]. let + c:f32[8] = sin b + d:f32[8] = mul c 3.0 + e:f32[8] = add a d + f:f32[] = reduce_sum[axes=(0,)] e + in (f,) } +``` Jax框架结合了Autograd 和 JIT,基于Jaxpr IR,支持循环、分支、递归、闭包函数求导以及三阶求导,并且支持自动微分的反向传播和前向传播。 @@ -182,13 +232,34 @@ typed)。每个节点需要有一个具体的类型,这个对于性能最大 :width:`800px` :label:`MindIR` -接下来我们通过 :numref:`MindIR_example`中的一段程序作为示例,来进一步分析MindIR。 +接下来我们通过如下的一段程序作为示例,来进一步分析MindIR。 -![MindIR的ANF表达](../img/ch04/中间表示-MindIR示例.png) -:width:`600px` -:label:`MindIR_example` +```python +def func(x, y): + return x / y -在ANF中,每个表达式都用let表达式绑定为一个变量,通过对变量的引用来表示对表达式输出的依赖,而在MindIR中,每个表达式都绑定为一个节点,通过节点与节点之间的有向边表示依赖关系。其函数图表示如 :numref:`MindIR_graph`所示。 +@ms_function +def test_f(x, y): + a = x - 1 + b = a + y + c = b * func(a, b) + return c +``` + +该函数对应的ANF表达式为: +``` +lambda (x, y) + let a = x - 1 in + let b = a + y in + let func = lambda (x, y) + let ret = x / y in + ret end in + let %1 = func(a, b) in + let c = b * %1 in + c end +``` + +在ANF中,每个表达式都用let表达式绑定为一个变量,通过对变量的引用来表示对表达式输出的依赖,而在MindIR中,每个表达式都绑定为一个节点,通过节点与节点之间的有向边表示依赖关系。该函数对应的MindIR的可视化表示如 :numref:`MindIR_graph`所示。 ![MindIR的函数图表示](../img/ch04/中间表示-MindIR图.png) :width:`800px`