casbin_adapter.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. package adapterUtils
  import (
  	"database/sql"
  	"errors"
  	"fmt"
  	"runtime"

  	"github.com/casbin/casbin/v2/model"
  	"github.com/casbin/casbin/v2/persist"
  	"github.com/gogf/gf/database/gdb"
  )
  // CasbinRule is one row of the policy table: a policy type (SavePolicy
  // writes the "p" and "g" sections) followed by up to six rule fields.
  // Each json tag matches the column name that createTable declares.
  type CasbinRule struct {
  	PType string `json:"ptype"` // policy type, e.g. "p" or "g"
  	V0    string `json:"v0"`    // rule field 0; "" when unused
  	V1    string `json:"v1"`    // rule field 1; "" when unused
  	V2    string `json:"v2"`    // rule field 2; "" when unused
  	V3    string `json:"v3"`    // rule field 3; "" when unused
  	V4    string `json:"v4"`    // rule field 4; "" when unused
  	V5    string `json:"v5"`    // rule field 5; "" when unused
  }
  18. // Adapter represents the gdb adapter for policy storage.
  19. type Adapter struct {
  20. DriverName string
  21. DataSourceName string
  22. TableName string
  23. Db gdb.DB
  24. }
  25. // finalizer is the destructor for Adapter.
  26. func finalizer(a *Adapter) {
  27. // 注意不用的时候不需要使用Close方法关闭数据库连接(并且gdb也没有提供Close方法),
  28. // 数据库引擎底层采用了链接池设计,当链接不再使用时会自动关闭
  29. a.Db = nil
  30. }
  31. // NewAdapter is the constructor for Adapter.
  32. func NewAdapter(driverName string, dataSourceName string) (*Adapter, error) {
  33. a := &Adapter{}
  34. a.DriverName = driverName
  35. a.DataSourceName = dataSourceName
  36. a.TableName = "casbin_rule"
  37. // Open the DB, create it if not existed.
  38. err := a.open()
  39. if err != nil {
  40. return nil, err
  41. }
  42. // Call the destructor when the object is released.
  43. runtime.SetFinalizer(a, finalizer)
  44. return a, nil
  45. }
  46. // NewAdapterFromOptions is the constructor for Adapter with existed connection
  47. func NewAdapterFromOptions(adapter *Adapter) (*Adapter, error) {
  48. if adapter.TableName == "" {
  49. adapter.TableName = "casbin_rule"
  50. }
  51. if adapter.Db == nil {
  52. err := adapter.open()
  53. if err != nil {
  54. return nil, err
  55. }
  56. runtime.SetFinalizer(adapter, finalizer)
  57. }
  58. return adapter, nil
  59. }
  60. func (a *Adapter) open() error {
  61. var err error
  62. var db gdb.DB
  63. gdb.SetConfig(gdb.Config{
  64. "casbin": gdb.ConfigGroup{
  65. gdb.ConfigNode{
  66. Type: a.DriverName,
  67. LinkInfo: a.DataSourceName,
  68. Role: "master",
  69. Weight: 100,
  70. },
  71. },
  72. })
  73. db, err = gdb.New("casbin")
  74. if err != nil {
  75. return err
  76. }
  77. a.Db = db
  78. return a.createTable()
  79. }
  80. func (a *Adapter) close() error {
  81. // 注意不用的时候不需要使用Close方法关闭数据库连接(并且gdb也没有提供Close方法),
  82. // 数据库引擎底层采用了链接池设计,当链接不再使用时会自动关闭
  83. a.Db = nil
  84. return nil
  85. }
  86. func (a *Adapter) createTable() error {
  87. _, err := a.Db.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (ptype VARCHAR(10), v0 VARCHAR(256), v1 VARCHAR(256), v2 VARCHAR(256), v3 VARCHAR(256), v4 VARCHAR(256), v5 VARCHAR(256))", a.TableName))
  88. return err
  89. }
  90. func (a *Adapter) dropTable() error {
  91. _, err := a.Db.Exec(fmt.Sprintf("DROP TABLE %s", a.TableName))
  92. return err
  93. }
  94. func loadPolicyLine(line CasbinRule, model model.Model) {
  95. lineText := line.PType
  96. if line.V0 != "" {
  97. lineText += ", " + line.V0
  98. }
  99. if line.V1 != "" {
  100. lineText += ", " + line.V1
  101. }
  102. if line.V2 != "" {
  103. lineText += ", " + line.V2
  104. }
  105. if line.V3 != "" {
  106. lineText += ", " + line.V3
  107. }
  108. if line.V4 != "" {
  109. lineText += ", " + line.V4
  110. }
  111. if line.V5 != "" {
  112. lineText += ", " + line.V5
  113. }
  114. persist.LoadPolicyLine(lineText, model)
  115. }
  116. // LoadPolicy loads policy from database.
  117. func (a *Adapter) LoadPolicy(model model.Model) error {
  118. var lines []CasbinRule
  119. if err := a.Db.Table(a.TableName).Scan(&lines); err != nil {
  120. return err
  121. }
  122. for _, line := range lines {
  123. loadPolicyLine(line, model)
  124. }
  125. return nil
  126. }
  127. func savePolicyLine(ptype string, rule []string) CasbinRule {
  128. line := CasbinRule{}
  129. line.PType = ptype
  130. if len(rule) > 0 {
  131. line.V0 = rule[0]
  132. }
  133. if len(rule) > 1 {
  134. line.V1 = rule[1]
  135. }
  136. if len(rule) > 2 {
  137. line.V2 = rule[2]
  138. }
  139. if len(rule) > 3 {
  140. line.V3 = rule[3]
  141. }
  142. if len(rule) > 4 {
  143. line.V4 = rule[4]
  144. }
  145. if len(rule) > 5 {
  146. line.V5 = rule[5]
  147. }
  148. return line
  149. }
  150. // SavePolicy saves policy to database.
  151. func (a *Adapter) SavePolicy(model model.Model) error {
  152. err := a.dropTable()
  153. if err != nil {
  154. return err
  155. }
  156. err = a.createTable()
  157. if err != nil {
  158. return err
  159. }
  160. for ptype, ast := range model["p"] {
  161. for _, rule := range ast.Policy {
  162. line := savePolicyLine(ptype, rule)
  163. _, err := a.Db.Table(a.TableName).Data(&line).Insert()
  164. if err != nil {
  165. return err
  166. }
  167. }
  168. }
  169. for ptype, ast := range model["g"] {
  170. for _, rule := range ast.Policy {
  171. line := savePolicyLine(ptype, rule)
  172. _, err := a.Db.Table(a.TableName).Data(&line).Insert()
  173. if err != nil {
  174. return err
  175. }
  176. }
  177. }
  178. return nil
  179. }
  180. // AddPolicy adds a policy rule to the storage.
  181. func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error {
  182. line := savePolicyLine(ptype, rule)
  183. _, err := a.Db.Table(a.TableName).Data(&line).Insert()
  184. return err
  185. }
  186. // RemovePolicy removes a policy rule from the storage.
  187. func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error {
  188. line := savePolicyLine(ptype, rule)
  189. err := rawDelete(a, line)
  190. return err
  191. }
  192. // RemoveFilteredPolicy removes policy rules that match the filter from the storage.
  193. func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
  194. line := CasbinRule{}
  195. line.PType = ptype
  196. if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) {
  197. line.V0 = fieldValues[0-fieldIndex]
  198. }
  199. if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) {
  200. line.V1 = fieldValues[1-fieldIndex]
  201. }
  202. if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) {
  203. line.V2 = fieldValues[2-fieldIndex]
  204. }
  205. if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) {
  206. line.V3 = fieldValues[3-fieldIndex]
  207. }
  208. if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) {
  209. line.V4 = fieldValues[4-fieldIndex]
  210. }
  211. if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
  212. line.V5 = fieldValues[5-fieldIndex]
  213. }
  214. err := rawDelete(a, line)
  215. return err
  216. }
  217. func rawDelete(a *Adapter, line CasbinRule) error {
  218. db := a.Db.Table(a.TableName)
  219. db.Where("ptype = ?", line.PType)
  220. if line.V0 != "" {
  221. db.Where("v0 = ?", line.V0)
  222. }
  223. if line.V1 != "" {
  224. db.Where("v1 = ?", line.V1)
  225. }
  226. if line.V2 != "" {
  227. db.Where("v2 = ?", line.V2)
  228. }
  229. if line.V3 != "" {
  230. db.Where("v3 = ?", line.V3)
  231. }
  232. if line.V4 != "" {
  233. db.Where("v4 = ?", line.V4)
  234. }
  235. if line.V5 != "" {
  236. db.Where("v5 = ?", line.V5)
  237. }
  238. _, err := db.Delete()
  239. return err
  240. }