mirror of
https://github.com/apachecn/ailearning.git
synced 2026-02-11 14:26:04 +08:00
99 lines
2.0 KiB
Markdown
99 lines
2.0 KiB
Markdown
# Theano 条件语句
|
||
|
||
`theano` 中提供了两种条件语句,`ifelse` 和 `switch`,两者都是用于在符号变量上使用条件语句:
|
||
|
||
* `ifelse(condition, var1, var2)`
|
||
* 如果 `condition` 为 `true`,返回 `var1`,否则返回 `var2`
|
||
* `switch(tensor, var1, var2)`
|
||
* Elementwise `ifelse` 操作,更一般化
|
||
* `switch` 会计算两个输出,而 `ifelse` 只会根据给定的条件,计算相应的输出。
|
||
|
||
`ifelse` 需要从 `theano.ifelse` 中导入,而 `switch` 在 `theano.tensor` 模块中。
|
||
|
||
In [1]:
|
||
|
||
```py
|
||
import theano, time
|
||
import theano.tensor as T
|
||
import numpy as np
|
||
from theano.ifelse import ifelse
|
||
|
||
```
|
||
|
||
```py
|
||
Using gpu device 1: Tesla K10.G2.8GB (CNMeM is disabled)
|
||
|
||
```
|
||
|
||
假设我们有两个标量参数:$a, b$,和两个矩阵 $\mathbf{x, y}$,定义函数为:
|
||
|
||
$$ \mathbf z = f(a, b,\mathbf{x, y}) = \left\{ \begin{aligned} \mathbf x & ,\ a <= b\\="" \mathbf="" y="" &="" ,\="" a=""> b \end{aligned} \right. $$
|
||
|
||
定义变量:
|
||
|
||
In [2]:
|
||
|
||
```py
|
||
a, b = T.scalars('a', 'b')
|
||
x, y = T.matrices('x', 'y')
|
||
|
||
```
|
||
|
||
用 `ifelse` 构造,小于等于用 `T.lt()`,大于等于用 `T.gt()`:
|
||
|
||
In [3]:
|
||
|
||
```py
|
||
z_ifelse = ifelse(T.lt(a, b), x, y)
|
||
|
||
f_ifelse = theano.function([a, b, x, y], z_ifelse)
|
||
|
||
```
|
||
|
||
用 `switch` 构造:
|
||
|
||
In [4]:
|
||
|
||
```py
|
||
z_switch = T.switch(T.lt(a, b), x, y)
|
||
|
||
f_switch = theano.function([a, b, x, y], z_switch)
|
||
|
||
```
|
||
|
||
测试数据:
|
||
|
||
In [5]:
|
||
|
||
```py
|
||
val1 = 0.
|
||
val2 = 1.
|
||
big_mat1 = np.ones((10000, 1000), dtype=theano.config.floatX)
|
||
big_mat2 = np.ones((10000, 1000), dtype=theano.config.floatX)
|
||
|
||
```
|
||
|
||
比较两者的运行速度:
|
||
|
||
In [6]:
|
||
|
||
```py
|
||
n_times = 10
|
||
|
||
tic = time.clock()
|
||
for i in xrange(n_times):
|
||
f_switch(val1, val2, big_mat1, big_mat2)
|
||
print 'time spent evaluating both values %f sec' % (time.clock() - tic)
|
||
|
||
tic = time.clock()
|
||
for i in xrange(n_times):
|
||
f_ifelse(val1, val2, big_mat1, big_mat2)
|
||
print 'time spent evaluating one value %f sec' % (time.clock() - tic)
|
||
|
||
```
|
||
|
||
```py
|
||
time spent evaluating both values 0.638598 sec
|
||
time spent evaluating one value 0.461249 sec
|
||
|
||
``` |