U2-Net 预测函数

发布时间：2023-04-28 13:21:35　作者：~逍遥子~

包含单张图片检测以及视频检测。

import os
import time

import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import time
import subprocess
from torchvision.transforms import transforms

from src import u2net_full
# Set KMP_DUPLICATE_LIB_OK=TRUE for this process before any inference runs.
# NOTE(review): presumably a workaround for the Intel OpenMP "duplicate
# runtime" abort (OMP: Error #15) when several packages bundle libiomp — confirm.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

def time_synchronized():
    """Return ``time.time()`` after waiting for queued CUDA work to finish.

    CUDA work is launched asynchronously, so reading the clock without a
    synchronize can under-report GPU time. Without CUDA this is simply
    ``time.time()``.

    Returns:
        float: seconds since the epoch.
    """
    # A plain ``if`` rather than a conditional expression evaluated only for
    # its side effect.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()

# Query GPU information through nvidia-smi.
def get_gpu_info():
    """Query the name and memory usage of every GPU via ``nvidia-smi``.

    Returns:
        list[list[str]] | None: one ``[name, memory.used, memory.total]`` row
        per GPU, each field as printed by nvidia-smi (a number followed by a
        unit for the memory fields); an empty list if nvidia-smi printed
        nothing; ``None`` if nvidia-smi could not be run or exited with an
        error (the error is printed).
    """
    cmd = [
        "nvidia-smi",
        "--query-gpu=name,memory.used,memory.total",
        "--format=csv,noheader",
    ]
    try:
        # Argument list without a shell: nothing goes through shell parsing.
        cmd_out = subprocess.check_output(cmd)
    except (subprocess.CalledProcessError, OSError) as e:
        # Without a shell, a missing nvidia-smi raises OSError
        # (FileNotFoundError) instead of a non-zero exit status.
        print("Error while invoking nvidia-smi: ", e)
        return None
    # Drop blank lines so every row can be unpacked into three fields.
    return [line.split(', ') for line in cmd_out.decode().splitlines() if line.strip()]

# Print each GPU's model and memory usage.
def print_gpu_usage(gpu_rows=None):
    """Print the model name and memory usage of every GPU.

    Args:
        gpu_rows: optional pre-fetched ``[name, used, total]`` rows in the
            format returned by ``get_gpu_info`` (memory fields start with an
            integer, e.g. ``"1234 MiB"``). When ``None`` (the default),
            ``get_gpu_info()`` is called. Nothing is printed if no rows are
            available.
    """
    if gpu_rows is None:
        gpu_rows = get_gpu_info()
    if not gpu_rows:
        return
    for name, used, total in gpu_rows:
        used_val = int(used.strip().split()[0])
        total_val = int(total.strip().split()[0])
        # Ratio per GPU (not from running totals), so each line reports that
        # GPU's own usage; guard against a zero total.
        memory_usage_percent = round(used_val / total_val * 100, 2) if total_val else 0.0
        print(f"GPU: {name.strip()}, Memory used: {used.strip()}, Memory total: {total.strip()}"
              f", Memory usage: {memory_usage_percent}%")

# Blend the original image with the segmentation result.
def Image_Blend(src, res, color=(0, 0, 255)):
    """Tint the segmented foreground of ``res`` onto ``src``.

    Every pixel of ``res`` that is not pure black ``[0, 0, 0]`` is treated
    as foreground and painted ``color``. The output is
    ``0.8 * src + 0.2 * overlay``, saturated to uint8.

    Args:
        src: original image, (H, W, 3); callers in this script pass uint8.
        res: segmentation image with the same (H, W, 3) shape, values in
            [0, 255]; any integer dtype is accepted. It is not modified.
        color: 3-element foreground colour in the same channel order as
            ``src``. The default ``(0, 0, 255)`` is red for BGR input and
            blue for RGB input.

    Returns:
        uint8 array with the same shape as ``src``.
    """
    # Work on a uint8 copy. Callers build ``res`` by multiplying with an
    # integer mask, so it can arrive as a wider integer type, and the copy
    # leaves the caller's array untouched.
    overlay = np.array(res, dtype=np.uint8)
    # Foreground = any pixel with at least one non-zero channel.
    foreground = ~(overlay == 0).all(axis=2)
    overlay[foreground] = color
    # ``dtype`` is the output *depth*; CV_8U is the documented form.
    return cv2.addWeighted(src, 0.8, overlay, 0.2, 0, dtype=cv2.CV_8U)

