adapter.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. package casbin_adapter_service
  2. import (
  3. "fmt"
  4. "gfast/app/model/admin/casbin_rule"
  5. "github.com/casbin/casbin/v2"
  6. "github.com/casbin/casbin/v2/model"
  7. "github.com/casbin/casbin/v2/persist"
  8. "github.com/gogf/gf/frame/g"
  9. "sync"
  10. )
  11. type Adapter struct{}
  12. var Enforcer *casbin.SyncedEnforcer
  13. var EnforcerErr error
  14. var once sync.Once
  15. //获取adapter单例对象
  16. func GetEnforcer() (*casbin.SyncedEnforcer, error) {
  17. once.Do(func() {
  18. _, EnforcerErr = newAdapter()
  19. })
  20. return Enforcer, EnforcerErr
  21. }
  22. //初始化adapter操作
  23. func newAdapter() (a *Adapter, err error) {
  24. a = new(Adapter)
  25. err = a.initPolicy()
  26. return
  27. }
  28. func (a *Adapter) initPolicy() error {
  29. // Because the DB is empty at first,
  30. // so we need to load the policy from the file adapter (.CSV) first.
  31. e, err := casbin.NewSyncedEnforcer(g.Cfg().GetString("casbin.modelFile"),
  32. g.Cfg().GetString("casbin.policyFile"))
  33. if err != nil {
  34. return err
  35. }
  36. // This is a trick to save the current policy to the DB.
  37. // We can't call e.SavePolicy() because the adapter in the enforcer is still the file adapter.
  38. // The current policy means the policy in the Casbin enforcer (aka in memory).
  39. //err = a.SavePolicy(e.GetModel())
  40. //if err != nil {
  41. // return err
  42. //}
  43. //set adapter
  44. e.SetAdapter(a)
  45. // Clear the current policy.
  46. e.ClearPolicy()
  47. Enforcer = e
  48. // Load the policy from DB.
  49. err = a.LoadPolicy(e.GetModel())
  50. if err != nil {
  51. return err
  52. }
  53. return nil
  54. }
  55. // SavePolicy saves policy to database.
  56. func (a *Adapter) SavePolicy(model model.Model) (err error) {
  57. err = a.dropTable()
  58. if err != nil {
  59. return
  60. }
  61. err = a.createTable()
  62. if err != nil {
  63. return
  64. }
  65. for ptype, ast := range model["p"] {
  66. for _, rule := range ast.Policy {
  67. line := savePolicyLine(ptype, rule)
  68. _, err := casbin_rule.Model.Data(&line).Insert()
  69. if err != nil {
  70. return err
  71. }
  72. }
  73. }
  74. for ptype, ast := range model["g"] {
  75. for _, rule := range ast.Policy {
  76. line := savePolicyLine(ptype, rule)
  77. _, err := casbin_rule.Model.Data(&line).Insert()
  78. if err != nil {
  79. return err
  80. }
  81. }
  82. }
  83. return
  84. }
  85. func (a *Adapter) dropTable() (err error) {
  86. _, err = g.DB("default").Exec(fmt.Sprintf("DROP TABLE %s", casbin_rule.Table))
  87. return
  88. }
  89. func (a *Adapter) createTable() (err error) {
  90. _, err = g.DB("default").Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (ptype VARCHAR(10), v0 VARCHAR(256), v1 VARCHAR(256), v2 VARCHAR(256), v3 VARCHAR(256), v4 VARCHAR(256), v5 VARCHAR(256))", casbin_rule.Table))
  91. return
  92. }
  93. // LoadPolicy loads policy from database.
  94. func (a *Adapter) LoadPolicy(model model.Model) error {
  95. var lines []casbin_rule.Entity
  96. if err := casbin_rule.Model.M.Scan(&lines); err != nil {
  97. return err
  98. }
  99. for _, line := range lines {
  100. loadPolicyLine(line, model)
  101. }
  102. return nil
  103. }
  104. // AddPolicy adds a policy rule to the storage.
  105. func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error {
  106. line := savePolicyLine(ptype, rule)
  107. _, err := casbin_rule.Model.M.Data(&line).Insert()
  108. return err
  109. }
  110. // RemovePolicy removes a policy rule from the storage.
  111. func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error {
  112. line := savePolicyLine(ptype, rule)
  113. err := rawDelete(a, line)
  114. return err
  115. }
  116. // RemoveFilteredPolicy removes policy rules that match the filter from the storage.
  117. func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
  118. line := casbin_rule.Entity{}
  119. line.Ptype = ptype
  120. if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) {
  121. line.V0 = fieldValues[0-fieldIndex]
  122. }
  123. if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) {
  124. line.V1 = fieldValues[1-fieldIndex]
  125. }
  126. if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) {
  127. line.V2 = fieldValues[2-fieldIndex]
  128. }
  129. if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) {
  130. line.V3 = fieldValues[3-fieldIndex]
  131. }
  132. if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) {
  133. line.V4 = fieldValues[4-fieldIndex]
  134. }
  135. if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
  136. line.V5 = fieldValues[5-fieldIndex]
  137. }
  138. err := rawDelete(a, line)
  139. return err
  140. }
  141. func loadPolicyLine(line casbin_rule.Entity, model model.Model) {
  142. lineText := line.Ptype
  143. if line.V0 != "" {
  144. lineText += ", " + line.V0
  145. }
  146. if line.V1 != "" {
  147. lineText += ", " + line.V1
  148. }
  149. if line.V2 != "" {
  150. lineText += ", " + line.V2
  151. }
  152. if line.V3 != "" {
  153. lineText += ", " + line.V3
  154. }
  155. if line.V4 != "" {
  156. lineText += ", " + line.V4
  157. }
  158. if line.V5 != "" {
  159. lineText += ", " + line.V5
  160. }
  161. persist.LoadPolicyLine(lineText, model)
  162. }
  163. func savePolicyLine(ptype string, rule []string) casbin_rule.Entity {
  164. line := casbin_rule.Entity{}
  165. line.Ptype = ptype
  166. if len(rule) > 0 {
  167. line.V0 = rule[0]
  168. }
  169. if len(rule) > 1 {
  170. line.V1 = rule[1]
  171. }
  172. if len(rule) > 2 {
  173. line.V2 = rule[2]
  174. }
  175. if len(rule) > 3 {
  176. line.V3 = rule[3]
  177. }
  178. if len(rule) > 4 {
  179. line.V4 = rule[4]
  180. }
  181. if len(rule) > 5 {
  182. line.V5 = rule[5]
  183. }
  184. return line
  185. }
  186. func rawDelete(a *Adapter, line casbin_rule.Entity) error {
  187. db := casbin_rule.Model
  188. db.Where("ptype = ?", line.Ptype)
  189. if line.V0 != "" {
  190. db = db.Where("v0 = ?", line.V0)
  191. }
  192. if line.V1 != "" {
  193. db = db.Where("v1 = ?", line.V1)
  194. }
  195. if line.V2 != "" {
  196. db = db.Where("v2 = ?", line.V2)
  197. }
  198. if line.V3 != "" {
  199. db = db.Where("v3 = ?", line.V3)
  200. }
  201. if line.V4 != "" {
  202. db = db.Where("v4 = ?", line.V4)
  203. }
  204. if line.V5 != "" {
  205. db = db.Where("v5 = ?", line.V5)
  206. }
  207. _, err := db.Delete()
  208. return err
  209. }