from collections import defaultdict, Counter
from dataclasses import dataclass, field
from time import strftime
from tqdm import tqdm
from multiprocessing import Pool
from LAFITE.utils import loc_distance, Vividict

class RawAttributeCollection[source]

RawAttributeCollection(name:str, start:int, end:int, fsm:bool=False, multi_exon:bool=True, correct_site:list=<factory>, merge_gap:list=<factory>, polyaed:bool=False, lowCredit_junction:dict=<factory>, splicing_tag:list=<factory>, rss_dis:int=None, res_dis:int=None, collapsed_ID:str=None)

RawAttributeCollection(name: str, start: int, end: int, fsm: bool = False, multi_exon: bool = True, correct_site: list = <factory>, merge_gap: list = <factory>, polyaed: bool = False, lowCredit_junction: dict = <factory>, splicing_tag: list = <factory>, rss_dis: int = None, res_dis: int = None, collapsed_ID: str = None)

@dataclass
class RawAttributeCollection:
	"""Per-read record of what the splice-junction correction step did to a multi-exon read.

	One instance is created for every multi-exon read and later written, field by
	field, to the read correction log.
	"""

	# --- identity and genomic extent of the raw read ---
	name: str  # read identifier
	start: int  # first coordinate of the read's full block
	end: int  # last coordinate of the read's full block

	# --- classification flags ---
	fsm: bool = False  # True once the (corrected) splicing chain is found among the reference multi-exon transcripts
	multi_exon: bool = True  # set to False when the read's only intron is merged away

	# --- per-junction bookkeeping ---
	correct_site: list = field(default_factory=list)  # raw [left, right] junctions that were replaced by reference sites
	merge_gap: list = field(default_factory=list)  # raw [left, right] junctions removed as unintended small introns
	polyaed: bool = False  # True when the poly(A) lookup for this read is truthy
	lowCredit_junction: dict = field(default_factory=dict)  # 1-based junction number -> motif, for unclassified junctions with coverage 1
	splicing_tag: list = field(default_factory=list)  # one tag per junction; NOTE(review): the FSM shortcut assigns the plain string 'FSM' instead of a list

	# --- filled in later ---
	rss_dis: int = None  # distance between read and reference start site, None unless the read is FSM
	res_dis: int = None  # distance between read and reference end site, None unless the read is FSM
	collapsed_ID: str = None  # identifier of the collapsed group the read was assigned to, if any

class ReadCorrectionColappse[source]

ReadCorrectionColappse(chrom, strand, chrand_processed_read, chrand_ref_exon, chrand_ref_junction, chrand_ref_single_exon_trans, chrand_ref_mutple_exon_trans, chrand_left_sj_set, chrand_right_sj_set, chrand_junction_dict, sj_correction_window, mis_intron_length, polya_dict, corExcept_dis=0)

