# train.py (2.3 KB)
  1. import retvec
  2. import datasets
  3. import os
  4. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' # silence TF INFO messages
  5. import tensorflow as tf
  6. import numpy as np
  7. from tensorflow.keras import layers
  8. from retvec.tf import RETVecTokenizer
  9. NUM_CLASSES = 3
  10. def getData(folder_path):
  11. labels = []
  12. msgs = []
  13. # 遍历文件夹
  14. for root, dirs, files in os.walk(folder_path):
  15. # 遍历当前文件夹下的所有文件
  16. for filename in files:
  17. # 判断是否为csv文件
  18. if filename.endswith(".csv"):
  19. file_path = os.path.join(root, filename)
  20. # 读取csv文件内容
  21. with open(file_path, 'r', errors='ignore') as csv_file:
  22. for line in csv_file:
  23. labels.append([int(str.strip(line[0]))])
  24. msgs.append(line[3:])
  25. return np.array(msgs), np.array(labels)
  26. trainDataMsgs, trainDataLabels = getData("./trainData")
  27. testDataMsgs, testDataLabels = getData("./testData")
  28. # preparing data
  29. x_train = tf.constant(trainDataMsgs, dtype=tf.string)
  30. print(x_train.shape)
  31. y_train = np.zeros((len(x_train),NUM_CLASSES))
  32. for idx, ex in enumerate(trainDataLabels):
  33. for val in ex:
  34. y_train[idx][val] = 1
  35. # test data
  36. x_test = tf.constant(testDataMsgs, dtype=tf.string)
  37. y_test = np.zeros((len(x_test),NUM_CLASSES))
  38. for idx, ex in enumerate(testDataLabels):
  39. for val in ex:
  40. y_test[idx][val] = 1
  41. # using strings directly requires to put a shape of (1,) and dtype tf.string
  42. inputs = layers.Input(shape=(1,), name="token", dtype=tf.string)
  43. # add RETVec tokenizer layer with default settings -- this is all you have to do to build a model with RETVec!
  44. x = RETVecTokenizer(model='retvec-v1')(inputs)
  45. # standard two layer LSTM
  46. x = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(x)
  47. x = layers.Bidirectional(layers.LSTM(64))(x)
  48. outputs = layers.Dense(NUM_CLASSES, activation='sigmoid')(x)
  49. model = tf.keras.Model(inputs, outputs)
  50. model.summary()
  51. # compile and train the model
  52. batch_size = 256
  53. epochs = 2
  54. model.compile('adam', 'binary_crossentropy', ['acc'])
  55. history = model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size,
  56. validation_data=(x_test, y_test))
  57. # saving the model
  58. save_path = './emotion_model/1'
  59. model.save(save_path)