LSTM股票預測

      在〈LSTM股票預測〉中尚無留言

LSTM除了用在 NLP 外,亦可應用在語音辨識,語言翻譯等。每次在做資料分析之前,資料收集都是困難且費時難。

本篇說明台灣股市的預測,而台股的預測需要收盤價及交易量,可以由 tejapi 及 yfinance 二個套件取得。tejapi 需要付費,而 yfinance 則是免費的,所以建議使用 yfinance。

tejapi

tejapi 是台灣經濟新報 TEJ 推出的套件,需要付費,但可以申請試用。試用可下載的資料有限,所以不建議使用,本段只是作為記錄而以。

請到 TEJAPI 網站 https://api.tej.com.tw/trial.html 申請試用金鑰,申請時會傳送驗証碼到手機中,然後會將金鑰傳送到 eMail 中。

這組金鑰只能下載試用資料庫,資料庫名稱及欄位請由 https://api.tej.com.tw/datatables.html?db=TRAIL&t=%E8%A9%A6%E7%94%A8%E8%B3%87%E6%96%99%E5%BA%AB 查詢。我們選擇 “上市(櫃)未調整股價(日)” 查詢台積電的股價,試用金鑰只能查到 2022 一整年的資料。

請先安裝tejapi套件

pip install tejapi

程式碼使用 tejapi.get 取得資料,傳回的資料為 panda 格式,代碼如下

import tejapi
import pandas as pd
def getData(coid):
    tejapi.ApiConfig.api_key = "your key"
    tejapi.ApiConfig.ignoretz = True
    mdate={'gte':'2020-01-01', 'lte':'2023-12-31'}
    opts={'columns': ['coid','mdate','open_d','high_d','low_d','close_d','amount']}
    data = tejapi.get('TRAIL/TAPRCD',
        coid = coid,
        mdate=mdate,
        opts=opts,
        paginate=True)
    return data
display=pd.options.display
display.max_columns=None
display.max_rows=None
display.width=None
display.max_colwidth=None
print(getData(2330))

結果:
      coid      mdate  open_d  high_d  low_d  close_d        amount
None                                                               
0     2330 2022-01-03   619.0   632.0  618.0    631.0  4.624972e+07
1     2330 2022-01-04   645.0   656.0  644.0    656.0  5.918820e+07
2     2330 2022-01-05   669.0   669.0  646.0    650.0  4.758283e+07
3     2330 2022-01-06   638.0   646.0  636.0    644.0  3.681764e+07
4     2330 2022-01-07   643.0   646.0  632.0    634.0  2.535824e+07
5     2330 2022-01-10   628.0   645.0  627.0    643.0  2.505214e+07
6     2330 2022-01-11   646.0   651.0  639.0    651.0  2.336159e+07
.......

Yahoo Finance

Yahoo Finance 提供 yfinance 套件讓 Python 可以免費下載台股資訊,這是最好用的套件。

安裝套件指令如下

pip install yfinance

yf.Ticker() 函數可以下載美股的資訊。

yf.download() 函數可下載台股資訊,函數中的 tickers 參數需給定要下載的股票代碼,比如
大盤    : “^TWII”
台積電 : “2330.TW”
電子股 : ” ^TELI”
相關股票代號可以由 https://tw.stock.yahoo.com/t/idx.php 查詢

下載代碼如下

import yfinance as yf
import pandas as pd
display=pd.options.display
display.max_columns=None
display.max_rows=None
display.width=None
display.max_colwidth=None
#df = yf.Ticker('AAPL').history(period = 'max')
df=yf.download("^TWII", start="2023-10-01", end="2023-10-22")
print(df)

結果:
[*********************100%%**********************]  1 of 1 completed
Date               Open          High           Low         Close     Adj Close   Volume
                                                                                    
