from collections import defaultdict, Counter
from dataclasses import dataclass, field
from time import strftime
from tqdm import tqdm
from multiprocessing import Pool
from LAFITE.utils import loc_distance, Vividict
@dataclass
class RawAttributeCollection:
    """Per-read record filled in while a multi-exon read is corrected.

    The field order is significant: instances are created positionally
    (name, start, end) and the log writer iterates ``__dict__`` in
    declaration order.
    """

    # identity and genomic span of the read (first / last entry of its block list)
    name: str
    start: int
    end: int

    # set to True when the (corrected) splicing chain is found in the reference
    fsm: bool = False
    # set to False when the only junction of the read is merged away
    multi_exon: bool = True

    # raw [left, right] junctions that were moved onto reference sites
    correct_site: list = field(default_factory=list)
    # raw [left, right] junctions removed as unintended small introns
    merge_gap: list = field(default_factory=list)

    # True when the poly(A) table holds a truthy entry for this read
    polyaed: bool = False

    # 1-based junction index -> motif, for uncorrected junctions with coverage 1
    lowCredit_junction: dict = field(default_factory=dict)
    # one tag per junction; replaced by the string 'FSM' for raw reference matches
    splicing_tag: list = field(default_factory=list)

    # distances to the reference record; stay None unless the read matches it
    rss_dis: int = None
    res_dis: int = None

    # identifier of the collapsed group the read was assigned to
    collapsed_ID: str = None
