from collections import defaultdict
from interlap import InterLap
from joblib import Parallel, delayed
from LAFITE.utils import Vividict

gtf_line_split[source]

gtf_line_split(entry)

Split one tab-delimited GTF line into its nine columns, converting start and end to integers.

def gtf_line_split(entry):
	"""Split one GTF record into its nine tab-separated columns.

	Trailing whitespace (including the newline) is removed first. The start and
	end columns are converted to int; every other column is returned as a string.

	Returns
	-------
	tuple
		(chrom, source, feature, start, end, score, strand, frame, attribute)

	Raises
	------
	ValueError
		If the line does not have exactly nine columns, or start/end are not integers.
	"""
	columns = entry.rstrip().split('\t')
	(chrom, source, feature, start, end, score, strand, frame, attribute) = columns
	return (chrom, source, feature, int(start), int(end), score, strand, frame, attribute)
def gtf2splicing(gtf, transcript_ref = 'transcript_id', gene_ref = 'gene_id', keepAttribute=False, mergeMode =False):
	"""Preprocess a GTF file into gene-level and isoform-level coordinate collections.

	Parameters
	----------
	gtf : str
		Path to the GTF file. Lines starting with '#' and blank lines are skipped.
	transcript_ref, gene_ref : str
		Attribute names holding the transcript and gene identifiers.
	keepAttribute : bool
		If True, every 'transcript'/'mRNA' record sets its isoform entry to
		``[attributes]`` (the parsed attribute dict). Exon coordinates seen
		afterwards are appended to that list; a transcript record that comes
		after its exons replaces the coordinates collected so far.
	mergeMode : bool
		If True, only attribute entries whose text contains gene_ref or
		transcript_ref are parsed.

	Returns
	-------
	gene_structure_dict : dict
		(chrom, strand, gene_id) -> set of every exon start/end of that gene.
	isoform_structure_dict : dict
		(chrom, strand, transcript_id, gene_id) -> flat list
		[start1, end1, start2, end2, ...] in file order.

	Raises
	------
	ValueError
		If the attribute column cannot be parsed into key/value pairs.
	"""
	isoform_structure_dict = defaultdict(dict)
	gene_structure_dict = defaultdict(dict)
	with open(gtf) as f:
		for line in f:
			# comment/header lines and blank lines carry no record (a blank line
			# would otherwise fail the 9-column unpack in gtf_line_split)
			if line.startswith('#') or not line.strip():
				continue
			chrom, source, feature, start, end, score, strand, frame, attributes = gtf_line_split(line)

			try:
				# attribute column: key "value"; key "value"; ...
				attributes = [a.strip() for a in attributes.strip(';').split('; ') if len(a)>0]
				if mergeMode:
					attributes = [a for a in attributes if any(i in a for i in [gene_ref, transcript_ref])]
				attributes = dict([a.replace('"','').split(' ',1)[0:2] for a in attributes])
			except ValueError as err:
				# dict() raises ValueError when an entry has no "key value" pair
				raise ValueError("Fatal: please check input GTF format!\n") from err

			if keepAttribute and feature in ('transcript', 'mRNA'):
				gene_id = attributes[gene_ref]
				transcript_id = attributes[transcript_ref]
				isoform_structure_dict[(chrom, strand, transcript_id, gene_id)] = [attributes]

			if feature == "exon":
				gene_id = attributes[gene_ref]
				transcript_id = attributes[transcript_ref]
				gene_structure_dict.setdefault((chrom, strand, gene_id), set()).update((start, end))
				isoform_structure_dict.setdefault((chrom, strand, transcript_id, gene_id), []).extend((start, end))

	return gene_structure_dict, isoform_structure_dict

gtf2splicing[source]

gtf2splicing(gtf, transcript_ref='transcript_id', keepAttribute=False)

Preprocess the GTF file into isoform-level exon coordinates, grouped by (chromosome, strand).

