from __future__ import annotations
import logging
import pathlib
from collections import Counter, namedtuple
from relion._parser.jobtype import JobType
# Module-level logger shared by everything in this file; the logger name is a
# fixed string literal rather than being derived from ``__name__``.
logger = logging.getLogger("relion._parser.class3D")
# Immutable record describing one class of a 3D classification job.
# The field order below is the positional order of the record.
Class3DParticleClass = namedtuple(
    "Class3DParticleClass",
    (
        "particle_sum reference_image class_distribution "
        "accuracy_rotations accuracy_translations_angst "
        "estimated_resolution overall_fourier_completeness "
        "initial_model_num_particles job"
    ),
)

# Docstrings are attached after creation, for the record type first and then
# for each field in positional order.
Class3DParticleClass.__doc__ = "3D Classification stage."
Class3DParticleClass.particle_sum.__doc__ = (
    "Sum of all particles in the class. "
    "Gives a tuple with the class number first, then the particle sum."
)
Class3DParticleClass.reference_image.__doc__ = "Reference image."
Class3DParticleClass.class_distribution.__doc__ = (
    "Class Distribution. "
    "Proportional to the number of particles per class."
)
Class3DParticleClass.accuracy_rotations.__doc__ = "Accuracy rotations."
Class3DParticleClass.accuracy_translations_angst.__doc__ = "Accuracy translations angst."
Class3DParticleClass.estimated_resolution.__doc__ = "Estimated resolution."
Class3DParticleClass.overall_fourier_completeness.__doc__ = "Overall Fourier completeness."
Class3DParticleClass.initial_model_num_particles.__doc__ = (
    "The number of particles used "
    "to generate the initial model."
)
Class3DParticleClass.job.__doc__ = "Job number of the Class3D job."
class Class3D(JobType):
    """Parser for the results of 3D classification jobs.

    ``_basepath`` and the star-file helpers used below
    (``_read_star_file``, ``_read_star_file_from_proj_dir``,
    ``_find_table_from_column_name`` and ``parse_star_file``) are not defined
    in this class; presumably they are provided by ``JobType``.
    """

    def __eq__(self, other):
        """Two parsers are equal when they point at the same base path."""
        if isinstance(other, Class3D):  # check this
            return self._basepath == other._basepath
        return False

    def __hash__(self):
        """Hash on the base path so that equal parsers hash equally."""
        return hash(("relion._parser.Class3D", self._basepath))

    def __repr__(self):
        return f"Class3D({repr(str(self._basepath))})"

    def __str__(self):
        return f"<Class3D parser at {self._basepath}>"

    @property
    def job_number(self):
        """Sorted names of every entry directly below the base path.

        Despite the singular name this is a list of names, not a single
        job number.
        """
        return sorted(entry.name for entry in self._basepath.iterdir())

    def _load_job_directory(self, jobdir):
        """Collect the per-class results of the last iteration of one job.

        Args:
            jobdir: Job directory, relative to the base path.

        Returns:
            A list of ``Class3DParticleClass`` records, one per reference
            image found in the model file. An empty list is returned when
            the result files, the required columns or ``job.star`` cannot be
            read. If the columns disagree in length the records collected
            before the mismatch are returned.
        """
        try:
            dfile, mfile = self._final_data_and_model(jobdir)
        except ValueError as e:
            logger.debug(
                f"The exception {e} was caught while trying to get data and model files. Returning an empty list",
                exc_info=True,
            )
            return []

        try:
            sdfile = self._read_star_file(jobdir, dfile)
            smfile = self._read_star_file(jobdir, mfile)
        except (FileNotFoundError, OSError, RuntimeError, ValueError):
            logger.debug(
                "gemmi could not open file while trying to get data and model files. Returning an empty list",
                exc_info=True,
            )
            return []

        model_table = self._find_table_from_column_name(
            "_rlnClassDistribution", smfile
        )
        if model_table is None:
            logger.debug(f"_rlnClassDistribution not found in file {mfile}")
            return []

        class_distribution = self.parse_star_file(
            "_rlnClassDistribution", smfile, model_table
        )
        accuracy_rotations = self.parse_star_file(
            "_rlnAccuracyRotations", smfile, model_table
        )
        accuracy_translations_angst = self.parse_star_file(
            "_rlnAccuracyTranslationsAngst", smfile, model_table
        )
        estimated_resolution = self.parse_star_file(
            "_rlnEstimatedResolution", smfile, model_table
        )
        overall_fourier_completeness = self.parse_star_file(
            "_rlnOverallFourierCompleteness", smfile, model_table
        )
        reference_image = self.parse_star_file(
            "_rlnReferenceImage", smfile, model_table
        )

        # The data file is a different document from the model file, so the
        # table holding the class numbers has to be located in the data file
        # itself instead of reusing the table found in the model file.
        data_table = self._find_table_from_column_name("_rlnClassNumber", sdfile)
        if data_table is None:
            logger.debug(f"_rlnClassNumber not found in file {dfile}")
            return []
        class_numbers = self.parse_star_file("_rlnClassNumber", sdfile, data_table)

        particle_sum = self._sum_all_particles(class_numbers)
        int_particle_sum = [(int(name), value) for name, value in particle_sum.items()]
        # something probably went wrong with file reading if this is the case
        # return empty list and hope to recover later
        if not int_particle_sum:
            return []

        try:
            checked_particle_list = self._class_checker(
                sorted(int_particle_sum), len(reference_image)
            )
        except IndexError:
            logger.debug(
                f"IndexError encountered in _class_checker for {jobdir}", exc_info=True
            )
            return []

        try:
            init_model_num_particles = self._get_init_model_num_particles(
                jobdir, "job.star"
            )
        except (RuntimeError, FileNotFoundError, OSError, ValueError):
            logger.debug(
                f"Encountered error trying to read {jobdir}/job.star", exc_info=True
            )
            return []

        if len(reference_image) != len(checked_particle_list):
            # Only reported here; the loop below stops at the first missing
            # entry and the records gathered so far are still returned.
            logger.debug(
                f"Number of reference images did not match number of classes for {jobdir}"
            )

        particle_class_list = []
        try:
            for j, image in enumerate(reference_image):
                particle_class_list.append(
                    Class3DParticleClass(
                        checked_particle_list[j],
                        str(self._basepath.parent / image),
                        float(class_distribution[j]),
                        accuracy_rotations[j],
                        accuracy_translations_angst[j],
                        estimated_resolution[j],
                        overall_fourier_completeness[j],
                        init_model_num_particles,
                        jobdir,
                    )
                )
        except IndexError:
            logger.debug(
                "An IndexError was encountered while collecting 3D classification data: there was possibly a mismatch between data from different files"
            )
        return particle_class_list

    def _get_init_model_num_particles(self, jobdir, param_file_name):
        """Count the particles of the class the initial model was built from.

        The path of the initial model is taken from the ``fn_ref`` option in
        the job parameter file. The class number is extracted from the
        ``class<N>`` part of that file name, the matching ``*_data.star``
        file name is derived from it, and the occurrences of that class
        number in its ``_rlnClassNumber`` column are counted.

        Args:
            jobdir: Job directory, relative to the base path.
            param_file_name: Name of the parameter star file in ``jobdir``.

        Returns:
            The number of matching particles, or ``None`` when the model
            file name contains no ``class`` component.

        Raises:
            ValueError: If ``fn_ref`` is not among the job options or the
                class component is not an integer.
        """
        paramfile = self._read_star_file(jobdir, param_file_name)
        info_table = self._find_table_from_column_name(
            "_rlnJobOptionVariable", paramfile
        )
        variables = [
            p.strip("'")
            for p in self.parse_star_file(
                "_rlnJobOptionVariable", paramfile, info_table
            )
        ]
        ini_model_index = variables.index("fn_ref")
        option_values = self.parse_star_file(
            "_rlnJobOptionValue", paramfile, info_table
        )
        ini_model_path = pathlib.Path(option_values[ini_model_index].strip("'"))

        # this string manipulation is bad, I'm sorry
        name_parts = [p.replace("'", "") for p in ini_model_path.name.split("_")]
        for position, part in enumerate(name_parts):
            if "class" in part:
                model_file_class = part.split(".")[0].replace("class", "")
                remainder = name_parts[position + 1 :]
                if remainder:
                    # drop suffix
                    remainder[-1] = "".join(remainder[-1].split(".")[:-1])
                break
        else:
            # No "class" component in the file name: nothing to look up.
            return None

        # Swap "class<N>[_<extra>...]" for "data" and the "mrc" extension for
        # "star" to obtain the name of the particle data file.
        class_token = (
            "class" + model_file_class + "".join("_" + r for r in remainder if r)
        )
        model_info_name = ini_model_path.name.replace(class_token, "data").replace(
            "mrc", "star"
        )
        model_info_file = self._read_star_file_from_proj_dir(
            ini_model_path.parent, model_info_name
        )
        info_table = self._find_table_from_column_name(
            "_rlnClassNumber", model_info_file
        )
        # this str(int()) thing strips the 0s off of model_file_class
        # should be faster than converting everything in the column to int
        # there's probably a better way
        class_column = self.parse_star_file(
            "_rlnClassNumber", model_info_file, info_table
        )
        return class_column.count(str(int(model_file_class)))

    def _final_data_and_model(self, job_path):
        """Return the data and model file names of the last iteration.

        The iteration number is read from the three characters that follow
        ``run_it`` in each ``run_it*.star`` file name.

        Args:
            job_path: Job directory, relative to the base path.

        Returns:
            A ``(data_file, model_file)`` tuple of file names.

        Raises:
            ValueError: If no iteration greater than zero is found, or if
                either of the two files does not exist.
        """
        job_dir = self._basepath / job_path
        number_list = [entry.stem[6:9] for entry in job_dir.glob("run_it*.star")]
        last_iteration_number = max(
            (int(n) for n in number_list if n.isnumeric()), default=0
        )
        if not last_iteration_number:
            raise ValueError(f"No result files found in {job_path}")
        data_file = f"run_it{last_iteration_number:03d}_data.star"
        model_file = f"run_it{last_iteration_number:03d}_model.star"
        for check_file in (job_dir / data_file, job_dir / model_file):
            if not check_file.exists():
                raise ValueError(f"File {check_file} missing from job directory")
        return data_file, model_file

    def _class_checker(self, tuple_list, length):
        """Make sure every class from 1 to ``length`` has a particle count.

        Args:
            tuple_list: ``(class_number, particle_count)`` tuples sorted by
                class number. The list is modified in place.
            length: Number of classes expected.

        Returns:
            The same list, with ``(class_number, 0)`` inserted for every
            class in ``1..length`` that had no entry.

        Raises:
            IndexError: If ``tuple_list`` is empty.
        """
        if not tuple_list:
            raise IndexError("no particle counts to check")
        for i in range(1, length + 1):
            # Compare against the class number only. Testing membership in
            # the whole tuple would also match a particle count that happens
            # to equal ``i`` and hide a missing class.
            if i > len(tuple_list) or tuple_list[i - 1][0] != i:
                tuple_list.insert(i - 1, (i, 0))
        return tuple_list

    def _count_all(self, list):
        """Return a ``Counter`` of the items in ``list``."""
        return Counter(list)

    def _sum_all_particles(self, list):
        """Return how often each class number occurs in ``list``."""
        return self._count_all(list)

    @staticmethod
    def db_unpack(particle_class):
        """Convert class records into plain dictionaries.

        Args:
            particle_class: Iterable of ``Class3DParticleClass`` records.

        Returns:
            A list with one dictionary per record. The
            ``initial_model_num_particles`` field is not part of the output.
        """
        return [
            {
                "type": "3D",
                "class_number": cl.particle_sum[0],
                "particles_per_class": cl.particle_sum[1],
                "rotation_accuracy": cl.accuracy_rotations,
                "translation_accuracy": cl.accuracy_translations_angst,
                "estimated_resolution": cl.estimated_resolution,
                "overall_fourier_completeness": cl.overall_fourier_completeness,
                "job_string": cl.job,
                "class_distribution": cl.class_distribution,
                "class_image_full_path": cl.reference_image,
            }
            for cl in particle_class
        ]