[CF][图论][最短路树] codeforces 1076D Edge Deletion

这个题倒是不难,但是我当时对于题目的理解导致写出来很复杂,所以记一下踩的坑。其实就是最短路树的应用,我的想法是先删掉不在树上的边,然后对于树上的边先删除连叶子的边,以此类推。这样写倒是没有错也可以过。

但是!这个题题面里说的是最多k个,所以其实根本不用考虑去掉非树边什么的,假如k>n-1的话只输出n-1条树边也是对的!这就省了好多代码了。其次,根本不用使用拓扑排序那样的处理过程(就是先去掉叶子这种),因为想一下dijkstra的过程就可以发现,每次从优先队列中取出来一个节点后,这个节点的前驱最短路其实已经确定了,只要在这时候把其前驱记录下来,根据dijkstra的每一步往外扩展一步的操作,按照这个顺序搞下去最后的记录表中的顺序其实就是拓扑序。。。。依次输出即可。

只能说自己对算法理解还是太浅吧,而且对题目也没有认真考虑,写了又臭又长的代码emmmm

自己的代码:

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cstring>
#include <set>
#include <map>
#include <string>
#include<cmath>
#include<queue>

using namespace std;

typedef long long LL;
const int maxn = 6e5 + 5;
const int mod = 1e9 + 7;
const int INF = 0x3f3f3f3f;

LL dis[maxn];
struct Edge
{
    int u,v,id;LL w;
}edges[maxn*2];

struct node
{
    int id;LL dis;
    node(int id,LL dis):id(id),dis(dis){}
    bool operator<(const struct node& rhs)const{
        return dis>rhs.dis;
    }
};

vector<int > G[maxn];
priority_queue<node> Q;
int deg[maxn],pa[maxn],me,pe[maxn],fgs[maxn*2];//pe and fgs store id
set<int> S;

void addedge(int u,int v,int w,int id)
{
    me++;
    edges[me].u=u,edges[me].v=v,edges[me].w=w,edges[me].id=id;
    G[u].push_back(me);
    me++;
    edges[me].u=v,edges[me].v=u,edges[me].w=w,edges[me].id=id;
    G[v].push_back(me);
}

void sssp(int s,int n,int k)
{
    memset(dis,INF,sizeof(dis));
    dis[s]=0;
    Q.push(node(s,0));
    while(!Q.empty())
    {
        node cn=Q.top();
        Q.pop();
        if(cn.dis!=dis[cn.id])continue;
        int cd=cn.dis,u=cn.id,len=G[cn.id].size();
        for(int i=0;i<len;i++)
        {
            Edge & e=edges[G[u][i]];
            if(dis[e.v]>dis[u]+e.w)
            {
                if(pa[e.v])
                {
                    deg[pa[e.v]]--;
                    fgs[pe[e.v]]=0;
                    
                }
                deg[u]++;
                pa[e.v]=u;
                pe[e.v]=e.id;
                if(e.id==0)
                    while(1);
                fgs[e.id]=1;
                dis[e.v]=dis[u]+e.w;
                Q.push(node(e.v,dis[e.v]));
            }
        }
    }
    if(k>=n-1)//其实是不用考虑的
    {
        for(int i=1;i<=n;i++)
                if(i!=s)
                {
                    cout<<pe[i]<<" ";
                }
        k-=n-1;
        for(int i=1;(i<<1)<=me&&k>0;i++)
        {
            if(fgs[i]==0)
            {
                k--;
                cout<<i<<" ";
            }
        }
        return ;
    }

    for(int i=1;i<=n;i++)
    {
        if(deg[i]==0)
        {
            S.insert(i);
        }
    }
    int cc=n-1;
    while(!S.empty()&&k<cc)
    {
        cc--;
        int cur=(*S.begin());
        int pre=pa[cur],pid=pe[cur];
        fgs[pid]=0;
        deg[pre]--;
        S.erase(S.begin());
        if(deg[pre]==0)
        {
            S.insert(pre);
        }
    }
    for(int i=1;(i<<1)<=me;i++)
        {
            if(fgs[i]!=0)
                cout<<i<<" ";
        }
}

int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0);
    int n,m,k,u,v,w;
    cin>>n>>m>>k;
    for(int i=1;i<=m;i++)
    {
        cin>>u>>v>>w;
        addedge(u,v,w,i);
    }
    cout<<k<<"\n";//其实没必要输出k个
    sssp(1,n,k);
    return 0;
}

官方代码:

#include <bits/stdc++.h>

using namespace std;

vector<pair<int, pair<int, int>>> g[300043];

int main()
{
  int n, m, k;
  scanf("%d %d %d", &n, &m, &k);
  for (int i = 0; i < m; i++)
  {
    int x, y, w;
    scanf("%d %d %d", &x, &y, &w);
    --x;
    --y;
    g[x].push_back(make_pair(y, make_pair(w, i)));
    g[y].push_back(make_pair(x, make_pair(w, i)));
  }
  set<pair<long long, int>> q;
  vector<long long> d(n, (long long)(1e18));
  d[0] = 0;
  q.insert(make_pair(0, 0));
  vector<int> last(n, -1);
  int cnt = 0;
  vector<int> ans;
  while (!q.empty() && cnt < k)
  {
    auto z = *q.begin();
    q.erase(q.begin());
    int k = z.second;
    if (last[k] != -1)
    {
      cnt++;
      ans.push_back(last[k]);
    }
    for (auto y : g[k])
    {
      int to = y.first;
      int w = y.second.first;
      int idx = y.second.second;
      if (d[to] > d[k] + w)
      {
        q.erase(make_pair(d[to], to));
        d[to] = d[k] + w;
        last[to] = idx;
        q.insert(make_pair(d[to], to));
      }
    }
  }
  printf("%d\n", ans.size());
  for (auto x : ans)
    printf("%d ", x + 1);
}

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注

此站点使用Akismet来减少垃圾评论。了解我们如何处理您的评论数据