evoxels.profiler

Classes

JAXMemoryProfiler()
    Memory profiler for the JAX backend.

MemoryProfiler()
    Base interface for tracking host and device memory usage.

TorchMemoryProfiler(device)
    Memory profiler for the Torch backend.

class evoxels.profiler.JAXMemoryProfiler

__init__()

Initialize the profiler for JAX.

print_memory_stats(start, end, iters)

Print usage statistics for the JAX backend.
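
This page does not say how JAXMemoryProfiler reads device memory. As a hedged sketch only, JAX exposes per-device allocator statistics through `Device.memory_stats()`, which returns a dict (or None when the backend does not report them). The helper name `peak_jax_memory_mb` is hypothetical and is not part of evoxels.

```python
from typing import Optional


def peak_jax_memory_mb() -> Optional[float]:
    """Hypothetical helper: peak bytes in use on the first JAX device, in MiB.

    Returns None if JAX is not installed or the device reports no statistics.
    """
    try:
        import jax
    except ImportError:
        return None  # JAX not installed
    stats = jax.devices()[0].memory_stats()
    if not stats or "peak_bytes_in_use" not in stats:
        return None  # e.g. CPU backends may not report allocator statistics
    return stats["peak_bytes_in_use"] / 2**20


print(peak_jax_memory_mb())
```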

class evoxels.profiler.MemoryProfiler

Base interface for tracking host and device memory usage.

__init__()

get_cuda_memory_from_nvidia_smi()

Return currently used CUDA memory in megabytes.
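
The method name suggests the value comes from the `nvidia-smi` command-line tool. The exact query evoxels issues is not shown here; the sketch below assumes the standard `--query-gpu=memory.used --format=csv,noheader,nounits` form, which prints one plain number per GPU in MiB. Both function names are hypothetical, and the parser is kept separate so it can be checked without a GPU.

```python
import shutil
import subprocess
from typing import List, Optional


def parse_memory_used(output: str) -> List[float]:
    # With "noheader,nounits", each non-empty line is one GPU's used memory in MiB.
    return [float(line.strip()) for line in output.splitlines() if line.strip()]


def query_cuda_memory_mb() -> Optional[List[float]]:
    # Hypothetical wrapper: returns None instead of raising when nvidia-smi is absent.
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"],
        capture_output=True,
        text=True,
        check=True,
    )
    return parse_memory_used(result.stdout)


print(parse_memory_used("1024\n2048\n"))  # [1024.0, 2048.0]
```

Note that `nvidia-smi` reports device-wide usage, including memory held by other processes on the same GPU.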

abstractmethod print_memory_stats(start: float, end: float, iters: int)

Print profiling summary after a simulation run.
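
The signature suggests `start` and `end` bracket the timed run and `iters` counts the steps taken. Assuming they are wall-clock timestamps in seconds (for example from `time.perf_counter()`), which this page does not confirm, the timing part of such a summary could be derived as below. `summarize_run` is a hypothetical helper, not an evoxels function.

```python
import time


def summarize_run(start: float, end: float, iters: int) -> dict:
    # Assumes start/end are timestamps in seconds and iters > 0.
    if iters <= 0:
        raise ValueError("iters must be positive")
    elapsed = end - start
    return {
        "elapsed_s": elapsed,
        "s_per_iter": elapsed / iters,
        "iters_per_s": iters / elapsed if elapsed > 0 else float("inf"),
    }


# Time a dummy loop the way a simulation run might be timed.
iters = 1000
start = time.perf_counter()
for _ in range(iters):
    sum(range(100))  # stand-in for one simulation step
end = time.perf_counter()
print(summarize_run(start, end, iters))
```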

update_memory_stats()
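
How the backends fill in this interface is not documented on this page. One plausible pattern, sketched here under that caveat, is for `update_memory_stats()` to keep a running peak during the run and for each backend to implement `print_memory_stats()`. The classes `MemoryProfilerSketch` and `ScriptedProfiler` and the hook `current_memory_mb` are hypothetical stand-ins, not the evoxels implementation.

```python
from abc import ABC, abstractmethod
from typing import Iterable


class MemoryProfilerSketch(ABC):
    """Hypothetical stand-in mirroring the documented method names."""

    def __init__(self) -> None:
        self.peak_mb = 0.0

    @abstractmethod
    def current_memory_mb(self) -> float:
        """Hypothetical hook: current memory use of the tracked device in MB."""

    def update_memory_stats(self) -> None:
        # Keep a running maximum; intended to be called once per step.
        self.peak_mb = max(self.peak_mb, self.current_memory_mb())

    @abstractmethod
    def print_memory_stats(self, start: float, end: float, iters: int) -> None:
        """Print a summary after the run."""


class ScriptedProfiler(MemoryProfilerSketch):
    # Replays fixed readings so the sketch runs without a GPU.
    def __init__(self, readings: Iterable[float]) -> None:
        super().__init__()
        self._readings = iter(readings)

    def current_memory_mb(self) -> float:
        return next(self._readings)

    def print_memory_stats(self, start: float, end: float, iters: int) -> None:
        print(f"{iters} iterations in {end - start:.2f} s, peak memory {self.peak_mb:.1f} MB")


profiler = ScriptedProfiler([120.0, 340.0, 210.0])
for _ in range(3):
    profiler.update_memory_stats()
profiler.print_memory_stats(0.0, 1.5, 3)  # prints "3 iterations in 1.50 s, peak memory 340.0 MB"
```

Because `MemoryProfilerSketch` is abstract, instantiating it directly raises `TypeError`; only subclasses that implement both abstract methods can be created.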

class evoxels.profiler.TorchMemoryProfiler(device)

__init__(device)

Initialize the profiler for a given torch device.

print_memory_stats(start, end, iters)

Print usage statistics for the Torch backend.
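
TorchMemoryProfiler is constructed with a `device`. Whether it uses PyTorch's allocator counters is not stated here. As an assumption-labeled sketch, the hypothetical helper below reads the peak memory allocated by tensors on a CUDA device with `torch.cuda.max_memory_allocated` and returns None for non-CUDA devices or when PyTorch is missing.

```python
from typing import Optional


def peak_cuda_memory_mb(device) -> Optional[float]:
    """Hypothetical helper: peak tensor memory allocated on `device`, in MiB."""
    try:
        import torch
    except ImportError:
        return None  # PyTorch not installed
    dev = torch.device(device)
    if dev.type != "cuda" or not torch.cuda.is_available():
        return None  # these counters only apply to CUDA devices
    return torch.cuda.max_memory_allocated(dev) / 2**20


print(peak_cuda_memory_mb("cpu"))  # None
```

Calling `torch.cuda.reset_peak_memory_stats(device)` before a run makes the reported peak refer to that run only. This counter covers memory allocated for tensors, which is typically lower than what `nvidia-smi` shows because of the caching allocator and CUDA context overhead.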