考虑分别计算每个点的贡献
引理:一棵树仅有重心满足重儿子的大小小于整棵树的大小的一半
钦定以原树重心为根,如果一个在原树上非重心的点$x$在割掉一条边以后成为了重心,那么割掉的那个子树一定不在$x$的子树内,这个考虑重心的性质就可以理解了
考虑割掉的不包含$x$的子树大小为$S$,$x$的重儿子为$g$,若$x$为所在子树的重心,可以得到以下约束
$$ \begin{cases} 2 * (n - S - s(x)) \leq n - S\\ 2 * (n - s(g)) \leq n - S \end{cases} $$
等价于
$$ n - 2s(x) \leq S \leq n - 2s(g) $$
要计算的就是这样的$S$的数量,值域树状数组即可,初始把每个子树的$S$丢进去,然后dfs的时候用换根的思路,把割掉的边在该点到根的边集里的这一部分贡献算上
不在$x$的子树内的限制,其实有个很自然的记录dfn直接上主席树的做法,但实际上再开一个树状数组用树上差分计算出子树内的贡献减去就行了
最后考虑的是最初钦定的根节点的贡献,考虑和别的点一样,如果割掉的点在最大子树内,次大子树就不能超过整棵树大小的一半,反之,最大子树就不能超过整棵树大小的一半
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <queue>
using namespace std;
const int N = 1e6 + 10, M = 2e6 + 10, inf = 0x3f3f3f3f;
const long long Linf = 0x3f3f3f3f3f3f3f3f;
inline int read() {
bool sym = 0; int res = 0; char ch = getchar();
while (!isdigit(ch)) sym |= (ch == '-'), ch = getchar();
while (isdigit(ch)) res = (res << 3) + (res << 1) + (ch ^ 48), ch = getchar();
return sym ? -res : res;
}
struct EDGE {
int u, v, nxt;
} edge[N];
bool vis[N];
int n, m, cnt, head[N], siz[N], son[N], rt, fir, sec;
long long ans;
struct BIT {
int sum[N];
#define lowbit(x) (x & -x)
void clear() {memset(sum, 0, sizeof(sum));}
void add(int x, int t) {for (int i = x; i <= n; i += lowbit(i)) sum[i] += t;}
int get(int x) {int res = 0; for (int i = x; i; i -= lowbit(i)) res += sum[i]; return res;}
int query(int l, int r) {return get(r) - get(l - 1);}
} t1, t2;
void add(int u, int v) {edge[++cnt] = (EDGE){u, v, head[u]}; head[u] = cnt;}
void dfs1(int u, int last) {
siz[u] = 1; son[u] = 0;
for (int e = head[u]; e; e = edge[e].nxt) {
int v = edge[e].v; if (v == last) continue; dfs1(v, u);
siz[u] += siz[v]; if (siz[v] > siz[son[u]]) son[u] = v;
}
}
void find() {
for (int i = 1; i <= n; i++) {
if (siz[son[i]] <= n / 2 && n - siz[i] <= n / 2) {rt = i; return;}
}
}
void dfs2(int u, int last) {
int l, r, t; if (vis[last]) vis[u] = 1;
if (u != rt) {
t1.add(n - siz[u], 1); t1.add(siz[u], -1); t2.add(siz[u], 1);
l = n - 2 * siz[u]; r = n - 2 * siz[son[u]]; t = t2.query(l, r);
if (vis[u]) ans += 1ll * rt * (2 * siz[sec] <= n - siz[u]);
else ans += 1ll * rt * (2 * siz[fir] <= n - siz[u]);
}
for (int e = head[u]; e; e = edge[e].nxt) {
int v = edge[e].v; if (v == last) continue; dfs2(v, u);
}
if (u != rt) {
t = t2.query(l, r) - t; ans += 1ll * (t1.query(l, r) - t) * u;
t1.add(n - siz[u], -1); t1.add(siz[u], 1);
}
}
int main() {
int T = read();
while (T--) {
n = read(); cnt = 0; ans = 0; memset(head, 0, sizeof(head));
for (int i = 2; i <= n; i++) {
int u = read(), v = read(); add(u, v); add(v, u);
}
dfs1(1, 0); find(); dfs1(rt, 0); t1.clear(); t2.clear();
for (int i = 1; i <= n; i++) if (i != rt) t1.add(siz[i], 1);
fir = sec = 0;
for (int e = head[rt]; e; e = edge[e].nxt) {
int v = edge[e].v;
if (siz[v] > siz[fir]) sec = fir, fir = v;
else if (siz[v] > siz[sec]) sec = v;
}
memset(vis, 0, sizeof(vis)); vis[fir] = 1;
dfs2(rt, 0); printf("%lld\n", ans);
}
return 0;
}