class ReadCorrectionColappse:
    """Correct splice junctions and collapse the reads of one chromosome/strand.

    An instance holds the reads and the reference annotation of a single
    (chrom, strand) branch; ``coco_operation`` is the entry point.
    """

    def __init__(
        self,
        chrom,
        strand,
        chrand_processed_read,
        chrand_ref_exon,
        chrand_ref_junction,
        chrand_ref_single_exon_trans,
        chrand_ref_mutple_exon_trans,
        chrand_left_sj_set,
        chrand_right_sj_set,
        chrand_junction_dict,
        sj_correction_window,
        mis_intron_length,
        polya_dict,
        corExcept_dis=0,
    ):
        """Store the per-branch inputs.

        Args:
            chrom: chromosome name.
            strand: '+' selects the forward-strand logic, anything else the reverse.
            chrand_processed_read: mapping read id -> block list
                ``[start, left_sj, right_sj, ..., end]``.
            chrand_ref_exon: interval index with a ``find()`` method
                (may be empty / falsy).
            chrand_ref_junction: container of reference ``(left, right)`` junctions.
            chrand_ref_single_exon_trans: interval index with a ``find()`` method
                (may be empty / falsy).
            chrand_ref_mutple_exon_trans: mapping splicing tuple -> reference
                record whose first two entries are used for distances.
            chrand_left_sj_set: reference left splice sites.
            chrand_right_sj_set: reference right splice sites.
            chrand_junction_dict: mapping ``(chrom, strand, left, right)`` ->
                record; index 0 is read as coverage, index 2 as motif.
            sj_correction_window: maximal distance for moving a splice site.
            mis_intron_length: maximal length of an intron tested for merging.
            polya_dict: mapping read id -> poly(A) flag, or None.
            corExcept_dis: distance threshold of the correction exception;
                0 disables the exception.
        """
        self.chrom = chrom
        self.strand = strand
        self.chrand_processed_read = chrand_processed_read
        self.chrand_ref_exon = chrand_ref_exon
        self.chrand_ref_junction = chrand_ref_junction
        self.chrand_ref_single_exon_trans = chrand_ref_single_exon_trans
        self.chrand_ref_mutple_exon_trans = chrand_ref_mutple_exon_trans
        self.chrand_left_sj_set = chrand_left_sj_set
        self.chrand_right_sj_set = chrand_right_sj_set
        self.chrand_junction_dict = chrand_junction_dict
        self.sj_correction_window = sj_correction_window
        self.mis_intron_length = mis_intron_length
        self.polya_dict = polya_dict
        self.corExcept_dis = corExcept_dis

    def single_exon_read_collapse(self, read, single_exon_read):
        """Add a single-exon read to the per-position coverage table.

        The read is dropped when an overlapping reference exon contains it
        once the exon is extended by ``sj_correction_window`` on both sides.

        Args:
            read: ``[start, end]`` of the read (inclusive).
            single_exon_read: mapping position -> number of covering reads;
                updated in place.

        Returns:
            The same ``single_exon_read`` mapping.
        """
        start, end = read
        window = self.sj_correction_window
        if self.chrand_ref_exon:
            for exon in self.chrand_ref_exon.find(read):
                if exon[0] <= start + window and exon[1] >= end - window:
                    return single_exon_read
        # .get() keeps defaultdict-like tables from materialising default entries
        for pos in range(start, end + 1):
            single_exon_read[pos] = single_exon_read.get(pos, 0) + 1
        return single_exon_read

    def RTS_refrence_distance(self, start, end, read_splicing):
        """Return the distances between the read ends and its reference record.

        Args:
            start: genomic start position regardless of strand.
            end: genomic end position regardless of strand.
            read_splicing: splicing tuple; must be a key of
                ``chrand_ref_mutple_exon_trans``.

        Returns:
            ``(rss_dis, res_dis)``: absolute differences to entries 0 and 1 of
            the reference record. On '+' entry 0 is compared with ``start`` and
            entry 1 with ``end``; otherwise the two read ends are swapped.

        Raises:
            KeyError: if ``read_splicing`` is not in the reference mapping.
        """
        ref_record = self.chrand_ref_mutple_exon_trans[read_splicing]
        if self.strand == '+':
            return abs(ref_record[0] - start), abs(ref_record[1] - end)
        return abs(ref_record[0] - end), abs(ref_record[1] - start)

    def multi_exon_read_correction(self, read_id, full_block):
        """Correct the splice junctions of one multi-exon read.

        Every junction receives exactly one tag:
        'M'/'KC'   both raw sites are reference sites; junction is / is not
                   a reference junction,
        'EXC'      correction exception (both sites within ``corExcept_dis``,
                   coverage > 1, canonical motif) - left uncorrected,
        'CM'/'CKC' as 'M'/'KC' but after moving the sites onto reference sites,
        'UI'       small intron inside a reference exon - junction removed,
        'NC'       none of the above.
        A read whose raw splicing already matches the reference gets the
        string 'FSM' instead of a tag list.

        Args:
            read_id: key of the read, also used for the poly(A) lookup.
            full_block: ``[start, left_sj, right_sj, ..., end]``.

        Returns:
            ``(corrected_read_splicing, raw_read_attribute)`` - the corrected
            splicing tuple and the filled ``RawAttributeCollection``.
        """
        raw_splicing = tuple(full_block[1:-1])
        raw_read_attribute = RawAttributeCollection(read_id, full_block[0], full_block[-1])

        # polya event checking; a missing read (KeyError) or a missing table
        # (None -> TypeError) both mean "no poly(A) evidence"
        try:
            if self.polya_dict[read_id]:
                raw_read_attribute.polyaed = True
        except (KeyError, TypeError):
            pass

        # raw read_splicing matching with the reference
        if raw_splicing in self.chrand_ref_mutple_exon_trans:
            raw_read_attribute.fsm = True
            raw_read_attribute.splicing_tag = 'FSM'
            raw_read_attribute.rss_dis, raw_read_attribute.res_dis = self.RTS_refrence_distance(
                raw_read_attribute.start, raw_read_attribute.end, raw_splicing)
            return raw_splicing, raw_read_attribute

        # splicing site correction
        corrected_read_splicing = []
        tags = raw_read_attribute.splicing_tag
        window = self.sj_correction_window
        last_idx = len(raw_splicing) // 2 - 1
        single_junction_read = len(full_block) == 4
        splice_sites = iter(raw_splicing)
        for idx, (left_sj, right_sj) in enumerate(zip(splice_sites, splice_sites)):
            # position of the left site inside full_block; positional, so
            # repeated coordinates cannot select the wrong slot
            sj_pos = 2 * idx + 1
            raw_junction = [left_sj, right_sj]
            prev_site = full_block[sj_pos - 1]
            next_site = full_block[sj_pos + 2]

            # check splicing coverage and motif in raw data
            junction_record = self.chrand_junction_dict[(self.chrom, self.strand, left_sj, right_sj)]
            junction_coverage = junction_record[0]
            junction_motif = junction_record[2]

            if left_sj in self.chrand_left_sj_set and right_sj in self.chrand_right_sj_set:
                if (left_sj, right_sj) in self.chrand_ref_junction:
                    tags.append('M')
                else:
                    tags.append('KC')
            elif self.chrand_left_sj_set:
                # loc_distance presumably returns (distance, nearest reference site)
                left_dis, left_ref_sj = loc_distance(self.chrand_left_sj_set, left_sj)
                right_dis, right_ref_sj = loc_distance(self.chrand_right_sj_set, right_sj)
                # do not correct the splicing site once the edit distance > sj_correction_window
                if left_dis > window:
                    left_ref_sj = left_sj
                if right_dis > window:
                    right_ref_sj = right_sj
                # correction exception, splicing junction with edit distance
                # less than the given value for both sides
                if (self.corExcept_dis
                        and left_dis <= self.corExcept_dis
                        and right_dis <= self.corExcept_dis
                        and junction_coverage > 1
                        and junction_motif == 'canonical'):
                    tags.append('EXC')
                # splicing site correction
                elif left_dis <= window or right_dis <= window:
                    if [left_sj, right_sj] == [left_ref_sj, right_ref_sj]:
                        pass
                    # only move the sites when they stay between their neighbours
                    elif prev_site < left_ref_sj < right_ref_sj < next_site:
                        raw_read_attribute.correct_site.append([left_sj, right_sj])
                        left_sj, right_sj = left_ref_sj, right_ref_sj
                        if left_sj in self.chrand_left_sj_set and right_sj in self.chrand_right_sj_set:
                            if (left_sj, right_sj) in self.chrand_ref_junction:
                                tags.append('CM')
                            else:
                                tags.append('CKC')

            # checking unintended small intron overlap with reference exons
            # (len(tags) == idx means this junction has not been tagged yet)
            if len(tags) == idx and right_sj - left_sj <= self.mis_intron_length:
                if self.chrand_ref_exon:
                    overlapped_ref_exon = tuple(self.chrand_ref_exon.find((left_sj, right_sj)))
                else:
                    overlapped_ref_exon = ()

                if overlapped_ref_exon:
                    # compare the unintended intron with the ref_exon from multi-exon transcripts
                    if idx == 0:
                        for exon in overlapped_ref_exon:
                            if single_junction_read and exon[1] >= right_sj and exon[0] <= left_sj:
                                raw_read_attribute.merge_gap.append(raw_junction)
                                raw_read_attribute.multi_exon = False
                                break
                            elif exon[1] == next_site and exon[0] <= left_sj:
                                raw_read_attribute.merge_gap.append(raw_junction)
                                break
                    elif idx == last_idx:
                        for exon in overlapped_ref_exon:
                            if exon[1] >= right_sj and exon[0] == prev_site:
                                raw_read_attribute.merge_gap.append(raw_junction)
                                break
                    elif 0 < idx < last_idx:
                        for exon in overlapped_ref_exon:
                            if exon[1] == next_site and exon[0] == prev_site:
                                raw_read_attribute.merge_gap.append(raw_junction)
                                break
                # compare the unintended intron with the exon from single-exon transcripts
                elif single_junction_read:
                    if self.chrand_ref_single_exon_trans:
                        for exon in self.chrand_ref_single_exon_trans.find((left_sj, right_sj)):
                            if exon[1] >= right_sj and exon[0] <= left_sj:
                                raw_read_attribute.merge_gap.append(raw_junction)
                                raw_read_attribute.multi_exon = False
                                break

            if raw_junction in raw_read_attribute.merge_gap:
                # merged intron: the junction is dropped from the splicing chain
                tags.append('UI')
            else:
                corrected_read_splicing.extend([left_sj, right_sj])
                if len(tags) == idx:
                    tags.append('NC')
                    if junction_coverage == 1:
                        raw_read_attribute.lowCredit_junction[idx + 1] = junction_motif

        corrected_read_splicing = tuple(corrected_read_splicing)
        if corrected_read_splicing in self.chrand_ref_mutple_exon_trans:
            raw_read_attribute.fsm = True
            raw_read_attribute.rss_dis, raw_read_attribute.res_dis = self.RTS_refrence_distance(
                raw_read_attribute.start, raw_read_attribute.end, corrected_read_splicing)
        return corrected_read_splicing, raw_read_attribute

    def multi_exon_read_collapse(self, corrected_read_splicing, raw_read_attribute, rss_dis_lst, res_dis_lst, multi_exon_read, collapsed_idx):
        """Merge one corrected read into the table of collapsed splicing chains.

        Reads with a low-credit junction or an empty splicing chain are left
        out. Table entries have the layout
        ``[starts, ends, polya flags, read count, fsm, collapsed id]``.

        Args:
            corrected_read_splicing: corrected splicing tuple of the read.
            raw_read_attribute: attribute record of the read; its
                ``collapsed_ID`` is set here.
            rss_dis_lst: list extended with the read's ``rss_dis``.
            res_dis_lst: list extended with the read's ``res_dis``.
            multi_exon_read: mapping splicing tuple -> entry; updated in place.
            collapsed_idx: running number of the last created group.

        Returns:
            ``(multi_exon_read, rss_dis_lst, res_dis_lst, collapsed_idx,
            raw_read_attribute)``.
        """
        prefix = 'POS' if self.strand == '+' else 'NEG'
        if raw_read_attribute.lowCredit_junction or not corrected_read_splicing:
            return multi_exon_read, rss_dis_lst, res_dis_lst, collapsed_idx, raw_read_attribute

        if corrected_read_splicing not in multi_exon_read:
            collapsed_idx += 1
            raw_read_attribute.collapsed_ID = f'{self.chrom}_{prefix}.{collapsed_idx}'
            multi_exon_read[corrected_read_splicing] = [
                [raw_read_attribute.start],
                [raw_read_attribute.end],
                [raw_read_attribute.polyaed],
                1,
                raw_read_attribute.fsm,
                raw_read_attribute.collapsed_ID,
            ]
        else:
            entry = multi_exon_read[corrected_read_splicing]
            entry[0].insert(0, raw_read_attribute.start)
            entry[1].insert(0, raw_read_attribute.end)
            entry[2].insert(0, raw_read_attribute.polyaed)
            entry[3] += 1
            raw_read_attribute.collapsed_ID = entry[5]

        # None means "no reference match"; a distance of 0 is a valid value
        # and must be recorded as well
        if raw_read_attribute.rss_dis is not None:
            rss_dis_lst.append(raw_read_attribute.rss_dis)
            res_dis_lst.append(raw_read_attribute.res_dis)
        return multi_exon_read, rss_dis_lst, res_dis_lst, collapsed_idx, raw_read_attribute

    def coco_operation(self, collapsed_idx=0):
        """Correct and collapse every read of this chromosome/strand.

        Args:
            collapsed_idx: number from which collapsed group ids continue.

        Returns:
            ``(chrom, strand, single_exon_read, multi_exon_read,
            corrected_read, correction_log, rss_dis_lst, res_dis_lst)``.
        """
        corrected_read = defaultdict(dict)
        correction_log = defaultdict(dict)
        single_exon_read = defaultdict(dict)
        multi_exon_read = defaultdict(dict)
        rss_dis_lst = []
        res_dis_lst = []
        progress = tqdm(
            self.chrand_processed_read.items(),
            desc=f'{strftime("%Y-%m-%d %H:%M:%S")}: Collapsing raw reads from {self.chrom} {self.strand}',
        )
        for read_id, full_block in progress:
            # single-exon read collapsing
            if len(full_block) == 2:
                single_exon_read = self.single_exon_read_collapse(full_block, single_exon_read)
                corrected_read[read_id] = full_block
                continue

            # multi-exon read correction and collapsing
            corrected_read_splicing, raw_read_attribute = self.multi_exon_read_correction(read_id, full_block)
            multi_exon_read, rss_dis_lst, res_dis_lst, collapsed_idx, raw_read_attribute = self.multi_exon_read_collapse(
                corrected_read_splicing, raw_read_attribute, rss_dis_lst, res_dis_lst, multi_exon_read, collapsed_idx)
            if corrected_read_splicing:
                corrected_read[read_id] = [raw_read_attribute.start, *corrected_read_splicing, raw_read_attribute.end]
            else:
                corrected_read[read_id] = [raw_read_attribute.start, raw_read_attribute.end]
            correction_log[read_id] = raw_read_attribute
        return self.chrom, self.strand, single_exon_read, multi_exon_read, corrected_read, correction_log, rss_dis_lst, res_dis_lst
