dbhelper.go - randomcrap - random crap programs of varying quality
(HTM) git clone git://git.codemadness.org/randomcrap
(DIR) Log
(DIR) Files
(DIR) Refs
(DIR) README
(DIR) LICENSE
---
dbhelper.go (16767B)
---
1 package dbhelper
2
3 import (
4 "database/sql"
5 "errors"
6 "fmt"
7 "reflect"
8 "strconv"
9 "strings"
10 "sync"
11 "time"
12 )
13
type (
	// KeyValue is one named query argument. Value is passed positionally to
	// the driver; Key is only used by PrepareProc to build "@Key = ?".
	KeyValue struct {
		Key   string
		Value interface{}
	}

	// Result is one row read by GetStmtResultsChan. Each element of Values
	// is a *interface{} that (*sql.Rows).Scan filled in.
	Result struct {
		Values []interface{}
	}

	// Results is a result set whose rows are streamed over Rows by a
	// goroutine started in GetStmtResultsChan.
	Results struct {
		Columns        []string       // column names as reported by the driver.
		ColumnOrdinals map[string]int // lower-cased column name -> index.
		Rows           chan *Result   // rows, in order, as they are scanned.
		Count          int            // number of rows sent on Rows so far.

		Err     error     // error recorded by the producer goroutine, if any.
		SqlRows *sql.Rows // released by Close.
		SqlDB   *sql.DB
		SqlStmt *sql.Stmt // released by Close.

		done chan bool    // receiving on it asks the producer goroutine to stop.
		l    sync.RWMutex // guards SqlRows between the producer and Close.
	}
)
37
// Sentinel errors returned by the typed Result getters (GetBool, GetBytes,
// GetFloat64, GetInt64, GetString, GetTime) when the stored value does not
// have the requested dynamic type.
var (
	err_interface_bool      = errors.New("interface conversion, not bool")
	err_interface_byteslice = errors.New("interface conversion, not []byte")
	err_interface_float64   = errors.New("interface conversion, not float64")
	err_interface_int64     = errors.New("interface conversion, not int64")
	err_interface_string    = errors.New("interface conversion, not string")
	err_interface_time      = errors.New("interface conversion, not time.Time")
)
44
// Reflect types of the standard driver value types
// (see http://golang.org/pkg/database/sql/driver/), built once.
var (
	type_bool      = reflect.TypeOf(true)
	type_string    = reflect.TypeOf("")
	type_int64     = reflect.TypeOf(int64(0))
	type_byteslice = reflect.TypeOf([]byte(nil))
	type_float64   = reflect.TypeOf(0.0)
	type_time      = reflect.TypeOf(time.Time{})
)

// types maps the type names accepted as the second element of a `db`
// struct tag (see Unmarshal) to their reflect.Type; it is built once so
// lookups are cheap.
var types = map[string]reflect.Type{
	"bool":      type_bool,
	"string":    type_string,
	"int64":     type_int64,
	"float64":   type_float64,
	"[]byte":    type_byteslice,
	"time.Time": type_time,
}

