CF Round 963-Div.2 | E. Xor-Grid Problem(最短路径/floyd/位运算/状压dp)

题目链接:https://codeforces.com/contest/1993/problem/E
题目
f47d778b82a5551e45d38453edd770d3.png
简单题意
给定一个n*m的二维矩阵现可以执行两种操作:

  • 选择第i行,将第i行上,第j列元素,替换为它所在列的元素异或和。即a[i][j] = xor(a[k][j]),1<=k<=n
  • 选择第i列,将第i列上,第j行元素,替换为它所在行的元素异或和。即a[j][i] = xor(a[j][k]), 1<=k<=m
    可以执行上述操作任意次。问,最终能得到,二维矩阵,相邻元素绝对差值之和,最小值,是多少。即sum(min(a[i][j]-a[k][l])),其中(abs(i-k)+abs(j-l)) == 1。

考虑变形,抽象问题

我们观察

  • 选择第i列,将第i列上,第j行元素,替换为它所在行的元素异或和。即a[j][i] = xor(a[j][k]), 1<=k<=m

涉及了m个元素,不方便观察,进一步抽象,只关注元素a[i][j]

  • 选择第i列,将第i列上,第j行元素,替换为它所在行的元素异或和。即a[j][i] = xor(a[j][k]), 1<=k<=m

修改前,第j行元素异或和为
xor_old = a[j][1]^a[j][2]^…^a[j][m],a[j][i]_old = a[j][i]
修改后,
a[j][i]_new = xor_old
第j行元素异或和为
xor_new = a[j][1]^a[j][2]^…^a[j][i]_new ^…^a[j][m],
进一步化简,
xor_new = a[j][1]^a[j][2]^…^xor_old ^…^a[j][m],
xor_new = a[j][1]^a[j][2]^…^(a[j][1]^a[j][2]^…^a[j][m]) ^…^a[j][m],
xor_new = a[j][i]
到这里,有没有发现!

  • 选择第i列,将第i列上,第j行元素,替换为它所在行的元素异或和。即a[j][i] = xor(a[j][k]), 1<=k<=m

上述操作,实际上,就是做了一件事情

1
swap(a[j][i],xor)

再扩展到第1,2…,m个元素,单次的列操作,实际上就是做了n次元素交换,被交换的对象为第j行(1<=j<=n)以及它所在行的原始异或总和xor_row[j]

1
swap(a[j][i], xor_row[j]), 1<=j<=n

同理,对于另一个操作选择

  • 第i列,将第i列上,第j行元素,替换为它所在行的元素异或和。即a[j][i] = xor(a[j][k]), 1<=k<=m

实际上就是做了m次元素交换,被交换的对象为第j列(1<=j<=m)以及它所在列的原始异或总和xor_col[j]

1
swap(a[i][j], xor_col[j]), 1<=j<=m

因此,我们可以抽象出一个单独第n+1列,第m+1行

比如对于原始数组是
9238ae35c4afe97c6cbb6c06d0a564f3.png
我们可以再脑补出额外的行、列,分别存储对应行、列的异或和。
a[n+1][m+1](实际上就是原始数组所有元素的异或和)也要补充出来。
67446c3f40af4.png
原来的操作转换为下述两种操作:

  • 选择第i行,将其与第n+1行进行,行交换。
  • 选择第i列,将其与第m+1列进行,列交换。

再进一步抽象

我们要最小的是beauty(a)=R+C,其中R,C分别为
67446c3eda3f5.png
又因为,我们可以自由交换任意两行,任意两列。
那么上述问题,实际上就是旅行商问题。我们只看第i行:

  • 从第任意一列出发,要求经过所有其他的列,需要的最短路径。

求任意两个点之间的最短路径,根据该题的数据范围,我们很容易想到,用floyd算法。

1
2
3
4
5
// floyd伪代码
for (i...)
 for (k ...)
  for (j ...)
    dis[i][j] = min(dis[i][j], dis[i][k] + dis[k][j]);

我们单独求出每一行、遍历所有节点的最短路径,每一列、遍历所有节点的的最短路径。
以哪个为起点,哪个为终点,途径点顺序又是怎么样的?我们可以用二进制状态压缩(俗称,状压),暴力枚举所有可能的情况。
最后,再考虑整合成二维。
最终,行上和列上,分别有两个结束节点j,i
以j列做为行上的最终结束节点,以i行做为列上的最终结束节点
我们分别计算行上的最短路径fr[i][j],以及列上的最短路径fc[i][j]
最终答案即为

1
min(fr[i][j] + fc[i][j])

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
#include <bits/stdc++.h>
using namespace std;
 
const int N = 16;
 
