原题链接

题目描述

给出了一个由 nn 个节点组成的网络,用 n×nn × n 个邻接矩阵图 graph 表示。在节点网络中,当 graph[i][j] = 1 时,表示节点 ii 能够直接连接到另一个节点 jj

一些节点 initial 最初被恶意软件感染。只要两个节点直接连接,且其中至少一个节点受到恶意软件的感染,那么两个节点都将被恶意软件感染。这种恶意软件的传播将继续,直到没有更多的节点可以被这种方式感染。

假设 M(initial) 是在恶意软件停止传播之后,整个网络中感染恶意软件的最终节点数。

如果从 initial移除某一节点能够最小化 M(initial), 返回该节点。如果有多个节点满足条件,就返回索引最小的节点。

请注意,如果某个节点已从受感染节点的列表 initial 中删除,它以后仍有可能因恶意软件传播而受到感染。

示例 1:

输入:graph = [[1,1,0],[1,1,0],[0,0,1]], initial = [0,1]
输出:0

示例 2:

输入:graph = [[1,0,0],[0,1,0],[0,0,1]], initial = [0,2]
输出:0

示例 3:

输入:graph = [[1,1,1],[1,1,1],[1,1,1]], initial = [1,2]
输出:1

提示:

  • n==graph.lengthn == graph.length
  • n==graph[i].lengthn == graph[i].length
  • 2n3002 \le n \le 300
  • graph[i][j]==0graph[i][j] == 0graph[i][j]==1graph[i][j] == 1.
  • graph[i][j]==graph[j][i]graph[i][j] == graph[j][i]
  • graph[i][i]==1graph[i][i] == 1
  • 1initial.lengthn1 \le initial.length \le n
  • 0initial[i]n10 \le initial[i] \le n - 1
  • initial 中所有整数均不重复

并查集

给了一个图,图中有许多连通块,每个连通块中有 00 个或多个被感染的节点,现在可以对某个节点杀毒,杀毒完成后,要求最终未被感染的节点最多,如果有多个符合要求的节点,返回下标最小的那个节点。

可以使用并查集来维护这些连通块,连通块中被感染节点数量如下:

  1. 00 个:不需要管,不会扩撒。

  2. 11 个:杀毒后,这个连通块将不会被感染。

  3. 22 个及以上:即使杀毒了,其他被感染的节点最终也会感染这个连通块。

综上,找到所有符合第 22 种情况的连通块,且该连通块的节点数量最多。若有多个符合情况,返回节点索引最小的即可。

const int N = 305;

class Solution {
private:
    int p[N], cnt[N], sum[N]; // p并查集,cnt连通块中被感染的数量,sum连通块节点数量

    int find(int x)
    {
        return p[x] == x ? p[x] : p[x] = find(p[x]);
    }

public:
    int minMalwareSpread(vector<vector<int>>& graph, vector<int>& initial) {
        int n = graph.size();
        // 初始化并查集
        for ( int i = 0; i < n; i ++ ) 
        {
            p[i] = i;
            sum[i] = 1;
            cnt[i] = 0;
        }
        // 连通各个节点
        for ( int i = 0; i < n; i ++ )
            for ( int j = 0; j < n; j ++ )
                // i, j连通,且并查集中未连通
                if ( graph[i][j] && find(i) != find(j) )
                {
                    sum[find(i)] += sum[find(j)];
                    p[find(j)] = find(i);
                }
        // 初始化各个块被感染的数量
        for ( auto x : initial ) cnt[find(x)] ++;

        int res = INT_MAX, tmp_max = -1;

        for ( auto x : initial )
        {
            // 当没有第二种情况出现时
            if ( tmp_max == -1 ) 
                res = min(x, res);
            if ( cnt[find(x)] == 1 )
            {
                if ( sum[find(x)] > tmp_max ) 
                {
                    tmp_max = sum[find(x)];
                    res = x;
                }
                else if ( sum[find(x)] == tmp_max )
                    res = min(res, x);
            }
        }
        return res;
    }
};