# -*- coding: utf-8 -*-
from vtkplotlib._get_vtk import vtk
import numpy as np
from pathlib import Path
from vtkplotlib.plots.BasePlot import ConstructedPlot
from vtkplotlib.plots.Lines import Lines
try:
from stl.mesh import Mesh as NumpyMesh
NUMPY_STL_AVAILABLE = True
except ImportError:
NumpyMesh = None
NUMPY_STL_AVAILABLE = False
def MESH_DATA_TYPE_EX(msg):
    """Build (without raising) the error used for a rejected ``mesh_data``.

    :param msg: Details appended after the fixed ``"Invalid mesh_data
        argument. "`` prefix.
    :return: A `ValueError` instance for the caller to ``raise``.
    """
    return ValueError("Invalid mesh_data argument. {}".format(msg))
def vtk_read_stl(path):
    """Read an STL file using VTK's ``vtkSTLReader``.

    :param path: Location of the STL file, handed to ``PathHandler``.
    :return: The reader's output wrapped in a ``PolyData``.
    :raises RuntimeError: If the polydata that comes back has no points.
    """
    from vtkplotlib.plots.polydata import PolyData
    from vtkplotlib.unicode_paths import PathHandler
    from vtkplotlib._vtk_errors import handler

    with PathHandler(path) as path_handler:
        stl_reader = vtk.vtkSTLReader()
        handler.attach(stl_reader)
        stl_reader.SetFileName(path_handler.access_path)

        # A VTK reader is lazy: nothing is read until something downstream
        # asks for it. Calling Update() makes the read happen right now.
        stl_reader.Update()
        polydata = PolyData(stl_reader.GetOutput())

    # VTK silently refuses some files (its docs carry some vague warnings
    # which may be about this). The only symptom is an empty output.
    point_count = polydata.vtk_polydata.GetNumberOfPoints()
    if point_count == 0:
        raise RuntimeError(
            "VTK's STLReader failed to read the STL file and no STL io "
            "backend is installed. VTK's STLReader is rather patchy. To read "
            "this file please ``pip install numpy-stl`` first.")

    return polydata
def set_from_path(self, path, ignore_numpystl=False):
    """Fill a mesh plot from an STL file on disk.

    :param self: The plot to populate.
    :param path: Location of the STL file.
    :param ignore_numpystl: If true, skip numpy-stl even when it is
        installed and go straight to VTK's reader.
    """
    vtk_only = ignore_numpystl or not NUMPY_STL_AVAILABLE

    if vtk_only:
        # VTK's own STL reader is the fallback - it is the less reliable of
        # the two.
        self.polydata = vtk_read_stl(path)
        self.connect()
    else:
        # numpy-stl is the preferred reader whenever it can be used.
        self.vectors = NumpyMesh.from_file(path).vectors
def set_vertices_index_pair(self, mesh_data):
    """Populate a mesh plot from a ``(vertices, indices)`` pair.

    :param self: The plot to populate. Its ``vertices`` and ``indices``
        attributes are assigned.
    :param mesh_data: A 2-item sequence: an array of points whose trailing
        shape is ``(3,)`` followed by an integer array whose trailing shape
        is ``(3,)``.
    :raises ValueError: If either item is not a `numpy.ndarray`, has the wrong
        shape, or the second is not of an integer dtype.
    """
    points, point_ids = mesh_data

    # -- the points --
    if not isinstance(points, np.ndarray):
        raise MESH_DATA_TYPE_EX(
            "First argument is of invalid type {}".format(type(points)))

    if points.shape[1:] != (3,):
        raise MESH_DATA_TYPE_EX(
            "First argument has invalid shape {}. Should be (..., 3).".format(
                points.shape))

    # -- the ids of the points making up each triangle --
    if not isinstance(point_ids, np.ndarray):
        raise MESH_DATA_TYPE_EX(
            "Second argument is of invalid type {}".format(type(point_ids)))

    if point_ids.shape[1:] != (3,):
        raise MESH_DATA_TYPE_EX(
            "Second argument has invalid shape {}. Should be (n, 3).".format(
                point_ids.shape))

    # Only signed ("i") or unsigned ("u") integers can be used as ids.
    if point_ids.dtype.kind not in "iu":
        raise MESH_DATA_TYPE_EX("Second argument must be an int dtype array")

    self.vertices = points
    self.indices = point_ids
def normalise_mesh_type(self, mesh_data):
    """Try to support as many of the mesh libraries out there as possible
    without having all of those libraries as dependencies.

    Depending on what **mesh_data** is, the plot **self** is populated by:

    * a ``str`` or ``pathlib.Path``: reading the file with `set_from_path`,
    * a 2-tuple: treating it as ``(vertices, indices)`` and passing it to
      `set_vertices_index_pair`,
    * a `numpy.ndarray`, or any object with a ``vectors`` attribute: assigning
      the ``(n, 3, 3)`` array to ``self.vectors``.

    :param self: The plot to populate.
    :param mesh_data: The mesh in any of the forms listed above.
    :raises ValueError: If **mesh_data** is of an unsupported type or the
        array of triangles is not 3-dimensional with at least a ``(3, 3)``
        trailing shape.
    """
    # If string or Path then read from file.
    if isinstance(mesh_data, Path):
        mesh_data = str(mesh_data)

    if isinstance(mesh_data, str):
        set_from_path(self, mesh_data)
        return

    # If in (vertices, indices) format.
    if isinstance(mesh_data, tuple) and len(mesh_data) == 2:
        set_vertices_index_pair(self, mesh_data)
        return

    # If already an array then great.
    if isinstance(mesh_data, np.ndarray):
        vectors = mesh_data
    # If a mesh class that has the vectors in mesh.vectors as is conventional.
    elif hasattr(mesh_data, "vectors"):
        # Not every library stores these as a numpy array. Converting here
        # means the shape checks below can't trip over a list.
        vectors = np.asarray(mesh_data.vectors)
    else:
        raise MESH_DATA_TYPE_EX("Unsupported type {}".format(type(mesh_data)))

    # Check shapes. The dimension count goes first, otherwise the 3-index
    # slice below fails with an unhelpful IndexError on, say, an (n, 3) array.
    if vectors.ndim != 3:
        raise MESH_DATA_TYPE_EX("mesh_data is invalid shape {}".format(
            vectors.shape))

    if vectors.shape[1:] != (3, 3):
        # Sometimes there are extra entries. pymesh has them. No idea why.
        vectors = vectors[:, :3, :3]
        if vectors.shape[1:] != (3, 3):
            raise MESH_DATA_TYPE_EX("mesh_data is invalid shape {}".format(
                vectors.shape))

    self.vectors = vectors
