分享

画pytorch模型图,以及参数计算

 LibraryPKU 2019-09-30

    刚入pytorch的坑,代码还没看太懂。之前用keras用习惯了,第一次使用pytorch还有些不适应,希望广大老司机多多指教。

    首先说说,我们如何可视化模型。在keras中就一句话,keras.summary(),或者plot_model(),就可以把模型展现的淋漓尽致。

但是pytorch中好像没有这样一个api让我们直观的看到模型的样子。但是有网友提供了一段代码,可以把模型画出来,对我来说简直就是如有神助啊。话不多说,上代码吧。

  1. import torch
  2. from torch.autograd import Variable
  3. import torch.nn as nn
  4. from graphviz import Digraph
  5. class CNN(nn.Module):
  6. def __init__(self):
  7. super(CNN, self).__init__()
  8. self.conv1 = nn.Sequential(
  9. nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
  10. nn.ReLU(),
  11. nn.MaxPool2d(kernel_size=2)
  12. )
  13. self.conv2 = nn.Sequential(
  14. nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
  15. nn.ReLU(),
  16. nn.MaxPool2d(kernel_size=2)
  17. )
  18. self.out = nn.Linear(32*7*7, 10)
  19. def forward(self, x):
  20. x = self.conv1(x)
  21. x = self.conv2(x)
  22. x = x.view(x.size(0), -1) # (batch, 32*7*7)
  23. out = self.out(x)
  24. return out
  25. def make_dot(var, params=None):
  26. """ Produces Graphviz representation of PyTorch autograd graph
  27. Blue nodes are the Variables that require grad, orange are Tensors
  28. saved for backward in torch.autograd.Function
  29. Args:
  30. var: output Variable
  31. params: dict of (name, Variable) to add names to node that
  32. require grad (TODO: make optional)
  33. """
  34. if params is not None:
  35. assert isinstance(params.values()[0], Variable)
  36. param_map = {id(v): k for k, v in params.items()}
  37. node_attr = dict(style='filled',
  38. shape='box',
  39. align='left',
  40. fontsize='12',
  41. ranksep='0.1',
  42. height='0.2')
  43. dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  44. seen = set()
  45. def size_to_str(size):
  46. return '('+(', ').join(['%d' % v for v in size])+')'
  47. def add_nodes(var):
  48. if var not in seen:
  49. if torch.is_tensor(var):
  50. dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
  51. elif hasattr(var, 'variable'):
  52. u = var.variable
  53. name = param_map[id(u)] if params is not None else ''
  54. node_name = '%s\n %s' % (name, size_to_str(u.size()))
  55. dot.node(str(id(var)), node_name, fillcolor='lightblue')
  56. else:
  57. dot.node(str(id(var)), str(type(var).__name__))
  58. seen.add(var)
  59. if hasattr(var, 'next_functions'):
  60. for u in var.next_functions:
  61. if u[0] is not None:
  62. dot.edge(str(id(u[0])), str(id(var)))
  63. add_nodes(u[0])
  64. if hasattr(var, 'saved_tensors'):
  65. for t in var.saved_tensors:
  66. dot.edge(str(id(t)), str(id(var)))
  67. add_nodes(t)
  68. add_nodes(var.grad_fn)
  69. return dot
  70. if __name__ == '__main__':
  71. net = CNN()
  72. x = Variable(torch.randn(1, 1, 28, 28))
  73. y = net(x)
  74. g = make_dot(y)
  75. g.view()
  76. params = list(net.parameters())
  77. k = 0
  78. for i in params:
  79. l = 1
  80. print("该层的结构:" + str(list(i.size())))
  81. for j in i.size():
  82. l *= j
  83. print("该层参数和:" + str(l))
  84. k = k + l
  85. print("总参数数量和:" + str(k))
    模型很简单,代码也很简单。就是conv -> relu -> maxpool -> conv -> relu -> maxpool -> fc

    大家在可视化的时候,直接复制make_dot那段代码即可,然后需要初始化一个net,以及这个网络需要的数据规模,此处就以    这段代码为例,初始化一个模型net,准备这个模型的输入数据x,shape为(batch,channels,height,width) 然后把数据传入模型得到输出结果y。传入make_dot即可得到下图。

  1. net = CNN()
  2. x = Variable(torch.randn(1, 1, 28, 28))
  3. y = net(x)
  4. g = make_dot(y)
  5. g.view()
 


    最后输出该网络的各种参数。

  1. 该层的结构:[16, 1, 5, 5]
  2. 该层参数和:400
  3. 该层的结构:[16]
  4. 该层参数和:16
  5. 该层的结构:[32, 16, 5, 5]
  6. 该层参数和:12800
  7. 该层的结构:[32]
  8. 该层参数和:32
  9. 该层的结构:[10, 1568]
  10. 该层参数和:15680
  11. 该层的结构:[10]
  12. 该层参数和:10
  13. 总参数数量和:28938

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多