水果模型訓練,比 Dollar bill detection 多了圖片的標示,train/valid 的分割。
安裝套件
請記得安裝 ultralytics 時,會自動安裝無法使用 GPU 的 torch 版本,所以需先安裝 torch GPU 版,再安裝 ultralytics。
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
pip install pip install ultralytics labelimg
圖片下載
請下載本站的 fruits.zip,儲存在專案下的 fruits 目錄,然後解壓縮至此,就會在 fruits 下產生 data 目錄
標示圖片
請在命令提示字元視窗輸入 pip install labelimg,然後直接執行 labelimg 即可開始標識圖片。標示完記得將 .txt 檔存放在 fruits/data/labels 之下
pip install labelimg
切割 train/valid
使用下面代碼,將 fruits/data 裏的資料,分割成 80% 存放在 datasets/train,20% 存放 datasets/valid。執行完後,將 fruits 下的 train 及 valid 二個目錄 copy 到專案下的 datasets 目錄。
import os
import random
import shutil
data_path='./data'
train_path='./train'
valid_path='./valid'
if os.path.exists(train_path):
shutil.rmtree(train_path)
if os.path.exists(valid_path):
shutil.rmtree(valid_path)
os.makedirs(os.path.join(train_path, 'images'))
os.makedirs(os.path.join(train_path, 'labels'))
os.makedirs(os.path.join(valid_path, 'images'))
os.makedirs(os.path.join(valid_path, 'labels'))
files=[os.path.splitext(file)[0]
for file in os.listdir(os.path.join(data_path, "images"))]
random.shuffle(files)
mid=int(len(files)*0.8)
for file in files[:mid]:
source=os.path.join(data_path, "images", f'{file}.jpg')
target=os.path.join(train_path,"images", f'{file}.jpg')
print(source, target)
shutil.copy(source, target)
source=os.path.join(data_path, "labels", f'{file}.txt')
target=os.path.join(train_path,"labels", f'{file}.txt')
print(source, target)
shutil.copy(source, target)
for file in files[mid:]:
source=os.path.join(data_path, "images", f'{file}.jpg')
target=os.path.join(valid_path,"images", f'{file}.jpg')
print(source, target)
shutil.copy(source, target)
source=os.path.join(data_path, "labels", f'{file}.txt')
target=os.path.join(valid_path,"labels", f'{file}.txt')
print(source, target)
shutil.copy(source, target)
設定 data.yaml
在專案根目錄下新增 data.yaml,內容如下。請注意一定要寫絕對路徑
train: e:/python_ai/yolov8_fruit/train/images val: e:/python_ai/yolov8_fruit/valid/images nc: 4 names: ['guava', 'lemon', 'pitaya', 'wax']
下載預訓練模型
到 https://github.com/ultralytics/ultralytics 網站,往下拉下載 YOLOv8s.pt ,儲存在專案根目錄下。
yolo.exe 訓練權重
在 Terminal 執行如下指令,epochs 設定為 200 次。不過 V8很聰明,在 epochs 約 98 次後,發現無法再逼近,就會自動停止
yolo task=detect mode=train model=./yolov8x.pt data=./data.yaml epochs=200 imgsz=640
本例圖片並不多,RTX 3080Ti 訓練時間約 6 分鐘。無法訓練的人,請下載 best.pt,然後儲存在 ./runs/detect/train/weights 之下。
使用 Python 訓練模型
在專案下新增 train.py 檔,由 YOLO 戴入預訓練模型 yolov8x.pt 產生 model 物件,再由 model.train() 即可開始訓練。
請注意這段程式碼一定要寫在 if __name__ 的區塊中,否則會出現需使用 fork 執行子行程的錯誤。
import os
import shutil
import time
from ultralytics import YOLO
#訓練模型時,一定要放在 __name__ 區塊內
#否則會出現需使用 fork來執行子行程的錯誤
if __name__=='__main__':
train_path="./runs/detect/train"
if os.path.exists(train_path):
shutil.rmtree(train_path)
model = YOLO("yolov8x.pt")#因圖片數少,所以使用v8x比較準
print("開始訓練 .........")
t1=time.time()
model.train(data="./data.yaml", epochs=200, imgsz=640)
t2=time.time()
print(f'訓練花費時間 : {t2-t1}秒')
path=model.export()
print(f'模型匯出路徑 : {path}')
偵測圖片
請使用如下代碼偵測圖片
import os
import platform
import pylab as plt
import cv2
import numpy as np
from PIL import Image, ImageFont, ImageDraw
from ultralytics import YOLO
def text(img, text, xy=(0, 0), color=(0, 0, 0), size=12):
pil = Image.fromarray(img)
s = platform.system()
if s == "Linux":
font =ImageFont.truetype('/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc', size)
elif s == "Darwin":
font = ImageFont.truetype('....', size)
else:
font = ImageFont.truetype('simsun.ttc', size)
ImageDraw.Draw(pil).text(xy, text, font=font, fill=color)
return np.asarray(pil)
model=YOLO('./runs/detect/train/weights/best.pt')
img_path="./valid/images"
plt.figure(figsize=(12,9))
for i,file in enumerate(os.listdir(img_path)[-6:]):
full=os.path.join(img_path, file)
img=cv2.imdecode(
np.fromfile(full, dtype=np.uint8),
cv2.IMREAD_UNCHANGED
)[:,:,::-1].copy()
results=model.predict(img, save=False)
boxes=results[0].boxes.xyxy
names = [results[0].names[int(idx.cpu().numpy())] for idx in results[0].boxes.cls]
for box, name in zip(boxes, names):
print(name)
box=box.cpu().numpy()
x1 = int(box[0]);y1 = int(box[1]);x2 = int(box[2]);y2 = int(box[3])
print(img.shape, x1, y1, x2, y2)
try:
img=cv2.rectangle(img,(x1, y1), (x2, y2), (0,255,0) , 2)
img=text(img, name, (x1, y1), (0,255,0),25)
except:
pass
plt.subplot(2,3,i+1)
plt.axis("off")
plt.imshow(img)
plt.show()
