import torch
from torch import nn
import torch.optim as optim
from torch.autograd import grad
import higher
from copy import deepcopy


class Model(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.x = nn.Linear(n_in, n_out, bias=False)

    def forward(self, input):
        return self.x(input)


class ComBModel(nn.Module):
    def __init__(self, model_1, model_2, model_3):
        super().__init__()
        # Record (start, end) parameter counts so the grad callback can later
        # tell policy gradients and critic gradients apart by parameter index.
        self.policy_params = [sum(p.numel() for p in self.parameters() if p.requires_grad), 0]
        self.policy = model_1
        self.policy_params[1] = sum(p.numel() for p in self.parameters() if p.requires_grad)
        self.critic_params = [sum(p.numel() for p in self.parameters() if p.requires_grad), 0]
        self.critic = model_2
        self.critic_params[1] = sum(p.numel() for p in self.parameters() if p.requires_grad)
        self.meta_network = model_3
        # One inner-loop optimizer with separate parameter groups for policy and critic.
        #self.optimizer = optim.SGD(list(self.policy.parameters()) + list(self.critic.parameters()), lr=1.0)
        self.optimizer = optim.SGD([{'params': self.policy.parameters(), 'lr': 1.0},
                                    {'params': self.critic.parameters(), 'lr': 1.0}], lr=1.0)
        self.meta_optimizer = optim.SGD(self.meta_network.parameters(), lr=1.0)
        #self.optimizer1 = optim.SGD(self.policy.parameters(), lr=1.0)
        #self.optimizer2 = optim.SGD(self.critic.parameters(), lr=1.0)

    def forward(self, input):
        policy_loss = self.policy(input)
        critic_loss = self.critic(input)
        meta_loss = self.meta_network(input)
        return meta_loss * policy_loss, critic_loss

    def get_loss(self, input):
        policy_loss = self.policy(input)
        critic_loss = self.critic(input)
        return policy_loss, critic_loss

    def forward_only_1(self, input):
        return self.policy(input)


x = Model(2, 1)
y = Model(2, 1)
meta_model = Model(2, 1)
z = ComBModel(x, y, meta_model)
z_copy = deepcopy(z)


def clip_grad_callback_old(grads, max_norm=1):
    # Clip the global norm over all gradients, analogous to torch.nn.utils.clip_grad_norm_.
    total_norm = torch.norm(torch.stack([torch.norm(grad.detach(), 2.0) for grad in grads]), 2.0)
    print("GRADS", grads)
    print("total norm", total_norm)
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    new_grads = []
    for grad in grads:
        #p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device))
        new_grads.append(grad * clip_coef_clamped)
    print("Clipped grads", new_grads)
    return tuple(new_grads)


def clip_grad_callback(grads, max_norm=10):
    # Clip policy and critic gradients separately, using the parameter-count
    # ranges recorded in ComBModel.__init__ to assign each gradient to a group.
    policy_se = z.policy_params
    value_se = z.critic_params
    sorted_grads = {"policy_grads": [], "value_grads": []}
    parameter_count = 0

    # Drop None entries and replace NaNs with zero.
    grads_cleaned = []
    for grad in grads:
        if grad is not None:
            grad[grad != grad] = 0.0
            grads_cleaned.append(grad)
    grads = grads_cleaned

    # Sort gradients into policy and critic buckets by parameter index.
    for grad in grads:
        num_parameters = grad.view(-1).shape[0]
        new_param_count = parameter_count + num_parameters
        if parameter_count >= policy_se[0] and new_param_count <= policy_se[1]:
            sorted_grads["policy_grads"].append(grad)
        elif parameter_count >= value_se[0] and new_param_count <= value_se[1]:
            sorted_grads["value_grads"].append(grad)
        parameter_count = new_param_count

    # Compute one clipping coefficient per group.
    sorted_factors = dict()
    for key in sorted_grads:
        total_norm = torch.norm(torch.stack([torch.norm(grad.detach(), 2.0) for grad in sorted_grads[key]]), 2.0)
        clip_coef = max_norm / (total_norm + 1e-6)
        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
        sorted_factors[key] = clip_coef_clamped

    # Scale every gradient by its group's coefficient.
    new_grads = []
    parameter_count = 0
    for grad in grads:
        num_parameters = grad.view(-1).shape[0]
        new_param_count = parameter_count + num_parameters
        if parameter_count >= policy_se[0] and new_param_count <= policy_se[1]:
            new_grads.append(grad * sorted_factors['policy_grads'])
        elif parameter_count >= value_se[0] and new_param_count <= value_se[1]:
            new_grads.append(grad * sorted_factors['value_grads'])
        else:
            new_grads.append(grad)
        parameter_count = new_param_count

    print("Old grads", grads)
    print("Clip values", sorted_factors)
    print("Clipped grads", new_grads)
    return tuple(new_grads)


print("policy parameters", list(x.parameters()))
print("critic parameters", list(y.parameters()))
print("meta parameters", list(meta_model.parameters()))

in_ = torch.Tensor([1, 1])
for k in range(3):
    print("outer loop", k)
    with higher.innerloop_ctx(z, z.optimizer, copy_initial_weights=False) as (fnet, diffopt):
        for i in range(2):
            print("INNER LOOP", i)
            # using get_loss does not work even though it is the same method
            #policy_loss, critic_loss = fnet.get_loss(in_)
            # using the forward method works
            policy_loss, critic_loss = fnet(in_)
            diffopt.step(policy_loss + critic_loss, grad_callback=clip_grad_callback)
            print("fnet params", fnet.parameters())
            for p in fnet.policy.parameters():
                print("POLICY PARAMS", p)
            for p in fnet.critic.parameters():
                print("CRITIC PARAMS", p)
            for p in z.meta_network.parameters():
                print("META PARAMS", p)

        # Outer (meta) update: backprop through the differentiable inner-loop
        # updates and step the meta network.
        policy_loss, _ = fnet(in_)
        z.meta_optimizer.zero_grad()
        policy_loss.backward()
        z.meta_optimizer.step()

        # Copy the adapted inner-loop weights back into the base model.
        z.policy.load_state_dict(fnet.policy.state_dict())
        z.critic.load_state_dict(fnet.critic.state_dict())
        #for p in z.model1.parameters():
        #    print("MODEL 1 PARAMS", p)

'''
print("NOW WITHOUT CONTEXT")
print("model 1 parameters", list(z_copy.model1.parameters()))
for k in range(3):
    print("outer loop", k)
    for i in range(2):
        print("INNER LOOP", i)
        loss = z_copy.model1(in_)
        z_copy_optimizer.zero_grad()
        loss.backward()
        z_copy_optimizer.step()
        for p in z_copy.model1.parameters():
            print("Z COPY MODEL 1 PARAMS", p)
            print("Z COPY MODEL 1 GRADS", p.grad)
    #z_copy.model1.load_state_dict(fnet.model1.state_dict())
for p in z_copy.model1.parameters():
    print("Z COPY MODEL 1 PARAMS", p)
'''
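# --- Optional sanity check (a minimal sketch, not part of the original script) ---
# With copy_initial_weights=False, gradients from policy_loss.backward() should
# flow back into z.meta_network through the differentiable inner-loop updates
# tracked by `higher`, so after training its parameters should have non-None
# grads and should differ from the untrained copy in z_copy.
for p_new, p_old in zip(z.meta_network.parameters(), z_copy.meta_network.parameters()):
    print("META GRAD (expected non-None)", p_new.grad)
    print("META PARAM CHANGED", not torch.allclose(p_new, p_old))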