NumPy code:
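import numpy as np
import time

# 10000 x 10000 matrix of uniform [0, 1) values, cast to float32
x = np.random.random([10000, 10000]).astype(np.float32)
try:
    st = time.time()
    y = np.matmul(x, x)  # NumPy runs synchronously, so this times the full product
    print(time.time() - st)
    print(y)
except Exception as e:
    print(f"error: {e}")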
JAX code:
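import jax.numpy as jnp
from jax import random
import time

# Same shape as the NumPy example; random.uniform defaults to float32
x = random.uniform(random.PRNGKey(0), [10000, 10000])
try:
    st = time.time()
    # JAX dispatches work asynchronously: block_until_ready() waits for the
    # result, otherwise only the dispatch time would be measured
    y = jnp.matmul(x, x).block_until_ready()
    print(time.time() - st)
    print(y)
except Exception as e:
    print(f"error: {e}")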
In this example, JAX and NumPy perform roughly on par.
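A single timed call can be misleading. The first JAX call also pays for compilation, and JAX returns before the computation has finished unless you wait on the result. The sketch below is one fairer way to time such calls. It relies on a hypothetical helper named `bench`, which is not part of NumPy or JAX. The helper runs a few warm-up calls first and then reports the median of several timed runs. The warm-up count and the choice of the median are arbitrary assumptions. It calls `block_until_ready()` only when the result provides that method, so the same helper can be used on NumPy arrays and on JAX arrays.

```python
import statistics
import time


def bench(fn, *args, warmup=1, repeat=5):
    """Return the median wall-clock time (seconds) of fn(*args).

    Hypothetical helper: warm-up calls are excluded so one-off costs
    (such as JAX compilation) are not counted. If the result has a
    block_until_ready() method (JAX arrays do), the timer waits for it,
    so asynchronous dispatch is not mistaken for finished work.
    """
    def run():
        out = fn(*args)
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()  # wait for async JAX work to finish
        return out

    for _ in range(warmup):
        run()  # untimed warm-up calls

    times = []
    for _ in range(repeat):
        start = time.perf_counter()
        run()
        times.append(time.perf_counter() - start)
    return statistics.median(times)
```

For NumPy, pass `numpy.matmul` and the NumPy array. For JAX, pass `jax.numpy.matmul` and the JAX array. No timings are quoted here because they depend on the hardware and the backend JAX runs on.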