from collections import defaultdict
from interlap import InterLap
import random
import numpy as np
import pandas as pd

class OutputAssembly

OutputAssembly(collected_refined_isoforms, output, label, relative_abundance_threshold)

class OutputAssembly:
	"""Assign locus IDs to refined isoforms and write them out as GTF records.

	Parameters
	----------
	collected_refined_isoforms : dict
		Maps ``(chrom, strand)`` to a dict of ``full_block -> read_attribute``.
		``full_block`` is a flat tuple of exon coordinates
		``(start1, end1, start2, end2, ...)``; ``read_attribute`` is an object
		whose ``__dict__`` holds the per-isoform fields (``start``, ``count``,
		``read_tag``, ``chrand_ID``, ``loci_ID``, ...).
	output : str or path-like
		GTF file written (and overwritten) by :meth:`write_out`.
	label : str
		Prefix for the ``gene_id`` / ``transcript_id`` values.
	relative_abundance_threshold : float
		Isoforms whose share of their locus' ``count`` is not above this value
		are dropped, unless their ``read_tag`` contains ``"Reference"``.
	"""

	def __init__(self, collected_refined_isoforms, output, label, relative_abundance_threshold):
		self.collected_refined_isoforms = collected_refined_isoforms
		self.label = label
		self.output = output
		self.relative_abundance_threshold = relative_abundance_threshold

	def idx_rearrange(self):
		"""Assign a ``loci_ID`` to every refined isoform, mutating the attribute objects.

		Multi-exon isoforms (``full_block`` longer than one start/end pair) share
		a locus when they share a ``chrand_ID`` within the same ``(chrom, strand)``.
		A mono-exon isoform joins the locus of the first indexed multi-exon exon it
		overlaps (via ``InterLap``), otherwise it gets a new locus. Locus numbers
		are consecutive integers starting at 1 across all ``(chrom, strand)`` keys,
		and are recomputed from scratch on every call.

		Returns
		-------
		dict
			``self.collected_refined_isoforms`` (the same object).
		"""
		loci_ID = 0
		for refined_isoforms in self.collected_refined_isoforms.values():
			chrand2loci = {}
			exon2loci_ID = {}
			mono_exon = []
			# Pass 1: multi-exon isoforms. All their exons are indexed before any
			# mono-exon isoform is matched, so the result does not depend on the
			# dict's iteration order.
			for full_block, read_attribute in refined_isoforms.items():
				if len(full_block) <= 2:
					mono_exon.append((full_block, read_attribute))
					continue
				if read_attribute.chrand_ID not in chrand2loci:
					loci_ID += 1
					chrand2loci[read_attribute.chrand_ID] = loci_ID
				read_attribute.loci_ID = chrand2loci[read_attribute.chrand_ID]
				coords = iter(full_block)
				for exon in zip(coords, coords):
					exon2loci_ID[exon] = read_attribute.loci_ID

			# Pass 2: mono-exon isoforms. The interval index is only built when needed.
			exon_interlap = None
			if mono_exon and exon2loci_ID:
				exon_interlap = InterLap()
				exon_interlap.update(list(exon2loci_ID))
			for full_block, read_attribute in mono_exon:
				overlap_ref = tuple(exon_interlap.find(full_block)) if exon_interlap is not None else ()
				if overlap_ref:
					read_attribute.loci_ID = exon2loci_ID[overlap_ref[0]]
				else:
					# Always issue a fresh ID. Keeping a stale ID from an earlier call
					# could collide with IDs the counter re-issues in this call.
					loci_ID += 1
					read_attribute.loci_ID = loci_ID
		return self.collected_refined_isoforms

	def write_out(self):
		"""Assign locus IDs, then write every retained isoform to ``self.output`` as GTF.

		Within each locus an isoform is kept when its relative abundance
		(``count`` / the locus' total ``count``) exceeds
		``self.relative_abundance_threshold`` or its ``read_tag`` contains
		``"Reference"``. Kept isoforms are numbered in ascending ``start`` order.
		Each one is written as a ``transcript`` line followed by its ``exon`` lines.
		"""
		collected_refined_isoforms = self.idx_rearrange()
		with open(self.output, 'w') as fw:
			for (chrom, strand), refined_isoforms in collected_refined_isoforms.items():
				if not refined_isoforms:
					continue
				# Column names are taken from the first attribute object. This assumes
				# every object has the same fields in the same order, since rows are
				# built positionally from ``__dict__.values()``.
				columns = ['full_block'] + list(next(iter(refined_isoforms.values())).__dict__)
				df = pd.DataFrame(
					[(k,) + tuple(v.__dict__.values()) for k, v in refined_isoforms.items()],
					columns=columns,
				).set_index('loci_ID')

				# groupby(level=0) visits each locus once, in ascending order.
				for loci, group in df.groupby(level=0):
					tmp = group.assign(abundance=group['count'] / group['count'].sum())
					keep = (tmp['abundance'] > self.relative_abundance_threshold) | tmp['read_tag'].str.contains("Reference")
					tmp = tmp[keep].sort_values('start', ascending=True)

					for isoform_idx, (_, row) in enumerate(tmp.iterrows(), start=1):
						# Positional unpacking. Assumes the non-index fields come in this
						# order, followed by the computed abundance. TODO confirm against
						# the attribute class.
						(full_block, start, end, count, polya_count, _fsm, _polyaed, as_site, apa_site,
						 _collapsed_name, rss_dis, read_tag, _processed, _chrand_ID, _abundance) = list(row)
						if as_site:
							as_site = ','.join([str(x) for x in as_site])
							apa_site = ','.join([str(x) for x in apa_site])
						else:
							as_site = str(start)
							apa_site = str(end)
						attribute = f'gene_id "{self.label}.{loci}"; transcript_id "{self.label}.{loci}.{isoform_idx}"; distance_to_nearset_TSS "{rss_dis}"; full_length_count "{count}"; full_length_PolyA_count "{polya_count}"; alternative_TSS "{as_site}"; alternative_TES "{apa_site}"; isoform_class "{read_tag}";'
						# The transcript span is the outer exon extent, so it always covers its exon lines.
						isoform_info = [chrom, 'LAFITE', 'transcript', str(full_block[0]), str(full_block[-1]), '.', strand, '.', attribute]
						fw.write('\t'.join(isoform_info) + '\n')

						coords = iter(full_block)
						for exon_idx, (exon_start, exon_end) in enumerate(zip(coords, coords), start=1):
							attribute = f'gene_id "{self.label}.{loci}"; transcript_id "{self.label}.{loci}.{isoform_idx}"; exon_number "{exon_idx}";'
							exon_info = [chrom, 'LAFITE', 'exon', str(exon_start), str(exon_end), '.', strand, '.', attribute]
							fw.write('\t'.join(exon_info) + '\n')

