1 // Package httpreaderat allows opening a URL as a io.ReaderAt. 2 package httpreaderat 3 4 import ( 5 "cmp" 6 "errors" 7 "fmt" 8 "io" 9 "io/fs" 10 "iter" 11 "mime" 12 "mime/multipart" 13 "net/http" 14 "strconv" 15 "strings" 16 17 "vimagination.zapto.org/cache" 18 ) 19 20 type block struct { 21 data string 22 prev, next *block 23 } 24 25 // Request represents an io.ReaderAt for an HTTP URL. 26 type Request struct { 27 url string 28 length int64 29 blockSize int64 30 cache *cache.LRU[int64, string] 31 } 32 33 // NewRequest creates a new Request object, for the given URL, that implements 34 // io.ReaderAt. 35 // 36 // Options can be passed in to modify how the maximum length is determined, and 37 // the characteristics of the block caching. 38 // 39 // By default, up to 256 4KB blocks will be cached. 40 func NewRequest(url string, opts ...Option) (*Request, error) { 41 r := &Request{url: url, length: -1, blockSize: 1 << 12} 42 43 for _, opt := range opts { 44 opt(r) 45 } 46 47 if r.length == -1 { 48 if err := r.getLength(); err != nil { 49 return nil, err 50 } 51 } 52 53 if r.cache == nil { 54 r.cache = cache.NewLRU[int64, string](256) 55 } 56 57 return r, nil 58 } 59 60 func (r *Request) getLength() error { 61 resp, err := http.Head(r.url) 62 if err != nil { 63 return err 64 } 65 66 if resp.Header.Get("Accept-Ranges") != "bytes" { 67 return ErrNoRange 68 } 69 70 if cl := resp.Header.Get("Content-Length"); cl != "" { 71 if r.length, err = strconv.ParseInt(cl, 10, 64); err != nil { 72 return fmt.Errorf("error parsing content-length: %w", err) 73 } 74 } 75 76 return nil 77 } 78 79 // ReadAt implements the io.ReaderAt interface. 80 func (r *Request) ReadAt(p []byte, n int64) (int, error) { 81 if r.length >= 0 && n > r.length { 82 return 0, io.EOF 83 } else if n < 0 { 84 return 0, fs.ErrInvalid 85 } 86 87 p = p[:min(int64(len(p)), r.length-n)] 88 89 blocks, err := r.getBlocks(n, int64(len(p))) 90 if err != nil { 91 return 0, err 92 } 93 94 blocks[0] = blocks[0][n%r.blockSize:] 95 96 l := len(p) 97 98 for _, block := range blocks { 99 p = p[copy(p, block):] 100 } 101 102 return l, nil 103 } 104 105 func (r *Request) Length() int64 { 106 return r.length 107 } 108 109 func (r *Request) getBlocks(start, length int64) ([]string, error) { 110 blocks, requests := r.getExistingBlocks(start, length) 111 if len(requests) == 0 { 112 return blocks, nil 113 } 114 115 if err := r.setNewBlocks(blocks, requests, start); err != nil { 116 return nil, err 117 } 118 119 return blocks, nil 120 } 121 122 type requests [][2]int64 123 124 func (r requests) Iter() iter.Seq[int64] { 125 return func(yield func(int64) bool) { 126 for _, request := range r { 127 for b := request[0]; b <= request[1]; b++ { 128 if !yield(b) { 129 return 130 } 131 } 132 } 133 } 134 } 135 136 func (r *Request) getExistingBlocks(start, length int64) ([]string, requests) { 137 var ( 138 blocks []string 139 requests requests 140 ) 141 142 startBlock := start / r.blockSize 143 144 for block := startBlock; block <= (start+length-1)/r.blockSize; block++ { 145 b, ok := r.cache.Get(block) 146 if ok { 147 blocks = append(blocks, b) 148 149 continue 150 } 151 152 blocks = append(blocks, "") 153 154 if len(requests) > 0 && requests[len(requests)-1][1]+1 == block { 155 requests[len(requests)-1][1] = block 156 } else { 157 requests = append(requests, [2]int64{block, block}) 158 } 159 } 160 161 return blocks, requests 162 } 163 164 func (r *Request) setNewBlocks(blocks []string, requests requests, start int64) error { 165 resp, err := r.requestBlocks(requests) 166 if err != nil { 167 return err 168 } 169 170 startBlock := start / r.blockSize 171 buf := make([]byte, r.blockSize) 172 173 rr, err := makeReader(resp) 174 if err != nil { 175 return err 176 } 177 178 for block := range requests.Iter() { 179 n, err := io.ReadFull(rr, buf[:cmp.Or(min(r.length, (block+1)*r.blockSize)%r.blockSize, r.blockSize)]) 180 if err != nil { 181 return err 182 } 183 184 data := string(buf[:n]) 185 blocks[block-startBlock] = data 186 187 r.cache.Set(block, data) 188 } 189 190 return resp.Body.Close() 191 } 192 193 func (r *Request) requestBlocks(requests requests) (*http.Response, error) { 194 req, err := http.NewRequest(http.MethodGet, r.url, nil) 195 if err != nil { 196 return nil, err 197 } 198 199 req.Header.Set("Range", r.makeByteRangeHeader(requests)) 200 201 resp, err := http.DefaultClient.Do(req) 202 if err != nil { 203 return nil, err 204 } 205 206 return resp, nil 207 } 208 209 func (r *Request) makeByteRangeHeader(requests requests) string { 210 var byteRange strings.Builder 211 212 byteRange.WriteString("bytes=") 213 214 for n, request := range requests { 215 if n > 0 { 216 byteRange.WriteByte(',') 217 } 218 219 fmt.Fprintf(&byteRange, "%d-%d", request[0]*r.blockSize, min((request[1]+1)*r.blockSize, r.length)-1) 220 } 221 222 return byteRange.String() 223 } 224 225 func makeReader(resp *http.Response) (io.Reader, error) { 226 var rr io.Reader = resp.Body 227 228 if mt, params, err := mime.ParseMediaType(resp.Header.Get("Content-Type")); err != nil { 229 return nil, err 230 } else if mt == "multipart/byteranges" { 231 mr := multipart.NewReader(resp.Body, params["boundary"]) 232 233 p, err := mr.NextPart() 234 if err != nil { 235 return nil, err 236 } 237 238 rr = &multipartReader{mr: mr, p: p} 239 } 240 241 return rr, nil 242 } 243 244 type multipartReader struct { 245 mr *multipart.Reader 246 p *multipart.Part 247 } 248 249 func (m *multipartReader) Read(p []byte) (int, error) { 250 n, err := m.p.Read(p) 251 if n == 0 && err == io.EOF { 252 var err error 253 254 if m.p, err = m.mr.NextPart(); err != nil { 255 return n, err 256 } 257 258 return m.Read(p) 259 } 260 261 return n, err 262 } 263 264 // Clear clears the block cache. 265 func (r *Request) Clear() { 266 r.cache.Clear() 267 } 268 269 // Errors. 270 var ErrNoRange = errors.New("no range header") 271