class MeshPlot(ConstructedPlot):
    """To plot STL files you will need some kind of STL reader library. If you don't
    have one then get `numpy-stl`_. Their Mesh class can be passed
    directly to `mesh_plot()`.

    .. _numpy-stl: https://pypi.org/project/numpy-stl/

    :param mesh_data: A mesh object to plot.

    :param tri_scalars: Per-triangle scalar, texture-coordinates or RGB values, defaults to None.
    :type tri_scalars: numpy.ndarray

    :param scalars: Per-vertex scalar, texture-coordinates or RGB values, defaults to None.
    :type scalars: numpy.ndarray

    :param color: The color (see `colors.as_rgb_a()`) of the whole plot, ignored if scalars are used, defaults to white.
    :type color: str or tuple or numpy.ndarray

    :param opacity: The translucency of the plot. Ranges from ``0.0`` (invisible) to ``1.0`` (solid).
    :type opacity: float

    :param cmap: A colormap (see `vtkplotlib.colors.as_vtk_cmap()`) to convert scalars to colors, defaults to ``'rainbow'``.

    :param fig: The figure to plot into, use `None` for no figure, defaults to the output of `vtkplotlib.gcf()`.
    :type fig: :class:`~vtkplotlib.figure` or :class:`~vtkplotlib.QtFigure`

    :param label: Give the plot a label to use in a `legend`.
    :type label: str

    :return: A mesh object.
    :rtype: `vtkplotlib.mesh_plot`


    The following example assumes you have installed `numpy-stl`_.

    .. code-block:: python

        import vtkplotlib as vpl
        from stl.mesh import Mesh

        # path = "if you have an STL file then put its path here."
        # Otherwise vtkplotlib comes with a small STL file for demos/testing.
        path = vpl.data.get_rabbit_stl()

        # Read the STL using numpy-stl
        mesh = Mesh.from_file(path)

        # Plot the mesh
        vpl.mesh_plot(mesh)

        # Show the figure
        vpl.show()


    Unfortunately there are far too many mesh/STL libraries/classes out there to
    support them all. To overcome this as best we can, mesh_plot has a flexible
    constructor which accepts any of the following.

    1.  A filename.

    2.  Some kind of mesh class that has form 3 stored in ``mesh.vectors``.
        For example numpy-stl's ``stl.mesh.Mesh`` or pymesh's ``pymesh.stl.Stl``.

    3.  A `numpy.array` with shape ``(n, 3, 3)`` of the form:

        .. code-block:: python

            np.array([[[x, y, z],  # corner 0  \\
                       [x, y, z],  # corner 1  | triangle 0
                       [x, y, z]], # corner 2  /
                      ...
                      [[x, y, z],  # corner 0  \\
                       [x, y, z],  # corner 1  | triangle n-1
                       [x, y, z]], # corner 2  /
                     ])

        Note it's not uncommon to have arrays of shape (n, 3, 4) or (n, 4, 3)
        where the additional entries' meanings are usually irrelevant (often to
        represent scalars but as STL has no color this is always uniform). Hence
        to support mesh classes that have these, these arrays are allowed and the
        extra entries are ignored.

    4.  A `numpy.array` with shape (k, 3) of (usually unique) vertices of the
        form:

        .. code-block:: python

            np.array([[x, y, z],
                      [x, y, z],
                      ...
                      [x, y, z],
                      [x, y, z]])

        And a second argument of a `numpy.array` of integers with shape
        ``(n, 3)`` of point args in the form:

        .. code-block:: python

            np.array([[i, j, k],  # triangle 0
                      ...
                      [i, j, k],  # triangle n-1
                      ])

        where i, j, k are the indices of the points (in the vertices array)
        representing each corner of a triangle.

        Note that this form can be easily converted to form 3) using

        .. code-block:: python

            vertices = unique_vertices[point_args]

    Hopefully this will cover most of the cases. If you are using or have written
    an STL library (or any other format) that you want supported then let me know.
    If it's numpy based then it's probably only a few extra lines to support. Or
    you can have a go at writing it yourself, either with `mesh_plot()` or
    with the `vtkplotlib.PolyData` class.


    **Mesh plotting with scalars:**

    To create a heat map like image use the **scalars** or **tri_scalars** options.

    Use the **scalars** option to assign a scalar value to each point/corner:

    .. code-block:: python

        import vtkplotlib as vpl
        from stl.mesh import Mesh

        # Open an STL as before
        path = vpl.data.get_rabbit_stl()
        mesh = Mesh.from_file(path)

        # Plot it with the z values as the scalars. scalars is 'per vertex' or 1
        # value for each corner of each triangle and should have shape (n, 3).
        plot = vpl.mesh_plot(mesh, scalars=mesh.z)

        # Optionally the plot created by mesh_plot can be passed to color_bar
        vpl.color_bar(plot, "Heights")

        vpl.show()


    Use the **tri_scalars** option to assign a scalar value to each triangle:

    .. code-block:: python

        import vtkplotlib as vpl
        from stl.mesh import Mesh
        import numpy as np

        # Open an STL as before
        path = vpl.data.get_rabbit_stl()
        mesh = Mesh.from_file(path)

        # `tri_scalars` must have one value per triangle and have shape (n,) or (n, 1).
        # Create some scalars showing "how upwards facing" each triangle is.
        tri_scalars = np.inner(mesh.units, np.array([0, 0, 1]))

        vpl.mesh_plot(mesh, tri_scalars=tri_scalars)

        vpl.show()

    .. note:: **scalars** and **tri_scalars** overwrite each other and can't be used simultaneously.

    .. seealso::

        Having per-triangle-edge scalars doesn't fit well with VTK. So it got
        its own separate function `mesh_plot_with_edge_scalars()`.

    """

    def __init__(self, mesh_data, tri_scalars=None, scalars=None, color=None,
                 opacity=None, cmap=None, fig="gcf", label=None):
        super().__init__(fig)
        self.connect()

        # Start out as an empty mesh of triangles. Both of these are kept up
        # to date by the ``vectors`` and ``indices`` setters.
        self.shape = (0, 3, 3)
        self._last_used_default_indices = False

        self.set_mesh_data(mesh_data)

        # Every remaining local is handed to ``__setstate__()``. The mesh has
        # already been dealt with so take it out first. Don't create any new
        # local variables in this method - they would be passed along too.
        del mesh_data
        self.__setstate__(locals())

    # Accepts any of the mesh forms listed in the class docstring.
    set_mesh_data = normalise_mesh_type

    @property
    def vectors(self):
        """The corners of every polygon, one row of points per polygon.

        Setting this replaces the points and, unless the shape is unchanged
        and the default index table is already in use, regenerates the
        index table so that consecutive points form each polygon.
        """
        if self._last_used_default_indices:
            return self.polydata.points.reshape(self.shape)
        else:
            return self.vertices[self.indices]

    @vectors.setter
    def vectors(self, vectors):
        vectors = np.asarray(vectors)
        self.polydata.points = vectors.reshape((-1, 3))

        # Ideally try to avoid rewriting the indices table.
        # ``self.vectors += translation`` shouldn't require a rewrite.
        # This is only safe to do if the user isn't directly playing with
        # self.indices. self._last_used_default_indices tests that.
        if vectors.shape == self.shape and self._last_used_default_indices:
            # If shape not changed, indices should be identical.
            return

        self.shape = vectors.shape

        # Otherwise the table has to be rewritten. The default table just
        # numbers the points in order, one row per polygon.
        # There used to be a branch here to slice the old table if the mesh
        # had only been cropped, but it compared ``len(vectors)`` to
        # ``self.shape[0]`` *after* ``self.shape`` was overwritten so could
        # never run. Regenerating gives exactly the same table.
        args = np.arange(np.prod(self.shape[:-1]),
                         dtype=self.polydata.ID_ARRAY_DTYPE)\
                 .reshape((-1, self.shape[-2]))

        self.polydata.polygons = args
        self._last_used_default_indices = True

    @property
    def vertices(self):
        """The points of the mesh. A direct alias for ``self.polydata.points``."""
        return self.polydata.points

    @vertices.setter
    def vertices(self, v):
        self.polydata.points = v

    @property
    def indices(self):
        """The ids of the points making up each polygon. Reads and writes
        ``self.polydata.polygons``.

        Setting this marks the index table as user-provided so that
        ``vectors`` is looked up as ``self.vertices[self.indices]``.
        """
        return self.polydata.polygons

    @indices.setter
    def indices(self, i):
        # Allow array-likes such as nested lists which have no ``.shape``.
        i = np.asarray(i)
        self.polydata.polygons = i
        self.shape = i.shape + (3,)
        self._last_used_default_indices = False

    # Per-vertex scalars reuse the ``color`` property of `Lines`.
    scalars = Lines.color

    @property
    def tri_scalars(self):
        """A scalar (or 2 or 3 values) per triangle for generating heatmaps.

        When setting, ``tri_scalars`` should have one entry per triangle: any
        array-like of shape ``(n,)``, ``(n, 1)``, ``(n, 2)`` or ``(n, 3)``
        where ``n == self.shape[0]``. The scalar range is reset afterwards
        unless it has been frozen.

        :raises ValueError: On setting, if the length or shape is wrong.
        """
        return self.polydata.polygon_colors.reshape((self.shape[0], -1))

    @tri_scalars.setter
    def tri_scalars(self, tri_scalars):
        if tri_scalars is not None:
            # Allow lists etc. which don't have ``reshape()``.
            tri_scalars = np.asarray(tri_scalars)

            if len(tri_scalars) != self.shape[0]:
                raise ValueError("`tri_scalars` should have the same length as "
                                 "`self.vectors` or `self.indices` to be one "
                                 "value per triangle.")

            reshaped = tri_scalars.reshape((self.shape[0], -1))
            if 1 <= reshaped.shape[1] <= 3:
                tri_scalars = reshaped
            else:
                raise ValueError("`tri_scalars` should have shape ({0},), "
                                 "({0}, 1), ({0}, 2) or ({0}, 3). Received {1}"
                                 .format(self.shape[0], tri_scalars.shape)) # yapf: disable

        # NOTE(review): None skips the validation above and is handed to the
        # polydata as is - confirm that is what the polydata expects.
        self.polydata.polygon_colors = tri_scalars
        if not self._freeze_scalar_range:
            self.scalar_range = Ellipsis

    @tri_scalars.deleter
    def tri_scalars(self):
        del self.polydata.polygon_colors
