从这一部分开始,我们将不再聚焦于基本的操作细节,而是更多地做一些有特点的修改(或者说"魔改")。
这篇博客记录的是:不预加载数据,而是在回测过程中实时加载数据的做法。
各模块解释
import pandas as pd
import backtrader as bt
from loguru import logger
from datetime import datetime
from backtrader.feed import DataBase
from backtrader import date2num
import efinance
1. 获取K线数据
这是一个借助 `efinance` 库获取历史行情数据的函数:
def get_k_data(stock_code, begin: datetime, end: datetime) -> pd.DataFrame:
    """
    Fetch historical K-line data for one stock through the efinance package.

    :param stock_code: stock code, e.g. '600519'
    :param begin: start date of the requested range (sent as YYYYMMDD)
    :param end: end date of the requested range (sent as YYYYMMDD)
    :return: DataFrame indexed by the parsed 'date' column, holding the
        columns 'open', 'close', 'high', 'low', 'volume' and 'turnover'
    """
    quotes: pd.DataFrame = efinance.stock.get_quote_history(
        stock_code, beg=begin.strftime("%Y%m%d"), end=end.strftime("%Y%m%d"))
    # Take an explicit copy of the first nine columns so the renaming below
    # never operates on a slice of the frame returned by efinance
    # (avoids pandas chained-assignment warnings).
    # NOTE(review): the positional column order is assumed to match the
    # names assigned below -- confirm against the installed efinance version.
    k_dataframe = quotes.iloc[:, :9].copy()
    k_dataframe.columns = ['name', 'code', 'date', 'open', 'close', 'high', 'low', 'volume', "turnover"]
    k_dataframe.index = pd.to_datetime(k_dataframe['date'])
    # The date now lives in the index; name/code are not needed downstream.
    return k_dataframe.drop(columns=['name', 'code', 'date'])
2. 可迭代数据类
class StockData(DataBase):
    """Custom data feed that delivers bars one row at a time from a DataFrame.

    Adds a ``turnover`` line on top of the lines inherited from ``DataBase``.
    """
    params = (('turnover', -1),)
    lines = ('turnover',)

    def __init__(self, stock_code, stock_dataframe):
        """
        :param stock_code: identifier kept as ``self.dataname`` and used in
            the debug log output
        :param stock_dataframe: DataFrame whose index holds the bar datetimes
            and which has 'open', 'high', 'low', 'close', 'volume' and
            'turnover' columns
        """
        self.dataname = stock_code
        # Sort by index so rows are served in ascending index order.
        self.stock_dataframe = stock_dataframe.sort_index()
        self.stock_iter = self.stock_dataframe.iterrows()

    def start(self):
        """Rewind the row iterator whenever the feed is (re)started.

        Without this, a feed that has already been consumed once would
        deliver no bars if it is started again.
        """
        super().start()
        self.stock_iter = self.stock_dataframe.iterrows()

    def _load(self):
        """Fill the current bar from the next DataFrame row.

        Author's note: this plays a role similar to a strategy's ``next()``;
        it is expected to run once for every ``next()`` call.

        :return: ``True`` when a bar was loaded, ``False`` when the rows are
            exhausted
        """
        _series_datetime, _series_data = next(self.stock_iter, (None, {}))
        if _series_datetime is None:
            return False

        # The datetime line stores numbers, so the timestamp is converted
        # with date2num before being written.
        self.lines.datetime[0] = date2num(_series_datetime)
        self.lines.open[0] = float(_series_data['open'])
        self.lines.high[0] = float(_series_data['high'])
        self.lines.low[0] = float(_series_data['low'])
        self.lines.close[0] = float(_series_data['close'])
        self.lines.volume[0] = int(_series_data['volume'])
        self.lines.turnover[0] = float(_series_data['turnover'])
        logger.debug("Load数据 {} TIME {} 收盘价 {}".format(self.dataname, _series_datetime, _series_data['close']))
        return True
这是一个可以加载【日期,开盘,最高,最低,收盘,成交量,换手】的数据类,其中:
在初始化的时候:
- `self.dataname`:在策略中使用 `self.getdatabyname()` 时会与这个值对应
- `self.lines.close[0]`:在策略中使用 `data.close[0]` 时会与这个值对应
其中时间部分需要注意使用 `date2num()` 进行转换;在 `_load(self)` 中使用如下代码可以得到上一天的日期:
datetime.fromordinal(int(self.datetime[-1])) # 只能解析成 date
或使用:
self.datetime.datetime(-1)
3. 策略模块
我们什么交易都不执行,只核对数据:
class TestStrategy(bt.Strategy):  # strategy
    """Strategy that places no orders; it only prints each feed's close price."""

    def __init__(self):
        # Collect the names of every attached feed whose class is named
        # "StockData"; next() iterates over exactly these names.
        self.total_bond_name = [
            feed._name
            for feed in self.datas
            if type(feed).__name__ == "StockData"
        ]

    def next(self):
        """Print datetime, feed name and close price for each tracked feed.

        Author's note: the framework calls next() repeatedly, once per bar.
        """
        for bond_name in self.total_bond_name:
            current_time = self.datetime.datetime(0)
            close_price = self.getdatabyname(bond_name).close[0]
            print("{} {} 收盘价 {}".format(current_time, bond_name, close_price))
        # Blank line separating the output of consecutive bars.
        print()
