#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat May  9 23:55:59 2020

@author: mario
"""

import connected_k_dominating_set
import lp_to_nx_graph
import networkx as nx
import matplotlib.pyplot as plt
import gurobipy as gp
import sys
import datetime
import math


def add_constraints(G, m, nodes, root):
    """Require the variable belonging to ``root`` to be at least 1.

    Args:
        G: Graph the model refers to. Not used by this function.
        m: Model object; its ``addConstr`` method receives the constraint.
        nodes: Indexable container mapping a vertex to its variable.
        root: Key of the vertex whose variable is bounded from below.
    """
    root_var = nodes[root]
    m.addConstr(root_var >= 1)

def add_path_constraints(G, m, nodes, root):
    """Add one root-path constraint per vertex of ``G``.

    For every vertex ``v`` the constraint is
    ``nodes[v] * len(P) <= quicksum(nodes)``, where ``P`` is the vertex list
    of a shortest path from ``root`` to ``v`` (so ``len(P)`` counts vertices,
    not edges).

    NOTE(review): ``gp.quicksum(nodes)`` iterates ``nodes`` directly. If
    ``nodes`` is dict-like this iterates its keys rather than its
    variables -- confirm against ``connected_k_dominating_set.model``.

    Args:
        G: networkx graph to take vertices and shortest paths from.
        m: Model receiving the constraints through ``addConstrs``.
        nodes: Container mapping a vertex to its variable.
        root: Start vertex of every shortest path.
    """
    shortest_path = nx.algorithms.shortest_path
    # Kept as a generator expression: that is the form addConstrs expects.
    m.addConstrs(
        nodes[target] * len(shortest_path(G, root, target)) <= gp.quicksum(nodes)
        for target in G.nodes
    )

def add_path_constraints2(G, m, nodes):
    """Add one pairwise path constraint per ordered vertex pair of ``G``.

    For every ordered pair ``(v, w)`` (including ``v == w``) the constraint
    is ``nodes[v] * nodes[w] * len(P) <= quicksum(nodes)``, where ``P`` is the
    vertex list of a shortest path from ``v`` to ``w``.

    NOTE(review): ``gp.quicksum(nodes)`` iterates ``nodes`` directly. If
    ``nodes`` is dict-like this iterates its keys rather than its
    variables -- confirm against ``connected_k_dominating_set.model``.

    Args:
        G: networkx graph to take vertices and shortest paths from.
        m: Model receiving the constraints through ``addConstrs``.
        nodes: Container mapping a vertex to its variable.
    """
    shortest_path = nx.algorithms.shortest_path
    # Kept as a generator expression: that is the form addConstrs expects.
    m.addConstrs(
        nodes[src] * nodes[dst] * len(shortest_path(G, src, dst)) <= gp.quicksum(nodes)
        for src in G.nodes
        for dst in G.nodes
    )

def add_path_constraints3(G, m, nodes, root):
    """Add a single aggregated root-path constraint.

    The left-hand side sums ``nodes[v] * len(P)`` over all vertices ``v``,
    where ``P`` is the vertex list of a shortest path from ``root`` to ``v``.
    The right-hand side is ``(S + 1) * S / 2`` with ``S = quicksum(nodes)``.

    NOTE(review): ``gp.quicksum(nodes)`` iterates ``nodes`` directly. If
    ``nodes`` is dict-like this iterates its keys rather than its
    variables -- confirm against ``connected_k_dominating_set.model``.

    Args:
        G: networkx graph to take vertices and shortest paths from.
        m: Model receiving the constraint through ``addConstr``.
        nodes: Container mapping a vertex to its variable.
        root: Start vertex of every shortest path.
    """
    shortest_path = nx.algorithms.shortest_path
    weighted_paths = gp.quicksum(
        nodes[target] * len(shortest_path(G, root, target)) for target in G.nodes
    )
    total = gp.quicksum(nodes)
    triangular_bound = (total + 1) * total / 2
    m.addConstr(weighted_paths <= triangular_bound)
    
def add_path_constraints4(G, m, nodes):
    """Add a single aggregated pairwise path constraint.

    The left-hand side sums ``nodes[v] * nodes[w] * len(P)`` over all ordered
    vertex pairs ``(v, w)``, where ``P`` is the vertex list of a shortest path
    from ``v`` to ``w``. The right-hand side is ``(S + 1) * S / 2`` with
    ``S = quicksum(nodes)``.

    NOTE(review): ``gp.quicksum(nodes)`` iterates ``nodes`` directly. If
    ``nodes`` is dict-like this iterates its keys rather than its
    variables -- confirm against ``connected_k_dominating_set.model``.

    Args:
        G: networkx graph to take vertices and shortest paths from.
        m: Model receiving the constraint through ``addConstr``.
        nodes: Container mapping a vertex to its variable.
    """
    shortest_path = nx.algorithms.shortest_path
    weighted_pairs = gp.quicksum(
        nodes[src] * nodes[dst] * len(shortest_path(G, src, dst))
        for src in G.nodes
        for dst in G.nodes
    )
    total = gp.quicksum(nodes)
    triangular_bound = (total + 1) * total / 2
    m.addConstr(weighted_pairs <= triangular_bound)

def _add_separator_constraints(G, m, nodes, sources):
    """Add one separator constraint per ordered non-adjacent pair ``(i, j)``.

    ``i`` ranges over ``sources`` and ``j`` over every vertex of ``G``. For
    each pair with ``i != j`` and no edge ``i``-``j`` the constraint is
    ``sum(nodes[s] for s in separator) >= nodes[i] + nodes[j] - 1``, i.e. the
    right-hand side only exceeds 0 when both endpoint variables are 1
    (assuming 0/1 variables -- confirm against the model builder).

    Args:
        G: networkx graph providing vertices and adjacency.
        m: Model receiving the constraints through ``addConstr``.
        nodes: Container mapping a vertex to its variable.
        sources: Iterable of the first endpoints ``i`` to consider.
    """
    for i in sources:
        for j in G.nodes:
            # has_edge is a direct adjacency lookup; the original scanned the
            # neighbors() iterator for every pair.
            if i == j or G.has_edge(i, j):
                continue
            # The separator comes from the project helper; the meaning of the
            # fourth argument ({i}) is defined there and not visible here.
            separator = connected_k_dominating_set.min_ij_separator(G, i, j, {i})
            m.addConstr(
                gp.quicksum(nodes[s] for s in separator) >= nodes[i] + nodes[j] - 1
            )


def add_vertex_separator_degree_constraints(G, m, nodes, max_degree=6):
    """Add separator constraints whose first endpoint has low degree.

    Only vertices ``i`` with ``G.degree[i] < max_degree`` are used as first
    endpoint; the second endpoint ranges over all vertices of ``G``.

    Args:
        G: networkx graph providing vertices, degrees and adjacency.
        m: Model receiving the constraints through ``addConstr``.
        nodes: Container mapping a vertex to its variable.
        max_degree: Exclusive upper bound on the degree of the first
            endpoint. Defaults to 6, the previously hard-coded threshold.
    """
    # Generator (not a list) so degrees are read lazily, in the same order as
    # the original nested loop.
    low_degree = (i for i in G.nodes if G.degree[i] < max_degree)
    _add_separator_constraints(G, m, nodes, low_degree)


def add_all_vertex_separator_constraaints(G, m, nodes):
    """Add separator constraints for every ordered non-adjacent vertex pair.

    The misspelled name is kept because existing callers may use it.

    Args:
        G: networkx graph providing vertices and adjacency.
        m: Model receiving the constraints through ``addConstr``.
        nodes: Container mapping a vertex to its variable.
    """
    _add_separator_constraints(G, m, nodes, G.nodes)


# Correctly spelled alias; the original (misspelled) name stays available.
add_all_vertex_separator_constraints = add_all_vertex_separator_constraaints
                    

def model(G, k, root):
    """Build the base ``"MINkRCDS"`` model and add the root constraint.

    The model and its variables come from
    ``connected_k_dominating_set.model(G, k, "MINkRCDS")``; afterwards
    ``add_constraints`` is applied for ``root``.

    Args:
        G: Graph the model is built for.
        k: Forwarded unchanged to the base model builder.
        root: Vertex passed to ``add_constraints``.

    Returns:
        The pair ``(m, nodes)`` returned by the base model builder, with the
        root constraint added to ``m``.
    """
    base_model, node_vars = connected_k_dominating_set.model(G, k, "MINkRCDS")
    add_constraints(G, base_model, node_vars, root)
    # Additional constraint families, currently disabled:
    # add_path_constraints(G, base_model, node_vars, root)
    # add_path_constraints2(G, base_model, node_vars)
    # add_path_constraints3(G, base_model, node_vars, root)
    # add_path_constraints4(G, base_model, node_vars)
    # add_vertex_separator_degree_constraints(G, base_model, node_vars)
    # add_all_vertex_separator_constraaints(G, base_model, node_vars)
    return base_model, node_vars

def solve(G, k, root, maxIterations):
    """Build the rooted model and hand it to the iterative solver.

    Args:
        G: Graph to solve on.
        k: Forwarded to both the model builder and the solver.
        root: Vertex forwarded to ``model``.
        maxIterations: Forwarded unchanged to
            ``connected_k_dominating_set.solve_iteratively``.

    Returns:
        Whatever ``connected_k_dominating_set.solve_iteratively`` returns;
        the script at the bottom of this file unpacks it as
        ``(ds, iterations)``.
    """
    rooted_model, node_vars = model(G, k, root)
    outcome = connected_k_dominating_set.solve_iteratively(
        G, k, rooted_model, node_vars, maxIterations
    )
    return outcome

if __name__ == '__main__':
    # Command line: GRAPH_FILE [k] [maxIterations]
    #   GRAPH_FILE     file read by lp_to_nx_graph.read
    #   k              optional, defaults to 1
    #   maxIterations  optional, defaults to unlimited (float("inf"))

    # Hand-built sample graphs kept for manual experiments:
    # G = nx.Graph()
    # G.add_nodes_from(range(16))
    # G.add_edges_from([(0,1), (0,2), (1,2), (1,3), (1,4), (1,7), (2,4), (2,5), (2,8), (3,6), (3,7), (3,4), (3,10), (4,7), (4,8), (4, 5), (4,11), (5,8), (5,9), (5,12),
    #                   (6,7), (6,10), (7,8), (7,10), (7,13), (7,11), (8,9), (8,11), (8,12), (8,14), (10,11), (10,13), (11,13), (11,14), (11,12), (13,14), (13,15), (14,15)])

    # G.add_edges_from([(0,1), (0,2), (1,3), (1,4), (2,4), (2,5), (3,6), (3,7), (4,7), (4,8), (5,8), (5,9),
    #                   (6,10), (7,10), (7,11), (8,11), (8,12), (9,12), (10,13), (11,13), (11,14), (12,14), (13,15), (14,15)])

    # maxIterations = float("inf")
    # maxIterations = 5

    # Exit with a usage hint instead of an IndexError traceback when the
    # mandatory graph file argument is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} GRAPH_FILE [k] [maxIterations]")

    G = lp_to_nx_graph.read(sys.argv[1])

    k = int(sys.argv[2]) if len(sys.argv) > 2 else 1
    maxIterations = int(sys.argv[3]) if len(sys.argv) > 3 else float("inf")

    # The root is fixed to vertex 0.
    # NOTE(review): assumes the graph that was read contains a vertex 0 -- confirm.
    starttime = datetime.datetime.now()
    ds, iterations = solve(G, k, 0, maxIterations)
    endtime = datetime.datetime.now()
    duration_sec = (endtime - starttime).total_seconds()
    print(f"iterations: {iterations}, duration(s): {duration_sec}")

    # Vertices contained in ds are drawn red, all others green.
    color_map = ['red' if i in ds else 'green' for i in G.nodes]
    # nx.draw(G, node_color = color_map, with_labels = True)
    nx.draw_kamada_kawai(G, node_color=color_map)
    plt.show()