# Read GPU information via torch.
def gpu_info() -> str:
    """Describe every visible CUDA device, one line per device.

    Each line reads ``CUDA:<index> (<name>, <total memory>MiB)``, with the
    memory rounded to whole MiB. Returns an empty string when no CUDA
    device is visible.
    """
    lines = []
    for idx in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(idx)
        mib = props.total_memory / (1 << 20)
        lines.append(f'CUDA:{idx} ({props.name}, {mib:.0f}MiB)')
    return '\n'.join(lines)

# Run detection on a single image.
def pic_predict(threshold, device, data_transform, origin_img, model,
                save_path="result/pred_result11.png", warmup=False):
    """Segment one image, blend the mask onto it and save the result.

    Args:
        threshold: pixels whose prediction exceeds this value are foreground.
        device: torch device the model lives on.
        data_transform: callable mapping the (H, W, 3) image to a
            (C, H', W') tensor.
        origin_img: (H, W, 3) image in RGB order; it is converted RGB->BGR
            before being written.
        model: segmentation network; presumably returns a (1, 1, H', W')
            probability map (see the squeeze below) — TODO confirm.
        save_path: file the blended image is written to. Its directory is
            created if missing.
        warmup: if True, first run one discarded forward pass on a zero
            tensor of the same size. Off by default because its output is
            never used.

    Returns:
        The blended RGB uint8 image, shape (H, W, 3).

    Raises:
        OSError: if the result image could not be written.
    """
    h, w = origin_img.shape[:2]
    img = data_transform(origin_img)
    img = torch.unsqueeze(img, 0).to(device)  # [C, H, W] -> [1, C, H, W]

    with torch.no_grad():
        if warmup:
            img_height, img_width = img.shape[-2:]
            model(torch.zeros((1, 3, img_height, img_width), device=device))

        # Inference
        pred = model(img)
        pred = torch.squeeze(pred).to("cpu").numpy()  # [1, 1, H, W] -> [H, W]

    pred = cv2.resize(pred, dsize=(w, h), interpolation=cv2.INTER_LINEAR)
    # A uint8 0/1 mask keeps seg_img uint8; np.where(..., 1, 0) would widen
    # it to the platform's default int.
    pred_mask = (pred > threshold).astype(np.uint8)
    origin_img = np.asarray(origin_img, dtype=np.uint8)
    seg_img = origin_img * pred_mask[..., None]

    img_res = Image_Blend(origin_img, seg_img)

    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(save_path, cv2.cvtColor(img_res.astype(np.uint8), cv2.COLOR_RGB2BGR)):
        raise OSError(f"failed to write result image to {save_path}")
    return img_res

# Video detection: segment a single frame.
def video_pre(threshold, device, data_transform, origin_img, model,
              warmup=False, show_gpu_usage=True, bgr_input=True):
    """Segment one video frame and return it blended with its mask.

    Args:
        threshold: pixels whose prediction exceeds this value are foreground.
        device: torch device the model lives on.
        data_transform: callable mapping an (H, W, 3) image to a (C, H', W')
            tensor.
        origin_img: (H, W, 3) frame. With ``bgr_input=True`` it is expected
            in BGR order, as returned by ``cv2.VideoCapture.read``.
        model: segmentation network; presumably returns a (1, 1, H', W')
            probability map (see the squeeze below) — TODO confirm.
        warmup: if True, first run one discarded forward pass on a zero
            tensor of the same size. Off by default: doing it on every frame
            doubles the per-frame inference work for no output.
        show_gpu_usage: print per-GPU memory usage (via nvidia-smi) after
            inference. Skipped when CUDA is unavailable.
        bgr_input: convert the frame BGR->RGB for the network input only.
            The single-image path of this script feeds the model RGB, so
            this keeps both paths consistent. Pass False for RGB frames.

    Returns:
        uint8 (H, W, 3) blended frame in the same channel order as
        ``origin_img``, ready for ``cv2.VideoWriter.write`` when the input
        was BGR.
    """
    h, w = origin_img.shape[:2]
    # Only the network input is converted; origin_img keeps its channel order
    # so the blended frame can be written back unchanged.
    model_input = cv2.cvtColor(origin_img, cv2.COLOR_BGR2RGB) if bgr_input else origin_img
    img = data_transform(model_input)
    img = torch.unsqueeze(img, 0).to(device)  # [C, H, W] -> [1, C, H, W]

    with torch.no_grad():
        if warmup:
            img_height, img_width = img.shape[-2:]
            model(torch.zeros((1, 3, img_height, img_width), device=device))

        # Inference
        pred = model(img)

        # print_gpu_usage spawns nvidia-smi on every call, so skip it when no
        # GPU is in use.
        if show_gpu_usage and torch.cuda.is_available():
            print_gpu_usage()

        pred = torch.squeeze(pred).to("cpu").numpy()  # [1, 1, H, W] -> [H, W]

    pred = cv2.resize(pred, dsize=(w, h), interpolation=cv2.INTER_LINEAR)
    # A uint8 0/1 mask keeps seg_img uint8 instead of a widened int type.
    pred_mask = (pred > threshold).astype(np.uint8)
    origin_img = np.asarray(origin_img, dtype=np.uint8)
    seg_img = origin_img * pred_mask[..., None]

    return Image_Blend(origin_img, seg_img)

