# -*- coding: utf-8 -*-
"""
Created on Tue Jun 28 16:45:06 2022

@author: timof

Minimum-information-partition (MIP) search for a two-layer binary grid
automaton with wrap-around (toroidal) neighbourhoods.

A system state is an integer whose binary expansion, most significant bit
first, lists the cell values: the first half of the bits is layer 0 and the
second half is layer 1, each read row-major into a ``dim x dim`` grid.
"""
import math as m

import numpy as np

try:
    from numba import njit
except ImportError:  # numba only accelerates the code; fall back to plain Python
    def njit(func):
        """Identity stand-in for ``numba.njit`` when numba is not installed."""
        return func

log2 = np.log2


@njit
def digit(num, pos):
    """Return bit ``pos`` of ``num``, counting from 1 at the least significant bit."""
    return (num % (2**pos)) // 2**(pos - 1)


@njit
def binstr(num, size):
    """Return the low ``size`` bits of ``num`` as a bool array, most significant bit first."""
    return np.array([digit(num, pos) for pos in range(size, 0, -1)], dtype=np.bool_)


@njit
def binint(state):
    """Inverse of ``binstr``: read a bit array (most significant bit first) as an int.

    Integer (Horner) accumulation keeps the result exact for any width; float
    powers of two would start rounding beyond 53 bits.
    """
    value = 0
    for bit in state:
        value = 2 * value + int(bit)
    return value


@njit
def binpart(state, num, part, part_comp):
    """Overwrite the ``part_comp`` positions of ``state`` with the bits of ``num``.

    Parameters
    ----------
    state : bit array (most significant bit first), e.g. from ``binstr``.
    num : int whose ``len(part_comp)`` low bits are written, its most
        significant bit going to position ``part_comp[0]``.
    part : positions that keep their value from ``state``. A position listed
        in both ``part`` and ``part_comp`` keeps its ``state`` value, as does
        a position listed in neither.
    part_comp : positions to overwrite.

    Returns
    -------
    int
        The resulting bit pattern as an integer. ``state`` itself is not modified.
    """
    num_str = binstr(num, len(part_comp))
    num_state = state.copy()
    for numpos, pos in enumerate(part_comp):
        if pos not in part:
            num_state[pos] = num_str[numpos]
    return binint(num_state)


@njit
def _neighbour_sum(grid, row, col):
    """Count the live up/down/left/right neighbours of ``(row, col)``, wrapping at the edges."""
    dim = grid.shape[0]
    return (int(grid[(row - 1) % dim, col]) + int(grid[(row + 1) % dim, col])
            + int(grid[row, (col - 1) % dim]) + int(grid[row, (col + 1) % dim]))


@njit
def evolve(state, lenght):
    """Advance the two-layer automaton by one step.

    Parameters
    ----------
    state : int encoding both layers (see module docstring).
    lenght : total number of bits; must equal ``2 * dim**2``. (The spelling
        is kept for compatibility with existing callers.)

    Returns
    -------
    int
        The successor state, encoded the same way.

    Raises
    ------
    ValueError
        If ``lenght`` is not twice a perfect square.

    Update rule (both layers read only the old state):
    - a layer-0 cell becomes 1 iff (live layer-0 neighbours) + (1 - its layer-1
      partner) >= 3;
    - a layer-1 cell becomes 1 iff (live layer-1 neighbours) + (its layer-0
      partner) >= 3.
    """
    layer = lenght // 2
    dim = int(m.sqrt(layer))
    if lenght != 2 * dim**2:
        raise ValueError("lenght must equal 2 * dim**2 for an integer dim")
    bits = binstr(state, lenght)
    lay0 = bits[:layer].reshape((dim, dim))
    lay1 = bits[layer:].reshape((dim, dim))
    new_lay0 = lay0.copy()
    new_lay1 = lay1.copy()
    for row in range(dim):
        for col in range(dim):
            alive0 = _neighbour_sum(lay0, row, col) + 1 - int(lay1[row, col])
            new_lay0[row, col] = alive0 >= 3
            alive1 = _neighbour_sum(lay1, row, col) + int(lay0[row, col])
            new_lay1[row, col] = alive1 >= 3
    new_state = np.hstack((np.reshape(new_lay0, layer), np.reshape(new_lay1, layer)))
    return binint(new_state)


