:)关于torch函数中dim的解释-读这篇就够了-|

发布时间 2023-04-08 11:21:35作者: lexn

关于torch函数中dim的解释-读这篇就够了

1 dim的取值范围

1)-1的作用

  0,1,2,-1.  其中-1 最后一维 即 2

  0,1,2,3,-1其中-1 最后一维 即3

2)维度

0,1,2,3表示 BCHW,常在CV任务中使用。

0,1,2 表示 CHW, 常在NLP任务中使用。

3)用图来说明

 

 2 NLP代码中实战dim

from torch import nn
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
check = "distilbert-base-uncased-finetuned-sst-2-english"
raw_eng = ["i like this video", "i hate the food"]

tokenizers = AutoTokenizer.from_pretrained(check)
model = AutoModelForSequenceClassification.from_pretrained(check)
# 打印模型结构
print(model)
inputs = tokenizers(raw_eng,
                    # 是否pad
                    padding=True,
                    # 是否截断
                    truncation=True,
                    # 返回torch.tensor
                    return_tensors="pt")
print(inputs)
# 使用toknizers.decode来解tok id 为 英文
eng_content = tokenizers.decode([101, 1045, 2066, 2023, 2678,  102])
print(eng_content)

# 开始推理
out = model(**inputs)
print(out)
# 输出为[2,2] 前面2 为batchsize,后面2为2分类
print(out.logits.shape)

nn.functional.softmax(out.logits, dim=-1)
print("---end---")

  

out 输出为2,2 

 

需要对第一行 两个数据求softmax,概率值(置信度)

需要对第二行(样本2) 两个数据求softmax。

 

所以 softmax函数dim 应该取CHW中w, 也就是2, 为了统一方便,取-1最后一维。