stands

Spatial Transcriptomics ANomaly Detection and Subtyping (STANDS) is an innovative computational method to detect anomalous tissue domains from multi-sample spatial transcriptomics (ST) data and reveal their biologically heterogeneous subdomains, which can be individual-specific or shared by all individuals.

Detecting and characterizing anomalous anatomic regions in tissue samples from affected individuals is crucial for clinical and biomedical research. This procedure, which we refer to as Detection and Dissection of Anomalous Tissue Domains (DDATD), is the first step in the analysis of clinical tissues because it reveals factors, such as pathogenic or differentiated cell types, associated with the development of diseases or biological traits. Traditionally, DDATD has relied on either laborious visual inspection by experts or computer vision algorithms applied to histology images. ST provides an unprecedented opportunity to enhance DDATD by incorporating spatial gene expression information. However, to the best of our knowledge, no existing method can perform de novo DDATD from ST datasets.

STANDS is built on state-of-the-art generative models for de novo DDATD from multi-sample ST, integrating multimodal information including spatial gene expression, histology images, and single-cell gene expression. STANDS fulfills DDATD's three sequential core tasks within one framework: detecting, aligning, and subtyping anomalous tissue domains across multiple samples. First, STANDS integrates multimodal information from spatial transcriptomics and the associated histology images to pinpoint anomalous tissue regions across multiple target datasets. Next, STANDS aligns the anomalies identified in the target datasets in a common data space via style-transfer learning to mitigate their non-biological variation. Finally, STANDS dissects the aligned anomalies into biologically heterogeneous subtypes that are either common or unique to the target datasets. Combining these steps in a unified framework maintains methodological coherence, which leads to its unparalleled performance in DDATD from multi-sample ST.

Modules:

Name        Description
read        Read a single spatial dataset and preprocess if required.
read_cross  Read spatial data from two sources and preprocess if required.
read_multi  Read multiple spatial datasets and preprocess if required.
pretrain    Pretrain STANDS using spatial data.
evaluate    Calculate various metrics (including SGD).

AnomalyDetect

AnomalyDetect(
    n_epochs: int = 10,
    batch_size: int = 128,
    learning_rate: float = 0.0003,
    n_dis: int = 2,
    GPU: Union[bool, str] = True,
    random_state: Optional[int] = None,
    weight: Optional[Dict[str, float]] = None,
)
Source code in src\stands\anomaly.py
def __init__(self, 
             n_epochs: int = 10, 
             batch_size: int = 128,
             learning_rate: float = 3e-4,
             n_dis: int = 2,
             GPU: Union[bool, str] = True,
             random_state: Optional[int] = None,
             weight: Optional[Dict[str, float]] = None):

    self.n_epochs = n_epochs
    self.batch_size = batch_size
    self.lr = learning_rate
    self.n_dis = n_dis
    self.device = select_device(GPU)

    if random_state is not None:
        seed_everything(random_state)

    if weight is None:
        weight = {'w_rec': 30, 'w_adv': 1, 'w_gp': 10}
    self.weight = weight
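
The constructor above stores a user-supplied `weight` dictionary as is, while `UpdateD` and `UpdateG` index it with `'w_rec'`, `'w_adv'` and `'w_gp'`. A dictionary that omits one of these keys would therefore presumably fail with a `KeyError` during training. The sketch below shows one way to overlay partial weights on the defaults before passing them in. `resolve_weight` and `DEFAULT_WEIGHT` are hypothetical names for illustration and are not part of STANDS.

```python
from typing import Dict, Optional

# Default loss weights, copied from the constructor shown above.
DEFAULT_WEIGHT: Dict[str, float] = {'w_rec': 30, 'w_adv': 1, 'w_gp': 10}


def resolve_weight(weight: Optional[Dict[str, float]] = None) -> Dict[str, float]:
    """Hypothetical helper: overlay user-supplied weights on the defaults."""
    merged = dict(DEFAULT_WEIGHT)
    if weight is not None:
        unknown = set(weight) - set(DEFAULT_WEIGHT)
        if unknown:
            # Reject misspelled keys instead of silently ignoring them.
            raise KeyError(f'Unknown loss weights: {sorted(unknown)}')
        merged.update(weight)
    return merged


# Only the reconstruction weight is overridden; the other two keep their defaults.
full_weight = resolve_weight({'w_rec': 50})
print(full_weight)
```

The resulting dictionary contains all three keys and could then be passed as `weight=full_weight`.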

UpdateD

UpdateD(blocks)

Updating discriminator

Source code in src\stands\anomaly.py
def UpdateD(self, blocks):
    '''Updating discriminator'''
    self.opt_D.zero_grad()

    if self.only_ST:
        # generate fake data
        _, fake_g = self.G.STforward(blocks, blocks[0].srcdata['gene'])

        # get real data from blocks
        real_g = blocks[1].dstdata['gene']

        d1 = torch.mean(self.D.SCforward(real_g))
        d2 = torch.mean(self.D.SCforward(fake_g.detach()))
        gp = calculate_gradient_penalty(self.D, real_g, fake_g.detach())

    else:
        _, fake_g, fake_p = self.G.Fullforward(
            blocks, blocks[0].srcdata['gene'], blocks[1].srcdata['patch']
        )

        # get real data from blocks
        real_g = blocks[1].dstdata['gene']
        real_p = blocks[1].dstdata['patch']

        d1 = torch.mean(self.D.Fullforward(real_g, real_p))
        d2 = torch.mean(self.D.Fullforward(fake_g.detach(), fake_p.detach()))
        gp = calculate_gradient_penalty(
            self.D, real_g, fake_g.detach(), real_p, fake_p.detach()
        )            

    # store discriminator loss for printing training information
    self.D_loss = - d1 + d2 + gp * self.weight['w_gp']
    self.D_loss.backward()
    self.opt_D.step()
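
The discriminator objective above has the shape of a Wasserstein critic loss with a gradient penalty: the mean score on real data is maximized, the mean score on generated data is minimized, and the penalty is scaled by `w_gp`. The following plain-Python sketch reproduces only that arithmetic on toy numbers. The body of `calculate_gradient_penalty` is not shown on this page, so `gradient_penalty` below assumes the common WGAN-GP form and may differ from the library's implementation.

```python
from statistics import fmean
from typing import Sequence


def gradient_penalty(grad_norms: Sequence[float]) -> float:
    """ASSUMPTION: standard WGAN-GP penalty, the mean of (||grad|| - 1)^2."""
    return fmean([(n - 1.0) ** 2 for n in grad_norms])


def critic_loss(real_scores: Sequence[float],
                fake_scores: Sequence[float],
                gp: float,
                w_gp: float = 10.0) -> float:
    """Same arithmetic as `- d1 + d2 + gp * self.weight['w_gp']` in UpdateD."""
    d1 = fmean(real_scores)   # mean critic score on real samples
    d2 = fmean(fake_scores)   # mean critic score on generated samples
    return -d1 + d2 + gp * w_gp


# Toy values: the critic already scores real data higher than generated data.
gp = gradient_penalty([1.0, 1.2])
loss = critic_loss([0.8, 1.2], [-0.5, 0.1], gp)
print(round(loss, 6))
```

A more negative loss means the critic separates real from generated samples more clearly, as long as the penalty term stays small.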

UpdateG

UpdateG(blocks)

Updating generator

Source code in src\stands\anomaly.py
def UpdateG(self, blocks):
    '''Updating generator'''
    self.opt_G.zero_grad()

    if self.only_ST:
        # generate fake data
        z, fake_g = self.G.STforward(blocks, blocks[0].srcdata['gene'])

        # get real data from blocks
        real_g = blocks[1].dstdata['gene']

        # discriminator provides feedback
        d = self.D.SCforward(fake_g)

        Loss_rec = self.L1(real_g, fake_g)
        Loss_adv = - torch.mean(d)

    else:
        z, fake_g, fake_p = self.G.Fullforward(
            blocks, blocks[0].srcdata['gene'], blocks[1].srcdata['patch']
        )

        # get real data from blocks
        real_g = blocks[1].dstdata['gene']
        real_p = blocks[1].dstdata['patch']

        # discriminator provides feedback
        d = self.D.Fullforward(fake_g, fake_p)

        Loss_rec = (self.L1(real_g, fake_g)+self.L1(real_p, fake_p))/2
        Loss_adv = - torch.mean(d)

    # store generator loss for printing training information and backward
    self.G_loss = self.weight['w_rec'] * Loss_rec + self.weight['w_adv'] * Loss_adv
    self.G_loss.backward()
    self.opt_G.step()

    # updating memory block with generated embeddings, fake_z
    self.G.Memory.update_mem(z)
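
The generator objective above combines a reconstruction term with an adversarial term, and in the multimodal branch the reconstruction term is the average of the gene and patch losses. The sketch below mirrors that arithmetic on flat lists of numbers. It assumes that `self.L1` behaves like a mean absolute error (the default reduction of `torch.nn.L1Loss`), which this page does not confirm.

```python
from statistics import fmean
from typing import Optional, Sequence


def l1_loss(real: Sequence[float], fake: Sequence[float]) -> float:
    """ASSUMPTION: self.L1 is a mean absolute error."""
    return fmean([abs(r - f) for r, f in zip(real, fake)])


def generator_loss(real_g: Sequence[float], fake_g: Sequence[float],
                   critic_scores: Sequence[float],
                   real_p: Optional[Sequence[float]] = None,
                   fake_p: Optional[Sequence[float]] = None,
                   w_rec: float = 30.0, w_adv: float = 1.0) -> float:
    """Same arithmetic as G_loss in UpdateG, for both branches."""
    rec = l1_loss(real_g, fake_g)
    if real_p is not None and fake_p is not None:
        # Multimodal branch: average the gene and patch reconstruction errors.
        rec = (rec + l1_loss(real_p, fake_p)) / 2
    adv = -fmean(critic_scores)  # the generator wants high critic scores
    return w_rec * rec + w_adv * adv


st_only = generator_loss([1.0, 2.0], [1.5, 2.0], [0.5, 1.5])
multimodal = generator_loss([1.0, 2.0], [1.5, 2.0], [0.5, 1.5],
                            real_p=[0.0, 1.0], fake_p=[0.0, 0.0])
print(st_only, multimodal)
```

With the default weights the reconstruction term dominates, so the generator is pushed mainly toward faithful reconstruction of normal spots.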

fit

fit(
    ref: Dict[str, Any],
    only_ST: bool = False,
    weight_dir: Optional[str] = None,
)

Train STANDS on reference graph

Source code in src\stands\anomaly.py
def fit(self, ref: Dict[str, Any], only_ST: bool = False, weight_dir: Optional[str] = None):
    '''Train STANDS on reference graph'''
    tqdm.write('Begin to train the model on reference datasets...')

    # dataset provides subgraph for training
    ref_g = ref['graph']
    self.sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
    self.dataset = dgl.dataloading.DataLoader(
        ref_g, ref_g.nodes(), self.sampler, batch_size=self.batch_size, 
        shuffle=True, drop_last=True, num_workers=0, device=self.device
    )

    self.only_ST = only_ST
    self.init_model(ref, weight_dir)

    self.G.train()
    self.D.train()
    with tqdm(total=self.n_epochs) as t:
        for _ in range(self.n_epochs):
            t.set_description(f'Train Epochs')

            for _, _, blocks in self.dataset:

                # Update discriminator for n_dis times
                for _ in range(self.n_dis):
                    self.UpdateD(blocks)

                # Update generator for one time
                self.UpdateG(blocks)

            # Update learning rate for G and D
            self.D_sch.step()
            self.G_sch.step()

            t.set_postfix(G_Loss = self.G_loss.item(),
                          D_Loss = self.D_loss.item())
            t.update(1)

    tqdm.write('Training has been finished.')
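
The training loop above runs `n_dis` discriminator updates for every generator update, and the data loader is created with `drop_last=True`, so an incomplete final batch is skipped in every epoch. The sketch below counts the resulting optimizer steps. `count_updates` is a hypothetical helper, and it assumes that the DGL data loader follows the usual PyTorch `drop_last` semantics.

```python
from typing import Tuple


def count_updates(n_nodes: int, batch_size: int = 128,
                  n_epochs: int = 10, n_dis: int = 2) -> Tuple[int, int, int]:
    """Return (batches per epoch, discriminator steps, generator steps)."""
    # drop_last=True: only full batches are used.
    batches_per_epoch = n_nodes // batch_size
    g_steps = n_epochs * batches_per_epoch
    d_steps = g_steps * n_dis
    return batches_per_epoch, d_steps, g_steps


batches, d_steps, g_steps = count_updates(n_nodes=1000)
print(batches, d_steps, g_steps)

# A reference graph smaller than one batch yields no batches at all,
# so the inner loop body would never run.
print(count_updates(n_nodes=100))
```

If the reference graph has fewer nodes than `batch_size`, reducing `batch_size` is the simplest way to make sure at least one batch is produced. The learning-rate schedulers step once per epoch, independent of these counts.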

init_weight

init_weight(weight_dir)

Initialize pretrained weights and the memory block

Source code in src\stands\anomaly.py
@torch.no_grad()
def init_weight(self, weight_dir):
    '''Initialize pretrained weights and the memory block'''
    self.G.extract.load_weight(weight_dir)

    # Initialize the memory block with the normal embeddings
    sum_t = self.G.Memory.mem_dim/self.batch_size
    t = 0
    while t < sum_t:
        for _, _, blocks in self.dataset:
            if self.only_ST:
                real_g = blocks[0].srcdata['gene']
                z, _ = self.G.STforward(blocks, real_g)
            else:
                real_g = blocks[0].srcdata['gene']
                real_p = blocks[1].srcdata['patch']
                z, _, _ = self.G.Fullforward(blocks, real_g, real_p)

            self.G.Memory.update_mem(z)
            t += 1
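
In the loop above, the stopping condition `t < sum_t` is checked only between full passes over the data loader, so the memory block receives a whole number of passes rather than exactly `mem_dim / batch_size` updates. The simulation below counts the `update_mem` calls for a given number of batches. The value `mem_dim=512` is a made-up example, and `memory_init_calls` is a hypothetical helper.

```python
def memory_init_calls(mem_dim: int, batch_size: int, n_batches: int) -> int:
    """Count update_mem calls made by a loop shaped like init_weight."""
    if n_batches <= 0:
        # With an empty loader t never grows, so the original while loop
        # would not terminate. Fail early in this sketch instead.
        raise ValueError('the data loader must yield at least one batch')
    sum_t = mem_dim / batch_size
    t = 0
    while t < sum_t:
        for _ in range(n_batches):   # one full pass over the loader
            t += 1                   # one update_mem(z) call per batch
    return t


print(memory_init_calls(mem_dim=512, batch_size=128, n_batches=7))
print(memory_init_calls(mem_dim=512, batch_size=128, n_batches=3))
```

The guard for an empty loader matters because the training loader uses `drop_last=True`: a reference graph with fewer nodes than `batch_size` would produce no batches.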

predict

predict(tgt: Dict[str, Any], run_gmm: bool = True)

Detect anomalous spots on target graph

Source code in src\stands\anomaly.py
@torch.no_grad()
def predict(self, tgt: Dict[str, Any], run_gmm: bool = True):
    '''Detect anomalous spots on target graph'''

    tgt_g = tgt['graph']
    dataset = dgl.dataloading.DataLoader(
        tgt_g, tgt_g.nodes(), self.sampler, batch_size=self.batch_size, 
        shuffle=False, drop_last=False, num_workers=0, device=self.device
    )

    self.G.eval()
    self.D.eval()
    tqdm.write('Detect anomalous spots on target dataset...')

    ref_score = self.score(self.dataset)
    tgt_score = self.score(dataset)

    tqdm.write('Anomalous spots have been detected.\n')

    if run_gmm:
        gmm = GMMWithPrior(ref_score)
        threshold = gmm.fit(tgt_score=tgt_score)
        tgt_label = [1 if s >= threshold else 0 for s in tgt_score]
        return tgt_score, tgt_label
    else:
        return tgt_score
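
`predict` returns a `(scores, labels)` pair when `run_gmm=True` and only the scores otherwise. In the second case the labelling step can be reproduced with a threshold chosen by the user. The sketch below applies the same `>=` rule as the source. The threshold value is arbitrary and does not come from the `GMMWithPrior` model, whose implementation is not shown on this page.

```python
from typing import List, Sequence


def label_scores(scores: Sequence[float], threshold: float) -> List[int]:
    """Label a spot 1 (anomalous) if its score reaches the threshold, else 0."""
    return [1 if s >= threshold else 0 for s in scores]


# Toy scores standing in for the output of predict(tgt, run_gmm=False).
tgt_score = [0.12, 0.80, 0.45, 0.46]
labels = label_scores(tgt_score, threshold=0.45)
print(labels, sum(labels) / len(labels))
```

Because the comparison is inclusive, a spot whose score equals the threshold is labelled anomalous.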

BatchAlign

BatchAlign(
    n_epochs: int = 10,
    batch_size: int = 128,
    learning_rate: float = 0.0003,
    n_dis: int = 3,
    GPU: Union[bool, str] = True,
    random_state: Optional[int] = None,
    weight: Optional[Dict[str, float]] = None,
)
Source code in src\stands\align.py
def __init__(self, 
             n_epochs: int = 10, 
             batch_size: int = 128,
             learning_rate: float = 3e-4, 
             n_dis: int = 3,
             GPU: Union[bool, str] = True, 
             random_state: Optional[int] = None,
             weight: Optional[Dict[str, float]] = None):

    self.n_epochs = n_epochs
    self.batch_size = batch_size
    self.lr = learning_rate
    self.n_dis = n_dis
    self.device = select_device(GPU)
    self.GPU = GPU

    self.seed = random_state
    if random_state is not None:
        seed_everything(random_state)

    if weight is None:
        weight = {'w_rec': 30, 'w_adv': 1, 'w_gp': 10}
    self.weight = weight

UpdateD

UpdateD(blocks)

Updating discriminator

Source code in src\stands\align.py
def UpdateD(self, blocks):
    '''Updating discriminator'''
    self.opt_D.zero_grad()

    # generate fake data
    batchid = blocks[-1].dstdata['batch']
    fake_g = self.G.STforward(blocks, blocks[0].srcdata['gene'], batchid)

    # get real data from blocks
    real_g = blocks[1].dstdata['gene']

    d1 = torch.mean(self.D.SCforward(real_g))
    d2 = torch.mean(self.D.SCforward(fake_g.detach()))
    gp = calculate_gradient_penalty(self.D, real_g, fake_g.detach())         

    # store discriminator loss for printing training information
    self.D_loss = - d1 + d2 + gp * self.weight['w_gp']
    self.D_loss.backward()
    self.opt_D.step()

UpdateG

UpdateG(blocks)

Updating generator

Source code in src\stands\align.py
def UpdateG(self, blocks):
    '''Updating generator'''
    self.opt_G.zero_grad()

    # generate fake data
    batchid = blocks[-1].dstdata['batch']
    fake_g = self.G.STforward(blocks, blocks[0].srcdata['gene'], batchid)

    # get real data from blocks
    real_g = blocks[1].dstdata['gene']

    # discriminator provides feedback
    d = self.D.SCforward(fake_g)

    Loss_rec = self.L1(real_g, fake_g)
    Loss_adv = - torch.mean(d)

    # store generator loss for printing training information and backward
    self.G_loss = self.weight['w_rec']*Loss_rec + self.weight['w_adv']*Loss_adv
    self.G_loss.backward()
    self.opt_G.step()

fit

fit(
    raw: Dict[str, Any],
    generator: GeneratorAD,
    **alignerkwargs
)

Remove batch effects

Source code in src\stands\align.py
def fit(self, raw: Dict[str, Any], generator: GeneratorAD, **alignerkwargs):
    '''Remove batch effects'''
    adatas = raw['adata']
    adata_ref = adatas[0]
    adata_tgt = ad.concat(adatas[1:])

    # find Kin Pairs
    Aligner = FindPairs(GPU=self.GPU, random_state=self.seed, **alignerkwargs)
    _, tgt_g = Aligner.fit(generator, raw)

    self.sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
    self.dataset = dgl.dataloading.DataLoader(
        tgt_g, tgt_g.nodes(), self.sampler, batch_size=self.batch_size, 
        shuffle=True, drop_last=False, num_workers=0, device=self.device
    )

    self.init_model(raw, generator)

    tqdm.write('Begin to correct spatial transcriptomics datasets...')
    self.G.train()
    self.D.train()
    with tqdm(total=self.n_epochs) as t:
        for _ in range(self.n_epochs):
            t.set_description(f'Train Epochs')

            for _, _, blocks in self.dataset:

                # Update discriminator for n_dis times
                for _ in range(self.n_dis):
                    self.UpdateD(blocks)

                # Update generator for one time
                self.UpdateG(blocks)

            # Update learning rate for G and D
            self.D_sch.step()
            self.G_sch.step()
            t.set_postfix(G_Loss = self.G_loss.item(),
                          D_Loss = self.D_loss.item())
            t.update(1)

    self.dataset = dgl.dataloading.DataLoader(
        tgt_g, tgt_g.nodes(), self.sampler, batch_size=self.batch_size, 
        shuffle=False, drop_last=False, num_workers=0, device=self.device
    )

    self.G.eval()
    corrected = []
    with torch.no_grad():
        for _, _, blocks in self.dataset:
            fake_g = self.G.STforward(
                blocks, blocks[0].srcdata['gene'], blocks[-1].dstdata['batch']
            )
            corrected.append(fake_g.cpu().detach())

    corrected = torch.cat(corrected, dim=0).numpy()
    adata_tgt.X = corrected
    adata = ad.concat([adata_ref, adata_tgt])
    tqdm.write('Datasets have been corrected.\n')
    return adata
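
After training, `fit` rebuilds the data loader with `shuffle=False` and `drop_last=False` before writing the generated expression into `adata_tgt.X`. The toy sketch below illustrates why both settings matter: only an ordered, complete pass yields one output row per input row in the original order. It assumes that the node order of `tgt_g` matches the row order of the concatenated target AnnData, which this page does not state. `iter_batches` and `collect_in_order` are hypothetical helpers.

```python
from typing import Callable, Iterator, List, Sequence


def iter_batches(n_rows: int, batch_size: int) -> Iterator[List[int]]:
    """Yield row indices in order, keeping the incomplete final batch."""
    for start in range(0, n_rows, batch_size):
        yield list(range(start, min(start + batch_size, n_rows)))


def collect_in_order(rows: Sequence[float], batch_size: int,
                     transform: Callable[[float], float]) -> List[float]:
    """Apply `transform` batch by batch and concatenate the results."""
    corrected: List[float] = []
    for idx in iter_batches(len(rows), batch_size):
        corrected.extend(transform(rows[i]) for i in idx)
    return corrected


rows = [1.0, 2.0, 3.0, 4.0, 5.0]
out = collect_in_order(rows, batch_size=2, transform=lambda x: x * 10)
# Same length and order as the input, so it can replace the input row by row.
print(out)
```

Note that only the target datasets are corrected. The first dataset in `raw['adata']` is treated as the reference and is concatenated back unchanged.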

read_cross

read_cross(
    ref: AnnData,
    tgt: AnnData,
    spa_key: str = "spatial",
    preprocess: bool = True,
    n_genes: int = 3000,
    patch_size: Optional[int] = None,
    n_neighbors: int = 4,
    augment: bool = True,
    return_type: Literal["anndata", "graph"] = "graph",
)

Read spatial data from two sources and preprocess if required. The data are converted into a reference graph and a target graph.

Parameters:

Name         Type                         Description                                      Default
ref          AnnData                      Reference AnnData object.                        required
tgt          AnnData                      Target AnnData object.                           required
spa_key      str                          Key for spatial information in AnnData objects.  'spatial'
preprocess   bool                         Perform data preprocessing.                      True
n_genes      int                          Number of genes for feature selection.           3000
patch_size   Optional[int]                Patch size for H&E images.                       None
n_neighbors  int                          Number of neighbors for spatial data reading.    4
augment      bool                         Whether to use the data augmentation.            True
return_type  Literal['anndata', 'graph']  Type of data to return.                          'graph'

Returns:

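Type                Description
Union[Tuple, Dict]  Depending on 'return_type', returns either a tuple of AnnData objects (reference, target) or, as the source code below shows, a pair of packed graph-related data objects, one for the reference and one for the target.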

Source code in src\stands\_read.py
@clear_warnings
def read_cross(ref: ad.AnnData, tgt: ad.AnnData, spa_key: str = 'spatial',
               preprocess: bool = True, n_genes: int = 3000, patch_size: Optional[int] = None,
               n_neighbors: int = 4, augment: bool = True, 
               return_type: Literal['anndata', 'graph'] = 'graph'):
    """
    Read spatial data from two sources and preprocess if required.
    The read data are transformed to reference and target graph.

    Parameters:
        ref (ad.AnnData): Reference AnnData object.
        tgt (ad.AnnData): Target AnnData object.
        spa_key (str): Key for spatial information in AnnData objects.
        preprocess (bool): Perform data preprocessing.
        n_genes (int): Number of genes for feature selection.
        patch_size (Optional[int]): Patch size for H&E images.
        n_neighbors (int): Number of neighbors for spatial data reading.
        augment (bool): Whether to use the data augmentation.
        return_type (Literal['anndata', 'graph']): Type of data to return.

    Returns:
        (Union[Tuple, Dict]): Depending on the 'return_type', returns either a tuple of AnnData objects or a dictionary of graph-related data.
    """
    seed_everything(0)

    ref, ref_img, ref_pos = read(ref, False, 'tuple', spa_key=spa_key, n_neighbors=n_neighbors)
    tgt, tgt_img, tgt_pos = read(tgt, False, 'tuple', spa_key=spa_key, n_neighbors=n_neighbors)
    overlap_gene = list(set(ref.var_names) & set(tgt.var_names))
    ref = ref[:, overlap_gene]
    tgt = tgt[:, overlap_gene]

    if preprocess:
        ref = preprocess_data(ref)
        tgt = preprocess_data(tgt)
        if len(overlap_gene) <= n_genes:
            warnings.warn(
                'There are too few overlapping genes to perform feature selection'
            )
        else:
            sc.pp.filter_genes(ref, min_cells=10)
            sc.pp.highly_variable_genes(ref, n_top_genes=n_genes, subset=True)
            tgt = tgt[:, ref.var_names]

    if return_type == 'anndata':
        return ref, tgt

    elif return_type == 'graph':
        if patch_size is None:
            patch_size = set_patch(ref)

        ref_b = BuildGraph(ref, ref_img, ref_pos, augment, n_neighbors, patch_size)
        tgt_b = BuildGraph(tgt, tgt_img, tgt_pos, augment, n_neighbors, patch_size)
        return ref_b.pack(), tgt_b.pack()
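
`read_cross` restricts both datasets to their shared genes with `list(set(...) & set(...))`. The order of a list built from a set of strings is arbitrary and can differ between Python processes, so the gene column order may not be reproducible across runs. The sketch below shows an order-preserving alternative together with the condition under which feature selection is skipped. `ordered_overlap` and `needs_feature_selection` are hypothetical helpers and are not part of STANDS.

```python
from typing import List, Sequence


def ordered_overlap(ref_genes: Sequence[str], tgt_genes: Sequence[str]) -> List[str]:
    """Shared genes, kept in the order in which they appear in the reference."""
    tgt_set = set(tgt_genes)
    return [g for g in ref_genes if g in tgt_set]


def needs_feature_selection(n_overlap: int, n_genes: int = 3000) -> bool:
    """Mirror of the source check: selection runs only if overlap > n_genes."""
    return n_overlap > n_genes


ref_genes = ['GAPDH', 'ACTB', 'CD3E', 'MS4A1']
tgt_genes = ['MS4A1', 'EPCAM', 'GAPDH', 'CD3E']
shared = ordered_overlap(ref_genes, tgt_genes)
print(shared)
print(needs_feature_selection(len(shared)))
```

When the overlap is not larger than `n_genes`, the source only emits a warning and keeps all shared genes.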

read_multi

read_multi(
    adata_list: List[AnnData],
    patch_size: Optional[int] = None,
    gene_list: Optional[List[str]] = None,
    preprocess: bool = True,
    n_genes: int = 3000,
    n_neighbors: int = 4,
    augment: bool = True,
    spa_key: str = "spatial",
    return_type: Literal["anndata", "graph"] = "graph",
)

Read multiple spatial datasets and preprocess if required. All datasets are combined into a single graph.

Parameters:

Name         Type                         Description                                      Default
adata_list   List[AnnData]                List of AnnData objects.                         required
patch_size   Optional[int]                Patch size for H&E images.                       None
gene_list    Optional[List[str]]          Selected gene list.                              None
preprocess   bool                         Perform data preprocessing.                      True
n_genes      int                          Number of genes for feature selection.           3000
n_neighbors  int                          Number of neighbors for spatial data reading.    4
augment      bool                         Whether to use the data augmentation.            True
spa_key      str                          Key for spatial information in AnnData objects.  'spatial'
return_type  Literal['anndata', 'graph']  Type of data to return.                          'graph'

Returns:

Type               Description
Union[List, Dict]  Depending on the 'return_type', returns either a list of AnnData objects or a dictionary of graph-related data.

Source code in src\stands\_read.py
@clear_warnings
def read_multi(adata_list: List[ad.AnnData], patch_size: Optional[int] = None,
               gene_list: Optional[List[str]] = None, preprocess: bool = True, 
               n_genes: int = 3000, n_neighbors: int = 4, augment: bool = True,
               spa_key: str = 'spatial', return_type: Literal['anndata', 'graph'] = 'graph'):
    """
    Read multiple spatial datasets and preprocess if required.
    All the datasets are transformed to only one graph.

    Parameters:
        adata_list (List[ad.AnnData]): List of AnnData objects.
        patch_size (Optional[int]): Patch size for H&E images.
        gene_list (Optional[List[str]]): Selected gene list.
        preprocess (bool): Perform data preprocessing.
        n_genes (int): Number of genes for feature selection.
        n_neighbors (int): Number of neighbors for spatial data reading.
        augment (bool): Whether to use the data augmentation.
        spa_key (str): Key for spatial information in AnnData objects.
        return_type (Literal['anndata', 'graph']): Type of data to return.

    Returns:
        (Union[List, Dict]): Depending on the 'return_type', returns either a list of AnnData objects or a dictionary of graph-related data.
    """
    seed_everything(0)

    adatas, images, positions = [], [], []
    for i in range(len(adata_list)):
        d, img, pos = read(adata_list[i], False, 'tuple', spa_key=spa_key, n_neighbors=n_neighbors)
        adatas.append(d)
        images.append(img)
        positions.append(pos)

    for img in images:
        if img is None:
            images = None
            break

    if preprocess:
        adatas = [preprocess_data(d) for d in adatas]
        if gene_list is None:
            ref = adatas[0]
            sc.pp.filter_genes(ref, min_cells=10)
            sc.pp.highly_variable_genes(ref, n_top_genes=n_genes, subset=True)
            adatas = [d[:, list(ref.var_names)] for d in adatas]
        else:
            adatas = [d[:, list(gene_list)] for d in adatas]

    if return_type == 'anndata':
        return adatas

    elif return_type == 'graph':
        if patch_size is None:
            patch_size = set_patch(adatas[0])
        builder = BuildMultiGraph(adatas, images, positions, augment, n_neighbors, patch_size)
        return builder.pack()
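
Two decisions in `read_multi` are easy to miss when reading the source. Images are used on an all-or-nothing basis: if any dataset lacks an image, `images` is set to `None` for all of them. Gene selection happens only inside the `preprocess` branch, where an explicit `gene_list` takes precedence over the highly variable genes of the first dataset and `n_genes` is then ignored. The sketch below restates both rules in plain Python. The function names are hypothetical, and the highly variable genes are passed in as a ready-made list instead of being computed with scanpy.

```python
from typing import List, Optional, Sequence


def resolve_images(images: Sequence[Optional[object]]) -> Optional[List[object]]:
    """Keep the images only if every dataset provides one."""
    if any(img is None for img in images):
        return None
    return list(images)


def resolve_genes(all_genes: Sequence[str],
                  first_dataset_hvgs: Sequence[str],
                  gene_list: Optional[Sequence[str]] = None,
                  preprocess: bool = True) -> List[str]:
    """Return the genes kept for every dataset."""
    if not preprocess:
        # Without preprocessing the source applies no gene subsetting at all.
        return list(all_genes)
    if gene_list is not None:
        return list(gene_list)           # explicit list wins, n_genes is unused
    return list(first_dataset_hvgs)      # HVGs of the first dataset


print(resolve_images(['img_a', None, 'img_c']))
print(resolve_genes(['A', 'B', 'C'], ['A', 'C'], gene_list=['B']))
print(resolve_genes(['A', 'B', 'C'], ['A', 'C'], gene_list=['B'], preprocess=False))
```

In practice this means a user who wants a fixed gene panel has to leave `preprocess=True`, and that the first element of `adata_list` determines the selected genes when no panel is given.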