跑通X2CT

我们初始的目标就是跑通X2CT,传入一套X光的正侧面,让该网络为我们生成对应的CT 3d模型

test运行命令分析

由于我们使用的是正侧两面,即Multi-view,官网给的命令是

1
python test.py --ymlpath=./experiment/multiview2500/d2_multiview2500.yml --gpu=0 --dataroot=./data/LIDC-HDF5-256 --dataset=test --tag=d2_multiview2500 --data=LIDC256 --dataset_class=align_ct_xray_views_std --model_class=MultiViewCTGAN --datasetfile=./data/test.txt --resultdir=./multiview --check_point=90 --how_many=3
  • –ymlpath指定模型参数
  • –gpu指定0号gpu
  • –dataroot指定数据源位置(需要改)
  • –dataset指定我们目标是test
  • –tag是包含模型的实验名
  • –model_class指定我们使用的模型名:MultiViewCTGAN
  • –data指定输入数据的前缀,用于保存和加载,如LIDC256(需要改)
  • –dataset_class输入数据的格式,如单X光或双X光
  • –resultdir: 算法的输出路径
  • –how_many: 可视化会跑多少个样本 (只用在可视化模式下)
  • –datasetfile:测试需要的data列表

test.py/evaluate函数

基础配置

首先进行了一些基础配置,如gpu, check point等

然后将args中的参数和我们指定的yml文件中的配置进行整合,记作opt

然后重点来了,调用了一个get_dataset函数,参数我们传入的dataset-class。

Dataset

经过查看factory.py中的get_dataset函数,可知道有这样五类dataset格式:

image-20230315191930025

不同的数据格式决定了我们使用不同的数据加载类

DataLoader

然后用上述得到的dataset数据集创建了一个dataloader,便于一会遍历

model: init_process

测试环境下,即等同于init_network

训练环境下,会多一个init_loss

预测部分

tqdm是进度条库,起美观效果,应该没啥用

enumerate dataloader,以epoch为单位取出数据

对于每组数据,model.set_input(data)

DataSet

BaseDataSet

  • get_item:pull_item
  • len:num_samples

AlignDataSet

  • pull_item(self, item):找出第item个filepath(是一个h5文件),取出第0维的ct,第1维的xray1,并给xray1扩充一维,

Model部分

  • 继承于Base_model

init_network

按照传入的opt进行了初始化,没啥特别的

set_input

  • G_real设为第0维度
  • G_input1和G_input2分别为第1维度的第0维,第1维度的第1维
  • 第2维及以后是file_path

forwardd