import numpy as np
import matplotlib.pyplot as plt
import asyncio
import random
import js    
import ctypes
from pyscript import document

# Module-level simulation state (most of it is overwritten by initialize_parameters)

# Run-control flags
paused = True        # True until Run/Reverse unpause; set again by Pause and at the final step
animating = False    # animate() does nothing while this is False
async_flag = True    # True when do_run may schedule a fresh main() task

# Options mirrored from the page checkboxes / radio buttons
collisions_off = False
walls_off = False
sp_plot = False      # live spatial-band chart mode
en_plot = False      # live energy chart mode

# Container size, also used as the simulation box bounds
wide = 500
high = 500

# Time index and time-step length
t = 0
dt = 1

# Live-chart cadence and speed scaling
sp_plots = 20        # plot_int = round(tmax / sp_plots) in spatial mode
en_plots = 10        # plot_int = round(tmax / en_plots) in energy mode
plot_int = 1         # steps between live charts
speed_factor = 1     # multiplier applied to the animation-speed slider value

# Setup
def do_set(event):
    """Handle the Set button: rebuild the simulation from the page inputs.

    Only acts while the simulation is paused.  Re-reads every parameter,
    lays the dots out again, removes the old SVG circles, closes the two most
    recently active figures and draws the new circles.

    Args:
        event: The DOM event (unused).
    """
    if not paused:
        return

    initialize_parameters()
    initialize_pos_vel()
    clear_dots()
    # Each plt.close() closes only the current figure, so two calls close two.
    for _ in range(2):
        plt.close()
    add_circles()
        
# Run the animation
def do_run(event):
    """Handle the Run button: start or resume the simulation.

    If the simulation is paused and not animating, it is marked as animating
    and unpaused, and one step is processed immediately via ``animate(t)`` using
    the module-level ``t``.  Independently, when ``async_flag`` says no driver
    task is scheduled, a new ``main()`` task is scheduled to step through the run.

    NOTE(review): the module-level ``t`` is only ever set to 0 in this file, so
    the immediate ``animate(t)`` always processes step 0, which ``main()`` then
    processes again.  Confirm that the extra step is intended.

    Args:
        event: The DOM event (unused).
    """
    global paused, animating, async_flag

    ready_to_start = paused and not animating
    if ready_to_start:
        animating = True
        paused = False
        animate(t)

    if not async_flag:
        return
    async_flag = False
    asyncio.ensure_future(main())
        
# Pause the animation
def do_pause(event):
    """Handle the Pause button: stop advancing the simulation.

    When animating, clears ``animating`` and sets ``paused``; otherwise does
    nothing.

    Args:
        event: The DOM event (unused; callers also pass '').
    """
    global animating, paused

    if not animating:
        return
    animating, paused = False, True

# Reverse the velocity direction of each dot
def do_reverse(event):
    """Handle the Reverse button: flip every dot's velocity and redraw.

    Only acts while animating.  Negates the first ``n`` velocity rows in place,
    unpauses, and processes one step via ``animate(t)`` with the module-level ``t``.

    Args:
        event: The DOM event (unused).
    """
    global paused, animating, vel

    if not animating:
        return

    vel[:n] *= -1
    paused = False
    animate(t)

