用 PyTorch 2.1 加速 NumPy 代码

发布时间：2023-10-23 19:29:15　作者：bregman
brew install libomp
python test_np.py

[图片]

  • test_np.py 代码如下：
import time
import numpy as np

def kmeans(X, means):
    """Return, for each point in ``X``, the index of its nearest centre in ``means``.

    This is the assignment step of k-means. Euclidean distances from every
    centre to every point are computed by broadcasting. The closest centre
    is then picked per point.

    Args:
        X: array of points; presumably shape (n_points, dim) — TODO confirm.
        means: array of centres; presumably shape (k, dim) — TODO confirm.

    Returns:
        Integer array of centre indices, one per point. Ties resolve to the
        lowest centre index, as ``np.argmin`` does.
    """
    # Insert a unit axis after the centre axis so the subtraction pairs every
    # centre with every point (assuming 2-D inputs: (k, 1, dim) vs (n, dim)).
    centres = means[:, np.newaxis]
    offsets = X - centres
    # Collapse the coordinate axis to one distance per (centre, point) pair.
    distances = np.linalg.norm(offsets, axis=2)
    # Scan along the centre axis to pick the nearest centre for each point.
    return np.argmin(distances, axis=0)

# Benchmark: run the same kmeans assignment step with plain NumPy and with
# torch.compile (which traces the NumPy code and compiles it), then compare.
npts = 10_000_000
X = np.repeat([[5, 5], [10, 10]], [npts, npts], axis=0)
X = X + np.random.randn(*X.shape)  # 2 distinct "blobs"
means = np.array([[5, 5], [10, 10]])

# Time only the kmeans call, not the data generation above.
# perf_counter is the monotonic, high-resolution clock meant for intervals.
t1 = time.perf_counter()
np_pred = kmeans(X, means)
t2 = time.perf_counter()

import torch

compiled_fn = torch.compile(kmeans)
# torch.compile is lazy: the first call does the tracing and code generation.
# Make one warm-up call so the timed call below measures execution only.
compiled_pred = compiled_fn(X, means)
t3 = time.perf_counter()
compiled_pred = compiled_fn(X, means)
t4 = time.perf_counter()

# The predictions are integer cluster indices, so compare them exactly
# rather than with a floating-point tolerance.
assert np.array_equal(np_pred, compiled_pred)
print('时间numpy = ', t2 - t1, '时间torch = ', t4 - t3)