// timeconv maps the zone names accepted as the third element of a `db`
// struct tag to the Location handed to fixtime; "gmt" is an alias of "utc".
var timeconv = map[string]*time.Location{
	"gmt":   time.UTC,
	"local": time.Local,
	"utc":   time.UTC,
}
68
// fixtime re-labels t with loc. The calendar date and wall-clock fields
// (down to the nanosecond) are kept and only the Location is replaced, so
// the resulting instant differs from t whenever the two zone offsets differ.
// loc must be non-nil (time.Date panics otherwise).
func fixtime(t time.Time, loc *time.Location) time.Time {
	year, month, day := t.Date()
	hour, minute, second := t.Clock()
	return time.Date(year, month, day, hour, minute, second, t.Nanosecond(), loc)
}
76
77 func (s Result) Get(ordinal int) interface{} {
78 return s.Values[ordinal]
79 }
80
81 func (s Result) GetString(ordinal int) (string, error) {
82 if _, ok := (*(s.Values[ordinal]).(*interface{})).(string); ok {
83 return (*(s.Values[ordinal]).(*interface{})).(string), nil
84 }
85 return "", err_interface_string
86 }
87
88 func (s Result) GetBool(ordinal int) (bool, error) {
89 if _, ok := (*(s.Values[ordinal]).(*interface{})).(bool); ok {
90 return (*(s.Values[ordinal]).(*interface{})).(bool), nil
91 }
92 return false, err_interface_bool
93 }
94
95 func (s Result) GetInt64(ordinal int) (int64, error) {
96 if _, ok := (*(s.Values[ordinal]).(*interface{})).(int64); ok {
97 return (*(s.Values[ordinal]).(*interface{})).(int64), nil
98 }
99 return int64(0), err_interface_int64
100 }
101
102 func (s Result) GetFloat64(ordinal int) (float64, error) {
103 if _, ok := (*(s.Values[ordinal]).(*interface{})).(float64); ok {
104 return (*(s.Values[ordinal]).(*interface{})).(float64), nil
105 }
106 return 0.0, err_interface_float64
107 }
108
109 func (s Result) GetTime(ordinal int) (time.Time, error) {
110 if _, ok := (*(s.Values[ordinal]).(*interface{})).(time.Time); ok {
111 return (*(s.Values[ordinal]).(*interface{})).(time.Time), nil
112 }
113 return time.Time{}, err_interface_time
114 }
115
116 func (s Result) GetBytes(ordinal int) ([]byte, error) {
117 if _, ok := (*(s.Values[ordinal]).(*interface{})).([]byte); ok {
118 return (*(s.Values[ordinal]).(*interface{})).([]byte), nil
119 }
120 return []byte{}, err_interface_byteslice
121 }
122
123 // probably MSSQL-specific, converts []byte to float64.
124 func (s Result) GetDecimalToFloat(ordinal int) (float64, error) {
125 b, err := s.GetBytes(ordinal)
126 if err != nil {
127 return 0.0, err
128 }
129 f, err := strconv.ParseFloat(string(b), 64)
130 if err != nil {
131 return 0.0, err
132 }
133 return f, nil
134 }
135
136 // Get field as string (standard supported database types, see database/sql).
137 func (s Result) GetAsString(ordinal int) string {
138 switch v := (*(s.Values[ordinal]).(*interface{})).(type) {
139 case nil:
140 return ""
141 case bool:
142 if v {
143 return "1"
144 } else {
145 return "0"
146 }
147 case string:
148 return v
149 // NOTE: for MSSQL driver: decimal(n,p) is returned as []byte.
150 case []byte:
151 return string(v)
152 case time.Time:
153 return v.Format(time.RFC3339)
154 default: // includes int64, float64.
155 return fmt.Sprint(v)
156 }
157 // NOTREACHED
158 }
159
160 // Lookup name (case-insensitive), returns ordinal or -1 if not found.
161 func (r *Results) GetOrdinal(name string) int {
162 k := strings.ToLower(name)
163 if _, ok := r.ColumnOrdinals[k]; ok {
164 return r.ColumnOrdinals[k]
165 }
166 return -1 // not found.
167 }
168
169 // for cleanup: close statement and rows descriptor.
170 func (r *Results) Close() {
171 // this mutex is needed because there is a deferred close after reading
172 // all the results in go func, but also when *Results are manually closed.
173 r.l.Lock()
174 defer r.l.Unlock()
175
176 if r.SqlRows != nil {
177 r.SqlRows.Close()
178 }
179 if r.SqlStmt != nil {
180 r.SqlStmt.Close()
181 }
182 }
183
184 func PrepareProc(db *sql.DB, name string, kv []KeyValue) (*sql.Stmt, error) {
185 // convenience: remove brackets from name, such as: [dbname].dbo.[GetBla]
186 name = strings.Replace(name, "[", "", -1)
187 name = strings.Replace(name, "]", "", -1)
188 klen := len(kv)
189 query := "Execute " + name + " "
190 for i, v := range kv {
191 query += "@" + v.Key + " = ?"
192 if i != klen-1 {
193 query += ", "
194 }
195 i++
196 }
197 return db.Prepare(query)
198 }
199
200 func GetAffected(db *sql.DB, query string, kv []KeyValue) (int64, error) {
201 st, err := db.Prepare(query)
202 if err != nil {
203 return 0, err
204 }
205 defer st.Close()
206 // slice to variadic arguments.
207 result, err := st.Exec(keyvalue_to_val(kv)...)
208 if err != nil {
209 return 0, err
210 }
211 return result.RowsAffected()
212 }
213
214 func GetProcAffected(db *sql.DB, name string, kv []KeyValue) (int64, error) {
215 st, err := PrepareProc(db, name, kv)
216 if err != nil {
217 return 0, err
218 }
219 defer st.Close()
220 // slice to variadic arguments.
221 result, err := st.Exec(keyvalue_to_val(kv)...)
222 if err != nil {
223 return 0, err
224 }
225 return result.RowsAffected()
226 }
227
228 func GetRows(db *sql.DB, st *sql.Stmt, kv []KeyValue) (*sql.Rows, error) {
229 // slice to variadic arguments.
230 rows, err := st.Query(keyvalue_to_val(kv)...)
231 if err != nil {
232 return nil, err
233 }
234 return rows, nil
235 }
236
237 func GetResults(db *sql.DB, query string, kv []KeyValue) (*Results, error) {
238 st, err := db.Prepare(query)
239 if err != nil {
240 return nil, err
241 }
242 // NOTE: cleanup is handled by Results.Close()
243 return GetStmtResultsChan(db, st, kv)
244 }
245
246 func GetProcResults(db *sql.DB, name string, kv []KeyValue) (*Results, error) {
247 st, err := PrepareProc(db, name, kv)
248 if err != nil {
249 return nil, err
250 }
251 // NOTE: cleanup is handled by Results.Close()
252 return GetStmtResultsChan(db, st, kv)
253 }
254
255 func GetStmtResultsChan(db *sql.DB, st *sql.Stmt, kv []KeyValue) (*Results, error) {
256 // slice to variadic arguments.
257 rows, err := st.Query(keyvalue_to_val(kv)...)
258 if err != nil {
259 st.Close()
260 return nil, err
261 }
262 cols, err := rows.Columns()
263 if err != nil {
264 st.Close()
265 rows.Close()
266 return nil, err
267 }
268 colslen := len(cols)
269
270 r := &Results{}
271 r.SqlRows = rows
272 r.SqlDB = db
273 r.SqlStmt = st
274 r.Count = 0
275 r.Columns = cols
276 r.ColumnOrdinals = make(map[string]int) // NOTE: names are mapped lower-case.
277 for i, k := range cols {
278 // names to ordinal are always mapped to lowercase.
279 k := strings.ToLower(k)
280 r.ColumnOrdinals[k] = i
281 }
282 r.Rows = make(chan *Result)
283 r.done = make(chan bool, 1)
284 r.Err = nil
285
286 go func() {
287 for {
288 select {
289 case v, ok := <-r.done:
290 // true value and channel not closed.
291 if v && ok {
292 return
293 }
294 default:
295 // allocate space for result.
296 ptrs := make([]interface{}, colslen)
297 for i := 0; i < colslen; i++ {
298 ptrs[i] = new(interface{})
299 }
300
301 // read lock for r.SqlRows -> rows
302 r.l.RLock()
303 if rows.Next() {
304 err := rows.Scan(ptrs...)
305 if err != nil {
306 r.Err = err
307 r.l.RUnlock()
308 return
309 }
310 r.l.RUnlock()
311 } else {
312 r.l.RUnlock()
313 close(r.Rows)
314 return
315 }
316 r.Rows <- &Result{Values: ptrs}
317 r.Count++
318 }
319 }
320 }()
321 return r, nil
322 }
323
324 // comfy function to convert KeyValue pairs to interface{} values.
325 func keyvalue_to_val(kv []KeyValue) []interface{} {
326 values := make([]interface{}, len(kv))
327 for i, v := range kv {
328 values[i] = v.Value
329 }
330 return values
331 }
332
333 // Unmarshal results into a slice of structs using the reflect package.
334 //
335 // By default types are assigned from the standard database type (see database/sql)
336 // to the type in the struct. These must be directly assignable.
337 //
338 // A special case is if a database field type is []byte and the struct type is float64
339 // it will be converted (useful to convert the MSSQL decimal type to float64.
340 //
341 // If a type is specified in the field struct tag it will be converted (not assigned)
342 // to the struct type. This is also useful to convert directly to your own custom types,
343 // like:
344 //
345 // type CustomTime time.Time
346 //
347 // The following format is supported:
348 //
349 // type Example struct {
350 // Field CustomTime `db:"Field_Name,time.Time"`
351 // }
352 //
353 // For times there is a third struct field tag to fix timezones in time.Time, it is used like:
354 //
355 // type Example struct {
356 // Field CustomTime `db:"Field_Name,time.Time,local"`
357 // }
358 //
359 // Fields are matched by name case-insensitive; to be clear: the order of the field in
360 // the struct won't matter.
361 //
362 // CAVEAT: channel result.Rows is not closed, do this manually yourself, usually: defer r.Close()
363 func (r *Results) Unmarshal(v interface{}) (err error) {
364 val := reflect.ValueOf(v)
365 if val.Kind() != reflect.Ptr || val.IsNil() {
366 return errors.New("must be pointer to slice of structs and non-nil")
367 }
368 val = val.Elem() // dereference pointer.
369 if val.Kind() != reflect.Slice {
370 return errors.New("must be pointer to slice of structs and non-nil")
371 }
372
373 err = nil
374
375 slicetype := val.Type()
376 vals := reflect.MakeSlice(slicetype, 0, 0)
377 eltype := vals.Type().Elem()
378
379 // map types (from, to) and names once per resultset.
380 structnumfield := eltype.NumField()
381 typesto := make([]reflect.Type, structnumfield)
382 typesnames := make([]string, structnumfield)
383 timeconvs := make([]string, structnumfield)
384 names := make([]string, structnumfield)
385 // field struct index to db field ordinal.
386 ordinals := make([]int, structnumfield)
387
388 embed := make([]int, structnumfield)
389
390 for i := 0; i < structnumfield; i++ {
391 embed[i] = -1
392 field := eltype.Field(i)
393 // field name, type (optional) and time conversion tag (optional).
394 fieldtag := field.Tag.Get("db") // "db" namespace.
395 name := field.Name
396 _type := ""
397 _timeconv := ""
398 if fieldtag != "" {
399 fieldtags := strings.Split(fieldtag, ",")
400 n := len(fieldtags)
401 // time conversion.
402 if n > 2 && fieldtags[2] != "" {
403 _timeconv = fieldtags[2]
404 }
405 // field type
406 if n > 1 && fieldtags[1] != "" {
407 _type = fieldtags[1]
408 }
409 // field name.
410 if n > 0 && fieldtags[0] != "" {
411 name = fieldtags[0]
412 }
413 }
414
415 ordinals[i] = r.GetOrdinal(name)
416 if ordinals[i] < 0 {
417 return errors.New("column ordinal not found: " + name)
418 }
419 names[i] = name
420 typesto[i] = field.Type
421
422 // field type: type specified but does not exist.
423 if _type != "" && types[_type] == nil {
424 return errors.New("unsupported type: " + _type)
425 }
426 typesnames[i] = _type
427
428 // check if type is set but not convertible, if so:
429 // search for Anonymous field with same type as types[_type].
430 t := field.Type
431 if _type != "" && types[_type] != nil &&
432 !t.ConvertibleTo(types[_type]) &&
433 t.Kind() == reflect.Struct {
434 n := t.NumField()
435 for j := 0; j < n; j++ {
436 f := t.Field(j)
437 if f.Anonymous && (f.Type == types[_type] || f.Type.ConvertibleTo(types[_type])) {
438 // also update Type to type of embedded field in struct.
439 embed[i] = j // index of embedded field.
440 typesto[i] = f.Type
441 break
442 }
443 }
444 if embed[i] == -1 {
445 return errors.New("type set but anonymous field / embedded field not found: " + _type)
446 }
447 }
448 // time conversation tag.
449 if _timeconv != "" && timeconv[_timeconv] == nil {
450 return errors.New("unsupported time conversation parameter: " + _timeconv)
451 }
452 timeconvs[i] = _timeconv
453 }
454
455 total := 0
456 fieldtypes := make([]reflect.Type, structnumfield)
457
458 defer func() {
459 r.done <- true
460 close(r.done)
461 }()
462
463 for row := range r.Rows {
464 // NOTE: !reflect.Value.IsValid() will be skipped so initial values
465 // of struct will be used.
466 t := reflect.New(eltype).Elem() // dereference
467
468 // first row, get type of field, should be equal in whole resultset.
469 if total == 0 {
470 for i := 0; i < structnumfield; i++ {
471 ordinal := ordinals[i]
472
473 if row.Values[ordinal] == nil {
474 return errors.New("field \"" + names[i] + "\" not found (nil)")
475 }
476
477 v := reflect.ValueOf(*(row.Values[ordinal]).(*interface{}))
478 // if value is nil then skip detecting it's type, it will be tried
479 // later to detect the type again, if that is also not possible
480 // the zero value of the corresponding type is used.
481 if !v.IsValid() {
482 continue
483 }
484
485 var field reflect.Value
486 vt := v.Type()
487 if embed[i] >= 0 {
488 field = t.Field(i).Field(embed[i])
489 } else {
490 field = t.Field(i)
491 }
492 // can not set value of this field (like an unexported field).
493 if !field.CanSet() {
494 return errors.New("cannot set value on field: " +
495 "name=\"" + names[i] + "\"" +
496 ", dbtype=\"" + vt.Name() + "\"" +
497 ", struct type=\"" + typesto[i].Name() + "\"" +
498 ", tag type=\"" + typesnames[i] + "\", unexported field?")
499 }
500
501 // NOTE: []byte to float64 is a special-case, this is not convertible
502 // but we handle this ourself, see below.
503 if vt == type_byteslice && (typesto[i] == type_float64 || (embed[i] >= 0 && typesnames[i] == "float64")) {
504 // special case, handled later.
505 } else if typesnames[i] == "" {
506 // NOTE: direct type to type, must be assignable, not convertible.
507 if !vt.AssignableTo(typesto[i]) {
508 return errors.New("non-assignable: " +
509 "name=\"" + names[i] + "\"" +
510 ", dbtype=\"" + vt.Name() + "\"" +
511 ", struct type=\"" + typesto[i].Name() + "\"" +
512 ", tag type=\"" + typesnames[i] + "\"")
513 }
514 } else if !(vt.ConvertibleTo(typesto[i])) {
515 // can't convert to type.
516 return errors.New("non-convertible: " +
517 "name=\"" + names[i] + "\"" +
518 ", dbtype=\"" + vt.Name() + "\"" +
519 ", struct type=\"" + typesto[i].Name() + "\"" +
520 ", tag type=\"" + typesnames[i] + "\"")
521 }
522 // convertible type, set it.
523 fieldtypes[i] = vt
524 }
525 }
526
527 for i := 0; i < structnumfield; i++ {
528 ordinal := ordinals[i]
529 v := reflect.ValueOf(*(row.Values[ordinal]).(*interface{}))
530 if !v.IsValid() {
531 continue
532 }
533
534 var field reflect.Value
535 if embed[i] >= 0 {
536 field = t.Field(i).Field(embed[i])
537 } else {
538 field = t.Field(i)
539 }
540
541 // if type was not known in first result row (for example when it is NULL)
542 // then set it anyway. NOTE that when a type is known at some point it must
543 // be the same across all rows in the resultset.
544 if fieldtypes[i] == nil {
545 vt := v.Type()
546 // can't convert to type.
547 if
548 // NOTE: []byte to float64 is a special-case, this is not convertible.
549 (vt != type_byteslice || (typesto[i] != type_float64 && !(embed[i] >= 0 && typesnames[i] == "float64"))) &&
550 !(vt.ConvertibleTo(typesto[i])) {
551 return errors.New("non-convertible type: " +
552 "name=\"" + names[i] + "\"" +
553 ", dbtype=\"" + vt.Name() + "\"" +
554 ", struct type=\"" + typesto[i].Name() + "\"" +
555 ", tag type=\"" + typesnames[i] + "\"")
556 }
557 // convertible type, set it.
558 fieldtypes[i] = vt
559 }
560
561 // used to convert MSSQL-types like money, decimal to float.
562 // these (money, decimal) types are returned as a []byte by the MSSQL driver.
563 if fieldtypes[i] == type_byteslice && (typesto[i] == type_float64 || (embed[i] >= 0 && typesnames[i] == "float64")) {
564 f, err := row.GetDecimalToFloat(ordinal)
565 if err != nil {
566 return err
567 }
568 v := reflect.ValueOf(f)
569 if typesnames[i] == "float64" {
570 v = v.Convert(typesto[i])
571 }
572 if !v.IsValid() {
573 continue
574 }
575 field.Set(v)
576 continue
577 }
578
579 // fix time, only correct timezone from database without modifying the time itself.
580 if timeconvs[i] != "" && fieldtypes[i] == type_time {
581 t, err := row.GetTime(ordinal)
582 if err != nil {
583 return err
584 }
585 t = fixtime(t, timeconv[timeconvs[i]])
586 v = reflect.ValueOf(t)
587 }
588
589 // assign to type in struct.
590 if typesnames[i] == "" {
591 field.Set(v)
592 } else {
593 // try to convert to type specified in struct tag: `db:"name,type"` if valid.
594 // if invalid leave as zero value for the type.
595 v = v.Convert(typesto[i])
596 if v.IsValid() {
597 field.Set(v)
598 }
599 }
600 }
601 vals = reflect.Append(vals, t)
602 total++
603 }
604 val.Set(vals)
605
606 return
607 }