diff --git a/src/main/scala/org/biodatageeks/sequila/flagStat/FlagStat.scala b/src/main/scala/org/biodatageeks/sequila/flagStat/FlagStat.scala
new file mode 100644
index 00000000..e1c916fc
--- /dev/null
+++ b/src/main/scala/org/biodatageeks/sequila/flagStat/FlagStat.scala
@@ -0,0 +1,165 @@
+package org.biodatageeks.sequila.flagStat
+
+import htsjdk.samtools.SAMRecord
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
+import org.biodatageeks.sequila.datasources.BAM.{BAMFileReader, BAMTableReader}
+import org.seqdoop.hadoop_bam.{BAMBDGInputFormat, CRAMBDGInputFormat}
+import org.slf4j.LoggerFactory
+import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.biodatageeks.sequila.datasources.InputDataType
+import org.biodatageeks.sequila.inputformats.BDGAlignInputFormat
+import org.biodatageeks.sequila.pileup.PileupMethods
+import org.biodatageeks.sequila.pileup.conf.Conf
+import org.biodatageeks.sequila.utils.{FileFuncs, TableFuncs}
+
+import scala.collection.mutable.ListBuffer
+
+case class FlagStatRow(
+  RCount: Long,
+  QCFail: Long,
+  DUPES: Long,
+  MAPPED: Long,
+  UNMAPPED: Long,
+  PiSEQ: Long,
+  Read1: Long,
+  Read2: Long,
+  PPaired: Long,
+  WIaMM: Long,
+  Singletons: Long
+)
+
+case class FlagStat(spark: SparkSession) {
+  val Logger = LoggerFactory.getLogger(this.getClass.getCanonicalName)
+
+  def processFile(bamFilePath: String): DataFrame = {
+    val tableReader = new BAMFileReader[BAMBDGInputFormat](spark, bamFilePath, null)
+    val records = tableReader.readFile
+    processDF(processRows(records))
+  }
+
+  def processRows(records: RDD[SAMRecord]): RDD[(String, Long)] = {
+    records.mapPartitions(partition => {
+      var RCount = 0L
+      var QCFail = 0L
+      var DUPES = 0L
+      var UNMAPPED = 0L
+      var MAPPED = 0L
+      var PiSEQ = 0L
+      var Read1 = 0L
+      var Read2 = 0L
+      var PPaired = 0L
+      var WIaMM = 0L
+      var Singletons = 0L
+
+      while (partition.hasNext) {
+        val record = partition.next
+
+        if (record.getReadFailsVendorQualityCheckFlag()) {
+          QCFail += 1
+        }
+        if (record.getDuplicateReadFlag()) {
+          DUPES += 1
+        }
+        if (record.getReadUnmappedFlag()) {
+          UNMAPPED += 1
+        } else {
+          MAPPED += 1
+        }
+        if (record.getReadPairedFlag()) {
+          PiSEQ += 1
+          if (record.getSecondOfPairFlag()) {
+            Read2 += 1
+          } else if (record.getFirstOfPairFlag()) {
+            Read1 += 1
+          }
+          if (record.getProperPairFlag()) {
+            PPaired += 1
+          }
+          if (!record.getReadUnmappedFlag() && !record.getMateUnmappedFlag()) {
+            WIaMM += 1
+          }
+          if (!record.getReadUnmappedFlag() && record.getMateUnmappedFlag()) {
+            Singletons += 1
+          }
+        }
+        RCount += 1
+      }
+
+      Iterator(
+        ("RCount", RCount),
+        ("QCFail", QCFail),
+        ("DUPES", DUPES),
+        ("MAPPED", MAPPED),
+        ("UNMAPPED", UNMAPPED),
+        ("PiSEQ", PiSEQ),
+        ("Read1", Read1),
+        ("Read2", Read2),
+        ("PPaired", PPaired),
+        ("WIaMM", WIaMM),
+        ("Singletons", Singletons)
+      )
+    }).reduceByKey((v1, v2) => v1 + v2)
+  }
+
+  def processDF(rows: RDD[(String, Long)]): DataFrame = {
+    val mapping = rows.collectAsMap()
+    val sequence = new ListBuffer[Long]
+    FlagStat.Schema.fieldNames.foreach(x => {
+      sequence += mapping(x)
+    })
+    val result = Row.fromSeq(sequence)
+    val rdd = spark.sparkContext.parallelize(Seq(result))
+    spark.createDataFrame(rdd, FlagStat.Schema)
+  }
+
+  def handleFlagStat(tableNameOrPath: String, sampleId: String): RDD[(String, Long)] = {
+    if (sampleId != null)
+      Logger.info(s"Calculating flagStat on table: $tableNameOrPath")
+    else
+      Logger.info(s"Calculating flagStat using file: $tableNameOrPath")
+
+    val records = {
+      if (sampleId != null) {
+        val metadata = TableFuncs.getTableMetadata(spark, tableNameOrPath)
+        val tableReader = metadata.provider match {
+          case Some(f) =>
+            if (f == InputDataType.BAMInputDataType)
+              new BAMTableReader[BAMBDGInputFormat](spark, tableNameOrPath, sampleId, "bam", None)
+            else throw new Exception("Only BAM file format is supported.")
+          case None => throw new Exception("Empty provider - only BAM file format is supported.")
+        }
+        tableReader
+          .readFile
+      }
+      else {
+        val fileReader = FileFuncs.getFileExtension(tableNameOrPath) match {
+          case "bam" => new BAMFileReader[BAMBDGInputFormat](spark, tableNameOrPath, None)
+          case _ => throw new Exception("Only BAM file format is supported.")
+        }
+        fileReader
+          .readFile
+      }
+    }
+
+    processRows(records)
+  }
+}
+
+object FlagStat {
+  val Schema = StructType(Array(
+    StructField("RCount", LongType, false),
+    StructField("QCFail", LongType, false),
+    StructField("DUPES", LongType, false),
+    StructField("MAPPED", LongType, false),
+    StructField("UNMAPPED", LongType, false),
+    StructField("PiSEQ", LongType, false),
+    StructField("Read1", LongType, false),
+    StructField("Read2", LongType, false),
+    StructField("PPaired", LongType, false),
+    StructField("WIaMM", LongType, false),
+    StructField("Singletons", LongType, false)
+  ))
+}
\ No newline at end of file
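For orientation, a minimal standalone use of the class above might look like the sketch below; the local session and the BAM path are placeholders, not part of this change. The comments also decode the shorthand column names, which mirror the categories of samtools flagstat.

```scala
import org.apache.spark.sql.SparkSession
import org.biodatageeks.sequila.flagStat.FlagStat

// A minimal sketch, assuming a local session and a placeholder BAM path.
// Column shorthands: RCount = total records, QCFail = vendor QC failed, DUPES = duplicates,
// PiSEQ = paired in sequencing, PPaired = properly paired,
// WIaMM = with itself and mate mapped, Singletons = read mapped but mate unmapped.
val spark = SparkSession.builder().master("local[2]").getOrCreate()
val stats = FlagStat(spark).processFile("/data/sample.bam") // single-row DataFrame with Long columns
stats.show(truncate = false)
```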
diff --git a/src/main/scala/org/biodatageeks/sequila/flagStat/FlagStatStrategy.scala b/src/main/scala/org/biodatageeks/sequila/flagStat/FlagStatStrategy.scala
new file mode 100644
index 00000000..aa1306d2
--- /dev/null
+++ b/src/main/scala/org/biodatageeks/sequila/flagStat/FlagStatStrategy.scala
@@ -0,0 +1,83 @@
+package org.biodatageeks.sequila.flagStat
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Encoders, FlagStatTemplate, PileupTemplate, Row, SparkSession, Strategy}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand
+import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
+import org.biodatageeks.sequila.datasources.BAM.BDGAlignFileReaderWriter
+import org.biodatageeks.sequila.datasources.InputDataType
+import org.biodatageeks.sequila.inputformats.BDGAlignInputFormat
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.biodatageeks.sequila.pileup.Pileup
+import org.biodatageeks.sequila.pileup.conf.Conf
+import org.biodatageeks.sequila.pileup.conf.QualityConstants.{DEFAULT_BIN_SIZE, DEFAULT_MAX_QUAL}
+import org.biodatageeks.sequila.utils.{FileFuncs, InternalParams, TableFuncs}
+import org.seqdoop.hadoop_bam.{BAMBDGInputFormat, CRAMBDGInputFormat}
+
+import scala.collection.mutable.ListBuffer
+import scala.reflect.ClassTag
+
+class FlagStatStrategy(spark: SparkSession) extends Strategy with Serializable {
+
+  override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
+    plan match {
+      // Write commands are not handled by this strategy.
+      case CreateDataSourceTableAsSelectCommand(table, mode, query, outputColumnNames) =>
+        Nil
+      case InsertIntoHadoopFsRelationCommand(outputPath, staticPartitions, ifPartitionNotExists,
+        partitionColumns, bucketSpec, fileFormat, options, query, mode, catalogTable, fileIndex, outputColumnNames) =>
+        Nil
+      case FlagStatTemplate(tableNameOrPath, sampleId, output) =>
+        val inputFormat = {
+          if (sampleId != null)
+            TableFuncs.getTableMetadata(spark, tableNameOrPath).provider
+          else if (FileFuncs.getFileExtension(tableNameOrPath) == "bam") Some(InputDataType.BAMInputDataType)
+          else None
+        }
+        inputFormat match {
+          case Some(f) =>
+            if (f == InputDataType.BAMInputDataType)
+              FlagStatPlan[BAMBDGInputFormat](plan, spark, tableNameOrPath, sampleId, output) :: Nil
+            else Nil
+          case None => throw new RuntimeException("Only BAM file format is supported in the flagStat function.")
+        }
+      case _ => Nil
+    }
+  }
+}
+
+object FlagStatPlan extends Serializable {
+
+}
+
+case class FlagStatPlan[T <: BDGAlignInputFormat](plan: LogicalPlan, spark: SparkSession,
+                                                  tableNameOrPath: String,
+                                                  sampleId: String,
+                                                  output: Seq[Attribute])
+  extends SparkPlan with Serializable with BDGAlignFileReaderWriter[T] {
+
+  override protected def otherCopyArgs: Seq[AnyRef] = Seq()
+
+  override def children: Seq[SparkPlan] = Nil
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    val fs = new FlagStat(spark)
+    val rows = fs.handleFlagStat(tableNameOrPath, sampleId)
+    val mapping = rows.collectAsMap
+    val fields = new ListBuffer[Long]
+    FlagStat.Schema.fieldNames.foreach(x => {
+      fields += mapping(x)
+    })
+    val result = InternalRow.fromSeq(fields)
+    spark.sparkContext.parallelize(Seq(result))
+  }
+}
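A sketch of how this strategy reaches the planner follows; SequilaSession (changed later in this diff) performs the registration itself, together with the table-valued-function analyzer rule, so this is only illustrative.

```scala
// Illustrative only – SequilaSession registers the strategy automatically (see SequilaSession.scala below).
import org.apache.spark.sql.SparkSession
import org.biodatageeks.sequila.flagStat.FlagStatStrategy

val spark = SparkSession.builder().master("local[2]").getOrCreate()
spark.experimental.extraStrategies = spark.experimental.extraStrategies :+ new FlagStatStrategy(spark)
```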
diff --git a/src/main/scala/org/biodatageeks/sequila/flagStat/flagStatDebugger.scala b/src/main/scala/org/biodatageeks/sequila/flagStat/flagStatDebugger.scala
new file mode 100644
index 00000000..b0dd3ba0
--- /dev/null
+++ b/src/main/scala/org/biodatageeks/sequila/flagStat/flagStatDebugger.scala
@@ -0,0 +1,34 @@
+package org.biodatageeks.sequila.flagStat
+
+import org.apache.spark.sql.{DataFrame, SequilaSession, SparkSession}
+import org.apache.spark.storage.StorageLevel
+import org.biodatageeks.sequila.utils.InternalParams
+import org.slf4j.LoggerFactory
+
+object FlagStatDebuggerEntryPoint {
+  val sampleId = "NA12878.proper.wes.md"
+  val bamFilePath: String = s"/Users/mwiewior/research/data/WES/${sampleId}.bam"
+
+  def main(args: Array[String]): Unit = {
+    performance(bamFilePath, null)
+  }
+
+  def performance(tableNameOrPath: String, sampleId: String): Unit = {
+    System.setSecurityManager(null)
+    val spark = SparkSession
+      .builder()
+      .master("local[4]")
+      .config("spark.driver.memory", "8g")
+      .config("spark.biodatageeks.bam.validation", "SILENT")
+      .config("spark.biodatageeks.readAligment.method", "hadoopBAM")
+      .config("spark.biodatageeks.bam.useGKLInflate", "true")
+      .getOrCreate()
+    val ss = SequilaSession(spark)
+    ss.time {
+      ss.flagStat(tableNameOrPath, sampleId).show()
+    }
+    ss.stop()
+  }
+}
\ No newline at end of file
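The analyzer change that follows registers a `flagstat` table-valued function; from SQL it is intended to be used as in the sketch below (the file path is a placeholder, the table and sample names reuse the test fixtures defined later in this diff).

```scala
// Intended SQL surface of the new TVF (placeholder path; table/sample come from the test suite below):
ss.sql("SELECT * FROM flagstat('/data/sample.bam')").show()                   // file path variant
ss.sql("SELECT * FROM flagstat('reads_bam', 'NA12878.multichrom.md')").show() // table + sampleId variant
```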
diff --git a/src/main/scala/org/biodatageeks/sequila/utvf/ResolveTableValuedFunctionsSeq.scala b/src/main/scala/org/biodatageeks/sequila/utvf/ResolveTableValuedFunctionsSeq.scala
index 062f6ca6..c81db035 100644
--- a/src/main/scala/org/biodatageeks/sequila/utvf/ResolveTableValuedFunctionsSeq.scala
+++ b/src/main/scala/org/biodatageeks/sequila/utvf/ResolveTableValuedFunctionsSeq.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql
 
 import java.util.Locale
-
 import org.apache.spark.sql.ResolveTableValuedFunctionsSeq.tvf
 import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, TypeCoercion, UnresolvedTableValuedFunction}
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression}
@@ -28,10 +27,10 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{StructField, _}
+import org.biodatageeks.sequila.flagStat.FlagStat
 import org.biodatageeks.sequila.pileup.conf.QualityConstants
 import org.biodatageeks.sequila.utils.Columns
-
 /**
  * Rule that resolves table-valued function references.
  */
@@ -109,6 +108,19 @@ object ResolveTableValuedFunctionsSeq extends Rule[LogicalPlan] {
       }
     ),
+    "flagstat" -> Map(
+      /* flagstat(tableName) */
+      tvf("tableName" -> StringType)
+      { case Seq(tableName: Any) =>
+        FlagStatTemplate(tableName.toString, null)
+      },
+      /* flagstat(tableNameOrPath, sampleId) */
+      tvf("tableNameOrPath" -> StringType, "sampleId" -> StringType)
+      { case Seq(tableNameOrPath: Any, sampleId: Any) =>
+        FlagStatTemplate(tableNameOrPath.toString, sampleId.toString)
+      }
+    ),
+
     "coverage" -> Map(
       /* coverage(tableName) */
       tvf("table" -> StringType, "sampleId" -> StringType, "refPath" -> StringType)
@@ -170,6 +182,35 @@ object ResolveTableValuedFunctionsSeq extends Rule[LogicalPlan] {
   }
 }
 
+case class FlagStatTemplate(tableNameOrPath: String, sampleId: String, output: Seq[Attribute])
+  extends LeafNode with MultiInstanceRelation {
+
+  override def newInstance(): FlagStatTemplate = copy(output = output.map(_.newInstance()))
+
+  def toSQL(): String = {
+    s"""
+      SELECT RCount, QCFail, DUPES, MAPPED, UNMAPPED, PiSEQ, Read1, Read2, PPaired, WIaMM, Singletons
+      AS `${output.head.name}`
+      FROM flagStat('$tableNameOrPath')"""
+  }
+
+  override def toString: String = {
+    s"FlagStatFunction ('$tableNameOrPath')"
+  }
+}
+
+object FlagStatTemplate {
+  private def output() = {
+    FlagStat.Schema.toAttributes
+  }
+
+  def apply(tableNameOrPath: String, sampleId: String) = {
+    new FlagStatTemplate(tableNameOrPath, sampleId, output())
+  }
+}
+
+
 object PileupTemplate {
diff --git a/src/main/scala/org/biodatageeks/sequila/utvf/SequilaSession.scala b/src/main/scala/org/biodatageeks/sequila/utvf/SequilaSession.scala
index 931cb6c7..96e62e00 100644
--- a/src/main/scala/org/biodatageeks/sequila/utvf/SequilaSession.scala
+++ b/src/main/scala/org/biodatageeks/sequila/utvf/SequilaSession.scala
@@ -10,6 +10,7 @@ import org.apache.spark.sql.execution.datasources.SequilaDataSourceStrategy
 import org.apache.spark.sql.functions.{lit, typedLit}
 import org.apache.spark.sql.internal.SessionState
 import org.apache.spark.sql.types.{ArrayType, ByteType, MapType, ShortType}
+import org.biodatageeks.sequila.flagStat.{FlagStatRow, FlagStatStrategy}
 import org.biodatageeks.sequila.pileup.PileupStrategy
 import org.biodatageeks.sequila.rangejoins.IntervalTree.IntervalTreeJoinStrategyOptim
 import org.biodatageeks.sequila.utils.{InternalParams, UDFRegister}
@@ -33,8 +34,8 @@ object SequilaSession {
         new SequilaDataSourceStrategy(spark),
         new IntervalTreeJoinStrategyOptim(spark),
         new PileupStrategy(spark),
+        new FlagStatStrategy(spark),
         new GenomicIntervalStrategy(spark)
-
       )
     /*Set params*/
     spark
@@ -111,6 +112,16 @@ case class SequilaSession(sparkSession: SparkSession) extends SparkSession(spark
       new Dataset(sparkSession, PileupTemplate(path, refPath, true, quals), Encoders.kryo[Row])
     }.as[Pileup]
 
+  /**
+    * Calculate flagStat
+    *
+    * @param tableNameOrPath BAM file path or table name
+    * @param sampleId        sample id
+    * @return flagStat result as Dataset[FlagStatRow]
+    */
+  def flagStat(tableNameOrPath: String, sampleId: String): Dataset[FlagStatRow] = {
+    new Dataset(sparkSession, FlagStatTemplate(tableNameOrPath, sampleId), Encoders.kryo[Row])
+  }.as[FlagStatRow]
 }
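The Dataset-level counterpart of the SQL form, using the new `flagStat` method added above (the path is a placeholder):

```scala
// Dataset API counterpart of the flagstat TVF (placeholder path):
import org.apache.spark.sql.Dataset
import org.biodatageeks.sequila.flagStat.FlagStatRow

val flags: Dataset[FlagStatRow] = ss.flagStat("/data/sample.bam", null) // null sampleId => file path variant
flags.show()
```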
diff --git a/src/test/scala/org/biodatageeks/sequila/tests/flagStat/FlagStatTestBase.scala b/src/test/scala/org/biodatageeks/sequila/tests/flagStat/FlagStatTestBase.scala
new file mode 100644
index 00000000..df9bcbbd
--- /dev/null
+++ b/src/test/scala/org/biodatageeks/sequila/tests/flagStat/FlagStatTestBase.scala
@@ -0,0 +1,50 @@
+package org.biodatageeks.sequila.tests.flagStat
+
+import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
+import org.apache.spark.sql.{SequilaSession, SparkSession}
+import org.apache.spark.storage.StorageLevel
+import org.biodatageeks.sequila.utils.InternalParams
+import org.scalatest.{BeforeAndAfter, FunSuite}
+
+import java.io.File
+import scala.reflect.io.Directory
+
+class FlagStatTestBase extends FunSuite
+  with DataFrameSuiteBase
+  with BeforeAndAfter
+  with SharedSparkContext {
+
+  System.setSecurityManager(null)
+
+  val sampleId = "NA12878.multichrom.md"
+  val bamPath: String = getClass.getResource(s"/multichrom/mdbam/${sampleId}.bam").getPath
+  val tableName = "reads_bam"
+  var ss: SequilaSession = null
+
+  def cleanup(dir: String) = {
+    val directory = new Directory(new File(dir))
+    directory.deleteRecursively()
+  }
+
+  val flagStatQuery =
+    s"""
+       |SELECT *
+       |FROM flagstat('$tableName', '${sampleId}')
+    """.stripMargin
+
+  before {
+    spark.sqlContext.setConf(InternalParams.SerializationMode, StorageLevel.DISK_ONLY.toString())
+    spark.conf.set("spark.sql.shuffle.partitions", 1)
+    ss = SequilaSession(spark)
+    ss.sparkContext.setLogLevel("ERROR")
+    ss.sqlContext.setConf(InternalParams.BAMValidationStringency, "SILENT")
+    ss.sqlContext.setConf(InternalParams.UseIntelGKL, "true")
+    ss.sqlContext.setConf(InternalParams.IOReadAlignmentMethod, "hadoopBAM")
+    ss.sql(s"DROP TABLE IF EXISTS $tableName")
+    ss.sql(s"""
+       |CREATE TABLE $tableName
+       |USING org.biodatageeks.sequila.datasources.BAM.BAMDataSource
+       |OPTIONS(path "$bamPath")
+       |
+    """.stripMargin)
+  }
+}
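The expected counts in the suite below were obtained from an external flagstat run on this BAM; a Spark-free cross-check of the total record count with htsjdk could look roughly like this (a hypothetical helper, not part of the change):

```scala
// Hypothetical cross-check of Expected("RCount") without Spark, using htsjdk directly:
import htsjdk.samtools.SamReaderFactory
import scala.collection.JavaConverters._

val reader = SamReaderFactory.makeDefault().open(new java.io.File(bamPath))
val totalReads = reader.iterator().asScala.size // should equal Expected("RCount")
reader.close()
```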
Test passed"); + } + + val Expected = Map[String, Long]( + "RCount" -> 22607, + "QCFail" -> 0, + "DUPES" -> 1532, + "MAPPED" -> 22277, + "UNMAPPED" -> 330, + "PiSEQ" -> 22607, + "Read1" -> 11309, + "Read2" -> 11298, + "PPaired" -> 21647, + "WIaMM" -> 21924, + "Singletons" -> 353 + ); + + private def performAssertions(df:DataFrame):Unit ={ + assert(df.count() == 1); + val obtained = df.where("RCount > 0"); + assert(obtained.count(), 1); + val results = obtained.first(); + Expected.foreach(kv => { + println(s"Checking: ${kv._1} == ${kv._2}..."); + val index = results.fieldIndex(kv._1); + assert(results.getLong(index), kv._2); + }) + } +}