import torch
from torch import nn
import torch.optim as optim
import higher
from copy import deepcopy
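# Experiment: meta-learning with `higher` on a combined policy/critic/meta model.
# The inner loop takes differentiable SGD steps on the policy and critic while a
# grad_callback clips their gradients per sub-network (by parameter-index range);
# the outer loop backpropagates through the unrolled inner steps to update the
# meta network. The prints are intentional debug output.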


class Model(nn.Module):
    """A single bias-free linear layer whose output doubles as a toy loss."""

    def __init__(self, n_in, n_out):
        super().__init__()
        self.x = nn.Linear(n_in, n_out, bias=False)

    def forward(self, input):
        return self.x(input)


class ComBModel(nn.Module):
    """Combines a policy, a critic, and a meta network into one module."""

    def __init__(self, model_1, model_2, model_3):
        super().__init__()
        # Record cumulative parameter-count boundaries [start, end] for each
        # sub-network by counting self.parameters() before and after attaching it.
        self.policy_params = [sum(p.numel() for p in self.parameters() if p.requires_grad), 0]
        self.policy = model_1
        self.policy_params[1] = sum(p.numel() for p in self.parameters() if p.requires_grad)

        self.critic_params = [sum(p.numel() for p in self.parameters() if p.requires_grad), 0]
        self.critic = model_2
        self.critic_params[1] = sum(p.numel() for p in self.parameters() if p.requires_grad)

        self.meta_network = model_3

        # One optimizer over both sub-networks; param groups keep the ordering
        # explicit (policy parameters first, then critic parameters).
        self.optimizer = optim.SGD([{'params': self.policy.parameters(), 'lr': 1.0},
                                    {'params': self.critic.parameters(), 'lr': 1.0}], lr=1.0)

        self.meta_optimizer = optim.SGD(self.meta_network.parameters(), lr=1.0)
        #self.optimizer1 = optim.SGD(self.policy.parameters(), lr=1.0)
        #self.optimizer2 = optim.SGD(self.critic.parameters(), lr=1.0)

    def forward(self, input):
        # The meta network's output scales the policy loss, which is how the
        # meta parameters influence the inner-loop policy updates.
        policy_loss = self.policy(input)
        critic_loss = self.critic(input)
        meta_loss = self.meta_network(input)
        return meta_loss * policy_loss, critic_loss

    def get_loss(self, input):
        policy_loss = self.policy(input)
        critic_loss = self.critic(input)
        return policy_loss, critic_loss

    def forward_only_1(self, input):
        return self.policy(input)


x = Model(2, 1)
y = Model(2, 1)
meta_model = Model(2, 1)
z = ComBModel(x, y, meta_model)
z_copy = deepcopy(z)  # untouched copy for the baseline comparison at the bottom


def clip_grad_callback_old(grads, max_norm=1):
    # Global-norm clipping over the whole gradient tuple: the same formula as
    # torch.nn.utils.clip_grad_norm_, but returning scaled gradients instead
    # of mutating parameters in place.
    total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2.0) for g in grads]), 2.0)
    print("GRADS", grads)
    print("total norm", total_norm)
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    new_grads = [g * clip_coef_clamped for g in grads]
    print("Clipped grads", new_grads)
    return tuple(new_grads)
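

# A minimal sanity check (a sketch; the tensor values are arbitrary): the
# callback above should agree with torch.nn.utils.clip_grad_norm_ applied to
# parameters carrying the same gradients.
def _check_old_callback(max_norm=1):
    params = [nn.Parameter(torch.zeros(2)), nn.Parameter(torch.zeros(3))]
    params[0].grad = torch.tensor([3.0, 4.0])
    params[1].grad = torch.tensor([1.0, 2.0, 2.0])
    clipped = clip_grad_callback_old(tuple(p.grad.clone() for p in params), max_norm)
    nn.utils.clip_grad_norm_(params, max_norm)  # clips p.grad in place
    for c, p in zip(clipped, params):
        assert torch.allclose(c, p.grad, atol=1e-4), (c, p.grad)
#_check_old_callback()  # uncomment to verify against the PyTorch implementation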


def clip_grad_callback(grads, max_norm=10):
    # Clip policy and critic gradients to separate global norms, using the
    # cumulative parameter-count ranges recorded in ComBModel. Any gradient
    # falling outside both ranges is passed through unclipped.
    policy_se = z.policy_params
    value_se = z.critic_params
    sorted_grads = {"policy_grads": [], "value_grads": []}
    parameter_count = 0

    # Zero out NaN entries and drop missing grads. (Caveat: dropping a None
    # would shift the index ranges below; here every parameter gets a grad.)
    grads_cleaned = []
    for g in grads:
        if g is not None:
            g[g != g] = 0.0  # NaN != NaN, so this zeroes NaNs in place
            grads_cleaned.append(g)
    grads = grads_cleaned

    # First pass: bucket each gradient by the parameter range it falls into.
    for g in grads:
        num_parameters = g.view(-1).shape[0]
        new_param_count = parameter_count + num_parameters
        if parameter_count >= policy_se[0] and new_param_count <= policy_se[1]:
            sorted_grads["policy_grads"].append(g)
        elif parameter_count >= value_se[0] and new_param_count <= value_se[1]:
            sorted_grads["value_grads"].append(g)
        parameter_count = new_param_count

    # Per-group clipping factor, as in torch.nn.utils.clip_grad_norm_.
    sorted_factors = dict()
    for key in sorted_grads:
        total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2.0) for g in sorted_grads[key]]), 2.0)
        clip_coef = max_norm / (total_norm + 1e-6)
        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
        sorted_factors[key] = clip_coef_clamped

    # Second pass: scale each gradient by its group's factor.
    new_grads = []
    parameter_count = 0
    for g in grads:
        num_parameters = g.view(-1).shape[0]
        new_param_count = parameter_count + num_parameters
        if parameter_count >= policy_se[0] and new_param_count <= policy_se[1]:
            new_grads.append(g * sorted_factors['policy_grads'])
        elif parameter_count >= value_se[0] and new_param_count <= value_se[1]:
            new_grads.append(g * sorted_factors['value_grads'])
        else:
            new_grads.append(g)
        parameter_count = new_param_count
    print("Old grads", grads)
    print("Clip values", sorted_factors)
    print("Clipped grads", new_grads)
    return tuple(new_grads)
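
# The range bookkeeping above assumes higher hands the callback gradients in
# the same order as the optimizer's param groups (policy first, then critic).
# With bias-free Linear(2, 1) sub-models the recorded boundaries come out as:
print("policy param range", z.policy_params)  # [0, 2]
print("critic param range", z.critic_params)  # [2, 4]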


print("policy parameters", list(x.parameters()))
print("critic parameters", list(y.parameters()))
print("meta parameters", list(meta_model.parameters()))


in_ = torch.tensor([1.0, 1.0])

for k in range(3):
    print("outer loop", k)
    # copy_initial_weights=False keeps the functional copy's initial weights
    # tied to z's parameters, so meta-gradients can flow back through them
    with higher.innerloop_ctx(z, z.optimizer, copy_initial_weights=False) as (fnet, diffopt):

        for i in range(2):
            print("INNER LOOP", i)
            # calling get_loss did not work here, even though it computes the
            # same losses as forward
            #policy_loss, critic_loss = fnet.get_loss(in_)

            # going through the patched forward works
            policy_loss, critic_loss = fnet(in_)
            # diffopt.step takes a differentiable step; grad_callback receives
            # the flat tuple of gradients for the optimizer's parameters, in
            # param-group order (policy weight, then critic weight)
            diffopt.step(policy_loss + critic_loss, grad_callback=clip_grad_callback)

            print("fnet params", fnet.parameters())

            for p in fnet.policy.parameters():
                print("POLICY PARAMS", p)
            for p in fnet.critic.parameters():
                print("CRITIC PARAMS", p)
            for p in z.meta_network.parameters():
                print("META PARAMS", p)

        # outer (meta) step: backpropagate the final policy loss through the
        # unrolled inner updates into the meta network's parameters
        policy_loss, _ = fnet(in_)
        z.meta_optimizer.zero_grad()
        policy_loss.backward()
        z.meta_optimizer.step()

        # copy the inner-loop result back into z so the next outer iteration
        # starts from the updated policy/critic; the meta step above only
        # updates meta_network
        z.policy.load_state_dict(fnet.policy.state_dict())
        z.critic.load_state_dict(fnet.critic.state_dict())

        #for p in z.policy.parameters():
        #    print("POLICY PARAMS", p)

# Baseline run without the higher context, kept for comparison (disabled):
'''

print("NOW WITHOUT CONTEXT")

print("policy parameters", list(z_copy.policy.parameters()))


for k in range(3):
    print("outer loop", k)

    for i in range(2):
        print("INNER LOOP", i)
        loss = z_copy.policy(in_)
        z_copy.optimizer.zero_grad()
        loss.backward()
        z_copy.optimizer.step()

        for p in z_copy.policy.parameters():
            print("Z COPY POLICY PARAMS", p)
            print("Z COPY POLICY GRADS", p.grad)

    #z_copy.policy.load_state_dict(fnet.policy.state_dict())

    for p in z_copy.policy.parameters():
        print("Z COPY POLICY PARAMS", p)

'''