def initialize_parameters():
    """Read the simulation settings from the page and (re)allocate all state.

    Reads the numberMolecules, Precision, numberTimeSteps, Radius and gridwidth
    inputs and the largeGrid, collisionsOff, wallsOff, gas and energy controls.
    It then resets the time/step/collision counters and allocates the position,
    velocity and microstate-history arrays as module globals.  The final
    ``do_pause('')`` call leaves the simulation paused with ``animating`` False,
    because ``animating`` is set to True just before it.

    NOTE(review): each ``int(...)`` parse only prints on failure.  The variable
    then keeps its value from a previous call, or is still unbound on the first
    call, which raises NameError a few lines later.
    """
    global n, t, tmax, time
    global pos, vel, precision, r_init, r_space, imax
    global step, coll, animating, grid_side, gridwidth, spacing, dot_size
    global collisions_off, walls_off, speed    # , results
    global sp_plot, en_plot, sp_plots, en_plots, plot_int
    global wide, high, large_grid, paused
    global c_energy, d_energy, c_emicro_t, d_emicro_t, c_smicro_t, d_smicro_t
    global s_micro, s_occup, s_alloc, c_emicro, d_emicro, c_level, d_level
    global d_speed, c_speed, d_smicro, c_smicro
    global s_micro_sum, d_emicro_sum, d_smicro_sum

    plt.close()                # AVW

    # Set container dimensions to match program constants
    document.getElementById("cont").style.width=wide
    document.getElementById("cont").style.height=high

    # Get the input value of n
    try:
        n = int(document.getElementById("numberMolecules").value )
    except Exception as e:
        print("A number should be entered:", e)

    # Get the energy precision (precision): the number of discrete bins per unit
    # of speed/energy.
    try:
        precision = int(document.getElementById("Precision").value )
    except Exception as e:
        print("A number should be entered:", e)

    # Fall back to precision 1 when it is 0, or when it is not a multiple of 10
    # and does not divide n (unless n is a multiple of 10 and precision a
    # multiple of 4).  This check uses n before it is rounded down to a perfect
    # square below.
    if precision == 0:
        precision = 1
    elif precision%10 != 0:
        if n/precision != int(n/precision):
            if not (n%10 == 0 and precision%4 == 0):
                precision = 1

    # Get the maximum time steps
    try:
        tmax = int(document.getElementById("numberTimeSteps").value )
    except Exception as e:
        print("A number should be entered:", e)

    # Get the initial radius (r_init)
    try:
        r_init = int(document.getElementById("Radius").value )
    except Exception as e:
        print("A number should be entered:", e)

    # Clamp to half the box width (r_init becomes a float in that case).
    # NOTE(review): r_init == 0 is not rejected and raises ZeroDivisionError below.
    if r_init > wide/2:
        r_init = wide/2

    # Calculate the number of space bands.  The edges are sqrt(i) * r_init, so
    # consecutive edges enclose annuli of equal area (pi * r_init**2), reaching
    # out to about wide/2.  Column 1 of r_space holds per-step band counts.
    imax = int(((wide/2) / r_init)**2 + 1)
    r_space = np.zeros((imax, 2))
    for i in range(imax):
        r_space[i, 0] = np.sqrt(i)*r_init

    # Set the grid size: round n down to a perfect square so the dots fill a
    # grid_side x grid_side grid.
    grid_side = int(np.sqrt(n))
    n = grid_side**2
    large_grid = False

    try:
        gridwidth = int(document.getElementById("gridwidth").value )
    except Exception as e:
        print("A number should be entered:", e)

    # Keep the initial grid between 1 and the full box width.
    if gridwidth <= 1:
        gridwidth = 1
    elif gridwidth >= wide:
        gridwidth = wide

    spacing = gridwidth / grid_side

    # "Large grid" spreads the initial dots over the whole box width instead.
    if (document.getElementById("largeGrid").checked):
        spacing = wide / grid_side
        large_grid = True
    else:
        large_grid = False

    show_message1('')

    t = 0
    step = 0
    coll = 0

    animating = True
    paused = True
    pos = np.ones((n, 2))
    vel = np.zeros((n, 2))
    time = np.zeros(tmax)

    # Spatial history: one row per time step, one column per band edge / band.
    # NOTE(review): s_occup is allocated here but never written anywhere in this file.
    s_micro = np.zeros((tmax, imax))
    s_occup = np.zeros((tmax, imax-1))
    s_alloc = np.zeros((tmax, imax-1))
    s_micro_sum = np.zeros(imax)

    # Per-dot speed/energy for the current step (c_ = continuous, d_ = rounded
    # to 1/precision), plus one row per time step of history.
    c_speed = np.zeros(n)
    d_speed = np.zeros(n)
    c_smicro = np.zeros((tmax, n))
    d_smicro = np.zeros((tmax, n))
    c_energy = np.zeros(n)
    d_energy = np.zeros(n)
    c_emicro = np.zeros((tmax, n))
    d_emicro = np.zeros((tmax, n))
    c_level = np.zeros(tmax)
    d_level = np.zeros(tmax, dtype=int)
    # Accumulated histograms: column 0 holds the bin centres (i + 0.5)/precision,
    # column 1 the counts.
    d_smicro_sum = np.zeros((n*precision, 2))
    d_emicro_sum = np.zeros((n*precision, 2))
    for i in range(n*precision):
        d_smicro_sum[i, 0] = (i + 0.5)/precision
        d_emicro_sum[i, 0] = (i + 0.5)/precision

    # Dot diameter: 0.1 * wide / sqrt(n), so n * dot_size**2 == (0.1 * wide)**2
    dot_size = np.sqrt((0.1 * wide)**2 / n)

    # Set collisions OFF on the checkbox selection
    if (document.getElementById("collisionsOff").checked):
        collisions_off = True
    else:
        collisions_off = False

    # Set walls OFF based on the checkbox selection
    if (document.getElementById("wallsOff").checked):
        walls_off = True
    else:
        walls_off = False

    # Select the mode based on the radio button selection:
    # gas -> animated dots only; energy -> live energy chart; otherwise -> live
    # spatial-band chart.
    if (document.getElementById("gas").checked):
        sp_plot = False
        en_plot = False
    elif(document.getElementById("energy").checked):
        plt.close('all')                # AVW
        sp_plot = False
        en_plot = True
    else:
        plt.close('all')                # AVW
        sp_plot = True
        en_plot = False

    # Set the microstate plot interval (in gas mode plot_int keeps its previous value)
    if sp_plot:
        plot_int = round(tmax / sp_plots)
    elif en_plot:
        plot_int = round(tmax / en_plots)

    if plot_int < 1:
        plot_int = 1

    do_pause('')
    
