博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
TensorFlow_CNN_MNIST问题
阅读量:6601 次
发布时间:2019-06-24

本文共 4083 字,大约阅读时间需要 13 分钟。

先把问题贴出来:

问题主要来自神经网络各层对输入数据维度理解的问题,还是在理论上欠缺很多。

这是修改后的code:

1 # Author: Lee  2 import tensorflow as tf  3 import numpy as np  4   5 # 下载并载入mnist手写数据库  6 from tensorflow.examples.tutorials.mnist import input_data  7   8 mnist = input_data.read_data_sets('mnist', one_hot=True)  9  10 #None表示张量Tensor的第一个维度可以是任何维度, /255.是对灰色图像做归一化>>>input_x输入的数据经过神经网络得到预测的output_y 11 input_x = tf.placeholder(tf.float32, [None, 28 * 28]) / 255. 12 output_y = tf.placeholder(tf.int32, [None, 1 * 10]) 13 #对输入数据进行改变形状28 * 28 * 1, -1是维度设置为auto 14 input_x_images = tf.reshape(input_x, [-1, 28, 28, 1]) 15  16 #从测试Test数据集中选取3000个手写数字的图片和对应的标签 17 test_x = mnist.test.images[:3000] # 图片 18 test_y = mnist.test.labels[:3000] # 标签 19  20 #构建神经网络 21 #第一层卷积 filters,kernals size,strides 22 conv1 = tf.layers.conv2d( 23     inputs = input_x_images, # shape = [28, 28, 1] 24     filters = 32,            # 32个过滤器(输出深度为32),相当于扫32遍 25     kernel_size = [5, 5],    # 过滤器在二维的大小为5 * 5(2D卷积窗口的高度和宽度) 26     strides = 1,             # 步长为1 27     padding = 'SAME',        # padding补零方案,same表示输出大小不变(same和valid的算法需要参考官方文档),需要在外围补零两圈 28     activation = tf.nn.relu 29     ) 30 #经过第一层卷积之后的输出数据shape为28 * 28 * 32 31  32 #第一层池化(亚采样)pooling 33 pool1 = tf.layers.max_pooling2d( 34     inputs = conv1, 35     pool_size = [2, 2],     # 过滤器在二维的大小,类比kernel_size 36     strides = 2,            # 步长2 37     ) 38 #经过第一层池化之后的输出数据shape为14 * 14 * 32 39  40 #第二层卷积 filters,kernals size,strides 41 conv2 = tf.layers.conv2d( 42     inputs = pool1,         # shape = [14, 14, 32] 43     filters = 64,           # 64个过滤器,输出深度为64 44     kernel_size = [5, 5],   # 过滤器在二维的大小为5 * 5,相当于过滤器大小 45     strides = 1,            # 步长为1 46     padding = 'SAME',       # padding补零方案,same表示输出大小不变,需要在外围补零两圈 47     activation = tf.nn.relu 48     ) 49 #经过第二层卷积之后的输出数据shape为14 * 14 * 64 50  51 #第二层池化(亚采样)pooling 52 pool2 = tf.layers.max_pooling2d( 53     inputs = conv2, 54     pool_size = [2, 2],     # 过滤器在二维的大小,类比kernel_size 55     strides = 2,            # 步长2 56     ) 57 #经过第一层池化之后的输出数据shape为7 * 7 * 64 58  59 #平坦化(flat),进行扁平化[7 * 7 * 64,] 60 flat = tf.reshape(pool2, [-1, 7 * 7 * 64]) 61  62 #1024个神经元的全连接层 63 dense = tf.layers.dense( 64     inputs = flat, 65     units = 1024, 66     activation = tf.nn.relu 67     ) 68  69 # Dropout 丢弃率为50%, Dropout的rate在[0,1] 70 dropout = tf.layers.dropout( 71     inputs = dense, 72     rate = 0.5 73     ) 74  75 # 10个神经元的全连接层, 这里不用激活函数做非线性化 76 logits = tf.layers.dense(inputs = dropout, units = 10) 77 #输出形状1 * 1 * 10 78  79 #计算误差
<计算cross_entropy(交叉熵),再用softmax进行计算百分比>
80 loss = tf.losses.softmax_cross_entropy(onehot_labels = output_y, logits = logits) 81 82 #使用Adam优化器,Adam为默认优化器,learning_rate = 0.001 83 train_op = tf.train.AdamOptimizer(learning_rate = 0.001).minimize(loss) 84 85 #预测值和实际标签的匹配度 86 #返回(accuracy, update_op),能够创建两个局部变量 87 accuracy = tf.metrics.accuracy( 88 labels = tf.argmax(output_y, axis = 1), 89 predictions = tf.argmax(logits, axis = 1))[1] 90 91 92 #创建会话Session 93 sess = tf.Session() 94 95 #初始化变量:全局变量和局部变量 96 init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 97 sess.run(init) 98 99 for i in range(30000):100 batch = mnist.train.next_batch(50)101 train_loss, train_op_ = sess.run([loss, train_op], {input_x: batch[0], output_y: batch[1]})102 if i % 100 == 0:103 test_accuracy = sess.run(accuracy, {input_x: test_x, output_y: test_y})104 print("step:", i, "accuracy:", test_accuracy,'loss:', train_loss)105 106 107 #测试, 预测值与真实值对比108 test_output = sess.run(logits, {input_x: test_x[:20]})109 inferenced_y = np.argmax(test_output, 1)110 #推测的数字111 print('inferenced_data:',inferenced_y)112 #实际数字113 print('test_output_data:', np.argmax(test_y[:20],1))

由于自己的台式机太水,在训练的时候太耗费时间了!

我是按照这张图搭建的神经网络,在第二次池化到扁平化处理的时候,我一直没有理解好为什么维度不兼容,后来查阅了一些别人写的MNIST代码,虽然没看到别人用tensorflow的layer写,大部分人还是用的tf.nn写的,感觉差不多,我主要还是对神经网络的理解出现了偏差,导致在第二次池化到扁平化的处理的数据出现问题。因为在开始的output_y是[None,10],当时在扁平化的时候使用的是[-1, 1, 1, 1024],其中-1是自动根据上下数据进行调整维度,根据上图来看为1 * 1 * 1024,所以就理解成需要扁平化成[1, 1 , 1024],所以还是需要多看看别人写的代码,从中吸取一点经验!!!

 

转载于:https://www.cnblogs.com/AlexHaiY/p/9343508.html

你可能感兴趣的文章
Linux wget 详解
查看>>
XamarinSQLite教程添加测试数据
查看>>
我的友情链接
查看>>
Mysql登录时提示1045的解决办法
查看>>
MySql 远程连接中phpmyadmin的设置
查看>>
类型判断时instanceof和equals的不同用法
查看>>
设计师与客户:迁就难出好设计
查看>>
discuz 门户diy实现翻页功能的修改记录
查看>>
授之以渔-运维平台应用模块一(应用树篇)
查看>>
pcDuino裸板程序-led
查看>>
3d打印机要火了还需时日
查看>>
关于Nature的.net版框架
查看>>
Hp DL380服务器硬盘故障数据恢复过程
查看>>
RAID磁盘阵列技术及数据恢复原理
查看>>
android:scaleType
查看>>
基于Quick-cocos2dx 2.2.3 的动态更新实现(不需修改任何框架上的)
查看>>
xenserver如何完美的创建本地ISO库
查看>>
JAVA 动态配置 (配置源={properties,redis})
查看>>
「C语言回顾之旅」第二篇:指针详解进阶
查看>>
Vmware虚拟机快速使用桥接模式上网
查看>>