一个简单的例子测试numpy和Jax的性能对比 (续)

发布时间 2024-01-03 23:51:39作者: Angry_Panda

相关:
一个简单的例子测试numpy和Jax的性能对比






numpy代码:

import numpy as np
import time

x = np.random.random([10000, 10000]).astype(np.float32)
try:
    st = time.time()
    y = np.matmul(x, x)
    print(time.time() - st)
    print(y)
except Exception as e:
    print(f"error: {e}")

image



Jax代码:

import jax.numpy as np
from jax import random
import time

x = random.uniform(random.PRNGKey(0), [10000, 10000])
st = time.time()
try:
    y = np.matmul(x, x)
    print(time.time() - st)
    print(y)
except Exception as e:
    print(f"error: {e}")

image



可以说,在这个例子里面,Jax和numpy的性能基本持平。