# Ad-hoc driver: runs the correction/collapsing for a single branch (chr19, '-')
# against hard-coded input files.
# NOTE(review): these statements execute at import time, and RefProcessWrapper,
# read_grouping and polya_signal_import are only imported further down in this
# file - confirm whether this section is leftover scratch code.
gtf = '/expt/zjzace/Nanopore_subcellular/Reference/gencode.v38.primary_assembly.annotation.sorted.gtf'
ref_exon, ref_junction, ref_single_exon_trans, ref_mutple_exon_trans, left_sj_set, right_sj_set, tss_dict = RefProcessWrapper(gtf, 16).result_collection()
bed = '/expt/zjzace/Nanopore_subcellular/Analysis/Assembly/LAFITE/GLRA_tmp/A549_Cyto_LAFITE_tmp/bam.bed'
fa = '/NFS/mnemosyne2/expt/zjzace/GenomeRef/GRCh38.primary_assembly.genome.fa'
junction_dict, processed_read = read_grouping(bed, fa)
polya_dict = polya_signal_import('/expt/zjzace/Nanopore_subcellular/Analysis/Nanopolish/A549_Cyto_PolyA.res')
# select the per-branch slices of every genome-wide table
chrom, strand = 'chr19', '-'
chrand_processed_read = processed_read[(chrom, strand)]
chrand_ref_exon = ref_exon[(chrom, strand)]
chrand_ref_junction = ref_junction[(chrom, strand)]
chrand_ref_single_exon_trans = ref_single_exon_trans[(chrom, strand)]
chrand_ref_mutple_exon_trans = ref_mutple_exon_trans[(chrom,strand)]
chrand_left_sj_set = left_sj_set[(chrom,strand)]
chrand_right_sj_set = right_sj_set[(chrom,strand)]
chrand_junction_dict = junction_dict[(chrom,strand)]
# correction parameters passed to ReadCorrectionColappse below
sj_correction_window = 40
mis_intron_length = 150
corExcept_dis=4
# self-assignment, has no effect
polya_dict=polya_dict
chrom, strand, single_exon_read, multi_exon_read, corrected_read, correction_log, rss_dis_lst, res_dis_lst = ReadCorrectionColappse(chrom, strand, chrand_processed_read, chrand_ref_exon, chrand_ref_junction, chrand_ref_single_exon_trans, chrand_ref_mutple_exon_trans, chrand_left_sj_set, chrand_right_sj_set, chrand_junction_dict, sj_correction_window, mis_intron_length, polya_dict, corExcept_dis).coco_operation()
class CoCoWrapper:
    """Run ``ReadCorrectionColappse`` for every chromosome/strand in parallel.

    The genome-wide tables are expected to be indexable by the same
    ``(chrom, strand)`` keys as ``processed_read``.
    """

    def __init__(self, thread, processed_read, ref_exon, ref_junction, ref_single_exon_trans, ref_mutple_exon_trans, left_sj_set, right_sj_set, junction_dict, sj_correction_window, polya_dict, mis_intron_length, tmp_dir, corExcept_dis=0):
        """Store the inputs.

        Args:
            thread: number of worker processes.
            processed_read: mapping ``(chrom, strand)`` -> reads of the branch.
            ref_exon, ref_junction, ref_single_exon_trans,
            ref_mutple_exon_trans, left_sj_set, right_sj_set, junction_dict:
                tables indexed by ``(chrom, strand)``.
            sj_correction_window: maximal distance for moving a splice site.
            polya_dict: poly(A) table handed to every job.
            mis_intron_length: maximal length of an intron tested for merging.
            tmp_dir: directory receiving the log and the corrected BED file.
            corExcept_dis: distance threshold of the correction exception.
        """
        self.thread = thread
        self.processed_read = processed_read
        self.ref_exon = ref_exon
        self.ref_junction = ref_junction
        self.ref_single_exon_trans = ref_single_exon_trans
        self.ref_mutple_exon_trans = ref_mutple_exon_trans
        self.left_sj_set = left_sj_set
        self.right_sj_set = right_sj_set
        self.junction_dict = junction_dict
        self.sj_correction_window = sj_correction_window
        self.polya_dict = polya_dict
        self.mis_intron_length = mis_intron_length
        self.tmp_dir = tmp_dir
        self.corExcept_dis = corExcept_dis

    def job_compute(self):
        """Submit one correction job per branch and wait for all of them.

        Returns:
            List of ``AsyncResult`` objects in the iteration order of
            ``processed_read``; empty when there is nothing to process.
        """
        job = []
        for branch in self.processed_read:
            chrom, strand = branch
            job.append(ReadCorrectionColappse(
                chrom,
                strand,
                self.processed_read[branch],
                self.ref_exon[branch],
                self.ref_junction[branch],
                self.ref_single_exon_trans[branch],
                self.ref_mutple_exon_trans[branch],
                self.left_sj_set[branch],
                self.right_sj_set[branch],
                self.junction_dict[branch],
                self.sj_correction_window,
                self.mis_intron_length,
                self.polya_dict,
                self.corExcept_dis,
            ))
        if not job:
            # nothing to do: avoid spawning worker processes
            return []
        # the context manager terminates the workers even if submission fails;
        # close() + join() first lets every job finish normally
        with Pool(processes=self.thread) as p:
            result = [p.apply_async(i.coco_operation, args=()) for i in job]
            p.close()
            p.join()
        return result

    def result_collection(self):
        """Run all jobs, write the log and BED files and merge the results.

        Writes ``read_correction.log`` (one line of ``key: value`` pairs per
        multi-exon read) and ``Corrected_reads.bed`` into ``tmp_dir``.
        Read names are the part of the read id after the first underscore.

        Returns:
            ``(collected_single_exon_read, collected_multi_exon_read,
            collected_rss, collected_res)``; the first two are keyed by
            ``(chrom, strand)``.
        """
        collected_single_exon_read = Vividict()
        collected_multi_exon_read = Vividict()
        collected_rss = []
        collected_res = []
        path_to_log = f'{self.tmp_dir}/read_correction.log'
        path_to_corrected_bed = f'{self.tmp_dir}/Corrected_reads.bed'
        result = self.job_compute()
        with open(path_to_log, 'w') as flog, open(path_to_corrected_bed, 'w') as fbed:
            for res in result:
                chrom, strand, single_exon_read, multi_exon_read, corrected_read, correction_log, rss_dis_lst, res_dis_lst = res.get()
                collected_single_exon_read[(chrom, strand)] = single_exon_read
                collected_multi_exon_read[(chrom, strand)] = multi_exon_read
                collected_rss.extend(rss_dis_lst)
                collected_res.extend(res_dis_lst)
                for raw_read_attribute in correction_log.values():
                    # the stored name is rewritten in place before it is logged
                    raw_read_attribute.name = raw_read_attribute.name.split('_', 1)[1]
                    attributes = '\t'.join(f'{key}: {value}' for key, value in raw_read_attribute.__dict__.items())
                    flog.write(f'{attributes}\n')
                for read_id, full_block in corrected_read.items():
                    read_name = read_id.split('_', 1)[1]
                    bed_block = splicing_to_bed_block(chrom, strand, read_name, full_block)
                    fbed.write(f'{bed_block}\n')
        return collected_single_exon_read, collected_multi_exon_read, collected_rss, collected_res