class ReadCorrectionColappse:
	"""Correct splice junctions and collapse the reads of one (chromosome, strand) branch.

	Every ``chrand_*`` argument is the part of the corresponding genome-wide
	structure that belongs to this chromosome/strand pair.

	Args:
		chrom (str): chromosome name
		strand (str): '+' or '-'
		chrand_processed_read (dict): read id -> full block ``[start, sj, sj, ..., end]``; a two-element block is treated as a single-exon read
		chrand_ref_exon: reference exons, queried with ``.find((left, right))``; may be empty/None
		chrand_ref_junction: container of reference ``(left, right)`` junction tuples
		chrand_ref_single_exon_trans: reference single-exon transcripts, queried with ``.find``; may be empty/None
		chrand_ref_mutple_exon_trans (dict): reference splicing chain (tuple) -> indexable whose items 0 and 1 are compared with the read ends
		chrand_left_sj_set: reference left splice sites
		chrand_right_sj_set: reference right splice sites
		chrand_junction_dict (dict): ``(chrom, strand, left, right)`` -> record; item 0 is read as the junction coverage and item 2 as the junction motif
		sj_correction_window (int): maximum distance over which a splice site is moved to a reference site; also the tolerance used when discarding single-exon reads
		mis_intron_length (int): introns up to this length are candidates for being merged as unintended gaps
		polya_dict (dict): read id -> truthy value when the read carries a poly(A) event; failed lookups are ignored
		corExcept_dis (int, optional): both sites within this distance (plus coverage > 1 and a 'canonical' motif) exempt a junction from correction; 0 disables the exemption
	"""

	def __init__(self, chrom, strand, chrand_processed_read, chrand_ref_exon, chrand_ref_junction, chrand_ref_single_exon_trans, chrand_ref_mutple_exon_trans, chrand_left_sj_set, chrand_right_sj_set, chrand_junction_dict, sj_correction_window, mis_intron_length, polya_dict, corExcept_dis=0):
		self.chrom = chrom
		self.strand = strand
		self.chrand_processed_read = chrand_processed_read
		self.chrand_ref_exon = chrand_ref_exon
		self.chrand_ref_junction = chrand_ref_junction
		self.chrand_ref_single_exon_trans = chrand_ref_single_exon_trans
		self.chrand_ref_mutple_exon_trans = chrand_ref_mutple_exon_trans
		self.chrand_left_sj_set = chrand_left_sj_set
		self.chrand_right_sj_set = chrand_right_sj_set
		self.chrand_junction_dict = chrand_junction_dict
		self.sj_correction_window = sj_correction_window
		self.mis_intron_length = mis_intron_length
		self.polya_dict = polya_dict
		self.corExcept_dis = corExcept_dis

	def single_exon_read_collapse(self, read, single_exon_read):
		"""Add a single-exon read to the per-base coverage, unless a reference exon explains it.

		Args:
			read (sequence): (start, end) of the single-exon read, both ends inclusive
			single_exon_read (dict): position -> number of kept single-exon reads covering it; updated in place

		Returns:
			dict: the same ``single_exon_read`` mapping that was passed in
		"""
		start, end = read
		if self.chrand_ref_exon:
			window = self.sj_correction_window
			for exon in self.chrand_ref_exon.find(read):
				# the read lies inside this reference exon (within the tolerance): discard it
				if exon[0] <= start+window and exon[1] >= end-window:
					return single_exon_read

		# every covered base gains exactly one count
		for pos in range(start, end+1):
			if pos in single_exon_read:
				single_exon_read[pos] += 1
			else:
				single_exon_read[pos] = 1

		return single_exon_read

	def RTS_refrence_distance(self, start, end, read_splicing):
		"""Calculate the distance between the read ends and the ends of the matching reference transcript.

		Args:
			start (int): genomic start position regardless of strand
			end (int): genomic end position regardless of strand
			read_splicing (tuple): splicing chain; must be a key of ``self.chrand_ref_mutple_exon_trans``

		Returns:
			tuple: (rss_dis, res_dis). On '+' item 0 of the reference record is compared with
			``start`` and item 1 with ``end``; on any other strand the two read ends are swapped.

		Raises:
			KeyError: if ``read_splicing`` is not a reference splicing chain (for a plain dict)
		"""
		ref_record = self.chrand_ref_mutple_exon_trans[read_splicing]
		if self.strand == '+':
			rss_dis = abs(ref_record[0]-start)
			res_dis = abs(ref_record[1]-end)
		else:
			rss_dis = abs(ref_record[0]-end)
			res_dis = abs(ref_record[1]-start)

		return rss_dis, res_dis

	def multi_exon_read_correction(self, read_id, full_block):
		"""Correct the splice junctions of one multi-exon read against the reference.

		Each junction receives one tag in ``splicing_tag``:
			'M'   both sites are reference sites and the pair is a reference junction
			'KC'  both sites are reference sites but the pair is not a reference junction
			'EXC' junction close to reference sites but exempted from correction
			'CM'  like 'M', for a junction in the correction branch
			'CKC' like 'KC', for a junction in the correction branch
			'UI'  unintended small intron that was merged (removed from the chain)
			'NC'  none of the above
		A read whose raw chain already equals a reference chain is tagged with the string 'FSM'.

		Args:
			read_id (str): read identifier
			full_block (list): [start, left_sj, right_sj, ..., end] of the read

		Returns:
			tuple: (corrected splicing chain as a tuple, RawAttributeCollection describing the read)
		"""
		raw_splicing = tuple(full_block[1:-1])
		raw_read_attribute = RawAttributeCollection(read_id, full_block[0], full_block[-1])
		corrected_read_splicing = []

		# poly(A) event checking: best effort, a read without an entry (or a missing
		# dictionary) simply keeps polyaed == False
		try:
			if self.polya_dict[read_id]:
				raw_read_attribute.polyaed = True
		except (LookupError, TypeError):
			pass

		if raw_splicing in self.chrand_ref_mutple_exon_trans: # raw read_splicing matching with the reference
			raw_read_attribute.fsm = True
			# NOTE(review): a plain string is stored here while every other path keeps a list;
			# left untouched because the value is written verbatim to the correction log
			raw_read_attribute.splicing_tag = 'FSM'
			corrected_read_splicing = raw_splicing
			raw_read_attribute.rss_dis, raw_read_attribute.res_dis = self.RTS_refrence_distance(raw_read_attribute.start, raw_read_attribute.end, corrected_read_splicing)

		else: # splicing site correction
			last_idx = int(len(raw_splicing)/2 - 1)
			itered_raw_splicing = iter(raw_splicing)
			for idx, (left_sj, right_sj) in enumerate(zip(itered_raw_splicing, itered_raw_splicing)):

				# position of the raw left site inside full_block: junction idx occupies
				# slots 2*idx+1 and 2*idx+2 (slot 0 is the read start)
				sj_pos = 2*idx + 1

				# check splicing coverage and motif in raw data
				tmp_sj = (self.chrom, self.strand, left_sj, right_sj)
				junction_coverage = self.chrand_junction_dict[tmp_sj][0]
				junction_motif = self.chrand_junction_dict[tmp_sj][2]

				if left_sj in self.chrand_left_sj_set and right_sj in self.chrand_right_sj_set:
					if (left_sj, right_sj) in self.chrand_ref_junction:
						raw_read_attribute.splicing_tag.append('M')
					else:
						raw_read_attribute.splicing_tag.append('KC')

				elif self.chrand_left_sj_set:
					left_dis, left_ref_sj = loc_distance(self.chrand_left_sj_set, left_sj)
					right_dis, right_ref_sj = loc_distance(self.chrand_right_sj_set, right_sj)

					# do not correct the splicing site once the edit distance > sj_correction_window
					if left_dis > self.sj_correction_window: left_ref_sj = left_sj
					if right_dis > self.sj_correction_window: right_ref_sj = right_sj

					# correction exception, splicing junction with edit distance less than the given value for both sides
					if self.corExcept_dis and left_dis <= self.corExcept_dis and right_dis <= self.corExcept_dis and junction_coverage > 1 and junction_motif == 'canonical':
						raw_read_attribute.splicing_tag.append('EXC')

					# splicing site correction
					elif left_dis <= self.sj_correction_window or right_dis <= self.sj_correction_window:
						if [left_sj, right_sj] == [left_ref_sj, right_ref_sj]:
							pass
						# only accept the move when the corrected junction stays strictly
						# between the neighbouring coordinates of the read
						elif full_block[sj_pos-1] < left_ref_sj < right_ref_sj < full_block[sj_pos+2]:
							raw_read_attribute.correct_site.append([left_sj, right_sj])
							left_sj, right_sj = left_ref_sj, right_ref_sj

						if left_sj in self.chrand_left_sj_set and right_sj in self.chrand_right_sj_set:
							if (left_sj, right_sj) in self.chrand_ref_junction:
								raw_read_attribute.splicing_tag.append('CM')
							else:
								raw_read_attribute.splicing_tag.append('CKC')

				# checking unintended small intron overlap with reference exons
				# (len(splicing_tag) == idx means this junction has not been tagged yet)
				if len(raw_read_attribute.splicing_tag) == idx and right_sj - left_sj <= self.mis_intron_length:
					if self.chrand_ref_exon:
						overlapped_ref_exon = tuple(self.chrand_ref_exon.find((left_sj, right_sj)))
					else:
						overlapped_ref_exon = ()
					# compare the unintended intron with the ref_exon from multi-exon transcripts
					if idx == 0 and overlapped_ref_exon:
						for exon in overlapped_ref_exon:
							# the read has a single intron and the reference exon spans it
							if (len(full_block) == 4 and exon[1] >= right_sj and exon[0] <= left_sj):
								raw_read_attribute.merge_gap.append([full_block[sj_pos], full_block[sj_pos+1]])
								raw_read_attribute.multi_exon = False
								break
							elif (exon[1] == full_block[sj_pos+2] and exon[0] <= left_sj):
								raw_read_attribute.merge_gap.append([full_block[sj_pos], full_block[sj_pos+1]])
								break

					elif idx == last_idx and overlapped_ref_exon:
						for exon in overlapped_ref_exon:
							if exon[1] >= right_sj and exon[0] == full_block[sj_pos-1]:
								raw_read_attribute.merge_gap.append([full_block[sj_pos], full_block[sj_pos+1]])
								break

					elif 0 < idx < last_idx and overlapped_ref_exon:
						for exon in overlapped_ref_exon:
							if exon[1] == full_block[sj_pos+2] and exon[0] == full_block[sj_pos-1]:
								raw_read_attribute.merge_gap.append([full_block[sj_pos], full_block[sj_pos+1]])
								break

					# compare the unintended intron with the exon from single-exon transcripts
					elif not overlapped_ref_exon and len(full_block) == 4:
						if self.chrand_ref_single_exon_trans:
							overlapped_ref_exon = tuple(self.chrand_ref_single_exon_trans.find((left_sj, right_sj)))

						if overlapped_ref_exon:
							for exon in overlapped_ref_exon:
								if exon[1] >= right_sj and exon[0] <= left_sj:
									raw_read_attribute.merge_gap.append([full_block[sj_pos], full_block[sj_pos+1]])
									raw_read_attribute.multi_exon = False
									break

				corrected_read_splicing.extend([left_sj, right_sj])

				if [full_block[sj_pos], full_block[sj_pos+1]] in raw_read_attribute.merge_gap:
					# merged intron: tag it and take the junction back out of the chain
					raw_read_attribute.splicing_tag.append('UI')
					del corrected_read_splicing[-2:]

				elif len(raw_read_attribute.splicing_tag) == idx:
					raw_read_attribute.splicing_tag.append('NC')
					if junction_coverage == 1:
						raw_read_attribute.lowCredit_junction[idx+1] = junction_motif

			corrected_read_splicing = tuple(corrected_read_splicing)
			if corrected_read_splicing in self.chrand_ref_mutple_exon_trans:
				raw_read_attribute.fsm = True
				raw_read_attribute.rss_dis, raw_read_attribute.res_dis = self.RTS_refrence_distance(raw_read_attribute.start, raw_read_attribute.end, corrected_read_splicing)

		return corrected_read_splicing, raw_read_attribute

	def multi_exon_read_collapse(self, corrected_read_splicing, raw_read_attribute, rss_dis_lst, res_dis_lst, multi_exon_read, collapsed_idx):
		"""Collapse one corrected multi-exon read into the group sharing its splicing chain.

		Reads with a low-credit junction or with an empty chain are not collapsed.
		Each group is stored as
		``[starts, ends, polyA flags, read count, fsm flag of the first read, collapsed id]``.

		Args:
			corrected_read_splicing (tuple): corrected splicing chain of the read
			raw_read_attribute: attribute record of the read; ``collapsed_ID`` is set when the read is collapsed
			rss_dis_lst (list): collected start-site distances, extended in place
			res_dis_lst (list): collected end-site distances, extended in place
			multi_exon_read (dict): splicing chain -> group record, updated in place
			collapsed_idx (int): number of groups created so far

		Returns:
			tuple: (multi_exon_read, rss_dis_lst, res_dis_lst, collapsed_idx, raw_read_attribute)
		"""
		prefix = 'POS' if self.strand == '+' else 'NEG'
		if raw_read_attribute.lowCredit_junction:
			pass
		elif corrected_read_splicing:
			if corrected_read_splicing not in multi_exon_read:
				collapsed_idx += 1
				raw_read_attribute.collapsed_ID = f'{self.chrom}_{prefix}.{collapsed_idx}'
				multi_exon_read[corrected_read_splicing] = [[raw_read_attribute.start], [raw_read_attribute.end], [raw_read_attribute.polyaed], 1, raw_read_attribute.fsm, raw_read_attribute.collapsed_ID]
			else:
				group = multi_exon_read[corrected_read_splicing]
				# the three per-read lists are all prepended so their entries stay aligned
				group[0].insert(0, raw_read_attribute.start)
				group[1].insert(0, raw_read_attribute.end)
				group[2].insert(0, raw_read_attribute.polyaed)
				group[3] += 1
				raw_read_attribute.collapsed_ID = group[5]

		# a distance of 0 is a valid measurement (read end exactly on the reference end),
		# so test against None instead of relying on truthiness
		if raw_read_attribute.rss_dis is not None:
			rss_dis_lst.append(raw_read_attribute.rss_dis)
			res_dis_lst.append(raw_read_attribute.res_dis)

		return multi_exon_read, rss_dis_lst, res_dis_lst, collapsed_idx, raw_read_attribute

	def coco_operation (self, collapsed_idx = 0):
		"""Main entry point: correct splice junctions and collapse every read of this branch.

		Args:
			collapsed_idx (int, optional): starting value of the collapsed-group counter

		Returns:
			tuple: (chrom, strand, single_exon_read, multi_exon_read, corrected_read,
			correction_log, rss_dis_lst, res_dis_lst)
		"""
		corrected_read = defaultdict(dict)
		correction_log = defaultdict(dict)
		single_exon_read = defaultdict(dict)
		multi_exon_read = defaultdict(dict)
		rss_dis_lst = []
		res_dis_lst = []

		for read_id, full_block in tqdm(self.chrand_processed_read.items(), desc = f'{strftime("%Y-%m-%d %H:%M:%S")}: Collapsing raw reads from {self.chrom} {self.strand}'):

			# single-exon read collapsing
			if len(full_block) == 2:
				single_exon_read = self.single_exon_read_collapse(full_block, single_exon_read)
				corrected_read[read_id] = full_block
			# multi-exon read correction and collapsing
			else:
				corrected_read_splicing, raw_read_attribute = self.multi_exon_read_correction(read_id, full_block)
				multi_exon_read, rss_dis_lst, res_dis_lst, collapsed_idx, raw_read_attribute = self.multi_exon_read_collapse(corrected_read_splicing, raw_read_attribute, rss_dis_lst, res_dis_lst, multi_exon_read, collapsed_idx)

				if corrected_read_splicing:
					corrected_read[read_id] = [raw_read_attribute.start, *list(corrected_read_splicing), raw_read_attribute.end]
				else:
					# every junction was merged away: the read is reported as a single block
					corrected_read[read_id] = [raw_read_attribute.start, raw_read_attribute.end]
				correction_log[read_id] = raw_read_attribute

		return self.chrom, self.strand, single_exon_read, multi_exon_read, corrected_read, correction_log, rss_dis_lst, res_dis_lst
