Idea & Solution : shailesh_2004
Prepared By : R.i.s.h.i_99
Required Knowledge : Bit Manipulation, Combinatorics, Digit DP
Time Complexity :
Editorialists : shailesh_2004, R.i.s.h.i_99
As we scan the number from left to right (most significant bit to least significant), every time we see a 1 we have two choices: keep it as 1 and continue matching the number, or turn it into 0.
If we turn it into 0, the number we are building is already smaller than the given one, so the remaining lower bits are no longer restricted. We can fill them in any way we want, as long as we end up with the required number of set bits. Instead of trying all possibilities one by one, we count these arrangements directly using combinations.
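To make the counting concrete, here is one way to write it down. The symbols $f$, $n_i$ and $c_i$ are notation introduced for this illustration only. Let $n_i$ be bit $i$ of the bound $n$, and let $c_i$ be the number of set bits of $n$ strictly above position $i$. Turning a 1 at position $i$ into 0 leaves $i$ free lower positions that must hold $k - c_i$ set bits, so the count of numbers in $[0, n]$ with exactly $k$ set bits is

$$
f(n, k) = \sum_{i \,:\, n_i = 1} \binom{i}{k - c_i} + [\operatorname{popcount}(n) = k]
$$

where $\binom{i}{j} = 0$ when $j < 0$ or $j > i$, and the bracket is 1 if the condition holds and 0 otherwise. For example, with $n = 13 = 1101_2$ and $k = 2$:

$$
f(13, 2) = \binom{3}{2} + \binom{2}{1} + \binom{0}{0} + 0 = 3 + 2 + 1 = 6
$$

which matches the six numbers 3, 5, 6, 9, 10, 12.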
So the idea is simple: while traversing the bits, whenever we decide to place 0 instead of 1 at bit i, we immediately add the number of valid ways to fill the remaining i positions, which is nCr(i, k - ones), where ones is the number of set bits already fixed above bit i. By doing this at every step and keeping track of how many set bits we still need, we count all valid numbers without generating them. Finally, if the given number itself has exactly k set bits, we count it as well.
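The scan described above can be sketched in a few lines of Python. This is a minimal illustration, not the setter's code: the names `count_up_to` and `brute` are made up here, the modulus is dropped for readability, and `math.comb` (Python 3.8+) stands in for a modular nCr.

```python
from math import comb


def count_up_to(n: int, k: int) -> int:
    """Count integers x in [0, n] with exactly k set bits (no modulus)."""
    if n < 0 or k < 0:
        return 0
    total = 0
    ones = 0  # set bits of n already fixed above the current position
    for i in range(n.bit_length() - 1, -1, -1):
        if (n >> i) & 1:
            # Choice 1: put 0 at bit i. The i lower bits become free and
            # must contain the remaining k - ones set bits.
            if k >= ones:
                total += comb(i, k - ones)
            # Choice 2: keep the 1 and continue matching n.
            ones += 1
    # n itself is never covered by a "turn 1 into 0" branch.
    if ones == k:
        total += 1
    return total


def brute(n: int, k: int) -> int:
    """Reference answer by direct enumeration, only usable for small n."""
    return sum(1 for x in range(n + 1) if bin(x).count("1") == k)


print(count_up_to(13, 2))  # 6  (the numbers 3, 5, 6, 9, 10, 12)
print(brute(13, 2))        # 6
```

Comparing against `brute` on small inputs is a cheap way to catch off-by-one mistakes in the bit positions before adding the modulus.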
#include <bits/stdc++.h>
#define ll long long
using namespace std;

const int MOD = 1e9 + 7;

// Fast modular exponentiation: base^exp mod MOD.
ll power(ll base, ll exp) {
    ll res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % MOD;
        base = (base * base) % MOD;
        exp /= 2;
    }
    return res;
}

// Modular inverse via Fermat's little theorem (MOD is prime).
ll modInverse(ll n) {
    return power(n, MOD - 2);
}

// nCr mod MOD, computed directly since n is at most 61 here.
ll nCr(int n, int r) {
    if (r < 0 || r > n) return 0;
    if (r == 0 || r == n) return 1;
    if (r > n / 2) r = n - r;
    ll num = 1;
    ll den = 1;
    for (int i = 0; i < r; i++) {
        num = (num * (n - i)) % MOD;
        den = (den * (i + 1)) % MOD;
    }
    return (num * modInverse(den)) % MOD;
}

// Number of integers in [0, n] with exactly k set bits.
ll countValid(ll n, int k) {
    if (n < 0 || k < 0) return 0;
    if (n == 0) return k == 0 ? 1 : 0;
    ll ans = 0;
    int ones = 0;  // set bits of n seen so far
    for (int i = 61; i >= 0; i--) {
        if ((n >> i) & 1) {
            // Place 0 at bit i: the i lower bits are free.
            if (k >= ones) {
                ans = (ans + nCr(i, k - ones)) % MOD;
            }
            // Keep the 1 and continue matching n.
            ones++;
        }
    }
    // n itself is a valid number if it has exactly k set bits.
    if (ones == k) {
        ans = (ans + 1) % MOD;
    }
    return ans;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    int t;
    if (cin >> t) {
        while (t--) {
            ll l, r;
            int k;
            cin >> l >> r >> k;
            // Answer for [l, r] = count in [0, r] minus count in [0, l - 1].
            ll ans = (countValid(r, k) - countValid(l - 1, k) + MOD) % MOD;
            cout << ans << '\n';
        }
    }
    return 0;
}
We process the number bit by bit from the most significant bit to the least significant and try to build all valid numbers. At each position, we decide whether to place 0 or 1, while keeping track of how many set bits are still needed.
A key idea is the tight condition. If we are still matching the given number, our choices are restricted by its current bit. But once we place a smaller bit, we are no longer restricted, and we can freely choose 0 or 1 for the remaining positions.
We define a DP state with the current index, remaining set bits, and whether we are tight or not. For each state, we try all valid choices, update the remaining set bits, and move to the next position. To make it efficient, we store results in a DP table so that each state is computed only once.
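The state described above (current index, remaining set bits, tight flag) can be sketched as follows. This is an illustrative version under a few assumptions: the names `count_up_to_dp`, `go` and `count_in_range` are hypothetical, `functools.lru_cache` replaces the explicit DP table, and the modulus is omitted because Python integers do not overflow.

```python
from functools import lru_cache


def count_up_to_dp(n: int, k: int) -> int:
    """Count integers x in [0, n] with exactly k set bits using digit DP."""
    if n < 0 or k < 0:
        return 0
    top = max(n.bit_length() - 1, 0)  # index of the highest bit to process

    @lru_cache(maxsize=None)
    def go(idx: int, need: int, tight: bool) -> int:
        if need < 0:
            return 0  # placed more set bits than allowed
        if idx < 0:
            return 1 if need == 0 else 0  # all bits placed
        # While tight, the current bit cannot exceed the bit of n.
        ub = (n >> idx) & 1 if tight else 1
        total = 0
        for bit in range(ub + 1):
            # We stay tight only if we were tight and matched the upper bound.
            total += go(idx - 1, need - bit, tight and bit == ub)
        return total

    return go(top, k, True)


def count_in_range(l: int, r: int, k: int) -> int:
    """Numbers in [l, r] with exactly k set bits, as a difference of prefixes."""
    return count_up_to_dp(r, k) - count_up_to_dp(l - 1, k)


print(count_up_to_dp(13, 2))     # 6
print(count_in_range(5, 13, 2))  # 5  (the numbers 5, 6, 9, 10, 12)
```

Because the cache lives inside `count_up_to_dp`, each bound gets a fresh table, which plays the same role as resetting the DP array between the two calls for r and l - 1.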
#include <bits/stdc++.h>
using namespace std;

#define int long long
#define double long double
#define endl '\n'
#define all(v) v.begin(),v.end()
#define rall(v) v.rbegin(),v.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl

template<class T>
void input(vector<T> &a) {
    for(auto &e:a)
        cin >> e;
}

int M=1e9+7;
// dp[idx][k][tight]: idx in 0..60, k in 0..61 (at most idx+1 bits can still be set).
int dp[61][62][2];

// Number of ways to fill bits idx..0 with exactly k set bits,
// staying <= n on those bits if tight is true.
int fun(int idx,int k,bool tight,int n) {
    if(k<0) return 0;
    if(idx==-1) return k==0;
    if(k>idx+1) return 0; // not enough positions left; also keeps k within the table
    int &res=dp[idx][k][tight];
    if(res!=-1) return res;
    res=0;
    int ub=tight?((n>>idx)&1):1;
    for(int i=0;i<=ub;i++) {
        res=(res+fun(idx-1,k-i,(tight&&(ub==i)),n))%M;
    }
    return res;
}

void solve() {
    int l,r,k;
    cin >> l >> r >> k;
    // The table depends on the bound, so reset it before each call.
    memset(dp,-1,sizeof(dp));
    int res=fun(60,k,true,r);
    memset(dp,-1,sizeof(dp));
    if(l>0) res=(res-fun(60,k,true,l-1)+M)%M;
    cout << res << endl;
}

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int t=1;
    cin >> t;
    while(t--)
        solve();
    return 0;
}
import sys

MOD = 10**9 + 7

def modInverse(n):
    return pow(n, MOD - 2, MOD)

def nCr(n, r):
    if r < 0 or r > n:
        return 0
    if r == 0 or r == n:
        return 1
    if r > n // 2:
        r = n - r
    num = 1
    den = 1
    for i in range(r):
        num = (num * (n - i)) % MOD
        den = (den * (i + 1)) % MOD
    return (num * modInverse(den)) % MOD

def countValid(n, k):
    if n < 0 or k < 0:
        return 0
    if n == 0:
        return 1 if k == 0 else 0
    ans = 0
    ones = 0
    for i in range(61, -1, -1):
        if (n >> i) & 1:
            if k >= ones:
                ans = (ans + nCr(i, k - ones)) % MOD
            ones += 1
    if ones == k:
        ans = (ans + 1) % MOD
    return ans

def solve():
    input_data = sys.stdin.read().split()
    if not input_data:
        return
    t = int(input_data[0])
    out = []
    idx = 1
    for _ in range(t):
        l = int(input_data[idx])
        r = int(input_data[idx+1])
        k = int(input_data[idx+2])
        idx += 3
        ans = (countValid(r, k) - countValid(l - 1, k) + MOD) % MOD
        out.append(str(ans))
    print('\n'.join(out))

if __name__ == '__main__':
    solve()