from LAFITE.reference_processing import RefProcessWrapper, short_reads_sj_import
from LAFITE.preprocessing import read_grouping, polya_signal_import, PolyAFinder
from LAFITE.utils import temp_dir_creation, bam2bed, keep_tmp_file
# Ad-hoc driver: full CoCoWrapper run on hard-coded SIRV input files.
# NOTE(review): executes at import time and rebinds the module-level names
# (gtf, bed, fa, junction_dict, ...) set by the earlier driver section.
gtf = '/expt/zjzace/Nanopore_subcellular/SIRV/SIRV_Set1/Raw_data/SIRV_isoforms_multi-fasta-annotation_C_170612a.gtf'
bed = '/expt/zjzace/Nanopore_subcellular/SIRV/SIRV_Set1/bam/SRR6058584.sorted.bed'
fa = '/expt/zjzace/Nanopore_subcellular/SIRV/SIRV_Set1/Raw_data/SIRV_isoforms_multi-fasta_170612a.fasta'
junction_dict, processed_read = read_grouping(bed, fa)
ref_exon, ref_junction, ref_single_exon_trans, ref_mutple_exon_trans, left_sj_set, right_sj_set, tss_dict = RefProcessWrapper(gtf, 16).result_collection()
# polya_dict = polya_signal_import('/expt/zjzace/Nanopore_subcellular/SIRV/SIRV_Set1/bam/SRR6058584.polya.res')
# poly(A) table is taken from PolyAFinder here instead of an imported result file
polya_dict = PolyAFinder(processed_read, fa, '/home/zjzace/software/SQANTI3-4.1/data/polyA_motifs/mouse_and_human.polyA_motif.txt').polya_estimation()
# 16 worker processes; log and BED output go to the current directory
collected_single_exon_read, collected_multi_exon_read, tss, tes = CoCoWrapper(16, processed_read, ref_exon, ref_junction, ref_single_exon_trans, ref_mutple_exon_trans, left_sj_set, right_sj_set, junction_dict, sj_correction_window=40, polya_dict=polya_dict, mis_intron_length = 150, tmp_dir='.',corExcept_dis=4).result_collection()
def single_exon_read_collapse(read, chrand_ref_exon, overhang, single_exon_read):
    """Add a single-exon read to the per-position coverage table.

    The read is dropped when an overlapping reference exon contains it once
    the exon is extended by ``overhang`` on both sides.

    Args:
        read (list): start and end position of the single-exon read (1 base, [1,100])
        chrand_ref_exon (interlap data): exon from reference multi-exon
            transcript; may be empty / falsy
        overhang (int): tolerance distance
        single_exon_read (dict): position -> number of covering reads;
            updated in place

    Returns:
        dict: the same ``single_exon_read`` mapping
    """
    start, end = read
    if chrand_ref_exon:
        for exon in chrand_ref_exon.find(read):
            if exon[0] <= start + overhang and exon[1] >= end - overhang:
                return single_exon_read
    # every position of the read contributes exactly one count;
    # .get() keeps defaultdict-like tables from materialising default entries
    for pos in range(start, end + 1):
        single_exon_read[pos] = single_exon_read.get(pos, 0) + 1
    return single_exon_read