def _load_model(weights_path, device):
    """Build u2net_full, load weights_path into it and return it on device in eval mode."""
    model = u2net_full()
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    weights = torch.load(weights_path, map_location='cpu')
    # Accept both a bare state_dict and a checkpoint dict with a "model" entry.
    model.load_state_dict(weights["model"] if "model" in weights else weights)
    model.to(device)
    model.eval()
    return model


def _process_video(video_path, out_path, threshold, device, data_transform, model):
    """Segment every frame of video_path and write the blended frames to out_path."""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise OSError(f"cannot open video {video_path}")
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Keep the source frame rate. Fall back to 30 if it is unreported
    # (0, negative or NaN — the comparison is False for all three).
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:
        fps = 30
    out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'FFV1'), fps,
                          (frame_width, frame_height))
    try:
        if not out.isOpened():
            raise OSError(f"cannot open video writer for {out_path}")
        i = 0
        while True:
            ret, frame = cap.read()
            # Check ret BEFORE processing: at end of stream frame is None.
            if not ret:
                break
            i += 1
            start_time = time.time()  # start time for processing this frame
            img_res = video_pre(threshold, device, data_transform, frame, model)
            out.write(img_res)
            cost_time = time.time() - start_time
            print("检测第 {} 帧花了 {:.8f}s 。".format(i, cost_time))
    finally:
        # Always release so the output file is finalised even on errors.
        cap.release()
        out.release()


def _process_image(image_path, threshold, device, data_transform, model):
    """Segment one image; pic_predict saves the blended result."""
    start_time = time.time()  # start time for processing the image
    bgr = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)
    # cv2.imread returns None instead of raising for unreadable files.
    if bgr is None:
        raise OSError(f"cannot read image {image_path}")
    origin_img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    pic_predict(threshold, device, data_transform, origin_img, model)
    cost_time = time.time() - start_time
    print("检测一张图片花了 {:.8f}s 。".format(cost_time))


def main():
    """Run U2-Net on the hard-coded video or image path; results go to result/."""
    weights_path = "model_best.pth"
    img_path = "test/video.mp4"
    threshold = 0.5

    # Explicit check instead of assert (asserts are stripped under python -O).
    if not os.path.exists(img_path):
        raise FileNotFoundError(f"image file {img_path} does not exist.")
    if torch.cuda.is_available():
        print(gpu_info())
    # Use the first GPU when available, otherwise the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(320),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225))
    ])

    model = _load_model(weights_path, device)

    # Case-insensitive; a superset of the original ".mp4" / ".jpg" checks.
    video_exts = (".mp4", ".avi", ".mov", ".mkv")
    image_exts = (".jpg", ".jpeg", ".png", ".bmp")
    ext = os.path.splitext(img_path)[-1].lower()
    if ext in video_exts:
        print("视频读入")
        # OpenCV's writers do not create missing directories.
        os.makedirs("result", exist_ok=True)
        _process_video(img_path, 'result/OutPut2.avi', threshold, device, data_transform, model)
    elif ext in image_exts:
        print("图片读入")
        os.makedirs("result", exist_ok=True)
        _process_image(img_path, threshold, device, data_transform, model)
    else:
        print("请重新读入图片或者视频")


if __name__ == '__main__':
    main()