
Usage of loss, optimizer, metrics in keras

After building a model architecture with keras, the next step is to compile it. When compiling, you usually need to specify three parameters:

loss

optimizer

metrics

There are two types of options for these three parameters:

Using Strings

Using identifiers, e.g. optimizer instances from keras.optimizers, loss functions from keras.losses, and functions under the keras.metrics package

Example:

sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy',
              optimizer=sgd,
              metrics=['accuracy'])
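
To make the two options concrete, here is a small sketch of both styles side by side (model construction is omitted; the optimizer settings are illustrative, not prescribed by the article):

from keras import losses, metrics
from keras.optimizers import SGD

# 1) string identifiers: Keras looks everything up by name and uses default settings
model.compile(loss='categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'])

# 2) explicit objects/functions: full control over the parameters
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss=losses.categorical_crossentropy,
              optimizer=sgd,
              metrics=[metrics.categorical_accuracy])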

Since sometimes strings and sometimes identifiers can be used, it is interesting to see how this works under the hood. Below we look at how each of the three objects is resolved: optimizer, loss, and metrics.

optimizer

A model has exactly one optimizer, and only one can be specified at compile time.

In keras/optimizers.py, there is a get() function that returns an optimizer instance based on the optimizer parameter passed in by the user:

def get(identifier):
    # If the backend is TensorFlow and the identifier is an optimizer that comes
    # with TensorFlow, the native TF optimizer can be used directly after wrapping
    if K.backend() == 'tensorflow':
        # Wrap TF optimizer instances
        if isinstance(identifier, tf.train.Optimizer):
            return TFOptimizer(identifier)
    # If the optimizer and its parameters are given as a config dict
    if isinstance(identifier, dict):
        return deserialize(identifier)
    elif isinstance(identifier, six.string_types):
        # If the optimizer is specified as a string, the optimizer's default
        # configuration parameters are used
        config = {'class_name': str(identifier), 'config': {}}
        return deserialize(config)
    if isinstance(identifier, Optimizer):
        # Already a Keras Optimizer instance: return it as-is
        return identifier
    else:
        raise ValueError('Could not interpret optimizer identifier: ' +
                         str(identifier))

Here, the deserialize(config) function deserializes the configuration into an optimizer instance.
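
As a quick illustration, the following sketch shows the three identifier forms that this get() function accepts (assuming the standalone Keras 2.x API used throughout this article; the Adam settings are just examples):

from keras import optimizers

opt_from_string = optimizers.get('adam')                   # string: default Adam configuration
opt_from_config = optimizers.get({'class_name': 'Adam',    # dict: deserialized from a config
                                  'config': {'lr': 0.001}})
opt_instance = optimizers.get(optimizers.Adam(lr=0.001))   # Optimizer instance: returned as-is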

loss

The losses module (keras/losses.py) also has a get(identifier) function. One thing worth noting:

If identifier is a callable, it is returned unchanged; that is, a custom loss function (one that takes y_true and y_pred and returns a tensor) can be passed directly. This makes it easy to implement a custom loss. So in addition to str and dict identifiers, we can also use loss functions from the keras.losses package directly.

def get(identifier):
    if identifier is None:
        return None
    if isinstance(identifier, six.string_types):
        identifier = str(identifier)
        return deserialize(identifier)
    if isinstance(identifier, dict):
        return deserialize(identifier)
    elif callable(identifier):
        return identifier
    else:
        raise ValueError('Could not interpret '
                         'loss function identifier:', identifier)
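
Since any callable is returned unchanged, a custom loss can be written as an ordinary function of (y_true, y_pred) that returns a tensor and passed straight to compile(). A minimal sketch (the loss itself is made up for illustration; the model is assumed to be already built):

from keras import backend as K

def mse_plus_l1(y_true, y_pred):
    # per-sample loss tensor, as Keras expects
    return (K.mean(K.square(y_pred - y_true), axis=-1) +
            0.01 * K.mean(K.abs(y_pred), axis=-1))

model.compile(loss=mse_plus_l1, optimizer='sgd')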

metrics

In compile(), both optimizer and loss are singular; only metrics is plural. That is because a model can specify only one optimizer and one loss, but any number of metrics. metrics also has the most complex handling logic of the three.
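
For example, several metrics can be requested at once, and for multi-output models they can even be given per output (a hedged sketch; the output names are hypothetical):

# several metrics for a single-output model
model.compile(loss='categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy', 'crossentropy'])

# per-output metrics for a multi-output model (output names are made up)
# model.compile(loss='categorical_crossentropy', optimizer='sgd',
#               metrics={'main_output': 'accuracy', 'aux_output': 'crossentropy'})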

Deep inside Keras's compile() logic there is the following function that handles metrics. It actually does two things:

Find the function corresponding to a specific metric based on the input metric

Compute the metric tensor

Finding the corresponding metric function covers two cases:

Metrics specified as strings ('accuracy'/'acc' and 'crossentropy'/'ce')

Metrics specified as functions from the keras.metrics package

def handle_metrics(metrics, weights=None):
    metric_name_prefix = 'weighted_' if weights is not None else ''

    for metric in metrics:
        # If the metric is one of the most common kinds: accuracy, cross entropy
        if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
            # custom handling of accuracy/crossentropy
            # (because of class mode duality)
            output_shape = K.int_shape(self.outputs[i])
            # If the output dimension is 1 or the loss function is binary
            # cross-entropy, this is a binary problem, so binary accuracy and
            # binary cross-entropy should be used
            if (output_shape[-1] == 1 or
               self.loss_functions[i] == losses.binary_crossentropy):
                # case: binary accuracy/crossentropy
                if metric in ('accuracy', 'acc'):
                    metric_fn = metrics_module.binary_accuracy
                elif metric in ('crossentropy', 'ce'):
                    metric_fn = metrics_module.binary_crossentropy
            # If the loss function is sparse_categorical_crossentropy, the targets
            # are not one-hot, so the sparse categorical accuracy and sparse
            # categorical cross-entropy must be used
            elif self.loss_functions[i] == losses.sparse_categorical_crossentropy:
                # case: categorical accuracy/crossentropy
                # with sparse targets
                if metric in ('accuracy', 'acc'):
                    metric_fn = metrics_module.sparse_categorical_accuracy
                elif metric in ('crossentropy', 'ce'):
                    metric_fn = metrics_module.sparse_categorical_crossentropy
            else:
                # case: categorical accuracy/crossentropy
                if metric in ('accuracy', 'acc'):
                    metric_fn = metrics_module.categorical_accuracy
                elif metric in ('crossentropy', 'ce'):
                    metric_fn = metrics_module.categorical_crossentropy
            if metric in ('accuracy', 'acc'):
                suffix = 'acc'
            elif metric in ('crossentropy', 'ce'):
                suffix = 'ce'
            weighted_metric_fn = weighted_masked_objective(metric_fn)
            metric_name = metric_name_prefix + suffix
        else:
            # If the metric is not one of those strings, ask the metrics module
            # for the corresponding function
            metric_fn = metrics_module.get(metric)
            weighted_metric_fn = weighted_masked_objective(metric_fn)
            # Get metric name as string
            if hasattr(metric_fn, 'name'):
                metric_name = metric_fn.name
            else:
                metric_name = metric_fn.__name__
            metric_name = metric_name_prefix + metric_name

        with K.name_scope(metric_name):
            metric_result = weighted_metric_fn(y_true, y_pred,
                                               weights=weights,
                                               mask=masks[i])

        # Append to self.metrics_names, self.metric_tensors,
        # self.stateful_metric_names
        if len(self.output_names) > 1:
            metric_name = self.output_names[i] + '_' + metric_name
        # Dedupe name
        j = 1
        base_metric_name = metric_name
        while metric_name in self.metrics_names:
            metric_name = base_metric_name + '_' + str(j)
            j += 1
        self.metrics_names.append(metric_name)
        self.metrics_tensors.append(metric_result)

        # Keep track of state updates created by
        # stateful metrics (i.e. metric layers)
        if isinstance(metric_fn, Layer) and metric_fn.stateful:
            self.stateful_metric_names.append(metric_name)
            self.stateful_metric_functions.append(metric_fn)
            self.metrics_updates += metric_fn.updates

No matter how a metric is specified, it eventually becomes a function under the metrics package. When accuracy or crossentropy is given as a string, keras is quite intelligent about deciding which function under the metrics package to use. This is necessary because the metric functions in that package target different scenarios, for example:

Some expect the targets (y) in one-hot form, others expect non-one-hot (integer-encoded) targets

Some are metrics for binary classification problems, others for multi-class classification problems

When a metric is specified with the strings "accuracy" or "crossentropy", Keras decides which metric function to use based on the loss function and the shape of the output layer. In any case, passing the desired function from the metrics package directly is never wrong.
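
The following sketch illustrates the dispatch described above (the metric function names are the real ones from keras.metrics; the surrounding model is assumed):

from keras import metrics

# With string metrics, the resolution depends on the loss / output shape:
#   binary_crossentropy loss or 1-unit output -> metrics.binary_accuracy
#   sparse_categorical_crossentropy loss      -> metrics.sparse_categorical_accuracy
#   anything else                             -> metrics.categorical_accuracy
model.compile(loss='sparse_categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'])  # resolved to sparse_categorical_accuracy

# Passing the function explicitly skips the guessing entirely:
model.compile(loss='sparse_categorical_crossentropy',
              optimizer='sgd',
              metrics=[metrics.sparse_categorical_accuracy])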

There is also a get(identifier) function in keras/metrics.py for resolving the metric function.

def get(identifier):
    if isinstance(identifier, dict):
        config = {'class_name': str(identifier), 'config': {}}
        return deserialize(config)
    elif isinstance(identifier, six.string_types):
        return deserialize(str(identifier))
    elif callable(identifier):
        return identifier
    else:
        raise ValueError('Could not interpret '
                         'metric function identifier:', identifier)

If the identifier is a string or a dictionary, then a metric function is deserialized based on the identifier.

If the identifier is itself a callable (a function), it is returned directly. This makes it very convenient to define custom metrics.
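
For instance, a custom metric is just a function of (y_true, y_pred) that returns a tensor; here is the classic example from the Keras documentation, sketched with an assumed pre-built model:

from keras import backend as K

def mean_pred(y_true, y_pred):
    # average predicted value, reported alongside the built-in accuracy
    return K.mean(y_pred)

model.compile(loss='binary_crossentropy',
              optimizer='sgd',
              metrics=['accuracy', mean_pred])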

This flexibility reflects how thoughtfully Keras's API is designed.

That is all I have to share about the usage of loss, optimizer, and metrics in keras. I hope it serves as a useful reference, and I appreciate your support.