Skip to content
代码片段 群组 项目
CollapsedCellOptimizer.cpp 35.7 KB
Newer Older
#include "CollapsedCellOptimizer.hpp"
#include "EMUtils.hpp"
#include <assert.h>

// Out-of-line defaulted constructor; equivalent to an empty-bodied definition.
CollapsedCellOptimizer::CollapsedCellOptimizer() = default;
/*
 * Use the "relax" EM algorithm over gene equivalence
 * classes to estimate the latent variables (alphaOut)
 * given the current estimates (alphaIn).
 */
void CellEMUpdate_(std::vector<SalmonEqClass>& eqVec,
k3yavi's avatar
k3yavi 已提交
                   const CollapsedCellOptimizer::SerialVecType& alphaIn,
                   CollapsedCellOptimizer::SerialVecType& alphaOut) {
  assert(alphaIn.size() == alphaOut.size());

  for (size_t eqID=0; eqID < eqVec.size(); eqID++) {
    auto& kv = eqVec[eqID];

    uint32_t count = kv.count;
    // for each label in this class
    const std::vector<uint32_t>& genes = kv.labels;
    size_t groupSize = genes.size();

    if (BOOST_LIKELY(groupSize > 1)) {
      double denom = 0.0;
      for (size_t i = 0; i < groupSize; ++i) {
        auto gid = genes[i];
        denom += alphaIn[gid];
      }

      if (denom > 0.0) {
        double invDenom = count / denom;
        for (size_t i = 0; i < groupSize; ++i) {
          auto gid = genes[i];
          double v = alphaIn[gid];
          if (!std::isnan(v)) {
            alphaOut[gid] += v * invDenom;
          }
        }//end-for
      }//endif for denom>0
    }//end-if boost gsize>1
    else if (groupSize == 1){
      alphaOut[genes.front()] += count;
    }
    else{
      std::cerr<<"0 Group size for salmonEqclasses in EM\n"
               <<"Please report this on github";
      exit(1);

/*
 * Zero out tiny expression values and report what is left.
 *
 * @param alphas  abundance vector, modified in place
 * @param cutoff  entries strictly smaller than this are set to 0.0
 * @return        the sum of all entries after truncation
 */
template <typename VecT>
double truncateAlphas(VecT& alphas, double cutoff) {
  double alphaSum = 0.0;
  for (auto& alpha : alphas) {
    if (alpha < cutoff) {
      alpha = 0.0;
    }
    alphaSum += alpha;
  }
  return alphaSum;
}

bool runPerCellEM(double& totalNumFrags, size_t numGenes,
                  CollapsedCellOptimizer::SerialVecType& alphas,
                  std::vector<SalmonEqClass>& salmonEqclasses,
                  std::shared_ptr<spdlog::logger>& jointlog,
                  bool initUniform){

  // An EM termination criterion, adopted from Bray et al. 2016
  uint32_t minIter {50};
  double relDiffTolerance {0.01};
  uint32_t maxIter {10000};
  size_t numClasses = salmonEqclasses.size();
  if ( initUniform ) {
    double uniformPrior = 1.0/numGenes;
    std::fill(alphas.begin(), alphas.end(), uniformPrior);
  }
  CollapsedCellOptimizer::SerialVecType alphasPrime(numGenes, 0.0);
  assert( numGenes == alphas.size() );
  for (size_t i = 0; i < numGenes; ++i) {
    alphas[i] += 0.5;
    alphas[i] *= 1e-3;
  }

  bool converged{false};
  double maxRelDiff = -std::numeric_limits<double>::max();
  size_t itNum = 0;

  // EM termination criteria, adopted from Bray et al. 2016
k3yavi's avatar
k3yavi 已提交
  double minAlpha = 1e-8;
  double alphaCheckCutoff = 1e-2;
  constexpr double minWeight = std::numeric_limits<double>::denorm_min();

  while (itNum < minIter or (itNum < maxIter and !converged)) {
    CellEMUpdate_(salmonEqclasses, alphas, alphasPrime);

    converged = true;
    maxRelDiff = -std::numeric_limits<double>::max();
      if (alphasPrime[i] > alphaCheckCutoff) {
        double relDiff =
          std::abs(alphas[i] - alphasPrime[i]) / alphasPrime[i];
        maxRelDiff = (relDiff > maxRelDiff) ? relDiff : maxRelDiff;
        if (relDiff > relDiffTolerance) {
          converged = false;
        }
      }
      alphas[i] = alphasPrime[i];
      alphasPrime[i] = 0.0;
    }

    ++itNum;
  }

  // Truncate tiny expression values
  totalNumFrags = truncateAlphas(alphas, minAlpha);
  if (totalNumFrags < minWeight) {
    jointlog->error("Total alpha weight was too small! "
k3yavi's avatar
k3yavi 已提交
                    "Make sure you ran salmon correctly.");
k3yavi's avatar
k3yavi 已提交
    jointlog->flush();
    return false;
  }

  return true;
}

bool runBootstraps(size_t numGenes,
k3yavi's avatar
k3yavi 已提交
                   CollapsedCellOptimizer::SerialVecType& geneAlphas,
                   std::vector<SalmonEqClass>& salmonEqclasses,
                   std::shared_ptr<spdlog::logger>& jointlog,
                   uint32_t numBootstraps,
k3yavi's avatar
k3yavi 已提交
                   CollapsedCellOptimizer::SerialVecType& variance,
                   bool useAllBootstraps,
                   std::vector<std::vector<double>>& sampleEstimates,
                   bool initUniform){
k3yavi's avatar
k3yavi 已提交

  // An EM termination criterion, adopted from Bray et al. 2016
  uint32_t minIter {50};
  double relDiffTolerance {0.01};
  uint32_t maxIter {10000};
  size_t numClasses = salmonEqclasses.size();

  CollapsedCellOptimizer::SerialVecType mean(numGenes, 0.0);
  CollapsedCellOptimizer::SerialVecType squareMean(numGenes, 0.0);
  CollapsedCellOptimizer::SerialVecType alphas(numGenes, 0.0);
  CollapsedCellOptimizer::SerialVecType alphasPrime(numGenes, 0.0);

  //extracting weight of eqclasses for making discrete distribution
k3yavi's avatar
k3yavi 已提交
  std::vector<uint64_t> eqCounts;
k3yavi's avatar
k3yavi 已提交
  for (auto& eqclass: salmonEqclasses) {
    totalNumFrags += eqclass.count;
    eqCounts.emplace_back(eqclass.count);
  }

k3yavi's avatar
k3yavi 已提交
  std::random_device rd;
  std::mt19937 gen(rd());
  std::discrete_distribution<uint64_t> csamp(eqCounts.begin(),
                                             eqCounts.end());

k3yavi's avatar
k3yavi 已提交
  while ( bsNum++ < numBootstraps) {
    csamp.reset();

k3yavi's avatar
k3yavi 已提交
    for (size_t sc = 0; sc < numClasses; ++sc) {
      salmonEqclasses[sc].count = 0;
k3yavi's avatar
k3yavi 已提交
    }

    for (size_t fn = 0; fn < totalNumFrags; ++fn) {
k3yavi's avatar
k3yavi 已提交
      salmonEqclasses[csamp(gen)].count += 1;
k3yavi's avatar
k3yavi 已提交
    }

    for (size_t i = 0; i < numGenes; ++i) {
      if ( initUniform ) {
        alphas[i] = 1.0 / numGenes;
      } else {
        alphas[i] = (geneAlphas[i] + 0.5) * 1e-3;
      }
k3yavi's avatar
k3yavi 已提交
    }

    bool converged{false};
    double maxRelDiff = -std::numeric_limits<double>::max();
    size_t itNum = 0;

    // EM termination criteria, adopted from Bray et al. 2016
    double minAlpha = 1e-8;
    double alphaCheckCutoff = 1e-2;
    constexpr double minWeight = std::numeric_limits<double>::denorm_min();

    while (itNum < minIter or (itNum < maxIter and !converged)) {
      CellEMUpdate_(salmonEqclasses, alphas, alphasPrime);

      converged = true;
      maxRelDiff = -std::numeric_limits<double>::max();
      for (size_t i = 0; i < numGenes; ++i) {
        if (alphasPrime[i] > alphaCheckCutoff) {
          double relDiff =
            std::abs(alphas[i] - alphasPrime[i]) / alphasPrime[i];
          maxRelDiff = (relDiff > maxRelDiff) ? relDiff : maxRelDiff;
          if (relDiff > relDiffTolerance) {
            converged = false;
          }
        }
        alphas[i] = alphasPrime[i];
        alphasPrime[i] = 0.0;
      }

      ++itNum;
k3yavi's avatar
k3yavi 已提交

    // Truncate tiny expression values
    double alphaSum = 0.0;
    // Truncate tiny expression values
    alphaSum = truncateAlphas(alphas, minAlpha);

    if (alphaSum < minWeight) {
      jointlog->error("Total alpha weight was too small! "
                      "Make sure you ran salmon correclty.");
      jointlog->flush();
      return false;
    }

    for(size_t i=0; i<numGenes; i++) {
      double alpha = alphas[i];
      mean[i] += alpha;
      squareMean[i] += alpha * alpha;
    }
k3yavi's avatar
k3yavi 已提交

    if (useAllBootstraps) {
      sampleEstimates.emplace_back(alphas);
    }
k3yavi's avatar
k3yavi 已提交

  // calculate mean and variance of the values
  for(size_t i=0; i<numGenes; i++) {
    double meanAlpha = mean[i] / numBootstraps;
    geneAlphas[i] = meanAlpha;
    variance[i] = (squareMean[i]/numBootstraps) - (meanAlpha*meanAlpha);
k3yavi's avatar
k3yavi 已提交
void optimizeCell(std::vector<std::string>& trueBarcodes,
                  std::atomic<uint32_t>& barcode,
k3yavi's avatar
k3yavi 已提交
                  size_t totalCells, eqMapT& eqMap,
k3yavi's avatar
k3yavi 已提交
                  std::deque<std::pair<TranscriptGroup, uint32_t>>& orderedTgroup,
                  std::shared_ptr<spdlog::logger>& jointlog,
                  bfs::path& outDir, std::vector<uint32_t>& umiCount,
                  std::vector<CellState>& skippedCB,
                  bool verbose, GZipWriter& gzw, size_t umiLength, bool noEM,
                  bool quiet, tbb::atomic<double>& totalDedupCounts,
k3yavi's avatar
k3yavi 已提交
                  tbb::atomic<uint32_t>& totalExpGeneCounts,
                  spp::sparse_hash_map<uint32_t, uint32_t>& txpToGeneMap,
                  uint32_t numGenes, uint32_t numBootstraps,
                  bool naiveEqclass, bool dumpUmiGraph, bool useAllBootstraps,
                  bool initUniform, CFreqMapT& freqCounter,
                  spp::sparse_hash_set<uint32_t>& mRnaGenes,
                  spp::sparse_hash_set<uint32_t>& rRnaGenes,
                  std::atomic<uint64_t>& totalUniEdgesCounts,
k3yavi's avatar
k3yavi 已提交
                  std::atomic<uint64_t>& totalBiEdgesCounts){
  size_t numCells {trueBarcodes.size()};
  size_t trueBarcodeIdx;
Rob Patro's avatar
Rob Patro 已提交

k3yavi's avatar
k3yavi 已提交
  // looping over until all the cells
Rob Patro's avatar
Rob Patro 已提交
  while((trueBarcodeIdx = barcode++) < totalCells) {
    // per-cell level optimization
    if ( umiCount[trueBarcodeIdx] == 0 ) {
      //skip the barcode if no mapped UMI
      skippedCB[trueBarcodeIdx].inActive = true;
k3yavi's avatar
k3yavi 已提交
    // extracting the sequence of the barcode
    auto& trueBarcodeStr = trueBarcodes[trueBarcodeIdx];

    //extracting per-cell level eq class information
k3yavi's avatar
k3yavi 已提交
    double totalExpGenes{0};
    std::vector<uint32_t> eqIDs;
k3yavi's avatar
k3yavi 已提交
    std::vector<uint32_t> eqCounts;
    std::vector<UGroupT> umiGroups;
    std::vector<tgrouplabelt> txpGroups;
    std::vector<double> geneAlphas(numGenes, 0.0);
k3yavi's avatar
k3yavi 已提交
    std::vector<uint8_t> tiers (numGenes, 0);

    for (auto& key : orderedTgroup) {
      //traversing each class and copying relevant data.
k3yavi's avatar
k3yavi 已提交
      bool isKeyPresent = eqMap.find_fn(key.first, [&](const SCTGValue& val){
          auto& bg = val.barcodeGroup;
          auto bcIt = bg.find(trueBarcodeIdx);

k3yavi's avatar
k3yavi 已提交
          // sub-selecting bgroup of this barcode only
          if (bcIt != bg.end()){
k3yavi's avatar
k3yavi 已提交
            // extracting txp labels
k3yavi's avatar
k3yavi 已提交
            const std::vector<uint32_t>& txps = key.first.txps;
k3yavi's avatar
k3yavi 已提交

            // original counts of the UMI
k3yavi's avatar
k3yavi 已提交
            uint32_t eqCount {0};
            for(auto& ugroup: bcIt->second){
              eqCount += ugroup.second;
k3yavi's avatar
k3yavi 已提交
            }
k3yavi's avatar
k3yavi 已提交

k3yavi's avatar
k3yavi 已提交
            txpGroups.emplace_back(txps);
            umiGroups.emplace_back(bcIt->second);

            // for dumping per-cell eqclass vector
k3yavi's avatar
k3yavi 已提交
            if(verbose){
              eqIDs.push_back(static_cast<uint32_t>(key.second));
k3yavi's avatar
k3yavi 已提交
            }
          }
        });

      if(!isKeyPresent){
        jointlog->error("Not able to find key in Cuckoo hash map."
                        "Please Report this issue on github");
k3yavi's avatar
k3yavi 已提交
        jointlog->flush();
k3yavi's avatar
k3yavi 已提交
    if ( !naiveEqclass ) {
      // perform the UMI deduplication step
      std::vector<SalmonEqClass> salmonEqclasses;
      spp::sparse_hash_map<uint16_t, uint32_t> numMolHash;
k3yavi's avatar
k3yavi 已提交
      bool dedupOk = dedupClasses(geneAlphas, totalCount, txpGroups,
                                  umiGroups, salmonEqclasses,
k3yavi's avatar
k3yavi 已提交
                                  txpToGeneMap, tiers, gzw,
                                  dumpUmiGraph, trueBarcodeStr, numMolHash,
k3yavi's avatar
k3yavi 已提交
                                  totalUniEdgesCounts, totalBiEdgesCounts);
k3yavi's avatar
k3yavi 已提交
      if( !dedupOk ){
        jointlog->error("Deduplication for cell {} failed \n"
                        "Please Report this on github.", trueBarcodeStr);
        jointlog->flush();
        std::exit(74);
k3yavi's avatar
k3yavi 已提交
      if ( numBootstraps and noEM ) {
        jointlog->error("Cannot perform bootstrapping with noEM");
k3yavi's avatar
k3yavi 已提交
        exit(1);
      }

      // perform EM for resolving ambiguity
      if ( !noEM ) {
        bool isEMok = runPerCellEM(totalCount,
                                   numGenes,
                                   geneAlphas,
                                   salmonEqclasses,
                                   jointlog,
                                   initUniform);
k3yavi's avatar
k3yavi 已提交
        if( !isEMok ){
          jointlog->error("EM iteration for cell {} failed \n"
                          "Please Report this on github.", trueBarcodeStr);
          jointlog->flush();
          std::exit(74);
      uint8_t featureCode {0};
      {
        std::stringstream featuresStream;
        featuresStream << trueBarcodeStr;

        // Making features
        double totalUmiCount {0.0};
        double maxNumUmi {0.0};
        for (auto count: geneAlphas) {
k3yavi's avatar
k3yavi 已提交
          if (count>0.0) {
            totalUmiCount += count;
            totalExpGenes += 1;
            if (count > maxNumUmi) { maxNumUmi = count; }
          }
        }

        uint32_t numGenesOverMean {0};
        double mitoCount {0.0}, riboCount {0.0};
        double meanNumUmi {totalUmiCount / totalExpGenes};
        double meanByMax = maxNumUmi ? meanNumUmi / maxNumUmi : 0.0;
        for (size_t j=0; j<geneAlphas.size(); j++){
          auto count = geneAlphas[j];
          if (count > meanNumUmi) { ++numGenesOverMean; }

          if (mRnaGenes.contains(j)){
            mitoCount += count;
          }
          if (rRnaGenes.contains(j)){
            riboCount += count;
          }
        }

        auto indexIt = freqCounter.find(trueBarcodeStr);
        bool indexOk = indexIt != freqCounter.end();
        if ( not indexOk ){
          jointlog->error("Error: index {} not found in freq Counter\n"
                          "Please Report the issue on github", trueBarcodeStr);
        }

        uint64_t numRawReads = *indexIt;
        uint64_t numMappedReads { umiCount[trueBarcodeIdx] };
        double mappingRate = numRawReads ?
          numMappedReads / static_cast<double>(numRawReads) : 0.0;
        double deduplicationRate = numMappedReads ?
          1.0 - (totalUmiCount / numMappedReads) : 0.0;

        // Feature created after discussion with Mehrtash
        double averageNumMolPerArbo {0.0};
        size_t totalNumArborescence {0};
        std::stringstream arboString ;
        for (auto& it: numMolHash) {
          totalNumArborescence += it.second;
          averageNumMolPerArbo += (it.first * it.second);
          if (dumpUmiGraph) {
            arboString << "\t" << it.first << ":" << it.second;
          }
        }
        averageNumMolPerArbo /= totalNumArborescence;

        featuresStream << "\t" << numRawReads
                       << "\t" << numMappedReads
                       << "\t" << totalUmiCount
                       << "\t" << mappingRate
                       << "\t" << deduplicationRate
                       << "\t" << meanByMax
                       << "\t" << totalExpGenes
                       << "\t" << numGenesOverMean;

        if (mRnaGenes.size() > 1) {
          featureCode += 1;
          featuresStream << "\t" << mitoCount / totalUmiCount;
        }

        if (rRnaGenes.size() > 1) {
          featureCode += 2;
          featuresStream << "\t" << riboCount / totalUmiCount;
        }
k3yavi's avatar
k3yavi 已提交
        if (dumpUmiGraph) {
          featuresStream << arboString.rdbuf();
        } else {
          featuresStream << "\t" << averageNumMolPerArbo;
        }

        features = featuresStream.str();
      } // end making features


      // write the abundance for the cell
      gzw.writeSparseAbundances( trueBarcodeStr,
                                 features,
                                 featureCode,
                                 geneAlphas,
                                 tiers,
                                 dumpUmiGraph );
k3yavi's avatar
k3yavi 已提交

k3yavi's avatar
k3yavi 已提交
      // maintaining count for total number of predicted UMI
      salmon::utils::incLoop(totalDedupCounts, totalCount);
k3yavi's avatar
k3yavi 已提交
      totalExpGeneCounts += totalExpGenes;
k3yavi's avatar
k3yavi 已提交

      if ( numBootstraps > 0 ){
k3yavi's avatar
k3yavi 已提交
        std::vector<std::vector<double>> sampleEstimates;
k3yavi's avatar
k3yavi 已提交
        std::vector<double> bootVariance(numGenes, 0.0);
k3yavi's avatar
k3yavi 已提交
        bool isBootstrappingOk = runBootstraps(numGenes,
                                               geneAlphas,
                                               salmonEqclasses,
                                               jointlog,
                                               numBootstraps,
k3yavi's avatar
k3yavi 已提交
                                               bootVariance,
                                               useAllBootstraps,
                                               sampleEstimates,
                                               initUniform);
k3yavi's avatar
k3yavi 已提交
        if( not isBootstrappingOk or
            (useAllBootstraps and sampleEstimates.size()!=numBootstraps)
            ){
k3yavi's avatar
k3yavi 已提交
          jointlog->error("Bootstrapping failed \n"
                          "Please Report this on github.");
          jointlog->flush();
          std::exit(74);
k3yavi's avatar
k3yavi 已提交
        }

        // write the abundance for the cell
        gzw.writeSparseBootstraps( trueBarcodeStr,
                                   geneAlphas, bootVariance,
                                   useAllBootstraps, sampleEstimates);
k3yavi's avatar
k3yavi 已提交
      }//end-if
    }
    else {
      // doing per eqclass level naive deduplication
      for (size_t eqId=0; eqId<umiGroups.size(); eqId++) {
        spp::sparse_hash_set<uint64_t> umis;

        for(auto& it: umiGroups[eqId]) {
          umis.insert( it.first );
        }
        totalCount += umis.size();

        // filling in the eqclass level deduplicated counts
        if (verbose) {
          eqCounts[eqId] = umis.size();
        }
      }

      // maintaining count for total number of predicted UMI
      salmon::utils::incLoop(totalDedupCounts, totalCount);
    }

    if (verbose) {
      gzw.writeCellEQVec(trueBarcodeIdx, eqIDs, eqCounts, true);
    }
k3yavi's avatar
k3yavi 已提交
    //printing on screen progress
    const char RESET_COLOR[] = "\x1b[0m";
    char green[] = "\x1b[30m";
    green[3] = '0' + static_cast<char>(fmt::GREEN);
    char red[] = "\x1b[30m";
    red[3] = '0' + static_cast<char>(fmt::RED);

rob-p's avatar
rob-p 已提交
    double cellCount {static_cast<double>(barcode)};//numCells-jqueue.size_approx()};
    if (cellCount > totalCells) { cellCount = totalCells; }
    double percentCompletion {cellCount*100/numCells};
k3yavi's avatar
k3yavi 已提交
    if (not quiet){
      fmt::print(stderr, "\033[A\r\r{}Analyzed {} cells ({}{}%{} of all).{}\n",
                 green, cellCount, red, round(percentCompletion), green, RESET_COLOR);
    }
  }
}

template <typename ProtocolT>
k3yavi's avatar
k3yavi 已提交
bool CollapsedCellOptimizer::optimize(EqMapT& fullEqMap,
                                      spp::sparse_hash_map<uint32_t, uint32_t>& txpToGeneMap,
                                      spp::sparse_hash_map<std::string, uint32_t>& geneIdxMap,
                                      AlevinOpts<ProtocolT>& aopt,
                                      GZipWriter& gzw,
                                      std::vector<std::string>& trueBarcodes,
k3yavi's avatar
k3yavi 已提交
                                      std::vector<uint32_t>& umiCount,
k3yavi's avatar
k3yavi 已提交
                                      CFreqMapT& freqCounter,
                                      size_t numLowConfidentBarcode){
  size_t numCells = trueBarcodes.size();
  size_t numGenes = geneIdxMap.size();
  size_t numWorkerThreads{1};
  bool hasWhitelist = boost::filesystem::exists(aopt.whitelistFile);

  if (aopt.numThreads > 1) {
    numWorkerThreads = aopt.numThreads - 1;
  }

  //get the keys of the map
k3yavi's avatar
k3yavi 已提交
  std::deque<std::pair<TranscriptGroup, uint32_t>> orderedTgroup;
k3yavi's avatar
k3yavi 已提交
  //spp::sparse_hash_set<uint64_t> uniqueUmisCounter;
k3yavi's avatar
k3yavi 已提交
  uint32_t eqId{0};
  for(const auto& kv : fullEqMap.lock_table()){
k3yavi's avatar
k3yavi 已提交
    // assuming the iteration through lock table is always same
    if(kv.first.txps.size() == 1){
k3yavi's avatar
k3yavi 已提交
      orderedTgroup.push_front(std::make_pair(kv.first, eqId));
k3yavi's avatar
k3yavi 已提交
      orderedTgroup.push_back(std::make_pair(kv.first, eqId));
k3yavi's avatar
k3yavi 已提交
    eqId++;
  if (aopt.noEM) {
    aopt.jointLog->warn("Not performing EM; this may result in discarding ambiguous reads\n");
    aopt.jointLog->flush();
  }

  if (aopt.initUniform) {
    aopt.jointLog->warn("Using uniform initialization for EM");
    aopt.jointLog->flush();
  }

  spp::sparse_hash_set<uint32_t> mRnaGenes, rRnaGenes;
  bool useMito {false}, useRibo {false};
  if(boost::filesystem::exists(aopt.mRnaFile)) {
    std::ifstream mRnaFile(aopt.mRnaFile.string());
    std::string gene;
    size_t skippedGenes {0};
    if(mRnaFile.is_open()) {
      while(getline(mRnaFile, gene)) {
        if (geneIdxMap.contains(gene)){
          mRnaGenes.insert(geneIdxMap[ gene ]);
        }
        else{
          skippedGenes += 1;
        }
      }
      mRnaFile.close();
    }
    if (skippedGenes > 0){
      aopt.jointLog->warn("{} mitorna gene(s) does not have transcript in the reference",
                          skippedGenes);
    }
    aopt.jointLog->info("Total {} usable mRna genes", mRnaGenes.size());
    if (mRnaGenes.size() > 0 ) { useMito = true; }
  else if (hasWhitelist) {
    aopt.jointLog->warn("mrna file not provided; using is 1 less feature for whitelisting");
  }

  if(boost::filesystem::exists(aopt.rRnaFile)){
    std::ifstream rRnaFile(aopt.rRnaFile.string());
    std::string gene;
    size_t skippedGenes {0};
    if(rRnaFile.is_open()) {
      while(getline(rRnaFile, gene)) {
        if (geneIdxMap.contains(gene)){
          rRnaGenes.insert(geneIdxMap[ gene ]);
        }
        else{
          skippedGenes += 1;
        }
      }
      rRnaFile.close();
    }
    if (skippedGenes > 0){
      aopt.jointLog->warn("{} ribosomal rna gene(s) does not have transcript in the reference",
                          skippedGenes);
    }
    aopt.jointLog->info("Total {} usable rRna genes", rRnaGenes.size());
    if (rRnaGenes.size() > 0 ) { useRibo = true; }
  else if (hasWhitelist) {
    aopt.jointLog->warn("rrna file not provided; using is 1 less feature for whitelisting");
  }


  std::vector<CellState> skippedCB (numCells);
  std::atomic<uint32_t> bcount{0};
  tbb::atomic<double> totalDedupCounts{0.0};
k3yavi's avatar
k3yavi 已提交
  tbb::atomic<uint32_t> totalExpGeneCounts{0};
k3yavi's avatar
k3yavi 已提交
  std::atomic<uint64_t> totalBiEdgesCounts{0};
  std::atomic<uint64_t> totalUniEdgesCounts{0};

  std::vector<std::thread> workerThreads;
  for (size_t tn = 0; tn < numWorkerThreads; ++tn) {
    workerThreads.emplace_back(optimizeCell,
                               std::ref(trueBarcodes),
                               std::ref(bcount),
                               numCells,
                               std::ref(fullEqMap),
                               std::ref(orderedTgroup),
                               std::ref(aopt.jointLog),
                               std::ref(aopt.outputDirectory),
                               std::ref(umiCount),
                               std::ref(skippedCB),
                               aopt.dumpBarcodeEq,
                               aopt.protocol.umiLength,
k3yavi's avatar
k3yavi 已提交
                               aopt.quiet,
k3yavi's avatar
k3yavi 已提交
                               std::ref(totalDedupCounts),
k3yavi's avatar
k3yavi 已提交
                               std::ref(totalExpGeneCounts),
                               std::ref(txpToGeneMap),
                               numGenes,
k3yavi's avatar
k3yavi 已提交
                               aopt.numBootstraps,
k3yavi's avatar
k3yavi 已提交
                               aopt.naiveEqclass,
k3yavi's avatar
k3yavi 已提交
                               aopt.dumpUmiGraph,
                               aopt.dumpfeatures,
k3yavi's avatar
k3yavi 已提交
                               aopt.initUniform,
                               std::ref(rRnaGenes),
                               std::ref(mRnaGenes),
k3yavi's avatar
k3yavi 已提交
                               std::ref(totalUniEdgesCounts),
                               std::ref(totalBiEdgesCounts));
  }

  for (auto& t : workerThreads) {
    t.join();
  }
  aopt.jointLog->info("Total {0:.2f} UMI after deduplicating.",
k3yavi's avatar
k3yavi 已提交
                      totalDedupCounts);
k3yavi's avatar
k3yavi 已提交
  aopt.jointLog->info("Total {} BiDirected Edges.",
                      totalBiEdgesCounts);
  aopt.jointLog->info("Total {} UniDirected Edges.",
                      totalUniEdgesCounts);

k3yavi's avatar
k3yavi 已提交
  //adjusting for float
  aopt.totalDedupUMIs = totalDedupCounts+1;
k3yavi's avatar
k3yavi 已提交
  aopt.totalExpGenes = totalExpGeneCounts;

  uint32_t skippedCBcount {0};
  for(auto cb: skippedCB){
    if (cb.inActive) {
      skippedCBcount += 1;
    }
  }

  if( skippedCBcount > 0 ) {
    aopt.jointLog->warn("Skipped {} barcodes due to No mapped read",
                        skippedCBcount);
k3yavi's avatar
k3yavi 已提交
    auto lowRegionCutoffIdx = numCells - numLowConfidentBarcode;
k3yavi's avatar
k3yavi 已提交
    std::vector<std::string> retainedTrueBarcodes ;
    for (size_t idx=0; idx < numCells; idx++){
k3yavi's avatar
k3yavi 已提交
      // not very efficient way but assuming the size is small enough
      if (skippedCB[idx].inActive) {
        if (idx > lowRegionCutoffIdx){
          numLowConfidentBarcode--;
        }
k3yavi's avatar
k3yavi 已提交
      } else {
        retainedTrueBarcodes.emplace_back(trueBarcodes[idx]);
k3yavi's avatar
k3yavi 已提交

    trueBarcodes = retainedTrueBarcodes;
k3yavi's avatar
k3yavi 已提交
    numCells = trueBarcodes.size();
k3yavi's avatar
k3yavi 已提交
  std::vector<std::string> geneNames(numGenes);
  for (auto geneIdx : geneIdxMap) {
    geneNames[geneIdx.second] = geneIdx.first;
k3yavi's avatar
k3yavi 已提交
  boost::filesystem::path gFilePath = aopt.outputDirectory / "quants_mat_cols.txt";
  std::ofstream gFile(gFilePath.string());
  std::ostream_iterator<std::string> giterator(gFile, "\n");
  std::copy(geneNames.begin(), geneNames.end(), giterator);
  gFile.close();
  if( not hasWhitelist ){
    aopt.jointLog->info("Clearing EqMap; Might take some time.");
    fullEqMap.clear();
    if ( numLowConfidentBarcode < aopt.lowRegionMinNumBarcodes ) {
      aopt.jointLog->warn("Num Low confidence barcodes too less {} < {}."
                          "Can't performing whitelisting; Skipping",
                          numLowConfidentBarcode,
                          aopt.lowRegionMinNumBarcodes);
    } else {
      aopt.jointLog->info("Starting white listing of {} cells", trueBarcodes.size());
      bool whitelistingSuccess = alevin::whitelist::performWhitelisting(aopt,
                                                                        umiCount,
                                                                        trueBarcodes,
                                                                        freqCounter,
                                                                        useRibo,
                                                                        useMito,
                                                                        numLowConfidentBarcode);
      if (!whitelistingSuccess) {
        aopt.jointLog->error(
                             "The white listing algorithm failed. This is likely the result of "
                             "bad input (or a bug). If you cannot track down the cause, please "
                             "report this issue on GitHub.");
        aopt.jointLog->flush();
        return false;
      }

      aopt.jointLog->info("Finished white listing");
      aopt.jointLog->flush();
k3yavi's avatar
k3yavi 已提交
    }
  } //end-if whitelisting
k3yavi's avatar
k3yavi 已提交
  if (aopt.dumpMtx){
    aopt.jointLog->info("Starting dumping cell v gene counts in mtx format");
    boost::filesystem::path qFilePath = aopt.outputDirectory / "quants_mat.mtx.gz";

    boost::iostreams::filtering_ostream qFile;
    qFile.push(boost::iostreams::gzip_compressor(6));
    qFile.push(boost::iostreams::file_sink(qFilePath.string(),
                                           std::ios_base::out | std::ios_base::binary));

    // mtx header
    qFile << "%%MatrixMarket\tmatrix\tcoordinate\treal\tgeneral" << std::endl
          << numCells << "\t" << numGenes << "\t" << totalExpGeneCounts << std::endl;

    {
k3yavi's avatar
k3yavi 已提交

      auto popcount = [](uint8_t n) {
        size_t count {0};
        while (n) {
          n &= n-1;
          ++count;
        }
        return count;
      };

      uint32_t zerod_cells {0};
k3yavi's avatar
k3yavi 已提交
      size_t numFlags = std::ceil(numGenes/8);
k3yavi's avatar
k3yavi 已提交
      std::vector<uint8_t> alphasFlag (numFlags, 0);
      size_t flagSize = sizeof(decltype(alphasFlag)::value_type);

      std::vector<float> alphasSparse;
      alphasSparse.reserve(numFlags/2);
      size_t elSize = sizeof(decltype(alphasSparse)::value_type);

      auto countMatFilename = aopt.outputDirectory / "quants_mat.gz";
      if(not boost::filesystem::exists(countMatFilename)){
k3yavi's avatar
k3yavi 已提交
        std::cout<<"ERROR: Can't import Binary file quants.mat.gz, it doesn't exist" << std::flush;
      }

      boost::iostreams::filtering_istream countMatrixStream;
      countMatrixStream.push(boost::iostreams::gzip_decompressor());
      countMatrixStream.push(boost::iostreams::file_source(countMatFilename.string(),
                                                           std::ios_base::in | std::ios_base::binary));

      for (size_t cellCount=0; cellCount<numCells; cellCount++){
k3yavi's avatar
k3yavi 已提交
        countMatrixStream.read(reinterpret_cast<char*>(alphasFlag.data()), flagSize * numFlags);

        size_t numExpGenes {0};
        std::vector<size_t> indices;
        for (size_t j=0; j<alphasFlag.size(); j++) {
          uint8_t flag = alphasFlag[j];
          size_t numNonZeros = popcount(flag);
          numExpGenes += numNonZeros;

          for (size_t i=0; i<8; i++){
            if (flag & (128 >> i)) {
k3yavi's avatar
k3yavi 已提交
              indices.emplace_back( i+(8*j) );
k3yavi's avatar
k3yavi 已提交
        if (indices.size() != numExpGenes) {
          aopt.jointLog->error("binary format reading error {}: {}: {}",
                               indices.size(), numExpGenes);
          aopt.jointLog->flush();
          exit(84);
k3yavi's avatar
k3yavi 已提交

        alphasSparse.clear();
        alphasSparse.resize(numExpGenes);
        countMatrixStream.read(reinterpret_cast<char*>(alphasSparse.data()), elSize * numExpGenes);

        float readCount {0.0};
        readCount += std::accumulate(alphasSparse.begin(), alphasSparse.end(), 0.0);

        for(size_t i=0; i<numExpGenes; i++) {
          qFile << std::fixed
                << cellCount + 1 << "\t"
                << indices[i] + 1 << "\t"
k3yavi's avatar
k3yavi 已提交
                << alphasSparse[i] <<  std::endl;
        }

        //size_t alphasSparseCounter {0};
        //for (size_t i=0; i<numGenes; i+=8) {
        //  uint8_t flag = alphasFlag[i];
        //  for (size_t j=0; j<8; j++) {
        //    size_t vectorIndex = i+j;
        //    if (vectorIndex >= numGenes) { break; }

        //    if ( flag & (1<<(7-j)) ) {
        //      if (alphasSparseCounter >= numExpGenes) {
        //        aopt.jointLog->error("binary format reading error {}: {}: {}",
        //                             alphasSparseCounter, numExpGenes, readCount);
        //        aopt.jointLog->flush();
        //        exit(84);
        //      }

        //      float count = alphasSparse[alphasSparseCounter];
        //      readCount += count;
        //      qFile << cellCount+1 << "\t"
        //            << vectorIndex+1 << "\t"
        //            << count << std::endl;

        //      ++alphasSparseCounter;
        //    }
        //  }
        //}

        if (readCount == 0.0){
          zerod_cells += 1;
        }
      } // end-for each cell

      if (zerod_cells > 0) {
        aopt.jointLog->warn("Found {} cells with 0 counts", zerod_cells);
    boost::iostreams::close(qFile);
k3yavi's avatar
k3yavi 已提交
    aopt.jointLog->info("Finished dumping counts into mtx");
  }
  return true;
} //end-optimize


namespace apt = alevin::protocols;
template
k3yavi's avatar
k3yavi 已提交
bool CollapsedCellOptimizer::optimize(EqMapT& fullEqMap,
                                      spp::sparse_hash_map<uint32_t, uint32_t>& txpToGeneMap,
                                      spp::sparse_hash_map<std::string, uint32_t>& geneIdxMap,
                                      AlevinOpts<apt::DropSeq>& aopt,
                                      GZipWriter& gzw,
                                      std::vector<std::string>& trueBarcodes,
                                      std::vector<uint32_t>& umiCount,
k3yavi's avatar
k3yavi 已提交
                                      CFreqMapT& freqCounter,
                                      size_t numLowConfidentBarcode);
template
k3yavi's avatar
k3yavi 已提交
bool CollapsedCellOptimizer::optimize(EqMapT& fullEqMap,
                                      spp::sparse_hash_map<uint32_t, uint32_t>& txpToGeneMap,
                                      spp::sparse_hash_map<std::string, uint32_t>& geneIdxMap,

                                      AlevinOpts<apt::InDrop>& aopt,
                                      GZipWriter& gzw,
                                      std::vector<std::string>& trueBarcodes,
                                      std::vector<uint32_t>& umiCount,
k3yavi's avatar
k3yavi 已提交
                                      CFreqMapT& freqCounter,
                                      size_t numLowConfidentBarcode);
template
k3yavi's avatar
k3yavi 已提交
bool CollapsedCellOptimizer::optimize(EqMapT& fullEqMap,
                                      spp::sparse_hash_map<uint32_t, uint32_t>& txpToGeneMap,
                                      spp::sparse_hash_map<std::string, uint32_t>& geneIdxMap,

k3yavi's avatar
k3yavi 已提交
                                      AlevinOpts<apt::ChromiumV3>& aopt,
                                      GZipWriter& gzw,
                                      std::vector<std::string>& trueBarcodes,
                                      std::vector<uint32_t>& umiCount,
k3yavi's avatar
k3yavi 已提交
                                      CFreqMapT& freqCounter,
                                      size_t numLowConfidentBarcode);
template
k3yavi's avatar
k3yavi 已提交
bool CollapsedCellOptimizer::optimize(EqMapT& fullEqMap,
                                      spp::sparse_hash_map<uint32_t, uint32_t>& txpToGeneMap,
                                      spp::sparse_hash_map<std::string, uint32_t>& geneIdxMap,

                                      AlevinOpts<apt::Chromium>& aopt,
                                      GZipWriter& gzw,
                                      std::vector<std::string>& trueBarcodes,
                                      std::vector<uint32_t>& umiCount,
k3yavi's avatar
k3yavi 已提交
                                      CFreqMapT& freqCounter,
                                      size_t numLowConfidentBarcode);
template
k3yavi's avatar
k3yavi 已提交
bool CollapsedCellOptimizer::optimize(EqMapT& fullEqMap,
                                      spp::sparse_hash_map<uint32_t, uint32_t>& txpToGeneMap,
                                      spp::sparse_hash_map<std::string, uint32_t>& geneIdxMap,

k3yavi's avatar
k3yavi 已提交
                                      AlevinOpts<apt::Gemcode>& aopt,
                                      GZipWriter& gzw,
                                      std::vector<std::string>& trueBarcodes,
                                      std::vector<uint32_t>& umiCount,
                                      CFreqMapT& freqCounter,
                                      size_t numLowConfidentBarcode);
template
k3yavi's avatar
k3yavi 已提交
bool CollapsedCellOptimizer::optimize(EqMapT& fullEqMap,
                                      spp::sparse_hash_map<uint32_t, uint32_t>& txpToGeneMap,
                                      spp::sparse_hash_map<std::string, uint32_t>& geneIdxMap,

                                      AlevinOpts<apt::CELSeq>& aopt,
k3yavi's avatar
k3yavi 已提交
                                      GZipWriter& gzw,
                                      std::vector<std::string>& trueBarcodes,
                                      std::vector<uint32_t>& umiCount,
                                      CFreqMapT& freqCounter,
                                      size_t numLowConfidentBarcode);
template
k3yavi's avatar
k3yavi 已提交
bool CollapsedCellOptimizer::optimize(EqMapT& fullEqMap,
                                      spp::sparse_hash_map<uint32_t, uint32_t>& txpToGeneMap,
                                      spp::sparse_hash_map<std::string, uint32_t>& geneIdxMap,

k3yavi's avatar
k3yavi 已提交
                                      AlevinOpts<apt::CELSeq2>& aopt,
                                      GZipWriter& gzw,
                                      std::vector<std::string>& trueBarcodes,
                                      std::vector<uint32_t>& umiCount,
                                      CFreqMapT& freqCounter,
                                      size_t numLowConfidentBarcode);
template
k3yavi's avatar
k3yavi 已提交
bool CollapsedCellOptimizer::optimize(EqMapT& fullEqMap,
                                      spp::sparse_hash_map<uint32_t, uint32_t>& txpToGeneMap,
                                      spp::sparse_hash_map<std::string, uint32_t>& geneIdxMap,

                                      AlevinOpts<apt::Custom>& aopt,
                                      GZipWriter& gzw,
                                      std::vector<std::string>& trueBarcodes,
                                      std::vector<uint32_t>& umiCount,
k3yavi's avatar
k3yavi 已提交
                                      CFreqMapT& freqCounter,
k3yavi's avatar
k3yavi 已提交
                                      size_t numLowConfidentBarcode);