快捷搜索:  汽车  科技

如何预测向量回归:18.坐标回归-三角形关键点 重心和内心

如何预测向量回归:18.坐标回归-三角形关键点 重心和内心查看数据集80020010一条数据,三、定义数据集#数据增强 class ImgTransforms(object): """ 图像预处理工具 并对图像的维度进行转换 从HWC变为CHW """ def __init__(self fmt): self.format = fmt def __call__(self img): img = img.transpose(self.format) return img #定义数据集 class GIDataset(Dataset): def __init__(self data_path mode="train" val_split=0.2): super()._

一、准备数据集

根据数学知识自动生成三角形关键点(重心和内心)坐标

import numpy as np import pandas as pd #三角形内心 # ix=(aX1 bX2 cX3)/(a b c) # iy=(aY1 bY2 cY3)/(a b c) #三角形重心 # gx=(x1 x2 x3)/3 # gy=(y1 y2 y3)/3 datadir="data" header=["x1" "y1" "x2" "y2" "x3" "y3" "gx" "gy" "ix" "iy"] #构造三角形的三个顶点 y1=0.0 x2=0.0 y3=224 #数据列表 lst=[] #返回三角开三边长 def dist(xy): ''' xy:三角形三个顶点坐标A(x1 y1) B(x2 y2) C(x3 y3) np.array[[x1 y1] [x2 y3] [x3 y3]] ''' #d=np.sqrt(np.sum(np.square(a-b))) a=np.linalg.norm(xy[1]-xy[2]) b=np.linalg.norm(xy[2]-xy[0]) c=np.linalg.norm(xy[0]-xy[1]) return a b c #返回三角形重心或内心坐标 def gixy(xy abc=np.array([1 1 1])): ''' abc:三角形三边np.array([a b c]) xy:三角形三个顶点坐标np.array[[x1 y1] [x2 y3] [x3 y3]] ''' gixy=np.dot(abc xy)/np.sum(abc) return gixy for i in range(1000): x1=np.random.rand()*224 y2=np.random.rand()*224 x3=np.random.rand()*224 #三角形顶点坐标 xy1=[x1 y1] xy2=[x2 y2] xy3=[x3 y3] xy=np.array([xy1 xy2 xy3]) a b c=dist(xy) #三角形的重心、内心坐标 gxy=gixy(xy) ixy=gixy(xy np.array([a b c])) v=(x1 y1 x2 y2 x3 y3 gxy[0] gxy[1] ixy[0] ixy[1]) lst.append(v) #写入csv文件 生成训练集和测试集 df=pd.DataFrame(lst columns=header) df.to_csv("data/train-gi.csv" index=False) # df.to_csv("data/test-gi.csv" index=False)

二、准备环境,查看数据集

import numpy as np import matplotlib.pyplot as plt import pandas as pd import os import cv2 import paddle from paddle.io import Dataset from paddle.vision import transforms from paddle.vision.models import resnet18 from paddle.nn import functional as F

#查看一下数据集 train_Dir = 'data/train-gi.csv' test_Dir = 'data/test-gi.csv'

data_source=pd.read_csv(train_Dir) img=data_source.iloc[3 :-4] img=img.to_numpy() img=img.reshape((3 2)) print(img) base=np.zeros((224 224 3) np.uint8) cv2.polylines(base [img.astype(int)] True (255 255 255) 1) plt.imshow(base) plt.show()

查看数据结果:

如何预测向量回归:18.坐标回归-三角形关键点 重心和内心(1)

一条数据,

三、定义数据集

#数据增强 class ImgTransforms(object): """ 图像预处理工具 并对图像的维度进行转换 从HWC变为CHW """ def __init__(self fmt): self.format = fmt def __call__(self img): img = img.transpose(self.format) return img #定义数据集 class GIDataset(Dataset): def __init__(self data_path mode="train" val_split=0.2): super().__init__() self.mode=mode self.data_source=pd.read_csv(data_path) self.data_source.dropna(how="any" inplace=True) self.data_label_all=self.data_source.drop(["x1" "y1" "x2" "y2" "x3" "y3"] axis=1) if self.mode in ["train" "val"]: np.random.seed(99) data_len=len(self.data_source) shuffled_ind=np.random.permutation(data_len) self.shuffled_ind=shuffled_ind val_set_size=int(data_len*val_split) if self.mode=="val": val_ind=shuffled_ind[:val_set_size] self.data_img=self.data_source.reindex().iloc[val_ind :-4] self.data_label=self.data_label_all.reindex().iloc[val_ind] elif self.mode=="train": train_ind=shuffled_ind[val_set_size:] self.data_img=self.data_source.reindex().iloc[train_ind :-4] self.data_label=self.data_label_all.reindex().iloc[train_ind] elif self.mode=="test": self.data_img=self.data_source.drop(["gx" "gy" "ix" "iy"] axis=1) self.data_label=self.data_label_all self.transforms = transforms.Compose([ ImgTransforms((2 0 1)) ]) def __getitem__(self idx): img=np.zeros((224 224 3) np.float32) #数据类型 HWC imgxy=self.data_img.iloc[idx] imgxy=imgxy.to_numpy() imgxy=np.reshape(imgxy (3 2)) cv2.polylines(img [imgxy.astype(int)] True (255 255 255) 1) img = self.transforms(img) #CHW label=np.array(self.data_label.iloc[idx :] np.float32)/224 return img label def __len__(self): return len(self.data_img)

train_dataset=GIDataset(train_Dir mode="train") val_dataset=GIDataset(train_Dir mode="val") test_dataset=GIDataset(test_Dir mode="test")

print(train_dataset.__len__()) print(val_dataset.__len__()) print(test_dataset.__len__())

运行结果:

800
200
10

查看数据集

#查看数据集 idx=np.random.randint(test_dataset.__len__()) print(idx) img label=test_dataset[idx] img=np.transpose(img (1 2 0)) plt.imshow(img) plt.show()

如何预测向量回归:18.坐标回归-三角形关键点 重心和内心(2)

for i in range(4): plt.subplot(2 2 i 1) idx=np.random.randint(train_dataset.__len__()) img label=train_dataset[idx] img=np.transpose(img (1 2 0)) label=label*224 label=label.astype(int) print(label) cv2.circle(img (label[0] label[1]) 2 (255 0 0) 2) cv2.circle(img (label[2] label[3]) 2 (0 255 0) 2) plt.imshow(img) plt.show()

如何预测向量回归:18.坐标回归-三角形关键点 重心和内心(3)

四、定义网络模型

#定义模型 #resnet18 ImageNet分类任务:1000类,在模型后接一个全连接层,将输出的1000维向量映射成4维,对应2个关键点的坐标。 class GINet(paddle.nn.Layer): def __init__(self num_keypoints pretrained=False): super(GINet self).__init__() self.backbone=resnet18(pretrained) self.outLayer1=paddle.nn.Sequential( paddle.nn.Linear(1000 512) paddle.nn.ReLU() paddle.nn.Dropout(0.1)) self.outLayer2=paddle.nn.Linear(512 num_keypoints*2) def forward(self inputs): out=self.backbone(inputs) out=self.outLayer1(out) out=self.outLayer2(out) return out

五、查看网络模型

#模型可视化 num_keypoints=2 model=paddle.Model(GINet(num_keypoints)) model.summary((1 3 224 224))

如何预测向量回归:18.坐标回归-三角形关键点 重心和内心(4)

六、模型训练

#训练模型 #坐标回归 model=paddle.Model(GINet(num_keypoints=2)) optim=paddle.optimizer.Adam(learning_rate=0.001 parameters=model.parameters()) model.prepare(optim paddle.nn.MSELoss()) model.fit(train_dataset val_dataset epochs=8 batch_size=16)

七、保存模型参数和模型

#保存模型参数 paddle.save(model.state_dict() 'kpmodels/kp.pdparams') #保存模型 model.save("kpmodels/best" training=False)

八、模型预测

#模型预测 model=GINet(num_keypoints=2) model_dict=paddle.load("kpmodels/kp.pdparams") model.load_dict(model_dict) model.eval() idx=np.random.randint(test_dataset.__len__()) print(idx) img label=test_dataset[idx] inputimg=img.copy() inputimg=inputimg.reshape((1 3 224 224)) inputimg=paddle.to_tensor(inputimg) result=model(inputimg) result=result[0]*224 result=np.array(result np.uint8) print(result) img=np.transpose(img (1 2 0)) label=label*224 label=label.astype(int) print(label) plt.imshow(img) plt.show()

如何预测向量回归:18.坐标回归-三角形关键点 重心和内心(5)

预测结果对比

本文主要一个实例来展示,如何通过坐标回归方式,来实现图像关键检测点。

猜您喜欢: