DP 套 DP

考试的时候碰见了这玩意,Kewth 也说挺常见的。
学的时候不太顺利,怕自己以后又忘了,写篇博客合影留念讲讲自己的理解。


概述

我们经常会碰到这样的问题:求有多少种满足某个限制的元素。
一般的 DP 题目中,是否满足这个限制会比较容易判断。但有的时候,即使我们给定了元素 xx ,判定它是否满足限制也需要 DP ,而不能直接得到。
比如限制「字符串 x 与字符串 y 的最长公共子序列长度恰好为 kk」。这个时候就比较难计数。
事实上,我们不要忘记了,DP 最重要的就是「状态」,而状态是可以压缩的。
压缩旧 DP 状态成新状态,添加新元素时先判断旧状态如何转移,再将更新后的旧状态重新压缩成新状态,计数原来的新状态对更新的新状态的贡献,这就是 DP 套 DP 。

例题

T1 [TJOI2018]游园会

Description

给定一个字符串 ss 和正整数 nn ,对于任意的 i[0,s]i \in [0,|s|] ,问有多少个字符串 tt ,满足 t=n|t|=nsstt 的最长公共子序列长度为 iisstt 均由三个字母 N O I 组成, tt 中不能有子串 NOI
n103,k15n\leq 10^3,k\leq 15

Solution

先考虑内层状态。
给定字符串 tts 的最长公共子序列(LCS)是可以 DP 出来的。
f[i][j] 表示 tt 的前 ii 位和 ss 的前 jj 位的 LCS 即可,这个 DP 很基础。
可以发现,当 ii+1i\rightarrow i+1 时, f[i][j] 最多加 11 。我们将 f[i][j] 的第二维差分,那么第二维是一个长度为 15150/10/1 数组。可以状压。
F[i][k] 表示枚举到 t 的前 ii 位,此时 f[i][j] 中的第二维状态的差分数组为 kk 的方案数。转移时枚举 t 的下一位字母,暴力将 kk 展开成内层 DP 式,在内层 DP 式上转移,完成后再压缩内层 DP 式即可。
避免子串 NOI 出现只需要再加一维就可以了。

Code

const int K = 30, N = 1e3 + 5, S = 1 << 15, mod = 1e9 + 7;
using namespace std;
int n, k;
char s[K];
int f[2][S][3], ans[N];
int g[K], h[K], sz[S];

void decode(int v){
    for(int i = 1; i <= k; ++i)
        g[i] = v >> (i - 1) & 1;
    for(int i = 1; i <= k; ++i)
        g[i] += g[i-1];
}
int  encode(){
    int ret = 0;
    for(int i = 1; i <= k; ++i)
        ret |= (h[i] - h[i - 1]) << (i - 1);
    return ret;
}
void trans(int a, int b, char c, int st, int v){
    decode(st);
    for(int i = 1; i <= k; ++i){
        h[i] = max(g[i], h[i - 1]);
        if(c == s[i])
            h[i] = max(h[i], g[i - 1] + 1);
    }
    int now = encode();
    (f[a][now][b] += v) %= mod;
}
int main(){

    scanf("%d%d", &n, &k);
    scanf("%s", s + 1);
    int tot = 1 << k;
    f[0][0][0] = 1;
    for(int i = 0; i < n; ++i){
        int p = i & 1, q = p ^ 1;
        memset(f[q], 0, sizeof f[q]);
        for(int j = 0; j < tot; ++j){
            int *t = f[p][j];
            if(t[0]){
                trans(q, 1, 'N', j, t[0]);
                trans(q, 0, 'O', j, t[0]);
                trans(q, 0, 'I', j, t[0]);
            }
            if(t[1]){
                trans(q, 1, 'N', j, t[1]);
                trans(q, 2, 'O', j, t[1]);
                trans(q, 0, 'I', j, t[1]);
            }
            if(t[2]){
                trans(q, 1, 'N', j, t[2]);
                trans(q, 0, 'O', j, t[2]);
            }
        }
    }
    for(int i = 0; i < tot; ++i) sz[i] = sz[i >> 1] + (i & 1);
    for(int i = 0; i < tot; ++i)
    for(int j = 0; j < 3; ++j)
        (ans[sz[i]] += f[n & 1][i][j]) %= mod;
    for(int i = 0; i <= k; ++i)
        w(ans[i]);
    return 0;
}

T2 count

Description

给定 Lx,Rx,Ly,Ry,TL_x,R_x,L_y,R_y,T ,求 {x and yLxxRxLyyRyx or y=T}\{x\text{ and }y |L_x\leq x\leq R_x\cap L_y\leq y\leq R_y\cap x\text{ or } y = T\} 的大小。

Solution

先考虑内层状态。
给定 kk ,判定能否找到 (x,y)(x,y) 满足 x and y=kx\text{ and }y =k 。这个东西可以通过数位 DP 来实现。
fi,0/1,0/1,0/1,0/1f_{i,0/1,0/1,0/1,0/1} 表示到第 ii 位时,某种压上下界的情况是否存在(即 ff 的值只有 0/10/1 )。ff 的转移可以通过枚举 xxyy 的当前位来实现(因为交和与都是位独立的,可以逐位判断)。ff 的后四维状态显然可以状压成一个 0150\sim 15 的数来表示。
我们令 aj=fi,ja_j=f_{i,j} ,显然 aa 序列是 16160/10/1 变量,它们的所有情况囊括了枚举到第 ii 位时 fi,jf_{i,j} 的所有情况。所以这个 aa 也是可以状压的。
于是我们对 kk 进行计数。设 Fi,sF_{i,s} 表示到第 ii 位时,所有满足第 ssaa 序列的合法的 kk 的数量。转移时分第 ii 位是 0/10/1 分类即可。
总结一下:先设计好内层 DP 在做什么,然后对内层 DP 的状态进行压缩,使得若干次内层 DP 同时进行。

Code

const int K = 60, S = (1 << 16);
using namespace std;
bool g[K], h[2][K];
long long f[K + 1][S];
long long t, lx, rx, ly, ry;
int main(){

    r(t, lx, rx, ly, ry);

    f[K][1 << 15] = 1;
    
    for(int i = K - 1; i >= 0; --i)
    for(int s = 0; s < S; ++s){
        if(!f[i + 1][s]) continue;
        for(int j = 0; j < 16; ++j){
            g[j] = s >> j & 1;
            h[0][j] = 0;
            h[1][j] = 0;
        }
        for(int x = 0; x < 2; ++x)
        for(int y = 0; y < 2; ++y){
            if((x | y) != (t >> i & 1)) continue;
            int test = 
                (x < (lx >> i & 1)) << 0 |
                (x > (rx >> i & 1)) << 1 |
                (y < (ly >> i & 1)) << 2 |
                (y > (ry >> i & 1)) << 3 ;
            for(int j = 0; j < 16; ++j){
                if(!g[j]) continue;
                if((test & j)) continue;
                int k =
                    (x == (lx >> i & 1)) << 0 |
                    (x == (rx >> i & 1)) << 1 |
                    (y == (ly >> i & 1)) << 2 |
                    (y == (ry >> i & 1)) << 3 ;
                h[x & y][k & j] |= 1;
            }
        }
        int T0 = 0, T1 = 0;
        for(int j = 0; j < 16; ++j){
            T0 |= h[0][j] << j;
            T1 |= h[1][j] << j;
        }
        f[i][T0] += f[i + 1][s];
        f[i][T1] += f[i + 1][s];
    }
    long long ans = 0;
    for(int s = 1; s < S; ++s)
        ans += f[0][s];
    w(ans);
    return 0;
}