def RTS_refrence_distance(strand, start, end, read_splicing, chrand_ref_mutple_exon_trans):
    """Return the distances between the read ends and its reference record.

    Args:
        strand (str): strand information
        start (int): genomic start position regardless of strand
        end (int): genomic end position regardless of strand
        read_splicing (tuple): corrected read splicing; must be a key of
            ``chrand_ref_mutple_exon_trans``
        chrand_ref_mutple_exon_trans (dict): reference multi-exon transcript

    Returns:
        tuple: ``(rss_dis, res_dis)`` - absolute differences to entries 0 and
        1 of the reference record. On '+' entry 0 is compared with ``start``
        and entry 1 with ``end``; on any other strand the read ends are swapped.
    """
    ref_record = chrand_ref_mutple_exon_trans[read_splicing]
    first_end, second_end = (start, end) if strand == '+' else (end, start)
    return abs(ref_record[0] - first_end), abs(ref_record[1] - second_end)
def multi_exon_read_correction(chrom, strand, name, full_block, chrand_ref_exon, chrand_ref_junction, chrand_ref_single_exon_trans, chrand_ref_mutple_exon_trans, chrand_left_sj_set, chrand_right_sj_set, chrand_junction_dict, sj_correction_window, mis_intron_length, corExcept_dis, polya_dict):
    """Correct the splice junctions of one multi-exon read.

    Every junction receives exactly one tag in ``tag_dict['splicing_tag']``:
    'M'/'KC'   both raw sites are reference sites; junction is / is not a
               reference junction,
    'EXC'      correction exception (both sites within ``corExcept_dis``,
               coverage > 1, canonical motif) - left uncorrected,
    'CM'/'CKC' as 'M'/'KC' but after moving the sites onto reference sites,
    'UI'       small intron inside a reference exon - junction removed,
    'NC'       none of the above.
    A read whose raw splicing already matches the reference gets the string
    'FSM' instead of a tag list.

    Args:
        chrom (str): chromosome
        strand (str): strand information
        name (str): read name
        full_block (list): start, end position and all splicing site of multi-exon read (1 base, [1,20,40, 100])
        chrand_ref_exon (interlap data): exon from reference multi-exon transcript
        chrand_ref_junction (list): reference splicing junction
        chrand_ref_single_exon_trans (interlap data): reference single-exon transcript
        chrand_ref_mutple_exon_trans (dict): reference multi-exon transcript
        chrand_left_sj_set (list): reference left splicing site
        chrand_right_sj_set (list): reference right splicing site
        chrand_junction_dict (dict): splicing junction detected from long read;
            index 0 of a record is read as coverage, index 2 as motif
        sj_correction_window (int): tolerance distance for splicing site correction
        mis_intron_length (int): unintended small intron gap that should be filled
        corExcept_dis (int, optional): edit distance to the reference splicing site
        polya_dict (dict, optional): raw long reads Polyadenylation event; may be None

    Returns:
        tuple: ``(start, end, corrected_read_splicing, tag_dict)``
    """
    raw_splicing = tuple(full_block[1:-1])
    start, end = full_block[0], full_block[-1]
    tag_dict = {'reference_match': False, 'multi-exon': True, 'correct_site': [], 'merge_site': [], 'polya_signal': False, 'lowCredit_junction': {}, 'splicing_tag': [], 'rss_dis': None, 'res_dis': None}

    # polya event checking; a missing read (KeyError) or a missing table
    # (None -> TypeError) both mean "no poly(A) evidence"
    try:
        if polya_dict[name]:
            tag_dict['polya_signal'] = True
    except (KeyError, TypeError):
        pass

    # uncorrected read_splicing matching with the reference
    if raw_splicing in chrand_ref_mutple_exon_trans:
        tag_dict['reference_match'] = True
        tag_dict['splicing_tag'] = 'FSM'
        tag_dict['rss_dis'], tag_dict['res_dis'] = RTS_refrence_distance(strand, start, end, raw_splicing, chrand_ref_mutple_exon_trans)
        return start, end, raw_splicing, tag_dict

    # splicing site correction
    corrected_read_splicing = []
    tags = tag_dict['splicing_tag']
    merge_site = tag_dict['merge_site']
    last_idx = len(raw_splicing) // 2 - 1
    single_junction_read = len(full_block) == 4
    splice_sites = iter(raw_splicing)
    for idx, (left_sj, right_sj) in enumerate(zip(splice_sites, splice_sites)):
        # position of the left site inside full_block; positional, so
        # repeated coordinates cannot select the wrong slot
        sj_pos = 2 * idx + 1
        raw_junction = [left_sj, right_sj]
        prev_site = full_block[sj_pos - 1]
        next_site = full_block[sj_pos + 2]

        # check splicing coverage and motif in raw data
        junction_record = chrand_junction_dict[(chrom, strand, left_sj, right_sj)]
        junction_coverage = junction_record[0]
        junction_motif = junction_record[2]

        if left_sj in chrand_left_sj_set and right_sj in chrand_right_sj_set:
            if (left_sj, right_sj) in chrand_ref_junction:
                tags.append('M')
            else:
                tags.append('KC')
        elif chrand_left_sj_set:
            # loc_distance presumably returns (distance, nearest reference site)
            left_dis, left_ref_sj = loc_distance(chrand_left_sj_set, left_sj)
            right_dis, right_ref_sj = loc_distance(chrand_right_sj_set, right_sj)
            # do not correct the splicing site once the edit distance > sj_correction_window
            if left_dis > sj_correction_window:
                left_ref_sj = left_sj
            if right_dis > sj_correction_window:
                right_ref_sj = right_sj
            # correction exception, splicing junction with edit distance
            # within corExcept_dis for both sides
            if (left_dis <= corExcept_dis
                    and right_dis <= corExcept_dis
                    and junction_coverage > 1
                    and junction_motif == 'canonical'):
                tags.append('EXC')
            # splicing site correction
            elif left_dis <= sj_correction_window or right_dis <= sj_correction_window:
                if [left_sj, right_sj] == [left_ref_sj, right_ref_sj]:
                    pass
                # only move the sites when they stay between their neighbours
                elif prev_site < left_ref_sj < right_ref_sj < next_site:
                    tag_dict['correct_site'].append([left_sj, right_sj])
                    left_sj, right_sj = left_ref_sj, right_ref_sj
                    if left_sj in chrand_left_sj_set and right_sj in chrand_right_sj_set:
                        if (left_sj, right_sj) in chrand_ref_junction:
                            tags.append('CM')
                        else:
                            tags.append('CKC')

        # checking unintended small intron overlap with reference exons
        # (len(tags) == idx means this junction has not been tagged yet)
        if len(tags) == idx and right_sj - left_sj <= mis_intron_length:
            if chrand_ref_exon:
                overlapped_ref_exon = tuple(chrand_ref_exon.find((left_sj, right_sj)))
            else:
                overlapped_ref_exon = ()

            if overlapped_ref_exon:
                # compare the unintended intron with the ref_exon from multi-exon transcripts
                if idx == 0:
                    for exon in overlapped_ref_exon:
                        if single_junction_read and exon[1] >= right_sj and exon[0] <= left_sj:
                            merge_site.append(raw_junction)
                            tag_dict['multi-exon'] = False
                            break
                        elif exon[1] == next_site and exon[0] <= left_sj:
                            merge_site.append(raw_junction)
                            break
                elif idx == last_idx:
                    for exon in overlapped_ref_exon:
                        if exon[1] >= right_sj and exon[0] == prev_site:
                            merge_site.append(raw_junction)
                            break
                elif 0 < idx < last_idx:
                    for exon in overlapped_ref_exon:
                        if exon[1] == next_site and exon[0] == prev_site:
                            merge_site.append(raw_junction)
                            break
            # compare the unintended intron with the exon from single-exon transcripts
            elif single_junction_read:
                if chrand_ref_single_exon_trans:
                    for exon in chrand_ref_single_exon_trans.find((left_sj, right_sj)):
                        if exon[1] >= right_sj and exon[0] <= left_sj:
                            merge_site.append(raw_junction)
                            tag_dict['multi-exon'] = False
                            break

        if raw_junction in merge_site:
            # merged intron: the junction is dropped from the splicing chain
            tags.append('UI')
        else:
            corrected_read_splicing.extend([left_sj, right_sj])
            if len(tags) == idx:
                tags.append('NC')
                if junction_coverage == 1:
                    tag_dict['lowCredit_junction'][idx + 1] = junction_motif

    corrected_read_splicing = tuple(corrected_read_splicing)
    if corrected_read_splicing in chrand_ref_mutple_exon_trans:
        tag_dict['reference_match'] = True
        tag_dict['rss_dis'], tag_dict['res_dis'] = RTS_refrence_distance(strand, start, end, corrected_read_splicing, chrand_ref_mutple_exon_trans)
    return start, end, corrected_read_splicing, tag_dict
