[线段树+][LCA][dfs序][二分查找][×] codeforces 1062E

http://codeforces.com/contest/1062/problem/E

这题两个难点,第一个是如何快速找到一堆不一定在一起的节点的公共LCA,第二个是如何确定这一堆节点中需要剔除哪一个。

可以先想一下,设原集合的lca为p1,可以取得最优答案的点为x,其他点为vi,那么x一定不在lca(所有的vi)这个点的子树上。这时候,任意一个vi和x的lca都是p1。这时假如维护一个S[l,r]表示[l,r]的公共LCA,根据上面说的性质,这个S序列一定存在一个断点,从这个点开始所有的S==p1,这样的话就可以利用二分查找找到这个断点,可以确定这个断点一定是候选答案之一。另外一个在二分的过程中其实是所有的lca都是基于最左边那个点积累过来的,所以除了分界点之外另外一个可能影响到答案的就是最左边这个点。到这里就算是找到了唯二可能影响答案的点,只需要对他们分别验证就可以了。从dalao的代码里学到了lca也可以作为线段树维护的变量。。。

贴上dalao的代码,我写的太丑了

#include <bits/stdc++.h>
#define pb push_back
#define lc (id << 1)
#define rc (id << 1 ^ 1)
#define md (l + r >> 1)
using namespace std;
const int N = 1e5 + 10, LG = 18;
int n, q, S[N * 4], H[N], P[N][LG];
vector<int> Adj[N];
void DFS(int v)
{
    for (int i = 1; i < LG; i++)
        P[v][i] = P[P[v][i - 1]][i - 1];
    for (int &u : Adj[v])
        H[u] = H[v] + 1, DFS(u);
}
inline int LCA(int v, int u)
{
    if (H[v] < H[u])
        return (LCA(u, v));
    for (int i = 0; i < LG; i++)
        if ((H[v] - H[u]) & (1 << i))
            v = P[v][i];
    if (v == u)
        return (v);
    for (int i = LG - 1; ~i; i--)
        if (P[v][i] != P[u][i])
            v = P[v][i], u = P[u][i];
    return (P[v][0]);
}
void Build(int id = 1, int l = 1, int r = n + 1)
{
    if (r - l < 2)
    {
        S[id] = l;
        return;
    }
    Build(lc, l, md);
    Build(rc, md, r);
    S[id] = LCA(S[lc], S[rc]);
}
int Get(int le, int ri, int id = 1, int l = 1, int r = n + 1)
{
    if (le <= l && r <= ri)
        return (S[id]);
    if (ri <= md)
        return (Get(le, ri, lc, l, md));
    if (md <= le)
        return (Get(le, ri, rc, md, r));
    return (LCA(Get(le, ri, lc, l, md), Get(le, ri, rc, md, r)));
}
int main()
{
    scanf("%d%d", &n, &q);
    for (int i = 2; i <= n; i++)
        scanf("%d", &P[i][0]), Adj[P[i][0]].pb(i);
    DFS(1);
    Build();
    for (; q; q--)
    {
        int l, r, lca;
        scanf("%d%d", &l, &r);
        lca = Get(l, r + 1);
        int le = l, ri = r, mid;
        while (ri - le > 1)
        {
            mid = (le + ri) >> 1;
            if (Get(l, mid + 1) == lca)
                ri = mid;
            else
                le = mid;
        }
        int lca1 = Get(l + 1, r + 1);
        int lca2;
        if (ri != r)
            lca2 = LCA(Get(l, ri), Get(ri + 1, r + 1));
        else
            lca2 = Get(l, ri);
        if (H[lca1] < H[lca2])
            printf("%d %d\n", ri, H[lca2]);
        else
            printf("%d %d\n", l, H[lca1]);
    }
    return (0);
}

另一种思路是官方题解的那种,利用dfs的时间戳:在[l,r]中找到最大和最小的时间戳,这俩节点就是决定lca的点,具体证明在官方tutorial里有,不过其实直觉上挺好理解的。。按照dfs序画画就明白了。

底下这个巨丑的代码就是我写的

#include<bits/stdc++.h>

using namespace std;
const int maxn=1e5+15;
const int INF=0x3f3f3f3f;
const int mod=1e9+7;

typedef long long LL;

vector<int> g[maxn];
int fa[maxn][18],stamp[maxn],ct,dep[maxn];

