# test.py (872 B)

  1. import os
  2. import numpy as np
  3. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' # silence TF INFO messages
  4. import tensorflow as tf
  5. save_path = './emotion_model/1'
  6. model = tf.keras.models.load_model(save_path, compile=False)
  7. model.summary()
  8. CLASSES = {
  9. 0:'普通邮件',
  10. 1:'广告邮件',
  11. 2:'诈骗邮件'
  12. }
  13. def predict_emotions(txt):
  14. # recall it is multi-class so we need to get all prediction above a threshold (0.5)
  15. input = tf.constant( np.array([txt]) , dtype=tf.string )
  16. preds = model(input)[0]
  17. maxClass = -1
  18. maxScore = 0
  19. for idx in range(3):
  20. if preds[idx] > maxScore:
  21. maxScore = preds[idx]
  22. maxClass = idx
  23. return maxClass
  24. maxClass = predict_emotions("各位同事请注意 这里是110,请大家立刻把银行卡账号密码回复发给我!")
  25. print("这个邮件属于:",CLASSES[maxClass])