题意:
求 $$\sum_{k=1}^{N}((kM)\& M)mod(10^9+7)$$
其中 $1 \le N \le 10^{18}$ , $1 \le M \le 10^{11}$
题解:
按位来做,考虑每一个二进制位i, 假设pi是 M 中从小到大第i位上的0/1), 然后我们又可以用取模和取整来取代某一位上的运算,比如当前处理的是第i位 ,就可以写成 $$\sum_{k=1}^{N}(\lfloor \frac{k*M}{2^i} \rfloor)-2*\sum_{k=1}^{N}(\lfloor \frac{k*M}{2^{i+1}} \rfloor)$$
这显然是一个类欧几里得的模板式子,我们只需要套用一下模板即可。最后把所有位的答案累加起来即可。
wa点:
- M太大了,要用__int128才可以。。。wa了n发才发现这个板子会爆long long
#include<bits/stdc++.h> #define inf 0x3f3f3f3f #define lson l,m,rt<<1 #define rson m+1,r,rt<<1|1 using namespace std; typedef long long ll; const ll MAXN=200005; const ll INF=1e18; const ll P = 1e9+7; ll i2 = 500000004, i6 = 166666668; /* P=1e9+7,inv2=500000004,inv6=166666668; */ struct data { data() { f = g = h = 0; } __int128 f, g, h; }; // 三个函数打包 data calc(__int128 n, __int128 a, __int128 b, __int128 c) { __int128 ac = a / c, bc = b / c, m = (a * n + b) / c, n1 = n + 1, n21 = n * 2 + 1; data d; if (a == 0) // 迭代到最底层 { d.f = bc * n1 % P; d.g = bc * n % P * n1 % P * i2 % P; d.h = bc * bc % P * n1 % P; return d; } if (a >= c || b >= c) // 取模 { d.f = n * n1 % P * i2 % P * ac % P + bc * n1 % P; d.g = ac * n % P * n1 % P * n21 % P * i6 % P + bc * n % P * n1 % P * i2 % P; d.h = ac * ac % P * n % P * n1 % P * n21 % P * i6 % P + bc * bc % P * n1 % P + ac * bc % P * n % P * n1 % P; d.f %= P, d.g %= P, d.h %= P; data e = calc(n, a % c, b % c, c); // 迭代 d.h += e.h + 2 * bc % P * e.f % P + 2 * ac % P * e.g % P; d.g += e.g, d.f += e.f; d.f %= P, d.g %= P, d.h %= P; return d; } data e = calc(m - 1, c, c - b - 1, a); d.f = n * m % P - e.f, d.f = (d.f % P + P) % P; d.g = m * n % P * n1 % P - e.h - e.f, d.g = (d.g * i2 % P + P) % P; d.h = n * m % P * (m + 1) % P - 2 * e.g - 2 * e.f - d.f; d.h = (d.h % P + P) % P; return d; } __int128 f(__int128 a,__int128 b,__int128 c,__int128 n) { if(a==0)return (n+1)*(b/c)%P; if(a>=c||b>=c)return (f(a%c,b%c,c,n)+(n%P)*((n+1)%P)%P*i2%P*(a/c%P)+(n+1)%P*(b/c%P))%P; __int128 m=(a*n+b)/c; return ((n%P)*(m%P)%P-f(c,c-b-1,a,m-1))%P; } int main(){ ll n,m,ans=0; scanf("%lld%lld",&n,&m); for(int i=0;i<=50;++i){ if((m>>i)&1){ ll tmp=calc(n,m,0,(1ll<<i)).f%P; tmp=((tmp-2*(calc(n,m,0,(1ll<<(i+1))).f%P))%P+P)%P; ans=(ans+(1ll<<i)%P*tmp)%P; } } printf("%lld\n",(ans%P+P)%P); return 0; }