PyTorch F.grid_sample

发布时间: 2023-10-07 19:22:19  作者: 无左无右
import torch
from torch.nn import functional as F

# Demo: use F.grid_sample to resample a 4x4 input onto a 20x20 output grid.
# The input is random integers in [1, 10) with no seed, so the printed values
# differ from run to run.
inp = torch.randint(1, 10, (1, 1, 4, 4)).float()  # (N, C, H_in, W_in)

# Goal: obtain an output tensor whose height and width are both 20.
out_h = 20
out_w = 20

# Normalized sampling coordinates in [-1, 1]; equivalent to
# torch.meshgrid(linspace_h, linspace_w) with "ij" indexing:
#   new_h[i, j] varies with the output row i    -> vertical (y) coordinate
#   new_w[i, j] varies with the output column j -> horizontal (x) coordinate
new_h = torch.linspace(-1, 1, out_h).view(-1, 1).repeat(1, out_w)
new_w = torch.linspace(-1, 1, out_w).repeat(out_h, 1)

# grid_sample interprets grid[..., 0] as x (width) and grid[..., 1] as y
# (height), so the x coordinate must come first. Putting new_h first would
# sample the input transposed.
grid = torch.stack((new_w, new_h), dim=2)
grid = grid.unsqueeze(0)  # [1, 20, 20, 2] == (N, H_out, W_out, 2)

# align_corners=False is the default since PyTorch 1.3; passing it explicitly
# keeps that behavior and avoids the default-change warning. With
# align_corners=False, grid values of -1/1 refer to the outer edges of the
# border pixels, so the output border is blended with the zero padding
# (e.g. each output corner equals 0.25 * the matching input corner).
outp = F.grid_sample(inp, grid=grid, mode='bilinear', align_corners=False)
print(outp.shape)  # torch.Size([1, 1, 20, 20])

print(inp)

print(outp)
torch.Size([1, 1, 20, 20])
tensor([[[[3., 1., 4., 6.],
          [4., 8., 6., 1.],
          [8., 4., 1., 9.],
          [6., 6., 8., 5.]]]])
