jax中对单步操作的缓存对性能造成的影响

发布时间 2024-01-09 11:03:14作者: Angry_Panda

代码:

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

def selu(x, alpha=1.65, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000005,))
%timeit selu(x).block_until_ready()

运行结果:
image


再次运行:

image



修改array的shape:

代码:

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

def selu(x, alpha=1.65, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000003,))
%timeit selu(x).block_until_ready()

运行结果:

image


再次运行:

image






PS. 由此可以看出,jax对单步运行其实也是使用缓存操作的,对单步操作也可以通过缓存来进行多次调用的速度提升的。