2023-10-02  16382.959961  16575.310547  16382.959961  16557.310547  16557.310547  2669600
2023-10-03  16520.240234  16571.660156  16453.800781  16454.339844  16454.339844  2697100
2023-10-04  16419.480469  16419.480469  16203.339844  16273.379883  16273.379883  2624800
2023-10-05  16313.889648  16477.609375  16313.889648  16453.519531  16453.519531  2503200
2023-10-06  16482.929688  16539.130859  16482.929688  16520.570312  16520.570312  2296800
2023-10-11  16567.429688  16729.500000  16567.429688  16672.029297  16672.029297  3635100
2023-10-12  16698.849609  16825.910156  16684.429688  16825.910156  16825.910156  2917100
2023-10-13  16815.910156  16815.910156  16726.429688  16782.570312  16782.570312  2955800
2023-10-16  16712.929688  16712.929688  16614.099609  16652.240234  16652.240234  2843500
2023-10-17  16678.279297  16770.640625  16618.300781  16642.550781  16642.550781  2965400
2023-10-18  16608.019531  16612.250000  16398.759766  16440.910156  16440.910156  3844900
2023-10-19  16416.539062  16479.349609  16382.400391  16452.730469  16452.730469  2608200
2023-10-20  16434.429688  16462.560547  16271.820312  16440.720703  16440.720703  3035200

yfinance 共有Date(日期)、Open(開盤)、High(最高)、Low(最低)、Close(收盤)、Adj Close(盤後交易)、Volume(交易量) 共 7 個欄位。

完整代碼

底下可繪出黃金期貨走勢圖及預測值,請先安裝如下套件

pip install plotly yfinance scikit-learn pandas

完整代碼如下

from datetime import datetime, timedelta
import pandas as pd
from dateutil.relativedelta import relativedelta
from sklearn.linear_model import LinearRegression

display=pd.options.display
display.max_columns=None
display.max_rows=None
display.width=None
display.max_colwidth=None
import yfinance as yf
import plotly.graph_objects as go
"""
大盤 : ^TWII
黃金期貨 : GC=F
"""
stock='GC=F'
df=yf.download(stock, start='2023-10-01', end='2024-05-12')

fig=go.Figure()
fig.add_trace(
    go.Scatter(
        x=df.index,
        y=df['Close'].values,
        mode='lines',
        name='實際價格',
        line=dict(color='royalblue', width=2)
    )
)

ma1=5#5日平均線
ma2=10#10日平均線
df=df.dropna()
df['s1']=df['Close'].rolling(window=ma1).mean()
df['s2']=df['Close'].rolling(window=ma2).mean()
df=df.dropna()

train=df[['Close','s1','s2']]
train['next_day_price']=train['Close'].shift(-1)
train=train.dropna()
x_train=train[['s1', 's2']]
y_train=train['next_day_price']
model=LinearRegression()
model.fit(x_train, y_train)
df['predict_price']=model.predict(df[['s1', 's2']])
pred=df[['predict_price']]
s=(pred.tail(1).index+timedelta(days=1))[0]
dates=pd.date_range(s, periods=1)
pred.loc[dates[0]]=[0]
pred['predict_price']=pred['predict_price'].shift(1)
print(pred)
fig.add_trace(
    go.Scatter(
        x=pred.index,
        y=pred['predict_price'].values,
        mode='lines',
        name='AI預測',
        line=dict(color='orange', width=1)
    )
)
current = datetime.now()
xrange = [(current - relativedelta(months=6)).strftime("%Y-%m-%d"), current.strftime("%Y-%m-%d")]
yrange = [df['Close'].tail(180).min(), df['Close'].tail(180).max()]
fig.update_layout(
    dragmode="pan",
    xaxis=go.layout.XAxis(
        range=xrange,
        rangeselector=dict(
            buttons=list([
                dict(count=1,
                     label="1 month",
                     step="month",
                     stepmode="backward"),
                dict(count=6,
                     label="6 month",
                     step="month",
                     stepmode="backward"),
                dict(count=1,
                     label="1 year",
                     step="year",
                     stepmode="backward"),
                dict(count=1,
                     label="1 day",
                     step="day",
                     stepmode="todate"),
                dict(step="all")
            ])
        ),
        rangeslider=dict(
            visible=True
        ),
        type="date"
    ),
    yaxis=dict(
        fixedrange=False,
        range=yrange
    )
)
fig.show()

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *