Commit 30c98dd5 authored by Giovanni La Mura's avatar Giovanni La Mura
Browse files

Enable OBJ model export with group-based color scaling

parent 9d1d7e4d
Loading
Loading
Loading
Loading
+43 −25
Original line number Diff line number Diff line
@@ -36,7 +36,7 @@ except ModuleNotFoundError as ex:
    print("WARNING: pyvista not found!")
    allow_3d = False

from pathlib import PurePath
from pathlib import Path
from sys import argv

## \brief Main execution code
@@ -69,7 +69,7 @@ def interpolate_constants(sconf):
    for i in range(sconf['configurations']):
        for j in range(sconf['nshl'][i]):
            file_idx = sconf['dielec_id'][i][j]
            dielec_path = PurePath(sconf['dielec_path'], sconf['dielec_file'][int(file_idx) - 1])
            dielec_path = Path(sconf['dielec_path'], sconf['dielec_file'][int(file_idx) - 1])
            file_name = str(dielec_path)
            dielec_file = open(file_name, 'r')
            wavelengths = []
@@ -149,7 +149,7 @@ def load_model(model_file):
            make_3d = False
        # Create the sconf dict
        sconf = {
            'out_file': PurePath(
            'out_file': Path(
                model['input_settings']['input_folder'],
                model['input_settings']['spheres_file']
            )
@@ -316,7 +316,7 @@ def load_model(model_file):
            print("ERROR: %s is not a recognized polarization state."%str_polar)
            return (None, None)
        gconf = {
            'out_file': PurePath(
            'out_file': Path(
                model['input_settings']['input_folder'],
                model['input_settings']['geometry_file']
            )
@@ -404,7 +404,7 @@ def match_grid(sconf):
            layers += 1
        for j in range(layers):
            file_idx = sconf['dielec_id'][i][j]
            dielec_path = PurePath(sconf['dielec_path'], sconf['dielec_file'][int(file_idx) - 1])
            dielec_path = Path(sconf['dielec_path'], sconf['dielec_file'][int(file_idx) - 1])
            file_name = str(dielec_path)
            dielec_file = open(file_name, 'r')
            wavelengths = []
@@ -784,6 +784,9 @@ def write_legacy_sconf(conf):
#  \param geometry: `dict` Geometry configuration dictionary (gets modified)
#  \param max_rad: `float` Maximum allowed radial extension of the aggregate
def write_obj(scatterer, geometry, max_rad):
    out_dir = scatterer['out_file'].absolute().parent
    out_model_path = Path(out_dir, "model.obj")
    out_material_path = Path(out_dir, "model.mtl")
    color_strings = [
        "1.0 1.0 1.0\n", # white
        "1.0 0.0 0.0\n", # red
@@ -793,9 +796,9 @@ def write_obj(scatterer, geometry, max_rad):
    color_names = [
        "white", "red", "blue", "green"
    ]
    mtl_file = open("model.mtl", "w")
    mtl_file = open(str(out_material_path), "w")
    for mi in range(len(color_strings)):
        mtl_line = "newmtl mtl{0:d}\n".format(mi)
        mtl_line = "newmtl "  + color_names[mi] + "\n"
        mtl_file.write(mtl_line)
        color_line = color_strings[mi]
        mtl_file.write("   Ka " + color_line)
@@ -808,28 +811,43 @@ def write_obj(scatterer, geometry, max_rad):
    pl = pv.Plotter()
    for si in range(scatterer['nsph']):
        sph_type_index = scatterer['vec_types'][si]
        color_by_name = color_names[sph_type_index]
        # color_index = 1 + (sph_type_index % (len(color_strings) - 1))
        # color_by_name = color_names[sph_type_index]
        radius = scatterer['ros'][sph_type_index - 1] / max_rad
        x = geometry['vec_sph_x'][si] / max_rad
        y = geometry['vec_sph_y'][si] / max_rad
        z = geometry['vec_sph_z'][si] / max_rad
        mesh = pv.Sphere(radius, (x, y, z))
        mesh.save("tmp_mesh.obj")
        pl.add_mesh(mesh) #, color=color_by_name)
        mesh_name = "sphere_{0:04d}.obj".format(si)
        in_obj_file = open("tmp_mesh.obj", "r")
        out_obj_file = open(mesh_name, "w")
        in_line = in_obj_file.readline()
        out_obj_file.write(in_line)
        out_obj_file.write("mtllib model.mtl\n")
        out_obj_file.write("usemtl mtl{0:d}\n".format(sph_type_index))
        while (in_line != ""):
            in_line = in_obj_file.readline()
            out_obj_file.write(in_line)
        in_obj_file.close()
        out_obj_file.close()
    pl.export_obj("model.obj")
    os.remove("tmp_mesh.obj")
        pl.add_mesh(mesh, color=None)
    pl.export_obj(str(Path(str(out_dir), "TMP_MODEL.obj")))
    tmp_model_file = open(str(Path(str(out_dir), "TMP_MODEL.obj")), "r")
    out_model_file = open(str(Path(str(out_dir), "model.obj")), "w")
    sph_index = 0
    sph_type_index = 0
    old_sph_type_index = 0
    str_line = tmp_model_file.readline()
    while (str_line != ""):
        if (str_line.startswith("mtllib")):
            str_line = "mtllib model.mtl\n"
        elif (str_line.startswith("g ")):
            sph_index += 1
            sph_type_index = scatterer['vec_types'][sph_index - 1]
            if (sph_type_index == old_sph_type_index):
                str_line = tmp_model_file.readline()
                str_line = tmp_model_file.readline()
            else:
                old_sph_type_index = sph_type_index
                color_index = sph_type_index % (len(color_names) - 1)
                str_line = "g grp{0:04d}\n".format(sph_type_index)
                out_model_file.write(str_line)
                str_line = tmp_model_file.readline()
                str_line = "usemtl {0:s}\n".format(color_names[color_index])
        out_model_file.write(str_line)
        str_line = tmp_model_file.readline()
    out_model_file.close()
    tmp_model_file.close()
    os.remove(str(Path(str(out_dir), "TMP_MODEL.obj")))
    os.remove(str(Path(str(out_dir), "TMP_MODEL.mtl")))

## \brief Exit code (0 for success)
exit_code = main()