mirror of
https://github.com/openmlsys/openmlsys-zh.git
synced 2026-04-05 03:37:53 +08:00
clean commit messages.
This commit is contained in:
@@ -26,7 +26,7 @@
|
||||
### 环境配置
|
||||
|
||||
在构建机器学习工作流程前,MindSpore需要通过context.set_context来配置运行需要的信息,如运行模式、后端信息、硬件等信息。
|
||||
导入context模块,配置运行需要的信息。
|
||||
导入context模块,配置运行需要的信息。以下代码运行环境为Ubuntu16.04,CUDA10.1,MindSpore1.5.2。
|
||||
|
||||
```python
|
||||
import os
|
||||
@@ -94,7 +94,7 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
|
||||
# 导入需要用到的模块
|
||||
import mindspore.nn as nn
|
||||
# 定义线性模型
|
||||
class MLPNet(nn.Module):
|
||||
class MLPNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(MLPNet, self).__init__()
|
||||
self.flatten = nn.Flatten()
|
||||
@@ -175,20 +175,32 @@ train_net(args, model, train_epoch, mnist_path, dataset_size, ckpoint, False)
|
||||
|
||||
测试是模型运行测试数据集得到的结果,通常在训练过程中,每训练一定的数据量后就会测试一次,以验证模型的泛化能力。MindSpore使用model.eval接口读入测试数据集。
|
||||
```python
|
||||
def test_net(network, model, data_path):
|
||||
def test_net(model, data_path):
|
||||
"""定义验证的方法"""
|
||||
ds_eval = create_dataset(os.path.join(data_path, "test"))
|
||||
acc = model.eval(ds_eval, dataset_sink_mode=False)
|
||||
print("{}".format(acc))
|
||||
# 验证模型精度
|
||||
test_net(model, mnist_path)
|
||||
```
|
||||
|
||||
在训练完毕后,参数保存在checkpoint中,可以将训练好的参数加载到模型中进行验证。
|
||||
```python
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from mindspore import load_checkpoint, load_param_into_net
|
||||
# 定义测试数据集,batch_size设置为1,则取出一张图片
|
||||
ds_test = create_dataset(os.path.join(mnist_path, "test"), batch_size=1).create_dict_iterator()
|
||||
data = next(ds_test)
|
||||
# images为测试图片,labels为测试图片的实际分类
|
||||
images = data["image"].asnumpy()
|
||||
labels = data["label"].asnumpy()
|
||||
# 加载已经保存的用于测试的模型
|
||||
param_dict = load_checkpoint("checkpoint_lenet-1_1875.ckpt")
|
||||
# 加载参数到网络中
|
||||
load_param_into_net(net, param_dict)
|
||||
# 使用函数model.predict预测image对应分类
|
||||
output = model.predict(Tensor(data['image']))
|
||||
# 输出预测分类与实际分类
|
||||
print(f'Predicted: "{predicted[0]}", Actual: "{labels[0]}"')
|
||||
```
|
||||
Reference in New Issue
Block a user