@njit
def update(dim):
    """Return the transition table of the ``dim x dim`` two-layer automaton.

    Entry ``phi`` is ``evolve(phi, 2 * dim**2)``; the table has
    ``2**(2 * dim**2)`` entries.
    """
    size = 2 * dim**2
    sig = 2**size
    return np.array([evolve(phi, size) for phi in range(sig)])


@njit
def _part_entropy(transition, state, fixed, free, lenght):
    """Entropy (bits) of the next values of ``fixed`` when ``free`` is randomised.

    ``fixed`` positions keep their value from the bit array ``state``; every
    assignment of the ``free`` positions is tried once with equal weight, and
    the successor (via ``transition``) is projected onto ``fixed``.
    """
    outcomes = []
    for num in range(2**len(free)):
        successor = transition[binpart(state, num, fixed, free)]
        successor_bits = binstr(successor, lenght)
        projected = np.array([bit for pos, bit in enumerate(successor_bits) if pos in fixed])
        outcomes.append(binint(projected))
    # TODO: a faster way to obtain these probabilities (e.g. skipping the
    # projection step) may exist.
    total = len(outcomes)
    entropy = 0.0
    for outcome in set(outcomes):
        p = outcomes.count(outcome) / total
        entropy -= p * np.log2(p)
    return entropy


@njit
def ei(transition, state, partition, lenght):
    """Return the summed part entropies of ``state`` under a bipartition.

    Parameters
    ----------
    transition : table mapping a state int to its successor int (e.g. ``update``).
    state : int, the current state.
    partition : int bitmask over the ``lenght`` state bits (same MSB-first
        order); bits set to 1 form one part, bits set to 0 the other.
    lenght : number of bits in a state.

    Returns
    -------
    float
        For each part: the entropy of that part's next value when the part is
        held at its current value and the other part takes every value with
        equal weight; the two entropies are summed.
    """
    mask = binstr(partition, lenght)
    bits = binstr(state, lenght)
    part0 = [pos for pos in range(lenght) if not mask[pos]]
    part1 = [pos for pos in range(lenght) if mask[pos]]
    return (_part_entropy(transition, bits, part0, part1, lenght)
            + _part_entropy(transition, bits, part1, part0, lenght))


@njit
def MIP(Transition, State, lenght):
    """Return ``(smallest ei, partition)`` over the non-trivial bipartitions.

    Only partitions ``1 .. 2**(lenght-1) - 1`` are tried: keeping the most
    significant bit at 0 puts position 0 in the same part every time, so each
    unordered bipartition is evaluated once (``ei`` is symmetric in its two
    parts), and partition 0 (an empty part) is skipped. Returns
    ``(inf, -1)`` when there is nothing to try (``lenght == 1``).
    """
    best_ei = np.inf
    best_part = -1
    for part in range(1, 2**(lenght - 1)):
        effinf = ei(Transition, State, part, lenght)
        if effinf < best_ei:
            best_ei = effinf
            best_part = part
    return (best_ei, best_part)


def print_MIPs(Transition, lenght):
    """Print the MIP of every state and return a summary of the largest one.

    Parameters
    ----------
    Transition : transition table with exactly ``2**lenght`` entries.
    lenght : number of bits in a state.

    Returns
    -------
    str
        ``"MIP(state): (ei, partition)"`` for the state whose MIP has the
        largest ``ei``; ``"MIP(-1): (0, 0)"`` if no state exceeds 0.

    Raises
    ------
    ValueError
        If the table size does not match ``lenght``; analysing a table with
        the wrong bit width silently truncates states.
    """
    if len(Transition) != 2**lenght:
        raise ValueError(
            "Transition has {} entries but lenght={} implies {}".format(
                len(Transition), lenght, 2**lenght))
    candidate = ((0, 0), -1)
    for state in range(len(Transition)):
        MIP_c = MIP(Transition, state, lenght)
        if MIP_c[0] > candidate[0][0]:
            candidate = (MIP_c, state)
        print("MIP({}): {}".format(state, MIP_c))
    return "MIP({}): {}".format(candidate[1], candidate[0])


if __name__ == "__main__":
    dim = 3
    update3 = update(dim)
    # The bit width must match the table: 2 * dim**2 bits (18 for dim=3).
    # This is very expensive (2**18 states x 2**17 bipartitions each);
    # use dim = 2 (8 bits) for a quick run.
    print(print_MIPs(update3, 2 * dim**2))