参考:
https://baijiahao.baidu.com/s?id=1725356123619612187&wfr=spider&for=pc
个人认为如果把Jax作为一款深度学习框架来学习,那么就没有多大的必要性,因为pytorch就够了。可以说,Jax可以做到的,pytorch也可以做到,Jax做不到的,Pytorch依旧可以做到,在深度学习的计算框架来说,pytorch确实有着Jax无法短期超越的优势(个人目前长期也不太可能超越)。
不过,如果把Jax作为一个科学计算框架,或者说把Jax看做是一个类似于numpy功能的加速框架,一种普适应用的矩阵计算框架,或者是传统机器学习的计算框架,那么Jax还是具备一定优势,毕竟在不考虑深度学习计算的前提下Jax要比pytorch更加的轻量,而唯一的不足就是Jax的稳定性有待观察,毕竟不是作为成熟产品推出的,而只是作为一个实验品推出的。
给出一个numpy和Jax的对比性能的小Demo:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np
import time
def fn(x):
return x+x*x+x*x*x
x=np.random.randn(10000, 10000)
a=time.time();fn(x);b=time.time()
print(b-a)
jax_fn=jit(fn)
x2=jnp.array(x)
a=time.time();jax_fn(x2).block_until_ready();b=time.time()
print(b-a)