mirror of
https://github.com/openmlsys/openmlsys-zh.git
synced 2026-06-14 22:16:11 +08:00
picture to code (#268)
Co-authored-by: liangzhibo <liangzhibo@huawei.com>
This commit is contained in:
@@ -50,7 +50,7 @@ Graph,DAG)、控制流图(Control-Flow Graph,CFG)等。
|
|||||||
AST抽象语法树采用树型中间表示的形式,是一种接近源代码层次的表示。对于表达式$a*5+a*5*b$,其AST表示如 :numref:`AST_DAG`所示。可以看到,AST形式包含$a*5$的两个不同副本,存在冗余。在AST的基础上,DAG提供了简化的表达形式,一个节点可以有多个父节点,相同子树可以重用。如果编译器能够证明$a$的值没有改变,则DAG可以重用子树,降低求值过程的代价。
|
AST抽象语法树采用树型中间表示的形式,是一种接近源代码层次的表示。对于表达式$a*5+a*5*b$,其AST表示如 :numref:`AST_DAG`所示。可以看到,AST形式包含$a*5$的两个不同副本,存在冗余。在AST的基础上,DAG提供了简化的表达形式,一个节点可以有多个父节点,相同子树可以重用。如果编译器能够证明$a$的值没有改变,则DAG可以重用子树,降低求值过程的代价。
|
||||||
|
|
||||||

|

