D - 神经网络

题目链接

这是一道经典题目改编,原题是dreamoon出的NTU-WF选拔赛D题(题目链接)。

不难发现,修复时所产生的花费,可以分解为一系列有向简单路径上结点权值之和。因此由期望的线性性可知,要求总花费的期望值,我们只需分别计算$n^2$条有向简单路径对总花费产生贡献的期望值。

对于树上一条有向简单路径$u \rightarrow v$,设该路径长度为$d \ (0 \leq d \leq n-1)$(也就是说有$d$条边,$d+1$个结点),路径上所有结点权值为$S$,那么结论就是:这条简单路径对总花费产生贡献的期望值为$\frac{S}{d}$ 。

为什么这个结论是正确的呢?首先要注意到,对于任意一个边修复顺序,交换除有向路径$u \rightarrow v$上的$d$条边以外的任意边的修复顺序,不影响有向路径$u \rightarrow v$对总花费的贡献。进一步分析我们发现,对于$u \rightarrow v$上的$d$条边,只有当连接着结点$u$的那条边是这$d$条边中最后一个被修复的边时,有向路径$u \rightarrow v$才会对总花费产生大小恰为$S$的贡献。不难证明“连接$u$的那条边在路径上的$d$条边中最后被修复”这一事件发生的概率为$\frac{1}{d}$,因此有向路径$u \rightarrow v$对最终总花费产生贡献的期望值就是$\frac{S}{d}$。

于是问题转化成,求树上所有有向路径的$\frac{\text{结点权值和}}{\text{路径长度}}$的和,而这可以通过树分治NTT计算。简单来说就是:我们每次找一个重心,对于重心的每个儿子对应的子树处理出两个值$cnt_i$和$val_i$,分别表示该子树中深度为$i$的结点个数,以及所有“深度为$i$的结点到该子树跟结点路径上的所有结点的权值之和”的和(如果没听懂就去看代码,反正我也应该听不懂)。于是我们对每棵子树都搞出了两个多项式。在每次计算跨越重心的贡献时,我们只需要合理利用$cnt_i$和$val_i$对应的多项式进行一波奥妙重重的NTT计数就统计出相应答案了,这之后再把这两棵子树的$cnt_i, val_i$分别合并即可。

需要注意的是,在计算某个分治出来的连通块的答案时,必须将子树按照对应多项式的大小排序,然后从小到大依次计算,否则会TLE(专门构造了卡未排序做法的数据,不过dreamoon出的那题没有卡掉未排序做法)。在有排序的做法下,处理每个连通块的时间复杂度为$O(size \log n)$,因此整个做法的复杂度为$O(n \log n \log n)$。

#include<bits/stdc++.h>

using namespace std;

#define mem(a,b) memset(a,b,sizeof(a))
#define REP(i,a,b) for(int i=a; i<=b; ++i)
#define PER(i,a,b) for(int i=a; i>=b; --i)
#define MP make_pair
#define PB push_back
#define fi first
#define se second
typedef long long LL;
typedef double DB;

const int maxn = 1e5;
const int maxa = 1e5;

const LL P = ((7 * 17) << 23) + 1, gg = 3;
const int maxL = 1<<17;

inline LL PowMod(LL a, LL b) { LL r=1; while(b) { if(b&1) r=r*a%P; a=a*a%P, b>>=1; } return r; }
LL A[maxL+5], B[maxL+5];
void NTT(LL *a, int len, int type) {
    int i, j, k, l;
    for(i=1, j=len>>1; i<len-1; ++i) {
        if(i<j) swap(a[i], a[j]);
        k = len>>1;
        while(j>=k)
            j -= k, k >>= 1;
        j += k;
    }
    LL var, step, u, v;
    for(l=2; l<=len; l<<=1) {
        step = PowMod(gg, (P-1)/l);
        for(k=0; k<len; k+=l) {
            var = 1;
            for(i=k; i<k+l/2; ++i) {
                u = a[i], v = var*a[i+l/2] % P;
                a[i] = (u+v) % P;
                a[i+l/2] = (u-v+P) % P;
                var = var*step % P;
            }
        }
    }
    if(type == -1) {
        for(i=1; i<len/2; ++i) swap(a[i], a[len-i]);
        LL inv = PowMod(len, P-2);
        for(i=0; i<len; ++i) a[i] = a[i]*inv % P;
    }
}

int n, ai[maxn+5];
LL inv[maxn+5];
LL ans = 0;

vector<int> G[maxn+5];
int que[maxn+5], fa[maxn+5], sz[maxn+5], msz[maxn+5];
bool ban[maxn+5];
int FindRoot(int x) {
    int s = 1, t = 1;
    que[1] = x, fa[x] = 0;
    while(s <= t) {
        x = que[s++], sz[x] = 1, msz[x] = 0;
        for(auto v : G[x])
            if(!ban[v] && v!=fa[x])
                que[++t] = v, fa[v] = x;
    }
    for(int i = t; i >= 1; --i) {
        x = que[i], msz[x] = max(msz[x], t-sz[x]);
        if((msz[x]<<1) <= t) return x;
        sz[fa[x]] += sz[x], msz[fa[x]] = max(msz[fa[x]], sz[x]);
    }
    assert(false);
    return 0;
}

int id[maxn+5];
vector<int> cnt[maxn+5], val[maxn+5];
void Mul(int a, int b) {
    int l1 = cnt[a].size();
    int l2 = val[b].size();
    int len = 1;
    while(len < l1+l2-1) len <<= 1;
    assert(len<=maxL);
    REP(i,0,l1-1) A[i] = cnt[a][i];
    REP(i,l1,len-1) A[i] = 0;
    REP(i,0,l2-1) B[i] = val[b][i];
    REP(i,l2,len-1) B[i] = 0;
    NTT(A,len,1), NTT(B,len,1);
    REP(i,0,len-1) A[i] = A[i]*B[i] % P;
    NTT(A,len,-1);
    REP(i,1,len-1) ans = (ans + A[i] * inv[i]) % P;
}

void dfs(int x, int p, int d, int sum, vector<int> &c, vector<int> &v) {
    sum = (sum + ai[x]) % P;
    for(auto y : G[x]) if(y!=p && !ban[y]) {
        dfs(y, x, d+1, sum, c, v);
    }
    if(d+1 > c.size()) {
        c.resize(d+1);
        v.resize(d+1);
    }
    ++c[d], v[d] = (v[d]+sum) % P;
}

void Solve(int x) {
    x = FindRoot(x);
    int tot = 1;
    cnt[0].resize(1), val[0].resize(1);
    cnt[0][0] = 1, val[0][0] = ai[x];
    id[0] = 0;
    for(auto y : G[x]) if(!ban[y]) {
        cnt[tot].clear(), val[tot].clear();
        dfs(y, x, 1, 0, cnt[tot], val[tot]);
        id[tot] = tot;
        ++tot;
    }
    sort(id, id+tot, [&](const int &a, const int &b){ return cnt[a].size() < cnt[b].size(); });
    for(int i = 0; i+1 < tot; ++i) {
        Mul(id[i], id[i+1]);
        Mul(id[i+1], id[i]);
        for(int k = 0; k < cnt[id[i+1]].size(); ++k)
            val[id[i+1]][k] = (val[id[i+1]][k] + LL(ai[x])*cnt[id[i+1]][k]) % P;
        for(int k = 0; k < cnt[id[i]].size(); ++k) {
            cnt[id[i+1]][k] = (cnt[id[i+1]][k] + cnt[id[i]][k]) % P;
            val[id[i+1]][k] = (val[id[i+1]][k] + val[id[i]][k]) % P;
        }
        cnt[id[i]].clear(), val[id[i]].clear();
    }
    cnt[id[tot-1]].clear(), val[id[tot-1]].clear();
    ban[x] = 1;
    for(auto y : G[x]) if(!ban[y]) Solve(y);
}

int main() {
    inv[1] = 1;
    REP(i,2,maxn) inv[i] = (P - P/i) * inv[P%i] % P;
    scanf("%d", &n);
    REP(i,1,n) scanf("%d", ai+i);
    for(int i = 1, u,v; i < n; ++i) {
        scanf("%d%d", &u, &v);
        G[u].PB(v);
        G[v].PB(u);
    }
    Solve(1);
    ans = ans*2 % P;
    printf("%lld\n", ans);
    return 0;
}

G - 铁树开花

题目链接

本题是防AK题,原因不在于思维度高,而在于码量大……(标程在压过代码量的情况下200+行,不过牛客同步赛上杭电在比赛刚结束不久就通过了这题,并且他们准AK,非常强悍)

