CF1508D 题解

发布时间 2023-07-03 09:49:15作者: yuhang-ren

CF1580D Subsequence 题解

Luogu

Codeforces

Description

从长度为 \(n\) 的序列中按顺序选择 \(m\) 个值,定义其权值为所选数和的 \(m\) 倍减去在原序列中所选的数两两之间的最小值。

首先我们看一下要求的这一个东西,对它进行一下化简。

\(\sum_{i = 1}^m (m \cdot a_{b_i}) - \sum_{i = 1}^m \sum_{j = 1}^m f(\min(b_i, b_j), \max(b_i, b_j))\)

\[\begin{aligned} \sum_{i = 1}^m (m \cdot a_{b_i}) - \sum_{i = 1}^m \sum_{j = 1}^m f(\min(b_i, b_j), \max(b_i, b_j)) &= \sum_{i = 1}^m (m \cdot a_{b_i}) - 2 \times \sum_{i = 1}^{m -1} \sum_{j = i + 1}^m f(b_{i},b_{j}) - \sum_{i = 1}^{m} a_{b_{i}} \\ &= (m - 1) \times \sum_{i = 1}^m a_{b_i} - 2 \times \sum_{i = 1}^{m - 1} \sum_{j = i + 1}^m f(b_{i},b_{j}) \end{aligned} \]

考虑到后面 \(f(i,j)\) 的部分为取区间最小值,因此我们可以构建出笛卡尔树。

利用栈构建笛卡尔树的代码:

for (int i = 1; i <= n; i++)
{
    read(nums[i]);
    st[st[0] + 1] = 0;
    while (st[0] && nums[st[st[0]]] > nums[i])
    {
        --st[0];
    }
    son[st[st[0]]][1] = i;
    son[i][0] = st[st[0] + 1];
    st[++st[0]] = i;
}

考虑利用树上 dp 求解,设 \(dp_{i,j}\) 表示在以 \(i\) 为根的子树中选取 \(j\) 个值的最大权值。\(lc\) 为树上当前节点的左孩子,\(rc\) 为树上当前节点的右孩子。

则有如下转移(没有左右节点时特判处理即可):

不选当前节点时:

\[dp_{x,i + j} = dp_{lc,i} + dp_{rc,j} - 2 \times i \times j \times nums_{x} \]

即左节点最大贡献加右节点最大贡献减去最小值的贡献。显然若左子树选了 \(i\) 个,右子树选了 \(j\) 个,则根据题意有 \(i \times j\) 个区间最小值为 \(nums_{x}\),减去即可。

选择当前节点时:

\[dp_{x,i + j + 1} = dp_{lc,i} + dp_{rc,j} - 2 \times i \times j \times nums_{x} + \left (m - 1 \right) \times nums_{x} - 2 \times \left (i + j \right) \times nums_{x} \]

前面的部分不变,在加上最小值单独的贡献即可。

该部分代码:

for (int i = 0; i <= siz[lc]; i++)
{
    for (int j = 0; j <= siz[rc]; j++)
    {
        dp[u][i + j] = max(dp[u][i + j], dp[lc][i] + dp[rc][j] - 2 * i * j * nums[u]);
    }
}
for (int i = 0; i <= siz[lc]; i++)
{
    for (int j = 0; j <= siz[rc]; j++)
    {
        dp[u][i + j + 1] = max(dp[u][i + j + 1], dp[lc][i] + dp[rc][j] + (m - 1) * nums[u] - 2 * (i * j + i + j) * nums[u]);
    }
}

dp 方程满足在 dfs 条件下的无后效性,利用 dfs 求解即可。

最后答案即为 \(dp_{root,m}\)

Codes

#include <bits/stdc++.h>
using namespace std;
#define int long long
#define max_n 5101
void read(int &p)
{
    p = 0;
    int k = 1;
    char c = getchar();
    while (c < '0' || c > '9')
    {
        if (c == '-')
        {
            k = -1;
        }
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        p = p * 10 + c - '0';
        c = getchar();
    }
    p *= k;
    return;
}
void write_(int x)
{
    if (x < 0)
    {
        putchar('-');
        x = -x;
    }
    if (x > 9)
    {
        write_(x / 10);
    }
    putchar(x % 10 + '0');
}
void writesp(int x)
{
    write_(x);
    putchar(' ');
}
void writeln(int x)
{
    write_(x);
    putchar('\n');
}
int n, m, nums[max_n];
int root, siz[max_n], st[max_n], son[max_n][2], fa[max_n]; // 分别为根、子树大小、构建用的栈、左/右孩子、父亲节点

int dp[max_n][max_n];
void dfs(int u)
{
    siz[u] = 1;
    int lc = son[u][0], rc = son[u][1];
    // 没有左右孩子(即区间长度为 1),不取贡献为 0,取贡献为 (m - 1) * nums[u];
    if (!lc && !rc)
    {
        dp[u][0] = 0;
        dp[u][1] = nums[u] * (m - 1);
        return;
    }
    // 只有右孩子
    if (!lc && rc)
    {
        dfs(rc);
        siz[u] += siz[rc];
        memcpy(dp[u], dp[rc], sizeof(dp[rc]));//不取当前点直接复制过来
        // 取当前点
        for (int i = 0; i <= siz[rc]; i++)
        {
            dp[u][i + 1] = max(dp[u][i + 1], dp[rc][i] + (m - 1) * nums[u] - 2 * i * nums[u]);
        }
        return;
    }
    // 只有左孩子同上
    if (lc && !rc)
    {
        dfs(lc);
        siz[u] += siz[lc];
        memcpy(dp[u], dp[lc], sizeof(dp[lc]));
        for (int i = 0; i <= siz[lc]; i++)
        {
            dp[u][i + 1] = max(dp[u][i + 1], dp[lc][i] + (m - 1) * nums[u] - 2 * i * nums[u]);
        }
        return;
    }

    // 左右孩子都有

    dfs(lc), dfs(rc);
    siz[u] += siz[lc];
    siz[u] += siz[rc];
    for (int i = 0; i <= siz[lc]; i++)
    {
        for (int j = 0; j <= siz[rc]; j++)
        {
            dp[u][i + j] = max(dp[u][i + j], dp[lc][i] + dp[rc][j] - 2 * i * j * nums[u]);
        }
    }
    for (int i = 0; i <= siz[lc]; i++)
    {
        for (int j = 0; j <= siz[rc]; j++)
        {
            dp[u][i + j + 1] = max(dp[u][i + j + 1], dp[lc][i] + dp[rc][j] + (m - 1) * nums[u] - 2 * (i * j + i + j) * nums[u]);
        }
    }
}
signed main()
{
#if _clang_
    freopen("1.in", "r", stdin);
    freopen("1.out", "w", stdout);
#endif
    read(n), read(m);
    for (int i = 1; i <= n; i++)
    {
        read(nums[i]);
        st[st[0] + 1] = 0;
        while (st[0] && nums[st[st[0]]] > nums[i])
        {
            --st[0];
        }
        son[st[st[0]]][1] = i;
        son[i][0] = st[st[0] + 1];
        st[++st[0]] = i;
    }
    for (int i = 1; i <= n; i++)
    {
        fa[son[i][0]] = fa[son[i][1]] = i;
    }
    root = min_element(fa + 1, fa + n + 1) - fa; // 找到根节点
    dfs(root);
    writeln(dp[root][m]);
    return 0;
}