接下來我們會用 LSTM 來玩台股分 K 預測,
不過!在這之前我們要先介紹一下 TFLearn 這個超好用的工具。
------------------------------------------------------------------------------------------------------------
一、TFLearn 介紹
TFLearn is a modular and transparent deep learning library built on top of Tensorflow.
基於 tensorflow 之上的 API,稍微講一下我跟 TFLearn 的機緣,
應該會對他的方便性更加有感。
去年十月在玩 CNN X TAIFEX 的時候,因為我自己是認為輸入資料比較重要的人,
所以我對 tune Model 並沒有下太多功夫,那 CNN 因為有 filter 的關係,
每一層每一層之間的 tensor 和 ops 定義很麻煩,
所以我就寫了一個小工具,讓自己一行就可以寫出複雜的 model ,
然後用 for loop 去亂試,哈哈哈!
猜想 TFLearn 的團隊一定也是和我一樣苟且,所以寫了這個 API。
不知道大家是否記得,當初在寫 CNN 有多煎熬。
但現在!只要靠著 TFLearn,每一層神經層和 activation func 甚至是 filter,
都只要一行!就可以完成!
不僅如此! TFLearn 還提供很多很方便的功能,大家可以在 Get Started 的地方查找,
從資料前處理的方法, Ops、神經層的定義以及最後視覺化,
TFLearn 都提供很好的解決方案,如果是快速開發,真的是很方便的 API!
------------------------------------------------------------------------------------------------------------
二、LSTM model 用於台指分K
其實,這一次要做的事情是很直觀的,
就使用前 15 分鐘的資料去預測未來 15 分鐘的漲跌,
把第一根 K 棒的資訊丟到第二根,第二根丟到第三根這樣,
累積丟十五根,之後預測未來 15 分鐘(灰色區塊)的漲跌。
K 棒的資訊就是 Open、High、Low、Close,
我們再額外加上五分均線和 20 分均線以及 K9、D9 作為 features,
K9、D9 技術指標的介紹,在這邊有。
這邊稍微想一下,以這樣的假設,
對 LSTM 來說,input size 以及 time step 分別是多少呢?
input size = 8
time step = 15
現在可以決定一下,我們 output 是要長什麼樣子,
我目前就是隨便設一個未來 15 分鐘內,
漲 10 點就是上漲,跌 10 點就是下跌,
上下 10 點或是震幅小於 10 點,就是平盤。
OK!決定好了之後,我們就可以開始處理訓練資料了!
------------------------------------------------------------------------------------------------------------
三、資料前處理
首先一樣 import 我們會使用到的 Library,
然後定義一些參數以及用 pickle 載入分 K 的資料。
(結果我完全忘記給大家分 K 資料,連結在這。)
import random
import numpy as np
import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.estimator import regression
from tflearn.layers.recurrent import lstm
import pickle
from datetime import date, timedelta
Learning_rate = 1e-3
pickle_in_min = open('TWfuture_min_data_2014_adjusted.pickle', 'rb')
df_min = pickle.load(pickle_in_min)
start_day = date(2014,1,3)
end_day = date(2016,12,31)
interval = timedelta(days=1)
基本設定做完之後,我們就可以在以下,寫下資料前處理的方法,
def initTrData(df,start_day,end_day):
training_data = []
span_day = end_day-start_day
span = int(span_day.days)
last_day = start_day
for _ in range(span):
y_price = df[df["date"]==last_day]["close"][-1]
start_day += interval
df_today = df[df["date"]==start_day]
# 把資料標準化
df_today.loc[:,'open_adjusted'] =
(df_today["open"] - y_price)*100/y_price
df_today.loc[:,"high_adjusted"] =
(df_today["high"] - y_price)*100/y_price
df_today.loc[:,"low_adjusted"] =
(df_today["low"] - y_price)*100/y_price
df_today.loc[:,"close_adjusted"] =
(df_today["close"] - y_price)*100/y_price
df_today.loc[:,"close_mvag5_adjusted"] =
(df_today["close_mvag5"] - y_price)*100/y_price
df_today.loc[:,"close_mvag20_adjusted"] =
(df_today["close_mvag20"] - y_price)*100/y_price
if len(df_today) > 0:
last_day = start_day
print (start_day)
for i in range(120):
df_today_min = df_today[i:i+15][["open_adjusted","high_adjusted",
"low_adjusted","close_adjusted","close_mvag5_adjusted","close_mvag20_adjusted","K9","D9"]]
X = np.array(df_today_min)
call = False
put = False
for j in range(15):
if (df_today.close.iat[i+15+j] -
df_today.open.iat[i+15]) > 10:
call = True
elif (df_today.close.iat[i+15+j] -
df_today.open.iat[i+15]) < -10:
put = True
if call == True and put == False:
y = [0,0,1] # 這定義為多訊號
elif call == False and put == True:
y = [1,0,0] # 這定義為空訊號
else:
y = [0,1,0] # 這定義為平盤訊號
training_data.append([X,y])
training_data_save = np.array(training_data)
# 用 np 儲存資料
np.save('preprocessing_15to15.npy',training_data_save)
initTrData(df_min,start_day,end_day)
執行這一個 script 後,就可以獲得 2014 ~ 2016 年整理完的資料了!
2017 到時候用來跑回測。
這邊附上 Github 連結與分 K 資料連結,
跑完之後,下一章就可以寫 LSTM model 來學習這些資料。
------------------------------------------------------------------------------------------------------------
Reference
[1] Practical Machine Learning Problem
[2] 圖解機器學習
[3] Coursera - Machine Leanring
[4] A tour of machine learning algorithms
pickle_in_min = open('TWfuture_min_data_2014_adjusted.pickle', 'rb')
回覆刪除df_min = pickle.load(pickle_in_min)
這兩行在python3.5的環境下
好像會有import module的error
可能要加入:
import pandas as pd
另外
修改為
df_min = pd.read_pickle(pickle_in_min)
不知是不是漏打
還是在python3.5的環境下會不支援pickle.open
我這邊不是用 pandas 去 read_pickle,
刪除剛剛試了一下用 3.5.1 版本 import pickle 讀取和寫入都沒問題。
應該是我當初開發時都習慣用 2.7.13 版本,所以 pickle 版本問題。
用 pandas 確實很方便,但我現在都傾向把東西寫成 csv, yaml, txt, json or database
pickle 其實是滿快,檔案也小一點,但就是編碼過,很容易有版本不相容問題。
學習到了!
刪除因為還是初學者,
不清楚使用上的方式><
感謝講解~