这题是把两个idea强行拼在一起的产物 (于是也间接导致数据极其难造)。第一个idea来自Google Kick Start 2019 Round A P3 Contention,也就是计算$\max( \min_{1 \leq i \leq m}(cnt_i) )$的方法。不妨考虑倒着染色,假设现在还有$i$种花的顺序没有确定,并且他们一定都是在前$i$步开花,那么我们先让他们都在树上开花染色,并且对于每朵花记录一个$temp_k$值,表示树上只由这一种花染过色的结点的个数。之后,我们只要选择$temp_k$最大的那朵花,让它在第$i$步开花,并用$temp_k$更新一下$\min_{1 \leq i \leq m}(cnt_i)$即可。如此反复做$n$次后就可以算出$\max( \min_{1 \leq i \leq m}(cnt_i) )$的值,这是一个贪心,仔细思考一下就会发现是对的。

下一个问题是,如何在每一轮快速计算$temp_k$。注意到开花半径最大不超过$2$,这给了我们乱搞的余地。我们考虑对树定根,然后按照深度建线段树,即把所有深度相同的结点放在连续的一段区间当中。然后对每个结点,我们记录以下它的所有深度为$1$的子结点对应的区间,然后在记录一下它的所有深度为$2$的子结点对应的区间,于是经过一番仔细的分类讨论后,你会发现每次开花染色或撤销操作只需要涉及$O(\log n)$个区间。由于需要记录当某个结点只被一朵花染色时,这朵花具体是哪个,所以可能还需要对结点多维护一个set,于是单次染色或撤销操作就变成了$O(\log n \log m)$。因此你就可以倒着做,让每朵花都在树上染色,然后每一轮找出最大的$temp_k$及其对应的花,然后把这朵花撤掉即可,具体实现看代码。

但是这样还没完,因为题目要求字典序最小的操作顺序 (主要是因为数据不好造才强行加了这个特技增加题目复杂度),不过有了上面的一系列操作后,求字典序最小其实很简单。我们正着开花,每次选择没开过的花里面满足$temp_k \geq \max( \min_{1 \leq i \leq m}(cnt_i) )$且编号最小的那朵花开花即可。时间复杂度自己算算吧。。。

#include<bits/stdc++.h>

using namespace std;

#define mem(a,b) memset(a,b,sizeof(a))
#define REP(i,a,b) for(int i=a; i<=b; ++i)
#define PER(i,a,b) for(int i=a; i>=b; --i)
#define MP make_pair
#define PB push_back
#define fi first
#define se second
#define lson (x<<1)
#define rson (x<<1|1)
#define mid ((l+r)>>1)
typedef long long LL;
typedef double DB;
typedef pair<int,int> pii;

const int maxn = 1e5;
const int maxM = 2e4;
const int maxnode = maxn<<2;

int n, m;
int DFN = 0, dfn[maxn+5], fa[maxn+5];
vector<int> vec[maxn+5];
int L1[maxn+5], R1[maxn+5], L2[maxn+5], R2[maxn+5];
int pt[maxM+5], dd[maxM+5];
int cnt[maxM+5];
bool dead[maxM+5];
vector<int> G[maxn+5];
priority_queue<pair<int,int>> q;

int tr[maxnode+5], lz[maxnode+5];
set<int> cov[maxnode+5];

void dfs(int x, int p, int dis) {
    fa[x] = p;
    vec[dis].PB(x);
    for(auto y : G[x]) if(y!=p) {
        dfs(y,x, dis+1);
    }
}

void dfs2(int x, int p) {
    L1[x] = L2[x] = n+1;
    R1[x] = R2[x] = 0;
    for(auto y : G[x]) if(y!=p) {
        dfs2(y,x);
        L1[x] = min(L1[x], dfn[y]);
        R1[x] = max(R1[x], dfn[y]);
        L2[x] = min(L2[x], L1[y]);
        R2[x] = max(R2[x], R1[y]);
    }
}

void reBuild(int x, int l, int r) {
    tr[x] = lz[x] = 0;
    assert(cov[x].size()==0);
    if(l<r) {
        reBuild(lson, l, mid);
        reBuild(rson, mid+1, r);
    }
}

void PushDown(int x) {
    lz[lson]+=lz[x], tr[lson]+=lz[x];
    lz[rson]+=lz[x], tr[rson]+=lz[x];
    lz[x] = 0;
}

