mirror of
https://github.com/apachecn/ailearning.git
synced 2026-02-09 13:25:39 +08:00
149 lines
2.4 KiB
Markdown
149 lines
2.4 KiB
Markdown
# 对象关系映射
|
||
|
||
数据库中的记录可以与一个 `Python` 对象对应。
|
||
|
||
例如对于上一节中的数据库:
|
||
|
||
| Order | Date | Stock | Quantity | Price |
|
||
| --- | --- | --- | --- | --- |
|
||
| A0001 | 2013-12-01 | AAPL | 1000 | 203.4 |
|
||
| A0002 | 2013-12-01 | MSFT | 1500 | 167.5 |
|
||
| A0003 | 2013-12-02 | GOOG | 1500 | 167.5 |
|
||
|
||
可以用一个类来描述:
|
||
|
||
| Attr. | Method |
|
||
| --- | --- |
|
||
| Order id | Cost |
|
||
| Date | |
|
||
| Stock | |
|
||
| Quant. | |
|
||
| Price | |
|
||
|
||
可以使用 `sqlalchemy` 来实现这种对应:
|
||
|
||
In [1]:
|
||
|
||
```py
|
||
from sqlalchemy.ext.declarative import declarative_base
|
||
from sqlalchemy import Column, Date, Float, Integer, String
|
||
|
||
Base = declarative_base()
|
||
|
||
class Order(Base):
|
||
__tablename__ = 'orders'
|
||
|
||
order_id = Column(String, primary_key=True)
|
||
date = Column(Date)
|
||
symbol = Column(String)
|
||
quantity = Column(Integer)
|
||
price = Column(Float)
|
||
|
||
def get_cost(self):
|
||
return self.quantity*self.price
|
||
|
||
```
|
||
|
||
生成一个 `Order` 对象:
|
||
|
||
In [2]:
|
||
|
||
```py
|
||
import datetime
|
||
order = Order(order_id='A0004', date=datetime.date.today(), symbol='MSFT', quantity=-1000, price=187.54)
|
||
|
||
```
|
||
|
||
调用方法:
|
||
|
||
In [3]:
|
||
|
||
```py
|
||
order.get_cost()
|
||
|
||
```
|
||
|
||
Out[3]:
|
||
|
||
```py
|
||
-187540.0
|
||
```
|
||
|
||
使用上一节生成的数据库产生一个 `session`:
|
||
|
||
In [4]:
|
||
|
||
```py
|
||
from sqlalchemy import create_engine
|
||
from sqlalchemy.orm import sessionmaker
|
||
|
||
engine = create_engine("sqlite:///my_database.sqlite") # 相当于 connection
|
||
Session = sessionmaker(bind=engine) # 相当于 cursor
|
||
session = Session()
|
||
|
||
```
|
||
|
||
使用这个 `session` 向数据库中添加刚才生成的对象:
|
||
|
||
In [5]:
|
||
|
||
```py
|
||
session.add(order)
|
||
session.commit()
|
||
|
||
```
|
||
|
||
显示是否添加成功:
|
||
|
||
In [6]:
|
||
|
||
```py
|
||
for row in engine.execute("SELECT * FROM orders"):
|
||
print row
|
||
|
||
```
|
||
|
||
```py
|
||
(u'A0001', u'2013-12-01', u'AAPL', 1000, 203.4)
|
||
(u'A0002', u'2013-12-01', u'MSFT', 1500, 167.5)
|
||
(u'A0003', u'2013-12-02', u'GOOG', 1500, 167.5)
|
||
(u'A0004', u'2015-09-10', u'MSFT', -1000, 187.54)
|
||
|
||
```
|
||
|
||
使用 `filter` 进行查询,返回的是 `Order` 对象的列表:
|
||
|
||
In [7]:
|
||
|
||
```py
|
||
for order in session.query(Order).filter(Order.symbol=="AAPL"):
|
||
print order.order_id, order.date, order.get_cost()
|
||
|
||
```
|
||
|
||
```py
|
||
A0001 2013-12-01 203400.0
|
||
|
||
```
|
||
|
||
返回列表的第一个:
|
||
|
||
In [8]:
|
||
|
||
```py
|
||
order_2 = session.query(Order).filter(Order.order_id=='A0002').first()
|
||
|
||
```
|
||
|
||
In [9]:
|
||
|
||
```py
|
||
order_2.symbol
|
||
|
||
```
|
||
|
||
Out[9]:
|
||
|
||
```py
|
||
u'MSFT'
|
||
``` |