Source code for satay.da_analysis

from pydeseq2.ds import DeseqStats
from pydeseq2.default_inference import DefaultInference
from pydeseq2.dds import DeseqDataSet
import pandas as pd
import pyranges as pr
from pathlib import Path
from datetime import datetime
import logging
import os
import sys


def run_deseq(count_df, sample_data, condition_col, baseline, a=0.01, n_cpus=None):
    """Run a pydeseq2 DESeq2 analysis, contrasting every condition to a baseline.

    Args:
        count_df: DataFrame with count data (samples as rows, genes as columns).
        sample_data: DataFrame with sample metadata, one row per sample. It is
            copied internally and never modified in place.
        condition_col: Column name for condition/treatment in ``sample_data``.
        baseline: Baseline (reference) condition. It is compared as a string
            because the condition column is cast to ``str`` before fitting.
        a: Alpha value for significance testing, passed to ``DeseqStats``.
        n_cpus: Number of CPUs for pydeseq2 inference; defaults to
            ``min(8, os.cpu_count() or 1)``.

    Returns:
        tuple: ``(vst_counts, res_df)``. ``vst_counts`` is the VST-transformed
        count matrix transposed to genes x samples. ``res_df`` concatenates
        the per-contrast results, each tagged in a ``contrast`` column as
        ``"<condition>_vs_<baseline>"``.

    Raises:
        ValueError: If ``baseline`` is not a level of ``condition_col``, or if
            it is the only level, leaving nothing to compare against.
    """
    logging.info(f"Running DESeq2 analysis with baseline: {baseline}")

    # Copy so the str cast below does not leak into the caller's DataFrame.
    sample_data = sample_data.copy()
    sample_data[condition_col] = sample_data[condition_col].astype(str)
    # Match the cast above, so e.g. baseline=0 finds the level "0".
    baseline = str(baseline)

    # Validate inputs before any expensive model fitting.
    levels = list(sample_data[condition_col].unique())
    if baseline not in levels:
        raise ValueError(
            f"Baseline '{baseline}' not found in column '{condition_col}', available: {sample_data[condition_col].unique()}")
    comps = [level for level in levels if level != baseline]
    if not comps:
        # Without this check pd.concat([]) fails only after the full fit.
        raise ValueError(
            f"Only the baseline '{baseline}' is present in column "
            f"'{condition_col}'; at least one other condition is needed")

    if not count_df.index.equals(sample_data.index):
        same_samples = (
            len(count_df.index) == len(sample_data.index)
            and sample_data.index.is_unique
            and set(count_df.index) == set(sample_data.index)
        )
        if same_samples:
            # Same samples in a different order: align the metadata rows to
            # the count rows so each sample is paired with its own metadata.
            logging.info("Reordering sample data to match count data")
            sample_data = sample_data.loc[count_df.index]
        else:
            logging.warning(
                "Count data and sample data indices don't match exactly")

    if n_cpus is None:
        n_cpus = min(8, os.cpu_count() or 1)
    inference = DefaultInference(n_cpus=n_cpus)

    dds = DeseqDataSet(
        counts=count_df,
        metadata=sample_data,
        design_factors=condition_col,
        refit_cooks=True,
        inference=inference,
    )
    dds.deseq2()
    dds.vst()

    # The VST layer is samples x genes (same layout as count_df); transpose
    # to genes x samples for output.
    vst_counts = pd.DataFrame(
        dds.layers["vst_counts"],
        index=count_df.index,
        columns=count_df.columns,
    ).T

    logging.info(
        f"Performing {len(comps)} comparisons against baseline '{baseline}'")
    res_l = []
    for comp in comps:
        res = DeseqStats(
            dds,
            contrast=[condition_col, comp, baseline],
            inference=inference,
            alpha=a,
        )
        res.summary()
        res_l.append(res.results_df.assign(contrast=f"{comp}_vs_{baseline}"))

    res_df = pd.concat(res_l)
    return vst_counts, res_df
