[类欧几里得][位运算] 2019牛客暑期多校训练营(第九场)I KM and M

题目链接

题意:

求 $$\sum_{k=1}^{N}((kM)\& M)mod(10^9+7)$$

其中 $1 \le N \le 10^{18}$ , $1 \le M \le 10^{11}$

题解:

按位来做,考虑每一个二进制位i, 假设pi是 M 中从小到大第i位上的0/1), 然后我们又可以用取模和取整来取代某一位上的运算,比如当前处理的是第i位 ,就可以写成 $$\sum_{k=1}^{N}(\lfloor \frac{k*M}{2^i} \rfloor)-2*\sum_{k=1}^{N}(\lfloor \frac{k*M}{2^{i+1}} \rfloor)$$
这显然是一个类欧几里得的模板式子,我们只需要套用一下模板即可。最后把所有位的答案累加起来即可。

wa点:

  • M太大了,要用__int128才可以。。。wa了n发才发现这个板子会爆long long
#include<bits/stdc++.h>
#define inf 0x3f3f3f3f
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
using namespace std;
typedef long long ll;
const ll MAXN=200005;
const ll INF=1e18;
 
const ll P = 1e9+7;
ll i2 = 500000004, i6 = 166666668;
/*
P=1e9+7,inv2=500000004,inv6=166666668;
*/
struct data {
    data() { f = g = h = 0; }
    __int128 f, g, h;
};  // 三个函数打包
data calc(__int128 n, __int128 a, __int128 b, __int128 c) {
    __int128 ac = a / c, bc = b / c, m = (a * n + b) / c, n1 = n + 1, n21 = n * 2 + 1;
    data d;
    if (a == 0)  // 迭代到最底层
    {
        d.f = bc * n1 % P;
        d.g = bc * n % P * n1 % P * i2 % P;
        d.h = bc * bc % P * n1 % P;
        return d;
    }
    if (a >= c || b >= c)  // 取模
    {
        d.f = n * n1 % P * i2 % P * ac % P + bc * n1 % P;
        d.g = ac * n % P * n1 % P * n21 % P * i6 % P + bc * n % P * n1 % P * i2 % P;
        d.h = ac * ac % P * n % P * n1 % P * n21 % P * i6 % P +
              bc * bc % P * n1 % P + ac * bc % P * n % P * n1 % P;
        d.f %= P, d.g %= P, d.h %= P;
 
        data e = calc(n, a % c, b % c, c);  // 迭代
 
        d.h += e.h + 2 * bc % P * e.f % P + 2 * ac % P * e.g % P;
        d.g += e.g, d.f += e.f;
        d.f %= P, d.g %= P, d.h %= P;
        return d;
    }
    data e = calc(m - 1, c, c - b - 1, a);
    d.f = n * m % P - e.f, d.f = (d.f % P + P) % P;
    d.g = m * n % P * n1 % P - e.h - e.f, d.g = (d.g * i2 % P + P) % P;
    d.h = n * m % P * (m + 1) % P - 2 * e.g - 2 * e.f - d.f;
    d.h = (d.h % P + P) % P;
    return d;
}
 
__int128 f(__int128 a,__int128 b,__int128 c,__int128 n)
{
    if(a==0)return (n+1)*(b/c)%P;
    if(a>=c||b>=c)return (f(a%c,b%c,c,n)+(n%P)*((n+1)%P)%P*i2%P*(a/c%P)+(n+1)%P*(b/c%P))%P;
    __int128 m=(a*n+b)/c;
    return ((n%P)*(m%P)%P-f(c,c-b-1,a,m-1))%P;
}
 
int main(){
    ll n,m,ans=0;
    scanf("%lld%lld",&n,&m);
    for(int i=0;i<=50;++i){
        if((m>>i)&1){
            ll tmp=calc(n,m,0,(1ll<<i)).f%P;
            tmp=((tmp-2*(calc(n,m,0,(1ll<<(i+1))).f%P))%P+P)%P;
            ans=(ans+(1ll<<i)%P*tmp)%P;
        }
    }
    printf("%lld\n",(ans%P+P)%P);
    return 0;
}

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注

此站点使用Akismet来减少垃圾评论。了解我们如何处理您的评论数据