def multi_exon_read_collapse(read_id, start, end, corrected_read_splicing, tag_dict, rss_dis_lst, res_dis_lst, multi_exon_read, collapsed_idx):
    """Merge one corrected read into the table of collapsed splicing chains.

    Reads with a low-credit junction or an empty splicing chain are left out.
    Table entries have the layout
    ``[starts, ends, polya flags, read count, read names, reference_match,
    collapsed id]``.

    Args:
        read_id (str): read id; the read name is the part after the first '_'
        start (int): read start
        end (int): read end
        corrected_read_splicing (tuple): corrected splicing of the read
        tag_dict (dict): tags produced by ``multi_exon_read_correction``
        rss_dis_lst (list): extended with the read's 'rss_dis'
        res_dis_lst (list): extended with the read's 'res_dis'
        multi_exon_read (dict): splicing tuple -> entry; updated in place
        collapsed_idx (int): running number of the last created group

    Returns:
        tuple: ``(multi_exon_read, rss_dis_lst, res_dis_lst, collapsed_idx)``

    Raises:
        IndexError: if ``read_id`` contains no underscore.
    """
    read_name = read_id.split('_', 1)[1]
    if tag_dict['lowCredit_junction'] or not corrected_read_splicing:
        return multi_exon_read, rss_dis_lst, res_dis_lst, collapsed_idx

    if corrected_read_splicing not in multi_exon_read:
        collapsed_idx += 1
        multi_exon_read[corrected_read_splicing] = [
            [start],
            [end],
            [tag_dict['polya_signal']],
            1,
            [read_name],
            tag_dict['reference_match'],
            f'collapsed.{collapsed_idx}',
        ]
    else:
        entry = multi_exon_read[corrected_read_splicing]
        entry[0].insert(0, start)
        entry[1].insert(0, end)
        entry[2].insert(0, tag_dict['polya_signal'])
        entry[3] += 1
        entry[4].insert(0, read_name)

    # None means "no reference match"; a distance of 0 is a valid value and
    # must be recorded as well
    if tag_dict['rss_dis'] is not None:
        rss_dis_lst.append(tag_dict['rss_dis'])
        res_dis_lst.append(tag_dict['res_dis'])
    return multi_exon_read, rss_dis_lst, res_dis_lst, collapsed_idx
