casbin.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. package service
  2. import (
  3. "context"
  4. "sync"
  5. "github.com/casbin/casbin/v2"
  6. "github.com/casbin/casbin/v2/model"
  7. "github.com/casbin/casbin/v2/persist"
  8. "github.com/gogf/gf/v2/frame/g"
  9. "github.com/tiger1103/gfast/v3/internal/app/common/dao"
  10. "github.com/tiger1103/gfast/v3/internal/app/common/model/entity"
  11. )
  12. type cabinImpl struct{}
  13. type adapterCasbin struct {
  14. Enforcer *casbin.SyncedEnforcer
  15. EnforcerErr error
  16. ctx context.Context
  17. }
  18. var (
  19. cb = cabinImpl{}
  20. once sync.Once
  21. ac *adapterCasbin
  22. )
  23. // CasbinEnforcer 获取adapter单例对象
  24. func CasbinEnforcer(ctx context.Context) (enforcer *casbin.SyncedEnforcer, err error) {
  25. once.Do(func() {
  26. ac = cb.newAdapter(ctx)
  27. })
  28. enforcer = ac.Enforcer
  29. err = ac.EnforcerErr
  30. return
  31. }
  32. // 初始化adapter操作
  33. func (s *cabinImpl) newAdapter(ctx context.Context) (a *adapterCasbin) {
  34. a = new(adapterCasbin)
  35. a.initPolicy(ctx)
  36. a.ctx = ctx
  37. return
  38. }
  39. func (a *adapterCasbin) initPolicy(ctx context.Context) {
  40. // Because the DB is empty at first,
  41. // so we need to load the policy from the file adapter (.CSV) first.
  42. e, err := casbin.NewSyncedEnforcer(g.Cfg().MustGet(ctx, "casbin.modelFile").String(), a)
  43. if err != nil {
  44. a.EnforcerErr = err
  45. return
  46. }
  47. a.Enforcer = e
  48. }
  49. // SavePolicy saves policy to database.
  50. func (a *adapterCasbin) SavePolicy(model model.Model) (err error) {
  51. err = a.dropTable()
  52. if err != nil {
  53. return
  54. }
  55. err = a.createTable()
  56. if err != nil {
  57. return
  58. }
  59. for ptype, ast := range model["p"] {
  60. for _, rule := range ast.Policy {
  61. line := savePolicyLine(ptype, rule)
  62. _, err := dao.CasbinRule.Ctx(a.ctx).Data(line).Insert()
  63. if err != nil {
  64. return err
  65. }
  66. }
  67. }
  68. for ptype, ast := range model["g"] {
  69. for _, rule := range ast.Policy {
  70. line := savePolicyLine(ptype, rule)
  71. _, err := dao.CasbinRule.Ctx(a.ctx).Data(line).Insert()
  72. if err != nil {
  73. return err
  74. }
  75. }
  76. }
  77. return
  78. }
  79. func (a *adapterCasbin) dropTable() (err error) {
  80. return
  81. }
  82. func (a *adapterCasbin) createTable() (err error) {
  83. return
  84. }
  85. // LoadPolicy loads policy from database.
  86. func (a *adapterCasbin) LoadPolicy(model model.Model) error {
  87. var lines []*entity.CasbinRule
  88. if err := dao.CasbinRule.Ctx(a.ctx).Scan(&lines); err != nil {
  89. return err
  90. }
  91. for _, line := range lines {
  92. loadPolicyLine(line, model)
  93. }
  94. return nil
  95. }
  96. // AddPolicy adds a policy rule to the storage.
  97. func (a *adapterCasbin) AddPolicy(sec string, ptype string, rule []string) error {
  98. line := savePolicyLine(ptype, rule)
  99. _, err := dao.CasbinRule.Ctx(a.ctx).Data(line).Insert()
  100. return err
  101. }
  102. // RemovePolicy removes a policy rule from the storage.
  103. func (a *adapterCasbin) RemovePolicy(sec string, ptype string, rule []string) error {
  104. line := savePolicyLine(ptype, rule)
  105. err := rawDelete(a, line)
  106. return err
  107. }
  108. // RemoveFilteredPolicy removes policy rules that match the filter from the storage.
  109. func (a *adapterCasbin) RemoveFilteredPolicy(sec string, ptype string,
  110. fieldIndex int, fieldValues ...string) error {
  111. line := &entity.CasbinRule{}
  112. line.Ptype = ptype
  113. if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) {
  114. line.V0 = fieldValues[0-fieldIndex]
  115. }
  116. if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) {
  117. line.V1 = fieldValues[1-fieldIndex]
  118. }
  119. if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) {
  120. line.V2 = fieldValues[2-fieldIndex]
  121. }
  122. if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) {
  123. line.V3 = fieldValues[3-fieldIndex]
  124. }
  125. if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) {
  126. line.V4 = fieldValues[4-fieldIndex]
  127. }
  128. if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
  129. line.V5 = fieldValues[5-fieldIndex]
  130. }
  131. err := rawDelete(a, line)
  132. return err
  133. }
  134. func loadPolicyLine(line *entity.CasbinRule, model model.Model) {
  135. lineText := line.Ptype
  136. if line.V0 != "" {
  137. lineText += ", " + line.V0
  138. }
  139. if line.V1 != "" {
  140. lineText += ", " + line.V1
  141. }
  142. if line.V2 != "" {
  143. lineText += ", " + line.V2
  144. }
  145. if line.V3 != "" {
  146. lineText += ", " + line.V3
  147. }
  148. if line.V4 != "" {
  149. lineText += ", " + line.V4
  150. }
  151. if line.V5 != "" {
  152. lineText += ", " + line.V5
  153. }
  154. persist.LoadPolicyLine(lineText, model)
  155. }
  156. func savePolicyLine(ptype string, rule []string) *entity.CasbinRule {
  157. line := &entity.CasbinRule{}
  158. line.Ptype = ptype
  159. if len(rule) > 0 {
  160. line.V0 = rule[0]
  161. }
  162. if len(rule) > 1 {
  163. line.V1 = rule[1]
  164. }
  165. if len(rule) > 2 {
  166. line.V2 = rule[2]
  167. }
  168. if len(rule) > 3 {
  169. line.V3 = rule[3]
  170. }
  171. if len(rule) > 4 {
  172. line.V4 = rule[4]
  173. }
  174. if len(rule) > 5 {
  175. line.V5 = rule[5]
  176. }
  177. return line
  178. }
  179. func rawDelete(a *adapterCasbin, line *entity.CasbinRule) error {
  180. db := dao.CasbinRule.Ctx(a.ctx).Where("ptype = ?", line.Ptype)
  181. if line.V0 != "" {
  182. db = db.Where("v0 = ?", line.V0)
  183. }
  184. if line.V1 != "" {
  185. db = db.Where("v1 = ?", line.V1)
  186. }
  187. if line.V2 != "" {
  188. db = db.Where("v2 = ?", line.V2)
  189. }
  190. if line.V3 != "" {
  191. db = db.Where("v3 = ?", line.V3)
  192. }
  193. if line.V4 != "" {
  194. db = db.Where("v4 = ?", line.V4)
  195. }
  196. if line.V5 != "" {
  197. db = db.Where("v5 = ?", line.V5)
  198. }
  199. _, err := db.Delete()
  200. return err
  201. }