[补题][数位dp][位运算] 2019牛客第七场 H Pair

题目链接

题意:给定ABC三个数,大小1e9,要求求出来1<=x<=A,1<=y<=B范围内满足(x&y)>C或(x^y)<C的(x,y)对数。

由于题目给的是或条件,所以还要考虑一个容斥的问题,先算出来单独满足两个条件的数目,再减去同时满足两个条件的数目,即可得到答案。先考虑单独计算两个条件的情况:

一般来说这种数位dp的框架都是类似的,在第i位时考虑前面的约束有没有“顶满”,如果顶满的话考虑当前这一位的限制,否则继续递归下去。这样的框架使得最后的递归边界处理变成了最有思维含量的一部分。这个题不同寻常的地方在于它是强制要求你同时考虑多个边界的,这是值得注意的一点。在这个题目里比较友好的一点是他对于与C的关系比对的时候是不等号,没有等号,这样的话,我们只要在dp计数的过程中一直保持可行性就可以了,到最后一步(pos=-1)只需要判断

1)前面是不是顶到C的头了,如果是的话就不行
2)前面是不是顶到了全0,因为题目要求x,y至少为1

当两个条件同时为否时返回1,否则返回0. 可以看出来这样的约束就要求我们在dp的过程中还要计入C是否顶到头和前缀是否为全0这两个flag。后者是分a和b约束的,所以总共多了3个flag,另外在本身的递归过程中要考虑a和b的范围约束,这又是两个flag,所以最后的dp状态写出来就是dp[pos][a1][b1][c1][za][zb]这样了,表示在第pos位的时候a,b,c,是否顶到头的flag为a1 b1 c1 ,在走的过程中两个边界前缀是否为全0的flag分别为za zb的时候,有多少种答案。(其实在考虑x&y>C的dp中这个全为0的flag是不需要的,因为不可能出现这种情况,他走的过程一定是大于嘛不可能是0的,也就是上面说的判定2一定是否定的)

然后就是两个条件都满足的情况。这时候走的时候要考虑双边界和双约束。。。。类比上面的分析就会有4个“顶头”flag,和两个前缀0flag。。。dp[pos][a1][b1][c1][c2][za][zb],然后边界也是同时考虑两个条件和前缀0,就变成了if (c1 == 0 && c2 == 0 && zero1 == 0 && zero2 == 0) return 1;了…

总的来说自己的数位dp还是有待很大的提高,emmmmmm心情复杂

#include<bits/stdc++.h>
using namespace std;
int a[31];
int b[31];
int c[31];
int d[31];
long long dp[33][2][2][2][2][2];
//在pos处,
long long dfs(int pos,int a1,int b1,int c1, int zero1, int zero2)
{
    if(pos == -1)//important boundary
    {
        if(c1 == 0)// && zero1 == 0 && zero2 == 0)
            return 1;
        else return 0;
    }
    if(dp[pos][a1][b1][c1][zero1][zero2] != -1)return dp[pos][a1][b1][c1][zero1][zero2];
    long long ans = 0;
    int enda = a1?a[pos]:1;
    int endb = b1?b[pos]:1;
    for(int i = 0;i <= enda;i++)
        for(int j = 0;j <= endb;j++)
        {
            if(c1 == 1 && (i&j) < c[pos])// c is full, impossible
                continue;
            ans += dfs(pos-1,a1 && i == enda,b1 && j==endb,c1 && (i&j) == c[pos], zero1 && i == 0, zero2 && j == 0);
        }
    return dp[pos][a1][b1][c1][zero1][zero2] = ans;
}
 
long long dfs2(int pos, int a1, int b1, int c1, int zero1, int zero2) {
    if (pos == -1) {
        if (c1 == 0 && zero1 == 0 && zero2 == 0) return 1;
        else return 0;
    }
    if (dp[pos][a1][b1][c1][zero1][zero2] != -1) return dp[pos][a1][b1][c1][zero1][zero2];
    long long ans = 0;
    int enda = a1?a[pos]:1;
    int endb = b1?b[pos]:1;
    for (int i = 0; i <= enda; i++)
        for (int j = 0; j <= endb; j++) {
            if (c1 == 1 && (i^j) > c[pos])continue;
            ans += dfs2(pos-1, a1 && i == enda, b1 && j == endb, c1 && (i^j) == c[pos], zero1 && i == 0, zero2 && j == 0);
        }
    return dp[pos][a1][b1][c1][zero1][zero2] = ans;
}
long long dp2[33][2][2][2][2][2][2];
long long dfs3(int pos, int a1, int b1, int c1, int c2, int zero1, int zero2) {
    if (pos == -1) {
        if (c1 == 0 && c2 == 0 && zero1 == 0 && zero2 == 0)return 1;
        else return 0;
    }
    if (dp2[pos][a1][b1][c1][c2][zero1][zero2] != -1) return dp2[pos][a1][b1][c1][c2][zero1][zero2];
    long long ans = 0;
    int enda = a1?a[pos]:1;
    int endb = b1?b[pos]:1;
    for (int i  = 0; i <= enda; i++)
        for (int j = 0; j <= endb; j++) {
            if (c1 == 1 && (i&j) < c[pos])continue;
            if (c2 == 1 && (i^j) > c[pos])continue;
            ans += dfs3(pos-1, a1 && i == enda, b1 && j == endb, c1 && (i&j) == c[pos], c2 && (i^j) == c[pos], zero1 && i == 0, zero2 && j == 0);
        }
    return dp2[pos][a1][b1][c1][c2][zero1][zero2] = ans;
}
 
 
int main()
{
    int T;
    int iCase = 0;
    int A,B,K;
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d%d%d",&A,&B,&K);
        iCase++;
        for(int i = 0;i <= 30;i++)
        {
            if(A & (1<<i))
                a[i] = 1;
            else a[i] = 0;
            if(B & (1<<i))
                b[i] = 1;
            else b[i] = 0;
            if(K & (1<<i))
                c[i] = 1;
            else c[i] = 0;
        }
        memset(dp,-1,sizeof(dp));
        long long tmp1 = dfs(30, 1, 1, 1, 1, 1);
 
        memset(dp,-1,sizeof(dp));
        long long tmp2 = dfs2(30, 1, 1, 1, 1, 1);
 
        memset(dp2, -1, sizeof(dp2));
        long long tmp3 = dfs3(30, 1, 1, 1, 1, 1, 1);

        long long ans = tmp1+tmp2-tmp3;
        cout<<ans<<endl;
    }
    return 0;
}

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注

此站点使用Akismet来减少垃圾评论。了解我们如何处理您的评论数据