void Upd(int x, int l, int r, int v) {
    if(tr[x]>1) return;
    if(cov[x].size()) v = *cov[x].begin();
    if(l==r) {
        if(tr[x]==0) tr[x] = n+10;
        else if(tr[x]==1) {
            assert(1<=v && v<=n);
            ++cnt[v];
            q.push( pair<int,int>(cnt[v],v) );
            tr[x] = n+10;
        }
    }
    else {
        PushDown(x);
        Upd(lson, l, mid, v);
        Upd(rson, mid+1, r, v);
        tr[x] = min(tr[lson], tr[rson]);
    }
}

void Add(int x, int l, int r, int ll, int rr, int v) {
    if(ll<=l && r<=rr) {
        ++tr[x], ++lz[x];
        cov[x].insert(v);
    }
    else {
        PushDown(x);
        if(ll<=mid) Add(lson, l, mid, ll, rr, v);
        if(mid<rr) Add(rson, mid+1, r, ll, rr, v);
        tr[x] = min(tr[lson], tr[rson]);
    }
}

void Sub(int x, int l, int r, int ll, int rr, int v) {
    if(ll<=l && r<=rr) {
        --tr[x], --lz[x];
        auto it = cov[x].find(v);
        assert(it != cov[x].end());
        cov[x].erase(it);
    }
    else {
        PushDown(x);
        if(ll<=mid) Sub(lson,l, mid, ll, rr, v);
        if(mid<rr) Sub(rson, mid+1, r, ll, rr, v);
        tr[x] = min(tr[lson], tr[rson]) + lz[x];
    }
}

void Mod(int l, int r, int v, int o) {
    if(l>r) return;
    if(o==1) Add(1, 1, n, l, r, v);
    else Sub(1, 1, n, l, r, v);
}

void Draw(int id, int o) {
    assert(o==1 || o==-1);
    int x = pt[id], kk = dd[id];
    if(kk==0) Mod(dfn[x], dfn[x], id, o);
    else if(kk==1) {
        Mod(L1[x], R1[x], id, o);
        Mod(dfn[x], dfn[x], id, o);
        if(fa[x]!=0) Mod(dfn[fa[x]], dfn[fa[x]], id, o);
    }
    else {
        Mod(L1[x], R1[x], id, o);
        Mod(L2[x], R2[x], id, o);
        if(fa[x]!=0) {
            int y = fa[x];
            Mod(L1[y], R1[y], id, o);
            Mod(dfn[y], dfn[y], id, o);
            if(fa[y]!=0) Mod(dfn[fa[y]], dfn[fa[y]], id, o);
        }
        else Mod(dfn[x], dfn[x], id, o);
    }
}

priority_queue<int> q2;

int main() {
    scanf("%d%d", &n, &m);
    for(int i = 1, u,v; i < n; ++i) {
        scanf("%d%d", &u, &v);
        G[u].PB(v); G[v].PB(u);
    }
    dfs(1,0,1);
    REP(i,1,n) {
        for(auto x : vec[i]) dfn[x] = ++DFN;
    }
    dfs2(1,0);
    REP(i,1,m) {
        scanf("%d%d", pt+i, dd+i);
        Draw(i,1);
    }
    REP(i,1,m) q.push( pair<int,int>(0,i) );
    Upd(1, 1, n, -1);
    int ans = n+10;
    REP(i,1,m) {
        while(dead[q.top().se] || cnt[q.top().se]!=q.top().fi) q.pop();
        ans = min(ans, q.top().fi);
        int x = q.top().se; q.pop();
        assert(1<=x && x<=m);
        dead[x] = 1;
        Draw(x,-1);
        Upd(1, 1, n, -1);
    }

    reBuild(1,1,n);
    while(!q.empty()) q.pop();
    REP(i,1,m) cnt[i] = 0, dead[i] = 0, Draw(i,1), q.push( pair<int,int>(0,i) );
    Upd(1, 1, n, -1);
    int gao = 0;
    REP(i,1,m) {
        while(gao < m) {
            while(dead[q.top().se] || cnt[q.top().se]!=q.top().fi) q.pop();
            if(q.top().fi < ans) break;
            q2.push( -q.top().se );
            dead[q.top().se] = 1;
            q.pop(); ++gao;
        }
        int x = -q2.top();
        q2.pop();
        Draw(x,-1);
        Upd(1, 1, n, -1);
    }

    printf("%d\n", ans);
    int tmp_sum = 0;
    REP(i,1,m) printf("%d\n", cnt[i]), tmp_sum += cnt[i];

    return 0;
}