# Initialize position and velocity                                      
def initialize_pos_vel():
    """Place the dots on a centred square grid and give each a random unit velocity.

    Reads the module globals ``grid_side``, ``spacing``, ``wide`` and ``high`` and
    fills the preallocated ``pos`` and ``vel`` arrays in place.  Dot ``k`` sits in
    grid column ``k // grid_side`` and row ``k % grid_side``.  Vertical spacing is
    scaled by ``high / wide``.  Each velocity is ``(cos a, sin a)`` for an angle
    ``a`` drawn uniformly from [0, 2*pi].
    """
    aspect = high / wide
    x_half = spacing / 2
    y_half = spacing * aspect / 2
    x_offset = (wide - grid_side * spacing) / 2
    y_offset = (high - grid_side * spacing) / 2

    for idx in range(grid_side * grid_side):
        col, row = divmod(idx, grid_side)
        pos[idx, 0] = float(col * spacing) + x_half + x_offset
        pos[idx, 1] = float(row * spacing * aspect) + y_half + y_offset
        theta = random.uniform(0, 2 * np.pi)
        vel[idx] = np.array((np.cos(theta), np.sin(theta)))

# Add the dots
def add_circles():
    """Create one blue SVG ``<circle>`` per dot inside the ``#cont`` container.

    Circle ``k`` gets id ``"cir<k>"``, centre ``pos[k]``, radius
    ``dot_size / 2`` and a ``value`` attribute of 1.  ``draw_dots`` later looks
    the circles up by these ids.
    """
    svg_ns = "http://www.w3.org/2000/svg"
    container = document.getElementById("cont")

    for k in range(grid_side * grid_side):
        circle = document.createElementNS(svg_ns, "circle")
        attributes = (
            ("id", "cir" + str(k)),
            ("cx", pos[k, 0]),
            ("cy", pos[k, 1]),
            ("r", dot_size / 2),
            ("fill", "blue"),
            ("value", 1),
        )
        for attr_name, attr_value in attributes:
            circle.setAttribute(attr_name, attr_value)
        container.appendChild(circle)

# Advance the simulation by one time step and record its microstates
def animate(t):
    """Advance the simulation by the single time step ``t``.

    Does nothing unless ``animating`` is set.  Otherwise it:

    1. reads the "animation_speed" slider into the global ``speed``;
    2. moves the dots and records microstates (``update_microstates``);
    3. resolves collisions and walls (``do_collisions_walls``);
    4. refreshes the live chart if a chart mode is active;
    5. on the final step (``t == tmax - 1``) draws the summary charts, shows the
       summary message, and sets ``paused`` and ``async_flag`` back to True.

    Args:
        t: Integer time-step index.
    """
    global speed, paused, async_flag

    if not animating:
        return

    slider_value = int(document.getElementById("animation_speed").value)
    speed = speed_factor * slider_value

    update_microstates(t)
    do_collisions_walls(t)

    # `and` binds tighter than `or`: spatial mode refreshes on every eligible
    # step, while energy mode only refreshes once t > 1.
    refresh_live_chart = sp_plot or (en_plot and t > 1)
    if refresh_live_chart:
        plot_chart_seq(t)

    if t != tmax - 1:
        return

    plot_charts(t)
    show_messages()
    paused = True
    async_flag = True

# Update microstates 
def _add_histogram(target, bins, nbins):
    """Add one count per finite entry of ``bins`` to ``target[k, 1]`` for k in [0, nbins)."""
    finite = bins[np.isfinite(bins)]
    idx = finite.astype(np.intp)
    idx = idx[(idx >= 0) & (idx < nbins)]
    target[:nbins, 1] += np.bincount(idx, minlength=nbins)[:nbins]


