洛谷P3101 题解

发布时间 2024-01-03 21:00:45作者: wbw121124

输入格式

\(1\) 行,三个整数 \(m,n,t\)

\(2\)\(m+1\) 行,\(m\) 个整数,表示海拔高度。

\(2+m\)\(2m+1\) 行,\(m\) 个整数。第 \(i\) 行,第 \(j\) 个整数表示 \(i,j\) 是否为起点。

输出格式

所有起点的最小难度评级之和(请注意,这可能不适合32位整数,即使个别难度评级会适合)。

答案

方法 \(1\):二分答案

  1. 枚举每一个起点,并对其二分答案 \(d\)
  2. dfs 对起点搜索连通块,上下左右移动,海拔差不超过 \(d\),检查连通块的大小是否达到 \(t\)
  3. 对一个起点,搜索时间复杂度 \(O(n * m* \log(d))\)
  4. 最多 \(n * m\) 个起点,总时间复杂度 \(O(n^2 * m^2 * \log(d))\)

方法 \(2\):贪心

  1. 将格子视为点,相邻的格子之间建边,边权为海拔差。
  2. 按照边权,将 \(tot\) 条边从小到大排序。
  3. 枚举边 \(i\),第 \(i\) 条边若首次将 \(2\) 端点的连通块连接,则 \(2\) 个连通块,变大。
  4. 若合并后连通块的大小达到 \(t\),则该连通块的起点的 \(d\) 值即为 \(a[i].z\)\(ans+=e[i].w * u[i]\),清空连通块内的起点标记,继续枚举边。
  5. 答案即为 \(ans\)
  6. 时间复杂度 \(O(n * m * \log(n * m))\)

方法 \(2\) 代码

/*
作者:wbw_121124
*/
#include<bits/stdc++.h>
#define debug false
#define int long long
using namespace std;
const int N = 510;
int n, ans, fa[N * N], v[N * N], m, t, tot, cnt, h[N * N], u[N * N];
struct node {
	int x, y, z;
	bool operator< (node x)
	{
		return z < x.z;
	}
}a[N*N*2];
int get_dis(int x, int y)
{
	return (x - 1) * m + y;
}
int find(int x)
{
	if (fa[x] == x)
		return x;
	return fa[x] = find(fa[x]);
}
void unionn(int x, int y)
{
	x = find(x);
	y = find(y);
	if (x != y)
		fa[x] = y, v[y] += v[x], u[y] += u[x], u[x] = v[x] = 0;
	return;
}
signed main()
{
	ios::sync_with_stdio(0);
	cin.tie(0), cout.tie(0);
	cin >> n >> m >> t;
	for (int i = 1; i <= n; i++)
		for (int j = 1; j <= m; j++)
		{
			cin >> h[get_dis(i, j)];
			v[get_dis(i, j)] = 1;
			fa[get_dis(i, j)] = get_dis(i, j);
		}
	for (int i = 1; i <= n; i++)
		for (int j = 1; j <= m; j++)
			cin >> u[get_dis(i, j)];
	for(int i=1;i<=n;i++)
		for (int j = 1; j <= m; j++)
		{
			if (i != n)
				a[++tot] = node{ get_dis(i,j),get_dis(i + 1,j),abs(h[get_dis(i,j)] - h[get_dis(i + 1,j)]) };
			if (j != m)
				a[++tot] = node{ get_dis(i,j),get_dis(i,j + 1),abs(h[get_dis(i,j)] - h[get_dis(i,j + 1)]) };
		}
	sort(a + 1, a + 1 + tot);
	for(int i=1;i<=tot;i++)
		if (find(a[i].x) != find(a[i].y))
		{
			unionn(a[i].x, a[i].y);
			if (v[find(a[i].y)] >= t)
				ans += u[find(a[i].y)] * a[i].z, u[find(a[i].y)] = 0;
		}
	cout << ans;
	return 0;
}