Objective:
The goal of this project is to fine-tune a state-of-the-art, pre-trained Vision Transformer model (CT-CLIP) to classify the presence of 18 different abnormalities in 3D CT scans.
Methodology & Pipeline
The process is broken down into five key steps, from raw data to a trained model. Here’s what the code is doing at each stage:
1. Data Cleaning & Preparation:
- Process: We start by loading the raw CSV metadata. A cleaning function (clean_and_prepare_data) is applied to filter out low-quality or irrelevant records. This involves removing entries with missing text, keeping only the target scan type (1.nii.gz), and removing duplicate patient studies (a minimal sketch follows this step).
- Status (from logs):
- Training Data: Started with 47,149 records, which were cleaned and filtered down to a high-quality set of 20,000 scans.
- Validation Data: Started with 3,039 records, cleaned down to 1,304 scans.
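A minimal sketch of what this cleaning step can look like with pandas. The column names (report_text, scan_path, patient_id) are illustrative assumptions; the real clean_and_prepare_data may use different fields and apply additional filters:

```python
import pandas as pd

def clean_and_prepare_data(df: pd.DataFrame) -> pd.DataFrame:
    """Filter raw CSV metadata down to usable records.

    Column names (report_text, scan_path, patient_id) are assumptions;
    adapt them to the actual metadata schema.
    """
    # Drop records with missing text.
    df = df.dropna(subset=["report_text"])
    # Keep only the target scan type (files ending in 1.nii.gz).
    df = df[df["scan_path"].str.endswith("1.nii.gz")]
    # Remove duplicate patient studies, keeping the first occurrence.
    df = df.drop_duplicates(subset=["patient_id"], keep="first")
    return df.reset_index(drop=True)

# Usage (hypothetical file name):
# train_df = clean_and_prepare_data(pd.read_csv("train_metadata.csv"))
```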
2. Data Sampling (for Efficient Prototyping):
- Process: To speed up experiments, we are currently working with a smaller, representative subset of the data. We are sampling 15% of the cleaned data for both training and validation.
- Method: A specialized technique called Iterative Stratification is used. This is crucial because our dataset is multi-label (a single scan can have multiple diseases). This method ensures that the proportion of each disease (and combinations of diseases) in the small sample is the same as in the full dataset, preventing statistical biases.
- Status (from logs):
- Final Training Set: 3,000 scans.
- Final Validation Set: 196 scans.
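A sketch of the sampling step above using MultilabelStratifiedShuffleSplit from the iterative-stratification package. The helper function and the label_cols argument (the 18 binary abnormality columns) are assumptions about how the labels are stored:

```python
import numpy as np
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit

def stratified_sample(df, label_cols, fraction=0.15, seed=42):
    """Draw a multi-label stratified subset containing `fraction` of the rows."""
    y = df[label_cols].values.astype(int)       # (n_scans, 18) binary label matrix
    X = np.zeros((len(df), 1))                  # features are unused by the splitter
    splitter = MultilabelStratifiedShuffleSplit(
        n_splits=1, test_size=fraction, random_state=seed
    )
    # The "test" side of the split is our 15% stratified sample.
    _, sample_idx = next(splitter.split(X, y))
    return df.iloc[sample_idx].reset_index(drop=True)

# Usage:
# train_sample = stratified_sample(train_df, label_cols=abnormality_columns)
```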
3. Model Loading & Fine-Tuning (Transfer Learning):
- Process: We are using a powerful pre-trained model, CTViT, which has already learned to understand features from a vast number of CT scans. Our approach (sketched in code after this step) is to:
- Load the pre-trained weights from the local CT_LiPro_v2.pt file.
- Freeze the main body of the model (the image_encoder). This locks in all the powerful, general-purpose knowledge it already has.
- Add a brand new, very small classification layer (classification_head) on top.
- Status (from logs):
- The pre-trained weights were loaded successfully.
- The main model body is frozen.
- The optimizer is set up to train only 4 parameter tensors: the weights and biases of our new, small classification head. This makes the training process highly efficient and targeted to our specific 18 diseases.
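A sketch of this freeze-and-replace setup in PyTorch. The CTViT constructor call, the checkpoint key layout, EMBED_DIM, HIDDEN_DIM, and the two-layer head are assumptions (a two-layer head happens to have exactly four trainable tensors, consistent with the log line above); only classification_head receives gradient updates:

```python
import torch
import torch.nn as nn

NUM_CLASSES = 18
EMBED_DIM = 512      # assumed size of the encoder's pooled output
HIDDEN_DIM = 256     # assumed hidden width of the new head

class CTClassifier(nn.Module):
    """Frozen pre-trained image encoder + small trainable classification head."""

    def __init__(self, image_encoder: nn.Module):
        super().__init__()
        self.image_encoder = image_encoder
        # Two small linear layers -> four trainable tensors (2 weights + 2 biases).
        self.classification_head = nn.Sequential(
            nn.Linear(EMBED_DIM, HIDDEN_DIM),
            nn.GELU(),
            nn.Linear(HIDDEN_DIM, NUM_CLASSES),
        )

    def forward(self, volume: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():                       # backbone stays frozen
            features = self.image_encoder(volume)   # assumed shape: (B, EMBED_DIM)
        return self.classification_head(features)

# Assembling the model (CTViT construction and state-dict keys are assumptions):
# image_encoder = CTViT(...)
# image_encoder.load_state_dict(
#     torch.load("CT_LiPro_v2.pt", map_location="cpu"), strict=False
# )
# model = CTClassifier(image_encoder).to("cuda")
# for p in model.image_encoder.parameters():
#     p.requires_grad = False                       # freeze the main body
# optimizer = torch.optim.AdamW(model.classification_head.parameters(), lr=1e-4)
```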
4. Handling Severe Class Imbalance:
- Problem: In medical imaging, most scans are normal for any given finding. A naive model would learn to just predict "no disease" and achieve high accuracy without being useful.
- Solution: We are using Focal Loss combined with class weighting (pos_weight); a sketch follows this step. This does two things:
- It gives significantly more importance to the rare, positive disease cases, forcing the model to pay attention to them.
- It focuses the training effort on "hard" examples that the model gets wrong.
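A minimal sketch of focal loss layered on BCE-with-logits with a per-finding pos_weight. The gamma = 2.0 value and the mean reduction are common defaults, not necessarily the project's confirmed settings:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLossWithPosWeight(nn.Module):
    """Multi-label focal loss on top of weighted BCE-with-logits.

    pos_weight up-weights the rare positive cases per finding; the
    (1 - p_t)^gamma factor focuses training on hard, misclassified examples.
    """
    def __init__(self, pos_weight: torch.Tensor, gamma: float = 2.0):
        super().__init__()
        self.register_buffer("pos_weight", pos_weight)  # shape: (num_classes,)
        self.gamma = gamma

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Per-element weighted BCE, no reduction yet.
        bce = F.binary_cross_entropy_with_logits(
            logits, targets, pos_weight=self.pos_weight, reduction="none"
        )
        # p_t is the model's probability for the true class of each element.
        probs = torch.sigmoid(logits)
        p_t = probs * targets + (1 - probs) * (1 - targets)
        focal_term = (1 - p_t) ** self.gamma
        return (focal_term * bce).mean()
```

In this kind of setup, pos_weight is typically computed from the training labels as the ratio of negative to positive scans per finding.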
5. Training & Validation Loop:
- Process: The model will be trained for 5 epochs.
- Evaluation: After each epoch of training on the 3,000 training images, the script will immediately run an evaluation on the 196 validation images. This is crucial as it shows us how well the model is generalizing to new data it hasn't been trained on.
- Metrics: For each epoch, we will see a full report of the key classification metrics: AUROC, accuracy, sensitivity, and specificity. This will allow us to track if the model is genuinely learning or just memorizing the training data.
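A compressed sketch of the epoch loop and metric computation described in this step. The run_epoch helper, the DataLoader setup, and the 0.5 decision threshold are assumptions; AUROC here comes from scikit-learn, and sensitivity/specificity are pooled over all labels:

```python
import numpy as np
import torch
from sklearn.metrics import roc_auc_score

def run_epoch(model, loader, criterion, optimizer=None, device="cuda"):
    """One pass over `loader`; trains if an optimizer is given, else evaluates."""
    training = optimizer is not None
    model.train(training)
    all_probs, all_targets = [], []
    with torch.set_grad_enabled(training):
        for volumes, targets in loader:
            volumes, targets = volumes.to(device), targets.to(device).float()
            logits = model(volumes)
            if training:
                loss = criterion(logits, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            all_probs.append(torch.sigmoid(logits).detach().cpu().numpy())
            all_targets.append(targets.cpu().numpy())
    probs, targets = np.concatenate(all_probs), np.concatenate(all_targets)
    preds = (probs >= 0.5).astype(int)              # assumed 0.5 threshold
    tp = ((preds == 1) & (targets == 1)).sum()
    tn = ((preds == 0) & (targets == 0)).sum()
    fp = ((preds == 1) & (targets == 0)).sum()
    fn = ((preds == 0) & (targets == 1)).sum()
    return {
        "auroc": roc_auc_score(targets, probs, average="macro"),
        "accuracy": (tp + tn) / (tp + tn + fp + fn),
        "sensitivity": tp / max(tp + fn, 1),
        "specificity": tn / max(tn + fp, 1),
    }

# for epoch in range(5):
#     run_epoch(model, train_loader, criterion, optimizer)        # train
#     val_metrics = run_epoch(model, val_loader, criterion)       # evaluate
#     print(f"epoch {epoch}: {val_metrics}")
```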