|
||||||
:width:`600px`
|
:width:`400px`
|
||||||
:label:`AST_DAG`
|
:label:`AST_DAG`
|
||||||
|
|
||||||
3、混合中间表示
|
3、混合中间表示
|
||||||
@@ -99,12 +99,36 @@ IR作为PyTorch模型的中间表示,通过JIT即时编译的形式,将Pytho
|
|||||||
|
|
||||||
PyTorch框架采用命令式编程方式,其TorchScript
|
PyTorch框架采用命令式编程方式,其TorchScript
|
||||||
IR以基于SSA的线性IR为基本组成形式,并通过JIT即时编译的Tracing和Scripting两种方法将Python代码转换成TorchScript
|
IR以基于SSA的线性IR为基本组成形式,并通过JIT即时编译的Tracing和Scripting两种方法将Python代码转换成TorchScript
|
||||||
IR。如 :numref:`TorchScript_IR`给出了Python示例代码及其TorchScript
|
IR。如下Python代码使用了Scripting方法并打印其对应的中间表示图:
|
||||||
IR。
|
```python
|
||||||
|
import torch
|
||||||
|
|
||||||

|
@torch.jit.script
|
||||||
:width:`800px`
|
def test_func(input):
|
||||||
:label:`TorchScript_IR`
|
rv = 10.0
|
||||||
|
for i in range(5):
|
||||||
|
rv = rv + input
|
||||||
|
rv = rv/2
|
||||||
|
return rv
|
||||||
|
|
||||||
|
print(test_func.graph)
|
||||||
|
```
|
||||||
|
该中间表示图的结构为:
|
||||||
|
```
|
||||||
|
graph(%input.1 : Tensor):
|
||||||
|
%9 : int = prim::Constant[value=1]()
|
||||||
|
%5 : bool = prim::Constant[value=1]() # test.py:6:1
|
||||||
|
%rv.1 : float = prim::Constant[value=10.]() # test.py:5:6
|
||||||
|
%2 : int = prim::Constant[value=5]() # test.py:6:16
|
||||||
|
%14 : int = prim::Constant[value=2]() # test.py:8:10
|
||||||
|
%rv : float = prim::Loop(%2, %5, %rv.1) # test.py:6:1
|
||||||
|
block0(%i : int, %rv.9 : float):
|
||||||
|
%rv.3 : Tensor = aten::add(%input.1, %rv.9, %9) # <string>:5:9
|
||||||
|
%12 : float = aten::FloatImplicit(%rv.3) # test.py:7:2
|
||||||
|
%rv.6 : float = aten::div(%12, %14) # test.py:8:7
|
||||||
|
-> (%5, %rv.6)
|
||||||
|
return (%rv)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
TorchScript是PyTorch的JIT实现,支持使用Python训练模型,然后通过JIT转换为语言无关的模块,从而提升模型部署能力,提高编译性能。同时,TorchScript
|
TorchScript是PyTorch的JIT实现,支持使用Python训练模型,然后通过JIT转换为语言无关的模块,从而提升模型部署能力,提高编译性能。同时,TorchScript
|
||||||
@@ -116,12 +140,38 @@ Jax机器学习框架同时支持静态图和动态图,其中间表示采用Ja
|
|||||||
Representation) IR。Jaxpr
|
Representation) IR。Jaxpr
|
||||||
IR是一种强类型、纯函数的中间表示,其输入、输出都带有类型信息,函数输出只依赖输入,不依赖全局变量。
|
IR是一种强类型、纯函数的中间表示,其输入、输出都带有类型信息,函数输出只依赖输入,不依赖全局变量。
|
||||||
|
|
||||||

|
|
||||||
:width:`800px`
|
|
||||||
:label:`Jaxpr`
|
|
||||||
|
|
||||||
Jaxpr IR的表达采用ANF(A-norm
|
Jaxpr IR的表达采用ANF(A-norm
|
||||||
Form)函数式表达形式,如 :numref:`Jaxpr`所示。ANF形式将表达式划分为两类:原子表达式(aexp)和复合表达式(cexp)。原子表达式用于表示常数、变量、原语、匿名函数,复合表达式由多个原子表达式组成,可看作一个匿名函数或原语函数调用,组合的第一个输入是调用的函数,其余输入是调用的参数。
|
Form)函数式表达形式,ANF文法如下所示:
|
||||||
|
|
||||||
|
```
|
||||||
|
<aexp> ::= NUMBER | STRING | VAR | BOOLEAN | PRIMOP
|
||||||
|
| (lambda (VAR ...) <exp>)
|
||||||
|
<cexp> ::= (<aexp> <aexp> ...)
|
||||||
|
| (if <aexp> <exp> <exp>)
|
||||||
|
<exp> ::= (let ([VAR <cexp>]) <exp>) | <cexp> | <aexp>
|
||||||
|
```
|
||||||
|
|
||||||
|
ANF形式将表达式划分为两类:原子表达式(aexp)和复合表达式(cexp)。原子表达式用于表示常数、变量、原语、匿名函数,复合表达式由多个原子表达式组成,可看作一个匿名函数或原语函数调用,组合的第一个输入是调用的函数,其余输入是调用的参数。如下代码打印了一个函数对应的JaxPr:
|
||||||
|
```python
|
||||||
|
from jax import make_jaxpr
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
def test_func(x, y):
|
||||||
|
ret = x + jnp.sin(y) * 3
|
||||||
|
return jnp.sum(ret)
|
||||||
|
|
||||||
|
print(make_jaxpr(test_func)(jnp.zeros(8), jnp.ones(8)))
|
||||||
|
```
|
||||||
|
其对应的JaxPr为:
|
||||||
|
```
|
||||||
|
{ lambda ; a:f32[8] b:f32[8]. let
|
||||||
|
c:f32[8] = sin b
|
||||||
|
d:f32[8] = mul c 3.0
|
||||||
|
e:f32[8] = add a d
|
||||||
|
f:f32[] = reduce_sum[axes=(0,)] e
|
||||||
|
in (f,) }
|
||||||
|
```
|
||||||
|
|
||||||
Jax框架结合了Autograd 和 JIT,基于Jaxpr
|
Jax框架结合了Autograd 和 JIT,基于Jaxpr
|
||||||
IR,支持循环、分支、递归、闭包函数求导以及三阶求导,并且支持自动微分的反向传播和前向传播。
|
IR,支持循环、分支、递归、闭包函数求导以及三阶求导,并且支持自动微分的反向传播和前向传播。
|
||||||
@@ -182,13 +232,34 @@ typed)。每个节点需要有一个具体的类型,这个对于性能最大
|
|||||||
:width:`800px`
|
:width:`800px`
|
||||||
:label:`MindIR`
|
:label:`MindIR`
|
||||||
|
|
||||||
接下来我们通过 :numref:`MindIR_example`中的一段程序作为示例,来进一步分析MindIR。
|
接下来我们通过如下的一段程序作为示例,来进一步分析MindIR。
|
||||||
|
|
||||||

|
```python
|
||||||
:width:`600px`
|
def func(x, y):
|
||||||
:label:`MindIR_example`
|
return x / y
|
||||||
|
|
||||||
在ANF中,每个表达式都用let表达式绑定为一个变量,通过对变量的引用来表示对表达式输出的依赖,而在MindIR中,每个表达式都绑定为一个节点,通过节点与节点之间的有向边表示依赖关系。其函数图表示如 :numref:`MindIR_graph`所示。
|
@ms_function
|
||||||
|
def test_f(x, y):
|
||||||
|
a = x - 1
|
||||||
|
b = a + y
|
||||||
|
c = b * func(a, b)
|
||||||
|
return c
|
||||||
|
```
|
||||||
|
|
||||||
|
该函数对应的ANF表达式为:
|
||||||
|
```
|
||||||
|
lambda (x, y)
|
||||||
|
let a = x - 1 in
|
||||||
|
let b = a + y in
|
||||||
|
let func = lambda (x, y)
|
||||||
|
let ret = x / y in
|
||||||
|
ret end in
|
||||||
|
let %1 = func(a, b) in
|
||||||
|
let c = b * %1 in
|
||||||
|
c end
|
||||||
|
```
|
||||||
|
|
||||||
|
在ANF中,每个表达式都用let表达式绑定为一个变量,通过对变量的引用来表示对表达式输出的依赖,而在MindIR中,每个表达式都绑定为一个节点,通过节点与节点之间的有向边表示依赖关系。该函数对应的MindIR的可视化表示如 :numref:`MindIR_graph`所示。
|
||||||
|
|
||||||

|

|
||||||
:width:`800px`
|
:width:`800px`
|
||||||
|
|||||||
Reference in New Issue
Block a user