Batch en Ejercicio LSTM

Batch en Ejercicio LSTM

de Ariel Michel Malowany Grossman -
Número de respuestas: 7

Hola, cómo están?

En la función get_batches generé una lista que contiene 28 tensores, uno por batch. 

No logro evaluar la red por batch, cómo puedo hacer para que la red entienda que el input es un batch?

Por ejemplo, el primer batch que debería pasar es train_word_batches[0] que tiene shape (64, 6).

Estuve viendo que una solución puede ser agregar una dimensión más, de forma que quede (BATCH_SIZE, SEQ_LENGTH, NUM_FEATURES), pero cuando hago eso me dice que el input debe ser 2D o 3D y no 4D.

Alguno sabe cómo resolverlo?

En respuesta a Ariel Michel Malowany Grossman

Re: Batch en Ejercicio LSTM

de Luis Chiruzzo -
Hola,
¿Le pusiste batch_first=True a la LSTM? Fijate si no es por eso.
Saludos,
Luis
En respuesta a Luis Chiruzzo

Re: Batch en Ejercicio LSTM

de Ariel Michel Malowany Grossman -

Si! Ya tengo los datos pasando por la red, el problema que tengo es la baja accuracy, me da 30% en el mejor de los casos. Hice un poco de debugging y me di cuenta que el problema está en que el clasificador solamente clasifica 3 clases. Para llegar a eso utilicé el argumento None en metrics.MulticlassAccuracy. 

Por ejemplo el accuracy por clase en test es: tensor([0.8898, nan, nan, 0.0193, nan, nan, nan, nan, nan, nan, nan, nan, nan, 0.9860, nan, nan, nan]). Las clases que detecta son None, ADV, PUNCT. 

Saben a qué se puede deber este comportamiento en el clasificador? 

En respuesta a Ariel Michel Malowany Grossman

Re: Batch en Ejercicio LSTM

de Luis Chiruzzo -
Hola,

Contanos cómo es la arquitectura de la red, cantidad de capas, tamaño de capa, etc. ¿Durante cuántas épocas la entrenaste? ¿Usaste early stopping?

Saludos,
Luis
En respuesta a Luis Chiruzzo

Re: Batch en Ejercicio LSTM

de Ariel Michel Malowany Grossman -
Hola!

La implementé con una sola capa oculta con 100 unidades. Entrené con learning rate de 0.1, en 10 epochs. No usé early stopping.

Saludos

Ariel
En respuesta a Ariel Michel Malowany Grossman

Re: Batch en Ejercicio LSTM

de Mathias Etcheverry -
Hola,
¿ese accuracy es para un batch o para todo el conjunto? Sí es para un batch podría ser que algunas clases no aparezcan en ese batch y por eso el NAN.
Igual llama la atención que las clases que sí aparecen sean solamente esas (None, ADV y PUNCT).

Por otro lado, probaría con un learning rate mas chico (ej. 0.01 o 0.001). Además puede ser util considerar gradient clipping por tratarse de una RNN.

¿Estás usando SGD? Si es así probaría con Adam.

Además puede ser útil mirar la gráfica del loss (y acc) en train, para tener más idea de como está aprendiendo el modelo en el transcurso de las epocas.

Saludos,
Mathias
En respuesta a Mathias Etcheverry

Re: Batch en Ejercicio LSTM

de Ariel Michel Malowany Grossman -
Hola!

Gracias a ambos, ya entendí cuales eran los problemas de mi implementación. Logré tener 84.24% en test.

En relación al optimizador, la performance con Adam y learning rate = 0.001 es muy buena.

En cuanto al accuracy, había mucho ruido por problemas de implementación. Estaba calculando accuracy por batch por ejemplo, al pedir accuracy por clase traía NAN cuando la clase no aparecía, y al sumar NAN el resultado acumulado en las iteraciones era NAN, las únicas clases presentes en todos los batches eran None, ADV, PUNCT.

Otra cosa que me llamó la atención de metrics.MulticlassAccuracy es la modalidad macro vs micro, por lo que investigué muy por arriba micro es mejor cuando hay desbalanceo de clases, que es lo que sucede en este dataset.

Saludos

Ariel
En respuesta a Ariel Michel Malowany Grossman

Re: Batch en Ejercicio LSTM

de Mathias Etcheverry -
Buenísimo!
Gracias por la explicación de los accuracy NAN. Está muy clara y le puede ser útil para otros grupos.
Respecto a micro vs macro, en macro el resultado de cada clase influye por igual independientemente de la cantidad de instancias (es el promedio entre clases). Micro es a nivel de instancia y se tiene que micro-precision=micro-recall=accuracy. Además siempre es útil tener en cuenta precision y recall por clase.
Saludos