mirror of
https://github.com/openmlsys/openmlsys-zh.git
synced 2026-04-02 10:20:20 +08:00
picture to code (#268)
Co-authored-by: liangzhibo <liangzhibo@huawei.com>
This commit is contained in:
@@ -50,7 +50,7 @@ Graph,DAG)、控制流图(Control-Flow Graph,CFG)等。
|
||||
AST抽象语法树采用树型中间表示的形式,是一种接近源代码层次的表示。对于表达式$a*5+a*5*b$,其AST表示如 :numref:`AST_DAG`所示。可以看到,AST形式包含$a*5$的两个不同副本,存在冗余。在AST的基础上,DAG提供了简化的表达形式,一个节点可以有多个父节点,相同子树可以重用。如果编译器能够证明$a$的值没有改变,则DAG可以重用子树,降低求值过程的代价。
|
||||
|
||||

|
||||
:width:`600px`
|
||||
:width:`400px`
|
||||
:label:`AST_DAG`
|
||||
|
||||
3、混合中间表示
|
||||
@@ -99,12 +99,36 @@ IR作为PyTorch模型的中间表示,通过JIT即时编译的形式,将Pytho
|
||||
|
||||
PyTorch框架采用命令式编程方式,其TorchScript
|
||||
IR以基于SSA的线性IR为基本组成形式,并通过JIT即时编译的Tracing和Scripting两种方法将Python代码转换成TorchScript
|
||||
IR。如 :numref:`TorchScript_IR`给出了Python示例代码及其TorchScript
|
||||
IR。
|
||||
IR。如下Python代码使用了Scripting方法并打印其对应的中间表示图:
|
||||
```python
|
||||
import torch
|
||||
|
||||

|
||||
:width:`800px`
|
||||
:label:`TorchScript_IR`
|
||||
@torch.jit.script
|
||||
def test_func(input):
|
||||
rv = 10.0
|
||||
for i in range(5):
|
||||
rv = rv + input
|
||||
rv = rv/2
|
||||
return rv
|
||||
|
||||
print(test_func.graph)
|
||||
```
|
||||
该中间表示图的结构为:
|
||||
```
|
||||
graph(%input.1 : Tensor):
|
||||
%9 : int = prim::Constant[value=1]()
|
||||
%5 : bool = prim::Constant[value=1]() # test.py:6:1
|
||||
%rv.1 : float = prim::Constant[value=10.]() # test.py:5:6
|
||||
%2 : int = prim::Constant[value=5]() # test.py:6:16
|
||||
%14 : int = prim::Constant[value=2]() # test.py:8:10
|
||||
%rv : float = prim::Loop(%2, %5, %rv.1) # test.py:6:1
|
||||
block0(%i : int, %rv.9 : float):
|
||||
%rv.3 : Tensor = aten::add(%input.1, %rv.9, %9) # <string>:5:9
|
||||
%12 : float = aten::FloatImplicit(%rv.3) # test.py:7:2
|
||||
%rv.6 : float = aten::div(%12, %14) # test.py:8:7
|
||||
-> (%5, %rv.6)
|
||||
return (%rv)
|
||||
```
|
||||
|
||||
|
||||
TorchScript是PyTorch的JIT实现,支持使用Python训练模型,然后通过JIT转换为语言无关的模块,从而提升模型部署能力,提高编译性能。同时,TorchScript
|
||||
@@ -116,12 +140,38 @@ Jax机器学习框架同时支持静态图和动态图,其中间表示采用Ja
|
||||
Representation) IR。Jaxpr
|
||||
IR是一种强类型、纯函数的中间表示,其输入、输出都带有类型信息,函数输出只依赖输入,不依赖全局变量。
|
||||
|
||||

