【HITCTF2023】 Network in network 神经网络图片恢复 题目描述 拿到题目后, 有三个文件, 一个是模型源码, 一个是跑模型编码后的图片, 还有模型的 pt 参数文件.
源码为:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 import torchimport torch.nn as nnimport torchvision.transforms as transformsimport torchvisionfrom PIL import Imageimport matplotlib.pyplot as pltimport numpy as npfile = Image.open ('origin.jpg' ) trans = transforms.Compose([ transforms.ToTensor(), ]) m = trans(file) torch.manual_seed(0x2daa1a1 ) net = nn.Sequential( nn.Conv2d(3 , 5 , 3 ), nn.ReLU(), nn.Conv2d(5 , 10 , 5 ), nn.ReLU(), nn.MaxPool2d(2 ), nn.Linear(317 , 800 ), nn.Conv2d(10 , 1 , 1 ), nn.Sigmoid() ) torchvision.utils.save_image(net(m).squeeze(), 'enc.png' ) torch.save(net, 'net.pt' )
解题思路 通过观察源码, 整个过程是经过一个卷积层变为 5 个通道数, 然后 Relu 非线性变化, 在经过一个卷积层变为 10 个通道数, 然后经过 Relu 与池化后经过一个全连接层, 再做卷积和 sigmoid 得到编码后的图片.
注意到整个过程会对图片产生影响的只有线性层 nn.Linear
, 其他层由于 kernel 很小, 做特征值处理后能够肉眼识别出 flag. 所以这里考虑复原 sigmoid
, Conv2d
, Linear
.
sigmoid 复原 sigmoid 函数
$$y = \frac{1}{1 + e^{-x}}$$
逆运算,求得 $x$:
$$x = -\ln\left(\frac{1}{y} - 1\right)$$
卷积层复原 卷积层的函数
$$Y_{i,j,k} = \sum_{m=1}^{C_{in}} W_{m,k} \cdot X_{i,j,m} + b_k$$
其中 $W$ 为每个通道的卷积核权重, $b_k$ 为每个通道的偏置
做复原, 仅需
$$X = \frac{Y_{i,j,k}-b_k}{\sum_{m=1}^{C_{in}} W_{m,k}}$$
全连接层复原 全连接层的函数
$$Y = XW^T + b$$
复原
$$X = \frac{Y-b}{W^T}$$
以上这些层的权重参数都可以从 net.pt
文件中提取出来, 使用 pytorch 的 detach()
函数, 图像复原代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 import torchimport torch.nn as nnimport torchvision.transforms as transformsimport torchvisionfrom PIL import Imageimport matplotlib.pyplot as pltimport numpy as npnet = torch.load('net.pt' ) image = np.array(Image.open ('enc.png' ).convert('L' )) / 255 image = torch.Tensor(image).reshape([1 , 197 , 800 ]) w_conv_sum = net[-2 ].weight.detach().sum () b_conv = net[-2 ].bias.detach() w_linear = net[-3 ].weight.detach() b_linear = net[-3 ].bias.detach() image = -torch.log((1 /image)-1 ) recover_image = ((image-b_conv)/w_conv_sum - b_linear) @ w_linear.T.pinverse() plt.imshow(recover_image[0 ,:,:],cmap='gray' ) plt.show()
复原后的效果为: