2021 ICPC EC 网络预选赛第二场 K.Meal

$n$ 个人选 $n$ 个物品, 每个人选一个. 每个人对物品有一个喜爱值 $a_{ij}$. 从第一个人开始, 在没有被选的物品中随机选择一个物品 $j$, 概率为 $\frac{a_{ij}}{\sum_{k\in S} a_{ik}}$, 其中 $S$ 是剩余物品集合. 问每个人选到每个物品的概率, 答案对 $998244353$ 取模.

$1 \le n \le 20, 1 \le a_{ij} \le 100$

学好概率论, 走遍天下都不怕(雾)

根据全概率公式, 一个很自然的思想是: 设 $A_{ij}$ 为事件 “第 $i$ 个人选到第 $j$ 个物品”, 那么有:

$$\begin{aligned} P(A_{2j}) &= \sum_{k \ne j} P(A_{2j} | A_{1k}) \cdot P(A_{1k}) \\ P(A_{3j}) &= \sum_{k \ne l \ne j} P(A_{3j} | A_{2k}A_{1l}) \cdot P(A_{2k} | A_{1l}) \cdot P(A_{1l}) \\ \cdots \end{aligned}$$

然后这玩意相当于枚举全排列……

考察 $P(A_{3j})$ 的某些项, 可以发现 $P(A_{3j} | A_{2k}A_{1l})$ 和 $P(A_{3j} | A_{2l}A_{1k})$ 是一样的, 这是重复计算. 想象更后面的人, 重复项会越来越多, 累计到阶乘的复杂度. 所以, 这是优化算法的一个突破口.

于是尝试重新对事件做一个划分: 设 $B_{is}$ 为随机事件 “前 $i$ 个人选择的物品集合为 $s$”. 那么有:

$$P(A_{ij}) = \sum_{x \in s} P(A_{ij} | B_{i-1, s - \{x\}}) \cdot P(B_{i-1, s - \{x\}})$$

那么计算所有 $P(A_{ij})$ 的总枚举量, 就是 $O(n \cdot 2^n)$.

然后考虑怎么计算 $P(B_{is})$.

这玩意显然不能直接算, 所以大抵能猜是个 dp, 有了上一步的经验, 来看看如何对事件 $B_{is}$ 进行一个划分, 得到有关联的全概率公式, 然后 dp 求解. 可以发现:

$$P(B_{is}) = \sum_{x \in s} P(B_{is} | B_{i-1, s - \{x\}}) \cdot P(B_{i-1, s - \{x\}})$$

这个公式和前面那个一样, 总枚举量为 $O(n \cdot 2^n)$.

公式里的 $P(A_{ij} | B_{i-1, s - \{x\}})$ 和 $P(B_{is} | B_{i-1, s - \{x\}})$ 其实是已知的, 就是第 $i$ 个人在 $s$ 被选的条件下, 选到 $j$ 的概率, 也就是 $\frac{a_{ij}}{\sum_{k \not \in s} a_{ik}}$.

然后两个方程一起 dp. 基本上就没了.

细节是要先计算一下之前那个式子的分母, 在 $O(n \cdot 2^n)$ 的时间内可预处理. 然后因为要算分数取模, 逆元不能每次求都用费马小, 否则复杂度稳定多一个 $O(\log P)$. 注意到分母最大也就 $n$ 个 $a$, 而 $a$ 只有 $100$. 所以可以预处理出 $20 \cdot 100$ 的逆元. 这里费马小或者线性推都可, 复杂度不会叠上去.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56

int count(int x) {
  int cnt = 0;
  while (x) {
    cnt++;
    x ^= lowbit(x);
  }
  return cnt;
}

int inv[MAXM];
void init_inv(int n) {
  inv[0] = inv[1] = 1;
  for (int i = 2; i <= n; i++)
    inv[i] = mult(P - P/i, inv[P%i]);
}

int n, a[MAXN][MAXN], sum[MAXN][MAXS], dp[MAXS], ans[MAXN][MAXN];
VI S[MAXN];

int main() {
  scanf("%d", &n);
  init_inv(2005);
  for (int i = 1; i <= n; i++)
    for (int j = 0; j < n; j++)
      scanf("%d", &a[i][j]);

  for (int s = 0; s < 1 << n; s++)
    S[count(s)].emplace_back(s);

  for (int i = 1; i <= n; i++) {
    sum[i][0] = 0;
    for (int c = 1; c <= n; c++)
      for (auto s : S[c]) {
        int s0 = s ^ lowbit(s);
        sum[i][s] = sum[i][s0] + a[i][__lg(lowbit(s))];
      }
  }

  dp[0] = 1;
  for (int i = 1; i <= n; i++)
    for (int s : S[i])
      for (int j = 0; j < n; j++)
        if (s & (1 << j)) {
          int s0 = s ^ (1 << j);
          int inv_s0 = (~s0) & ((1 << n) - 1);
          int cur = mult(dp[s0], mult(a[i][j], inv[sum[i][inv_s0]]));
          dp[s] = pls(dp[s], cur);
          ans[i][j] = pls(ans[i][j], cur);

  for (int i = 1; i <= n; i++)
    for (int j = 0; j < n; j++)
      printf("%d%c", ans[i][j], " \n"[j == n-1]);

  return 0;
}