def gtf2splicing(gtf, transcript_ref = 'transcript_id', keepAttribute=False):
	"""Preprocess a GTF file into isoform-level exon coordinates per (chromosome, strand).

	Parameters
	----------
	gtf : str
		Path to the GTF file. Lines starting with '#' and blank lines are skipped.
	transcript_ref : str
		Attribute name holding the transcript identifier.
	keepAttribute : bool
		If True, every 'transcript'/'mRNA' record sets its entry to
		``[attributes]`` (the parsed attribute dict). Exon coordinates seen
		afterwards are appended to that list; a transcript record that comes
		after its exons replaces the coordinates collected so far.

	Returns
	-------
	Vividict
		isoform_structure_dict[(chrom, strand)][transcript_id] -> flat list
		[start1, end1, start2, end2, ...] of exon coordinates in file order.

	Raises
	------
	ValueError
		If the attribute column cannot be parsed into key/value pairs.
	"""
	isoform_structure_dict = Vividict()
	with open(gtf) as f:
		for line in f:
			# comment/header lines and blank lines carry no record (a blank line
			# would otherwise fail the 9-column unpack in gtf_line_split)
			if line.startswith('#') or not line.strip():
				continue
			chrom, source, feature, start, end, score, strand, frame, attributes = gtf_line_split(line)

			try:
				# attribute column: key "value"; key "value"; ...
				attributes = [a.strip().replace('"','') for a in attributes.strip(';').split('"; ') if len(a)>0]
				attributes = dict([a.split(' ',1)[0:2] for a in attributes])
			except ValueError as err:
				# dict() raises ValueError when an entry has no "key value" pair
				raise ValueError("Fatal: please check input GTF format\n") from err

			# index isoform_structure_dict only for records that are stored, so
			# (chrom, strand) groups are created only when they receive data
			if keepAttribute and feature in ('transcript', 'mRNA'):
				transcript_id = attributes[transcript_ref]
				isoform_structure_dict[(chrom, strand)][transcript_id] = [attributes]

			if feature == "exon":
				transcript_id = attributes[transcript_ref]
				chrand_isoforms = isoform_structure_dict[(chrom, strand)]
				if transcript_id in chrand_isoforms:
					chrand_isoforms[transcript_id].extend([start, end])
				else:
					chrand_isoforms[transcript_id] = [start, end]

	return isoform_structure_dict
