2017年4月1日 星期六

[深度學習練習] [Deep Learning Practice] 神經網路入門(九)- 遞歸神經網路 X 實作

上一集中,介紹了基本的 RNN 以及 LSTM model,

這一篇將會實作 RNN model,預測 mnist 的資料,

以更加了解 RNN 的使用方法和原理。

------------------------------------------------------------------------------------------------------------

最近 tensorflow 更新到 1.0.1 版本,做了滿多更新,

一些基本的方法,model 的位置,類別內的參數,都有更新。

可以用官方提供的工具去把之前的 code 更新,再稍微看一下這個

大家一定要記得去更新喔!

更新之後,之前用 tensorflow 練習的 code 也都不能用了,會再找時間更新。

這一篇文章將會以 tensorflow 1.0.1 實現。

------------------------------------------------------------------------------------------------------------

一、遞歸神經網路實作

這邊會如同之前寫基礎神經網路和卷積神經網路時一樣,

先定義一些基本參數。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.rnn import BasicLSTMCell, static_rnn

mnist = input_data.read_data_sets('/tmp/data', one_hot = True)

# 來回訓練三次
hm_epochs = 3
# 0~9 十種的多元分類
n_classes = 10
# 因為是要用 SGD 方法,所以定義,在 mnist 55000 training data 中,每 128 組最佳化一次。
batch_size = 128
# 每一次要輸入的 x 的大小
chunk_size = 28 # feature size
# 有幾個要輸入的 x
n_chunks = 28 # time step
# hidden state size
rnn_size = 128

x = tf.placeholder('float',[None, n_chunks, chunk_size])
y = tf.placeholder('float')

這邊要注意一下,feature size 以及 time step 的概念

feature size 也就是上一章中所指的,被 vectorize 後單字的向量長度

time step 則是總共有幾個 loop 跑完,以上一章範例所指就是 6,

早,啊,你,今天,要,做。

而在這個數字預測中,我們 input 是 28*28 的照片,

我們就把它當作是長度為 28 的向量,有 28 個 time step

第一行,第二行... ... 一行一行這樣運算下去,

最後在 full connect 到一個 1*10 的向量,

完成多元分類。

圖解大概如下,

























把 28 * 28 的圖像分成 28 份,然後逐一輸入每個 Cell,最後解碼 output。

------------------------------------------------------------------------------------------------------------

二、定義遞歸神經網路

定義好基本的參數,也瞭解我們要如何對數字預測的圖片下手後,

可以開始用 tf 寫我們的遞歸神經網路了。

def recurrent_neural_network(x):

    layer = {'weights':tf.Variable(tf.random_normal([rnn_size, n_classes])),
            'biases':tf.Variable(tf.random_normal([n_classes]))}


    # transpose 這個 funtion 是對矩陣做不同維度座標軸的轉換

    x = tf.transpose(x, [1,0,2])
    x = tf.reshape(x, [-1, chunk_size])


    # 這邊把一張圖片轉成以每列為單位的輸入,即是上圖那個切圖片的
    x = tf.split(axis=0, num_or_size_splits=n_chunks, value=x)

    # 定義要被 loop 的基本單元
    lstm_cell = BasicLSTMCell(rnn_size)


    # 選一個把 cell 串起來的 model
    outputs, states = static_rnn(lstm_cell, x, dtype= tf.float32)

    # 用一個 full connection layer 輸出預測
    output = tf.matmul(outputs[-1], layer['weights']) + layer['biases']

    return output

transpose、reshape 和 split,我認為是非常重要的資料前處理工具

大家可以開一個 test.py,去玩玩看這幾個方法,熟悉之後,要做事情比較快。

------------------------------------------------------------------------------------------------------------

三、最佳化遞歸神經網路

定義完神經網路後,我們就可以進行最後一步,最佳化

這一步編寫的方式和基礎神經網路還有卷積神經網路,一模一樣


def train_neural_network(x):

    prediction = recurrent_neural_network(x)

    cost = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y))

    optimizer = tf.train.AdamOptimizer().minimize(cost)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in range(hm_epochs):
            epoch_loss = 0
            for _ in range(int(mnist.train.num_examples/batch_size)):
                epoch_x, epoch_y = mnist.train.next_batch(batch_size)
                epoch_x = epoch_x.reshape((batch_size, n_chunks, chunk_size))

                _, c = sess.run([optimizer,cost], feed_dict = {x: epoch_x, y: epoch_y})
                epoch_loss += c
            print ('Epoch completed:', epoch, hm_epochs, 'loss:', epoch_loss)

        correct = tf.equal(tf.argmax(prediction,1), tf.argmax(y,1))

        accuracy = tf.reduce_mean(tf.cast(correct, 'float'))
        print ('Accuracy:' ,accuracy.eval({x:mnist.test.images.reshape((-1, n_chunks, chunk_size)), y:mnist.test.labels}))

train_neural_network(x)

大功告成!

接著在終端機執行此程式碼,





在 3 個 epochs 底下,還是有 0.97 的準確率!




RNN 的基礎練習就到這邊,希望大家對定義 RNN 比較有概念了。

這邊附上這篇的 Github

rnn_090.py 適用於 tf 1.0.1 之前的所有版本,

rnn_101.py 則是適用於 tf 1.0.1。




接下來,我們將使用 LSTM 來玩 台股分K 的資料,

不要再玩每日每日的,交易頻率太低,資料量太少,沒有趣味

再來的 LSTM X 台指 實作中,還會學到 TFLearn,是一個很簡潔的高階 TF API!

------------------------------------------------------------------------------------------------------------

Reference

[1] Practical Machine Learning Problem

[2] 圖解機器學習

[3] Coursera - Machine Leanring 

[4] A tour of machine learning algorithms

沒有留言:

張貼留言