一个简单的例子测试numpy和Jax的性能对比

发布时间 2024-01-03 23:21:25作者: Angry_Panda

参考:
https://baijiahao.baidu.com/s?id=1725356123619612187&wfr=spider&for=pc



个人认为如果把Jax作为一款深度学习框架来学习,那么就没有多大的必要性,因为pytorch就够了。可以说,Jax可以做到的,pytorch也可以做到,Jax做不到的,Pytorch依旧可以做到,在深度学习的计算框架来说,pytorch确实有着Jax无法短期超越的优势(个人目前长期也不太可能超越)。


不过,如果把Jax作为一个科学计算框架,或者说把Jax看做是一个类似于numpy功能的加速框架,一种普适应用的矩阵计算框架,或者是传统机器学习的计算框架,那么Jax还是具备一定优势,毕竟在不考虑深度学习计算的前提下Jax要比pytorch更加的轻量,而唯一的不足就是Jax的稳定性有待观察,毕竟不是作为成熟产品推出的,而只是作为一个实验品推出的。





给出一个numpy和Jax的对比性能的小Demo:

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

import numpy as np

import time


def fn(x):
    return x+x*x+x*x*x

x=np.random.randn(10000, 10000)
a=time.time();fn(x);b=time.time() 
print(b-a)


jax_fn=jit(fn)
x2=jnp.array(x)
a=time.time();jax_fn(x2).block_until_ready();b=time.time()
print(b-a)

image


image