from collections import defaultdict
from interlap import InterLap
from joblib import Parallel, delayed
from LAFITE.utils import Vividict
def gtf_line_split(entry):
    """Break one tab-delimited GTF record into its nine columns.

    Trailing whitespace is removed before the line is split on tabs.

    Args:
        entry: a single GTF line as text.

    Returns:
        The tuple ``(chrom, source, feature, start, end, score, strand,
        frame, attribute)`` where ``start`` and ``end`` are ``int`` and all
        other columns are the raw strings.

    Raises:
        ValueError: if the line does not contain exactly nine columns or the
            coordinates are not integers.
    """
    fields = entry.rstrip().split('\t')
    # The unpack enforces the nine-column layout before any conversion runs.
    chrom, source, feature, start, end, score, strand, frame, attribute = fields
    return chrom, source, feature, int(start), int(end), score, strand, frame, attribute
def gtf2splicing(gtf, transcript_ref='transcript_id', gene_ref='gene_id', keepAttribute=False, mergeMode=False):
    """Collect exon coordinates from a GTF file at gene and isoform level.

    NOTE(review): another function with this same name is defined later in
    this module and rebinds the name at import time - confirm which of the
    two definitions is meant to be live.

    Args:
        gtf: path of the GTF file to read.
        transcript_ref: attribute key that holds the transcript identifier.
        gene_ref: attribute key that holds the gene identifier.
        keepAttribute: when true, a ``transcript``/``mRNA`` record stores its
            parsed attribute dict as the first element of the isoform list.
        mergeMode: when true, only attributes whose text contains
            ``gene_ref`` or ``transcript_ref`` are kept before parsing.

    Returns:
        ``(gene_structure_dict, isoform_structure_dict)``:

        * ``gene_structure_dict`` maps ``(chrom, strand, gene_id)`` to the set
          of all exon start/end coordinates of that gene;
        * ``isoform_structure_dict`` maps
          ``(chrom, strand, transcript_id, gene_id)`` to a flat list
          ``[start, end, start, end, ...]`` in file order.

    Raises:
        ValueError: if the attribute column of a record cannot be parsed.
        KeyError: if a used record lacks ``gene_ref`` or ``transcript_ref``.
    """
    isoform_structure_dict = defaultdict(dict)
    gene_structure_dict = defaultdict(dict)
    with open(gtf) as f:
        for line in f:
            # Comment lines carry no record; blank lines would otherwise
            # crash in the nine-column unpack with an unhelpful message.
            if line.startswith('#') or not line.strip():
                continue
            chrom, _source, feature, start, end, _score, strand, _frame, attributes = gtf_line_split(line)
            try:
                attributes = [a.strip() for a in attributes.strip(';').split('; ') if a]
                if mergeMode:
                    wanted = (gene_ref, transcript_ref)
                    attributes = [a for a in attributes if any(ref in a for ref in wanted)]
                # Each attribute is 'key value'; a piece without a value makes
                # dict() raise, which is reported as a format error below.
                attributes = dict(a.replace('"', '').split(' ', 1) for a in attributes)
            except Exception as err:
                # Chain the cause so the offending parse error stays visible.
                raise ValueError("Fatal: please check input GTF format!\n") from err

            if keepAttribute and feature in ('transcript', 'mRNA'):
                gene_id = attributes[gene_ref]
                transcript_id = attributes[transcript_ref]
                isoform_structure_dict[(chrom, strand, transcript_id, gene_id)] = [attributes]

            if feature == "exon":
                gene_id = attributes[gene_ref]
                transcript_id = attributes[transcript_ref]
                # setdefault avoids triggering the defaultdict factory, whose
                # dict default would be the wrong container type here.
                gene_structure_dict.setdefault((chrom, strand, gene_id), set()).update((start, end))
                isoform_structure_dict.setdefault(
                    (chrom, strand, transcript_id, gene_id), []
                ).extend((start, end))
    return gene_structure_dict, isoform_structure_dict