J - 单身狗救星

题目链接

把题目中的公式化简后,问题就转化成:给出二维平面内$n$个点,再给$m$个询问点,对每个询问点输出$n$个点中和它形成斜率最大的那个点的编号。

据出题人说这个idea去年jls在牛客多校上出过,做法是搞两遍凸包。我们把所有$n+m$点按照先$x$坐标后$y$坐标的顺序排个序,然后按照新顺序扫一遍,同时维护一个动态凸包。当我们扫到点是前$n$个定点之一,就把这个点加进凸包里;如果我们扫到的点是后$m$个询问点之一,我们就在凸包上二分找一个使得斜率最大的定点,并更新这个询问点的答案。

做完之后,再将所有点关于原点翻转,之后再做一遍上述操作即可。总时间复杂度$O((n+m) \log n)$。

为什么这样做是对的呢?简单口胡一下,当扫到某个询问点时,所有在其左侧的定点(也就是$x$座标小于它的点)一定都被包含在下凸包或者就在下凸包上。现在我们找到了下凸包上的斜率最大点,假设凸包内有一点能使得斜率更大,那么画画图就会发现一定能在凸包上找到另一个点使得斜率比这个凸包内点还要大,这样就出现了矛盾。所以这样操作一定能找到每个询问点左侧使其斜率最大的点,而对于右侧的情况,沿原点翻转再算一遍就出来了。

// hls tql
#include <bits/stdc++.h>
#define ll long long

using namespace std;

struct node {
    ll x, y;
    int ty, id;
    node() {}
    node(ll _x, ll _y, int _ty, int _id)
        : x(_x), y(_y), ty(_ty), id(_id) {}
    void init(int _ty, int _id) {
        scanf("%lld%lld", &x, &y);
        ty = _ty;
        id = _id;
    }
    node operator-(const node &b) const {
        return node(x - b.x, y - b.y, 0, 0);
    }
    ll operator^(const node &b) const {
        return x * b.y - y * b.x;
    }
    bool operator<(const node &b) const {
        return x == b.x ? y < b.y : x < b.x;
    }
};

const int MAXN = 1e5 + 10;
node a[MAXN], b[MAXN], cc[MAXN << 1];
int ans[MAXN], que[MAXN], lim;

bool judge(node a, node b, node c) {
    ll tmp = (b - a) ^ (a - c);
    return tmp == 0 ? c.id < b.id : tmp > 0;
}

void solve() {
    int las = 0;
    for (int i = 1; i <= lim; i++) {
        if (cc[i].ty == 1) {
            while (las > 1 && ((cc[i] - cc[que[las]]) ^ (cc[que[las]] - cc[que[las - 1]])) > 0)
                las--;
            que[++las] = i;
        } else {
            if (!las)
                continue;
            int l = 1, r = las;
            while (l < r) {
                int mid = (l + r) >> 1;
                if (((cc[i] - cc[que[mid + 1]]) ^ (cc[que[mid + 1]] - cc[que[mid]])) > 0)
                    r = mid;
                else
                    l = mid + 1;
            }
            if (ans[cc[i].id] == -1 || judge(b[cc[i].id], a[ans[cc[i].id]], a[cc[que[l]].id]))
                ans[cc[i].id] = cc[que[l]].id;
        }
    }
}

int main() {
    int n, q;
    scanf("%d%d", &n, &q);
    lim = n + q;
    for (int i = 1; i <= lim; i++) {
        if (i <= n) {
            a[i].init(1, i);
            cc[i] = a[i];
        } else {
            b[i - n].init(2, i - n);
            cc[i] = b[i - n];
        }
    }
    memset(ans, -1, sizeof(ans));
    sort(cc + 1, cc + lim + 1);
    solve();
    reverse(cc + 1, cc + lim + 1);
    for (int i = 1; i <= lim; i++) {
        cc[i].x = -cc[i].x;
        cc[i].y = -cc[i].y;
    }
    solve();
    for (int i = 1; i <= q; i++)
        printf("%d\n", ans[i]);
}