对比全分解（Factorized Prior）模型与超先验（Scale Hyperprior）模型

发布时间：2023-12-07 20:47:08　作者：浪矢-CL

全分解（Factorized Prior）模型——单一路径（g_a → entropy_bottleneck → g_s）：

def forward(self, x):
    """Run the factorized-prior model end to end.

    The input is encoded by ``g_a``. The latent is passed through
    ``entropy_bottleneck``, which yields ``y_hat`` together with its
    likelihoods. ``y_hat`` is then decoded by ``g_s``.

    Returns:
        dict with ``"x_hat"`` (output of ``g_s``) and ``"likelihoods"``,
        a dict holding the latent likelihoods under the key ``"y"``.
    """
    latent = self.g_a(x)
    latent_hat, latent_lik = self.entropy_bottleneck(latent)
    likelihoods = {"y": latent_lik}
    return {"x_hat": self.g_s(latent_hat), "likelihoods": likelihoods}

@classmethod
def from_state_dict(cls, state_dict):
    """Return a new model instance from `state_dict`.

    The two constructor sizes are read from the stored weights. ``N`` is the
    first dimension (``size(0)``) of ``g_a.0.weight``, and ``M`` is that of
    ``g_a.6.weight``. The instance is created as ``cls(N, M)`` and then
    populated with ``load_state_dict``.
    """
    n_channels, m_channels = (
        state_dict[key].size(0) for key in ("g_a.0.weight", "g_a.6.weight")
    )
    model = cls(n_channels, m_channels)
    model.load_state_dict(state_dict)
    return model

def compress(self, x):
    """Encode ``x`` into entropy-coded strings.

    Returns:
        dict with ``"strings"`` and ``"shape"``. ``"strings"`` is a
        one-element list holding the result of
        ``entropy_bottleneck.compress``. ``"shape"`` is the last two
        dimensions of the latent, which ``decompress`` takes back as
        ``shape``.
    """
    latent = self.g_a(x)
    encoded = self.entropy_bottleneck.compress(latent)
    spatial_shape = latent.size()[-2:]
    return {"strings": [encoded], "shape": spatial_shape}

def decompress(self, strings, shape):
    """Reconstruct an image from the output of ``compress``.

    Args:
        strings: list holding exactly one entry, the latent strings.
        shape: the ``"shape"`` value returned by ``compress``.

    Returns:
        dict with ``"x_hat"``, the ``g_s`` output clamped in place to [0, 1].
    """
    assert isinstance(strings, list) and len(strings) == 1
    latent_hat = self.entropy_bottleneck.decompress(strings[0], shape)
    reconstruction = self.g_s(latent_hat)
    # clamp_ modifies the tensor in place and returns it.
    return {"x_hat": reconstruction.clamp_(0, 1)}

[图片]

超先验（Scale Hyperprior）模型——在主路径之外增加超先验分支（h_a / h_s）与高斯条件熵模型：

def forward(self, x):
    """Run the scale-hyperprior model end to end.

    The latent ``y = g_a(x)`` is summarized by ``h_a`` applied to ``|y|``.
    That summary goes through ``entropy_bottleneck``. ``h_s`` turns the
    result into scales, which condition ``gaussian_conditional`` when it
    processes ``y``. Finally, ``g_s`` decodes ``y_hat``.

    Returns:
        dict with ``"x_hat"`` and ``"likelihoods"``. ``"likelihoods"`` maps
        ``"y"`` and ``"z"`` to the likelihoods of the latent and of the
        hyper-latent, respectively.
    """
    y = self.g_a(x)
    # The hyper-analysis transform only sees the magnitude of the latent.
    z_hat, z_likelihoods = self.entropy_bottleneck(self.h_a(torch.abs(y)))
    y_hat, y_likelihoods = self.gaussian_conditional(y, self.h_s(z_hat))
    return {
        "x_hat": self.g_s(y_hat),
        "likelihoods": dict(y=y_likelihoods, z=z_likelihoods),
    }

@classmethod
def from_state_dict(cls, state_dict):
    """Return a new model instance from `state_dict`.

    ``N`` and ``M`` are taken from the first dimension (``size(0)``) of the
    stored ``g_a.0.weight`` and ``g_a.6.weight`` tensors. The weights are
    then loaded into the freshly built ``cls(N, M)``.
    """
    first = state_dict["g_a.0.weight"]
    last = state_dict["g_a.6.weight"]
    instance = cls(first.size(0), last.size(0))
    instance.load_state_dict(state_dict)
    return instance

def compress(self, x):
    """Encode ``x`` into latent strings plus hyper-latent strings.

    The hyper-latent ``z`` is compressed and then immediately decompressed.
    The scales used to code ``y`` therefore come from the same ``z_hat``
    that ``decompress`` rebuilds from the strings.

    Returns:
        dict with ``"strings"`` (``[y_strings, z_strings]``) and ``"shape"``
        (the last two dimensions of ``z``).
    """
    latent = self.g_a(x)
    hyper = self.h_a(torch.abs(latent))

    hyper_strings = self.entropy_bottleneck.compress(hyper)
    hyper_shape = hyper.size()[-2:]
    hyper_hat = self.entropy_bottleneck.decompress(hyper_strings, hyper_shape)

    coder = self.gaussian_conditional
    indexes = coder.build_indexes(self.h_s(hyper_hat))
    latent_strings = coder.compress(latent, indexes)
    return {"strings": [latent_strings, hyper_strings], "shape": hyper.size()[-2:]}

def decompress(self, strings, shape):
    """Reconstruct an image from the output of ``compress``.

    Args:
        strings: list of exactly two entries, ``[y_strings, z_strings]``.
        shape: the ``"shape"`` value returned by ``compress`` (the
            hyper-latent's last two dimensions).

    Returns:
        dict with ``"x_hat"``, the ``g_s`` output clamped in place to [0, 1].
    """
    assert isinstance(strings, list) and len(strings) == 2
    latent_strings, hyper_strings = strings[0], strings[1]

    hyper_hat = self.entropy_bottleneck.decompress(hyper_strings, shape)
    coder = self.gaussian_conditional
    indexes = coder.build_indexes(self.h_s(hyper_hat))
    # The latent is decoded with the same dtype as the hyper-latent.
    latent_hat = coder.decompress(latent_strings, indexes, hyper_hat.dtype)
    return {"x_hat": self.g_s(latent_hat).clamp_(0, 1)}

其区别为：