def mesh_plot_with_edge_scalars(mesh_data, edge_scalars, centre_scalar="mean",
                                opacity=None, cmap=None, fig="gcf", label=None):
    r"""Like `mesh_plot` but able to add scalars per triangle's edge. By default,
    the scalar value at centre of each triangle is taken to be the mean of the
    scalars of its edges, but it can be far more visually effective to use
    ``centre_scalar=fixed_value``.

    :param mesh_data: The mesh to plot (see `mesh_plot()`).

    :param edge_scalars: Per-edge scalar, texture-coordinates or RGB values.
    :type edge_scalars: numpy.ndarray

    :param centre_scalar: Scalar value(s) for the centre of each triangle, defaults to 'mean'.
    :type centre_scalar: str

    :param opacity: The translucency of the plot. Ranges from ``0.0`` (invisible) to ``1.0`` (solid).
    :type opacity: float

    :param cmap: A colormap (see `vtkplotlib.colors.as_vtk_cmap()`) to convert scalars to colors, defaults to ``'rainbow'``.

    :param fig: The figure to plot into, use `None` for no figure, defaults to the output of `vtkplotlib.gcf()`.
    :type fig: :class:`~vtkplotlib.figure` or :class:`~vtkplotlib.QtFigure`

    :param label: Give the plot a label to use in a `legend`.
    :type label: str

    :return: A mesh plot object.
    :rtype: `vtkplotlib.mesh_plot`


    Edge scalars are very much not the way VTK likes it. In fact VTK doesn't
    allow it. To overcome this, this function triple-ises each triangle. See
    the diagram below to see how this is done:

    .. code-block:: text

        (The diagram's tacky, I know)

                     p1

                  //|\\          Double lines represent the original triangle.
                 // | \\         The single lines represent the division lines that
           l0   //  |  \\   l1   split the triangle into three.
               //  / \  \\       The annotations show the order in which the
              // /     \ \\      scalar for each edge must be provided.
             ///~~~~~~~~~\\\
         p0  ~~~~~~~~~~~~~~~  p2

                    l2

        (reST doesn't like it either)

    Here is a usage example:

    .. code-block:: python

        import vtkplotlib as vpl
        from vtkplotlib import geometry
        from stl.mesh import Mesh
        import numpy as np

        path = vpl.data.get_rabbit_stl()
        mesh = Mesh.from_file(path)

        # This is the length of each side of each triangle.
        edge_scalars = geometry.distance(mesh.vectors[:, np.arange(1, 4) % 3] - mesh.vectors)

        vpl.mesh_plot_with_edge_scalars(mesh, edge_scalars, centre_scalar=0, cmap="Greens")

        vpl.show()


    I wrote this originally to visualise curvature. The calculation is ugly, but
    on the off chance someone needs it, here it is.

    .. code-block:: python

        import vtkplotlib as vpl
        from vtkplotlib import geometry
        from stl.mesh import Mesh
        import numpy as np

        path = vpl.data.get_rabbit_stl()
        mesh = Mesh.from_file(path)

        def astype(arr, dtype):
            return np.frombuffer(arr.tobytes(), dtype)

        def build_tri2tri_map(mesh):
            '''This creates an (n, 3) array that maps each triangle to its 3
            adjacent triangles. It takes advantage of each triangles vertices
            being consistently ordered anti-clockwise. If triangle A shares an
            edge with triangle B then both A and B have the edges ends as
            vertices but in opposite order. Looking for this helps reduce the
            complexity of the problem.
            '''

            # The most efficient way to make a pair of points hashable is to
            # take its binary representation.
            dtype = np.array(mesh.vectors[0, :2].tobytes()).dtype

            vectors_rolled = mesh.vectors[:, np.arange(1, 4) % 3]

            # Get all point pairs going one way round.
            pairs = np.concatenate((mesh.vectors, vectors_rolled), -1)

            # Get all point pairs going the other way round.
            pairs_rev = np.concatenate((vectors_rolled, mesh.vectors), -1)

            bin_pairs = astype(pairs, dtype).reshape(-1, 3)
            bin_pairs_rev = astype(pairs_rev, dtype).reshape(-1, 3)

            # Use a dictionary to find all the matching pairs.
            mapp = dict(zip(bin_pairs.ravel(), np.arange(bin_pairs.size) // 3))
            args = np.fromiter(map(mapp.get, bin_pairs_rev.flat), dtype=float, count=bin_pairs.size).reshape(-1, 3)

            # Triangles with a missing adjacent edge come out as nans.
            # Convert mapping to ints and nans to -1s.
            mask = np.isfinite(args)
            tri2tri_map = np.empty(args.shape, int)
            tri2tri_map[mask] = args[mask]
            tri2tri_map[~mask] = -1

            return tri2tri_map

        tri2tri_map = build_tri2tri_map(mesh)

        tri_centres = np.mean(mesh.vectors, axis=1)

        curves = np.cross(mesh.units[tri2tri_map], mesh.units[:, np.newaxis])
        displacements = tri_centres[tri2tri_map] - tri_centres[:, np.newaxis]
        curvatures = curves / geometry.distance(displacements, keepdims=True)

        curvature_signs = np.sign(geometry.inner_product(mesh.units[:, np.newaxis],
                                                         displacements)) * -1
        signed_curvatures = geometry.distance(curvatures) * curvature_signs

        # And finally, to plot it.
        plot = vpl.mesh_plot_with_edge_scalars(mesh, signed_curvatures)

        # Curvature must be clipped to prevent anomalies overwidening the
        # scalar range.
        plot.scalar_range = -.1, .1

        # Red represents an inside corner, blue represents an outside corner.
        plot.cmap = "coolwarm_r"

        vpl.show()

    """
    plot = MeshPlot(mesh_data, fig=fig)

    vectors = plot.vectors
    tri_centres = np.mean(vectors, 1)

    # Split every triangle into 3 smaller ones, each made of one original
    # edge plus the centre. The sub-triangles of triangle ``t`` are stored in
    # rows ``3 * t + i`` for edges ``i = 0, 1, 2``.
    new_vectors = np.empty((len(vectors) * 3, 3, 3), vectors.dtype)
    for i in range(3):
        new_vectors[i::3, 0] = vectors[:, i]
        new_vectors[i::3, 1] = vectors[:, (i + 1) % 3]
        new_vectors[i::3, 2] = tri_centres

    # Allow array-likes such as nested lists.
    edge_scalars = np.asarray(edge_scalars)
    tri_scalars = edge_scalars.ravel()

    # Check it's a string first. Comparing an array of centre values to
    # "mean" would be an ambiguous (or erroring) elementwise comparison.
    if isinstance(centre_scalar, str) and centre_scalar == "mean":
        centre_scalars = np.mean(edge_scalars, 1)
        # The mean of integers is generally not an integer. Promote so the
        # means aren't silently truncated on being written to the output.
        scalars_dtype = np.result_type(tri_scalars.dtype, centre_scalars.dtype)
    else:
        centre_scalars = centre_scalar
        scalars_dtype = tri_scalars.dtype

    # Per-corner scalars for the new triangles: both ends of the original
    # edge take that edge's scalar, the centre corner takes the centre value.
    new_scalars = np.empty((len(tri_scalars), 3), dtype=scalars_dtype)
    new_scalars[:, 0] = new_scalars[:, 1] = tri_scalars
    for i in range(3):
        new_scalars[i::3, 2] = centre_scalars

    plot.vectors = new_vectors
    plot.scalars = new_scalars
    plot.opacity = opacity
    plot.fig = fig
    plot.label = label
    plot.cmap = cmap

    return plot