init.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. package db
  2. import (
  3. "fmt"
  4. _ "github.com/go-sql-driver/mysql"
  5. log "github.com/sirupsen/logrus"
  6. _ "modernc.org/sqlite"
  7. "pmail/config"
  8. "pmail/models"
  9. "pmail/utils/context"
  10. "pmail/utils/errors"
  11. "xorm.io/xorm"
  12. )
  13. var Instance *xorm.Engine
  14. func Init(version string) error {
  15. dsn := config.Instance.DbDSN
  16. var err error
  17. switch config.Instance.DbType {
  18. case "mysql":
  19. Instance, err = xorm.NewEngine("mysql", dsn)
  20. Instance.SetMaxOpenConns(100)
  21. Instance.SetMaxIdleConns(10)
  22. case "sqlite":
  23. Instance, err = xorm.NewEngine("sqlite", dsn)
  24. Instance.SetMaxOpenConns(1)
  25. Instance.SetMaxIdleConns(1)
  26. default:
  27. return errors.New("Database Type Error!")
  28. }
  29. if err != nil {
  30. return errors.Wrap(err)
  31. }
  32. Instance.ShowSQL(false)
  33. // 同步表结构
  34. syncTables()
  35. // 更新历史数据
  36. fixHistoryData()
  37. // 在数据库中记录程序版本
  38. var v models.Version
  39. _, err = Instance.Get(&v)
  40. if err != nil {
  41. panic(err)
  42. }
  43. if version != "" && v.Info != version {
  44. v.Info = version
  45. Instance.Update(&v)
  46. }
  47. return nil
  48. }
  49. func WithContext(ctx *context.Context, sql string) string {
  50. if ctx != nil {
  51. logId := ctx.GetValue(context.LogID)
  52. return fmt.Sprintf("/* %s */ %s", logId, sql)
  53. }
  54. return sql
  55. }
  56. func syncTables() {
  57. err := Instance.Sync2(&models.User{})
  58. if err != nil {
  59. panic(err)
  60. }
  61. err = Instance.Sync2(&models.Email{})
  62. if err != nil {
  63. panic(err)
  64. }
  65. err = Instance.Sync2(&models.Group{})
  66. if err != nil {
  67. panic(err)
  68. }
  69. err = Instance.Sync2(&models.Rule{})
  70. if err != nil {
  71. panic(err)
  72. }
  73. err = Instance.Sync2(&models.Sessions{})
  74. if err != nil {
  75. panic(err)
  76. }
  77. err = Instance.Sync2(&models.UserEmail{})
  78. if err != nil {
  79. panic(err)
  80. }
  81. err = Instance.Sync2(&models.Version{})
  82. if err != nil {
  83. panic(err)
  84. }
  85. }
  86. func fixHistoryData() {
  87. var ueNum int
  88. _, err := Instance.Table(&models.UserEmail{}).Select("count(1)").Get(&ueNum)
  89. if err != nil {
  90. panic(err)
  91. }
  92. if ueNum > 0 {
  93. return
  94. }
  95. // 只有一个管理员用户
  96. var user []models.User
  97. err = Instance.Table(&models.User{}).OrderBy("id asc").Find(&user)
  98. if err != nil {
  99. panic(err)
  100. }
  101. // 只有一个账号,且不是管理员账号,将账号提权为管理员
  102. if len(user) == 1 && user[0].IsAdmin == 0 {
  103. u := user[0]
  104. u.IsAdmin = 1
  105. _, err = Instance.Update(&u)
  106. if err != nil {
  107. panic(err)
  108. }
  109. }
  110. if len(user) != 1 {
  111. return
  112. }
  113. // 以前有邮件
  114. var emails []*models.Email
  115. err = Instance.Table(&models.Email{}).Select("id,status").OrderBy("id asc").Find(&emails)
  116. if err != nil {
  117. panic(err)
  118. }
  119. if len(emails) == 0 {
  120. return
  121. }
  122. log.Infof("Sync History Data!Please Wait!")
  123. // 把以前的邮件,全部分到管理员账号下面去
  124. for _, email := range emails {
  125. ue := models.UserEmail{
  126. UserID: user[0].ID,
  127. EmailID: email.Id,
  128. Status: email.Status,
  129. }
  130. _, err = Instance.Insert(&ue)
  131. if err != nil {
  132. log.Errorf("SQL Error: %v", err)
  133. }
  134. }
  135. log.Infof("Sync History Data Finished. Num: %d", len(emails))
  136. }