mirror of
https://github.com/apachecn/ailearning.git
synced 2026-02-11 14:26:04 +08:00
139 lines
1.5 KiB
Markdown
139 lines
1.5 KiB
Markdown
# Theano tensor 模块:索引
|
||
|
||
In [1]:
|
||
|
||
```py
|
||
import theano
|
||
import theano.tensor as T
|
||
import numpy as np
|
||
|
||
```
|
||
|
||
```py
|
||
Using gpu device 1: Tesla C2075 (CNMeM is disabled)
|
||
|
||
```
|
||
|
||
## 简单索引
|
||
|
||
`tensor` 模块完全支持 `numpy` 中的简单索引:
|
||
|
||
In [2]:
|
||
|
||
```py
|
||
t = T.arange(9)
|
||
|
||
print t[1::2].eval()
|
||
|
||
```
|
||
|
||
```py
|
||
[1 3 5 7]
|
||
|
||
```
|
||
|
||
`numpy` 结果:
|
||
|
||
In [3]:
|
||
|
||
```py
|
||
n = np.arange(9)
|
||
|
||
print n[1::2]
|
||
|
||
```
|
||
|
||
```py
|
||
[1 3 5 7]
|
||
|
||
```
|
||
|
||
## mask 索引
|
||
|
||
`tensor` 模块虽然支持简单索引,但并不支持 `mask` 索引,例如这样的做法是<font color="red">错误</font>的:
|
||
|
||
In [4]:
|
||
|
||
```py
|
||
t = T.arange(9).reshape((3,3))
|
||
|
||
print t[t > 4].eval()
|
||
|
||
```
|
||
|
||
```py
|
||
[[[0 1 2]
|
||
[0 1 2]
|
||
[0 1 2]]
|
||
|
||
[[0 1 2]
|
||
[0 1 2]
|
||
[3 4 5]]
|
||
|
||
[[3 4 5]
|
||
[3 4 5]
|
||
[3 4 5]]]
|
||
|
||
```
|
||
|
||
`numpy` 中的结果:
|
||
|
||
In [5]:
|
||
|
||
```py
|
||
n = np.arange(9).reshape((3,3))
|
||
|
||
print n[n > 4]
|
||
|
||
```
|
||
|
||
```py
|
||
[5 6 7 8]
|
||
|
||
```
|
||
|
||
要想像 `numpy` 一样得到正确结果,我们需要使用这样的方法:
|
||
|
||
In [6]:
|
||
|
||
```py
|
||
print t[(t > 4).nonzero()].eval()
|
||
|
||
```
|
||
|
||
```py
|
||
[5 6 7 8]
|
||
|
||
```
|
||
|
||
## 使用索引进行赋值
|
||
|
||
`tensor` 模块不支持直接使用索引赋值,例如 `a[5] = b, a[5]+=b` 等是不允许的。
|
||
|
||
不过可以考虑用 `set_subtensor` 和 `inc_subtensor` 来实现类似的功能:
|
||
|
||
### T.set_subtensor(x, y)
|
||
|
||
实现类似 r[10:] = 5 的功能:
|
||
|
||
In [7]:
|
||
|
||
```py
|
||
r = T.vector()
|
||
|
||
new_r = T.set_subtensor(r[10:], 5)
|
||
|
||
```
|
||
|
||
### T.inc_subtensor(x, y)
|
||
|
||
实现类似 r[10:] += 5 的功能:
|
||
|
||
In [8]:
|
||
|
||
```py
|
||
r = T.vector()
|
||
|
||
new_r = T.inc_subtensor(r[10:], 5)
|
||
|
||
``` |