SSRS
#include <bits/stdc++.h>
using namespace std;
// Base number of binary-doubling levels in main(); LOG levels handle N + 1 < 2^LOG.
const int LOG = 18;
// Prime modulus; all counts are reported modulo MOD.
const long long MOD = 998244353;
// Walsh–Hadamard (XOR) transform modulo MOD; returns the transformed copy of A.
//   A   : coefficient vector whose size must be 0 or a power of two
//         (butterflies pair index j with j + len).
//   inv : when true, compute the inverse transform, i.e. the forward transform
//         scaled by 1 / A.size() modulo MOD.
// Entries are expected to lie in [0, MOD]; outputs are reduced into [0, MOD).
vector<long long> xor_fwt(vector<long long> A, bool inv){
    const int N = A.size();
    // A non-power-of-two size would make the partner index run past the end.
    assert((N & (N - 1)) == 0);
    for (int len = 1; len < N; len <<= 1){
        for (int start = 0; start < N; start += len << 1){
            for (int j = start; j < start + len; j++){
                const long long x = A[j];
                const long long y = A[j + len];
                A[j] = (x + y) % MOD;
                A[j + len] = (x + MOD - y) % MOD;
            }
        }
    }
    if (inv && N > 1){
        // Every level of the inverse halves all entries, so the total factor is
        // (1/2)^log2(N) = 1/N: apply it once instead of inside every butterfly.
        const long long inv2 = (MOD + 1) / 2;
        long long scale = 1;
        for (int len = 1; len < N; len <<= 1){
            scale = scale * inv2 % MOD;
        }
        for (long long &v : A){
            v = v * scale % MOD;
        }
    }
    return A;
}
// XOR convolution modulo MOD: C[t] = sum over i ^ j == t of A[i] * B[j].
// Inputs of different lengths are zero-padded to a common power-of-two length,
// which is also the length of the result; if both inputs are empty the result is empty.
// For equal power-of-two lengths this is the plain transform / multiply / inverse pipeline.
vector<long long> xor_convolution(vector<long long> A, vector<long long> B){
    const size_t len = max(A.size(), B.size());
    if (len == 0){
        return {};
    }
    size_t N = 1;
    while (N < len){
        N <<= 1;
    }
    // Zero-padding never changes the convolution: the added coefficients are all zero.
    A.resize(N, 0);
    B.resize(N, 0);
    A = xor_fwt(std::move(A), false);
    B = xor_fwt(std::move(B), false);
    vector<long long> C(N);
    for (size_t i = 0; i < N; i++){
        C[i] = A[i] * B[i] % MOD;
    }
    return xor_fwt(std::move(C), true);
}
int main(){
int N, K;
cin >> N >> K;
vector<int> A(K);
for (int i = 0; i < K; i++){
cin >> A[i];
}
vector<long long> f(1 << 16, 0);
for (int i = 0; i < K; i++){
f[A[i]]++;
}
vector<vector<long long>> dp(LOG);
dp[0] = f;
for (int i = 0; i < LOG - 1; i++){
dp[i + 1] = xor_convolution(dp[i], dp[i]);
}
vector<vector<long long>> dp2(LOG);
dp2[0] = vector<long long>(1 << 16, 0);
dp2[0][0] = 1;
for (int i = 0; i < LOG - 1; i++){
vector<long long> tmp = dp[i];
tmp[0]++;
dp2[i + 1] = xor_convolution(tmp, dp2[i]);
}
vector<long long> ans(1 << 16, 0);
vector<long long> curr(1 << 16, 0);
curr[0] = 1;
N++;
for (int i = LOG - 1; i >= 0; i--){
if ((N >> i & 1) == 1){
vector<long long> tmp = xor_convolution(dp2[i], curr);
for (int j = 0; j < (1 << 16); j++){
ans[j] += tmp[j];
ans[j] %= MOD;
}
curr = xor_convolution(curr, dp[i]);
}
}
long long S = 0;
for (int i = 1; i < (1 << 16); i++){
S += ans[i];
}
S %= MOD;
cout << S << endl;
}
MoRanSky
// Skyqwq
#include <bits/stdc++.h>
// Shorthand macros: pb -> push_back, fi/se -> pair members, mp -> make_pair.
#define pb push_back
#define fi first
#define se second
#define mp make_pair
using namespace std;
// 64-bit alias used for intermediate products before reducing modulo P.
typedef long long LL;
// Raises x to y when y compares strictly greater, so x keeps the running maximum.
template <typename T> void chkMax(T &x, T y) {
    if (!(y > x)) return;
    x = y;
}
// Lowers x to y when y compares strictly smaller, so x keeps the running minimum.
template <typename T> void chkMin(T &x, T y) {
    if (!(y < x)) return;
    x = y;
}
// Reads one decimal integer from stdin into x, skipping any non-digit characters
// before it. A '-' seen among the skipped characters makes the result negative.
// If input ends before a digit is found, x is left as 0 instead of looping forever.
template <typename T> void inline read(T &x) {
    int f = 1; x = 0;
    // Keep getchar()'s result as int so EOF stays distinct from every real character.
    int s = getchar();
    while (s != EOF && (s < '0' || s > '9')) { if (s == '-') f = -1; s = getchar(); }
    while (s <= '9' && s >= '0') x = x * 10 + (s ^ 48), s = getchar();
    x *= f;
}
// N: capacity of a[] (main only uses the first n = 2^16 entries); P: prime modulus.
const int N = 2e5 + 5, P = 998244353;
// n: transform length (set in main), m: maximum sequence length, k: number of input
// values (m and k are read in main); a: histogram of the input values, transformed in place.
int n, m, k, a[N];
int inline power(int a, int b) {
int res = 1;
while (b) {
if (b & 1) res = (LL)res * a % P;
a = (LL)a * a % P;
b >>= 1;
}
return res;
}
void inline XOR(int n, int a[], int o) {
for (int w = 1; w < n; w <<= 1)
for (int i = 0; i < n; i += (w << 1))
for (int j = 0; j < w; j++) {
int u = a[i + j], v = a[i + j + w];
a[i + j] = ((LL)u + v + P) * o % P;
a[i + j + w] = ((LL)u - v + P) * o % P;
}
}
// Returns q^1 + q^2 + ... + q^n modulo P using the closed form q * (q^n - 1) / (q - 1),
// where division is multiplication by (q - 1)^(P-2). When q - 1 is a multiple of P the
// ratio is 1 and n is returned unchanged.
int inline sum(int q, int n) {
    const bool ratioIsOne = (q - 1) % P == 0;
    if (ratioIsOne) return n;
    const int numerator = 1ll * q * (power(q, n) - 1 + P) % P;
    const int invDenominator = power(q - 1, P - 2);
    return 1ll * numerator * invDenominator % P;
}
// Modular inverse of 2 (2^(P-2) mod P, Fermat's little theorem); main passes it to XOR() for the inverse transform.
int inv2 = power(2, P - 2);
// Reads m, k and k values; prints the total, over all non-zero XOR values x, of the
// number of sequences of length 1..m over the inputs whose XOR is x (mod P).
int main() {
    read(m);
    read(k);
    n = 1 << 16;
    // Histogram of the input values.
    while (k-- != 0) {
        int value;
        read(value);
        a[value] += 1;
    }
    // Forward transform, pointwise a -> a + a^2 + ... + a^m, then inverse transform.
    XOR(n, a, 1);
    for (int idx = 0; idx < n; ++idx) a[idx] = sum(a[idx], m);
    XOR(n, a, inv2);
    // Add up the counts of every non-zero XOR value.
    int total = 0;
    for (int idx = 1; idx < n; ++idx) total = (total + a[idx]) % P;
    printf("%d\n", total);
    return 0;
}
Huah
#include <bits/stdc++.h>
// Integer aliases.
typedef unsigned long long ull;
typedef long long ll;
// Large sentinel value (0x3f3f3f3f = 1061109567).
#define inf 0x3f3f3f3f
// Inclusive ascending loop: i = l, l+1, ..., r.
#define rep(i, l, r) for (int i = l; i <= r; i++)
// Inclusive descending loop: i = r, r-1, ..., l.
#define nep(i, r, l) for (int i = r; i >= l; i--)
// Input helpers: read one to three ints, long longs, or whitespace-delimited words
// from stdin. scanf's return value is not checked.
void sc(int &x)
{
    std::scanf("%d", &x);
}
void sc(int &x, int &y)
{
    std::scanf("%d%d", &x, &y);
}
void sc(int &x, int &y, int &z)
{
    std::scanf("%d%d%d", &x, &y, &z);
}
void sc(long long &x)
{
    std::scanf("%lld", &x);
}
void sc(long long &x, long long &y)
{
    std::scanf("%lld%lld", &x, &y);
}
void sc(long long &x, long long &y, long long &z)
{
    std::scanf("%lld%lld%lld", &x, &y, &z);
}
void sc(char *x)
{
    std::scanf("%s", x);
}
void sc(char *x, char *y)
{
    std::scanf("%s%s", x, y);
}
void sc(char *x, char *y, char *z)
{
    std::scanf("%s%s%s", x, y, z);
}
// Output helpers: print one to three ints or long longs separated by single spaces,
// followed by a newline.
void out(int x)
{
    std::printf("%d\n", x);
}
void out(long long x)
{
    std::printf("%lld\n", x);
}
void out(int x, int y)
{
    std::printf("%d %d\n", x, y);
}
void out(long long x, long long y)
{
    std::printf("%lld %lld\n", x, y);
}
void out(int x, int y, int z)
{
    std::printf("%d %d %d\n", x, y, z);
}
void out(long long x, long long y, long long z)
{
    std::printf("%lld %lld %lld\n", x, y, z);
}
// Debug check that x lies in the closed range [l, r]; compiled out under NDEBUG.
void ast(long long x, long long l, long long r)
{
    assert(x>=l&&x<=r);
}
using namespace std;
// N: capacity of the global arrays (sol only uses indices below 2^16); mod: prime modulus.
const int N=4e5+5,mod=998244353;
// n: maximum sequence length, k: number of input values (both read in sol).
int n,k;
// a: histogram of input values, XOR-transformed in place by sol;
// b: pointwise geometric sums of the transformed a, then inverse-transformed by sol;
// c: filled with 1 by sol but not read by any live code; d: not used by live code.
ll a[N],c[N],b[N],d[N];
// Binary exponentiation: returns a^n modulo mod for n >= 0.
ll qpow(ll a,ll n)
{
    ll result = 1;
    while (n)
    {
        if (n & 1) result = result * a % mod;
        n >>= 1;
        a = a * a % mod;
    }
    return result;
}
// In-place XOR (Walsh–Hadamard) transform of f[0..n) modulo mod; assumes n is a power of two.
// Each butterfly output is multiplied by opt: sol passes 1 for the forward transform and
// mod - mod/2 (the inverse of 2) for the inverse transform.
void fwt_xor(ll f[N],ll opt,int n)
{
    for (int half = 1; 2 * half <= n; half *= 2)
    {
        for (int block = 0; block < n; block += 2 * half)
        {
            for (int t = block; t < block + half; t++)
            {
                const ll lo = f[t];
                const ll hi = f[t + half];
                const ll sumPart = (lo + hi) % mod;
                // (lo + hi) - 2*hi == lo - hi, kept non-negative by the two +mod terms.
                const ll diffPart = (sumPart + mod - hi + mod - hi) % mod;
                f[t] = sumPart * opt % mod;
                f[t + half] = diffPart * opt % mod;
            }
        }
    }
}
void sol(int cas)
{
sc(n,k);
rep(i,1,k)
{
int x;sc(x);
a[x]++;
}
fwt_xor(a,1,1<<16);
rep(i,0,(1<<16)-1) c[i]=1;
rep(i,0,(1<<16)-1)
if(a[i]!=1) b[i]=1ll*a[i]*(mod+1-qpow(a[i],n))%mod*qpow(mod+1-a[i],mod-2)%mod;
else b[i]=1ll*a[i]*n%mod;
// rep(j,1,n)
// {
// rep(i,0,(1<<16)-1) c[i]=1ll*c[i]*a[i]%mod,d[i]=(d[i]+c[i])%mod;
// }
// rep(i,0,(1<<16)-1) if(b[i]!=d[i]) cout<<i<<' '<<b[i]<<' '<<d[i]<<endl;
fwt_xor(b,mod-mod/2,1<<16);
ll ans=0;
for(int i=1,j=k;i<=n;i++)
{
ans=(ans+j)%mod;
j=1ll*j*k%mod;
}
ans=(ans+mod-b[0])%mod;
out(ans);
}
// Entry point: runs a single test case.
int main()
{
    // Seeds rand(); nothing in this program calls rand().
    srand(time(0));
    const int tests = 1;
    for (int cas = 1; cas <= tests; cas++)
        sol(cas);
}
/*
Before submitting, check:
- freopen calls are disabled
- size of N
- mod
- debug output is removed
*/