MNIST数据集处理是机器学习新手绕不开的问题,但TensorFlow官方给出了一个更为简单易懂的例子——Iris(鸢尾花)品种分类问题。
以下代码已详细注释,作为学习记录,若有纰漏,还请指正!
main.py:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import tensorflow as tf
import iris_data
# Command-line options accepted by this script.
parser = argparse.ArgumentParser()
# Batch size handed to the training, evaluation and prediction input functions.
parser.add_argument(
    '--batch_size',
    type=int,
    default=100,
    help='batch size',
)
# Number of steps passed to classifier.train.
parser.add_argument(
    '--train_steps',
    type=int,
    default=1000,
    help='number of training steps',
)
def main(argv):
    """Train a DNN classifier on the Iris data, evaluate it and print predictions.

    Args:
        argv: Argument list; the first element is skipped and the remainder
            is parsed with the module-level ``parser``.
    """
    # Read the command-line options (everything after the first entry).
    args = parser.parse_args(argv[1:])

    # Fetch the train / test splits as (features, labels) pairs.
    (train_x, train_y), (test_x, test_y) = iris_data.load_data()

    # Feature columns describe how the model uses the input: one numeric
    # column per key of the training features.
    my_feature_columns = [
        tf.feature_column.numeric_column(key=key) for key in train_x.keys()
    ]

    # Build a DNN with three hidden layers of 100 units each and
    # three label classes.
    classifier = tf.estimator.DNNClassifier(
        feature_columns=my_feature_columns,
        hidden_units=[100, 100, 100],
        n_classes=3)

    # Zero-argument input functions closing over the data and the batch size.
    def _train_input():
        return iris_data.train_input_fn(train_x, train_y, args.batch_size)

    def _eval_input():
        return iris_data.eval_input_fn(test_x, test_y, args.batch_size)

    # Train the model.
    classifier.train(input_fn=_train_input, steps=args.train_steps)
    print('训练完毕...')

    # Evaluate the model on the test split and report its accuracy.
    eval_result = classifier.evaluate(input_fn=_eval_input)
    print('\n评估完毕...')
    print('测试集准确率: {accuracy:0.4f}\n'.format(**eval_result))

    # Three unlabeled samples together with the species expected for each.
    expected = ['Setosa', 'Versicolor', 'Virginica']
    predict_x = {
        'SepalLength': [5.1, 5.9, 6.9],
        'SepalWidth': [3.3, 3.0, 3.1],
        'PetalLength': [1.7, 4.2, 5.4],
        'PetalWidth': [0.5, 1.5, 2.1],
    }

    def _predict_input():
        # No labels are supplied for prediction.
        return iris_data.eval_input_fn(predict_x, labels=None,
                                       batch_size=args.batch_size)

    # Predict, then print each predicted species with its probability
    # (as a percentage) next to the expected species.
    predictions = classifier.predict(input_fn=_predict_input)
    template = '预测为:"{}" ({:.4f}%), 期望为:"{}"'
    for prediction, wanted in zip(predictions, expected):
        class_id = prediction['class_ids'][0]
        probability = prediction['probabilities'][class_id]
        species = iris_data.SPECIES[class_id]
        print(template.format(species, 100 * probability, wanted))
if __name__ == '__main__':
    # Show INFO-level TensorFlow log messages, then let tf.app.run
    # invoke main.
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run(main=main)
iris_data.py:
import pandas as pd
import tensorflow as tf
# Download locations of the Iris training and test CSV files.
TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv"
TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

# Column names applied to both CSV files when they are read; the last
# entry is the default label column of load_data.
CSV_COLUMN_NAMES = [
    'SepalLength',
    'SepalWidth',
    'PetalLength',
    'PetalWidth',
    'Species',
]

# Species names; main.py indexes this list with the predicted class id.
SPECIES = [
    'Setosa',
    'Versicolor',
    'Virginica',
]
def maybe_download():
    """Return local paths of the training and test CSV files.

    Each file is obtained through ``tf.keras.utils.get_file``, which
    downloads it only when it is not already available locally. The local
    file name is the last path component of the URL.

    Returns:
        A ``(train_path, test_path)`` tuple.
    """
    train_path, test_path = [
        tf.keras.utils.get_file(url.split('/')[-1], url)
        for url in (TRAIN_URL, TEST_URL)
    ]
    return train_path, test_path
def load_data(y_name='Species'):
    """Return the Iris dataset as ``(train_x, train_y), (test_x, test_y)``.

    Args:
        y_name: Name of the column to split off as the label.

    Returns:
        Two ``(features, labels)`` pairs, training first and test second.
        The features are the DataFrame left after the label column has been
        removed; the labels are that removed column.
    """
    train_path, test_path = maybe_download()

    def _read_split(csv_path):
        # The first row of the file is treated as a header and replaced
        # by CSV_COLUMN_NAMES.
        frame = pd.read_csv(csv_path, names=CSV_COLUMN_NAMES, header=0)
        # pop removes the label column from the frame in place.
        labels = frame.pop(y_name)
        return frame, labels

    train_split = _read_split(train_path)
    test_split = _read_split(test_path)
    return train_split, test_split
def train_input_fn(features, labels, batch_size):
    """Build the input dataset used for training.

    Args:
        features: Mapping-like object convertible with ``dict()``.
        labels: Labels matching the features.
        batch_size: Number of examples per batch.

    Returns:
        A ``tf.data.Dataset`` of ``(features, labels)`` slices that is
        shuffled, repeated indefinitely and batched.
    """
    # Turn the inputs into a dataset of per-example slices.
    slices = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    # Shuffle with a buffer of 1000, repeat without limit, then batch.
    shuffled = slices.shuffle(1000)
    endless = shuffled.repeat()
    batched = endless.batch(batch_size)
    return batched
def eval_input_fn(features, labels, batch_size):
    """Build the input dataset used for evaluation and prediction.

    Args:
        features: Mapping-like object convertible with ``dict()``.
        labels: Labels matching the features, or ``None`` for prediction,
            in which case the dataset contains the features only.
        batch_size: Number of examples per batch; must not be ``None``.

    Returns:
        A batched ``tf.data.Dataset`` of either feature slices or
        ``(features, labels)`` slices.

    Raises:
        ValueError: If ``batch_size`` is ``None``.
    """
    # Validate first with a real exception: an ``assert`` is stripped when
    # Python runs with -O, and checking up front avoids building a dataset
    # that is then thrown away.
    if batch_size is None:
        raise ValueError("需指定一个值")

    features = dict(features)
    # Without labels the features alone form the dataset.
    inputs = features if labels is None else (features, labels)

    # Turn the inputs into a dataset of per-example slices and batch it.
    dataset = tf.data.Dataset.from_tensor_slices(inputs)
    return dataset.batch(batch_size)
更多内容请访问:IT源点
注意:本文归作者所有,未经作者允许,不得转载