【网信中山】Misc Pytorch 神经网络复原

题目描述

中山杯打到后期,放出了一道 misc 题,下载附件拿到一个压缩包。解压获得了如下文件:

  • .npy 文件 : 用于存储 numpy 数组
  • label.json 文件 : 用于存储模型标签
  • MyLeNet.pt 文件 : PyTorch 的序列化文件格式,用于保存和加载 PyTorch 模型的参数和状态。这种格式方便用户在训练过程中保存检查点,以及在后续的推理或继续训练过程中加载模型。

思考过程

在这里思路其实很明显: 先根据 pt 文件复原模型,然后将 npy 文件全部喂给模型,得到的输出根据 label.json文件做标签分类,然后再拼接获得flag。

这里最难的就是模型的复原,完全不知道他使用的哪个非线性方法,我试了很久也没试出来,跑出来的结果对称性非常强,一度让我以为是还要解码,最后遗憾告终,对这类题还是太陌生了。

题解

比赛结束后看了唯一做出来的那队的 WP,才发现居然要逆MyLeNet.pt文件,直接把 MyLeNet 文件放到 CyberChef 里,看内容,发现里面有 sigmoid 和 relu 非线性方法,由于只有两种非线性方法,一共四个层,做一个排列组合就能够拿到 flag 了,

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import json

# 定义模型结构,与提供的 MyLeNet.pt 结构匹配
class MyLeNet(nn.Module):
def __init__(self):
super(MyLeNet, self).__init__()
self.conv1 = nn.Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
self.conv2 = nn.Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
self.fc1 = nn.Linear(256, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 62)

def forward(self, x):
# 这个地方要排列组合去尝试
x = torch.sigmoid(self.conv1(x))
x = torch.max_pool2d(x, 2)
x = torch.relu(self.conv2(x))
x = torch.max_pool2d(x, 2)
x = x.view(-1, 256)
x = torch.sigmoid(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x

# 加载模型参数
model = MyLeNet()
model.load_state_dict(torch.load('MyLeNet.pt', map_location=torch.device('cpu')))
model.eval()
# 定义一个空字符串用于拼接结果
predicted_string = ""

# 加载标签映射
with open("label.json", "r") as f:
label_map = json.load(f)

# 加载并预测样本
for i in range(56):
sample = np.load(f"{i}.npy")
sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0).unsqueeze(0) # 添加 batch 和 channel 维度
prediction = model(sample)
predicted_label = torch.argmax(prediction, dim=1).item()

# 根据标签映射找到预测的字符
predicted_character = [char for char, label in label_map.items() if label == predicted_label][0]

# 将预测的字符添加到字符串中
predicted_string += predicted_character

# 打印拼接后的完整字符串
print(predicted_string)

然后试的话发现其中 sigmoid+relu+sigmoid+relu 这个组合的结果为:

ZmxhZ3thNzkzZjI1Ny01Nzg4LWYwZjYtY2E5Zi00YTgyZWE3MmUwYzZ9

眼熟的 base64,拿去 base64 解码即可获得 flag:

flag{a793f257-5788-f0f6-ca9f-4a82ea72e0c6}