def update_microstates(t):
    """Advance every dot by one step and record the step-``t`` microstates.

    Does nothing when ``t >= tmax``.  Otherwise it:

    * stores ``t`` in ``time[t]``, moves the dots (``pos += speed * vel * dt``)
      and redraws them with ``draw_dots``;
    * counts the dots whose distance from the box centre lies in each half-open
      band ``[r_space[j, 0], r_space[j + 1, 0])``.  The counts go into
      ``r_space[:, 1]`` (last entry always 0) and are copied to ``s_micro[t]``;
    * stores each dot's continuous speed/energy (``c_speed``/``c_energy``) and
      their values rounded to the nearest ``1 / precision``
      (``d_speed``/``d_energy``) in row ``t`` of ``c_smicro``, ``d_smicro``,
      ``c_emicro`` and ``d_emicro``;
    * adds this step's histogram of rounded speeds/energies to column 1 of
      ``d_smicro_sum``/``d_emicro_sum``, where bin ``k`` has centre
      ``(k + 0.5) / precision``.

    Args:
        t: Integer time-step index.
    """
    global pos

    if not t < tmax:
        return
    time[t] = t

    pos += speed * vel * dt                 # Move the dots
    draw_dots()

    # --- spatial microstate --------------------------------------------------
    edges = r_space[:, 0]
    last = len(r_space) - 1
    radius = np.sqrt((pos[:n, 0] - wide/2)**2 + (pos[:n, 1] - high/2)**2)
    # Candidate band = index of the last edge <= radius (edges ascend when
    # r_init > 0).  The explicit range check keeps the half-open
    # [edge_j, edge_j+1) rule and drops radii beyond the last edge (and NaN).
    band = np.searchsorted(edges, radius, side='right') - 1
    in_range = (band >= 0) & (band < last)
    band, r_in = band[in_range], radius[in_range]
    band = band[(edges[band] <= r_in) & (r_in < edges[band + 1])]
    r_space[:, 1] = np.bincount(band, minlength=len(r_space))
    s_micro[t] = r_space[:, 1]

    # --- energy microstates --------------------------------------------------
    speed_sq = vel[:n, 0]**2 + vel[:n, 1]**2
    c_speed[:] = np.sqrt(speed_sq)                    # Theoretical speed
    c_energy[:] = speed_sq / 2                        # Continuous energy
    # np.rint rounds half to even, like the built-in round() used before.
    speed_bin = np.rint(c_speed * precision)
    energy_bin = np.rint(c_energy * precision)
    d_speed[:] = speed_bin / precision                # Approximate speed
    d_energy[:] = energy_bin / precision              # Rounded energy

    c_smicro[t] = c_speed
    d_smicro[t] = d_speed
    c_emicro[t] = c_energy
    d_emicro[t] = d_energy

    # A value rounded to k / precision belongs to bin k (centre (k + 0.5)/precision).
    # Using the integer k avoids comparing 5-decimal-rounded floats with exact
    # centres, which never matched most bins for precisions such as 3 or 7.
    nbins = n * precision
    _add_histogram(d_smicro_sum, speed_bin, nbins)
    _add_histogram(d_emicro_sum, energy_bin, nbins)