def gtf2splicing(gtf, transcript_ref='transcript_id', keepAttribute=False):
    """Collect per-isoform exon coordinates from a GTF file.

    Args:
        gtf: path of the GTF file to read.
        transcript_ref: attribute key that holds the transcript identifier.
        keepAttribute: when true, a ``transcript``/``mRNA`` record stores its
            parsed attribute dict as the first element of the isoform list.

    Returns:
        A ``Vividict`` keyed by ``(chrom, strand)`` whose values map each
        transcript id to a flat list ``[start, end, start, end, ...]`` of
        exon coordinates in file order.

    Raises:
        ValueError: if the attribute column of a record cannot be parsed.
        KeyError: if a used record lacks ``transcript_ref``.
    """
    # Vividict comes from LAFITE.utils; presumably an auto-vivifying dict,
    # so indexing a new (chrom, strand) key creates the inner mapping.
    isoform_structure_dict = Vividict()
    with open(gtf) as f:
        for line in f:
            # Comment lines carry no record; blank lines would otherwise
            # crash in the nine-column unpack with an unhelpful message.
            if line.startswith('#') or not line.strip():
                continue
            chrom, _source, feature, start, end, _score, strand, _frame, attributes = gtf_line_split(line)
            try:
                # Attributes are separated on the closing quote + '; ', then
                # every remaining quote character is dropped.
                attributes = [
                    a.strip().replace('"', '')
                    for a in attributes.strip(';').split('"; ')
                    if a
                ]
                attributes = dict(a.split(' ', 1) for a in attributes)
            except Exception as err:
                # Chain the cause so the offending parse error stays visible.
                raise ValueError("Fatal: please check input GTF format\n") from err

            if keepAttribute and feature in ('transcript', 'mRNA'):
                transcript_id = attributes[transcript_ref]
                isoform_structure_dict[(chrom, strand)][transcript_id] = [attributes]

            if feature == "exon":
                transcript_id = attributes[transcript_ref]
                per_strand = isoform_structure_dict[(chrom, strand)]
                if transcript_id in per_strand:
                    per_strand[transcript_id].extend([start, end])
                else:
                    per_strand[transcript_id] = [start, end]
    return isoform_structure_dict