class RefAnnotationExtraction:
	"""
	Extract splicing/exon information from a whole reference GTF, with every
	collection keyed by (chromosome, strand).
	"""

	def __init__(self, ref_gtf):
		# path of the reference GTF, handed to gtf2splicing
		self.ref_gtf = ref_gtf

	def annotation_extraction(self):
		"""Collect exons, junctions, splice sites and TSSs of all reference transcripts.

		NOTE(review): this unpacks two values (gene and isoform structures) from
		gtf2splicing; confirm that the gtf2splicing in scope when this runs has
		that two-value return shape.

		Returns
		-------
		tuple of dicts keyed by (chrom, strand):
			exons          -> set of (start, end) of exons of multi-exon transcripts
			junctions      -> set of (exon end, next exon start) pairs
			single exon    -> set of (start, end) of single-exon transcripts
			multi exon     -> {tuple of internal coordinates: [TSS, TES]}
			left sites     -> set of exon ends that start a junction
			right sites    -> set of exon starts that end a junction
			tss            -> set of transcript start sites
		"""
		exon_by_chrand = defaultdict(set)
		junction_by_chrand = defaultdict(set)
		multi_exon_by_chrand = Vividict()
		single_exon_by_chrand = defaultdict(set)
		donor_by_chrand = defaultdict(set)
		acceptor_by_chrand = defaultdict(set)
		tss_by_chrand = defaultdict(set)

		_, isoform_blocks = gtf2splicing(self.ref_gtf)

		for (chrom, strand, _transcript_id, _gene_id), coords in isoform_blocks.items():
			coords.sort()
			chrand = (chrom, strand)
			interior = coords[1:-1]
			first, last = coords[0], coords[-1]

			# the TSS is the leftmost coordinate on '+', the rightmost otherwise
			tss_by_chrand[chrand].add(first if strand == '+' else last)

			if not interior:
				single_exon_by_chrand[chrand].add((first, last))
				continue

			# consecutive coordinate pairs are the exons
			exon_pairs = list(zip(coords[0::2], coords[1::2]))
			exon_by_chrand[chrand].update(exon_pairs)
			# every junction links one exon's end to the next exon's start
			for upstream, downstream in zip(exon_pairs, exon_pairs[1:]):
				donor, acceptor = upstream[1], downstream[0]
				donor_by_chrand[chrand].add(donor)
				acceptor_by_chrand[chrand].add(acceptor)
				junction_by_chrand[chrand].add((donor, acceptor))

			# stored as [TSS, TES]: reversed for minus-strand transcripts
			multi_exon_by_chrand[chrand][tuple(interior)] = [last, first] if strand == "-" else [first, last]

		return exon_by_chrand, junction_by_chrand, single_exon_by_chrand, multi_exon_by_chrand, donor_by_chrand, acceptor_by_chrand, tss_by_chrand

	def annotation_sorting(self):
		"""Sort the extracted collections and index exon intervals with InterLap.

		TSSs and splice sites become sorted lists, multi-exon structures are
		reordered by descending number of internal coordinates, and the exon and
		single-exon interval sets become InterLap objects. The return order
		matches annotation_extraction.
		"""
		exon_by_chrand, junction_by_chrand, single_exon_by_chrand, multi_exon_by_chrand, donor_by_chrand, acceptor_by_chrand, tss_by_chrand = self.annotation_extraction()

		for chrand, sites in tss_by_chrand.items():
			tss_by_chrand[chrand] = sorted(sites)

		# longest splice structures first
		for chrand, structures in multi_exon_by_chrand.items():
			multi_exon_by_chrand[chrand] = {splice: structures[splice] for splice in sorted(structures, key=len, reverse=True)}

		for site_store in (donor_by_chrand, acceptor_by_chrand):
			for chrand, sites in site_store.items():
				site_store[chrand] = sorted(sites)

		# convert interval sets to InterLap for overlap queries
		for interval_store in (single_exon_by_chrand, exon_by_chrand):
			for chrand, intervals in interval_store.items():
				interval_index = InterLap()
				interval_index.update(list(intervals))
				interval_store[chrand] = interval_index

		return exon_by_chrand, junction_by_chrand, single_exon_by_chrand, multi_exon_by_chrand, donor_by_chrand, acceptor_by_chrand, tss_by_chrand

class RefAnnotationExtraction[source]

RefAnnotationExtraction(chrom, strand, chrand_ref_trans_structure_dict)

Extract the splicing/exon information of the reference annotation for a single chromosome/strand pair.