def load_filter_deseq(counts_file, sample_data_file, filter=100, comp_col='conc',
                      baseline="0nMaF", a=0.01, sample_id_col='sample_id',
                      n_cpus=None):
    """Load count data and sample metadata, filter low-count genes, run DESeq2.

    Args:
        counts_file: CSV whose first column is the gene ID (used as the index)
            and whose remaining columns are samples.
        sample_data_file: CSV of sample metadata with one row per sample.
        filter: Minimum total count across retained samples for a gene to be
            kept.
        comp_col: Metadata column holding the condition to compare.
        baseline: Reference condition passed to ``run_deseq``.
        a: Alpha value for significance testing.
        sample_id_col: Metadata column whose values match the count-file
            sample column headers.
        n_cpus: Number of CPUs for pydeseq2 inference (see ``run_deseq``).

    Returns:
        tuple: ``(vst_counts, res_df)`` as returned by ``run_deseq``.

    Raises:
        FileNotFoundError: If either input file is missing.
        ValueError: If a required column is missing, sample IDs are
            duplicated, no samples are shared between the files, or no genes
            survive filtering.
    """
    try:
        logging.info(f"Loading sample data from {sample_data_file}")
        # Read sample IDs as strings. Count-file column headers are always
        # strings, so numeric IDs (e.g. 1 or 001) parsed as numbers would
        # never match them.
        sample_data = pd.read_csv(sample_data_file,
                                  dtype={sample_id_col: str})
        logging.info(f"Loading count data from {counts_file}")
        df = pd.read_csv(counts_file, index_col=0)

        # Validate required columns exist
        if sample_id_col not in sample_data.columns:
            raise ValueError(
                f"Sample ID column '{sample_id_col}' not found in sample data")
        if comp_col not in sample_data.columns:
            raise ValueError(
                f"Comparison column '{comp_col}' not found in sample data")

        # Report samples present in only one of the two files
        metadata_samples = set(sample_data[sample_id_col].unique())
        count_samples = set(df.columns)
        missing_from_counts = metadata_samples - count_samples
        missing_from_metadata = count_samples - metadata_samples
        if missing_from_counts:
            logging.warning(
                f"Samples in metadata but not in count data: {sorted(missing_from_counts)}")
        if missing_from_metadata:
            logging.warning(
                f"Samples in count data but not in metadata: {sorted(missing_from_metadata)}")

        # Use only samples present in both files
        common_samples = metadata_samples & count_samples
        if not common_samples:
            raise ValueError(
                "No common samples found between count data and metadata")
        logging.info(
            f"Using {len(common_samples)} samples present in both files")

        sample_data = sample_data[sample_data[sample_id_col].isin(common_samples)]
        # Repeated IDs would give the metadata more rows than the counts and
        # misalign samples downstream.
        duplicated = sample_data[sample_id_col].duplicated(keep=False)
        if duplicated.any():
            dupes = sorted(sample_data.loc[duplicated, sample_id_col].unique())
            raise ValueError(f"Duplicate sample IDs in sample data: {dupes}")

        # Samples as rows in both tables, in the same (sorted) order.
        df = df[sorted(common_samples)].T.sort_index()
        sample_data = sample_data.set_index(sample_id_col).sort_index()

        # Filter low-count genes
        initial_gene_count = df.shape[1]
        df = df.loc[:, df.sum(axis=0) >= filter]
        logging.info(
            f"Filtered genes: {initial_gene_count} -> {df.shape[1]} (removed {initial_gene_count - df.shape[1]} low-count genes)")
        if df.empty:
            raise ValueError("No genes remain after filtering")

        vst_counts, res_df = run_deseq(
            df, sample_data, comp_col, baseline, a=a, n_cpus=n_cpus)
        return vst_counts, res_df

    except FileNotFoundError as e:
        logging.error(f"File not found: {e}")
        raise
    except Exception as e:
        logging.error(f"Error in load_filter_deseq: {e}")
        raise