# --- usage example: correct and collapse the reads of a single branch (chr19, minus strand) ---

# input files
gtf = '/expt/zjzace/Nanopore_subcellular/Reference/gencode.v38.primary_assembly.annotation.sorted.gtf'
bed = '/expt/zjzace/Nanopore_subcellular/Analysis/Assembly/LAFITE/GLRA_tmp/A549_Cyto_LAFITE_tmp/bam.bed'
fa = '/NFS/mnemosyne2/expt/zjzace/GenomeRef/GRCh38.primary_assembly.genome.fa'

# genome-wide reference structures, read groups and poly(A) calls
ref_exon, ref_junction, ref_single_exon_trans, ref_mutple_exon_trans, left_sj_set, right_sj_set, tss_dict = RefProcessWrapper(gtf, 16).result_collection()
junction_dict, processed_read = read_grouping(bed, fa)
polya_dict = polya_signal_import('/expt/zjzace/Nanopore_subcellular/Analysis/Nanopolish/A549_Cyto_PolyA.res')

# slice every structure down to the branch of interest
chrom, strand = 'chr19', '-'
branch = (chrom, strand)
chrand_processed_read = processed_read[branch]
chrand_ref_exon = ref_exon[branch]
chrand_ref_junction = ref_junction[branch]
chrand_ref_single_exon_trans = ref_single_exon_trans[branch]
chrand_ref_mutple_exon_trans = ref_mutple_exon_trans[branch]
chrand_left_sj_set = left_sj_set[branch]
chrand_right_sj_set = right_sj_set[branch]
chrand_junction_dict = junction_dict[branch]

# correction parameters
sj_correction_window = 40
mis_intron_length = 150
corExcept_dis = 4

corrector = ReadCorrectionColappse(
	chrom,
	strand,
	chrand_processed_read,
	chrand_ref_exon,
	chrand_ref_junction,
	chrand_ref_single_exon_trans,
	chrand_ref_mutple_exon_trans,
	chrand_left_sj_set,
	chrand_right_sj_set,
	chrand_junction_dict,
	sj_correction_window,
	mis_intron_length,
	polya_dict,
	corExcept_dis,
)
chrom, strand, single_exon_read, multi_exon_read, corrected_read, correction_log, rss_dis_lst, res_dis_lst = corrector.coco_operation()
2022-04-13 12:04:51: Collapsing raw reads from chr19 -: 100%|██████████| 48521/48521 [00:05<00:00, 9070.96it/s] 

class CoCoWrapper[source]

CoCoWrapper(thread, processed_read, ref_exon, ref_junction, ref_single_exon_trans, ref_mutple_exon_trans, left_sj_set, right_sj_set, junction_dict, sj_correction_window, polya_dict, mis_intron_length, tmp_dir, corExcept_dis=0)

class CoCoWrapper:
	"""Run read correction and collapsing for every (chromosome, strand) branch in parallel.

	Args:
		thread (int): number of worker processes
		processed_read (dict): (chrom, strand) -> reads of that branch
		ref_exon, ref_junction, ref_single_exon_trans, ref_mutple_exon_trans,
		left_sj_set, right_sj_set, junction_dict: genome-wide structures, each indexed by (chrom, strand)
		sj_correction_window (int): forwarded to ReadCorrectionColappse
		polya_dict (dict): forwarded to ReadCorrectionColappse
		mis_intron_length (int): forwarded to ReadCorrectionColappse
		tmp_dir (str): directory that receives ``read_correction.log`` and ``Corrected_reads.bed``
		corExcept_dis (int, optional): forwarded to ReadCorrectionColappse
	"""

	def __init__(self, thread, processed_read, ref_exon, ref_junction, ref_single_exon_trans, ref_mutple_exon_trans, left_sj_set, right_sj_set, junction_dict, sj_correction_window, polya_dict, mis_intron_length, tmp_dir,corExcept_dis=0):
		self.thread = thread
		self.processed_read = processed_read
		self.ref_exon = ref_exon
		self.ref_junction = ref_junction
		self.ref_single_exon_trans = ref_single_exon_trans
		self.ref_mutple_exon_trans = ref_mutple_exon_trans
		self.left_sj_set = left_sj_set
		self.right_sj_set = right_sj_set
		self.junction_dict = junction_dict
		self.sj_correction_window = sj_correction_window
		self.polya_dict = polya_dict
		self.mis_intron_length = mis_intron_length
		self.tmp_dir = tmp_dir
		self.corExcept_dis = corExcept_dis

	def job_compute(self):
		"""Build one ReadCorrectionColappse per branch and run them all in a process pool.

		Returns:
			list: one ``AsyncResult`` per branch, in the iteration order of
			``self.processed_read``; all of them are complete when this returns
		"""
		job = []
		for branch in self.processed_read:
			chrom, strand = branch
			chrand_processed_read = self.processed_read[branch]
			chrand_ref_exon = self.ref_exon[branch]
			chrand_ref_junction = self.ref_junction[branch]
			chrand_ref_single_exon_trans = self.ref_single_exon_trans[branch]
			chrand_ref_mutple_exon_trans = self.ref_mutple_exon_trans[branch]
			chrand_left_sj_set = self.left_sj_set[branch]
			chrand_right_sj_set = self.right_sj_set[branch]
			chrand_junction_dict = self.junction_dict[branch]
			job.append(ReadCorrectionColappse(chrom, strand, chrand_processed_read, chrand_ref_exon, chrand_ref_junction, chrand_ref_single_exon_trans, chrand_ref_mutple_exon_trans, chrand_left_sj_set, chrand_right_sj_set, chrand_junction_dict, self.sj_correction_window, self.mis_intron_length, self.polya_dict, self.corExcept_dis))

		# the context manager guarantees the workers are torn down even when
		# submitting a job raises; close() + join() still wait for every job first
		with Pool(processes = self.thread) as pool:
			result = [pool.apply_async(worker.coco_operation, args=()) for worker in job]
			pool.close()
			pool.join()

		return result

	def result_collection(self):
		"""Gather the per-branch results and write the correction log and the corrected BED file.

		Returns:
			tuple: (collected_single_exon_read, collected_multi_exon_read, collected_rss, collected_res);
			the first two are keyed by (chrom, strand), the last two are flat lists
		"""
		collected_single_exon_read = Vividict()
		collected_multi_exon_read = Vividict()
		collected_rss = []
		collected_res = []
		path_to_log = f'{self.tmp_dir}/read_correction.log'
		path_to_corrected_bed = f'{self.tmp_dir}/Corrected_reads.bed'

		result = self.job_compute()
		with open(path_to_log, 'w') as flog, open(path_to_corrected_bed, 'w') as fbed:
			for res in result:
				# .get() re-raises here any exception that occurred inside the worker
				chrom, strand, single_exon_read, multi_exon_read, corrected_read, correction_log, rss_dis_lst, res_dis_lst = res.get()

				collected_single_exon_read[(chrom, strand)] = single_exon_read
				collected_multi_exon_read[(chrom, strand)] = multi_exon_read
				collected_rss.extend(rss_dis_lst)
				collected_res.extend(res_dis_lst)
				for raw_read_attribute in correction_log.values():
					# keep only what follows the first '_' of the stored name
					# NOTE(review): assumes every read id contains a '_' -- IndexError otherwise
					raw_read_attribute.name = raw_read_attribute.name.split('_', 1)[1]
					attributes = '\t'.join('{}: {}'.format(key, value) for key, value in vars(raw_read_attribute).items())
					flog.write(f'{attributes}\n')

				for read_id, full_block in corrected_read.items():
					read_name = read_id.split('_', 1)[1]
					bed_block = splicing_to_bed_block(chrom, strand, read_name, full_block)
					fbed.write(f'{bed_block}\n')
		return collected_single_exon_read, collected_multi_exon_read, collected_rss, collected_res
from LAFITE.reference_processing import RefProcessWrapper, short_reads_sj_import
from LAFITE.preprocessing import read_grouping, polya_signal_import, PolyAFinder
from LAFITE.utils import temp_dir_creation, bam2bed, keep_tmp_file
# --- usage example: run the whole correction/collapse step on the SIRV spike-in data ---

# input files
gtf = '/expt/zjzace/Nanopore_subcellular/SIRV/SIRV_Set1/Raw_data/SIRV_isoforms_multi-fasta-annotation_C_170612a.gtf'
bed = '/expt/zjzace/Nanopore_subcellular/SIRV/SIRV_Set1/bam/SRR6058584.sorted.bed'
fa = '/expt/zjzace/Nanopore_subcellular/SIRV/SIRV_Set1/Raw_data/SIRV_isoforms_multi-fasta_170612a.fasta'

# read groups first, then the reference structures
junction_dict, processed_read = read_grouping(bed, fa)
ref_exon, ref_junction, ref_single_exon_trans, ref_mutple_exon_trans, left_sj_set, right_sj_set, tss_dict = RefProcessWrapper(gtf, 16).result_collection()

# poly(A) calls can alternatively be loaded from a precomputed file:
# polya_dict = polya_signal_import('/expt/zjzace/Nanopore_subcellular/SIRV/SIRV_Set1/bam/SRR6058584.polya.res')
polya_dict = PolyAFinder(processed_read, fa, '/home/zjzace/software/SQANTI3-4.1/data/polyA_motifs/mouse_and_human.polyA_motif.txt').polya_estimation()

coco = CoCoWrapper(
	16,
	processed_read,
	ref_exon,
	ref_junction,
	ref_single_exon_trans,
	ref_mutple_exon_trans,
	left_sj_set,
	right_sj_set,
	junction_dict,
	sj_correction_window=40,
	polya_dict=polya_dict,
	mis_intron_length=150,
	tmp_dir='.',
	corExcept_dis=4,
)
collected_single_exon_read, collected_multi_exon_read, tss, tes = coco.result_collection()
2022-04-12 19:18:26: Collapsing raw reads from SIRV4 +: 100%|██████████| 2661/2661 [00:00<00:00, 91686.87it/s]
2022-04-12 19:18:26: Collapsing raw reads from SIRV5 -: 100%|██████████| 590/590 [00:00<00:00, 12759.49it/s]
2022-04-12 19:18:26: Collapsing raw reads from SIRV6 -: 100%|██████████| 1364/1364 [00:00<00:00, 10556.48it/s]
2022-04-12 19:18:26: Collapsing raw reads from SIRV2 +: 100%|██████████| 4088/4088 [00:00<00:00, 6357.57it/s]
2022-04-12 19:18:26: Collapsing raw reads from SIRV1 +: 100%|██████████| 3615/3615 [00:01<00:00, 2790.63it/s]
2022-04-12 19:18:26: Collapsing raw reads from SIRV2 -: 100%|██████████| 6835/6835 [00:01<00:00, 5238.81it/s]
2022-04-12 19:18:26: Collapsing raw reads from SIRV4 -: 100%|██████████| 7002/7002 [00:01<00:00, 5559.16it/s]
2022-04-12 19:18:26: Collapsing raw reads from SIRV1 -: 100%|██████████| 7764/7764 [00:01<00:00, 5626.42it/s]
2022-04-12 19:18:26: Collapsing raw reads from SIRV7 -: 100%|██████████| 7413/7413 [00:01<00:00, 6001.32it/s]
2022-04-12 19:18:26: Collapsing raw reads from SIRV3 -: 100%|██████████| 2899/2899 [00:01<00:00, 1956.22it/s]
2022-04-12 19:18:28: Collapsing raw reads from SIRV3 +: 100%|██████████| 10834/10834 [00:00<00:00, 79545.93it/s]
2022-04-12 19:18:27: Collapsing raw reads from SIRV5 +: 100%|██████████| 20402/20402 [00:00<00:00, 29591.10it/s]
2022-04-12 19:18:28: Collapsing raw reads from SIRV6 +: 100%|██████████| 32687/32687 [00:01<00:00, 30373.87it/s]
def single_exon_read_collapse(read, chrand_ref_exon, overhang, single_exon_read):
	"""Add a single-exon read to the per-base coverage unless a reference exon contains it.

	A read is dropped when some overlapping reference exon starts no later than
	``start + overhang`` and ends no earlier than ``end - overhang``.

	Args:
		read (list): start and end position of the single-exon read (1 base, [1,100])
		chrand_ref_exon (interlap data): exons from reference multi-exon transcripts; may be empty/None
		overhang (int): tolerance distance
		single_exon_read (dict): position -> number of kept reads covering it, updated in place

	Returns:
		dict: the same ``single_exon_read`` mapping that was passed in
	"""
	start, end = read

	contained = False
	if chrand_ref_exon:
		contained = any(
			ref[0] <= start+overhang and ref[1] >= end-overhang
			for ref in tuple(chrand_ref_exon.find(read))
		)

	if not contained:
		# each covered base gains exactly one count
		for pos in range(start, end+1):
			single_exon_read[pos] = single_exon_read[pos] + 1 if pos in single_exon_read else 1

	return single_exon_read
def RTS_refrence_distance(strand, start, end, read_splicing, chrand_ref_mutple_exon_trans):
	"""Distance between the read ends and the ends of the reference transcript sharing its splicing chain.

	Args:
		strand (str): strand information
		start (int): genomic start position regardless of strand
		end (int): genomic end position regardless of strand
		read_splicing (tuple): corrected read splicing; must be a key of the reference mapping
		chrand_ref_mutple_exon_trans (dict): reference multi-exon transcripts, splicing chain -> indexable record

	Returns:
		tuple: (rss_dis, res_dis). On '+' item 0 of the reference record is compared with
		``start`` and item 1 with ``end``; on any other strand the two read ends are swapped.
	"""
	ref_record = chrand_ref_mutple_exon_trans[read_splicing]
	# pick which read end faces record item 0 / item 1
	toward_first, toward_second = (start, end) if strand == '+' else (end, start)

	return abs(ref_record[0]-toward_first), abs(ref_record[1]-toward_second)

	
def multi_exon_read_correction(chrom, strand, name, full_block, chrand_ref_exon, chrand_ref_junction, chrand_ref_single_exon_trans, chrand_ref_mutple_exon_trans, chrand_left_sj_set, chrand_right_sj_set, chrand_junction_dict, sj_correction_window, mis_intron_length, corExcept_dis, polya_dict):
	"""Splice-junction correction for one multi-exon read.

	Each junction receives one tag in ``tag_dict['splicing_tag']``:
		'M'   both sites are reference sites and the pair is a reference junction
		'KC'  both sites are reference sites but the pair is not a reference junction
		'EXC' junction close to reference sites but exempted from correction
		'CM'  like 'M', for a junction in the correction branch
		'CKC' like 'KC', for a junction in the correction branch
		'UI'  unintended small intron that was merged (removed from the chain)
		'NC'  none of the above
	A read whose raw chain already equals a reference chain is tagged with the string 'FSM'.

	Args:
		chrom (str): chromosome
		strand (str): strand information
		name (str): read name
		full_block (list): start, end position and all splicing site of multi-exon read (1 base, [1,20,40, 100])
		chrand_ref_exon (interlap data): exon from reference multi-exon transcript
		chrand_ref_junction (list): reference splicing junction
		chrand_ref_single_exon_trans (interlap data): reference single-exon transcript
		chrand_ref_mutple_exon_trans (dict): reference multi-exon transcript
		chrand_left_sj_set (list): reference left splicing site
		chrand_right_sj_set (list): reference right splicing site
		chrand_junction_dict (dict): splicing junction detected from long read; item 0 of a record is read as coverage, item 2 as motif
		sj_correction_window (int): tolerance distance for splicing site correction
		mis_intron_length (int): unintended small intron gap that should be filled
		corExcept_dis (int, optional): edit distance to the reference splicing site
		polya_dict (dict, optional): raw long reads Polyadenylation event; failed lookups are ignored

	Returns:
		tuple: (start, end, corrected splicing chain as a tuple, tag_dict)
	"""

	raw_splicing = tuple(full_block[1:-1])
	start, end = full_block[0], full_block[-1]
	corrected_read_splicing = []
	tag_dict={'reference_match':False, 'multi-exon':True, 'correct_site':[], 'merge_site':[], 'polya_signal':False , 'lowCredit_junction':{}, 'splicing_tag':[], 'rss_dis':None, 'res_dis':None}

	# polya event checking: best effort, a read without an entry (or polya_dict=None)
	# simply keeps polya_signal == False
	try:
		if polya_dict[name]:
			tag_dict['polya_signal'] = True
	except (LookupError, TypeError):
		pass

	if raw_splicing in chrand_ref_mutple_exon_trans: # uncorrected read_splicing matching with the reference
		tag_dict['reference_match'] = True
		# NOTE(review): a plain string replaces the list here; kept because the value
		# is handed to callers verbatim
		tag_dict['splicing_tag'] = 'FSM'
		corrected_read_splicing = raw_splicing
		tag_dict['rss_dis'], tag_dict['res_dis'] = RTS_refrence_distance(strand, start, end, corrected_read_splicing, chrand_ref_mutple_exon_trans)

	else: # splicing site correction
		last_idx = int(len(raw_splicing)/2 - 1)
		itered_raw_splicing = iter(raw_splicing)
		for idx, (left_sj, right_sj) in enumerate(zip(itered_raw_splicing, itered_raw_splicing)):

			# position of the raw left site inside full_block: junction idx occupies
			# slots 2*idx+1 and 2*idx+2 (slot 0 is the read start)
			sj_pos = 2*idx + 1

			# check splicing coverage and motif in raw data
			tmp_sj = (chrom,strand,left_sj,right_sj)
			junction_coverage = chrand_junction_dict[tmp_sj][0]
			junction_motif = chrand_junction_dict[tmp_sj][2]

			if left_sj in chrand_left_sj_set and right_sj in chrand_right_sj_set:
				if (left_sj, right_sj) in chrand_ref_junction:
					tag_dict['splicing_tag'].append('M')
				else:
					tag_dict['splicing_tag'].append('KC')

			elif chrand_left_sj_set:
				left_dis, left_ref_sj = loc_distance(chrand_left_sj_set, left_sj)
				right_dis, right_ref_sj = loc_distance(chrand_right_sj_set, right_sj)

				# do not correct the splicing site once the edit distance > sj_correction_window
				if left_dis > sj_correction_window: left_ref_sj = left_sj
				if right_dis > sj_correction_window: right_ref_sj = right_sj

				# correction exception, splicing junction with edit distance within corExcept_dis for both sides
				if left_dis <= corExcept_dis and right_dis <= corExcept_dis and junction_coverage > 1 and junction_motif == 'canonical':
					tag_dict['splicing_tag'].append('EXC')

				# splicing site correction
				elif left_dis <= sj_correction_window or right_dis <= sj_correction_window:
					if [left_sj, right_sj] == [left_ref_sj, right_ref_sj]:
						pass
					# only accept the move when the corrected junction stays strictly
					# between the neighbouring coordinates of the read
					elif full_block[sj_pos-1] < left_ref_sj < right_ref_sj < full_block[sj_pos+2]:
						tag_dict['correct_site'].append([left_sj, right_sj])
						left_sj, right_sj = left_ref_sj, right_ref_sj

					if left_sj in chrand_left_sj_set and right_sj in chrand_right_sj_set:
						if (left_sj, right_sj) in chrand_ref_junction:
							tag_dict['splicing_tag'].append('CM')
						else:
							tag_dict['splicing_tag'].append('CKC')

			# checking unintended small intron overlap with reference exons
			# (len(splicing_tag) == idx means this junction has not been tagged yet)
			if len(tag_dict['splicing_tag']) == idx and right_sj - left_sj <= mis_intron_length:
				if chrand_ref_exon:
					overlapped_ref_exon = tuple(chrand_ref_exon.find((left_sj, right_sj)))
				else:
					overlapped_ref_exon = ()
				# compare the unintended intron with the ref_exon from multi-exon transcripts
				if idx == 0 and overlapped_ref_exon:
					for exon in overlapped_ref_exon:
						# the read has a single intron and the reference exon spans it
						if (len(full_block) == 4 and exon[1] >= right_sj and exon[0] <= left_sj):
							tag_dict['merge_site'].append([full_block[sj_pos], full_block[sj_pos+1]])
							tag_dict['multi-exon'] = False
							break
						elif (exon[1] == full_block[sj_pos+2] and exon[0] <= left_sj):
							tag_dict['merge_site'].append([full_block[sj_pos], full_block[sj_pos+1]])
							break

				elif idx == last_idx and overlapped_ref_exon:
					for exon in overlapped_ref_exon:
						if exon[1] >= right_sj and exon[0] == full_block[sj_pos-1]:
							tag_dict['merge_site'].append([full_block[sj_pos], full_block[sj_pos+1]])
							break

				elif 0 < idx < last_idx and overlapped_ref_exon:
					for exon in overlapped_ref_exon:
						if exon[1] == full_block[sj_pos+2]  and exon[0] == full_block[sj_pos-1]:
							tag_dict['merge_site'].append([full_block[sj_pos], full_block[sj_pos+1]])
							break

				# compare the unintended intron with the exon from single-exon transcripts
				elif not overlapped_ref_exon and len(full_block) == 4:
					if chrand_ref_single_exon_trans:
						overlapped_ref_exon = tuple(chrand_ref_single_exon_trans.find((left_sj, right_sj)))

					if overlapped_ref_exon:
						for exon in overlapped_ref_exon:
							if exon[1] >= right_sj and exon[0] <= left_sj:
								tag_dict['merge_site'].append([full_block[sj_pos], full_block[sj_pos+1]])
								tag_dict['multi-exon'] = False
								break

			corrected_read_splicing.extend([left_sj, right_sj])

			if [full_block[sj_pos], full_block[sj_pos+1]] in tag_dict['merge_site']:
				# merged intron: tag it and take the junction back out of the chain
				tag_dict['splicing_tag'].append('UI')
				del corrected_read_splicing[-2:]

			elif len(tag_dict['splicing_tag']) == idx:
				tag_dict['splicing_tag'].append('NC')
				if junction_coverage == 1:
					tag_dict['lowCredit_junction'][idx+1] = junction_motif

		corrected_read_splicing = tuple(corrected_read_splicing)
		if corrected_read_splicing in chrand_ref_mutple_exon_trans:
			tag_dict['reference_match'] = True
			tag_dict['rss_dis'], tag_dict['res_dis'] = RTS_refrence_distance(strand, start, end, corrected_read_splicing, chrand_ref_mutple_exon_trans)

	return start, end, corrected_read_splicing, tag_dict
def multi_exon_read_collapse(read_id, start, end, corrected_read_splicing, tag_dict, rss_dis_lst, res_dis_lst, multi_exon_read, collapsed_idx):
	"""Collapse one corrected multi-exon read into the group sharing its splicing chain.

	Reads with a low-credit junction or with an empty chain are not collapsed.
	Each group is stored as
	``[starts, ends, polyA flags, read count, read names, reference_match of the first read, 'collapsed.<n>']``.

	Args:
		read_id (str): read identifier; the part after the first '_' is used as the read name
		start (int): read start
		end (int): read end
		corrected_read_splicing (tuple): corrected splicing chain of the read
		tag_dict (dict): correction tags of the read (reads 'lowCredit_junction', 'polya_signal', 'reference_match', 'rss_dis', 'res_dis')
		rss_dis_lst (list): collected start-site distances, extended in place
		res_dis_lst (list): collected end-site distances, extended in place
		multi_exon_read (dict): splicing chain -> group record, updated in place
		collapsed_idx (int): number of groups created so far

	Returns:
		tuple: (multi_exon_read, rss_dis_lst, res_dis_lst, collapsed_idx)

	Raises:
		IndexError: if ``read_id`` contains no '_'
	"""
	read_name = read_id.split('_', 1)[1]
	if tag_dict['lowCredit_junction']:
		pass
	elif corrected_read_splicing:
		if corrected_read_splicing not in multi_exon_read:
			collapsed_idx += 1
			multi_exon_read[corrected_read_splicing] = [[start], [end], [tag_dict['polya_signal']], 1, [read_name], tag_dict['reference_match'], f'collapsed.{collapsed_idx}']
		else:
			group = multi_exon_read[corrected_read_splicing]
			# all per-read lists are prepended so their entries stay aligned
			group[0].insert(0, start)
			group[1].insert(0, end)
			group[2].insert(0,tag_dict['polya_signal'])
			group[3] += 1
			group[4].insert(0,read_name)

	# a distance of 0 is a valid measurement (read end exactly on the reference end),
	# so test against None instead of relying on truthiness
	if tag_dict['rss_dis'] is not None:
		rss_dis_lst.append(tag_dict['rss_dis'])
		res_dis_lst.append(tag_dict['res_dis'])

	return multi_exon_read, rss_dis_lst, res_dis_lst, collapsed_idx
def three_prime_exon_extraction(strand, multi_exon_read):
	"""Collect one terminal exon per collapsed multi-exon read group.

	On '+' the exon runs from the last splice site to the largest recorded read end;
	on any other strand it runs from the largest recorded read start to the first
	splice site.

	Args:
		strand (str): strand information
		multi_exon_read (dict): splicing chain (tuple) -> group record whose item 0 is the list of read starts and item 1 the list of read ends

	Returns:
		set: distinct (left, right) exon tuples
	"""
	# a plain set: the previous defaultdict(set) has no .add() and raised
	# AttributeError as soon as there was one group to process
	three_prime_exon = set()
	for corrected_read_splicing, read_info in multi_exon_read.items():
		if strand == '+':
			max_end = max(read_info[1])
			exon = (corrected_read_splicing[-1], max_end)
		else:
			# NOTE(review): this takes the largest start, i.e. the shortest exon, whereas
			# the '+' branch takes the longest one -- confirm min() is not intended here
			max_end = max(read_info[0])
			exon = (max_end, corrected_read_splicing[0])
		three_prime_exon.add(exon)

	return three_prime_exon
def coco_operation (chrom, strand, chrand_processed_read, chrand_ref_exon, chrand_ref_junction, chrand_ref_single_exon_trans, chrand_ref_mutple_exon_trans, chrand_left_sj_set, chrand_right_sj_set, chrand_junction_dict, sj_correction_window = 40, mis_intron_length = 150, corExcept_dis = 0, polya_dict=None, collapsed_idx = 0):
	"""Main function for correcting splicing junctions and collapsing the reads of one branch.

	Args:
		chrom (str): chromosome
		strand (str): strand information
		chrand_processed_read (dict): read id -> full block; a two-element block is treated as a single-exon read
		chrand_ref_exon, chrand_ref_junction, chrand_ref_single_exon_trans,
		chrand_ref_mutple_exon_trans, chrand_left_sj_set, chrand_right_sj_set,
		chrand_junction_dict: reference and junction structures of this branch, forwarded to the correction step
		sj_correction_window (int, optional): tolerance distance for splicing site correction; also the overhang used for single-exon reads
		mis_intron_length (int, optional): unintended small intron gap that should be filled
		corExcept_dis (int, optional): edit distance for the correction exception
		polya_dict (dict, optional): raw long reads Polyadenylation event
		collapsed_idx (int, optional): starting value of the collapsed-group counter

	Returns:
		tuple: (chrom, strand, single_exon_read, multi_exon_read, corrected_read,
		correction_log, rss_dis_lst, res_dis_lst)
	"""
	corrected_read = defaultdict(dict)
	correction_log = defaultdict(dict)
	single_exon_read = defaultdict(dict)
	multi_exon_read = defaultdict(dict)
	# two independent lists: a chained assignment would bind both names to the same
	# object and mix start-site and end-site distances together
	rss_dis_lst = []
	res_dis_lst = []

	for read_id, full_block in tqdm(chrand_processed_read.items(), desc = f'{strftime("%Y-%m-%d %H:%M:%S")}: Collapsing raw reads from {chrom} {strand}'):

		# single-exon read collapsing
		if len(full_block) == 2:
			single_exon_read = single_exon_read_collapse(full_block, chrand_ref_exon, sj_correction_window, single_exon_read)
			corrected_read[read_id] = full_block
		# multi-exon read correction and collapsing
		else:
			start, end, corrected_read_splicing, tag_dict = multi_exon_read_correction(chrom, strand, read_id, full_block, chrand_ref_exon, chrand_ref_junction, chrand_ref_single_exon_trans, chrand_ref_mutple_exon_trans, chrand_left_sj_set, chrand_right_sj_set, chrand_junction_dict, sj_correction_window, mis_intron_length, corExcept_dis, polya_dict)

			multi_exon_read, rss_dis_lst, res_dis_lst, collapsed_idx = multi_exon_read_collapse(read_id, start, end, corrected_read_splicing, tag_dict, rss_dis_lst, res_dis_lst, multi_exon_read, collapsed_idx)
			if corrected_read_splicing:
				corrected_read[read_id] = [start, *list(corrected_read_splicing), end]
			else:
				# every junction was merged away: the read is reported as a single block
				corrected_read[read_id] = [start, end]
			correction_log[read_id] = tag_dict

	return chrom, strand, single_exon_read, multi_exon_read, corrected_read, correction_log, rss_dis_lst, res_dis_lst

splicing_to_bed_block[source]

splicing_to_bed_block(chrom, strand, name, full_block)

def splicing_to_bed_block(chrom, strand, name, full_block):
	"""Render one read as a 12-field, tab-separated BED-style line.

	`full_block` is consumed as consecutive (left, right) coordinate pairs,
	one pair per block; a trailing unpaired coordinate is ignored when
	building the blocks. The first and last entries of `full_block` give the
	overall span, and the span start is written as ``start - 1``.

	Returns the line as a string without a trailing newline.
	"""
	span_start, span_end = full_block[0], full_block[-1]
	coords = list(full_block)
	# Even positions are block left ends, odd positions the matching right ends.
	pairs = list(zip(coords[0::2], coords[1::2]))

	# Block starts are offsets relative to the span start; sizes are inclusive.
	rel_starts = ','.join(str(left - span_start) for left, _ in pairs)
	sizes = ','.join(str(right - left + 1) for left, right in pairs)

	fields = (
		chrom, span_start - 1, span_end, name, '-', strand,
		span_start - 1, span_end, '255,0,0',
		len(pairs), sizes, rel_starts,
	)
	return '\t'.join(map(str, fields))
class CoCoWrapper:
	"""Run `coco_operation` on every branch of `processed_read` in a process pool.

	The keys of `processed_read` are unpacked as ``(chrom, strand)`` and the
	same key is used to index each of the reference containers.
	"""

	def __init__(self, thread, processed_read, ref_exon, ref_junction, ref_single_exon_trans, ref_mutple_exon_trans, left_sj_set, right_sj_set, junction_dict, tmp_dir, sj_correction_window, polya_dict, mis_intron_length, corExcept_dis=4):
		"""Store the inputs; no work is done until `result_collection` is called.

		Args:
			thread: number of worker processes for the pool.
			processed_read: mapping keyed by (chrom, strand).
			ref_exon, ref_junction, ref_single_exon_trans,
			ref_mutple_exon_trans, left_sj_set, right_sj_set, junction_dict:
				containers indexed by the same (chrom, strand) keys.
			tmp_dir: directory that receives the log and BED output files.
			sj_correction_window, polya_dict, mis_intron_length, corExcept_dis:
				forwarded unchanged to `coco_operation`.
		"""
		self.thread = thread
		self.processed_read = processed_read
		self.ref_exon = ref_exon
		self.ref_junction = ref_junction
		self.ref_single_exon_trans = ref_single_exon_trans
		self.ref_mutple_exon_trans = ref_mutple_exon_trans
		self.left_sj_set = left_sj_set
		self.right_sj_set = right_sj_set
		self.junction_dict = junction_dict
		self.sj_correction_window = sj_correction_window
		self.tmp_dir = tmp_dir
		self.polya_dict = polya_dict
		self.mis_intron_length = mis_intron_length
		self.corExcept_dis = corExcept_dis

	def job_compute(self):
		"""Submit one `coco_operation` job per branch and wait for all of them.

		Returns:
			List of `AsyncResult` objects, one per branch, in iteration order
			of `self.processed_read`. All jobs have finished when this returns.
		"""
		result = []
		# The context manager terminates the pool if scheduling raises (for
		# example a branch missing from one of the reference containers), so
		# worker processes are not left behind. On the normal path the pool
		# has already been closed and joined before the block exits.
		with Pool(processes = self.thread) as p:
			for branch in self.processed_read:
				chrom, strand = branch
				job_args = (
					chrom,
					strand,
					self.processed_read[branch],
					self.ref_exon[branch],
					self.ref_junction[branch],
					self.ref_single_exon_trans[branch],
					self.ref_mutple_exon_trans[branch],
					self.left_sj_set[branch],
					self.right_sj_set[branch],
					self.junction_dict[branch],
					self.sj_correction_window,
					self.mis_intron_length,
					self.corExcept_dis,
					self.polya_dict,
				)
				result.append(p.apply_async(coco_operation, job_args))

			p.close()
			p.join()

		return result

	def result_collection(self):
		"""Run all jobs, write the correction log and corrected-read BED, and merge results.

		Writes ``read_correction.log`` and ``Corrected_reads.bed`` into
		`self.tmp_dir` (both are overwritten).

		Returns:
			Tuple ``(collected_single_exon_read, collected_multi_exon_read,
			collected_rss, collected_res)``; the first two are keyed by
			(chrom, strand), the last two are flat lists concatenated over
			all branches.
		"""
		collected_single_exon_read = Vividict()
		collected_multi_exon_read = Vividict()
		collected_rss = []
		collected_res = []
		path_to_log = f'{self.tmp_dir}/read_correction.log'
		path_to_corrected_bed = f'{self.tmp_dir}/Corrected_reads.bed'

		result = self.job_compute()
		with open(path_to_log, 'w') as flog, open(path_to_corrected_bed, 'w') as fbed:
			for res in result:
				# get() re-raises any exception that occurred in the worker.
				chrom, strand, single_exon_read, multi_exon_read, corrected_read, correction_log, rss_dis_lst, res_dis_lst = res.get()
				collected_single_exon_read[(chrom, strand)] = single_exon_read
				collected_multi_exon_read[(chrom, strand)] = multi_exon_read
				collected_rss.extend(rss_dis_lst)
				collected_res.extend(res_dis_lst)
				for read_id, tag_dict in correction_log.items():
					# Keep only the text after the first underscore.
					# NOTE(review): presumably read ids carry a prefix added
					# upstream; an id without '_' raises IndexError here.
					read_name = read_id.split('_', 1)[1]
					attributes = '\t'.join('{}: {}'.format(key, value) for key, value in tag_dict.items())
					flog.write(f'{read_name}\t{attributes}\n')

				for read_id, full_block in corrected_read.items():
					read_name = read_id.split('_', 1)[1]
					bed_block = splicing_to_bed_block(chrom, strand, read_name, full_block)
					fbed.write(f'{bed_block}\n')
		return collected_single_exon_read, collected_multi_exon_read, collected_rss, collected_res
# chrand_read_list1['adcc854a-2c55-432c-9d65-11ca1d8c9eb4'] = chrand_read_list['adcc854a-2c55-432c-9d65-11ca1d8c9eb4']
# # chrand_read_list1['adcc854a-2c55-432c-9d65-11ca1d8c9eb4'] = [167500,167519, 167584, 167610]


# for (chrom, strand) in processed_read:
# 	chrand_read_list = processed_read[(chrom, strand)]
# 	chrand_ref_exon = ref_exon[(chrom, strand)]
# 	chrand_ref_junction = ref_junction[(chrom, strand)]
# 	chrand_ref_single_exon_trans = ref_single_exon_trans[(chrom, strand)]
# 	chrand_ref_mutple_exon_trans = ref_mutple_exon_trans[(chrom,strand)]
# 	chrand_left_sj_set = left_sj_set[(chrom,strand)]
# 	chrand_right_sj_set = right_sj_set[(chrom,strand)]
# 	chrand_junction_dict = junction_dict[(chrom,strand)]
# 	sj_correction_window = 40
# 	mis_intron_length = 150
# 	corExcept_dis=4
# 	polya_dict=None
# 	single_exon_read = {}
# 	tmp_dir='./'
# 	# with open ('test.out', 'w') as fw:
# 	# 	for read, splicing in tqdm(chrand_read_list.items(), desc = f'{strftime("%Y-%m-%d %H:%M:%S")}: Collapsing raw reads from {chrom} {strand}'):
# 	# 		if len(splicing) == 2:
# 	# 			single_exon_read_collapse(splicing, chrand_ref_exon, 40, single_exon_ead)
# 	# 		if len(splicing) > 2:

# 	single_exon_read, multi_exon_read, corrected_read, correction_log, rss_dis_lst, res_dis_lst = coco_operation(chrom, strand, chrand_read_list, chrand_ref_exon, chrand_ref_junction, chrand_ref_single_exon_trans, chrand_ref_mutple_exon_trans, chrand_left_sj_set, chrand_right_sj_set, chrand_junction_dict, sj_correction_window)
# result_collection() returns four values (single-exon reads, multi-exon
# reads, and the two distance lists); unpack all of them.
collected_single_exon_read, collected_multi_exon_read, collected_rss, collected_res = CoCoWrapper(32, processed_read, ref_exon, ref_junction, ref_single_exon_trans, ref_mutple_exon_trans, left_sj_set, right_sj_set, junction_dict, tmp_dir='.', sj_correction_window=40, mis_intron_length = 150, corExcept_dis=4, polya_dict=None).result_collection()