|
||||
:width:`800px`
|
||||
:label:`Jaxpr`
|
||||
|
||||
Jaxpr IR的表达采用ANF(A-norm
|
||||
Form)函数式表达形式,如 :numref:`Jaxpr`所示。ANF形式将表达式划分为两类:原子表达式(aexp)和复合表达式(cexp)。原子表达式用于表示常数、变量、原语、匿名函数,复合表达式由多个原子表达式组成,可看作一个匿名函数或原语函数调用,组合的第一个输入是调用的函数,其余输入是调用的参数。
|
||||
Form)函数式表达形式,ANF文法如下所示:
|
||||
|
||||
```
|
||||
<aexp> ::= NUMBER | STRING | VAR | BOOLEAN | PRIMOP
|
||||
| (lambda (VAR ...) <exp>)
|
||||
<cexp> ::= (<aexp> <aexp> ...)
|
||||
| (if <aexp> <exp> <exp>)
|
||||
<exp> ::= (let ([VAR <cexp>]) <exp>) | <cexp> | <aexp>
|
||||
```
|
||||
|
||||
ANF形式将表达式划分为两类:原子表达式(aexp)和复合表达式(cexp)。原子表达式用于表示常数、变量、原语、匿名函数,复合表达式由多个原子表达式组成,可看作一个匿名函数或原语函数调用,组合的第一个输入是调用的函数,其余输入是调用的参数。如下代码打印了一个函数对应的JaxPr:
|
||||
```python
|
||||
from jax import make_jaxpr
|
||||
import jax.numpy as jnp
|
||||
|
||||
def test_func(x, y):
|
||||
ret = x + jnp.sin(y) * 3
|
||||
return jnp.sum(ret)
|
||||
|
||||
print(make_jaxpr(test_func)(jnp.zeros(8), jnp.ones(8)))
|
||||
```
|
||||
其对应的JaxPr为:
|
||||
```
|
||||
{ lambda ; a:f32[8] b:f32[8]. let
|
||||
c:f32[8] = sin b
|
||||
d:f32[8] = mul c 3.0
|
||||
e:f32[8] = add a d
|
||||
f:f32[] = reduce_sum[axes=(0,)] e
|
||||
in (f,) }
|
||||
```
|
||||
|
||||
Jax框架结合了Autograd 和 JIT,基于Jaxpr
|
||||
IR,支持循环、分支、递归、闭包函数求导以及三阶求导,并且支持自动微分的反向传播和前向传播。
|
||||
@@ -182,13 +232,34 @@ typed)。每个节点需要有一个具体的类型,这个对于性能最大
|
||||
:width:`800px`
|
||||
:label:`MindIR`
|
||||
|
||||
接下来我们通过 :numref:`MindIR_example`中的一段程序作为示例,来进一步分析MindIR。
|
||||
接下来我们通过如下的一段程序作为示例,来进一步分析MindIR。
|
||||
|
||||

|
||||
:width:`600px`
|
||||
:label:`MindIR_example`
|
||||
```python
|
||||
def func(x, y):
|
||||
return x / y
|
||||
|
||||
在ANF中,每个表达式都用let表达式绑定为一个变量,通过对变量的引用来表示对表达式输出的依赖,而在MindIR中,每个表达式都绑定为一个节点,通过节点与节点之间的有向边表示依赖关系。其函数图表示如 :numref:`MindIR_graph`所示。
|
||||
@ms_function
|
||||
def test_f(x, y):
|
||||
a = x - 1
|
||||
b = a + y
|
||||
c = b * func(a, b)
|
||||
return c
|
||||
```
|
||||
|
||||
该函数对应的ANF表达式为:
|
||||
```
|
||||
lambda (x, y)
|
||||
let a = x - 1 in
|
||||
let b = a + y in
|
||||
let func = lambda (x, y)
|
||||
let ret = x / y in
|
||||
ret end in
|
||||
let %1 = func(a, b) in
|
||||
let c = b * %1 in
|
||||
c end
|
||||
```
|
||||
|
||||
在ANF中,每个表达式都用let表达式绑定为一个变量,通过对变量的引用来表示对表达式输出的依赖,而在MindIR中,每个表达式都绑定为一个节点,通过节点与节点之间的有向边表示依赖关系。该函数对应的MindIR的可视化表示如 :numref:`MindIR_graph`所示。
|
||||
|
||||

|
||||
:width:`800px`
|
||||
|
||||
Reference in New Issue
Block a user