class RefAnnotationExtraction:
    """Extract splicing and exon information from a reference annotation.

    NOTE(review): a second class with this same name is defined later in this
    module and rebinds the name. This version also unpacks two values from
    ``gtf2splicing`` and expects four-field isoform keys, which matches the
    first ``gtf2splicing`` definition only - confirm whether this class is
    still in use.
    """

    def __init__(self, ref_gtf):
        # Path of the reference GTF handed to gtf2splicing.
        self.ref_gtf = ref_gtf

    def annotation_extraction(self):
        """Build per-(chrom, strand) exon, junction and TSS collections.

        Returns:
            ``(ref_exon, ref_junction, ref_single_exon_trans,
            ref_mutple_exon_trans, left_sj_set, right_sj_set, tss_dict)``.
            All are keyed by ``(chrom, strand)``; the multi-exon container
            maps the tuple of internal splice sites to ``[start, end]``
            (``[end, start]`` on the minus strand).
        """
        exons_by_strand = defaultdict(set)          # exon (start, end) pairs
        junctions_by_strand = defaultdict(set)      # (left_sj, right_sj) pairs
        multi_exon_by_strand = Vividict()           # splice chain -> transcript ends
        single_exon_by_strand = defaultdict(set)    # (start, end) of single-exon transcripts
        left_sj_set = defaultdict(set)
        right_sj_set = defaultdict(set)
        tss_dict = defaultdict(set)                 # transcript start sites

        gene_structure_dict, ref_trans_structure_dict = gtf2splicing(self.ref_gtf)
        for isoform, coords in ref_trans_structure_dict.items():
            chrom, strand, transcript_id, gene_id = isoform
            key = (chrom, strand)
            coords.sort()
            inner_sites = coords[1:-1]
            first, last = coords[0], coords[-1]

            # The TSS is the leftmost coordinate on '+', the rightmost otherwise.
            if strand == '+':
                tss_dict[key].add(first)
            else:
                tss_dict[key].add(last)

            if not inner_sites:
                single_exon_by_strand[key].add((first, last))
                continue

            # Pair consecutive coordinates into exons; an unpaired trailing
            # value is dropped, exactly as zipping one iterator with itself.
            exons = list(zip(coords[0::2], coords[1::2]))
            for exon in exons:
                exons_by_strand[key].add(exon)
            # A junction joins the end of one exon to the start of the next.
            for upstream, downstream in zip(exons, exons[1:]):
                left_sj, right_sj = upstream[1], downstream[0]
                left_sj_set[key].add(left_sj)
                right_sj_set[key].add(right_sj)
                junctions_by_strand[key].add((left_sj, right_sj))

            # Transcript ends are stored in transcription order.
            se_site = [last, first] if strand == "-" else [first, last]
            multi_exon_by_strand[key][tuple(inner_sites)] = se_site

        return (
            exons_by_strand,
            junctions_by_strand,
            single_exon_by_strand,
            multi_exon_by_strand,
            left_sj_set,
            right_sj_set,
            tss_dict,
        )

    def annotation_sorting(self):
        """Sort the extracted collections and wrap intervals in InterLap.

        TSS and splice-site sets become sorted lists, multi-exon transcripts
        are ordered by decreasing number of splice sites, and the exon and
        single-exon sets are loaded into ``InterLap`` objects.
        """
        (
            ref_exon,
            ref_junction,
            ref_single_exon_trans,
            ref_mutple_exon_trans,
            left_sj_set,
            right_sj_set,
            tss_dict,
        ) = self.annotation_extraction()

        for key in tss_dict:
            tss_dict[key] = sorted(tss_dict[key])

        for key in ref_mutple_exon_trans:
            chains = ref_mutple_exon_trans[key]
            # Longest splice chains first; ties keep their insertion order.
            ref_mutple_exon_trans[key] = {
                chain: chains[chain] for chain in sorted(chains, key=len, reverse=True)
            }

        for key in left_sj_set:
            left_sj_set[key] = sorted(left_sj_set[key])
        for key in right_sj_set:
            right_sj_set[key] = sorted(right_sj_set[key])

        # Convert the interval sets to the InterLap data format.
        for interval_sets in (ref_single_exon_trans, ref_exon):
            for key in interval_sets:
                intervals = list(interval_sets[key])
                interval_sets[key] = InterLap()
                interval_sets[key].update(intervals)

        return ref_exon, ref_junction, ref_single_exon_trans, ref_mutple_exon_trans, left_sj_set, right_sj_set, tss_dict