idx_rearrange

idx_rearrange(collected_refined_isoforms)

def idx_rearrange(collected_refined_isoforms):
	"""Assign a ``loci_ID`` to every refined isoform, mutating the attribute objects.

	Parameters
	----------
	collected_refined_isoforms : dict
		Maps ``(chrom, strand)`` to a dict of ``full_block -> read_attribute``.
		``full_block`` is a flat tuple of exon coordinates
		``(start1, end1, start2, end2, ...)``; ``read_attribute`` must expose
		``chrand_ID`` and a writable ``loci_ID``.

	Multi-exon isoforms (``full_block`` longer than one start/end pair) share a
	locus when they share a ``chrand_ID`` within the same ``(chrom, strand)``.
	A mono-exon isoform joins the locus of the first indexed multi-exon exon it
	overlaps (via ``InterLap``), otherwise it gets a new locus. Locus numbers are
	consecutive integers starting at 1 across all keys, and are recomputed from
	scratch on every call.

	Returns
	-------
	dict
		``collected_refined_isoforms`` (the same object).
	"""
	loci_ID = 0
	for refined_isoforms in collected_refined_isoforms.values():
		chrand2loci = {}
		exon2loci_ID = {}
		mono_exon = []
		# Pass 1: multi-exon isoforms. All their exons are indexed before any
		# mono-exon isoform is matched, so the result does not depend on the
		# dict's iteration order.
		for full_block, read_attribute in refined_isoforms.items():
			if len(full_block) <= 2:
				mono_exon.append((full_block, read_attribute))
				continue
			if read_attribute.chrand_ID not in chrand2loci:
				loci_ID += 1
				chrand2loci[read_attribute.chrand_ID] = loci_ID
			read_attribute.loci_ID = chrand2loci[read_attribute.chrand_ID]
			coords = iter(full_block)
			for exon in zip(coords, coords):
				exon2loci_ID[exon] = read_attribute.loci_ID

		# Pass 2: mono-exon isoforms. The interval index is only built when needed.
		exon_interlap = None
		if mono_exon and exon2loci_ID:
			exon_interlap = InterLap()
			exon_interlap.update(list(exon2loci_ID))
		for full_block, read_attribute in mono_exon:
			overlap_ref = tuple(exon_interlap.find(full_block)) if exon_interlap is not None else ()
			if overlap_ref:
				read_attribute.loci_ID = exon2loci_ID[overlap_ref[0]]
			else:
				# Always issue a fresh ID. Keeping a stale ID from an earlier call
				# could collide with IDs the counter re-issues in this call.
				loci_ID += 1
				read_attribute.loci_ID = loci_ID
	return collected_refined_isoforms

