Skip to content

SquareGridArtist2D

SquareGridArtist2D

Bases: StructureArtist

A helper StructureArtist class for rendering 2D square grids.

Source code in pylattica/visualization/square_grid_artist_2D.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
class SquareGridArtist2D(StructureArtist):
    """A helper StructureArtist class for rendering 2D square grids.

    The rendered image consists of the grid itself on the left, a vertical
    white separator, and a color legend on the right.
    """

    def _draw_image(self, state: SimulationState, **kwargs):
        """Render ``state`` on this artist's structure as an RGB image.

        Each site is painted as a ``cell_size`` x ``cell_size`` square whose
        color comes from ``self.cell_artist``. The y axis is flipped so that
        a site with location y == 0 is drawn on the bottom row of the grid.
        Legend entries are listed top to bottom in sorted key order.

        Args:
            state: The simulation state whose site states are drawn.
            **kwargs: Optional rendering settings:
                label: Text drawn in white at the top-left corner of the
                    image. Nothing is drawn when it is ``None`` (default).
                cell_size: Edge length of one grid cell in pixels
                    (default 20).

        Returns:
            The image object created by ``Image.new``.
        """
        label = kwargs.get("label", None)
        cell_size = kwargs.get("cell_size", 20)

        legend = self.cell_artist.get_legend(state)
        # Only the first lattice vector length is consulted; the grid is
        # treated as square.
        state_size = int(self.structure.lattice.vec_lengths[0])

        legend_width_cells = 6  # horizontal room reserved for the legend
        legend_border_width = 5  # separator thickness, in pixels
        white = (255, 255, 255)

        width = state_size + legend_width_cells
        # Make the image tall enough for the legend even for small grids.
        height = max(state_size, len(legend) + 1)
        img = Image.new(
            "RGB",
            (width * cell_size + legend_border_width, height * cell_size),
            "black",
        )
        draw = ImageDraw.Draw(img)

        def fill_rect(x_start, y_start, rect_width, rect_height, color):
            # ImageDraw.rectangle treats both corners as inclusive, hence the
            # "- 1". Degenerate sizes are skipped so they stay a no-op, as
            # they were with the per-pixel loops this replaces.
            if rect_width <= 0 or rect_height <= 0:
                return
            draw.rectangle(
                (
                    x_start,
                    y_start,
                    x_start + rect_width - 1,
                    y_start + rect_height - 1,
                ),
                fill=color,
            )

        def as_text(value):
            # draw.text only understands str/bytes; legend keys (or a label)
            # of any other type are rendered through str().
            return value if isinstance(value, (str, bytes)) else str(value)

        # Grid cells.
        for site in self.structure.sites():
            loc = site[LOCATION]
            cell_state = state.get_site_state(site[SITE_ID])
            cell_color = self.cell_artist.get_color_from_cell_state(cell_state)
            p_x_start = int((loc[0]) * cell_size)
            p_y_start = int((state_size - 1 - loc[1]) * cell_size)
            fill_rect(p_x_start, p_y_start, cell_size, cell_size, cell_color)

        # White separator between the grid and the legend.
        fill_rect(
            state_size * cell_size,
            0,
            legend_border_width,
            height * cell_size,
            white,
        )

        # Legend: one color swatch plus its key per row.
        legend_hoffset = int(cell_size / 4)
        legend_voffset = int(cell_size / 4)
        p_col_start = state_size * cell_size + legend_border_width + legend_hoffset

        for row, phase in enumerate(sorted(legend)):
            p_row_start = row * cell_size + legend_voffset
            fill_rect(p_col_start, p_row_start, cell_size, cell_size, legend[phase])

            legend_label_loc = (
                int(p_col_start + cell_size + cell_size / 4),
                int(p_row_start + cell_size / 4),
            )
            draw.text(legend_label_loc, as_text(phase), white)

        if label is not None:
            draw.text((5, 5), as_text(label), white)

        return img