X2CT
跑通X2CT
我们初始的目标就是跑通X2CT,传入一套X光的正侧面,让该网络为我们生成对应的CT 3d模型
test运行命令分析
由于我们使用的是正侧两面,即Multi-view,官网给的命令是
1 | python test.py --ymlpath=./experiment/multiview2500/d2_multiview2500.yml --gpu=0 --dataroot=./data/LIDC-HDF5-256 --dataset=test --tag=d2_multiview2500 --data=LIDC256 --dataset_class=align_ct_xray_views_std --model_class=MultiViewCTGAN --datasetfile=./data/test.txt --resultdir=./multiview --check_point=90 --how_many=3 |
- –ymlpath指定模型参数
- –gpu指定0号gpu
- –dataroot指定数据源位置(需要改)
- –dataset指定我们目标是test
- –tag是包含模型的实验名
- –model_class指定我们使用的模型名:MultiViewCTGAN
- –data指定输入数据的前缀,用于保存和加载,如LIDC256(需要改)
- –dataset_class输入数据的格式,如单X光或双X光
- –resultdir: 算法的输出路径
- –how_many: 可视化会跑多少个样本 (只用在可视化模式下)
- –datasetfile:测试需要的data列表
test.py/evaluate函数
基础配置
首先进行了一些基础配置,如gpu, check point等
然后将args中的参数和我们指定的yml文件中的配置进行整合,记作opt
然后重点来了,调用了一个get_dataset函数,参数我们传入的dataset-class。
Dataset
经过查看factory.py中的get_dataset函数,可知道有这样五类dataset格式:
不同的数据格式决定了我们使用不同的数据加载类
DataLoader
然后用上述得到的dataset数据集创建了一个dataloader,便于一会遍历
model: init_process
测试环境下,即等同于init_network
训练环境下,会多一个init_loss
预测部分
tqdm是进度条库,起美观效果,应该没啥用
enumerate dataloader,以epoch为单位取出数据
对于每组数据,model.set_input(data)
DataSet
BaseDataSet
- get_item:pull_item
- len:num_samples
AlignDataSet
- pull_item(self, item):找出第item个filepath(是一个h5文件),取出第0维的ct,第1维的xray1,并给xray1扩充一维,
Model部分
- 继承于Base_model
init_network
按照传入的opt进行了初始化,没啥特别的
set_input
- G_real设为第0维度
- G_input1和G_input2分别为第1维度的第0维,第1维度的第1维
- 第2维及以后是file_path
forwardd
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 Ando's blog!