tensor([[[[0.7500, 1.0658, 1.3816, 1.5658, 1.6711, 1.7763, 1.8816, 1.9868,
           2.3684, 2.7895, 3.2105, 3.6316, 3.9737, 3.7632, 3.5526, 3.3421,
           3.1316, 2.7632, 2.1316, 1.5000],
          [1.0658, 1.5145, 1.9633, 2.2251, 2.3747, 2.5242, 2.6738, 2.8234,
           3.3657, 3.9640, 4.5623, 5.1607, 5.6468, 5.3476, 5.0485, 4.7493,
           4.4501, 3.9266, 3.0291, 2.1316],
          [1.3816, 1.9633, 2.5450, 2.8843, 3.0783, 3.2722, 3.4661, 3.6600,
           4.3629, 5.1385, 5.9141, 6.6898, 7.3199, 6.9321, 6.5443, 6.1565,
           5.7687, 5.0900, 3.9266, 2.7632],
          [1.3684, 1.9446, 2.5208, 2.9723, 3.3490, 3.7258, 4.1025, 4.4792,
           5.0693, 5.6898, 6.3102, 6.9307, 7.4349, 7.1247, 6.8144, 6.5042,
           6.1939, 5.5263, 4.2632, 3.0000],
          [1.1579, 1.6454, 2.1330, 2.7175, 3.3601, 4.0028, 4.6454, 5.2881,
           5.6011, 5.8670, 6.1330, 6.3989, 6.6150, 6.4820, 6.3490, 6.2161,
           6.0831, 5.5263, 4.2632, 3.0000],
          [0.9474, 1.3463, 1.7452, 2.4626, 3.3712, 4.2798, 5.1884, 6.0970,
           6.1330, 6.0443, 5.9557, 5.8670, 5.7950, 5.8393, 5.8837, 5.9280,
           5.9723, 5.5263, 4.2632, 3.0000],
          [0.7368, 1.0471, 1.3573, 2.2078, 3.3823, 4.5568, 5.7313, 6.9058,
           6.6648, 6.2216, 5.7784, 5.3352, 4.9751, 5.1967, 5.4183, 5.6399,
           5.8615, 5.5263, 4.2632, 3.0000],
          [0.5263, 0.7479, 0.9695, 1.9529, 3.3934, 4.8338, 6.2742, 7.7147,
           7.1967, 6.3989, 5.6011, 4.8033, 4.1551, 4.5540, 4.9529, 5.3518,
           5.7507, 5.5263, 4.2632, 3.0000],
          [0.7763, 1.1032, 1.4301, 2.3525, 3.6323, 4.9121, 6.1918, 7.4716,
           6.8608, 5.9799, 5.0990, 4.2181, 3.5242, 4.1392, 4.7542, 5.3691,
           5.9841, 5.8657, 4.5249, 3.1842],
          [1.0921, 1.5519, 2.0118, 2.8456, 3.9037, 4.9619, 6.0201, 7.0783,
           6.4010, 5.4758, 4.5506, 3.6253, 2.9204, 3.7569, 4.5935, 5.4301,
           6.2666, 6.2535, 4.8241, 3.3947],
          [1.4079, 2.0007, 2.5935, 3.3386, 4.1752, 5.0118, 5.8483, 6.6849,
           5.9411, 4.9716, 4.0021, 3.0325, 2.3165, 3.3747, 4.4328, 5.4910,
           6.5492, 6.6413, 5.1233, 3.6053],
          [1.7237, 2.4494, 3.1752, 3.8317, 4.4467, 5.0616, 5.6766, 6.2916,
           5.4813, 4.4675, 3.4536, 2.4398, 1.7126, 2.9924, 4.2722, 5.5519,
           6.8317, 7.0291, 5.4224, 3.8158],
          [2.0263, 2.8795, 3.7327, 4.2916, 4.6738, 5.0561, 5.4384, 5.8206,
           5.0104, 4.0298, 3.0492, 2.0686, 1.3871, 2.7999, 4.2126, 5.6253,
           7.0381, 7.2957, 5.6281, 3.9605],
          [2.2368, 3.1787, 4.1205, 4.5187, 4.5907, 4.6627, 4.7348, 4.8068,
           4.4619, 4.0575, 3.6530, 3.2486, 3.0104, 3.9356, 4.8608, 5.7860,
           6.7112, 6.7140, 5.1794, 3.6447],
          [2.4474, 3.4778, 4.5083, 4.7458, 4.5076, 4.2694, 4.0312, 3.7929,
           3.9134, 4.0852, 4.2569, 4.4287, 4.6337, 5.0713, 5.5090, 5.9467,
           6.3843, 6.1323, 4.7306, 3.3289],
          [2.6579, 3.7770, 4.8961, 4.9730, 4.4245, 3.8760, 3.3276, 2.7791,
           3.3650, 4.1129, 4.8608, 5.6087, 6.2569, 6.2071, 6.1572, 6.1073,
           6.0575, 5.5506, 4.2819, 3.0132],
          [2.8684, 4.0762, 5.2839, 5.2001, 4.3414, 3.4827, 2.6240, 1.7652,
           2.8165, 4.1406, 5.4647, 6.7888, 7.8802, 7.3428, 6.8054, 6.2680,
           5.7306, 4.9688, 3.8331, 2.6974],
          [2.7632, 3.9266, 5.0900, 4.9204, 3.9508, 2.9813, 2.0118, 1.0422,
           2.2784, 3.8296, 5.3809, 6.9321, 8.1925, 7.4169, 6.6413, 5.8657,
           5.0900, 4.2417, 3.2722, 2.3026],
          [2.1316, 3.0291, 3.9266, 3.7957, 3.0478, 2.2999, 1.5519, 0.8040,
           1.7576, 2.9543, 4.1510, 5.3476, 6.3199, 5.7216, 5.1233, 4.5249,
           3.9266, 3.2722, 2.5242, 1.7763],
          [1.5000, 2.1316, 2.7632, 2.6711, 2.1447, 1.6184, 1.0921, 0.5658,
           1.2368, 2.0789, 2.9211, 3.7632, 4.4474, 4.0263, 3.6053, 3.1842,
           2.7632, 2.3026, 1.7763, 1.2500]]]])

Process finished with exit code 0

PyTorch grid_sample 解析
https://blog.csdn.net/xingye_fan/article/details/121852084

PyTorch中grid_sample的使用及说明_python
https://www.ab62.cn/article/35103.html