class RefAnnotationExtraction:
	"""
	Extract the splicing/exon information of the reference annotation for one
	chromosome/strand pair.
	"""

	def __init__(self, chrom, strand, chrand_ref_trans_structure_dict):
		"""
		chrom, strand : chromosome name and strand this instance covers; strand
			'+' takes the leftmost coordinate as TSS, anything else the rightmost.
		chrand_ref_trans_structure_dict : mapping transcript_id -> flat list of
			exon coordinates [start1, end1, start2, end2, ...].
		"""
		self.chrom = chrom
		self.strand = strand
		self.chrand_ref_trans_structure_dict = chrand_ref_trans_structure_dict

	def annotation_extraction(self):
		"""Collect exons, junctions, splice sites and TSSs of this chrand's transcripts.

		The input coordinate lists are not modified.

		Returns
		-------
		(exons, junctions, single_exon_trans, multi_exon_trans, left_sj, right_sj, tss)
			exons            : set of (start, end) of exons of multi-exon transcripts
			junctions        : set of (exon end, next exon start) pairs
			single_exon_trans: set of (start, end) of single-exon transcripts
			multi_exon_trans : {tuple of internal coordinates: [TSS, TES]}
			left_sj          : set of exon ends that start a junction
			right_sj         : set of exon starts that end a junction
			tss              : set of transcript start sites
		"""
		chrand_ref_exon = set()
		chrand_ref_junction = set()
		chrand_ref_mutple_exon_trans = defaultdict(dict)
		chrand_ref_single_exon_trans = set()
		chrand_left_sj_set = set()
		chrand_right_sj_set = set()
		chrand_tss_dict = set()

		for full_block in self.chrand_ref_trans_structure_dict.values():
			# sorted() rather than list.sort(): do not reorder the caller's lists
			full_block = sorted(full_block)
			iso_splicing = full_block[1:-1]
			start, end = full_block[0], full_block[-1]

			# the TSS is the leftmost coordinate on '+', the rightmost otherwise
			if self.strand == '+':
				chrand_tss_dict.add(start)
			else:
				chrand_tss_dict.add(end)

			if not iso_splicing:
				chrand_ref_single_exon_trans.add((start, end))
				continue

			# consecutive coordinate pairs are the exons
			exons = list(zip(full_block[0::2], full_block[1::2]))
			chrand_ref_exon.update(exons)
			# every junction links one exon's end to the next exon's start
			for (_, left_sj), (right_sj, _) in zip(exons, exons[1:]):
				chrand_left_sj_set.add(left_sj)
				chrand_right_sj_set.add(right_sj)
				chrand_ref_junction.add((left_sj, right_sj))

			# stored as [TSS, TES]: reversed for minus-strand transcripts
			se_site = [end, start] if self.strand == '-' else [start, end]
			chrand_ref_mutple_exon_trans[tuple(iso_splicing)] = se_site

		return chrand_ref_exon, chrand_ref_junction, chrand_ref_single_exon_trans, chrand_ref_mutple_exon_trans, chrand_left_sj_set, chrand_right_sj_set, chrand_tss_dict

	def annotation_sorting(self):
		"""Sort the extracted collections and index exon intervals with InterLap.

		Returns (chrom, strand, exons, junctions, single_exon_trans,
		multi_exon_trans, left_sj, right_sj, tss).
		- tss, left_sj and right_sj are sorted lists.
		- multi_exon_trans is a dict ordered by descending number of internal coordinates.
		- exons and single_exon_trans are InterLap objects when non-empty, and
		  stay empty sets otherwise.
		"""
		chrand_ref_exon, chrand_ref_junction, chrand_ref_single_exon_trans, chrand_ref_mutple_exon_trans, chrand_left_sj_set, chrand_right_sj_set, chrand_tss_dict = self.annotation_extraction()

		chrand_tss_dict = sorted(chrand_tss_dict)
		# longest splice structures first
		chrand_ref_mutple_exon_trans = dict(sorted(chrand_ref_mutple_exon_trans.items(), key=lambda d: len(d[0]), reverse=True))

		chrand_left_sj_set = sorted(chrand_left_sj_set)
		chrand_right_sj_set = sorted(chrand_right_sj_set)

		# convert interval sets to InterLap for overlap queries
		if chrand_ref_single_exon_trans:
			intervals = list(chrand_ref_single_exon_trans)
			chrand_ref_single_exon_trans = InterLap()
			chrand_ref_single_exon_trans.update(intervals)

		if chrand_ref_exon:
			intervals = list(chrand_ref_exon)
			chrand_ref_exon = InterLap()
			chrand_ref_exon.update(intervals)

		return self.chrom, self.strand, chrand_ref_exon, chrand_ref_junction, chrand_ref_single_exon_trans, chrand_ref_mutple_exon_trans, chrand_left_sj_set, chrand_right_sj_set, chrand_tss_dict