def three_prime_exon_extraction(strand, multi_exon_read):
    """Collect one terminal exon per collapsed splicing chain.

    Args:
        strand (str): strand information
        multi_exon_read (dict): splicing tuple -> entry whose index 0 holds
            the read starts and index 1 the read ends

    Returns:
        set: ``(left, right)`` tuples. On '+' the exon runs from the last
        splice site to the largest read end; otherwise from the largest read
        start to the first splice site.
    """
    # a plain set: the former defaultdict(set) has no .add() and raised
    # AttributeError on the first entry
    three_prime_exon = set()
    for corrected_read_splicing, read_info in multi_exon_read.items():
        if strand == '+':
            exon = (corrected_read_splicing[-1], max(read_info[1]))
        else:
            # NOTE(review): uses the largest start; confirm whether the
            # smallest start was intended for the reverse strand
            exon = (max(read_info[0]), corrected_read_splicing[0])
        three_prime_exon.add(exon)
    return three_prime_exon
def coco_operation(chrom, strand, chrand_processed_read, chrand_ref_exon, chrand_ref_junction, chrand_ref_single_exon_trans, chrand_ref_mutple_exon_trans, chrand_left_sj_set, chrand_right_sj_set, chrand_junction_dict, sj_correction_window=40, mis_intron_length=150, corExcept_dis=0, polya_dict=None, collapsed_idx=0):
    """Correct splicing junctions and collapse the reads of one chromosome/strand.

    Reads with a two-element block list are handled as single-exon reads,
    all others as multi-exon reads.

    Args:
        chrom (str): chromosome
        strand (str): strand information
        chrand_processed_read (dict): read id -> block list
        chrand_ref_exon, chrand_ref_junction, chrand_ref_single_exon_trans,
        chrand_ref_mutple_exon_trans, chrand_left_sj_set,
        chrand_right_sj_set, chrand_junction_dict: reference and junction
            tables of this branch, passed through to the correction step
        sj_correction_window (int): tolerance distance for splicing site
            correction; also used as overhang for single-exon reads
        mis_intron_length (int): unintended small intron gap that should be filled
        corExcept_dis (int): edit distance of the correction exception
        polya_dict (dict, optional): poly(A) table
        collapsed_idx (int): number from which collapsed group ids continue

    Returns:
        tuple: ``(chrom, strand, single_exon_read, multi_exon_read,
        corrected_read, correction_log, rss_dis_lst, res_dis_lst)``
    """
    corrected_read = defaultdict(dict)
    correction_log = defaultdict(dict)
    single_exon_read = defaultdict(dict)
    multi_exon_read = defaultdict(dict)
    # two independent lists; chained assignment would make both names share
    # one list and mix the two distance series
    rss_dis_lst = []
    res_dis_lst = []
    progress = tqdm(
        chrand_processed_read.items(),
        desc=f'{strftime("%Y-%m-%d %H:%M:%S")}: Collapsing raw reads from {chrom} {strand}',
    )
    for read_id, full_block in progress:
        # single-exon read collapsing
        if len(full_block) == 2:
            single_exon_read = single_exon_read_collapse(full_block, chrand_ref_exon, sj_correction_window, single_exon_read)
            corrected_read[read_id] = full_block
            continue

        # multi-exon read correction and collapsing
        start, end, corrected_read_splicing, tag_dict = multi_exon_read_correction(
            chrom, strand, read_id, full_block, chrand_ref_exon, chrand_ref_junction,
            chrand_ref_single_exon_trans, chrand_ref_mutple_exon_trans, chrand_left_sj_set,
            chrand_right_sj_set, chrand_junction_dict, sj_correction_window, mis_intron_length,
            corExcept_dis, polya_dict)
        multi_exon_read, rss_dis_lst, res_dis_lst, collapsed_idx = multi_exon_read_collapse(
            read_id, start, end, corrected_read_splicing, tag_dict, rss_dis_lst, res_dis_lst,
            multi_exon_read, collapsed_idx)
        if corrected_read_splicing:
            corrected_read[read_id] = [start, *corrected_read_splicing, end]
        else:
            corrected_read[read_id] = [start, end]
        correction_log[read_id] = tag_dict
    return chrom, strand, single_exon_read, multi_exon_read, corrected_read, correction_log, rss_dis_lst, res_dis_lst
def splicing_to_bed_block(chrom, strand, name, full_block):
    """Format a read's block list as one tab-separated 12-column BED line.

    Args:
        chrom (str): chromosome, column 1
        strand (str): strand, column 6
        name (str): read name, column 4
        full_block (list): alternating exon boundaries
            ``[start, exon_end, exon_start, ..., end]``

    Returns:
        str: the BED line without trailing newline. Start columns are
        ``start - 1``; column 5 is the literal '-' and column 9 '255,0,0'.
    """
    start, end = full_block[0], full_block[-1]
    # consecutive pairs (0,1), (2,3), ... are the inclusive exon boundaries
    exons = list(zip(full_block[0::2], full_block[1::2]))
    block_sizes = ','.join(str(exon_end - exon_start + 1) for exon_start, exon_end in exons)
    block_starts = ','.join(str(exon_start - start) for exon_start, _ in exons)
    columns = (chrom, start - 1, end, name, '-', strand, start - 1, end, '255,0,0', len(exons), block_sizes, block_starts)
    return '\t'.join(map(str, columns))
