自定模型5種花卉

      在〈自定模型5種花卉〉中尚無留言

VGG19可以分辨1000個種類,但如果我們要辨識的,並不在這1000個之內,就要自已訓練。訓練的步驟為 : 
收集圖片資料,訓練,偵測。開始之前,請先安裝如下套件

pip install tensorflow==2.10.1 matplotlib opencv-python

訓練圖片

收集圖片是一項大工程,還好 tensorflow 幫我們收集了一大堆圖片。請到 http://download.tensorflow.org/example_images/flower_photos.tgz 下載圖片。下載後置於專案之下並解開,就會多一個 flower_photos 目錄。

flower_photos 裏面 有 5 個資料夾,分別是 daisy(雛菊),dandelion(蒲公英),roses(玫瑰),sunflowers(向日葵),tulips(鬱金香)

測試圖片

要進行辨識測試的圖片,請由如下網址下載,存儲在專案下的 images 目錄

images.zip

Dense 層

Dense 層只會出現在全連接層,也就是輸出層。Dense 的作用就是由原本的 x 種狀況 (特徵),經過某種演算後,變成下一層的 y 種狀況。

剛剛講的某種演算,通常就是用捲積的方式來計算下一層的結果。

訓練模型

把每張圖片讀入縮小後放入 data 中,並同時由目錄名稱記錄每張圖的類別放在 labels 中。然後將VGG19前三個連接層抽出,再加入自已的全連接層。

一開始始用 GlobalAveragePooling2D 將 (長*寬*通道) 轉換成 (1*1*通道),其方法是將每個通道中的權重作平均值。如下圖所示,最前面的通道中,(1+5+4+5+6+5+3+9+4+2+5+2+8+6+8+7)/16 = 5

Dense 種類為 256 及 64 種,激活方式為 relu (線性整流, 將負值變為0),最後的輸出層激活含數為 softmax (轉換為機率,總合為1)。

在每個連接層( Dense) 之後,都需作 BatchNormalization。BN 層的作用是把一個 batch 內的所有數據,從不規則的分佈拉到常態分佈,將數字集中在平均為 0,標準差為 1 的範圍。這樣作的好處是使得數據能夠分佈在激活函數的敏感區域,敏感區域即為梯度較大的區域,因此在反向傳播時能夠較快反饋誤差傳播。BN的運作方式又是一篇論文,裏面有著複雜的數學運算,所以只需了解其功能即可。

模型中所加入的各層,不一定要按這個方式,也可以改用其它層,可自行測試看看。

訓練時要注意 batch_size,如果太大會造成顯卡記憶体不足,就需往下調整。訓練的時間依硬体等級有所不同。訓練好的模型會儲存在 flower 目錄中。

下述代碼若使用 CPU 訓練,每個 epoch 約要 2 分鐘,共需 100 分鐘。若使用 GPU 訓練,每個 epoch 約 4 秒,共需 200 秒。

#!./.venv/bin/python3
import os
import random
from keras import Sequential
from keras.src.applications.vgg19 import VGG19, preprocess_input
from keras.src.layers import GlobalAveragePooling2D, Dense, BatchNormalization, Dropout
from keras.src.optimizers import Adam
from MahalSdk.cv import cv
import cv2
import numpy as np
import pylab as plt
imgs=[]
labels=[]
path="./flower_photos"
flowers=[]
for flower in os.listdir(path):
    flowers.append(flower)
    for file in os.listdir(os.path.join(path, flower)):
        img=cv.read(os.path.join(path, flower, file))[:,:,::-1].copy()
        img=cv2.resize(img,(224, 224) , interpolation=cv2.INTER_LINEAR)
        # 作不作預處理都沒差
        #imgs.append(preprocess_input(img))
        imgs.append(img)
        labels.append(flower)

#Linux 讀取目錄的順序不會按字母排序,所以要手動排序
flowers=sorted(flowers)

upset=list(zip(imgs, labels))
random.seed(1)
random.shuffle(upset)
imgs, labels=zip(*upset)
train=int(len(imgs)*0.9)
train_img=np.array(imgs[:train])
train_label=np.array(labels[:train])
test_img=np.array(imgs[train:])
test_label=np.array(labels[train:])
train_onehot=np.zeros([len(train_label), 5])
test_onehot=np.zeros([len(test_label), 5])
for i in range(len(train_label)):
    train_onehot[i][flowers.index(train_label[i])] = 1
for i in range(len(test_label)):
    test_onehot[i][flowers.index(test_label[i])]=1
print(test_onehot)
#建立模型
model_base=VGG19(
    weights='imagenet',
    include_top=False,
    input_shape=(224,224,3)
)
for layer in model_base.layers:
    layer.trainable=False

model=Sequential()
model.add(model_base)
model.add(GlobalAveragePooling2D())
model.add(Dense(256,activation='relu'))
model.add(BatchNormalization())
model.add(Dense(64,activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.2))
model.add(Dense(5, activation='softmax'))

#編譯模型
model.compile(
    optimizer=Adam(learning_rate=0.0001),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

#開始訓練模型
history=model.fit(
    train_img,
    train_onehot,
    batch_size=128,
    epochs=50,
    validation_data=(test_img, test_onehot)
)
model.save("./flower.keras")

p1=plt.plot(history.history['accuracy'], label='training acc')#訓練時的準確度
p2=plt.plot(history.history['val_accuracy'], label='val acc')#測試時的準確度
p3=plt.plot(history.history['loss'], label='training loss')#訓練時的損失率
p4=plt.plot(history.history['val_loss'], label='val loss')#測試時的損失率
plt.legend()
plt.show()

辨識圖片

如果無法訓練模型,可下載人本已訓練好的模型 : flower5.zip

載入模型後,再載入圖片即可辨識。記得圖片需縮小為 224*224,然後擴展為4維。

#!./.venv/bin/python3
import os
from keras.src.applications.convnext import preprocess_input
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import keras
import pylab as plt
import cv2, numpy as np
model=keras.models.load_model('./flower.keras')
path="./images"
flowers=sorted(os.listdir("./flower_photos"))
for i, file in enumerate(os.listdir(path)):
    full=os.path.join(path, file)
    img=cv2.imdecode(np.fromfile(full, dtype=np.uint8), cv2.IMREAD_COLOR)
    img=img[:,:,::-1].copy()
    x=cv2.resize(img, (224,224), interpolation=cv2.INTER_LINEAR)
    x=np.expand_dims(x, axis=0)

    #不作預處理也沒差
    #x=preprocess_input(x)

    out=model.predict(x)
    idx=out[0].argmax()
    name=flowers[idx]
    ax=plt.subplot(3,5,i+1)
    ax.set_title(name)
    ax.imshow(img)
    ax.axis("off")
plt.show()

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *