import math
from collections import defaultdict, OrderedDict
from dataclasses import dataclass
from joblib import Parallel, delayed
from operator import sub
from statistics import mean
from LAFITE.utils import loc_distance, Vividict

class AttributeCollection[source]

AttributeCollection(start: int, end: int, count: float, polya_count: int = None, fsm: bool = False, polyaed: bool = False, as_site: list = None, apa_site: list = None, name: str = None, rss_dis: int = inf, read_tag: str = None, processed: bool = False, chrand_ID: int = None, loci_ID: int = None)

@dataclass
class AttributeCollection:
	start: int
	end: int
	count: float
	polya_count: int = None
	fsm: bool = False
	polyaed: bool = False
	as_site: list = None
	apa_site: list = None
	name: str = None
	rss_dis: int = math.inf
	read_tag: str = None
	processed: bool = False
	chrand_ID: int = None
	loci_ID: int = None
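
A minimal usage sketch (hypothetical values): only start, end, and count are required, and the refiners below fill the remaining fields in place as each read is classified.

attr = AttributeCollection(start=1000, end=1450, count=12.0)
attr.read_tag = 'Keep_coverage'  # tags are assigned during refinement
attr.processed = True
assert attr.rss_dis == math.inf  # stays inf until a TSS distance is computed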

class SingleExonReadRefine[source]

SingleExonReadRefine(chrom, strand, chrand_collected_single_exon_read, chrand_ref_single_exon_trans, chrand_ref_mutple_exon_trans, chrand_three_prime_exon, min_single_exon_coverage, min_single_exon_len, rs_tolerance=24, coverage_cliff=0)

class SingleExonReadRefine:
	def __init__(self, chrom, strand, chrand_collected_single_exon_read, chrand_ref_single_exon_trans, chrand_ref_mutple_exon_trans, chrand_three_prime_exon, min_single_exon_coverage, min_single_exon_len, rs_tolerance = 24, coverage_cliff = 0):
		self.chrom = chrom
		self.strand = strand
		self.chrand_collected_single_exon_read = chrand_collected_single_exon_read
		self.chrand_ref_single_exon_trans = chrand_ref_single_exon_trans
		self.chrand_ref_mutple_exon_trans = chrand_ref_mutple_exon_trans
		self.chrand_three_prime_exon = chrand_three_prime_exon
		self.min_single_exon_coverage = min_single_exon_coverage
		self.min_single_exon_len = min_single_exon_len
		self.rs_tolerance = rs_tolerance
		self.coverage_cliff = coverage_cliff

	def single_exon_refine(self, subread, position, coverage, refined_single_exon_read, i):
		# keep only windows that meet the minimum single-exon length
		if subread[1]-subread[0] >= self.min_single_exon_len:
			subread_attribute = AttributeCollection(position[subread[0]], position[subread[1]], mean(coverage[subread[0]:subread[1]]))
			# keep subreads covering at least half of a reference single-exon transcript
			if self.chrand_ref_single_exon_trans:
				overlap_ref = tuple(self.chrand_ref_single_exon_trans.find((position[subread[0]],position[subread[1]])))
				if overlap_ref:
					for ref_exon in overlap_ref:
						overlap = list(set(range(position[subread[0]], position[subread[1]]+1)).intersection(range(ref_exon[0],ref_exon[1]+1)))
						overlap.sort()
						if (overlap[-1] - overlap[0] + 1)/(ref_exon[-1] - ref_exon[0] + 1) >= 0.5:
							subread_attribute.processed = True
							subread_attribute.read_tag = 'Keep_ref'
							break
			# drop subreads that look like 3' remnants of reference multi-exon transcripts
			if not subread_attribute.processed and self.chrand_three_prime_exon:
				overlap_ref = tuple(self.chrand_three_prime_exon.find((position[subread[0]],position[subread[1]])))
				if overlap_ref:
					for ref_exon in overlap_ref:
						if (self.strand == '+' and position[subread[0]] > ref_exon[0] - self.rs_tolerance) or (self.strand == '-' and position[subread[1]] < ref_exon[1] + self.rs_tolerance):
							subread_attribute.processed = True
							break
			# otherwise keep only sufficiently covered subreads
			if not subread_attribute.processed:
				if subread_attribute.count >= self.min_single_exon_coverage:
					subread_attribute.processed = True
					subread_attribute.read_tag = 'Keep_coverage'
			if subread_attribute.read_tag:
				refined_single_exon_read[(position[subread[0]], position[subread[1]])] = subread_attribute
		subread = [i+1,i+1]  # reset the window to start at the next position
		return subread, refined_single_exon_read

	def refine(self):
		refined_single_exon_read = defaultdict(dict)
		chrand_collected_single_exon_read = dict(OrderedDict(sorted(self.chrand_collected_single_exon_read.items())))
		position = list(chrand_collected_single_exon_read.keys())
		coverage = list(chrand_collected_single_exon_read.values())
		subread = [0,0]
		if not self.chrand_ref_mutple_exon_trans:
			# without a multi-exon reference, require near-flat coverage within a window
			self.coverage_cliff = 0.9
		for i in range(len(position)-1):
			# extend the window while positions stay contiguous and coverage does
			# not drop below the cliff ratio in the 5'->3' direction
			if position[i+1] == position[i]+1 and ((self.strand == '+' and coverage[i+1]/coverage[i] >= self.coverage_cliff) or (self.strand == '-' and coverage[i]/coverage[i+1] >= self.coverage_cliff)):
				subread[1] = i+1
				if i == len(position) - 2:
					subread, refined_single_exon_read = self.single_exon_refine(subread, position, coverage, refined_single_exon_read, i)
			else:
				subread, refined_single_exon_read = self.single_exon_refine(subread, position, coverage, refined_single_exon_read, i)
		return self.chrom, self.strand, refined_single_exon_read
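
A hedged toy run (hypothetical inputs; real calls pass interval-tree-like objects exposing .find() for the reference arguments): five contiguous positions with flat coverage collapse into one candidate single-exon read, kept here by the coverage rule.

toy_coverage = {100: 5, 101: 5, 102: 5, 103: 5, 104: 5}  # position -> read depth
refiner = SingleExonReadRefine(
	chrom='chr1', strand='+',
	chrand_collected_single_exon_read=toy_coverage,
	chrand_ref_single_exon_trans=None,   # no reference single-exon transcripts
	chrand_ref_mutple_exon_trans=None,   # triggers the stricter 0.9 coverage cliff
	chrand_three_prime_exon=None,
	min_single_exon_coverage=3,
	min_single_exon_len=3)               # toy threshold; real exons are longer
chrom, strand, refined = refiner.refine()
# refined[(100, 104)].read_tag == 'Keep_coverage'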
		
		

class MultiExonReadRefine[source]

MultiExonReadRefine(chrom, strand, chrand_processed_collected_multi_exon_read, chrand_ref_mutple_exon_trans, chrand_tss_dict, tss_cutoff, tes_cutoff, min_novel_trans_count, relative_fl_coverage=5, rs_tolerance=24, re_tolerance=50)

class MultiExonReadRefine():
	def __init__(self, chrom, strand, chrand_processed_collected_multi_exon_read, chrand_ref_mutple_exon_trans, chrand_tss_dict, tss_cutoff, tes_cutoff, min_novel_trans_count, relative_fl_coverage=5, rs_tolerance=24, re_tolerance=50):
		self.chrom = chrom
		self.strand = strand
		self.chrand_processed_collected_multi_exon_read = chrand_processed_collected_multi_exon_read
		self.chrand_ref_mutple_exon_trans = chrand_ref_mutple_exon_trans
		self.chrand_tss_dict = chrand_tss_dict
		self.tss_cutoff = tss_cutoff
		self.tes_cutoff = tes_cutoff
		self.rs_tolerance = rs_tolerance
		self.re_tolerance = re_tolerance
		self.min_novel_trans_count = min_novel_trans_count
		self.relative_fl_coverage = relative_fl_coverage

	def trucated_reads_filtering(self, start, end, corrected_read_splicing, trans_structure_pool, refined_multi_exon_read, total_count):
		tmp_tag = None  # initialised here so the return below is safe for an empty pool
		for ref_iso_splicing in trans_structure_pool:
			tmp_tag = None
			if len(ref_iso_splicing) > len(corrected_read_splicing):
				if set(corrected_read_splicing).issubset(set(ref_iso_splicing)):
					idx_start = ref_iso_splicing.index(corrected_read_splicing[0])
					idx_end = ref_iso_splicing.index(corrected_read_splicing[-1])
					if corrected_read_splicing == ref_iso_splicing[idx_start:idx_end+1]:
						if self.strand == '+' and idx_start > 0:
							if idx_end == len(ref_iso_splicing) -1:
								if start >= ref_iso_splicing[idx_start-1] - self.rs_tolerance:
									tmp_tag = 'Disqualify_Trucated_ISM'
							else:
								if start >= ref_iso_splicing[idx_start-1] - self.rs_tolerance and end <= ref_iso_splicing[idx_end+1] + self.re_tolerance:
									tmp_tag = 'Disqualify_Trucated_ISM'

						elif self.strand == '-' and idx_end < len(ref_iso_splicing)-1:
							if idx_start == 0:
								if start <= ref_iso_splicing[idx_end+1] + self.rs_tolerance:
									tmp_tag = 'Disqualify_Trucated_ISM'
							else:
								if start <= ref_iso_splicing[idx_end+1] + self.rs_tolerance and end >= ref_iso_splicing[idx_start-1] - self.rs_tolerance:
									tmp_tag = 'Disqualify_Trucated_ISM'
					if tmp_tag:
						if ref_iso_splicing in refined_multi_exon_read:
							if total_count/refined_multi_exon_read[ref_iso_splicing].count >= self.relative_fl_coverage:
								tmp_tag = None
					if tmp_tag:
						break
		return tmp_tag
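
A worked illustration of the truncation test (hypothetical coordinates, default rs_tolerance=24 and re_tolerance=50): the read's splice chain is a contiguous slice of a longer reference chain, and its ends stay inside the flanking reference exons, so it is flagged as a truncated ISM.

refiner = MultiExonReadRefine('chr1', '+', {}, {}, {}, 50, 50, 3)
tag = refiner.trucated_reads_filtering(
	start=160, end=390,
	corrected_read_splicing=(200, 300),  # the read's internal splice sites
	trans_structure_pool=[(100, 150, 200, 300, 400, 450)],
	refined_multi_exon_read={}, total_count=1)
# tag == 'Disqualify_Trucated_ISM': start 160 >= 150 - rs_tolerance and
# end 390 <= 400 + re_tolerance, so the read sits inside the reference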

	def ISM_NIC_sorting(self, corrected_read_splicing, trans_structure_pool, rss_dis, total_count, start_pos=None, relative_intact=None, tmp_tag=None):
		# mutable default arguments persist across calls in Python, so the
		# accumulator lists are created per call instead
		start_pos = [] if start_pos is None else start_pos
		relative_intact = [] if relative_intact is None else relative_intact
		for ref_iso_splicing in trans_structure_pool:
			if len(ref_iso_splicing) > len(corrected_read_splicing):
				if set(corrected_read_splicing).issubset(set(ref_iso_splicing)):
					idx_start = ref_iso_splicing.index(corrected_read_splicing[0])
					idx_end = ref_iso_splicing.index(corrected_read_splicing[-1])
					if (self.strand == '+' and idx_start == 0) or (self.strand == '-' and idx_end == len(ref_iso_splicing)-1):
						start_pos.append(1)
					else:
						start_pos.append(0)
					ref_exon_num = (len(ref_iso_splicing)+2)/2 
					if self.strand == '+':
						relative_intact.append((ref_exon_num - idx_start/2)/ref_exon_num)
					else:
						relative_intact.append(((idx_end+3)/2)/ref_exon_num)
		if start_pos:
			if all(start_pos) or rss_dis < self.tss_cutoff:
				tmp_tag = 'Subset'
			elif total_count >= self.min_novel_trans_count and mean(relative_intact) > 0.5:
				tmp_tag = 'Subset_coverage'
			else:
				tmp_tag = 'Disqualify_ISM_NIC'
		
		return tmp_tag
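
The exon arithmetic above, worked through for a hypothetical chain: a splice chain of N sites bounds N/2 introns and hence (N+2)/2 exons, and relative_intact is the fraction of reference exons the read retains.

ref_chain = (100, 150, 200, 300, 400, 450)  # 6 splice sites -> 3 introns -> 4 exons
ref_exon_num = (len(ref_chain) + 2) / 2     # 4.0
idx_start = 2  # the read's first shared site is ref_chain[2], so exon 1 is missing
relative_intact = (ref_exon_num - idx_start / 2) / ref_exon_num  # 0.75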

	
	def closest_ref_trans(self, corrected_read_splicing, trans_structure_pool):
		"""Return the number of shared splice sites and the splice chain of the
		closest reference transcript for the input read."""
		inter_SJ = ()
		cmp_trans = ()
		for ref_iso_splicing in trans_structure_pool:
			if self.strand == '-':
				ref_iso_splicing = tuple(reversed(ref_iso_splicing))
				corrected_read_splicing = tuple(reversed(corrected_read_splicing))

			tmp_inter = list(set(ref_iso_splicing).intersection(set(corrected_read_splicing)))
			tmp_inter.sort(reverse=True)
			if len(tmp_inter) > len(inter_SJ):
				inter_SJ, cmp_trans = tmp_inter, ref_iso_splicing
			elif len(tmp_inter) == len(inter_SJ) and len(tmp_inter) > 0:
				index1 = [i for i, val in enumerate(cmp_trans) if val in inter_SJ] 
				index2 = [i for i, val in enumerate(ref_iso_splicing) if val in tmp_inter] 
				index_sub = list(map(sub, index1, index2))
				if any(index_sub):
					if index_sub[next((i for i, x in enumerate(index_sub) if x!=0), None)] > 0:
						inter_SJ = tmp_inter
						cmp_trans = ref_iso_splicing
				elif len(ref_iso_splicing) < len(cmp_trans):
					inter_SJ = tmp_inter
					cmp_trans = ref_iso_splicing
		
		if self.strand == '-':
			cmp_trans = tuple(reversed(cmp_trans))

		return len(inter_SJ), cmp_trans
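
A hedged illustration of the tie-break (hypothetical chains): both references share two splice sites with the read, but the second places them at earlier indices, i.e. closer to its 5' end, so index1 - index2 is positive at the first nonzero entry and the second chain wins; when even the positions tie, the shorter chain is preferred.

refiner = MultiExonReadRefine('chr1', '+', {}, {}, {}, 50, 50, 3)
n_shared, best = refiner.closest_ref_trans(
	(200, 300),
	[(100, 150, 200, 300), (200, 300, 400, 450)])
# n_shared == 2, best == (200, 300, 400, 450)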
	
	def secondary_refine(self, corrected_read_splicing, read_attribute, refined_multi_exon_read, loci_idx):
		if 'Disqualify' not in read_attribute.read_tag:
			# fold 3'-complete truncations of an already kept chain into that chain
			if read_attribute.read_tag != 'Reference':
				for ref_iso_splicing in refined_multi_exon_read:
					if set(corrected_read_splicing).issubset(set(ref_iso_splicing)):
						idx_start = ref_iso_splicing.index(corrected_read_splicing[0])
						idx_end = ref_iso_splicing.index(corrected_read_splicing[-1])
						if corrected_read_splicing == ref_iso_splicing[idx_start:idx_end+1]:
							if (self.strand == '+' and idx_end == len(ref_iso_splicing) -1) or (self.strand == '-' and idx_start == 0):
								if (self.strand == "+" and read_attribute.start >= ref_iso_splicing[idx_start-1] - self.rs_tolerance) or (self.strand == "-" and read_attribute.start <= ref_iso_splicing[idx_end+1] + self.rs_tolerance):
									if read_attribute.count/refined_multi_exon_read[ref_iso_splicing].count < self.relative_fl_coverage:
										read_attribute.read_tag = 'Disqualify_merged'
										break

		if 'Disqualify' not in read_attribute.read_tag:
			# assign a locus ID: inherit it from the closest kept chain, or open a new locus
			len_inter_SJ, cmp_trans = self.closest_ref_trans(corrected_read_splicing, refined_multi_exon_read)
			if len_inter_SJ > 0:
				read_attribute.chrand_ID = refined_multi_exon_read[cmp_trans].chrand_ID
			else:
				loci_idx += 1
				read_attribute.chrand_ID = loci_idx

			refined_multi_exon_read[corrected_read_splicing] = read_attribute

		return refined_multi_exon_read, read_attribute, loci_idx
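
A small sketch of the locus bookkeeping (hypothetical chains): the first kept read opens locus 1; a second read sharing splice sites with it inherits the same chrand_ID.

refiner = MultiExonReadRefine('chr1', '+', {}, {}, {}, 50, 50, 3)
refined, loci_idx = {}, 0
first = AttributeCollection(100, 500, 4.0, read_tag='Novel_loci')
refined, first, loci_idx = refiner.secondary_refine((200, 300), first, refined, loci_idx)
second = AttributeCollection(150, 550, 2.0, read_tag='Similar')
refined, second, loci_idx = refiner.secondary_refine((200, 300, 350, 450), second, refined, loci_idx)
# first.chrand_ID == 1 (new locus); second.chrand_ID == 1 (shared junctions)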


	def main_refine(self, loci_idx = 0):
		refined_multi_exon_read = defaultdict(dict)
		refine_log = defaultdict(dict)
		for corrected_read_splicing, read_info in self.chrand_processed_collected_multi_exon_read.items():
			read_attribute = AttributeCollection(*read_info)
			trans_structure_pool = self.chrand_ref_mutple_exon_trans

			# calculate the distance between the TSS of the collapsed read and the nearest reference TSS
			read_attribute.rss_dis, _ = loc_distance(self.chrand_tss_dict, read_attribute.start)

			# check full splice match (FSM) reads whose start and end sites match the reference transcript within the TSS/TES cutoffs
			if read_attribute.fsm:
				if any(abs(x - trans_structure_pool[corrected_read_splicing][0]) <= self.tss_cutoff for x in read_attribute.as_site):
					if read_attribute.polyaed:
						read_attribute.read_tag = 'Reference'
						read_attribute.processed = True
					elif any(abs(x - trans_structure_pool[corrected_read_splicing][1]) <= self.tes_cutoff for x in read_attribute.apa_site):
						read_attribute.read_tag = 'Reference'
						read_attribute.processed = True

			if not read_attribute.processed:
				# check whether the read is a truncated fragment of a reference transcript
				if trans_structure_pool:
					read_tag = self.trucated_reads_filtering(read_attribute.start, read_attribute.end, corrected_read_splicing, trans_structure_pool, refined_multi_exon_read, read_attribute.count)
					if read_tag:
						read_attribute.read_tag = read_tag
						read_attribute.processed = True

				if not read_attribute.processed and read_attribute.fsm:
					read_attribute.read_tag = 'Reference'
					read_attribute.processed = True
				
				if not read_attribute.processed and not read_attribute.polyaed:
					read_attribute.read_tag = 'Disqualify_No_ployA'
					read_attribute.processed = True
					
				elif not read_attribute.processed:
					if trans_structure_pool:
						read_tag = self.ISM_NIC_sorting(corrected_read_splicing, trans_structure_pool, read_attribute.rss_dis, read_attribute.count)
						if read_tag:
							read_attribute.read_tag = read_tag
							read_attribute.processed = True

					if not read_attribute.processed:
						len_inter_SJ, cmp_trans = self.closest_ref_trans(corrected_read_splicing, trans_structure_pool)
						read_attribute.processed = True
						if len_inter_SJ > 0:
							if (cmp_trans[0] == corrected_read_splicing[0]) or set(cmp_trans).issubset(set(corrected_read_splicing)) or read_attribute.rss_dis <= self.tss_cutoff:
								read_attribute.read_tag = 'Similar'
							else:
								read_attribute.read_tag = 'Disqualify_NNC'
						else:
							if read_attribute.rss_dis <= self.tss_cutoff or read_attribute.count  >= self.min_novel_trans_count:
								read_attribute.read_tag = 'Novel_loci'
							else:
								read_attribute.read_tag = "Disqualify_Other"
			
			refined_multi_exon_read, read_attribute, loci_idx = self.secondary_refine(corrected_read_splicing, read_attribute, refined_multi_exon_read, loci_idx)
			refine_log[corrected_read_splicing] = read_attribute
		return self.chrom, self.strand, refined_multi_exon_read, refine_log
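
In source order, main_refine assigns: 'Reference' (FSM reads with supported TSS and polyA/TES evidence), 'Disqualify_Trucated_ISM' (contiguous sub-chains of a reference transcript), 'Reference' again for remaining FSM reads, 'Disqualify_No_ployA' (non-polyadenylated, otherwise unclassified reads; tag spellings follow the source), then the ISM/NIC outcomes ('Subset', 'Subset_coverage', 'Disqualify_ISM_NIC'), the closest-reference outcomes ('Similar', 'Disqualify_NNC'), and finally 'Novel_loci' or 'Disqualify_Other' for reads sharing no splice sites with the reference.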

class RefineWrapper[source]

RefineWrapper(processed_collected_multi_exon_read, collected_single_exon_read, ref_mutple_exon_trans, ref_single_exon_trans, three_prime_exon, tss_dict, tss_cutoff, tes_cutoff, min_novel_trans_count, min_single_exon_coverage, min_single_exon_len, thread, tmp_dir)

class RefineWrapper:
	def __init__(self, processed_collected_multi_exon_read, collected_single_exon_read, ref_mutple_exon_trans, ref_single_exon_trans, three_prime_exon, tss_dict, tss_cutoff, tes_cutoff, min_novel_trans_count, min_single_exon_coverage, min_single_exon_len, thread, tmp_dir):
		self.processed_collected_multi_exon_read = processed_collected_multi_exon_read
		self.collected_single_exon_read = collected_single_exon_read
		self.ref_mutple_exon_trans = ref_mutple_exon_trans
		self.ref_single_exon_trans = ref_single_exon_trans
		self.three_prime_exon = three_prime_exon
		self.tss_dict = tss_dict
		self.tss_cutoff = tss_cutoff
		self.tes_cutoff = tes_cutoff
		self.min_novel_trans_count = min_novel_trans_count
		self.min_single_exon_coverage = min_single_exon_coverage
		self.min_single_exon_len = min_single_exon_len
		self.thread = thread
		self.tmp_dir = tmp_dir

	def run1(self):
		# refine multi-exon reads in parallel, one (chrom, strand) job per worker
		multi_precompute_list = []
		for (chrom, strand), chrand_processed_collected_multi_exon_read in self.processed_collected_multi_exon_read.items():
			multi_precompute_list.append(MultiExonReadRefine(chrom, strand, chrand_processed_collected_multi_exon_read, self.ref_mutple_exon_trans[(chrom, strand)], self.tss_dict[(chrom, strand)], self.tss_cutoff, self.tes_cutoff, self.min_novel_trans_count))
		with Parallel(n_jobs = self.thread) as parallel:
			multi_exon_results = parallel(delayed(lambda x:x.main_refine())(job) for job in multi_precompute_list)

		return multi_exon_results

	def run2(self):
		# refine single-exon reads in parallel, one (chrom, strand) job per worker
		single_precompute_list = []
		for (chrom, strand), chrand_collected_single_exon_read in self.collected_single_exon_read.items():
			single_precompute_list.append(SingleExonReadRefine(chrom, strand, chrand_collected_single_exon_read, self.ref_single_exon_trans[(chrom,strand)], self.ref_mutple_exon_trans[(chrom,strand)], self.three_prime_exon[(chrom,strand)], self.min_single_exon_coverage, self.min_single_exon_len))
		with Parallel(n_jobs = self.thread) as parallel:
			single_exon_results = parallel(delayed(lambda x:x.refine())(job) for job in single_precompute_list)
		
		return single_exon_results
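
The lambdas passed to delayed() serialize because joblib's default loky backend pickles callables with cloudpickle; the same dispatch can also be spelled with the unbound method. A minimal, self-contained sketch of the pattern (toy class, not LAFITE code):

from joblib import Parallel, delayed

class _Toy:
	def work(self):
		return 42

jobs = [_Toy(), _Toy()]
results = Parallel(n_jobs=2)(delayed(_Toy.work)(j) for j in jobs)
# results == [42, 42], mirroring how run1()/run2() fan out per (chrom, strand)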


	def result_collection(self):
		multi_exon_results = self.run1()
		single_exon_results = self.run2()
		collected_refined_isoforms = Vividict()
		path_to_refine_log = f'{self.tmp_dir}/refine.log'
		for result in sorted(multi_exon_results):
			chrom, strand, refined_multi_exon_read, refine_log = result
			for corrected_read_splicing, read_attribute in refined_multi_exon_read.items():
				if strand == '-':
					corrected_read_splicing = (read_attribute.end,) + corrected_read_splicing + (read_attribute.start,)
				else:
					corrected_read_splicing = (read_attribute.start,) + corrected_read_splicing + (read_attribute.end,)
				collected_refined_isoforms[(chrom, strand)][corrected_read_splicing] = read_attribute
			with open(path_to_refine_log, 'a') as flog:
				for corrected_read_splicing, read_attribute in refine_log.items():
					corrected_read_splicing = ','.join([str(s) for s in corrected_read_splicing])
					read_attribute = read_attribute.__dict__
					read_attribute = '\t'.join('{}: {}'.format(key, str(value)) for key, value in read_attribute.items())
					flog.write(f'{corrected_read_splicing}\t{read_attribute}\n')
		for result in single_exon_results:
			chrom, strand, refined_single_exon_read = result
			for corrected_read_splicing, read_attribute in refined_single_exon_read.items():
				collected_refined_isoforms[(chrom, strand)][corrected_read_splicing] = read_attribute
		
		collected_refined_isoforms = dict(sorted(collected_refined_isoforms.items()))

		return collected_refined_isoforms