class RefProcessWrapper[source]

RefProcessWrapper(ref_gtf, thread)

class RefProcessWrapper:
	"""Extract reference-annotation features for every (chromosome, strand) in parallel.

	The reference GTF is parsed once with gtf2splicing. Each (chrom, strand)
	group is processed by RefAnnotationExtraction.annotation_sorting through
	joblib, and result_collection merges the per-group results.
	"""
	def __init__(self, ref_gtf, thread):
		self.ref_gtf = ref_gtf  # path to the reference GTF file
		self.thread = thread  # passed to joblib.Parallel as n_jobs

	def process(self):
		"""Run annotation_sorting for every (chrom, strand) group.

		Returns the list of annotation_sorting result tuples, one per group.
		"""
		ref_trans_structure_dict = gtf2splicing(self.ref_gtf)
		preprocess_lst = [
			RefAnnotationExtraction(chrom, strand, chrand_ref_trans_structure_dict)
			for (chrom, strand), chrand_ref_trans_structure_dict in ref_trans_structure_dict.items()
		]
		with Parallel(n_jobs = self.thread) as parallel:
			# a bound method is picklable with the standard pickle protocol, unlike a lambda
			results = parallel(delayed(job.annotation_sorting)() for job in preprocess_lst)

		return results

	def result_collection(self):
		"""Merge the per-(chrom, strand) results into mappings keyed by (chrom, strand).

		Returns (ref_exon, ref_junction, ref_single_exon_trans,
		ref_mutple_exon_trans, left_sj_set, right_sj_set, tss_dict).
		- Empty per-group collections are omitted, except tss_dict, which gets an
		  entry for every group.
		- The default factories match the stored value types (InterLap for
		  intervals, set for junctions, list for sorted sites), so looking up an
		  absent (chrom, strand) yields an empty object of the right kind instead
		  of an empty dict.
		"""
		results = self.process()
		ref_exon = defaultdict(InterLap) # exon (start, end) intervals per chrom/strand
		ref_junction = defaultdict(set) # splicing junctions (exon end, next exon start) per chrom/strand
		ref_mutple_exon_trans = Vividict() # per chrom/strand: {internal splice coordinates: [TSS, TES]}
		ref_single_exon_trans = defaultdict(InterLap) # single-exon transcript (start, end) intervals
		left_sj_set = defaultdict(list) # sorted junction-starting exon ends
		right_sj_set = defaultdict(list) # sorted junction-ending exon starts
		tss_dict = defaultdict(list) # sorted transcript start sites

		for result in results:
			chrom, strand, chrand_ref_exon, chrand_ref_junction, chrand_ref_single_exon_trans, chrand_ref_mutple_exon_trans, chrand_left_sj_set, chrand_right_sj_set, chrand_tss_dict = result
			chrand = (chrom, strand)

			tss_dict[chrand] = chrand_tss_dict
			if chrand_ref_exon:
				ref_exon[chrand] = chrand_ref_exon
			if chrand_ref_junction:
				ref_junction[chrand] = chrand_ref_junction
			if chrand_ref_mutple_exon_trans:
				ref_mutple_exon_trans[chrand] = chrand_ref_mutple_exon_trans
			if chrand_ref_single_exon_trans:
				ref_single_exon_trans[chrand] = chrand_ref_single_exon_trans
			if chrand_left_sj_set:
				left_sj_set[chrand] = chrand_left_sj_set
			if chrand_right_sj_set:
				right_sj_set[chrand] = chrand_right_sj_set

		return ref_exon, ref_junction, ref_single_exon_trans, ref_mutple_exon_trans, left_sj_set, right_sj_set, tss_dict

short_reads_sj_import[source]

short_reads_sj_import(sj_tab, left_sj_set, right_sj_set)

Import the splice junctions detected from short-read data (STAR `SJ.out.tab`) and merge them into the left/right splice-site collections.

