Last Code Block Doesn't Work #6

Open
felixm3 opened this issue Oct 12, 2023 · 1 comment

felixm3 commented Oct 12, 2023

Hello,

Thank you for sharing the primer.

Everything works except the last code block.

sequence_index = 1999  # You can change this to compute the gradient for a different example. But if so, change the coloring below as well.
sal = compute_salient_bases(model, input_features[sequence_index])

plt.figure(figsize=[16,5])
barlist = plt.bar(np.arange(len(sal)), sal)
[barlist[i].set_color('C1') for i in range(5,17)]  # Change the coloring here if you change the sequence index.
plt.xlabel('Bases')
plt.ylabel('Magnitude of saliency values')
plt.xticks(np.arange(len(sal)), list(sequences[sequence_index]));
plt.title('Saliency map for bases in one of the positive sequences'
          ' (green indicates the actual bases in motif)');

Running this on Google Colab returns the error below:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-10-b6400cc2276d> in <cell line: 2>()
      1 sequence_index = 1999  # You can change this to compute the gradient for a different example. But if so, change the coloring below as well.
----> 2 sal = compute_salient_bases(model, input_features[sequence_index])
      3 
      4 plt.figure(figsize=[16,5])
      5 barlist = plt.bar(np.arange(len(sal)), sal)

<ipython-input-9-9666bf2dbd7f> in compute_salient_bases(model, x)
      3 def compute_salient_bases(model, x):
      4   input_tensors = [model.input]
----> 5   gradients = model.optimizer.get_gradients(model.output[0][1], model.input)
      6   compute_gradients = K.function(inputs = input_tensors, outputs = gradients)
      7 

AttributeError: 'Adam' object has no attribute 'get_gradients'


felixm3 commented Oct 12, 2023

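This looks like a TensorFlow 1.x vs. 2.x issue. The original compute_salient_bases uses the graph-mode pattern model.optimizer.get_gradients plus K.function, and the Adam optimizer in recent TensorFlow 2.x releases has no get_gradients method. In TensorFlow 2.x, gradients with respect to the input are taken with tf.GradientTape instead.

Replacing the last two code blocks with the following works: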


import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

def compute_salient_bases(model, x):
    """Gradient-times-input saliency for one one-hot encoded sequence of shape (length, 4)."""
    # Add a batch dimension and use float32, the dtype Keras layers expect by default.
    x_value = tf.constant(np.expand_dims(x, axis=0).astype('float32'))
    with tf.GradientTape() as tape:
        tape.watch(x_value)  # x_value is a constant, not a Variable, so it must be watched explicitly
        probs = model(x_value, training=False)
        target = probs[:, 1]  # Output for class 1 (binary classification); change the index if needed
    gradients = tape.gradient(target, x_value)
    if gradients is None:  # tape.gradient returns None if target does not depend on x_value
        gradients = tf.zeros_like(x_value)
    # Keep only the gradient at the base actually present at each position,
    # drop the batch dimension, then clip negative contributions to zero.
    sal = tf.reduce_sum(gradients * x_value, axis=2)[0]
    return np.maximum(sal.numpy(), 0)

sequence_index = 1999  # You can change this to compute the gradient for a different example. But if so, change the coloring below as well.
sal = compute_salient_bases(model, input_features[sequence_index])

plt.figure(figsize=[16, 5])
barlist = plt.bar(np.arange(len(sal)), sal)
for i in range(5, 17):  # Change the coloring here if you change the sequence index.
    barlist[i].set_color('green')
plt.xlabel('Bases')
plt.ylabel('Magnitude of saliency values')
plt.xticks(np.arange(len(sal)), list(sequences[sequence_index]))
plt.title('Saliency map for bases in one of the positive sequences'
          ' (green indicates the actual bases in the motif)')
plt.show()
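For one-hot encoded input, the gradient-times-input sum over the last axis in compute_salient_bases simply picks out the gradient at the base actually present at each position. The small NumPy sketch below illustrates this with made-up gradient values; the helper name select_base_gradients and the toy sequence are hypothetical and only stand in for the primer's real data.

```python
import numpy as np

def select_base_gradients(gradients, one_hot):
    """Hypothetical helper: gradient x input, summed over the 4 nucleotide channels.

    gradients: array of shape (1, length, 4), like the output of tape.gradient
    one_hot:   array of shape (length, 4) with exactly one 1 per row
    """
    return np.sum(gradients * one_hot, axis=2)[0]

# Hypothetical toy data: a 5-base sequence and random stand-in gradients.
alphabet = "ACGT"
bases = "ACGTA"
one_hot = np.zeros((len(bases), 4))
one_hot[np.arange(len(bases)), [alphabet.index(b) for b in bases]] = 1.0

rng = np.random.default_rng(0)
gradients = rng.normal(size=(1, len(bases), 4))

sal = select_base_gradients(gradients, one_hot)

# Same result by indexing the gradient of each position's actual base directly.
picked = gradients[0, np.arange(len(bases)), one_hot.argmax(axis=1)]
print(np.allclose(sal, picked))  # True

# Clipping negative contributions, as in compute_salient_bases.
clipped = np.maximum(sal, 0)
```

Because every non-present base is multiplied by 0, only one gradient entry per position survives the sum, which is why the bar chart has exactly one bar per base.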