示例代码
import pandas as pd
import backtrader as bt
from loguru import logger
from datetime import datetime
from backtrader.feed import DataBase
from backtrader import date2num
import efinance
def get_k_data(stock_code, begin: datetime, end: datetime) -> pd.DataFrame:
    """
    Fetch historical K-line data for one stock through the efinance package.

    :param stock_code: stock code, e.g. '600519'
    :param begin: start date of the requested range (sent as YYYYMMDD)
    :param end: end date of the requested range (sent as YYYYMMDD)
    :return: DataFrame indexed by the parsed 'date' column, holding the
        columns 'open', 'close', 'high', 'low', 'volume' and 'turnover'
    """
    quotes: pd.DataFrame = efinance.stock.get_quote_history(
        stock_code, beg=begin.strftime("%Y%m%d"), end=end.strftime("%Y%m%d"))
    # Take an explicit copy of the first nine columns so the renaming below
    # never operates on a slice of the frame returned by efinance
    # (avoids pandas chained-assignment warnings).
    # NOTE(review): the positional column order is assumed to match the
    # names assigned below -- confirm against the installed efinance version.
    k_dataframe = quotes.iloc[:, :9].copy()
    k_dataframe.columns = ['name', 'code', 'date', 'open', 'close', 'high', 'low', 'volume', "turnover"]
    k_dataframe.index = pd.to_datetime(k_dataframe['date'])
    # The date now lives in the index; name/code are not needed downstream.
    return k_dataframe.drop(columns=['name', 'code', 'date'])
class TestStrategy(bt.Strategy):  # strategy
    """Strategy that places no orders; it only prints each feed's close price."""

    def __init__(self):
        # Collect the names of every attached feed whose class is named
        # "StockData"; next() iterates over exactly these names.
        self.total_bond_name = [
            feed._name
            for feed in self.datas
            if type(feed).__name__ == "StockData"
        ]

    def next(self):
        """Print datetime, feed name and close price for each tracked feed.

        Author's note: the framework calls next() repeatedly, once per bar.
        """
        for bond_name in self.total_bond_name:
            current_time = self.datetime.datetime(0)
            close_price = self.getdatabyname(bond_name).close[0]
            print("{} {} 收盘价 {}".format(current_time, bond_name, close_price))
        # Blank line separating the output of consecutive bars.
        print()
class StockData(DataBase):
    """Custom data feed that delivers bars one row at a time from a DataFrame.

    Adds a ``turnover`` line on top of the lines inherited from ``DataBase``.
    """
    params = (('turnover', -1),)
    lines = ('turnover',)

    def __init__(self, stock_code, stock_dataframe):
        """
        :param stock_code: identifier kept as ``self.dataname`` and used in
            the debug log output
        :param stock_dataframe: DataFrame whose index holds the bar datetimes
            and which has 'open', 'high', 'low', 'close', 'volume' and
            'turnover' columns
        """
        self.dataname = stock_code
        # Sort by index so rows are served in ascending index order.
        self.stock_dataframe = stock_dataframe.sort_index()
        self.stock_iter = self.stock_dataframe.iterrows()

    def start(self):
        """Rewind the row iterator whenever the feed is (re)started.

        Without this, a feed that has already been consumed once would
        deliver no bars if it is started again.
        """
        super().start()
        self.stock_iter = self.stock_dataframe.iterrows()

    def _load(self):
        """Fill the current bar from the next DataFrame row.

        Author's note: this plays a role similar to a strategy's ``next()``;
        it is expected to run once for every ``next()`` call.

        :return: ``True`` when a bar was loaded, ``False`` when the rows are
            exhausted
        """
        _series_datetime, _series_data = next(self.stock_iter, (None, {}))
        if _series_datetime is None:
            return False

        # The datetime line stores numbers, so the timestamp is converted
        # with date2num before being written.
        self.lines.datetime[0] = date2num(_series_datetime)
        self.lines.open[0] = float(_series_data['open'])
        self.lines.high[0] = float(_series_data['high'])
        self.lines.low[0] = float(_series_data['low'])
        self.lines.close[0] = float(_series_data['close'])
        self.lines.volume[0] = int(_series_data['volume'])
        self.lines.turnover[0] = float(_series_data['turnover'])
        logger.debug("Load数据 {} TIME {} 收盘价 {}".format(self.dataname, _series_datetime, _series_data['close']))
        return True
