Tutorial 2: Identify cancerous domains across multiple ST datasets concurrently¶
We use STANDS to identify anomalous tissue domains across multiple ST datasets concurrently. Specifically, this experiment involves three 10x Visium reference datasets (10x-hNB-v05, 10x-hNB-v06, 10x-hNB-v07), prepared from healthy human breast tissue and encompassing four normal domain types, and two 10x Visium target datasets (10x-hBC-G2, 10x-hBC-H1), prepared from human breast cancer tissue. Both target datasets contain two additional carcinogenic domain types: cancer in situ (CIS) and invasive cancer (IC).
Loading package¶
import warnings
warnings.filterwarnings("ignore")
import torch
import stands
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
Reading ST data¶
We read the processed ST datasets. In this example, each demo dataset includes: 1) the gene expression matrix in adata.X; 2) the spatial coordinates in adata.obsm['spatial']; 3) the histology image in adata.uns['spatial']. So that the model can read your data successfully, make sure it has the same anndata structure as the example.
# reference datasets (healthy tissue)
path = [
    './HumanBreast/process/V05.h5ad',
    './HumanBreast/process/V06.h5ad',
    './HumanBreast/process/V07.h5ad'
]
ref_list = []
for p in path:
    ref = sc.read_h5ad(p)
    ref_list.append(ref)
# target datasets (cancer tissue)
path = [
    './HumanBreast/process/G2.h5ad',
    './HumanBreast/process/H1.h5ad'
]
tgt_list = []
for p in path:
    tgt = sc.read_h5ad(p)
    tgt_list.append(tgt)
ref_list
[AnnData object with n_obs × n_vars = 2224 × 3000
     obs: 'cell_type', 'batch', 'disease'
     uns: 'spatial'
     obsm: 'spatial',
 AnnData object with n_obs × n_vars = 3037 × 3000
     obs: 'cell_type', 'batch', 'disease'
     uns: 'spatial'
     obsm: 'spatial',
 AnnData object with n_obs × n_vars = 2086 × 3000
     obs: 'cell_type', 'batch', 'disease'
     uns: 'spatial'
     obsm: 'spatial']
tgt_list
[AnnData object with n_obs × n_vars = 467 × 3000
     obs: 'cell_type', 'batch', 'disease'
     uns: 'spatial'
     obsm: 'spatial',
 AnnData object with n_obs × n_vars = 613 × 3000
     obs: 'cell_type', 'batch', 'disease'
     uns: 'spatial'
     obsm: 'spatial']
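Before converting your own data, it can help to confirm that each dataset carries the three fields described at the start of this section. The sketch below is not part of STANDS: the helper name missing_st_fields is hypothetical, and the stand-in objects only mimic the attributes of an AnnData object so that the snippet runs without the demo files.

```python
from types import SimpleNamespace


def missing_st_fields(adata):
    """Return the names of required fields that an AnnData-like object lacks."""
    missing = []
    if getattr(adata, "X", None) is None:
        missing.append("X")
    if "spatial" not in getattr(adata, "obsm", {}):
        missing.append("obsm['spatial']")
    if "spatial" not in getattr(adata, "uns", {}):
        missing.append("uns['spatial']")
    return missing


# Stand-in objects (hypothetical) that mimic the AnnData attributes used above.
complete = SimpleNamespace(X=[[0.0]], obsm={"spatial": [[0, 0]]}, uns={"spatial": {}})
no_image = SimpleNamespace(X=[[0.0]], obsm={"spatial": [[0, 0]]}, uns={})

print(missing_st_fields(complete))  # []
print(missing_st_fields(no_image))  # ["uns['spatial']"]
```

With the real data, the same helper could be applied to every element of ref_list and tgt_list; an empty list means the structure matches the example.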
Converting data¶
For ST input, STANDS first converts the anndata data into a graph, where each node represents a spot and each edge represents the adjacency relationship between two spots. In this example, the node features of the converted graph include the gene expression vector and the image patch. If the data has already been preprocessed, set preprocess=False. Note that STANDS converts multiple datasets in a slightly different way than a single dataset, which is why stands.read_multi is used here.
ref_g = stands.read_multi(ref_list, patch_size=64, n_genes=3000, preprocess=False)
tgt_g = stands.read_multi(tgt_list, patch_size=64, n_genes=3000, preprocess=False)
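To make the idea of a spot graph concrete, here is a minimal sketch that links each spot to its nearest neighbours using only the spatial coordinates. It is an illustration, not the code behind stands.read_multi: the function knn_edges, the choice of k nearest neighbours as the adjacency rule, and the toy coordinates are all assumptions, and image patches are left out.

```python
import math


def knn_edges(coords, k=6):
    """Connect each spot to its k nearest neighbours by Euclidean distance."""
    edges = []
    for i, (xi, yi) in enumerate(coords):
        # Distance from spot i to every other spot, sorted ascending.
        dists = sorted(
            (math.hypot(xi - xj, yi - yj), j)
            for j, (xj, yj) in enumerate(coords)
            if j != i
        )
        edges.extend((i, j) for _, j in dists[:k])
    return edges


# Toy coordinates (hypothetical): three nearby spots and one distant spot.
coords = [(0, 0), (1, 0), (0, 1), (5, 5)]
edges = knn_edges(coords, k=2)
print(len(edges))  # 8 directed edges, two per spot
```

In a graph like this, every spot keeps its gene expression vector as a node feature, and the edges let the model see each spot in its tissue neighbourhood.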
Training the model¶
After the converted reference data is passed to fit, STANDS trains the multimodal GAN. Once training finishes, passing the converted target data to predict detects anomalies and outputs an anomaly score for each spot, where a higher score indicates a more likely anomaly. In addition, if run_gmm=True is specified in predict, STANDS also uses a GMM to determine the threshold between anomalous and normal spots and returns the binary classification results.
model = stands.AnomalyDetect()
model.fit(ref_g)
Begin to train the model on reference datasets...
Train Epochs: 100%|██████████| 10/10 [09:46<00:00, 58.70s/it, D_Loss=-1.58, G_Loss=4.03]
Training has been finished.
scores, labels = model.predict(tgt_g)
Detect anomalous spots on target dataset...
Anomalous spots have been detected.
# store the results
tgt1, tgt2 = tgt_list
tgt1.obs['score'] = scores[0]
tgt1.obs['pred'] = labels[0]
tgt2.obs['score'] = scores[1]
tgt2.obs['pred'] = labels[1]
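The GMM step described in this section turns continuous anomaly scores into binary labels. The sketch below shows the general idea with a two-component, one-dimensional Gaussian mixture fitted by EM in plain Python. It is an assumption about how such thresholding can work, not the implementation inside STANDS; the function gmm_labels and the toy scores are hypothetical.

```python
import math


def gmm_labels(scores, n_iter=100):
    """Fit a two-component 1-D Gaussian mixture with EM and label each score.

    A score is labelled 1 (anomalous) when the component with the higher
    mean is the more probable one, otherwise 0 (normal).
    """
    lo, hi = min(scores), max(scores)
    means = [lo, hi]  # start one component at each extreme
    var = max((hi - lo) ** 2 / 4.0, 1e-6)
    variances = [var, var]
    weights = [0.5, 0.5]
    resp = []
    for _ in range(n_iter):
        # E-step: posterior probability of each component for every score.
        resp = []
        for x in scores:
            dens = [
                w * math.exp(-((x - m) ** 2) / (2 * v)) / math.sqrt(2 * math.pi * v)
                for w, m, v in zip(weights, means, variances)
            ]
            total = sum(dens) or 1e-300
            resp.append([d / total for d in dens])
        # M-step: re-estimate weight, mean and variance of each component.
        for k in range(2):
            nk = max(sum(r[k] for r in resp), 1e-12)
            means[k] = sum(r[k] * x for r, x in zip(resp, scores)) / nk
            variances[k] = max(
                sum(r[k] * (x - means[k]) ** 2 for r, x in zip(resp, scores)) / nk,
                1e-6,
            )
            weights[k] = nk / len(scores)
    high = 0 if means[0] > means[1] else 1
    return [1 if r[high] > 0.5 else 0 for r in resp]


# Toy anomaly scores (hypothetical): five low scores and three high ones.
toy_scores = [0.10, 0.12, 0.09, 0.11, 0.10, 0.90, 0.88, 0.92]
toy_labels = gmm_labels(toy_scores)
print(toy_labels)  # [0, 0, 0, 0, 0, 1, 1, 1]
```

The appeal of a mixture model here is that the threshold adapts to each dataset's score distribution instead of being fixed in advance.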
Saving the weight¶
The model trained in the anomaly detection phase serves as the feature extractor for the subsequent tasks, so saving its weights lets those tasks reuse it and helps to improve their performance.
torch.save(model.G.state_dict(), 'generator.pth')
Evaluation¶
STANDS integrates several evaluation metrics for anomaly detection tasks in stands.evaluate, which can be called directly on the stored results.
metrics = ['Accuracy', 'F1', 'SGD_degree', 'SGD_cc']
result = stands.evaluate(metrics, adata=tgt1, spaid='spatial', y_true=tgt1.obs['disease'],
                         y_pred=tgt1.obs['pred'], y_score=tgt1.obs['score'])
pd.DataFrame(zip(metrics, result))
| | 0 | 1 |
|---|---|---|
| 0 | Accuracy | 0.845824 |
| 1 | F1 | 0.775000 |
| 2 | SGD_degree | 0.591077 |
| 3 | SGD_cc | 0.299607 |
metrics = ['Accuracy', 'F1', 'SGD_degree', 'SGD_cc']
result = stands.evaluate(metrics, adata=tgt2, spaid='spatial', y_true=tgt2.obs['disease'],
                         y_pred=tgt2.obs['pred'], y_score=tgt2.obs['score'])
pd.DataFrame(zip(metrics, result))
| | 0 | 1 |
|---|---|---|
| 0 | Accuracy | 0.833605 |
| 1 | F1 | 0.727273 |
| 2 | SGD_degree | 0.844450 |
| 3 | SGD_cc | 0.310134 |
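For readers who want to see what the Accuracy and F1 rows measure, the following sketch computes both from binary labels using their standard definitions. It assumes labels are encoded as 0 (normal) and 1 (anomalous) and that the input is non-empty. Whether stands.evaluate computes them in exactly this way is an assumption, the helper accuracy_and_f1 is hypothetical, and the SGD metrics are not covered here.

```python
def accuracy_and_f1(y_true, y_pred):
    """Compute accuracy and F1 for binary labels, with 1 as the anomalous class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)  # anomalies found
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)  # normals kept
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)  # false alarms
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)  # missed anomalies
    accuracy = (tp + tn) / len(pairs)
    denom = 2 * tp + fp + fn
    f1 = 2 * tp / denom if denom else 0.0
    return accuracy, f1


# Toy labels (hypothetical): two hits, one miss, one false alarm, two true normals.
y_true = [1, 1, 0, 0, 1, 0]
y_pred = [1, 0, 0, 0, 1, 1]
acc, f1 = accuracy_and_f1(y_true, y_pred)
print(round(acc, 4), round(f1, 4))  # 0.6667 0.6667
```

Accuracy counts all correctly labelled spots, while F1 focuses on the anomalous class, which is why the two can diverge when anomalous spots are a minority.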
Visualization¶
We use the spatial map to visualise the results of anomaly detection.
tgt1.obs['pred'] = tgt1.obs['pred'].astype('category')
tgt1.obs['disease'] = tgt1.obs['disease'].astype('category')
ax = sc.pl.spatial(tgt1, color=['pred', 'disease'], s=90, show=False, crop_coord=(0, 8700, 800, 7600))
ax[0].legend(['Normal', 'Anomaly'], fontsize=12)
ax[0].set_title('STANDS', fontsize=18)
ax[0].set_xlabel('Spatial 1', fontsize=14)
ax[0].set_ylabel('Spatial 2', fontsize=14)
ax[1].legend(['Normal', 'Anomaly'], fontsize=12)
ax[1].set_title('Ground Truth', fontsize=18)
ax[1].set_xlabel('Spatial 1', fontsize=14)
ax[1].set_ylabel('Spatial 2', fontsize=14)
plt.show()
ax = sc.pl.spatial(tgt1, color=['score'], s=90, show=False, crop_coord=(0, 8700, 800, 7600))
ax[0].set_title('Anomaly Score', fontsize=18)
ax[0].set_xlabel('Spatial 1', fontsize=14)
ax[0].set_ylabel('Spatial 2', fontsize=14)
plt.show()
tgt2.obs['pred'] = tgt2.obs['pred'].astype('category')
tgt2.obs['disease'] = tgt2.obs['disease'].astype('category')
ax = sc.pl.spatial(tgt2, color=['pred', 'disease'], s=90, show=False, crop_coord=(100, 9200, 2000, 9800))
ax[0].legend(['Normal', 'Anomaly'], fontsize=12)
ax[0].set_title('STANDS', fontsize=18)
ax[0].set_xlabel('Spatial 1', fontsize=14)
ax[0].set_ylabel('Spatial 2', fontsize=14)
ax[1].legend(['Normal', 'Anomaly'], fontsize=12)
ax[1].set_title('Ground Truth', fontsize=18)
ax[1].set_xlabel('Spatial 1', fontsize=14)
ax[1].set_ylabel('Spatial 2', fontsize=14)
plt.show()
ax = sc.pl.spatial(tgt2, color=['score'], s=90, show=False, crop_coord=(100, 9200, 2000, 9800))
ax[0].set_title('Anomaly Score', fontsize=18)
ax[0].set_xlabel('Spatial 1', fontsize=14)
ax[0].set_ylabel('Spatial 2', fontsize=14)
plt.show()