# train.py (2.4 KB)
  1. import retvec
  2. import datasets
  3. import os
  4. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' # silence TF INFO messages
  5. import tensorflow as tf
  6. import numpy as np
  7. from tensorflow.keras import layers
  8. from retvec.tf import RETVecTokenizer
  9. NUM_CLASSES = 3
  10. def getData(folder_path):
  11. labels = []
  12. msgs = []
  13. # 遍历文件夹
  14. for root, dirs, files in os.walk(folder_path):
  15. # 遍历当前文件夹下的所有文件
  16. for filename in files:
  17. # 判断是否为csv文件
  18. if filename.endswith(".csv"):
  19. file_path = os.path.join(root, filename)
  20. # 读取csv文件内容
  21. with open(file_path, 'r', errors='ignore') as csv_file:
  22. for line in csv_file:
  23. if line[0] == '' or line[0]==' ':
  24. continue
  25. labels.append([int(str.strip(line[0]))])
  26. msgs.append(line[3:])
  27. return np.array(msgs), np.array(labels)
  28. trainDataMsgs, trainDataLabels = getData("./trainData")
  29. testDataMsgs, testDataLabels = getData("./testData")
  30. # preparing data
  31. x_train = tf.constant(trainDataMsgs, dtype=tf.string)
  32. print(x_train.shape)
  33. y_train = np.zeros((len(x_train),NUM_CLASSES))
  34. for idx, ex in enumerate(trainDataLabels):
  35. for val in ex:
  36. y_train[idx][val] = 1
  37. # test data
  38. x_test = tf.constant(testDataMsgs, dtype=tf.string)
  39. y_test = np.zeros((len(x_test),NUM_CLASSES))
  40. for idx, ex in enumerate(testDataLabels):
  41. for val in ex:
  42. y_test[idx][val] = 1
  43. # using strings directly requires to put a shape of (1,) and dtype tf.string
  44. inputs = layers.Input(shape=(1,), name="token", dtype=tf.string)
  45. # add RETVec tokenizer layer with default settings -- this is all you have to do to build a model with RETVec!
  46. x = RETVecTokenizer(model='retvec-v1')(inputs)
  47. # standard two layer LSTM
  48. x = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(x)
  49. x = layers.Bidirectional(layers.LSTM(64))(x)
  50. outputs = layers.Dense(NUM_CLASSES, activation='sigmoid')(x)
  51. model = tf.keras.Model(inputs, outputs)
  52. model.summary()
  53. # compile and train the model
  54. batch_size = 256
  55. epochs = 2
  56. model.compile('adam', 'binary_crossentropy', ['acc'])
  57. history = model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size,
  58. validation_data=(x_test, y_test))
  59. # saving the model
  60. save_path = './emotion_model/1'
  61. model.save(save_path)