def load_stock_data(start_date: datetime, end_date: datetime,
                    stock_codes=("000636", "600519")):
    """Build one StockData feed per stock code.

    :param start_date: start date passed on to get_k_data
    :param end_date: end date passed on to get_k_data
    :param stock_codes: iterable of stock codes to load; defaults to the two
        codes that were previously hard-coded
    :return: dict mapping each stock code to its StockData feed
    """
    return {
        stock_code: StockData(stock_code, get_k_data(stock_code, start_date, end_date))
        for stock_code in stock_codes
    }
def main():
    """Attach the stock feeds and TestStrategy to Cerebro and run it with preload disabled."""
    start_date = datetime(2019, 1, 1)  # first day of the backtest window
    end_date = datetime(2020, 1, 1)  # last day of the backtest window
    initial_cash = 10000  # starting capital
    commission = 1 / 1000  # friction cost, applied to both buys and sells (author's note)

    # Data sources: register each feed under its stock code.
    cerebral_system = bt.Cerebro()
    feeds = load_stock_data(start_date, end_date)
    for code, feed in feeds.items():
        cerebral_system.adddata(feed, name=code)

    # Broker configuration.
    cerebral_system.broker.setcash(initial_cash)
    cerebral_system.broker.setcommission(commission=commission)
    logger.debug('初始资金: {} 回测期间:from {} to {}'.format(initial_cash, start_date, end_date))

    # Run the backtest; preload=False asks Cerebro not to preload the feeds.
    logger.info("Run cerebral system")
    cerebral_system.addstrategy(TestStrategy)
    cerebral_system.run(preload=False)


if __name__ == '__main__':
    main()
得到的结果如下:
2022-12-01 14:44:35.485 | DEBUG | __main__:main:98 - 初始资金: 10000 回测期间:from 2019-01-01 00:00:00 to 2020-01-01 00:00:00
2022-12-01 14:44:35.485 | INFO | __main__:main:100 - Run cerebral system
2022-12-01 14:44:35.489 | DEBUG | __main__:_load:74 - Load数据 000636 TIME 2019-01-02 00:00:00 收盘价 10.17
2022-12-01 14:44:35.490 | DEBUG | __main__:_load:74 - Load数据 600519 TIME 2019-01-02 00:00:00 收盘价 526.45
2019-01-02 00:00:00 000636 收盘价 10.17
2019-01-02 00:00:00 600519 收盘价 526.45
2019-01-03 00:00:00 000636 收盘价 10.46
2019-01-03 00:00:00 600519 收盘价 517.47
2022-12-01 14:44:42.082 | DEBUG | __main__:_load:74 - Load数据 000636 TIME 2019-01-03 00:00:00 收盘价 10.46
2022-12-01 14:44:42.082 | DEBUG | __main__:_load:74 - Load数据 600519 TIME 2019-01-03 00:00:00 收盘价 517.47
2019-01-04 00:00:00 000636 收盘价 10.72
2019-01-04 00:00:00 600519 收盘价 529.47
2022-12-01 14:44:42.820 | DEBUG | __main__:_load:74 - Load数据 000636 TIME 2019-01-04 00:00:00 收盘价 10.72
2022-12-01 14:44:42.821 | DEBUG | __main__:_load:74 - Load数据 600519 TIME 2019-01-04 00:00:00 收盘价 529.47
2019-01-07 00:00:00 000636 收盘价 11.06
2019-01-07 00:00:00 600519 收盘价 532.96
2022-12-01 14:44:43.661 | DEBUG | __main__:_load:74 - Load数据 000636 TIME 2019-01-07 00:00:00 收盘价 11.06
2022-12-01 14:44:43.662 | DEBUG | __main__:_load:74 - Load数据 600519 TIME 2019-01-07 00:00:00 收盘价 532.96
2019-01-08 00:00:00 000636 收盘价 10.8
2019-01-08 00:00:00 600519 收盘价 532.26
2022-12-01 14:44:46.996 | DEBUG | __main__:_load:74 - Load数据 000636 TIME 2019-01-08 00:00:00 收盘价 10.8
2022-12-01 14:44:46.997 | DEBUG | __main__:_load:74 - Load数据 600519 TIME 2019-01-08 00:00:00 收盘价 532.26
数据是一致的,因此,我们可以在 `def _load()` 里添加任何我们想读取的数据,比如实时从数据库里获取。