Jax计算框架的NamedSharding的reshape —— namedsharding-gives-a-way-to-express-shardings-with-names

发布时间：2024-01-07 19:26:09　作者：Angry_Panda

官方文档参考:
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names



本文主要讲解以下两种写法的不同：
jax.device_put(x, mesh_sharding(P(('a', 'b'), None)))

jax.device_put(x, mesh_sharding(P(('b', 'a'), None)))

两者都把数组的第 0 维同时切分到网格轴 'a' 和 'b' 上、第 1 维不切分；区别只在元组中轴的顺序：排在前面的轴是主（major）轴，因此各个行块被分配到设备上的顺序不同。


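主机有四个 CPU 设备的情况（注意：纯 CPU 环境下 JAX 默认只暴露一个设备，需要在导入 jax 之前设置环境变量 XLA_FLAGS=--xla_force_host_platform_device_count=4，才能模拟出下面代码所需的四个设备）：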

代码:

import os

import functools
from typing import Optional

import numpy as np

import jax
import jax.numpy as jnp

from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding

# Build a flat, 1-D mesh of four devices and wrap it in a PositionalSharding;
# reshaping that sharding below decides how array blocks map onto devices.
device_row = mesh_utils.create_device_mesh((4,))
sharding = PositionalSharding(device_row)

# A random 8192 x 8192 matrix to distribute.
prng_key = jax.random.PRNGKey(0)
x = jax.random.normal(prng_key, (8192, 8192))

# Viewing the four devices as a 2 x 2 grid splits both array dimensions in
# two, so each device ends up holding one 4096 x 4096 tile.
grid_sharding = sharding.reshape(2, 2)
y = jax.device_put(x, grid_sharding)
jax.debug.visualize_array_sharding(y)

运行结果:

image




代码:

点击查看代码
from typing import Optional
import jax
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
from jax.experimental import mesh_utils

from jax.sharding import PositionalSharding

P = PartitionSpec

# Lay the four devices out as a 2 x 2 grid whose axes are named 'a' (grid
# rows) and 'b' (grid columns).
devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, axis_names=('a', 'b'))

# Give x an initial layout: the positional sharding is reshaped to (4, 1), so
# the rows are split into four blocks (one per device) and the columns are
# not split.
sharding = PositionalSharding(devices)

x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
x = jax.device_put(x, sharding.reshape(4, 1))

# Mesh used by mesh_sharding() when the caller does not pass one. It is the
# same 2 x 2 ('a', 'b') mesh built above, so reuse it instead of building an
# identical device array and Mesh a second time.
default_mesh = mesh
def mesh_sharding(pspec: PartitionSpec,
                  mesh: Optional[Mesh] = None) -> NamedSharding:
  """Return a NamedSharding that lays ``pspec`` out over a device mesh.

  Args:
    pspec: how each array dimension maps onto the mesh axes.
    mesh: mesh to shard over; the module-level ``default_mesh`` is used
      when this is None.
  """
  target_mesh = default_mesh if mesh is None else mesh
  return NamedSharding(target_mesh, pspec)


# Split dimension 0 over both mesh axes, with 'a' as the major axis and 'b'
# as the minor one; dimension 1 (None) is not partitioned.
row_spec = P(('a', 'b'), None)
y = jax.device_put(x, mesh_sharding(row_spec))
jax.debug.visualize_array_sharding(y)

运行结果:
image



代码:

点击查看代码
from typing import Optional
import jax
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
from jax.experimental import mesh_utils

from jax.sharding import PositionalSharding

P = PartitionSpec

# Lay the four devices out as a 2 x 2 grid whose axes are named 'a' (grid
# rows) and 'b' (grid columns).
devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, axis_names=('a', 'b'))

# Give x an initial layout: the positional sharding is reshaped to (4, 1), so
# the rows are split into four blocks (one per device) and the columns are
# not split.
sharding = PositionalSharding(devices)

x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
x = jax.device_put(x, sharding.reshape(4, 1))

# Mesh used by mesh_sharding() when the caller does not pass one. It is the
# same 2 x 2 ('a', 'b') mesh built above, so reuse it instead of building an
# identical device array and Mesh a second time.
default_mesh = mesh
def mesh_sharding(pspec: PartitionSpec,
                  mesh: Optional[Mesh] = None) -> NamedSharding:
  """Return a NamedSharding that lays ``pspec`` out over a device mesh.

  Args:
    pspec: how each array dimension maps onto the mesh axes.
    mesh: mesh to shard over; the module-level ``default_mesh`` is used
      when this is None.
  """
  target_mesh = default_mesh if mesh is None else mesh
  return NamedSharding(target_mesh, pspec)


# Same as sharding over ('a', 'b') except for the axis order: 'b' is now the
# major axis and 'a' the minor one, so the row blocks are assigned to the
# devices of the 2 x 2 mesh column by column instead of row by row.
# Dimension 1 (None) is still not partitioned.
row_spec = P(('b', 'a'), None)
y = jax.device_put(x, mesh_sharding(row_spec))
jax.debug.visualize_array_sharding(y)

运行结果:

image