Creating a custom GATK Walker (GATK 3.6) : my notebook
This is my notebook for creating a custom walker (engine) in GATK.
Description
I want to read a VCF file and to get a table of category/count. Something like this:
HAVE_ID | TYPE | COUNT |
---|---|---|
YES | SNP | 123 |
NO | SNP | 3 |
NO | INDEL | 13 |
Class Category
I create a class Category describing each row in the table. It's just a List of Strings
static class Category
implements Comparable<Category>
{
private final List<String> labels ;
Category(final List<String> labels) {
this.labels=new ArrayList<>(labels);
}
As we're going to use Category as a key in an associative array / Map, we need to implement hashCode and equals:
Implementing hashCode and equals
static class Category
implements Comparable<Category>
{
private final List<String> labels ;
Category(final List<String> labels) {
this.labels=new ArrayList<>(labels);
}
@Override
public int hashCode() {
return labels.hashCode();
}
@Override
public int compareTo(final Category o) {
for(int i=0;i< labels.size();++i)
{
int d = labels.get(i).compareTo(o.labels.get(i));
if(d!=0) return d;
}
return 0;
}
@Override
public boolean equals(Object o) {
if(o==this) return true;
if(o==null || !(o instanceof Category)) return false;
return compareTo(Category.class.cast(o))==0;
}
}
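To see why hashCode and equals matter for a map key, here is a minimal, GATK-free sketch. The classes NaiveKey and LabelKey are hypothetical stand-ins for Category (they are not part of the walker), and LabelKey compares with List.equals rather than compareTo, which should behave the same way for label lists of equal length.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class CategoryKeyDemo {
    /** Hypothetical key WITHOUT equals/hashCode: identical labels are still distinct keys. */
    static final class NaiveKey {
        final List<String> labels;
        NaiveKey(final List<String> labels) { this.labels = new ArrayList<>(labels); }
    }

    /** Hypothetical key WITH equals/hashCode delegating to the label list. */
    static final class LabelKey {
        final List<String> labels;
        LabelKey(final List<String> labels) { this.labels = new ArrayList<>(labels); }
        @Override public int hashCode() { return labels.hashCode(); }
        @Override public boolean equals(final Object o) {
            return o == this || (o instanceof LabelKey && labels.equals(((LabelKey) o).labels));
        }
    }

    /** Counts the labels (Y, SNP) twice with naive keys and returns the number of map entries. */
    static int countWithNaiveKeys() {
        final Map<NaiveKey, Long> count = new HashMap<>();
        for (int i = 0; i < 2; i++) {
            count.merge(new NaiveKey(Arrays.asList("Y", "SNP")), 1L, Long::sum);
        }
        return count.size();
    }

    /** Counts the labels (Y, SNP) twice with keys that define equals/hashCode. */
    static Map<LabelKey, Long> countWithLabelKeys() {
        final Map<LabelKey, Long> count = new HashMap<>();
        for (int i = 0; i < 2; i++) {
            count.merge(new LabelKey(Arrays.asList("Y", "SNP")), 1L, Long::sum);
        }
        return count;
    }

    public static void main(final String[] args) {
        System.out.println("entries with naive keys: " + countWithNaiveKeys());
        System.out.println("entries with label keys: " + countWithLabelKeys().size());
    }
}
```

Without the two overrides, every new key object lands in its own map entry and nothing is ever counted together.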
# The Main Walker:
The main walker is called CountPredictions; it extends RodWalker. Why? I don't really know, I copied this from another walker. The two type parameters of RodWalker are the type returned by map() and the type returned by reduce(). We're going to map/reduce some 'Map' objects, so my class is declared as:
public class CountPredictions
extends RodWalker< Map<CountPredictions.Category,Long>, Map<CountPredictions.Category,Long>>
{
...
}
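As I understand the map/reduce contract, the engine roughly does "sum = reduce(map(site), sum)" for each site, starting from reduceInit(). The sketch below mimics that flow in plain Java. MiniWalker and traverse are hypothetical names invented for this illustration; they are not GATK classes, and the real engine does much more (sharding, parallel traversal, and so on).

```java
import java.util.Arrays;
import java.util.List;

class WalkerLifecycleSketch {
    /** Hypothetical stand-in for a walker: M is the per-site map type, R the reduce type. */
    interface MiniWalker<I, M, R> {
        R reduceInit();
        M map(I site);
        R reduce(M value, R sum);
    }

    /** Hypothetical driver: folds every site into a single result, from first to last. */
    static <I, M, R> R traverse(final MiniWalker<I, M, R> walker, final List<I> sites) {
        R sum = walker.reduceInit();
        for (final I site : sites) {
            sum = walker.reduce(walker.map(site), sum);
        }
        return sum;
    }

    /** Toy walker: map() emits 1 for a SNP and 0 otherwise, reduce() adds the values up. */
    static long countSnps(final List<String> variantTypes) {
        return traverse(new MiniWalker<String, Integer, Long>() {
            @Override public Long reduceInit() { return 0L; }
            @Override public Integer map(final String type) { return "SNP".equals(type) ? 1 : 0; }
            @Override public Long reduce(final Integer value, final Long sum) { return sum + value; }
        }, variantTypes);
    }

    public static void main(final String[] args) {
        System.out.println(countSnps(Arrays.asList("SNP", "INDEL", "SNP")));
    }
}
```

In CountPredictions both types are the same Map of Category to Long, so map() and reduce() exchange small count tables instead of plain numbers.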
# Documenting the Walker
An annotation, 'DocumentedGATKFeature', is added so the GATK engine will be able to find our walker:
(...)
@DocumentedGATKFeature(
summary="Count Predictions",
groupName = HelpConstants.DOCS_CAT_VARMANIP,
extraDocs = {CommandLineGATK.class} )
public class CountPredictions extends RodWalker
(....)
Describing the input and the output
The input is one VCF file:
@Input(fullName="variant", shortName = "V", doc="Input VCF file", required=true)
public RodBinding<VariantContext> variants;
The output is a PrintStream (default is stdout) where we'll write the table:
@Output(doc="File to which result should be written")
protected PrintStream out = System.out;
The other arguments
The other arguments are also decorated using Java annotations. These are the switches that add columns to the final table:
(...)
@Argument(fullName="chrom",shortName="chrom",required=false,doc="Group by Chromosome/Contig")
public boolean bychrom = false;
@Argument(fullName="ID",shortName="ID",required=false,doc="Group by having/not-having ID")
public boolean byID = false;
@Argument(fullName="variantType",shortName="variantType",required=false,doc="Group by VariantType")
(...)
The initialize() method
This method is called after the arguments have been parsed. As we want to be able to parse the functional annotation (the ANN INFO field, as written by SnpEff), we're going to get the VCF header and extract the components of the ANN attribute to get the indexes for 'Annotation_Impact' and 'Transcript_BioType':
@Override
public void initialize() {
if(byImpact || bybiotype) {
final VCFHeader vcfHeader = GATKVCFUtils.getVCFHeadersFromRods(getToolkit()).get(variants.getName());
final VCFInfoHeaderLine annInfo = vcfHeader.getInfoHeaderLine("ANN");
if(annInfo==null)
{
logger.warn("NO ANN in "+variants.getSource());
}
else
{
int q0=annInfo.getDescription().indexOf('\'');
int q1=annInfo.getDescription().lastIndexOf('\'');
if(q0==-1 || q1<=q0)
{
logger.warn("Cannot parse "+annInfo.getDescription());
}
else
{
final String fields[]=pipeRegex.split(annInfo.getDescription().substring(q0+1, q1));
for(int c=0;c<fields.length;++c)
{
final String column=fields[c].trim();
if(column.equals("Annotation_Impact"))
{
ann_impact_column=c;
}
else if(column.equals("Transcript_BioType"))
{
ann_transcript_biotype_column=c;
}
}
}
}
}
super.initialize();
}
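The header parsing can be tried outside GATK. This is a minimal sketch of the same idea (take the text between the first and last single quote, split on '|', trim, remember each index). The class name, the method columnIndexes and the shortened description string are assumptions made up for the example; a real ANN header lists more fields.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.regex.Pattern;

class AnnHeaderSketch {
    private static final Pattern PIPE = Pattern.compile("[|]");

    /**
     * Returns column name -> index for the quoted, pipe-separated field list
     * found in an INFO description, or an empty map if no quoted list is found.
     */
    static Map<String, Integer> columnIndexes(final String description) {
        final Map<String, Integer> indexes = new LinkedHashMap<>();
        final int q0 = description.indexOf('\'');
        final int q1 = description.lastIndexOf('\'');
        if (q0 == -1 || q1 <= q0) return indexes; // nothing to parse
        final String[] fields = PIPE.split(description.substring(q0 + 1, q1));
        for (int c = 0; c < fields.length; ++c) {
            indexes.put(fields[c].trim(), c);
        }
        return indexes;
    }

    public static void main(final String[] args) {
        // Illustrative, shortened description (assumption, not copied from a real file).
        final String description =
            "Functional annotations: 'Allele | Annotation | Annotation_Impact | Gene_Name | Gene_ID"
            + " | Feature_Type | Feature_ID | Transcript_BioType'";
        final Map<String, Integer> indexes = columnIndexes(description);
        System.out.println("Annotation_Impact  -> " + indexes.get("Annotation_Impact"));
        System.out.println("Transcript_BioType -> " + indexes.get("Transcript_BioType"));
    }
}
```

The same indexes are later used to pick the right token out of each pipe-separated ANN value of a variant.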
reduceInit()
As far as I understand, 'reduceInit()' is used to create the very first item of the map/reduce process, so we return an empty associative map:
@Override
public Map<CountPredictions.Category,Long> reduceInit() {
return Collections.emptyMap();
}
the reduce method
This method combines the result of one map() call ('value') with the running total ('sum'), so we're creating a new map combining both counts:
@Override
public Map<CountPredictions.Category,Long> reduce(Map<CountPredictions.Category,Long> value, Map<CountPredictions.Category,Long> sum) {
final Map<CountPredictions.Category,Long> newmap = new HashMap<>(sum);
for(Category cat:value.keySet()) {
Long sv = sum.get(cat);
Long vv = value.get(cat);
newmap.put(cat, sv==null?vv:sv+vv);
}
return newmap;
}
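The same merge can be written a bit more compactly with Map.merge (Java 8). This is only a sketch of an alternative, not what the walker uses; the class and method names are hypothetical, and String keys stand in for Category.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

class MergeCountsSketch {
    /** Returns a new map holding, for every key, the sum of its counts in value and sum. */
    static <K> Map<K, Long> mergeCounts(final Map<K, Long> value, final Map<K, Long> sum) {
        final Map<K, Long> merged = new HashMap<>(sum); // copy, so an immutable 'sum' is fine
        for (final Map.Entry<K, Long> e : value.entrySet()) {
            merged.merge(e.getKey(), e.getValue(), Long::sum);
        }
        return merged;
    }

    public static void main(final String[] args) {
        final Map<String, Long> site = new HashMap<>();
        site.put("SNP", 1L);
        // Start from an empty, immutable map, as reduceInit() does.
        Map<String, Long> total = Collections.emptyMap();
        total = mergeCounts(site, total);
        total = mergeCounts(site, total);
        System.out.println(total);
    }
}
```

Copying the running total on every call costs time proportional to its size. That is harmless for a handful of categories, but a mutable accumulator might be worth considering if the number of categories grows.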
the map method
This is the workhorse of the walker. As far as I can see, it is called for each variant. The method 'tracker.getValues()' returns a list of all the variants at the current position (here, usually only one because we only have one input). At the end of the method we must return a count per Category for those variants:
@Override
public Map<CountPredictions.Category,Long> map(final RefMetaDataTracker tracker,final ReferenceContext ref, final AlignmentContext context) {
if ( tracker == null )return Collections.emptyMap();
final Map<CountPredictions.Category,Long> count = new HashMap<>();
for(final VariantContext ctx: tracker.getValues(this.variants,context.getLocation()))
{
...
}
return count;
}
Filling the Category
For each variant, a new Category is filled:
(...)
final List<String> labels=new ArrayList<>();
if(bychrom) labels.add(ctx.getContig());
if(byID) labels.add(ctx.hasID()?"Y":".");
(...)
final Category cat=new Category(labels);
Long n=count.get(cat);
count.put(cat, n==null?1L:n+1);
(...)
At the end, onTraversalDone, printing the table
When onTraversalDone is called, the table is built with GATKReportTable and printed using the class GATKReport:
@Override
public void onTraversalDone(final Map<CountPredictions.Category,Long> counts) {
GATKReportTable table=new GATKReportTable(
"Variants", "Variants "+variants.getSource(),0);
if(bychrom) table.addColumn("CONTIG");
if(byID) table.addColumn("IN_DBSNP");
if(byType) table.addColumn("TYPE");
(...)
table.addColumn("COUNT");
int nRows=0;
for(final Category cat: counts.keySet())
{
for(int x=0;x<cat.labels.size();++x)
{
table.set(nRows, x, cat.labels.get(x));
}
table.set(nRows, cat.labels.size(), counts.get(cat));
++nRows;
}
GATKReport report = new GATKReport();
report.addTable(table);
report.print(this.out);
out.flush();
}
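One thing to keep in mind: the loop above iterates over a HashMap, whose iteration order is unspecified, so the rows can come out in any order. Since Category implements Comparable, one option would be to sort the keys first. Here is a small GATK-free sketch of that idea, with plain label lists standing in for Category and tab-separated lines standing in for the report; all the names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

class SortedRowsSketch {
    /** Compares label lists element by element, then by length (hypothetical helper). */
    static int compareLabels(final List<String> a, final List<String> b) {
        final int n = Math.min(a.size(), b.size());
        for (int i = 0; i < n; ++i) {
            final int d = a.get(i).compareTo(b.get(i));
            if (d != 0) return d;
        }
        return Integer.compare(a.size(), b.size());
    }

    /** Renders one tab-separated line per category, sorted by labels, with the count last. */
    static List<String> renderRows(final Map<List<String>, Long> counts) {
        final Comparator<List<String>> byLabels = SortedRowsSketch::compareLabels;
        final Map<List<String>, Long> sorted = new TreeMap<>(byLabels);
        sorted.putAll(counts);
        final List<String> rows = new ArrayList<>();
        for (final Map.Entry<List<String>, Long> e : sorted.entrySet()) {
            rows.add(String.join("\t", e.getKey()) + "\t" + e.getValue());
        }
        return rows;
    }

    public static void main(final String[] args) {
        final Map<List<String>, Long> counts = new HashMap<>();
        counts.put(Arrays.asList("Y", "SNP"), 123L);
        counts.put(Arrays.asList("N", "SNP"), 3L);
        counts.put(Arrays.asList("N", "INDEL"), 13L);
        for (final String row : renderRows(counts)) {
            System.out.println(row);
        }
    }
}
```

In the walker itself, the equivalent would presumably be to iterate over a TreeMap built from counts, but I haven't tested that.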
All in one:
The full source code:
package mygatk;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import org.broadinstitute.gatk.engine.CommandLineGATK;
import org.broadinstitute.gatk.engine.GATKVCFUtils;
import org.broadinstitute.gatk.engine.walkers.RodWalker;
import org.broadinstitute.gatk.utils.commandline.Argument;
import org.broadinstitute.gatk.utils.commandline.Input;
import org.broadinstitute.gatk.utils.commandline.Output;
import org.broadinstitute.gatk.utils.commandline.RodBinding;
import org.broadinstitute.gatk.utils.contexts.AlignmentContext;
import org.broadinstitute.gatk.utils.contexts.ReferenceContext;
import org.broadinstitute.gatk.utils.help.DocumentedGATKFeature;
import org.broadinstitute.gatk.utils.help.HelpConstants;
import org.broadinstitute.gatk.utils.refdata.RefMetaDataTracker;
import org.broadinstitute.gatk.utils.report.GATKReport;
import org.broadinstitute.gatk.utils.report.GATKReportTable;
import htsjdk.variant.variantcontext.Genotype;
import htsjdk.variant.variantcontext.VariantContext;
import htsjdk.variant.vcf.VCFHeader;
import htsjdk.variant.vcf.VCFInfoHeaderLine;
/**
* Test Documentation
*/
@DocumentedGATKFeature(
summary="Count Predictions",
groupName = HelpConstants.DOCS_CAT_VARMANIP,
extraDocs = {CommandLineGATK.class} )
public class CountPredictions extends RodWalker<Map<CountPredictions.Category,Long>, Map<CountPredictions.Category,Long>> {
/* SnpEff impact levels, declared from most to least severe */
private enum IMPACT { HIGH, MODERATE, LOW, MODIFIER };
@Input(fullName="variant", shortName = "V", doc="Input VCF file", required=true)
public RodBinding<VariantContext> variants;
@Output(doc="File to which result should be written")
protected PrintStream out = System.out;
@Argument(fullName="minQuality",shortName="mq",required=false,doc="Group by Quality. Set the threshold for Minimum Quality")
public double minQuality = -1;
@Argument(fullName="chrom",shortName="chrom",required=false,doc="Group by Chromosome/Contig")
public boolean bychrom = false;
@Argument(fullName="ID",shortName="ID",required=false,doc="Group by ID")
public boolean byID = false;
@Argument(fullName="variantType",shortName="variantType",required=false,doc="Group by VariantType")
public boolean byType = false;
@Argument(fullName="filter",shortName="filter",required=false,doc="Group by FILTER")
public boolean byFilter = false;
@Argument(fullName="impact",shortName="impact",required=false,doc="Group by ANN/IMPACT")
public boolean byImpact = false;
@Argument(fullName="biotype",shortName="biotype",required=false,doc="Group by ANN/biotype")
public boolean bybiotype = false;
@Argument(fullName="nalts",shortName="nalts",required=false,doc="Group by number of ALTS")
public boolean bynalts = false;
@Argument(fullName="affected",shortName="affected",required=false,doc="Group by number of Samples called and not HOMREF")
public boolean byAffected = false;
@Argument(fullName="called",shortName="called",required=false,doc="Group by number of Samples called")
public boolean byCalled = false;
@Argument(fullName="maxSamples",shortName="maxSamples",required=false,doc="if the number of samples is greater than or equal to --maxSamples use the label \"GE_<maxSamples>\"")
public int maxSamples = Integer.MAX_VALUE;
private final Pattern pipeRegex=Pattern.compile("[\\|]");
static class Category
implements Comparable<Category>
{
private final List<String> labels ;
Category(final List<String> labels) {
this.labels=new ArrayList<>(labels);
}
@Override
public int hashCode() {
return labels.hashCode();
}
@Override
public int compareTo(final Category o) {
for(int i=0;i< labels.size();++i)
{
int d = labels.get(i).compareTo(o.labels.get(i));
if(d!=0) return d;
}
return 0;
}
@Override
public boolean equals(Object o) {
if(o==this) return true;
if(o==null || !(o instanceof Category)) return false;
return compareTo(Category.class.cast(o))==0;
}
}
private int ann_impact_column=-1;/* eg: MODIFIER / LOW */
private int ann_transcript_biotype_column=-1;/* eg: transcript /intergenic_region */
@Override
public void initialize() {
if(byImpact || bybiotype) {
final VCFHeader vcfHeader = GATKVCFUtils.getVCFHeadersFromRods(getToolkit()).get(variants.getName());
final VCFInfoHeaderLine annInfo = vcfHeader.getInfoHeaderLine("ANN");
if(annInfo==null)
{
logger.warn("NO ANN in "+variants.getSource());
}
else
{
int q0=annInfo.getDescription().indexOf('\'');
int q1=annInfo.getDescription().lastIndexOf('\'');
if(q0==-1 || q1<=q0)
{
logger.warn("Cannot parse "+annInfo.getDescription());
}
else
{
final String fields[]=pipeRegex.split(annInfo.getDescription().substring(q0+1, q1));
for(int c=0;c<fields.length;++c)
{
final String column=fields[c].trim();
if(column.equals("Annotation_Impact"))
{
ann_impact_column=c;
}
else if(column.equals("Transcript_BioType"))
{
ann_transcript_biotype_column=c;
}
}
}
}
}
super.initialize();
}
@Override
public Map<CountPredictions.Category,Long> map(final RefMetaDataTracker tracker,final ReferenceContext ref, final AlignmentContext context) {
if ( tracker == null )return Collections.emptyMap();
final Map<CountPredictions.Category,Long> count = new HashMap<>();
for(final VariantContext ctx: tracker.getValues(this.variants,context.getLocation()))
{
final List<String> labels=new ArrayList<>();
if(bychrom) labels.add(ctx.getContig());
if(byID) labels.add(ctx.hasID()?"Y":".");
if(byType) labels.add(ctx.getType().name());
if(byFilter) labels.add(ctx.isFiltered()?"F":".");
if(minQuality>=0) {
labels.add(ctx.hasLog10PError() && ctx.getPhredScaledQual()>=this.minQuality ?
".":"LOWQUAL"
);
}
if(byImpact || bybiotype)
{
String biotype=null;
IMPACT impact=null;
final List<Object> anns = ctx.getAttributeAsList("ANN");
for(final Object anno:anns) {
final String tokens[]=this.pipeRegex.split(anno.toString());
if(this.ann_impact_column==-1 ||
this.ann_impact_column >= tokens.length ||
tokens[this.ann_impact_column].isEmpty()
) continue;
IMPACT currImpact = IMPACT.valueOf(tokens[this.ann_impact_column]);
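/* keep the most severe impact seen so far; assumes that is the intent and that IMPACT is declared from most to least severe */
if(impact!=null && currImpact.compareTo(impact)>0) continue;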
impact=currImpact;
biotype=null;
if(this.ann_transcript_biotype_column==-1 ||
this.ann_transcript_biotype_column >= tokens.length ||
tokens[this.ann_transcript_biotype_column].isEmpty()
) continue;
biotype=tokens[ann_transcript_biotype_column];
}
if(byImpact) labels.add(impact==null?".":impact.name());
if(bybiotype) labels.add(biotype==null?".":biotype);
}
if(bynalts)labels.add(String.valueOf(ctx.getAlternateAlleles().size()));
if(byAffected || byCalled)
{
int nc=0;
int ng=0;
for(int i=0;i< ctx.getNSamples();++i)
{
final Genotype g= ctx.getGenotype(i);
if(!(g.isNoCall() || g.isHomRef()))
{
ng++;
}
if(g.isCalled())
{
nc++;
}
}
if(byCalled) labels.add(nc< maxSamples?String.valueOf(nc):"GE_"+maxSamples);
if(byAffected) labels.add(ng< maxSamples?String.valueOf(ng):"GE_"+maxSamples);
}
final Category cat=new Category(labels);
Long n=count.get(cat);
count.put(cat, n==null?1L:n+1);
}
return count;
}
@Override
public Map<CountPredictions.Category,Long> reduce(Map<CountPredictions.Category,Long> value, Map<CountPredictions.Category,Long> sum) {
final Map<CountPredictions.Category,Long> newmap = new HashMap<>(sum);
for(Category cat:value.keySet()) {
Long sv = sum.get(cat);
Long vv = value.get(cat);
newmap.put(cat, sv==null?vv:sv+vv);
}
return newmap;
}
@Override
public Map<CountPredictions.Category,Long> reduceInit() {
return Collections.emptyMap();
}
@Override
public void onTraversalDone(final Map<CountPredictions.Category,Long> counts) {
GATKReportTable table=new GATKReportTable(
"Variants", "Variants "+variants.getSource(),0);
if(bychrom) table.addColumn("CONTIG");
if(byID) table.addColumn("IN_DBSNP");
if(byType) table.addColumn("TYPE");
if(byFilter) table.addColumn("FILTER");
if(minQuality>=0) table.addColumn("QUAL_GE_"+this.minQuality);
if(byImpact) table.addColumn("IMPACT");
if(bybiotype) table.addColumn("BIOTYPE");
if(bynalts) table.addColumn("N_ALT_ALLELES");
if(byCalled) table.addColumn("CALLED_SAMPLES");
if(byAffected) table.addColumn("AFFECTED_SAMPLES");
table.addColumn("COUNT");
int nRows=0;
for(final Category cat: counts.keySet())
{
for(int x=0;x<cat.labels.size();++x)
{
table.set(nRows, x, cat.labels.get(x));
}
table.set(nRows, cat.labels.size(), counts.get(cat));
++nRows;
}
GATKReport report = new GATKReport();
report.addTable(table);
report.print(this.out);
out.flush();
logger.info("TraversalDone");
}
}
Compiling and testing
gatk.jar=/path/to/GenomeAnalysisTK.jar
VCF=input.vcf.gz

test: dist/mygatk.jar
	java -cp ${gatk.jar}:dist/mygatk.jar \
		org.broadinstitute.gatk.engine.CommandLineGATK \
		-R /path/to/ref.fasta \
		-T CountPredictions -V ${VCF} -ID

dist/mygatk.jar : ${gatk.jar} ./mygatk/CountPredictions.java
	mkdir -p tmp dist
	javac -d tmp -cp .:${gatk.jar} ./mygatk/CountPredictions.java
	jar cfv $@ -C tmp .
	rm -rf tmp
Output:
INFO 16:47:16,545 HelpFormatter - ----------------------------------------------------------------------------------
INFO 16:47:16,547 HelpFormatter - The Genome Analysis Toolkit (GATK) v3.6-0-g89b7209, Compiled 2016/06/01 22:27:29
INFO 16:47:16,547 HelpFormatter - Copyright (c) 2010-2016 The Broad Institute
INFO 16:47:16,547 HelpFormatter - For support and documentation go to https://www.broadinstitute.org/gatk
INFO 16:47:16,714 HelpFormatter - [Fri Jan 13 16:47:16 CET 2017] Executing on Linux 3.10.0-327.36.3.el7.x86_64 amd64
INFO 16:47:16,714 HelpFormatter - Java HotSpot(TM) 64-Bit Server VM 1.8.0_102-b14 JdkDeflater
INFO 16:47:16,717 HelpFormatter - Program Args: -R ref.fasta -T CountPredictions -V input.vcf.gz -ID
INFO 16:47:16,724 HelpFormatter - Date/Time: 2017/01/13 16:47:16
INFO 16:47:16,725 HelpFormatter - ----------------------------------------------------------------------------------
INFO 16:47:16,725 HelpFormatter - ----------------------------------------------------------------------------------
INFO 16:47:16,740 GenomeAnalysisEngine - Strictness is SILENT
INFO 16:47:16,829 GenomeAnalysisEngine - Downsampling Settings: Method: BY_SAMPLE, Target Coverage: 1000
WARN 16:47:16,881 IndexDictionaryUtils - Track variant doesn't have a sequence dictionary built in, skipping dictionary validation
INFO 16:47:16,936 GenomeAnalysisEngine - Preparing for traversal
INFO 16:47:16,940 GenomeAnalysisEngine - Done preparing for traversal
INFO 16:47:16,940 ProgressMeter - [INITIALIZATION COMPLETE; STARTING PROCESSING]
INFO 16:47:16,941 ProgressMeter - | processed | time | per 1M | | total | remaining
INFO 16:47:16,941 ProgressMeter - Location | sites | elapsed | sites | completed | runtime | runtime
INFO 16:47:46,946 ProgressMeter - chr22:26185095 167649.0 30.0 s 3.0 m 91.0% 32.0 s 2.0 s
#:GATKReport.v1.1:1
#:GATKTable:2:1:%s:%s:;
#:GATKTable:Variants:Variants input.vcf.gz
IN_DBSNP COUNT
. 139984
INFO 16:47:52,501 CountPredictions - TraversalDone
INFO 16:47:52,502 ProgressMeter - done 202931.0 35.0 s 2.9 m 91.1% 38.0 s 3.0 s
INFO 16:47:52,502 ProgressMeter - Total runtime 35.56 secs, 0.59 min, 0.01 hours
------------------------------------------------------------------------------------------
Done. There were 1 WARN messages, the first 1 are repeated below.
WARN 16:47:16,881 IndexDictionaryUtils - Track variant doesn't have a sequence dictionary built in, skipping dictionary validation
------------------------------------------------------------------------------------------