# Calculate collisions with molecules and walls
def do_collisions_walls(t):
    """Resolve dot-dot collisions and wall bounces, then derive step-``t`` summaries.

    Mutates ``vel`` and ``pos`` in place and increments the global counters
    ``step`` (once per call) and ``coll`` (once per overlapping ordered pair
    found).  It also writes:

    * ``s_micro_t``: row ``t`` of ``s_micro`` without its last entry;
      ``s_alloc[t]`` holds those counts as fractions of their sum;
    * ``d_emicro_t`` / ``c_emicro_t``: row ``t`` of ``d_emicro`` / ``c_emicro``;
    * ``d_level[t]`` / ``c_level[t]``: the number of dots whose energy differs
      from dot 0's energy at step 0, plus one if any dot still has that energy.

    NOTE(review): ``s_alloc[t]`` divides by ``sum(s_micro_t)``.  If no dot lies
    inside any band this gives NaN (with a numpy warning).  Confirm this is
    acceptable.

    Args:
        t: Integer time-step index.
    """
    global n, dt, random_numbers, dot_size, step, coll, pos, vel
    global speed, collisions_off, walls_off, precision
    global c_energy, d_energy, c_emicro_t, d_emicro_t, c_smicro_t, d_smicro_t
    global c_level, d_level, s_micro, s_micro_t, s_alloc

    step += 1                                   # one call == one simulation step

    for i in range(n):
        # Do Collisions
        if collisions_off == False:
            # NOTE(review): every ordered pair (i, j) is visited, so a touching
            # pair can be resolved (and counted) from both sides in one step.
            for j in range(n):
                if i != j:
                    distance = ((pos[i, 0] - pos[j, 0])**2 + (pos[i, 1] - pos[j, 1])**2)**0.5
                    # Molecular collisions: centres closer than one dot diameter
                    if distance <= dot_size:
                        coll += 1
                        pos_i, vel_i = pos[i], vel[i]
                        pos_j, vel_j = pos[j], vel[j]
                        rel_pos, rel_vel = pos_i - pos_j, vel_i - vel_j
                        # |x_i - x_j|^2; if the centres coincide this is 0 and the
                        # division below yields NaN velocities
                        r_rel = rel_pos @ rel_pos
                        v_rel = rel_vel @ rel_pos
                        v_rel = 2 * rel_pos * v_rel / r_rel - rel_vel
                        v_cm = (vel_i + vel_j) / 2
                        # Together these give the equal-mass elastic update
                        # v_i' = v_i - ((v_i - v_j).(x_i - x_j) / |x_i - x_j|^2) (x_i - x_j)
                        vel_i = v_cm - v_rel/2
                        vel_j = v_cm + v_rel/2
                        vel[i] = vel_i
                        vel[j] = vel_j
                        # Nudge both dots one step along their new velocities
                        pos[i] += speed * vel[i] * dt
                        pos[j] += speed * vel[j] * dt

        if walls_off == False:
            # Bounce off walls: ChatGPT code
            # NOTE(review): this pass is vectorised over ALL dots but sits inside
            # the per-dot loop, so it runs n times per call.  Later passes only
            # re-clamp dots pushed out by subsequent collision nudges.  Hoisting
            # it out of the loop would change results slightly.

            # Left/right walls: reflect the x-velocity and clamp to dot_size from the wall
            hit_left = pos[:, 0] < dot_size
            hit_right = pos[:, 0] > wide - dot_size
            vel[hit_left | hit_right, 0] *= -1
            pos[hit_left, 0] = dot_size
            pos[hit_right, 0] = wide - dot_size

            # Bottom/top walls: same for y
            hit_bottom = pos[:, 1] < dot_size
            hit_top = pos[:, 1] > high - dot_size
            vel[hit_bottom | hit_top, 1] *= -1
            pos[hit_bottom, 1] = dot_size
            pos[hit_top, 1] = high - dot_size

    # Discrete space: drop the last entry (update_microstates never counts into
    # the final edge) and normalise to fractions of the counted dots.
    s_micro_t = s_micro[t]
    s_micro_t = s_micro_t[:-1]
    s_alloc[t] = s_micro_t / sum(s_micro_t)

    # Discrete energy: count dots still at dot 0's step-0 energy, then
    # level count = dots elsewhere, +1 if that initial level is still occupied.
    d_emicro_t = d_emicro[t]
    ones_count = 0
    for item in d_emicro_t:
        if item == d_emicro[0, 0]:
            ones_count += 1
    M = n - ones_count
    if M < n:
        M = M + 1
    d_level[t] = M

    # Continuous energy: same level count using exact float equality
    c_emicro_t = c_emicro[t]
    ones_count = 0
    for item in c_emicro_t:
        if item == c_emicro[0, 0]:
            ones_count += 1
    M = n - ones_count
    if M < n:
        M = M + 1
    c_level[t] = M

# Draw the dots                 
def draw_dots():
    """Move each SVG circle ``cir<i>`` to ``pos[i]`` and set its ``value`` to ``i``.

    Skipped when ``sp_plot`` and ``en_plot`` differ, i.e. while a live chart
    mode is active.  With the radio buttons this leaves gas mode, where both
    flags are False.
    """
    if sp_plot != en_plot:
        return

    for idx in range(n):
        circle = document.getElementById("cir" + str(idx))
        circle.setAttribute("cx", pos[idx, 0])
        circle.setAttribute("cy", pos[idx, 1])
        circle.setAttribute("value", idx)

# Clear the dots    
def clear_dots():
    """Remove every child node (the SVG circles) from the ``#cont`` container."""
    container = document.getElementById("cont")
    while True:
        if not container.hasChildNodes():
            break
        container.removeChild(container.firstChild)

# Plot charts
def plot_charts(t):
    """Build the run summary message and draw the end-of-run charts.

    Called from ``animate`` on the final step; ``t`` itself is not used.
    Energy charts are drawn when ``en_plot`` is set, and spatial charts when
    ``sp_plot`` is set.  Both are drawn when the two flags are equal (gas mode).

    Side effects:
        * sets the module global ``message1`` (displayed by ``show_messages``);
        * closes every open figure before drawing new ones;
        * divides ``d_smicro_sum[:, 1]`` and ``d_emicro_sum[:, 1]`` by
          ``n * tmax`` in place and rebinds ``s_micro_sum`` to
          ``s_micro_sum / tmax``.

    The history arrays ``d_emicro``, ``c_emicro``, ``d_level`` and ``c_level``
    are only read.  The truncated, per-molecule copies used for plotting are
    local, so the globals keep their ``tmax`` rows.  Rebinding them to
    ``tmax - 1`` rows made a later Run without Set fail with IndexError on the
    last step.
    """
    global s_micro_sum, message1

    message1 = 'Molecules: ' + str(n) + '   Molecular Collisions: ' + str(coll) + '   Speed: ' + str(speed) + '   Number of Time Steps: ' + str(tmax) + '   Energy Precision: ' + str(precision) + '   Initial Radius: ' + str(r_init)

    plt.close('all')                # AVW

    d_smicro_sum[:, 1] = d_smicro_sum[:, 1] / (n*tmax)      # Average speed allocation
    d_emicro_sum[:, 1] = d_emicro_sum[:, 1] / (n*tmax)      # Average energy probability
    s_micro_sum = s_micro_sum / tmax                        # Average discrete spatial distribution

    xscale_factor = np.sqrt(precision)       # only used by the disabled charts below

    both_flags_equal = en_plot == sp_plot    # gas mode: draw every chart

    if en_plot or both_flags_equal:

        en_time = time[:tmax-1]
        """        
        # Plot speed
        fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(9, 3))

        prob = np.zeros((n*precision, 2))
        for i in range(n*precision):
            Si = (i + 0.5) / precision
            maxwell = np.exp(-Si**2)            # From Maxwell (1860)
            prob[i, 0] = Si
            prob[i, 1] = Si * maxwell                
        sum_prob = sum(prob[:, 1])
        prob[:, 1] = prob[:, 1]/sum_prob                        
        
        ymax = max(max(d_smicro_sum[:, 1]), max(prob[:, 1]))

        # Plot average speed distribution
        axes[0].bar(d_smicro_sum[:, 0], d_smicro_sum[:, 1], width = .8/precision)   # Average speed distribution (Maxwell)
        axes[0].set_title('Average Speed Distribution n='+str(n)+' c='+str(precision)+' t='+str(tmax), fontsize=10)
        axes[0].set_xlabel('Observed Speed')
        axes[0].set_ylabel('Proportion of Molecules')
        axes[0].set_xlim(0, xscale_factor)       
        axes[0].set_ylim(0, 1.1 * ymax)         
        
        # Plot theoretical speed distribution
        axes[1].bar(prob[:, 0], prob[:, 1], width = .8/precision)                   # Theoretical speed distribution
        axes[1].set_title('Theoretical Speed Distribution n='+str(n)+' c='+str(precision)+' t='+str(tmax)+' 2D', fontsize=10)
        axes[1].set_xlabel('Theoretical Speed')
        axes[1].set_ylabel('Probability')
        axes[1].set_xlim(0, xscale_factor)      
        axes[1].set_ylim(0, 1.1 * ymax)         
        
        # Show speed distributions
        plt.subplots_adjust(left= .075, right= .965, top=.93, bottom=.12)
        plt.show()

        # Plot energy
        fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(9, 3))

        prob = np.zeros((n*precision, 2))
        for i in range(n*precision):
            Ei = (i + 0.5) / precision
            if i > 0:                   # corresponds to Range 0...
                degen = 1               # From Gibbs (1902) 2D
            else:
                degen = .5 
            boltz = np.exp(-2*Ei)      
            prob[i, 0] = Ei
            prob[i, 1] = degen * boltz
        sum_prob = sum(prob[:, 1])
        prob[:, 1] = prob[:, 1]/sum_prob                        

        ymax = max(max(d_emicro_sum[:, 1]), max(prob[:, 1]))

        # Plot average discrete energy distribution
        axes[0].bar(d_emicro_sum[:, 0], d_emicro_sum[:, 1], width = .8/precision)
        axes[0].set_title('Average Energy Distribution n='+str(n)+' c='+str(precision)+' t='+str(tmax), fontsize=10)
        axes[0].set_xlabel('Discrete Energy')
        axes[0].set_ylabel('Proportion of Molecules')
        axes[0].set_xlim(0, xscale_factor)        
        axes[0].set_ylim(0, 1.1 * ymax)         

        # Plot theoretical energy distribution
        axes[1].bar(prob[:, 0], prob[:, 1], width = .8/precision)   
        axes[1].set_title('Theoretical Energy Distribution n='+str(n)+' c='+str(precision)+' t='+str(tmax)+' 1D', fontsize=10)
        axes[1].set_xlabel('Theoretical Energy')
        axes[1].set_ylabel('Probability')
        axes[1].set_xlim(0, xscale_factor)        
        axes[1].set_ylim(0, 1.1 * ymax)         

        # Show energy distribution shapes
        plt.subplots_adjust(left= .075, right= .965, top=.93, bottom=.12)
        plt.show()
        """
        # Plot discrete energy microstate allocation sequence
        fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(9, 3))

        d_emicro_plot = d_emicro[:tmax-1] / n      # local copy; global history untouched
        axes[0].plot(en_time, d_emicro_plot)
        axes[0].set_title('Observed Energy Microstates n='+str(n)+' c='+str(precision)+' t='+str(tmax), fontsize=10)
        axes[0].set_xlabel('Number of Time Steps')
        axes[0].set_ylabel('Proportion of Energy per Molecule')

        # Plot continuous energy microstate allocation sequence
        c_emicro_plot = c_emicro[:tmax-1] / n
        axes[1].plot(en_time, c_emicro_plot)
        axes[1].set_title('Continuous Energy Microstates n='+str(n)+' t='+str(tmax), fontsize=10)
        axes[1].set_xlabel('Number of Time Steps')
        axes[1].set_ylabel('Proportion of Energy per Molecule')

        # Show energy microstate allocation sequence
        plt.subplots_adjust(left= .075, right= .965, top=.93, bottom=.12)
        plt.show()
        """
        # Plot average discrete energy distribution
        fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(9, 3))

        axes[0].bar(d_emicro_sum[:, 0], d_emicro_sum[:, 1], width = .8/precision)
        axes[0].set_title('Average Energy Distribution n='+str(n)+' c='+str(precision)+' t='+str(tmax), fontsize=10)
        axes[0].set_xlabel('Discrete Energy')
        axes[0].set_ylabel('Proportion of Molecules')
        axes[0].set_xlim(0, xscale_factor)        
        axes[0].set_ylim(0, 1.1 * ymax)         

        unique_rows, row_counts = np.unique(c_emicro_t, axis=0, return_counts=True)
        c_emicro_t = np.column_stack((unique_rows, row_counts))

        # Plot continous energy distribution
        axes[1].bar(c_emicro_t[:, 0], c_emicro_t[:, 1]/n, width = .08/precision) # Theoretical energy
        axes[1].set_title('Sample Continuous Energy Distribution n='+str(n)+' t='+str(tmax), fontsize=10)
        axes[1].set_xlabel('Continuous Energy')
        axes[1].set_ylabel('Probability')
        axes[1].set_xlim(0, xscale_factor)        
        axes[1].set_ylim(0, 1.1 * ymax)        

        # Show energy distributions
        plt.subplots_adjust(left= .075, right= .965, top=.93, bottom=.12)
        plt.show()
        """
        # Plot discrete energy levels
        fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(9, 3))

        d_level_plot = d_level[:tmax-1]
        axes[0].plot(en_time, d_level_plot, 'green')
        axes[0].set_title('Observed Energy Levels n='+str(n)+' c='+str(precision)+' t='+str(tmax), fontsize=10)
        axes[0].set_xlabel('Number of Time Steps')
        axes[0].set_ylabel('Number of Observed Energy Levels')
        # The discrete panel shares its y-range with the full continuous level series.
        axes[0].set_ylim(0, max(c_level) + 1)

        # Plot continuous energy levels
        c_level_plot = c_level[:tmax-1]
        axes[1].plot(en_time, c_level_plot, 'green')
        axes[1].set_title('Continuous Energy Levels n='+str(n)+' t='+str(tmax), fontsize=10)
        axes[1].set_xlabel('Number of Time Steps')
        axes[1].set_ylabel('Number of Theoretical Energy Levels')
        axes[1].set_ylim(0, max(c_level_plot) + 1)

        # Show energy levels
        plt.subplots_adjust(left= .075, right= .965, top=.93, bottom=.12)
        plt.show()

    if sp_plot or both_flags_equal:

        # Plot discrete spatial microstate sequence
        fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(9, 3))

        axes[0].plot(time, s_alloc)
        axes[0].set_title('Proportion of Molecules per Band n='+str(n)+' r='+str(r_init), fontsize=10)
        axes[0].set_xlabel('Number of Time Steps')
        axes[0].set_ylabel('Proportion of Molecules per Band ')

        # Plot discrete spatial microstate distribution (last step's band counts)
        row = np.arange(1, imax, dtype=float)      # band numbers 1 .. imax-1
        axes[1].bar(row, s_micro_t, color='brown')
        axes[1].set_title('Occupation Numbers per Band n='+str(n)+' r='+str(r_init), fontsize=10)
        axes[1].set_xlabel('Band Ranked by Radius')
        axes[1].set_ylabel('Occupation Number')
        axes[1].set_ylim(0, r_init*n/200)

        print('Num Spatial Bands ', imax-1)
        # NOTE(review): s_occup is allocated as zeros and never written in this
        # file, so this likely always prints 0. Confirm whether it should be filled.
        print('Max Occupation Num', int(max(s_occup[:, imax-2])))
        print('Max Spatial Alloc ', round(100*max(s_alloc[:, imax-2]), 2), '% of Total')

        # Show discrete spatial microstates
        plt.subplots_adjust(left= .075, right= .965, top=.93, bottom=.12)
        plt.show()

