Shell 流程控制
1160
2022-05-30
人像摳圖:算法概述及工程實現(xiàn)(一)
Pytorch->Caffe模型轉(zhuǎn)換
上一篇概述了人像摳圖現(xiàn)有的算法,平衡速度與精度選擇了MODNet作為baseline,本文將重點闡述工程實現(xiàn)、優(yōu)化改進的細節(jié)。
本項目的最終目的是在HiLens Kit硬件上落地實現(xiàn)實時視頻讀入與背景替換,開發(fā)環(huán)境為HiLens配套在線開發(fā)環(huán)境HiLens Studio,先上一下對比baseline的改進效果:
使用modnet預訓練模型modnet_photographic_portrait_matting.ckpt進行測試結(jié)果如下:
可以看到由于場景較為陌生、逆光等原因會導致?lián)笀D結(jié)果有些閃爍,雖然modnet可以針對特定視頻進行自監(jiān)督finetune,但我們的目的是在普遍意義上效果更好,因此沒有對本視頻進行自監(jiān)督學習。
優(yōu)化后的模型效果如下:
注:原視頻來自Human-centric video matting發(fā)布的數(shù)據(jù)集
本視頻并沒有作為訓練數(shù)據(jù)。可以看到,摳圖的閃爍情況減少了很多,毛發(fā)等細節(jié)也基本沒有損失。
工程落地
為了測試baseline效果,首先我們要在使用場景下對baseline進行工程落地。根據(jù)文檔導入/轉(zhuǎn)換本地開發(fā)模型可知
昇騰310 AI處理器支持模型格式為".om",對于Pytorch模型來說可以通過"Pytorch->Caffe->om"或"Pytorch->onnx->om"(新版本)的轉(zhuǎn)換方式得到,這里我選擇的是第一種。Pytorch->Caffe模型轉(zhuǎn)換方法與注意事項在之前的博客中有具體闡述過,這里不贅述。轉(zhuǎn)換得到Caffe模型后,可以在HiLens Studio中直接轉(zhuǎn)為om模型,非常方便。
首先在HiLens Studio中新建一個技能,此處選擇了空模板,只需要修改一下技能名稱就可以。
將Caffe模型上傳到model文件夾下:
在控制臺中運行模型轉(zhuǎn)換命令即可得到可以運行的om模型:
/opt/ddk/bin/aarch64-linux-gcc7.3.0/omg --model=./modnet_portrait_320.prototxt --weight=./modnet_portrait_320.caffemodel --framework=0 --output=./modnet_portrait_320 --insert_op_conf=./aipp.cfg
接下來完善demo代碼。在測試時HiLens Studio可以在工具欄選擇使用視頻模擬攝像頭輸入,或連接手機使用手機進行測試:
具體的demo代碼如下:
# -*- coding: utf-8 -*- # !/usr/bin/python3 # HiLens Framework 0.2.2 python demo import cv2 import os import hilens import numpy as np from utils import preprocess import time def run(work_path): hilens.init("hello") # 與創(chuàng)建技能時的校驗值一致 camera = hilens.VideoCapture('test/camera0_2.mp4') # 模擬輸入的視頻路徑 display = hilens.Display(hilens.HDMI) # 初始化模型 model_path = os.path.join(work_path, 'model/modnet_portrait_320.om') # 模型路徑 model = hilens.Model(model_path) while True: try: input_yuv = camera.read() input_rgb = cv2.cvtColor(input_yuv, cv2.COLOR_YUV2RGB_NV21) # 摳圖后替換的背景 bg_img = cv2.cvtColor(cv2.imread('data/tiantan.jpg'), cv2.COLOR_BGR2RGB) crop_img, input_img = preprocess(input_rgb) # 預處理 s = time.time() matte_tensor = model.infer([input_img.flatten()])[0] print('infer time:', time.time() - s) matte_tensor = matte_tensor.reshape(1, 1, 384, 384) alpha_t = matte_tensor[0].transpose(1, 2, 0) matte_np = cv2.resize(np.tile(alpha_t, (1, 1, 3)), (640, 640)) fg_np = matte_np * crop_img + (1 - matte_np) * bg_img # 替換背景 view_np = np.uint8(np.concatenate((crop_img, fg_np), axis=1)) print('all time:', time.time() - s) output_nv21 = hilens.cvt_color(view_np, hilens.RGB2YUV_NV21) display.show(output_nv21) except Exception as e: print(e) break hilens.terminate()
其中預處理部分的代碼為:
import cv2 import numpy as np TARGET_SIZE = 640 MODEL_SIZE = 384 def preprocess(ori_img): ori_img = cv2.flip(ori_img, 1) H, W, C = ori_img.shape x_start = max((W - min(H, W)) // 2, 0) y_start = max((H - min(H, W)) // 2, 0) crop_img = ori_img[y_start: y_start + min(H, W), x_start: x_start + min(H, W)] crop_img = cv2.resize(crop_img, (TARGET_SIZE, TARGET_SIZE)) input_img = cv2.resize(crop_img, (MODEL_SIZE, MODEL_SIZE)) return crop_img, input_img
demo部分的代碼非常簡單,點擊運行即可在模擬器中看到效果:
模型推理耗時44ms左右,端到端運行耗時60ms左右,達到了我們想要的實時的效果。
效果改進
預訓練模型在工程上存在著時序閃爍的問題,原論文中提出了一種使視頻結(jié)果在時間上更平滑的后處理方式OFD,即用前后兩幀平均誤差大的中間幀。但這種辦法只適合慢速運動,同時會導致一幀延遲,而我們希望可以對攝像頭輸入進行實時、普適的時序處理,因此OFD不適合我們的應用場景。
在Video Object Segmentation任務中有一些基于Memory Network的方法(如STM),摳圖領(lǐng)域也有新論文如DVM考慮引入時序記憶單元使摳圖結(jié)果在時序上更穩(wěn)定,但這些方法普遍需要前后n幀信息,在資源占用、推理實時性、適用場景上都與我們希望的場景不符合。
考慮到資源消耗與效果的平衡,我們采用將前一幀的alpha結(jié)果cat到當前幀RGB圖像后共同作為輸入的方法來使網(wǎng)絡(luò)在時序上更穩(wěn)定。
網(wǎng)絡(luò)上的修改非常簡單,只需在模型初始化時指定in_channels = 4:
modnet = MODNet(in_channels=4, backbone_pretrained=False)
訓練數(shù)據(jù)方面,我們選擇一些VideoMatting的數(shù)據(jù)集:VideoMatte240K、ConferenceVideoSegmentationDataset。
最初,我們嘗試將前一幀alpha作為輸入、缺失前幀時補零這種簡單的策略對模型進行訓練:
if os.path.exists(os.path.join(self.alpha_path, alpha_pre_path)): alpha_pre = cv2.imread(os.path.join(self.alpha_path, alpha_pre_path)) else: alpha_pre = np.zeros_like(alpha) net_input = torch.cat([image, alpha_pre], dim=0)
收斂部署后發(fā)現(xiàn),在場景比較穩(wěn)定時模型效果提升較大,而在人進、出畫面時模型適應較差,同時如果某一幀結(jié)果較差,將對后續(xù)幀產(chǎn)生很大影響。針對這些問題,考慮制定相應的數(shù)據(jù)增強的策略來解決問題。
人進、出畫面時模型適應較差:數(shù)據(jù)集中空白幀較少,對人物入畫出畫學習不夠,因此在數(shù)據(jù)處理時增加空白幀概率:
if os.path.exists(os.path.join(self.alpha_path, alpha_pre_path)) and random.random() < 0.7: alpha_pre = cv2.imread(os.path.join(self.alpha_path, alpha_pre_path)) else: alpha_pre = np.zeros_like(alpha)
某一幀結(jié)果較差,將對后續(xù)幀產(chǎn)生很大影響:目前的結(jié)果較為依賴前一幀alpha,沒有學會拋棄錯誤結(jié)果,因此在數(shù)據(jù)處理時對alpha_pre進行一定概率的仿射變換,使網(wǎng)絡(luò)學會忽略偏差較大的結(jié)果;
此外,光照問題仍然存在,在背光或光線較強處摳圖效果較差:對圖像進行光照增強,具體的,一定概率情況下模擬點光源或線光源疊加到原圖中,使網(wǎng)絡(luò)對光照更魯棒。光照數(shù)據(jù)增強有兩種比較常用的方式,一種是通過opencv進行簡單的模擬,具體可以參考augmentation.py,另外還有通過GAN生成數(shù)據(jù),我們使用opencv進行模擬。
重新訓練后,我們的模型效果已經(jīng)可以達到前文展示的效果,在16T算力的HiLens Kit上完全達到了實時、優(yōu)雅的效果。進一步的,我還想要模型成為耗時更少、效果更好的優(yōu)秀模型~目前在做的提升方向是:
更換backbone:針對應用硬件選擇合適的backbone一向是提升模型性價比最高的方法,直接根據(jù)耗時與資源消耗針對硬件搜一個模型出來最不錯,目前搜出來的模型轉(zhuǎn)為onnx測試結(jié)果(輸入192x192):
GPU: Average Performance excluding first iteration. Iterations 2 to 300. (Iterations greater than 1 only bind and evaluate) Average Bind: 0.124713 ms Average Evaluate: 16.0683 ms Average Working Set Memory usage (bind): 6.53219e-05 MB Average Working Set Memory usage (evaluate): 0.546117 MB Average Dedicated Memory usage (bind): 0 MB Average Dedicated Memory usage (evaluate): 0 MB Average Shared Memory usage (bind): 0 MB Average Shared Memory usage (evaluate): 0.000483382 MB CPU: Average Performance excluding first iteration. Iterations 2 to 300. (Iterations greater than 1 only bind and evaluate) Average Bind: 0.150212 ms Average Evaluate: 13.7656 ms Average Working Set Memory usage (bind): 9.14507e-05 MB Average Working Set Memory usage (evaluate): 0.566746 MB Average Dedicated Memory usage (bind): 0 MB Average Dedicated Memory usage (evaluate): 0 MB Average Shared Memory usage (bind): 0 MB Average Shared Memory usage (evaluate): 0 MB
模型分支:在使用的觀察中發(fā)現(xiàn),大部分較為穩(wěn)定的場景可以使用較小的模型得到不錯的結(jié)果,所有考慮finetune LRBranch處理簡單場景,HRBranch與FusionBranch依舊用來處理復雜場景,這項工作還在進行中。
后續(xù)還會進行一些量化蒸餾的優(yōu)化嘗試,期待更好更快的結(jié)果。
華為HiLens 視頻
版權(quán)聲明:本文內(nèi)容由網(wǎng)絡(luò)用戶投稿,版權(quán)歸原作者所有,本站不擁有其著作權(quán),亦不承擔相應法律責任。如果您發(fā)現(xiàn)本站中有涉嫌抄襲或描述失實的內(nèi)容,請聯(lián)系我們jiasou666@gmail.com 處理,核實后本網(wǎng)站將在24小時內(nèi)刪除侵權(quán)內(nèi)容。