NOI2021 庆典题解

发布时间 2023-10-12 15:30:52作者: _hjy

又是一道锻炼代码能力的题目。

首先遇到这种求经过多少个节点的题可以先缩点,然后我们考虑那个特殊限制怎么用。

如果对于两个强联通分量 \(x\) 能到 \(z\)\(y\) 能到 \(z\),则 \(x,y\) 之间一定有一个限制,假设这个限制是 \(x\) 能到 \(y\),那么我们可以只记录 \(x\)\(y\)\(y\)\(z\),就把一个 \(DAG\) 变成了一颗叶向树。

考虑 \(k=0\) 的时候怎么做。容易发现就是你去判断 \(t\) 是否在 \(s\) 的子树中,如果在的话就是他们链上的边权和。

显然 \(k = 1,k = 2\) 的时候可以通过分类讨论的方法做,但是这样太逆天了,我们考虑一个更简洁的做法。

首先我们有 \(O(n)\) 的暴力,对起点 \(bfs\) ,在反图上以终点为起点 \(bfs\),对到达的点算贡献,那么我们可以对起点与终点和加入的边的端点建出一颗虚树,在虚树上跑这个暴力即可。由于虚树上的点最多只有 \(O(k)\) 个,所以总时间复杂度不高,可以通过此题。

代码:

#include<bits/stdc++.h>
using namespace std;

const int MX = 3e5;

#define pii pair<int,int >
#define mk make_pair
#define fir first
#define sec second

int n,m,q,k,s,t;

int u[15],v[15];
vector<int > g[MX + 7],r[MX + 7];
int scc[MX + 7],low[MX + 7],dfn[MX + 7],stk[MX + 7],num[MX + 7],in[MX + 7];
int topp = 0,tot = 0,col = 0,rt = 0;
bool instk[MX + 7];

queue<int > que;

int read(){
	int x = 0,ch = getchar();
	while(!(ch >= 48 && ch <= 57))ch = getchar();
	while(ch >= 48 && ch <= 57)x = x * 10 + ch - 48,ch = getchar();
	return x;
}

void print(int x){
	if(x>9)print(x / 10);
	putchar(x % 10 + 48);
}

void tarjan(int x){
	dfn[x] = low[x] = ++tot;
	stk[++topp] = x;
	instk[x] = true;
	for(int i = 0;i < g[x].size();i++){
		int to = g[x][i];
		if(!dfn[to]){
			tarjan(to);
			low[x] = min(low[x],low[to]);
		}
		else if(instk[to])low[x] = min(low[x],dfn[to]);
	}
	if(low[x] == dfn[x]){
		scc[x] = ++col;
		num[col]++;
		instk[x] = false;
		while(stk[topp] != x){
			scc[stk[topp]] = col;
			num[col]++;
			instk[stk[topp]] = false;
			topp--;
		}
		topp--;
	}
}

int heavy[MX + 7],dep[MX + 7],sz[MX + 7],top[MX + 7],fa[MX + 7],sum[MX + 7];

void dfs1(int x,int f){
    sz[x] = 1;
    fa[x] = f;
    int maxn = 0;
    for(int i = 0;i < g[x].size();i++){
        int to = g[x][i];
        if(to == f)continue;
        dep[to] = dep[x] + 1;
        sum[to] = sum[x] + num[to];
        dfs1(to,x);
        if(maxn < sz[to]){
            maxn = sz[to];
            heavy[x] = to;
        }
        sz[x] += sz[to];
    }
}

void dfs2(int x,int f,int t){
    dfn[x] = ++tot;
    top[x] = t;
    if(g[x].size() == 0)return;
    dfs2(heavy[x],x,t);
    for(int i = 0;i < g[x].size();i++){
        int to = g[x][i];
        if(to != heavy[x] && to != f)dfs2(to,x,to);
    }
}

void build_scc(){
	for(int i = 1;i <= n;i++)
		if(!dfn[i])tarjan(i);
	for(int i = 1;i <= n;i++){
		for(int j = 0;j < g[i].size();j++){
			int to = g[i][j];
			if(scc[to] != scc[i]){
				in[scc[to]]++;
				r[scc[i]].push_back(scc[to]);
			}
		}
		g[i].clear();
	}
}

