2 solutions

  • 2
    @ 2024-2-6 19:33:47
    #include<bits/stdc++.h>
    using namespace std;
    const int N = 400000 + 10;   // capacity of p[] and tr[]
    const int mod = 998244353;   // every answer is reported modulo this value
    int T;                       // number of test cases
    int n;                       // length of p in the current test case
    int k;                       // number of q values in the current test case
    int q;                       // scratch: the q value currently being read
    int p[N];                    // p[1..n] of the current test case
    int tr[N];                   // Fenwick-tree cells, indices 1..size
    // Isolate the lowest set bit of x (two's complement): 12 -> 4, 8 -> 8, 0 -> 0.
    int lowbit(int x){
    	return x & (~x + 1);
    }
    // Fenwick-tree point update: add 1 at position x in a tree of size sup.
    // x must be >= 1 (lowbit(0) == 0, so x == 0 would never advance).
    void add(int x, int sup){
    	while(x <= sup){
    		tr[x]++;
    		x += lowbit(x);
    	}
    }
    // Fenwick-tree prefix sum: total stored at positions 1..x (0 when x <= 0).
    int query(int x){
    	int total = 0;
    	while(x > 0){
    		total += tr[x];
    		x -= lowbit(x);
    	}
    	return total;
    }
    // Zero every Fenwick-tree cell. Not called anywhere in the visible code.
    void init(){
    	for(int &cell : tr) cell = 0;
    }
    // Hand-rolled absolute value for int (same name as the standard abs).
    // As with -x, the result for INT_MIN is undefined.
    int abs(int x){
    	if(x >= 0) return x;
    	return -x;
    }
    // Sum of the arithmetic progression st, st +/- del, ..., ed (inclusive),
    // returned modulo `mod` as a value in [0, mod).
    //   st, ed : first and last terms (either order).
    //   del    : common difference, must be > 0. The term count is
    //            |st - ed| / del + 1 (integer division), so ed is expected to
    //            be reachable from st in steps of del.
    // All intermediates are 64-bit, so st + ed and st - ed cannot overflow int.
    // The halving is applied to whichever factor is even before the modular
    // product, so (st + ed) * count never has to fit in long long. The result
    // is congruent to the truncating division (st + ed) * count / 2 for every
    // input, and equal to it for non-negative results.
    int del_sum(int st, int ed, int del){
    	long long ends = 1ll * st + ed;   // first + last, computed in 64 bits
    	long long span = 1ll * st - ed;
    	if(span < 0) span = -span;
    	long long cnt = span / del + 1;   // number of terms
    	long long r;
    	if(ends % 2 == 0)
    		r = (ends / 2 % mod) * (cnt % mod) % mod;
    	else if(cnt % 2 == 0)
    		r = (ends % mod) * (cnt / 2 % mod) % mod;
    	else
    		// both odd (not a true progression): trunc(ends * cnt / 2)
    		// == ends * ((cnt - 1) / 2) + trunc(ends / 2)
    		r = ((ends % mod) * ((cnt - 1) / 2 % mod) + ends / 2) % mod;
    	if(r < 0) r += mod;   // canonical residue for negative sums
    	return (int)r;
    }
    // Handles one test case: reads n, k, p[1..n], then k values of q from stdin,
    // and prints one answer modulo `mod` to stdout. Overwrites the globals n, k,
    // p, q and tr.
    // NOTE(review): the arithmetic matches counting inversions of the sequence
    // a[i*k + j] = p[i] * 2^(q[j]) (cf. Codeforces 1917D), assuming p holds
    // distinct odd values in [1, 2n-1] and q is a permutation of 0..k-1.
    // Confirm against the problem statement.
    void solve(){
    	scanf("%d %d", &n, &k);
    	for(int i = 1; i <= n; i += 1) scanf("%d", &p[i]);
    	// Phase 1 uses the Fenwick tree over indices q + 1 in [1, k]; clear only that prefix.
    	for(int i = 1; i <= k; i += 1) tr[i] = 0;
    	int ans = 0;
    	// For each q, count earlier q values greater than it (inversions of q).
    	// Values are stored at index q + 1, so query(k) - query(q + 1) counts
    	// earlier entries with stored index > q + 1, i.e. earlier q > current q.
    	for(int i = 1; i <= k; i += 1){
    		scanf("%d", &q);
    		ans = (ans + query(k) - query(q + 1)) % mod;
    		add(q + 1, k);
    	}
    	// Phase 2 reuses the same tree, indexed by p values in [1, 2n-1].
    	for(int i = 1; i <= 2 * n - 1; i += 1) tr[i] = 0;
    	// Scale the q-inversion count by n (presumably one copy per block of k).
    	ans = 1ll * ans * n % mod;
    	// For each p[i], count weighted pairs against p[1..i-1], which are the
    	// values already inserted in the tree; query(2n-1) is how many there are.
    	for(int i = 1; i <= n; i += 1){
    		int res = 0;
    		// Shift 0: earlier values in (p[i], 2n-1], each weighted by k.
    		// When p[i] >= 2n-1 no inserted value can be larger, so it is skipped.
    		if(p[i] < 2 * n - 1)
    			res = (query(2 * n - 1) - query(p[i])) * 1ll * k % mod;
    		// Shifts l = 1, 2, ...: earlier values greater than p[i] >> l,
    		// each weighted by (k - l).
    		int l = 1, x = p[i] >> 1;
    		while(x > 0 && l < k){
    			res = (res + (query(2 * n - 1) - query(x)) * 1ll * (k - l)) % mod;
    			l += 1;
    			x >>= 1;
    		}
    		// Once p[i] >> l reaches 0, every inserted value qualifies for each
    		// remaining shift l..k-1, with weights (k - l) + ... + 1 = del_sum(0, k - l, 1).
    		// If the loop stopped because l == k, this adds del_sum(0, 0, 1) = 0.
    		res = (res + query(2 * n - 1) * 1ll * del_sum(0, k - l, 1) % mod) % mod;
    		// Opposite direction: earlier values greater than p[i] << l, weighted
    		// by (k - l). Stops once p[i] << l >= 2n-1, because no inserted value
    		// (all in [1, 2n-1]) can exceed it from then on.
    		l = 1, x = (p[i] << 1);
    		while(x < 2 * n - 1 && l < k){
    			res = (res + (query(2 * n - 1) - query(x)) * 1ll * (k - l)) % mod;
    			l += 1;
    			x <<= 1;
    		}
    		ans = (ans + res) % mod;
    		// Insert p[i] only after counting, so it is compared with earlier elements only.
    		add(p[i], 2 * n - 1);
    	}
    	printf("%d\n",ans);
    }
    // Program entry: read the number of test cases T, then run solve() once per case.
    // For local file I/O, uncomment the two redirections below.
    int main(){
    	//freopen("count.in","r",stdin);
    	//freopen("count.out","w",stdout);
    	scanf("%d",&T);
    	for(int remaining = T; remaining != 0; remaining -= 1) solve();
    	return 0;
    }
    

    Information

    ID: 309
    Time: 1000 ms
    Memory: 256 MiB
    Difficulty: 6
    Tags: (None)
    # Submissions: 37
    Accepted: 13
    Uploaded By