store - store.go
// Package store automatically configures an SQL database to store structured information held in Go structs.
2 package store // import "vimagination.zapto.org/store"
3
4 import (
5 "database/sql"
6 "errors"
7 "reflect"
8 "strings"
9 "sync"
10 "time"
11
12 _ "github.com/mxk/go-sqlite/sqlite3"
13 )
14
// Indices into typeInfo.statements: one prepared SQL statement per operation.
const (
	add     = iota // INSERT a new row (all non-key columns)
	get            // SELECT the non-key columns of one row by key
	update         // UPDATE the non-key columns of one row by key
	remove         // DELETE one row by key
	getPage        // SELECT a page of keys, ordered by key
	count          // SELECT COUNT(1) over the whole table
)
23
// field describes one persisted struct field, i.e. one table column.
// NOTE: field order matters; defineType builds values with a positional literal.
type field struct {
	isStruct bool   // the field is another registered struct, stored as an INTEGER reference to its key
	pos      int    // index of the field within the Go struct (as used with reflect)
	name     string // column name: the `store` tag if set, otherwise the Go field name
}
29
// typeInfo holds what is needed to persist one registered struct type.
type typeInfo struct {
	primary    int         // index into fields of the primary-key column
	fields     []field     // persisted columns, in struct field order
	statements []*sql.Stmt // prepared statements, indexed by add, get, update, remove, getPage and count
}
35
// Store persists registered struct types to an SQLite database, one table per
// registered type. Create one with New.
type Store struct {
	db    *sql.DB             // nil once Close has been called
	types map[string]typeInfo // registered types, keyed by typeName
	mutex sync.Mutex          // held by Register, Set, Get, GetPage, Remove and Count
}
41
42 func New(dataSourceName string) (*Store, error) {
43 db, err := sql.Open("sqlite3", dataSourceName)
44 if err != nil {
45 return nil, err
46 }
47 return &Store{
48 db: db,
49 types: make(map[string]typeInfo),
50 }, nil
51 }
52
53 func (s *Store) Close() error {
54 err := s.db.Close()
55 s.db = nil
56 return err
57 }
58
59 func (s *Store) Register(is ...interface{}) error {
60 if s.db == nil {
61 return ErrDBClosed
62 }
63 s.mutex.Lock()
64 defer s.mutex.Unlock()
65 for _, i := range is {
66 if !isPointerStruct(i) {
67 return ErrNoPointerStruct
68 }
69 if err := s.defineType(i); err != nil {
70 return err
71 }
72 }
73 return nil
74 }
75
76 func (s *Store) defineType(i interface{}) error {
77 name := typeName(i)
78 if _, ok := s.types[name]; ok {
79 return nil
80 }
81
82 s.types[name] = typeInfo{}
83
84 v := reflect.ValueOf(i).Elem()
85 numFields := v.Type().NumField()
86 fields := make([]field, 0, numFields)
87 id := 0
88 idType := 0
89
90 for n := 0; n < numFields; n++ {
91 f := v.Type().Field(n)
92 if f.PkgPath != "" { // not exported
93 continue
94 }
95 fieldName := f.Name
96 if fn := f.Tag.Get("store"); fn != "" {
97 fieldName = fn
98 }
99 if fieldName == "-" { // Skip field
100 continue
101 }
102 tmp := strings.ToLower(fieldName)
103 for _, tf := range fields {
104 if strings.ToLower(tf.name) == tmp {
105 return ErrDuplicateColumn
106 }
107 }
108 isPointer := f.Type.Kind() == reflect.Ptr
109 var iface interface{}
110 if isPointer {
111 iface = v.Field(n).Interface()
112 } else {
113 iface = v.Field(n).Addr().Interface()
114 }
115 isStruct := false
116 if isPointerStruct(iface) {
117 if _, ok := iface.(*time.Time); !ok {
118 if err := s.defineType(iface); err != nil {
119 return err
120 }
121 isStruct = true
122 }
123 } else if !isValidType(iface) {
124 continue
125 }
126 if isValidKeyType(iface) {
127 if idType < 3 && f.Tag.Get("key") == "1" {
128 idType = 3
129 id = len(fields)
130 } else if idType < 2 && strings.ToLower(fieldName) == "id" {
131 idType = 2
132 id = len(fields)
133 } else if idType < 1 {
134 idType = 1
135 id = len(fields)
136 }
137 }
138 fields = append(fields, field{
139 isStruct,
140 n,
141 fieldName,
142 })
143 }
144 if idType == 0 {
145 return ErrNoKey
146 }
147 s.types[name] = typeInfo{
148 primary: id,
149 }
150
151 // create statements
152 var (
153 sqlVars, sqlParams, setSQLParams, tableVars string
154 doneFirst, doneFirstNonKey bool
155 )
156
157 for pos, f := range fields {
158 if doneFirst {
159 tableVars += ", "
160 } else {
161 doneFirst = true
162 }
163 if pos != id {
164 if doneFirstNonKey {
165 sqlVars += ", "
166 setSQLParams += ", "
167 sqlParams += ", "
168 } else {
169 doneFirstNonKey = true
170 }
171 }
172 var varType string
173 if f.isStruct {
174 varType = "INTEGER"
175 } else {
176 varType = getType(i, f.pos)
177 }
178 tableVars += "[" + f.name + "] " + varType
179 if pos == id {
180 tableVars += " PRIMARY KEY AUTOINCREMENT"
181 } else {
182 sqlVars += "[" + f.name + "]"
183 setSQLParams += "[" + f.name + "] = ?"
184 sqlParams += "?"
185 }
186 }
187
188 statements := make([]*sql.Stmt, 6)
189
190 sql := "CREATE TABLE IF NOT EXISTS [" + name + "](" + tableVars + ");"
191 _, err := s.db.Exec(sql)
192 if err != nil {
193 return err
194 }
195
196 sql = "INSERT INTO [" + name + "] (" + sqlVars + ") VALUES (" + sqlParams + ");"
197 stmt, err := s.db.Prepare(sql)
198 if err != nil {
199 return err
200 }
201 statements[add] = stmt
202
203 sql = "SELECT " + sqlVars + " FROM [" + name + "] WHERE [" + fields[id].name + "] = ? LIMIT 1;"
204 stmt, err = s.db.Prepare(sql)
205 if err != nil {
206 return err
207 }
208 statements[get] = stmt
209
210 sql = "UPDATE [" + name + "] SET " + setSQLParams + " WHERE [" + fields[id].name + "] = ?;"
211 stmt, err = s.db.Prepare(sql)
212 if err != nil {
213 return err
214 }
215 statements[update] = stmt
216
217 sql = "DELETE FROM [" + name + "] WHERE [" + fields[id].name + "] = ?;"
218 stmt, err = s.db.Prepare(sql)
219 if err != nil {
220 return err
221 }
222 statements[remove] = stmt
223
224 sql = "SELECT [" + fields[id].name + "] FROM [" + name + "] ORDER BY [" + fields[id].name + "] LIMIT ? OFFSET ?;"
225 stmt, err = s.db.Prepare(sql)
226 if err != nil {
227 return err
228 }
229 statements[getPage] = stmt
230
231 sql = "SELECT COUNT(1) FROM [" + name + "];"
232 stmt, err = s.db.Prepare(sql)
233 if err != nil {
234 return err
235 }
236 statements[count] = stmt
237
238 s.types[name] = typeInfo{
239 primary: id,
240 fields: fields,
241 statements: statements,
242 }
243 return nil
244 }
245
246 func (s *Store) Set(is ...interface{}) error {
247 s.mutex.Lock()
248 defer s.mutex.Unlock()
249 var toSet []interface{}
250 for _, i := range is {
251 t, ok := s.types[typeName(i)]
252 if !ok {
253 return ErrUnregisteredType
254 }
255 toSet = toSet[:0]
256 err := s.set(i, &t, &toSet)
257 if err != nil {
258 return err
259 }
260 }
261 return nil
262 }
263
264 func (s *Store) set(i interface{}, t *typeInfo, toSet *[]interface{}) error {
265 for _, oi := range *toSet {
266 if oi == i {
267 return nil
268 }
269 }
270 (*toSet) = append(*toSet, i)
271 id := t.GetID(i)
272 isUpdate := id != 0
273 vars := make([]interface{}, 0, len(t.fields))
274 for pos, f := range t.fields {
275 if pos == t.primary {
276 continue
277 }
278 if f.isStruct {
279 ni := getFieldPointer(i, f.pos)
280 nt := s.types[typeName(ni)]
281 err := s.set(ni, &nt, toSet)
282 if err != nil {
283 return err
284 }
285 vars = append(vars, getField(ni, nt.fields[nt.primary].pos))
286 } else {
287 vars = append(vars, getField(i, f.pos))
288 }
289 }
290 if isUpdate {
291 r, err := t.statements[update].Exec(append(vars, id)...)
292 if err != nil {
293 return err
294 }
295 if ra, err := r.RowsAffected(); err != nil {
296 return err
297 } else if ra > 0 {
298 return nil
299 }
300 // id wasn't found, so insert...
301 }
302 r, err := t.statements[add].Exec(vars...)
303 if err != nil {
304 return err
305 }
306 lid, err := r.LastInsertId()
307 if err != nil {
308 return err
309 }
310 t.SetID(i, lid)
311 return nil
312 }
313
314 func (s *Store) Get(is ...interface{}) error {
315 s.mutex.Lock()
316 defer s.mutex.Unlock()
317 return s.get(is...)
318 }
319 func (s *Store) get(is ...interface{}) error {
320 for _, i := range is {
321 t, ok := s.types[typeName(i)]
322 if !ok {
323 return ErrUnregisteredType
324 }
325 id := t.GetID(i)
326 if id == 0 {
327 continue
328 }
329 vars := make([]interface{}, 0, len(t.fields)-1)
330 var toGet []interface{}
331 for pos, f := range t.fields {
332 if pos == t.primary {
333 continue
334 }
335 if f.isStruct {
336 ni := getFieldPointer(i, f.pos)
337 nt := s.types[typeName(ni)]
338 toGet = append(toGet, ni)
339 vars = append(vars, getFieldPointer(ni, nt.fields[nt.primary].pos))
340 } else {
341 vars = append(vars, getFieldPointer(i, f.pos))
342 }
343 }
344 row := t.statements[get].QueryRow(id)
345 err := row.Scan(vars...)
346 if err == sql.ErrNoRows {
347 t.SetID(i, 0)
348 } else if err != nil {
349 return err
350 } else if len(toGet) > 0 {
351 if err = s.get(toGet...); err != nil {
352 return err
353 }
354 }
355 }
356 return nil
357 }
358
359 func (s *Store) GetPage(is []interface{}, offset int) (int, error) {
360 if len(is) == 0 {
361 return 0, nil
362 }
363 s.mutex.Lock()
364 defer s.mutex.Unlock()
365 t, ok := s.types[typeName(is[0])]
366 if !ok {
367 return 0, ErrInvalidType
368 }
369 rows, err := t.statements[getPage].Query(len(is), offset)
370 if err != nil {
371 return 0, err
372 }
373 defer rows.Close()
374 return s.getPage(is, rows)
375 }
376
377 func (s *Store) getPage(is []interface{}, rows *sql.Rows) (int, error) {
378 t := s.types[typeName(is[0])]
379 n := 0
380 for rows.Next() {
381 var id int64
382 if err := rows.Scan(&id); err != nil {
383 return 0, err
384 }
385 t.SetID(is[n], id)
386 n++
387 }
388 is = is[:n]
389 if err := rows.Err(); err == sql.ErrNoRows {
390 return 0, nil
391 } else if err != nil {
392 return 0, err
393 } else if err = s.get(is...); err != nil {
394 return 0, err
395 }
396 return n, nil
397 }
398
399 func (s *Store) Remove(is ...interface{}) error {
400 s.mutex.Lock()
401 defer s.mutex.Unlock()
402 for _, i := range is {
403 t, ok := s.types[typeName(i)]
404 if !ok {
405 return ErrUnregisteredType
406 }
407 _, err := t.statements[remove].Exec(t.GetID(i))
408 if err != nil {
409 return err
410 }
411 }
412 return nil
413 }
414
415 func (s *Store) Count(i interface{}) (int, error) {
416 s.mutex.Lock()
417 defer s.mutex.Unlock()
418 if !isPointerStruct(i) {
419 return 0, ErrNoPointerStruct
420 }
421 t, ok := s.types[typeName(i)]
422 if !ok {
423 return 0, ErrUnregisteredType
424 }
425 num := 0
426 err := t.statements[count].QueryRow().Scan(&num)
427 return num, err
428 }
429
// Errors returned by Store methods.
var (
	// ErrDBClosed is returned (for example by Register) once the Store has
	// been closed.
	ErrDBClosed = errors.New("database already closed")

	// ErrNoPointerStruct is returned when a value that is not a pointer to a
	// struct is passed where one is required, such as Register or Count.
	ErrNoPointerStruct = errors.New("given variable is not a pointer to a struct")

	// ErrNoKey is returned when registering a type none of whose fields can
	// serve as the primary key.
	ErrNoKey = errors.New("could not determine key")

	// ErrDuplicateColumn is returned when two persisted fields of a type map
	// to the same column name, compared case-insensitively.
	ErrDuplicateColumn = errors.New("duplicate column name found")

	// ErrUnregisteredType is returned when a value of a type that was never
	// registered is passed to methods such as Set, Get, Remove or Count.
	ErrUnregisteredType = errors.New("type not registered")

	// ErrInvalidType is returned by GetPage, for example when the element type
	// has not been registered.
	ErrInvalidType = errors.New("invalid type")
)