assembly_output

assembly_output(collected_refined_isoforms, output, label, relative_abundance_threshold)

def assembly_output(collected_refined_isoforms, output, label, relative_abundance_threshold):
	"""Write refined isoforms with pre-assigned ``loci_ID`` values as GTF records.

	Locus IDs are not assigned here. Every attribute object is expected to
	already carry a ``loci_ID``, e.g. set by ``idx_rearrange``.

	Parameters
	----------
	collected_refined_isoforms : dict
		Maps ``(chrom, strand)`` to a dict of ``full_block -> read_attribute``,
		where ``full_block`` is a flat tuple of exon coordinates
		``(start1, end1, start2, end2, ...)``.
	output : str, path-like or None
		GTF file to (over)write. ``None`` writes to standard output instead.
	label : str
		Prefix for the ``gene_id`` / ``transcript_id`` values.
	relative_abundance_threshold : float
		An isoform is kept only if its share of its locus' ``count`` is above
		this value or its ``read_tag`` contains ``"Reference"``.
	"""
	if output is None:
		import sys
		_write_assembly_records(collected_refined_isoforms, sys.stdout, label, relative_abundance_threshold)
		return
	with open(output, 'w') as fw:
		_write_assembly_records(collected_refined_isoforms, fw, label, relative_abundance_threshold)


def _write_assembly_records(collected_refined_isoforms, fw, label, relative_abundance_threshold):
	"""Write transcript and exon GTF lines for every retained isoform to the text stream ``fw``."""
	for (chrom, strand), refined_isoforms in collected_refined_isoforms.items():
		if not refined_isoforms:
			continue
		# Column names come from the first attribute object. This assumes every
		# object has the same fields in the same order, since rows are built
		# positionally from ``__dict__.values()``.
		columns = ['full_block'] + list(next(iter(refined_isoforms.values())).__dict__)
		df = pd.DataFrame(
			[(k,) + tuple(v.__dict__.values()) for k, v in refined_isoforms.items()],
			columns=columns,
		).set_index('loci_ID')

		# groupby(level=0) visits each locus once, in ascending order.
		for loci, group in df.groupby(level=0):
			tmp = group.assign(abundance=group['count'] / group['count'].sum())
			keep = (tmp['abundance'] > relative_abundance_threshold) | tmp['read_tag'].str.contains("Reference")
			tmp = tmp[keep].sort_values('start', ascending=True)

			for isoform_idx, (_, row) in enumerate(tmp.iterrows(), start=1):
				# Positional unpacking. Assumes the non-index fields come in this
				# order, followed by the computed abundance. TODO confirm against
				# the attribute class.
				(full_block, start, end, count, polya_count, _fsm, _polyaed, as_site, apa_site,
				 _collapsed_name, rss_dis, read_tag, _processed, _chrand_ID, _abundance) = list(row)
				if as_site:
					as_site = ','.join([str(x) for x in as_site])
					apa_site = ','.join([str(x) for x in apa_site])
				else:
					as_site = str(start)
					apa_site = str(end)
				attribute = f'gene_id "{label}.{loci}"; transcript_id "{label}.{loci}.{isoform_idx}"; distance_to_nearset_TSS "{rss_dis}"; full_length_count "{count}"; full_length_PolyA_count "{polya_count}"; alternative_TSS "{as_site}"; alternative_TES "{apa_site}"; isoform_class "{read_tag}";'
				# The transcript span is the outer exon extent, so it always covers its exon lines.
				isoform_info = [chrom, 'LAFITE', 'transcript', str(full_block[0]), str(full_block[-1]), '.', strand, '.', attribute]
				fw.write('\t'.join(isoform_info) + '\n')

				coords = iter(full_block)
				for exon_idx, (exon_start, exon_end) in enumerate(zip(coords, coords), start=1):
					attribute = f'gene_id "{label}.{loci}"; transcript_id "{label}.{loci}.{isoform_idx}"; exon_number "{exon_idx}";'
					exon_info = [chrom, 'LAFITE', 'exon', str(exon_start), str(exon_end), '.', strand, '.', attribute]
					fw.write('\t'.join(exon_info) + '\n')