Skip to content

Commit

Permalink
implement VideoViz to record model runs in a video
Browse files Browse the repository at this point in the history
  • Loading branch information
wang-boyu committed Nov 3, 2024
1 parent 527c023 commit ce2f0d5
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,6 @@ dmypy.json
# JS dependencies
mesa/visualization/templates/external/
mesa/visualization/templates/js/external/

# Video
**/*.mp4
40 changes: 40 additions & 0 deletions mesa/examples/basic/schelling/video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Example of using VideoViz with the Schelling model."""

from mesa.examples.basic.schelling.model import Schelling
from mesa.visualization.video_viz import (

Check warning on line 4 in mesa/examples/basic/schelling/video.py

View check run for this annotation

Codecov / codecov/patch

mesa/examples/basic/schelling/video.py#L3-L4

Added lines #L3 - L4 were not covered by tests
VideoViz,
make_measure_component,
make_space_component,
)

# Create model
model = Schelling(10, 10)

Check warning on line 11 in mesa/examples/basic/schelling/video.py

View check run for this annotation

Codecov / codecov/patch

mesa/examples/basic/schelling/video.py#L11

Added line #L11 was not covered by tests


def agent_portrayal(agent):

Check warning on line 14 in mesa/examples/basic/schelling/video.py

View check run for this annotation

Codecov / codecov/patch

mesa/examples/basic/schelling/video.py#L14

Added line #L14 was not covered by tests
"""Portray agents based on their type."""
if agent is None:
return {}

Check warning on line 17 in mesa/examples/basic/schelling/video.py

View check run for this annotation

Codecov / codecov/patch

mesa/examples/basic/schelling/video.py#L17

Added line #L17 was not covered by tests

portrayal = {

Check warning on line 19 in mesa/examples/basic/schelling/video.py

View check run for this annotation

Codecov / codecov/patch

mesa/examples/basic/schelling/video.py#L19

Added line #L19 was not covered by tests
"color": "red" if agent.type == 0 else "blue",
"size": 25,
"marker": "s", # square marker
}
return portrayal

Check warning on line 24 in mesa/examples/basic/schelling/video.py

View check run for this annotation

Codecov / codecov/patch

mesa/examples/basic/schelling/video.py#L24

Added line #L24 was not covered by tests


# Create visualization with space and some metrics
viz = VideoViz(

Check warning on line 28 in mesa/examples/basic/schelling/video.py

View check run for this annotation

Codecov / codecov/patch

mesa/examples/basic/schelling/video.py#L28

Added line #L28 was not covered by tests
model,
[
make_space_component(agent_portrayal=agent_portrayal, save_format="svg"),
make_measure_component("happy", save_format="svg"),
],
title="Schelling's Segregation Model",
)

# Record simulation
if __name__ == "__main__":
video_path = viz.record(steps=50, filepath="schelling.mp4")
print(f"Video saved to: {video_path}")

Check warning on line 40 in mesa/examples/basic/schelling/video.py

View check run for this annotation

Codecov / codecov/patch

mesa/examples/basic/schelling/video.py#L39-L40

Added lines #L39 - L40 were not covered by tests
184 changes: 184 additions & 0 deletions mesa/visualization/video_viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
"""Video recording components for Mesa model visualization."""

import shutil
from collections.abc import Callable, Sequence
from pathlib import Path

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np

import mesa
from mesa.visualization.matplotlib_renderer import (
MatplotlibRenderer,
MeasureRendererMatplotlib,
SpaceRenderMatplotlib,
)


def make_space_component(
agent_portrayal: Callable | None = None,
propertylayer_portrayal: dict | None = None,
post_process: Callable | None = None,
**space_drawing_kwargs,
):
"""Create a Matplotlib-based space visualization component.
Args:
agent_portrayal: Function to portray agents.
propertylayer_portrayal: Dictionary of PropertyLayer portrayal specifications
post_process : a callable that will be called with the Axes instance. Allows for fine tuning plots (e.g., control ticks)
backend: The backend to use for rendering the space. Can be "matplotlib" or "altair".
space_drawing_kwargs : additional keyword arguments to be passed on to the underlying space drawer function. See
the functions for drawing the various spaces for further details.
``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
"size", "marker", and "zorder". Other field are ignored and will result in a user warning.
Returns:
SpaceRenderMatplotlib: A component for rendering the space.
"""
if agent_portrayal is None:

def agent_portrayal(a):
return {}

return SpaceRenderMatplotlib(
agent_portrayal,
propertylayer_portrayal,
post_process=post_process,
**space_drawing_kwargs,
)


def make_measure_component(
measure: Callable,
**kwargs,
):
"""Create a plotting function for a specified measure.
Args:
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
kwargs: Additional keyword arguments to pass to the MeasureRendererMatplotlib constructor.
Returns:
MeasureRendererMatplotlib: A component for rendering the measure.
"""
return MeasureRendererMatplotlib(
measure,
**kwargs,
)


class VideoViz:
"""Create high-quality video recordings of model simulations."""

def __init__(
self,
model: mesa.Model,
components: Sequence[MatplotlibRenderer],
*,
title: str | None = None,
figsize: tuple[float, float] | None = None,
grid: tuple[int, int] | None = None,
):
"""Initialize video visualization configuration.
Args:
model: The model to simulate and record
components: Sequence of component objects defining what to visualize
title: Optional title for the video
figsize: Optional figure size in inches (width, height)
grid: Optional (rows, cols) for custom layout. Auto-calculated if None.
"""
# Check if FFmpeg is available
if not shutil.which("ffmpeg"):
raise RuntimeError(
"FFmpeg not found. Please install FFmpeg to save animations:\n"
" - macOS: brew install ffmpeg\n"
" - Linux: sudo apt-get install ffmpeg\n"
" - Windows: download from https://ffmpeg.org/download.html"
)
self.model = model
self.components = components
self.title = title
self.figsize = figsize
self.grid = grid or self._calculate_grid(len(components))

# Setup figure and axes
self.fig, self.axes = self._setup_figure()

def record(
self,
*,
steps: int,
filepath: str | Path,
dpi: int = 100,
fps: int = 10,
codec: str = "h264",
bitrate: int = 2000,
) -> Path:
"""Record model simulation to video file.
Args:
steps: Number of simulation steps to record
filepath: Where to save the video file
dpi: Resolution of the output video
fps: Frames per second in the output video
codec: Video codec to use
bitrate: Video bitrate in kbps (default: 2000)
Returns:
Path to the saved video file
Raises:
RuntimeError: If FFmpeg is not installed
"""
filepath = Path(filepath)

def update(frame_num):
# Update model state
self.model.step()

# Render all visualization frames
for component, ax in zip(self.components, self.axes):
ax.clear()
component.draw(self.model, ax)
return self.axes

# Create and save animation
anim = animation.FuncAnimation(
self.fig, update, frames=steps, interval=1000 / fps, blit=False
)

writer = animation.FFMpegWriter(
fps=fps,
codec=codec,
bitrate=bitrate, # Now passing as integer
)

anim.save(filepath, writer=writer, dpi=dpi)
return filepath

def _calculate_grid(self, n_frames: int) -> tuple[int, int]:
"""Calculate optimal grid layout for given number of frames."""
cols = min(3, n_frames) # Max 3 columns
rows = int(np.ceil(n_frames / cols))
return (rows, cols)

def _setup_figure(self):
"""Setup matplotlib figure and axes."""
if not self.figsize:
self.figsize = (5 * self.grid[1], 5 * self.grid[0])
fig = plt.figure(figsize=self.figsize)
axes = []

for i in range(len(self.components)):
ax = fig.add_subplot(self.grid[0], self.grid[1], i + 1)
axes.append(ax)

if self.title:
fig.suptitle(self.title, fontsize=16)
fig.tight_layout()
return fig, axes

0 comments on commit ce2f0d5

Please sign in to comment.