ProblemPython

Python-Programme
Anonymous
 Problem

Post by Anonymous »

Ich versuche derzeit, ein Customproblem aus der BaseProblem -Klasse in Tensorneat zu erstellen, die eine JAX -basierte Bibliothek ist. Bei dem Versuch, die Funktion dieser Klasse zu implementieren, verwende ich eine booleale Maske, aber ich habe Probleme, sie zum Laufen zu bringen. Mein Code führt zu Jax.Eror.nonconcretbooleanIndexError: Array Boolesche Indizes müssen konkret sein; Got Shapedarray (Bool [N, N]) , was meiner Meinung nach auf einige meiner Arrays zurückzuführen ist, die keine bestimmte Form haben. Wie umgab ich das?

Code: Select all

import numpy as np

ran_int = np.random.randint(1, 5, size=(2, 2))
print(ran_int)

ran_bool = np.random.randint(0,2, size=(2,2), dtype=bool)
print(ran_bool)

a = (ran_int[ran_bool]>0).astype(int)
print(a)
< /code>
Es könnte eine Ausgabe wie folgt ergeben: < /p>
[[2 2]
[3 4]]
[[ True False]
[ True  True]]
[1 1 1] #Is 1D and has less elements than before boolean mask was applied!
Aber in JAX führt die gleiche Denkweise in den NonconcreteBooleIndexError Fehler, den ich bekommen habe. überschreiben ">

Code: Select all

#NB! len(labels) = len(inputs) = n
def evaluate(self, state, randkey, act_func, params):
# do batch forward for all inputs (using jax.vamp).
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state, params, self.inputs
)  # should be shape (n, 1)

#calculating pairwise labels and predictions
pairwise_labels = self.labels - self.labels.T # shape (n, n)
pairwise_predictions = predict - predict.T  # shape (n, n)

#finding which pairs to keep
pairs_to_keep = jnp.abs(pairwise_labels) > self.threshold
print(pairs_to_keep.shape) #this prints (n, n)

pairwise_labels = pairwise_labels[pairs_to_keep] #ERROR HAPPENS HERE
pairwise_labels = jnp.where(pairwise_labels > 0, True, False)
print(pairwise_labels.shape) #want this to print a 1D array that potentially has less elements than n*n depending on the boolean mask

pairwise_predictions = pairwise_predictions[pairs_to_keep] #WOULD HAPPEN HERE TOO IF THIS PART WAS FIRST
pairwise_predictions = jax.nn.sigmoid(pairwise_predictions)
print(pairwise_predictions.shape) #want this to print a 1D array that potentially has less elements than n*n depending on the boolean mask

# calculate loss
loss = binary_cross_entropy(pairwise_predictions, pairwise_labels)  # shape (n)

# reduce loss to a scalar
loss = jnp.mean(loss)

# return negative loss as fitness
# TensorNEAT maximizes fitness, equivalent to minimizing loss
return -loss
Ich habe in Betracht gezogen, Jnp.where zu lösen, um das Problem zu lösen, aber die resultierenden paarweisen Pearwise_Labels und paarweise_Predictions haben eine andere Form als ich erwarte (erwarte (ich erwarte (ich erwarte nämlich (n, n) ), wie im folgenden Code zu sehen ist:

Code: Select all

#NB! len(labels) = len(inputs) = n
def evaluate(self, state, randkey, act_func, params):
# do batch forward for all inputs (using jax.vamp).
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state, params, self.inputs
)  # should be shape (n, 1)

#calculating pairwise labels and predictions
pairwise_labels = self.labels - self.labels.T # shape (n, n)
pairwise_predictions = predict - predict.T  # shape (n, n)

#finding which pairs to keep
pairs_to_keep = jnp.abs(pairwise_labels) > self.threshold
print(pairs_to_keep.shape) #this prints (n, n)

pairwise_labels = jnp.where(pairs_to_keep, pairwise_labels, -jnp.inf) #one problem is that now I have -inf instead of discarding the element entirely
pairwise_labels = jnp.where(pairwise_labels > 0, True, False)
print(pairwise_labels.shape) # shape (n, n)

pairwise_predictions = jnp.where(pairs_to_keep, pairwise_predictions, -jnp.inf) #one problem is that now I have -inf instead of discarding the element entirely
pairwise_predictions = jax.nn.sigmoid(pairwise_predictions)
print(pairwise_predictions.shape) # shape (n, n)

# calculate loss
loss = binary_cross_entropy(pairwise_predictions, pairwise_labels)  # shape (n ,n)

# reduce loss to a scalar
loss = jnp.mean(loss)

# return negative loss as fitness
# TensorNEAT maximizes fitness, equivalent to minimizing loss
return -loss
Ich befürchte, dass die unterschiedlichen Formen von paarwiell_Predictions und paarwiell_labels nach der Verwendung von JNP.where zu einem anderen Verlust führen, als wenn ich gerade verwendet hätte, Die Boolesche Maske wie ich in NP . Es gibt auch die Tatsache, dass ich einen weiteren Fehler erhalte, der später in der Pipeline mit dem Ausgaberteil ValueError: max () iterable Argument auftritt >. This is curiously circumvented by changing pairs_to_keep = jnp.abs(pairwise_labels) > self.threshold to pairs_to_keep = jnp.abs(pairwise_labels - pairwise_predictions) > self.threshold, which probably also results in some loss that ist falsch. Py PrettyPrint-Override ">

Code: Select all

from tensorneat import algorithm, genome, common
from tensorneat.pipeline import Pipeline
from tensorneat.genome.gene.node import DefaultNode
from tensorneat.genome.gene.conn import DefaultConn
from tensorneat.genome.operations import mutation
import jax, jax.numpy as jnp
from tensorneat.problem import BaseProblem

def binary_cross_entropy(prediction, target):
return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction))

# Define the custom Problem
class CustomProblem(BaseProblem):

jitable = True  # necessary

def __init__(self, inputs, labels, threshold):
self.inputs = jnp.array(inputs) #nb! already has shape (n, 768)
self.labels = jnp.array(labels).reshape((-1,1)) #nb! has shape (n), must be transformed to have shape (n, 1)
self.threshold = threshold

def evaluate(self, state, randkey, act_func, params):
# do batch forward for all inputs (using jax.vamp).
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state, params, self.inputs
)  # should be shape (len(labels), 1)

#calculating pairwise labels and predictions
pairwise_labels = self.labels - self.labels.T # shape (len(labels), len(labels))
pairwise_predictions = predict - predict.T  # shape (len(inputs), len(inputs))

#finding which pairs to keep
pairs_to_keep = jnp.abs(pairwise_labels) > self.threshold #this is the thing I actually want
#pairs_to_keep = jnp.abs(pairwise_labels - pairwise_predictions) > self.threshold #weird fix to circumvent ValueError: max() iterable argument is empty when using jnp.where for pairwise_labels and pairwise_predictions
print(pairs_to_keep.shape)

pairwise_labels = pairwise_labels[pairs_to_keep] #normal boolean mask that doesnt work
#pairwise_labels = jnp.where(pairs_to_keep, pairwise_labels, -jnp.inf) #using jnp.where to circumvent NonConcreteBooleanIndexError, but gives different shape than I want
pairwise_labels = jnp.where(pairwise_labels > 0, True, False)
print(pairwise_labels.shape)

pairwise_predictions = pairwise_predictions[pairs_to_keep] #normal boolean mask that doesnt work
#pairwise_predictions = jnp.where(pairs_to_keep, pairwise_predictions, -jnp.inf) #using jnp.where to circumvent NonConcreteBooleanIndexError, but gives different shape than I want
pairwise_predictions = jax.nn.sigmoid(pairwise_predictions)
print(pairwise_predictions.shape)

# calculate loss
loss = binary_cross_entropy(pairwise_predictions, pairwise_labels)  # shape (len(labels), len(labels))

# reduce loss to a scalar
loss = jnp.mean(loss)

# return negative loss as fitness
# TensorNEAT maximizes fitness, equivalent to minimizing loss
return -loss

@property
def input_shape(self):
# the input shape that the act_func expects
return (self.inputs.shape[1],)

@property
def output_shape(self):
# the output shape that the act_func returns
return (1,)

def show(self, state, randkey, act_func, params, *args, **kwargs):
# showcase the performance of one individual
predict = jax.vmap(act_func, in_axes=(None, None, 0))(state, params, self.inputs)

loss = jnp.mean(jnp.square(predict - self.labels))

n_elements = 5
if n_elements >  len(self.inputs):
n_elements = len(self.inputs)

msg = f"Looking at {n_elements} first elements of input\n"
for i in range(n_elements):
msg += f"for input i: {i}, target: {self.labels[i]}, predict: {predict[i]}\n"
msg += f"total loss: {loss}\n"
print(msg)

algorithm = algorithm.NEAT(
pop_size=10,
survival_threshold=0.2,
min_species_size=2,
compatibility_threshold=3.0,
species_elitism=2,
genome=genome.DefaultGenome(
num_inputs=768,
num_outputs=1,
max_nodes=769,  # must at least be same as inputs and outputs
max_conns=768,  # must be 768 connections for the network to be fully connected
output_transform=common.ACT.sigmoid,
mutation=mutation.DefaultMutation(
# no allowing adding or deleting nodes
node_add=0.0,
node_delete=0.0,
# set mutation rates for edges to 0.5
conn_add=0.5,
conn_delete=0.5,
),
node_gene=DefaultNode(),
conn_gene=DefaultConn(),
),
)

INPUTS = jax.random.uniform(jax.random.PRNGKey(0), (100, 768)) #the input data x
LABELS = jax.random.uniform(jax.random.PRNGKey(0), (100)) #the annotated labels y

problem = CustomProblem(INPUTS, LABELS, 0.25)

print("Setting up pipeline and running it")
print("-----------------------------------------------------------------------")
pipeline = Pipeline(
algorithm,
problem,
generation_limit=1,
fitness_target=1,
seed=42,
)

state = pipeline.setup()
# run until termination
state, best = pipeline.auto_run(state)
# show results
pipeline.show(state, best)

Quick Reply

Change Text Case: 
   
  • Similar Topics
    Replies
    Views
    Last post