# PyTorch -> TensorFlow/Keras weight transfer for VGG16.
#
# Keras variables (listing truncated in the source, shown in layer order):
#   ... 64) conv2d/bias:0 (64,)  conv2d_1/kernel:0 (3, 3, 64, 64)
#   conv2d_1/bias:0 (64,)  conv2d_2/kernel:0 (3, 3, 64, 128) ...
#   dense/kernel:0 (25088, 4096) ... dense_2/bias:0 (1000,)
# PyTorch state_dict entries (from torchvision.models.vgg16(), same order):
#   features.0.weight (64, 3, 3, 3)  features.0.bias (64,) ...
#   classifier.6.weight (1000, 4096) classifier.6.bias (1000,)
#
# Both frameworks enumerate VGG16's parameters in the same sequence, so the
# copy pairs them positionally and only fixes up the axis layout:
#   Conv2d.weight: torch (out_ch, in_ch, kH, kW) -> TF (kH, kW, in_ch, out_ch)
#   Linear.weight: torch (out_feat, in_feat)     -> TF (in_feat, out_feat)
#   biases:        identical 1-D shape, copied unchanged.
# NOTE(review): this relies on tf_model (TFVGG()) declaring its variables in
# exactly torchvision's parameter order — verify if either model changes.

torch_params = torch_model.state_dict()
# Iterate the ordered keys lazily; the original list.pop(0) is O(n) per call.
torch_key_iter = iter(torch_params)

for layer in tf_model.layers:
    for var in layer.weights:
        torch_key = next(torch_key_iter)
        # detach().cpu() makes .numpy() safe for CUDA / grad-tracking tensors.
        torch_param = torch_params[torch_key].detach().cpu().numpy()
        if torch_param.ndim == 4:    # Conv2d.weight
            var.assign(torch_param.transpose(2, 3, 1, 0))
        elif torch_param.ndim == 2:  # Linear.weight
            var.assign(torch_param.transpose(1, 0))
        else:                        # bias vector
            var.assign(torch_param)