void dfs(int u,int p)
{
    stamp[u]=++ct;
    int len=g[u].size();
    for(int i=0;i<len;i++)
    {
        if(g[u][i]==p)continue;
        fa[g[u][i]][0]=u;dep[g[u][i]]=dep[u]+1;
        dfs(g[u][i],u);
    }
}

struct node
{
    int v,id;
    node(){v=id=INF;}
    node(int v,int id):v(v),id(id){}
    bool operator<(const struct node &rhs)const
    {
        return v<rhs.v;
    }
};

node minx[maxn<<2],maxx[maxn<<2];
void build(int o,int l,int r)
{
    if(l==r)
    {
        minx[o].v=stamp[l],minx[o].id=l;
        maxx[o].v=stamp[l],maxx[o].id=l;
    }
    else
    {
        int m=(l+r)>>1;
        build(o<<1,l,m);build(o<<1|1,m+1,r);
        minx[o]=min(minx[o<<1],minx[o<<1|1]),maxx[o]=max(maxx[o<<1],maxx[o<<1|1]);
    }
}
int minxid,maxxid;
int ql,qr;
node query1(int o,int l,int r)
{
    if(ql<=l&&r<=qr)
        return minx[o];
    else
    {
        int m=(l+r)>>1;
        node tmp;
        if(ql<=m)
            tmp=query1(o<<1,l,m);
        if(m<qr)
            tmp=min(tmp,query1(o<<1|1,m+1,r));
        return tmp;
    }
}
node query2(int o,int l,int r)
{
    if(ql<=l&&r<=qr)
        return maxx[o];
    else
    {
        int m=(l+r)>>1;
        node tmp(0,0);
        if(ql<=m)
            tmp=query2(o<<1,l,m);
        if(m<qr)
            tmp=max(tmp,query2(o<<1|1,m+1,r));
        return tmp;
    }
}
void initLca(int n)
{
    for(int i=1;i<=17;i++)
        for(int j=1;j<=n;j++)
            fa[j][i]=fa[fa[j][i-1]][i-1];
}
int lca(int x,int y)
{
    if(dep[x]>dep[y])swap(x,y);
    int d=dep[y]-dep[x];
    for(int i=0;(1<<i)<=d;i++)
        if((1<<i)&d)y=fa[y][i];
    if(x==y)return x;
    else
    {
        for(int i=17;i>=0;i--)
        if(fa[x][i]!=fa[y][i])
            x=fa[x][i],y=fa[y][i];
        return fa[x][0];
    }
}

int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    int n,q,u,v;
    cin>>n>>q;
    for(int i=2;i<=n;i++)
    {
        cin>>v;
        g[i].push_back(v),g[v].push_back(i);
    }
    dep[1]=0;
    dfs(1,0);
    initLca(n);
    build(1,1,n);
    node t1,t2;
    for(int i=1;i<=q;i++)
    {
        cin>>ql>>qr;
        t1=query1(1,1,n);t2=query2(1,1,n);
        int temp=0,lca1,lca2,alca,tql=ql,tqr=qr,tans;
        u=t1.id,v=t2.id;
        qr=u-1;
        if(qr<ql)
            lca1=qr+2;
        else
        {
            t1=query1(1,1,n);t2=query2(1,1,n);
            lca1=lca(t1.id,t2.id);
        }
        ql=u+1,qr=tqr;
        if(qr<ql)
        {
            lca2=ql-2;
        }
        else
        {
            t1=query1(1,1,n);t2=query2(1,1,n);
            lca2=lca(t1.id,t2.id);
        }
        
        tans=lca(lca1,lca2);
        
        ql=tql,qr=v-1;
        
        if(qr<ql)
            lca1=qr+2;
        else
        {
            t1=query1(1,1,n);t2=query2(1,1,n);
            lca1=lca(t1.id,t2.id);
        }
        ql=v+1,qr=tqr;
        if(qr<ql)
            lca2=ql-2;
        else
        {
            t1=query1(1,1,n);t2=query2(1,1,n);
            lca2=lca(t1.id,t2.id);
        }
        alca=lca(lca1,lca2);
        if(dep[alca]>dep[tans])
            cout<<v<<" "<<dep[alca]<<"\n";
        else
            cout<<u<<" "<<dep[tans]<<"\n";
    }
    return 0;
}

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

此站点使用 Akismet 来减少垃圾评论。了解我们如何处理您的评论数据