void topsort(){
	for(int i = 1;i <= col;i++)
		if(in[i] == 0){
			que.push(i);
			rt = i;
			sum[rt] = sz[rt];
		}
	while(!que.empty()){
		int x = que.front();
		//cout << x << ' ' << scc[x] << '\n';
		que.pop();
		for(int i = 0;i < r[x].size();i++){
			int to = r[x][i];
			in[to]--;
			if(in[to] == 0){
				g[x].push_back(to);
				que.push(to);
			}
		}
	}
	for(int i = 1;i <= col;i++)r[i].clear();
}

void build_heavy(){
	tot = 0;
	dfs1(rt,rt);
	dfs2(rt,rt,rt);
}

void init(){
	build_scc();
	topsort();
	build_heavy();
}

int lca(int x,int y){
	while(top[x] != top[y]){
		if(dep[top[x]] > dep[top[y]])
			x = fa[top[x]];
		else
			y = fa[top[y]];
	}
	return dep[x] < dep[y]? x : y;
}

bool cmp(int x,int y){
	return dfn[x] < dfn[y];
}

struct Edge{
	int from,to,w;
	int nxt;
}edges[MX + 7],edget[MX + 7];

int heads[MX + 7],headt[MX + 7],tNode = 0;
int a[2 * MX + 7],h[MX + 7],len = 0;
bool viss[MX + 7],vist[MX + 7],visE[MX + 7];

void addEs(int x,int y,int w){
	edges[tNode].from = x;
	edges[tNode].to = y;
	edges[tNode].w = w;
	edges[tNode].nxt = heads[x];
	heads[x] = tNode;
}

void addEt(int x,int y,int w){
	edget[tNode].from = x;
	edget[tNode].to = y;
	edget[tNode].w = w;
	edget[tNode].nxt = headt[x];
	headt[x] = tNode;
}

void addE(int x,int y,int w){
	++tNode;
	visE[tNode] = false;
	addEs(x,y,w);
	addEt(y,x,w);
}

void build_vtree(int l){ 
	sort(h + 1,h + 1 + l,cmp);
	len = 0;
	for(int i = 1;i < l;i++){
		a[++len] = h[i];
		a[++len] = lca(h[i],h[i + 1]);
	}
	a[++len] = h[l];
	sort(a + 1,a + 1 + len,cmp);
	len = unique(a + 1,a + len + 1) - a - 1;
	for(int i = 1;i < len;i++){
		int lc = lca(a[i],a[i + 1]);
		addE(lc,a[i + 1],sum[a[i + 1]] - sum[lc] - num[a[i + 1]]);
	}
}

void build_query(){
	s = read();
	t = read();
	s = scc[s];
	t = scc[t];
	int tmp = 0;
	h[++tmp] = s;
	h[++tmp] = t;
	for(int i = 1;i <= k;i++){
		u[i] = read();
		v[i] = read();
		u[i] = scc[u[i]];
		v[i] = scc[v[i]];
		h[++tmp] = u[i];
		h[++tmp] = v[i];
	}
	build_vtree(tmp);
	for(int i = 1;i <= k;i++)addE(u[i],v[i],0);
}

void bfs1(int x){
	que.push(x);
	while(!que.empty()){
		int p = que.front();
		que.pop();
		viss[p] = true;
		for(int i = heads[p];i;i = edges[i].nxt){
			int to = edges[i].to;
			visE[i] = true;
			if(!viss[to]){
				viss[to] = true;
				que.push(to);
			}		
		}
	}
}

void bfs2(int x){
	int ans = 0;
	que.push(x);
	while(!que.empty()){
		int p = que.front();
		que.pop();
		if(viss[p]){
			ans += num[p];
			viss[p] = false;
		}
		vist[p] = true;
		for(int i = headt[p];i;i = edget[i].nxt){
			int to = edget[i].to,w = edget[i].w;
			if(visE[i]){
				ans += w;
				visE[i] = false;
			}
			if(!vist[to]){
				vist[to] = true;
				que.push(to);
			}
		}
	}
	print(ans);
	putchar('\n');
}

void clear_node(int x){
	heads[x] = headt[x] = 0;
	viss[x] = vist[x] = false;
}

void clear_vtree(){
	tNode = 0;
	for(int i = 1;i < len;i++){
		int lc = lca(a[i],a[i + 1]);
		clear_node(lc);
		clear_node(a[i + 1]);
	}
}

void solve(){
	build_query();
	bfs1(s);
	bfs2(t);
	clear_vtree();
}

int main(){
	n = read();
	m = read();
	q = read();
	k = read();
	for(int i = 1;i <= m;i++){
		s = read();
		t = read();
		g[s].push_back(t);
	}
	init();
	for(int i = 1;i <= q;i++)solve();
	return 0;
}