def plot_chart_seq(t):
    """Redraw the live chart for step ``t`` in spatial or energy mode.

    Draws only when a live mode is active (``sp_plot`` or ``en_plot``) and ``t``
    is 1 or a multiple of ``plot_int``.

    * Spatial mode: bar chart of the band occupation numbers in the module
      global ``s_micro_t`` (bands numbered 1 .. imax-1).
    * Energy mode: fraction of molecules at each distinct discrete energy in the
      module global ``d_emicro_t``.

    Both globals are set by ``do_collisions_walls`` for this step and are only
    read here.  The energy histogram is built in locals and no longer
    overwrites ``d_emicro_t`` with a (value, count) table.

    Args:
        t: Integer time-step index.
    """
    if not (sp_plot or en_plot):
        return
    if not (t == 1 or t % plot_int == 0):
        return

    plt.clf()

    if sp_plot:
        # Plot discrete spatial microstate distribution
        bands = np.arange(1, imax, dtype=float)
        plt.bar(bands, s_micro_t, color='brown')
        plt.title('Occupation Numbers per Band n='+str(n)+' r='+str(r_init)+' t='+str(t), fontsize=10)
        plt.xlabel('Band Ranked by Radius')
        plt.ylabel('Occupation Numbers')
        plt.ylim(0, r_init*n/200)
        pause_seconds = .01
    else:
        levels, counts = np.unique(d_emicro_t, axis=0, return_counts=True)
        xscale_factor = np.sqrt(precision)
        ymax = max(counts) / n

        # Plot discrete energy distribution
        plt.bar(levels, counts / n, width = .8/precision, align='edge')
        plt.title('Molecular Energy Distribution n='+str(n)+' c='+str(precision)+' t='+str(t), fontsize=10)
        plt.xlabel('Discrete Energy')
        plt.ylabel('Proportion of Molecules')
        plt.xlim(0, xscale_factor)
        plt.ylim(0, 1.1 * ymax)
        pause_seconds = 1.5

    plt.ion()
    plt.pause(pause_seconds)
    plt.show()
    plt.ioff()

def show_message1(m: str):
    """Write ``m`` as the text of the page element with id ``message1``."""
    document.getElementById("message1").innerText = m

def show_messages():
    """Display the run summary stored in the module global ``message1``.

    ``message1`` is assigned by ``plot_charts``, so this must run after it.
    """
    show_message1(message1)
 

# Function to start/continue the animation 
async def main():
    """Drive the simulation from step 0 to ``tmax - 1`` without blocking the page.

    Each iteration advances one step with ``animate``, refreshes the "Step" and
    "Collision" counters on the page, then yields to the event loop for about a
    millisecond.  When the loop ends, every task in ``asyncio.all_tasks()`` is
    cancelled (this one included), and ``async_flag`` is reset so ``do_run`` can
    schedule another run.

    NOTE(review): ``asyncio.all_tasks()`` is not limited to tasks created by
    this file.  Confirm that cancelling everything on the loop is intended.
    """
    global animating, async_flag

    animating = True
    for frame in range(tmax):
        animate(frame)
        document.getElementById("Step").innerText = str(step)
        document.getElementById("Collision").innerText = str(coll)
        await asyncio.sleep(.001)

    for pending in asyncio.all_tasks():
        pending.cancel()

    async_flag = True
