【HITCTF2023】 Network in network 神经网络图片恢复

题目描述

拿到题目后, 有三个文件, 一个是模型源码, 一个是跑模型编码后的图片, 还有模型的 pt 参数文件.

源码为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

file = Image.open('origin.jpg')

trans = transforms.Compose([
transforms.ToTensor(),
])

m = trans(file)

torch.manual_seed(0x2daa1a1)

net = nn.Sequential(
nn.Conv2d(3, 5, 3),
nn.ReLU(),
nn.Conv2d(5, 10, 5),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Linear(317, 800),
nn.Conv2d(10, 1, 1),
nn.Sigmoid()
)

torchvision.utils.save_image(net(m).squeeze(), 'enc.png')

torch.save(net, 'net.pt')

解题思路

通过观察源码, 整个过程是经过一个卷积层变为 5 个通道数, 然后 Relu 非线性变化, 在经过一个卷积层变为 10 个通道数, 然后经过 Relu 与池化后经过一个全连接层, 再做卷积和 sigmoid 得到编码后的图片.

注意到整个过程会对图片产生影响的只有线性层 nn.Linear , 其他层由于 kernel 很小, 做特征值处理后能够肉眼识别出 flag. 所以这里考虑复原 sigmoid, Conv2d, Linear.

sigmoid 复原

sigmoid 函数

$$y = \frac{1}{1 + e^{-x}}$$

逆运算,求得 $x$:

$$x = -\ln\left(\frac{1}{y} - 1\right)$$

卷积层复原

卷积层的函数

$$Y_{i,j,k} = \sum_{m=1}^{C_{in}} W_{m,k} \cdot X_{i,j,m} + b_k$$

其中 $W$ 为每个通道的卷积核权重, $b_k$ 为每个通道的偏置

做复原, 仅需

$$X = \frac{Y_{i,j,k}-b_k}{\sum_{m=1}^{C_{in}} W_{m,k}}$$

全连接层复原

全连接层的函数

$$Y = XW^T + b$$

复原

$$X = \frac{Y-b}{W^T}$$

以上这些层的权重参数都可以从 net.pt 文件中提取出来, 使用 pytorch 的 detach() 函数, 图像复原代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# 加载模型参数
net = torch.load('net.pt')
# 将图片转换为灰度图并将像素值归一化到 0~1
image = np.array(Image.open('enc.png').convert('L')) / 255
# 转换为 pytorch 张量
image = torch.Tensor(image).reshape([1, 197, 800])
# 获取卷积层的权重
w_conv_sum = net[-2].weight.detach().sum()
b_conv = net[-2].bias.detach()
# 获取线性变化层的矩阵
w_linear = net[-3].weight.detach()
b_linear = net[-3].bias.detach()
# sigmoid 逆变换
image = -torch.log((1/image)-1)
recover_image = ((image-b_conv)/w_conv_sum - b_linear) @ w_linear.T.pinverse()
plt.imshow(recover_image[0,:,:],cmap='gray')
plt.show()

复原后的效果为: