dbhelper.go - randomcrap - random crap programs of varying quality
 (HTM) git clone git://git.codemadness.org/randomcrap
 (DIR) Log
 (DIR) Files
 (DIR) Refs
 (DIR) README
 (DIR) LICENSE
       ---
       dbhelper.go (16767B)
       ---
            1 package dbhelper
            2 
            3 import (
            4         "database/sql"
            5         "errors"
            6         "fmt"
            7         "reflect"
            8         "strconv"
            9         "strings"
           10         "sync"
           11         "time"
           12 )
           13 
           // KeyValue is a named query argument. PrepareProc uses Key as the
// "@Key" parameter name; only Value is handed to the database driver
// (see keyvalue_to_val).
type KeyValue struct {
	Key   string      // parameter name, without the leading '@'
	Value interface{} // argument value bound to the placeholder
}

// Result is one row of a result set. Rows produced by GetStmtResultsChan hold
// one *interface{} per column in Values, in column order.
type Result struct {
	Values []interface{}
}

// Results is a streamed result set. Rows are delivered on the Rows channel by
// a goroutine started in GetStmtResultsChan; call Close when finished to
// release the underlying rows and prepared statement.
type Results struct {
	Columns        []string       // column names as reported by the driver
	ColumnOrdinals map[string]int // lower-cased column name -> column index
	Rows           chan *Result   // one value per database row
	Count          int            // number of rows sent on Rows so far

	Err     error     // error recorded by the row producer (e.g. a Scan failure)
	SqlRows *sql.Rows // open rows; closed by Close
	SqlDB   *sql.DB   // database the query was run on
	SqlStmt *sql.Stmt // prepared statement; closed by Close

	done chan bool    // send true to ask the row producer to stop early
	l    sync.RWMutex // guards SqlRows between the row producer and Close
}
           37 
           // Sentinel errors returned by the typed Result getters when the stored
// value has a different dynamic type (including NULL) than requested.
var (
	err_interface_bool      = errors.New("interface conversion, not bool")
	err_interface_byteslice = errors.New("interface conversion, not []byte")
	err_interface_float64   = errors.New("interface conversion, not float64")
	err_interface_int64     = errors.New("interface conversion, not int64")
	err_interface_string    = errors.New("interface conversion, not string")
	err_interface_time      = errors.New("interface conversion, not time.Time")
)

// reflect.Type values of the standard driver value types, computed once at
// package initialization for fast comparisons.
var (
	type_bool      = reflect.TypeOf(true)
	type_string    = reflect.TypeOf("")
	type_int64     = reflect.TypeOf(int64(0))
	type_byteslice = reflect.TypeOf([]byte(nil))
	type_float64   = reflect.TypeOf(float64(0))
	type_time      = reflect.TypeOf(time.Time{})
)

// types maps the type names accepted as the second element of a
// `db:"name,type"` struct tag to their reflect.Type. These are the standard
// driver value types, see https://golang.org/pkg/database/sql/driver/.
var types = map[string]reflect.Type{
	"[]byte":    type_byteslice,
	"bool":      type_bool,
	"float64":   type_float64,
	"int64":     type_int64,
	"string":    type_string,
	"time.Time": type_time,
}

// timeconv maps the optional third struct-tag element to the location that
// fixtime applies; "gmt" is an alias for "utc".
var timeconv = map[string]*time.Location{
	"gmt":   time.UTC,
	"local": time.Local,
	"utc":   time.UTC,
}
           68 
           // fixtime re-interprets the wall-clock fields of t (date, clock and
// nanoseconds) in loc, ignoring t's own Location. The printed clock reading
// stays the same while the instant it denotes generally changes.
func fixtime(t time.Time, loc *time.Location) time.Time {
	year, month, day := t.Date()
	hour, minute, second := t.Clock()
	return time.Date(year, month, day, hour, minute, second, t.Nanosecond(), loc)
}
           76 
           77 func (s Result) Get(ordinal int) interface{} {
           78         return s.Values[ordinal]
           79 }
           80 
           81 func (s Result) GetString(ordinal int) (string, error) {
           82         if _, ok := (*(s.Values[ordinal]).(*interface{})).(string); ok {
           83                 return (*(s.Values[ordinal]).(*interface{})).(string), nil
           84         }
           85         return "", err_interface_string
           86 }
           87 
           88 func (s Result) GetBool(ordinal int) (bool, error) {
           89         if _, ok := (*(s.Values[ordinal]).(*interface{})).(bool); ok {
           90                 return (*(s.Values[ordinal]).(*interface{})).(bool), nil
           91         }
           92         return false, err_interface_bool
           93 }
           94 
           95 func (s Result) GetInt64(ordinal int) (int64, error) {
           96         if _, ok := (*(s.Values[ordinal]).(*interface{})).(int64); ok {
           97                 return (*(s.Values[ordinal]).(*interface{})).(int64), nil
           98         }
           99         return int64(0), err_interface_int64
          100 }
          101 
          102 func (s Result) GetFloat64(ordinal int) (float64, error) {
          103         if _, ok := (*(s.Values[ordinal]).(*interface{})).(float64); ok {
          104                 return (*(s.Values[ordinal]).(*interface{})).(float64), nil
          105         }
          106         return 0.0, err_interface_float64
          107 }
          108 
          109 func (s Result) GetTime(ordinal int) (time.Time, error) {
          110         if _, ok := (*(s.Values[ordinal]).(*interface{})).(time.Time); ok {
          111                 return (*(s.Values[ordinal]).(*interface{})).(time.Time), nil
          112         }
          113         return time.Time{}, err_interface_time
          114 }
          115 
          116 func (s Result) GetBytes(ordinal int) ([]byte, error) {
          117         if _, ok := (*(s.Values[ordinal]).(*interface{})).([]byte); ok {
          118                 return (*(s.Values[ordinal]).(*interface{})).([]byte), nil
          119         }
          120         return []byte{}, err_interface_byteslice
          121 }
          122 
          123 // probably MSSQL-specific, converts []byte to float64.
          124 func (s Result) GetDecimalToFloat(ordinal int) (float64, error) {
          125         b, err := s.GetBytes(ordinal)
          126         if err != nil {
          127                 return 0.0, err
          128         }
          129         f, err := strconv.ParseFloat(string(b), 64)
          130         if err != nil {
          131                 return 0.0, err
          132         }
          133         return f, nil
          134 }
          135 
          136 // Get field as string (standard supported database types, see database/sql).
          137 func (s Result) GetAsString(ordinal int) string {
          138         switch v := (*(s.Values[ordinal]).(*interface{})).(type) {
          139         case nil:
          140                 return ""
          141         case bool:
          142                 if v {
          143                         return "1"
          144                 } else {
          145                         return "0"
          146                 }
          147         case string:
          148                 return v
          149         // NOTE: for MSSQL driver: decimal(n,p) is returned as []byte.
          150         case []byte:
          151                 return string(v)
          152         case time.Time:
          153                 return v.Format(time.RFC3339)
          154         default: // includes int64, float64.
          155                 return fmt.Sprint(v)
          156         }
          157         // NOTREACHED
          158 }
          159 
          160 // Lookup name (case-insensitive), returns ordinal or -1 if not found.
          161 func (r *Results) GetOrdinal(name string) int {
          162         k := strings.ToLower(name)
          163         if _, ok := r.ColumnOrdinals[k]; ok {
          164                 return r.ColumnOrdinals[k]
          165         }
          166         return -1 // not found.
          167 }
          168 
          169 // for cleanup: close statement and rows descriptor.
          170 func (r *Results) Close() {
          171         // this mutex is needed because there is a deferred close after reading
          172         // all the results in go func, but also when *Results are manually closed.
          173         r.l.Lock()
          174         defer r.l.Unlock()
          175 
          176         if r.SqlRows != nil {
          177                 r.SqlRows.Close()
          178         }
          179         if r.SqlStmt != nil {
          180                 r.SqlStmt.Close()
          181         }
          182 }
          183 
          184 func PrepareProc(db *sql.DB, name string, kv []KeyValue) (*sql.Stmt, error) {
          185         // convenience: remove brackets from name, such as: [dbname].dbo.[GetBla]
          186         name = strings.Replace(name, "[", "", -1)
          187         name = strings.Replace(name, "]", "", -1)
          188         klen := len(kv)
          189         query := "Execute " + name + " "
          190         for i, v := range kv {
          191                 query += "@" + v.Key + " = ?"
          192                 if i != klen-1 {
          193                         query += ", "
          194                 }
          195                 i++
          196         }
          197         return db.Prepare(query)
          198 }
          199 
          200 func GetAffected(db *sql.DB, query string, kv []KeyValue) (int64, error) {
          201         st, err := db.Prepare(query)
          202         if err != nil {
          203                 return 0, err
          204         }
          205         defer st.Close()
          206         // slice to variadic arguments.
          207         result, err := st.Exec(keyvalue_to_val(kv)...)
          208         if err != nil {
          209                 return 0, err
          210         }
          211         return result.RowsAffected()
          212 }
          213 
          214 func GetProcAffected(db *sql.DB, name string, kv []KeyValue) (int64, error) {
          215         st, err := PrepareProc(db, name, kv)
          216         if err != nil {
          217                 return 0, err
          218         }
          219         defer st.Close()
          220         // slice to variadic arguments.
          221         result, err := st.Exec(keyvalue_to_val(kv)...)
          222         if err != nil {
          223                 return 0, err
          224         }
          225         return result.RowsAffected()
          226 }
          227 
          228 func GetRows(db *sql.DB, st *sql.Stmt, kv []KeyValue) (*sql.Rows, error) {
          229         // slice to variadic arguments.
          230         rows, err := st.Query(keyvalue_to_val(kv)...)
          231         if err != nil {
          232                 return nil, err
          233         }
          234         return rows, nil
          235 }
          236 
          237 func GetResults(db *sql.DB, query string, kv []KeyValue) (*Results, error) {
          238         st, err := db.Prepare(query)
          239         if err != nil {
          240                 return nil, err
          241         }
          242         // NOTE: cleanup is handled by Results.Close()
          243         return GetStmtResultsChan(db, st, kv)
          244 }
          245 
          246 func GetProcResults(db *sql.DB, name string, kv []KeyValue) (*Results, error) {
          247         st, err := PrepareProc(db, name, kv)
          248         if err != nil {
          249                 return nil, err
          250         }
          251         // NOTE: cleanup is handled by Results.Close()
          252         return GetStmtResultsChan(db, st, kv)
          253 }
          254 
          255 func GetStmtResultsChan(db *sql.DB, st *sql.Stmt, kv []KeyValue) (*Results, error) {
          256         // slice to variadic arguments.
          257         rows, err := st.Query(keyvalue_to_val(kv)...)
          258         if err != nil {
          259                 st.Close()
          260                 return nil, err
          261         }
          262         cols, err := rows.Columns()
          263         if err != nil {
          264                 st.Close()
          265                 rows.Close()
          266                 return nil, err
          267         }
          268         colslen := len(cols)
          269 
          270         r := &Results{}
          271         r.SqlRows = rows
          272         r.SqlDB = db
          273         r.SqlStmt = st
          274         r.Count = 0
          275         r.Columns = cols
          276         r.ColumnOrdinals = make(map[string]int) // NOTE: names are mapped lower-case.
          277         for i, k := range cols {
          278                 // names to ordinal are always mapped to lowercase.
          279                 k := strings.ToLower(k)
          280                 r.ColumnOrdinals[k] = i
          281         }
          282         r.Rows = make(chan *Result)
          283         r.done = make(chan bool, 1)
          284         r.Err = nil
          285 
          286         go func() {
          287                 for {
          288                         select {
          289                         case v, ok := <-r.done:
          290                                 // true value and channel not closed.
          291                                 if v && ok {
          292                                         return
          293                                 }
          294                         default:
          295                                 // allocate space for result.
          296                                 ptrs := make([]interface{}, colslen)
          297                                 for i := 0; i < colslen; i++ {
          298                                         ptrs[i] = new(interface{})
          299                                 }
          300 
          301                                 // read lock for r.SqlRows -> rows
          302                                 r.l.RLock()
          303                                 if rows.Next() {
          304                                         err := rows.Scan(ptrs...)
          305                                         if err != nil {
          306                                                 r.Err = err
          307                                                 r.l.RUnlock()
          308                                                 return
          309                                         }
          310                                         r.l.RUnlock()
          311                                 } else {
          312                                         r.l.RUnlock()
          313                                         close(r.Rows)
          314                                         return
          315                                 }
          316                                 r.Rows <- &Result{Values: ptrs}
          317                                 r.Count++
          318                         }
          319                 }
          320         }()
          321         return r, nil
          322 }
          323 
          324 // comfy function to convert KeyValue pairs to interface{} values.
          325 func keyvalue_to_val(kv []KeyValue) []interface{} {
          326         values := make([]interface{}, len(kv))
          327         for i, v := range kv {
          328                 values[i] = v.Value
          329         }
          330         return values
          331 }
          332 
          333 // Unmarshal results into a slice of structs using the reflect package.
          334 //
          335 // By default types are assigned from the standard database type (see database/sql)
          336 // to the type in the struct. These must be directly assignable.
          337 //
          338 // A special case is if a database field type is []byte and the struct type is float64
          339 // it will be converted (useful to convert the MSSQL decimal type to float64.
          340 //
          341 // If a type is specified in the field struct tag it will be converted (not assigned)
          342 // to the struct type. This is also useful to convert directly to your own custom types,
          343 // like:
          344 //
          345 // type CustomTime time.Time
          346 //
          347 // The following format is supported:
          348 //
          349 // type Example struct {
          350 //         Field CustomTime `db:"Field_Name,time.Time"`
          351 // }
          352 //
          353 // For times there is a third struct field tag to fix timezones in time.Time, it is used like:
          354 //
          355 // type Example struct {
          356 //         Field CustomTime `db:"Field_Name,time.Time,local"`
          357 // }
          358 //
          359 // Fields are matched by name case-insensitive; to be clear: the order of the field in
          360 // the struct won't matter.
          361 //
          362 // CAVEAT: channel result.Rows is not closed, do this manually yourself, usually: defer r.Close()
          363 func (r *Results) Unmarshal(v interface{}) (err error) {
          364         val := reflect.ValueOf(v)
          365         if val.Kind() != reflect.Ptr || val.IsNil() {
          366                 return errors.New("must be pointer to slice of structs and non-nil")
          367         }
          368         val = val.Elem() // dereference pointer.
          369         if val.Kind() != reflect.Slice {
          370                 return errors.New("must be pointer to slice of structs and non-nil")
          371         }
          372 
          373         err = nil
          374 
          375         slicetype := val.Type()
          376         vals := reflect.MakeSlice(slicetype, 0, 0)
          377         eltype := vals.Type().Elem()
          378 
          379         // map types (from, to) and names once per resultset.
          380         structnumfield := eltype.NumField()
          381         typesto := make([]reflect.Type, structnumfield)
          382         typesnames := make([]string, structnumfield)
          383         timeconvs := make([]string, structnumfield)
          384         names := make([]string, structnumfield)
          385         // field struct index to db field ordinal.
          386         ordinals := make([]int, structnumfield)
          387 
          388         embed := make([]int, structnumfield)
          389 
          390         for i := 0; i < structnumfield; i++ {
          391                 embed[i] = -1
          392                 field := eltype.Field(i)
          393                 // field name, type (optional) and time conversion tag (optional).
          394                 fieldtag := field.Tag.Get("db") // "db" namespace.
          395                 name := field.Name
          396                 _type := ""
          397                 _timeconv := ""
          398                 if fieldtag != "" {
          399                         fieldtags := strings.Split(fieldtag, ",")
          400                         n := len(fieldtags)
          401                         // time conversion.
          402                         if n > 2 && fieldtags[2] != "" {
          403                                 _timeconv = fieldtags[2]
          404                         }
          405                         // field type
          406                         if n > 1 && fieldtags[1] != "" {
          407                                 _type = fieldtags[1]
          408                         }
          409                         // field name.
          410                         if n > 0 && fieldtags[0] != "" {
          411                                 name = fieldtags[0]
          412                         }
          413                 }
          414 
          415                 ordinals[i] = r.GetOrdinal(name)
          416                 if ordinals[i] < 0 {
          417                         return errors.New("column ordinal not found: " + name)
          418                 }
          419                 names[i] = name
          420                 typesto[i] = field.Type
          421 
          422                 // field type: type specified but does not exist.
          423                 if _type != "" && types[_type] == nil {
          424                         return errors.New("unsupported type: " + _type)
          425                 }
          426                 typesnames[i] = _type
          427 
          428                 // check if type is set but not convertible, if so:
          429                 // search for Anonymous field with same type as types[_type].
          430                 t := field.Type
          431                 if _type != "" && types[_type] != nil &&
          432                         !t.ConvertibleTo(types[_type]) &&
          433                         t.Kind() == reflect.Struct {
          434                         n := t.NumField()
          435                         for j := 0; j < n; j++ {
          436                                 f := t.Field(j)
          437                                 if f.Anonymous && (f.Type == types[_type] || f.Type.ConvertibleTo(types[_type])) {
          438                                         // also update Type to type of embedded field in struct.
          439                                         embed[i] = j // index of embedded field.
          440                                         typesto[i] = f.Type
          441                                         break
          442                                 }
          443                         }
          444                         if embed[i] == -1 {
          445                                 return errors.New("type set but anonymous field / embedded field not found: " + _type)
          446                         }
          447                 }
          448                 // time conversation tag.
          449                 if _timeconv != "" && timeconv[_timeconv] == nil {
          450                         return errors.New("unsupported time conversation parameter: " + _timeconv)
          451                 }
          452                 timeconvs[i] = _timeconv
          453         }
          454 
          455         total := 0
          456         fieldtypes := make([]reflect.Type, structnumfield)
          457 
          458         defer func() {
          459                 r.done <- true
          460                 close(r.done)
          461         }()
          462 
          463         for row := range r.Rows {
          464                 // NOTE: !reflect.Value.IsValid() will be skipped so initial values
          465                 //       of struct will be used.
          466                 t := reflect.New(eltype).Elem() // dereference
          467 
          468                 // first row, get type of field, should be equal in whole resultset.
          469                 if total == 0 {
          470                         for i := 0; i < structnumfield; i++ {
          471                                 ordinal := ordinals[i]
          472 
          473                                 if row.Values[ordinal] == nil {
          474                                         return errors.New("field \"" + names[i] + "\" not found (nil)")
          475                                 }
          476 
          477                                 v := reflect.ValueOf(*(row.Values[ordinal]).(*interface{}))
          478                                 // if value is nil then skip detecting it's type, it will be tried
          479                                 // later to detect the type again, if that is also not possible
          480                                 // the zero value of the corresponding type is used.
          481                                 if !v.IsValid() {
          482                                         continue
          483                                 }
          484 
          485                                 var field reflect.Value
          486                                 vt := v.Type()
          487                                 if embed[i] >= 0 {
          488                                         field = t.Field(i).Field(embed[i])
          489                                 } else {
          490                                         field = t.Field(i)
          491                                 }
          492                                 // can not set value of this field (like an unexported field).
          493                                 if !field.CanSet() {
          494                                         return errors.New("cannot set value on field: " +
          495                                                 "name=\"" + names[i] + "\"" +
          496                                                 ", dbtype=\"" + vt.Name() + "\"" +
          497                                                 ", struct type=\"" + typesto[i].Name() + "\"" +
          498                                                 ", tag type=\"" + typesnames[i] + "\", unexported field?")
          499                                 }
          500 
          501                                 // NOTE: []byte to float64 is a special-case, this is not convertible
          502                                 // but we handle this ourself, see below.
          503                                 if vt == type_byteslice && (typesto[i] == type_float64 || (embed[i] >= 0 && typesnames[i] == "float64")) {
          504                                         // special case, handled later.
          505                                 } else if typesnames[i] == "" {
          506                                         // NOTE: direct type to type, must be assignable, not convertible.
          507                                         if !vt.AssignableTo(typesto[i]) {
          508                                                 return errors.New("non-assignable: " +
          509                                                         "name=\"" + names[i] + "\"" +
          510                                                         ", dbtype=\"" + vt.Name() + "\"" +
          511                                                         ", struct type=\"" + typesto[i].Name() + "\"" +
          512                                                         ", tag type=\"" + typesnames[i] + "\"")
          513                                         }
          514                                 } else if !(vt.ConvertibleTo(typesto[i])) {
          515                                         // can't convert to type.
          516                                         return errors.New("non-convertible: " +
          517                                                 "name=\"" + names[i] + "\"" +
          518                                                 ", dbtype=\"" + vt.Name() + "\"" +
          519                                                 ", struct type=\"" + typesto[i].Name() + "\"" +
          520                                                 ", tag type=\"" + typesnames[i] + "\"")
          521                                 }
          522                                 // convertible type, set it.
          523                                 fieldtypes[i] = vt
          524                         }
          525                 }
          526 
          527                 for i := 0; i < structnumfield; i++ {
          528                         ordinal := ordinals[i]
          529                         v := reflect.ValueOf(*(row.Values[ordinal]).(*interface{}))
          530                         if !v.IsValid() {
          531                                 continue
          532                         }
          533 
          534                         var field reflect.Value
          535                         if embed[i] >= 0 {
          536                                 field = t.Field(i).Field(embed[i])
          537                         } else {
          538                                 field = t.Field(i)
          539                         }
          540 
          541                         // if type was not known in first result row (for example when it is NULL)
          542                         // then set it anyway. NOTE that when a type is known at some point it must
          543                         // be the same across all rows in the resultset.
          544                         if fieldtypes[i] == nil {
          545                                 vt := v.Type()
          546                                 // can't convert to type.
          547                                 if
          548                                 // NOTE: []byte to float64 is a special-case, this is not convertible.
          549                                 (vt != type_byteslice || (typesto[i] != type_float64 && !(embed[i] >= 0 && typesnames[i] == "float64"))) &&
          550                                         !(vt.ConvertibleTo(typesto[i])) {
          551                                         return errors.New("non-convertible type: " +
          552                                                 "name=\"" + names[i] + "\"" +
          553                                                 ", dbtype=\"" + vt.Name() + "\"" +
          554                                                 ", struct type=\"" + typesto[i].Name() + "\"" +
          555                                                 ", tag type=\"" + typesnames[i] + "\"")
          556                                 }
          557                                 // convertible type, set it.
          558                                 fieldtypes[i] = vt
          559                         }
          560 
          561                         // used to convert MSSQL-types like money, decimal to float.
          562                         // these (money, decimal) types are returned as a []byte by the MSSQL driver.
          563                         if fieldtypes[i] == type_byteslice && (typesto[i] == type_float64 || (embed[i] >= 0 && typesnames[i] == "float64")) {
          564                                 f, err := row.GetDecimalToFloat(ordinal)
          565                                 if err != nil {
          566                                         return err
          567                                 }
          568                                 v := reflect.ValueOf(f)
          569                                 if typesnames[i] == "float64" {
          570                                         v = v.Convert(typesto[i])
          571                                 }
          572                                 if !v.IsValid() {
          573                                         continue
          574                                 }
          575                                 field.Set(v)
          576                                 continue
          577                         }
          578 
          579                         // fix time, only correct timezone from database without modifying the time itself.
          580                         if timeconvs[i] != "" && fieldtypes[i] == type_time {
          581                                 t, err := row.GetTime(ordinal)
          582                                 if err != nil {
          583                                         return err
          584                                 }
          585                                 t = fixtime(t, timeconv[timeconvs[i]])
          586                                 v = reflect.ValueOf(t)
          587                         }
          588 
          589                         // assign to type in struct.
          590                         if typesnames[i] == "" {
          591                                 field.Set(v)
          592                         } else {
          593                                 // try to convert to type specified in struct tag: `db:"name,type"` if valid.
          594                                 // if invalid leave as zero value for the type.
          595                                 v = v.Convert(typesto[i])
          596                                 if v.IsValid() {
          597                                         field.Set(v)
          598                                 }
          599                         }
          600                 }
          601                 vals = reflect.Append(vals, t)
          602                 total++
          603         }
          604         val.Set(vals)
          605 
          606         return
          607 }