class RefAnnotationExtraction:
    """Extract splicing and exon information for one chromosome/strand pair.

    The instance holds the chromosome name, the strand and a mapping from
    isoform id to its flat list of exon coordinates.
    """

    def __init__(self, chrom, strand, chrand_ref_trans_structure_dict):
        self.chrom = chrom
        self.strand = strand
        # isoform id -> [start, end, start, end, ...]
        self.chrand_ref_trans_structure_dict = chrand_ref_trans_structure_dict

    def annotation_extraction(self):
        """Derive exons, junctions, splice sites and TSSs from the isoforms.

        Each coordinate list is sorted in place before it is used.

        Returns:
            ``(exons, junctions, single_exon_transcripts,
            multi_exon_transcripts, left_splice_sites, right_splice_sites,
            tss_sites)``. All are sets except the multi-exon container, which
            maps the tuple of internal splice sites to ``[start, end]``
            (``[end, start]`` on the minus strand).
        """
        exons = set()                   # exon (start, end) pairs
        junctions = set()               # (left_sj, right_sj) pairs
        multi_exon = defaultdict(dict)  # splice chain -> transcript ends
        single_exon = set()             # (start, end) of single-exon transcripts
        left_sites = set()
        right_sites = set()
        tss_sites = set()               # transcript start sites
        on_plus = self.strand == '+'

        for _isoform, coords in self.chrand_ref_trans_structure_dict.items():
            coords.sort()
            inner_sites = coords[1:-1]
            first, last = coords[0], coords[-1]

            # The TSS is the leftmost coordinate on '+', the rightmost otherwise.
            tss_sites.add(first if on_plus else last)

            if not inner_sites:
                single_exon.add((first, last))
                continue

            # Pair consecutive coordinates into exons; an unpaired trailing
            # value is dropped, exactly as zipping one iterator with itself.
            blocks = list(zip(coords[0::2], coords[1::2]))
            exons.update(blocks)
            # A junction joins the end of one exon to the start of the next.
            for upstream, downstream in zip(blocks, blocks[1:]):
                left_sj, right_sj = upstream[1], downstream[0]
                left_sites.add(left_sj)
                right_sites.add(right_sj)
                junctions.add((left_sj, right_sj))

            # Transcript ends are stored in transcription order.
            if self.strand == '-':
                multi_exon[tuple(inner_sites)] = [last, first]
            else:
                multi_exon[tuple(inner_sites)] = [first, last]

        return exons, junctions, single_exon, multi_exon, left_sites, right_sites, tss_sites

    def annotation_sorting(self):
        """Sort the extracted collections and wrap intervals in InterLap.

        Returns:
            ``(chrom, strand, exons, junctions, single_exon_transcripts,
            multi_exon_transcripts, left_splice_sites, right_splice_sites,
            tss_sites)``. Splice-site and TSS collections are sorted lists,
            multi-exon transcripts are a plain dict ordered by decreasing
            number of splice sites, and non-empty exon / single-exon sets
            are replaced by ``InterLap`` objects (empty ones stay sets).
        """
        (
            exons,
            junctions,
            single_exon,
            multi_exon,
            left_sites,
            right_sites,
            tss_sites,
        ) = self.annotation_extraction()

        tss_sites = sorted(tss_sites)
        left_sites = sorted(left_sites)
        right_sites = sorted(right_sites)

        # Longest splice chains first; ties keep their insertion order.
        multi_exon = {
            chain: multi_exon[chain] for chain in sorted(multi_exon, key=len, reverse=True)
        }

        # Convert the interval sets to the InterLap data format.
        if single_exon:
            single_intervals = InterLap()
            single_intervals.update(list(single_exon))
            single_exon = single_intervals
        if exons:
            exon_intervals = InterLap()
            exon_intervals.update(list(exons))
            exons = exon_intervals

        return (
            self.chrom,
            self.strand,
            exons,
            junctions,
            single_exon,
            multi_exon,
            left_sites,
            right_sites,
            tss_sites,
        )
class RefProcessWrapper:
    """Run the per-chromosome/strand reference extraction and merge the results."""

    def __init__(self, ref_gtf, thread):
        self.ref_gtf = ref_gtf  # path of the reference GTF
        self.thread = thread    # passed to joblib as n_jobs

    def process(self):
        """Parse the GTF and sort every (chrom, strand) group via joblib.

        Returns:
            A list with one ``annotation_sorting()`` result per
            ``(chrom, strand)`` key returned by ``gtf2splicing``.
        """
        structures = gtf2splicing(self.ref_gtf)
        jobs = [
            RefAnnotationExtraction(chrom, strand, isoforms)
            for (chrom, strand), isoforms in structures.items()
        ]
        with Parallel(n_jobs=self.thread) as parallel:
            results = parallel(
                delayed(lambda x: x.annotation_sorting())(job) for job in jobs
            )
        return results

    def result_collection(self):
        """Merge the per-group results into containers keyed by (chrom, strand).

        The TSS entry is always stored; every other entry is stored only when
        the group produced a non-empty value.

        Returns:
            ``(ref_exon, ref_junction, ref_single_exon_trans,
            ref_mutple_exon_trans, left_sj_set, right_sj_set, tss_dict)``.
        """
        results = self.process()

        ref_exon = defaultdict(dict)               # exon intervals
        ref_junction = defaultdict(dict)           # splice junctions
        ref_mutple_exon_trans = Vividict()         # splice chain -> transcript ends
        ref_single_exon_trans = defaultdict(dict)  # single-exon transcript intervals
        left_sj_set = defaultdict(dict)
        right_sj_set = defaultdict(dict)
        tss_dict = defaultdict(dict)               # transcript start sites

        for result in results:
            (
                chrom,
                strand,
                group_exon,
                group_junction,
                group_single,
                group_multi,
                group_left,
                group_right,
                group_tss,
            ) = result
            key = (chrom, strand)
            tss_dict[key] = group_tss
            optional_entries = (
                (ref_exon, group_exon),
                (ref_junction, group_junction),
                (ref_mutple_exon_trans, group_multi),
                (ref_single_exon_trans, group_single),
                (left_sj_set, group_left),
                (right_sj_set, group_right),
            )
            for container, value in optional_entries:
                if value:
                    container[key] = value

        return ref_exon, ref_junction, ref_single_exon_trans, ref_mutple_exon_trans, left_sj_set, right_sj_set, tss_dict
