#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat May  9 23:55:40 2020

@author: mario
"""

import k_transitive_closure 
import k_dominating_set
import networkx as nx
import matplotlib.pyplot as plt
import gurobipy as gp
from gurobipy import GRB

def solve(G, k, maxIterations=float("inf")):
    """Solve MINkCDS (minimum connected k-dominating set) on ``G``.

    Builds the model via ``model`` (k-dominating-set formulation plus the base
    connectivity constraints) and then tightens it with separator cuts in
    ``solve_iteratively`` until the selected vertices induce a connected
    subgraph or the iteration limit is reached.

    Args:
        G: networkx graph.
        k: domination parameter forwarded to ``k_dominating_set.model``.
        maxIterations: maximum number of separation rounds; unlimited by
            default.

    Returns:
        ``(ds, iterations)`` as returned by ``solve_iteratively``: the set of
        selected vertices and the number of separation rounds performed.
    """
    m, nodes = model(G, k, "MINkCDS")
    # ``model`` already installs the base connectivity constraints, so they
    # are not added a second time here.
    return solve_iteratively(G, k, m, nodes, maxIterations)

def min_ij_separator(G, i, j, C_i):
    """Return a minimal vertex separator between component ``C_i`` and ``j``.

    The separator is ``N(C_i) ∩ N(R_j)``, where ``N(C_i)`` is the open
    neighbourhood of ``C_i`` (vertices adjacent to ``C_i`` but not in it) and
    ``R_j`` is the set of vertices reachable from ``j`` in ``G - N(C_i)``.
    Every path from ``C_i`` to ``j`` passes through the returned set, and when
    ``C_i`` is connected no proper subset has that property.

    Args:
        G: graph; only ``G[v]`` (iteration over the neighbours of ``v``) is
            used, so a networkx graph or a dict of neighbour sets both work.
        i: a vertex of ``C_i``; not used by the computation (``C_i`` carries
            the information) but kept for interface compatibility.
        j: target vertex; expected to lie outside ``C_i`` and ``N(C_i)``.
        C_i: collection of vertices (the component containing ``i``); it is
            not modified.

    Returns:
        set of separator vertices; empty if ``j`` cannot reach ``C_i``.
    """
    component = set(C_i)
    # Open neighbourhood N(C_i): vertices adjacent to C_i that are not in it.
    boundary = {v for c in component for v in G[c]} - component

    # R_j: vertices reachable from j without stepping onto N(C_i).  The search
    # must not pass *through* boundary vertices, otherwise boundary vertices
    # that are only reachable via other boundary vertices would be included
    # and the separator would no longer be minimal.
    reached = {j}
    stack = [j]
    while stack:
        u = stack.pop()
        for w in G[u]:
            if w not in reached and w not in boundary:
                reached.add(w)
                stack.append(w)

    # Keep only the boundary vertices that actually touch R_j.
    return {v for v in boundary if any(w in reached for w in G[v])}
    
def model(G, k, name):
    """Create the MINkCDS model for ``G``.

    Starts from the model returned by ``k_dominating_set.model`` and attaches
    this module's connectivity constraints to it via ``add_constraints``.

    Args:
        G: networkx graph.
        k: domination parameter forwarded to ``k_dominating_set.model``.
        name: model name forwarded to ``k_dominating_set.model``.

    Returns:
        ``(m, nodes)`` exactly as produced by ``k_dominating_set.model``;
        ``nodes`` is indexed by the vertices of ``G``.
    """
    base_model, vertex_vars = k_dominating_set.model(G, k, name)
    add_constraints(G, base_model, vertex_vars)
    return base_model, vertex_vars

def add_base_connectivity_constraint(G, m, nodes):
    """Require every selected vertex to have at least one selected neighbour.

    Adds, for each vertex ``v`` of ``G``, the constraint
    ``nodes[v] <= sum(nodes[w] for w in G.neighbors(v))``.

    Consequences of this constraint:
    - A solution consisting of a single selected vertex is infeasible.
    - An isolated vertex is forced to 0.

    NOTE(review): this may exclude the true optimum when one vertex alone
    would suffice. Confirm against the domination constraints in
    ``k_dominating_set``.
    """
    def neighbour_sum(vertex):
        return gp.quicksum(nodes[w] for w in G.neighbors(vertex))

    m.addConstrs(nodes[v] <= neighbour_sum(v) for v in G.nodes)

def add_constraints(G, m, nodes):
    """Attach the connectivity constraints of the MINkCDS model to ``m``.

    At present the only constraint family is the one added by
    ``add_base_connectivity_constraint``.
    """
    for add_family in (add_base_connectivity_constraint,):
        add_family(G, m, nodes)
 
def solve_iteratively(G, k, m, nodes, maxIterations=float("inf")):
    """Optimize ``m``, adding separator cuts until the solution is connected.

    After each optimization the selected vertices are collected. If they
    induce a disconnected subgraph, the following cut is added for every pair
    of components ``C_i``, ``C_j`` with representatives ``h`` and ``l``:
    ``sum(nodes[s] for s in S) >= nodes[h] + nodes[l] - 1``. Here ``S`` is
    the separator returned by ``min_ij_separator``. The model is then
    re-optimized.

    Args:
        G: networkx graph.
        k: not used here; kept for interface compatibility.
        m: gurobipy model; cuts are added to it in place.
        nodes: per-vertex decision variables indexed by the vertices of ``G``.
        maxIterations: maximum number of separation rounds (default
            unlimited).

    Returns:
        ``(ds, iterations)``: the selected vertices of the last solution and
        the number of separation rounds performed. ``ds`` may still be
        disconnected if the iteration limit was reached.

    Raises:
        RuntimeError: if an optimization run ends without any solution.
    """
    iterations = 0
    ds = _optimize_and_select(G, m, nodes)

    while iterations < maxIterations and not _induces_connected(G, ds):
        iterations += 1
        components = list(nx.connected_components(G.subgraph(ds)))
        for idx, C_i in enumerate(components):
            rep_i = next(iter(C_i))
            for C_j in components[idx + 1:]:
                rep_j = next(iter(C_j))
                separator = min_ij_separator(G, rep_i, rep_j, C_i)
                m.addConstr(
                    gp.quicksum(nodes[s] for s in separator)
                    >= nodes[rep_i] + nodes[rep_j] - 1
                )
        ds = _optimize_and_select(G, m, nodes)

    return ds, iterations


def _optimize_and_select(G, m, nodes):
    """Optimize ``m`` and return the vertices whose variable is set to 1."""
    m.optimize()
    if m.SolCount == 0:
        raise RuntimeError(
            f"optimization ended without a solution (status {m.Status})"
        )
    # Read the variables by vertex (not by position in m.getVars()), and use
    # a tolerance because binary values can come back as e.g. 0.9999999.
    return {v for v in G.nodes if nodes[v].X > 0.5}


def _induces_connected(G, vertices):
    """Return True if ``vertices`` induce a connected subgraph of ``G``.

    An empty vertex set is treated as connected (nothing to connect);
    ``nx.is_connected`` would raise on the null graph.
    """
    return not vertices or nx.is_connected(G.subgraph(vertices))