遥感影像去除 NoData 值：用 OpenCV 修复（inpaint）填充 NaN 像元

发布时间：2023-04-25 17:26:55　作者：行走的蓑衣客

 

import cv2

import scipy.interpolate
import numpy as np
from osgeo import gdal


def read_img(filename):
    """Open a raster with GDAL and read all of its pixels into memory.

    Args:
        filename: Path of a raster that GDAL can open.

    Returns:
        Tuple ``(im_proj, im_geotrans, im_width, im_height, im_data)``:
        the projection string from ``GetProjection``, the geotransform from
        ``GetGeoTransform``, the raster width and height in pixels, and the
        array returned by ``Dataset.ReadAsArray`` for the full extent
        (per GDAL: 2-D for a single band, band-first 3-D for several bands).

    Raises:
        RuntimeError: if GDAL cannot open ``filename``.
    """
    dataset = gdal.Open(filename)
    # Unless gdal.UseExceptions() is active, gdal.Open reports failure by
    # returning None; fail here with a clear message rather than with an
    # AttributeError on the next line.
    if dataset is None:
        raise RuntimeError("GDAL could not open raster: {}".format(filename))

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    # Dropping the last reference closes the GDAL dataset.
    del dataset
    return im_proj, im_geotrans, im_width, im_height, im_data


def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a NumPy array to a GeoTIFF with the given georeferencing.

    Args:
        filename: Output GeoTIFF path.
        im_proj: Projection string, e.g. as returned by ``read_img``.
        im_geotrans: GDAL geotransform, e.g. as returned by ``read_img``.
        im_data: 2-D ``(rows, cols)`` array for a single band, or 3-D
            ``(bands, rows, cols)`` array for one or more bands.

    The GDAL band type is chosen from the exact array dtype so values are
    stored without loss; dtypes without a mapping below (e.g. bool, int64)
    are written as Float32.

    Raises:
        ValueError: if ``im_data`` is neither 2-D nor 3-D.
        RuntimeError: if the GTiff driver cannot create ``filename``.
    """
    # Match dtype names exactly: a substring test such as "'int16' in name"
    # cannot tell int16 from uint16 and would write signed data as unsigned.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        # GDT_Int8 only exists in GDAL >= 3.7; Int16 holds every int8 value.
        'int8': gdal.GDT_Int16,
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
        'float32': gdal.GDT_Float32,
        'float64': gdal.GDT_Float64,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if im_data.ndim not in (2, 3):
        raise ValueError(
            "im_data must be 2-D or 3-D, got shape {}".format(im_data.shape))
    # Treat 2-D input as one band so a single loop handles every case,
    # including 3-D arrays that contain only one band.
    bands = im_data[np.newaxis] if im_data.ndim == 2 else im_data
    im_bands, im_height, im_width = bands.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    # Driver.Create returns None on failure (when GDAL exceptions are off).
    if dataset is None:
        raise RuntimeError("GDAL could not create raster: {}".format(filename))

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    # GDAL band indices are 1-based.
    for band_index, band in enumerate(bands, start=1):
        dataset.GetRasterBand(band_index).WriteArray(band)

    # Dropping the last reference flushes and closes the file.
    del dataset


def NoData_kill(in_path, out_path, inpaint_radius=3):
    """Fill NaN (NoData) pixels of a raster by inpainting and save the result.

    Each band is handled independently. Its NaN pixels are reconstructed from
    the surrounding valid pixels with OpenCV's Telea inpainting. Every other
    pixel keeps its original value. Bands without NaN are left untouched.

    Args:
        in_path: Input raster path, read with ``read_img``.
        out_path: Output GeoTIFF path, written with ``write_img``.
        inpaint_radius: Neighbourhood radius in pixels passed to
            ``cv2.inpaint`` (default 3).
    """
    im_proj, im_geotrans, _, _, im_data = read_img(in_path)

    # View the data as (bands, rows, cols). For 2-D input this is a one-band
    # view, so assigning into ``band`` below writes straight into ``im_data``.
    bands = im_data[np.newaxis] if im_data.ndim == 2 else im_data
    for band in bands:
        nan_mask = np.isnan(band)
        if not nan_mask.any():
            continue
        # cv2.inpaint needs a single-channel 8-bit, 16-bit unsigned or 32-bit
        # float image; NaN only occurs in float data, so use float32. Zeroing
        # the holes first keeps NaN values out of OpenCV entirely.
        src = np.where(nan_mask, 0, band).astype(np.float32)
        filled = cv2.inpaint(src, nan_mask.astype(np.uint8),
                             inpaintRadius=inpaint_radius,
                             flags=cv2.INPAINT_TELEA)
        # Copy back only the reconstructed pixels so valid data keeps its
        # original dtype and precision.
        band[nan_mask] = filled[nan_mask]

    write_img(out_path, im_proj, im_geotrans, im_data)



if __name__ == "__main__":
    import argparse

    # Paths may be given on the command line; without arguments the
    # original hard-coded example paths are used.
    parser = argparse.ArgumentParser(
        description="Fill NaN (NoData) pixels of a raster by inpainting.")
    parser.add_argument('in_path', nargs='?',
                        default=r'E:\jpg_test\image\test.tif',
                        help="input raster path")
    parser.add_argument('out_path', nargs='?',
                        default=r'E:\cq_test2.tif',
                        help="output GeoTIFF path")
    args = parser.parse_args()
    NoData_kill(args.in_path, args.out_path)