def _site_bucket(site_dict, key):
    """Return the set stored under *key*, converting or creating it if needed."""
    bucket = site_dict.get(key)
    if not isinstance(bucket, set):
        # Existing values may be sorted lists, and a defaultdict(dict) would
        # hand out an empty dict; neither supports .add().
        bucket = set(bucket) if bucket else set()
        site_dict[key] = bucket
    return bucket


def short_reads_sj_import(sj_tab, left_sj_set, right_sj_set):
    """Add splice junctions detected from short reads (STAR SJ tab files).

    Each line is tab-separated: column 1 is the chromosome, columns 2 and 3
    are the junction coordinates and column 4 is the strand. The left site is
    stored as ``column2 - 1`` and the right site as ``column3 + 1``.

    Args:
        sj_tab: iterable of paths to junction tables.
        left_sj_set: mapping ``(chrom, strand) -> left splice sites``; values
            may be sets or sorted lists. Updated in place.
        right_sj_set: same layout for right splice sites. Updated in place.

    Returns:
        ``(left_sj_set, right_sj_set)`` with every value as a sorted list.
    """
    for path in sj_tab:
        with open(path) as handle:
            for record in handle:
                if not record.strip():
                    continue  # tolerate blank lines
                fields = record.strip('\n').split("\t")
                chrom = fields[0]
                # NOTE(review): any strand code other than '+' or '1' is
                # treated as '-', which includes an undefined-strand code
                # such as '0' - confirm that this is intended.
                strand = '+' if fields[3] in ('+', '1') else '-'
                key = (chrom, strand)
                _site_bucket(left_sj_set, key).add(int(fields[1]) - 1)
                _site_bucket(right_sj_set, key).add(int(fields[2]) + 1)

    for key in left_sj_set:
        left_sj_set[key] = sorted(left_sj_set[key])
    for key in right_sj_set:
        right_sj_set[key] = sorted(right_sj_set[key])
    return left_sj_set, right_sj_set
def cage_tss_import(cage_tss, tss_dict):
    """Add transcript start sites taken from CAGE peak files.

    Each line is tab-separated: column 1 is the chromosome, columns 2 and 3
    are the peak boundaries and column 4 is the strand. The peak centre is
    stored as a genomic coordinate: the smaller boundary plus
    ``round(width / 2 + 1)``.

    Args:
        cage_tss: iterable of paths to peak files.
        tss_dict: mapping ``(chrom, strand) -> TSS positions``; values may be
            sets or sorted lists. Updated in place.

    Returns:
        ``tss_dict`` with every value as a sorted list.
    """
    for path in cage_tss:
        with open(path) as handle:
            for record in handle:
                if not record.strip():
                    continue  # tolerate blank lines
                fields = record.strip('\n').split("\t")
                chrom, strand = fields[0], fields[3]
                peak_start, peak_end = int(fields[1]), int(fields[2])
                # The offset alone is only half the peak width; anchor it on
                # the peak's lower boundary to get a genomic position.
                offset = round(abs(peak_end - peak_start) / 2 + 1)
                tss_center = min(peak_start, peak_end) + offset

                key = (chrom, strand)
                sites = tss_dict.get(key)
                if not isinstance(sites, set):
                    # Existing values may be sorted lists, and a
                    # defaultdict(dict) would hand out an empty dict;
                    # neither supports .add().
                    sites = set(sites) if sites else set()
                    tss_dict[key] = sites
                sites.add(tss_center)

    for key in tss_dict:
        tss_dict[key] = sorted(tss_dict[key])
    return tss_dict