int n, m;
int a[N][N];
 
int fr[N][N], fc[N][N];
int w[N][N], dp[N][1<<N];
 
int main() {
    cin.tie(0)->sync_with_stdio(0);
    
    int t;
    cin >> t;
 
    while (t--) {
        cin >> n >> m;
 
        for (int i = 0; i <= n; i++) a[i][m] = 0;
        for (int j = 0; j <= m; j++) a[n][j] = 0;
 
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                cin >> a[i][j];
                // 计算第n+1行,第m+1列,包括a[n][m]
                a[i][m] ^= a[i][j];
                a[n][j] ^= a[i][j];
                a[n][m] ^= a[i][j];
            }
        }
 
    // 状压常见表达方式,用 111...111 表示最大值
        int fullmask_n = (1 << (n+1)) - 1;
        int fullmask_m = (1 << (m+1)) - 1;
 
        // 枚举 rmv做为行上的终点
        for (int rmv = 0; rmv <= m; rmv++) {
         // floyd算法
            for (int i = 0; i <= n; i++) {
                for (int j = i + 1; j <= n; j++) {
                    w[i][j] = 0;
                    for (int l = 0; l <= m; l++) {
                        if (rmv == l) continue;
                        w[i][j] += abs(a[i][l] - a[j][l]);
                    }
                    w[j][i] = w[i][j]; // 双向图,所以成立
                }
            }
 
     // dp[i][mask]表示以i为当前 出发点,
     // 且图上已经遍历了mask个节点(用1表示已遍历)
     // 需要的最短路径
            for (int i = 0; i <= n; i++) {
                fill(dp[i], dp[i] + fullmask_n, INT_MAX);
                dp[i][1 << i] = 0;
            }
     // 状压
            for (int mask = 0; mask <= fullmask_n; mask++) {
                for (int last = 0; last <= n; last++) {
                  // last在mask对应的bit必须为1
                    if (~mask >> last & 1) continue; 
                    // mask已经有n个1的情况 已经到达结束条件
                    if (__builtin_popcount(mask) == n) continue;
 
                    for (int next = 0; next <= n; next++) {
                     // next在mask对应的bit必须为0
                        if (mask >> next & 1) continue;
 
                        int new_mask = mask | 1 << next;
                        dp[next][new_mask] = min(
                            dp[next][new_mask],
                            dp[last][mask] + w[last][next]
                        );
                    }
                }
            }
     // 枚举 第i点做为 列上终点
            for (int i = 0; i <= n; i++) {
                fr[i][rmv] = INT_MAX;
                // 我们总共有n+1个点,而我们实际只需要走n个点
                // 最后一个点i,不需要走
                // 所以通过异或,去掉该点
                int mask = fullmask_n ^ 1 << i;
 
                for (int last = 0; last <= n; last++) {
                    fr[i][rmv] = min(fr[i][rmv], dp[last][mask]);
                }
            }
        }
 
    // 枚举 rmv做为列上的终点,思路同上
        for (int rmv = 0; rmv <= n; rmv++) {
            for (int i = 0; i <= m; i++) {
                for (int j = i + 1; j <= m; j++) {
                    w[i][j] = 0;
                    for (int l = 0; l <= n; l++) {
                        if (rmv == l) continue;
                        w[i][j] += abs(a[l][i] - a[l][j]);
                    }
                    w[j][i] = w[i][j];
                }
            }
 
            for (int i = 0; i <= m; i++) {
                fill(dp[i], dp[i] + fullmask_m, INT_MAX);
                dp[i][1 << i] = 0;
            }
 
            for (int mask = 0; mask <= fullmask_m; mask++) {
                for (int last = 0; last <= m; last++) {
                    if (~mask >> last & 1) continue;
                    if (__builtin_popcount(mask) == m) continue;
 
                    for (int next = 0; next <= m; next++) {
                        if (mask >> next & 1) continue;
 
                        int new_mask = mask | 1 << next;
                        dp[next][new_mask] = min(
                            dp[next][new_mask],
                            dp[last][mask] + w[last][next]
                        );
                    }
                }
            }
 
            for (int i = 0; i <= m; i++) {
                fc[rmv][i] = INT_MAX;
                int mask = fullmask_m ^ 1 << i;
 
                for (int last = 0; last <= m; last++) {
                    fc[rmv][i] = min(fc[rmv][i], dp[last][mask]);
                }
            }
        }
 
        int ans = INT_MAX; // 求最终答案
        for (int i = 0; i <= n; i++) {
            for (int j = 0; j <= m; j++) {
                ans = min(ans, fr[i][j] + fc[i][j]);
            }
        }
 
        cout << ans << '\n';
    }
}