db.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. package setup
  2. import (
  3. "encoding/json"
  4. "os"
  5. "pmail/config"
  6. "pmail/db"
  7. "pmail/models"
  8. "pmail/utils/array"
  9. "pmail/utils/context"
  10. "pmail/utils/errors"
  11. "pmail/utils/file"
  12. "pmail/utils/password"
  13. )
  14. func GetDatabaseSettings(ctx *context.Context) (string, string, error) {
  15. configData, err := ReadConfig()
  16. if err != nil {
  17. return "", "", errors.Wrap(err)
  18. }
  19. if configData.DbType == "" && configData.DbDSN == "" {
  20. return config.DBTypeSQLite, "./config/pmail.db", nil
  21. }
  22. return configData.DbType, configData.DbDSN, nil
  23. }
  24. func GetAdminPassword(ctx *context.Context) (string, error) {
  25. users := []*models.User{}
  26. err := db.Instance.Select(&users, "select * from user")
  27. if err != nil {
  28. return "", errors.Wrap(err)
  29. }
  30. if len(users) > 0 {
  31. return users[0].Account, nil
  32. }
  33. return "", nil
  34. }
  35. func SetAdminPassword(ctx *context.Context, account, pwd string) error {
  36. encodePwd := password.Encode(pwd)
  37. res, err := db.Instance.Exec(db.WithContext(ctx, "INSERT INTO user (account, name, password) VALUES (?, 'admin',?)"), account, encodePwd)
  38. if err != nil {
  39. return errors.Wrap(err)
  40. }
  41. id, err := res.LastInsertId()
  42. if err != nil {
  43. return errors.Wrap(err)
  44. }
  45. _, err = db.Instance.Exec(db.WithContext(ctx, "INSERT INTO user_auth (user_id, email_account) VALUES (?, '*')"), id)
  46. if err != nil {
  47. return errors.Wrap(err)
  48. }
  49. return nil
  50. }
  51. func SetDatabaseSettings(ctx *context.Context, dbType, dbDSN string) error {
  52. configData, err := ReadConfig()
  53. if err != nil {
  54. return errors.Wrap(err)
  55. }
  56. if !array.InArray(dbType, config.DBTypes) {
  57. return errors.New("dbtype error")
  58. }
  59. if dbDSN == "" {
  60. return errors.New("DSN error")
  61. }
  62. configData.DbType = dbType
  63. configData.DbDSN = dbDSN
  64. err = WriteConfig(configData)
  65. if err != nil {
  66. return errors.Wrap(err)
  67. }
  68. config.Init()
  69. // 检查数据库是否能正确连接
  70. err = db.Init()
  71. if err != nil {
  72. return errors.Wrap(err)
  73. }
  74. return nil
  75. }
  76. func WriteConfig(cfg *config.Config) error {
  77. bytes, _ := json.Marshal(cfg)
  78. err := os.WriteFile("./config/config.json", bytes, 0666)
  79. if err != nil {
  80. return errors.Wrap(err)
  81. }
  82. return nil
  83. }
  84. func ReadConfig() (*config.Config, error) {
  85. configData := config.Config{
  86. DkimPrivateKeyPath: "config/dkim/dkim.priv",
  87. SSLPrivateKeyPath: "config/ssl/private.key",
  88. SSLPublicKeyPath: "config/ssl/public.crt",
  89. }
  90. if !file.PathExist("./config/config.json") {
  91. bytes, _ := json.Marshal(configData)
  92. err := os.WriteFile("./config/config.json", bytes, 0666)
  93. if err != nil {
  94. return nil, errors.Wrap(err)
  95. }
  96. } else {
  97. cfgData, err := os.ReadFile("./config/config.json")
  98. if err != nil {
  99. return nil, errors.Wrap(err)
  100. }
  101. err = json.Unmarshal(cfgData, &configData)
  102. if err != nil {
  103. return nil, errors.Wrap(err)
  104. }
  105. }
  106. return &configData, nil
  107. }