class CoCoWrapper:
    """Run ``coco_operation`` for every chromosome/strand in parallel.

    The genome-wide tables are expected to be indexable by the same
    ``(chrom, strand)`` keys as ``processed_read``.
    """

    def __init__(self, thread, processed_read, ref_exon, ref_junction, ref_single_exon_trans, ref_mutple_exon_trans, left_sj_set, right_sj_set, junction_dict, tmp_dir, sj_correction_window, polya_dict, mis_intron_length, corExcept_dis=4):
        """Store the inputs.

        Args:
            thread: number of worker processes.
            processed_read: mapping ``(chrom, strand)`` -> reads of the branch.
            ref_exon, ref_junction, ref_single_exon_trans,
            ref_mutple_exon_trans, left_sj_set, right_sj_set, junction_dict:
                tables indexed by ``(chrom, strand)``.
            tmp_dir: directory receiving the log and the corrected BED file.
            sj_correction_window: maximal distance for moving a splice site.
            polya_dict: poly(A) table handed to every job.
            mis_intron_length: maximal length of an intron tested for merging.
            corExcept_dis: distance threshold of the correction exception.
        """
        self.thread = thread
        self.processed_read = processed_read
        self.ref_exon = ref_exon
        self.ref_junction = ref_junction
        self.ref_single_exon_trans = ref_single_exon_trans
        self.ref_mutple_exon_trans = ref_mutple_exon_trans
        self.left_sj_set = left_sj_set
        self.right_sj_set = right_sj_set
        self.junction_dict = junction_dict
        self.sj_correction_window = sj_correction_window
        self.tmp_dir = tmp_dir
        self.polya_dict = polya_dict
        self.mis_intron_length = mis_intron_length
        self.corExcept_dis = corExcept_dis

    def job_compute(self):
        """Submit one ``coco_operation`` job per branch and wait for all of them.

        Returns:
            List of ``AsyncResult`` objects in the iteration order of
            ``processed_read``; empty when there is nothing to process.
        """
        branches = list(self.processed_read)
        if not branches:
            # nothing to do: avoid spawning worker processes
            return []
        result = []
        # the context manager terminates the workers even if submission fails;
        # close() + join() first lets every job finish normally
        with Pool(processes=self.thread) as p:
            for branch in branches:
                chrom, strand = branch
                job_args = (
                    chrom,
                    strand,
                    self.processed_read[branch],
                    self.ref_exon[branch],
                    self.ref_junction[branch],
                    self.ref_single_exon_trans[branch],
                    self.ref_mutple_exon_trans[branch],
                    self.left_sj_set[branch],
                    self.right_sj_set[branch],
                    self.junction_dict[branch],
                    self.sj_correction_window,
                    self.mis_intron_length,
                    self.corExcept_dis,
                    self.polya_dict,
                )
                result.append(p.apply_async(coco_operation, job_args))
            p.close()
            p.join()
        return result

    def result_collection(self):
        """Run all jobs, write the log and BED files and merge the results.

        Writes ``read_correction.log`` (read name followed by ``key: value``
        pairs, one line per multi-exon read) and ``Corrected_reads.bed`` into
        ``tmp_dir``. Read names are the part of the read id after the first
        underscore.

        Returns:
            ``(collected_single_exon_read, collected_multi_exon_read,
            collected_rss, collected_res)``; the first two are keyed by
            ``(chrom, strand)``.
        """
        collected_single_exon_read = Vividict()
        collected_multi_exon_read = Vividict()
        collected_rss = []
        collected_res = []
        path_to_log = f'{self.tmp_dir}/read_correction.log'
        path_to_corrected_bed = f'{self.tmp_dir}/Corrected_reads.bed'
        result = self.job_compute()
        with open(path_to_log, 'w') as flog, open(path_to_corrected_bed, 'w') as fbed:
            for res in result:
                chrom, strand, single_exon_read, multi_exon_read, corrected_read, correction_log, rss_dis_lst, res_dis_lst = res.get()
                collected_single_exon_read[(chrom, strand)] = single_exon_read
                collected_multi_exon_read[(chrom, strand)] = multi_exon_read
                collected_rss.extend(rss_dis_lst)
                collected_res.extend(res_dis_lst)
                for read_id, tag_dict in correction_log.items():
                    read_name = read_id.split('_', 1)[1]
                    attributes = '\t'.join(f'{key}: {value}' for key, value in tag_dict.items())
                    flog.write(f'{read_name}\t{attributes}\n')
                for read_id, full_block in corrected_read.items():
                    read_name = read_id.split('_', 1)[1]
                    bed_block = splicing_to_bed_block(chrom, strand, read_name, full_block)
                    fbed.write(f'{bed_block}\n')
        return collected_single_exon_read, collected_multi_exon_read, collected_rss, collected_res
# chrand_read_list1['adcc854a-2c55-432c-9d65-11ca1d8c9eb4'] = chrand_read_list['adcc854a-2c55-432c-9d65-11ca1d8c9eb4']
# # chrand_read_list1['adcc854a-2c55-432c-9d65-11ca1d8c9eb4'] = [167500,167519, 167584, 167610]
# for (chrom, strand) in processed_read:
# chrand_read_list = processed_read[(chrom, strand)]
# chrand_ref_exon = ref_exon[(chrom, strand)]
# chrand_ref_junction = ref_junction[(chrom, strand)]
# chrand_ref_single_exon_trans = ref_single_exon_trans[(chrom, strand)]
# chrand_ref_mutple_exon_trans = ref_mutple_exon_trans[(chrom,strand)]
# chrand_left_sj_set = left_sj_set[(chrom,strand)]
# chrand_right_sj_set = right_sj_set[(chrom,strand)]
# chrand_junction_dict = junction_dict[(chrom,strand)]
# sj_correction_window = 40
# mis_intron_length = 150
# corExcept_dis=4
# polya_dict=None
# single_exon_read = {}
# tmp_dir='./'
# # with open ('test.out', 'w') as fw:
# # for read, splicing in tqdm(chrand_read_list.items(), desc = f'{strftime("%Y-%m-%d %H:%M:%S")}: Collapsing raw reads from {chrom} {strand}'):
# # if len(splicing) == 2:
# # single_exon_read_collapse(splicing, chrand_ref_exon, 40, single_exon_ead)
# # if len(splicing) > 2:
# single_exon_read, multi_exon_read, corrected_read, correction_log, rss_dis_lst, res_dis_lst = coco_operation(chrom, strand, chrand_read_list, chrand_ref_exon, chrand_ref_junction, chrand_ref_single_exon_trans, chrand_ref_mutple_exon_trans, chrand_left_sj_set, chrand_right_sj_set, chrand_junction_dict, sj_correction_window)
# result_collection() returns four values (single-exon table, multi-exon table
# and the two distance lists); unpacking only two names raised ValueError.
collected_single_exon_read, collected_multi_exon_read, collected_rss, collected_res = CoCoWrapper(32, processed_read, ref_exon, ref_junction, ref_single_exon_trans, ref_mutple_exon_trans, left_sj_set, right_sj_set, junction_dict, tmp_dir='.', sj_correction_window=40, mis_intron_length = 150, corExcept_dis=4, polya_dict=None).result_collection()