2020-10-19 21:08:55 +08:00

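# Theano Conditional Statements
`theano` provides two conditional constructs, `ifelse` and `switch`. Both apply a condition to symbolic variables:
* `ifelse(condition, var1, var2)`
    * Returns `var1` if `condition` is `true`, otherwise returns `var2`.
* `switch(tensor, var1, var2)`
    * An elementwise version of `ifelse`, and therefore more general.
* `switch` computes both outputs, whereas `ifelse` computes only the output selected by the given condition.

`ifelse` must be imported from `theano.ifelse`, whereas `switch` lives in the `theano.tensor` module.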
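The difference between computing both branches and computing only the selected one can be sketched in plain Python, without Theano. The helpers `branch`, `eager_switch`, and `lazy_ifelse` below are hypothetical names used only for illustration; they mimic the evaluation order described above, not Theano's actual implementation.

```py
# Plain-Python sketch of eager vs. lazy branch evaluation (not Theano code).
calls = []  # records which branches were actually computed


def branch(name, value):
    """Return a zero-argument function that logs when it is evaluated."""
    def thunk():
        calls.append(name)
        return value
    return thunk


def eager_switch(condition, then_thunk, else_thunk):
    # Both branches are computed first, then one result is selected.
    then_value, else_value = then_thunk(), else_thunk()
    return then_value if condition else else_value


def lazy_ifelse(condition, then_thunk, else_thunk):
    # Only the branch chosen by the condition is computed.
    return then_thunk() if condition else else_thunk()


eager_result = eager_switch(True, branch("x", 1), branch("y", 2))
eager_calls = list(calls)

calls.clear()
lazy_result = lazy_ifelse(True, branch("x", 1), branch("y", 2))
lazy_calls = list(calls)

print(eager_result, eager_calls)
print(lazy_result, lazy_calls)
```

Both helpers return the same value; only the amount of work differs, which is why the choice matters mainly when the branches are expensive to compute.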
In [1]:
```py
import theano, time
import theano.tensor as T
import numpy as np
from theano.ifelse import ifelse
```
```py
Using gpu device 1: Tesla K10.G2.8GB (CNMeM is disabled)
```
Suppose we have two scalar parameters $a, b$ and two matrices $\mathbf{x, y}$, and define the function:
$$ \mathbf z = f(a, b,\mathbf{x, y}) = \left\{ \begin{aligned} \mathbf x & ,\ a \le b \\ \mathbf y & ,\ a > b \end{aligned} \right. $$
Define the variables:
In [2]:
```py
a, b = T.scalars('a', 'b')
x, y = T.matrices('x', 'y')
```
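Build it with `ifelse`. The definition uses $a \le b$, which corresponds to `T.le()`; `T.lt()` and `T.gt()` are the strict less-than and greater-than comparisons, and `T.ge()` is greater-than-or-equal:
In [3]:
```py
z_ifelse = ifelse(T.le(a, b), x, y)
f_ifelse = theano.function([a, b, x, y], z_ifelse)
```
Build it with `switch`:
In [4]:
```py
z_switch = T.switch(T.le(a, b), x, y)
f_switch = theano.function([a, b, x, y], z_switch)
```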
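To see what "elementwise" means for `switch`, a NumPy analogy may help: `np.where` selects per element in much the same way (Theano also exposes `T.where` as an alias for `T.switch`). The sketch below uses NumPy only, so it illustrates the selection semantics rather than Theano's symbolic graph; the array values are made up for illustration.

```py
import numpy as np

x = np.array([[1.0, 5.0], [3.0, 7.0]])
y = np.array([[4.0, 2.0], [6.0, 0.0]])

# Elementwise condition: each output entry is picked independently.
elementwise_min = np.where(x < y, x, y)

# Scalar condition: the single boolean broadcasts, so a whole matrix is picked.
a, b = 0.0, 1.0
whole_matrix = np.where(a <= b, x, y)

print(elementwise_min)
print(whole_matrix)
```

With a scalar condition the result matches what `ifelse` would return, which is why `switch` can be seen as the more general of the two.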
Test data:
In [5]:
```py
val1 = 0.
val2 = 1.
big_mat1 = np.ones((10000, 1000), dtype=theano.config.floatX)
big_mat2 = np.ones((10000, 1000), dtype=theano.config.floatX)
```
Compare the running speed of the two:
In [6]:
```py
n_times = 10

# time.clock() was removed in Python 3.8; perf_counter() is its replacement
tic = time.perf_counter()
for i in range(n_times):
    f_switch(val1, val2, big_mat1, big_mat2)
print('time spent evaluating both values %f sec' % (time.perf_counter() - tic))

tic = time.perf_counter()
for i in range(n_times):
    f_ifelse(val1, val2, big_mat1, big_mat2)
print('time spent evaluating one value %f sec' % (time.perf_counter() - tic))
```
```py
time spent evaluating both values 0.638598 sec
time spent evaluating one value 0.461249 sec
```
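The timing loop above can be wrapped in a small helper so that both functions are measured the same way. This is only a sketch: `time_calls` is a hypothetical helper, not part of Theano, and it relies on `time.perf_counter` from the Python 3 standard library. A stand-in function replaces `f_switch` and `f_ifelse` so that the block runs without Theano installed.

```py
import time


def time_calls(func, n_times, *args):
    """Call func(*args) n_times and return the total elapsed seconds."""
    tic = time.perf_counter()
    for _ in range(n_times):
        func(*args)
    return time.perf_counter() - tic


def stand_in(a, b, x, y):
    # Placeholder for a compiled Theano function such as f_switch.
    return x if a <= b else y


elapsed = time_calls(stand_in, 10, 0.0, 1.0, [1.0], [2.0])
print('time spent in stand-in: %f sec' % elapsed)
```

With the compiled functions from the earlier cells, the call would look like `time_calls(f_switch, n_times, val1, val2, big_mat1, big_mat2)`.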