def short_reads_sj_import(sj_tab, left_sj_set, right_sj_set):
	"""Import splice junctions detected from short-read data (STAR ``SJ.out.tab``).

	For each junction, ``intron_start - 1`` (last base of the upstream exon) is
	added to left_sj_set and ``intron_end + 1`` (first base of the downstream
	exon) to right_sj_set, under the key (chrom, strand).

	Parameters
	----------
	sj_tab : iterable of str
		Paths to tab-separated junction files. Column 1 is the chromosome,
		columns 2/3 the first/last intron base, and column 4 the strand
		('1' or '+' forward, '2' or '-' reverse). Junctions with any other strand
		value (STAR uses '0' for undefined) are skipped. Blank lines are ignored.
	left_sj_set, right_sj_set : dict
		Mapping (chrom, strand) -> known splice sites. Values may be sets or
		lists (e.g. the sorted lists produced by annotation_sorting).

	Returns
	-------
	(left_sj_set, right_sj_set)
		The same mapping objects, updated in place; every value becomes a
		sorted list of unique positions.
	"""
	# accumulate in sets: incoming values are often sorted lists (no .add), and
	# absent keys must not depend on the caller's defaultdict factory
	left_pool = defaultdict(set, {key: set(sites) for key, sites in left_sj_set.items()})
	right_pool = defaultdict(set, {key: set(sites) for key, sites in right_sj_set.items()})

	for file in sj_tab:
		with open(file) as f:
			for line in f:
				if not line.strip():
					continue
				fields = line.rstrip('\r\n').split("\t")
				if fields[3] in ('+', '1'):
					strand = '+'
				elif fields[3] in ('-', '2'):
					strand = '-'
				else:
					# undefined strand: the junction cannot be attributed to either strand
					continue
				chrand = (fields[0], strand)
				left_pool[chrand].add(int(fields[1]) - 1)
				right_pool[chrand].add(int(fields[2]) + 1)

	for key, sites in left_pool.items():
		left_sj_set[key] = sorted(sites)
	for key, sites in right_pool.items():
		right_sj_set[key] = sorted(sites)

	return left_sj_set, right_sj_set

cage_tss_import[source]

cage_tss_import(cage_tss, tss_dict)

Import transcription start sites (TSS) from CAGE peak files and merge them into `tss_dict`.

def cage_tss_import(cage_tss, tss_dict):
	"""Import CAGE-derived transcription start sites (TSS) and merge them into tss_dict.

	Each peak contributes the midpoint of its interval as a TSS.

	Parameters
	----------
	cage_tss : iterable of str
		Paths to tab-separated files whose columns 1-4 are chromosome, start,
		end and strand. Blank lines are ignored.
		NOTE(review): coordinates are assumed BED-style (0-based start,
		exclusive end), in line with the ``+ 1`` of the earlier formula; confirm
		against the CAGE input actually used.
	tss_dict : dict
		Mapping (chrom, strand) -> known TSS positions (set or sorted list).

	Returns
	-------
	tss_dict
		The same mapping, updated in place; every value becomes a sorted list
		of unique positions.
	"""
	# accumulate in sets: incoming values are often sorted lists (no .add), and
	# absent keys must not depend on the caller's defaultdict factory
	pool = defaultdict(set, {key: set(sites) for key, sites in tss_dict.items()})

	for file in cage_tss:
		with open(file) as f:
			for line in f:
				if not line.strip():
					continue
				fields = line.rstrip('\r\n').split("\t")
				chrom, strand = fields[0], fields[3]
				lo, hi = sorted((int(fields[1]), int(fields[2])))
				# centre of the 1-based closed interval [lo + 1, hi]; the former
				# abs(end - start) / 2 + 1 was the half-width, not a genomic position
				pool[(chrom, strand)].add((lo + 1 + hi) // 2)

	for key, sites in pool.items():
		tss_dict[key] = sorted(sites)

	return tss_dict