def merge_annotations(res_df, gff_file, ids_to_keep=("locus_tag", "gene", "product")):
    """Merge DESeq2 results with gene annotations from a GFF3 file.

    This is best-effort: any failure is logged and ``res_df`` is returned
    unchanged.

    Args:
        res_df: DESeq2 results. Gene IDs are read from an ``ID`` column if one
            exists, otherwise from the index.
        gff_file: Path to a GFF3 file, read with ``pyranges.read_gff3``.
        ids_to_keep: GFF attribute columns to attach. ``locus_tag`` is always
            included because it is the join key.

    Returns:
        DataFrame: ``res_df`` left-joined to one annotation row per locus tag,
        with gene IDs in an ``ID`` column. Returns ``res_df`` unchanged if the
        GFF has no ``locus_tag`` attribute or if anything fails.
    """
    try:
        logging.info(f"Loading annotations from {gff_file}")
        gff = pr.read_gff3(gff_file).as_df()

        # Drop requested columns that aren't actually present in the GFF
        missing_cols = set(ids_to_keep) - set(gff.columns)
        if missing_cols:
            logging.warning(f"Missing annotation columns: {missing_cols}")
        ids_to_keep = [col for col in ids_to_keep if col in gff.columns]

        # 'locus_tag' is the key used to join annotations to results, so it
        # must be present in the GFF and always selected.
        if 'locus_tag' not in gff.columns:
            logging.warning(
                "GFF has no 'locus_tag' field to merge on; "
                "returning results without annotations")
            return res_df
        if 'locus_tag' not in ids_to_keep:
            ids_to_keep = ['locus_tag'] + ids_to_keep

        # A locus tag can appear on several features (e.g. a gene and its CDS)
        # whose attributes differ (one may lack 'product'). drop_duplicates()
        # would keep several rows per tag, and the left merge would then
        # duplicate result rows. Collapse to one row per tag instead, taking
        # the first non-null value of each attribute. Features without a
        # locus tag are dropped by groupby.
        annotations = (
            gff[ids_to_keep]
            .groupby('locus_tag', sort=False)
            .first()
            .reset_index()
        )

        # Join on an explicit 'ID' column. pydeseq2 results carry gene IDs in
        # the index, so move them into a column when needed.
        if 'ID' in res_df.columns:
            left = res_df
        else:
            left = res_df.rename_axis('ID').reset_index()

        merged = left.merge(
            annotations, left_on='ID', right_on='locus_tag', how='left')
        logging.info(
            f"Merged annotations: {len(res_df)} -> {len(merged)} rows")
        return merged
    except Exception as e:
        logging.error(f"Error merging annotations: {e}")
        return res_df
def gene_da(counts_file, sample_data_file, output_dir, filter, comp_col, baseline, a,
            gff_file='', ids_to_keep=("locus_tag", "gene", "product"),
            sample_id_col='sample_id', n_cpus=None):
    """Run the complete differential analysis pipeline and write CSV outputs.

    Loads and filters the data, runs DESeq2 via ``load_filter_deseq``, and
    optionally adds GFF annotations via ``merge_annotations``. It then writes
    two files to ``output_dir``, prefixed with today's date (``YY-MM-DD``):
    ``<date>_<counts stem>_<comp_col>_l0a<a>.csv`` with the results, and
    ``<date>_<counts stem>_vstcounts.csv`` with the VST counts.

    Args:
        counts_file: Path to the count CSV.
        sample_data_file: Path to the sample metadata CSV.
        output_dir: Output directory; created if it does not exist.
        filter: Minimum total count for a gene to be kept.
        comp_col: Metadata column holding the condition to compare.
        baseline: Reference condition.
        a: Alpha value for significance testing.
        gff_file: Optional GFF3 annotation file; skipped when empty.
        ids_to_keep: GFF attribute columns to attach to the results.
        sample_id_col: Metadata column matching the count-file sample headers.
        n_cpus: Number of CPUs for pydeseq2 inference.

    On any error the failure is logged with its traceback and the process
    exits with status 1.
    """
    # Set up logging
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s')

    try:
        # Convert string paths to Path objects
        counts_file = Path(counts_file)
        sample_data_file = Path(sample_data_file)
        output_dir = Path(output_dir)

        # Ensure output directory exists
        output_dir.mkdir(parents=True, exist_ok=True)

        logging.info("Starting differential analysis pipeline")
        today_str = datetime.today().strftime("%y-%m-%d")

        vst_counts, res_df = load_filter_deseq(
            counts_file, sample_data_file, filter, comp_col, baseline, a,
            sample_id_col=sample_id_col, n_cpus=n_cpus)
        logging.info(f"DESeq2 results shape: {res_df.shape}")

        # pydeseq2 results keep gene IDs in the index. Move them into an 'ID'
        # column so they survive the index=False export below and give
        # merge_annotations a column to join on.
        if 'ID' not in res_df.columns:
            res_df = res_df.rename_axis('ID').reset_index()

        if gff_file:
            res_df = merge_annotations(res_df, gff_file, ids_to_keep)
            logging.info(f"Results with annotations shape: {res_df.shape}")

        # Save results
        results_file = output_dir / \
            f"{today_str}_{counts_file.stem}_{comp_col}_l0a{a}.csv"
        vst_file = output_dir / f"{today_str}_{counts_file.stem}_vstcounts.csv"

        res_df.to_csv(results_file, index=False)
        vst_counts.to_csv(vst_file)

        logging.info(f"Results saved to: {results_file}")
        logging.info(f"VST counts saved to: {vst_file}")
        logging.info("Differential analysis complete")

    except Exception as e:
        # logging.exception keeps the traceback, which plain error() dropped.
        logging.exception(f"Pipeline failed: {e}")
        sys.exit(1)