9.20模拟赛T3题解【限时公开,阅后即焚】

发布时间 2023-09-22 16:38:59作者: cool_milo

考场做法。
复杂度是优美的\(\Theta(n^2 \log n)\)
强烈谴责高复杂度碾标算行为
考虑一个观察:对于一个左上角 \((x, y)\) ,如果我们确定了它的边长一个区间 \([L,R]\),使得这个区间内 至少存在 \(k\)\(k\) 列1,(可能还有一些多余的1),那么我们就可以用二分确定哪一部分是合法的,单调性显然。
我们记点 \((x, y)\) 向下延伸的 \(1\) 的个数为 \(dn_{i,j}\), 同理,\(rt_{i,j}\)为向右延伸的 \(1\) 的个数。
当左上角为 \(0\) 的时候,直接找到 \((x, y)\) 右边,下边前 \(k\) 个1就可以了。L就是两个 1出现的位置和 \((x, y)\) 的距离 取\(max\),R就是区间非零 \(dn\) 的最小值和区间非零 \(rt\) 的最小值取 \(min\)
但是这个做法在 \((x, y)\)\(1\) 的时候会出问题,怎么会是呢?
我们发现,难点在于我们 不知道1往下延伸的部分是作为一列出现还是整体作为一行出现
image
但是我们可以同时发现,我们排除掉所有全1正方形之后,以向右为例,假设当前的正方形边长为 \(len\)\((x, [y, y + len - 1])\)\(dn\) 的最小值为 \(mindn\) ,那么 \(mindn\) 就是“紧贴 \((x, y)\)的全1行的数量”。

(原谅作者菜的补星的语文。。)

这样就好做了,维护一个指针 \(ptr\) ,表示 \((x, [y, ptr])\) 中非最小值的出现次数为 \(k\) 的最左边的列号。对向下也做一遍同样的事情,这样两个长度取 \(max\) 就是二分下界,区间次小值取min就是二分上界,实现时可以用map搞。
还有一个问题是 \(ptr\) 的单调性。但是注意到 \(dn_{x,y}\) 不是最小值时没有问题, \(dn_{x,y}\) 是最小值时因为 \(len \geq k + 1\) 也没有问题。那么就做完了(?)

丑陋的代码贴出来全网求hack,我自己也和hfy的代码写写对拍(?)

点击查看代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int P = 998244353;
const int N = 1005;
template<typename T>inline void read(T &a)
{
	a = 0;
	int f = 1;
	char c = getchar();
	while(!isdigit(c))
	{
		if(c == '-')	f = -1;
		c = getchar();
	}
	while(isdigit(c))
		a = a * 10 + c - 48,c = getchar();
	a *= f;
}

template<typename T,typename ...L> inline void read(T &a,L &...l)
{
	read(a),read(l...);
}

char matrix[N][N];
int n, m, k, sum[N][N], dn[N][N], rt[N][N], lw[N][N], hg[N][N];//下界,上界 

inline int f(int len) {
	return 2 * len * k - k * k;
}

inline int Sum(int x, int y, int xx, int yy) {
	return sum[xx][yy] - sum[xx][y - 1] - sum[x - 1][yy] + sum[x - 1][y - 1];
}

int main() {
	freopen("matrix.in", "r", stdin);
	freopen("matrix.out", "w", stdout);
	read(n, m, k);
	for(int i = 1; i <= n; i++) {
		scanf("%s", matrix[i] + 1);
	}
	for(int i = 1; i <= n; i++) {
		for(int j = 1; j <= m; j++) {
			sum[i][j] = sum[i - 1][j] + sum[i][j - 1] - sum[i - 1][j - 1] + matrix[i][j] - '0';
		}
	}
	
	memset(hg, 0x3f, sizeof hg);
	for(int i = 1; i <= n; i++) {
		for(int j = 1; j <= m; j++) {
			lw[i][j] = k + 1;
		}
	} 
	
	for(int i = n; i >= 1; i--) {
		for(int j = 1; j <= m; j++) {
			dn[i][j] = (matrix[i][j] == '1') ? dn[i + 1][j] + 1 : 0; 
		} 
	}
	
	for(int j = m; j >= 1; j--) {
		for(int i = 1; i <= n; i++) {
			rt[i][j] = (matrix[i][j] == '1') ? rt[i][j + 1] + 1 : 0;
		} 
	}
	for(int i = 1; i <= n; i++) {
		int len = 0;
		map<int, int> mp;
		for(int j = m; j >= 1; j--) {
			len++;
			mp[dn[i][j]]++;
			while(len - mp.begin() -> second > k) {
				--mp[dn[i][j + len - 1]];
				if(!mp[dn[i][j + len - 1]]) {
					mp.erase(dn[i][j + len - 1]);
				}
				len--;
			}
			while(mp.begin() -> second != len && dn[i][j + len - 1] == mp.begin() -> first) {
				--mp.begin() -> second;
				if(!mp.begin() -> second) {
					mp.erase(mp.begin());
				}
				len--;
			} 
			lw[i][j] = max(lw[i][j], len);
			if(int(mp.size()) > 1 && len - mp.begin() -> second >= k) {
				hg[i][j] = min(hg[i][j], next(mp.begin()) -> first);
			}
			else {
				hg[i][j] = -1;
			}
		}
	}
	
	for(int j = 1; j <= m; j++) {
		int len = 0;
		map<int, int> mp;
		for(int i = n; i >= 1; i--) {
			len++;
			mp[rt[i][j]]++;
			while(len - mp.begin() -> second > k) {
				--mp[rt[i + len - 1][j]];
				if(!mp[rt[i + len - 1][j]]) {
					mp.erase(rt[i + len - 1][j]);
				}
				len--;
			}
			while(mp.begin() -> second != len && rt[i + len - 1][j] == mp.begin() -> first) {
				--mp.begin() -> second;
				if(!mp.begin() -> second) {
					mp.erase(mp.begin());
				}
				len--;
			} 
			lw[i][j] = max(lw[i][j], len);
			if(int(mp.size()) > 1 && len - mp.begin() -> second >= k) {
				hg[i][j] = min(hg[i][j], next(mp.begin()) -> first);
			}
			else {
				hg[i][j] = -1;
			}
		}
	}
	int ans = 0;
	for(int i = 1; i <= n; i++) {
		for(int j = 1; j <= m; j++) {
			if(i + k - 1 <= n && j + k - 1 <= m) {
				ans += (Sum(i, j, i + k - 1, j + k - 1) == k * k);
			}
			if(lw[i][j] <= hg[i][j] && Sum(i, j, i + lw[i][j] - 1, j + lw[i][j] - 1) == f(lw[i][j])) {
				int L = lw[i][j], R = hg[i][j];
				while(L < R) {
					int mid = (L + R + 1) >> 1;
					if(f(mid) == Sum(i, j, i + mid - 1, j + mid - 1)) {
						L = mid;
					}
					else {
						R = mid - 1;
					}
				}
				ans += L - lw[i][j] + 1;
			}
		}
	} 
	cout<<ans<<endl;
}

/*
	start coding at:2023/9/22 11:02
	finish debugging at:2023/9/22 11:58
	stubid mistakes:求前缀和没有减'0', mp.erase()时len没有-- 
*/

/*
  吾日三省吾身:
  你排序了吗?
  你取模了吗?
  你用%lld输出long long 了吗?
  1LL<<x写对了吗?
  判断=0用了abs()吗?
  算组合数判了a<b吗?
  线段树build()了吗?
  .size()加了(signed)吗?
  树链剖分的DFS序是在第二次更新的吗?
  修改在询问前面吗?
  线段树合并到叶子结点return了吗?
*/